summaryrefslogtreecommitdiffstats
path: root/src/rocksdb/util
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/rocksdb/util/aligned_buffer.h234
-rw-r--r--src/rocksdb/util/async_file_reader.cc73
-rw-r--r--src/rocksdb/util/async_file_reader.h144
-rw-r--r--src/rocksdb/util/autovector.h406
-rw-r--r--src/rocksdb/util/autovector_test.cc331
-rw-r--r--src/rocksdb/util/bloom_impl.h489
-rw-r--r--src/rocksdb/util/bloom_test.cc1175
-rw-r--r--src/rocksdb/util/build_version.cc.in81
-rw-r--r--src/rocksdb/util/cast_util.h42
-rw-r--r--src/rocksdb/util/channel.h69
-rw-r--r--src/rocksdb/util/cleanable.cc181
-rw-r--r--src/rocksdb/util/coding.cc90
-rw-r--r--src/rocksdb/util/coding.h389
-rw-r--r--src/rocksdb/util/coding_lean.h101
-rw-r--r--src/rocksdb/util/coding_test.cc217
-rw-r--r--src/rocksdb/util/compaction_job_stats_impl.cc100
-rw-r--r--src/rocksdb/util/comparator.cc391
-rw-r--r--src/rocksdb/util/compression.cc122
-rw-r--r--src/rocksdb/util/compression.h1786
-rw-r--r--src/rocksdb/util/compression_context_cache.cc106
-rw-r--r--src/rocksdb/util/compression_context_cache.h47
-rw-r--r--src/rocksdb/util/concurrent_task_limiter_impl.cc64
-rw-r--r--src/rocksdb/util/concurrent_task_limiter_impl.h67
-rw-r--r--src/rocksdb/util/core_local.h83
-rw-r--r--src/rocksdb/util/coro_utils.h112
-rw-r--r--src/rocksdb/util/crc32c.cc1351
-rw-r--r--src/rocksdb/util/crc32c.h56
-rw-r--r--src/rocksdb/util/crc32c_arm64.cc215
-rw-r--r--src/rocksdb/util/crc32c_arm64.h52
-rw-r--r--src/rocksdb/util/crc32c_ppc.c94
-rw-r--r--src/rocksdb/util/crc32c_ppc.h22
-rw-r--r--src/rocksdb/util/crc32c_ppc_asm.S756
-rw-r--r--src/rocksdb/util/crc32c_ppc_constants.h900
-rw-r--r--src/rocksdb/util/crc32c_test.cc213
-rw-r--r--src/rocksdb/util/defer.h82
-rw-r--r--src/rocksdb/util/defer_test.cc51
-rw-r--r--src/rocksdb/util/distributed_mutex.h48
-rw-r--r--src/rocksdb/util/duplicate_detector.h71
-rw-r--r--src/rocksdb/util/dynamic_bloom.cc70
-rw-r--r--src/rocksdb/util/dynamic_bloom.h214
-rw-r--r--src/rocksdb/util/dynamic_bloom_test.cc325
-rw-r--r--src/rocksdb/util/fastrange.h114
-rw-r--r--src/rocksdb/util/file_checksum_helper.cc172
-rw-r--r--src/rocksdb/util/file_checksum_helper.h100
-rw-r--r--src/rocksdb/util/file_reader_writer_test.cc1066
-rw-r--r--src/rocksdb/util/filelock_test.cc148
-rw-r--r--src/rocksdb/util/filter_bench.cc840
-rw-r--r--src/rocksdb/util/gflags_compat.h30
-rw-r--r--src/rocksdb/util/hash.cc201
-rw-r--r--src/rocksdb/util/hash.h137
-rw-r--r--src/rocksdb/util/hash128.h26
-rw-r--r--src/rocksdb/util/hash_containers.h51
-rw-r--r--src/rocksdb/util/hash_map.h67
-rw-r--r--src/rocksdb/util/hash_test.cc853
-rw-r--r--src/rocksdb/util/heap.h174
-rw-r--r--src/rocksdb/util/heap_test.cc131
-rw-r--r--src/rocksdb/util/kv_map.h33
-rw-r--r--src/rocksdb/util/log_write_bench.cc88
-rw-r--r--src/rocksdb/util/math.h294
-rw-r--r--src/rocksdb/util/math128.h316
-rw-r--r--src/rocksdb/util/murmurhash.cc196
-rw-r--r--src/rocksdb/util/murmurhash.h43
-rw-r--r--src/rocksdb/util/mutexlock.h180
-rw-r--r--src/rocksdb/util/ppc-opcode.h27
-rw-r--r--src/rocksdb/util/random.cc62
-rw-r--r--src/rocksdb/util/random.h190
-rw-r--r--src/rocksdb/util/random_test.cc107
-rw-r--r--src/rocksdb/util/rate_limiter.cc378
-rw-r--r--src/rocksdb/util/rate_limiter.h146
-rw-r--r--src/rocksdb/util/rate_limiter_test.cc476
-rw-r--r--src/rocksdb/util/repeatable_thread.h149
-rw-r--r--src/rocksdb/util/repeatable_thread_test.cc111
-rw-r--r--src/rocksdb/util/ribbon_alg.h1225
-rw-r--r--src/rocksdb/util/ribbon_config.cc506
-rw-r--r--src/rocksdb/util/ribbon_config.h182
-rw-r--r--src/rocksdb/util/ribbon_impl.h1137
-rw-r--r--src/rocksdb/util/ribbon_test.cc1308
-rw-r--r--src/rocksdb/util/set_comparator.h24
-rw-r--r--src/rocksdb/util/single_thread_executor.h56
-rw-r--r--src/rocksdb/util/slice.cc405
-rw-r--r--src/rocksdb/util/slice_test.cc191
-rw-r--r--src/rocksdb/util/slice_transform_test.cc154
-rw-r--r--src/rocksdb/util/status.cc154
-rw-r--r--src/rocksdb/util/stderr_logger.cc30
-rw-r--r--src/rocksdb/util/stderr_logger.h31
-rw-r--r--src/rocksdb/util/stop_watch.h118
-rw-r--r--src/rocksdb/util/string_util.cc504
-rw-r--r--src/rocksdb/util/string_util.h177
-rw-r--r--src/rocksdb/util/thread_guard.h41
-rw-r--r--src/rocksdb/util/thread_list_test.cc360
-rw-r--r--src/rocksdb/util/thread_local.cc521
-rw-r--r--src/rocksdb/util/thread_local.h100
-rw-r--r--src/rocksdb/util/thread_local_test.cc582
-rw-r--r--src/rocksdb/util/thread_operation.h112
-rw-r--r--src/rocksdb/util/threadpool_imp.cc551
-rw-r--r--src/rocksdb/util/threadpool_imp.h120
-rw-r--r--src/rocksdb/util/timer.h340
-rw-r--r--src/rocksdb/util/timer_queue.h231
-rw-r--r--src/rocksdb/util/timer_queue_test.cc73
-rw-r--r--src/rocksdb/util/timer_test.cc402
-rw-r--r--src/rocksdb/util/user_comparator_wrapper.h64
-rw-r--r--src/rocksdb/util/vector_iterator.h118
-rw-r--r--src/rocksdb/util/work_queue.h150
-rw-r--r--src/rocksdb/util/work_queue_test.cc272
-rw-r--r--src/rocksdb/util/xxhash.cc48
-rw-r--r--src/rocksdb/util/xxhash.h5346
-rw-r--r--src/rocksdb/util/xxph3.h1764
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge.cc238
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge.h49
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge_test.cc135
-rw-r--r--src/rocksdb/utilities/agg_merge/test_agg_merge.cc104
-rw-r--r--src/rocksdb/utilities/agg_merge/test_agg_merge.h47
-rw-r--r--src/rocksdb/utilities/backup/backup_engine.cc3181
-rw-r--r--src/rocksdb/utilities/backup/backup_engine_impl.h36
-rw-r--r--src/rocksdb/utilities/backup/backup_engine_test.cc4219
-rw-r--r--src/rocksdb/utilities/blob_db/blob_compaction_filter.cc490
-rw-r--r--src/rocksdb/utilities/blob_db/blob_compaction_filter.h204
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db.cc114
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db.h266
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_gc_stats.h56
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_impl.cc2177
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_impl.h503
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc113
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_iterator.h150
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_listener.h71
-rw-r--r--src/rocksdb/utilities/blob_db/blob_db_test.cc2407
-rw-r--r--src/rocksdb/utilities/blob_db/blob_dump_tool.cc282
-rw-r--r--src/rocksdb/utilities/blob_db/blob_dump_tool.h58
-rw-r--r--src/rocksdb/utilities/blob_db/blob_file.cc318
-rw-r--r--src/rocksdb/utilities/blob_db/blob_file.h246
-rw-r--r--src/rocksdb/utilities/cache_dump_load.cc69
-rw-r--r--src/rocksdb/utilities/cache_dump_load_impl.cc393
-rw-r--r--src/rocksdb/utilities/cache_dump_load_impl.h359
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc110
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h57
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_format_test.cc377
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_functional_test.cc446
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_options.h43
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc98
-rw-r--r--src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc164
-rw-r--r--src/rocksdb/utilities/cassandra/format.cc367
-rw-r--r--src/rocksdb/utilities/cassandra/format.h183
-rw-r--r--src/rocksdb/utilities/cassandra/merge_operator.cc82
-rw-r--r--src/rocksdb/utilities/cassandra/merge_operator.h44
-rw-r--r--src/rocksdb/utilities/cassandra/serialize.h81
-rw-r--r--src/rocksdb/utilities/cassandra/test_utils.cc69
-rw-r--r--src/rocksdb/utilities/cassandra/test_utils.h42
-rw-r--r--src/rocksdb/utilities/checkpoint/checkpoint_impl.cc469
-rw-r--r--src/rocksdb/utilities/checkpoint/checkpoint_impl.h66
-rw-r--r--src/rocksdb/utilities/checkpoint/checkpoint_test.cc974
-rw-r--r--src/rocksdb/utilities/compaction_filters.cc56
-rw-r--r--src/rocksdb/utilities/compaction_filters/layered_compaction_filter_base.h41
-rw-r--r--src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc26
-rw-r--r--src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h28
-rw-r--r--src/rocksdb/utilities/convenience/info_log_finder.cc26
-rw-r--r--src/rocksdb/utilities/counted_fs.cc379
-rw-r--r--src/rocksdb/utilities/counted_fs.h158
-rw-r--r--src/rocksdb/utilities/debug.cc120
-rw-r--r--src/rocksdb/utilities/env_mirror.cc275
-rw-r--r--src/rocksdb/utilities/env_mirror_test.cc226
-rw-r--r--src/rocksdb/utilities/env_timed.cc187
-rw-r--r--src/rocksdb/utilities/env_timed.h97
-rw-r--r--src/rocksdb/utilities/env_timed_test.cc44
-rw-r--r--src/rocksdb/utilities/fault_injection_env.cc555
-rw-r--r--src/rocksdb/utilities/fault_injection_env.h258
-rw-r--r--src/rocksdb/utilities/fault_injection_fs.cc1032
-rw-r--r--src/rocksdb/utilities/fault_injection_fs.h584
-rw-r--r--src/rocksdb/utilities/fault_injection_secondary_cache.cc131
-rw-r--r--src/rocksdb/utilities/fault_injection_secondary_cache.h108
-rw-r--r--src/rocksdb/utilities/leveldb_options/leveldb_options.cc57
-rw-r--r--src/rocksdb/utilities/memory/memory_test.cc279
-rw-r--r--src/rocksdb/utilities/memory/memory_util.cc52
-rw-r--r--src/rocksdb/utilities/memory_allocators.h104
-rw-r--r--src/rocksdb/utilities/merge_operators.cc120
-rw-r--r--src/rocksdb/utilities/merge_operators.h36
-rw-r--r--src/rocksdb/utilities/merge_operators/bytesxor.cc57
-rw-r--r--src/rocksdb/utilities/merge_operators/bytesxor.h40
-rw-r--r--src/rocksdb/utilities/merge_operators/max.cc80
-rw-r--r--src/rocksdb/utilities/merge_operators/put.cc92
-rw-r--r--src/rocksdb/utilities/merge_operators/sortlist.cc95
-rw-r--r--src/rocksdb/utilities/merge_operators/sortlist.h42
-rw-r--r--src/rocksdb/utilities/merge_operators/string_append/stringappend.cc78
-rw-r--r--src/rocksdb/utilities/merge_operators/string_append/stringappend.h32
-rw-r--r--src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc132
-rw-r--r--src/rocksdb/utilities/merge_operators/string_append/stringappend2.h52
-rw-r--r--src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc640
-rw-r--r--src/rocksdb/utilities/merge_operators/uint64add.cc75
-rw-r--r--src/rocksdb/utilities/object_registry.cc383
-rw-r--r--src/rocksdb/utilities/object_registry_test.cc872
-rw-r--r--src/rocksdb/utilities/option_change_migration/option_change_migration.cc186
-rw-r--r--src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc550
-rw-r--r--src/rocksdb/utilities/options/options_util.cc159
-rw-r--r--src/rocksdb/utilities/options/options_util_test.cc779
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier.cc422
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier.h156
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc610
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h293
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h127
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc86
-rw-r--r--src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h124
-rw-r--r--src/rocksdb/utilities/persistent_cache/hash_table.h239
-rw-r--r--src/rocksdb/utilities/persistent_cache/hash_table_bench.cc310
-rw-r--r--src/rocksdb/utilities/persistent_cache/hash_table_evictable.h168
-rw-r--r--src/rocksdb/utilities/persistent_cache/hash_table_test.cc163
-rw-r--r--src/rocksdb/utilities/persistent_cache/lrulist.h174
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc359
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc462
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_test.h286
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc167
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h342
-rw-r--r--src/rocksdb/utilities/persistent_cache/persistent_cache_util.h67
-rw-r--r--src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc140
-rw-r--r--src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h141
-rw-r--r--src/rocksdb/utilities/simulator_cache/cache_simulator.cc288
-rw-r--r--src/rocksdb/utilities/simulator_cache/cache_simulator.h231
-rw-r--r--src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc497
-rw-r--r--src/rocksdb/utilities/simulator_cache/sim_cache.cc364
-rw-r--r--src/rocksdb/utilities/simulator_cache/sim_cache_test.cc226
-rw-r--r--src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc227
-rw-r--r--src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h70
-rw-r--r--src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc245
-rw-r--r--src/rocksdb/utilities/trace/file_trace_reader_writer.cc133
-rw-r--r--src/rocksdb/utilities/trace/file_trace_reader_writer.h48
-rw-r--r--src/rocksdb/utilities/trace/replayer_impl.cc316
-rw-r--r--src/rocksdb/utilities/trace/replayer_impl.h86
-rw-r--r--src/rocksdb/utilities/transactions/lock/lock_manager.cc29
-rw-r--r--src/rocksdb/utilities/transactions/lock/lock_manager.h82
-rw-r--r--src/rocksdb/utilities/transactions/lock/lock_tracker.h209
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_manager.cc721
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_manager.h224
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.cc181
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.h324
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.cc257
-rw-r--r--src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.h99
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_lock_manager.h36
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_locking_test.cc459
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.AGPLv3661
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.APACHEv2174
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.GPLv2339
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/README13
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/db.h76
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/comparator.h138
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/ft-status.h102
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc139
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.h174
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc222
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.h141
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.cc527
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.h255
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.cc1023
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.h580
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/manager.cc527
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.cc265
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.h178
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.cc520
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.h302
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.cc120
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.h92
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.cc213
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.h124
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/memory.h215
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_assert_subst.h39
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_atomic.h130
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_external_pthread.h83
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_instrumentation.h286
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_portability.h87
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_pthread.h520
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_race_tools.h179
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_time.h193
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/txn_subst.h27
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/standalone_port.cc132
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.cc153
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.h98
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/growable_array.h144
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.cc201
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.h141
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt.h794
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt_impl.h1295
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/partitioned_counter.h165
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/status.h76
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.cc503
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h137
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.cc156
-rw-r--r--src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h146
-rw-r--r--src/rocksdb/utilities/transactions/optimistic_transaction.cc196
-rw-r--r--src/rocksdb/utilities/transactions/optimistic_transaction.h101
-rw-r--r--src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc111
-rw-r--r--src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h88
-rw-r--r--src/rocksdb/utilities/transactions/optimistic_transaction_test.cc1491
-rw-r--r--src/rocksdb/utilities/transactions/pessimistic_transaction.cc1175
-rw-r--r--src/rocksdb/utilities/transactions/pessimistic_transaction.h313
-rw-r--r--src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc782
-rw-r--r--src/rocksdb/utilities/transactions/pessimistic_transaction_db.h318
-rw-r--r--src/rocksdb/utilities/transactions/snapshot_checker.cc53
-rw-r--r--src/rocksdb/utilities/transactions/timestamped_snapshot_test.cc466
-rw-r--r--src/rocksdb/utilities/transactions/transaction_base.cc731
-rw-r--r--src/rocksdb/utilities/transactions/transaction_base.h384
-rw-r--r--src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc135
-rw-r--r--src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h26
-rw-r--r--src/rocksdb/utilities/transactions/transaction_test.cc6550
-rw-r--r--src/rocksdb/utilities/transactions/transaction_test.h578
-rw-r--r--src/rocksdb/utilities/transactions/transaction_util.cc206
-rw-r--r--src/rocksdb/utilities/transactions/transaction_util.h85
-rw-r--r--src/rocksdb/utilities/transactions/write_committed_transaction_ts_test.cc588
-rw-r--r--src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc4078
-rw-r--r--src/rocksdb/utilities/transactions/write_prepared_txn.cc512
-rw-r--r--src/rocksdb/utilities/transactions/write_prepared_txn.h119
-rw-r--r--src/rocksdb/utilities/transactions/write_prepared_txn_db.cc1030
-rw-r--r--src/rocksdb/utilities/transactions/write_prepared_txn_db.h1125
-rw-r--r--src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc790
-rw-r--r--src/rocksdb/utilities/transactions/write_unprepared_txn.cc1053
-rw-r--r--src/rocksdb/utilities/transactions/write_unprepared_txn.h341
-rw-r--r--src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc473
-rw-r--r--src/rocksdb/utilities/transactions/write_unprepared_txn_db.h108
-rw-r--r--src/rocksdb/utilities/ttl/db_ttl_impl.cc609
-rw-r--r--src/rocksdb/utilities/ttl/db_ttl_impl.h245
-rw-r--r--src/rocksdb/utilities/ttl/ttl_test.cc912
-rw-r--r--src/rocksdb/utilities/util_merge_operators_test.cc100
-rw-r--r--src/rocksdb/utilities/wal_filter.cc23
-rw-r--r--src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc695
-rw-r--r--src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc735
-rw-r--r--src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h344
-rw-r--r--src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc2419
323 files changed, 118084 insertions, 0 deletions
diff --git a/src/rocksdb/util/aligned_buffer.h b/src/rocksdb/util/aligned_buffer.h
new file mode 100644
index 000000000..95ee5dfe8
--- /dev/null
+++ b/src/rocksdb/util/aligned_buffer.h
@@ -0,0 +1,234 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#pragma once
+
+#include <algorithm>
+
+#include "port/port.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// This file contains utilities to handle the alignment of pages and buffers.
+
+// Truncate to a multiple of page_size, which is also a page boundary. This
+// helps to figuring out the right alignment.
+// Example:
+// TruncateToPageBoundary(4096, 5000) => 4096
+// TruncateToPageBoundary((4096, 10000) => 8192
+inline size_t TruncateToPageBoundary(size_t page_size, size_t s) {
+ s -= (s & (page_size - 1));
+ assert((s % page_size) == 0);
+ return s;
+}
+
+// Round up x to a multiple of y.
+// Example:
+// Roundup(13, 5) => 15
+// Roundup(201, 16) => 208
+inline size_t Roundup(size_t x, size_t y) { return ((x + y - 1) / y) * y; }
+
+// Round down x to a multiple of y.
+// Example:
+// Rounddown(13, 5) => 10
+// Rounddown(201, 16) => 192
+inline size_t Rounddown(size_t x, size_t y) { return (x / y) * y; }
+
+// AlignedBuffer manages a buffer by taking alignment into consideration, and
+// aligns the buffer start and end positions. It is mainly used for direct I/O,
+// though it can be used other purposes as well.
+// It also supports expanding the managed buffer, and copying whole or part of
+// the data from old buffer into the new expanded buffer. Such a copy especially
+// helps in cases avoiding an IO to re-fetch the data from disk.
+//
+// Example:
+// AlignedBuffer buf;
+// buf.Alignment(alignment);
+// buf.AllocateNewBuffer(user_requested_buf_size);
+// ...
+// buf.AllocateNewBuffer(2*user_requested_buf_size, /*copy_data*/ true,
+// copy_offset, copy_len);
+class AlignedBuffer {
+ size_t alignment_;
+ std::unique_ptr<char[]> buf_;
+ size_t capacity_;
+ size_t cursize_;
+ char* bufstart_;
+
+ public:
+ AlignedBuffer()
+ : alignment_(), capacity_(0), cursize_(0), bufstart_(nullptr) {}
+
+ AlignedBuffer(AlignedBuffer&& o) noexcept { *this = std::move(o); }
+
+ AlignedBuffer& operator=(AlignedBuffer&& o) noexcept {
+ alignment_ = std::move(o.alignment_);
+ buf_ = std::move(o.buf_);
+ capacity_ = std::move(o.capacity_);
+ cursize_ = std::move(o.cursize_);
+ bufstart_ = std::move(o.bufstart_);
+ return *this;
+ }
+
+ AlignedBuffer(const AlignedBuffer&) = delete;
+
+ AlignedBuffer& operator=(const AlignedBuffer&) = delete;
+
+ static bool isAligned(const void* ptr, size_t alignment) {
+ return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
+ }
+
+ static bool isAligned(size_t n, size_t alignment) {
+ return n % alignment == 0;
+ }
+
+ size_t Alignment() const { return alignment_; }
+
+ size_t Capacity() const { return capacity_; }
+
+ size_t CurrentSize() const { return cursize_; }
+
+ const char* BufferStart() const { return bufstart_; }
+
+ char* BufferStart() { return bufstart_; }
+
+ void Clear() { cursize_ = 0; }
+
+ char* Release() {
+ cursize_ = 0;
+ capacity_ = 0;
+ bufstart_ = nullptr;
+ return buf_.release();
+ }
+
+ void Alignment(size_t alignment) {
+ assert(alignment > 0);
+ assert((alignment & (alignment - 1)) == 0);
+ alignment_ = alignment;
+ }
+
+ // Allocates a new buffer and sets the start position to the first aligned
+ // byte.
+ //
+ // requested_capacity: requested new buffer capacity. This capacity will be
+ // rounded up based on alignment.
+ // copy_data: Copy data from old buffer to new buffer. If copy_offset and
+ // copy_len are not passed in and the new requested capacity is bigger
+ // than the existing buffer's capacity, the data in the exising buffer is
+ // fully copied over to the new buffer.
+ // copy_offset: Copy data from this offset in old buffer.
+ // copy_len: Number of bytes to copy.
+ //
+ // The function does nothing if the new requested_capacity is smaller than
+ // the current buffer capacity and copy_data is true i.e. the old buffer is
+ // retained as is.
+ void AllocateNewBuffer(size_t requested_capacity, bool copy_data = false,
+ uint64_t copy_offset = 0, size_t copy_len = 0) {
+ assert(alignment_ > 0);
+ assert((alignment_ & (alignment_ - 1)) == 0);
+
+ copy_len = copy_len > 0 ? copy_len : cursize_;
+ if (copy_data && requested_capacity < copy_len) {
+ // If we are downsizing to a capacity that is smaller than the current
+ // data in the buffer -- Ignore the request.
+ return;
+ }
+
+ size_t new_capacity = Roundup(requested_capacity, alignment_);
+ char* new_buf = new char[new_capacity + alignment_];
+ char* new_bufstart = reinterpret_cast<char*>(
+ (reinterpret_cast<uintptr_t>(new_buf) + (alignment_ - 1)) &
+ ~static_cast<uintptr_t>(alignment_ - 1));
+
+ if (copy_data) {
+ assert(bufstart_ + copy_offset + copy_len <= bufstart_ + cursize_);
+ memcpy(new_bufstart, bufstart_ + copy_offset, copy_len);
+ cursize_ = copy_len;
+ } else {
+ cursize_ = 0;
+ }
+
+ bufstart_ = new_bufstart;
+ capacity_ = new_capacity;
+ buf_.reset(new_buf);
+ }
+
+ // Append to the buffer.
+ //
+ // src : source to copy the data from.
+ // append_size : number of bytes to copy from src.
+ // Returns the number of bytes appended.
+ //
+ // If append_size is more than the remaining buffer size only the
+ // remaining-size worth of bytes are copied.
+ size_t Append(const char* src, size_t append_size) {
+ size_t buffer_remaining = capacity_ - cursize_;
+ size_t to_copy = std::min(append_size, buffer_remaining);
+
+ if (to_copy > 0) {
+ memcpy(bufstart_ + cursize_, src, to_copy);
+ cursize_ += to_copy;
+ }
+ return to_copy;
+ }
+
+ // Read from the buffer.
+ //
+ // dest : destination buffer to copy the data to.
+ // offset : the buffer offset to start reading from.
+ // read_size : the number of bytes to copy from the buffer to dest.
+ // Returns the number of bytes read/copied to dest.
+ size_t Read(char* dest, size_t offset, size_t read_size) const {
+ assert(offset < cursize_);
+
+ size_t to_read = 0;
+ if (offset < cursize_) {
+ to_read = std::min(cursize_ - offset, read_size);
+ }
+ if (to_read > 0) {
+ memcpy(dest, bufstart_ + offset, to_read);
+ }
+ return to_read;
+ }
+
+ // Pad to the end of alignment with "padding"
+ void PadToAlignmentWith(int padding) {
+ size_t total_size = Roundup(cursize_, alignment_);
+ size_t pad_size = total_size - cursize_;
+
+ if (pad_size > 0) {
+ assert((pad_size + cursize_) <= capacity_);
+ memset(bufstart_ + cursize_, padding, pad_size);
+ cursize_ += pad_size;
+ }
+ }
+
+ void PadWith(size_t pad_size, int padding) {
+ assert((pad_size + cursize_) <= capacity_);
+ memset(bufstart_ + cursize_, padding, pad_size);
+ cursize_ += pad_size;
+ }
+
+ // After a partial flush move the tail to the beginning of the buffer.
+ void RefitTail(size_t tail_offset, size_t tail_size) {
+ if (tail_size > 0) {
+ memmove(bufstart_, bufstart_ + tail_offset, tail_size);
+ }
+ cursize_ = tail_size;
+ }
+
+ // Returns a place to start appending.
+ // WARNING: Note that it is possible to write past the end of the buffer if
+ // the buffer is modified without using the write APIs or encapsulation
+ // offered by AlignedBuffer. It is up to the user to guard against such
+ // errors.
+ char* Destination() { return bufstart_ + cursize_; }
+
+ void Size(size_t cursize) { cursize_ = cursize; }
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/async_file_reader.cc b/src/rocksdb/util/async_file_reader.cc
new file mode 100644
index 000000000..8401a6b44
--- /dev/null
+++ b/src/rocksdb/util/async_file_reader.cc
@@ -0,0 +1,73 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#if USE_COROUTINES
+#include "util/async_file_reader.h"
+
+namespace ROCKSDB_NAMESPACE {
+bool AsyncFileReader::MultiReadAsyncImpl(ReadAwaiter* awaiter) {
+ if (tail_) {
+ tail_->next_ = awaiter;
+ }
+ tail_ = awaiter;
+ if (!head_) {
+ head_ = awaiter;
+ }
+ num_reqs_ += awaiter->num_reqs_;
+ awaiter->io_handle_.resize(awaiter->num_reqs_);
+ awaiter->del_fn_.resize(awaiter->num_reqs_);
+ for (size_t i = 0; i < awaiter->num_reqs_; ++i) {
+ awaiter->file_
+ ->ReadAsync(
+ awaiter->read_reqs_[i], awaiter->opts_,
+ [](const FSReadRequest& req, void* cb_arg) {
+ FSReadRequest* read_req = static_cast<FSReadRequest*>(cb_arg);
+ read_req->status = req.status;
+ read_req->result = req.result;
+ },
+ &awaiter->read_reqs_[i], &awaiter->io_handle_[i],
+ &awaiter->del_fn_[i], /*aligned_buf=*/nullptr)
+ .PermitUncheckedError();
+ }
+ return true;
+}
+
+void AsyncFileReader::Wait() {
+ if (!head_) {
+ return;
+ }
+ ReadAwaiter* waiter;
+ std::vector<void*> io_handles;
+ io_handles.reserve(num_reqs_);
+ waiter = head_;
+ do {
+ for (size_t i = 0; i < waiter->num_reqs_; ++i) {
+ if (waiter->io_handle_[i]) {
+ io_handles.push_back(waiter->io_handle_[i]);
+ }
+ }
+ } while (waiter != tail_ && (waiter = waiter->next_));
+ if (io_handles.size() > 0) {
+ StopWatch sw(SystemClock::Default().get(), stats_, POLL_WAIT_MICROS);
+ fs_->Poll(io_handles, io_handles.size()).PermitUncheckedError();
+ }
+ do {
+ waiter = head_;
+ head_ = waiter->next_;
+
+ for (size_t i = 0; i < waiter->num_reqs_; ++i) {
+ if (waiter->io_handle_[i] && waiter->del_fn_[i]) {
+ waiter->del_fn_[i](waiter->io_handle_[i]);
+ }
+ }
+ waiter->awaiting_coro_.resume();
+ } while (waiter != tail_);
+ head_ = tail_ = nullptr;
+ RecordInHistogram(stats_, MULTIGET_IO_BATCH_SIZE, num_reqs_);
+ num_reqs_ = 0;
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // USE_COROUTINES
diff --git a/src/rocksdb/util/async_file_reader.h b/src/rocksdb/util/async_file_reader.h
new file mode 100644
index 000000000..df69a840e
--- /dev/null
+++ b/src/rocksdb/util/async_file_reader.h
@@ -0,0 +1,144 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).#pragma once
+#pragma once
+
+#if USE_COROUTINES
+#include "file/random_access_file_reader.h"
+#include "folly/experimental/coro/ViaIfAsync.h"
+#include "port/port.h"
+#include "rocksdb/file_system.h"
+#include "rocksdb/statistics.h"
+#include "util/autovector.h"
+#include "util/stop_watch.h"
+
+namespace ROCKSDB_NAMESPACE {
+class SingleThreadExecutor;
+
+// AsyncFileReader implements the Awaitable concept, which allows calling
+// coroutines to co_await it. When the AsyncFileReader Awaitable is
+// resumed, it initiates the fie reads requested by the awaiting caller
+// by calling RandomAccessFileReader's ReadAsync. It then suspends the
+// awaiting coroutine. The suspended awaiter is later resumed by Wait().
+class AsyncFileReader {
+ class ReadAwaiter;
+ template <typename Awaiter>
+ class ReadOperation;
+
+ public:
+ AsyncFileReader(FileSystem* fs, Statistics* stats) : fs_(fs), stats_(stats) {}
+
+ ~AsyncFileReader() {}
+
+ ReadOperation<ReadAwaiter> MultiReadAsync(RandomAccessFileReader* file,
+ const IOOptions& opts,
+ FSReadRequest* read_reqs,
+ size_t num_reqs,
+ AlignedBuf* aligned_buf) noexcept {
+ return ReadOperation<ReadAwaiter>{*this, file, opts,
+ read_reqs, num_reqs, aligned_buf};
+ }
+
+ private:
+ friend SingleThreadExecutor;
+
+ // Implementation of the Awaitable concept
+ class ReadAwaiter {
+ public:
+ explicit ReadAwaiter(AsyncFileReader& reader, RandomAccessFileReader* file,
+ const IOOptions& opts, FSReadRequest* read_reqs,
+ size_t num_reqs, AlignedBuf* /*aligned_buf*/) noexcept
+ : reader_(reader),
+ file_(file),
+ opts_(opts),
+ read_reqs_(read_reqs),
+ num_reqs_(num_reqs),
+ next_(nullptr) {}
+
+ bool await_ready() noexcept { return false; }
+
+ // A return value of true means suspend the awaiter (calling coroutine). The
+ // awaiting_coro parameter is the handle of the awaiter. The handle can be
+ // resumed later, so we cache it here.
+ bool await_suspend(
+ folly::coro::impl::coroutine_handle<> awaiting_coro) noexcept {
+ awaiting_coro_ = awaiting_coro;
+ // MultiReadAsyncImpl always returns true, so caller will be suspended
+ return reader_.MultiReadAsyncImpl(this);
+ }
+
+ void await_resume() noexcept {}
+
+ private:
+ friend AsyncFileReader;
+
+ // The parameters passed to MultiReadAsync are cached here when the caller
+ // calls MultiReadAsync. Later, when the execution of this awaitable is
+ // started, these are used to do the actual IO
+ AsyncFileReader& reader_;
+ RandomAccessFileReader* file_;
+ const IOOptions& opts_;
+ FSReadRequest* read_reqs_;
+ size_t num_reqs_;
+ autovector<void*, 32> io_handle_;
+ autovector<IOHandleDeleter, 32> del_fn_;
+ folly::coro::impl::coroutine_handle<> awaiting_coro_;
+ // Use this to link to the next ReadAwaiter in the suspended coroutine
+ // list. The head and tail of the list are tracked by AsyncFileReader.
+ // We use this approach rather than an STL container in order to avoid
+ // extra memory allocations. The coroutine call already allocates a
+ // ReadAwaiter object.
+ ReadAwaiter* next_;
+ };
+
+ // An instance of ReadOperation is returned to the caller of MultiGetAsync.
+ // This represents an awaitable that can be started later.
+ template <typename Awaiter>
+ class ReadOperation {
+ public:
+ explicit ReadOperation(AsyncFileReader& reader,
+ RandomAccessFileReader* file, const IOOptions& opts,
+ FSReadRequest* read_reqs, size_t num_reqs,
+ AlignedBuf* aligned_buf) noexcept
+ : reader_(reader),
+ file_(file),
+ opts_(opts),
+ read_reqs_(read_reqs),
+ num_reqs_(num_reqs),
+ aligned_buf_(aligned_buf) {}
+
+ auto viaIfAsync(folly::Executor::KeepAlive<> executor) const {
+ return folly::coro::co_viaIfAsync(
+ std::move(executor),
+ Awaiter{reader_, file_, opts_, read_reqs_, num_reqs_, aligned_buf_});
+ }
+
+ private:
+ AsyncFileReader& reader_;
+ RandomAccessFileReader* file_;
+ const IOOptions& opts_;
+ FSReadRequest* read_reqs_;
+ size_t num_reqs_;
+ AlignedBuf* aligned_buf_;
+ };
+
+ // This function does the actual work when this awaitable starts execution
+ bool MultiReadAsyncImpl(ReadAwaiter* awaiter);
+
+ // Called by the SingleThreadExecutor to poll for async IO completion.
+ // This also resumes the awaiting coroutines.
+ void Wait();
+
+ // Head of the queue of awaiters waiting for async IO completion
+ ReadAwaiter* head_ = nullptr;
+ // Tail of the awaiter queue
+ ReadAwaiter* tail_ = nullptr;
+ // Total number of pending async IOs
+ size_t num_reqs_ = 0;
+ FileSystem* fs_;
+ Statistics* stats_;
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif // USE_COROUTINES
diff --git a/src/rocksdb/util/autovector.h b/src/rocksdb/util/autovector.h
new file mode 100644
index 000000000..f758473b7
--- /dev/null
+++ b/src/rocksdb/util/autovector.h
@@ -0,0 +1,406 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <initializer_list>
+#include <iterator>
+#include <stdexcept>
+#include <vector>
+
+#include "port/lang.h"
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifdef ROCKSDB_LITE
+template <class T, size_t kSize = 8>
+class autovector : public std::vector<T> {
+ using std::vector<T>::vector;
+
+ public:
+ autovector() {
+ // Make sure the initial vector has space for kSize elements
+ std::vector<T>::reserve(kSize);
+ }
+};
+#else
+// A vector that leverages pre-allocated stack-based array to achieve better
+// performance for array with small amount of items.
+//
+// The interface resembles that of vector, but with less features since we aim
+// to solve the problem that we have in hand, rather than implementing a
+// full-fledged generic container.
+//
+// Currently we don't support:
+// * shrink_to_fit()
+// If used correctly, in most cases, people should not touch the
+// underlying vector at all.
+// * random insert()/erase(), please only use push_back()/pop_back().
+// * No move/swap operations. Each autovector instance has a
+// stack-allocated array and if we want support move/swap operations, we
+// need to copy the arrays other than just swapping the pointers. In this
+// case we'll just explicitly forbid these operations since they may
+// lead users to make false assumption by thinking they are inexpensive
+// operations.
+//
+// Naming style of public methods almost follows that of the STL's.
+template <class T, size_t kSize = 8>
+class autovector {
+ public:
+ // General STL-style container member types.
+ using value_type = T;
+ using difference_type = typename std::vector<T>::difference_type;
+ using size_type = typename std::vector<T>::size_type;
+ using reference = value_type&;
+ using const_reference = const value_type&;
+ using pointer = value_type*;
+ using const_pointer = const value_type*;
+
+ // This class is the base for regular/const iterator
+ template <class TAutoVector, class TValueType>
+ class iterator_impl {
+ public:
+ // -- iterator traits
+ using self_type = iterator_impl<TAutoVector, TValueType>;
+ using value_type = TValueType;
+ using reference = TValueType&;
+ using pointer = TValueType*;
+ using difference_type = typename TAutoVector::difference_type;
+ using iterator_category = std::random_access_iterator_tag;
+
+ iterator_impl(TAutoVector* vect, size_t index)
+ : vect_(vect), index_(index){};
+ iterator_impl(const iterator_impl&) = default;
+ ~iterator_impl() {}
+ iterator_impl& operator=(const iterator_impl&) = default;
+
+ // -- Advancement
+ // ++iterator
+ self_type& operator++() {
+ ++index_;
+ return *this;
+ }
+
+ // iterator++
+ self_type operator++(int) {
+ auto old = *this;
+ ++index_;
+ return old;
+ }
+
+ // --iterator
+ self_type& operator--() {
+ --index_;
+ return *this;
+ }
+
+ // iterator--
+ self_type operator--(int) {
+ auto old = *this;
+ --index_;
+ return old;
+ }
+
+ self_type operator-(difference_type len) const {
+ return self_type(vect_, index_ - len);
+ }
+
+ difference_type operator-(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ - other.index_;
+ }
+
+ self_type operator+(difference_type len) const {
+ return self_type(vect_, index_ + len);
+ }
+
+ self_type& operator+=(difference_type len) {
+ index_ += len;
+ return *this;
+ }
+
+ self_type& operator-=(difference_type len) {
+ index_ -= len;
+ return *this;
+ }
+
+ // -- Reference
+ reference operator*() const {
+ assert(vect_->size() >= index_);
+ return (*vect_)[index_];
+ }
+
+ pointer operator->() const {
+ assert(vect_->size() >= index_);
+ return &(*vect_)[index_];
+ }
+
+ reference operator[](difference_type len) const { return *(*this + len); }
+
+ // -- Logical Operators
+ bool operator==(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ == other.index_;
+ }
+
+ bool operator!=(const self_type& other) const { return !(*this == other); }
+
+ bool operator>(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ > other.index_;
+ }
+
+ bool operator<(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ < other.index_;
+ }
+
+ bool operator>=(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ >= other.index_;
+ }
+
+ bool operator<=(const self_type& other) const {
+ assert(vect_ == other.vect_);
+ return index_ <= other.index_;
+ }
+
+ private:
+ TAutoVector* vect_ = nullptr;
+ size_t index_ = 0;
+ };
+
+ using iterator = iterator_impl<autovector, value_type>;
+ using const_iterator = iterator_impl<const autovector, const value_type>;
+ using reverse_iterator = std::reverse_iterator<iterator>;
+ using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+
+ autovector() : values_(reinterpret_cast<pointer>(buf_)) {}
+
+ autovector(std::initializer_list<T> init_list)
+ : values_(reinterpret_cast<pointer>(buf_)) {
+ for (const T& item : init_list) {
+ push_back(item);
+ }
+ }
+
+ ~autovector() { clear(); }
+
+ // -- Immutable operations
+ // Indicate if all data resides in in-stack data structure.
+ bool only_in_stack() const {
+ // If no element was inserted at all, the vector's capacity will be `0`.
+ return vect_.capacity() == 0;
+ }
+
+ size_type size() const { return num_stack_items_ + vect_.size(); }
+
+ // resize does not guarantee anything about the contents of the newly
+ // available elements
+ void resize(size_type n) {
+ if (n > kSize) {
+ vect_.resize(n - kSize);
+ while (num_stack_items_ < kSize) {
+ new ((void*)(&values_[num_stack_items_++])) value_type();
+ }
+ num_stack_items_ = kSize;
+ } else {
+ vect_.clear();
+ while (num_stack_items_ < n) {
+ new ((void*)(&values_[num_stack_items_++])) value_type();
+ }
+ while (num_stack_items_ > n) {
+ values_[--num_stack_items_].~value_type();
+ }
+ }
+ }
+
+ bool empty() const { return size() == 0; }
+
+ size_type capacity() const { return kSize + vect_.capacity(); }
+
+ void reserve(size_t cap) {
+ if (cap > kSize) {
+ vect_.reserve(cap - kSize);
+ }
+
+ assert(cap <= capacity());
+ }
+
+ const_reference operator[](size_type n) const {
+ assert(n < size());
+ if (n < kSize) {
+ return values_[n];
+ }
+ return vect_[n - kSize];
+ }
+
+ reference operator[](size_type n) {
+ assert(n < size());
+ if (n < kSize) {
+ return values_[n];
+ }
+ return vect_[n - kSize];
+ }
+
+ const_reference at(size_type n) const {
+ assert(n < size());
+ return (*this)[n];
+ }
+
+ reference at(size_type n) {
+ assert(n < size());
+ return (*this)[n];
+ }
+
+ reference front() {
+ assert(!empty());
+ return *begin();
+ }
+
+ const_reference front() const {
+ assert(!empty());
+ return *begin();
+ }
+
+ reference back() {
+ assert(!empty());
+ return *(end() - 1);
+ }
+
+ const_reference back() const {
+ assert(!empty());
+ return *(end() - 1);
+ }
+
+ // -- Mutable Operations
+ void push_back(T&& item) {
+ if (num_stack_items_ < kSize) {
+ new ((void*)(&values_[num_stack_items_])) value_type();
+ values_[num_stack_items_++] = std::move(item);
+ } else {
+ vect_.push_back(item);
+ }
+ }
+
+ void push_back(const T& item) {
+ if (num_stack_items_ < kSize) {
+ new ((void*)(&values_[num_stack_items_])) value_type();
+ values_[num_stack_items_++] = item;
+ } else {
+ vect_.push_back(item);
+ }
+ }
+
+ template <class... Args>
+#if _LIBCPP_STD_VER > 14
+ reference emplace_back(Args&&... args) {
+ if (num_stack_items_ < kSize) {
+ return *(new ((void*)(&values_[num_stack_items_++]))
+ value_type(std::forward<Args>(args)...));
+ } else {
+ return vect_.emplace_back(std::forward<Args>(args)...);
+ }
+ }
+#else
+ void emplace_back(Args&&... args) {
+ if (num_stack_items_ < kSize) {
+ new ((void*)(&values_[num_stack_items_++]))
+ value_type(std::forward<Args>(args)...);
+ } else {
+ vect_.emplace_back(std::forward<Args>(args)...);
+ }
+ }
+#endif
+
+ void pop_back() {
+ assert(!empty());
+ if (!vect_.empty()) {
+ vect_.pop_back();
+ } else {
+ values_[--num_stack_items_].~value_type();
+ }
+ }
+
+ void clear() {
+ while (num_stack_items_ > 0) {
+ values_[--num_stack_items_].~value_type();
+ }
+ vect_.clear();
+ }
+
+ // -- Copy and Assignment
+ autovector& assign(const autovector& other);
+
+ autovector(const autovector& other) { assign(other); }
+
+ autovector& operator=(const autovector& other) { return assign(other); }
+
+ autovector(autovector&& other) noexcept { *this = std::move(other); }
+ autovector& operator=(autovector&& other);
+
+ // -- Iterator Operations
+ iterator begin() { return iterator(this, 0); }
+
+ const_iterator begin() const { return const_iterator(this, 0); }
+
+ iterator end() { return iterator(this, this->size()); }
+
+ const_iterator end() const { return const_iterator(this, this->size()); }
+
+ reverse_iterator rbegin() { return reverse_iterator(end()); }
+
+ const_reverse_iterator rbegin() const {
+ return const_reverse_iterator(end());
+ }
+
+ reverse_iterator rend() { return reverse_iterator(begin()); }
+
+ const_reverse_iterator rend() const {
+ return const_reverse_iterator(begin());
+ }
+
+ private:
+ size_type num_stack_items_ = 0; // current number of items
+ alignas(alignof(
+ value_type)) char buf_[kSize *
+ sizeof(value_type)]; // the first `kSize` items
+ pointer values_;
+ // used only if there are more than `kSize` items.
+ std::vector<T> vect_;
+};
+
+template <class T, size_t kSize>
+autovector<T, kSize>& autovector<T, kSize>::assign(
+ const autovector<T, kSize>& other) {
+ values_ = reinterpret_cast<pointer>(buf_);
+ // copy the internal vector
+ vect_.assign(other.vect_.begin(), other.vect_.end());
+
+ // copy array
+ num_stack_items_ = other.num_stack_items_;
+ std::copy(other.values_, other.values_ + num_stack_items_, values_);
+
+ return *this;
+}
+
+template <class T, size_t kSize>
+autovector<T, kSize>& autovector<T, kSize>::operator=(
+ autovector<T, kSize>&& other) {
+ values_ = reinterpret_cast<pointer>(buf_);
+ vect_ = std::move(other.vect_);
+ size_t n = other.num_stack_items_;
+ num_stack_items_ = n;
+ other.num_stack_items_ = 0;
+ for (size_t i = 0; i < n; ++i) {
+ values_[i] = std::move(other.values_[i]);
+ }
+ return *this;
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/autovector_test.cc b/src/rocksdb/util/autovector_test.cc
new file mode 100644
index 000000000..8c7c39ce6
--- /dev/null
+++ b/src/rocksdb/util/autovector_test.cc
@@ -0,0 +1,331 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/autovector.h"
+
+#include <atomic>
+#include <iostream>
+#include <string>
+#include <utility>
+
+#include "rocksdb/env.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/string_util.h"
+
+using std::cout;
+using std::endl;
+
+namespace ROCKSDB_NAMESPACE {
+
+class AutoVectorTest : public testing::Test {};
+const unsigned long kSize = 8;
+
+namespace {
+template <class T>
+void AssertAutoVectorOnlyInStack(autovector<T, kSize>* vec, bool result) {
+#ifndef ROCKSDB_LITE
+ ASSERT_EQ(vec->only_in_stack(), result);
+#else
+ (void)vec;
+ (void)result;
+#endif // !ROCKSDB_LITE
+}
+} // namespace
+
+TEST_F(AutoVectorTest, PushBackAndPopBack) {
+ autovector<size_t, kSize> vec;
+ ASSERT_TRUE(vec.empty());
+ ASSERT_EQ(0ul, vec.size());
+
+ for (size_t i = 0; i < 1000 * kSize; ++i) {
+ vec.push_back(i);
+ ASSERT_TRUE(!vec.empty());
+ if (i < kSize) {
+ AssertAutoVectorOnlyInStack(&vec, true);
+ } else {
+ AssertAutoVectorOnlyInStack(&vec, false);
+ }
+ ASSERT_EQ(i + 1, vec.size());
+ ASSERT_EQ(i, vec[i]);
+ ASSERT_EQ(i, vec.at(i));
+ }
+
+ size_t size = vec.size();
+ while (size != 0) {
+ vec.pop_back();
+ // will always be in heap
+ AssertAutoVectorOnlyInStack(&vec, false);
+ ASSERT_EQ(--size, vec.size());
+ }
+
+ ASSERT_TRUE(vec.empty());
+}
+
+TEST_F(AutoVectorTest, EmplaceBack) {
+ using ValType = std::pair<size_t, std::string>;
+ autovector<ValType, kSize> vec;
+
+ for (size_t i = 0; i < 1000 * kSize; ++i) {
+ vec.emplace_back(i, std::to_string(i + 123));
+ ASSERT_TRUE(!vec.empty());
+ if (i < kSize) {
+ AssertAutoVectorOnlyInStack(&vec, true);
+ } else {
+ AssertAutoVectorOnlyInStack(&vec, false);
+ }
+
+ ASSERT_EQ(i + 1, vec.size());
+ ASSERT_EQ(i, vec[i].first);
+ ASSERT_EQ(std::to_string(i + 123), vec[i].second);
+ }
+
+ vec.clear();
+ ASSERT_TRUE(vec.empty());
+ AssertAutoVectorOnlyInStack(&vec, false);
+}
+
+TEST_F(AutoVectorTest, Resize) {
+ autovector<size_t, kSize> vec;
+
+ vec.resize(kSize);
+ AssertAutoVectorOnlyInStack(&vec, true);
+ for (size_t i = 0; i < kSize; ++i) {
+ vec[i] = i;
+ }
+
+ vec.resize(kSize * 2);
+ AssertAutoVectorOnlyInStack(&vec, false);
+ for (size_t i = 0; i < kSize; ++i) {
+ ASSERT_EQ(vec[i], i);
+ }
+ for (size_t i = 0; i < kSize; ++i) {
+ vec[i + kSize] = i;
+ }
+
+ vec.resize(1);
+ ASSERT_EQ(1U, vec.size());
+}
+
+namespace {
+void AssertEqual(const autovector<size_t, kSize>& a,
+ const autovector<size_t, kSize>& b) {
+ ASSERT_EQ(a.size(), b.size());
+ ASSERT_EQ(a.empty(), b.empty());
+#ifndef ROCKSDB_LITE
+ ASSERT_EQ(a.only_in_stack(), b.only_in_stack());
+#endif // !ROCKSDB_LITE
+ for (size_t i = 0; i < a.size(); ++i) {
+ ASSERT_EQ(a[i], b[i]);
+ }
+}
+} // namespace
+
+TEST_F(AutoVectorTest, CopyAndAssignment) {
+ // Test both heap-allocated and stack-allocated cases.
+ for (auto size : {kSize / 2, kSize * 1000}) {
+ autovector<size_t, kSize> vec;
+ for (size_t i = 0; i < size; ++i) {
+ vec.push_back(i);
+ }
+
+ {
+ autovector<size_t, kSize> other;
+ other = vec;
+ AssertEqual(other, vec);
+ }
+
+ {
+ autovector<size_t, kSize> other(vec);
+ AssertEqual(other, vec);
+ }
+ }
+}
+
+TEST_F(AutoVectorTest, Iterators) {
+ autovector<std::string, kSize> vec;
+ for (size_t i = 0; i < kSize * 1000; ++i) {
+ vec.push_back(std::to_string(i));
+ }
+
+ // basic operator test
+ ASSERT_EQ(vec.front(), *vec.begin());
+ ASSERT_EQ(vec.back(), *(vec.end() - 1));
+ ASSERT_TRUE(vec.begin() < vec.end());
+
+ // non-const iterator
+ size_t index = 0;
+ for (const auto& item : vec) {
+ ASSERT_EQ(vec[index++], item);
+ }
+
+ index = vec.size() - 1;
+ for (auto pos = vec.rbegin(); pos != vec.rend(); ++pos) {
+ ASSERT_EQ(vec[index--], *pos);
+ }
+
+ // const iterator
+ const auto& cvec = vec;
+ index = 0;
+ for (const auto& item : cvec) {
+ ASSERT_EQ(cvec[index++], item);
+ }
+
+ index = vec.size() - 1;
+ for (auto pos = cvec.rbegin(); pos != cvec.rend(); ++pos) {
+ ASSERT_EQ(cvec[index--], *pos);
+ }
+
+ // forward and backward
+ auto pos = vec.begin();
+ while (pos != vec.end()) {
+ auto old_val = *pos;
+ auto old = pos++;
+ // HACK: make sure -> works
+ ASSERT_TRUE(!old->empty());
+ ASSERT_EQ(old_val, *old);
+ ASSERT_TRUE(pos == vec.end() || old_val != *pos);
+ }
+
+ pos = vec.begin();
+ for (size_t i = 0; i < vec.size(); i += 2) {
+ // Cannot use ASSERT_EQ since that macro depends on iostream serialization
+ ASSERT_TRUE(pos + 2 - 2 == pos);
+ pos += 2;
+ ASSERT_TRUE(pos >= vec.begin());
+ ASSERT_TRUE(pos <= vec.end());
+
+ size_t diff = static_cast<size_t>(pos - vec.begin());
+ ASSERT_EQ(i + 2, diff);
+ }
+}
+
+namespace {
+std::vector<std::string> GetTestKeys(size_t size) {
+ std::vector<std::string> keys;
+ keys.resize(size);
+
+ int index = 0;
+ for (auto& key : keys) {
+ key = "item-" + std::to_string(index++);
+ }
+ return keys;
+}
+} // namespace
+
+template <class TVector>
+void BenchmarkVectorCreationAndInsertion(
+ std::string name, size_t ops, size_t item_size,
+ const std::vector<typename TVector::value_type>& items) {
+ auto env = Env::Default();
+
+ int index = 0;
+ auto start_time = env->NowNanos();
+ auto ops_remaining = ops;
+ while (ops_remaining--) {
+ TVector v;
+ for (size_t i = 0; i < item_size; ++i) {
+ v.push_back(items[index++]);
+ }
+ }
+ auto elapsed = env->NowNanos() - start_time;
+ cout << "created " << ops << " " << name << " instances:\n\t"
+ << "each was inserted with " << item_size << " elements\n\t"
+ << "total time elapsed: " << elapsed << " (ns)" << endl;
+}
+
+template <class TVector>
+size_t BenchmarkSequenceAccess(std::string name, size_t ops, size_t elem_size) {
+ TVector v;
+ for (const auto& item : GetTestKeys(elem_size)) {
+ v.push_back(item);
+ }
+ auto env = Env::Default();
+
+ auto ops_remaining = ops;
+ auto start_time = env->NowNanos();
+ size_t total = 0;
+ while (ops_remaining--) {
+ auto end = v.end();
+ for (auto pos = v.begin(); pos != end; ++pos) {
+ total += pos->size();
+ }
+ }
+ auto elapsed = env->NowNanos() - start_time;
+ cout << "performed " << ops << " sequence access against " << name << "\n\t"
+ << "size: " << elem_size << "\n\t"
+ << "total time elapsed: " << elapsed << " (ns)" << endl;
+ // HACK avoid compiler's optimization to ignore total
+ return total;
+}
+
+// This test case only reports the performance between std::vector<std::string>
+// and autovector<std::string>. We chose string for comparison because in most
+// of our use cases we used std::vector<std::string>.
+TEST_F(AutoVectorTest, PerfBench) {
+ // We run same operations for kOps times in order to get a more fair result.
+ size_t kOps = 100000;
+
+ // Creation and insertion test
+ // Test the case when there is:
+ // * no element inserted: internal array of std::vector may not really get
+ // initialize.
+ // * one element inserted: internal array of std::vector must have
+ // initialized.
+ // * kSize elements inserted. This shows the most time we'll spend if we
+ // keep everything in stack.
+ // * 2 * kSize elements inserted. The internal vector of
+ // autovector must have been initialized.
+ cout << "=====================================================" << endl;
+ cout << "Creation and Insertion Test (value type: std::string)" << endl;
+ cout << "=====================================================" << endl;
+
+ // pre-generated unique keys
+ auto string_keys = GetTestKeys(kOps * 2 * kSize);
+ for (auto insertions : {0ul, 1ul, kSize / 2, kSize, 2 * kSize}) {
+ BenchmarkVectorCreationAndInsertion<std::vector<std::string>>(
+ "std::vector<std::string>", kOps, insertions, string_keys);
+ BenchmarkVectorCreationAndInsertion<autovector<std::string, kSize>>(
+ "autovector<std::string>", kOps, insertions, string_keys);
+ cout << "-----------------------------------" << endl;
+ }
+
+ cout << "=====================================================" << endl;
+ cout << "Creation and Insertion Test (value type: uint64_t)" << endl;
+ cout << "=====================================================" << endl;
+
+ // pre-generated unique keys
+ std::vector<uint64_t> int_keys(kOps * 2 * kSize);
+ for (size_t i = 0; i < kOps * 2 * kSize; ++i) {
+ int_keys[i] = i;
+ }
+ for (auto insertions : {0ul, 1ul, kSize / 2, kSize, 2 * kSize}) {
+ BenchmarkVectorCreationAndInsertion<std::vector<uint64_t>>(
+ "std::vector<uint64_t>", kOps, insertions, int_keys);
+ BenchmarkVectorCreationAndInsertion<autovector<uint64_t, kSize>>(
+ "autovector<uint64_t>", kOps, insertions, int_keys);
+ cout << "-----------------------------------" << endl;
+ }
+
+ // Sequence Access Test
+ cout << "=====================================================" << endl;
+ cout << "Sequence Access Test" << endl;
+ cout << "=====================================================" << endl;
+ for (auto elem_size : {kSize / 2, kSize, 2 * kSize}) {
+ BenchmarkSequenceAccess<std::vector<std::string>>("std::vector", kOps,
+ elem_size);
+ BenchmarkSequenceAccess<autovector<std::string, kSize>>("autovector", kOps,
+ elem_size);
+ cout << "-----------------------------------" << endl;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/bloom_impl.h b/src/rocksdb/util/bloom_impl.h
new file mode 100644
index 000000000..fadd012d3
--- /dev/null
+++ b/src/rocksdb/util/bloom_impl.h
@@ -0,0 +1,489 @@
+// Copyright (c) 2019-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Implementation details of various Bloom filter implementations used in
+// RocksDB. (DynamicBloom is in a separate file for now because it
+// supports concurrent write.)
+
+#pragma once
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cmath>
+
+#include "port/port.h" // for PREFETCH
+#include "rocksdb/slice.h"
+#include "util/hash.h"
+
+#ifdef HAVE_AVX2
+#include <immintrin.h>
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+class BloomMath {
+ public:
+ // False positive rate of a standard Bloom filter, for given ratio of
+ // filter memory bits to added keys, and number of probes per operation.
+ // (The false positive rate is effectively independent of scale, assuming
+ // the implementation scales OK.)
+ static double StandardFpRate(double bits_per_key, int num_probes) {
+ // Standard very-good-estimate formula. See
+ // https://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives
+ return std::pow(1.0 - std::exp(-num_probes / bits_per_key), num_probes);
+ }
+
+ // False positive rate of a "blocked"/"shareded"/"cache-local" Bloom filter,
+ // for given ratio of filter memory bits to added keys, number of probes per
+ // operation (all within the given block or cache line size), and block or
+ // cache line size.
+ static double CacheLocalFpRate(double bits_per_key, int num_probes,
+ int cache_line_bits) {
+ if (bits_per_key <= 0.0) {
+ // Fix a discontinuity
+ return 1.0;
+ }
+ double keys_per_cache_line = cache_line_bits / bits_per_key;
+ // A reasonable estimate is the average of the FP rates for one standard
+ // deviation above and below the mean bucket occupancy. See
+ // https://github.com/facebook/rocksdb/wiki/RocksDB-Bloom-Filter#the-math
+ double keys_stddev = std::sqrt(keys_per_cache_line);
+ double crowded_fp = StandardFpRate(
+ cache_line_bits / (keys_per_cache_line + keys_stddev), num_probes);
+ double uncrowded_fp = StandardFpRate(
+ cache_line_bits / (keys_per_cache_line - keys_stddev), num_probes);
+ return (crowded_fp + uncrowded_fp) / 2;
+ }
+
+ // False positive rate of querying a new item against `num_keys` items, all
+ // hashed to `fingerprint_bits` bits. (This assumes the fingerprint hashes
+ // themselves are stored losslessly. See Section 4 of
+ // http://www.ccs.neu.edu/home/pete/pub/bloom-filters-verification.pdf)
+ static double FingerprintFpRate(size_t num_keys, int fingerprint_bits) {
+ double inv_fingerprint_space = std::pow(0.5, fingerprint_bits);
+ // Base estimate assumes each key maps to a unique fingerprint.
+ // Could be > 1 in extreme cases.
+ double base_estimate = num_keys * inv_fingerprint_space;
+ // To account for potential overlap, we choose between two formulas
+ if (base_estimate > 0.0001) {
+ // A very good formula assuming we don't construct a floating point
+ // number extremely close to 1. Always produces a probability < 1.
+ return 1.0 - std::exp(-base_estimate);
+ } else {
+ // A very good formula when base_estimate is far below 1. (Subtract
+ // away the integral-approximated sum that some key has same hash as
+ // one coming before it in a list.)
+ return base_estimate - (base_estimate * base_estimate * 0.5);
+ }
+ }
+
+ // Returns the probably of either of two independent(-ish) events
+ // happening, given their probabilities. (This is useful for combining
+ // results from StandardFpRate or CacheLocalFpRate with FingerprintFpRate
+ // for a hash-efficient Bloom filter's FP rate. See Section 4 of
+ // http://www.ccs.neu.edu/home/pete/pub/bloom-filters-verification.pdf)
+ static double IndependentProbabilitySum(double rate1, double rate2) {
+ // Use formula that avoids floating point extremely close to 1 if
+ // rates are extremely small.
+ return rate1 + rate2 - (rate1 * rate2);
+ }
+};
+
+// A fast, flexible, and accurate cache-local Bloom implementation with
+// SIMD-optimized query performance (currently using AVX2 on Intel). Write
+// performance and non-SIMD read are very good, benefiting from FastRange32
+// used in place of % and single-cycle multiplication on recent processors.
+//
+// Most other SIMD Bloom implementations sacrifice flexibility and/or
+// accuracy by requiring num_probes to be a power of two and restricting
+// where each probe can occur in a cache line. This implementation sacrifices
+// SIMD-optimization for add (might still be possible, especially with AVX512)
+// in favor of allowing any num_probes, not crossing cache line boundary,
+// and accuracy close to theoretical best accuracy for a cache-local Bloom.
+// E.g. theoretical best for 10 bits/key, num_probes=6, and 512-bit bucket
+// (Intel cache line size) is 0.9535% FP rate. This implementation yields
+// about 0.957%. (Compare to LegacyLocalityBloomImpl<false> at 1.138%, or
+// about 0.951% for 1024-bit buckets, cache line size for some ARM CPUs.)
+//
+// This implementation can use a 32-bit hash (let h2 be h1 * 0x9e3779b9) or
+// a 64-bit hash (split into two uint32s). With many millions of keys, the
+// false positive rate associated with using a 32-bit hash can dominate the
+// false positive rate of the underlying filter. At 10 bits/key setting, the
+// inflection point is about 40 million keys, so 32-bit hash is a bad idea
+// with 10s of millions of keys or more.
+//
+// Despite accepting a 64-bit hash, this implementation uses 32-bit fastrange
+// to pick a cache line, which can be faster than 64-bit in some cases.
+// This only hurts accuracy as you get into 10s of GB for a single filter,
+// and accuracy abruptly breaks down at 256GB (2^32 cache lines). Switch to
+// 64-bit fastrange if you need filters so big. ;)
+//
+// Using only a 32-bit input hash within each cache line has negligible
+// impact for any reasonable cache line / bucket size, for arbitrary filter
+// size, and potentially saves intermediate data size in some cases vs.
+// tracking full 64 bits. (Even in an implementation using 64-bit arithmetic
+// to generate indices, I might do the same, as a single multiplication
+// suffices to generate a sufficiently mixed 64 bits from 32 bits.)
+//
+// This implementation is currently tied to Intel cache line size, 64 bytes ==
+// 512 bits. If there's sufficient demand for other cache line sizes, this is
+// a pretty good implementation to extend, but slight performance enhancements
+// are possible with an alternate implementation (probably not very compatible
+// with SIMD):
+// (1) Use rotation in addition to multiplication for remixing
+// (like murmur hash). (Using multiplication alone *slightly* hurts accuracy
+// because lower bits never depend on original upper bits.)
+// (2) Extract more than one bit index from each re-mix. (Only if rotation
+// or similar is part of remix, because otherwise you're making the
+// multiplication-only problem worse.)
+// (3) Re-mix full 64 bit hash, to get maximum number of bit indices per
+// re-mix.
+//
+class FastLocalBloomImpl {
+ public:
+ // NOTE: this has only been validated to enough accuracy for producing
+ // reasonable warnings / user feedback, not for making functional decisions.
+ static double EstimatedFpRate(size_t keys, size_t bytes, int num_probes,
+ int hash_bits) {
+ return BloomMath::IndependentProbabilitySum(
+ BloomMath::CacheLocalFpRate(8.0 * bytes / keys, num_probes,
+ /*cache line bits*/ 512),
+ BloomMath::FingerprintFpRate(keys, hash_bits));
+ }
+
+ static inline int ChooseNumProbes(int millibits_per_key) {
+ // Since this implementation can (with AVX2) make up to 8 probes
+ // for the same cost, we pick the most accurate num_probes, based
+ // on actual tests of the implementation. Note that for higher
+ // bits/key, the best choice for cache-local Bloom can be notably
+ // smaller than standard bloom, e.g. 9 instead of 11 @ 16 b/k.
+ if (millibits_per_key <= 2080) {
+ return 1;
+ } else if (millibits_per_key <= 3580) {
+ return 2;
+ } else if (millibits_per_key <= 5100) {
+ return 3;
+ } else if (millibits_per_key <= 6640) {
+ return 4;
+ } else if (millibits_per_key <= 8300) {
+ return 5;
+ } else if (millibits_per_key <= 10070) {
+ return 6;
+ } else if (millibits_per_key <= 11720) {
+ return 7;
+ } else if (millibits_per_key <= 14001) {
+ // Would be something like <= 13800 but sacrificing *slightly* for
+ // more settings using <= 8 probes.
+ return 8;
+ } else if (millibits_per_key <= 16050) {
+ return 9;
+ } else if (millibits_per_key <= 18300) {
+ return 10;
+ } else if (millibits_per_key <= 22001) {
+ return 11;
+ } else if (millibits_per_key <= 25501) {
+ return 12;
+ } else if (millibits_per_key > 50000) {
+ // Top out at 24 probes (three sets of 8)
+ return 24;
+ } else {
+ // Roughly optimal choices for remaining range
+ // e.g.
+ // 28000 -> 12, 28001 -> 13
+ // 50000 -> 23, 50001 -> 24
+ return (millibits_per_key - 1) / 2000 - 1;
+ }
+ }
+
+ static inline void AddHash(uint32_t h1, uint32_t h2, uint32_t len_bytes,
+ int num_probes, char *data) {
+ uint32_t bytes_to_cache_line = FastRange32(len_bytes >> 6, h1) << 6;
+ AddHashPrepared(h2, num_probes, data + bytes_to_cache_line);
+ }
+
+ static inline void AddHashPrepared(uint32_t h2, int num_probes,
+ char *data_at_cache_line) {
+ uint32_t h = h2;
+ for (int i = 0; i < num_probes; ++i, h *= uint32_t{0x9e3779b9}) {
+ // 9-bit address within 512 bit cache line
+ int bitpos = h >> (32 - 9);
+ data_at_cache_line[bitpos >> 3] |= (uint8_t{1} << (bitpos & 7));
+ }
+ }
+
+ static inline void PrepareHash(uint32_t h1, uint32_t len_bytes,
+ const char *data,
+ uint32_t /*out*/ *byte_offset) {
+ uint32_t bytes_to_cache_line = FastRange32(len_bytes >> 6, h1) << 6;
+ PREFETCH(data + bytes_to_cache_line, 0 /* rw */, 1 /* locality */);
+ PREFETCH(data + bytes_to_cache_line + 63, 0 /* rw */, 1 /* locality */);
+ *byte_offset = bytes_to_cache_line;
+ }
+
+ static inline bool HashMayMatch(uint32_t h1, uint32_t h2, uint32_t len_bytes,
+ int num_probes, const char *data) {
+ uint32_t bytes_to_cache_line = FastRange32(len_bytes >> 6, h1) << 6;
+ return HashMayMatchPrepared(h2, num_probes, data + bytes_to_cache_line);
+ }
+
+ static inline bool HashMayMatchPrepared(uint32_t h2, int num_probes,
+ const char *data_at_cache_line) {
+ uint32_t h = h2;
+#ifdef HAVE_AVX2
+ int rem_probes = num_probes;
+
+ // NOTE: For better performance for num_probes in {1, 2, 9, 10, 17, 18,
+ // etc.} one can insert specialized code for rem_probes <= 2, bypassing
+ // the SIMD code in those cases. There is a detectable but minor overhead
+ // applied to other values of num_probes (when not statically determined),
+ // but smoother performance curve vs. num_probes. But for now, when
+ // in doubt, don't add unnecessary code.
+
+ // Powers of 32-bit golden ratio, mod 2**32.
+ const __m256i multipliers =
+ _mm256_setr_epi32(0x00000001, 0x9e3779b9, 0xe35e67b1, 0x734297e9,
+ 0x35fbe861, 0xdeb7c719, 0x448b211, 0x3459b749);
+
+ for (;;) {
+ // Eight copies of hash
+ __m256i hash_vector = _mm256_set1_epi32(h);
+
+ // Same effect as repeated multiplication by 0x9e3779b9 thanks to
+ // associativity of multiplication.
+ hash_vector = _mm256_mullo_epi32(hash_vector, multipliers);
+
+ // Now the top 9 bits of each of the eight 32-bit values in
+ // hash_vector are bit addresses for probes within the cache line.
+ // While the platform-independent code uses byte addressing (6 bits
+ // to pick a byte + 3 bits to pick a bit within a byte), here we work
+ // with 32-bit words (4 bits to pick a word + 5 bits to pick a bit
+ // within a word) because that works well with AVX2 and is equivalent
+ // under little-endian.
+
+ // Shift each right by 28 bits to get 4-bit word addresses.
+ const __m256i word_addresses = _mm256_srli_epi32(hash_vector, 28);
+
+ // Gather 32-bit values spread over 512 bits by 4-bit address. In
+ // essence, we are dereferencing eight pointers within the cache
+ // line.
+ //
+ // Option 1: AVX2 gather (seems to be a little slow - understandable)
+ // const __m256i value_vector =
+ // _mm256_i32gather_epi32(static_cast<const int
+ // *>(data_at_cache_line),
+ // word_addresses,
+ // /*bytes / i32*/ 4);
+ // END Option 1
+ // Potentially unaligned as we're not *always* cache-aligned -> loadu
+ const __m256i *mm_data =
+ reinterpret_cast<const __m256i *>(data_at_cache_line);
+ __m256i lower = _mm256_loadu_si256(mm_data);
+ __m256i upper = _mm256_loadu_si256(mm_data + 1);
+ // Option 2: AVX512VL permute hack
+ // Only negligibly faster than Option 3, so not yet worth supporting
+ // const __m256i value_vector =
+ // _mm256_permutex2var_epi32(lower, word_addresses, upper);
+ // END Option 2
+ // Option 3: AVX2 permute+blend hack
+ // Use lowest three bits to order probing values, as if all from same
+ // 256 bit piece.
+ lower = _mm256_permutevar8x32_epi32(lower, word_addresses);
+ upper = _mm256_permutevar8x32_epi32(upper, word_addresses);
+ // Just top 1 bit of address, to select between lower and upper.
+ const __m256i upper_lower_selector = _mm256_srai_epi32(hash_vector, 31);
+ // Finally: the next 8 probed 32-bit values, in probing sequence order.
+ const __m256i value_vector =
+ _mm256_blendv_epi8(lower, upper, upper_lower_selector);
+ // END Option 3
+
+ // We might not need to probe all 8, so build a mask for selecting only
+ // what we need. (The k_selector(s) could be pre-computed but that
+ // doesn't seem to make a noticeable performance difference.)
+ const __m256i zero_to_seven = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ // Subtract rem_probes from each of those constants
+ __m256i k_selector =
+ _mm256_sub_epi32(zero_to_seven, _mm256_set1_epi32(rem_probes));
+ // Negative after subtract -> use/select
+ // Keep only high bit (logical shift right each by 31).
+ k_selector = _mm256_srli_epi32(k_selector, 31);
+
+ // Strip off the 4 bit word address (shift left)
+ __m256i bit_addresses = _mm256_slli_epi32(hash_vector, 4);
+ // And keep only 5-bit (32 - 27) bit-within-32-bit-word addresses.
+ bit_addresses = _mm256_srli_epi32(bit_addresses, 27);
+ // Build a bit mask
+ const __m256i bit_mask = _mm256_sllv_epi32(k_selector, bit_addresses);
+
+ // Like ((~value_vector) & bit_mask) == 0)
+ bool match = _mm256_testc_si256(value_vector, bit_mask) != 0;
+
+ // This check first so that it's easy for branch predictor to optimize
+ // num_probes <= 8 case, making it free of unpredictable branches.
+ if (rem_probes <= 8) {
+ return match;
+ } else if (!match) {
+ return false;
+ }
+ // otherwise
+ // Need another iteration. 0xab25f4c1 == golden ratio to the 8th power
+ h *= 0xab25f4c1;
+ rem_probes -= 8;
+ }
+#else
+ for (int i = 0; i < num_probes; ++i, h *= uint32_t{0x9e3779b9}) {
+ // 9-bit address within 512 bit cache line
+ int bitpos = h >> (32 - 9);
+ if ((data_at_cache_line[bitpos >> 3] & (char(1) << (bitpos & 7))) == 0) {
+ return false;
+ }
+ }
+ return true;
+#endif
+ }
+};
+
+// A legacy Bloom filter implementation with no locality of probes (slow).
+// It uses double hashing to generate a sequence of hash values.
+// Asymptotic analysis is in [Kirsch,Mitzenmacher 2006], but known to have
+// subtle accuracy flaws for practical sizes [Dillinger,Manolios 2004].
+//
+// DO NOT REUSE
+//
+class LegacyNoLocalityBloomImpl {
+ public:
+ static inline int ChooseNumProbes(int bits_per_key) {
+ // We intentionally round down to reduce probing cost a little bit
+ int num_probes = static_cast<int>(bits_per_key * 0.69); // 0.69 =~ ln(2)
+ if (num_probes < 1) num_probes = 1;
+ if (num_probes > 30) num_probes = 30;
+ return num_probes;
+ }
+
+ static inline void AddHash(uint32_t h, uint32_t total_bits, int num_probes,
+ char *data) {
+ const uint32_t delta = (h >> 17) | (h << 15); // Rotate right 17 bits
+ for (int i = 0; i < num_probes; i++) {
+ const uint32_t bitpos = h % total_bits;
+ data[bitpos / 8] |= (1 << (bitpos % 8));
+ h += delta;
+ }
+ }
+
+ static inline bool HashMayMatch(uint32_t h, uint32_t total_bits,
+ int num_probes, const char *data) {
+ const uint32_t delta = (h >> 17) | (h << 15); // Rotate right 17 bits
+ for (int i = 0; i < num_probes; i++) {
+ const uint32_t bitpos = h % total_bits;
+ if ((data[bitpos / 8] & (1 << (bitpos % 8))) == 0) {
+ return false;
+ }
+ h += delta;
+ }
+ return true;
+ }
+};
+
+// A legacy Bloom filter implementation with probes local to a single
+// cache line (fast). Because SST files might be transported between
+// platforms, the cache line size is a parameter rather than hard coded.
+// (But if specified as a constant parameter, an optimizing compiler
+// should take advantage of that.)
+//
+// When ExtraRotates is false, this implementation is notably deficient in
+// accuracy. Specifically, it uses double hashing with a 1/512 chance of the
+// increment being zero (when cache line size is 512 bits). Thus, there's a
+// 1/512 chance of probing only one index, which we'd expect to incur about
+// a 1/2 * 1/512 or absolute 0.1% FP rate penalty. More detail at
+// https://github.com/facebook/rocksdb/issues/4120
+//
+// DO NOT REUSE
+//
+template <bool ExtraRotates>
+class LegacyLocalityBloomImpl {
+ private:
+ static inline uint32_t GetLine(uint32_t h, uint32_t num_lines) {
+ uint32_t offset_h = ExtraRotates ? (h >> 11) | (h << 21) : h;
+ return offset_h % num_lines;
+ }
+
+ public:
+ // NOTE: this has only been validated to enough accuracy for producing
+ // reasonable warnings / user feedback, not for making functional decisions.
+ static double EstimatedFpRate(size_t keys, size_t bytes, int num_probes) {
+ double bits_per_key = 8.0 * bytes / keys;
+ double filter_rate = BloomMath::CacheLocalFpRate(bits_per_key, num_probes,
+ /*cache line bits*/ 512);
+ if (!ExtraRotates) {
+ // Good estimate of impact of flaw in index computation.
+ // Adds roughly 0.002 around 50 bits/key and 0.001 around 100 bits/key.
+ // The + 22 shifts it nicely to fit for lower bits/key.
+ filter_rate += 0.1 / (bits_per_key * 0.75 + 22);
+ } else {
+ // Not yet validated
+ assert(false);
+ }
+ // Always uses 32-bit hash
+ double fingerprint_rate = BloomMath::FingerprintFpRate(keys, 32);
+ return BloomMath::IndependentProbabilitySum(filter_rate, fingerprint_rate);
+ }
+
+ static inline void AddHash(uint32_t h, uint32_t num_lines, int num_probes,
+ char *data, int log2_cache_line_bytes) {
+ const int log2_cache_line_bits = log2_cache_line_bytes + 3;
+
+ char *data_at_offset =
+ data + (GetLine(h, num_lines) << log2_cache_line_bytes);
+ const uint32_t delta = (h >> 17) | (h << 15);
+ for (int i = 0; i < num_probes; ++i) {
+ // Mask to bit-within-cache-line address
+ const uint32_t bitpos = h & ((1 << log2_cache_line_bits) - 1);
+ data_at_offset[bitpos / 8] |= (1 << (bitpos % 8));
+ if (ExtraRotates) {
+ h = (h >> log2_cache_line_bits) | (h << (32 - log2_cache_line_bits));
+ }
+ h += delta;
+ }
+ }
+
+ static inline void PrepareHashMayMatch(uint32_t h, uint32_t num_lines,
+ const char *data,
+ uint32_t /*out*/ *byte_offset,
+ int log2_cache_line_bytes) {
+ uint32_t b = GetLine(h, num_lines) << log2_cache_line_bytes;
+ PREFETCH(data + b, 0 /* rw */, 1 /* locality */);
+ PREFETCH(data + b + ((1 << log2_cache_line_bytes) - 1), 0 /* rw */,
+ 1 /* locality */);
+ *byte_offset = b;
+ }
+
+ static inline bool HashMayMatch(uint32_t h, uint32_t num_lines,
+ int num_probes, const char *data,
+ int log2_cache_line_bytes) {
+ uint32_t b = GetLine(h, num_lines) << log2_cache_line_bytes;
+ return HashMayMatchPrepared(h, num_probes, data + b, log2_cache_line_bytes);
+ }
+
+ static inline bool HashMayMatchPrepared(uint32_t h, int num_probes,
+ const char *data_at_offset,
+ int log2_cache_line_bytes) {
+ const int log2_cache_line_bits = log2_cache_line_bytes + 3;
+
+ const uint32_t delta = (h >> 17) | (h << 15);
+ for (int i = 0; i < num_probes; ++i) {
+ // Mask to bit-within-cache-line address
+ const uint32_t bitpos = h & ((1 << log2_cache_line_bits) - 1);
+ if (((data_at_offset[bitpos / 8]) & (1 << (bitpos % 8))) == 0) {
+ return false;
+ }
+ if (ExtraRotates) {
+ h = (h >> log2_cache_line_bits) | (h << (32 - log2_cache_line_bits));
+ }
+ h += delta;
+ }
+ return true;
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/bloom_test.cc b/src/rocksdb/util/bloom_test.cc
new file mode 100644
index 000000000..9d509ac3d
--- /dev/null
+++ b/src/rocksdb/util/bloom_test.cc
@@ -0,0 +1,1175 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run this test... Skipping...\n");
+ return 0;
+}
+#else
+
+#include <array>
+#include <cmath>
+#include <vector>
+
+#include "cache/cache_entry_roles.h"
+#include "cache/cache_reservation_manager.h"
+#include "memory/arena.h"
+#include "port/jemalloc_helper.h"
+#include "rocksdb/filter_policy.h"
+#include "table/block_based/filter_policy_internal.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/gflags_compat.h"
+#include "util/hash.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+
+// The test is not fully designed for bits_per_key other than 10, but with
+// this parameter you can easily explore the behavior of other bits_per_key.
+// See also filter_bench.
+DEFINE_int32(bits_per_key, 10, "");
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+const std::string kLegacyBloom = test::LegacyBloomFilterPolicy::kClassName();
+const std::string kFastLocalBloom =
+ test::FastLocalBloomFilterPolicy::kClassName();
+const std::string kStandard128Ribbon =
+ test::Standard128RibbonFilterPolicy::kClassName();
+} // namespace
+
+static const int kVerbose = 1;
+
+static Slice Key(int i, char* buffer) {
+ std::string s;
+ PutFixed32(&s, static_cast<uint32_t>(i));
+ memcpy(buffer, s.c_str(), sizeof(i));
+ return Slice(buffer, sizeof(i));
+}
+
+static int NextLength(int length) {
+ if (length < 10) {
+ length += 1;
+ } else if (length < 100) {
+ length += 10;
+ } else if (length < 1000) {
+ length += 100;
+ } else {
+ length += 1000;
+ }
+ return length;
+}
+
+class FullBloomTest : public testing::TestWithParam<std::string> {
+ protected:
+ BlockBasedTableOptions table_options_;
+
+ private:
+ std::shared_ptr<const FilterPolicy>& policy_;
+ std::unique_ptr<FilterBitsBuilder> bits_builder_;
+ std::unique_ptr<FilterBitsReader> bits_reader_;
+ std::unique_ptr<const char[]> buf_;
+ size_t filter_size_;
+
+ public:
+ FullBloomTest() : policy_(table_options_.filter_policy), filter_size_(0) {
+ ResetPolicy();
+ }
+
+ BuiltinFilterBitsBuilder* GetBuiltinFilterBitsBuilder() {
+ // Throws on bad cast
+ return dynamic_cast<BuiltinFilterBitsBuilder*>(bits_builder_.get());
+ }
+
+ const BloomLikeFilterPolicy* GetBloomLikeFilterPolicy() {
+ // Throws on bad cast
+ return &dynamic_cast<const BloomLikeFilterPolicy&>(*policy_);
+ }
+
+ void Reset() {
+ bits_builder_.reset(BloomFilterPolicy::GetBuilderFromContext(
+ FilterBuildingContext(table_options_)));
+ bits_reader_.reset(nullptr);
+ buf_.reset(nullptr);
+ filter_size_ = 0;
+ }
+
+ void ResetPolicy(double bits_per_key) {
+ policy_ = BloomLikeFilterPolicy::Create(GetParam(), bits_per_key);
+ Reset();
+ }
+
+ void ResetPolicy() { ResetPolicy(FLAGS_bits_per_key); }
+
+ void Add(const Slice& s) { bits_builder_->AddKey(s); }
+
+ void OpenRaw(const Slice& s) {
+ bits_reader_.reset(policy_->GetFilterBitsReader(s));
+ }
+
+ void Build() {
+ Slice filter = bits_builder_->Finish(&buf_);
+ bits_reader_.reset(policy_->GetFilterBitsReader(filter));
+ filter_size_ = filter.size();
+ }
+
+ size_t FilterSize() const { return filter_size_; }
+
+ Slice FilterData() { return Slice(buf_.get(), filter_size_); }
+
+ int GetNumProbesFromFilterData() {
+ assert(filter_size_ >= 5);
+ int8_t raw_num_probes = static_cast<int8_t>(buf_.get()[filter_size_ - 5]);
+ if (raw_num_probes == -1) { // New bloom filter marker
+ return static_cast<uint8_t>(buf_.get()[filter_size_ - 3]);
+ } else {
+ return raw_num_probes;
+ }
+ }
+
+ int GetRibbonSeedFromFilterData() {
+ assert(filter_size_ >= 5);
+ // Check for ribbon marker
+ assert(-2 == static_cast<int8_t>(buf_.get()[filter_size_ - 5]));
+ return static_cast<uint8_t>(buf_.get()[filter_size_ - 4]);
+ }
+
+ bool Matches(const Slice& s) {
+ if (bits_reader_ == nullptr) {
+ Build();
+ }
+ return bits_reader_->MayMatch(s);
+ }
+
+ // Provides a kind of fingerprint on the Bloom filter's
+ // behavior, for reasonbly high FP rates.
+ uint64_t PackedMatches() {
+ char buffer[sizeof(int)];
+ uint64_t result = 0;
+ for (int i = 0; i < 64; i++) {
+ if (Matches(Key(i + 12345, buffer))) {
+ result |= uint64_t{1} << i;
+ }
+ }
+ return result;
+ }
+
+ // Provides a kind of fingerprint on the Bloom filter's
+ // behavior, for lower FP rates.
+ std::string FirstFPs(int count) {
+ char buffer[sizeof(int)];
+ std::string rv;
+ int fp_count = 0;
+ for (int i = 0; i < 1000000; i++) {
+ // Pack four match booleans into each hexadecimal digit
+ if (Matches(Key(i + 1000000, buffer))) {
+ ++fp_count;
+ rv += std::to_string(i);
+ if (fp_count == count) {
+ break;
+ }
+ rv += ',';
+ }
+ }
+ return rv;
+ }
+
+ double FalsePositiveRate() {
+ char buffer[sizeof(int)];
+ int result = 0;
+ for (int i = 0; i < 10000; i++) {
+ if (Matches(Key(i + 1000000000, buffer))) {
+ result++;
+ }
+ }
+ return result / 10000.0;
+ }
+};
+
+TEST_P(FullBloomTest, FilterSize) {
+ // In addition to checking the consistency of space computation, we are
+ // checking that denoted and computed doubles are interpreted as expected
+ // as bits_per_key values.
+ bool some_computed_less_than_denoted = false;
+ // Note: to avoid unproductive configurations, bits_per_key < 0.5 is rounded
+ // down to 0 (no filter), and 0.5 <= bits_per_key < 1.0 is rounded up to 1
+ // bit per key (1000 millibits). Also, enforced maximum is 100 bits per key
+ // (100000 millibits).
+ for (auto bpk : std::vector<std::pair<double, int> >{{-HUGE_VAL, 0},
+ {-INFINITY, 0},
+ {0.0, 0},
+ {0.499, 0},
+ {0.5, 1000},
+ {1.234, 1234},
+ {3.456, 3456},
+ {9.5, 9500},
+ {10.0, 10000},
+ {10.499, 10499},
+ {21.345, 21345},
+ {99.999, 99999},
+ {1234.0, 100000},
+ {HUGE_VAL, 100000},
+ {INFINITY, 100000},
+ {NAN, 100000}}) {
+ ResetPolicy(bpk.first);
+ auto bfp = GetBloomLikeFilterPolicy();
+ EXPECT_EQ(bpk.second, bfp->GetMillibitsPerKey());
+ EXPECT_EQ((bpk.second + 500) / 1000, bfp->GetWholeBitsPerKey());
+
+ double computed = bpk.first;
+ // This transforms e.g. 9.5 -> 9.499999999999998, which we still
+ // round to 10 for whole bits per key.
+ computed += 0.5;
+ computed /= 1234567.0;
+ computed *= 1234567.0;
+ computed -= 0.5;
+ some_computed_less_than_denoted |= (computed < bpk.first);
+ ResetPolicy(computed);
+ bfp = GetBloomLikeFilterPolicy();
+ EXPECT_EQ(bpk.second, bfp->GetMillibitsPerKey());
+ EXPECT_EQ((bpk.second + 500) / 1000, bfp->GetWholeBitsPerKey());
+
+ auto bits_builder = GetBuiltinFilterBitsBuilder();
+ if (bpk.second == 0) {
+ ASSERT_EQ(bits_builder, nullptr);
+ continue;
+ }
+
+ size_t n = 1;
+ size_t space = 0;
+ for (; n < 1000000; n += 1 + n / 1000) {
+ // Ensure consistency between CalculateSpace and ApproximateNumEntries
+ space = bits_builder->CalculateSpace(n);
+ size_t n2 = bits_builder->ApproximateNumEntries(space);
+ EXPECT_GE(n2, n);
+ size_t space2 = bits_builder->CalculateSpace(n2);
+ if (n > 12000 && GetParam() == kStandard128Ribbon) {
+ // TODO(peterd): better approximation?
+ EXPECT_GE(space2, space);
+ EXPECT_LE(space2 * 0.998, space * 1.0);
+ } else {
+ EXPECT_EQ(space2, space);
+ }
+ }
+ // Until size_t overflow
+ for (; n < (n + n / 3); n += n / 3) {
+ // Ensure space computation is not overflowing; capped is OK
+ size_t space2 = bits_builder->CalculateSpace(n);
+ EXPECT_GE(space2, space);
+ space = space2;
+ }
+ }
+ // Check that the compiler hasn't optimized our computation into nothing
+ EXPECT_TRUE(some_computed_less_than_denoted);
+ ResetPolicy();
+}
+
+TEST_P(FullBloomTest, FullEmptyFilter) {
+ // Empty filter is not match, at this level
+ ASSERT_TRUE(!Matches("hello"));
+ ASSERT_TRUE(!Matches("world"));
+}
+
+TEST_P(FullBloomTest, FullSmall) {
+ Add("hello");
+ Add("world");
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+ ASSERT_TRUE(!Matches("x"));
+ ASSERT_TRUE(!Matches("foo"));
+}
+
+TEST_P(FullBloomTest, FullVaryingLengths) {
+ char buffer[sizeof(int)];
+
+ // Count number of filters that significantly exceed the false positive rate
+ int mediocre_filters = 0;
+ int good_filters = 0;
+
+ for (int length = 1; length <= 10000; length = NextLength(length)) {
+ Reset();
+ for (int i = 0; i < length; i++) {
+ Add(Key(i, buffer));
+ }
+ Build();
+
+ EXPECT_LE(FilterSize(), (size_t)((length * FLAGS_bits_per_key / 8) +
+ CACHE_LINE_SIZE * 2 + 5));
+
+ // All added keys must match
+ for (int i = 0; i < length; i++) {
+ ASSERT_TRUE(Matches(Key(i, buffer)))
+ << "Length " << length << "; key " << i;
+ }
+
+ // Check false positive rate
+ double rate = FalsePositiveRate();
+ if (kVerbose >= 1) {
+ fprintf(stderr, "False positives: %5.2f%% @ length = %6d ; bytes = %6d\n",
+ rate * 100.0, length, static_cast<int>(FilterSize()));
+ }
+ if (FLAGS_bits_per_key == 10) {
+ EXPECT_LE(rate, 0.02); // Must not be over 2%
+ if (rate > 0.0125) {
+ mediocre_filters++; // Allowed, but not too often
+ } else {
+ good_filters++;
+ }
+ }
+ }
+ if (kVerbose >= 1) {
+ fprintf(stderr, "Filters: %d good, %d mediocre\n", good_filters,
+ mediocre_filters);
+ }
+ EXPECT_LE(mediocre_filters, good_filters / 5);
+}
+
+TEST_P(FullBloomTest, OptimizeForMemory) {
+ char buffer[sizeof(int)];
+ for (bool offm : {true, false}) {
+ table_options_.optimize_filters_for_memory = offm;
+ ResetPolicy();
+ Random32 rnd(12345);
+ uint64_t total_size = 0;
+ uint64_t total_mem = 0;
+ int64_t total_keys = 0;
+ double total_fp_rate = 0;
+ constexpr int nfilters = 100;
+ for (int i = 0; i < nfilters; ++i) {
+ int nkeys = static_cast<int>(rnd.Uniformish(10000)) + 100;
+ Reset();
+ for (int j = 0; j < nkeys; ++j) {
+ Add(Key(j, buffer));
+ }
+ Build();
+ size_t size = FilterData().size();
+ total_size += size;
+ // optimize_filters_for_memory currently depends on malloc_usable_size
+ // but we run the rest of the test to ensure no bad behavior without it.
+#ifdef ROCKSDB_MALLOC_USABLE_SIZE
+ size = malloc_usable_size(const_cast<char*>(FilterData().data()));
+#endif // ROCKSDB_MALLOC_USABLE_SIZE
+ total_mem += size;
+ total_keys += nkeys;
+ total_fp_rate += FalsePositiveRate();
+ }
+ if (FLAGS_bits_per_key == 10) {
+ EXPECT_LE(total_fp_rate / double{nfilters}, 0.011);
+ EXPECT_GE(total_fp_rate / double{nfilters},
+ CACHE_LINE_SIZE >= 256 ? 0.007 : 0.008);
+ }
+
+ int64_t ex_min_total_size = int64_t{FLAGS_bits_per_key} * total_keys / 8;
+ if (GetParam() == kStandard128Ribbon) {
+ // ~ 30% savings vs. Bloom filter
+ ex_min_total_size = 7 * ex_min_total_size / 10;
+ }
+ EXPECT_GE(static_cast<int64_t>(total_size), ex_min_total_size);
+
+ int64_t blocked_bloom_overhead = nfilters * (CACHE_LINE_SIZE + 5);
+ if (GetParam() == kLegacyBloom) {
+ // this config can add extra cache line to make odd number
+ blocked_bloom_overhead += nfilters * CACHE_LINE_SIZE;
+ }
+
+ EXPECT_GE(total_mem, total_size);
+
+ // optimize_filters_for_memory not implemented with legacy Bloom
+ if (offm && GetParam() != kLegacyBloom) {
+ // This value can include a small extra penalty for kExtraPadding
+ fprintf(stderr, "Internal fragmentation (optimized): %g%%\n",
+ (total_mem - total_size) * 100.0 / total_size);
+ // Less than 1% internal fragmentation
+ EXPECT_LE(total_mem, total_size * 101 / 100);
+ // Up to 2% storage penalty
+ EXPECT_LE(static_cast<int64_t>(total_size),
+ ex_min_total_size * 102 / 100 + blocked_bloom_overhead);
+ } else {
+ fprintf(stderr, "Internal fragmentation (not optimized): %g%%\n",
+ (total_mem - total_size) * 100.0 / total_size);
+ // TODO: add control checks for more allocators?
+#ifdef ROCKSDB_JEMALLOC
+ fprintf(stderr, "Jemalloc detected? %d\n", HasJemalloc());
+ if (HasJemalloc()) {
+#ifdef ROCKSDB_MALLOC_USABLE_SIZE
+ // More than 5% internal fragmentation
+ EXPECT_GE(total_mem, total_size * 105 / 100);
+#endif // ROCKSDB_MALLOC_USABLE_SIZE
+ }
+#endif // ROCKSDB_JEMALLOC
+ // No storage penalty, just usual overhead
+ EXPECT_LE(static_cast<int64_t>(total_size),
+ ex_min_total_size + blocked_bloom_overhead);
+ }
+ }
+}
+
+class ChargeFilterConstructionTest : public testing::Test {};
+TEST_F(ChargeFilterConstructionTest, RibbonFilterFallBackOnLargeBanding) {
+ constexpr std::size_t kCacheCapacity =
+ 8 * CacheReservationManagerImpl<
+ CacheEntryRole::kFilterConstruction>::GetDummyEntrySize();
+ constexpr std::size_t num_entries_for_cache_full = kCacheCapacity / 8;
+
+ for (CacheEntryRoleOptions::Decision charge_filter_construction_mem :
+ {CacheEntryRoleOptions::Decision::kEnabled,
+ CacheEntryRoleOptions::Decision::kDisabled}) {
+ bool will_fall_back = charge_filter_construction_mem ==
+ CacheEntryRoleOptions::Decision::kEnabled;
+
+ BlockBasedTableOptions table_options;
+ table_options.cache_usage_options.options_overrides.insert(
+ {CacheEntryRole::kFilterConstruction,
+ {/*.charged = */ charge_filter_construction_mem}});
+ LRUCacheOptions lo;
+ lo.capacity = kCacheCapacity;
+ lo.num_shard_bits = 0; // 2^0 shard
+ lo.strict_capacity_limit = true;
+ std::shared_ptr<Cache> cache(NewLRUCache(lo));
+ table_options.block_cache = cache;
+ table_options.filter_policy =
+ BloomLikeFilterPolicy::Create(kStandard128Ribbon, FLAGS_bits_per_key);
+ FilterBuildingContext ctx(table_options);
+ std::unique_ptr<FilterBitsBuilder> filter_bits_builder(
+ table_options.filter_policy->GetBuilderWithContext(ctx));
+
+ char key_buffer[sizeof(int)];
+ for (std::size_t i = 0; i < num_entries_for_cache_full; ++i) {
+ filter_bits_builder->AddKey(Key(static_cast<int>(i), key_buffer));
+ }
+
+ std::unique_ptr<const char[]> buf;
+ Slice filter = filter_bits_builder->Finish(&buf);
+
+ // To verify Ribbon Filter fallbacks to Bloom Filter properly
+ // based on cache charging result
+ // See BloomFilterPolicy::GetBloomBitsReader re: metadata
+ // -1 = Marker for newer Bloom implementations
+ // -2 = Marker for Standard128 Ribbon
+ if (will_fall_back) {
+ EXPECT_EQ(filter.data()[filter.size() - 5], static_cast<char>(-1));
+ } else {
+ EXPECT_EQ(filter.data()[filter.size() - 5], static_cast<char>(-2));
+ }
+
+ if (charge_filter_construction_mem ==
+ CacheEntryRoleOptions::Decision::kEnabled) {
+ const size_t dummy_entry_num = static_cast<std::size_t>(std::ceil(
+ filter.size() * 1.0 /
+ CacheReservationManagerImpl<
+ CacheEntryRole::kFilterConstruction>::GetDummyEntrySize()));
+ EXPECT_GE(
+ cache->GetPinnedUsage(),
+ dummy_entry_num *
+ CacheReservationManagerImpl<
+ CacheEntryRole::kFilterConstruction>::GetDummyEntrySize());
+ EXPECT_LT(
+ cache->GetPinnedUsage(),
+ (dummy_entry_num + 1) *
+ CacheReservationManagerImpl<
+ CacheEntryRole::kFilterConstruction>::GetDummyEntrySize());
+ } else {
+ EXPECT_EQ(cache->GetPinnedUsage(), 0);
+ }
+ }
+}
+
+namespace {
+inline uint32_t SelectByCacheLineSize(uint32_t for64, uint32_t for128,
+ uint32_t for256) {
+ (void)for64;
+ (void)for128;
+ (void)for256;
+#if CACHE_LINE_SIZE == 64
+ return for64;
+#elif CACHE_LINE_SIZE == 128
+ return for128;
+#elif CACHE_LINE_SIZE == 256
+ return for256;
+#else
+#error "CACHE_LINE_SIZE unknown or unrecognized"
+#endif
+}
+} // namespace
+
+// Ensure the implementation doesn't accidentally change in an
+// incompatible way. This test doesn't check the reading side
+// (FirstFPs/PackedMatches) for LegacyBloom because it requires the
+// ability to read filters generated using other cache line sizes.
+// See RawSchema.
+TEST_P(FullBloomTest, Schema) {
+#define EXPECT_EQ_Bloom(a, b) \
+ { \
+ if (GetParam() != kStandard128Ribbon) { \
+ EXPECT_EQ(a, b); \
+ } \
+ }
+#define EXPECT_EQ_Ribbon(a, b) \
+ { \
+ if (GetParam() == kStandard128Ribbon) { \
+ EXPECT_EQ(a, b); \
+ } \
+ }
+#define EXPECT_EQ_FastBloom(a, b) \
+ { \
+ if (GetParam() == kFastLocalBloom) { \
+ EXPECT_EQ(a, b); \
+ } \
+ }
+#define EXPECT_EQ_LegacyBloom(a, b) \
+ { \
+ if (GetParam() == kLegacyBloom) { \
+ EXPECT_EQ(a, b); \
+ } \
+ }
+#define EXPECT_EQ_NotLegacy(a, b) \
+ { \
+ if (GetParam() != kLegacyBloom) { \
+ EXPECT_EQ(a, b); \
+ } \
+ }
+
+ char buffer[sizeof(int)];
+
+ // First do a small number of keys, where Ribbon config will fall back on
+ // fast Bloom filter and generate the same data
+ ResetPolicy(5); // num_probes = 3
+ for (int key = 0; key < 87; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ(GetNumProbesFromFilterData(), 3);
+
+ EXPECT_EQ_NotLegacy(BloomHash(FilterData()), 4130687756U);
+
+ EXPECT_EQ_NotLegacy("31,38,40,43,61,83,86,112,125,131", FirstFPs(10));
+
+ // Now use enough keys so that changing bits / key by 1 is guaranteed to
+ // change number of allocated cache lines. So keys > max cache line bits.
+
+ // Note that the first attempted Ribbon seed is determined by the hash
+ // of the first key added (for pseudorandomness in practice, determinism in
+ // testing)
+
+ ResetPolicy(2); // num_probes = 1
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 1);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(1567096579, 1964771444, 2659542661U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 3817481309U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1705851228U);
+
+ EXPECT_EQ_FastBloom("11,13,17,25,29,30,35,37,45,53", FirstFPs(10));
+ EXPECT_EQ_Ribbon("3,8,10,17,19,20,23,28,31,32", FirstFPs(10));
+
+ ResetPolicy(3); // num_probes = 2
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 2);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(2707206547U, 2571983456U, 218344685));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 2807269961U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1095342358U);
+
+ EXPECT_EQ_FastBloom("4,15,17,24,27,28,29,53,63,70", FirstFPs(10));
+ EXPECT_EQ_Ribbon("3,17,20,28,32,33,36,43,49,54", FirstFPs(10));
+
+ ResetPolicy(5); // num_probes = 3
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 3);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(515748486, 94611728, 2436112214U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 204628445U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 3971337699U);
+
+ EXPECT_EQ_FastBloom("15,24,29,39,53,87,89,100,103,104", FirstFPs(10));
+ EXPECT_EQ_Ribbon("3,33,36,43,67,70,76,78,84,102", FirstFPs(10));
+
+ ResetPolicy(8); // num_probes = 5
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 5);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(1302145999, 2811644657U, 756553699));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 355564975U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 3651449053U);
+
+ EXPECT_EQ_FastBloom("16,60,66,126,220,238,244,256,265,287", FirstFPs(10));
+ EXPECT_EQ_Ribbon("33,187,203,296,300,322,411,419,547,582", FirstFPs(10));
+
+ ResetPolicy(9); // num_probes = 6
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(2092755149, 661139132, 1182970461));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 2137566013U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1005676675U);
+
+ EXPECT_EQ_FastBloom("156,367,791,872,945,1015,1139,1159,1265", FirstFPs(9));
+ EXPECT_EQ_Ribbon("33,187,203,296,411,419,604,612,615,619", FirstFPs(10));
+
+ ResetPolicy(11); // num_probes = 7
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 7);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(3755609649U, 1812694762, 1449142939));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 2561502687U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 3129900846U);
+
+ EXPECT_EQ_FastBloom("34,74,130,236,643,882,962,1015,1035,1110", FirstFPs(10));
+ EXPECT_EQ_Ribbon("411,419,623,665,727,794,955,1052,1323,1330", FirstFPs(10));
+
+ // This used to be 9 probes, but 8 is a better choice for speed,
+ // especially with SIMD groups of 8 probes, with essentially no
+ // change in FP rate.
+ // FP rate @ 9 probes, old Bloom: 0.4321%
+ // FP rate @ 9 probes, new Bloom: 0.1846%
+ // FP rate @ 8 probes, new Bloom: 0.1843%
+ ResetPolicy(14); // num_probes = 8 (new), 9 (old)
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_LegacyBloom(GetNumProbesFromFilterData(), 9);
+ EXPECT_EQ_FastBloom(GetNumProbesFromFilterData(), 8);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(178861123, 379087593, 2574136516U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 3709876890U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1855638875U);
+
+ EXPECT_EQ_FastBloom("130,240,522,565,989,2002,2526,3147,3543", FirstFPs(9));
+ EXPECT_EQ_Ribbon("665,727,1323,1755,3866,4232,4442,4492,4736", FirstFPs(9));
+
+ // This used to be 11 probes, but 9 is a better choice for speed
+ // AND accuracy.
+ // FP rate @ 11 probes, old Bloom: 0.3571%
+ // FP rate @ 11 probes, new Bloom: 0.0884%
+ // FP rate @ 9 probes, new Bloom: 0.0843%
+ ResetPolicy(16); // num_probes = 9 (new), 11 (old)
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_LegacyBloom(GetNumProbesFromFilterData(), 11);
+ EXPECT_EQ_FastBloom(GetNumProbesFromFilterData(), 9);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(1129406313, 3049154394U, 1727750964));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 1087138490U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 459379967U);
+
+ EXPECT_EQ_FastBloom("3299,3611,3916,6620,7822,8079,8482,8942", FirstFPs(8));
+ EXPECT_EQ_Ribbon("727,1323,1755,4442,4736,5386,6974,7154,8222", FirstFPs(9));
+
+ ResetPolicy(10); // num_probes = 6, but different memory ratio vs. 9
+ for (int key = 0; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 61);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(1478976371, 2910591341U, 1182970461));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 2498541272U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1273231667U);
+
+ EXPECT_EQ_FastBloom("16,126,133,422,466,472,813,1002,1035", FirstFPs(9));
+ EXPECT_EQ_Ribbon("296,411,419,612,619,623,630,665,686,727", FirstFPs(10));
+
+ ResetPolicy(10);
+ for (int key = /*CHANGED*/ 1; key < 2087; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), /*CHANGED*/ 184);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(4205696321U, 1132081253U, 2385981855U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 2058382345U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 3007790572U);
+
+ EXPECT_EQ_FastBloom("16,126,133,422,466,472,813,1002,1035", FirstFPs(9));
+ EXPECT_EQ_Ribbon("33,152,383,497,589,633,737,781,911,990", FirstFPs(10));
+
+ ResetPolicy(10);
+ for (int key = 1; key < /*CHANGED*/ 2088; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 184);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ SelectByCacheLineSize(2885052954U, 769447944, 4175124908U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 23699164U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1942323379U);
+
+ EXPECT_EQ_FastBloom("16,126,133,422,466,472,813,1002,1035", FirstFPs(9));
+ EXPECT_EQ_Ribbon("33,95,360,589,737,911,990,1048,1081,1414", FirstFPs(10));
+
+ // With new fractional bits_per_key, check that we are rounding to
+ // whole bits per key for old Bloom filters but fractional for
+ // new Bloom filter.
+ ResetPolicy(9.5);
+ for (int key = 1; key < 2088; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_Bloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 184);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ /*SAME*/ SelectByCacheLineSize(2885052954U, 769447944, 4175124908U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 3166884174U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 1148258663U);
+
+ EXPECT_EQ_FastBloom("126,156,367,444,458,791,813,976,1015", FirstFPs(9));
+ EXPECT_EQ_Ribbon("33,54,95,360,589,693,737,911,990,1048", FirstFPs(10));
+
+ ResetPolicy(10.499);
+ for (int key = 1; key < 2088; key++) {
+ Add(Key(key, buffer));
+ }
+ Build();
+ EXPECT_EQ_LegacyBloom(GetNumProbesFromFilterData(), 6);
+ EXPECT_EQ_FastBloom(GetNumProbesFromFilterData(), 7);
+ EXPECT_EQ_Ribbon(GetRibbonSeedFromFilterData(), 184);
+
+ EXPECT_EQ_LegacyBloom(
+ BloomHash(FilterData()),
+ /*SAME*/ SelectByCacheLineSize(2885052954U, 769447944, 4175124908U));
+ EXPECT_EQ_FastBloom(BloomHash(FilterData()), 4098502778U);
+ EXPECT_EQ_Ribbon(BloomHash(FilterData()), 792138188U);
+
+ EXPECT_EQ_FastBloom("16,236,240,472,1015,1045,1111,1409,1465", FirstFPs(9));
+ EXPECT_EQ_Ribbon("33,95,360,589,737,990,1048,1081,1414,1643", FirstFPs(10));
+
+ ResetPolicy();
+}
+
+// A helper class for testing custom or corrupt filter bits as read by
+// built-in FilterBitsReaders.
+struct RawFilterTester {
+ // Buffer, from which we always return a tail Slice, so the
+ // last five bytes are always the metadata bytes.
+ std::array<char, 3000> data_;
+ // Points five bytes from the end
+ char* metadata_ptr_;
+
+ RawFilterTester() : metadata_ptr_(&*(data_.end() - 5)) {}
+
+ Slice ResetNoFill(uint32_t len_without_metadata, uint32_t num_lines,
+ uint32_t num_probes) {
+ metadata_ptr_[0] = static_cast<char>(num_probes);
+ EncodeFixed32(metadata_ptr_ + 1, num_lines);
+ uint32_t len = len_without_metadata + /*metadata*/ 5;
+ assert(len <= data_.size());
+ return Slice(metadata_ptr_ - len_without_metadata, len);
+ }
+
+ Slice Reset(uint32_t len_without_metadata, uint32_t num_lines,
+ uint32_t num_probes, bool fill_ones) {
+ data_.fill(fill_ones ? 0xff : 0);
+ return ResetNoFill(len_without_metadata, num_lines, num_probes);
+ }
+
+ Slice ResetWeirdFill(uint32_t len_without_metadata, uint32_t num_lines,
+ uint32_t num_probes) {
+ for (uint32_t i = 0; i < data_.size(); ++i) {
+ data_[i] = static_cast<char>(0x7b7b >> (i % 7));
+ }
+ return ResetNoFill(len_without_metadata, num_lines, num_probes);
+ }
+};
+
+TEST_P(FullBloomTest, RawSchema) {
+ RawFilterTester cft;
+ // Legacy Bloom configurations
+ // Two probes, about 3/4 bits set: ~50% "FP" rate
+ // One 256-byte cache line.
+ OpenRaw(cft.ResetWeirdFill(256, 1, 2));
+ EXPECT_EQ(uint64_t{11384799501900898790U}, PackedMatches());
+
+ // Two 128-byte cache lines.
+ OpenRaw(cft.ResetWeirdFill(256, 2, 2));
+ EXPECT_EQ(uint64_t{10157853359773492589U}, PackedMatches());
+
+ // Four 64-byte cache lines.
+ OpenRaw(cft.ResetWeirdFill(256, 4, 2));
+ EXPECT_EQ(uint64_t{7123594913907464682U}, PackedMatches());
+
+ // Fast local Bloom configurations (marker 255 -> -1)
+ // Two probes, about 3/4 bits set: ~50% "FP" rate
+ // Four 64-byte cache lines.
+ OpenRaw(cft.ResetWeirdFill(256, 2U << 8, 255));
+ EXPECT_EQ(uint64_t{9957045189927952471U}, PackedMatches());
+
+ // Ribbon configurations (marker 254 -> -2)
+
+ // Even though the builder never builds configurations this
+ // small (preferring Bloom), we can test that the configuration
+ // can be read, for possible future-proofing.
+
+ // 256 slots, one result column = 32 bytes (2 blocks, seed 0)
+ // ~50% FP rate:
+ // 0b0101010111110101010000110000011011011111100100001110010011101010
+ OpenRaw(cft.ResetWeirdFill(32, 2U << 8, 254));
+ EXPECT_EQ(uint64_t{6193930559317665002U}, PackedMatches());
+
+ // 256 slots, three-to-four result columns = 112 bytes
+ // ~ 1 in 10 FP rate:
+ // 0b0000000000100000000000000000000001000001000000010000101000000000
+ OpenRaw(cft.ResetWeirdFill(112, 2U << 8, 254));
+ EXPECT_EQ(uint64_t{9007200345328128U}, PackedMatches());
+}
+
+TEST_P(FullBloomTest, CorruptFilters) {
+ RawFilterTester cft;
+
+ for (bool fill : {false, true}) {
+ // Legacy Bloom configurations
+ // Good filter bits - returns same as fill
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 6, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Good filter bits - returns same as fill
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE * 3, 3, 6, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Good filter bits - returns same as fill
+ // 256 is unusual but legal cache line size
+ OpenRaw(cft.Reset(256 * 3, 3, 6, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Good filter bits - returns same as fill
+ // 30 should be max num_probes
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 30, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Good filter bits - returns same as fill
+ // 1 should be min num_probes
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 1, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Type 1 trivial filter bits - returns true as if FP by zero probes
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 0, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Type 2 trivial filter bits - returns false as if built from zero keys
+ OpenRaw(cft.Reset(0, 0, 6, fill));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Type 2 trivial filter bits - returns false as if built from zero keys
+ OpenRaw(cft.Reset(0, 37, 6, fill));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Type 2 trivial filter bits - returns false as 0 size trumps 0 probes
+ OpenRaw(cft.Reset(0, 0, 0, fill));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Bad filter bits - returns true for safety
+ // No solution to 0 * x == CACHE_LINE_SIZE
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 0, 6, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Bad filter bits - returns true for safety
+ // Can't have 3 * x == 4 for integer x
+ OpenRaw(cft.Reset(4, 3, 6, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Bad filter bits - returns true for safety
+ // 97 bytes is not a power of two, so not a legal cache line size
+ OpenRaw(cft.Reset(97 * 3, 3, 6, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Bad filter bits - returns true for safety
+ // 65 bytes is not a power of two, so not a legal cache line size
+ OpenRaw(cft.Reset(65 * 3, 3, 6, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Bad filter bits - returns false as if built from zero keys
+ // < 5 bytes overall means missing even metadata
+ OpenRaw(cft.Reset(static_cast<uint32_t>(-1), 3, 6, fill));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ OpenRaw(cft.Reset(static_cast<uint32_t>(-5), 3, 6, fill));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Dubious filter bits - returns same as fill (for now)
+ // 31 is not a useful num_probes, nor generated by RocksDB unless directly
+ // using filter bits API without BloomFilterPolicy.
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 31, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Dubious filter bits - returns same as fill (for now)
+ // Similar, with 127, largest positive char
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 127, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Dubious filter bits - returns true (for now)
+ // num_probes set to 128 / -128, lowest negative char
+ // NB: Bug in implementation interprets this as negative and has same
+ // effect as zero probes, but effectively reserves negative char values
+ // for future use.
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 128, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Dubious filter bits - returns true (for now)
+ // Similar, with 253 / -3
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 1, 253, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // #########################################################
+ // Fast local Bloom configurations (marker 255 -> -1)
+ // Good config with six probes
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 6U << 8, 255, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Becomes bad/reserved config (always true) if any other byte set
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, (6U << 8) | 1U, 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, (6U << 8) | (1U << 16), 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, (6U << 8) | (1U << 24), 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Good config, max 30 probes
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 30U << 8, 255, fill));
+ ASSERT_EQ(fill, Matches("hello"));
+ ASSERT_EQ(fill, Matches("world"));
+
+ // Bad/reserved config (always true) if more than 30
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 31U << 8, 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 33U << 8, 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 66U << 8, 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ OpenRaw(cft.Reset(CACHE_LINE_SIZE, 130U << 8, 255, fill));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+ }
+
+ // #########################################################
+ // Ribbon configurations (marker 254 -> -2)
+ // ("fill" doesn't work to detect good configurations, we just
+ // have to rely on TN probability)
+
+ // Good: 2 blocks * 16 bytes / segment * 4 columns = 128 bytes
+ // seed = 123
+ OpenRaw(cft.Reset(128, (2U << 8) + 123U, 254, false));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Good: 2 blocks * 16 bytes / segment * 8 columns = 256 bytes
+ OpenRaw(cft.Reset(256, (2U << 8) + 123U, 254, false));
+ ASSERT_FALSE(Matches("hello"));
+ ASSERT_FALSE(Matches("world"));
+
+ // Surprisingly OK: 5000 blocks (640,000 slots) in only 1024 bits
+ // -> average close to 0 columns
+ OpenRaw(cft.Reset(128, (5000U << 8) + 123U, 254, false));
+ // *Almost* all FPs
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+ // Need many queries to find a "true negative"
+ for (int i = 0; Matches(std::to_string(i)); ++i) {
+ ASSERT_LT(i, 1000);
+ }
+
+ // Bad: 1 block not allowed (for implementation detail reasons)
+ OpenRaw(cft.Reset(128, (1U << 8) + 123U, 254, false));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+
+ // Bad: 0 blocks not allowed
+ OpenRaw(cft.Reset(128, (0U << 8) + 123U, 254, false));
+ ASSERT_TRUE(Matches("hello"));
+ ASSERT_TRUE(Matches("world"));
+}
+
+INSTANTIATE_TEST_CASE_P(Full, FullBloomTest,
+ testing::Values(kLegacyBloom, kFastLocalBloom,
+ kStandard128Ribbon));
+
+static double GetEffectiveBitsPerKey(FilterBitsBuilder* builder) {
+ union {
+ uint64_t key_value = 0;
+ char key_bytes[8];
+ };
+
+ const unsigned kNumKeys = 1000;
+
+ Slice key_slice{key_bytes, 8};
+ for (key_value = 0; key_value < kNumKeys; ++key_value) {
+ builder->AddKey(key_slice);
+ }
+
+ std::unique_ptr<const char[]> buf;
+ auto filter = builder->Finish(&buf);
+ return filter.size() * /*bits per byte*/ 8 / (1.0 * kNumKeys);
+}
+
+static void SetTestingLevel(int levelish, FilterBuildingContext* ctx) {
+ if (levelish == -1) {
+ // Flush is treated as level -1 for this option but actually level 0
+ ctx->level_at_creation = 0;
+ ctx->reason = TableFileCreationReason::kFlush;
+ } else {
+ ctx->level_at_creation = levelish;
+ ctx->reason = TableFileCreationReason::kCompaction;
+ }
+}
+
+TEST(RibbonTest, RibbonTestLevelThreshold) {
+ BlockBasedTableOptions opts;
+ FilterBuildingContext ctx(opts);
+ // A few settings
+ for (CompactionStyle cs : {kCompactionStyleLevel, kCompactionStyleUniversal,
+ kCompactionStyleFIFO, kCompactionStyleNone}) {
+ ctx.compaction_style = cs;
+ for (int bloom_before_level : {-1, 0, 1, 10}) {
+ std::vector<std::unique_ptr<const FilterPolicy> > policies;
+ policies.emplace_back(NewRibbonFilterPolicy(10, bloom_before_level));
+
+ if (bloom_before_level == 0) {
+ // Also test new API default
+ policies.emplace_back(NewRibbonFilterPolicy(10));
+ }
+
+ for (std::unique_ptr<const FilterPolicy>& policy : policies) {
+ // Claim to be generating filter for this level
+ SetTestingLevel(bloom_before_level, &ctx);
+
+ std::unique_ptr<FilterBitsBuilder> builder{
+ policy->GetBuilderWithContext(ctx)};
+
+ // Must be Ribbon (more space efficient than 10 bits per key)
+ ASSERT_LT(GetEffectiveBitsPerKey(builder.get()), 8);
+
+ if (bloom_before_level >= 0) {
+ // Claim to be generating filter for previous level
+ SetTestingLevel(bloom_before_level - 1, &ctx);
+
+ builder.reset(policy->GetBuilderWithContext(ctx));
+
+ if (cs == kCompactionStyleLevel || cs == kCompactionStyleUniversal) {
+ // Level is considered.
+ // Must be Bloom (~ 10 bits per key)
+ ASSERT_GT(GetEffectiveBitsPerKey(builder.get()), 9);
+ } else {
+ // Level is ignored under non-traditional compaction styles.
+ // Must be Ribbon (more space efficient than 10 bits per key)
+ ASSERT_LT(GetEffectiveBitsPerKey(builder.get()), 8);
+ }
+ }
+
+ // Like SST file writer
+ ctx.level_at_creation = -1;
+ ctx.reason = TableFileCreationReason::kMisc;
+
+ builder.reset(policy->GetBuilderWithContext(ctx));
+
+ // Must be Ribbon (more space efficient than 10 bits per key)
+ ASSERT_LT(GetEffectiveBitsPerKey(builder.get()), 8);
+ }
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ return RUN_ALL_TESTS();
+}
+
+#endif // GFLAGS
diff --git a/src/rocksdb/util/build_version.cc.in b/src/rocksdb/util/build_version.cc.in
new file mode 100644
index 000000000..c1706dc1f
--- /dev/null
+++ b/src/rocksdb/util/build_version.cc.in
@@ -0,0 +1,81 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+#include <memory>
+
+#include "rocksdb/version.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "util/string_util.h"
+
+// The build script may replace these values with real values based
+// on whether or not GIT is available and the platform settings
+static const std::string rocksdb_build_git_sha = "rocksdb_build_git_sha:@GIT_SHA@";
+static const std::string rocksdb_build_git_tag = "rocksdb_build_git_tag:@GIT_TAG@";
+#define HAS_GIT_CHANGES @GIT_MOD@
+#if HAS_GIT_CHANGES == 0
+// If HAS_GIT_CHANGES is 0, the GIT date is used.
+// Use the time the branch/tag was last modified
+static const std::string rocksdb_build_date = "rocksdb_build_date:@GIT_DATE@";
+#else
+// If HAS_GIT_CHANGES is > 0, the branch/tag has modifications.
+// Use the time the build was created.
+static const std::string rocksdb_build_date = "rocksdb_build_date:@BUILD_DATE@";
+#endif
+
+#ifndef ROCKSDB_LITE
+extern "C" {
+@ROCKSDB_PLUGIN_EXTERNS@
+} // extern "C"
+
+std::unordered_map<std::string, ROCKSDB_NAMESPACE::RegistrarFunc> ROCKSDB_NAMESPACE::ObjectRegistry::builtins_ = {
+ @ROCKSDB_PLUGIN_BUILTINS@
+};
+#endif //ROCKSDB_LITE
+
+namespace ROCKSDB_NAMESPACE {
+static void AddProperty(std::unordered_map<std::string, std::string> *props, const std::string& name) {
+ size_t colon = name.find(":");
+ if (colon != std::string::npos && colon > 0 && colon < name.length() - 1) {
+ // If we found a "@:", then this property was a build-time substitution that failed. Skip it
+ size_t at = name.find("@", colon);
+ if (at != colon + 1) {
+ // Everything before the colon is the name, after is the value
+ (*props)[name.substr(0, colon)] = name.substr(colon + 1);
+ }
+ }
+}
+
+static std::unordered_map<std::string, std::string>* LoadPropertiesSet() {
+ auto * properties = new std::unordered_map<std::string, std::string>();
+ AddProperty(properties, rocksdb_build_git_sha);
+ AddProperty(properties, rocksdb_build_git_tag);
+ AddProperty(properties, rocksdb_build_date);
+ return properties;
+}
+
+const std::unordered_map<std::string, std::string>& GetRocksBuildProperties() {
+ static std::unique_ptr<std::unordered_map<std::string, std::string>> props(LoadPropertiesSet());
+ return *props;
+}
+
+std::string GetRocksVersionAsString(bool with_patch) {
+ std::string version = std::to_string(ROCKSDB_MAJOR) + "." + std::to_string(ROCKSDB_MINOR);
+ if (with_patch) {
+ return version + "." + std::to_string(ROCKSDB_PATCH);
+ } else {
+ return version;
+ }
+}
+
+std::string GetRocksBuildInfoAsString(const std::string& program, bool verbose) {
+ std::string info = program + " (RocksDB) " + GetRocksVersionAsString(true);
+ if (verbose) {
+ for (const auto& it : GetRocksBuildProperties()) {
+ info.append("\n ");
+ info.append(it.first);
+ info.append(": ");
+ info.append(it.second);
+ }
+ }
+ return info;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/cast_util.h b/src/rocksdb/util/cast_util.h
new file mode 100644
index 000000000..c91b6ff1e
--- /dev/null
+++ b/src/rocksdb/util/cast_util.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <type_traits>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+// The helper function to assert the move from dynamic_cast<> to
+// static_cast<> is correct. This function is to deal with legacy code.
+// It is not recommended to add new code to issue class casting. The preferred
+// solution is to implement the functionality without a need of casting.
+template <class DestClass, class SrcClass>
+inline DestClass* static_cast_with_check(SrcClass* x) {
+ DestClass* ret = static_cast<DestClass*>(x);
+#ifdef ROCKSDB_USE_RTTI
+ assert(ret == dynamic_cast<DestClass*>(x));
+#endif
+ return ret;
+}
+
+// A wrapper around static_cast for lossless conversion between integral
+// types, including enum types. For example, this can be used for converting
+// between signed/unsigned or enum type and underlying type without fear of
+// stripping away data, now or in the future.
+template <typename To, typename From>
+inline To lossless_cast(From x) {
+ using FromValue = typename std::remove_reference<From>::type;
+ static_assert(
+ std::is_integral<FromValue>::value || std::is_enum<FromValue>::value,
+ "Only works on integral types");
+ static_assert(std::is_integral<To>::value || std::is_enum<To>::value,
+ "Only works on integral types");
+ static_assert(sizeof(To) >= sizeof(FromValue), "Must be lossless");
+ return static_cast<To>(x);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/channel.h b/src/rocksdb/util/channel.h
new file mode 100644
index 000000000..19b956297
--- /dev/null
+++ b/src/rocksdb/util/channel.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <utility>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+template <class T>
+class channel {
+ public:
+ explicit channel() : eof_(false) {}
+
+ channel(const channel&) = delete;
+ void operator=(const channel&) = delete;
+
+ void sendEof() {
+ std::lock_guard<std::mutex> lk(lock_);
+ eof_ = true;
+ cv_.notify_all();
+ }
+
+ bool eof() {
+ std::lock_guard<std::mutex> lk(lock_);
+ return buffer_.empty() && eof_;
+ }
+
+ size_t size() const {
+ std::lock_guard<std::mutex> lk(lock_);
+ return buffer_.size();
+ }
+
+ // writes elem to the queue
+ void write(T&& elem) {
+ std::unique_lock<std::mutex> lk(lock_);
+ buffer_.emplace(std::forward<T>(elem));
+ cv_.notify_one();
+ }
+
+ /// Moves a dequeued element onto elem, blocking until an element
+ /// is available.
+ // returns false if EOF
+ bool read(T& elem) {
+ std::unique_lock<std::mutex> lk(lock_);
+ cv_.wait(lk, [&] { return eof_ || !buffer_.empty(); });
+ if (eof_ && buffer_.empty()) {
+ return false;
+ }
+ elem = std::move(buffer_.front());
+ buffer_.pop();
+ cv_.notify_one();
+ return true;
+ }
+
+ private:
+ std::condition_variable cv_;
+ mutable std::mutex lock_;
+ std::queue<T> buffer_;
+ bool eof_;
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/cleanable.cc b/src/rocksdb/util/cleanable.cc
new file mode 100644
index 000000000..89a7ab9be
--- /dev/null
+++ b/src/rocksdb/util/cleanable.cc
@@ -0,0 +1,181 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/cleanable.h"
+
+#include <atomic>
+#include <cassert>
+#include <utility>
+
+namespace ROCKSDB_NAMESPACE {
+
+Cleanable::Cleanable() {
+ cleanup_.function = nullptr;
+ cleanup_.next = nullptr;
+}
+
+Cleanable::~Cleanable() { DoCleanup(); }
+
+Cleanable::Cleanable(Cleanable&& other) noexcept { *this = std::move(other); }
+
+Cleanable& Cleanable::operator=(Cleanable&& other) noexcept {
+ assert(this != &other); // https://stackoverflow.com/a/9322542/454544
+ cleanup_ = other.cleanup_;
+ other.cleanup_.function = nullptr;
+ other.cleanup_.next = nullptr;
+ return *this;
+}
+
+// If the entire linked list was on heap we could have simply add attach one
+// link list to another. However the head is an embeded object to avoid the cost
+// of creating objects for most of the use cases when the Cleanable has only one
+// Cleanup to do. We could put evernything on heap if benchmarks show no
+// negative impact on performance.
+// Also we need to iterate on the linked list since there is no pointer to the
+// tail. We can add the tail pointer but maintainin it might negatively impact
+// the perforamnce for the common case of one cleanup where tail pointer is not
+// needed. Again benchmarks could clarify that.
+// Even without a tail pointer we could iterate on the list, find the tail, and
+// have only that node updated without the need to insert the Cleanups one by
+// one. This however would be redundant when the source Cleanable has one or a
+// few Cleanups which is the case most of the time.
+// TODO(myabandeh): if the list is too long we should maintain a tail pointer
+// and have the entire list (minus the head that has to be inserted separately)
+// merged with the target linked list at once.
+void Cleanable::DelegateCleanupsTo(Cleanable* other) {
+ assert(other != nullptr);
+ if (cleanup_.function == nullptr) {
+ return;
+ }
+ Cleanup* c = &cleanup_;
+ other->RegisterCleanup(c->function, c->arg1, c->arg2);
+ c = c->next;
+ while (c != nullptr) {
+ Cleanup* next = c->next;
+ other->RegisterCleanup(c);
+ c = next;
+ }
+ cleanup_.function = nullptr;
+ cleanup_.next = nullptr;
+}
+
+void Cleanable::RegisterCleanup(Cleanable::Cleanup* c) {
+ assert(c != nullptr);
+ if (cleanup_.function == nullptr) {
+ cleanup_.function = c->function;
+ cleanup_.arg1 = c->arg1;
+ cleanup_.arg2 = c->arg2;
+ delete c;
+ } else {
+ c->next = cleanup_.next;
+ cleanup_.next = c;
+ }
+}
+
+void Cleanable::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) {
+ assert(func != nullptr);
+ Cleanup* c;
+ if (cleanup_.function == nullptr) {
+ c = &cleanup_;
+ } else {
+ c = new Cleanup;
+ c->next = cleanup_.next;
+ cleanup_.next = c;
+ }
+ c->function = func;
+ c->arg1 = arg1;
+ c->arg2 = arg2;
+}
+
+struct SharedCleanablePtr::Impl : public Cleanable {
+ std::atomic<unsigned> ref_count{1}; // Start with 1 ref
+ void Ref() { ref_count.fetch_add(1, std::memory_order_relaxed); }
+ void Unref() {
+ if (ref_count.fetch_sub(1, std::memory_order_relaxed) == 1) {
+ // Last ref
+ delete this;
+ }
+ }
+ static void UnrefWrapper(void* arg1, void* /*arg2*/) {
+ static_cast<SharedCleanablePtr::Impl*>(arg1)->Unref();
+ }
+};
+
+void SharedCleanablePtr::Reset() {
+ if (ptr_) {
+ ptr_->Unref();
+ ptr_ = nullptr;
+ }
+}
+
+void SharedCleanablePtr::Allocate() {
+ Reset();
+ ptr_ = new Impl();
+}
+
+SharedCleanablePtr::SharedCleanablePtr(const SharedCleanablePtr& from) {
+ *this = from;
+}
+
+SharedCleanablePtr::SharedCleanablePtr(SharedCleanablePtr&& from) noexcept {
+ *this = std::move(from);
+}
+
+SharedCleanablePtr& SharedCleanablePtr::operator=(
+ const SharedCleanablePtr& from) {
+ if (this != &from) {
+ Reset();
+ ptr_ = from.ptr_;
+ if (ptr_) {
+ ptr_->Ref();
+ }
+ }
+ return *this;
+}
+
+SharedCleanablePtr& SharedCleanablePtr::operator=(
+ SharedCleanablePtr&& from) noexcept {
+ assert(this != &from); // https://stackoverflow.com/a/9322542/454544
+ Reset();
+ ptr_ = from.ptr_;
+ from.ptr_ = nullptr;
+ return *this;
+}
+
+SharedCleanablePtr::~SharedCleanablePtr() { Reset(); }
+
+Cleanable& SharedCleanablePtr::operator*() {
+ return *ptr_; // implicit upcast
+}
+
+Cleanable* SharedCleanablePtr::operator->() {
+ return ptr_; // implicit upcast
+}
+
+Cleanable* SharedCleanablePtr::get() {
+ return ptr_; // implicit upcast
+}
+
+void SharedCleanablePtr::RegisterCopyWith(Cleanable* target) {
+ if (ptr_) {
+ // "Virtual" copy of the pointer
+ ptr_->Ref();
+ target->RegisterCleanup(&Impl::UnrefWrapper, ptr_, nullptr);
+ }
+}
+
+void SharedCleanablePtr::MoveAsCleanupTo(Cleanable* target) {
+ if (ptr_) {
+ // "Virtual" move of the pointer
+ target->RegisterCleanup(&Impl::UnrefWrapper, ptr_, nullptr);
+ ptr_ = nullptr;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/coding.cc b/src/rocksdb/util/coding.cc
new file mode 100644
index 000000000..3da8afaa2
--- /dev/null
+++ b/src/rocksdb/util/coding.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/coding.h"
+
+#include <algorithm>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// conversion' conversion from 'type1' to 'type2', possible loss of data
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4244)
+#endif
+char* EncodeVarint32(char* dst, uint32_t v) {
+ // Operate on characters as unsigneds
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ static const int B = 128;
+ if (v < (1 << 7)) {
+ *(ptr++) = v;
+ } else if (v < (1 << 14)) {
+ *(ptr++) = v | B;
+ *(ptr++) = v >> 7;
+ } else if (v < (1 << 21)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = v >> 14;
+ } else if (v < (1 << 28)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = v >> 21;
+ } else {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = (v >> 21) | B;
+ *(ptr++) = v >> 28;
+ }
+ return reinterpret_cast<char*>(ptr);
+}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32_t* value) {
+ uint32_t result = 0;
+ for (uint32_t shift = 0; shift <= 28 && p < limit; shift += 7) {
+ uint32_t byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return nullptr;
+}
+
+const char* GetVarint64Ptr(const char* p, const char* limit, uint64_t* value) {
+ uint64_t result = 0;
+ for (uint32_t shift = 0; shift <= 63 && p < limit; shift += 7) {
+ uint64_t byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return nullptr;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/coding.h b/src/rocksdb/util/coding.h
new file mode 100644
index 000000000..3168fd2fd
--- /dev/null
+++ b/src/rocksdb/util/coding.h
@@ -0,0 +1,389 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Encoding independent of machine byte order:
+// * Fixed-length numbers are encoded with least-significant byte first
+// (little endian, native order on Intel and others)
+// * In addition we support variable length "varint" encoding
+// * Strings are encoded prefixed by their length in varint format
+//
+// Some related functions are provided in coding_lean.h
+
+#pragma once
+#include <algorithm>
+#include <string>
+
+#include "port/port.h"
+#include "rocksdb/slice.h"
+#include "util/coding_lean.h"
+
+// Some processors does not allow unaligned access to memory
+#if defined(__sparc)
+#define PLATFORM_UNALIGNED_ACCESS_NOT_ALLOWED
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+// The maximum length of a varint in bytes for 64-bit.
+const uint32_t kMaxVarint64Length = 10;
+
+// Standard Put... routines append to a string
+extern void PutFixed16(std::string* dst, uint16_t value);
+extern void PutFixed32(std::string* dst, uint32_t value);
+extern void PutFixed64(std::string* dst, uint64_t value);
+extern void PutVarint32(std::string* dst, uint32_t value);
+extern void PutVarint32Varint32(std::string* dst, uint32_t value1,
+ uint32_t value2);
+extern void PutVarint32Varint32Varint32(std::string* dst, uint32_t value1,
+ uint32_t value2, uint32_t value3);
+extern void PutVarint64(std::string* dst, uint64_t value);
+extern void PutVarint64Varint64(std::string* dst, uint64_t value1,
+ uint64_t value2);
+extern void PutVarint32Varint64(std::string* dst, uint32_t value1,
+ uint64_t value2);
+extern void PutVarint32Varint32Varint64(std::string* dst, uint32_t value1,
+ uint32_t value2, uint64_t value3);
+extern void PutLengthPrefixedSlice(std::string* dst, const Slice& value);
+extern void PutLengthPrefixedSliceParts(std::string* dst,
+ const SliceParts& slice_parts);
+extern void PutLengthPrefixedSlicePartsWithPadding(
+ std::string* dst, const SliceParts& slice_parts, size_t pad_sz);
+
+// Standard Get... routines parse a value from the beginning of a Slice
+// and advance the slice past the parsed value.
+extern bool GetFixed64(Slice* input, uint64_t* value);
+extern bool GetFixed32(Slice* input, uint32_t* value);
+extern bool GetFixed16(Slice* input, uint16_t* value);
+extern bool GetVarint32(Slice* input, uint32_t* value);
+extern bool GetVarint64(Slice* input, uint64_t* value);
+extern bool GetVarsignedint64(Slice* input, int64_t* value);
+extern bool GetLengthPrefixedSlice(Slice* input, Slice* result);
+// This function assumes data is well-formed.
+extern Slice GetLengthPrefixedSlice(const char* data);
+
+extern Slice GetSliceUntil(Slice* slice, char delimiter);
+
+// Borrowed from
+// https://github.com/facebook/fbthrift/blob/449a5f77f9f9bae72c9eb5e78093247eef185c04/thrift/lib/cpp/util/VarintUtils-inl.h#L202-L208
+constexpr inline uint64_t i64ToZigzag(const int64_t l) {
+ return (static_cast<uint64_t>(l) << 1) ^ static_cast<uint64_t>(l >> 63);
+}
+inline int64_t zigzagToI64(uint64_t n) {
+ return (n >> 1) ^ -static_cast<int64_t>(n & 1);
+}
+
+// Pointer-based variants of GetVarint... These either store a value
+// in *v and return a pointer just past the parsed value, or return
+// nullptr on error. These routines only look at bytes in the range
+// [p..limit-1]
+extern const char* GetVarint32Ptr(const char* p, const char* limit,
+ uint32_t* v);
+extern const char* GetVarint64Ptr(const char* p, const char* limit,
+ uint64_t* v);
+inline const char* GetVarsignedint64Ptr(const char* p, const char* limit,
+ int64_t* value) {
+ uint64_t u = 0;
+ const char* ret = GetVarint64Ptr(p, limit, &u);
+ *value = zigzagToI64(u);
+ return ret;
+}
+
+// Returns the length of the varint32 or varint64 encoding of "v"
+extern int VarintLength(uint64_t v);
+
+// Lower-level versions of Put... that write directly into a character buffer
+// and return a pointer just past the last byte written.
+// REQUIRES: dst has enough space for the value being written
+extern char* EncodeVarint32(char* dst, uint32_t value);
+extern char* EncodeVarint64(char* dst, uint64_t value);
+
+// Internal routine for use by fallback path of GetVarint32Ptr
+extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32_t* value);
+inline const char* GetVarint32Ptr(const char* p, const char* limit,
+ uint32_t* value) {
+ if (p < limit) {
+ uint32_t result = *(reinterpret_cast<const unsigned char*>(p));
+ if ((result & 128) == 0) {
+ *value = result;
+ return p + 1;
+ }
+ }
+ return GetVarint32PtrFallback(p, limit, value);
+}
+
+// Pull the last 8 bits and cast it to a character
+inline void PutFixed16(std::string* dst, uint16_t value) {
+ if (port::kLittleEndian) {
+ dst->append(const_cast<const char*>(reinterpret_cast<char*>(&value)),
+ sizeof(value));
+ } else {
+ char buf[sizeof(value)];
+ EncodeFixed16(buf, value);
+ dst->append(buf, sizeof(buf));
+ }
+}
+
+inline void PutFixed32(std::string* dst, uint32_t value) {
+ if (port::kLittleEndian) {
+ dst->append(const_cast<const char*>(reinterpret_cast<char*>(&value)),
+ sizeof(value));
+ } else {
+ char buf[sizeof(value)];
+ EncodeFixed32(buf, value);
+ dst->append(buf, sizeof(buf));
+ }
+}
+
+inline void PutFixed64(std::string* dst, uint64_t value) {
+ if (port::kLittleEndian) {
+ dst->append(const_cast<const char*>(reinterpret_cast<char*>(&value)),
+ sizeof(value));
+ } else {
+ char buf[sizeof(value)];
+ EncodeFixed64(buf, value);
+ dst->append(buf, sizeof(buf));
+ }
+}
+
+inline void PutVarint32(std::string* dst, uint32_t v) {
+ char buf[5];
+ char* ptr = EncodeVarint32(buf, v);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarint32Varint32(std::string* dst, uint32_t v1, uint32_t v2) {
+ char buf[10];
+ char* ptr = EncodeVarint32(buf, v1);
+ ptr = EncodeVarint32(ptr, v2);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarint32Varint32Varint32(std::string* dst, uint32_t v1,
+ uint32_t v2, uint32_t v3) {
+ char buf[15];
+ char* ptr = EncodeVarint32(buf, v1);
+ ptr = EncodeVarint32(ptr, v2);
+ ptr = EncodeVarint32(ptr, v3);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline char* EncodeVarint64(char* dst, uint64_t v) {
+ static const unsigned int B = 128;
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ while (v >= B) {
+ *(ptr++) = (v & (B - 1)) | B;
+ v >>= 7;
+ }
+ *(ptr++) = static_cast<unsigned char>(v);
+ return reinterpret_cast<char*>(ptr);
+}
+
+inline void PutVarint64(std::string* dst, uint64_t v) {
+ char buf[kMaxVarint64Length];
+ char* ptr = EncodeVarint64(buf, v);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarsignedint64(std::string* dst, int64_t v) {
+ char buf[kMaxVarint64Length];
+ // Using Zigzag format to convert signed to unsigned
+ char* ptr = EncodeVarint64(buf, i64ToZigzag(v));
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarint64Varint64(std::string* dst, uint64_t v1, uint64_t v2) {
+ char buf[20];
+ char* ptr = EncodeVarint64(buf, v1);
+ ptr = EncodeVarint64(ptr, v2);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarint32Varint64(std::string* dst, uint32_t v1, uint64_t v2) {
+ char buf[15];
+ char* ptr = EncodeVarint32(buf, v1);
+ ptr = EncodeVarint64(ptr, v2);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutVarint32Varint32Varint64(std::string* dst, uint32_t v1,
+ uint32_t v2, uint64_t v3) {
+ char buf[20];
+ char* ptr = EncodeVarint32(buf, v1);
+ ptr = EncodeVarint32(ptr, v2);
+ ptr = EncodeVarint64(ptr, v3);
+ dst->append(buf, static_cast<size_t>(ptr - buf));
+}
+
+inline void PutLengthPrefixedSlice(std::string* dst, const Slice& value) {
+ PutVarint32(dst, static_cast<uint32_t>(value.size()));
+ dst->append(value.data(), value.size());
+}
+
+inline void PutLengthPrefixedSliceParts(std::string* dst, size_t total_bytes,
+ const SliceParts& slice_parts) {
+ for (int i = 0; i < slice_parts.num_parts; ++i) {
+ total_bytes += slice_parts.parts[i].size();
+ }
+ PutVarint32(dst, static_cast<uint32_t>(total_bytes));
+ for (int i = 0; i < slice_parts.num_parts; ++i) {
+ dst->append(slice_parts.parts[i].data(), slice_parts.parts[i].size());
+ }
+}
+
+inline void PutLengthPrefixedSliceParts(std::string* dst,
+ const SliceParts& slice_parts) {
+ PutLengthPrefixedSliceParts(dst, /*total_bytes=*/0, slice_parts);
+}
+
+inline void PutLengthPrefixedSlicePartsWithPadding(
+ std::string* dst, const SliceParts& slice_parts, size_t pad_sz) {
+ PutLengthPrefixedSliceParts(dst, /*total_bytes=*/pad_sz, slice_parts);
+ dst->append(pad_sz, '\0');
+}
+
+inline int VarintLength(uint64_t v) {
+ int len = 1;
+ while (v >= 128) {
+ v >>= 7;
+ len++;
+ }
+ return len;
+}
+
+inline bool GetFixed64(Slice* input, uint64_t* value) {
+ if (input->size() < sizeof(uint64_t)) {
+ return false;
+ }
+ *value = DecodeFixed64(input->data());
+ input->remove_prefix(sizeof(uint64_t));
+ return true;
+}
+
+inline bool GetFixed32(Slice* input, uint32_t* value) {
+ if (input->size() < sizeof(uint32_t)) {
+ return false;
+ }
+ *value = DecodeFixed32(input->data());
+ input->remove_prefix(sizeof(uint32_t));
+ return true;
+}
+
+inline bool GetFixed16(Slice* input, uint16_t* value) {
+ if (input->size() < sizeof(uint16_t)) {
+ return false;
+ }
+ *value = DecodeFixed16(input->data());
+ input->remove_prefix(sizeof(uint16_t));
+ return true;
+}
+
+inline bool GetVarint32(Slice* input, uint32_t* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint32Ptr(p, limit, value);
+ if (q == nullptr) {
+ return false;
+ } else {
+ *input = Slice(q, static_cast<size_t>(limit - q));
+ return true;
+ }
+}
+
+inline bool GetVarint64(Slice* input, uint64_t* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint64Ptr(p, limit, value);
+ if (q == nullptr) {
+ return false;
+ } else {
+ *input = Slice(q, static_cast<size_t>(limit - q));
+ return true;
+ }
+}
+
+inline bool GetVarsignedint64(Slice* input, int64_t* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarsignedint64Ptr(p, limit, value);
+ if (q == nullptr) {
+ return false;
+ } else {
+ *input = Slice(q, static_cast<size_t>(limit - q));
+ return true;
+ }
+}
+
+inline bool GetLengthPrefixedSlice(Slice* input, Slice* result) {
+ uint32_t len = 0;
+ if (GetVarint32(input, &len) && input->size() >= len) {
+ *result = Slice(input->data(), len);
+ input->remove_prefix(len);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+inline Slice GetLengthPrefixedSlice(const char* data) {
+ uint32_t len = 0;
+ // +5: we assume "data" is not corrupted
+ // unsigned char is 7 bits, uint32_t is 32 bits, need 5 unsigned char
+ auto p = GetVarint32Ptr(data, data + 5 /* limit */, &len);
+ return Slice(p, len);
+}
+
+inline Slice GetSliceUntil(Slice* slice, char delimiter) {
+ uint32_t len = 0;
+ for (len = 0; len < slice->size() && slice->data()[len] != delimiter; ++len) {
+ // nothing
+ }
+
+ Slice ret(slice->data(), len);
+ slice->remove_prefix(len + ((len < slice->size()) ? 1 : 0));
+ return ret;
+}
+
+template <class T>
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+inline void
+PutUnaligned(T* memory, const T& value) {
+#if defined(PLATFORM_UNALIGNED_ACCESS_NOT_ALLOWED)
+ char* nonAlignedMemory = reinterpret_cast<char*>(memory);
+ memcpy(nonAlignedMemory, reinterpret_cast<const char*>(&value), sizeof(T));
+#else
+ *memory = value;
+#endif
+}
+
+template <class T>
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+inline void
+GetUnaligned(const T* memory, T* value) {
+#if defined(PLATFORM_UNALIGNED_ACCESS_NOT_ALLOWED)
+ char* nonAlignedMemory = reinterpret_cast<char*>(value);
+ memcpy(nonAlignedMemory, reinterpret_cast<const char*>(memory), sizeof(T));
+#else
+ *value = *memory;
+#endif
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/coding_lean.h b/src/rocksdb/util/coding_lean.h
new file mode 100644
index 000000000..6966f7a66
--- /dev/null
+++ b/src/rocksdb/util/coding_lean.h
@@ -0,0 +1,101 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// Encoding independent of machine byte order:
+// * Fixed-length numbers are encoded with least-significant byte first
+// (little endian, native order on Intel and others)
+//
+// More functions in coding.h
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+
+#include "port/port.h" // for port::kLittleEndian
+
+namespace ROCKSDB_NAMESPACE {
+
+// Lower-level versions of Put... that write directly into a character buffer
+// REQUIRES: dst has enough space for the value being written
+// -- Implementation of the functions declared above
+inline void EncodeFixed16(char* buf, uint16_t value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ }
+}
+
+inline void EncodeFixed32(char* buf, uint32_t value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ }
+}
+
+inline void EncodeFixed64(char* buf, uint64_t value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ buf[4] = (value >> 32) & 0xff;
+ buf[5] = (value >> 40) & 0xff;
+ buf[6] = (value >> 48) & 0xff;
+ buf[7] = (value >> 56) & 0xff;
+ }
+}
+
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+
+inline uint16_t DecodeFixed16(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint16_t result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ return ((static_cast<uint16_t>(static_cast<unsigned char>(ptr[0]))) |
+ (static_cast<uint16_t>(static_cast<unsigned char>(ptr[1])) << 8));
+ }
+}
+
+inline uint32_t DecodeFixed32(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint32_t result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ return ((static_cast<uint32_t>(static_cast<unsigned char>(ptr[0]))) |
+ (static_cast<uint32_t>(static_cast<unsigned char>(ptr[1])) << 8) |
+ (static_cast<uint32_t>(static_cast<unsigned char>(ptr[2])) << 16) |
+ (static_cast<uint32_t>(static_cast<unsigned char>(ptr[3])) << 24));
+ }
+}
+
+inline uint64_t DecodeFixed64(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint64_t result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ uint64_t lo = DecodeFixed32(ptr);
+ uint64_t hi = DecodeFixed32(ptr + 4);
+ return (hi << 32) | lo;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/coding_test.cc b/src/rocksdb/util/coding_test.cc
new file mode 100644
index 000000000..79dd7b82e
--- /dev/null
+++ b/src/rocksdb/util/coding_test.cc
@@ -0,0 +1,217 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/coding.h"
+
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class Coding {};
+TEST(Coding, Fixed16) {
+ std::string s;
+ for (uint16_t v = 0; v < 0xFFFF; v++) {
+ PutFixed16(&s, v);
+ }
+
+ const char* p = s.data();
+ for (uint16_t v = 0; v < 0xFFFF; v++) {
+ uint16_t actual = DecodeFixed16(p);
+ ASSERT_EQ(v, actual);
+ p += sizeof(uint16_t);
+ }
+}
+
+TEST(Coding, Fixed32) {
+ std::string s;
+ for (uint32_t v = 0; v < 100000; v++) {
+ PutFixed32(&s, v);
+ }
+
+ const char* p = s.data();
+ for (uint32_t v = 0; v < 100000; v++) {
+ uint32_t actual = DecodeFixed32(p);
+ ASSERT_EQ(v, actual);
+ p += sizeof(uint32_t);
+ }
+}
+
+TEST(Coding, Fixed64) {
+ std::string s;
+ for (int power = 0; power <= 63; power++) {
+ uint64_t v = static_cast<uint64_t>(1) << power;
+ PutFixed64(&s, v - 1);
+ PutFixed64(&s, v + 0);
+ PutFixed64(&s, v + 1);
+ }
+
+ const char* p = s.data();
+ for (int power = 0; power <= 63; power++) {
+ uint64_t v = static_cast<uint64_t>(1) << power;
+ uint64_t actual = 0;
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v - 1, actual);
+ p += sizeof(uint64_t);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 0, actual);
+ p += sizeof(uint64_t);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 1, actual);
+ p += sizeof(uint64_t);
+ }
+}
+
+// Test that encoding routines generate little-endian encodings
+TEST(Coding, EncodingOutput) {
+ std::string dst;
+ PutFixed32(&dst, 0x04030201);
+ ASSERT_EQ(4U, dst.size());
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+
+ dst.clear();
+ PutFixed64(&dst, 0x0807060504030201ull);
+ ASSERT_EQ(8U, dst.size());
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+ ASSERT_EQ(0x05, static_cast<int>(dst[4]));
+ ASSERT_EQ(0x06, static_cast<int>(dst[5]));
+ ASSERT_EQ(0x07, static_cast<int>(dst[6]));
+ ASSERT_EQ(0x08, static_cast<int>(dst[7]));
+}
+
+TEST(Coding, Varint32) {
+ std::string s;
+ for (uint32_t i = 0; i < (32 * 32); i++) {
+ uint32_t v = (i / 32) << (i % 32);
+ PutVarint32(&s, v);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (uint32_t i = 0; i < (32 * 32); i++) {
+ uint32_t expected = (i / 32) << (i % 32);
+ uint32_t actual = 0;
+ const char* start = p;
+ p = GetVarint32Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != nullptr);
+ ASSERT_EQ(expected, actual);
+ ASSERT_EQ(VarintLength(actual), p - start);
+ }
+ ASSERT_EQ(p, s.data() + s.size());
+}
+
+TEST(Coding, Varint64) {
+ // Construct the list of values to check
+ std::vector<uint64_t> values;
+ // Some special values
+ values.push_back(0);
+ values.push_back(100);
+ values.push_back(~static_cast<uint64_t>(0));
+ values.push_back(~static_cast<uint64_t>(0) - 1);
+ for (uint32_t k = 0; k < 64; k++) {
+ // Test values near powers of two
+ const uint64_t power = 1ull << k;
+ values.push_back(power);
+ values.push_back(power - 1);
+ values.push_back(power + 1);
+ };
+
+ std::string s;
+ for (unsigned int i = 0; i < values.size(); i++) {
+ PutVarint64(&s, values[i]);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (unsigned int i = 0; i < values.size(); i++) {
+ ASSERT_TRUE(p < limit);
+ uint64_t actual = 0;
+ const char* start = p;
+ p = GetVarint64Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != nullptr);
+ ASSERT_EQ(values[i], actual);
+ ASSERT_EQ(VarintLength(actual), p - start);
+ }
+ ASSERT_EQ(p, limit);
+}
+
+TEST(Coding, Varint32Overflow) {
+ uint32_t result;
+ std::string input("\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(),
+ &result) == nullptr);
+}
+
+TEST(Coding, Varint32Truncation) {
+ uint32_t large_value = (1u << 31) + 100;
+ std::string s;
+ PutVarint32(&s, large_value);
+ uint32_t result;
+ for (unsigned int len = 0; len + 1 < s.size(); len++) {
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == nullptr);
+ }
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) !=
+ nullptr);
+ ASSERT_EQ(large_value, result);
+}
+
+TEST(Coding, Varint64Overflow) {
+ uint64_t result;
+ std::string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(),
+ &result) == nullptr);
+}
+
+TEST(Coding, Varint64Truncation) {
+ uint64_t large_value = (1ull << 63) + 100ull;
+ std::string s;
+ PutVarint64(&s, large_value);
+ uint64_t result;
+ for (unsigned int len = 0; len + 1 < s.size(); len++) {
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == nullptr);
+ }
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) !=
+ nullptr);
+ ASSERT_EQ(large_value, result);
+}
+
+TEST(Coding, Strings) {
+ std::string s;
+ PutLengthPrefixedSlice(&s, Slice(""));
+ PutLengthPrefixedSlice(&s, Slice("foo"));
+ PutLengthPrefixedSlice(&s, Slice("bar"));
+ PutLengthPrefixedSlice(&s, Slice(std::string(200, 'x')));
+
+ Slice input(s);
+ Slice v;
+ ASSERT_TRUE(GetLengthPrefixedSlice(&input, &v));
+ ASSERT_EQ("", v.ToString());
+ ASSERT_TRUE(GetLengthPrefixedSlice(&input, &v));
+ ASSERT_EQ("foo", v.ToString());
+ ASSERT_TRUE(GetLengthPrefixedSlice(&input, &v));
+ ASSERT_EQ("bar", v.ToString());
+ ASSERT_TRUE(GetLengthPrefixedSlice(&input, &v));
+ ASSERT_EQ(std::string(200, 'x'), v.ToString());
+ ASSERT_EQ("", input.ToString());
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/compaction_job_stats_impl.cc b/src/rocksdb/util/compaction_job_stats_impl.cc
new file mode 100644
index 000000000..cfab2a4fe
--- /dev/null
+++ b/src/rocksdb/util/compaction_job_stats_impl.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/compaction_job_stats.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifndef ROCKSDB_LITE
+
+void CompactionJobStats::Reset() {
+ elapsed_micros = 0;
+ cpu_micros = 0;
+
+ num_input_records = 0;
+ num_blobs_read = 0;
+ num_input_files = 0;
+ num_input_files_at_output_level = 0;
+
+ num_output_records = 0;
+ num_output_files = 0;
+ num_output_files_blob = 0;
+
+ is_full_compaction = false;
+ is_manual_compaction = false;
+
+ total_input_bytes = 0;
+ total_blob_bytes_read = 0;
+ total_output_bytes = 0;
+ total_output_bytes_blob = 0;
+
+ num_records_replaced = 0;
+
+ total_input_raw_key_bytes = 0;
+ total_input_raw_value_bytes = 0;
+
+ num_input_deletion_records = 0;
+ num_expired_deletion_records = 0;
+
+ num_corrupt_keys = 0;
+
+ file_write_nanos = 0;
+ file_range_sync_nanos = 0;
+ file_fsync_nanos = 0;
+ file_prepare_write_nanos = 0;
+
+ smallest_output_key_prefix.clear();
+ largest_output_key_prefix.clear();
+
+ num_single_del_fallthru = 0;
+ num_single_del_mismatch = 0;
+}
+
+void CompactionJobStats::Add(const CompactionJobStats& stats) {
+ elapsed_micros += stats.elapsed_micros;
+ cpu_micros += stats.cpu_micros;
+
+ num_input_records += stats.num_input_records;
+ num_blobs_read += stats.num_blobs_read;
+ num_input_files += stats.num_input_files;
+ num_input_files_at_output_level += stats.num_input_files_at_output_level;
+
+ num_output_records += stats.num_output_records;
+ num_output_files += stats.num_output_files;
+ num_output_files_blob += stats.num_output_files_blob;
+
+ total_input_bytes += stats.total_input_bytes;
+ total_blob_bytes_read += stats.total_blob_bytes_read;
+ total_output_bytes += stats.total_output_bytes;
+ total_output_bytes_blob += stats.total_output_bytes_blob;
+
+ num_records_replaced += stats.num_records_replaced;
+
+ total_input_raw_key_bytes += stats.total_input_raw_key_bytes;
+ total_input_raw_value_bytes += stats.total_input_raw_value_bytes;
+
+ num_input_deletion_records += stats.num_input_deletion_records;
+ num_expired_deletion_records += stats.num_expired_deletion_records;
+
+ num_corrupt_keys += stats.num_corrupt_keys;
+
+ file_write_nanos += stats.file_write_nanos;
+ file_range_sync_nanos += stats.file_range_sync_nanos;
+ file_fsync_nanos += stats.file_fsync_nanos;
+ file_prepare_write_nanos += stats.file_prepare_write_nanos;
+
+ num_single_del_fallthru += stats.num_single_del_fallthru;
+ num_single_del_mismatch += stats.num_single_del_mismatch;
+}
+
+#else
+
+void CompactionJobStats::Reset() {}
+
+void CompactionJobStats::Add(const CompactionJobStats& /*stats*/) {}
+
+#endif // !ROCKSDB_LITE
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/comparator.cc b/src/rocksdb/util/comparator.cc
new file mode 100644
index 000000000..f85ed69ee
--- /dev/null
+++ b/src/rocksdb/util/comparator.cc
@@ -0,0 +1,391 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/comparator.h"
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <sstream>
+
+#include "db/dbformat.h"
+#include "port/lang.h"
+#include "port/port.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/customizable_util.h"
+#include "rocksdb/utilities/object_registry.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+class BytewiseComparatorImpl : public Comparator {
+ public:
+ BytewiseComparatorImpl() {}
+ static const char* kClassName() { return "leveldb.BytewiseComparator"; }
+ const char* Name() const override { return kClassName(); }
+
+ int Compare(const Slice& a, const Slice& b) const override {
+ return a.compare(b);
+ }
+
+ bool Equal(const Slice& a, const Slice& b) const override { return a == b; }
+
+ void FindShortestSeparator(std::string* start,
+ const Slice& limit) const override {
+ // Find length of common prefix
+ size_t min_length = std::min(start->size(), limit.size());
+ size_t diff_index = 0;
+ while ((diff_index < min_length) &&
+ ((*start)[diff_index] == limit[diff_index])) {
+ diff_index++;
+ }
+
+ if (diff_index >= min_length) {
+ // Do not shorten if one string is a prefix of the other
+ } else {
+ uint8_t start_byte = static_cast<uint8_t>((*start)[diff_index]);
+ uint8_t limit_byte = static_cast<uint8_t>(limit[diff_index]);
+ if (start_byte >= limit_byte) {
+ // Cannot shorten since limit is smaller than start or start is
+ // already the shortest possible.
+ return;
+ }
+ assert(start_byte < limit_byte);
+
+ if (diff_index < limit.size() - 1 || start_byte + 1 < limit_byte) {
+ (*start)[diff_index]++;
+ start->resize(diff_index + 1);
+ } else {
+ // v
+ // A A 1 A A A
+ // A A 2
+ //
+ // Incrementing the current byte will make start bigger than limit, we
+ // will skip this byte, and find the first non 0xFF byte in start and
+ // increment it.
+ diff_index++;
+
+ while (diff_index < start->size()) {
+ // Keep moving until we find the first non 0xFF byte to
+ // increment it
+ if (static_cast<uint8_t>((*start)[diff_index]) <
+ static_cast<uint8_t>(0xff)) {
+ (*start)[diff_index]++;
+ start->resize(diff_index + 1);
+ break;
+ }
+ diff_index++;
+ }
+ }
+ assert(Compare(*start, limit) < 0);
+ }
+ }
+
+ void FindShortSuccessor(std::string* key) const override {
+ // Find first character that can be incremented
+ size_t n = key->size();
+ for (size_t i = 0; i < n; i++) {
+ const uint8_t byte = (*key)[i];
+ if (byte != static_cast<uint8_t>(0xff)) {
+ (*key)[i] = byte + 1;
+ key->resize(i + 1);
+ return;
+ }
+ }
+ // *key is a run of 0xffs. Leave it alone.
+ }
+
+ bool IsSameLengthImmediateSuccessor(const Slice& s,
+ const Slice& t) const override {
+ if (s.size() != t.size() || s.size() == 0) {
+ return false;
+ }
+ size_t diff_ind = s.difference_offset(t);
+ // same slice
+ if (diff_ind >= s.size()) return false;
+ uint8_t byte_s = static_cast<uint8_t>(s[diff_ind]);
+ uint8_t byte_t = static_cast<uint8_t>(t[diff_ind]);
+ // first different byte must be consecutive, and remaining bytes must be
+ // 0xff for s and 0x00 for t
+ if (byte_s != uint8_t{0xff} && byte_s + 1 == byte_t) {
+ for (size_t i = diff_ind + 1; i < s.size(); ++i) {
+ byte_s = static_cast<uint8_t>(s[i]);
+ byte_t = static_cast<uint8_t>(t[i]);
+ if (byte_s != uint8_t{0xff} || byte_t != uint8_t{0x00}) {
+ return false;
+ }
+ }
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ bool CanKeysWithDifferentByteContentsBeEqual() const override {
+ return false;
+ }
+
+ using Comparator::CompareWithoutTimestamp;
+ int CompareWithoutTimestamp(const Slice& a, bool /*a_has_ts*/, const Slice& b,
+ bool /*b_has_ts*/) const override {
+ return a.compare(b);
+ }
+
+ bool EqualWithoutTimestamp(const Slice& a, const Slice& b) const override {
+ return a == b;
+ }
+};
+
+class ReverseBytewiseComparatorImpl : public BytewiseComparatorImpl {
+ public:
+ ReverseBytewiseComparatorImpl() {}
+
+ static const char* kClassName() {
+ return "rocksdb.ReverseBytewiseComparator";
+ }
+ const char* Name() const override { return kClassName(); }
+
+ int Compare(const Slice& a, const Slice& b) const override {
+ return -a.compare(b);
+ }
+
+ void FindShortestSeparator(std::string* start,
+ const Slice& limit) const override {
+ // Find length of common prefix
+ size_t min_length = std::min(start->size(), limit.size());
+ size_t diff_index = 0;
+ while ((diff_index < min_length) &&
+ ((*start)[diff_index] == limit[diff_index])) {
+ diff_index++;
+ }
+
+ assert(diff_index <= min_length);
+ if (diff_index == min_length) {
+ // Do not shorten if one string is a prefix of the other
+ //
+ // We could handle cases like:
+ // V
+ // A A 2 X Y
+ // A A 2
+ // in a similar way as BytewiseComparator::FindShortestSeparator().
+ // We keep it simple by not implementing it. We can come back to it
+ // later when needed.
+ } else {
+ uint8_t start_byte = static_cast<uint8_t>((*start)[diff_index]);
+ uint8_t limit_byte = static_cast<uint8_t>(limit[diff_index]);
+ if (start_byte > limit_byte && diff_index < start->size() - 1) {
+ // Case like
+ // V
+ // A A 3 A A
+ // A A 1 B B
+ //
+ // or
+ // v
+ // A A 2 A A
+ // A A 1 B B
+ // In this case "AA2" will be good.
+#ifndef NDEBUG
+ std::string old_start = *start;
+#endif
+ start->resize(diff_index + 1);
+#ifndef NDEBUG
+ assert(old_start >= *start);
+#endif
+ assert(Slice(*start).compare(limit) > 0);
+ }
+ }
+ }
+
+ void FindShortSuccessor(std::string* /*key*/) const override {
+ // Don't do anything for simplicity.
+ }
+
+ bool IsSameLengthImmediateSuccessor(const Slice& s,
+ const Slice& t) const override {
+ // Always returning false to prevent surfacing design flaws in
+ // auto_prefix_mode
+ (void)s, (void)t;
+ return false;
+ // "Correct" implementation:
+ // return BytewiseComparatorImpl::IsSameLengthImmediateSuccessor(t, s);
+ }
+
+ bool CanKeysWithDifferentByteContentsBeEqual() const override {
+ return false;
+ }
+
+ using Comparator::CompareWithoutTimestamp;
+ int CompareWithoutTimestamp(const Slice& a, bool /*a_has_ts*/, const Slice& b,
+ bool /*b_has_ts*/) const override {
+ return -a.compare(b);
+ }
+};
+
+// EXPERIMENTAL
+// Comparator with 64-bit integer timestamp.
+// We did not performance test this yet.
+template <typename TComparator>
+class ComparatorWithU64TsImpl : public Comparator {
+ static_assert(std::is_base_of<Comparator, TComparator>::value,
+ "template type must be a inherited type of comparator");
+
+ public:
+ explicit ComparatorWithU64TsImpl() : Comparator(/*ts_sz=*/sizeof(uint64_t)) {
+ assert(cmp_without_ts_.timestamp_size() == 0);
+ }
+
+ static const char* kClassName() {
+ static std::string class_name = kClassNameInternal();
+ return class_name.c_str();
+ }
+
+ const char* Name() const override { return kClassName(); }
+
+ void FindShortSuccessor(std::string*) const override {}
+ void FindShortestSeparator(std::string*, const Slice&) const override {}
+ int Compare(const Slice& a, const Slice& b) const override {
+ int ret = CompareWithoutTimestamp(a, b);
+ size_t ts_sz = timestamp_size();
+ if (ret != 0) {
+ return ret;
+ }
+ // Compare timestamp.
+ // For the same user key with different timestamps, larger (newer) timestamp
+ // comes first.
+ return -CompareTimestamp(ExtractTimestampFromUserKey(a, ts_sz),
+ ExtractTimestampFromUserKey(b, ts_sz));
+ }
+ using Comparator::CompareWithoutTimestamp;
+ int CompareWithoutTimestamp(const Slice& a, bool a_has_ts, const Slice& b,
+ bool b_has_ts) const override {
+ const size_t ts_sz = timestamp_size();
+ assert(!a_has_ts || a.size() >= ts_sz);
+ assert(!b_has_ts || b.size() >= ts_sz);
+ Slice lhs = a_has_ts ? StripTimestampFromUserKey(a, ts_sz) : a;
+ Slice rhs = b_has_ts ? StripTimestampFromUserKey(b, ts_sz) : b;
+ return cmp_without_ts_.Compare(lhs, rhs);
+ }
+ int CompareTimestamp(const Slice& ts1, const Slice& ts2) const override {
+ assert(ts1.size() == sizeof(uint64_t));
+ assert(ts2.size() == sizeof(uint64_t));
+ uint64_t lhs = DecodeFixed64(ts1.data());
+ uint64_t rhs = DecodeFixed64(ts2.data());
+ if (lhs < rhs) {
+ return -1;
+ } else if (lhs > rhs) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ private:
+ static std::string kClassNameInternal() {
+ std::stringstream ss;
+ ss << TComparator::kClassName() << ".u64ts";
+ return ss.str();
+ }
+
+ TComparator cmp_without_ts_;
+};
+
+} // namespace
+
+const Comparator* BytewiseComparator() {
+ STATIC_AVOID_DESTRUCTION(BytewiseComparatorImpl, bytewise);
+ return &bytewise;
+}
+
+const Comparator* ReverseBytewiseComparator() {
+ STATIC_AVOID_DESTRUCTION(ReverseBytewiseComparatorImpl, rbytewise);
+ return &rbytewise;
+}
+
+const Comparator* BytewiseComparatorWithU64Ts() {
+ STATIC_AVOID_DESTRUCTION(ComparatorWithU64TsImpl<BytewiseComparatorImpl>,
+ comp_with_u64_ts);
+ return &comp_with_u64_ts;
+}
+
+#ifndef ROCKSDB_LITE
+static int RegisterBuiltinComparators(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<const Comparator>(
+ BytewiseComparatorImpl::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<const Comparator>* /*guard */,
+ std::string* /* errmsg */) { return BytewiseComparator(); });
+ library.AddFactory<const Comparator>(
+ ReverseBytewiseComparatorImpl::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<const Comparator>* /*guard */,
+ std::string* /* errmsg */) { return ReverseBytewiseComparator(); });
+ library.AddFactory<const Comparator>(
+ ComparatorWithU64TsImpl<BytewiseComparatorImpl>::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<const Comparator>* /*guard */,
+ std::string* /* errmsg */) { return BytewiseComparatorWithU64Ts(); });
+ return 3;
+}
+#endif // ROCKSDB_LITE
+
+Status Comparator::CreateFromString(const ConfigOptions& config_options,
+ const std::string& value,
+ const Comparator** result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterBuiltinComparators(*(ObjectLibrary::Default().get()), "");
+ });
+#endif // ROCKSDB_LITE
+ std::string id;
+ std::unordered_map<std::string, std::string> opt_map;
+ Status status = Customizable::GetOptionsMap(config_options, *result, value,
+ &id, &opt_map);
+ if (!status.ok()) { // GetOptionsMap failed
+ return status;
+ }
+ if (id == BytewiseComparatorImpl::kClassName()) {
+ *result = BytewiseComparator();
+ } else if (id == ReverseBytewiseComparatorImpl::kClassName()) {
+ *result = ReverseBytewiseComparator();
+ } else if (id ==
+ ComparatorWithU64TsImpl<BytewiseComparatorImpl>::kClassName()) {
+ *result = BytewiseComparatorWithU64Ts();
+ } else if (value.empty()) {
+ // No Id and no options. Clear the object
+ *result = nullptr;
+ return Status::OK();
+ } else if (id.empty()) { // We have no Id but have options. Not good
+ return Status::NotSupported("Cannot reset object ", id);
+ } else {
+#ifndef ROCKSDB_LITE
+ status = config_options.registry->NewStaticObject(id, result);
+#else
+ status = Status::NotSupported("Cannot load object in LITE mode ", id);
+#endif // ROCKSDB_LITE
+ if (!status.ok()) {
+ if (config_options.ignore_unsupported_options &&
+ status.IsNotSupported()) {
+ return Status::OK();
+ } else {
+ return status;
+ }
+ } else {
+ Comparator* comparator = const_cast<Comparator*>(*result);
+ status =
+ Customizable::ConfigureNewObject(config_options, comparator, opt_map);
+ }
+ }
+ return status;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/compression.cc b/src/rocksdb/util/compression.cc
new file mode 100644
index 000000000..8e2f01b12
--- /dev/null
+++ b/src/rocksdb/util/compression.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2022-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/compression.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+StreamingCompress* StreamingCompress::Create(CompressionType compression_type,
+ const CompressionOptions& opts,
+ uint32_t compress_format_version,
+ size_t max_output_len) {
+ switch (compression_type) {
+ case kZSTD: {
+ if (!ZSTD_Streaming_Supported()) {
+ return nullptr;
+ }
+ return new ZSTDStreamingCompress(opts, compress_format_version,
+ max_output_len);
+ }
+ default:
+ return nullptr;
+ }
+}
+
+StreamingUncompress* StreamingUncompress::Create(
+ CompressionType compression_type, uint32_t compress_format_version,
+ size_t max_output_len) {
+ switch (compression_type) {
+ case kZSTD: {
+ if (!ZSTD_Streaming_Supported()) {
+ return nullptr;
+ }
+ return new ZSTDStreamingUncompress(compress_format_version,
+ max_output_len);
+ }
+ default:
+ return nullptr;
+ }
+}
+
+int ZSTDStreamingCompress::Compress(const char* input, size_t input_size,
+ char* output, size_t* output_pos) {
+ assert(input != nullptr && output != nullptr && output_pos != nullptr);
+ *output_pos = 0;
+ // Don't need to compress an empty input
+ if (input_size == 0) {
+ return 0;
+ }
+#ifndef ZSTD_STREAMING
+ (void)input;
+ (void)input_size;
+ (void)output;
+ return -1;
+#else
+ if (input_buffer_.src == nullptr || input_buffer_.src != input) {
+ // New input
+ // Catch errors where the previous input was not fully decompressed.
+ assert(input_buffer_.pos == input_buffer_.size);
+ input_buffer_ = {input, input_size, /*pos=*/0};
+ } else if (input_buffer_.src == input) {
+ // Same input, not fully compressed.
+ }
+ ZSTD_outBuffer output_buffer = {output, max_output_len_, /*pos=*/0};
+ const size_t remaining =
+ ZSTD_compressStream2(cctx_, &output_buffer, &input_buffer_, ZSTD_e_end);
+ if (ZSTD_isError(remaining)) {
+ // Failure
+ Reset();
+ return -1;
+ }
+ // Success
+ *output_pos = output_buffer.pos;
+ return (int)remaining;
+#endif
+}
+
+void ZSTDStreamingCompress::Reset() {
+#ifdef ZSTD_STREAMING
+ ZSTD_CCtx_reset(cctx_, ZSTD_ResetDirective::ZSTD_reset_session_only);
+ input_buffer_ = {/*src=*/nullptr, /*size=*/0, /*pos=*/0};
+#endif
+}
+
+int ZSTDStreamingUncompress::Uncompress(const char* input, size_t input_size,
+ char* output, size_t* output_pos) {
+ assert(input != nullptr && output != nullptr && output_pos != nullptr);
+ *output_pos = 0;
+ // Don't need to uncompress an empty input
+ if (input_size == 0) {
+ return 0;
+ }
+#ifdef ZSTD_STREAMING
+ if (input_buffer_.src != input) {
+ // New input
+ input_buffer_ = {input, input_size, /*pos=*/0};
+ }
+ ZSTD_outBuffer output_buffer = {output, max_output_len_, /*pos=*/0};
+ size_t ret = ZSTD_decompressStream(dctx_, &output_buffer, &input_buffer_);
+ if (ZSTD_isError(ret)) {
+ Reset();
+ return -1;
+ }
+ *output_pos = output_buffer.pos;
+ return (int)(input_buffer_.size - input_buffer_.pos);
+#else
+ (void)input;
+ (void)input_size;
+ (void)output;
+ return -1;
+#endif
+}
+
+void ZSTDStreamingUncompress::Reset() {
+#ifdef ZSTD_STREAMING
+ ZSTD_DCtx_reset(dctx_, ZSTD_ResetDirective::ZSTD_reset_session_only);
+ input_buffer_ = {/*src=*/nullptr, /*size=*/0, /*pos=*/0};
+#endif
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/compression.h b/src/rocksdb/util/compression.h
new file mode 100644
index 000000000..0d4febcfb
--- /dev/null
+++ b/src/rocksdb/util/compression.h
@@ -0,0 +1,1786 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#ifdef ROCKSDB_MALLOC_USABLE_SIZE
+#ifdef OS_FREEBSD
+#include <malloc_np.h>
+#else // OS_FREEBSD
+#include <malloc.h>
+#endif // OS_FREEBSD
+#endif // ROCKSDB_MALLOC_USABLE_SIZE
+#include <string>
+
+#include "memory/memory_allocator.h"
+#include "rocksdb/options.h"
+#include "rocksdb/table.h"
+#include "test_util/sync_point.h"
+#include "util/coding.h"
+#include "util/compression_context_cache.h"
+#include "util/string_util.h"
+
+#ifdef SNAPPY
+#include <snappy.h>
+#endif
+
+#ifdef ZLIB
+#include <zlib.h>
+#endif
+
+#ifdef BZIP2
+#include <bzlib.h>
+#endif
+
+#if defined(LZ4)
+#include <lz4.h>
+#include <lz4hc.h>
+#endif
+
+#if defined(ZSTD)
+#include <zstd.h>
+// v1.1.3+
+#if ZSTD_VERSION_NUMBER >= 10103
+#include <zdict.h>
+#endif // ZSTD_VERSION_NUMBER >= 10103
+// v1.4.0+
+#if ZSTD_VERSION_NUMBER >= 10400
+#define ZSTD_STREAMING
+#endif // ZSTD_VERSION_NUMBER >= 10400
+namespace ROCKSDB_NAMESPACE {
+// Need this for the context allocation override
+// On windows we need to do this explicitly
+#if (ZSTD_VERSION_NUMBER >= 500)
+#if defined(ROCKSDB_JEMALLOC) && defined(OS_WIN) && \
+ defined(ZSTD_STATIC_LINKING_ONLY)
+#define ROCKSDB_ZSTD_CUSTOM_MEM
+namespace port {
+ZSTD_customMem GetJeZstdAllocationOverrides();
+} // namespace port
+#endif // defined(ROCKSDB_JEMALLOC) && defined(OS_WIN) &&
+ // defined(ZSTD_STATIC_LINKING_ONLY)
+
+// We require `ZSTD_sizeof_DDict` and `ZSTD_createDDict_byReference` to use
+// `ZSTD_DDict`. The former was introduced in v1.0.0 and the latter was
+// introduced in v1.1.3. But an important bug fix for `ZSTD_sizeof_DDict` came
+// in v1.1.4, so that is the version we require. As of today's latest version
+// (v1.3.8), they are both still in the experimental API, which means they are
+// only exported when the compiler flag `ZSTD_STATIC_LINKING_ONLY` is set.
+#if defined(ZSTD_STATIC_LINKING_ONLY) && ZSTD_VERSION_NUMBER >= 10104
+#define ROCKSDB_ZSTD_DDICT
+#endif // defined(ZSTD_STATIC_LINKING_ONLY) && ZSTD_VERSION_NUMBER >= 10104
+
+// Cached data represents a portion that can be re-used
+// If, in the future we have more than one native context to
+// cache we can arrange this as a tuple
+class ZSTDUncompressCachedData {
+ public:
+ using ZSTDNativeContext = ZSTD_DCtx*;
+ ZSTDUncompressCachedData() {}
+ // Init from cache
+ ZSTDUncompressCachedData(const ZSTDUncompressCachedData& o) = delete;
+ ZSTDUncompressCachedData& operator=(const ZSTDUncompressCachedData&) = delete;
+ ZSTDUncompressCachedData(ZSTDUncompressCachedData&& o) noexcept
+ : ZSTDUncompressCachedData() {
+ *this = std::move(o);
+ }
+ ZSTDUncompressCachedData& operator=(ZSTDUncompressCachedData&& o) noexcept {
+ assert(zstd_ctx_ == nullptr);
+ std::swap(zstd_ctx_, o.zstd_ctx_);
+ std::swap(cache_idx_, o.cache_idx_);
+ return *this;
+ }
+ ZSTDNativeContext Get() const { return zstd_ctx_; }
+ int64_t GetCacheIndex() const { return cache_idx_; }
+ void CreateIfNeeded() {
+ if (zstd_ctx_ == nullptr) {
+#ifdef ROCKSDB_ZSTD_CUSTOM_MEM
+ zstd_ctx_ =
+ ZSTD_createDCtx_advanced(port::GetJeZstdAllocationOverrides());
+#else // ROCKSDB_ZSTD_CUSTOM_MEM
+ zstd_ctx_ = ZSTD_createDCtx();
+#endif // ROCKSDB_ZSTD_CUSTOM_MEM
+ cache_idx_ = -1;
+ }
+ }
+ void InitFromCache(const ZSTDUncompressCachedData& o, int64_t idx) {
+ zstd_ctx_ = o.zstd_ctx_;
+ cache_idx_ = idx;
+ }
+ ~ZSTDUncompressCachedData() {
+ if (zstd_ctx_ != nullptr && cache_idx_ == -1) {
+ ZSTD_freeDCtx(zstd_ctx_);
+ }
+ }
+
+ private:
+ ZSTDNativeContext zstd_ctx_ = nullptr;
+ int64_t cache_idx_ = -1; // -1 means this instance owns the context
+};
+#endif // (ZSTD_VERSION_NUMBER >= 500)
+} // namespace ROCKSDB_NAMESPACE
+#endif // ZSTD
+
+#if !(defined ZSTD) || !(ZSTD_VERSION_NUMBER >= 500)
+namespace ROCKSDB_NAMESPACE {
+class ZSTDUncompressCachedData {
+ void* padding; // unused
+ public:
+ using ZSTDNativeContext = void*;
+ ZSTDUncompressCachedData() {}
+ ZSTDUncompressCachedData(const ZSTDUncompressCachedData&) {}
+ ZSTDUncompressCachedData& operator=(const ZSTDUncompressCachedData&) = delete;
+ ZSTDUncompressCachedData(ZSTDUncompressCachedData&&) noexcept = default;
+ ZSTDUncompressCachedData& operator=(ZSTDUncompressCachedData&&) noexcept =
+ default;
+ ZSTDNativeContext Get() const { return nullptr; }
+ int64_t GetCacheIndex() const { return -1; }
+ void CreateIfNeeded() {}
+ void InitFromCache(const ZSTDUncompressCachedData&, int64_t) {}
+
+ private:
+ void ignore_padding__() { padding = nullptr; }
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif
+
+#if defined(XPRESS)
+#include "port/xpress.h"
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+// Holds dictionary and related data, like ZSTD's digested compression
+// dictionary.
+struct CompressionDict {
+#if ZSTD_VERSION_NUMBER >= 700
+ ZSTD_CDict* zstd_cdict_ = nullptr;
+#endif // ZSTD_VERSION_NUMBER >= 700
+ std::string dict_;
+
+ public:
+#if ZSTD_VERSION_NUMBER >= 700
+ CompressionDict(std::string dict, CompressionType type, int level) {
+#else // ZSTD_VERSION_NUMBER >= 700
+ CompressionDict(std::string dict, CompressionType /*type*/, int /*level*/) {
+#endif // ZSTD_VERSION_NUMBER >= 700
+ dict_ = std::move(dict);
+#if ZSTD_VERSION_NUMBER >= 700
+ zstd_cdict_ = nullptr;
+ if (!dict_.empty() && (type == kZSTD || type == kZSTDNotFinalCompression)) {
+ if (level == CompressionOptions::kDefaultCompressionLevel) {
+ // 3 is the value of ZSTD_CLEVEL_DEFAULT (not exposed publicly), see
+ // https://github.com/facebook/zstd/issues/1148
+ level = 3;
+ }
+ // Should be safe (but slower) if below call fails as we'll use the
+ // raw dictionary to compress.
+ zstd_cdict_ = ZSTD_createCDict(dict_.data(), dict_.size(), level);
+ assert(zstd_cdict_ != nullptr);
+ }
+#endif // ZSTD_VERSION_NUMBER >= 700
+ }
+
+ ~CompressionDict() {
+#if ZSTD_VERSION_NUMBER >= 700
+ size_t res = 0;
+ if (zstd_cdict_ != nullptr) {
+ res = ZSTD_freeCDict(zstd_cdict_);
+ }
+ assert(res == 0); // Last I checked they can't fail
+ (void)res; // prevent unused var warning
+#endif // ZSTD_VERSION_NUMBER >= 700
+ }
+
+#if ZSTD_VERSION_NUMBER >= 700
+ const ZSTD_CDict* GetDigestedZstdCDict() const { return zstd_cdict_; }
+#endif // ZSTD_VERSION_NUMBER >= 700
+
+ Slice GetRawDict() const { return dict_; }
+
+ static const CompressionDict& GetEmptyDict() {
+ static CompressionDict empty_dict{};
+ return empty_dict;
+ }
+
+ CompressionDict() = default;
+ // Disable copy/move
+ CompressionDict(const CompressionDict&) = delete;
+ CompressionDict& operator=(const CompressionDict&) = delete;
+ CompressionDict(CompressionDict&&) = delete;
+ CompressionDict& operator=(CompressionDict&&) = delete;
+};
+
+// Holds dictionary and related data, like ZSTD's digested uncompression
+// dictionary.
+struct UncompressionDict {
+ // Block containing the data for the compression dictionary in case the
+ // constructor that takes a string parameter is used.
+ std::string dict_;
+
+ // Block containing the data for the compression dictionary in case the
+ // constructor that takes a Slice parameter is used and the passed in
+ // CacheAllocationPtr is not nullptr.
+ CacheAllocationPtr allocation_;
+
+ // Slice pointing to the compression dictionary data. Can point to
+ // dict_, allocation_, or some other memory location, depending on how
+ // the object was constructed.
+ Slice slice_;
+
+#ifdef ROCKSDB_ZSTD_DDICT
+ // Processed version of the contents of slice_ for ZSTD compression.
+ ZSTD_DDict* zstd_ddict_ = nullptr;
+#endif // ROCKSDB_ZSTD_DDICT
+
+#ifdef ROCKSDB_ZSTD_DDICT
+ UncompressionDict(std::string dict, bool using_zstd)
+#else // ROCKSDB_ZSTD_DDICT
+ UncompressionDict(std::string dict, bool /* using_zstd */)
+#endif // ROCKSDB_ZSTD_DDICT
+ : dict_(std::move(dict)), slice_(dict_) {
+#ifdef ROCKSDB_ZSTD_DDICT
+ if (!slice_.empty() && using_zstd) {
+ zstd_ddict_ = ZSTD_createDDict_byReference(slice_.data(), slice_.size());
+ assert(zstd_ddict_ != nullptr);
+ }
+#endif // ROCKSDB_ZSTD_DDICT
+ }
+
+#ifdef ROCKSDB_ZSTD_DDICT
+ UncompressionDict(Slice slice, CacheAllocationPtr&& allocation,
+ bool using_zstd)
+#else // ROCKSDB_ZSTD_DDICT
+ UncompressionDict(Slice slice, CacheAllocationPtr&& allocation,
+ bool /* using_zstd */)
+#endif // ROCKSDB_ZSTD_DDICT
+ : allocation_(std::move(allocation)), slice_(std::move(slice)) {
+#ifdef ROCKSDB_ZSTD_DDICT
+ if (!slice_.empty() && using_zstd) {
+ zstd_ddict_ = ZSTD_createDDict_byReference(slice_.data(), slice_.size());
+ assert(zstd_ddict_ != nullptr);
+ }
+#endif // ROCKSDB_ZSTD_DDICT
+ }
+
+ UncompressionDict(UncompressionDict&& rhs)
+ : dict_(std::move(rhs.dict_)),
+ allocation_(std::move(rhs.allocation_)),
+ slice_(std::move(rhs.slice_))
+#ifdef ROCKSDB_ZSTD_DDICT
+ ,
+ zstd_ddict_(rhs.zstd_ddict_)
+#endif
+ {
+#ifdef ROCKSDB_ZSTD_DDICT
+ rhs.zstd_ddict_ = nullptr;
+#endif
+ }
+
+ ~UncompressionDict() {
+#ifdef ROCKSDB_ZSTD_DDICT
+ size_t res = 0;
+ if (zstd_ddict_ != nullptr) {
+ res = ZSTD_freeDDict(zstd_ddict_);
+ }
+ assert(res == 0); // Last I checked they can't fail
+ (void)res; // prevent unused var warning
+#endif // ROCKSDB_ZSTD_DDICT
+ }
+
+ UncompressionDict& operator=(UncompressionDict&& rhs) {
+ if (this == &rhs) {
+ return *this;
+ }
+
+ dict_ = std::move(rhs.dict_);
+ allocation_ = std::move(rhs.allocation_);
+ slice_ = std::move(rhs.slice_);
+
+#ifdef ROCKSDB_ZSTD_DDICT
+ zstd_ddict_ = rhs.zstd_ddict_;
+ rhs.zstd_ddict_ = nullptr;
+#endif
+
+ return *this;
+ }
+
+ // The object is self-contained if the string constructor is used, or the
+ // Slice constructor is invoked with a non-null allocation. Otherwise, it
+ // is the caller's responsibility to ensure that the underlying storage
+ // outlives this object.
+ bool own_bytes() const { return !dict_.empty() || allocation_; }
+
+ const Slice& GetRawDict() const { return slice_; }
+
+#ifdef ROCKSDB_ZSTD_DDICT
+ const ZSTD_DDict* GetDigestedZstdDDict() const { return zstd_ddict_; }
+#endif // ROCKSDB_ZSTD_DDICT
+
+ static const UncompressionDict& GetEmptyDict() {
+ static UncompressionDict empty_dict{};
+ return empty_dict;
+ }
+
+ size_t ApproximateMemoryUsage() const {
+ size_t usage = sizeof(struct UncompressionDict);
+ usage += dict_.size();
+ if (allocation_) {
+ auto allocator = allocation_.get_deleter().allocator;
+ if (allocator) {
+ usage += allocator->UsableSize(allocation_.get(), slice_.size());
+ } else {
+ usage += slice_.size();
+ }
+ }
+#ifdef ROCKSDB_ZSTD_DDICT
+ usage += ZSTD_sizeof_DDict(zstd_ddict_);
+#endif // ROCKSDB_ZSTD_DDICT
+ return usage;
+ }
+
+ UncompressionDict() = default;
+ // Disable copy
+ UncompressionDict(const CompressionDict&) = delete;
+ UncompressionDict& operator=(const CompressionDict&) = delete;
+};
+
+class CompressionContext {
+ private:
+#if defined(ZSTD) && (ZSTD_VERSION_NUMBER >= 500)
+ ZSTD_CCtx* zstd_ctx_ = nullptr;
+ void CreateNativeContext(CompressionType type) {
+ if (type == kZSTD || type == kZSTDNotFinalCompression) {
+#ifdef ROCKSDB_ZSTD_CUSTOM_MEM
+ zstd_ctx_ =
+ ZSTD_createCCtx_advanced(port::GetJeZstdAllocationOverrides());
+#else // ROCKSDB_ZSTD_CUSTOM_MEM
+ zstd_ctx_ = ZSTD_createCCtx();
+#endif // ROCKSDB_ZSTD_CUSTOM_MEM
+ }
+ }
+ void DestroyNativeContext() {
+ if (zstd_ctx_ != nullptr) {
+ ZSTD_freeCCtx(zstd_ctx_);
+ }
+ }
+
+ public:
+ // callable inside ZSTD_Compress
+ ZSTD_CCtx* ZSTDPreallocCtx() const {
+ assert(zstd_ctx_ != nullptr);
+ return zstd_ctx_;
+ }
+
+#else // ZSTD && (ZSTD_VERSION_NUMBER >= 500)
+ private:
+ void CreateNativeContext(CompressionType /* type */) {}
+ void DestroyNativeContext() {}
+#endif // ZSTD && (ZSTD_VERSION_NUMBER >= 500)
+ public:
+ explicit CompressionContext(CompressionType type) {
+ CreateNativeContext(type);
+ }
+ ~CompressionContext() { DestroyNativeContext(); }
+ CompressionContext(const CompressionContext&) = delete;
+ CompressionContext& operator=(const CompressionContext&) = delete;
+};
+
+class CompressionInfo {
+ const CompressionOptions& opts_;
+ const CompressionContext& context_;
+ const CompressionDict& dict_;
+ const CompressionType type_;
+ const uint64_t sample_for_compression_;
+
+ public:
+ CompressionInfo(const CompressionOptions& _opts,
+ const CompressionContext& _context,
+ const CompressionDict& _dict, CompressionType _type,
+ uint64_t _sample_for_compression)
+ : opts_(_opts),
+ context_(_context),
+ dict_(_dict),
+ type_(_type),
+ sample_for_compression_(_sample_for_compression) {}
+
+ const CompressionOptions& options() const { return opts_; }
+ const CompressionContext& context() const { return context_; }
+ const CompressionDict& dict() const { return dict_; }
+ CompressionType type() const { return type_; }
+ uint64_t SampleForCompression() const { return sample_for_compression_; }
+};
+
+class UncompressionContext {
+ private:
+ CompressionContextCache* ctx_cache_ = nullptr;
+ ZSTDUncompressCachedData uncomp_cached_data_;
+
+ public:
+ explicit UncompressionContext(CompressionType type) {
+ if (type == kZSTD || type == kZSTDNotFinalCompression) {
+ ctx_cache_ = CompressionContextCache::Instance();
+ uncomp_cached_data_ = ctx_cache_->GetCachedZSTDUncompressData();
+ }
+ }
+ ~UncompressionContext() {
+ if (uncomp_cached_data_.GetCacheIndex() != -1) {
+ assert(ctx_cache_ != nullptr);
+ ctx_cache_->ReturnCachedZSTDUncompressData(
+ uncomp_cached_data_.GetCacheIndex());
+ }
+ }
+ UncompressionContext(const UncompressionContext&) = delete;
+ UncompressionContext& operator=(const UncompressionContext&) = delete;
+
+ ZSTDUncompressCachedData::ZSTDNativeContext GetZSTDContext() const {
+ return uncomp_cached_data_.Get();
+ }
+};
+
+class UncompressionInfo {
+ const UncompressionContext& context_;
+ const UncompressionDict& dict_;
+ const CompressionType type_;
+
+ public:
+ UncompressionInfo(const UncompressionContext& _context,
+ const UncompressionDict& _dict, CompressionType _type)
+ : context_(_context), dict_(_dict), type_(_type) {}
+
+ const UncompressionContext& context() const { return context_; }
+ const UncompressionDict& dict() const { return dict_; }
+ CompressionType type() const { return type_; }
+};
+
+inline bool Snappy_Supported() {
+#ifdef SNAPPY
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool Zlib_Supported() {
+#ifdef ZLIB
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool BZip2_Supported() {
+#ifdef BZIP2
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool LZ4_Supported() {
+#ifdef LZ4
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool XPRESS_Supported() {
+#ifdef XPRESS
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool ZSTD_Supported() {
+#ifdef ZSTD
+ // ZSTD format is finalized since version 0.8.0.
+ return (ZSTD_versionNumber() >= 800);
+#else
+ return false;
+#endif
+}
+
+inline bool ZSTDNotFinal_Supported() {
+#ifdef ZSTD
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool ZSTD_Streaming_Supported() {
+#if defined(ZSTD) && defined(ZSTD_STREAMING)
+ return true;
+#else
+ return false;
+#endif
+}
+
+inline bool StreamingCompressionTypeSupported(
+ CompressionType compression_type) {
+ switch (compression_type) {
+ case kNoCompression:
+ return true;
+ case kZSTD:
+ return ZSTD_Streaming_Supported();
+ default:
+ return false;
+ }
+}
+
+inline bool CompressionTypeSupported(CompressionType compression_type) {
+ switch (compression_type) {
+ case kNoCompression:
+ return true;
+ case kSnappyCompression:
+ return Snappy_Supported();
+ case kZlibCompression:
+ return Zlib_Supported();
+ case kBZip2Compression:
+ return BZip2_Supported();
+ case kLZ4Compression:
+ return LZ4_Supported();
+ case kLZ4HCCompression:
+ return LZ4_Supported();
+ case kXpressCompression:
+ return XPRESS_Supported();
+ case kZSTDNotFinalCompression:
+ return ZSTDNotFinal_Supported();
+ case kZSTD:
+ return ZSTD_Supported();
+ default:
+ assert(false);
+ return false;
+ }
+}
+
+inline bool DictCompressionTypeSupported(CompressionType compression_type) {
+ switch (compression_type) {
+ case kNoCompression:
+ return false;
+ case kSnappyCompression:
+ return false;
+ case kZlibCompression:
+ return Zlib_Supported();
+ case kBZip2Compression:
+ return false;
+ case kLZ4Compression:
+ case kLZ4HCCompression:
+#if LZ4_VERSION_NUMBER >= 10400 // r124+
+ return LZ4_Supported();
+#else
+ return false;
+#endif
+ case kXpressCompression:
+ return false;
+ case kZSTDNotFinalCompression:
+#if ZSTD_VERSION_NUMBER >= 500 // v0.5.0+
+ return ZSTDNotFinal_Supported();
+#else
+ return false;
+#endif
+ case kZSTD:
+#if ZSTD_VERSION_NUMBER >= 500 // v0.5.0+
+ return ZSTD_Supported();
+#else
+ return false;
+#endif
+ default:
+ assert(false);
+ return false;
+ }
+}
+
+inline std::string CompressionTypeToString(CompressionType compression_type) {
+ switch (compression_type) {
+ case kNoCompression:
+ return "NoCompression";
+ case kSnappyCompression:
+ return "Snappy";
+ case kZlibCompression:
+ return "Zlib";
+ case kBZip2Compression:
+ return "BZip2";
+ case kLZ4Compression:
+ return "LZ4";
+ case kLZ4HCCompression:
+ return "LZ4HC";
+ case kXpressCompression:
+ return "Xpress";
+ case kZSTD:
+ return "ZSTD";
+ case kZSTDNotFinalCompression:
+ return "ZSTDNotFinal";
+ case kDisableCompressionOption:
+ return "DisableOption";
+ default:
+ assert(false);
+ return "";
+ }
+}
+
+inline std::string CompressionOptionsToString(
+ CompressionOptions& compression_options) {
+ std::string result;
+ result.reserve(512);
+ result.append("window_bits=")
+ .append(std::to_string(compression_options.window_bits))
+ .append("; ");
+ result.append("level=")
+ .append(std::to_string(compression_options.level))
+ .append("; ");
+ result.append("strategy=")
+ .append(std::to_string(compression_options.strategy))
+ .append("; ");
+ result.append("max_dict_bytes=")
+ .append(std::to_string(compression_options.max_dict_bytes))
+ .append("; ");
+ result.append("zstd_max_train_bytes=")
+ .append(std::to_string(compression_options.zstd_max_train_bytes))
+ .append("; ");
+ result.append("enabled=")
+ .append(std::to_string(compression_options.enabled))
+ .append("; ");
+ result.append("max_dict_buffer_bytes=")
+ .append(std::to_string(compression_options.max_dict_buffer_bytes))
+ .append("; ");
+ result.append("use_zstd_dict_trainer=")
+ .append(std::to_string(compression_options.use_zstd_dict_trainer))
+ .append("; ");
+ return result;
+}
+
+// compress_format_version can have two values:
+// 1 -- decompressed sizes for BZip2 and Zlib are not included in the compressed
+// block. Also, decompressed sizes for LZ4 are encoded in platform-dependent
+// way.
+// 2 -- Zlib, BZip2 and LZ4 encode decompressed size as Varint32 just before the
+// start of compressed block. Snappy format is the same as version 1.
+
+inline bool Snappy_Compress(const CompressionInfo& /*info*/, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef SNAPPY
+ output->resize(snappy::MaxCompressedLength(length));
+ size_t outlen;
+ snappy::RawCompress(input, length, &(*output)[0], &outlen);
+ output->resize(outlen);
+ return true;
+#else
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+inline CacheAllocationPtr Snappy_Uncompress(
+ const char* input, size_t length, size_t* uncompressed_size,
+ MemoryAllocator* allocator = nullptr) {
+#ifdef SNAPPY
+ size_t uncompressed_length = 0;
+ if (!snappy::GetUncompressedLength(input, length, &uncompressed_length)) {
+ return nullptr;
+ }
+
+ CacheAllocationPtr output = AllocateBlock(uncompressed_length, allocator);
+
+ if (!snappy::RawUncompress(input, length, output.get())) {
+ return nullptr;
+ }
+
+ *uncompressed_size = uncompressed_length;
+
+ return output;
+#else
+ (void)input;
+ (void)length;
+ (void)uncompressed_size;
+ (void)allocator;
+ return nullptr;
+#endif
+}
+
+namespace compression {
+// returns size
+inline size_t PutDecompressedSizeInfo(std::string* output, uint32_t length) {
+ PutVarint32(output, length);
+ return output->size();
+}
+
+inline bool GetDecompressedSizeInfo(const char** input_data,
+ size_t* input_length,
+ uint32_t* output_len) {
+ auto new_input_data =
+ GetVarint32Ptr(*input_data, *input_data + *input_length, output_len);
+ if (new_input_data == nullptr) {
+ return false;
+ }
+ *input_length -= (new_input_data - *input_data);
+ *input_data = new_input_data;
+ return true;
+}
+} // namespace compression
+
+// compress_format_version == 1 -- decompressed size is not included in the
+// block header
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline bool Zlib_Compress(const CompressionInfo& info,
+ uint32_t compress_format_version, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef ZLIB
+ if (length > std::numeric_limits<uint32_t>::max()) {
+ // Can't compress more than 4GB
+ return false;
+ }
+
+ size_t output_header_len = 0;
+ if (compress_format_version == 2) {
+ output_header_len = compression::PutDecompressedSizeInfo(
+ output, static_cast<uint32_t>(length));
+ }
+
+ // The memLevel parameter specifies how much memory should be allocated for
+ // the internal compression state.
+ // memLevel=1 uses minimum memory but is slow and reduces compression ratio.
+ // memLevel=9 uses maximum memory for optimal speed.
+ // The default value is 8. See zconf.h for more details.
+ static const int memLevel = 8;
+ int level;
+ if (info.options().level == CompressionOptions::kDefaultCompressionLevel) {
+ level = Z_DEFAULT_COMPRESSION;
+ } else {
+ level = info.options().level;
+ }
+ z_stream _stream;
+ memset(&_stream, 0, sizeof(z_stream));
+ int st = deflateInit2(&_stream, level, Z_DEFLATED, info.options().window_bits,
+ memLevel, info.options().strategy);
+ if (st != Z_OK) {
+ return false;
+ }
+
+ Slice compression_dict = info.dict().GetRawDict();
+ if (compression_dict.size()) {
+ // Initialize the compression library's dictionary
+ st = deflateSetDictionary(
+ &_stream, reinterpret_cast<const Bytef*>(compression_dict.data()),
+ static_cast<unsigned int>(compression_dict.size()));
+ if (st != Z_OK) {
+ deflateEnd(&_stream);
+ return false;
+ }
+ }
+
+ // Get an upper bound on the compressed size.
+ size_t upper_bound =
+ deflateBound(&_stream, static_cast<unsigned long>(length));
+ output->resize(output_header_len + upper_bound);
+
+ // Compress the input, and put compressed data in output.
+ _stream.next_in = (Bytef*)input;
+ _stream.avail_in = static_cast<unsigned int>(length);
+
+ // Initialize the output size.
+ _stream.avail_out = static_cast<unsigned int>(upper_bound);
+ _stream.next_out = reinterpret_cast<Bytef*>(&(*output)[output_header_len]);
+
+ bool compressed = false;
+ st = deflate(&_stream, Z_FINISH);
+ if (st == Z_STREAM_END) {
+ compressed = true;
+ output->resize(output->size() - _stream.avail_out);
+ }
+ // The only return value we really care about is Z_STREAM_END.
+ // Z_OK means insufficient output space. This means the compression is
+ // bigger than decompressed size. Just fail the compression in that case.
+
+ deflateEnd(&_stream);
+ return compressed;
+#else
+ (void)info;
+ (void)compress_format_version;
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is not included in the
+// block header
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline CacheAllocationPtr Zlib_Uncompress(
+ const UncompressionInfo& info, const char* input_data, size_t input_length,
+ size_t* uncompressed_size, uint32_t compress_format_version,
+ MemoryAllocator* allocator = nullptr, int windowBits = -14) {
+#ifdef ZLIB
+ uint32_t output_len = 0;
+ if (compress_format_version == 2) {
+ if (!compression::GetDecompressedSizeInfo(&input_data, &input_length,
+ &output_len)) {
+ return nullptr;
+ }
+ } else {
+ // Assume the decompressed data size will 5x of compressed size, but round
+ // to the page size
+ size_t proposed_output_len = ((input_length * 5) & (~(4096 - 1))) + 4096;
+ output_len = static_cast<uint32_t>(
+ std::min(proposed_output_len,
+ static_cast<size_t>(std::numeric_limits<uint32_t>::max())));
+ }
+
+ z_stream _stream;
+ memset(&_stream, 0, sizeof(z_stream));
+
+ // For raw inflate, the windowBits should be -8..-15.
+ // If windowBits is bigger than zero, it will use either zlib
+ // header or gzip header. Adding 32 to it will do automatic detection.
+ int st =
+ inflateInit2(&_stream, windowBits > 0 ? windowBits + 32 : windowBits);
+ if (st != Z_OK) {
+ return nullptr;
+ }
+
+ const Slice& compression_dict = info.dict().GetRawDict();
+ if (compression_dict.size()) {
+ // Initialize the compression library's dictionary
+ st = inflateSetDictionary(
+ &_stream, reinterpret_cast<const Bytef*>(compression_dict.data()),
+ static_cast<unsigned int>(compression_dict.size()));
+ if (st != Z_OK) {
+ return nullptr;
+ }
+ }
+
+ _stream.next_in = (Bytef*)input_data;
+ _stream.avail_in = static_cast<unsigned int>(input_length);
+
+ auto output = AllocateBlock(output_len, allocator);
+
+ _stream.next_out = (Bytef*)output.get();
+ _stream.avail_out = static_cast<unsigned int>(output_len);
+
+ bool done = false;
+ while (!done) {
+ st = inflate(&_stream, Z_SYNC_FLUSH);
+ switch (st) {
+ case Z_STREAM_END:
+ done = true;
+ break;
+ case Z_OK: {
+ // No output space. Increase the output space by 20%.
+ // We should never run out of output space if
+ // compress_format_version == 2
+ assert(compress_format_version != 2);
+ size_t old_sz = output_len;
+ uint32_t output_len_delta = output_len / 5;
+ output_len += output_len_delta < 10 ? 10 : output_len_delta;
+ auto tmp = AllocateBlock(output_len, allocator);
+ memcpy(tmp.get(), output.get(), old_sz);
+ output = std::move(tmp);
+
+ // Set more output.
+ _stream.next_out = (Bytef*)(output.get() + old_sz);
+ _stream.avail_out = static_cast<unsigned int>(output_len - old_sz);
+ break;
+ }
+ case Z_BUF_ERROR:
+ default:
+ inflateEnd(&_stream);
+ return nullptr;
+ }
+ }
+
+ // If we encoded decompressed block size, we should have no bytes left
+ assert(compress_format_version != 2 || _stream.avail_out == 0);
+ assert(output_len >= _stream.avail_out);
+ *uncompressed_size = output_len - _stream.avail_out;
+ inflateEnd(&_stream);
+ return output;
+#else
+ (void)info;
+ (void)input_data;
+ (void)input_length;
+ (void)uncompressed_size;
+ (void)compress_format_version;
+ (void)allocator;
+ (void)windowBits;
+ return nullptr;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is not included in the
+// block header
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+inline bool BZip2_Compress(const CompressionInfo& /*info*/,
+ uint32_t compress_format_version, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef BZIP2
+ if (length > std::numeric_limits<uint32_t>::max()) {
+ // Can't compress more than 4GB
+ return false;
+ }
+ size_t output_header_len = 0;
+ if (compress_format_version == 2) {
+ output_header_len = compression::PutDecompressedSizeInfo(
+ output, static_cast<uint32_t>(length));
+ }
+ // Resize output to be the plain data length.
+ // This may not be big enough if the compression actually expands data.
+ output->resize(output_header_len + length);
+
+ bz_stream _stream;
+ memset(&_stream, 0, sizeof(bz_stream));
+
+ // Block size 1 is 100K.
+ // 0 is for silent.
+ // 30 is the default workFactor
+ int st = BZ2_bzCompressInit(&_stream, 1, 0, 30);
+ if (st != BZ_OK) {
+ return false;
+ }
+
+ // Compress the input, and put compressed data in output.
+ _stream.next_in = (char*)input;
+ _stream.avail_in = static_cast<unsigned int>(length);
+
+ // Initialize the output size.
+ _stream.avail_out = static_cast<unsigned int>(length);
+ _stream.next_out = reinterpret_cast<char*>(&(*output)[output_header_len]);
+
+ bool compressed = false;
+ st = BZ2_bzCompress(&_stream, BZ_FINISH);
+ if (st == BZ_STREAM_END) {
+ compressed = true;
+ output->resize(output->size() - _stream.avail_out);
+ }
+ // The only return value we really care about is BZ_STREAM_END.
+ // BZ_FINISH_OK means insufficient output space. This means the compression
+ // is bigger than decompressed size. Just fail the compression in that case.
+
+ BZ2_bzCompressEnd(&_stream);
+ return compressed;
+#else
+ (void)compress_format_version;
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is not included in the
+// block header
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+inline CacheAllocationPtr BZip2_Uncompress(
+ const char* input_data, size_t input_length, size_t* uncompressed_size,
+ uint32_t compress_format_version, MemoryAllocator* allocator = nullptr) {
+#ifdef BZIP2
+ uint32_t output_len = 0;
+ if (compress_format_version == 2) {
+ if (!compression::GetDecompressedSizeInfo(&input_data, &input_length,
+ &output_len)) {
+ return nullptr;
+ }
+ } else {
+ // Assume the decompressed data size will 5x of compressed size, but round
+ // to the next page size
+ size_t proposed_output_len = ((input_length * 5) & (~(4096 - 1))) + 4096;
+ output_len = static_cast<uint32_t>(
+ std::min(proposed_output_len,
+ static_cast<size_t>(std::numeric_limits<uint32_t>::max())));
+ }
+
+ bz_stream _stream;
+ memset(&_stream, 0, sizeof(bz_stream));
+
+ int st = BZ2_bzDecompressInit(&_stream, 0, 0);
+ if (st != BZ_OK) {
+ return nullptr;
+ }
+
+ _stream.next_in = (char*)input_data;
+ _stream.avail_in = static_cast<unsigned int>(input_length);
+
+ auto output = AllocateBlock(output_len, allocator);
+
+ _stream.next_out = (char*)output.get();
+ _stream.avail_out = static_cast<unsigned int>(output_len);
+
+ bool done = false;
+ while (!done) {
+ st = BZ2_bzDecompress(&_stream);
+ switch (st) {
+ case BZ_STREAM_END:
+ done = true;
+ break;
+ case BZ_OK: {
+ // No output space. Increase the output space by 20%.
+ // We should never run out of output space if
+ // compress_format_version == 2
+ assert(compress_format_version != 2);
+ uint32_t old_sz = output_len;
+ output_len = output_len * 1.2;
+ auto tmp = AllocateBlock(output_len, allocator);
+ memcpy(tmp.get(), output.get(), old_sz);
+ output = std::move(tmp);
+
+ // Set more output.
+ _stream.next_out = (char*)(output.get() + old_sz);
+ _stream.avail_out = static_cast<unsigned int>(output_len - old_sz);
+ break;
+ }
+ default:
+ BZ2_bzDecompressEnd(&_stream);
+ return nullptr;
+ }
+ }
+
+ // If we encoded decompressed block size, we should have no bytes left
+ assert(compress_format_version != 2 || _stream.avail_out == 0);
+ assert(output_len >= _stream.avail_out);
+ *uncompressed_size = output_len - _stream.avail_out;
+ BZ2_bzDecompressEnd(&_stream);
+ return output;
+#else
+ (void)input_data;
+ (void)input_length;
+ (void)uncompressed_size;
+ (void)compress_format_version;
+ (void)allocator;
+ return nullptr;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is included in the
+// block header using memcpy, which makes database non-portable)
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline bool LZ4_Compress(const CompressionInfo& info,
+ uint32_t compress_format_version, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef LZ4
+ if (length > std::numeric_limits<uint32_t>::max()) {
+ // Can't compress more than 4GB
+ return false;
+ }
+
+ size_t output_header_len = 0;
+ if (compress_format_version == 2) {
+ // new encoding, using varint32 to store size information
+ output_header_len = compression::PutDecompressedSizeInfo(
+ output, static_cast<uint32_t>(length));
+ } else {
+ // legacy encoding, which is not really portable (depends on big/little
+ // endianness)
+ output_header_len = 8;
+ output->resize(output_header_len);
+ char* p = const_cast<char*>(output->c_str());
+ memcpy(p, &length, sizeof(length));
+ }
+ int compress_bound = LZ4_compressBound(static_cast<int>(length));
+ output->resize(static_cast<size_t>(output_header_len + compress_bound));
+
+ int outlen;
+#if LZ4_VERSION_NUMBER >= 10400 // r124+
+ LZ4_stream_t* stream = LZ4_createStream();
+ Slice compression_dict = info.dict().GetRawDict();
+ if (compression_dict.size()) {
+ LZ4_loadDict(stream, compression_dict.data(),
+ static_cast<int>(compression_dict.size()));
+ }
+#if LZ4_VERSION_NUMBER >= 10700 // r129+
+ outlen =
+ LZ4_compress_fast_continue(stream, input, &(*output)[output_header_len],
+ static_cast<int>(length), compress_bound, 1);
+#else // up to r128
+ outlen = LZ4_compress_limitedOutput_continue(
+ stream, input, &(*output)[output_header_len], static_cast<int>(length),
+ compress_bound);
+#endif
+ LZ4_freeStream(stream);
+#else // up to r123
+ outlen = LZ4_compress_limitedOutput(input, &(*output)[output_header_len],
+ static_cast<int>(length), compress_bound);
+#endif // LZ4_VERSION_NUMBER >= 10400
+
+ if (outlen == 0) {
+ return false;
+ }
+ output->resize(static_cast<size_t>(output_header_len + outlen));
+ return true;
+#else // LZ4
+ (void)info;
+ (void)compress_format_version;
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is included in the
+// block header using memcpy, which makes database non-portable)
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline CacheAllocationPtr LZ4_Uncompress(const UncompressionInfo& info,
+ const char* input_data,
+ size_t input_length,
+ size_t* uncompressed_size,
+ uint32_t compress_format_version,
+ MemoryAllocator* allocator = nullptr) {
+#ifdef LZ4
+ uint32_t output_len = 0;
+ if (compress_format_version == 2) {
+ // new encoding, using varint32 to store size information
+ if (!compression::GetDecompressedSizeInfo(&input_data, &input_length,
+ &output_len)) {
+ return nullptr;
+ }
+ } else {
+ // legacy encoding, which is not really portable (depends on big/little
+ // endianness)
+ if (input_length < 8) {
+ return nullptr;
+ }
+ if (port::kLittleEndian) {
+ memcpy(&output_len, input_data, sizeof(output_len));
+ } else {
+ memcpy(&output_len, input_data + 4, sizeof(output_len));
+ }
+ input_length -= 8;
+ input_data += 8;
+ }
+
+ auto output = AllocateBlock(output_len, allocator);
+
+ int decompress_bytes = 0;
+
+#if LZ4_VERSION_NUMBER >= 10400 // r124+
+ LZ4_streamDecode_t* stream = LZ4_createStreamDecode();
+ const Slice& compression_dict = info.dict().GetRawDict();
+ if (compression_dict.size()) {
+ LZ4_setStreamDecode(stream, compression_dict.data(),
+ static_cast<int>(compression_dict.size()));
+ }
+ decompress_bytes = LZ4_decompress_safe_continue(
+ stream, input_data, output.get(), static_cast<int>(input_length),
+ static_cast<int>(output_len));
+ LZ4_freeStreamDecode(stream);
+#else // up to r123
+ decompress_bytes = LZ4_decompress_safe(input_data, output.get(),
+ static_cast<int>(input_length),
+ static_cast<int>(output_len));
+#endif // LZ4_VERSION_NUMBER >= 10400
+
+ if (decompress_bytes < 0) {
+ return nullptr;
+ }
+ assert(decompress_bytes == static_cast<int>(output_len));
+ *uncompressed_size = decompress_bytes;
+ return output;
+#else // LZ4
+ (void)info;
+ (void)input_data;
+ (void)input_length;
+ (void)uncompressed_size;
+ (void)compress_format_version;
+ (void)allocator;
+ return nullptr;
+#endif
+}
+
+// compress_format_version == 1 -- decompressed size is included in the
+// block header using memcpy, which makes database non-portable)
+// compress_format_version == 2 -- decompressed size is included in the block
+// header in varint32 format
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline bool LZ4HC_Compress(const CompressionInfo& info,
+ uint32_t compress_format_version, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef LZ4
+ if (length > std::numeric_limits<uint32_t>::max()) {
+ // Can't compress more than 4GB
+ return false;
+ }
+
+ size_t output_header_len = 0;
+ if (compress_format_version == 2) {
+ // new encoding, using varint32 to store size information
+ output_header_len = compression::PutDecompressedSizeInfo(
+ output, static_cast<uint32_t>(length));
+ } else {
+ // legacy encoding, which is not really portable (depends on big/little
+ // endianness)
+ output_header_len = 8;
+ output->resize(output_header_len);
+ char* p = const_cast<char*>(output->c_str());
+ memcpy(p, &length, sizeof(length));
+ }
+ int compress_bound = LZ4_compressBound(static_cast<int>(length));
+ output->resize(static_cast<size_t>(output_header_len + compress_bound));
+
+ int outlen;
+ int level;
+ if (info.options().level == CompressionOptions::kDefaultCompressionLevel) {
+ level = 0; // lz4hc.h says any value < 1 will be sanitized to default
+ } else {
+ level = info.options().level;
+ }
+#if LZ4_VERSION_NUMBER >= 10400 // r124+
+ LZ4_streamHC_t* stream = LZ4_createStreamHC();
+ LZ4_resetStreamHC(stream, level);
+ Slice compression_dict = info.dict().GetRawDict();
+ const char* compression_dict_data =
+ compression_dict.size() > 0 ? compression_dict.data() : nullptr;
+ size_t compression_dict_size = compression_dict.size();
+ if (compression_dict_data != nullptr) {
+ LZ4_loadDictHC(stream, compression_dict_data,
+ static_cast<int>(compression_dict_size));
+ }
+
+#if LZ4_VERSION_NUMBER >= 10700 // r129+
+ outlen =
+ LZ4_compress_HC_continue(stream, input, &(*output)[output_header_len],
+ static_cast<int>(length), compress_bound);
+#else // r124-r128
+ outlen = LZ4_compressHC_limitedOutput_continue(
+ stream, input, &(*output)[output_header_len], static_cast<int>(length),
+ compress_bound);
+#endif // LZ4_VERSION_NUMBER >= 10700
+ LZ4_freeStreamHC(stream);
+
+#elif LZ4_VERSION_MAJOR // r113-r123
+ outlen = LZ4_compressHC2_limitedOutput(input, &(*output)[output_header_len],
+ static_cast<int>(length),
+ compress_bound, level);
+#else // up to r112
+ outlen =
+ LZ4_compressHC_limitedOutput(input, &(*output)[output_header_len],
+ static_cast<int>(length), compress_bound);
+#endif // LZ4_VERSION_NUMBER >= 10400
+
+ if (outlen == 0) {
+ return false;
+ }
+ output->resize(static_cast<size_t>(output_header_len + outlen));
+ return true;
+#else // LZ4
+ (void)info;
+ (void)compress_format_version;
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+#ifdef XPRESS
+inline bool XPRESS_Compress(const char* input, size_t length,
+ std::string* output) {
+ return port::xpress::Compress(input, length, output);
+}
+#else
+inline bool XPRESS_Compress(const char* /*input*/, size_t /*length*/,
+ std::string* /*output*/) {
+ return false;
+}
+#endif
+
+#ifdef XPRESS
+inline char* XPRESS_Uncompress(const char* input_data, size_t input_length,
+ size_t* uncompressed_size) {
+ return port::xpress::Decompress(input_data, input_length, uncompressed_size);
+}
+#else
+inline char* XPRESS_Uncompress(const char* /*input_data*/,
+ size_t /*input_length*/,
+ size_t* /*uncompressed_size*/) {
+ return nullptr;
+}
+#endif
+
+inline bool ZSTD_Compress(const CompressionInfo& info, const char* input,
+ size_t length, ::std::string* output) {
+#ifdef ZSTD
+ if (length > std::numeric_limits<uint32_t>::max()) {
+ // Can't compress more than 4GB
+ return false;
+ }
+
+ size_t output_header_len = compression::PutDecompressedSizeInfo(
+ output, static_cast<uint32_t>(length));
+
+ size_t compressBound = ZSTD_compressBound(length);
+ output->resize(static_cast<size_t>(output_header_len + compressBound));
+ size_t outlen = 0;
+ int level;
+ if (info.options().level == CompressionOptions::kDefaultCompressionLevel) {
+ // 3 is the value of ZSTD_CLEVEL_DEFAULT (not exposed publicly), see
+ // https://github.com/facebook/zstd/issues/1148
+ level = 3;
+ } else {
+ level = info.options().level;
+ }
+#if ZSTD_VERSION_NUMBER >= 500 // v0.5.0+
+ ZSTD_CCtx* context = info.context().ZSTDPreallocCtx();
+ assert(context != nullptr);
+#if ZSTD_VERSION_NUMBER >= 700 // v0.7.0+
+ if (info.dict().GetDigestedZstdCDict() != nullptr) {
+ outlen = ZSTD_compress_usingCDict(context, &(*output)[output_header_len],
+ compressBound, input, length,
+ info.dict().GetDigestedZstdCDict());
+ }
+#endif // ZSTD_VERSION_NUMBER >= 700
+ if (outlen == 0) {
+ outlen = ZSTD_compress_usingDict(context, &(*output)[output_header_len],
+ compressBound, input, length,
+ info.dict().GetRawDict().data(),
+ info.dict().GetRawDict().size(), level);
+ }
+#else // up to v0.4.x
+ outlen = ZSTD_compress(&(*output)[output_header_len], compressBound, input,
+ length, level);
+#endif // ZSTD_VERSION_NUMBER >= 500
+ if (outlen == 0) {
+ return false;
+ }
+ output->resize(output_header_len + outlen);
+ return true;
+#else // ZSTD
+ (void)info;
+ (void)input;
+ (void)length;
+ (void)output;
+ return false;
+#endif
+}
+
+// @param compression_dict Data for presetting the compression library's
+// dictionary.
+inline CacheAllocationPtr ZSTD_Uncompress(
+ const UncompressionInfo& info, const char* input_data, size_t input_length,
+ size_t* uncompressed_size, MemoryAllocator* allocator = nullptr) {
+#ifdef ZSTD
+ uint32_t output_len = 0;
+ if (!compression::GetDecompressedSizeInfo(&input_data, &input_length,
+ &output_len)) {
+ return nullptr;
+ }
+
+ auto output = AllocateBlock(output_len, allocator);
+ size_t actual_output_length = 0;
+#if ZSTD_VERSION_NUMBER >= 500 // v0.5.0+
+ ZSTD_DCtx* context = info.context().GetZSTDContext();
+ assert(context != nullptr);
+#ifdef ROCKSDB_ZSTD_DDICT
+ if (info.dict().GetDigestedZstdDDict() != nullptr) {
+ actual_output_length = ZSTD_decompress_usingDDict(
+ context, output.get(), output_len, input_data, input_length,
+ info.dict().GetDigestedZstdDDict());
+ }
+#endif // ROCKSDB_ZSTD_DDICT
+ if (actual_output_length == 0) {
+ actual_output_length = ZSTD_decompress_usingDict(
+ context, output.get(), output_len, input_data, input_length,
+ info.dict().GetRawDict().data(), info.dict().GetRawDict().size());
+ }
+#else // up to v0.4.x
+ (void)info;
+ actual_output_length =
+ ZSTD_decompress(output.get(), output_len, input_data, input_length);
+#endif // ZSTD_VERSION_NUMBER >= 500
+ assert(actual_output_length == output_len);
+ *uncompressed_size = actual_output_length;
+ return output;
+#else // ZSTD
+ (void)info;
+ (void)input_data;
+ (void)input_length;
+ (void)uncompressed_size;
+ (void)allocator;
+ return nullptr;
+#endif
+}
+
+inline bool ZSTD_TrainDictionarySupported() {
+#ifdef ZSTD
+ // Dictionary trainer is available since v0.6.1 for static linking, but not
+ // available for dynamic linking until v1.1.3. For now we enable the feature
+ // in v1.1.3+ only.
+ return (ZSTD_versionNumber() >= 10103);
+#else
+ return false;
+#endif
+}
+
+inline std::string ZSTD_TrainDictionary(const std::string& samples,
+ const std::vector<size_t>& sample_lens,
+ size_t max_dict_bytes) {
+ // Dictionary trainer is available since v0.6.1 for static linking, but not
+ // available for dynamic linking until v1.1.3. For now we enable the feature
+ // in v1.1.3+ only.
+#if ZSTD_VERSION_NUMBER >= 10103 // v1.1.3+
+ assert(samples.empty() == sample_lens.empty());
+ if (samples.empty()) {
+ return "";
+ }
+ std::string dict_data(max_dict_bytes, '\0');
+ size_t dict_len = ZDICT_trainFromBuffer(
+ &dict_data[0], max_dict_bytes, &samples[0], &sample_lens[0],
+ static_cast<unsigned>(sample_lens.size()));
+ if (ZDICT_isError(dict_len)) {
+ return "";
+ }
+ assert(dict_len <= max_dict_bytes);
+ dict_data.resize(dict_len);
+ return dict_data;
+#else // up to v1.1.2
+ assert(false);
+ (void)samples;
+ (void)sample_lens;
+ (void)max_dict_bytes;
+ return "";
+#endif // ZSTD_VERSION_NUMBER >= 10103
+}
+
+inline std::string ZSTD_TrainDictionary(const std::string& samples,
+ size_t sample_len_shift,
+ size_t max_dict_bytes) {
+ // Dictionary trainer is available since v0.6.1, but ZSTD was marked stable
+ // only since v0.8.0. For now we enable the feature in stable versions only.
+#if ZSTD_VERSION_NUMBER >= 10103 // v1.1.3+
+ // skips potential partial sample at the end of "samples"
+ size_t num_samples = samples.size() >> sample_len_shift;
+ std::vector<size_t> sample_lens(num_samples, size_t(1) << sample_len_shift);
+ return ZSTD_TrainDictionary(samples, sample_lens, max_dict_bytes);
+#else // up to v1.1.2
+ assert(false);
+ (void)samples;
+ (void)sample_len_shift;
+ (void)max_dict_bytes;
+ return "";
+#endif // ZSTD_VERSION_NUMBER >= 10103
+}
+
+inline bool ZSTD_FinalizeDictionarySupported() {
+#ifdef ZSTD
+ // ZDICT_finalizeDictionary API is stable since v1.4.5
+ return (ZSTD_versionNumber() >= 10405);
+#else
+ return false;
+#endif
+}
+
+inline std::string ZSTD_FinalizeDictionary(
+ const std::string& samples, const std::vector<size_t>& sample_lens,
+ size_t max_dict_bytes, int level) {
+ // ZDICT_finalizeDictionary is stable since version v1.4.5
+#if ZSTD_VERSION_NUMBER >= 10405 // v1.4.5+
+ assert(samples.empty() == sample_lens.empty());
+ if (samples.empty()) {
+ return "";
+ }
+ if (level == CompressionOptions::kDefaultCompressionLevel) {
+ // 3 is the value of ZSTD_CLEVEL_DEFAULT (not exposed publicly), see
+ // https://github.com/facebook/zstd/issues/1148
+ level = 3;
+ }
+ std::string dict_data(max_dict_bytes, '\0');
+ size_t dict_len = ZDICT_finalizeDictionary(
+ dict_data.data(), max_dict_bytes, samples.data(),
+ std::min(static_cast<size_t>(samples.size()), max_dict_bytes),
+ samples.data(), sample_lens.data(),
+ static_cast<unsigned>(sample_lens.size()),
+ {level, 0 /* notificationLevel */, 0 /* dictID */});
+ if (ZDICT_isError(dict_len)) {
+ return "";
+ } else {
+ assert(dict_len <= max_dict_bytes);
+ dict_data.resize(dict_len);
+ return dict_data;
+ }
+#else // up to v1.4.4
+ (void)samples;
+ (void)sample_lens;
+ (void)max_dict_bytes;
+ (void)level;
+ return "";
+#endif // ZSTD_VERSION_NUMBER >= 10405
+}
+
+inline bool CompressData(const Slice& raw,
+ const CompressionInfo& compression_info,
+ uint32_t compress_format_version,
+ std::string* compressed_output) {
+ bool ret = false;
+
+ // Will return compressed block contents if (1) the compression method is
+ // supported in this platform and (2) the compression rate is "good enough".
+ switch (compression_info.type()) {
+ case kSnappyCompression:
+ ret = Snappy_Compress(compression_info, raw.data(), raw.size(),
+ compressed_output);
+ break;
+ case kZlibCompression:
+ ret = Zlib_Compress(compression_info, compress_format_version, raw.data(),
+ raw.size(), compressed_output);
+ break;
+ case kBZip2Compression:
+ ret = BZip2_Compress(compression_info, compress_format_version,
+ raw.data(), raw.size(), compressed_output);
+ break;
+ case kLZ4Compression:
+ ret = LZ4_Compress(compression_info, compress_format_version, raw.data(),
+ raw.size(), compressed_output);
+ break;
+ case kLZ4HCCompression:
+ ret = LZ4HC_Compress(compression_info, compress_format_version,
+ raw.data(), raw.size(), compressed_output);
+ break;
+ case kXpressCompression:
+ ret = XPRESS_Compress(raw.data(), raw.size(), compressed_output);
+ break;
+ case kZSTD:
+ case kZSTDNotFinalCompression:
+ ret = ZSTD_Compress(compression_info, raw.data(), raw.size(),
+ compressed_output);
+ break;
+ default:
+ // Do not recognize this compression type
+ break;
+ }
+
+ TEST_SYNC_POINT_CALLBACK("CompressData:TamperWithReturnValue",
+ static_cast<void*>(&ret));
+
+ return ret;
+}
+
+inline CacheAllocationPtr UncompressData(
+ const UncompressionInfo& uncompression_info, const char* data, size_t n,
+ size_t* uncompressed_size, uint32_t compress_format_version,
+ MemoryAllocator* allocator = nullptr) {
+ switch (uncompression_info.type()) {
+ case kSnappyCompression:
+ return Snappy_Uncompress(data, n, uncompressed_size, allocator);
+ case kZlibCompression:
+ return Zlib_Uncompress(uncompression_info, data, n, uncompressed_size,
+ compress_format_version, allocator);
+ case kBZip2Compression:
+ return BZip2_Uncompress(data, n, uncompressed_size,
+ compress_format_version, allocator);
+ case kLZ4Compression:
+ case kLZ4HCCompression:
+ return LZ4_Uncompress(uncompression_info, data, n, uncompressed_size,
+ compress_format_version, allocator);
+ case kXpressCompression:
+ // XPRESS allocates memory internally, thus no support for custom
+ // allocator.
+ return CacheAllocationPtr(XPRESS_Uncompress(data, n, uncompressed_size));
+ case kZSTD:
+ case kZSTDNotFinalCompression:
+ return ZSTD_Uncompress(uncompression_info, data, n, uncompressed_size,
+ allocator);
+ default:
+ return CacheAllocationPtr();
+ }
+}
+
+// Records the compression type for subsequent WAL records.
+class CompressionTypeRecord {
+ public:
+ explicit CompressionTypeRecord(CompressionType compression_type)
+ : compression_type_(compression_type) {}
+
+ CompressionType GetCompressionType() const { return compression_type_; }
+
+ inline void EncodeTo(std::string* dst) const {
+ assert(dst != nullptr);
+ PutFixed32(dst, compression_type_);
+ }
+
+ inline Status DecodeFrom(Slice* src) {
+ constexpr char class_name[] = "CompressionTypeRecord";
+
+ uint32_t val;
+ if (!GetFixed32(src, &val)) {
+ return Status::Corruption(class_name,
+ "Error decoding WAL compression type");
+ }
+ CompressionType compression_type = static_cast<CompressionType>(val);
+ if (!StreamingCompressionTypeSupported(compression_type)) {
+ return Status::Corruption(class_name,
+ "WAL compression type not supported");
+ }
+ compression_type_ = compression_type;
+ return Status::OK();
+ }
+
+ inline std::string DebugString() const {
+ return "compression_type: " + CompressionTypeToString(compression_type_);
+ }
+
+ private:
+ CompressionType compression_type_;
+};
+
+// Base class to implement compression for a stream of buffers.
+// Instantiate an implementation of the class using Create() with the
+// compression type and use Compress() repeatedly.
+// The output buffer needs to be at least max_output_len.
+// Call Reset() in between frame boundaries or in case of an error.
+// NOTE: This class is not thread safe.
+class StreamingCompress {
+ public:
+ StreamingCompress(CompressionType compression_type,
+ const CompressionOptions& opts,
+ uint32_t compress_format_version, size_t max_output_len)
+ : compression_type_(compression_type),
+ opts_(opts),
+ compress_format_version_(compress_format_version),
+ max_output_len_(max_output_len) {}
+ virtual ~StreamingCompress() = default;
+ // compress should be called repeatedly with the same input till the method
+ // returns 0
+ // Parameters:
+ // input - buffer to compress
+ // input_size - size of input buffer
+ // output - compressed buffer allocated by caller, should be at least
+ // max_output_len
+ // output_size - size of the output buffer
+ // Returns -1 for errors, the remaining size of the input buffer that needs to
+ // be compressed
+ virtual int Compress(const char* input, size_t input_size, char* output,
+ size_t* output_pos) = 0;
+ // static method to create object of a class inherited from StreamingCompress
+ // based on the actual compression type.
+ static StreamingCompress* Create(CompressionType compression_type,
+ const CompressionOptions& opts,
+ uint32_t compress_format_version,
+ size_t max_output_len);
+ virtual void Reset() = 0;
+
+ protected:
+ const CompressionType compression_type_;
+ const CompressionOptions opts_;
+ const uint32_t compress_format_version_;
+ const size_t max_output_len_;
+};
+
+// Base class to uncompress a stream of compressed buffers.
+// Instantiate an implementation of the class using Create() with the
+// compression type and use Uncompress() repeatedly.
+// The output buffer needs to be at least max_output_len.
+// Call Reset() in between frame boundaries or in case of an error.
+// NOTE: This class is not thread safe.
+class StreamingUncompress {
+ public:
+ StreamingUncompress(CompressionType compression_type,
+ uint32_t compress_format_version, size_t max_output_len)
+ : compression_type_(compression_type),
+ compress_format_version_(compress_format_version),
+ max_output_len_(max_output_len) {}
+ virtual ~StreamingUncompress() = default;
+ // uncompress should be called again with the same input if output_size is
+ // equal to max_output_len or with the next input fragment.
+ // Parameters:
+ // input - buffer to uncompress
+ // input_size - size of input buffer
+ // output - uncompressed buffer allocated by caller, should be at least
+ // max_output_len
+ // output_size - size of the output buffer
+ // Returns -1 for errors, remaining input to be processed otherwise.
+ virtual int Uncompress(const char* input, size_t input_size, char* output,
+ size_t* output_pos) = 0;
+ static StreamingUncompress* Create(CompressionType compression_type,
+ uint32_t compress_format_version,
+ size_t max_output_len);
+ virtual void Reset() = 0;
+
+ protected:
+ CompressionType compression_type_;
+ uint32_t compress_format_version_;
+ size_t max_output_len_;
+};
+
+class ZSTDStreamingCompress final : public StreamingCompress {
+ public:
+ explicit ZSTDStreamingCompress(const CompressionOptions& opts,
+ uint32_t compress_format_version,
+ size_t max_output_len)
+ : StreamingCompress(kZSTD, opts, compress_format_version,
+ max_output_len) {
+#ifdef ZSTD_STREAMING
+ cctx_ = ZSTD_createCCtx();
+ // Each compressed frame will have a checksum
+ ZSTD_CCtx_setParameter(cctx_, ZSTD_c_checksumFlag, 1);
+ assert(cctx_ != nullptr);
+ input_buffer_ = {/*src=*/nullptr, /*size=*/0, /*pos=*/0};
+#endif
+ }
+ ~ZSTDStreamingCompress() override {
+#ifdef ZSTD_STREAMING
+ ZSTD_freeCCtx(cctx_);
+#endif
+ }
+ int Compress(const char* input, size_t input_size, char* output,
+ size_t* output_pos) override;
+ void Reset() override;
+#ifdef ZSTD_STREAMING
+ ZSTD_CCtx* cctx_;
+ ZSTD_inBuffer input_buffer_;
+#endif
+};
+
+class ZSTDStreamingUncompress final : public StreamingUncompress {
+ public:
+ explicit ZSTDStreamingUncompress(uint32_t compress_format_version,
+ size_t max_output_len)
+ : StreamingUncompress(kZSTD, compress_format_version, max_output_len) {
+#ifdef ZSTD_STREAMING
+ dctx_ = ZSTD_createDCtx();
+ assert(dctx_ != nullptr);
+ input_buffer_ = {/*src=*/nullptr, /*size=*/0, /*pos=*/0};
+#endif
+ }
+ ~ZSTDStreamingUncompress() override {
+#ifdef ZSTD_STREAMING
+ ZSTD_freeDCtx(dctx_);
+#endif
+ }
+ int Uncompress(const char* input, size_t input_size, char* output,
+ size_t* output_size) override;
+ void Reset() override;
+
+ private:
+#ifdef ZSTD_STREAMING
+ ZSTD_DCtx* dctx_;
+ ZSTD_inBuffer input_buffer_;
+#endif
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/compression_context_cache.cc b/src/rocksdb/util/compression_context_cache.cc
new file mode 100644
index 000000000..52c3fac72
--- /dev/null
+++ b/src/rocksdb/util/compression_context_cache.cc
@@ -0,0 +1,106 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+
+#include "util/compression_context_cache.h"
+
+#include <atomic>
+
+#include "util/compression.h"
+#include "util/core_local.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace compression_cache {
+
+void* const SentinelValue = nullptr;
+// Cache ZSTD uncompression contexts for reads
+// if needed we can add ZSTD compression context caching
+// which is currently is not done since BlockBasedTableBuilder
+// simply creates one compression context per new SST file.
+struct ZSTDCachedData {
+ // We choose to cache the below structure instead of a ptr
+ // because we want to avoid a) native types leak b) make
+ // cache use transparent for the user
+ ZSTDUncompressCachedData uncomp_cached_data_;
+ std::atomic<void*> zstd_uncomp_sentinel_;
+
+ char
+ padding[(CACHE_LINE_SIZE -
+ (sizeof(ZSTDUncompressCachedData) + sizeof(std::atomic<void*>)) %
+ CACHE_LINE_SIZE)]; // unused padding field
+
+ ZSTDCachedData() : zstd_uncomp_sentinel_(&uncomp_cached_data_) {}
+ ZSTDCachedData(const ZSTDCachedData&) = delete;
+ ZSTDCachedData& operator=(const ZSTDCachedData&) = delete;
+
+ ZSTDUncompressCachedData GetUncompressData(int64_t idx) {
+ ZSTDUncompressCachedData result;
+ void* expected = &uncomp_cached_data_;
+ if (zstd_uncomp_sentinel_.compare_exchange_strong(expected,
+ SentinelValue)) {
+ uncomp_cached_data_.CreateIfNeeded();
+ result.InitFromCache(uncomp_cached_data_, idx);
+ } else {
+ // Creates one time use data
+ result.CreateIfNeeded();
+ }
+ return result;
+ }
+ // Return the entry back into circulation
+ // This is executed only when we successfully obtained
+ // in the first place
+ void ReturnUncompressData() {
+ if (zstd_uncomp_sentinel_.exchange(&uncomp_cached_data_) != SentinelValue) {
+ // Means we are returning while not having it acquired.
+ assert(false);
+ }
+ }
+};
+static_assert(sizeof(ZSTDCachedData) % CACHE_LINE_SIZE == 0,
+ "Expected CACHE_LINE_SIZE alignment");
+} // namespace compression_cache
+
+class CompressionContextCache::Rep {
+ public:
+ Rep() {}
+ ZSTDUncompressCachedData GetZSTDUncompressData() {
+ auto p = per_core_uncompr_.AccessElementAndIndex();
+ int64_t idx = static_cast<int64_t>(p.second);
+ return p.first->GetUncompressData(idx);
+ }
+ void ReturnZSTDUncompressData(int64_t idx) {
+ assert(idx >= 0);
+ auto* cn = per_core_uncompr_.AccessAtCore(static_cast<size_t>(idx));
+ cn->ReturnUncompressData();
+ }
+
+ private:
+ CoreLocalArray<compression_cache::ZSTDCachedData> per_core_uncompr_;
+};
+
+CompressionContextCache::CompressionContextCache() : rep_(new Rep()) {}
+
+CompressionContextCache* CompressionContextCache::Instance() {
+ static CompressionContextCache instance;
+ return &instance;
+}
+
+void CompressionContextCache::InitSingleton() { Instance(); }
+
+ZSTDUncompressCachedData
+CompressionContextCache::GetCachedZSTDUncompressData() {
+ return rep_->GetZSTDUncompressData();
+}
+
+void CompressionContextCache::ReturnCachedZSTDUncompressData(int64_t idx) {
+ rep_->ReturnZSTDUncompressData(idx);
+}
+
+CompressionContextCache::~CompressionContextCache() { delete rep_; }
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/compression_context_cache.h b/src/rocksdb/util/compression_context_cache.h
new file mode 100644
index 000000000..7b7b2d507
--- /dev/null
+++ b/src/rocksdb/util/compression_context_cache.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+
+// Compression context cache allows to cache compression/uncompression contexts
+// This helps with Random Read latencies and reduces CPU utilization
+// Caching is implemented using CoreLocal facility. Compression/Uncompression
+// instances are cached on a per core basis using CoreLocalArray. A borrowed
+// instance is atomically replaced with a sentinel value for the time of being
+// used. If it turns out that another thread is already makes use of the
+// instance we still create one on the heap which is later is destroyed.
+
+#pragma once
+
+#include <stdint.h>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+class ZSTDUncompressCachedData;
+
+class CompressionContextCache {
+ public:
+ // Singleton
+ static CompressionContextCache* Instance();
+ static void InitSingleton();
+ CompressionContextCache(const CompressionContextCache&) = delete;
+ CompressionContextCache& operator=(const CompressionContextCache&) = delete;
+
+ ZSTDUncompressCachedData GetCachedZSTDUncompressData();
+ void ReturnCachedZSTDUncompressData(int64_t idx);
+
+ private:
+ // Singleton
+ CompressionContextCache();
+ ~CompressionContextCache();
+
+ class Rep;
+ Rep* rep_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/concurrent_task_limiter_impl.cc b/src/rocksdb/util/concurrent_task_limiter_impl.cc
new file mode 100644
index 000000000..a0fc7331f
--- /dev/null
+++ b/src/rocksdb/util/concurrent_task_limiter_impl.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/concurrent_task_limiter_impl.h"
+
+#include "rocksdb/concurrent_task_limiter.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+ConcurrentTaskLimiterImpl::ConcurrentTaskLimiterImpl(
+ const std::string& name, int32_t max_outstanding_task)
+ : name_(name),
+ max_outstanding_tasks_{max_outstanding_task},
+ outstanding_tasks_{0} {}
+
+ConcurrentTaskLimiterImpl::~ConcurrentTaskLimiterImpl() {
+ assert(outstanding_tasks_ == 0);
+}
+
+const std::string& ConcurrentTaskLimiterImpl::GetName() const { return name_; }
+
+void ConcurrentTaskLimiterImpl::SetMaxOutstandingTask(int32_t limit) {
+ max_outstanding_tasks_.store(limit, std::memory_order_relaxed);
+}
+
+void ConcurrentTaskLimiterImpl::ResetMaxOutstandingTask() {
+ max_outstanding_tasks_.store(-1, std::memory_order_relaxed);
+}
+
+int32_t ConcurrentTaskLimiterImpl::GetOutstandingTask() const {
+ return outstanding_tasks_.load(std::memory_order_relaxed);
+}
+
+std::unique_ptr<TaskLimiterToken> ConcurrentTaskLimiterImpl::GetToken(
+ bool force) {
+ int32_t limit = max_outstanding_tasks_.load(std::memory_order_relaxed);
+ int32_t tasks = outstanding_tasks_.load(std::memory_order_relaxed);
+ // force = true, bypass the throttle.
+ // limit < 0 means unlimited tasks.
+ while (force || limit < 0 || tasks < limit) {
+ if (outstanding_tasks_.compare_exchange_weak(tasks, tasks + 1)) {
+ return std::unique_ptr<TaskLimiterToken>(new TaskLimiterToken(this));
+ }
+ }
+ return nullptr;
+}
+
+ConcurrentTaskLimiter* NewConcurrentTaskLimiter(const std::string& name,
+ int32_t limit) {
+ return new ConcurrentTaskLimiterImpl(name, limit);
+}
+
+TaskLimiterToken::~TaskLimiterToken() {
+ --limiter_->outstanding_tasks_;
+ assert(limiter_->outstanding_tasks_ >= 0);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/concurrent_task_limiter_impl.h b/src/rocksdb/util/concurrent_task_limiter_impl.h
new file mode 100644
index 000000000..4952ae23a
--- /dev/null
+++ b/src/rocksdb/util/concurrent_task_limiter_impl.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#include <atomic>
+#include <memory>
+
+#include "rocksdb/concurrent_task_limiter.h"
+#include "rocksdb/env.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TaskLimiterToken;
+
+class ConcurrentTaskLimiterImpl : public ConcurrentTaskLimiter {
+ public:
+ explicit ConcurrentTaskLimiterImpl(const std::string& name,
+ int32_t max_outstanding_task);
+ // No copying allowed
+ ConcurrentTaskLimiterImpl(const ConcurrentTaskLimiterImpl&) = delete;
+ ConcurrentTaskLimiterImpl& operator=(const ConcurrentTaskLimiterImpl&) =
+ delete;
+
+ virtual ~ConcurrentTaskLimiterImpl();
+
+ virtual const std::string& GetName() const override;
+
+ virtual void SetMaxOutstandingTask(int32_t limit) override;
+
+ virtual void ResetMaxOutstandingTask() override;
+
+ virtual int32_t GetOutstandingTask() const override;
+
+ // Request token for adding a new task.
+ // If force == true, it requests a token bypassing throttle.
+ // Returns nullptr if it got throttled.
+ virtual std::unique_ptr<TaskLimiterToken> GetToken(bool force);
+
+ private:
+ friend class TaskLimiterToken;
+
+ std::string name_;
+ std::atomic<int32_t> max_outstanding_tasks_;
+ std::atomic<int32_t> outstanding_tasks_;
+};
+
+class TaskLimiterToken {
+ public:
+ explicit TaskLimiterToken(ConcurrentTaskLimiterImpl* limiter)
+ : limiter_(limiter) {}
+ ~TaskLimiterToken();
+
+ private:
+ ConcurrentTaskLimiterImpl* limiter_;
+
+ // no copying allowed
+ TaskLimiterToken(const TaskLimiterToken&) = delete;
+ void operator=(const TaskLimiterToken&) = delete;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/core_local.h b/src/rocksdb/util/core_local.h
new file mode 100644
index 000000000..b444a1152
--- /dev/null
+++ b/src/rocksdb/util/core_local.h
@@ -0,0 +1,83 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <cstddef>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "port/likely.h"
+#include "port/port.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// An array of core-local values. Ideally the value type, T, is cache aligned to
+// prevent false sharing.
+template <typename T>
+class CoreLocalArray {
+ public:
+ CoreLocalArray();
+
+ size_t Size() const;
+ // returns pointer to the element corresponding to the core that the thread
+ // currently runs on.
+ T* Access() const;
+ // same as above, but also returns the core index, which the client can cache
+ // to reduce how often core ID needs to be retrieved. Only do this if some
+ // inaccuracy is tolerable, as the thread may migrate to a different core.
+ std::pair<T*, size_t> AccessElementAndIndex() const;
+ // returns pointer to element for the specified core index. This can be used,
+ // e.g., for aggregation, or if the client caches core index.
+ T* AccessAtCore(size_t core_idx) const;
+
+ private:
+ std::unique_ptr<T[]> data_;
+ int size_shift_;
+};
+
+template <typename T>
+CoreLocalArray<T>::CoreLocalArray() {
+ int num_cpus = static_cast<int>(std::thread::hardware_concurrency());
+ // find a power of two >= num_cpus and >= 8
+ size_shift_ = 3;
+ while (1 << size_shift_ < num_cpus) {
+ ++size_shift_;
+ }
+ data_.reset(new T[static_cast<size_t>(1) << size_shift_]);
+}
+
+template <typename T>
+size_t CoreLocalArray<T>::Size() const {
+ return static_cast<size_t>(1) << size_shift_;
+}
+
+template <typename T>
+T* CoreLocalArray<T>::Access() const {
+ return AccessElementAndIndex().first;
+}
+
+template <typename T>
+std::pair<T*, size_t> CoreLocalArray<T>::AccessElementAndIndex() const {
+ int cpuid = port::PhysicalCoreID();
+ size_t core_idx;
+ if (UNLIKELY(cpuid < 0)) {
+ // cpu id unavailable, just pick randomly
+ core_idx = Random::GetTLSInstance()->Uniform(1 << size_shift_);
+ } else {
+ core_idx = static_cast<size_t>(cpuid & ((1 << size_shift_) - 1));
+ }
+ return {AccessAtCore(core_idx), core_idx};
+}
+
+template <typename T>
+T* CoreLocalArray<T>::AccessAtCore(size_t core_idx) const {
+ assert(core_idx < static_cast<size_t>(1) << size_shift_);
+ return &data_[core_idx];
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/coro_utils.h b/src/rocksdb/util/coro_utils.h
new file mode 100644
index 000000000..5b4211135
--- /dev/null
+++ b/src/rocksdb/util/coro_utils.h
@@ -0,0 +1,112 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#if defined(USE_COROUTINES)
+#include "folly/experimental/coro/Coroutine.h"
+#include "folly/experimental/coro/Task.h"
+#endif
+#include "rocksdb/rocksdb_namespace.h"
+
+// This file has two sctions. The first section applies to all instances of
+// header file inclusion and has an include guard. The second section is
+// meant for multiple inclusions in the same source file, and is idempotent.
+namespace ROCKSDB_NAMESPACE {
+
+#ifndef UTIL_CORO_UTILS_H_
+#define UTIL_CORO_UTILS_H_
+
+#if defined(USE_COROUTINES)
+
+// The follwoing macros expand to regular and coroutine function
+// declarations for a given function
+#define DECLARE_SYNC_AND_ASYNC(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__); \
+ folly::coro::Task<__ret_type__> __func_name__##Coroutine(__VA_ARGS__);
+
+#define DECLARE_SYNC_AND_ASYNC_OVERRIDE(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__) override; \
+ folly::coro::Task<__ret_type__> __func_name__##Coroutine(__VA_ARGS__) \
+ override;
+
+#define DECLARE_SYNC_AND_ASYNC_CONST(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__) const; \
+ folly::coro::Task<__ret_type__> __func_name__##Coroutine(__VA_ARGS__) const;
+
+constexpr bool using_coroutines() { return true; }
+#else // !USE_COROUTINES
+
+// The follwoing macros expand to a regular function declaration for a given
+// function
+#define DECLARE_SYNC_AND_ASYNC(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__);
+
+#define DECLARE_SYNC_AND_ASYNC_OVERRIDE(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__) override;
+
+#define DECLARE_SYNC_AND_ASYNC_CONST(__ret_type__, __func_name__, ...) \
+ __ret_type__ __func_name__(__VA_ARGS__) const;
+
+constexpr bool using_coroutines() { return false; }
+#endif // USE_COROUTINES
+#endif // UTIL_CORO_UTILS_H_
+
+// The following section of the file is meant to be included twice in a
+// source file - once defining WITH_COROUTINES and once defining
+// WITHOUT_COROUTINES
+#undef DEFINE_SYNC_AND_ASYNC
+#undef CO_AWAIT
+#undef CO_RETURN
+
+#if defined(WITH_COROUTINES) && defined(USE_COROUTINES)
+
+// This macro should be used in the beginning of the function
+// definition. The declaration should have been done using one of the
+// DECLARE_SYNC_AND_ASYNC* macros. It expands to the return type and
+// the function name with the Coroutine suffix. For example -
+// DEFINE_SYNC_AND_ASYNC(int, foo)(bool bar) {}
+// would expand to -
+// folly::coro::Task<int> fooCoroutine(bool bar) {}
+#define DEFINE_SYNC_AND_ASYNC(__ret_type__, __func_name__) \
+ folly::coro::Task<__ret_type__> __func_name__##Coroutine
+
+// This macro should be used to call a function that might be a
+// coroutine. It expands to the correct function name and prefixes
+// the co_await operator if necessary. For example -
+// s = CO_AWAIT(foo)(true);
+// if the code is compiled WITH_COROUTINES, would expand to
+// s = co_await fooCoroutine(true);
+// if compiled WITHOUT_COROUTINES, would expand to
+// s = foo(true);
+#define CO_AWAIT(__func_name__) co_await __func_name__##Coroutine
+
+#define CO_RETURN co_return
+
+#elif defined(WITHOUT_COROUTINES)
+
+// This macro should be used in the beginning of the function
+// definition. The declaration should have been done using one of the
+// DECLARE_SYNC_AND_ASYNC* macros. It expands to the return type and
+// the function name without the Coroutine suffix. For example -
+// DEFINE_SYNC_AND_ASYNC(int, foo)(bool bar) {}
+// would expand to -
+// int foo(bool bar) {}
+#define DEFINE_SYNC_AND_ASYNC(__ret_type__, __func_name__) \
+ __ret_type__ __func_name__
+
+// This macro should be used to call a function that might be a
+// coroutine. It expands to the correct function name and prefixes
+// the co_await operator if necessary. For example -
+// s = CO_AWAIT(foo)(true);
+// if the code is compiled WITH_COROUTINES, would expand to
+// s = co_await fooCoroutine(true);
+// if compiled WITHOUT_COROUTINES, would expand to
+// s = foo(true);
+#define CO_AWAIT(__func_name__) __func_name__
+
+#define CO_RETURN return
+
+#endif // DO_NOT_USE_COROUTINES
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/crc32c.cc b/src/rocksdb/util/crc32c.cc
new file mode 100644
index 000000000..d71c71c2e
--- /dev/null
+++ b/src/rocksdb/util/crc32c.cc
@@ -0,0 +1,1351 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// A portable implementation of crc32c, optimized to handle
+// four bytes at a time.
+#include "util/crc32c.h"
+
+#include <stdint.h>
+
+#include <array>
+#include <utility>
+#ifdef HAVE_SSE42
+#include <nmmintrin.h>
+#include <wmmintrin.h>
+#endif
+
+#include "port/lang.h"
+#include "util/coding.h"
+#include "util/crc32c_arm64.h"
+#include "util/math.h"
+
+#ifdef __powerpc64__
+#include "util/crc32c_ppc.h"
+#include "util/crc32c_ppc_constants.h"
+
+#if __linux__
+#ifdef ROCKSDB_AUXV_GETAUXVAL_PRESENT
+#include <sys/auxv.h>
+#endif
+
+#ifndef PPC_FEATURE2_VEC_CRYPTO
+#define PPC_FEATURE2_VEC_CRYPTO 0x02000000
+#endif
+
+#ifndef AT_HWCAP2
+#define AT_HWCAP2 26
+#endif
+
+#elif __FreeBSD__
+#include <machine/cpu.h>
+#include <sys/auxv.h>
+#include <sys/elf_common.h>
+#endif /* __linux__ */
+
+#endif
+
+#if defined(HAVE_ARM64_CRC)
+bool pmull_runtime_flag = false;
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+namespace crc32c {
+
+#if defined(HAVE_POWER8) && defined(HAS_ALTIVEC)
+#ifdef __powerpc64__
+static int arch_ppc_crc32 = 0;
+#endif /* __powerpc64__ */
+#endif
+
+static const uint32_t table0_[256] = {
+ 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c,
+ 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+ 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c,
+ 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+ 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc,
+ 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+ 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512,
+ 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+ 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad,
+ 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+ 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf,
+ 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+ 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f,
+ 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+ 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f,
+ 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+ 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e,
+ 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+ 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e,
+ 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+ 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de,
+ 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+ 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4,
+ 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+ 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b,
+ 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+ 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5,
+ 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+ 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975,
+ 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+ 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905,
+ 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+ 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8,
+ 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+ 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8,
+ 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+ 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78,
+ 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+ 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6,
+ 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+ 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69,
+ 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+ 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351};
+static const uint32_t table1_[256] = {
+ 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab,
+ 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21,
+ 0xd39ea264, 0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6,
+ 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4,
+ 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92,
+ 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b,
+ 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd,
+ 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff,
+ 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28,
+ 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2,
+ 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73,
+ 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41,
+ 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17,
+ 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c,
+ 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a,
+ 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78,
+ 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad,
+ 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27,
+ 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0,
+ 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2,
+ 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94,
+ 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260,
+ 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136,
+ 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004,
+ 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3,
+ 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059,
+ 0x1798a91c, 0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c,
+ 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be,
+ 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8,
+ 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3,
+ 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5,
+ 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287,
+ 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556,
+ 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc,
+ 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b,
+ 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439,
+ 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f,
+ 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766,
+ 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430,
+ 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502,
+ 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5,
+ 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f,
+ 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483};
+static const uint32_t table2_[256] = {
+ 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664,
+ 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6,
+ 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4,
+ 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3,
+ 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b,
+ 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67,
+ 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af,
+ 0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8,
+ 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa,
+ 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828,
+ 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7,
+ 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0,
+ 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878,
+ 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20,
+ 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8,
+ 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff,
+ 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9,
+ 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b,
+ 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439,
+ 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e,
+ 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6,
+ 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730,
+ 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8,
+ 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def,
+ 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad,
+ 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f,
+ 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29,
+ 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e,
+ 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6,
+ 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae,
+ 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66,
+ 0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71,
+ 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe,
+ 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c,
+ 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e,
+ 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79,
+ 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1,
+ 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd,
+ 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035,
+ 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622,
+ 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760,
+ 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2,
+ 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8};
+static const uint32_t table3_[256] = {
+ 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b,
+ 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf,
+ 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85,
+ 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7,
+ 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990,
+ 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2,
+ 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5,
+ 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7,
+ 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd,
+ 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69,
+ 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f,
+ 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 0xb2a6e72d,
+ 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a,
+ 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3,
+ 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784,
+ 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6,
+ 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027,
+ 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3,
+ 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9,
+ 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b,
+ 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc,
+ 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006,
+ 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061,
+ 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213,
+ 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349,
+ 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd,
+ 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c,
+ 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e,
+ 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519,
+ 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0,
+ 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7,
+ 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5,
+ 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93,
+ 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07,
+ 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d,
+ 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 0x09629f2f,
+ 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48,
+ 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a,
+ 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d,
+ 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f,
+ 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825,
+ 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1,
+ 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842};
+
+// Used to fetch a naturally-aligned 32-bit word in little endian byte-order
+static inline uint32_t LE_LOAD32(const uint8_t* p) {
+ return DecodeFixed32(reinterpret_cast<const char*>(p));
+}
+
+#if defined(HAVE_SSE42) && (defined(__LP64__) || defined(_WIN64))
+static inline uint64_t LE_LOAD64(const uint8_t* p) {
+ return DecodeFixed64(reinterpret_cast<const char*>(p));
+}
+#endif
+
+static inline void Slow_CRC32(uint64_t* l, uint8_t const** p) {
+ uint32_t c = static_cast<uint32_t>(*l ^ LE_LOAD32(*p));
+ *p += 4;
+ *l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^
+ table1_[(c >> 16) & 0xff] ^ table0_[c >> 24];
+ // DO it twice.
+ c = static_cast<uint32_t>(*l ^ LE_LOAD32(*p));
+ *p += 4;
+ *l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^
+ table1_[(c >> 16) & 0xff] ^ table0_[c >> 24];
+}
+
+#if (!(defined(HAVE_POWER8) && defined(HAS_ALTIVEC))) && \
+ (!defined(HAVE_ARM64_CRC)) || \
+ defined(NO_THREEWAY_CRC32C)
+static inline void Fast_CRC32(uint64_t* l, uint8_t const** p) {
+#ifndef HAVE_SSE42
+ Slow_CRC32(l, p);
+#elif defined(__LP64__) || defined(_WIN64)
+ *l = _mm_crc32_u64(*l, LE_LOAD64(*p));
+ *p += 8;
+#else
+ *l = _mm_crc32_u32(static_cast<unsigned int>(*l), LE_LOAD32(*p));
+ *p += 4;
+ *l = _mm_crc32_u32(static_cast<unsigned int>(*l), LE_LOAD32(*p));
+ *p += 4;
+#endif
+}
+#endif
+
+template <void (*CRC32)(uint64_t*, uint8_t const**)>
+uint32_t ExtendImpl(uint32_t crc, const char* buf, size_t size) {
+ const uint8_t* p = reinterpret_cast<const uint8_t*>(buf);
+ const uint8_t* e = p + size;
+ uint64_t l = crc ^ 0xffffffffu;
+
+// Align n to (1 << m) byte boundary
+#define ALIGN(n, m) ((n + ((1 << m) - 1)) & ~((1 << m) - 1))
+
+#define STEP1 \
+ do { \
+ int c = (l & 0xff) ^ *p++; \
+ l = table0_[c] ^ (l >> 8); \
+ } while (0)
+
+ // Point x at first 16-byte aligned byte in string. This might be
+ // just past the end of the string.
+ const uintptr_t pval = reinterpret_cast<uintptr_t>(p);
+ const uint8_t* x = reinterpret_cast<const uint8_t*>(ALIGN(pval, 4));
+ if (x <= e) {
+ // Process bytes until finished or p is 16-byte aligned
+ while (p != x) {
+ STEP1;
+ }
+ }
+ // Process bytes 16 at a time
+ while ((e - p) >= 16) {
+ CRC32(&l, &p);
+ CRC32(&l, &p);
+ }
+ // Process bytes 8 at a time
+ while ((e - p) >= 8) {
+ CRC32(&l, &p);
+ }
+ // Process the last few bytes
+ while (p != e) {
+ STEP1;
+ }
+#undef STEP1
+#undef ALIGN
+ return static_cast<uint32_t>(l ^ 0xffffffffu);
+}
+
+// Detect if ARM64 CRC or not.
+#ifndef HAVE_ARM64_CRC
+// Detect if SS42 or not.
+#ifndef HAVE_POWER8
+
+static bool isSSE42() {
+#ifndef HAVE_SSE42
+ return false;
+#elif defined(__GNUC__) && defined(__x86_64__) && !defined(IOS_CROSS_COMPILE)
+ uint32_t c_;
+ __asm__("cpuid" : "=c"(c_) : "a"(1) : "ebx", "edx");
+ return c_ & (1U << 20); // copied from CpuId.h in Folly. Test SSE42
+#elif defined(_WIN64)
+ int info[4];
+ __cpuidex(info, 0x00000001, 0);
+ return (info[2] & ((int)1 << 20)) != 0;
+#else
+ return false;
+#endif
+}
+
+static bool isPCLMULQDQ() {
+#ifndef HAVE_SSE42
+ // in build_detect_platform we set this macro when both SSE42 and PCLMULQDQ
+ // are supported by compiler
+ return false;
+#elif defined(__GNUC__) && defined(__x86_64__) && !defined(IOS_CROSS_COMPILE)
+ uint32_t c_;
+ __asm__("cpuid" : "=c"(c_) : "a"(1) : "ebx", "edx");
+ return c_ & (1U << 1); // PCLMULQDQ is in bit 1 (not bit 0)
+#elif defined(_WIN64)
+ int info[4];
+ __cpuidex(info, 0x00000001, 0);
+ return (info[2] & ((int)1 << 1)) != 0;
+#else
+ return false;
+#endif
+}
+
+#endif // HAVE_POWER8
+#endif // HAVE_ARM64_CRC
+
+using Function = uint32_t (*)(uint32_t, const char*, size_t);
+
+#if defined(HAVE_POWER8) && defined(HAS_ALTIVEC)
+uint32_t ExtendPPCImpl(uint32_t crc, const char* buf, size_t size) {
+ return crc32c_ppc(crc, (const unsigned char*)buf, size);
+}
+
+#if __linux__
+static int arch_ppc_probe(void) {
+ arch_ppc_crc32 = 0;
+
+#if defined(__powerpc64__) && defined(ROCKSDB_AUXV_GETAUXVAL_PRESENT)
+ if (getauxval(AT_HWCAP2) & PPC_FEATURE2_VEC_CRYPTO) arch_ppc_crc32 = 1;
+#endif /* __powerpc64__ */
+
+ return arch_ppc_crc32;
+}
+#elif __FreeBSD__
+static int arch_ppc_probe(void) {
+ unsigned long cpufeatures;
+ arch_ppc_crc32 = 0;
+
+#if defined(__powerpc64__)
+ elf_aux_info(AT_HWCAP2, &cpufeatures, sizeof(cpufeatures));
+ if (cpufeatures & PPC_FEATURE2_HAS_VEC_CRYPTO) arch_ppc_crc32 = 1;
+#endif /* __powerpc64__ */
+
+ return arch_ppc_crc32;
+}
+#endif // __linux__
+
+static bool isAltiVec() {
+ if (arch_ppc_probe()) {
+ return true;
+ } else {
+ return false;
+ }
+}
+#endif
+
+#if defined(HAVE_ARM64_CRC)
+uint32_t ExtendARMImpl(uint32_t crc, const char* buf, size_t size) {
+ return crc32c_arm64(crc, (const unsigned char*)buf, size);
+}
+#endif
+
+std::string IsFastCrc32Supported() {
+ bool has_fast_crc = false;
+ std::string fast_zero_msg;
+ std::string arch;
+#ifdef HAVE_POWER8
+#ifdef HAS_ALTIVEC
+ if (arch_ppc_probe()) {
+ has_fast_crc = true;
+ arch = "PPC";
+ }
+#else
+ has_fast_crc = false;
+ arch = "PPC";
+#endif
+#elif defined(HAVE_ARM64_CRC)
+ if (crc32c_runtime_check()) {
+ has_fast_crc = true;
+ arch = "Arm64";
+ pmull_runtime_flag = crc32c_pmull_runtime_check();
+ } else {
+ has_fast_crc = false;
+ arch = "Arm64";
+ }
+#else
+ has_fast_crc = isSSE42();
+ arch = "x86";
+#endif
+ if (has_fast_crc) {
+ fast_zero_msg.append("Supported on " + arch);
+ } else {
+ fast_zero_msg.append("Not supported on " + arch);
+ }
+ return fast_zero_msg;
+}
+
+/*
+ * Copyright 2016 Ferry Toth, Exalon Delft BV, The Netherlands
+ * This software is provided 'as-is', without any express or implied
+ * warranty. In no event will the author be held liable for any damages
+ * arising from the use of this software.
+ * Permission is granted to anyone to use this software for any purpose,
+ * including commercial applications, and to alter it and redistribute it
+ * freely, subject to the following restrictions:
+ * 1. The origin of this software must not be misrepresented; you must not
+ * claim that you wrote the original software. If you use this software
+ * in a product, an acknowledgment in the product documentation would be
+ * appreciated but is not required.
+ * 2. Altered source versions must be plainly marked as such, and must not be
+ * misrepresented as being the original software.
+ * 3. This notice may not be removed or altered from any source distribution.
+ * Ferry Toth
+ * ftoth@exalondelft.nl
+ *
+ * https://github.com/htot/crc32c
+ *
+ * Modified by Facebook
+ *
+ * Original intel whitepaper:
+ * "Fast CRC Computation for iSCSI Polynomial Using CRC32 Instruction"
+ * https://www.intel.com/content/dam/www/public/us/en/documents/white-papers/crc-iscsi-polynomial-crc32-instruction-paper.pdf
+ *
+ * This version is from the folly library, created by Dave Watson
+ * <davejwatson@fb.com>
+ *
+ */
+#if defined HAVE_SSE42 && defined HAVE_PCLMUL
+
+#define CRCtriplet(crc, buf, offset) \
+ crc##0 = _mm_crc32_u64(crc##0, *(buf##0 + offset)); \
+ crc##1 = _mm_crc32_u64(crc##1, *(buf##1 + offset)); \
+ crc##2 = _mm_crc32_u64(crc##2, *(buf##2 + offset));
+
+#define CRCduplet(crc, buf, offset) \
+ crc##0 = _mm_crc32_u64(crc##0, *(buf##0 + offset)); \
+ crc##1 = _mm_crc32_u64(crc##1, *(buf##1 + offset));
+
+#define CRCsinglet(crc, buf, offset) \
+ crc = _mm_crc32_u64(crc, *(uint64_t*)(buf + offset));
+
+// Numbers taken directly from intel whitepaper.
+// clang-format off
+const uint64_t clmul_constants[] = {
+ 0x14cd00bd6, 0x105ec76f0, 0x0ba4fc28e, 0x14cd00bd6,
+ 0x1d82c63da, 0x0f20c0dfe, 0x09e4addf8, 0x0ba4fc28e,
+ 0x039d3b296, 0x1384aa63a, 0x102f9b8a2, 0x1d82c63da,
+ 0x14237f5e6, 0x01c291d04, 0x00d3b6092, 0x09e4addf8,
+ 0x0c96cfdc0, 0x0740eef02, 0x18266e456, 0x039d3b296,
+ 0x0daece73e, 0x0083a6eec, 0x0ab7aff2a, 0x102f9b8a2,
+ 0x1248ea574, 0x1c1733996, 0x083348832, 0x14237f5e6,
+ 0x12c743124, 0x02ad91c30, 0x0b9e02b86, 0x00d3b6092,
+ 0x018b33a4e, 0x06992cea2, 0x1b331e26a, 0x0c96cfdc0,
+ 0x17d35ba46, 0x07e908048, 0x1bf2e8b8a, 0x18266e456,
+ 0x1a3e0968a, 0x11ed1f9d8, 0x0ce7f39f4, 0x0daece73e,
+ 0x061d82e56, 0x0f1d0f55e, 0x0d270f1a2, 0x0ab7aff2a,
+ 0x1c3f5f66c, 0x0a87ab8a8, 0x12ed0daac, 0x1248ea574,
+ 0x065863b64, 0x08462d800, 0x11eef4f8e, 0x083348832,
+ 0x1ee54f54c, 0x071d111a8, 0x0b3e32c28, 0x12c743124,
+ 0x0064f7f26, 0x0ffd852c6, 0x0dd7e3b0c, 0x0b9e02b86,
+ 0x0f285651c, 0x0dcb17aa4, 0x010746f3c, 0x018b33a4e,
+ 0x1c24afea4, 0x0f37c5aee, 0x0271d9844, 0x1b331e26a,
+ 0x08e766a0c, 0x06051d5a2, 0x093a5f730, 0x17d35ba46,
+ 0x06cb08e5c, 0x11d5ca20e, 0x06b749fb2, 0x1bf2e8b8a,
+ 0x1167f94f2, 0x021f3d99c, 0x0cec3662e, 0x1a3e0968a,
+ 0x19329634a, 0x08f158014, 0x0e6fc4e6a, 0x0ce7f39f4,
+ 0x08227bb8a, 0x1a5e82106, 0x0b0cd4768, 0x061d82e56,
+ 0x13c2b89c4, 0x188815ab2, 0x0d7a4825c, 0x0d270f1a2,
+ 0x10f5ff2ba, 0x105405f3e, 0x00167d312, 0x1c3f5f66c,
+ 0x0f6076544, 0x0e9adf796, 0x026f6a60a, 0x12ed0daac,
+ 0x1a2adb74e, 0x096638b34, 0x19d34af3a, 0x065863b64,
+ 0x049c3cc9c, 0x1e50585a0, 0x068bce87a, 0x11eef4f8e,
+ 0x1524fa6c6, 0x19f1c69dc, 0x16cba8aca, 0x1ee54f54c,
+ 0x042d98888, 0x12913343e, 0x1329d9f7e, 0x0b3e32c28,
+ 0x1b1c69528, 0x088f25a3a, 0x02178513a, 0x0064f7f26,
+ 0x0e0ac139e, 0x04e36f0b0, 0x0170076fa, 0x0dd7e3b0c,
+ 0x141a1a2e2, 0x0bd6f81f8, 0x16ad828b4, 0x0f285651c,
+ 0x041d17b64, 0x19425cbba, 0x1fae1cc66, 0x010746f3c,
+ 0x1a75b4b00, 0x18db37e8a, 0x0f872e54c, 0x1c24afea4,
+ 0x01e41e9fc, 0x04c144932, 0x086d8e4d2, 0x0271d9844,
+ 0x160f7af7a, 0x052148f02, 0x05bb8f1bc, 0x08e766a0c,
+ 0x0a90fd27a, 0x0a3c6f37a, 0x0b3af077a, 0x093a5f730,
+ 0x04984d782, 0x1d22c238e, 0x0ca6ef3ac, 0x06cb08e5c,
+ 0x0234e0b26, 0x063ded06a, 0x1d88abd4a, 0x06b749fb2,
+ 0x04597456a, 0x04d56973c, 0x0e9e28eb4, 0x1167f94f2,
+ 0x07b3ff57a, 0x19385bf2e, 0x0c9c8b782, 0x0cec3662e,
+ 0x13a9cba9e, 0x0e417f38a, 0x093e106a4, 0x19329634a,
+ 0x167001a9c, 0x14e727980, 0x1ddffc5d4, 0x0e6fc4e6a,
+ 0x00df04680, 0x0d104b8fc, 0x02342001e, 0x08227bb8a,
+ 0x00a2a8d7e, 0x05b397730, 0x168763fa6, 0x0b0cd4768,
+ 0x1ed5a407a, 0x0e78eb416, 0x0d2c3ed1a, 0x13c2b89c4,
+ 0x0995a5724, 0x1641378f0, 0x19b1afbc4, 0x0d7a4825c,
+ 0x109ffedc0, 0x08d96551c, 0x0f2271e60, 0x10f5ff2ba,
+ 0x00b0bf8ca, 0x00bf80dd2, 0x123888b7a, 0x00167d312,
+ 0x1e888f7dc, 0x18dcddd1c, 0x002ee03b2, 0x0f6076544,
+ 0x183e8d8fe, 0x06a45d2b2, 0x133d7a042, 0x026f6a60a,
+ 0x116b0f50c, 0x1dd3e10e8, 0x05fabe670, 0x1a2adb74e,
+ 0x130004488, 0x0de87806c, 0x000bcf5f6, 0x19d34af3a,
+ 0x18f0c7078, 0x014338754, 0x017f27698, 0x049c3cc9c,
+ 0x058ca5f00, 0x15e3e77ee, 0x1af900c24, 0x068bce87a,
+ 0x0b5cfca28, 0x0dd07448e, 0x0ded288f8, 0x1524fa6c6,
+ 0x059f229bc, 0x1d8048348, 0x06d390dec, 0x16cba8aca,
+ 0x037170390, 0x0a3e3e02c, 0x06353c1cc, 0x042d98888,
+ 0x0c4584f5c, 0x0d73c7bea, 0x1f16a3418, 0x1329d9f7e,
+ 0x0531377e2, 0x185137662, 0x1d8d9ca7c, 0x1b1c69528,
+ 0x0b25b29f2, 0x18a08b5bc, 0x19fb2a8b0, 0x02178513a,
+ 0x1a08fe6ac, 0x1da758ae0, 0x045cddf4e, 0x0e0ac139e,
+ 0x1a91647f2, 0x169cf9eb0, 0x1a0f717c4, 0x0170076fa,
+};
+
+// Compute the crc32c value for buffer smaller than 8
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+inline void align_to_8(
+ size_t len,
+ uint64_t& crc0, // crc so far, updated on return
+ const unsigned char*& next) { // next data pointer, updated on return
+ uint32_t crc32bit = static_cast<uint32_t>(crc0);
+ if (len & 0x04) {
+ crc32bit = _mm_crc32_u32(crc32bit, *(uint32_t*)next);
+ next += sizeof(uint32_t);
+ }
+ if (len & 0x02) {
+ crc32bit = _mm_crc32_u16(crc32bit, *(uint16_t*)next);
+ next += sizeof(uint16_t);
+ }
+ if (len & 0x01) {
+ crc32bit = _mm_crc32_u8(crc32bit, *(next));
+ next++;
+ }
+ crc0 = crc32bit;
+}
+
+//
+// CombineCRC performs pclmulqdq multiplication of 2 partial CRC's and a well
+// chosen constant and xor's these with the remaining CRC.
+//
+inline uint64_t CombineCRC(
+ size_t block_size,
+ uint64_t crc0,
+ uint64_t crc1,
+ uint64_t crc2,
+ const uint64_t* next2) {
+ const auto multiplier =
+ *(reinterpret_cast<const __m128i*>(clmul_constants) + block_size - 1);
+ const auto crc0_xmm = _mm_set_epi64x(0, crc0);
+ const auto res0 = _mm_clmulepi64_si128(crc0_xmm, multiplier, 0x00);
+ const auto crc1_xmm = _mm_set_epi64x(0, crc1);
+ const auto res1 = _mm_clmulepi64_si128(crc1_xmm, multiplier, 0x10);
+ const auto res = _mm_xor_si128(res0, res1);
+ crc0 = _mm_cvtsi128_si64(res);
+ crc0 = crc0 ^ *((uint64_t*)next2 - 1);
+ crc2 = _mm_crc32_u64(crc2, crc0);
+ return crc2;
+}
+
+// Compute CRC-32C using the Intel hardware instruction.
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+uint32_t crc32c_3way(uint32_t crc, const char* buf, size_t len) {
+ const unsigned char* next = (const unsigned char*)buf;
+ uint64_t count;
+ uint64_t crc0, crc1, crc2;
+ crc0 = crc ^ 0xffffffffu;
+
+
+ if (len >= 8) {
+ // if len > 216 then align and use triplets
+ if (len > 216) {
+ {
+ // Work on the bytes (< 8) before the first 8-byte alignment addr starts
+ uint64_t align_bytes = (8 - (uintptr_t)next) & 7;
+ len -= align_bytes;
+ align_to_8(align_bytes, crc0, next);
+ }
+
+ // Now work on the remaining blocks
+ count = len / 24; // number of triplets
+ len %= 24; // bytes remaining
+ uint64_t n = count >> 7; // #blocks = first block + full blocks
+ uint64_t block_size = count & 127;
+ if (block_size == 0) {
+ block_size = 128;
+ } else {
+ n++;
+ }
+ // points to the first byte of the next block
+ const uint64_t* next0 = (uint64_t*)next + block_size;
+ const uint64_t* next1 = next0 + block_size;
+ const uint64_t* next2 = next1 + block_size;
+
+ crc1 = crc2 = 0;
+ // Use Duff's device, a for() loop inside a switch()
+ // statement. This needs to execute at least once, round len
+ // down to nearest triplet multiple
+ switch (block_size) {
+ case 128:
+ do {
+ // jumps here for a full block of len 128
+ CRCtriplet(crc, next, -128);
+ FALLTHROUGH_INTENDED;
+ case 127:
+ // jumps here or below for the first block smaller
+ CRCtriplet(crc, next, -127);
+ FALLTHROUGH_INTENDED;
+ case 126:
+ CRCtriplet(crc, next, -126); // than 128
+ FALLTHROUGH_INTENDED;
+ case 125:
+ CRCtriplet(crc, next, -125);
+ FALLTHROUGH_INTENDED;
+ case 124:
+ CRCtriplet(crc, next, -124);
+ FALLTHROUGH_INTENDED;
+ case 123:
+ CRCtriplet(crc, next, -123);
+ FALLTHROUGH_INTENDED;
+ case 122:
+ CRCtriplet(crc, next, -122);
+ FALLTHROUGH_INTENDED;
+ case 121:
+ CRCtriplet(crc, next, -121);
+ FALLTHROUGH_INTENDED;
+ case 120:
+ CRCtriplet(crc, next, -120);
+ FALLTHROUGH_INTENDED;
+ case 119:
+ CRCtriplet(crc, next, -119);
+ FALLTHROUGH_INTENDED;
+ case 118:
+ CRCtriplet(crc, next, -118);
+ FALLTHROUGH_INTENDED;
+ case 117:
+ CRCtriplet(crc, next, -117);
+ FALLTHROUGH_INTENDED;
+ case 116:
+ CRCtriplet(crc, next, -116);
+ FALLTHROUGH_INTENDED;
+ case 115:
+ CRCtriplet(crc, next, -115);
+ FALLTHROUGH_INTENDED;
+ case 114:
+ CRCtriplet(crc, next, -114);
+ FALLTHROUGH_INTENDED;
+ case 113:
+ CRCtriplet(crc, next, -113);
+ FALLTHROUGH_INTENDED;
+ case 112:
+ CRCtriplet(crc, next, -112);
+ FALLTHROUGH_INTENDED;
+ case 111:
+ CRCtriplet(crc, next, -111);
+ FALLTHROUGH_INTENDED;
+ case 110:
+ CRCtriplet(crc, next, -110);
+ FALLTHROUGH_INTENDED;
+ case 109:
+ CRCtriplet(crc, next, -109);
+ FALLTHROUGH_INTENDED;
+ case 108:
+ CRCtriplet(crc, next, -108);
+ FALLTHROUGH_INTENDED;
+ case 107:
+ CRCtriplet(crc, next, -107);
+ FALLTHROUGH_INTENDED;
+ case 106:
+ CRCtriplet(crc, next, -106);
+ FALLTHROUGH_INTENDED;
+ case 105:
+ CRCtriplet(crc, next, -105);
+ FALLTHROUGH_INTENDED;
+ case 104:
+ CRCtriplet(crc, next, -104);
+ FALLTHROUGH_INTENDED;
+ case 103:
+ CRCtriplet(crc, next, -103);
+ FALLTHROUGH_INTENDED;
+ case 102:
+ CRCtriplet(crc, next, -102);
+ FALLTHROUGH_INTENDED;
+ case 101:
+ CRCtriplet(crc, next, -101);
+ FALLTHROUGH_INTENDED;
+ case 100:
+ CRCtriplet(crc, next, -100);
+ FALLTHROUGH_INTENDED;
+ case 99:
+ CRCtriplet(crc, next, -99);
+ FALLTHROUGH_INTENDED;
+ case 98:
+ CRCtriplet(crc, next, -98);
+ FALLTHROUGH_INTENDED;
+ case 97:
+ CRCtriplet(crc, next, -97);
+ FALLTHROUGH_INTENDED;
+ case 96:
+ CRCtriplet(crc, next, -96);
+ FALLTHROUGH_INTENDED;
+ case 95:
+ CRCtriplet(crc, next, -95);
+ FALLTHROUGH_INTENDED;
+ case 94:
+ CRCtriplet(crc, next, -94);
+ FALLTHROUGH_INTENDED;
+ case 93:
+ CRCtriplet(crc, next, -93);
+ FALLTHROUGH_INTENDED;
+ case 92:
+ CRCtriplet(crc, next, -92);
+ FALLTHROUGH_INTENDED;
+ case 91:
+ CRCtriplet(crc, next, -91);
+ FALLTHROUGH_INTENDED;
+ case 90:
+ CRCtriplet(crc, next, -90);
+ FALLTHROUGH_INTENDED;
+ case 89:
+ CRCtriplet(crc, next, -89);
+ FALLTHROUGH_INTENDED;
+ case 88:
+ CRCtriplet(crc, next, -88);
+ FALLTHROUGH_INTENDED;
+ case 87:
+ CRCtriplet(crc, next, -87);
+ FALLTHROUGH_INTENDED;
+ case 86:
+ CRCtriplet(crc, next, -86);
+ FALLTHROUGH_INTENDED;
+ case 85:
+ CRCtriplet(crc, next, -85);
+ FALLTHROUGH_INTENDED;
+ case 84:
+ CRCtriplet(crc, next, -84);
+ FALLTHROUGH_INTENDED;
+ case 83:
+ CRCtriplet(crc, next, -83);
+ FALLTHROUGH_INTENDED;
+ case 82:
+ CRCtriplet(crc, next, -82);
+ FALLTHROUGH_INTENDED;
+ case 81:
+ CRCtriplet(crc, next, -81);
+ FALLTHROUGH_INTENDED;
+ case 80:
+ CRCtriplet(crc, next, -80);
+ FALLTHROUGH_INTENDED;
+ case 79:
+ CRCtriplet(crc, next, -79);
+ FALLTHROUGH_INTENDED;
+ case 78:
+ CRCtriplet(crc, next, -78);
+ FALLTHROUGH_INTENDED;
+ case 77:
+ CRCtriplet(crc, next, -77);
+ FALLTHROUGH_INTENDED;
+ case 76:
+ CRCtriplet(crc, next, -76);
+ FALLTHROUGH_INTENDED;
+ case 75:
+ CRCtriplet(crc, next, -75);
+ FALLTHROUGH_INTENDED;
+ case 74:
+ CRCtriplet(crc, next, -74);
+ FALLTHROUGH_INTENDED;
+ case 73:
+ CRCtriplet(crc, next, -73);
+ FALLTHROUGH_INTENDED;
+ case 72:
+ CRCtriplet(crc, next, -72);
+ FALLTHROUGH_INTENDED;
+ case 71:
+ CRCtriplet(crc, next, -71);
+ FALLTHROUGH_INTENDED;
+ case 70:
+ CRCtriplet(crc, next, -70);
+ FALLTHROUGH_INTENDED;
+ case 69:
+ CRCtriplet(crc, next, -69);
+ FALLTHROUGH_INTENDED;
+ case 68:
+ CRCtriplet(crc, next, -68);
+ FALLTHROUGH_INTENDED;
+ case 67:
+ CRCtriplet(crc, next, -67);
+ FALLTHROUGH_INTENDED;
+ case 66:
+ CRCtriplet(crc, next, -66);
+ FALLTHROUGH_INTENDED;
+ case 65:
+ CRCtriplet(crc, next, -65);
+ FALLTHROUGH_INTENDED;
+ case 64:
+ CRCtriplet(crc, next, -64);
+ FALLTHROUGH_INTENDED;
+ case 63:
+ CRCtriplet(crc, next, -63);
+ FALLTHROUGH_INTENDED;
+ case 62:
+ CRCtriplet(crc, next, -62);
+ FALLTHROUGH_INTENDED;
+ case 61:
+ CRCtriplet(crc, next, -61);
+ FALLTHROUGH_INTENDED;
+ case 60:
+ CRCtriplet(crc, next, -60);
+ FALLTHROUGH_INTENDED;
+ case 59:
+ CRCtriplet(crc, next, -59);
+ FALLTHROUGH_INTENDED;
+ case 58:
+ CRCtriplet(crc, next, -58);
+ FALLTHROUGH_INTENDED;
+ case 57:
+ CRCtriplet(crc, next, -57);
+ FALLTHROUGH_INTENDED;
+ case 56:
+ CRCtriplet(crc, next, -56);
+ FALLTHROUGH_INTENDED;
+ case 55:
+ CRCtriplet(crc, next, -55);
+ FALLTHROUGH_INTENDED;
+ case 54:
+ CRCtriplet(crc, next, -54);
+ FALLTHROUGH_INTENDED;
+ case 53:
+ CRCtriplet(crc, next, -53);
+ FALLTHROUGH_INTENDED;
+ case 52:
+ CRCtriplet(crc, next, -52);
+ FALLTHROUGH_INTENDED;
+ case 51:
+ CRCtriplet(crc, next, -51);
+ FALLTHROUGH_INTENDED;
+ case 50:
+ CRCtriplet(crc, next, -50);
+ FALLTHROUGH_INTENDED;
+ case 49:
+ CRCtriplet(crc, next, -49);
+ FALLTHROUGH_INTENDED;
+ case 48:
+ CRCtriplet(crc, next, -48);
+ FALLTHROUGH_INTENDED;
+ case 47:
+ CRCtriplet(crc, next, -47);
+ FALLTHROUGH_INTENDED;
+ case 46:
+ CRCtriplet(crc, next, -46);
+ FALLTHROUGH_INTENDED;
+ case 45:
+ CRCtriplet(crc, next, -45);
+ FALLTHROUGH_INTENDED;
+ case 44:
+ CRCtriplet(crc, next, -44);
+ FALLTHROUGH_INTENDED;
+ case 43:
+ CRCtriplet(crc, next, -43);
+ FALLTHROUGH_INTENDED;
+ case 42:
+ CRCtriplet(crc, next, -42);
+ FALLTHROUGH_INTENDED;
+ case 41:
+ CRCtriplet(crc, next, -41);
+ FALLTHROUGH_INTENDED;
+ case 40:
+ CRCtriplet(crc, next, -40);
+ FALLTHROUGH_INTENDED;
+ case 39:
+ CRCtriplet(crc, next, -39);
+ FALLTHROUGH_INTENDED;
+ case 38:
+ CRCtriplet(crc, next, -38);
+ FALLTHROUGH_INTENDED;
+ case 37:
+ CRCtriplet(crc, next, -37);
+ FALLTHROUGH_INTENDED;
+ case 36:
+ CRCtriplet(crc, next, -36);
+ FALLTHROUGH_INTENDED;
+ case 35:
+ CRCtriplet(crc, next, -35);
+ FALLTHROUGH_INTENDED;
+ case 34:
+ CRCtriplet(crc, next, -34);
+ FALLTHROUGH_INTENDED;
+ case 33:
+ CRCtriplet(crc, next, -33);
+ FALLTHROUGH_INTENDED;
+ case 32:
+ CRCtriplet(crc, next, -32);
+ FALLTHROUGH_INTENDED;
+ case 31:
+ CRCtriplet(crc, next, -31);
+ FALLTHROUGH_INTENDED;
+ case 30:
+ CRCtriplet(crc, next, -30);
+ FALLTHROUGH_INTENDED;
+ case 29:
+ CRCtriplet(crc, next, -29);
+ FALLTHROUGH_INTENDED;
+ case 28:
+ CRCtriplet(crc, next, -28);
+ FALLTHROUGH_INTENDED;
+ case 27:
+ CRCtriplet(crc, next, -27);
+ FALLTHROUGH_INTENDED;
+ case 26:
+ CRCtriplet(crc, next, -26);
+ FALLTHROUGH_INTENDED;
+ case 25:
+ CRCtriplet(crc, next, -25);
+ FALLTHROUGH_INTENDED;
+ case 24:
+ CRCtriplet(crc, next, -24);
+ FALLTHROUGH_INTENDED;
+ case 23:
+ CRCtriplet(crc, next, -23);
+ FALLTHROUGH_INTENDED;
+ case 22:
+ CRCtriplet(crc, next, -22);
+ FALLTHROUGH_INTENDED;
+ case 21:
+ CRCtriplet(crc, next, -21);
+ FALLTHROUGH_INTENDED;
+ case 20:
+ CRCtriplet(crc, next, -20);
+ FALLTHROUGH_INTENDED;
+ case 19:
+ CRCtriplet(crc, next, -19);
+ FALLTHROUGH_INTENDED;
+ case 18:
+ CRCtriplet(crc, next, -18);
+ FALLTHROUGH_INTENDED;
+ case 17:
+ CRCtriplet(crc, next, -17);
+ FALLTHROUGH_INTENDED;
+ case 16:
+ CRCtriplet(crc, next, -16);
+ FALLTHROUGH_INTENDED;
+ case 15:
+ CRCtriplet(crc, next, -15);
+ FALLTHROUGH_INTENDED;
+ case 14:
+ CRCtriplet(crc, next, -14);
+ FALLTHROUGH_INTENDED;
+ case 13:
+ CRCtriplet(crc, next, -13);
+ FALLTHROUGH_INTENDED;
+ case 12:
+ CRCtriplet(crc, next, -12);
+ FALLTHROUGH_INTENDED;
+ case 11:
+ CRCtriplet(crc, next, -11);
+ FALLTHROUGH_INTENDED;
+ case 10:
+ CRCtriplet(crc, next, -10);
+ FALLTHROUGH_INTENDED;
+ case 9:
+ CRCtriplet(crc, next, -9);
+ FALLTHROUGH_INTENDED;
+ case 8:
+ CRCtriplet(crc, next, -8);
+ FALLTHROUGH_INTENDED;
+ case 7:
+ CRCtriplet(crc, next, -7);
+ FALLTHROUGH_INTENDED;
+ case 6:
+ CRCtriplet(crc, next, -6);
+ FALLTHROUGH_INTENDED;
+ case 5:
+ CRCtriplet(crc, next, -5);
+ FALLTHROUGH_INTENDED;
+ case 4:
+ CRCtriplet(crc, next, -4);
+ FALLTHROUGH_INTENDED;
+ case 3:
+ CRCtriplet(crc, next, -3);
+ FALLTHROUGH_INTENDED;
+ case 2:
+ CRCtriplet(crc, next, -2);
+ FALLTHROUGH_INTENDED;
+ case 1:
+ CRCduplet(crc, next, -1); // the final triplet is actually only 2
+ //{ CombineCRC(); }
+ crc0 = CombineCRC(block_size, crc0, crc1, crc2, next2);
+ if (--n > 0) {
+ crc1 = crc2 = 0;
+ block_size = 128;
+ // points to the first byte of the next block
+ next0 = next2 + 128;
+ next1 = next0 + 128; // from here on all blocks are 128 long
+ next2 = next1 + 128;
+ }
+ FALLTHROUGH_INTENDED;
+ case 0:;
+ } while (n > 0);
+ }
+ next = (const unsigned char*)next2;
+ }
+ uint64_t count2 = len >> 3; // 216 of less bytes is 27 or less singlets
+ len = len & 7;
+ next += (count2 * 8);
+ switch (count2) {
+ case 27:
+ CRCsinglet(crc0, next, -27 * 8);
+ FALLTHROUGH_INTENDED;
+ case 26:
+ CRCsinglet(crc0, next, -26 * 8);
+ FALLTHROUGH_INTENDED;
+ case 25:
+ CRCsinglet(crc0, next, -25 * 8);
+ FALLTHROUGH_INTENDED;
+ case 24:
+ CRCsinglet(crc0, next, -24 * 8);
+ FALLTHROUGH_INTENDED;
+ case 23:
+ CRCsinglet(crc0, next, -23 * 8);
+ FALLTHROUGH_INTENDED;
+ case 22:
+ CRCsinglet(crc0, next, -22 * 8);
+ FALLTHROUGH_INTENDED;
+ case 21:
+ CRCsinglet(crc0, next, -21 * 8);
+ FALLTHROUGH_INTENDED;
+ case 20:
+ CRCsinglet(crc0, next, -20 * 8);
+ FALLTHROUGH_INTENDED;
+ case 19:
+ CRCsinglet(crc0, next, -19 * 8);
+ FALLTHROUGH_INTENDED;
+ case 18:
+ CRCsinglet(crc0, next, -18 * 8);
+ FALLTHROUGH_INTENDED;
+ case 17:
+ CRCsinglet(crc0, next, -17 * 8);
+ FALLTHROUGH_INTENDED;
+ case 16:
+ CRCsinglet(crc0, next, -16 * 8);
+ FALLTHROUGH_INTENDED;
+ case 15:
+ CRCsinglet(crc0, next, -15 * 8);
+ FALLTHROUGH_INTENDED;
+ case 14:
+ CRCsinglet(crc0, next, -14 * 8);
+ FALLTHROUGH_INTENDED;
+ case 13:
+ CRCsinglet(crc0, next, -13 * 8);
+ FALLTHROUGH_INTENDED;
+ case 12:
+ CRCsinglet(crc0, next, -12 * 8);
+ FALLTHROUGH_INTENDED;
+ case 11:
+ CRCsinglet(crc0, next, -11 * 8);
+ FALLTHROUGH_INTENDED;
+ case 10:
+ CRCsinglet(crc0, next, -10 * 8);
+ FALLTHROUGH_INTENDED;
+ case 9:
+ CRCsinglet(crc0, next, -9 * 8);
+ FALLTHROUGH_INTENDED;
+ case 8:
+ CRCsinglet(crc0, next, -8 * 8);
+ FALLTHROUGH_INTENDED;
+ case 7:
+ CRCsinglet(crc0, next, -7 * 8);
+ FALLTHROUGH_INTENDED;
+ case 6:
+ CRCsinglet(crc0, next, -6 * 8);
+ FALLTHROUGH_INTENDED;
+ case 5:
+ CRCsinglet(crc0, next, -5 * 8);
+ FALLTHROUGH_INTENDED;
+ case 4:
+ CRCsinglet(crc0, next, -4 * 8);
+ FALLTHROUGH_INTENDED;
+ case 3:
+ CRCsinglet(crc0, next, -3 * 8);
+ FALLTHROUGH_INTENDED;
+ case 2:
+ CRCsinglet(crc0, next, -2 * 8);
+ FALLTHROUGH_INTENDED;
+ case 1:
+ CRCsinglet(crc0, next, -1 * 8);
+ FALLTHROUGH_INTENDED;
+ case 0:;
+ }
+ }
+ {
+ align_to_8(len, crc0, next);
+ return (uint32_t)crc0 ^ 0xffffffffu;
+ }
+}
+
+#endif //HAVE_SSE42 && HAVE_PCLMUL
+
+static inline Function Choose_Extend() {
+#ifdef HAVE_POWER8
+ return isAltiVec() ? ExtendPPCImpl : ExtendImpl<Slow_CRC32>;
+#elif defined(HAVE_ARM64_CRC)
+ if(crc32c_runtime_check()) {
+ pmull_runtime_flag = crc32c_pmull_runtime_check();
+ return ExtendARMImpl;
+ } else {
+ return ExtendImpl<Slow_CRC32>;
+ }
+#else
+ if (isSSE42()) {
+ if (isPCLMULQDQ()) {
+#if (defined HAVE_SSE42 && defined HAVE_PCLMUL) && !defined NO_THREEWAY_CRC32C
+ return crc32c_3way;
+#else
+ return ExtendImpl<Fast_CRC32>; // Fast_CRC32 will check HAVE_SSE42 itself
+#endif
+ }
+ else { // no runtime PCLMULQDQ support but has SSE42 support
+ return ExtendImpl<Fast_CRC32>;
+ }
+ } // end of isSSE42()
+ else {
+ return ExtendImpl<Slow_CRC32>;
+ }
+#endif
+}
+
+static Function ChosenExtend = Choose_Extend();
+uint32_t Extend(uint32_t crc, const char* buf, size_t size) {
+ return ChosenExtend(crc, buf, size);
+}
+
+// The code for crc32c combine, copied with permission from folly
+
+// Standard galois-field multiply. The only modification is that a,
+// b, m, and p are all bit-reflected.
+//
+// https://en.wikipedia.org/wiki/Finite_field_arithmetic
+static constexpr uint32_t gf_multiply_sw_1(
+ size_t i, uint32_t p, uint32_t a, uint32_t b, uint32_t m) {
+ // clang-format off
+ return i == 32 ? p : gf_multiply_sw_1(
+ /* i = */ i + 1,
+ /* p = */ p ^ ((0u-((b >> 31) & 1)) & a),
+ /* a = */ (a >> 1) ^ ((0u-(a & 1)) & m),
+ /* b = */ b << 1,
+ /* m = */ m);
+ // clang-format on
+}
+static constexpr uint32_t gf_multiply_sw(uint32_t a, uint32_t b, uint32_t m) {
+ return gf_multiply_sw_1(/* i = */ 0, /* p = */ 0, a, b, m);
+}
+
+static constexpr uint32_t gf_square_sw(uint32_t a, uint32_t m) {
+ return gf_multiply_sw(a, a, m);
+}
+
+template <size_t i, uint32_t m>
+struct gf_powers_memo {
+ static constexpr uint32_t value =
+ gf_square_sw(gf_powers_memo<i - 1, m>::value, m);
+};
+template <uint32_t m>
+struct gf_powers_memo<0, m> {
+ static constexpr uint32_t value = m;
+};
+
+template <typename T, T... Ints>
+struct integer_sequence {
+ using value_type = T;
+ static constexpr size_t size() { return sizeof...(Ints); }
+};
+
+template <typename T, std::size_t N, T... Is>
+struct make_integer_sequence : make_integer_sequence<T, N - 1, N - 1, Is...> {};
+
+template <typename T, T... Is>
+struct make_integer_sequence<T, 0, Is...> : integer_sequence<T, Is...> {};
+
+template <std::size_t N>
+using make_index_sequence = make_integer_sequence<std::size_t, N>;
+
+template <uint32_t m>
+struct gf_powers_make {
+ template <size_t... i>
+ using index_sequence = integer_sequence<size_t, i...>;
+ template <size_t... i>
+ constexpr std::array<uint32_t, sizeof...(i)> operator()(
+ index_sequence<i...>) const {
+ return std::array<uint32_t, sizeof...(i)>{{gf_powers_memo<i, m>::value...}};
+ }
+};
+
+static constexpr uint32_t crc32c_m = 0x82f63b78;
+
+static constexpr std::array<uint32_t, 62> const crc32c_powers =
+ gf_powers_make<crc32c_m>{}(make_index_sequence<62>{});
+
+// Expects a "pure" crc (see Crc32cCombine)
+static uint32_t Crc32AppendZeroes(
+ uint32_t crc, size_t len_over_4, uint32_t polynomial,
+ std::array<uint32_t, 62> const& powers_array) {
+ auto powers = powers_array.data();
+ // Append by multiplying by consecutive powers of two of the zeroes
+ // array
+ size_t len_bits = len_over_4;
+
+ while (len_bits) {
+ // Advance directly to next bit set.
+ auto r = CountTrailingZeroBits(len_bits);
+ len_bits >>= r;
+ powers += r;
+
+ crc = gf_multiply_sw(crc, *powers, polynomial);
+
+ len_bits >>= 1;
+ powers++;
+ }
+
+ return crc;
+}
+
+static inline uint32_t InvertedToPure(uint32_t crc) { return ~crc; }
+
+static inline uint32_t PureToInverted(uint32_t crc) { return ~crc; }
+
+static inline uint32_t PureExtend(uint32_t crc, const char* buf, size_t size) {
+ return InvertedToPure(Extend(PureToInverted(crc), buf, size));
+}
+
+// Background:
+// RocksDB uses two kinds of crc32c values: masked and unmasked. Neither is
+// a "pure" CRC because a pure CRC satisfies (^ for xor)
+// crc(a ^ b) = crc(a) ^ crc(b)
+// The unmasked is closest, and this function takes unmasked crc32c values.
+// The unmasked values are impure in two ways:
+// * The initial setting at the start of CRC computation is all 1 bits
+// (like -1) instead of zero.
+// * The result has all bits invered.
+// Note that together, these result in the empty string having a crc32c of
+// zero. See
+// https://en.wikipedia.org/wiki/Computation_of_cyclic_redundancy_checks#CRC_variants
+//
+// Simplified version of strategy, using xor through pure CRCs (+ for concat):
+//
+// pure_crc(str1 + str2) = pure_crc(str1 + zeros(len(str2))) ^
+// pure_crc(zeros(len(str1)) + str2)
+//
+// because the xor of these two zero-padded strings is str1 + str2. For pure
+// CRC, leading zeros don't affect the result, so we only need
+//
+// pure_crc(str1 + str2) = pure_crc(str1 + zeros(len(str2))) ^
+// pure_crc(str2)
+//
+// Considering we aren't working with pure CRCs, what is actually in the input?
+//
+// crc1 = PureToInverted(PureExtendCrc32c(-1, zeros, crc1len) ^
+// PureCrc32c(str1, crc1len))
+// crc2 = PureToInverted(PureExtendCrc32c(-1, zeros, crc2len) ^
+// PureCrc32c(str2, crc2len))
+//
+// The result we want to compute is
+// combined = PureToInverted(PureExtendCrc32c(PureExtendCrc32c(-1, zeros,
+// crc1len) ^
+// PureCrc32c(str1, crc1len),
+// zeros, crc2len) ^
+// PureCrc32c(str2, crc2len))
+//
+// Thus, in addition to extending crc1 over the length of str2 in (virtual)
+// zeros, we need to cancel out the -1 initializer that was used in computing
+// crc2. To cancel it out, we also need to extend it over crc2len in zeros.
+// To simplify, since the end of str1 and that -1 initializer for crc2 are at
+// the same logical position, we can combine them before we extend over the
+// zeros.
+uint32_t Crc32cCombine(uint32_t crc1, uint32_t crc2, size_t crc2len) {
+ uint32_t pure_crc1_with_init = InvertedToPure(crc1);
+ uint32_t pure_crc2_with_init = InvertedToPure(crc2);
+ uint32_t pure_crc2_init = static_cast<uint32_t>(-1);
+
+ // Append up to 32 bits of zeroes in the normal way
+ char zeros[4] = {0, 0, 0, 0};
+ auto len = crc2len & 3;
+ uint32_t tmp = pure_crc1_with_init ^ pure_crc2_init;
+ if (len) {
+ tmp = PureExtend(tmp, zeros, len);
+ }
+ return PureToInverted(
+ Crc32AppendZeroes(tmp, crc2len / 4, crc32c_m, crc32c_powers) ^
+ pure_crc2_with_init);
+}
+
+} // namespace crc32c
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/crc32c.h b/src/rocksdb/util/crc32c.h
new file mode 100644
index 000000000..a08ad60af
--- /dev/null
+++ b/src/rocksdb/util/crc32c.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace crc32c {
+
+extern std::string IsFastCrc32Supported();
+
+// Return the crc32c of concat(A, data[0,n-1]) where init_crc is the
+// crc32c of some string A. Extend() is often used to maintain the
+// crc32c of a stream of data.
+extern uint32_t Extend(uint32_t init_crc, const char* data, size_t n);
+
+// Takes two unmasked crc32c values, and the length of the string from
+// which `crc2` was computed, and computes a crc32c value for the
+// concatenation of the original two input strings. Running time is
+// ~ log(crc2len).
+extern uint32_t Crc32cCombine(uint32_t crc1, uint32_t crc2, size_t crc2len);
+
+// Return the crc32c of data[0,n-1]
+inline uint32_t Value(const char* data, size_t n) { return Extend(0, data, n); }
+
+static const uint32_t kMaskDelta = 0xa282ead8ul;
+
+// Return a masked representation of crc.
+//
+// Motivation: it is problematic to compute the CRC of a string that
+// contains embedded CRCs. Therefore we recommend that CRCs stored
+// somewhere (e.g., in files) should be masked before being stored.
+inline uint32_t Mask(uint32_t crc) {
+ // Rotate right by 15 bits and add a constant.
+ return ((crc >> 15) | (crc << 17)) + kMaskDelta;
+}
+
+// Return the crc whose masked representation is masked_crc.
+inline uint32_t Unmask(uint32_t masked_crc) {
+ uint32_t rot = masked_crc - kMaskDelta;
+ return ((rot >> 17) | (rot << 15));
+}
+
+} // namespace crc32c
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/crc32c_arm64.cc b/src/rocksdb/util/crc32c_arm64.cc
new file mode 100644
index 000000000..4885f4fe1
--- /dev/null
+++ b/src/rocksdb/util/crc32c_arm64.cc
@@ -0,0 +1,215 @@
+// Copyright (c) 2018, Arm Limited and affiliates. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/crc32c_arm64.h"
+
+#if defined(HAVE_ARM64_CRC)
+
+#if defined(__linux__)
+#include <asm/hwcap.h>
+#endif
+#ifdef ROCKSDB_AUXV_GETAUXVAL_PRESENT
+#include <sys/auxv.h>
+#endif
+#ifndef HWCAP_CRC32
+#define HWCAP_CRC32 (1 << 7)
+#endif
+#ifndef HWCAP_PMULL
+#define HWCAP_PMULL (1 << 4)
+#endif
+#if defined(__APPLE__)
+#include <sys/sysctl.h>
+#endif
+#if defined(__OpenBSD__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#include <machine/cpu.h>
+#include <machine/armreg.h>
+#endif
+
+#ifdef HAVE_ARM64_CRYPTO
+/* unfolding to compute 8 * 3 = 24 bytes parallelly */
+#define CRC32C24BYTES(ITR) \
+ crc1 = crc32c_u64(crc1, *(buf64 + BLK_LENGTH + (ITR))); \
+ crc2 = crc32c_u64(crc2, *(buf64 + BLK_LENGTH * 2 + (ITR))); \
+ crc0 = crc32c_u64(crc0, *(buf64 + (ITR)));
+
+/* unfolding to compute 24 * 7 = 168 bytes parallelly */
+#define CRC32C7X24BYTES(ITR) \
+ do { \
+ CRC32C24BYTES((ITR)*7 + 0) \
+ CRC32C24BYTES((ITR)*7 + 1) \
+ CRC32C24BYTES((ITR)*7 + 2) \
+ CRC32C24BYTES((ITR)*7 + 3) \
+ CRC32C24BYTES((ITR)*7 + 4) \
+ CRC32C24BYTES((ITR)*7 + 5) \
+ CRC32C24BYTES((ITR)*7 + 6) \
+ } while (0)
+#endif
+
+extern bool pmull_runtime_flag;
+
+uint32_t crc32c_runtime_check(void) {
+#if defined(ROCKSDB_AUXV_GETAUXVAL_PRESENT) || defined(__FreeBSD__)
+ uint64_t auxv = 0;
+#if defined(ROCKSDB_AUXV_GETAUXVAL_PRESENT)
+ auxv = getauxval(AT_HWCAP);
+#elif defined(__FreeBSD__)
+ elf_aux_info(AT_HWCAP, &auxv, sizeof(auxv));
+#endif
+ return (auxv & HWCAP_CRC32) != 0;
+#elif defined(__APPLE__)
+ int r;
+ size_t l = sizeof(r);
+ if (sysctlbyname("hw.optional.armv8_crc32", &r, &l, NULL, 0) == -1) return 0;
+ return r == 1;
+#elif defined(__OpenBSD__)
+ int r = 0;
+ const int isar0_mib[] = { CTL_MACHDEP, CPU_ID_AA64ISAR0 };
+ uint64_t isar0;
+ size_t len = sizeof(isar0);
+
+ if (sysctl(isar0_mib, 2, &isar0, &len, NULL, 0) != -1) {
+ if (ID_AA64ISAR0_CRC32(isar0) >= ID_AA64ISAR0_CRC32_BASE)
+ r = 1;
+ }
+ return r;
+#else
+ return 0;
+#endif
+}
+
+bool crc32c_pmull_runtime_check(void) {
+#if defined(ROCKSDB_AUXV_GETAUXVAL_PRESENT) || defined(__FreeBSD__)
+ uint64_t auxv = 0;
+#if defined(ROCKSDB_AUXV_GETAUXVAL_PRESENT)
+ auxv = getauxval(AT_HWCAP);
+#elif defined(__FreeBSD__)
+ elf_aux_info(AT_HWCAP, &auxv, sizeof(auxv));
+#endif
+ return (auxv & HWCAP_PMULL) != 0;
+#elif defined(__APPLE__)
+ return true;
+#elif defined(__OpenBSD__)
+ bool r = false;
+ const int isar0_mib[] = { CTL_MACHDEP, CPU_ID_AA64ISAR0 };
+ uint64_t isar0;
+ size_t len = sizeof(isar0);
+
+ if (sysctl(isar0_mib, 2, &isar0, &len, NULL, 0) != -1) {
+ if (ID_AA64ISAR0_AES(isar0) >= ID_AA64ISAR0_AES_PMULL)
+ r = true;
+ }
+ return r;
+#else
+ return false;
+#endif
+}
+
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+uint32_t
+crc32c_arm64(uint32_t crc, unsigned char const *data, size_t len) {
+ const uint8_t *buf8;
+ const uint64_t *buf64 = (uint64_t *)data;
+ int length = (int)len;
+ crc ^= 0xffffffff;
+
+ /*
+ * Pmull runtime check here.
+ * Raspberry Pi supports crc32 but doesn't support pmull.
+ * Skip Crc32c Parallel computation if no crypto extension available.
+ */
+ if (pmull_runtime_flag) {
+/* Macro (HAVE_ARM64_CRYPTO) is used for compiling check */
+#ifdef HAVE_ARM64_CRYPTO
+/* Crc32c Parallel computation
+ * Algorithm comes from Intel whitepaper:
+ * crc-iscsi-polynomial-crc32-instruction-paper
+ *
+ * Input data is divided into three equal-sized blocks
+ * Three parallel blocks (crc0, crc1, crc2) for 1024 Bytes
+ * One Block: 42(BLK_LENGTH) * 8(step length: crc32c_u64) bytes
+ */
+#define BLK_LENGTH 42
+ while (length >= 1024) {
+ uint64_t t0, t1;
+ uint32_t crc0 = 0, crc1 = 0, crc2 = 0;
+
+ /* Parallel Param:
+ * k0 = CRC32(x ^ (42 * 8 * 8 * 2 - 1));
+ * k1 = CRC32(x ^ (42 * 8 * 8 - 1));
+ */
+ uint32_t k0 = 0xe417f38a, k1 = 0x8f158014;
+
+ /* Prefetch data for following block to avoid cache miss */
+ PREF1KL1((uint8_t *)buf64, 1024);
+
+ /* First 8 byte for better pipelining */
+ crc0 = crc32c_u64(crc, *buf64++);
+
+ /* 3 blocks crc32c parallel computation
+ * Macro unfolding to compute parallelly
+ * 168 * 6 = 1008 (bytes)
+ */
+ CRC32C7X24BYTES(0);
+ CRC32C7X24BYTES(1);
+ CRC32C7X24BYTES(2);
+ CRC32C7X24BYTES(3);
+ CRC32C7X24BYTES(4);
+ CRC32C7X24BYTES(5);
+ buf64 += (BLK_LENGTH * 3);
+
+ /* Last 8 bytes */
+ crc = crc32c_u64(crc2, *buf64++);
+
+ t0 = (uint64_t)vmull_p64(crc0, k0);
+ t1 = (uint64_t)vmull_p64(crc1, k1);
+
+ /* Merge (crc0, crc1, crc2) -> crc */
+ crc1 = crc32c_u64(0, t1);
+ crc ^= crc1;
+ crc0 = crc32c_u64(0, t0);
+ crc ^= crc0;
+
+ length -= 1024;
+ }
+
+ if (length == 0) return crc ^ (0xffffffffU);
+#endif
+ } // if Pmull runtime check here
+
+ buf8 = (const uint8_t *)buf64;
+ while (length >= 8) {
+ crc = crc32c_u64(crc, *(const uint64_t *)buf8);
+ buf8 += 8;
+ length -= 8;
+ }
+
+ /* The following is more efficient than the straight loop */
+ if (length >= 4) {
+ crc = crc32c_u32(crc, *(const uint32_t *)buf8);
+ buf8 += 4;
+ length -= 4;
+ }
+
+ if (length >= 2) {
+ crc = crc32c_u16(crc, *(const uint16_t *)buf8);
+ buf8 += 2;
+ length -= 2;
+ }
+
+ if (length >= 1) crc = crc32c_u8(crc, *buf8);
+
+ crc ^= 0xffffffff;
+ return crc;
+}
+
+#endif
diff --git a/src/rocksdb/util/crc32c_arm64.h b/src/rocksdb/util/crc32c_arm64.h
new file mode 100644
index 000000000..4b27fe871
--- /dev/null
+++ b/src/rocksdb/util/crc32c_arm64.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2018, Arm Limited and affiliates. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef UTIL_CRC32C_ARM64_H
+#define UTIL_CRC32C_ARM64_H
+
+#include <cinttypes>
+#include <cstddef>
+
+#if defined(__aarch64__) || defined(__AARCH64__)
+
+#ifdef __ARM_FEATURE_CRC32
+#define HAVE_ARM64_CRC
+#include <arm_acle.h>
+#define crc32c_u8(crc, v) __crc32cb(crc, v)
+#define crc32c_u16(crc, v) __crc32ch(crc, v)
+#define crc32c_u32(crc, v) __crc32cw(crc, v)
+#define crc32c_u64(crc, v) __crc32cd(crc, v)
+// clang-format off
+#define PREF4X64L1(buffer, PREF_OFFSET, ITR) \
+ __asm__("PRFM PLDL1KEEP, [%x[v],%[c]]" ::[v] "r"(buffer), \
+ [c] "I"((PREF_OFFSET) + ((ITR) + 0) * 64)); \
+ __asm__("PRFM PLDL1KEEP, [%x[v],%[c]]" ::[v] "r"(buffer), \
+ [c] "I"((PREF_OFFSET) + ((ITR) + 1) * 64)); \
+ __asm__("PRFM PLDL1KEEP, [%x[v],%[c]]" ::[v] "r"(buffer), \
+ [c] "I"((PREF_OFFSET) + ((ITR) + 2) * 64)); \
+ __asm__("PRFM PLDL1KEEP, [%x[v],%[c]]" ::[v] "r"(buffer), \
+ [c] "I"((PREF_OFFSET) + ((ITR) + 3) * 64));
+// clang-format on
+
+#define PREF1KL1(buffer, PREF_OFFSET) \
+ PREF4X64L1(buffer, (PREF_OFFSET), 0) \
+ PREF4X64L1(buffer, (PREF_OFFSET), 4) \
+ PREF4X64L1(buffer, (PREF_OFFSET), 8) \
+ PREF4X64L1(buffer, (PREF_OFFSET), 12)
+
+extern uint32_t crc32c_arm64(uint32_t crc, unsigned char const *data,
+ size_t len);
+extern uint32_t crc32c_runtime_check(void);
+extern bool crc32c_pmull_runtime_check(void);
+
+#ifdef __ARM_FEATURE_CRYPTO
+#define HAVE_ARM64_CRYPTO
+#include <arm_neon.h>
+#endif // __ARM_FEATURE_CRYPTO
+#endif // __ARM_FEATURE_CRC32
+
+#endif // defined(__aarch64__) || defined(__AARCH64__)
+
+#endif
diff --git a/src/rocksdb/util/crc32c_ppc.c b/src/rocksdb/util/crc32c_ppc.c
new file mode 100644
index 000000000..b37dfb158
--- /dev/null
+++ b/src/rocksdb/util/crc32c_ppc.c
@@ -0,0 +1,94 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2017 International Business Machines Corp.
+// All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#define CRC_TABLE
+#include <stdint.h>
+#include <stdlib.h>
+#include <strings.h>
+#include "util/crc32c_ppc_constants.h"
+
+#define VMX_ALIGN 16
+#define VMX_ALIGN_MASK (VMX_ALIGN - 1)
+
+#ifdef REFLECT
+static unsigned int crc32_align(unsigned int crc, unsigned char const *p,
+ unsigned long len) {
+ while (len--) crc = crc_table[(crc ^ *p++) & 0xff] ^ (crc >> 8);
+ return crc;
+}
+#endif
+
+#ifdef HAVE_POWER8
+unsigned int __crc32_vpmsum(unsigned int crc, unsigned char const *p,
+ unsigned long len);
+
+static uint32_t crc32_vpmsum(uint32_t crc, unsigned char const *data,
+ size_t len) {
+ unsigned int prealign;
+ unsigned int tail;
+
+#ifdef CRC_XOR
+ crc ^= 0xffffffff;
+#endif
+
+ if (len < VMX_ALIGN + VMX_ALIGN_MASK) {
+ crc = crc32_align(crc, data, (unsigned long)len);
+ goto out;
+ }
+
+ if ((unsigned long)data & VMX_ALIGN_MASK) {
+ prealign = VMX_ALIGN - ((unsigned long)data & VMX_ALIGN_MASK);
+ crc = crc32_align(crc, data, prealign);
+ len -= prealign;
+ data += prealign;
+ }
+
+ crc = __crc32_vpmsum(crc, data, (unsigned long)len & ~VMX_ALIGN_MASK);
+
+ tail = len & VMX_ALIGN_MASK;
+ if (tail) {
+ data += len & ~VMX_ALIGN_MASK;
+ crc = crc32_align(crc, data, tail);
+ }
+
+out:
+#ifdef CRC_XOR
+ crc ^= 0xffffffff;
+#endif
+
+ return crc;
+}
+
+/* This wrapper function works around the fact that crc32_vpmsum
+ * does not gracefully handle the case where the data pointer is NULL. There
+ * may be room for performance improvement here.
+ */
+uint32_t crc32c_ppc(uint32_t crc, unsigned char const *data, size_t len) {
+ unsigned char *buf2;
+
+ if (!data) {
+ buf2 = (unsigned char *)malloc(len);
+ bzero(buf2, len);
+ crc = crc32_vpmsum(crc, buf2, len);
+ free(buf2);
+ } else {
+ crc = crc32_vpmsum(crc, data, (unsigned long)len);
+ }
+ return crc;
+}
+
+#else /* HAVE_POWER8 */
+
+/* This symbol has to exist on non-ppc architectures (and on legacy
+ * ppc systems using power7 or below) in order to compile properly
+ * there, even though it won't be called.
+ */
+uint32_t crc32c_ppc(uint32_t crc, unsigned char const *data, size_t len) {
+ return 0;
+}
+
+#endif /* HAVE_POWER8 */
diff --git a/src/rocksdb/util/crc32c_ppc.h b/src/rocksdb/util/crc32c_ppc.h
new file mode 100644
index 000000000..f0b0b66d5
--- /dev/null
+++ b/src/rocksdb/util/crc32c_ppc.h
@@ -0,0 +1,22 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2017 International Business Machines Corp.
+// All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern uint32_t crc32c_ppc(uint32_t crc, unsigned char const *buffer,
+ size_t len);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/rocksdb/util/crc32c_ppc_asm.S b/src/rocksdb/util/crc32c_ppc_asm.S
new file mode 100644
index 000000000..6959ba839
--- /dev/null
+++ b/src/rocksdb/util/crc32c_ppc_asm.S
@@ -0,0 +1,756 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2015 Anton Blanchard <anton@au.ibm.com>, IBM
+// Copyright (c) 2017 International Business Machines Corp.
+// All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#if defined (__clang__)
+#include "third-party/gcc/ppc-asm.h"
+#else
+#include <ppc-asm.h>
+#endif
+#include "ppc-opcode.h"
+
+#undef toc
+
+#ifndef r1
+#define r1 1
+#endif
+
+#ifndef r2
+#define r2 2
+#endif
+
+ .section .rodata
+.balign 16
+
+.byteswap_constant:
+ /* byte reverse permute constant */
+ .octa 0x0F0E0D0C0B0A09080706050403020100
+
+#define __ASSEMBLY__
+#include "crc32c_ppc_constants.h"
+
+ .text
+
+#if defined(__BIG_ENDIAN__) && defined(REFLECT)
+#define BYTESWAP_DATA
+#elif defined(__LITTLE_ENDIAN__) && !defined(REFLECT)
+#define BYTESWAP_DATA
+#else
+#undef BYTESWAP_DATA
+#endif
+
+#define off16 r25
+#define off32 r26
+#define off48 r27
+#define off64 r28
+#define off80 r29
+#define off96 r30
+#define off112 r31
+
+#define const1 v24
+#define const2 v25
+
+#define byteswap v26
+#define mask_32bit v27
+#define mask_64bit v28
+#define zeroes v29
+
+#ifdef BYTESWAP_DATA
+#define VPERM(A, B, C, D) vperm A, B, C, D
+#else
+#define VPERM(A, B, C, D)
+#endif
+
+/* unsigned int __crc32_vpmsum(unsigned int crc, void *p, unsigned long len) */
+FUNC_START(__crc32_vpmsum)
+ std r31,-8(r1)
+ std r30,-16(r1)
+ std r29,-24(r1)
+ std r28,-32(r1)
+ std r27,-40(r1)
+ std r26,-48(r1)
+ std r25,-56(r1)
+
+ li off16,16
+ li off32,32
+ li off48,48
+ li off64,64
+ li off80,80
+ li off96,96
+ li off112,112
+ li r0,0
+
+ /* Enough room for saving 10 non volatile VMX registers */
+ subi r6,r1,56+10*16
+ subi r7,r1,56+2*16
+
+ stvx v20,0,r6
+ stvx v21,off16,r6
+ stvx v22,off32,r6
+ stvx v23,off48,r6
+ stvx v24,off64,r6
+ stvx v25,off80,r6
+ stvx v26,off96,r6
+ stvx v27,off112,r6
+ stvx v28,0,r7
+ stvx v29,off16,r7
+
+ mr r10,r3
+
+ vxor zeroes,zeroes,zeroes
+ vspltisw v0,-1
+
+ vsldoi mask_32bit,zeroes,v0,4
+ vsldoi mask_64bit,zeroes,v0,8
+
+ /* Get the initial value into v8 */
+ vxor v8,v8,v8
+ MTVRD(v8, r3)
+#ifdef REFLECT
+ vsldoi v8,zeroes,v8,8 /* shift into bottom 32 bits */
+#else
+ vsldoi v8,v8,zeroes,4 /* shift into top 32 bits */
+#endif
+
+#ifdef BYTESWAP_DATA
+ addis r3,r2,.byteswap_constant@toc@ha
+ addi r3,r3,.byteswap_constant@toc@l
+
+ lvx byteswap,0,r3
+ addi r3,r3,16
+#endif
+
+ cmpdi r5,256
+ blt .Lshort
+
+ rldicr r6,r5,0,56
+
+ /* Checksum in blocks of MAX_SIZE */
+1: lis r7,MAX_SIZE@h
+ ori r7,r7,MAX_SIZE@l
+ mr r9,r7
+ cmpd r6,r7
+ bgt 2f
+ mr r7,r6
+2: subf r6,r7,r6
+
+ /* our main loop does 128 bytes at a time */
+ srdi r7,r7,7
+
+ /*
+ * Work out the offset into the constants table to start at. Each
+ * constant is 16 bytes, and it is used against 128 bytes of input
+ * data - 128 / 16 = 8
+ */
+ sldi r8,r7,4
+ srdi r9,r9,3
+ subf r8,r8,r9
+
+ /* We reduce our final 128 bytes in a separate step */
+ addi r7,r7,-1
+ mtctr r7
+
+ addis r3,r2,.constants@toc@ha
+ addi r3,r3,.constants@toc@l
+
+ /* Find the start of our constants */
+ add r3,r3,r8
+
+ /* zero v0-v7 which will contain our checksums */
+ vxor v0,v0,v0
+ vxor v1,v1,v1
+ vxor v2,v2,v2
+ vxor v3,v3,v3
+ vxor v4,v4,v4
+ vxor v5,v5,v5
+ vxor v6,v6,v6
+ vxor v7,v7,v7
+
+ lvx const1,0,r3
+
+ /*
+ * If we are looping back to consume more data we use the values
+ * already in v16-v23.
+ */
+ cmpdi r0,1
+ beq 2f
+
+ /* First warm up pass */
+ lvx v16,0,r4
+ lvx v17,off16,r4
+ VPERM(v16,v16,v16,byteswap)
+ VPERM(v17,v17,v17,byteswap)
+ lvx v18,off32,r4
+ lvx v19,off48,r4
+ VPERM(v18,v18,v18,byteswap)
+ VPERM(v19,v19,v19,byteswap)
+ lvx v20,off64,r4
+ lvx v21,off80,r4
+ VPERM(v20,v20,v20,byteswap)
+ VPERM(v21,v21,v21,byteswap)
+ lvx v22,off96,r4
+ lvx v23,off112,r4
+ VPERM(v22,v22,v22,byteswap)
+ VPERM(v23,v23,v23,byteswap)
+ addi r4,r4,8*16
+
+ /* xor in initial value */
+ vxor v16,v16,v8
+
+2: bdz .Lfirst_warm_up_done
+
+ addi r3,r3,16
+ lvx const2,0,r3
+
+ /* Second warm up pass */
+ VPMSUMD(v8,v16,const1)
+ lvx v16,0,r4
+ VPERM(v16,v16,v16,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v9,v17,const1)
+ lvx v17,off16,r4
+ VPERM(v17,v17,v17,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v10,v18,const1)
+ lvx v18,off32,r4
+ VPERM(v18,v18,v18,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v11,v19,const1)
+ lvx v19,off48,r4
+ VPERM(v19,v19,v19,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v12,v20,const1)
+ lvx v20,off64,r4
+ VPERM(v20,v20,v20,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v13,v21,const1)
+ lvx v21,off80,r4
+ VPERM(v21,v21,v21,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v14,v22,const1)
+ lvx v22,off96,r4
+ VPERM(v22,v22,v22,byteswap)
+ ori r2,r2,0
+
+ VPMSUMD(v15,v23,const1)
+ lvx v23,off112,r4
+ VPERM(v23,v23,v23,byteswap)
+
+ addi r4,r4,8*16
+
+ bdz .Lfirst_cool_down
+
+ /*
+ * main loop. We modulo schedule it such that it takes three iterations
+ * to complete - first iteration load, second iteration vpmsum, third
+ * iteration xor.
+ */
+ .balign 16
+4: lvx const1,0,r3
+ addi r3,r3,16
+ ori r2,r2,0
+
+ vxor v0,v0,v8
+ VPMSUMD(v8,v16,const2)
+ lvx v16,0,r4
+ VPERM(v16,v16,v16,byteswap)
+ ori r2,r2,0
+
+ vxor v1,v1,v9
+ VPMSUMD(v9,v17,const2)
+ lvx v17,off16,r4
+ VPERM(v17,v17,v17,byteswap)
+ ori r2,r2,0
+
+ vxor v2,v2,v10
+ VPMSUMD(v10,v18,const2)
+ lvx v18,off32,r4
+ VPERM(v18,v18,v18,byteswap)
+ ori r2,r2,0
+
+ vxor v3,v3,v11
+ VPMSUMD(v11,v19,const2)
+ lvx v19,off48,r4
+ VPERM(v19,v19,v19,byteswap)
+ lvx const2,0,r3
+ ori r2,r2,0
+
+ vxor v4,v4,v12
+ VPMSUMD(v12,v20,const1)
+ lvx v20,off64,r4
+ VPERM(v20,v20,v20,byteswap)
+ ori r2,r2,0
+
+ vxor v5,v5,v13
+ VPMSUMD(v13,v21,const1)
+ lvx v21,off80,r4
+ VPERM(v21,v21,v21,byteswap)
+ ori r2,r2,0
+
+ vxor v6,v6,v14
+ VPMSUMD(v14,v22,const1)
+ lvx v22,off96,r4
+ VPERM(v22,v22,v22,byteswap)
+ ori r2,r2,0
+
+ vxor v7,v7,v15
+ VPMSUMD(v15,v23,const1)
+ lvx v23,off112,r4
+ VPERM(v23,v23,v23,byteswap)
+
+ addi r4,r4,8*16
+
+ bdnz 4b
+
+.Lfirst_cool_down:
+ /* First cool down pass */
+ lvx const1,0,r3
+ addi r3,r3,16
+
+ vxor v0,v0,v8
+ VPMSUMD(v8,v16,const1)
+ ori r2,r2,0
+
+ vxor v1,v1,v9
+ VPMSUMD(v9,v17,const1)
+ ori r2,r2,0
+
+ vxor v2,v2,v10
+ VPMSUMD(v10,v18,const1)
+ ori r2,r2,0
+
+ vxor v3,v3,v11
+ VPMSUMD(v11,v19,const1)
+ ori r2,r2,0
+
+ vxor v4,v4,v12
+ VPMSUMD(v12,v20,const1)
+ ori r2,r2,0
+
+ vxor v5,v5,v13
+ VPMSUMD(v13,v21,const1)
+ ori r2,r2,0
+
+ vxor v6,v6,v14
+ VPMSUMD(v14,v22,const1)
+ ori r2,r2,0
+
+ vxor v7,v7,v15
+ VPMSUMD(v15,v23,const1)
+ ori r2,r2,0
+
+.Lsecond_cool_down:
+ /* Second cool down pass */
+ vxor v0,v0,v8
+ vxor v1,v1,v9
+ vxor v2,v2,v10
+ vxor v3,v3,v11
+ vxor v4,v4,v12
+ vxor v5,v5,v13
+ vxor v6,v6,v14
+ vxor v7,v7,v15
+
+#ifdef REFLECT
+ /*
+ * vpmsumd produces a 96 bit result in the least significant bits
+ * of the register. Since we are bit reflected we have to shift it
+ * left 32 bits so it occupies the least significant bits in the
+ * bit reflected domain.
+ */
+ vsldoi v0,v0,zeroes,4
+ vsldoi v1,v1,zeroes,4
+ vsldoi v2,v2,zeroes,4
+ vsldoi v3,v3,zeroes,4
+ vsldoi v4,v4,zeroes,4
+ vsldoi v5,v5,zeroes,4
+ vsldoi v6,v6,zeroes,4
+ vsldoi v7,v7,zeroes,4
+#endif
+
+ /* xor with last 1024 bits */
+ lvx v8,0,r4
+ lvx v9,off16,r4
+ VPERM(v8,v8,v8,byteswap)
+ VPERM(v9,v9,v9,byteswap)
+ lvx v10,off32,r4
+ lvx v11,off48,r4
+ VPERM(v10,v10,v10,byteswap)
+ VPERM(v11,v11,v11,byteswap)
+ lvx v12,off64,r4
+ lvx v13,off80,r4
+ VPERM(v12,v12,v12,byteswap)
+ VPERM(v13,v13,v13,byteswap)
+ lvx v14,off96,r4
+ lvx v15,off112,r4
+ VPERM(v14,v14,v14,byteswap)
+ VPERM(v15,v15,v15,byteswap)
+
+ addi r4,r4,8*16
+
+ vxor v16,v0,v8
+ vxor v17,v1,v9
+ vxor v18,v2,v10
+ vxor v19,v3,v11
+ vxor v20,v4,v12
+ vxor v21,v5,v13
+ vxor v22,v6,v14
+ vxor v23,v7,v15
+
+ li r0,1
+ cmpdi r6,0
+ addi r6,r6,128
+ bne 1b
+
+ /* Work out how many bytes we have left */
+ andi. r5,r5,127
+
+ /* Calculate where in the constant table we need to start */
+ subfic r6,r5,128
+ add r3,r3,r6
+
+ /* How many 16 byte chunks are in the tail */
+ srdi r7,r5,4
+ mtctr r7
+
+ /*
+ * Reduce the previously calculated 1024 bits to 64 bits, shifting
+ * 32 bits to include the trailing 32 bits of zeros
+ */
+ lvx v0,0,r3
+ lvx v1,off16,r3
+ lvx v2,off32,r3
+ lvx v3,off48,r3
+ lvx v4,off64,r3
+ lvx v5,off80,r3
+ lvx v6,off96,r3
+ lvx v7,off112,r3
+ addi r3,r3,8*16
+
+ VPMSUMW(v0,v16,v0)
+ VPMSUMW(v1,v17,v1)
+ VPMSUMW(v2,v18,v2)
+ VPMSUMW(v3,v19,v3)
+ VPMSUMW(v4,v20,v4)
+ VPMSUMW(v5,v21,v5)
+ VPMSUMW(v6,v22,v6)
+ VPMSUMW(v7,v23,v7)
+
+ /* Now reduce the tail (0 - 112 bytes) */
+ cmpdi r7,0
+ beq 1f
+
+ lvx v16,0,r4
+ lvx v17,0,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off16,r4
+ lvx v17,off16,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off32,r4
+ lvx v17,off32,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off48,r4
+ lvx v17,off48,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off64,r4
+ lvx v17,off64,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off80,r4
+ lvx v17,off80,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+ bdz 1f
+
+ lvx v16,off96,r4
+ lvx v17,off96,r3
+ VPERM(v16,v16,v16,byteswap)
+ VPMSUMW(v16,v16,v17)
+ vxor v0,v0,v16
+
+ /* Now xor all the parallel chunks together */
+1: vxor v0,v0,v1
+ vxor v2,v2,v3
+ vxor v4,v4,v5
+ vxor v6,v6,v7
+
+ vxor v0,v0,v2
+ vxor v4,v4,v6
+
+ vxor v0,v0,v4
+
+.Lbarrett_reduction:
+ /* Barrett constants */
+ addis r3,r2,.barrett_constants@toc@ha
+ addi r3,r3,.barrett_constants@toc@l
+
+ lvx const1,0,r3
+ lvx const2,off16,r3
+
+ vsldoi v1,v0,v0,8
+ vxor v0,v0,v1 /* xor two 64 bit results together */
+
+#ifdef REFLECT
+ /* shift left one bit */
+ vspltisb v1,1
+ vsl v0,v0,v1
+#endif
+
+ vand v0,v0,mask_64bit
+
+#ifndef REFLECT
+ /*
+ * Now for the Barrett reduction algorithm. The idea is to calculate q,
+ * the multiple of our polynomial that we need to subtract. By
+ * doing the computation 2x bits higher (ie 64 bits) and shifting the
+ * result back down 2x bits, we round down to the nearest multiple.
+ */
+ VPMSUMD(v1,v0,const1) /* ma */
+ vsldoi v1,zeroes,v1,8 /* q = floor(ma/(2^64)) */
+ VPMSUMD(v1,v1,const2) /* qn */
+ vxor v0,v0,v1 /* a - qn, subtraction is xor in GF(2) */
+
+ /*
+ * Get the result into r3. We need to shift it left 8 bytes:
+ * V0 [ 0 1 2 X ]
+ * V0 [ 0 X 2 3 ]
+ */
+ vsldoi v0,v0,zeroes,8 /* shift result into top 64 bits */
+#else
+ /*
+ * The reflected version of Barrett reduction. Instead of bit
+ * reflecting our data (which is expensive to do), we bit reflect our
+ * constants and our algorithm, which means the intermediate data in
+ * our vector registers goes from 0-63 instead of 63-0. We can reflect
+ * the algorithm because we don't carry in mod 2 arithmetic.
+ */
+ vand v1,v0,mask_32bit /* bottom 32 bits of a */
+ VPMSUMD(v1,v1,const1) /* ma */
+ vand v1,v1,mask_32bit /* bottom 32bits of ma */
+ VPMSUMD(v1,v1,const2) /* qn */
+ vxor v0,v0,v1 /* a - qn, subtraction is xor in GF(2) */
+
+ /*
+ * Since we are bit reflected, the result (ie the low 32 bits) is in
+ * the high 32 bits. We just need to shift it left 4 bytes
+ * V0 [ 0 1 X 3 ]
+ * V0 [ 0 X 2 3 ]
+ */
+ vsldoi v0,v0,zeroes,4 /* shift result into top 64 bits of */
+#endif
+
+ /* Get it into r3 */
+ MFVRD(r3, v0)
+
+.Lout:
+ subi r6,r1,56+10*16
+ subi r7,r1,56+2*16
+
+ lvx v20,0,r6
+ lvx v21,off16,r6
+ lvx v22,off32,r6
+ lvx v23,off48,r6
+ lvx v24,off64,r6
+ lvx v25,off80,r6
+ lvx v26,off96,r6
+ lvx v27,off112,r6
+ lvx v28,0,r7
+ lvx v29,off16,r7
+
+ ld r31,-8(r1)
+ ld r30,-16(r1)
+ ld r29,-24(r1)
+ ld r28,-32(r1)
+ ld r27,-40(r1)
+ ld r26,-48(r1)
+ ld r25,-56(r1)
+
+ blr
+
+.Lfirst_warm_up_done:
+ lvx const1,0,r3
+ addi r3,r3,16
+
+ VPMSUMD(v8,v16,const1)
+ VPMSUMD(v9,v17,const1)
+ VPMSUMD(v10,v18,const1)
+ VPMSUMD(v11,v19,const1)
+ VPMSUMD(v12,v20,const1)
+ VPMSUMD(v13,v21,const1)
+ VPMSUMD(v14,v22,const1)
+ VPMSUMD(v15,v23,const1)
+
+ b .Lsecond_cool_down
+
+.Lshort:
+ cmpdi r5,0
+ beq .Lzero
+
+ addis r3,r2,.short_constants@toc@ha
+ addi r3,r3,.short_constants@toc@l
+
+ /* Calculate where in the constant table we need to start */
+ subfic r6,r5,256
+ add r3,r3,r6
+
+ /* How many 16 byte chunks? */
+ srdi r7,r5,4
+ mtctr r7
+
+ vxor v19,v19,v19
+ vxor v20,v20,v20
+
+ lvx v0,0,r4
+ lvx v16,0,r3
+ VPERM(v0,v0,v16,byteswap)
+ vxor v0,v0,v8 /* xor in initial value */
+ VPMSUMW(v0,v0,v16)
+ bdz .Lv0
+
+ lvx v1,off16,r4
+ lvx v17,off16,r3
+ VPERM(v1,v1,v17,byteswap)
+ VPMSUMW(v1,v1,v17)
+ bdz .Lv1
+
+ lvx v2,off32,r4
+ lvx v16,off32,r3
+ VPERM(v2,v2,v16,byteswap)
+ VPMSUMW(v2,v2,v16)
+ bdz .Lv2
+
+ lvx v3,off48,r4
+ lvx v17,off48,r3
+ VPERM(v3,v3,v17,byteswap)
+ VPMSUMW(v3,v3,v17)
+ bdz .Lv3
+
+ lvx v4,off64,r4
+ lvx v16,off64,r3
+ VPERM(v4,v4,v16,byteswap)
+ VPMSUMW(v4,v4,v16)
+ bdz .Lv4
+
+ lvx v5,off80,r4
+ lvx v17,off80,r3
+ VPERM(v5,v5,v17,byteswap)
+ VPMSUMW(v5,v5,v17)
+ bdz .Lv5
+
+ lvx v6,off96,r4
+ lvx v16,off96,r3
+ VPERM(v6,v6,v16,byteswap)
+ VPMSUMW(v6,v6,v16)
+ bdz .Lv6
+
+ lvx v7,off112,r4
+ lvx v17,off112,r3
+ VPERM(v7,v7,v17,byteswap)
+ VPMSUMW(v7,v7,v17)
+ bdz .Lv7
+
+ addi r3,r3,128
+ addi r4,r4,128
+
+ lvx v8,0,r4
+ lvx v16,0,r3
+ VPERM(v8,v8,v16,byteswap)
+ VPMSUMW(v8,v8,v16)
+ bdz .Lv8
+
+ lvx v9,off16,r4
+ lvx v17,off16,r3
+ VPERM(v9,v9,v17,byteswap)
+ VPMSUMW(v9,v9,v17)
+ bdz .Lv9
+
+ lvx v10,off32,r4
+ lvx v16,off32,r3
+ VPERM(v10,v10,v16,byteswap)
+ VPMSUMW(v10,v10,v16)
+ bdz .Lv10
+
+ lvx v11,off48,r4
+ lvx v17,off48,r3
+ VPERM(v11,v11,v17,byteswap)
+ VPMSUMW(v11,v11,v17)
+ bdz .Lv11
+
+ lvx v12,off64,r4
+ lvx v16,off64,r3
+ VPERM(v12,v12,v16,byteswap)
+ VPMSUMW(v12,v12,v16)
+ bdz .Lv12
+
+ lvx v13,off80,r4
+ lvx v17,off80,r3
+ VPERM(v13,v13,v17,byteswap)
+ VPMSUMW(v13,v13,v17)
+ bdz .Lv13
+
+ lvx v14,off96,r4
+ lvx v16,off96,r3
+ VPERM(v14,v14,v16,byteswap)
+ VPMSUMW(v14,v14,v16)
+ bdz .Lv14
+
+ lvx v15,off112,r4
+ lvx v17,off112,r3
+ VPERM(v15,v15,v17,byteswap)
+ VPMSUMW(v15,v15,v17)
+
+.Lv15: vxor v19,v19,v15
+.Lv14: vxor v20,v20,v14
+.Lv13: vxor v19,v19,v13
+.Lv12: vxor v20,v20,v12
+.Lv11: vxor v19,v19,v11
+.Lv10: vxor v20,v20,v10
+.Lv9: vxor v19,v19,v9
+.Lv8: vxor v20,v20,v8
+.Lv7: vxor v19,v19,v7
+.Lv6: vxor v20,v20,v6
+.Lv5: vxor v19,v19,v5
+.Lv4: vxor v20,v20,v4
+.Lv3: vxor v19,v19,v3
+.Lv2: vxor v20,v20,v2
+.Lv1: vxor v19,v19,v1
+.Lv0: vxor v20,v20,v0
+
+ vxor v0,v19,v20
+
+ b .Lbarrett_reduction
+
+.Lzero:
+ mr r3,r10
+ b .Lout
+
+FUNC_END(__crc32_vpmsum)
diff --git a/src/rocksdb/util/crc32c_ppc_constants.h b/src/rocksdb/util/crc32c_ppc_constants.h
new file mode 100644
index 000000000..f6494cd01
--- /dev/null
+++ b/src/rocksdb/util/crc32c_ppc_constants.h
@@ -0,0 +1,900 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (C) 2015, 2017 International Business Machines Corp.
+// All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#define CRC 0x1edc6f41
+#define REFLECT
+#define CRC_XOR
+
+#ifndef __ASSEMBLY__
+#ifdef CRC_TABLE
+static const unsigned int crc_table[] = {
+ 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c,
+ 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+ 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c,
+ 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+ 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc,
+ 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+ 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512,
+ 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+ 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad,
+ 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+ 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf,
+ 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+ 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f,
+ 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+ 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f,
+ 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+ 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e,
+ 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+ 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e,
+ 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+ 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de,
+ 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+ 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4,
+ 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+ 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b,
+ 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+ 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5,
+ 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+ 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975,
+ 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+ 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905,
+ 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+ 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8,
+ 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+ 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8,
+ 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+ 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78,
+ 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+ 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6,
+ 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+ 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69,
+ 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+ 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
+};
+
+#endif
+
+#else
+#define MAX_SIZE 32768
+.constants :
+
+ /* Reduce 262144 kbits to 1024 bits */
+ /* x^261120 mod p(x)` << 1, x^261184 mod p(x)` << 1 */
+ .octa 0x00000000b6ca9e20000000009c37c408
+
+ /* x^260096 mod p(x)` << 1, x^260160 mod p(x)` << 1 */
+ .octa 0x00000000350249a800000001b51df26c
+
+ /* x^259072 mod p(x)` << 1, x^259136 mod p(x)` << 1 */
+ .octa 0x00000001862dac54000000000724b9d0
+
+ /* x^258048 mod p(x)` << 1, x^258112 mod p(x)` << 1 */
+ .octa 0x00000001d87fb48c00000001c00532fe
+
+ /* x^257024 mod p(x)` << 1, x^257088 mod p(x)` << 1 */
+ .octa 0x00000001f39b699e00000000f05a9362
+
+ /* x^256000 mod p(x)` << 1, x^256064 mod p(x)` << 1 */
+ .octa 0x0000000101da11b400000001e1007970
+
+ /* x^254976 mod p(x)` << 1, x^255040 mod p(x)` << 1 */
+ .octa 0x00000001cab571e000000000a57366ee
+
+ /* x^253952 mod p(x)` << 1, x^254016 mod p(x)` << 1 */
+ .octa 0x00000000c7020cfe0000000192011284
+
+ /* x^252928 mod p(x)` << 1, x^252992 mod p(x)` << 1 */
+ .octa 0x00000000cdaed1ae0000000162716d9a
+
+ /* x^251904 mod p(x)` << 1, x^251968 mod p(x)` << 1 */
+ .octa 0x00000001e804effc00000000cd97ecde
+
+ /* x^250880 mod p(x)` << 1, x^250944 mod p(x)` << 1 */
+ .octa 0x0000000077c3ea3a0000000058812bc0
+
+ /* x^249856 mod p(x)` << 1, x^249920 mod p(x)` << 1 */
+ .octa 0x0000000068df31b40000000088b8c12e
+
+ /* x^248832 mod p(x)` << 1, x^248896 mod p(x)` << 1 */
+ .octa 0x00000000b059b6c200000001230b234c
+
+ /* x^247808 mod p(x)` << 1, x^247872 mod p(x)` << 1 */
+ .octa 0x0000000145fb8ed800000001120b416e
+
+ /* x^246784 mod p(x)` << 1, x^246848 mod p(x)` << 1 */
+ .octa 0x00000000cbc0916800000001974aecb0
+
+ /* x^245760 mod p(x)` << 1, x^245824 mod p(x)` << 1 */
+ .octa 0x000000005ceeedc2000000008ee3f226
+
+ /* x^244736 mod p(x)` << 1, x^244800 mod p(x)` << 1 */
+ .octa 0x0000000047d74e8600000001089aba9a
+
+ /* x^243712 mod p(x)` << 1, x^243776 mod p(x)` << 1 */
+ .octa 0x00000001407e9e220000000065113872
+
+ /* x^242688 mod p(x)` << 1, x^242752 mod p(x)` << 1 */
+ .octa 0x00000001da967bda000000005c07ec10
+
+ /* x^241664 mod p(x)` << 1, x^241728 mod p(x)` << 1 */
+ .octa 0x000000006c8983680000000187590924
+
+ /* x^240640 mod p(x)` << 1, x^240704 mod p(x)` << 1 */
+ .octa 0x00000000f2d14c9800000000e35da7c6
+
+ /* x^239616 mod p(x)` << 1, x^239680 mod p(x)` << 1 */
+ .octa 0x00000001993c6ad4000000000415855a
+
+ /* x^238592 mod p(x)` << 1, x^238656 mod p(x)` << 1 */
+ .octa 0x000000014683d1ac0000000073617758
+
+ /* x^237568 mod p(x)` << 1, x^237632 mod p(x)` << 1 */
+ .octa 0x00000001a7c93e6c0000000176021d28
+
+ /* x^236544 mod p(x)` << 1, x^236608 mod p(x)` << 1 */
+ .octa 0x000000010211e90a00000001c358fd0a
+
+ /* x^235520 mod p(x)` << 1, x^235584 mod p(x)` << 1 */
+ .octa 0x000000001119403e00000001ff7a2c18
+
+ /* x^234496 mod p(x)` << 1, x^234560 mod p(x)` << 1 */
+ .octa 0x000000001c3261aa00000000f2d9f7e4
+
+ /* x^233472 mod p(x)` << 1, x^233536 mod p(x)` << 1 */
+ .octa 0x000000014e37a634000000016cf1f9c8
+
+ /* x^232448 mod p(x)` << 1, x^232512 mod p(x)` << 1 */
+ .octa 0x0000000073786c0c000000010af9279a
+
+ /* x^231424 mod p(x)` << 1, x^231488 mod p(x)` << 1 */
+ .octa 0x000000011dc037f80000000004f101e8
+
+ /* x^230400 mod p(x)` << 1, x^230464 mod p(x)` << 1 */
+ .octa 0x0000000031433dfc0000000070bcf184
+
+ /* x^229376 mod p(x)` << 1, x^229440 mod p(x)` << 1 */
+ .octa 0x000000009cde8348000000000a8de642
+
+ /* x^228352 mod p(x)` << 1, x^228416 mod p(x)` << 1 */
+ .octa 0x0000000038d3c2a60000000062ea130c
+
+ /* x^227328 mod p(x)` << 1, x^227392 mod p(x)` << 1 */
+ .octa 0x000000011b25f26000000001eb31cbb2
+
+ /* x^226304 mod p(x)` << 1, x^226368 mod p(x)` << 1 */
+ .octa 0x000000001629e6f00000000170783448
+
+ /* x^225280 mod p(x)` << 1, x^225344 mod p(x)` << 1 */
+ .octa 0x0000000160838b4c00000001a684b4c6
+
+ /* x^224256 mod p(x)` << 1, x^224320 mod p(x)` << 1 */
+ .octa 0x000000007a44011c00000000253ca5b4
+
+ /* x^223232 mod p(x)` << 1, x^223296 mod p(x)` << 1 */
+ .octa 0x00000000226f417a0000000057b4b1e2
+
+ /* x^222208 mod p(x)` << 1, x^222272 mod p(x)` << 1 */
+ .octa 0x0000000045eb2eb400000000b6bd084c
+
+ /* x^221184 mod p(x)` << 1, x^221248 mod p(x)` << 1 */
+ .octa 0x000000014459d70c0000000123c2d592
+
+ /* x^220160 mod p(x)` << 1, x^220224 mod p(x)` << 1 */
+ .octa 0x00000001d406ed8200000000159dafce
+
+ /* x^219136 mod p(x)` << 1, x^219200 mod p(x)` << 1 */
+ .octa 0x0000000160c8e1a80000000127e1a64e
+
+ /* x^218112 mod p(x)` << 1, x^218176 mod p(x)` << 1 */
+ .octa 0x0000000027ba80980000000056860754
+
+ /* x^217088 mod p(x)` << 1, x^217152 mod p(x)` << 1 */
+ .octa 0x000000006d92d01800000001e661aae8
+
+ /* x^216064 mod p(x)` << 1, x^216128 mod p(x)` << 1 */
+ .octa 0x000000012ed7e3f200000000f82c6166
+
+ /* x^215040 mod p(x)` << 1, x^215104 mod p(x)` << 1 */
+ .octa 0x000000002dc8778800000000c4f9c7ae
+
+ /* x^214016 mod p(x)` << 1, x^214080 mod p(x)` << 1 */
+ .octa 0x0000000018240bb80000000074203d20
+
+ /* x^212992 mod p(x)` << 1, x^213056 mod p(x)` << 1 */
+ .octa 0x000000001ad381580000000198173052
+
+ /* x^211968 mod p(x)` << 1, x^212032 mod p(x)` << 1 */
+ .octa 0x00000001396b78f200000001ce8aba54
+
+ /* x^210944 mod p(x)` << 1, x^211008 mod p(x)` << 1 */
+ .octa 0x000000011a68133400000001850d5d94
+
+ /* x^209920 mod p(x)` << 1, x^209984 mod p(x)` << 1 */
+ .octa 0x000000012104732e00000001d609239c
+
+ /* x^208896 mod p(x)` << 1, x^208960 mod p(x)` << 1 */
+ .octa 0x00000000a140d90c000000001595f048
+
+ /* x^207872 mod p(x)` << 1, x^207936 mod p(x)` << 1 */
+ .octa 0x00000001b7215eda0000000042ccee08
+
+ /* x^206848 mod p(x)` << 1, x^206912 mod p(x)` << 1 */
+ .octa 0x00000001aaf1df3c000000010a389d74
+
+ /* x^205824 mod p(x)` << 1, x^205888 mod p(x)` << 1 */
+ .octa 0x0000000029d15b8a000000012a840da6
+
+ /* x^204800 mod p(x)` << 1, x^204864 mod p(x)` << 1 */
+ .octa 0x00000000f1a96922000000001d181c0c
+
+ /* x^203776 mod p(x)` << 1, x^203840 mod p(x)` << 1 */
+ .octa 0x00000001ac80d03c0000000068b7d1f6
+
+ /* x^202752 mod p(x)` << 1, x^202816 mod p(x)` << 1 */
+ .octa 0x000000000f11d56a000000005b0f14fc
+
+ /* x^201728 mod p(x)` << 1, x^201792 mod p(x)` << 1 */
+ .octa 0x00000001f1c022a20000000179e9e730
+
+ /* x^200704 mod p(x)` << 1, x^200768 mod p(x)` << 1 */
+ .octa 0x0000000173d00ae200000001ce1368d6
+
+ /* x^199680 mod p(x)` << 1, x^199744 mod p(x)` << 1 */
+ .octa 0x00000001d4ffe4ac0000000112c3a84c
+
+ /* x^198656 mod p(x)` << 1, x^198720 mod p(x)` << 1 */
+ .octa 0x000000016edc5ae400000000de940fee
+
+ /* x^197632 mod p(x)` << 1, x^197696 mod p(x)` << 1 */
+ .octa 0x00000001f1a0214000000000fe896b7e
+
+ /* x^196608 mod p(x)` << 1, x^196672 mod p(x)` << 1 */
+ .octa 0x00000000ca0b28a000000001f797431c
+
+ /* x^195584 mod p(x)` << 1, x^195648 mod p(x)` << 1 */
+ .octa 0x00000001928e30a20000000053e989ba
+
+ /* x^194560 mod p(x)` << 1, x^194624 mod p(x)` << 1 */
+ .octa 0x0000000097b1b002000000003920cd16
+
+ /* x^193536 mod p(x)` << 1, x^193600 mod p(x)` << 1 */
+ .octa 0x00000000b15bf90600000001e6f579b8
+
+ /* x^192512 mod p(x)` << 1, x^192576 mod p(x)` << 1 */
+ .octa 0x00000000411c5d52000000007493cb0a
+
+ /* x^191488 mod p(x)` << 1, x^191552 mod p(x)` << 1 */
+ .octa 0x00000001c36f330000000001bdd376d8
+
+ /* x^190464 mod p(x)` << 1, x^190528 mod p(x)` << 1 */
+ .octa 0x00000001119227e0000000016badfee6
+
+ /* x^189440 mod p(x)` << 1, x^189504 mod p(x)` << 1 */
+ .octa 0x00000000114d47020000000071de5c58
+
+ /* x^188416 mod p(x)` << 1, x^188480 mod p(x)` << 1 */
+ .octa 0x00000000458b5b9800000000453f317c
+
+ /* x^187392 mod p(x)` << 1, x^187456 mod p(x)` << 1 */
+ .octa 0x000000012e31fb8e0000000121675cce
+
+ /* x^186368 mod p(x)` << 1, x^186432 mod p(x)` << 1 */
+ .octa 0x000000005cf619d800000001f409ee92
+
+ /* x^185344 mod p(x)` << 1, x^185408 mod p(x)` << 1 */
+ .octa 0x0000000063f4d8b200000000f36b9c88
+
+ /* x^184320 mod p(x)` << 1, x^184384 mod p(x)` << 1 */
+ .octa 0x000000004138dc8a0000000036b398f4
+
+ /* x^183296 mod p(x)` << 1, x^183360 mod p(x)` << 1 */
+ .octa 0x00000001d29ee8e000000001748f9adc
+
+ /* x^182272 mod p(x)` << 1, x^182336 mod p(x)` << 1 */
+ .octa 0x000000006a08ace800000001be94ec00
+
+ /* x^181248 mod p(x)` << 1, x^181312 mod p(x)` << 1 */
+ .octa 0x0000000127d4201000000000b74370d6
+
+ /* x^180224 mod p(x)` << 1, x^180288 mod p(x)` << 1 */
+ .octa 0x0000000019d76b6200000001174d0b98
+
+ /* x^179200 mod p(x)` << 1, x^179264 mod p(x)` << 1 */
+ .octa 0x00000001b1471f6e00000000befc06a4
+
+ /* x^178176 mod p(x)` << 1, x^178240 mod p(x)` << 1 */
+ .octa 0x00000001f64c19cc00000001ae125288
+
+ /* x^177152 mod p(x)` << 1, x^177216 mod p(x)` << 1 */
+ .octa 0x00000000003c0ea00000000095c19b34
+
+ /* x^176128 mod p(x)` << 1, x^176192 mod p(x)` << 1 */
+ .octa 0x000000014d73abf600000001a78496f2
+
+ /* x^175104 mod p(x)` << 1, x^175168 mod p(x)` << 1 */
+ .octa 0x00000001620eb84400000001ac5390a0
+
+ /* x^174080 mod p(x)` << 1, x^174144 mod p(x)` << 1 */
+ .octa 0x0000000147655048000000002a80ed6e
+
+ /* x^173056 mod p(x)` << 1, x^173120 mod p(x)` << 1 */
+ .octa 0x0000000067b5077e00000001fa9b0128
+
+ /* x^172032 mod p(x)` << 1, x^172096 mod p(x)` << 1 */
+ .octa 0x0000000010ffe20600000001ea94929e
+
+ /* x^171008 mod p(x)` << 1, x^171072 mod p(x)` << 1 */
+ .octa 0x000000000fee8f1e0000000125f4305c
+
+ /* x^169984 mod p(x)` << 1, x^170048 mod p(x)` << 1 */
+ .octa 0x00000001da26fbae00000001471e2002
+
+ /* x^168960 mod p(x)` << 1, x^169024 mod p(x)` << 1 */
+ .octa 0x00000001b3a8bd880000000132d2253a
+
+ /* x^167936 mod p(x)` << 1, x^168000 mod p(x)` << 1 */
+ .octa 0x00000000e8f3898e00000000f26b3592
+
+ /* x^166912 mod p(x)` << 1, x^166976 mod p(x)` << 1 */
+ .octa 0x00000000b0d0d28c00000000bc8b67b0
+
+ /* x^165888 mod p(x)` << 1, x^165952 mod p(x)` << 1 */
+ .octa 0x0000000030f2a798000000013a826ef2
+
+ /* x^164864 mod p(x)` << 1, x^164928 mod p(x)` << 1 */
+ .octa 0x000000000fba10020000000081482c84
+
+ /* x^163840 mod p(x)` << 1, x^163904 mod p(x)` << 1 */
+ .octa 0x00000000bdb9bd7200000000e77307c2
+
+ /* x^162816 mod p(x)` << 1, x^162880 mod p(x)` << 1 */
+ .octa 0x0000000075d3bf5a00000000d4a07ec8
+
+ /* x^161792 mod p(x)` << 1, x^161856 mod p(x)` << 1 */
+ .octa 0x00000000ef1f98a00000000017102100
+
+ /* x^160768 mod p(x)` << 1, x^160832 mod p(x)` << 1 */
+ .octa 0x00000000689c760200000000db406486
+
+ /* x^159744 mod p(x)` << 1, x^159808 mod p(x)` << 1 */
+ .octa 0x000000016d5fa5fe0000000192db7f88
+
+ /* x^158720 mod p(x)` << 1, x^158784 mod p(x)` << 1 */
+ .octa 0x00000001d0d2b9ca000000018bf67b1e
+
+ /* x^157696 mod p(x)` << 1, x^157760 mod p(x)` << 1 */
+ .octa 0x0000000041e7b470000000007c09163e
+
+ /* x^156672 mod p(x)` << 1, x^156736 mod p(x)` << 1 */
+ .octa 0x00000001cbb6495e000000000adac060
+
+ /* x^155648 mod p(x)` << 1, x^155712 mod p(x)` << 1 */
+ .octa 0x000000010052a0b000000000bd8316ae
+
+ /* x^154624 mod p(x)` << 1, x^154688 mod p(x)` << 1 */
+ .octa 0x00000001d8effb5c000000019f09ab54
+
+ /* x^153600 mod p(x)` << 1, x^153664 mod p(x)` << 1 */
+ .octa 0x00000001d969853c0000000125155542
+
+ /* x^152576 mod p(x)` << 1, x^152640 mod p(x)` << 1 */
+ .octa 0x00000000523ccce2000000018fdb5882
+
+ /* x^151552 mod p(x)` << 1, x^151616 mod p(x)` << 1 */
+ .octa 0x000000001e2436bc00000000e794b3f4
+
+ /* x^150528 mod p(x)` << 1, x^150592 mod p(x)` << 1 */
+ .octa 0x00000000ddd1c3a2000000016f9bb022
+
+ /* x^149504 mod p(x)` << 1, x^149568 mod p(x)` << 1 */
+ .octa 0x0000000019fcfe3800000000290c9978
+
+ /* x^148480 mod p(x)` << 1, x^148544 mod p(x)` << 1 */
+ .octa 0x00000001ce95db640000000083c0f350
+
+ /* x^147456 mod p(x)` << 1, x^147520 mod p(x)` << 1 */
+ .octa 0x00000000af5828060000000173ea6628
+
+ /* x^146432 mod p(x)` << 1, x^146496 mod p(x)` << 1 */
+ .octa 0x00000001006388f600000001c8b4e00a
+
+ /* x^145408 mod p(x)` << 1, x^145472 mod p(x)` << 1 */
+ .octa 0x0000000179eca00a00000000de95d6aa
+
+ /* x^144384 mod p(x)` << 1, x^144448 mod p(x)` << 1 */
+ .octa 0x0000000122410a6a000000010b7f7248
+
+ /* x^143360 mod p(x)` << 1, x^143424 mod p(x)` << 1 */
+ .octa 0x000000004288e87c00000001326e3a06
+
+ /* x^142336 mod p(x)` << 1, x^142400 mod p(x)` << 1 */
+ .octa 0x000000016c5490da00000000bb62c2e6
+
+ /* x^141312 mod p(x)` << 1, x^141376 mod p(x)` << 1 */
+ .octa 0x00000000d1c71f6e0000000156a4b2c2
+
+ /* x^140288 mod p(x)` << 1, x^140352 mod p(x)` << 1 */
+ .octa 0x00000001b4ce08a6000000011dfe763a
+
+ /* x^139264 mod p(x)` << 1, x^139328 mod p(x)` << 1 */
+ .octa 0x00000001466ba60c000000007bcca8e2
+
+ /* x^138240 mod p(x)` << 1, x^138304 mod p(x)` << 1 */
+ .octa 0x00000001f6c488a40000000186118faa
+
+ /* x^137216 mod p(x)` << 1, x^137280 mod p(x)` << 1 */
+ .octa 0x000000013bfb06820000000111a65a88
+
+ /* x^136192 mod p(x)` << 1, x^136256 mod p(x)` << 1 */
+ .octa 0x00000000690e9e54000000003565e1c4
+
+ /* x^135168 mod p(x)` << 1, x^135232 mod p(x)` << 1 */
+ .octa 0x00000000281346b6000000012ed02a82
+
+ /* x^134144 mod p(x)` << 1, x^134208 mod p(x)` << 1 */
+ .octa 0x000000015646402400000000c486ecfc
+
+ /* x^133120 mod p(x)` << 1, x^133184 mod p(x)` << 1 */
+ .octa 0x000000016063a8dc0000000001b951b2
+
+ /* x^132096 mod p(x)` << 1, x^132160 mod p(x)` << 1 */
+ .octa 0x0000000116a663620000000048143916
+
+ /* x^131072 mod p(x)` << 1, x^131136 mod p(x)` << 1 */
+ .octa 0x000000017e8aa4d200000001dc2ae124
+
+ /* x^130048 mod p(x)` << 1, x^130112 mod p(x)` << 1 */
+ .octa 0x00000001728eb10c00000001416c58d6
+
+ /* x^129024 mod p(x)` << 1, x^129088 mod p(x)` << 1 */
+ .octa 0x00000001b08fd7fa00000000a479744a
+
+ /* x^128000 mod p(x)` << 1, x^128064 mod p(x)` << 1 */
+ .octa 0x00000001092a16e80000000096ca3a26
+
+ /* x^126976 mod p(x)` << 1, x^127040 mod p(x)` << 1 */
+ .octa 0x00000000a505637c00000000ff223d4e
+
+ /* x^125952 mod p(x)` << 1, x^126016 mod p(x)` << 1 */
+ .octa 0x00000000d94869b2000000010e84da42
+
+ /* x^124928 mod p(x)` << 1, x^124992 mod p(x)` << 1 */
+ .octa 0x00000001c8b203ae00000001b61ba3d0
+
+ /* x^123904 mod p(x)` << 1, x^123968 mod p(x)` << 1 */
+ .octa 0x000000005704aea000000000680f2de8
+
+ /* x^122880 mod p(x)` << 1, x^122944 mod p(x)` << 1 */
+ .octa 0x000000012e295fa2000000008772a9a8
+
+ /* x^121856 mod p(x)` << 1, x^121920 mod p(x)` << 1 */
+ .octa 0x000000011d0908bc0000000155f295bc
+
+ /* x^120832 mod p(x)` << 1, x^120896 mod p(x)` << 1 */
+ .octa 0x0000000193ed97ea00000000595f9282
+
+ /* x^119808 mod p(x)` << 1, x^119872 mod p(x)` << 1 */
+ .octa 0x000000013a0f1c520000000164b1c25a
+
+ /* x^118784 mod p(x)` << 1, x^118848 mod p(x)` << 1 */
+ .octa 0x000000010c2c40c000000000fbd67c50
+
+ /* x^117760 mod p(x)` << 1, x^117824 mod p(x)` << 1 */
+ .octa 0x00000000ff6fac3e0000000096076268
+
+ /* x^116736 mod p(x)` << 1, x^116800 mod p(x)` << 1 */
+ .octa 0x000000017b3609c000000001d288e4cc
+
+ /* x^115712 mod p(x)` << 1, x^115776 mod p(x)` << 1 */
+ .octa 0x0000000088c8c92200000001eaac1bdc
+
+ /* x^114688 mod p(x)` << 1, x^114752 mod p(x)` << 1 */
+ .octa 0x00000001751baae600000001f1ea39e2
+
+ /* x^113664 mod p(x)` << 1, x^113728 mod p(x)` << 1 */
+ .octa 0x000000010795297200000001eb6506fc
+
+ /* x^112640 mod p(x)` << 1, x^112704 mod p(x)` << 1 */
+ .octa 0x0000000162b00abe000000010f806ffe
+
+ /* x^111616 mod p(x)` << 1, x^111680 mod p(x)` << 1 */
+ .octa 0x000000000d7b404c000000010408481e
+
+ /* x^110592 mod p(x)` << 1, x^110656 mod p(x)` << 1 */
+ .octa 0x00000000763b13d40000000188260534
+
+ /* x^109568 mod p(x)` << 1, x^109632 mod p(x)` << 1 */
+ .octa 0x00000000f6dc22d80000000058fc73e0
+
+ /* x^108544 mod p(x)` << 1, x^108608 mod p(x)` << 1 */
+ .octa 0x000000007daae06000000000391c59b8
+
+ /* x^107520 mod p(x)` << 1, x^107584 mod p(x)` << 1 */
+ .octa 0x000000013359ab7c000000018b638400
+
+ /* x^106496 mod p(x)` << 1, x^106560 mod p(x)` << 1 */
+ .octa 0x000000008add438a000000011738f5c4
+
+ /* x^105472 mod p(x)` << 1, x^105536 mod p(x)` << 1 */
+ .octa 0x00000001edbefdea000000008cf7c6da
+
+ /* x^104448 mod p(x)` << 1, x^104512 mod p(x)` << 1 */
+ .octa 0x000000004104e0f800000001ef97fb16
+
+ /* x^103424 mod p(x)` << 1, x^103488 mod p(x)` << 1 */
+ .octa 0x00000000b48a82220000000102130e20
+
+ /* x^102400 mod p(x)` << 1, x^102464 mod p(x)` << 1 */
+ .octa 0x00000001bcb4684400000000db968898
+
+ /* x^101376 mod p(x)` << 1, x^101440 mod p(x)` << 1 */
+ .octa 0x000000013293ce0a00000000b5047b5e
+
+ /* x^100352 mod p(x)` << 1, x^100416 mod p(x)` << 1 */
+ .octa 0x00000001710d0844000000010b90fdb2
+
+ /* x^99328 mod p(x)` << 1, x^99392 mod p(x)` << 1 */
+ .octa 0x0000000117907f6e000000004834a32e
+
+ /* x^98304 mod p(x)` << 1, x^98368 mod p(x)` << 1 */
+ .octa 0x0000000087ddf93e0000000059c8f2b0
+
+ /* x^97280 mod p(x)` << 1, x^97344 mod p(x)` << 1 */
+ .octa 0x000000005970e9b00000000122cec508
+
+ /* x^96256 mod p(x)` << 1, x^96320 mod p(x)` << 1 */
+ .octa 0x0000000185b2b7d0000000000a330cda
+
+ /* x^95232 mod p(x)` << 1, x^95296 mod p(x)` << 1 */
+ .octa 0x00000001dcee0efc000000014a47148c
+
+ /* x^94208 mod p(x)` << 1, x^94272 mod p(x)` << 1 */
+ .octa 0x0000000030da27220000000042c61cb8
+
+ /* x^93184 mod p(x)` << 1, x^93248 mod p(x)` << 1 */
+ .octa 0x000000012f925a180000000012fe6960
+
+ /* x^92160 mod p(x)` << 1, x^92224 mod p(x)` << 1 */
+ .octa 0x00000000dd2e357c00000000dbda2c20
+
+ /* x^91136 mod p(x)` << 1, x^91200 mod p(x)` << 1 */
+ .octa 0x00000000071c80de000000011122410c
+
+ /* x^90112 mod p(x)` << 1, x^90176 mod p(x)` << 1 */
+ .octa 0x000000011513140a00000000977b2070
+
+ /* x^89088 mod p(x)` << 1, x^89152 mod p(x)` << 1 */
+ .octa 0x00000001df876e8e000000014050438e
+
+ /* x^88064 mod p(x)` << 1, x^88128 mod p(x)` << 1 */
+ .octa 0x000000015f81d6ce0000000147c840e8
+
+ /* x^87040 mod p(x)` << 1, x^87104 mod p(x)` << 1 */
+ .octa 0x000000019dd94dbe00000001cc7c88ce
+
+ /* x^86016 mod p(x)` << 1, x^86080 mod p(x)` << 1 */
+ .octa 0x00000001373d206e00000001476b35a4
+
+ /* x^84992 mod p(x)` << 1, x^85056 mod p(x)` << 1 */
+ .octa 0x00000000668ccade000000013d52d508
+
+ /* x^83968 mod p(x)` << 1, x^84032 mod p(x)` << 1 */
+ .octa 0x00000001b192d268000000008e4be32e
+
+ /* x^82944 mod p(x)` << 1, x^83008 mod p(x)` << 1 */
+ .octa 0x00000000e30f3a7800000000024120fe
+
+ /* x^81920 mod p(x)` << 1, x^81984 mod p(x)` << 1 */
+ .octa 0x000000010ef1f7bc00000000ddecddb4
+
+ /* x^80896 mod p(x)` << 1, x^80960 mod p(x)` << 1 */
+ .octa 0x00000001f5ac738000000000d4d403bc
+
+ /* x^79872 mod p(x)` << 1, x^79936 mod p(x)` << 1 */
+ .octa 0x000000011822ea7000000001734b89aa
+
+ /* x^78848 mod p(x)` << 1, x^78912 mod p(x)` << 1 */
+ .octa 0x00000000c3a33848000000010e7a58d6
+
+ /* x^77824 mod p(x)` << 1, x^77888 mod p(x)` << 1 */
+ .octa 0x00000001bd151c2400000001f9f04e9c
+
+ /* x^76800 mod p(x)` << 1, x^76864 mod p(x)` << 1 */
+ .octa 0x0000000056002d7600000000b692225e
+
+ /* x^75776 mod p(x)` << 1, x^75840 mod p(x)` << 1 */
+ .octa 0x000000014657c4f4000000019b8d3f3e
+
+ /* x^74752 mod p(x)` << 1, x^74816 mod p(x)` << 1 */
+ .octa 0x0000000113742d7c00000001a874f11e
+
+ /* x^73728 mod p(x)` << 1, x^73792 mod p(x)` << 1 */
+ .octa 0x000000019c5920ba000000010d5a4254
+
+ /* x^72704 mod p(x)` << 1, x^72768 mod p(x)` << 1 */
+ .octa 0x000000005216d2d600000000bbb2f5d6
+
+ /* x^71680 mod p(x)` << 1, x^71744 mod p(x)` << 1 */
+ .octa 0x0000000136f5ad8a0000000179cc0e36
+
+ /* x^70656 mod p(x)` << 1, x^70720 mod p(x)` << 1 */
+ .octa 0x000000018b07beb600000001dca1da4a
+
+ /* x^69632 mod p(x)` << 1, x^69696 mod p(x)` << 1 */
+ .octa 0x00000000db1e93b000000000feb1a192
+
+ /* x^68608 mod p(x)` << 1, x^68672 mod p(x)` << 1 */
+ .octa 0x000000000b96fa3a00000000d1eeedd6
+
+ /* x^67584 mod p(x)` << 1, x^67648 mod p(x)` << 1 */
+ .octa 0x00000001d9968af0000000008fad9bb4
+
+ /* x^66560 mod p(x)` << 1, x^66624 mod p(x)` << 1 */
+ .octa 0x000000000e4a77a200000001884938e4
+
+ /* x^65536 mod p(x)` << 1, x^65600 mod p(x)` << 1 */
+ .octa 0x00000000508c2ac800000001bc2e9bc0
+
+ /* x^64512 mod p(x)` << 1, x^64576 mod p(x)` << 1 */
+ .octa 0x0000000021572a8000000001f9658a68
+
+ /* x^63488 mod p(x)` << 1, x^63552 mod p(x)` << 1 */
+ .octa 0x00000001b859daf2000000001b9224fc
+
+ /* x^62464 mod p(x)` << 1, x^62528 mod p(x)` << 1 */
+ .octa 0x000000016f7884740000000055b2fb84
+
+ /* x^61440 mod p(x)` << 1, x^61504 mod p(x)` << 1 */
+ .octa 0x00000001b438810e000000018b090348
+
+ /* x^60416 mod p(x)` << 1, x^60480 mod p(x)` << 1 */
+ .octa 0x0000000095ddc6f2000000011ccbd5ea
+
+ /* x^59392 mod p(x)` << 1, x^59456 mod p(x)` << 1 */
+ .octa 0x00000001d977c20c0000000007ae47f8
+
+ /* x^58368 mod p(x)` << 1, x^58432 mod p(x)` << 1 */
+ .octa 0x00000000ebedb99a0000000172acbec0
+
+ /* x^57344 mod p(x)` << 1, x^57408 mod p(x)` << 1 */
+ .octa 0x00000001df9e9e9200000001c6e3ff20
+
+ /* x^56320 mod p(x)` << 1, x^56384 mod p(x)` << 1 */
+ .octa 0x00000001a4a3f95200000000e1b38744
+
+ /* x^55296 mod p(x)` << 1, x^55360 mod p(x)` << 1 */
+ .octa 0x00000000e2f5122000000000791585b2
+
+ /* x^54272 mod p(x)` << 1, x^54336 mod p(x)` << 1 */
+ .octa 0x000000004aa01f3e00000000ac53b894
+
+ /* x^53248 mod p(x)` << 1, x^53312 mod p(x)` << 1 */
+ .octa 0x00000000b3e90a5800000001ed5f2cf4
+
+ /* x^52224 mod p(x)` << 1, x^52288 mod p(x)` << 1 */
+ .octa 0x000000000c9ca2aa00000001df48b2e0
+
+ /* x^51200 mod p(x)` << 1, x^51264 mod p(x)` << 1 */
+ .octa 0x000000015168231600000000049c1c62
+
+ /* x^50176 mod p(x)` << 1, x^50240 mod p(x)` << 1 */
+ .octa 0x0000000036fce78c000000017c460c12
+
+ /* x^49152 mod p(x)` << 1, x^49216 mod p(x)` << 1 */
+ .octa 0x000000009037dc10000000015be4da7e
+
+ /* x^48128 mod p(x)` << 1, x^48192 mod p(x)` << 1 */
+ .octa 0x00000000d3298582000000010f38f668
+
+ /* x^47104 mod p(x)` << 1, x^47168 mod p(x)` << 1 */
+ .octa 0x00000001b42e8ad60000000039f40a00
+
+ /* x^46080 mod p(x)` << 1, x^46144 mod p(x)` << 1 */
+ .octa 0x00000000142a983800000000bd4c10c4
+
+ /* x^45056 mod p(x)` << 1, x^45120 mod p(x)` << 1 */
+ .octa 0x0000000109c7f1900000000042db1d98
+
+ /* x^44032 mod p(x)` << 1, x^44096 mod p(x)` << 1 */
+ .octa 0x0000000056ff931000000001c905bae6
+
+ /* x^43008 mod p(x)` << 1, x^43072 mod p(x)` << 1 */
+ .octa 0x00000001594513aa00000000069d40ea
+
+ /* x^41984 mod p(x)` << 1, x^42048 mod p(x)` << 1 */
+ .octa 0x00000001e3b5b1e8000000008e4fbad0
+
+ /* x^40960 mod p(x)` << 1, x^41024 mod p(x)` << 1 */
+ .octa 0x000000011dd5fc080000000047bedd46
+
+ /* x^39936 mod p(x)` << 1, x^40000 mod p(x)` << 1 */
+ .octa 0x00000001675f0cc20000000026396bf8
+
+ /* x^38912 mod p(x)` << 1, x^38976 mod p(x)` << 1 */
+ .octa 0x00000000d1c8dd4400000000379beb92
+
+ /* x^37888 mod p(x)` << 1, x^37952 mod p(x)` << 1 */
+ .octa 0x0000000115ebd3d8000000000abae54a
+
+ /* x^36864 mod p(x)` << 1, x^36928 mod p(x)` << 1 */
+ .octa 0x00000001ecbd0dac0000000007e6a128
+
+ /* x^35840 mod p(x)` << 1, x^35904 mod p(x)` << 1 */
+ .octa 0x00000000cdf67af2000000000ade29d2
+
+ /* x^34816 mod p(x)` << 1, x^34880 mod p(x)` << 1 */
+ .octa 0x000000004c01ff4c00000000f974c45c
+
+ /* x^33792 mod p(x)` << 1, x^33856 mod p(x)` << 1 */
+ .octa 0x00000000f2d8657e00000000e77ac60a
+
+ /* x^32768 mod p(x)` << 1, x^32832 mod p(x)` << 1 */
+ .octa 0x000000006bae74c40000000145895816
+
+ /* x^31744 mod p(x)` << 1, x^31808 mod p(x)` << 1 */
+ .octa 0x0000000152af8aa00000000038e362be
+
+ /* x^30720 mod p(x)` << 1, x^30784 mod p(x)` << 1 */
+ .octa 0x0000000004663802000000007f991a64
+
+ /* x^29696 mod p(x)` << 1, x^29760 mod p(x)` << 1 */
+ .octa 0x00000001ab2f5afc00000000fa366d3a
+
+ /* x^28672 mod p(x)` << 1, x^28736 mod p(x)` << 1 */
+ .octa 0x0000000074a4ebd400000001a2bb34f0
+
+ /* x^27648 mod p(x)` << 1, x^27712 mod p(x)` << 1 */
+ .octa 0x00000001d7ab3a4c0000000028a9981e
+
+ /* x^26624 mod p(x)` << 1, x^26688 mod p(x)` << 1 */
+ .octa 0x00000001a8da60c600000001dbc672be
+
+ /* x^25600 mod p(x)` << 1, x^25664 mod p(x)` << 1 */
+ .octa 0x000000013cf6382000000000b04d77f6
+
+ /* x^24576 mod p(x)` << 1, x^24640 mod p(x)` << 1 */
+ .octa 0x00000000bec12e1e0000000124400d96
+
+ /* x^23552 mod p(x)` << 1, x^23616 mod p(x)` << 1 */
+ .octa 0x00000001c6368010000000014ca4b414
+
+ /* x^22528 mod p(x)` << 1, x^22592 mod p(x)` << 1 */
+ .octa 0x00000001e6e78758000000012fe2c938
+
+ /* x^21504 mod p(x)` << 1, x^21568 mod p(x)` << 1 */
+ .octa 0x000000008d7f2b3c00000001faed01e6
+
+ /* x^20480 mod p(x)` << 1, x^20544 mod p(x)` << 1 */
+ .octa 0x000000016b4a156e000000007e80ecfe
+
+ /* x^19456 mod p(x)` << 1, x^19520 mod p(x)` << 1 */
+ .octa 0x00000001c63cfeb60000000098daee94
+
+ /* x^18432 mod p(x)` << 1, x^18496 mod p(x)` << 1 */
+ .octa 0x000000015f902670000000010a04edea
+
+ /* x^17408 mod p(x)` << 1, x^17472 mod p(x)` << 1 */
+ .octa 0x00000001cd5de11e00000001c00b4524
+
+ /* x^16384 mod p(x)` << 1, x^16448 mod p(x)` << 1 */
+ .octa 0x000000001acaec540000000170296550
+
+ /* x^15360 mod p(x)` << 1, x^15424 mod p(x)` << 1 */
+ .octa 0x000000002bd0ca780000000181afaa48
+
+ /* x^14336 mod p(x)` << 1, x^14400 mod p(x)` << 1 */
+ .octa 0x0000000032d63d5c0000000185a31ffa
+
+ /* x^13312 mod p(x)` << 1, x^13376 mod p(x)` << 1 */
+ .octa 0x000000001c6d4e4c000000002469f608
+
+ /* x^12288 mod p(x)` << 1, x^12352 mod p(x)` << 1 */
+ .octa 0x0000000106a60b92000000006980102a
+
+ /* x^11264 mod p(x)` << 1, x^11328 mod p(x)` << 1 */
+ .octa 0x00000000d3855e120000000111ea9ca8
+
+ /* x^10240 mod p(x)` << 1, x^10304 mod p(x)` << 1 */
+ .octa 0x00000000e312563600000001bd1d29ce
+
+ /* x^9216 mod p(x)` << 1, x^9280 mod p(x)` << 1 */
+ .octa 0x000000009e8f7ea400000001b34b9580
+
+ /* x^8192 mod p(x)` << 1, x^8256 mod p(x)` << 1 */
+ .octa 0x00000001c82e562c000000003076054e
+
+ /* x^7168 mod p(x)` << 1, x^7232 mod p(x)` << 1 */
+ .octa 0x00000000ca9f09ce000000012a608ea4
+
+ /* x^6144 mod p(x)` << 1, x^6208 mod p(x)` << 1 */
+ .octa 0x00000000c63764e600000000784d05fe
+
+ /* x^5120 mod p(x)` << 1, x^5184 mod p(x)` << 1 */
+ .octa 0x0000000168d2e49e000000016ef0d82a
+
+ /* x^4096 mod p(x)` << 1, x^4160 mod p(x)` << 1 */
+ .octa 0x00000000e986c1480000000075bda454
+
+ /* x^3072 mod p(x)` << 1, x^3136 mod p(x)` << 1 */
+ .octa 0x00000000cfb65894000000003dc0a1c4
+
+ /* x^2048 mod p(x)` << 1, x^2112 mod p(x)` << 1 */
+ .octa 0x0000000111cadee400000000e9a5d8be
+
+ /* x^1024 mod p(x)` << 1, x^1088 mod p(x)` << 1 */
+ .octa 0x0000000171fb63ce00000001609bc4b4
+
+ .short_constants :
+
+ /* Reduce final 1024-2048 bits to 64 bits, shifting 32 bits to include
+ the trailing 32 bits of zeros */
+ /* x^1952 mod p(x)`, x^1984 mod p(x)`, x^2016 mod p(x)`, x^2048 mod
+ p(x)` */
+ .octa 0x7fec2963e5bf80485cf015c388e56f72
+
+ /* x^1824 mod p(x)`, x^1856 mod p(x)`, x^1888 mod p(x)`, x^1920 mod
+ p(x)` */
+ .octa 0x38e888d4844752a9963a18920246e2e6
+
+ /* x^1696 mod p(x)`, x^1728 mod p(x)`, x^1760 mod p(x)`, x^1792 mod
+ p(x)` */
+ .octa 0x42316c00730206ad419a441956993a31
+
+ /* x^1568 mod p(x)`, x^1600 mod p(x)`, x^1632 mod p(x)`, x^1664 mod
+ p(x)` */
+ .octa 0x543d5c543e65ddf9924752ba2b830011
+
+ /* x^1440 mod p(x)`, x^1472 mod p(x)`, x^1504 mod p(x)`, x^1536 mod
+ p(x)` */
+ .octa 0x78e87aaf56767c9255bd7f9518e4a304
+
+ /* x^1312 mod p(x)`, x^1344 mod p(x)`, x^1376 mod p(x)`, x^1408 mod
+ p(x)` */
+ .octa 0x8f68fcec1903da7f6d76739fe0553f1e
+
+ /* x^1184 mod p(x)`, x^1216 mod p(x)`, x^1248 mod p(x)`, x^1280 mod
+ p(x)` */
+ .octa 0x3f4840246791d588c133722b1fe0b5c3
+
+ /* x^1056 mod p(x)`, x^1088 mod p(x)`, x^1120 mod p(x)`, x^1152 mod
+ p(x)` */
+ .octa 0x34c96751b04de25a64b67ee0e55ef1f3
+
+ /* x^928 mod p(x)`, x^960 mod p(x)`, x^992 mod p(x)`, x^1024 mod p(x)`
+ */
+ .octa 0x156c8e180b4a395b069db049b8fdb1e7
+
+ /* x^800 mod p(x)`, x^832 mod p(x)`, x^864 mod p(x)`, x^896 mod p(x)` */
+ .octa 0xe0b99ccbe661f7bea11bfaf3c9e90b9e
+
+ /* x^672 mod p(x)`, x^704 mod p(x)`, x^736 mod p(x)`, x^768 mod p(x)` */
+ .octa 0x041d37768cd75659817cdc5119b29a35
+
+ /* x^544 mod p(x)`, x^576 mod p(x)`, x^608 mod p(x)`, x^640 mod p(x)` */
+ .octa 0x3a0777818cfaa9651ce9d94b36c41f1c
+
+ /* x^416 mod p(x)`, x^448 mod p(x)`, x^480 mod p(x)`, x^512 mod p(x)` */
+ .octa 0x0e148e8252377a554f256efcb82be955
+
+ /* x^288 mod p(x)`, x^320 mod p(x)`, x^352 mod p(x)`, x^384 mod p(x)` */
+ .octa 0x9c25531d19e65ddeec1631edb2dea967
+
+ /* x^160 mod p(x)`, x^192 mod p(x)`, x^224 mod p(x)`, x^256 mod p(x)` */
+ .octa 0x790606ff9957c0a65d27e147510ac59a
+
+ /* x^32 mod p(x)`, x^64 mod p(x)`, x^96 mod p(x)`, x^128 mod p(x)` */
+ .octa 0x82f63b786ea2d55ca66805eb18b8ea18
+
+ .barrett_constants :
+ /* 33 bit reflected Barrett constant m - (4^32)/n */
+ .octa 0x000000000000000000000000dea713f1 /* x^64 div p(x)` */
+ /* 33 bit reflected Barrett constant n */
+ .octa 0x00000000000000000000000105ec76f1
+#endif
diff --git a/src/rocksdb/util/crc32c_test.cc b/src/rocksdb/util/crc32c_test.cc
new file mode 100644
index 000000000..715d63e2d
--- /dev/null
+++ b/src/rocksdb/util/crc32c_test.cc
@@ -0,0 +1,213 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#include "util/crc32c.h"
+
+#include "test_util/testharness.h"
+#include "util/coding.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace crc32c {
+
+class CRC {};
+
+// Tests for 3-way crc32c algorithm. We need these tests because it uses
+// different lookup tables than the original Fast_CRC32
+const unsigned int BUFFER_SIZE = 512 * 1024 * sizeof(uint64_t);
+char buffer[BUFFER_SIZE];
+
+struct ExpectedResult {
+ size_t offset;
+ size_t length;
+ uint32_t crc32c;
+};
+
+ExpectedResult expectedResults[] = {
+ // Zero-byte input
+ {0, 0, ~0U},
+ // Small aligned inputs to test special cases in SIMD implementations
+ {8, 1, 1543413366},
+ {8, 2, 523493126},
+ {8, 3, 1560427360},
+ {8, 4, 3422504776},
+ {8, 5, 447841138},
+ {8, 6, 3910050499},
+ {8, 7, 3346241981},
+ // Small unaligned inputs
+ {9, 1, 3855826643},
+ {10, 2, 560880875},
+ {11, 3, 1479707779},
+ {12, 4, 2237687071},
+ {13, 5, 4063855784},
+ {14, 6, 2553454047},
+ {15, 7, 1349220140},
+ // Larger inputs to test leftover chunks at the end of aligned blocks
+ {8, 8, 627613930},
+ {8, 9, 2105929409},
+ {8, 10, 2447068514},
+ {8, 11, 863807079},
+ {8, 12, 292050879},
+ {8, 13, 1411837737},
+ {8, 14, 2614515001},
+ {8, 15, 3579076296},
+ {8, 16, 2897079161},
+ {8, 17, 675168386},
+ // // Much larger inputs
+ {0, BUFFER_SIZE, 2096790750},
+ {1, BUFFER_SIZE / 2, 3854797577},
+
+};
+
+TEST(CRC, StandardResults) {
+ // Original Fast_CRC32 tests.
+ // From rfc3720 section B.4.
+ char buf[32];
+
+ memset(buf, 0, sizeof(buf));
+ ASSERT_EQ(0x8a9136aaU, Value(buf, sizeof(buf)));
+
+ memset(buf, 0xff, sizeof(buf));
+ ASSERT_EQ(0x62a8ab43U, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = static_cast<char>(i);
+ }
+ ASSERT_EQ(0x46dd794eU, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = static_cast<char>(31 - i);
+ }
+ ASSERT_EQ(0x113fdb5cU, Value(buf, sizeof(buf)));
+
+ unsigned char data[48] = {
+ 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
+ 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+ ASSERT_EQ(0xd9963a56, Value(reinterpret_cast<char*>(data), sizeof(data)));
+
+ // 3-Way Crc32c tests ported from folly.
+ // Test 1: single computation
+ for (auto expected : expectedResults) {
+ uint32_t result = Value(buffer + expected.offset, expected.length);
+ EXPECT_EQ(~expected.crc32c, result);
+ }
+
+ // Test 2: stitching two computations
+ for (auto expected : expectedResults) {
+ size_t partialLength = expected.length / 2;
+ uint32_t partialChecksum = Value(buffer + expected.offset, partialLength);
+ uint32_t result =
+ Extend(partialChecksum, buffer + expected.offset + partialLength,
+ expected.length - partialLength);
+ EXPECT_EQ(~expected.crc32c, result);
+ }
+}
+
+TEST(CRC, Values) { ASSERT_NE(Value("a", 1), Value("foo", 3)); }
+
+TEST(CRC, Extend) {
+ ASSERT_EQ(Value("hello world", 11), Extend(Value("hello ", 6), "world", 5));
+}
+
+TEST(CRC, Mask) {
+ uint32_t crc = Value("foo", 3);
+ ASSERT_NE(crc, Mask(crc));
+ ASSERT_NE(crc, Mask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc)))));
+}
+
+TEST(CRC, Crc32cCombineBasicTest) {
+ uint32_t crc1 = Value("hello ", 6);
+ uint32_t crc2 = Value("world", 5);
+ uint32_t crc3 = Value("hello world", 11);
+ uint32_t crc1_2_combine = Crc32cCombine(crc1, crc2, 5);
+ ASSERT_EQ(crc3, crc1_2_combine);
+}
+
+TEST(CRC, Crc32cCombineOrderMattersTest) {
+ uint32_t crc1 = Value("hello ", 6);
+ uint32_t crc2 = Value("world", 5);
+ uint32_t crc3 = Value("hello world", 11);
+ uint32_t crc2_1_combine = Crc32cCombine(crc2, crc1, 6);
+ ASSERT_NE(crc3, crc2_1_combine);
+}
+
+TEST(CRC, Crc32cCombineFullCoverTest) {
+ int scale = 4 * 1024;
+ Random rnd(test::RandomSeed());
+ int size_1 = 1024 * 1024;
+ std::string s1 = rnd.RandomBinaryString(size_1);
+ uint32_t crc1 = Value(s1.data(), size_1);
+ for (int i = 0; i < scale; i++) {
+ int size_2 = i;
+ std::string s2 = rnd.RandomBinaryString(size_2);
+ uint32_t crc2 = Value(s2.data(), s2.size());
+ uint32_t crc1_2 = Extend(crc1, s2.data(), s2.size());
+ uint32_t crc1_2_combine = Crc32cCombine(crc1, crc2, size_2);
+ ASSERT_EQ(crc1_2, crc1_2_combine);
+ }
+}
+
+TEST(CRC, Crc32cCombineBigSizeTest) {
+ Random rnd(test::RandomSeed());
+ int size_1 = 1024 * 1024;
+ std::string s1 = rnd.RandomBinaryString(size_1);
+ uint32_t crc1 = Value(s1.data(), size_1);
+ int size_2 = 16 * 1024 * 1024 - 1;
+ std::string s2 = rnd.RandomBinaryString(size_2);
+ uint32_t crc2 = Value(s2.data(), s2.size());
+ uint32_t crc1_2 = Extend(crc1, s2.data(), s2.size());
+ uint32_t crc1_2_combine = Crc32cCombine(crc1, crc2, size_2);
+ ASSERT_EQ(crc1_2, crc1_2_combine);
+}
+
+} // namespace crc32c
+} // namespace ROCKSDB_NAMESPACE
+
+// copied from folly
+const uint64_t FNV_64_HASH_START = 14695981039346656037ULL;
+inline uint64_t fnv64_buf(const void* buf, size_t n,
+ uint64_t hash = FNV_64_HASH_START) {
+ // forcing signed char, since other platforms can use unsigned
+ const signed char* char_buf = reinterpret_cast<const signed char*>(buf);
+
+ for (size_t i = 0; i < n; ++i) {
+ hash += (hash << 1) + (hash << 4) + (hash << 5) + (hash << 7) +
+ (hash << 8) + (hash << 40);
+ hash ^= char_buf[i];
+ }
+ return hash;
+}
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+
+ // Populate a buffer with a deterministic pattern
+ // on which to compute checksums
+
+ const uint8_t* src = (uint8_t*)ROCKSDB_NAMESPACE::crc32c::buffer;
+ uint64_t* dst = (uint64_t*)ROCKSDB_NAMESPACE::crc32c::buffer;
+ const uint64_t* end =
+ (const uint64_t*)(ROCKSDB_NAMESPACE::crc32c::buffer +
+ ROCKSDB_NAMESPACE::crc32c::BUFFER_SIZE);
+ *dst++ = 0;
+ while (dst < end) {
+ ROCKSDB_NAMESPACE::EncodeFixed64(
+ reinterpret_cast<char*>(dst),
+ fnv64_buf((const char*)src, sizeof(uint64_t)));
+ dst++;
+ src += sizeof(uint64_t);
+ }
+
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/defer.h b/src/rocksdb/util/defer.h
new file mode 100644
index 000000000..f71e67ba9
--- /dev/null
+++ b/src/rocksdb/util/defer.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <functional>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Defers the execution of the provided function until the Defer
+// object goes out of scope.
+//
+// Usage example:
+//
+// Status DeferTest() {
+// Status s;
+// Defer defer([&s]() {
+// if (!s.ok()) {
+// // do cleanups ...
+// }
+// });
+// // do something ...
+// if (!s.ok()) return;
+// // do some other things ...
+// return s;
+// }
+//
+// The above code ensures that cleanups will always happen on returning.
+//
+// Without the help of Defer, you can
+// 1. every time when !s.ok(), do the cleanup;
+// 2. instead of returning when !s.ok(), continue the work only when s.ok(),
+// but sometimes, this might lead to nested blocks of "if (s.ok()) {...}".
+//
+// With the help of Defer, you can centralize the cleanup logic inside the
+// lambda passed to Defer, and you can return immediately on failure when
+// necessary.
+class Defer final {
+ public:
+ explicit Defer(std::function<void()>&& fn) : fn_(std::move(fn)) {}
+ ~Defer() { fn_(); }
+
+ // Disallow copy.
+ Defer(const Defer&) = delete;
+ Defer& operator=(const Defer&) = delete;
+
+ private:
+ std::function<void()> fn_;
+};
+
+// An RAII utility object that saves the current value of an object so that
+// it can be overwritten, and restores it to the saved value when the
+// SaveAndRestore object goes out of scope.
+template <typename T>
+class SaveAndRestore {
+ public:
+ // obj is non-null pointer to value to be saved and later restored.
+ explicit SaveAndRestore(T* obj) : obj_(obj), saved_(*obj) {}
+ // new_value is stored in *obj
+ SaveAndRestore(T* obj, const T& new_value)
+ : obj_(obj), saved_(std::move(*obj)) {
+ *obj = new_value;
+ }
+ SaveAndRestore(T* obj, T&& new_value) : obj_(obj), saved_(std::move(*obj)) {
+ *obj = std::move(new_value);
+ }
+ ~SaveAndRestore() { *obj_ = std::move(saved_); }
+
+ // No copies
+ SaveAndRestore(const SaveAndRestore&) = delete;
+ SaveAndRestore& operator=(const SaveAndRestore&) = delete;
+
+ private:
+ T* const obj_;
+ T saved_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/defer_test.cc b/src/rocksdb/util/defer_test.cc
new file mode 100644
index 000000000..0e98f68b6
--- /dev/null
+++ b/src/rocksdb/util/defer_test.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/defer.h"
+
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class DeferTest {};
+
+TEST(DeferTest, BlockScope) {
+ int v = 1;
+ {
+ Defer defer([&v]() { v *= 2; });
+ }
+ ASSERT_EQ(2, v);
+}
+
+TEST(DeferTest, FunctionScope) {
+ int v = 1;
+ auto f = [&v]() {
+ Defer defer([&v]() { v *= 2; });
+ v = 2;
+ };
+ f();
+ ASSERT_EQ(4, v);
+}
+
+TEST(SaveAndRestoreTest, BlockScope) {
+ int v = 1;
+ {
+ SaveAndRestore<int> sr(&v);
+ ASSERT_EQ(v, 1);
+ v = 2;
+ ASSERT_EQ(v, 2);
+ }
+ ASSERT_EQ(v, 1);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/distributed_mutex.h b/src/rocksdb/util/distributed_mutex.h
new file mode 100644
index 000000000..9675a1e2d
--- /dev/null
+++ b/src/rocksdb/util/distributed_mutex.h
@@ -0,0 +1,48 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "rocksdb/rocksdb_namespace.h"
+
+// This file declares a wrapper around the efficient folly DistributedMutex
+// that falls back on a standard mutex when not available. See
+// https://github.com/facebook/folly/blob/main/folly/synchronization/DistributedMutex.h
+// for benefits and limitations.
+
+// At the moment, only scoped locking is supported using DMutexLock
+// RAII wrapper, because lock/unlock APIs will vary.
+
+#ifdef USE_FOLLY
+
+#include <folly/synchronization/DistributedMutex.h>
+
+namespace ROCKSDB_NAMESPACE {
+
+class DMutex : public folly::DistributedMutex {
+ public:
+ static const char* kName() { return "folly::DistributedMutex"; }
+
+ explicit DMutex(bool IGNORED_adaptive = false) { (void)IGNORED_adaptive; }
+
+ // currently no-op
+ void AssertHeld() {}
+};
+using DMutexLock = std::lock_guard<folly::DistributedMutex>;
+
+} // namespace ROCKSDB_NAMESPACE
+
+#else
+
+#include "port/port.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+using DMutex = port::Mutex;
+using DMutexLock = std::lock_guard<DMutex>;
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/util/duplicate_detector.h b/src/rocksdb/util/duplicate_detector.h
new file mode 100644
index 000000000..d778622db
--- /dev/null
+++ b/src/rocksdb/util/duplicate_detector.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <cstdint>
+
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "util/set_comparator.h"
+
+namespace ROCKSDB_NAMESPACE {
+// During recovery if the memtable is flushed we cannot rely on its help on
+// duplicate key detection and as key insert will not be attempted. This class
+// will be used as a emulator of memtable to tell if insertion of a key/seq
+// would have resulted in duplication.
+class DuplicateDetector {
+ public:
+ explicit DuplicateDetector(DBImpl* db) : db_(db) {}
+ bool IsDuplicateKeySeq(uint32_t cf, const Slice& key, SequenceNumber seq) {
+ assert(seq >= batch_seq_);
+ if (batch_seq_ != seq) { // it is a new batch
+ keys_.clear();
+ }
+ batch_seq_ = seq;
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ InitWithComp(cf);
+ }
+ auto it = cf_keys.insert(key);
+ if (it.second == false) { // second is false if a element already existed.
+ keys_.clear();
+ InitWithComp(cf);
+ keys_[cf].insert(key);
+ return true;
+ }
+ return false;
+ }
+
+ private:
+ SequenceNumber batch_seq_ = 0;
+ DBImpl* db_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ void InitWithComp(const uint32_t cf) {
+ auto h = db_->GetColumnFamilyHandle(cf);
+ if (!h) {
+ // TODO(myabandeh): This is not a concern in MyRocks as drop cf is not
+ // implemented yet. When it does, we should return proper error instead
+ // of throwing exception.
+ ROCKS_LOG_FATAL(
+ db_->immutable_db_options().info_log,
+ "Recovering an entry from the dropped column family %" PRIu32
+ ". WAL must must have been emptied before dropping the column "
+ "family",
+ cf);
+#ifndef ROCKSDB_LITE
+ throw std::runtime_error(
+ "Recovering an entry from a dropped column family. "
+ "WAL must must have been flushed before dropping the column "
+ "family");
+#endif
+ return;
+ }
+ auto cmp = h->GetComparator();
+ keys_[cf] = CFKeys(SetComparator(cmp));
+ }
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/dynamic_bloom.cc b/src/rocksdb/util/dynamic_bloom.cc
new file mode 100644
index 000000000..0ff3b4a75
--- /dev/null
+++ b/src/rocksdb/util/dynamic_bloom.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "dynamic_bloom.h"
+
+#include <algorithm>
+
+#include "memory/allocator.h"
+#include "port/port.h"
+#include "rocksdb/slice.h"
+#include "util/hash.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+uint32_t roundUpToPow2(uint32_t x) {
+ uint32_t rv = 1;
+ while (rv < x) {
+ rv <<= 1;
+ }
+ return rv;
+}
+} // namespace
+
+DynamicBloom::DynamicBloom(Allocator* allocator, uint32_t total_bits,
+ uint32_t num_probes, size_t huge_page_tlb_size,
+ Logger* logger)
+ // Round down, except round up with 1
+ : kNumDoubleProbes((num_probes + (num_probes == 1)) / 2) {
+ assert(num_probes % 2 == 0); // limitation of current implementation
+ assert(num_probes <= 10); // limitation of current implementation
+ assert(kNumDoubleProbes > 0);
+
+ // Determine how much to round off + align by so that x ^ i (that's xor) is
+ // a valid u64 index if x is a valid u64 index and 0 <= i < kNumDoubleProbes.
+ uint32_t block_bytes = /*bytes/u64*/ 8 *
+ /*u64s*/ std::max(1U, roundUpToPow2(kNumDoubleProbes));
+ uint32_t block_bits = block_bytes * 8;
+ uint32_t blocks = (total_bits + block_bits - 1) / block_bits;
+ uint32_t sz = blocks * block_bytes;
+ kLen = sz / /*bytes/u64*/ 8;
+ assert(kLen > 0);
+#ifndef NDEBUG
+ for (uint32_t i = 0; i < kNumDoubleProbes; ++i) {
+ // Ensure probes starting at last word are in range
+ assert(((kLen - 1) ^ i) < kLen);
+ }
+#endif
+
+ // Padding to correct for allocation not originally aligned on block_bytes
+ // boundary
+ sz += block_bytes - 1;
+ assert(allocator);
+
+ char* raw = allocator->AllocateAligned(sz, huge_page_tlb_size, logger);
+ memset(raw, 0, sz);
+ auto block_offset = reinterpret_cast<uintptr_t>(raw) % block_bytes;
+ if (block_offset > 0) {
+ // Align on block_bytes boundary
+ raw += block_bytes - block_offset;
+ }
+ static_assert(sizeof(std::atomic<uint64_t>) == sizeof(uint64_t),
+ "Expecting zero-space-overhead atomic");
+ data_ = reinterpret_cast<std::atomic<uint64_t>*>(raw);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/dynamic_bloom.h b/src/rocksdb/util/dynamic_bloom.h
new file mode 100644
index 000000000..40cd29404
--- /dev/null
+++ b/src/rocksdb/util/dynamic_bloom.h
@@ -0,0 +1,214 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <array>
+#include <atomic>
+#include <memory>
+#include <string>
+
+#include "port/port.h"
+#include "rocksdb/slice.h"
+#include "table/multiget_context.h"
+#include "util/hash.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class Slice;
+class Allocator;
+class Logger;
+
+// A Bloom filter intended only to be used in memory, never serialized in a way
+// that could lead to schema incompatibility. Supports opt-in lock-free
+// concurrent access.
+//
+// This implementation is also intended for applications generally preferring
+// speed vs. maximum accuracy: roughly 0.9x BF op latency for 1.1x FP rate.
+// For 1% FP rate, that means that the latency of a look-up triggered by an FP
+// should be less than roughly 100x the cost of a Bloom filter op.
+//
+// For simplicity and performance, the current implementation requires
+// num_probes to be a multiple of two and <= 10.
+//
+class DynamicBloom {
+ public:
+ // allocator: pass allocator to bloom filter, hence trace the usage of memory
+ // total_bits: fixed total bits for the bloom
+ // num_probes: number of hash probes for a single key
+ // hash_func: customized hash function
+ // huge_page_tlb_size: if >0, try to allocate bloom bytes from huge page TLB
+ // within this page size. Need to reserve huge pages for
+ // it to be allocated, like:
+ // sysctl -w vm.nr_hugepages=20
+ // See linux doc Documentation/vm/hugetlbpage.txt
+ explicit DynamicBloom(Allocator* allocator, uint32_t total_bits,
+ uint32_t num_probes = 6, size_t huge_page_tlb_size = 0,
+ Logger* logger = nullptr);
+
+ ~DynamicBloom() {}
+
+ // Assuming single threaded access to this function.
+ void Add(const Slice& key);
+
+ // Like Add, but may be called concurrent with other functions.
+ void AddConcurrently(const Slice& key);
+
+ // Assuming single threaded access to this function.
+ void AddHash(uint32_t hash);
+
+ // Like AddHash, but may be called concurrent with other functions.
+ void AddHashConcurrently(uint32_t hash);
+
+ // Multithreaded access to this function is OK
+ bool MayContain(const Slice& key) const;
+
+ void MayContain(int num_keys, Slice* keys, bool* may_match) const;
+
+ // Multithreaded access to this function is OK
+ bool MayContainHash(uint32_t hash) const;
+
+ void Prefetch(uint32_t h);
+
+ private:
+ // Length of the structure, in 64-bit words. For this structure, "word"
+ // will always refer to 64-bit words.
+ uint32_t kLen;
+ // We make the k probes in pairs, two for each 64-bit read/write. Thus,
+ // this stores k/2, the number of words to double-probe.
+ const uint32_t kNumDoubleProbes;
+
+ std::atomic<uint64_t>* data_;
+
+ // or_func(ptr, mask) should effect *ptr |= mask with the appropriate
+ // concurrency safety, working with bytes.
+ template <typename OrFunc>
+ void AddHash(uint32_t hash, const OrFunc& or_func);
+
+ bool DoubleProbe(uint32_t h32, size_t a) const;
+};
+
+inline void DynamicBloom::Add(const Slice& key) { AddHash(BloomHash(key)); }
+
+inline void DynamicBloom::AddConcurrently(const Slice& key) {
+ AddHashConcurrently(BloomHash(key));
+}
+
+inline void DynamicBloom::AddHash(uint32_t hash) {
+ AddHash(hash, [](std::atomic<uint64_t>* ptr, uint64_t mask) {
+ ptr->store(ptr->load(std::memory_order_relaxed) | mask,
+ std::memory_order_relaxed);
+ });
+}
+
+inline void DynamicBloom::AddHashConcurrently(uint32_t hash) {
+ AddHash(hash, [](std::atomic<uint64_t>* ptr, uint64_t mask) {
+ // Happens-before between AddHash and MaybeContains is handled by
+ // access to versions_->LastSequence(), so all we have to do here is
+ // avoid races (so we don't give the compiler a license to mess up
+ // our code) and not lose bits. std::memory_order_relaxed is enough
+ // for that.
+ if ((mask & ptr->load(std::memory_order_relaxed)) != mask) {
+ ptr->fetch_or(mask, std::memory_order_relaxed);
+ }
+ });
+}
+
+inline bool DynamicBloom::MayContain(const Slice& key) const {
+ return (MayContainHash(BloomHash(key)));
+}
+
+inline void DynamicBloom::MayContain(int num_keys, Slice* keys,
+ bool* may_match) const {
+ std::array<uint32_t, MultiGetContext::MAX_BATCH_SIZE> hashes;
+ std::array<size_t, MultiGetContext::MAX_BATCH_SIZE> byte_offsets;
+ for (int i = 0; i < num_keys; ++i) {
+ hashes[i] = BloomHash(keys[i]);
+ size_t a = FastRange32(kLen, hashes[i]);
+ PREFETCH(data_ + a, 0, 3);
+ byte_offsets[i] = a;
+ }
+
+ for (int i = 0; i < num_keys; i++) {
+ may_match[i] = DoubleProbe(hashes[i], byte_offsets[i]);
+ }
+}
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+// local variable is initialized but not referenced
+#pragma warning(disable : 4189)
+#endif
+inline void DynamicBloom::Prefetch(uint32_t h32) {
+ size_t a = FastRange32(kLen, h32);
+ PREFETCH(data_ + a, 0, 3);
+}
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+// Speed hacks in this implementation:
+// * Uses fastrange instead of %
+// * Minimum logic to determine first (and all) probed memory addresses.
+// (Uses constant bit-xor offsets from the starting probe address.)
+// * (Major) Two probes per 64-bit memory fetch/write.
+// Code simplification / optimization: only allow even number of probes.
+// * Very fast and effective (murmur-like) hash expansion/re-mixing. (At
+// least on recent CPUs, integer multiplication is very cheap. Each 64-bit
+// remix provides five pairs of bit addresses within a uint64_t.)
+// Code simplification / optimization: only allow up to 10 probes, from a
+// single 64-bit remix.
+//
+// The FP rate penalty for this implementation, vs. standard Bloom filter, is
+// roughly 1.12x on top of the 1.15x penalty for a 512-bit cache-local Bloom.
+// This implementation does not explicitly use the cache line size, but is
+// effectively cache-local (up to 16 probes) because of the bit-xor offsetting.
+//
+// NB: could easily be upgraded to support a 64-bit hash and
+// total_bits > 2^32 (512MB). (The latter is a bad idea without the former,
+// because of false positives.)
+
+inline bool DynamicBloom::MayContainHash(uint32_t h32) const {
+ size_t a = FastRange32(kLen, h32);
+ PREFETCH(data_ + a, 0, 3);
+ return DoubleProbe(h32, a);
+}
+
+inline bool DynamicBloom::DoubleProbe(uint32_t h32, size_t byte_offset) const {
+ // Expand/remix with 64-bit golden ratio
+ uint64_t h = 0x9e3779b97f4a7c13ULL * h32;
+ for (unsigned i = 0;; ++i) {
+ // Two bit probes per uint64_t probe
+ uint64_t mask =
+ ((uint64_t)1 << (h & 63)) | ((uint64_t)1 << ((h >> 6) & 63));
+ uint64_t val = data_[byte_offset ^ i].load(std::memory_order_relaxed);
+ if (i + 1 >= kNumDoubleProbes) {
+ return (val & mask) == mask;
+ } else if ((val & mask) != mask) {
+ return false;
+ }
+ h = (h >> 12) | (h << 52);
+ }
+}
+
+template <typename OrFunc>
+inline void DynamicBloom::AddHash(uint32_t h32, const OrFunc& or_func) {
+ size_t a = FastRange32(kLen, h32);
+ PREFETCH(data_ + a, 0, 3);
+ // Expand/remix with 64-bit golden ratio
+ uint64_t h = 0x9e3779b97f4a7c13ULL * h32;
+ for (unsigned i = 0;; ++i) {
+ // Two bit probes per uint64_t probe
+ uint64_t mask =
+ ((uint64_t)1 << (h & 63)) | ((uint64_t)1 << ((h >> 6) & 63));
+ or_func(&data_[a ^ i], mask);
+ if (i + 1 >= kNumDoubleProbes) {
+ return;
+ }
+ h = (h >> 12) | (h << 52);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/dynamic_bloom_test.cc b/src/rocksdb/util/dynamic_bloom_test.cc
new file mode 100644
index 000000000..925c5479a
--- /dev/null
+++ b/src/rocksdb/util/dynamic_bloom_test.cc
@@ -0,0 +1,325 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run this test... Skipping...\n");
+ return 0;
+}
+#else
+
+#include <algorithm>
+#include <atomic>
+#include <cinttypes>
+#include <functional>
+#include <memory>
+#include <thread>
+#include <vector>
+
+#include "dynamic_bloom.h"
+#include "memory/arena.h"
+#include "port/port.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/gflags_compat.h"
+#include "util/stop_watch.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+
+DEFINE_int32(bits_per_key, 10, "");
+DEFINE_int32(num_probes, 6, "");
+DEFINE_bool(enable_perf, false, "");
+
+namespace ROCKSDB_NAMESPACE {
+
+struct KeyMaker {
+ uint64_t a;
+ uint64_t b;
+
+ // Sequential, within a hash function block
+ inline Slice Seq(uint64_t i) {
+ a = i;
+ return Slice(reinterpret_cast<char *>(&a), sizeof(a));
+ }
+ // Not quite sequential, varies across hash function blocks
+ inline Slice Nonseq(uint64_t i) {
+ a = i;
+ b = i * 123;
+ return Slice(reinterpret_cast<char *>(this), sizeof(*this));
+ }
+ inline Slice Key(uint64_t i, bool nonseq) {
+ return nonseq ? Nonseq(i) : Seq(i);
+ }
+};
+
+class DynamicBloomTest : public testing::Test {};
+
+TEST_F(DynamicBloomTest, EmptyFilter) {
+ Arena arena;
+ DynamicBloom bloom1(&arena, 100, 2);
+ ASSERT_TRUE(!bloom1.MayContain("hello"));
+ ASSERT_TRUE(!bloom1.MayContain("world"));
+
+ DynamicBloom bloom2(&arena, CACHE_LINE_SIZE * 8 * 2 - 1, 2);
+ ASSERT_TRUE(!bloom2.MayContain("hello"));
+ ASSERT_TRUE(!bloom2.MayContain("world"));
+}
+
+TEST_F(DynamicBloomTest, Small) {
+ Arena arena;
+ DynamicBloom bloom1(&arena, 100, 2);
+ bloom1.Add("hello");
+ bloom1.Add("world");
+ ASSERT_TRUE(bloom1.MayContain("hello"));
+ ASSERT_TRUE(bloom1.MayContain("world"));
+ ASSERT_TRUE(!bloom1.MayContain("x"));
+ ASSERT_TRUE(!bloom1.MayContain("foo"));
+
+ DynamicBloom bloom2(&arena, CACHE_LINE_SIZE * 8 * 2 - 1, 2);
+ bloom2.Add("hello");
+ bloom2.Add("world");
+ ASSERT_TRUE(bloom2.MayContain("hello"));
+ ASSERT_TRUE(bloom2.MayContain("world"));
+ ASSERT_TRUE(!bloom2.MayContain("x"));
+ ASSERT_TRUE(!bloom2.MayContain("foo"));
+}
+
+TEST_F(DynamicBloomTest, SmallConcurrentAdd) {
+ Arena arena;
+ DynamicBloom bloom1(&arena, 100, 2);
+ bloom1.AddConcurrently("hello");
+ bloom1.AddConcurrently("world");
+ ASSERT_TRUE(bloom1.MayContain("hello"));
+ ASSERT_TRUE(bloom1.MayContain("world"));
+ ASSERT_TRUE(!bloom1.MayContain("x"));
+ ASSERT_TRUE(!bloom1.MayContain("foo"));
+
+ DynamicBloom bloom2(&arena, CACHE_LINE_SIZE * 8 * 2 - 1, 2);
+ bloom2.AddConcurrently("hello");
+ bloom2.AddConcurrently("world");
+ ASSERT_TRUE(bloom2.MayContain("hello"));
+ ASSERT_TRUE(bloom2.MayContain("world"));
+ ASSERT_TRUE(!bloom2.MayContain("x"));
+ ASSERT_TRUE(!bloom2.MayContain("foo"));
+}
+
+static uint32_t NextNum(uint32_t num) {
+ if (num < 10) {
+ num += 1;
+ } else if (num < 100) {
+ num += 10;
+ } else if (num < 1000) {
+ num += 100;
+ } else {
+ num = num * 26 / 10;
+ }
+ return num;
+}
+
+TEST_F(DynamicBloomTest, VaryingLengths) {
+ KeyMaker km;
+
+ // Count number of filters that significantly exceed the false positive rate
+ int mediocre_filters = 0;
+ int good_filters = 0;
+ uint32_t num_probes = static_cast<uint32_t>(FLAGS_num_probes);
+
+ fprintf(stderr, "bits_per_key: %d num_probes: %d\n", FLAGS_bits_per_key,
+ num_probes);
+
+ // NB: FP rate impact of 32-bit hash is noticeable starting around 10M keys.
+ // But that effect is hidden if using sequential keys (unique hashes).
+ for (bool nonseq : {false, true}) {
+ const uint32_t max_num = FLAGS_enable_perf ? 40000000 : 400000;
+ for (uint32_t num = 1; num <= max_num; num = NextNum(num)) {
+ uint32_t bloom_bits = 0;
+ Arena arena;
+ bloom_bits = num * FLAGS_bits_per_key;
+ DynamicBloom bloom(&arena, bloom_bits, num_probes);
+ for (uint64_t i = 0; i < num; i++) {
+ bloom.Add(km.Key(i, nonseq));
+ ASSERT_TRUE(bloom.MayContain(km.Key(i, nonseq)));
+ }
+
+ // All added keys must match
+ for (uint64_t i = 0; i < num; i++) {
+ ASSERT_TRUE(bloom.MayContain(km.Key(i, nonseq)));
+ }
+
+ // Check false positive rate
+ int result = 0;
+ for (uint64_t i = 0; i < 30000; i++) {
+ if (bloom.MayContain(km.Key(i + 1000000000, nonseq))) {
+ result++;
+ }
+ }
+ double rate = result / 30000.0;
+
+ fprintf(stderr,
+ "False positives (%s keys): "
+ "%5.2f%% @ num = %6u, bloom_bits = %6u\n",
+ nonseq ? "nonseq" : "seq", rate * 100.0, num, bloom_bits);
+
+ if (rate > 0.0125)
+ mediocre_filters++; // Allowed, but not too often
+ else
+ good_filters++;
+ }
+ }
+
+ fprintf(stderr, "Filters: %d good, %d mediocre\n", good_filters,
+ mediocre_filters);
+ ASSERT_LE(mediocre_filters, good_filters / 25);
+}
+
+TEST_F(DynamicBloomTest, perf) {
+ KeyMaker km;
+ StopWatchNano timer(SystemClock::Default().get());
+ uint32_t num_probes = static_cast<uint32_t>(FLAGS_num_probes);
+
+ if (!FLAGS_enable_perf) {
+ return;
+ }
+
+ for (uint32_t m = 1; m <= 8; ++m) {
+ Arena arena;
+ const uint32_t num_keys = m * 8 * 1024 * 1024;
+ fprintf(stderr, "testing %" PRIu32 "M keys\n", m * 8);
+
+ DynamicBloom std_bloom(&arena, num_keys * 10, num_probes);
+
+ timer.Start();
+ for (uint64_t i = 1; i <= num_keys; ++i) {
+ std_bloom.Add(km.Seq(i));
+ }
+
+ uint64_t elapsed = timer.ElapsedNanos();
+ fprintf(stderr, "dynamic bloom, avg add latency %3g\n",
+ static_cast<double>(elapsed) / num_keys);
+
+ uint32_t count = 0;
+ timer.Start();
+ for (uint64_t i = 1; i <= num_keys; ++i) {
+ if (std_bloom.MayContain(km.Seq(i))) {
+ ++count;
+ }
+ }
+ ASSERT_EQ(count, num_keys);
+ elapsed = timer.ElapsedNanos();
+ assert(count > 0);
+ fprintf(stderr, "dynamic bloom, avg query latency %3g\n",
+ static_cast<double>(elapsed) / count);
+ }
+}
+
+TEST_F(DynamicBloomTest, concurrent_with_perf) {
+ uint32_t num_probes = static_cast<uint32_t>(FLAGS_num_probes);
+
+ uint32_t m_limit = FLAGS_enable_perf ? 8 : 1;
+
+ uint32_t num_threads = 4;
+ std::vector<port::Thread> threads;
+
+ // NB: Uses sequential keys for speed, but that hides the FP rate
+ // impact of 32-bit hash, which is noticeable starting around 10M keys
+ // when they vary across hashing blocks.
+ for (uint32_t m = 1; m <= m_limit; ++m) {
+ Arena arena;
+ const uint32_t num_keys = m * 8 * 1024 * 1024;
+ fprintf(stderr, "testing %" PRIu32 "M keys\n", m * 8);
+
+ DynamicBloom std_bloom(&arena, num_keys * 10, num_probes);
+
+ std::atomic<uint64_t> elapsed(0);
+
+ std::function<void(size_t)> adder([&](size_t t) {
+ KeyMaker km;
+ StopWatchNano timer(SystemClock::Default().get());
+ timer.Start();
+ for (uint64_t i = 1 + t; i <= num_keys; i += num_threads) {
+ std_bloom.AddConcurrently(km.Seq(i));
+ }
+ elapsed += timer.ElapsedNanos();
+ });
+ for (size_t t = 0; t < num_threads; ++t) {
+ threads.emplace_back(adder, t);
+ }
+ while (threads.size() > 0) {
+ threads.back().join();
+ threads.pop_back();
+ }
+
+ fprintf(stderr,
+ "dynamic bloom, avg parallel add latency %3g"
+ " nanos/key\n",
+ static_cast<double>(elapsed) / num_threads / num_keys);
+
+ elapsed = 0;
+ std::function<void(size_t)> hitter([&](size_t t) {
+ KeyMaker km;
+ StopWatchNano timer(SystemClock::Default().get());
+ timer.Start();
+ for (uint64_t i = 1 + t; i <= num_keys; i += num_threads) {
+ bool f = std_bloom.MayContain(km.Seq(i));
+ ASSERT_TRUE(f);
+ }
+ elapsed += timer.ElapsedNanos();
+ });
+ for (size_t t = 0; t < num_threads; ++t) {
+ threads.emplace_back(hitter, t);
+ }
+ while (threads.size() > 0) {
+ threads.back().join();
+ threads.pop_back();
+ }
+
+ fprintf(stderr,
+ "dynamic bloom, avg parallel hit latency %3g"
+ " nanos/key\n",
+ static_cast<double>(elapsed) / num_threads / num_keys);
+
+ elapsed = 0;
+ std::atomic<uint32_t> false_positives(0);
+ std::function<void(size_t)> misser([&](size_t t) {
+ KeyMaker km;
+ StopWatchNano timer(SystemClock::Default().get());
+ timer.Start();
+ for (uint64_t i = num_keys + 1 + t; i <= 2 * num_keys; i += num_threads) {
+ bool f = std_bloom.MayContain(km.Seq(i));
+ if (f) {
+ ++false_positives;
+ }
+ }
+ elapsed += timer.ElapsedNanos();
+ });
+ for (size_t t = 0; t < num_threads; ++t) {
+ threads.emplace_back(misser, t);
+ }
+ while (threads.size() > 0) {
+ threads.back().join();
+ threads.pop_back();
+ }
+
+ fprintf(stderr,
+ "dynamic bloom, avg parallel miss latency %3g"
+ " nanos/key, %f%% false positive rate\n",
+ static_cast<double>(elapsed) / num_threads / num_keys,
+ false_positives.load() * 100.0 / num_keys);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char **argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ return RUN_ALL_TESTS();
+}
+
+#endif // GFLAGS
diff --git a/src/rocksdb/util/fastrange.h b/src/rocksdb/util/fastrange.h
new file mode 100644
index 000000000..a70a980f6
--- /dev/null
+++ b/src/rocksdb/util/fastrange.h
@@ -0,0 +1,114 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// fastrange/FastRange: A faster alternative to % for mapping a hash value
+// to an arbitrary range. See https://github.com/lemire/fastrange
+//
+// Generally recommended are FastRange32 for mapping results of 32-bit
+// hash functions and FastRange64 for mapping results of 64-bit hash
+// functions. FastRange is less forgiving than % if the input hashes are
+// not well distributed over the full range of the type (32 or 64 bits).
+//
+// Also included is a templated implementation FastRangeGeneric for use
+// in generic algorithms, but not otherwise recommended because of
+// potential ambiguity. Unlike with %, it is critical to use the right
+// FastRange variant for the output size of your hash function.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+#ifdef TEST_UINT128_COMPAT
+#undef HAVE_UINT128_EXTENSION
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace detail {
+
+// Using a class template to support partial specialization
+template <typename Hash, typename Range>
+struct FastRangeGenericImpl {
+ // only reach this on no supported specialization
+};
+
+template <typename Range>
+struct FastRangeGenericImpl<uint32_t, Range> {
+ static inline Range Fn(uint32_t hash, Range range) {
+ static_assert(std::is_unsigned<Range>::value, "must be unsigned");
+ static_assert(sizeof(Range) <= sizeof(uint32_t),
+ "cannot be larger than hash (32 bits)");
+
+ uint64_t product = uint64_t{range} * hash;
+ return static_cast<Range>(product >> 32);
+ }
+};
+
+template <typename Range>
+struct FastRangeGenericImpl<uint64_t, Range> {
+ static inline Range Fn(uint64_t hash, Range range) {
+ static_assert(std::is_unsigned<Range>::value, "must be unsigned");
+ static_assert(sizeof(Range) <= sizeof(uint64_t),
+ "cannot be larger than hash (64 bits)");
+
+#ifdef HAVE_UINT128_EXTENSION
+ // Can use compiler's 128-bit type. Trust it to do the right thing.
+ __uint128_t wide = __uint128_t{range} * hash;
+ return static_cast<Range>(wide >> 64);
+#else
+ // Fall back: full decomposition.
+ // NOTE: GCC seems to fully understand this code as 64-bit x 64-bit
+ // -> 128-bit multiplication and optimize it appropriately
+ uint64_t range64 = range; // ok to shift by 32, even if Range is 32-bit
+ uint64_t tmp = uint64_t{range64 & 0xffffFFFF} * uint64_t{hash & 0xffffFFFF};
+ tmp >>= 32;
+ tmp += uint64_t{range64 & 0xffffFFFF} * uint64_t{hash >> 32};
+ // Avoid overflow: first add lower 32 of tmp2, and later upper 32
+ uint64_t tmp2 = uint64_t{range64 >> 32} * uint64_t{hash & 0xffffFFFF};
+ tmp += static_cast<uint32_t>(tmp2);
+ tmp >>= 32;
+ tmp += (tmp2 >> 32);
+ tmp += uint64_t{range64 >> 32} * uint64_t{hash >> 32};
+ return static_cast<Range>(tmp);
+#endif
+ }
+};
+
+} // namespace detail
+
+// Now an omnibus templated function (yay parameter inference).
+//
+// NOTICE:
+// This templated version is not recommended for typical use because
+// of the potential to mix a 64-bit FastRange with a 32-bit bit hash,
+// most likely because you put your 32-bit hash in an "unsigned long"
+// which is 64 bits on some platforms. That doesn't really matter for
+// an operation like %, but 64-bit FastRange gives extremely bad results,
+// mostly zero, on 32-bit hash values. And because good hashing is not
+// generally required for correctness, this kind of mistake could go
+// unnoticed with just unit tests. Plus it could vary by platform.
+template <typename Hash, typename Range>
+inline Range FastRangeGeneric(Hash hash, Range range) {
+ return detail::FastRangeGenericImpl<Hash, Range>::Fn(hash, range);
+}
+
+// The most popular / convenient / recommended variants:
+
+// Map a quality 64-bit hash value down to an arbitrary size_t range.
+// (size_t is standard for mapping to things in memory.)
+inline size_t FastRange64(uint64_t hash, size_t range) {
+ return FastRangeGeneric(hash, range);
+}
+
+// Map a quality 32-bit hash value down to an arbitrary uint32_t range.
+inline uint32_t FastRange32(uint32_t hash, uint32_t range) {
+ return FastRangeGeneric(hash, range);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/file_checksum_helper.cc b/src/rocksdb/util/file_checksum_helper.cc
new file mode 100644
index 000000000..a73920352
--- /dev/null
+++ b/src/rocksdb/util/file_checksum_helper.cc
@@ -0,0 +1,172 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/file_checksum_helper.h"
+
+#include <unordered_set>
+
+#include "db/log_reader.h"
+#include "db/version_edit.h"
+#include "db/version_edit_handler.h"
+#include "file/sequence_file_reader.h"
+#include "rocksdb/utilities/customizable_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+void FileChecksumListImpl::reset() { checksum_map_.clear(); }
+
+size_t FileChecksumListImpl::size() const { return checksum_map_.size(); }
+
+Status FileChecksumListImpl::GetAllFileChecksums(
+ std::vector<uint64_t>* file_numbers, std::vector<std::string>* checksums,
+ std::vector<std::string>* checksum_func_names) {
+ if (file_numbers == nullptr || checksums == nullptr ||
+ checksum_func_names == nullptr) {
+ return Status::InvalidArgument("Pointer has not been initiated");
+ }
+
+ for (auto i : checksum_map_) {
+ file_numbers->push_back(i.first);
+ checksums->push_back(i.second.first);
+ checksum_func_names->push_back(i.second.second);
+ }
+ return Status::OK();
+}
+
+Status FileChecksumListImpl::SearchOneFileChecksum(
+ uint64_t file_number, std::string* checksum,
+ std::string* checksum_func_name) {
+ if (checksum == nullptr || checksum_func_name == nullptr) {
+ return Status::InvalidArgument("Pointer has not been initiated");
+ }
+
+ auto it = checksum_map_.find(file_number);
+ if (it == checksum_map_.end()) {
+ return Status::NotFound();
+ } else {
+ *checksum = it->second.first;
+ *checksum_func_name = it->second.second;
+ }
+ return Status::OK();
+}
+
+Status FileChecksumListImpl::InsertOneFileChecksum(
+ uint64_t file_number, const std::string& checksum,
+ const std::string& checksum_func_name) {
+ auto it = checksum_map_.find(file_number);
+ if (it == checksum_map_.end()) {
+ checksum_map_.insert(std::make_pair(
+ file_number, std::make_pair(checksum, checksum_func_name)));
+ } else {
+ it->second.first = checksum;
+ it->second.second = checksum_func_name;
+ }
+ return Status::OK();
+}
+
+Status FileChecksumListImpl::RemoveOneFileChecksum(uint64_t file_number) {
+ auto it = checksum_map_.find(file_number);
+ if (it == checksum_map_.end()) {
+ return Status::NotFound();
+ } else {
+ checksum_map_.erase(it);
+ }
+ return Status::OK();
+}
+
+FileChecksumList* NewFileChecksumList() {
+ FileChecksumListImpl* checksum_list = new FileChecksumListImpl();
+ return checksum_list;
+}
+
+std::shared_ptr<FileChecksumGenFactory> GetFileChecksumGenCrc32cFactory() {
+ static std::shared_ptr<FileChecksumGenFactory> default_crc32c_gen_factory(
+ new FileChecksumGenCrc32cFactory());
+ return default_crc32c_gen_factory;
+}
+
+Status GetFileChecksumsFromManifest(Env* src_env, const std::string& abs_path,
+ uint64_t manifest_file_size,
+ FileChecksumList* checksum_list) {
+ if (checksum_list == nullptr) {
+ return Status::InvalidArgument("checksum_list is nullptr");
+ }
+ assert(checksum_list);
+ checksum_list->reset();
+ Status s;
+
+ std::unique_ptr<SequentialFileReader> file_reader;
+ {
+ std::unique_ptr<FSSequentialFile> file;
+ const std::shared_ptr<FileSystem>& fs = src_env->GetFileSystem();
+ s = fs->NewSequentialFile(abs_path,
+ fs->OptimizeForManifestRead(FileOptions()), &file,
+ nullptr /* dbg */);
+ if (!s.ok()) {
+ return s;
+ }
+ file_reader.reset(new SequentialFileReader(std::move(file), abs_path));
+ }
+
+ struct LogReporter : public log::Reader::Reporter {
+ Status* status_ptr;
+ virtual void Corruption(size_t /*bytes*/, const Status& st) override {
+ if (status_ptr->ok()) {
+ *status_ptr = st;
+ }
+ }
+ } reporter;
+ reporter.status_ptr = &s;
+ log::Reader reader(nullptr, std::move(file_reader), &reporter,
+ true /* checksum */, 0 /* log_number */);
+ FileChecksumRetriever retriever(manifest_file_size, *checksum_list);
+ retriever.Iterate(reader, &s);
+ assert(!retriever.status().ok() ||
+ manifest_file_size == std::numeric_limits<uint64_t>::max() ||
+ reader.LastRecordEnd() == manifest_file_size);
+
+ return retriever.status();
+}
+
+#ifndef ROCKSDB_LITE
+namespace {
+static int RegisterFileChecksumGenFactories(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<FileChecksumGenFactory>(
+ FileChecksumGenCrc32cFactory::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<FileChecksumGenFactory>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new FileChecksumGenCrc32cFactory());
+ return guard->get();
+ });
+ return 1;
+}
+} // namespace
+#endif // !ROCKSDB_LITE
+
+Status FileChecksumGenFactory::CreateFromString(
+ const ConfigOptions& options, const std::string& value,
+ std::shared_ptr<FileChecksumGenFactory>* result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterFileChecksumGenFactories(*(ObjectLibrary::Default().get()), "");
+ });
+#endif // ROCKSDB_LITE
+ if (value == FileChecksumGenCrc32cFactory::kClassName()) {
+ *result = GetFileChecksumGenCrc32cFactory();
+ return Status::OK();
+ } else {
+ Status s = LoadSharedObject<FileChecksumGenFactory>(options, value, nullptr,
+ result);
+ return s;
+ }
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/file_checksum_helper.h b/src/rocksdb/util/file_checksum_helper.h
new file mode 100644
index 000000000..d622e9bba
--- /dev/null
+++ b/src/rocksdb/util/file_checksum_helper.h
@@ -0,0 +1,100 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <cassert>
+#include <unordered_map>
+
+#include "port/port.h"
+#include "rocksdb/file_checksum.h"
+#include "rocksdb/status.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/math.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// This is the class to generate the file checksum based on Crc32. It
+// will be used as the default checksum method for SST file checksum
+class FileChecksumGenCrc32c : public FileChecksumGenerator {
+ public:
+ FileChecksumGenCrc32c(const FileChecksumGenContext& /*context*/) {
+ checksum_ = 0;
+ }
+
+ void Update(const char* data, size_t n) override {
+ checksum_ = crc32c::Extend(checksum_, data, n);
+ }
+
+ void Finalize() override {
+ assert(checksum_str_.empty());
+ // Store as big endian raw bytes
+ PutFixed32(&checksum_str_, EndianSwapValue(checksum_));
+ }
+
+ std::string GetChecksum() const override {
+ assert(!checksum_str_.empty());
+ return checksum_str_;
+ }
+
+ const char* Name() const override { return "FileChecksumCrc32c"; }
+
+ private:
+ uint32_t checksum_;
+ std::string checksum_str_;
+};
+
+class FileChecksumGenCrc32cFactory : public FileChecksumGenFactory {
+ public:
+ std::unique_ptr<FileChecksumGenerator> CreateFileChecksumGenerator(
+ const FileChecksumGenContext& context) override {
+ if (context.requested_checksum_func_name.empty() ||
+ context.requested_checksum_func_name == "FileChecksumCrc32c") {
+ return std::unique_ptr<FileChecksumGenerator>(
+ new FileChecksumGenCrc32c(context));
+ } else {
+ return nullptr;
+ }
+ }
+
+ static const char* kClassName() { return "FileChecksumGenCrc32cFactory"; }
+ const char* Name() const override { return kClassName(); }
+};
+
+// The default implementaion of FileChecksumList
+class FileChecksumListImpl : public FileChecksumList {
+ public:
+ FileChecksumListImpl() {}
+ void reset() override;
+
+ size_t size() const override;
+
+ Status GetAllFileChecksums(
+ std::vector<uint64_t>* file_numbers, std::vector<std::string>* checksums,
+ std::vector<std::string>* checksum_func_names) override;
+
+ Status SearchOneFileChecksum(uint64_t file_number, std::string* checksum,
+ std::string* checksum_func_name) override;
+
+ Status InsertOneFileChecksum(uint64_t file_number,
+ const std::string& checksum,
+ const std::string& checksum_func_name) override;
+
+ Status RemoveOneFileChecksum(uint64_t file_number) override;
+
+ private:
+ // Key is the file number, the first portion of the value is checksum, the
+ // second portion of the value is checksum function name.
+ std::unordered_map<uint64_t, std::pair<std::string, std::string>>
+ checksum_map_;
+};
+
+// If manifest_file_size < std::numeric_limits<uint64_t>::max(), only use
+// that length prefix of the manifest file.
+Status GetFileChecksumsFromManifest(Env* src_env, const std::string& abs_path,
+ uint64_t manifest_file_size,
+ FileChecksumList* checksum_list);
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/file_reader_writer_test.cc b/src/rocksdb/util/file_reader_writer_test.cc
new file mode 100644
index 000000000..e778efc3c
--- /dev/null
+++ b/src/rocksdb/util/file_reader_writer_test.cc
@@ -0,0 +1,1066 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include <algorithm>
+#include <vector>
+
+#include "db/db_test_util.h"
+#include "env/mock_env.h"
+#include "file/line_file_reader.h"
+#include "file/random_access_file_reader.h"
+#include "file/read_write_util.h"
+#include "file/readahead_raf.h"
+#include "file/sequence_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "rocksdb/file_system.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/crc32c.h"
+#include "util/random.h"
+#include "utilities/fault_injection_fs.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WritableFileWriterTest : public testing::Test {};
+
+constexpr uint32_t kMb = static_cast<uint32_t>(1) << 20;
+
+TEST_F(WritableFileWriterTest, RangeSync) {
+ class FakeWF : public FSWritableFile {
+ public:
+ explicit FakeWF() : size_(0), last_synced_(0) {}
+ ~FakeWF() override {}
+
+ using FSWritableFile::Append;
+ IOStatus Append(const Slice& data, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ size_ += data.size();
+ return IOStatus::OK();
+ }
+ IOStatus Truncate(uint64_t /*size*/, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Close(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_GE(size_, last_synced_ + kMb);
+ EXPECT_LT(size_, last_synced_ + 2 * kMb);
+ // Make sure random writes generated enough writes.
+ EXPECT_GT(size_, 10 * kMb);
+ return IOStatus::OK();
+ }
+ IOStatus Flush(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Sync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Fsync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ void SetIOPriority(Env::IOPriority /*pri*/) override {}
+ uint64_t GetFileSize(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return size_;
+ }
+ void GetPreallocationStatus(size_t* /*block_size*/,
+ size_t* /*last_allocated_block*/) override {}
+ size_t GetUniqueId(char* /*id*/, size_t /*max_size*/) const override {
+ return 0;
+ }
+ IOStatus InvalidateCache(size_t /*offset*/, size_t /*length*/) override {
+ return IOStatus::OK();
+ }
+
+ protected:
+ IOStatus Allocate(uint64_t /*offset*/, uint64_t /*len*/,
+ const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus RangeSync(uint64_t offset, uint64_t nbytes,
+ const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(offset % 4096, 0u);
+ EXPECT_EQ(nbytes % 4096, 0u);
+
+ EXPECT_EQ(offset, last_synced_);
+ last_synced_ = offset + nbytes;
+ EXPECT_GE(size_, last_synced_ + kMb);
+ if (size_ > 2 * kMb) {
+ EXPECT_LT(size_, last_synced_ + 2 * kMb);
+ }
+ return IOStatus::OK();
+ }
+
+ uint64_t size_;
+ uint64_t last_synced_;
+ };
+
+ EnvOptions env_options;
+ env_options.bytes_per_sync = kMb;
+ std::unique_ptr<FakeWF> wf(new FakeWF);
+ std::unique_ptr<WritableFileWriter> writer(
+ new WritableFileWriter(std::move(wf), "" /* don't care */, env_options));
+ Random r(301);
+ Status s;
+ std::unique_ptr<char[]> large_buf(new char[10 * kMb]);
+ for (int i = 0; i < 1000; i++) {
+ int skew_limit = (i < 700) ? 10 : 15;
+ uint32_t num = r.Skewed(skew_limit) * 100 + r.Uniform(100);
+ s = writer->Append(Slice(large_buf.get(), num));
+ ASSERT_OK(s);
+
+ // Flush in a chance of 1/10.
+ if (r.Uniform(10) == 0) {
+ s = writer->Flush();
+ ASSERT_OK(s);
+ }
+ }
+ s = writer->Close();
+ ASSERT_OK(s);
+}
+
+TEST_F(WritableFileWriterTest, IncrementalBuffer) {
+ class FakeWF : public FSWritableFile {
+ public:
+ explicit FakeWF(std::string* _file_data, bool _use_direct_io,
+ bool _no_flush)
+ : file_data_(_file_data),
+ use_direct_io_(_use_direct_io),
+ no_flush_(_no_flush) {}
+ ~FakeWF() override {}
+
+ using FSWritableFile::Append;
+ IOStatus Append(const Slice& data, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ file_data_->append(data.data(), data.size());
+ size_ += data.size();
+ return IOStatus::OK();
+ }
+ using FSWritableFile::PositionedAppend;
+ IOStatus PositionedAppend(const Slice& data, uint64_t pos,
+ const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_TRUE(pos % 512 == 0);
+ EXPECT_TRUE(data.size() % 512 == 0);
+ file_data_->resize(pos);
+ file_data_->append(data.data(), data.size());
+ size_ += data.size();
+ return IOStatus::OK();
+ }
+
+ IOStatus Truncate(uint64_t size, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ file_data_->resize(size);
+ return IOStatus::OK();
+ }
+ IOStatus Close(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Flush(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Sync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Fsync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ void SetIOPriority(Env::IOPriority /*pri*/) override {}
+ uint64_t GetFileSize(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return size_;
+ }
+ void GetPreallocationStatus(size_t* /*block_size*/,
+ size_t* /*last_allocated_block*/) override {}
+ size_t GetUniqueId(char* /*id*/, size_t /*max_size*/) const override {
+ return 0;
+ }
+ IOStatus InvalidateCache(size_t /*offset*/, size_t /*length*/) override {
+ return IOStatus::OK();
+ }
+ bool use_direct_io() const override { return use_direct_io_; }
+
+ std::string* file_data_;
+ bool use_direct_io_;
+ bool no_flush_;
+ size_t size_ = 0;
+ };
+
+ Random r(301);
+ const int kNumAttempts = 50;
+ for (int attempt = 0; attempt < kNumAttempts; attempt++) {
+ bool no_flush = (attempt % 3 == 0);
+ EnvOptions env_options;
+ env_options.writable_file_max_buffer_size =
+ (attempt < kNumAttempts / 2) ? 512 * 1024 : 700 * 1024;
+ std::string actual;
+ std::unique_ptr<FakeWF> wf(new FakeWF(&actual,
+#ifndef ROCKSDB_LITE
+ attempt % 2 == 1,
+#else
+ false,
+#endif
+ no_flush));
+ std::unique_ptr<WritableFileWriter> writer(new WritableFileWriter(
+ std::move(wf), "" /* don't care */, env_options));
+
+ std::string target;
+ for (int i = 0; i < 20; i++) {
+ uint32_t num = r.Skewed(16) * 100 + r.Uniform(100);
+ std::string random_string = r.RandomString(num);
+ ASSERT_OK(writer->Append(Slice(random_string.c_str(), num)));
+ target.append(random_string.c_str(), num);
+
+ // In some attempts, flush in a chance of 1/10.
+ if (!no_flush && r.Uniform(10) == 0) {
+ ASSERT_OK(writer->Flush());
+ }
+ }
+ ASSERT_OK(writer->Flush());
+ ASSERT_OK(writer->Close());
+ ASSERT_EQ(target.size(), actual.size());
+ ASSERT_EQ(target, actual);
+ }
+}
+
+TEST_F(WritableFileWriterTest, BufferWithZeroCapacityDirectIO) {
+ EnvOptions env_opts;
+ env_opts.use_direct_writes = true;
+ env_opts.writable_file_max_buffer_size = 0;
+ {
+ std::unique_ptr<WritableFileWriter> writer;
+ const Status s =
+ WritableFileWriter::Create(FileSystem::Default(), /*fname=*/"dont_care",
+ FileOptions(env_opts), &writer,
+ /*dbg=*/nullptr);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ }
+}
+
+class DBWritableFileWriterTest : public DBTestBase {
+ public:
+ DBWritableFileWriterTest()
+ : DBTestBase("db_secondary_cache_test", /*env_do_fsync=*/true) {
+ fault_fs_.reset(new FaultInjectionTestFS(env_->GetFileSystem()));
+ fault_env_.reset(new CompositeEnvWrapper(env_, fault_fs_));
+ }
+
+ std::shared_ptr<FaultInjectionTestFS> fault_fs_;
+ std::unique_ptr<Env> fault_env_;
+};
+
+TEST_F(DBWritableFileWriterTest, AppendWithChecksum) {
+ FileOptions file_options = FileOptions();
+ Options options = GetDefaultOptions();
+ options.create_if_missing = true;
+ DestroyAndReopen(options);
+ std::string fname = dbname_ + "/test_file";
+ std::unique_ptr<FSWritableFile> writable_file_ptr;
+ ASSERT_OK(fault_fs_->NewWritableFile(fname, file_options, &writable_file_ptr,
+ /*dbg*/ nullptr));
+ std::unique_ptr<TestFSWritableFile> file;
+ file.reset(new TestFSWritableFile(
+ fname, file_options, std::move(writable_file_ptr), fault_fs_.get()));
+ std::unique_ptr<WritableFileWriter> file_writer;
+ ImmutableOptions ioptions(options);
+ file_writer.reset(new WritableFileWriter(
+ std::move(file), fname, file_options, SystemClock::Default().get(),
+ nullptr, ioptions.stats, ioptions.listeners,
+ ioptions.file_checksum_gen_factory.get(), true, true));
+
+ Random rnd(301);
+ std::string data = rnd.RandomString(1000);
+ uint32_t data_crc32c = crc32c::Value(data.c_str(), data.size());
+ fault_fs_->SetChecksumHandoffFuncType(ChecksumType::kCRC32c);
+
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+ ASSERT_OK(file_writer->Flush());
+ Random size_r(47);
+ for (int i = 0; i < 2000; i++) {
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 10000));
+ data_crc32c = crc32c::Value(data.c_str(), data.size());
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 97));
+ ASSERT_OK(file_writer->Append(Slice(data.c_str())));
+ ASSERT_OK(file_writer->Flush());
+ }
+ ASSERT_OK(file_writer->Close());
+ Destroy(options);
+}
+
+TEST_F(DBWritableFileWriterTest, AppendVerifyNoChecksum) {
+ FileOptions file_options = FileOptions();
+ Options options = GetDefaultOptions();
+ options.create_if_missing = true;
+ DestroyAndReopen(options);
+ std::string fname = dbname_ + "/test_file";
+ std::unique_ptr<FSWritableFile> writable_file_ptr;
+ ASSERT_OK(fault_fs_->NewWritableFile(fname, file_options, &writable_file_ptr,
+ /*dbg*/ nullptr));
+ std::unique_ptr<TestFSWritableFile> file;
+ file.reset(new TestFSWritableFile(
+ fname, file_options, std::move(writable_file_ptr), fault_fs_.get()));
+ std::unique_ptr<WritableFileWriter> file_writer;
+ ImmutableOptions ioptions(options);
+ // Enable checksum handoff for this file, but do not enable buffer checksum.
+ // So Append with checksum logic will not be triggered
+ file_writer.reset(new WritableFileWriter(
+ std::move(file), fname, file_options, SystemClock::Default().get(),
+ nullptr, ioptions.stats, ioptions.listeners,
+ ioptions.file_checksum_gen_factory.get(), true, false));
+
+ Random rnd(301);
+ std::string data = rnd.RandomString(1000);
+ uint32_t data_crc32c = crc32c::Value(data.c_str(), data.size());
+ fault_fs_->SetChecksumHandoffFuncType(ChecksumType::kCRC32c);
+
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+ ASSERT_OK(file_writer->Flush());
+ Random size_r(47);
+ for (int i = 0; i < 1000; i++) {
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 10000));
+ data_crc32c = crc32c::Value(data.c_str(), data.size());
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 97));
+ ASSERT_OK(file_writer->Append(Slice(data.c_str())));
+ ASSERT_OK(file_writer->Flush());
+ }
+ ASSERT_OK(file_writer->Close());
+ Destroy(options);
+}
+
+TEST_F(DBWritableFileWriterTest, AppendWithChecksumRateLimiter) {
+ FileOptions file_options = FileOptions();
+ file_options.rate_limiter = nullptr;
+ Options options = GetDefaultOptions();
+ options.create_if_missing = true;
+ DestroyAndReopen(options);
+ std::string fname = dbname_ + "/test_file";
+ std::unique_ptr<FSWritableFile> writable_file_ptr;
+ ASSERT_OK(fault_fs_->NewWritableFile(fname, file_options, &writable_file_ptr,
+ /*dbg*/ nullptr));
+ std::unique_ptr<TestFSWritableFile> file;
+ file.reset(new TestFSWritableFile(
+ fname, file_options, std::move(writable_file_ptr), fault_fs_.get()));
+ std::unique_ptr<WritableFileWriter> file_writer;
+ ImmutableOptions ioptions(options);
+ // Enable checksum handoff for this file, but do not enable buffer checksum.
+ // So Append with checksum logic will not be triggered
+ file_writer.reset(new WritableFileWriter(
+ std::move(file), fname, file_options, SystemClock::Default().get(),
+ nullptr, ioptions.stats, ioptions.listeners,
+ ioptions.file_checksum_gen_factory.get(), true, true));
+ fault_fs_->SetChecksumHandoffFuncType(ChecksumType::kCRC32c);
+
+ Random rnd(301);
+ std::string data;
+ uint32_t data_crc32c;
+ uint64_t start = fault_env_->NowMicros();
+ Random size_r(47);
+ uint64_t bytes_written = 0;
+ for (int i = 0; i < 100; i++) {
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 10000));
+ data_crc32c = crc32c::Value(data.c_str(), data.size());
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+ bytes_written += static_cast<uint64_t>(data.size());
+
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 97));
+ ASSERT_OK(file_writer->Append(Slice(data.c_str())));
+ ASSERT_OK(file_writer->Flush());
+ bytes_written += static_cast<uint64_t>(data.size());
+ }
+ uint64_t elapsed = fault_env_->NowMicros() - start;
+ double raw_rate = bytes_written * 1000000.0 / elapsed;
+ ASSERT_OK(file_writer->Close());
+
+ // Set the rate-limiter
+ FileOptions file_options1 = FileOptions();
+ file_options1.rate_limiter =
+ NewGenericRateLimiter(static_cast<int64_t>(0.5 * raw_rate));
+ fname = dbname_ + "/test_file_1";
+ std::unique_ptr<FSWritableFile> writable_file_ptr1;
+ ASSERT_OK(fault_fs_->NewWritableFile(fname, file_options1,
+ &writable_file_ptr1,
+ /*dbg*/ nullptr));
+ file.reset(new TestFSWritableFile(
+ fname, file_options1, std::move(writable_file_ptr1), fault_fs_.get()));
+ // Enable checksum handoff for this file, but do not enable buffer checksum.
+ // So Append with checksum logic will not be triggered
+ file_writer.reset(new WritableFileWriter(
+ std::move(file), fname, file_options1, SystemClock::Default().get(),
+ nullptr, ioptions.stats, ioptions.listeners,
+ ioptions.file_checksum_gen_factory.get(), true, true));
+
+ for (int i = 0; i < 1000; i++) {
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 10000));
+ data_crc32c = crc32c::Value(data.c_str(), data.size());
+ ASSERT_OK(file_writer->Append(Slice(data.c_str()), data_crc32c));
+
+ data = rnd.RandomString((static_cast<int>(size_r.Next()) % 97));
+ ASSERT_OK(file_writer->Append(Slice(data.c_str())));
+ ASSERT_OK(file_writer->Flush());
+ }
+ ASSERT_OK(file_writer->Close());
+ if (file_options1.rate_limiter != nullptr) {
+ delete file_options1.rate_limiter;
+ }
+
+ Destroy(options);
+}
+
+#ifndef ROCKSDB_LITE
+TEST_F(WritableFileWriterTest, AppendStatusReturn) {
+ class FakeWF : public FSWritableFile {
+ public:
+ explicit FakeWF() : use_direct_io_(false), io_error_(false) {}
+
+ bool use_direct_io() const override { return use_direct_io_; }
+
+ using FSWritableFile::Append;
+ IOStatus Append(const Slice& /*data*/, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ if (io_error_) {
+ return IOStatus::IOError("Fake IO error");
+ }
+ return IOStatus::OK();
+ }
+ using FSWritableFile::PositionedAppend;
+ IOStatus PositionedAppend(const Slice& /*data*/, uint64_t,
+ const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ if (io_error_) {
+ return IOStatus::IOError("Fake IO error");
+ }
+ return IOStatus::OK();
+ }
+ IOStatus Close(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Flush(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Sync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ void Setuse_direct_io(bool val) { use_direct_io_ = val; }
+ void SetIOError(bool val) { io_error_ = val; }
+
+ protected:
+ bool use_direct_io_;
+ bool io_error_;
+ };
+ std::unique_ptr<FakeWF> wf(new FakeWF());
+ wf->Setuse_direct_io(true);
+ std::unique_ptr<WritableFileWriter> writer(
+ new WritableFileWriter(std::move(wf), "" /* don't care */, EnvOptions()));
+
+ ASSERT_OK(writer->Append(std::string(2 * kMb, 'a')));
+
+ // Next call to WritableFile::Append() should fail
+ FakeWF* fwf = static_cast<FakeWF*>(writer->writable_file());
+ fwf->SetIOError(true);
+ ASSERT_NOK(writer->Append(std::string(2 * kMb, 'b')));
+}
+#endif
+
+class ReadaheadRandomAccessFileTest
+ : public testing::Test,
+ public testing::WithParamInterface<size_t> {
+ public:
+ static std::vector<size_t> GetReadaheadSizeList() {
+ return {1lu << 12, 1lu << 16};
+ }
+ void SetUp() override {
+ readahead_size_ = GetParam();
+ scratch_.reset(new char[2 * readahead_size_]);
+ ResetSourceStr();
+ }
+ ReadaheadRandomAccessFileTest() : control_contents_() {}
+ std::string Read(uint64_t offset, size_t n) {
+ Slice result;
+ Status s = test_read_holder_->Read(offset, n, IOOptions(), &result,
+ scratch_.get(), nullptr);
+ EXPECT_TRUE(s.ok() || s.IsInvalidArgument());
+ return std::string(result.data(), result.size());
+ }
+ void ResetSourceStr(const std::string& str = "") {
+ std::unique_ptr<FSWritableFile> sink(
+ new test::StringSink(&control_contents_));
+ std::unique_ptr<WritableFileWriter> write_holder(new WritableFileWriter(
+ std::move(sink), "" /* don't care */, FileOptions()));
+ Status s = write_holder->Append(Slice(str));
+ EXPECT_OK(s);
+ s = write_holder->Flush();
+ EXPECT_OK(s);
+ std::unique_ptr<FSRandomAccessFile> read_holder(
+ new test::StringSource(control_contents_));
+ test_read_holder_ =
+ NewReadaheadRandomAccessFile(std::move(read_holder), readahead_size_);
+ }
+ size_t GetReadaheadSize() const { return readahead_size_; }
+
+ private:
+ size_t readahead_size_;
+ Slice control_contents_;
+ std::unique_ptr<FSRandomAccessFile> test_read_holder_;
+ std::unique_ptr<char[]> scratch_;
+};
+
+TEST_P(ReadaheadRandomAccessFileTest, EmptySourceStr) {
+ ASSERT_EQ("", Read(0, 1));
+ ASSERT_EQ("", Read(0, 0));
+ ASSERT_EQ("", Read(13, 13));
+}
+
+TEST_P(ReadaheadRandomAccessFileTest, SourceStrLenLessThanReadaheadSize) {
+ std::string str = "abcdefghijklmnopqrs";
+ ResetSourceStr(str);
+ ASSERT_EQ(str.substr(3, 4), Read(3, 4));
+ ASSERT_EQ(str.substr(0, 3), Read(0, 3));
+ ASSERT_EQ(str, Read(0, str.size()));
+ ASSERT_EQ(str.substr(7, std::min(static_cast<int>(str.size()) - 7, 30)),
+ Read(7, 30));
+ ASSERT_EQ("", Read(100, 100));
+}
+
+TEST_P(ReadaheadRandomAccessFileTest, SourceStrLenGreaterThanReadaheadSize) {
+ Random rng(42);
+ for (int k = 0; k < 100; ++k) {
+ size_t strLen = k * GetReadaheadSize() +
+ rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ std::string str = rng.HumanReadableString(static_cast<int>(strLen));
+ ResetSourceStr(str);
+ for (int test = 1; test <= 100; ++test) {
+ size_t offset = rng.Uniform(static_cast<int>(strLen));
+ size_t n = rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ ASSERT_EQ(str.substr(offset, std::min(n, strLen - offset)),
+ Read(offset, n));
+ }
+ }
+}
+
+TEST_P(ReadaheadRandomAccessFileTest, ReadExceedsReadaheadSize) {
+ Random rng(7);
+ size_t strLen = 4 * GetReadaheadSize() +
+ rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ std::string str = rng.HumanReadableString(static_cast<int>(strLen));
+ ResetSourceStr(str);
+ for (int test = 1; test <= 100; ++test) {
+ size_t offset = rng.Uniform(static_cast<int>(strLen));
+ size_t n =
+ GetReadaheadSize() + rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ ASSERT_EQ(str.substr(offset, std::min(n, strLen - offset)),
+ Read(offset, n));
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(
+ EmptySourceStr, ReadaheadRandomAccessFileTest,
+ ::testing::ValuesIn(ReadaheadRandomAccessFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ SourceStrLenLessThanReadaheadSize, ReadaheadRandomAccessFileTest,
+ ::testing::ValuesIn(ReadaheadRandomAccessFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ SourceStrLenGreaterThanReadaheadSize, ReadaheadRandomAccessFileTest,
+ ::testing::ValuesIn(ReadaheadRandomAccessFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ ReadExceedsReadaheadSize, ReadaheadRandomAccessFileTest,
+ ::testing::ValuesIn(ReadaheadRandomAccessFileTest::GetReadaheadSizeList()));
+
+class ReadaheadSequentialFileTest : public testing::Test,
+ public testing::WithParamInterface<size_t> {
+ public:
+ static std::vector<size_t> GetReadaheadSizeList() {
+ return {1lu << 8, 1lu << 12, 1lu << 16, 1lu << 18};
+ }
+ void SetUp() override {
+ readahead_size_ = GetParam();
+ scratch_.reset(new char[2 * readahead_size_]);
+ ResetSourceStr();
+ }
+ ReadaheadSequentialFileTest() {}
+ std::string Read(size_t n) {
+ Slice result;
+ Status s = test_read_holder_->Read(
+ n, &result, scratch_.get(), Env::IO_TOTAL /* rate_limiter_priority*/);
+ EXPECT_TRUE(s.ok() || s.IsInvalidArgument());
+ return std::string(result.data(), result.size());
+ }
+ void Skip(size_t n) { test_read_holder_->Skip(n); }
+ void ResetSourceStr(const std::string& str = "") {
+ auto read_holder = std::unique_ptr<FSSequentialFile>(
+ new test::SeqStringSource(str, &seq_read_count_));
+ test_read_holder_.reset(new SequentialFileReader(std::move(read_holder),
+ "test", readahead_size_));
+ }
+ size_t GetReadaheadSize() const { return readahead_size_; }
+
+ private:
+ size_t readahead_size_;
+ std::unique_ptr<SequentialFileReader> test_read_holder_;
+ std::unique_ptr<char[]> scratch_;
+ std::atomic<int> seq_read_count_;
+};
+
+TEST_P(ReadaheadSequentialFileTest, EmptySourceStr) {
+ ASSERT_EQ("", Read(0));
+ ASSERT_EQ("", Read(1));
+ ASSERT_EQ("", Read(13));
+}
+
+TEST_P(ReadaheadSequentialFileTest, SourceStrLenLessThanReadaheadSize) {
+ std::string str = "abcdefghijklmnopqrs";
+ ResetSourceStr(str);
+ ASSERT_EQ(str.substr(0, 3), Read(3));
+ ASSERT_EQ(str.substr(3, 1), Read(1));
+ ASSERT_EQ(str.substr(4), Read(str.size()));
+ ASSERT_EQ("", Read(100));
+}
+
+TEST_P(ReadaheadSequentialFileTest, SourceStrLenGreaterThanReadaheadSize) {
+ Random rng(42);
+ for (int s = 0; s < 1; ++s) {
+ for (int k = 0; k < 100; ++k) {
+ size_t strLen = k * GetReadaheadSize() +
+ rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ std::string str = rng.HumanReadableString(static_cast<int>(strLen));
+ ResetSourceStr(str);
+ size_t offset = 0;
+ for (int test = 1; test <= 100; ++test) {
+ size_t n = rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ if (s && test % 2) {
+ Skip(n);
+ } else {
+ ASSERT_EQ(str.substr(offset, std::min(n, strLen - offset)), Read(n));
+ }
+ offset = std::min(offset + n, strLen);
+ }
+ }
+ }
+}
+
+TEST_P(ReadaheadSequentialFileTest, ReadExceedsReadaheadSize) {
+ Random rng(42);
+ for (int s = 0; s < 1; ++s) {
+ for (int k = 0; k < 100; ++k) {
+ size_t strLen = k * GetReadaheadSize() +
+ rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ std::string str = rng.HumanReadableString(static_cast<int>(strLen));
+ ResetSourceStr(str);
+ size_t offset = 0;
+ for (int test = 1; test <= 100; ++test) {
+ size_t n = GetReadaheadSize() +
+ rng.Uniform(static_cast<int>(GetReadaheadSize()));
+ if (s && test % 2) {
+ Skip(n);
+ } else {
+ ASSERT_EQ(str.substr(offset, std::min(n, strLen - offset)), Read(n));
+ }
+ offset = std::min(offset + n, strLen);
+ }
+ }
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(
+ EmptySourceStr, ReadaheadSequentialFileTest,
+ ::testing::ValuesIn(ReadaheadSequentialFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ SourceStrLenLessThanReadaheadSize, ReadaheadSequentialFileTest,
+ ::testing::ValuesIn(ReadaheadSequentialFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ SourceStrLenGreaterThanReadaheadSize, ReadaheadSequentialFileTest,
+ ::testing::ValuesIn(ReadaheadSequentialFileTest::GetReadaheadSizeList()));
+INSTANTIATE_TEST_CASE_P(
+ ReadExceedsReadaheadSize, ReadaheadSequentialFileTest,
+ ::testing::ValuesIn(ReadaheadSequentialFileTest::GetReadaheadSizeList()));
+
+namespace {
+std::string GenerateLine(int n) {
+ std::string rv;
+ // Multiples of 17 characters per line, for likely bad buffer alignment
+ for (int i = 0; i < n; ++i) {
+ rv.push_back(static_cast<char>('0' + (i % 10)));
+ rv.append("xxxxxxxxxxxxxxxx");
+ }
+ return rv;
+}
+} // namespace
+
+TEST(LineFileReaderTest, LineFileReaderTest) {
+ const int nlines = 1000;
+
+ std::unique_ptr<Env> mem_env(MockEnv::Create(Env::Default()));
+ std::shared_ptr<FileSystem> fs = mem_env->GetFileSystem();
+ // Create an input file
+ {
+ std::unique_ptr<FSWritableFile> file;
+ ASSERT_OK(
+ fs->NewWritableFile("testfile", FileOptions(), &file, /*dbg*/ nullptr));
+
+ for (int i = 0; i < nlines; ++i) {
+ std::string line = GenerateLine(i);
+ line.push_back('\n');
+ ASSERT_OK(file->Append(line, IOOptions(), /*dbg*/ nullptr));
+ }
+ }
+
+ // Verify with no I/O errors
+ {
+ std::unique_ptr<LineFileReader> reader;
+ ASSERT_OK(LineFileReader::Create(fs, "testfile", FileOptions(), &reader,
+ nullptr /* dbg */,
+ nullptr /* rate_limiter */));
+ std::string line;
+ int count = 0;
+ while (reader->ReadLine(&line, Env::IO_TOTAL /* rate_limiter_priority */)) {
+ ASSERT_EQ(line, GenerateLine(count));
+ ++count;
+ ASSERT_EQ(static_cast<int>(reader->GetLineNumber()), count);
+ }
+ ASSERT_OK(reader->GetStatus());
+ ASSERT_EQ(count, nlines);
+ ASSERT_EQ(static_cast<int>(reader->GetLineNumber()), count);
+ // And still
+ ASSERT_FALSE(
+ reader->ReadLine(&line, Env::IO_TOTAL /* rate_limiter_priority */));
+ ASSERT_OK(reader->GetStatus());
+ ASSERT_EQ(static_cast<int>(reader->GetLineNumber()), count);
+ }
+
+ // Verify with injected I/O error
+ {
+ std::unique_ptr<LineFileReader> reader;
+ ASSERT_OK(LineFileReader::Create(fs, "testfile", FileOptions(), &reader,
+ nullptr /* dbg */,
+ nullptr /* rate_limiter */));
+ std::string line;
+ int count = 0;
+ // Read part way through the file
+ while (count < nlines / 4) {
+ ASSERT_TRUE(
+ reader->ReadLine(&line, Env::IO_TOTAL /* rate_limiter_priority */));
+ ASSERT_EQ(line, GenerateLine(count));
+ ++count;
+ ASSERT_EQ(static_cast<int>(reader->GetLineNumber()), count);
+ }
+ ASSERT_OK(reader->GetStatus());
+
+ // Inject error
+ int callback_count = 0;
+ SyncPoint::GetInstance()->SetCallBack(
+ "MemFile::Read:IOStatus", [&](void* arg) {
+ IOStatus* status = static_cast<IOStatus*>(arg);
+ *status = IOStatus::Corruption("test");
+ ++callback_count;
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ while (reader->ReadLine(&line, Env::IO_TOTAL /* rate_limiter_priority */)) {
+ ASSERT_EQ(line, GenerateLine(count));
+ ++count;
+ ASSERT_EQ(static_cast<int>(reader->GetLineNumber()), count);
+ }
+ ASSERT_TRUE(reader->GetStatus().IsCorruption());
+ ASSERT_LT(count, nlines / 2);
+ ASSERT_EQ(callback_count, 1);
+
+ // Still get error & no retry
+ ASSERT_FALSE(
+ reader->ReadLine(&line, Env::IO_TOTAL /* rate_limiter_priority */));
+ ASSERT_TRUE(reader->GetStatus().IsCorruption());
+ ASSERT_EQ(callback_count, 1);
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+}
+
+#ifndef ROCKSDB_LITE
+class IOErrorEventListener : public EventListener {
+ public:
+ IOErrorEventListener() { notify_error_.store(0); }
+
+ void OnIOError(const IOErrorInfo& io_error_info) override {
+ notify_error_++;
+ EXPECT_FALSE(io_error_info.file_path.empty());
+ EXPECT_FALSE(io_error_info.io_status.ok());
+ }
+
+ size_t NotifyErrorCount() { return notify_error_; }
+
+ bool ShouldBeNotifiedOnFileIO() override { return true; }
+
+ private:
+ std::atomic<size_t> notify_error_;
+};
+
+TEST_F(DBWritableFileWriterTest, IOErrorNotification) {
+ class FakeWF : public FSWritableFile {
+ public:
+ explicit FakeWF() : io_error_(false) {
+ file_append_errors_.store(0);
+ file_flush_errors_.store(0);
+ }
+
+ using FSWritableFile::Append;
+ IOStatus Append(const Slice& /*data*/, const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ if (io_error_) {
+ file_append_errors_++;
+ return IOStatus::IOError("Fake IO error");
+ }
+ return IOStatus::OK();
+ }
+
+ using FSWritableFile::PositionedAppend;
+ IOStatus PositionedAppend(const Slice& /*data*/, uint64_t,
+ const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ if (io_error_) {
+ return IOStatus::IOError("Fake IO error");
+ }
+ return IOStatus::OK();
+ }
+ IOStatus Close(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+ IOStatus Flush(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ if (io_error_) {
+ file_flush_errors_++;
+ return IOStatus::IOError("Fake IO error");
+ }
+ return IOStatus::OK();
+ }
+ IOStatus Sync(const IOOptions& /*options*/,
+ IODebugContext* /*dbg*/) override {
+ return IOStatus::OK();
+ }
+
+ void SetIOError(bool val) { io_error_ = val; }
+
+ void CheckCounters(int file_append_errors, int file_flush_errors) {
+ ASSERT_EQ(file_append_errors, file_append_errors_);
+ ASSERT_EQ(file_flush_errors_, file_flush_errors);
+ }
+
+ protected:
+ bool io_error_;
+ std::atomic<size_t> file_append_errors_;
+ std::atomic<size_t> file_flush_errors_;
+ };
+
+ FileOptions file_options = FileOptions();
+ Options options = GetDefaultOptions();
+ options.create_if_missing = true;
+ IOErrorEventListener* listener = new IOErrorEventListener();
+ options.listeners.emplace_back(listener);
+
+ DestroyAndReopen(options);
+ ImmutableOptions ioptions(options);
+
+ std::string fname = dbname_ + "/test_file";
+ std::unique_ptr<FakeWF> writable_file_ptr(new FakeWF);
+
+ std::unique_ptr<WritableFileWriter> file_writer;
+ writable_file_ptr->SetIOError(true);
+
+ file_writer.reset(new WritableFileWriter(
+ std::move(writable_file_ptr), fname, file_options,
+ SystemClock::Default().get(), nullptr, ioptions.stats, ioptions.listeners,
+ ioptions.file_checksum_gen_factory.get(), true, true));
+
+ FakeWF* fwf = static_cast<FakeWF*>(file_writer->writable_file());
+
+ fwf->SetIOError(true);
+ ASSERT_NOK(file_writer->Append(std::string(2 * kMb, 'a')));
+ fwf->CheckCounters(1, 0);
+ ASSERT_EQ(listener->NotifyErrorCount(), 1);
+
+ file_writer->reset_seen_error();
+ fwf->SetIOError(true);
+ ASSERT_NOK(file_writer->Flush());
+ fwf->CheckCounters(1, 1);
+ ASSERT_EQ(listener->NotifyErrorCount(), 2);
+
+ /* No error generation */
+ file_writer->reset_seen_error();
+ fwf->SetIOError(false);
+ ASSERT_OK(file_writer->Append(std::string(2 * kMb, 'b')));
+ ASSERT_EQ(listener->NotifyErrorCount(), 2);
+ fwf->CheckCounters(1, 1);
+}
+#endif // ROCKSDB_LITE
+
+class WritableFileWriterIOPriorityTest : public testing::Test {
+ protected:
+ // This test is to check whether the rate limiter priority can be passed
+ // correctly from WritableFileWriter functions to FSWritableFile functions.
+
+ void SetUp() override {
+ // When op_rate_limiter_priority parameter in WritableFileWriter functions
+ // is the default (Env::IO_TOTAL).
+ std::unique_ptr<FakeWF> wf{new FakeWF(Env::IO_HIGH)};
+ FileOptions file_options;
+ writer_.reset(new WritableFileWriter(std::move(wf), "" /* don't care */,
+ file_options));
+ }
+
+ class FakeWF : public FSWritableFile {
+ public:
+ explicit FakeWF(Env::IOPriority io_priority) { SetIOPriority(io_priority); }
+ ~FakeWF() override {}
+
+ IOStatus Append(const Slice& /*data*/, const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Append(const Slice& data, const IOOptions& options,
+ const DataVerificationInfo& /* verification_info */,
+ IODebugContext* dbg) override {
+ return Append(data, options, dbg);
+ }
+ IOStatus PositionedAppend(const Slice& /*data*/, uint64_t /*offset*/,
+ const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus PositionedAppend(
+ const Slice& /* data */, uint64_t /* offset */,
+ const IOOptions& options,
+ const DataVerificationInfo& /* verification_info */,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Truncate(uint64_t /*size*/, const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Close(const IOOptions& options, IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Flush(const IOOptions& options, IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Sync(const IOOptions& options, IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus Fsync(const IOOptions& options, IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ uint64_t GetFileSize(const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return 0;
+ }
+ void GetPreallocationStatus(size_t* /*block_size*/,
+ size_t* /*last_allocated_block*/) override {}
+ size_t GetUniqueId(char* /*id*/, size_t /*max_size*/) const override {
+ return 0;
+ }
+ IOStatus InvalidateCache(size_t /*offset*/, size_t /*length*/) override {
+ return IOStatus::OK();
+ }
+
+ IOStatus Allocate(uint64_t /*offset*/, uint64_t /*len*/,
+ const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+ IOStatus RangeSync(uint64_t /*offset*/, uint64_t /*nbytes*/,
+ const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ return IOStatus::OK();
+ }
+
+ void PrepareWrite(size_t /*offset*/, size_t /*len*/,
+ const IOOptions& options,
+ IODebugContext* /*dbg*/) override {
+ EXPECT_EQ(options.rate_limiter_priority, io_priority_);
+ }
+
+ bool IsSyncThreadSafe() const override { return true; }
+ };
+
+ std::unique_ptr<WritableFileWriter> writer_;
+};
+
+TEST_F(WritableFileWriterIOPriorityTest, Append) {
+ ASSERT_OK(writer_->Append(Slice("abc")));
+}
+
+TEST_F(WritableFileWriterIOPriorityTest, Pad) { ASSERT_OK(writer_->Pad(500)); }
+
+TEST_F(WritableFileWriterIOPriorityTest, Flush) { ASSERT_OK(writer_->Flush()); }
+
+TEST_F(WritableFileWriterIOPriorityTest, Close) { ASSERT_OK(writer_->Close()); }
+
+TEST_F(WritableFileWriterIOPriorityTest, Sync) {
+ ASSERT_OK(writer_->Sync(false));
+ ASSERT_OK(writer_->Sync(true));
+}
+
+TEST_F(WritableFileWriterIOPriorityTest, SyncWithoutFlush) {
+ ASSERT_OK(writer_->SyncWithoutFlush(false));
+ ASSERT_OK(writer_->SyncWithoutFlush(true));
+}
+
+TEST_F(WritableFileWriterIOPriorityTest, BasicOp) {
+ EnvOptions env_options;
+ env_options.bytes_per_sync = kMb;
+ std::unique_ptr<FakeWF> wf(new FakeWF(Env::IO_HIGH));
+ std::unique_ptr<WritableFileWriter> writer(
+ new WritableFileWriter(std::move(wf), "" /* don't care */, env_options));
+ Random r(301);
+ Status s;
+ std::unique_ptr<char[]> large_buf(new char[10 * kMb]);
+ for (int i = 0; i < 1000; i++) {
+ int skew_limit = (i < 700) ? 10 : 15;
+ uint32_t num = r.Skewed(skew_limit) * 100 + r.Uniform(100);
+ s = writer->Append(Slice(large_buf.get(), num));
+ ASSERT_OK(s);
+
+ // Flush in a chance of 1/10.
+ if (r.Uniform(10) == 0) {
+ s = writer->Flush();
+ ASSERT_OK(s);
+ }
+ }
+ s = writer->Close();
+ ASSERT_OK(s);
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/filelock_test.cc b/src/rocksdb/util/filelock_test.cc
new file mode 100644
index 000000000..69947a732
--- /dev/null
+++ b/src/rocksdb/util/filelock_test.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include <fcntl.h>
+
+#include "rocksdb/env.h"
+#include "rocksdb/status.h"
+#ifdef __FreeBSD__
+#include <sys/types.h>
+#include <sys/wait.h>
+#endif
+#include <vector>
+
+#include "test_util/testharness.h"
+#include "util/coding.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class LockTest : public testing::Test {
+ public:
+ static LockTest* current_;
+ std::string file_;
+ ROCKSDB_NAMESPACE::Env* env_;
+
+ LockTest()
+ : file_(test::PerThreadDBPath("db_testlock_file")),
+ env_(ROCKSDB_NAMESPACE::Env::Default()) {
+ current_ = this;
+ }
+
+ ~LockTest() override {}
+
+ Status LockFile(FileLock** db_lock) { return env_->LockFile(file_, db_lock); }
+
+ Status UnlockFile(FileLock* db_lock) { return env_->UnlockFile(db_lock); }
+
+ bool AssertFileIsLocked() {
+ return CheckFileLock(/* lock_expected = */ true);
+ }
+
+ bool AssertFileIsNotLocked() {
+ return CheckFileLock(/* lock_expected = */ false);
+ }
+
+ bool CheckFileLock(bool lock_expected) {
+ // We need to fork to check the fcntl lock as we need
+ // to open and close the file from a different process
+ // to avoid either releasing the lock on close, or not
+ // contending for it when requesting a lock.
+
+#ifdef OS_WIN
+
+ // WaitForSingleObject and GetExitCodeProcess can do what waitpid does.
+ // TODO - implement on Windows
+ return true;
+
+#else
+
+ pid_t pid = fork();
+ if (0 == pid) {
+ // child process
+ int exit_val = EXIT_FAILURE;
+ int fd = open(file_.c_str(), O_RDWR | O_CREAT, 0644);
+ if (fd < 0) {
+ // could not open file, could not check if it was locked
+ fprintf(stderr, "Open on on file %s failed.\n", file_.c_str());
+ exit(exit_val);
+ }
+
+ struct flock f;
+ memset(&f, 0, sizeof(f));
+ f.l_type = (F_WRLCK);
+ f.l_whence = SEEK_SET;
+ f.l_start = 0;
+ f.l_len = 0; // Lock/unlock entire file
+ int value = fcntl(fd, F_SETLK, &f);
+ if (value == -1) {
+ if (lock_expected) {
+ exit_val = EXIT_SUCCESS;
+ }
+ } else {
+ if (!lock_expected) {
+ exit_val = EXIT_SUCCESS;
+ }
+ }
+ close(fd); // lock is released for child process
+ exit(exit_val);
+ } else if (pid > 0) {
+ // parent process
+ int status;
+ while (-1 == waitpid(pid, &status, 0))
+ ;
+ if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
+ // child process exited with non success status
+ return false;
+ } else {
+ return true;
+ }
+ } else {
+ fprintf(stderr, "Fork failed\n");
+ return false;
+ }
+ return false;
+
+#endif
+ }
+};
+LockTest* LockTest::current_;
+
+TEST_F(LockTest, LockBySameThread) {
+ FileLock* lock1;
+ FileLock* lock2;
+
+ // acquire a lock on a file
+ ASSERT_OK(LockFile(&lock1));
+
+ // check the file is locked
+ ASSERT_TRUE(AssertFileIsLocked());
+
+ // re-acquire the lock on the same file. This should fail.
+ Status s = LockFile(&lock2);
+ ASSERT_TRUE(s.IsIOError());
+#ifndef OS_WIN
+ // Validate that error message contains current thread ID.
+ ASSERT_TRUE(s.ToString().find(std::to_string(
+ Env::Default()->GetThreadID())) != std::string::npos);
+#endif
+
+ // check the file is locked
+ ASSERT_TRUE(AssertFileIsLocked());
+
+ // release the lock
+ ASSERT_OK(UnlockFile(lock1));
+
+ // check the file is not locked
+ ASSERT_TRUE(AssertFileIsNotLocked());
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/filter_bench.cc b/src/rocksdb/util/filter_bench.cc
new file mode 100644
index 000000000..93186cd08
--- /dev/null
+++ b/src/rocksdb/util/filter_bench.cc
@@ -0,0 +1,840 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#if !defined(GFLAGS) || defined(ROCKSDB_LITE)
+#include <cstdio>
+int main() {
+ fprintf(stderr, "filter_bench requires gflags and !ROCKSDB_LITE\n");
+ return 1;
+}
+#else
+
+#include <cinttypes>
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include "memory/arena.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/cache.h"
+#include "rocksdb/env.h"
+#include "rocksdb/system_clock.h"
+#include "rocksdb/table.h"
+#include "table/block_based/filter_policy_internal.h"
+#include "table/block_based/full_filter_block.h"
+#include "table/block_based/mock_block_based_table.h"
+#include "table/plain/plain_table_bloom.h"
+#include "util/cast_util.h"
+#include "util/gflags_compat.h"
+#include "util/hash.h"
+#include "util/random.h"
+#include "util/stderr_logger.h"
+#include "util/stop_watch.h"
+#include "util/string_util.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+using GFLAGS_NAMESPACE::RegisterFlagValidator;
+using GFLAGS_NAMESPACE::SetUsageMessage;
+
+DEFINE_uint32(seed, 0, "Seed for random number generators");
+
+DEFINE_double(working_mem_size_mb, 200,
+ "MB of memory to get up to among all filters, unless "
+ "m_keys_total_max is specified.");
+
+DEFINE_uint32(average_keys_per_filter, 10000,
+ "Average number of keys per filter");
+
+DEFINE_double(vary_key_count_ratio, 0.4,
+ "Vary number of keys by up to +/- vary_key_count_ratio * "
+ "average_keys_per_filter.");
+
+DEFINE_uint32(key_size, 24, "Average number of bytes for each key");
+
+DEFINE_bool(vary_key_alignment, true,
+ "Whether to vary key alignment (default: at least 32-bit "
+ "alignment)");
+
+DEFINE_uint32(vary_key_size_log2_interval, 5,
+ "Use same key size 2^n times, then change. Key size varies from "
+ "-2 to +2 bytes vs. average, unless n>=30 to fix key size.");
+
+DEFINE_uint32(batch_size, 8, "Number of keys to group in each batch");
+
+DEFINE_double(bits_per_key, 10.0, "Bits per key setting for filters");
+
+DEFINE_double(m_queries, 200, "Millions of queries for each test mode");
+
+DEFINE_double(m_keys_total_max, 0,
+ "Maximum total keys added to filters, in millions. "
+ "0 (default) disables. Non-zero overrides working_mem_size_mb "
+ "option.");
+
+DEFINE_bool(use_full_block_reader, false,
+ "Use FullFilterBlockReader interface rather than FilterBitsReader");
+
+DEFINE_bool(use_plain_table_bloom, false,
+ "Use PlainTableBloom structure and interface rather than "
+ "FilterBitsReader/FullFilterBlockReader");
+
+DEFINE_bool(new_builder, false,
+ "Whether to create a new builder for each new filter");
+
+DEFINE_uint32(impl, 0,
+ "Select filter implementation. Without -use_plain_table_bloom:"
+ "0 = legacy full Bloom filter, "
+ "1 = format_version 5 Bloom filter, 2 = Ribbon128 filter. With "
+ "-use_plain_table_bloom: 0 = no locality, 1 = locality.");
+
+DEFINE_bool(net_includes_hashing, false,
+ "Whether query net ns/op times should include hashing. "
+ "(if not, dry run will include hashing) "
+ "(build times always include hashing)");
+
+DEFINE_bool(optimize_filters_for_memory, false,
+ "Setting for BlockBasedTableOptions::optimize_filters_for_memory");
+
+DEFINE_bool(detect_filter_construct_corruption, false,
+ "Setting for "
+ "BlockBasedTableOptions::detect_filter_construct_corruption");
+
+DEFINE_uint32(block_cache_capacity_MB, 8,
+ "Setting for "
+ "LRUCacheOptions::capacity");
+
+DEFINE_bool(charge_filter_construction, false,
+ "Setting for "
+ "CacheEntryRoleOptions::charged of"
+ "CacheEntryRole::kFilterConstruction");
+
+DEFINE_bool(strict_capacity_limit, false,
+ "Setting for "
+ "LRUCacheOptions::strict_capacity_limit");
+
+DEFINE_bool(quick, false, "Run more limited set of tests, fewer queries");
+
+DEFINE_bool(best_case, false, "Run limited tests only for best-case");
+
+DEFINE_bool(allow_bad_fp_rate, false, "Continue even if FP rate is bad");
+
+DEFINE_bool(legend, false,
+ "Print more information about interpreting results instead of "
+ "running tests");
+
+DEFINE_uint32(runs, 1, "Number of times to rebuild and run benchmark tests");
+
+void _always_assert_fail(int line, const char *file, const char *expr) {
+ fprintf(stderr, "%s: %d: Assertion %s failed\n", file, line, expr);
+ abort();
+}
+
+#define ALWAYS_ASSERT(cond) \
+ ((cond) ? (void)0 : ::_always_assert_fail(__LINE__, __FILE__, #cond))
+
+#ifndef NDEBUG
+// This could affect build times enough that we should not include it for
+// accurate speed tests
+#define PREDICT_FP_RATE
+#endif
+
+using ROCKSDB_NAMESPACE::Arena;
+using ROCKSDB_NAMESPACE::BlockContents;
+using ROCKSDB_NAMESPACE::BloomFilterPolicy;
+using ROCKSDB_NAMESPACE::BloomHash;
+using ROCKSDB_NAMESPACE::BloomLikeFilterPolicy;
+using ROCKSDB_NAMESPACE::BuiltinFilterBitsBuilder;
+using ROCKSDB_NAMESPACE::CachableEntry;
+using ROCKSDB_NAMESPACE::Cache;
+using ROCKSDB_NAMESPACE::CacheEntryRole;
+using ROCKSDB_NAMESPACE::CacheEntryRoleOptions;
+using ROCKSDB_NAMESPACE::EncodeFixed32;
+using ROCKSDB_NAMESPACE::Env;
+using ROCKSDB_NAMESPACE::FastRange32;
+using ROCKSDB_NAMESPACE::FilterBitsReader;
+using ROCKSDB_NAMESPACE::FilterBuildingContext;
+using ROCKSDB_NAMESPACE::FilterPolicy;
+using ROCKSDB_NAMESPACE::FullFilterBlockReader;
+using ROCKSDB_NAMESPACE::GetSliceHash;
+using ROCKSDB_NAMESPACE::GetSliceHash64;
+using ROCKSDB_NAMESPACE::Lower32of64;
+using ROCKSDB_NAMESPACE::LRUCacheOptions;
+using ROCKSDB_NAMESPACE::ParsedFullFilterBlock;
+using ROCKSDB_NAMESPACE::PlainTableBloomV1;
+using ROCKSDB_NAMESPACE::Random32;
+using ROCKSDB_NAMESPACE::Slice;
+using ROCKSDB_NAMESPACE::static_cast_with_check;
+using ROCKSDB_NAMESPACE::Status;
+using ROCKSDB_NAMESPACE::StderrLogger;
+using ROCKSDB_NAMESPACE::mock::MockBlockBasedTableTester;
+
+struct KeyMaker {
+ KeyMaker(size_t avg_size)
+ : smallest_size_(avg_size -
+ (FLAGS_vary_key_size_log2_interval >= 30 ? 2 : 0)),
+ buf_size_(avg_size + 11), // pad to vary key size and alignment
+ buf_(new char[buf_size_]) {
+ memset(buf_.get(), 0, buf_size_);
+ assert(smallest_size_ > 8);
+ }
+ size_t smallest_size_;
+ size_t buf_size_;
+ std::unique_ptr<char[]> buf_;
+
+ // Returns a unique(-ish) key based on the given parameter values. Each
+ // call returns a Slice from the same buffer so previously returned
+ // Slices should be considered invalidated.
+ Slice Get(uint32_t filter_num, uint32_t val_num) {
+ size_t start = FLAGS_vary_key_alignment ? val_num % 4 : 0;
+ size_t len = smallest_size_;
+ if (FLAGS_vary_key_size_log2_interval < 30) {
+ // To get range [avg_size - 2, avg_size + 2]
+ // use range [smallest_size, smallest_size + 4]
+ len += FastRange32(
+ (val_num >> FLAGS_vary_key_size_log2_interval) * 1234567891, 5);
+ }
+ char *data = buf_.get() + start;
+ // Populate key data such that all data makes it into a key of at
+ // least 8 bytes. We also don't want all the within-filter key
+ // variance confined to a contiguous 32 bits, because then a 32 bit
+ // hash function can "cheat" the false positive rate by
+ // approximating a perfect hash.
+ EncodeFixed32(data, val_num);
+ EncodeFixed32(data + 4, filter_num + val_num);
+ // ensure clearing leftovers from different alignment
+ EncodeFixed32(data + 8, 0);
+ return Slice(data, len);
+ }
+};
+
+void PrintWarnings() {
+#if defined(__GNUC__) && !defined(__OPTIMIZE__)
+ fprintf(stdout,
+ "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n");
+#endif
+#ifndef NDEBUG
+ fprintf(stdout,
+ "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n");
+#endif
+}
+
+void PrintError(const char *error) { fprintf(stderr, "ERROR: %s\n", error); }
+
+struct FilterInfo {
+ uint32_t filter_id_ = 0;
+ std::unique_ptr<const char[]> owner_;
+ Slice filter_;
+ Status filter_construction_status = Status::OK();
+ uint32_t keys_added_ = 0;
+ std::unique_ptr<FilterBitsReader> reader_;
+ std::unique_ptr<FullFilterBlockReader> full_block_reader_;
+ std::unique_ptr<PlainTableBloomV1> plain_table_bloom_;
+ uint64_t outside_queries_ = 0;
+ uint64_t false_positives_ = 0;
+};
+
+enum TestMode {
+ kSingleFilter,
+ kBatchPrepared,
+ kBatchUnprepared,
+ kFiftyOneFilter,
+ kEightyTwentyFilter,
+ kRandomFilter,
+};
+
+static const std::vector<TestMode> allTestModes = {
+ kSingleFilter, kBatchPrepared, kBatchUnprepared,
+ kFiftyOneFilter, kEightyTwentyFilter, kRandomFilter,
+};
+
+static const std::vector<TestMode> quickTestModes = {
+ kSingleFilter,
+ kRandomFilter,
+};
+
+static const std::vector<TestMode> bestCaseTestModes = {
+ kSingleFilter,
+};
+
+const char *TestModeToString(TestMode tm) {
+ switch (tm) {
+ case kSingleFilter:
+ return "Single filter";
+ case kBatchPrepared:
+ return "Batched, prepared";
+ case kBatchUnprepared:
+ return "Batched, unprepared";
+ case kFiftyOneFilter:
+ return "Skewed 50% in 1%";
+ case kEightyTwentyFilter:
+ return "Skewed 80% in 20%";
+ case kRandomFilter:
+ return "Random filter";
+ }
+ return "Bad TestMode";
+}
+
+// Do just enough to keep some data dependence for the
+// compiler / CPU
+static uint32_t DryRunNoHash(Slice &s) {
+ uint32_t sz = static_cast<uint32_t>(s.size());
+ if (sz >= 4) {
+ return sz + s.data()[3];
+ } else {
+ return sz;
+ }
+}
+
+static uint32_t DryRunHash32(Slice &s) {
+ // Same perf characteristics as GetSliceHash()
+ return BloomHash(s);
+}
+
+static uint32_t DryRunHash64(Slice &s) {
+ return Lower32of64(GetSliceHash64(s));
+}
+
+const std::shared_ptr<const FilterPolicy> &GetPolicy() {
+ static std::shared_ptr<const FilterPolicy> policy;
+ if (!policy) {
+ policy = BloomLikeFilterPolicy::Create(
+ BloomLikeFilterPolicy::GetAllFixedImpls().at(FLAGS_impl),
+ FLAGS_bits_per_key);
+ }
+ return policy;
+}
+
+struct FilterBench : public MockBlockBasedTableTester {
+ std::vector<KeyMaker> kms_;
+ std::vector<FilterInfo> infos_;
+ Random32 random_;
+ std::ostringstream fp_rate_report_;
+ Arena arena_;
+ double m_queries_;
+ StderrLogger stderr_logger_;
+
+ FilterBench()
+ : MockBlockBasedTableTester(GetPolicy()),
+ random_(FLAGS_seed),
+ m_queries_(0) {
+ for (uint32_t i = 0; i < FLAGS_batch_size; ++i) {
+ kms_.emplace_back(FLAGS_key_size < 8 ? 8 : FLAGS_key_size);
+ }
+ ioptions_.logger = &stderr_logger_;
+ table_options_.optimize_filters_for_memory =
+ FLAGS_optimize_filters_for_memory;
+ table_options_.detect_filter_construct_corruption =
+ FLAGS_detect_filter_construct_corruption;
+ table_options_.cache_usage_options.options_overrides.insert(
+ {CacheEntryRole::kFilterConstruction,
+ {/*.charged = */ FLAGS_charge_filter_construction
+ ? CacheEntryRoleOptions::Decision::kEnabled
+ : CacheEntryRoleOptions::Decision::kDisabled}});
+ if (FLAGS_charge_filter_construction) {
+ table_options_.no_block_cache = false;
+ LRUCacheOptions lo;
+ lo.capacity = FLAGS_block_cache_capacity_MB * 1024 * 1024;
+ lo.num_shard_bits = 0; // 2^0 shard
+ lo.strict_capacity_limit = FLAGS_strict_capacity_limit;
+ std::shared_ptr<Cache> cache(NewLRUCache(lo));
+ table_options_.block_cache = cache;
+ }
+ }
+
+ void Go();
+
+ double RandomQueryTest(uint32_t inside_threshold, bool dry_run,
+ TestMode mode);
+};
+
+void FilterBench::Go() {
+ if (FLAGS_use_plain_table_bloom && FLAGS_use_full_block_reader) {
+ throw std::runtime_error(
+ "Can't combine -use_plain_table_bloom and -use_full_block_reader");
+ }
+ if (FLAGS_use_plain_table_bloom) {
+ if (FLAGS_impl > 1) {
+ throw std::runtime_error(
+ "-impl must currently be >= 0 and <= 1 for Plain table");
+ }
+ } else {
+ if (FLAGS_impl > 2) {
+ throw std::runtime_error(
+ "-impl must currently be >= 0 and <= 2 for Block-based table");
+ }
+ }
+
+ if (FLAGS_vary_key_count_ratio < 0.0 || FLAGS_vary_key_count_ratio > 1.0) {
+ throw std::runtime_error("-vary_key_count_ratio must be >= 0.0 and <= 1.0");
+ }
+
+ // For example, average_keys_per_filter = 100, vary_key_count_ratio = 0.1.
+ // Varys up to +/- 10 keys. variance_range = 21 (generating value 0..20).
+ // variance_offset = 10, so value - offset average value is always 0.
+ const uint32_t variance_range =
+ 1 + 2 * static_cast<uint32_t>(FLAGS_vary_key_count_ratio *
+ FLAGS_average_keys_per_filter);
+ const uint32_t variance_offset = variance_range / 2;
+
+ const std::vector<TestMode> &testModes = FLAGS_best_case ? bestCaseTestModes
+ : FLAGS_quick ? quickTestModes
+ : allTestModes;
+
+ m_queries_ = FLAGS_m_queries;
+ double working_mem_size_mb = FLAGS_working_mem_size_mb;
+ if (FLAGS_quick) {
+ m_queries_ /= 7.0;
+ } else if (FLAGS_best_case) {
+ m_queries_ /= 3.0;
+ working_mem_size_mb /= 10.0;
+ }
+
+ std::cout << "Building..." << std::endl;
+
+ std::unique_ptr<BuiltinFilterBitsBuilder> builder;
+
+ size_t total_memory_used = 0;
+ size_t total_size = 0;
+ size_t total_keys_added = 0;
+#ifdef PREDICT_FP_RATE
+ double weighted_predicted_fp_rate = 0.0;
+#endif
+ size_t max_total_keys;
+ size_t max_mem;
+ if (FLAGS_m_keys_total_max > 0) {
+ max_total_keys = static_cast<size_t>(1000000 * FLAGS_m_keys_total_max);
+ max_mem = SIZE_MAX;
+ } else {
+ max_total_keys = SIZE_MAX;
+ max_mem = static_cast<size_t>(1024 * 1024 * working_mem_size_mb);
+ }
+
+ ROCKSDB_NAMESPACE::StopWatchNano timer(
+ ROCKSDB_NAMESPACE::SystemClock::Default().get(), true);
+
+ infos_.clear();
+ while ((working_mem_size_mb == 0 || total_size < max_mem) &&
+ total_keys_added < max_total_keys) {
+ uint32_t filter_id = random_.Next();
+ uint32_t keys_to_add = FLAGS_average_keys_per_filter +
+ FastRange32(random_.Next(), variance_range) -
+ variance_offset;
+ if (max_total_keys - total_keys_added < keys_to_add) {
+ keys_to_add = static_cast<uint32_t>(max_total_keys - total_keys_added);
+ }
+ infos_.emplace_back();
+ FilterInfo &info = infos_.back();
+ info.filter_id_ = filter_id;
+ info.keys_added_ = keys_to_add;
+ if (FLAGS_use_plain_table_bloom) {
+ info.plain_table_bloom_.reset(new PlainTableBloomV1());
+ info.plain_table_bloom_->SetTotalBits(
+ &arena_, static_cast<uint32_t>(keys_to_add * FLAGS_bits_per_key),
+ FLAGS_impl, 0 /*huge_page*/, nullptr /*logger*/);
+ for (uint32_t i = 0; i < keys_to_add; ++i) {
+ uint32_t hash = GetSliceHash(kms_[0].Get(filter_id, i));
+ info.plain_table_bloom_->AddHash(hash);
+ }
+ info.filter_ = info.plain_table_bloom_->GetRawData();
+ } else {
+ if (!builder) {
+ builder.reset(
+ static_cast_with_check<BuiltinFilterBitsBuilder>(GetBuilder()));
+ }
+ for (uint32_t i = 0; i < keys_to_add; ++i) {
+ builder->AddKey(kms_[0].Get(filter_id, i));
+ }
+ info.filter_ =
+ builder->Finish(&info.owner_, &info.filter_construction_status);
+ if (info.filter_construction_status.ok()) {
+ info.filter_construction_status =
+ builder->MaybePostVerify(info.filter_);
+ }
+ if (!info.filter_construction_status.ok()) {
+ PrintError(info.filter_construction_status.ToString().c_str());
+ }
+#ifdef PREDICT_FP_RATE
+ weighted_predicted_fp_rate +=
+ keys_to_add *
+ builder->EstimatedFpRate(keys_to_add, info.filter_.size());
+#endif
+ if (FLAGS_new_builder) {
+ builder.reset();
+ }
+ info.reader_.reset(
+ table_options_.filter_policy->GetFilterBitsReader(info.filter_));
+ CachableEntry<ParsedFullFilterBlock> block(
+ new ParsedFullFilterBlock(table_options_.filter_policy.get(),
+ BlockContents(info.filter_)),
+ nullptr /* cache */, nullptr /* cache_handle */,
+ true /* own_value */);
+ info.full_block_reader_.reset(
+ new FullFilterBlockReader(table_.get(), std::move(block)));
+ }
+ total_size += info.filter_.size();
+#ifdef ROCKSDB_MALLOC_USABLE_SIZE
+ total_memory_used +=
+ malloc_usable_size(const_cast<char *>(info.filter_.data()));
+#endif // ROCKSDB_MALLOC_USABLE_SIZE
+ total_keys_added += keys_to_add;
+ }
+
+ uint64_t elapsed_nanos = timer.ElapsedNanos();
+ double ns = double(elapsed_nanos) / total_keys_added;
+ std::cout << "Build avg ns/key: " << ns << std::endl;
+ std::cout << "Number of filters: " << infos_.size() << std::endl;
+ std::cout << "Total size (MB): " << total_size / 1024.0 / 1024.0 << std::endl;
+ if (total_memory_used > 0) {
+ std::cout << "Reported total allocated memory (MB): "
+ << total_memory_used / 1024.0 / 1024.0 << std::endl;
+ std::cout << "Reported internal fragmentation: "
+ << (total_memory_used - total_size) * 100.0 / total_size << "%"
+ << std::endl;
+ }
+
+ double bpk = total_size * 8.0 / total_keys_added;
+ std::cout << "Bits/key stored: " << bpk << std::endl;
+#ifdef PREDICT_FP_RATE
+ std::cout << "Predicted FP rate %: "
+ << 100.0 * (weighted_predicted_fp_rate / total_keys_added)
+ << std::endl;
+#endif
+ if (!FLAGS_quick && !FLAGS_best_case) {
+ double tolerable_rate = std::pow(2.0, -(bpk - 1.0) / (1.4 + bpk / 50.0));
+ std::cout << "Best possible FP rate %: " << 100.0 * std::pow(2.0, -bpk)
+ << std::endl;
+ std::cout << "Tolerable FP rate %: " << 100.0 * tolerable_rate << std::endl;
+
+ std::cout << "----------------------------" << std::endl;
+ std::cout << "Verifying..." << std::endl;
+
+ uint32_t outside_q_per_f =
+ static_cast<uint32_t>(m_queries_ * 1000000 / infos_.size());
+ uint64_t fps = 0;
+ for (uint32_t i = 0; i < infos_.size(); ++i) {
+ FilterInfo &info = infos_[i];
+ for (uint32_t j = 0; j < info.keys_added_; ++j) {
+ if (FLAGS_use_plain_table_bloom) {
+ uint32_t hash = GetSliceHash(kms_[0].Get(info.filter_id_, j));
+ ALWAYS_ASSERT(info.plain_table_bloom_->MayContainHash(hash));
+ } else {
+ ALWAYS_ASSERT(
+ info.reader_->MayMatch(kms_[0].Get(info.filter_id_, j)));
+ }
+ }
+ for (uint32_t j = 0; j < outside_q_per_f; ++j) {
+ if (FLAGS_use_plain_table_bloom) {
+ uint32_t hash =
+ GetSliceHash(kms_[0].Get(info.filter_id_, j | 0x80000000));
+ fps += info.plain_table_bloom_->MayContainHash(hash);
+ } else {
+ fps += info.reader_->MayMatch(
+ kms_[0].Get(info.filter_id_, j | 0x80000000));
+ }
+ }
+ }
+ std::cout << " No FNs :)" << std::endl;
+ double prelim_rate = double(fps) / outside_q_per_f / infos_.size();
+ std::cout << " Prelim FP rate %: " << (100.0 * prelim_rate) << std::endl;
+
+ if (!FLAGS_allow_bad_fp_rate) {
+ ALWAYS_ASSERT(prelim_rate < tolerable_rate);
+ }
+ }
+
+ std::cout << "----------------------------" << std::endl;
+ std::cout << "Mixed inside/outside queries..." << std::endl;
+ // 50% each inside and outside
+ uint32_t inside_threshold = UINT32_MAX / 2;
+ for (TestMode tm : testModes) {
+ random_.Seed(FLAGS_seed + 1);
+ double f = RandomQueryTest(inside_threshold, /*dry_run*/ false, tm);
+ random_.Seed(FLAGS_seed + 1);
+ double d = RandomQueryTest(inside_threshold, /*dry_run*/ true, tm);
+ std::cout << " " << TestModeToString(tm) << " net ns/op: " << (f - d)
+ << std::endl;
+ }
+
+ if (!FLAGS_quick) {
+ std::cout << "----------------------------" << std::endl;
+ std::cout << "Inside queries (mostly)..." << std::endl;
+ // Do about 95% inside queries rather than 100% so that branch predictor
+ // can't give itself an artifically crazy advantage.
+ inside_threshold = UINT32_MAX / 20 * 19;
+ for (TestMode tm : testModes) {
+ random_.Seed(FLAGS_seed + 1);
+ double f = RandomQueryTest(inside_threshold, /*dry_run*/ false, tm);
+ random_.Seed(FLAGS_seed + 1);
+ double d = RandomQueryTest(inside_threshold, /*dry_run*/ true, tm);
+ std::cout << " " << TestModeToString(tm) << " net ns/op: " << (f - d)
+ << std::endl;
+ }
+
+ std::cout << "----------------------------" << std::endl;
+ std::cout << "Outside queries (mostly)..." << std::endl;
+ // Do about 95% outside queries rather than 100% so that branch predictor
+ // can't give itself an artifically crazy advantage.
+ inside_threshold = UINT32_MAX / 20;
+ for (TestMode tm : testModes) {
+ random_.Seed(FLAGS_seed + 2);
+ double f = RandomQueryTest(inside_threshold, /*dry_run*/ false, tm);
+ random_.Seed(FLAGS_seed + 2);
+ double d = RandomQueryTest(inside_threshold, /*dry_run*/ true, tm);
+ std::cout << " " << TestModeToString(tm) << " net ns/op: " << (f - d)
+ << std::endl;
+ }
+ }
+ std::cout << fp_rate_report_.str();
+
+ std::cout << "----------------------------" << std::endl;
+ std::cout << "Done. (For more info, run with -legend or -help.)" << std::endl;
+}
+
+double FilterBench::RandomQueryTest(uint32_t inside_threshold, bool dry_run,
+ TestMode mode) {
+ for (auto &info : infos_) {
+ info.outside_queries_ = 0;
+ info.false_positives_ = 0;
+ }
+
+ auto dry_run_hash_fn = DryRunNoHash;
+ if (!FLAGS_net_includes_hashing) {
+ if (FLAGS_impl == 0 || FLAGS_use_plain_table_bloom) {
+ dry_run_hash_fn = DryRunHash32;
+ } else {
+ dry_run_hash_fn = DryRunHash64;
+ }
+ }
+
+ uint32_t num_infos = static_cast<uint32_t>(infos_.size());
+ uint32_t dry_run_hash = 0;
+ uint64_t max_queries = static_cast<uint64_t>(m_queries_ * 1000000 + 0.50);
+ // Some filters may be considered secondary in order to implement skewed
+ // queries. num_primary_filters is the number that are to be treated as
+ // equal, and any remainder will be treated as secondary.
+ uint32_t num_primary_filters = num_infos;
+ // The proportion (when divided by 2^32 - 1) of filter queries going to
+ // the primary filters (default = all). The remainder of queries are
+ // against secondary filters.
+ uint32_t primary_filter_threshold = 0xffffffff;
+ if (mode == kSingleFilter) {
+ // 100% of queries to 1 filter
+ num_primary_filters = 1;
+ } else if (mode == kFiftyOneFilter) {
+ if (num_infos < 50) {
+ return 0.0; // skip
+ }
+ // 50% of queries
+ primary_filter_threshold /= 2;
+ // to 1% of filters
+ num_primary_filters = (num_primary_filters + 99) / 100;
+ } else if (mode == kEightyTwentyFilter) {
+ if (num_infos < 5) {
+ return 0.0; // skip
+ }
+ // 80% of queries
+ primary_filter_threshold = primary_filter_threshold / 5 * 4;
+ // to 20% of filters
+ num_primary_filters = (num_primary_filters + 4) / 5;
+ } else if (mode == kRandomFilter) {
+ if (num_infos == 1) {
+ return 0.0; // skip
+ }
+ }
+ uint32_t batch_size = 1;
+ std::unique_ptr<Slice[]> batch_slices;
+ std::unique_ptr<Slice *[]> batch_slice_ptrs;
+ std::unique_ptr<bool[]> batch_results;
+ if (mode == kBatchPrepared || mode == kBatchUnprepared) {
+ batch_size = static_cast<uint32_t>(kms_.size());
+ }
+
+ batch_slices.reset(new Slice[batch_size]);
+ batch_slice_ptrs.reset(new Slice *[batch_size]);
+ batch_results.reset(new bool[batch_size]);
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ batch_results[i] = false;
+ batch_slice_ptrs[i] = &batch_slices[i];
+ }
+
+ ROCKSDB_NAMESPACE::StopWatchNano timer(
+ ROCKSDB_NAMESPACE::SystemClock::Default().get(), true);
+
+ for (uint64_t q = 0; q < max_queries; q += batch_size) {
+ bool inside_this_time = random_.Next() <= inside_threshold;
+
+ uint32_t filter_index;
+ if (random_.Next() <= primary_filter_threshold) {
+ filter_index = random_.Uniformish(num_primary_filters);
+ } else {
+ // secondary
+ filter_index = num_primary_filters +
+ random_.Uniformish(num_infos - num_primary_filters);
+ }
+ FilterInfo &info = infos_[filter_index];
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ if (inside_this_time) {
+ batch_slices[i] =
+ kms_[i].Get(info.filter_id_, random_.Uniformish(info.keys_added_));
+ } else {
+ batch_slices[i] =
+ kms_[i].Get(info.filter_id_, random_.Uniformish(info.keys_added_) |
+ uint32_t{0x80000000});
+ info.outside_queries_++;
+ }
+ }
+ // TODO: implement batched interface to full block reader
+ // TODO: implement batched interface to plain table bloom
+ if (mode == kBatchPrepared && !FLAGS_use_full_block_reader &&
+ !FLAGS_use_plain_table_bloom) {
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ batch_results[i] = false;
+ }
+ if (dry_run) {
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ batch_results[i] = true;
+ dry_run_hash += dry_run_hash_fn(batch_slices[i]);
+ }
+ } else {
+ info.reader_->MayMatch(batch_size, batch_slice_ptrs.get(),
+ batch_results.get());
+ }
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ if (inside_this_time) {
+ ALWAYS_ASSERT(batch_results[i]);
+ } else {
+ info.false_positives_ += batch_results[i];
+ }
+ }
+ } else {
+ for (uint32_t i = 0; i < batch_size; ++i) {
+ bool may_match;
+ if (FLAGS_use_plain_table_bloom) {
+ if (dry_run) {
+ dry_run_hash += dry_run_hash_fn(batch_slices[i]);
+ may_match = true;
+ } else {
+ uint32_t hash = GetSliceHash(batch_slices[i]);
+ may_match = info.plain_table_bloom_->MayContainHash(hash);
+ }
+ } else if (FLAGS_use_full_block_reader) {
+ if (dry_run) {
+ dry_run_hash += dry_run_hash_fn(batch_slices[i]);
+ may_match = true;
+ } else {
+ may_match = info.full_block_reader_->KeyMayMatch(
+ batch_slices[i],
+ /*no_io=*/false, /*const_ikey_ptr=*/nullptr,
+ /*get_context=*/nullptr,
+ /*lookup_context=*/nullptr, Env::IO_TOTAL);
+ }
+ } else {
+ if (dry_run) {
+ dry_run_hash += dry_run_hash_fn(batch_slices[i]);
+ may_match = true;
+ } else {
+ may_match = info.reader_->MayMatch(batch_slices[i]);
+ }
+ }
+ if (inside_this_time) {
+ ALWAYS_ASSERT(may_match);
+ } else {
+ info.false_positives_ += may_match;
+ }
+ }
+ }
+ }
+
+ uint64_t elapsed_nanos = timer.ElapsedNanos();
+ double ns = double(elapsed_nanos) / max_queries;
+
+ if (!FLAGS_quick) {
+ if (dry_run) {
+ // Printing part of hash prevents dry run components from being optimized
+ // away by compiler
+ std::cout << " Dry run (" << std::hex << (dry_run_hash & 0xfffff)
+ << std::dec << ") ";
+ } else {
+ std::cout << " Gross filter ";
+ }
+ std::cout << "ns/op: " << ns << std::endl;
+ }
+
+ if (!dry_run) {
+ fp_rate_report_.str("");
+ uint64_t q = 0;
+ uint64_t fp = 0;
+ double worst_fp_rate = 0.0;
+ double best_fp_rate = 1.0;
+ for (auto &info : infos_) {
+ q += info.outside_queries_;
+ fp += info.false_positives_;
+ if (info.outside_queries_ > 0) {
+ double fp_rate = double(info.false_positives_) / info.outside_queries_;
+ worst_fp_rate = std::max(worst_fp_rate, fp_rate);
+ best_fp_rate = std::min(best_fp_rate, fp_rate);
+ }
+ }
+ fp_rate_report_ << " Average FP rate %: " << 100.0 * fp / q << std::endl;
+ if (!FLAGS_quick && !FLAGS_best_case) {
+ fp_rate_report_ << " Worst FP rate %: " << 100.0 * worst_fp_rate
+ << std::endl;
+ fp_rate_report_ << " Best FP rate %: " << 100.0 * best_fp_rate
+ << std::endl;
+ fp_rate_report_ << " Best possible bits/key: "
+ << -std::log(double(fp) / q) / std::log(2.0) << std::endl;
+ }
+ }
+ return ns;
+}
+
+int main(int argc, char **argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
+ " [-quick] [OTHER OPTIONS]...");
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ PrintWarnings();
+
+ if (FLAGS_legend) {
+ std::cout
+ << "Legend:" << std::endl
+ << " \"Inside\" - key that was added to filter" << std::endl
+ << " \"Outside\" - key that was not added to filter" << std::endl
+ << " \"FN\" - false negative query (must not happen)" << std::endl
+ << " \"FP\" - false positive query (OK at low rate)" << std::endl
+ << " \"Dry run\" - cost of testing and hashing overhead." << std::endl
+ << " \"Gross filter\" - cost of filter queries including testing "
+ << "\n and hashing overhead." << std::endl
+ << " \"net\" - best estimate of time in filter operation, without "
+ << "\n testing and hashing overhead (gross filter - dry run)"
+ << std::endl
+ << " \"ns/op\" - nanoseconds per operation (key query or add)"
+ << std::endl
+ << " \"Single filter\" - essentially minimum cost, assuming filter"
+ << "\n fits easily in L1 CPU cache." << std::endl
+ << " \"Batched, prepared\" - several queries at once against a"
+ << "\n randomly chosen filter, using multi-query interface."
+ << std::endl
+ << " \"Batched, unprepared\" - similar, but using serial calls"
+ << "\n to single query interface." << std::endl
+ << " \"Random filter\" - a filter is chosen at random as target"
+ << "\n of each query." << std::endl
+ << " \"Skewed X% in Y%\" - like \"Random filter\" except Y% of"
+ << "\n the filters are designated as \"hot\" and receive X%"
+ << "\n of queries." << std::endl;
+ } else {
+ FilterBench b;
+ for (uint32_t i = 0; i < FLAGS_runs; ++i) {
+ b.Go();
+ FLAGS_seed += 100;
+ b.random_.Seed(FLAGS_seed);
+ }
+ }
+
+ return 0;
+}
+
+#endif // !defined(GFLAGS) || defined(ROCKSDB_LITE)
diff --git a/src/rocksdb/util/gflags_compat.h b/src/rocksdb/util/gflags_compat.h
new file mode 100644
index 000000000..b6f88a5bc
--- /dev/null
+++ b/src/rocksdb/util/gflags_compat.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <gflags/gflags.h>
+
+#include <functional>
+
+#ifndef GFLAGS_NAMESPACE
+// in case it's not defined in old versions, that's probably because it was
+// still google by default.
+#define GFLAGS_NAMESPACE google
+#endif
+
+#ifndef DEFINE_uint32
+// DEFINE_uint32 does not appear in older versions of gflags. This should be
+// a sane definition for those versions.
+#include <cstdint>
+#define DEFINE_uint32(name, val, txt) \
+ namespace gflags_compat { \
+ DEFINE_int32(name, val, txt); \
+ } \
+ std::reference_wrapper<uint32_t> FLAGS_##name = \
+ std::ref(*reinterpret_cast<uint32_t *>(&gflags_compat::FLAGS_##name));
+
+#define DECLARE_uint32(name) \
+ extern std::reference_wrapper<uint32_t> FLAGS_##name;
+#endif // !DEFINE_uint32
diff --git a/src/rocksdb/util/hash.cc b/src/rocksdb/util/hash.cc
new file mode 100644
index 000000000..0f7f2edc1
--- /dev/null
+++ b/src/rocksdb/util/hash.cc
@@ -0,0 +1,201 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/hash.h"
+
+#include <string>
+
+#include "port/lang.h"
+#include "util/coding.h"
+#include "util/hash128.h"
+#include "util/math128.h"
+#include "util/xxhash.h"
+#include "util/xxph3.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+uint64_t (*kGetSliceNPHash64UnseededFnPtr)(const Slice&) = &GetSliceHash64;
+
+uint32_t Hash(const char* data, size_t n, uint32_t seed) {
+ // MurmurHash1 - fast but mediocre quality
+ // https://github.com/aappleby/smhasher/wiki/MurmurHash1
+ //
+ const uint32_t m = 0xc6a4a793;
+ const uint32_t r = 24;
+ const char* limit = data + n;
+ uint32_t h = static_cast<uint32_t>(seed ^ (n * m));
+
+ // Pick up four bytes at a time
+ while (data + 4 <= limit) {
+ uint32_t w = DecodeFixed32(data);
+ data += 4;
+ h += w;
+ h *= m;
+ h ^= (h >> 16);
+ }
+
+ // Pick up remaining bytes
+ switch (limit - data) {
+ // Note: The original hash implementation used data[i] << shift, which
+ // promotes the char to int and then performs the shift. If the char is
+ // negative, the shift is undefined behavior in C++. The hash algorithm is
+ // part of the format definition, so we cannot change it; to obtain the same
+ // behavior in a legal way we just cast to uint32_t, which will do
+ // sign-extension. To guarantee compatibility with architectures where chars
+ // are unsigned we first cast the char to int8_t.
+ case 3:
+ h += static_cast<uint32_t>(static_cast<int8_t>(data[2])) << 16;
+ FALLTHROUGH_INTENDED;
+ case 2:
+ h += static_cast<uint32_t>(static_cast<int8_t>(data[1])) << 8;
+ FALLTHROUGH_INTENDED;
+ case 1:
+ h += static_cast<uint32_t>(static_cast<int8_t>(data[0]));
+ h *= m;
+ h ^= (h >> r);
+ break;
+ }
+ return h;
+}
+
+// We are standardizing on a preview release of XXH3, because that's
+// the best available at time of standardizing.
+//
+// In testing (mostly Intel Skylake), this hash function is much more
+// thorough than Hash32 and is almost universally faster. Hash() only
+// seems faster when passing runtime-sized keys of the same small size
+// (less than about 24 bytes) thousands of times in a row; this seems
+// to allow the branch predictor to work some magic. XXH3's speed is
+// much less dependent on branch prediction.
+//
+// Hashing with a prefix extractor is potentially a common case of
+// hashing objects of small, predictable size. We could consider
+// bundling hash functions specialized for particular lengths with
+// the prefix extractors.
+uint64_t Hash64(const char* data, size_t n, uint64_t seed) {
+ return XXPH3_64bits_withSeed(data, n, seed);
+}
+
+uint64_t Hash64(const char* data, size_t n) {
+ // Same as seed = 0
+ return XXPH3_64bits(data, n);
+}
+
+uint64_t GetSlicePartsNPHash64(const SliceParts& data, uint64_t seed) {
+ // TODO(ajkr): use XXH3 streaming APIs to avoid the copy/allocation.
+ size_t concat_len = 0;
+ for (int i = 0; i < data.num_parts; ++i) {
+ concat_len += data.parts[i].size();
+ }
+ std::string concat_data;
+ concat_data.reserve(concat_len);
+ for (int i = 0; i < data.num_parts; ++i) {
+ concat_data.append(data.parts[i].data(), data.parts[i].size());
+ }
+ assert(concat_data.size() == concat_len);
+ return NPHash64(concat_data.data(), concat_len, seed);
+}
+
+Unsigned128 Hash128(const char* data, size_t n, uint64_t seed) {
+ auto h = XXH3_128bits_withSeed(data, n, seed);
+ return (Unsigned128{h.high64} << 64) | (h.low64);
+}
+
+Unsigned128 Hash128(const char* data, size_t n) {
+ // Same as seed = 0
+ auto h = XXH3_128bits(data, n);
+ return (Unsigned128{h.high64} << 64) | (h.low64);
+}
+
+void Hash2x64(const char* data, size_t n, uint64_t* high64, uint64_t* low64) {
+ // Same as seed = 0
+ auto h = XXH3_128bits(data, n);
+ *high64 = h.high64;
+ *low64 = h.low64;
+}
+
+void Hash2x64(const char* data, size_t n, uint64_t seed, uint64_t* high64,
+ uint64_t* low64) {
+ auto h = XXH3_128bits_withSeed(data, n, seed);
+ *high64 = h.high64;
+ *low64 = h.low64;
+}
+
+namespace {
+
+inline uint64_t XXH3_avalanche(uint64_t h64) {
+ h64 ^= h64 >> 37;
+ h64 *= 0x165667919E3779F9U;
+ h64 ^= h64 >> 32;
+ return h64;
+}
+
+inline uint64_t XXH3_unavalanche(uint64_t h64) {
+ h64 ^= h64 >> 32;
+ h64 *= 0x8da8ee41d6df849U; // inverse of 0x165667919E3779F9U
+ h64 ^= h64 >> 37;
+ return h64;
+}
+
+} // namespace
+
+void BijectiveHash2x64(uint64_t in_high64, uint64_t in_low64, uint64_t seed,
+ uint64_t* out_high64, uint64_t* out_low64) {
+ // Adapted from XXH3_len_9to16_128b
+ const uint64_t bitflipl = /*secret part*/ 0x59973f0033362349U - seed;
+ const uint64_t bitfliph = /*secret part*/ 0xc202797692d63d58U + seed;
+ Unsigned128 tmp128 =
+ Multiply64to128(in_low64 ^ in_high64 ^ bitflipl, 0x9E3779B185EBCA87U);
+ uint64_t lo = Lower64of128(tmp128);
+ uint64_t hi = Upper64of128(tmp128);
+ lo += 0x3c0000000000000U; // (len - 1) << 54
+ in_high64 ^= bitfliph;
+ hi += in_high64 + (Lower32of64(in_high64) * uint64_t{0x85EBCA76});
+ lo ^= EndianSwapValue(hi);
+ tmp128 = Multiply64to128(lo, 0xC2B2AE3D27D4EB4FU);
+ lo = Lower64of128(tmp128);
+ hi = Upper64of128(tmp128) + (hi * 0xC2B2AE3D27D4EB4FU);
+ *out_low64 = XXH3_avalanche(lo);
+ *out_high64 = XXH3_avalanche(hi);
+}
+
+void BijectiveUnhash2x64(uint64_t in_high64, uint64_t in_low64, uint64_t seed,
+ uint64_t* out_high64, uint64_t* out_low64) {
+ // Inverted above (also consulting XXH3_len_9to16_128b)
+ const uint64_t bitflipl = /*secret part*/ 0x59973f0033362349U - seed;
+ const uint64_t bitfliph = /*secret part*/ 0xc202797692d63d58U + seed;
+ uint64_t lo = XXH3_unavalanche(in_low64);
+ uint64_t hi = XXH3_unavalanche(in_high64);
+ lo *= 0xba79078168d4baf; // inverse of 0xC2B2AE3D27D4EB4FU
+ hi -= Upper64of128(Multiply64to128(lo, 0xC2B2AE3D27D4EB4FU));
+ hi *= 0xba79078168d4baf; // inverse of 0xC2B2AE3D27D4EB4FU
+ lo ^= EndianSwapValue(hi);
+ lo -= 0x3c0000000000000U;
+ lo *= 0x887493432badb37U; // inverse of 0x9E3779B185EBCA87U
+ hi -= Upper64of128(Multiply64to128(lo, 0x9E3779B185EBCA87U));
+ uint32_t tmp32 = Lower32of64(hi) * 0xb6c92f47; // inverse of 0x85EBCA77
+ hi -= tmp32;
+ hi = (hi & 0xFFFFFFFF00000000U) -
+ ((tmp32 * uint64_t{0x85EBCA76}) & 0xFFFFFFFF00000000U) + tmp32;
+ hi ^= bitfliph;
+ lo ^= hi ^ bitflipl;
+ *out_high64 = hi;
+ *out_low64 = lo;
+}
+
+void BijectiveHash2x64(uint64_t in_high64, uint64_t in_low64,
+ uint64_t* out_high64, uint64_t* out_low64) {
+ BijectiveHash2x64(in_high64, in_low64, /*seed*/ 0, out_high64, out_low64);
+}
+
+void BijectiveUnhash2x64(uint64_t in_high64, uint64_t in_low64,
+ uint64_t* out_high64, uint64_t* out_low64) {
+ BijectiveUnhash2x64(in_high64, in_low64, /*seed*/ 0, out_high64, out_low64);
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/hash.h b/src/rocksdb/util/hash.h
new file mode 100644
index 000000000..eafa47f34
--- /dev/null
+++ b/src/rocksdb/util/hash.h
@@ -0,0 +1,137 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Common hash functions with convenient interfaces. If hashing a
+// statically-sized input in a performance-critical context, consider
+// calling a specific hash implementation directly, such as
+// XXH3_64bits from xxhash.h.
+//
+// Since this is a very common header, implementation details are kept
+// out-of-line. Out-of-lining also aids in tracking the time spent in
+// hashing functions. Inlining is of limited benefit for runtime-sized
+// hash inputs.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "rocksdb/slice.h"
+#include "util/fastrange.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Stable/persistent 64-bit hash. Higher quality and generally faster than
+// Hash(), especially for inputs > 24 bytes.
+// KNOWN FLAW: incrementing seed by 1 might not give sufficiently independent
+// results from previous seed. Recommend incrementing by a large odd number.
+extern uint64_t Hash64(const char* data, size_t n, uint64_t seed);
+
+// Specific optimization without seed (same as seed = 0)
+extern uint64_t Hash64(const char* data, size_t n);
+
+// Non-persistent hash. Must only used for in-memory data structures.
+// The hash results are thus subject to change between releases,
+// architectures, build configuration, etc. (Thus, it rarely makes sense
+// to specify a seed for this function, except for a "rolling" hash.)
+// KNOWN FLAW: incrementing seed by 1 might not give sufficiently independent
+// results from previous seed. Recommend incrementing by a large odd number.
+inline uint64_t NPHash64(const char* data, size_t n, uint64_t seed) {
+#ifdef ROCKSDB_MODIFY_NPHASH
+ // For testing "subject to change"
+ return Hash64(data, n, seed + 123456789);
+#else
+ // Currently same as Hash64
+ return Hash64(data, n, seed);
+#endif
+}
+
+// Specific optimization without seed (same as seed = 0)
+inline uint64_t NPHash64(const char* data, size_t n) {
+#ifdef ROCKSDB_MODIFY_NPHASH
+ // For testing "subject to change"
+ return Hash64(data, n, 123456789);
+#else
+ // Currently same as Hash64
+ return Hash64(data, n);
+#endif
+}
+
+// Convenient and equivalent version of Hash128 without depending on 128-bit
+// scalars
+void Hash2x64(const char* data, size_t n, uint64_t* high64, uint64_t* low64);
+void Hash2x64(const char* data, size_t n, uint64_t seed, uint64_t* high64,
+ uint64_t* low64);
+
+// Hash 128 bits to 128 bits, guaranteed not to lose data (equivalent to
+// Hash2x64 on 16 bytes little endian)
+void BijectiveHash2x64(uint64_t in_high64, uint64_t in_low64,
+ uint64_t* out_high64, uint64_t* out_low64);
+void BijectiveHash2x64(uint64_t in_high64, uint64_t in_low64, uint64_t seed,
+ uint64_t* out_high64, uint64_t* out_low64);
+
+// Inverse of above (mostly for testing)
+void BijectiveUnhash2x64(uint64_t in_high64, uint64_t in_low64,
+ uint64_t* out_high64, uint64_t* out_low64);
+void BijectiveUnhash2x64(uint64_t in_high64, uint64_t in_low64, uint64_t seed,
+ uint64_t* out_high64, uint64_t* out_low64);
+
+// Stable/persistent 32-bit hash. Moderate quality and high speed on
+// small inputs.
+// TODO: consider rename to Hash32
+// KNOWN FLAW: incrementing seed by 1 might not give sufficiently independent
+// results from previous seed. Recommend pseudorandom or hashed seeds.
+extern uint32_t Hash(const char* data, size_t n, uint32_t seed);
+
+// TODO: consider rename to LegacyBloomHash32
+inline uint32_t BloomHash(const Slice& key) {
+ return Hash(key.data(), key.size(), 0xbc9f1d34);
+}
+
+inline uint64_t GetSliceHash64(const Slice& key) {
+ return Hash64(key.data(), key.size());
+}
+// Provided for convenience for use with template argument deduction, where a
+// specific overload needs to be used.
+extern uint64_t (*kGetSliceNPHash64UnseededFnPtr)(const Slice&);
+
+inline uint64_t GetSliceNPHash64(const Slice& s) {
+ return NPHash64(s.data(), s.size());
+}
+
+inline uint64_t GetSliceNPHash64(const Slice& s, uint64_t seed) {
+ return NPHash64(s.data(), s.size(), seed);
+}
+
+// Similar to `GetSliceNPHash64()` with `seed`, but input comes from
+// concatenation of `Slice`s in `data`.
+extern uint64_t GetSlicePartsNPHash64(const SliceParts& data, uint64_t seed);
+
+inline size_t GetSliceRangedNPHash(const Slice& s, size_t range) {
+ return FastRange64(NPHash64(s.data(), s.size()), range);
+}
+
+// TODO: consider rename to GetSliceHash32
+inline uint32_t GetSliceHash(const Slice& s) {
+ return Hash(s.data(), s.size(), 397);
+}
+
+// Useful for splitting up a 64-bit hash
+inline uint32_t Upper32of64(uint64_t v) {
+ return static_cast<uint32_t>(v >> 32);
+}
+inline uint32_t Lower32of64(uint64_t v) { return static_cast<uint32_t>(v); }
+
+// std::hash compatible interface.
+// TODO: consider rename to SliceHasher32
+struct SliceHasher {
+ uint32_t operator()(const Slice& s) const { return GetSliceHash(s); }
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/hash128.h b/src/rocksdb/util/hash128.h
new file mode 100644
index 000000000..305caa14a
--- /dev/null
+++ b/src/rocksdb/util/hash128.h
@@ -0,0 +1,26 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+// 128-bit hash gets it own header so that more popular hash.h doesn't
+// depend on math128.h
+
+#include "rocksdb/slice.h"
+#include "util/math128.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Stable/persistent 128-bit hash for non-cryptographic applications.
+Unsigned128 Hash128(const char* data, size_t n, uint64_t seed);
+
+// Specific optimization without seed (same as seed = 0)
+Unsigned128 Hash128(const char* data, size_t n);
+
+inline Unsigned128 GetSliceHash128(const Slice& key) {
+ return Hash128(key.data(), key.size());
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/hash_containers.h b/src/rocksdb/util/hash_containers.h
new file mode 100644
index 000000000..52be3718c
--- /dev/null
+++ b/src/rocksdb/util/hash_containers.h
@@ -0,0 +1,51 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// This header establishes compile-time pluggable implementations of hashed
+// container structures, so that deployments have the option of minimal
+// dependencies with ok performance (e.g. std::unordered_map) or more
+// dependencies with optimized performance (e.g. folly::F14FastMap).
+
+#pragma once
+
+#include "rocksdb/rocksdb_namespace.h"
+
+#ifdef USE_FOLLY
+
+#include <folly/container/F14Map.h>
+#include <folly/container/F14Set.h>
+
+namespace ROCKSDB_NAMESPACE {
+
+template <typename K, typename V>
+using UnorderedMap = folly::F14FastMap<K, V>;
+
+template <typename K, typename V, typename H>
+using UnorderedMapH = folly::F14FastMap<K, V, H>;
+
+template <typename K>
+using UnorderedSet = folly::F14FastSet<K>;
+
+} // namespace ROCKSDB_NAMESPACE
+
+#else
+
+#include <unordered_map>
+#include <unordered_set>
+
+namespace ROCKSDB_NAMESPACE {
+
+template <typename K, typename V>
+using UnorderedMap = std::unordered_map<K, V>;
+
+template <typename K, typename V, typename H>
+using UnorderedMapH = std::unordered_map<K, V, H>;
+
+template <typename K>
+using UnorderedSet = std::unordered_set<K>;
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/util/hash_map.h b/src/rocksdb/util/hash_map.h
new file mode 100644
index 000000000..e3ad2584f
--- /dev/null
+++ b/src/rocksdb/util/hash_map.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#pragma once
+
+#include <algorithm>
+#include <array>
+#include <utility>
+
+#include "util/autovector.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// This is similar to std::unordered_map, except that it tries to avoid
+// allocating or deallocating memory as much as possible. With
+// std::unordered_map, an allocation/deallocation is made for every insertion
+// or deletion because of the requirement that iterators remain valid even
+// with insertions or deletions. This means that the hash chains will be
+// implemented as linked lists.
+//
+// This implementation uses autovector as hash chains insteads.
+//
+template <typename K, typename V, size_t size = 128>
+class HashMap {
+ std::array<autovector<std::pair<K, V>, 1>, size> table_;
+
+ public:
+ bool Contains(K key) {
+ auto& bucket = table_[key % size];
+ auto it = std::find_if(
+ bucket.begin(), bucket.end(),
+ [key](const std::pair<K, V>& p) { return p.first == key; });
+ return it != bucket.end();
+ }
+
+ void Insert(K key, const V& value) {
+ auto& bucket = table_[key % size];
+ bucket.push_back({key, value});
+ }
+
+ void Delete(K key) {
+ auto& bucket = table_[key % size];
+ auto it = std::find_if(
+ bucket.begin(), bucket.end(),
+ [key](const std::pair<K, V>& p) { return p.first == key; });
+ if (it != bucket.end()) {
+ auto last = bucket.end() - 1;
+ if (it != last) {
+ *it = *last;
+ }
+ bucket.pop_back();
+ }
+ }
+
+ V& Get(K key) {
+ auto& bucket = table_[key % size];
+ auto it = std::find_if(
+ bucket.begin(), bucket.end(),
+ [key](const std::pair<K, V>& p) { return p.first == key; });
+ return it->second;
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/hash_test.cc b/src/rocksdb/util/hash_test.cc
new file mode 100644
index 000000000..72112b044
--- /dev/null
+++ b/src/rocksdb/util/hash_test.cc
@@ -0,0 +1,853 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/hash.h"
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "test_util/testharness.h"
+#include "util/coding.h"
+#include "util/coding_lean.h"
+#include "util/hash128.h"
+#include "util/math.h"
+#include "util/math128.h"
+
+using ROCKSDB_NAMESPACE::BijectiveHash2x64;
+using ROCKSDB_NAMESPACE::BijectiveUnhash2x64;
+using ROCKSDB_NAMESPACE::DecodeFixed64;
+using ROCKSDB_NAMESPACE::EncodeFixed32;
+using ROCKSDB_NAMESPACE::EndianSwapValue;
+using ROCKSDB_NAMESPACE::GetSliceHash64;
+using ROCKSDB_NAMESPACE::Hash;
+using ROCKSDB_NAMESPACE::Hash128;
+using ROCKSDB_NAMESPACE::Hash2x64;
+using ROCKSDB_NAMESPACE::Hash64;
+using ROCKSDB_NAMESPACE::Lower32of64;
+using ROCKSDB_NAMESPACE::Lower64of128;
+using ROCKSDB_NAMESPACE::ReverseBits;
+using ROCKSDB_NAMESPACE::Slice;
+using ROCKSDB_NAMESPACE::Unsigned128;
+using ROCKSDB_NAMESPACE::Upper32of64;
+using ROCKSDB_NAMESPACE::Upper64of128;
+
+// The hash algorithm is part of the file format, for example for the Bloom
+// filters. Test that the hash values are stable for a set of random strings of
+// varying lengths.
+TEST(HashTest, Values) {
+ constexpr uint32_t kSeed = 0xbc9f1d34; // Same as BloomHash.
+
+ EXPECT_EQ(Hash("", 0, kSeed), 3164544308u);
+ EXPECT_EQ(Hash("\x08", 1, kSeed), 422599524u);
+ EXPECT_EQ(Hash("\x17", 1, kSeed), 3168152998u);
+ EXPECT_EQ(Hash("\x9a", 1, kSeed), 3195034349u);
+ EXPECT_EQ(Hash("\x1c", 1, kSeed), 2651681383u);
+ EXPECT_EQ(Hash("\x4d\x76", 2, kSeed), 2447836956u);
+ EXPECT_EQ(Hash("\x52\xd5", 2, kSeed), 3854228105u);
+ EXPECT_EQ(Hash("\x91\xf7", 2, kSeed), 31066776u);
+ EXPECT_EQ(Hash("\xd6\x27", 2, kSeed), 1806091603u);
+ EXPECT_EQ(Hash("\x30\x46\x0b", 3, kSeed), 3808221797u);
+ EXPECT_EQ(Hash("\x56\xdc\xd6", 3, kSeed), 2157698265u);
+ EXPECT_EQ(Hash("\xd4\x52\x33", 3, kSeed), 1721992661u);
+ EXPECT_EQ(Hash("\x6a\xb5\xf4", 3, kSeed), 2469105222u);
+ EXPECT_EQ(Hash("\x67\x53\x81\x1c", 4, kSeed), 118283265u);
+ EXPECT_EQ(Hash("\x69\xb8\xc0\x88", 4, kSeed), 3416318611u);
+ EXPECT_EQ(Hash("\x1e\x84\xaf\x2d", 4, kSeed), 3315003572u);
+ EXPECT_EQ(Hash("\x46\xdc\x54\xbe", 4, kSeed), 447346355u);
+ EXPECT_EQ(Hash("\xd0\x7a\x6e\xea\x56", 5, kSeed), 4255445370u);
+ EXPECT_EQ(Hash("\x86\x83\xd5\xa4\xd8", 5, kSeed), 2390603402u);
+ EXPECT_EQ(Hash("\xb7\x46\xbb\x77\xce", 5, kSeed), 2048907743u);
+ EXPECT_EQ(Hash("\x6c\xa8\xbc\xe5\x99", 5, kSeed), 2177978500u);
+ EXPECT_EQ(Hash("\x5c\x5e\xe1\xa0\x73\x81", 6, kSeed), 1036846008u);
+ EXPECT_EQ(Hash("\x08\x5d\x73\x1c\xe5\x2e", 6, kSeed), 229980482u);
+ EXPECT_EQ(Hash("\x42\xfb\xf2\x52\xb4\x10", 6, kSeed), 3655585422u);
+ EXPECT_EQ(Hash("\x73\xe1\xff\x56\x9c\xce", 6, kSeed), 3502708029u);
+ EXPECT_EQ(Hash("\x5c\xbe\x97\x75\x54\x9a\x52", 7, kSeed), 815120748u);
+ EXPECT_EQ(Hash("\x16\x82\x39\x49\x88\x2b\x36", 7, kSeed), 3056033698u);
+ EXPECT_EQ(Hash("\x59\x77\xf0\xa7\x24\xf4\x78", 7, kSeed), 587205227u);
+ EXPECT_EQ(Hash("\xd3\xa5\x7c\x0e\xc0\x02\x07", 7, kSeed), 2030937252u);
+ EXPECT_EQ(Hash("\x31\x1b\x98\x75\x96\x22\xd3\x9a", 8, kSeed), 469635402u);
+ EXPECT_EQ(Hash("\x38\xd6\xf7\x28\x20\xb4\x8a\xe9", 8, kSeed), 3530274698u);
+ EXPECT_EQ(Hash("\xbb\x18\x5d\xf4\x12\x03\xf7\x99", 8, kSeed), 1974545809u);
+ EXPECT_EQ(Hash("\x80\xd4\x3b\x3b\xae\x22\xa2\x78", 8, kSeed), 3563570120u);
+ EXPECT_EQ(Hash("\x1a\xb5\xd0\xfe\xab\xc3\x61\xb2\x99", 9, kSeed),
+ 2706087434u);
+ EXPECT_EQ(Hash("\x8e\x4a\xc3\x18\x20\x2f\x06\xe6\x3c", 9, kSeed),
+ 1534654151u);
+ EXPECT_EQ(Hash("\xb6\xc0\xdd\x05\x3f\xc4\x86\x4c\xef", 9, kSeed),
+ 2355554696u);
+ EXPECT_EQ(Hash("\x9a\x5f\x78\x0d\xaf\x50\xe1\x1f\x55", 9, kSeed),
+ 1400800912u);
+ EXPECT_EQ(Hash("\x22\x6f\x39\x1f\xf8\xdd\x4f\x52\x17\x94", 10, kSeed),
+ 3420325137u);
+ EXPECT_EQ(Hash("\x32\x89\x2a\x75\x48\x3a\x4a\x02\x69\xdd", 10, kSeed),
+ 3427803584u);
+ EXPECT_EQ(Hash("\x06\x92\x5c\xf4\x88\x0e\x7e\x68\x38\x3e", 10, kSeed),
+ 1152407945u);
+ EXPECT_EQ(Hash("\xbd\x2c\x63\x38\xbf\xe9\x78\xb7\xbf\x15", 10, kSeed),
+ 3382479516u);
+}
+
+// The hash algorithm is part of the file format, for example for the Bloom
+// filters.
+TEST(HashTest, Hash64Misc) {
+ constexpr uint32_t kSeed = 0; // Same as GetSliceHash64
+
+ for (char fill : {'\0', 'a', '1', '\xff'}) {
+ const size_t max_size = 1000;
+ const std::string str(max_size, fill);
+
+ for (size_t size = 0; size <= max_size; ++size) {
+ uint64_t here = Hash64(str.data(), size, kSeed);
+
+ // Must be same as unseeded Hash64 and GetSliceHash64
+ EXPECT_EQ(here, Hash64(str.data(), size));
+ EXPECT_EQ(here, GetSliceHash64(Slice(str.data(), size)));
+
+ // Upper and Lower must reconstruct hash
+ EXPECT_EQ(here, (uint64_t{Upper32of64(here)} << 32) | Lower32of64(here));
+ EXPECT_EQ(here, (uint64_t{Upper32of64(here)} << 32) + Lower32of64(here));
+ EXPECT_EQ(here, (uint64_t{Upper32of64(here)} << 32) ^ Lower32of64(here));
+
+ // Seed changes hash value (with high probability)
+ for (uint64_t var_seed = 1; var_seed != 0; var_seed <<= 1) {
+ EXPECT_NE(here, Hash64(str.data(), size, var_seed));
+ }
+
+ // Size changes hash value (with high probability)
+ size_t max_smaller_by = std::min(size_t{30}, size);
+ for (size_t smaller_by = 1; smaller_by <= max_smaller_by; ++smaller_by) {
+ EXPECT_NE(here, Hash64(str.data(), size - smaller_by, kSeed));
+ }
+ }
+ }
+}
+
+// Test that hash values are "non-trivial" for "trivial" inputs
+TEST(HashTest, Hash64Trivial) {
+ // Thorough test too slow for regression testing
+ constexpr bool thorough = false;
+
+ // For various seeds, make sure hash of empty string is not zero.
+ constexpr uint64_t max_seed = thorough ? 0x1000000 : 0x10000;
+ for (uint64_t seed = 0; seed < max_seed; ++seed) {
+ uint64_t here = Hash64("", 0, seed);
+ EXPECT_NE(Lower32of64(here), 0u);
+ EXPECT_NE(Upper32of64(here), 0u);
+ }
+
+ // For standard seed, make sure hash of small strings are not zero
+ constexpr uint32_t kSeed = 0; // Same as GetSliceHash64
+ char input[4];
+ constexpr int max_len = thorough ? 3 : 2;
+ for (int len = 1; len <= max_len; ++len) {
+ for (uint32_t i = 0; (i >> (len * 8)) == 0; ++i) {
+ EncodeFixed32(input, i);
+ uint64_t here = Hash64(input, len, kSeed);
+ EXPECT_NE(Lower32of64(here), 0u);
+ EXPECT_NE(Upper32of64(here), 0u);
+ }
+ }
+}
+
+// Test that the hash values are stable for a set of random strings of
+// varying small lengths.
+TEST(HashTest, Hash64SmallValueSchema) {
+ constexpr uint32_t kSeed = 0; // Same as GetSliceHash64
+
+ EXPECT_EQ(Hash64("", 0, kSeed), uint64_t{5999572062939766020u});
+ EXPECT_EQ(Hash64("\x08", 1, kSeed), uint64_t{583283813901344696u});
+ EXPECT_EQ(Hash64("\x17", 1, kSeed), uint64_t{16175549975585474943u});
+ EXPECT_EQ(Hash64("\x9a", 1, kSeed), uint64_t{16322991629225003903u});
+ EXPECT_EQ(Hash64("\x1c", 1, kSeed), uint64_t{13269285487706833447u});
+ EXPECT_EQ(Hash64("\x4d\x76", 2, kSeed), uint64_t{6859542833406258115u});
+ EXPECT_EQ(Hash64("\x52\xd5", 2, kSeed), uint64_t{4919611532550636959u});
+ EXPECT_EQ(Hash64("\x91\xf7", 2, kSeed), uint64_t{14199427467559720719u});
+ EXPECT_EQ(Hash64("\xd6\x27", 2, kSeed), uint64_t{12292689282614532691u});
+ EXPECT_EQ(Hash64("\x30\x46\x0b", 3, kSeed), uint64_t{11404699285340020889u});
+ EXPECT_EQ(Hash64("\x56\xdc\xd6", 3, kSeed), uint64_t{12404347133785524237u});
+ EXPECT_EQ(Hash64("\xd4\x52\x33", 3, kSeed), uint64_t{15853805298481534034u});
+ EXPECT_EQ(Hash64("\x6a\xb5\xf4", 3, kSeed), uint64_t{16863488758399383382u});
+ EXPECT_EQ(Hash64("\x67\x53\x81\x1c", 4, kSeed),
+ uint64_t{9010661983527562386u});
+ EXPECT_EQ(Hash64("\x69\xb8\xc0\x88", 4, kSeed),
+ uint64_t{6611781377647041447u});
+ EXPECT_EQ(Hash64("\x1e\x84\xaf\x2d", 4, kSeed),
+ uint64_t{15290969111616346501u});
+ EXPECT_EQ(Hash64("\x46\xdc\x54\xbe", 4, kSeed),
+ uint64_t{7063754590279313623u});
+ EXPECT_EQ(Hash64("\xd0\x7a\x6e\xea\x56", 5, kSeed),
+ uint64_t{6384167718754869899u});
+ EXPECT_EQ(Hash64("\x86\x83\xd5\xa4\xd8", 5, kSeed),
+ uint64_t{16874407254108011067u});
+ EXPECT_EQ(Hash64("\xb7\x46\xbb\x77\xce", 5, kSeed),
+ uint64_t{16809880630149135206u});
+ EXPECT_EQ(Hash64("\x6c\xa8\xbc\xe5\x99", 5, kSeed),
+ uint64_t{1249038833153141148u});
+ EXPECT_EQ(Hash64("\x5c\x5e\xe1\xa0\x73\x81", 6, kSeed),
+ uint64_t{17358142495308219330u});
+ EXPECT_EQ(Hash64("\x08\x5d\x73\x1c\xe5\x2e", 6, kSeed),
+ uint64_t{4237646583134806322u});
+ EXPECT_EQ(Hash64("\x42\xfb\xf2\x52\xb4\x10", 6, kSeed),
+ uint64_t{4373664924115234051u});
+ EXPECT_EQ(Hash64("\x73\xe1\xff\x56\x9c\xce", 6, kSeed),
+ uint64_t{12012981210634596029u});
+ EXPECT_EQ(Hash64("\x5c\xbe\x97\x75\x54\x9a\x52", 7, kSeed),
+ uint64_t{5716522398211028826u});
+ EXPECT_EQ(Hash64("\x16\x82\x39\x49\x88\x2b\x36", 7, kSeed),
+ uint64_t{15604531309862565013u});
+ EXPECT_EQ(Hash64("\x59\x77\xf0\xa7\x24\xf4\x78", 7, kSeed),
+ uint64_t{8601330687345614172u});
+ EXPECT_EQ(Hash64("\xd3\xa5\x7c\x0e\xc0\x02\x07", 7, kSeed),
+ uint64_t{8088079329364056942u});
+ EXPECT_EQ(Hash64("\x31\x1b\x98\x75\x96\x22\xd3\x9a", 8, kSeed),
+ uint64_t{9844314944338447628u});
+ EXPECT_EQ(Hash64("\x38\xd6\xf7\x28\x20\xb4\x8a\xe9", 8, kSeed),
+ uint64_t{10973293517982163143u});
+ EXPECT_EQ(Hash64("\xbb\x18\x5d\xf4\x12\x03\xf7\x99", 8, kSeed),
+ uint64_t{9986007080564743219u});
+ EXPECT_EQ(Hash64("\x80\xd4\x3b\x3b\xae\x22\xa2\x78", 8, kSeed),
+ uint64_t{1729303145008254458u});
+ EXPECT_EQ(Hash64("\x1a\xb5\xd0\xfe\xab\xc3\x61\xb2\x99", 9, kSeed),
+ uint64_t{13253403748084181481u});
+ EXPECT_EQ(Hash64("\x8e\x4a\xc3\x18\x20\x2f\x06\xe6\x3c", 9, kSeed),
+ uint64_t{7768754303876232188u});
+ EXPECT_EQ(Hash64("\xb6\xc0\xdd\x05\x3f\xc4\x86\x4c\xef", 9, kSeed),
+ uint64_t{12439346786701492u});
+ EXPECT_EQ(Hash64("\x9a\x5f\x78\x0d\xaf\x50\xe1\x1f\x55", 9, kSeed),
+ uint64_t{10841838338450144690u});
+ EXPECT_EQ(Hash64("\x22\x6f\x39\x1f\xf8\xdd\x4f\x52\x17\x94", 10, kSeed),
+ uint64_t{12883919702069153152u});
+ EXPECT_EQ(Hash64("\x32\x89\x2a\x75\x48\x3a\x4a\x02\x69\xdd", 10, kSeed),
+ uint64_t{12692903507676842188u});
+ EXPECT_EQ(Hash64("\x06\x92\x5c\xf4\x88\x0e\x7e\x68\x38\x3e", 10, kSeed),
+ uint64_t{6540985900674032620u});
+ EXPECT_EQ(Hash64("\xbd\x2c\x63\x38\xbf\xe9\x78\xb7\xbf\x15", 10, kSeed),
+ uint64_t{10551812464348219044u});
+}
+
+std::string Hash64TestDescriptor(const char *repeat, size_t limit) {
+ const char *mod61_encode =
+ "abcdefghijklmnopqrstuvwxyz123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+
+ std::string input;
+ while (input.size() < limit) {
+ input.append(repeat);
+ }
+ std::string rv;
+ for (size_t i = 0; i < limit; ++i) {
+ uint64_t h = GetSliceHash64(Slice(input.data(), i));
+ rv.append(1, mod61_encode[static_cast<size_t>(h % 61)]);
+ }
+ return rv;
+}
+
+// XXPH3 changes its algorithm for various sizes up through 250 bytes, so
+// we need to check the stability of larger sizes also.
+TEST(HashTest, Hash64LargeValueSchema) {
+ // Each of these derives a "descriptor" from the hash values for all
+ // lengths up to 430.
+ // Note that "c" is common for the zero-length string.
+ EXPECT_EQ(
+ Hash64TestDescriptor("foo", 430),
+ "cRhyWsY67B6klRA1udmOuiYuX7IthyGBKqbeosz2hzVglWCmQx8nEdnpkvPfYX56Up2OWOTV"
+ "lTzfAoYwvtqKzjD8E9xttR2unelbXbIV67NUe6bOO23BxaSFRcA3njGu5cUWfgwOqNoTsszp"
+ "uPvKRP6qaUR5VdoBkJUCFIefd7edlNK5mv6JYWaGdwxehg65hTkTmjZoPKxTZo4PLyzbL9U4"
+ "xt12ITSfeP2MfBHuLI2z2pDlBb44UQKVMx27LEoAHsdLp3WfWfgH3sdRBRCHm33UxCM4QmE2"
+ "xJ7gqSvNwTeH7v9GlC8zWbGroyD3UVNeShMLx29O7tH1biemLULwAHyIw8zdtLMDpEJ8m2ic"
+ "l6Lb4fDuuFNAs1GCVUthjK8CV8SWI8Rsz5THSwn5CGhpqUwSZcFknjwWIl5rNCvDxXJqYr");
+ // Note that "1EeRk" is common for "Rocks"
+ EXPECT_EQ(
+ Hash64TestDescriptor("Rocks", 430),
+ "c1EeRkrzgOYWLA8PuhJrwTePJewoB44WdXYDfhbk3ZxTqqg25WlPExDl7IKIQLJvnA6gJxxn"
+ "9TCSLkFGfJeXehaSS1GBqWSzfhEH4VXiXIUCuxJXxtKXcSC6FrNIQGTZbYDiUOLD6Y5inzrF"
+ "9etwQhXUBanw55xAUdNMFQAm2GjJ6UDWp2mISLiMMkLjANWMKLaZMqaFLX37qB4MRO1ooVRv"
+ "zSvaNRSCLxlggQCasQq8icWjzf3HjBlZtU6pd4rkaUxSzHqmo9oM5MghbU5Rtxg8wEfO7lVN"
+ "5wdMONYecslQTwjZUpO1K3LDf3K3XK6sUXM6ShQQ3RHmMn2acB4YtTZ3QQcHYJSOHn2DuWpa"
+ "Q8RqzX5lab92YmOLaCdOHq1BPsM7SIBzMdLgePNsJ1vvMALxAaoDUHPxoFLO2wx18IXnyX");
+ EXPECT_EQ(
+ Hash64TestDescriptor("RocksDB", 430),
+ "c1EeRkukbkb28wLTahwD2sfUhZzaBEnF8SVrxnPVB6A7b8CaAl3UKsDZISF92GSq2wDCukOq"
+ "Jgrsp7A3KZhDiLW8dFXp8UPqPxMCRlMdZeVeJ2dJxrmA6cyt99zkQFj7ELbut6jAeVqARFnw"
+ "fnWVXOsaLrq7bDCbMcns2DKvTaaqTCLMYxI7nhtLpFN1jR755FRQFcOzrrDbh7QhypjdvlYw"
+ "cdAMSZgp9JMHxbM23wPSuH6BOFgxejz35PScZfhDPvTOxIy1jc3MZsWrMC3P324zNolO7JdW"
+ "CX2I5UDKjjaEJfxbgVgJIXxtQGlmj2xkO5sPpjULQV4X2HlY7FQleJ4QRaJIB4buhCA4vUTF"
+ "eMFlxCIYUpTCsal2qsmnGOWa8WCcefrohMjDj1fjzSvSaQwlpyR1GZHF2uPOoQagiCpHpm");
+}
+
+TEST(HashTest, Hash128Misc) {
+ constexpr uint32_t kSeed = 0; // Same as GetSliceHash128
+
+ for (char fill : {'\0', 'a', '1', '\xff', 'e'}) {
+ const size_t max_size = 1000;
+ std::string str(max_size, fill);
+
+ if (fill == 'e') {
+ // Use different characters to check endianness handling
+ for (size_t i = 0; i < str.size(); ++i) {
+ str[i] += static_cast<char>(i);
+ }
+ }
+
+ for (size_t size = 0; size <= max_size; ++size) {
+ Unsigned128 here = Hash128(str.data(), size, kSeed);
+
+ // Must be same as unseeded Hash128 and GetSliceHash128
+ EXPECT_EQ(here, Hash128(str.data(), size));
+ EXPECT_EQ(here, GetSliceHash128(Slice(str.data(), size)));
+ {
+ uint64_t hi, lo;
+ Hash2x64(str.data(), size, &hi, &lo);
+ EXPECT_EQ(Lower64of128(here), lo);
+ EXPECT_EQ(Upper64of128(here), hi);
+ }
+ if (size == 16) {
+ const uint64_t in_hi = DecodeFixed64(str.data() + 8);
+ const uint64_t in_lo = DecodeFixed64(str.data());
+ uint64_t hi, lo;
+ BijectiveHash2x64(in_hi, in_lo, &hi, &lo);
+ EXPECT_EQ(Lower64of128(here), lo);
+ EXPECT_EQ(Upper64of128(here), hi);
+ uint64_t un_hi, un_lo;
+ BijectiveUnhash2x64(hi, lo, &un_hi, &un_lo);
+ EXPECT_EQ(in_lo, un_lo);
+ EXPECT_EQ(in_hi, un_hi);
+ }
+
+ // Upper and Lower must reconstruct hash
+ EXPECT_EQ(here,
+ (Unsigned128{Upper64of128(here)} << 64) | Lower64of128(here));
+ EXPECT_EQ(here,
+ (Unsigned128{Upper64of128(here)} << 64) ^ Lower64of128(here));
+
+ // Seed changes hash value (with high probability)
+ for (uint64_t var_seed = 1; var_seed != 0; var_seed <<= 1) {
+ Unsigned128 seeded = Hash128(str.data(), size, var_seed);
+ EXPECT_NE(here, seeded);
+ // Must match seeded Hash2x64
+ {
+ uint64_t hi, lo;
+ Hash2x64(str.data(), size, var_seed, &hi, &lo);
+ EXPECT_EQ(Lower64of128(seeded), lo);
+ EXPECT_EQ(Upper64of128(seeded), hi);
+ }
+ if (size == 16) {
+ const uint64_t in_hi = DecodeFixed64(str.data() + 8);
+ const uint64_t in_lo = DecodeFixed64(str.data());
+ uint64_t hi, lo;
+ BijectiveHash2x64(in_hi, in_lo, var_seed, &hi, &lo);
+ EXPECT_EQ(Lower64of128(seeded), lo);
+ EXPECT_EQ(Upper64of128(seeded), hi);
+ uint64_t un_hi, un_lo;
+ BijectiveUnhash2x64(hi, lo, var_seed, &un_hi, &un_lo);
+ EXPECT_EQ(in_lo, un_lo);
+ EXPECT_EQ(in_hi, un_hi);
+ }
+ }
+
+ // Size changes hash value (with high probability)
+ size_t max_smaller_by = std::min(size_t{30}, size);
+ for (size_t smaller_by = 1; smaller_by <= max_smaller_by; ++smaller_by) {
+ EXPECT_NE(here, Hash128(str.data(), size - smaller_by, kSeed));
+ }
+ }
+ }
+}
+
+// Test that hash values are "non-trivial" for "trivial" inputs
+TEST(HashTest, Hash128Trivial) {
+ // Thorough test too slow for regression testing
+ constexpr bool thorough = false;
+
+ // For various seeds, make sure hash of empty string is not zero.
+ constexpr uint64_t max_seed = thorough ? 0x1000000 : 0x10000;
+ for (uint64_t seed = 0; seed < max_seed; ++seed) {
+ Unsigned128 here = Hash128("", 0, seed);
+ EXPECT_NE(Lower64of128(here), 0u);
+ EXPECT_NE(Upper64of128(here), 0u);
+ }
+
+ // For standard seed, make sure hash of small strings are not zero
+ constexpr uint32_t kSeed = 0; // Same as GetSliceHash128
+ char input[4];
+ constexpr int max_len = thorough ? 3 : 2;
+ for (int len = 1; len <= max_len; ++len) {
+ for (uint32_t i = 0; (i >> (len * 8)) == 0; ++i) {
+ EncodeFixed32(input, i);
+ Unsigned128 here = Hash128(input, len, kSeed);
+ EXPECT_NE(Lower64of128(here), 0u);
+ EXPECT_NE(Upper64of128(here), 0u);
+ }
+ }
+}
+
+std::string Hash128TestDescriptor(const char *repeat, size_t limit) {
+ const char *mod61_encode =
+ "abcdefghijklmnopqrstuvwxyz123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+
+ std::string input;
+ while (input.size() < limit) {
+ input.append(repeat);
+ }
+ std::string rv;
+ for (size_t i = 0; i < limit; ++i) {
+ auto h = GetSliceHash128(Slice(input.data(), i));
+ uint64_t h2 = Upper64of128(h) + Lower64of128(h);
+ rv.append(1, mod61_encode[static_cast<size_t>(h2 % 61)]);
+ }
+ return rv;
+}
+
+// XXH3 changes its algorithm for various sizes up through 250 bytes, so
+// we need to check the stability of larger sizes also.
+TEST(HashTest, Hash128ValueSchema) {
+ // Each of these derives a "descriptor" from the hash values for all
+ // lengths up to 430.
+ // Note that "b" is common for the zero-length string.
+ EXPECT_EQ(
+ Hash128TestDescriptor("foo", 430),
+ "bUMA3As8n9I4vNGhThXlEevxZlyMcbb6TYAlIKJ2f5ponsv99q962rYclQ7u3gfnRdCDQ5JI"
+ "2LrGUaCycbXrvLFe4SjgRb9RQwCfrnmNQ7VSEwSKMnkGCK3bDbXSrnIh5qLXdtvIZklbJpGH"
+ "Dqr93BlqF9ubTnOSYkSdx89XvQqflMIW8bjfQp9BPjQejWOeEQspnN1D3sfgVdFhpaQdHYA5"
+ "pI2XcPlCMFPxvrFuRr7joaDvjNe9IUZaunLPMewuXmC3EL95h52Ju3D7y9RNKhgYxMTrA84B"
+ "yJrMvyjdm3vlBxet4EN7v2GEyjbGuaZW9UL6lrX6PghJDg7ACfLGdxNbH3qXM4zaiG2RKnL5"
+ "S3WXKR78RBB5fRFQ8KDIEQjHFvSNsc3GrAEi6W8P2lv8JMTzjBODO2uN4wadVQFT9wpGfV");
+ // Note that "35D2v" is common for "Rocks"
+ EXPECT_EQ(
+ Hash128TestDescriptor("Rocks", 430),
+ "b35D2vzvklFVDqJmyLRXyApwGGO3EAT3swhe8XJAN3mY2UVPglzdmydxcba6JI2tSvwO6zSu"
+ "ANpjSM7tc9G5iMhsa7R8GfyCXRO1TnLg7HvdWNdgGGBirxZR68BgT7TQsYJt6zyEyISeXI1n"
+ "MXA48Xo7dWfJeYN6Z4KWlqZY7TgFXGbks9AX4ehZNSGtIhdO5i58qlgVX1bEejeOVaCcjC79"
+ "67DrMfOKds7rUQzjBa77sMPcoPW1vu6ljGJPZH3XkRyDMZ1twxXKkNxN3tE8nR7JHwyqBAxE"
+ "fTcjbOWrLZ1irWxRSombD8sGDEmclgF11IxqEhe3Rt7gyofO3nExGckKkS9KfRqsCHbiUyva"
+ "JGkJwUHRXaZnh58b4i1Ei9aQKZjXlvIVDixoZrjcNaH5XJIJlRZce9Z9t82wYapTpckYSg");
+ EXPECT_EQ(
+ Hash128TestDescriptor("RocksDB", 430),
+ "b35D2vFUst3XDZCRlSrhmYYakmqImV97LbBsV6EZlOEQpUPH1d1sD3xMKAPlA5UErHehg5O7"
+ "n966fZqhAf3hRc24kGCLfNAWjyUa7vSNOx3IcPoTyVRFZeFlcCtfl7t1QJumHOCpS33EBmBF"
+ "hvK13QjBbDWYWeHQhJhgV9Mqbx17TIcvUkEnYZxb8IzWNmjVsJG44Z7v52DjGj1ZzS62S2Vv"
+ "qWcDO7apvH5VHg68E9Wl6nXP21vlmUqEH9GeWRehfWVvY7mUpsAg5drHHQyDSdiMceiUuUxJ"
+ "XJqHFcDdzbbPk7xDvbLgWCKvH8k3MpQNWOmbSSRDdAP6nGlDjoTToYkcqVREHJzztSWAAq5h"
+ "GHSUNJ6OxsMHhf8EhXfHtKyUzRmPtjYyeckQcGmrQfFFLidc6cjMDKCdBG6c6HVBrS7H2R");
+}
+
+TEST(FastRange32Test, Values) {
+ using ROCKSDB_NAMESPACE::FastRange32;
+ // Zero range
+ EXPECT_EQ(FastRange32(0, 0), 0U);
+ EXPECT_EQ(FastRange32(123, 0), 0U);
+ EXPECT_EQ(FastRange32(0xffffffff, 0), 0U);
+
+ // One range
+ EXPECT_EQ(FastRange32(0, 1), 0U);
+ EXPECT_EQ(FastRange32(123, 1), 0U);
+ EXPECT_EQ(FastRange32(0xffffffff, 1), 0U);
+
+ // Two range
+ EXPECT_EQ(FastRange32(0, 2), 0U);
+ EXPECT_EQ(FastRange32(123, 2), 0U);
+ EXPECT_EQ(FastRange32(0x7fffffff, 2), 0U);
+ EXPECT_EQ(FastRange32(0x80000000, 2), 1U);
+ EXPECT_EQ(FastRange32(0xffffffff, 2), 1U);
+
+ // Seven range
+ EXPECT_EQ(FastRange32(0, 7), 0U);
+ EXPECT_EQ(FastRange32(123, 7), 0U);
+ EXPECT_EQ(FastRange32(613566756, 7), 0U);
+ EXPECT_EQ(FastRange32(613566757, 7), 1U);
+ EXPECT_EQ(FastRange32(1227133513, 7), 1U);
+ EXPECT_EQ(FastRange32(1227133514, 7), 2U);
+ // etc.
+ EXPECT_EQ(FastRange32(0xffffffff, 7), 6U);
+
+ // Big
+ EXPECT_EQ(FastRange32(1, 0x80000000), 0U);
+ EXPECT_EQ(FastRange32(2, 0x80000000), 1U);
+ EXPECT_EQ(FastRange32(4, 0x7fffffff), 1U);
+ EXPECT_EQ(FastRange32(4, 0x80000000), 2U);
+ EXPECT_EQ(FastRange32(0xffffffff, 0x7fffffff), 0x7ffffffeU);
+ EXPECT_EQ(FastRange32(0xffffffff, 0x80000000), 0x7fffffffU);
+}
+
+TEST(FastRange64Test, Values) {
+ using ROCKSDB_NAMESPACE::FastRange64;
+ // Zero range
+ EXPECT_EQ(FastRange64(0, 0), 0U);
+ EXPECT_EQ(FastRange64(123, 0), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFF, 0), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 0), 0U);
+
+ // One range
+ EXPECT_EQ(FastRange64(0, 1), 0U);
+ EXPECT_EQ(FastRange64(123, 1), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFF, 1), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 1), 0U);
+
+ // Two range
+ EXPECT_EQ(FastRange64(0, 2), 0U);
+ EXPECT_EQ(FastRange64(123, 2), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFF, 2), 0U);
+ EXPECT_EQ(FastRange64(0x7fffFFFFffffFFFF, 2), 0U);
+ EXPECT_EQ(FastRange64(0x8000000000000000, 2), 1U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 2), 1U);
+
+ // Seven range
+ EXPECT_EQ(FastRange64(0, 7), 0U);
+ EXPECT_EQ(FastRange64(123, 7), 0U);
+ EXPECT_EQ(FastRange64(0xffffFFFF, 7), 0U);
+ EXPECT_EQ(FastRange64(2635249153387078802, 7), 0U);
+ EXPECT_EQ(FastRange64(2635249153387078803, 7), 1U);
+ EXPECT_EQ(FastRange64(5270498306774157604, 7), 1U);
+ EXPECT_EQ(FastRange64(5270498306774157605, 7), 2U);
+ EXPECT_EQ(FastRange64(0x7fffFFFFffffFFFF, 7), 3U);
+ EXPECT_EQ(FastRange64(0x8000000000000000, 7), 3U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 7), 6U);
+
+ // Big but 32-bit range
+ EXPECT_EQ(FastRange64(0x100000000, 0x80000000), 0U);
+ EXPECT_EQ(FastRange64(0x200000000, 0x80000000), 1U);
+ EXPECT_EQ(FastRange64(0x400000000, 0x7fffFFFF), 1U);
+ EXPECT_EQ(FastRange64(0x400000000, 0x80000000), 2U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 0x7fffFFFF), 0x7fffFFFEU);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 0x80000000), 0x7fffFFFFU);
+
+ // Big, > 32-bit range
+#if SIZE_MAX == UINT64_MAX
+ EXPECT_EQ(FastRange64(0x7fffFFFFffffFFFF, 0x4200000002), 0x2100000000U);
+ EXPECT_EQ(FastRange64(0x8000000000000000, 0x4200000002), 0x2100000001U);
+
+ EXPECT_EQ(FastRange64(0x0000000000000000, 420000000002), 0U);
+ EXPECT_EQ(FastRange64(0x7fffFFFFffffFFFF, 420000000002), 210000000000U);
+ EXPECT_EQ(FastRange64(0x8000000000000000, 420000000002), 210000000001U);
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 420000000002), 420000000001U);
+
+ EXPECT_EQ(FastRange64(0xffffFFFFffffFFFF, 0xffffFFFFffffFFFF),
+ 0xffffFFFFffffFFFEU);
+#endif
+}
+
+TEST(FastRangeGenericTest, Values) {
+ using ROCKSDB_NAMESPACE::FastRangeGeneric;
+ // Generic (including big and small)
+ // Note that FastRangeGeneric is also tested indirectly above via
+ // FastRange32 and FastRange64.
+ EXPECT_EQ(
+ FastRangeGeneric(uint64_t{0x8000000000000000}, uint64_t{420000000002}),
+ uint64_t{210000000001});
+ EXPECT_EQ(FastRangeGeneric(uint64_t{0x8000000000000000}, uint16_t{12468}),
+ uint16_t{6234});
+ EXPECT_EQ(FastRangeGeneric(uint32_t{0x80000000}, uint16_t{12468}),
+ uint16_t{6234});
+ // Not recommended for typical use because for example this could fail on
+ // some platforms and pass on others:
+ // EXPECT_EQ(FastRangeGeneric(static_cast<unsigned long>(0x80000000),
+ // uint16_t{12468}),
+ // uint16_t{6234});
+}
+
+// for inspection of disassembly
+uint32_t FastRange32(uint32_t hash, uint32_t range) {
+ return ROCKSDB_NAMESPACE::FastRange32(hash, range);
+}
+
+// for inspection of disassembly
+size_t FastRange64(uint64_t hash, size_t range) {
+ return ROCKSDB_NAMESPACE::FastRange64(hash, range);
+}
+
+// Tests for math.h / math128.h (not worth a separate test binary)
+using ROCKSDB_NAMESPACE::BitParity;
+using ROCKSDB_NAMESPACE::BitsSetToOne;
+using ROCKSDB_NAMESPACE::ConstexprFloorLog2;
+using ROCKSDB_NAMESPACE::CountTrailingZeroBits;
+using ROCKSDB_NAMESPACE::DecodeFixed128;
+using ROCKSDB_NAMESPACE::DecodeFixedGeneric;
+using ROCKSDB_NAMESPACE::DownwardInvolution;
+using ROCKSDB_NAMESPACE::EncodeFixed128;
+using ROCKSDB_NAMESPACE::EncodeFixedGeneric;
+using ROCKSDB_NAMESPACE::FloorLog2;
+using ROCKSDB_NAMESPACE::Lower64of128;
+using ROCKSDB_NAMESPACE::Multiply64to128;
+using ROCKSDB_NAMESPACE::Unsigned128;
+using ROCKSDB_NAMESPACE::Upper64of128;
+
+int blah(int x) { return DownwardInvolution(x); }
+
+template <typename T>
+static void test_BitOps() {
+ // This complex code is to generalize to 128-bit values. Otherwise
+ // we could just use = static_cast<T>(0x5555555555555555ULL);
+ T everyOtherBit = 0;
+ for (unsigned i = 0; i < sizeof(T); ++i) {
+ everyOtherBit = (everyOtherBit << 8) | T{0x55};
+ }
+
+ // This one built using bit operations, as our 128-bit layer
+ // might not implement arithmetic such as subtraction.
+ T vm1 = 0; // "v minus one"
+
+ for (int i = 0; i < int{8 * sizeof(T)}; ++i) {
+ T v = T{1} << i;
+ // If we could directly use arithmetic:
+ // T vm1 = static_cast<T>(v - 1);
+
+ // FloorLog2
+ if (v > 0) {
+ EXPECT_EQ(FloorLog2(v), i);
+ EXPECT_EQ(ConstexprFloorLog2(v), i);
+ }
+ if (vm1 > 0) {
+ EXPECT_EQ(FloorLog2(vm1), i - 1);
+ EXPECT_EQ(ConstexprFloorLog2(vm1), i - 1);
+ EXPECT_EQ(FloorLog2(everyOtherBit & vm1), (i - 1) & ~1);
+ EXPECT_EQ(ConstexprFloorLog2(everyOtherBit & vm1), (i - 1) & ~1);
+ }
+
+ // CountTrailingZeroBits
+ if (v != 0) {
+ EXPECT_EQ(CountTrailingZeroBits(v), i);
+ }
+ if (vm1 != 0) {
+ EXPECT_EQ(CountTrailingZeroBits(vm1), 0);
+ }
+ if (i < int{8 * sizeof(T)} - 1) {
+ EXPECT_EQ(CountTrailingZeroBits(~vm1 & everyOtherBit), (i + 1) & ~1);
+ }
+
+ // BitsSetToOne
+ EXPECT_EQ(BitsSetToOne(v), 1);
+ EXPECT_EQ(BitsSetToOne(vm1), i);
+ EXPECT_EQ(BitsSetToOne(vm1 & everyOtherBit), (i + 1) / 2);
+
+ // BitParity
+ EXPECT_EQ(BitParity(v), 1);
+ EXPECT_EQ(BitParity(vm1), i & 1);
+ EXPECT_EQ(BitParity(vm1 & everyOtherBit), ((i + 1) / 2) & 1);
+
+ // EndianSwapValue
+ T ev = T{1} << (((sizeof(T) - 1 - (i / 8)) * 8) + i % 8);
+ EXPECT_EQ(EndianSwapValue(v), ev);
+
+ // ReverseBits
+ EXPECT_EQ(ReverseBits(v), static_cast<T>(T{1} << (8 * sizeof(T) - 1 - i)));
+#ifdef HAVE_UINT128_EXTENSION // Uses multiplication
+ if (std::is_unsigned<T>::value) { // Technical UB on signed type
+ T rv = T{1} << (8 * sizeof(T) - 1 - i);
+ EXPECT_EQ(ReverseBits(vm1), static_cast<T>(rv * ~T{1}));
+ }
+#endif
+
+ // DownwardInvolution
+ {
+ T misc = static_cast<T>(/*random*/ 0xc682cd153d0e3279U +
+ i * /*random*/ 0x9b3972f3bea0baa3U);
+ if constexpr (sizeof(T) > 8) {
+ misc = (misc << 64) | (/*random*/ 0x52af031a38ced62dU +
+ i * /*random*/ 0x936f803d9752ddc3U);
+ }
+ T misc_masked = misc & vm1;
+ EXPECT_LE(misc_masked, vm1);
+ T di_misc_masked = DownwardInvolution(misc_masked);
+ EXPECT_LE(di_misc_masked, vm1);
+ if (misc_masked > 0) {
+ // Highest-order 1 in same position
+ EXPECT_EQ(FloorLog2(misc_masked), FloorLog2(di_misc_masked));
+ }
+ // Validate involution property on short value
+ EXPECT_EQ(DownwardInvolution(di_misc_masked), misc_masked);
+
+ // Validate involution property on large value
+ T di_misc = DownwardInvolution(misc);
+ EXPECT_EQ(DownwardInvolution(di_misc), misc);
+ // Highest-order 1 in same position
+ if (misc > 0) {
+ EXPECT_EQ(FloorLog2(misc), FloorLog2(di_misc));
+ }
+
+ // Validate distributes over xor.
+ // static_casts to avoid numerical promotion effects.
+ EXPECT_EQ(DownwardInvolution(static_cast<T>(misc_masked ^ vm1)),
+ static_cast<T>(di_misc_masked ^ DownwardInvolution(vm1)));
+ T misc2 = static_cast<T>(misc >> 1);
+ EXPECT_EQ(DownwardInvolution(static_cast<T>(misc ^ misc2)),
+ static_cast<T>(di_misc ^ DownwardInvolution(misc2)));
+
+ // Choose some small number of bits to pull off to test combined
+ // uniqueness guarantee
+ int in_bits = i % 7;
+ unsigned in_mask = (unsigned{1} << in_bits) - 1U;
+ // IMPLICIT: int out_bits = 8 - in_bits;
+ std::vector<bool> seen(256, false);
+ for (int j = 0; j < 255; ++j) {
+ T t_in = misc ^ static_cast<T>(j);
+ unsigned in = static_cast<unsigned>(t_in);
+ unsigned out = static_cast<unsigned>(DownwardInvolution(t_in));
+ unsigned val = ((out << in_bits) | (in & in_mask)) & 255U;
+ EXPECT_FALSE(seen[val]);
+ seen[val] = true;
+ }
+
+ if (i + 8 < int{8 * sizeof(T)}) {
+ // Also test manipulating bits in the middle of input is
+ // bijective in bottom of output
+ seen = std::vector<bool>(256, false);
+ for (int j = 0; j < 255; ++j) {
+ T in = misc ^ (static_cast<T>(j) << i);
+ unsigned val = static_cast<unsigned>(DownwardInvolution(in)) & 255U;
+ EXPECT_FALSE(seen[val]);
+ seen[val] = true;
+ }
+ }
+ }
+
+ vm1 = (vm1 << 1) | 1;
+ }
+
+ EXPECT_EQ(ConstexprFloorLog2(T{1}), 0);
+ EXPECT_EQ(ConstexprFloorLog2(T{2}), 1);
+ EXPECT_EQ(ConstexprFloorLog2(T{3}), 1);
+ EXPECT_EQ(ConstexprFloorLog2(T{42}), 5);
+}
+
+TEST(MathTest, BitOps) {
+ test_BitOps<uint32_t>();
+ test_BitOps<uint64_t>();
+ test_BitOps<uint16_t>();
+ test_BitOps<uint8_t>();
+ test_BitOps<unsigned char>();
+ test_BitOps<unsigned short>();
+ test_BitOps<unsigned int>();
+ test_BitOps<unsigned long>();
+ test_BitOps<unsigned long long>();
+ test_BitOps<char>();
+ test_BitOps<size_t>();
+ test_BitOps<int32_t>();
+ test_BitOps<int64_t>();
+ test_BitOps<int16_t>();
+ test_BitOps<int8_t>();
+ test_BitOps<signed char>();
+ test_BitOps<short>();
+ test_BitOps<int>();
+ test_BitOps<long>();
+ test_BitOps<long long>();
+ test_BitOps<ptrdiff_t>();
+}
+
+TEST(MathTest, BitOps128) { test_BitOps<Unsigned128>(); }
+
+TEST(MathTest, Math128) {
+ const Unsigned128 sixteenHexOnes = 0x1111111111111111U;
+ const Unsigned128 thirtyHexOnes = (sixteenHexOnes << 56) | sixteenHexOnes;
+ const Unsigned128 sixteenHexTwos = 0x2222222222222222U;
+ const Unsigned128 thirtyHexTwos = (sixteenHexTwos << 56) | sixteenHexTwos;
+
+ // v will slide from all hex ones to all hex twos
+ Unsigned128 v = thirtyHexOnes;
+ for (int i = 0; i <= 30; ++i) {
+ // Test bitwise operations
+ EXPECT_EQ(BitsSetToOne(v), 30);
+ EXPECT_EQ(BitsSetToOne(~v), 128 - 30);
+ EXPECT_EQ(BitsSetToOne(v & thirtyHexOnes), 30 - i);
+ EXPECT_EQ(BitsSetToOne(v | thirtyHexOnes), 30 + i);
+ EXPECT_EQ(BitsSetToOne(v ^ thirtyHexOnes), 2 * i);
+ EXPECT_EQ(BitsSetToOne(v & thirtyHexTwos), i);
+ EXPECT_EQ(BitsSetToOne(v | thirtyHexTwos), 60 - i);
+ EXPECT_EQ(BitsSetToOne(v ^ thirtyHexTwos), 60 - 2 * i);
+
+ // Test comparisons
+ EXPECT_EQ(v == thirtyHexOnes, i == 0);
+ EXPECT_EQ(v == thirtyHexTwos, i == 30);
+ EXPECT_EQ(v > thirtyHexOnes, i > 0);
+ EXPECT_EQ(v > thirtyHexTwos, false);
+ EXPECT_EQ(v >= thirtyHexOnes, true);
+ EXPECT_EQ(v >= thirtyHexTwos, i == 30);
+ EXPECT_EQ(v < thirtyHexOnes, false);
+ EXPECT_EQ(v < thirtyHexTwos, i < 30);
+ EXPECT_EQ(v <= thirtyHexOnes, i == 0);
+ EXPECT_EQ(v <= thirtyHexTwos, true);
+
+ // Update v, clearing upper-most byte
+ v = ((v << 12) >> 8) | 0x2;
+ }
+
+ for (int i = 0; i < 128; ++i) {
+ // Test shifts
+ Unsigned128 sl = thirtyHexOnes << i;
+ Unsigned128 sr = thirtyHexOnes >> i;
+ EXPECT_EQ(BitsSetToOne(sl), std::min(30, 32 - i / 4));
+ EXPECT_EQ(BitsSetToOne(sr), std::max(0, 30 - (i + 3) / 4));
+ EXPECT_EQ(BitsSetToOne(sl & sr), i % 2 ? 0 : std::max(0, 30 - i / 2));
+ }
+
+ // Test 64x64->128 multiply
+ Unsigned128 product =
+ Multiply64to128(0x1111111111111111U, 0x2222222222222222U);
+ EXPECT_EQ(Lower64of128(product), 2295594818061633090U);
+ EXPECT_EQ(Upper64of128(product), 163971058432973792U);
+}
+
+TEST(MathTest, Coding128) {
+ const char *in = "_1234567890123456";
+ // Note: in + 1 is likely unaligned
+ Unsigned128 decoded = DecodeFixed128(in + 1);
+ EXPECT_EQ(Lower64of128(decoded), 0x3837363534333231U);
+ EXPECT_EQ(Upper64of128(decoded), 0x3635343332313039U);
+ char out[18];
+ out[0] = '_';
+ EncodeFixed128(out + 1, decoded);
+ out[17] = '\0';
+ EXPECT_EQ(std::string(in), std::string(out));
+}
+
+TEST(MathTest, CodingGeneric) {
+ const char *in = "_1234567890123456";
+ // Decode
+ // Note: in + 1 is likely unaligned
+ Unsigned128 decoded128 = DecodeFixedGeneric<Unsigned128>(in + 1);
+ EXPECT_EQ(Lower64of128(decoded128), 0x3837363534333231U);
+ EXPECT_EQ(Upper64of128(decoded128), 0x3635343332313039U);
+
+ uint64_t decoded64 = DecodeFixedGeneric<uint64_t>(in + 1);
+ EXPECT_EQ(decoded64, 0x3837363534333231U);
+
+ uint32_t decoded32 = DecodeFixedGeneric<uint32_t>(in + 1);
+ EXPECT_EQ(decoded32, 0x34333231U);
+
+ uint16_t decoded16 = DecodeFixedGeneric<uint16_t>(in + 1);
+ EXPECT_EQ(decoded16, 0x3231U);
+
+ // Encode
+ char out[18];
+ out[0] = '_';
+ memset(out + 1, '\0', 17);
+ EncodeFixedGeneric(out + 1, decoded128);
+ EXPECT_EQ(std::string(in), std::string(out));
+
+ memset(out + 1, '\0', 9);
+ EncodeFixedGeneric(out + 1, decoded64);
+ EXPECT_EQ(std::string("_12345678"), std::string(out));
+
+ memset(out + 1, '\0', 5);
+ EncodeFixedGeneric(out + 1, decoded32);
+ EXPECT_EQ(std::string("_1234"), std::string(out));
+
+ memset(out + 1, '\0', 3);
+ EncodeFixedGeneric(out + 1, decoded16);
+ EXPECT_EQ(std::string("_12"), std::string(out));
+}
+
+int main(int argc, char **argv) {
+ fprintf(stderr, "NPHash64 id: %x\n",
+ static_cast<int>(ROCKSDB_NAMESPACE::GetSliceNPHash64("RocksDB")));
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/heap.h b/src/rocksdb/util/heap.h
new file mode 100644
index 000000000..f221fc732
--- /dev/null
+++ b/src/rocksdb/util/heap.h
@@ -0,0 +1,174 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+
+#include "port/port.h"
+#include "util/autovector.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Binary heap implementation optimized for use in multi-way merge sort.
+// Comparison to std::priority_queue:
+// - In libstdc++, std::priority_queue::pop() usually performs just over logN
+// comparisons but never fewer.
+// - std::priority_queue does not have a replace-top operation, requiring a
+// pop+push. If the replacement element is the new top, this requires
+// around 2logN comparisons.
+// - This heap's pop() uses a "schoolbook" downheap which requires up to ~2logN
+// comparisons.
+// - This heap provides a replace_top() operation which requires [1, 2logN]
+// comparisons. When the replacement element is also the new top, this
+// takes just 1 or 2 comparisons.
+//
+// The last property can yield an order-of-magnitude performance improvement
+// when merge-sorting real-world non-random data. If the merge operation is
+// likely to take chunks of elements from the same input stream, only 1
+// comparison per element is needed. In RocksDB-land, this happens when
+// compacting a database where keys are not randomly distributed across L0
+// files but nearby keys are likely to be in the same L0 file.
+//
+// The container uses the same counterintuitive ordering as
+// std::priority_queue: the comparison operator is expected to provide the
+// less-than relation, but top() will return the maximum.
+
+template <typename T, typename Compare = std::less<T>>
+class BinaryHeap {
+ public:
+ BinaryHeap() {}
+ explicit BinaryHeap(Compare cmp) : cmp_(std::move(cmp)) {}
+
+ void push(const T& value) {
+ data_.push_back(value);
+ upheap(data_.size() - 1);
+ }
+
+ void push(T&& value) {
+ data_.push_back(std::move(value));
+ upheap(data_.size() - 1);
+ }
+
+ const T& top() const {
+ assert(!empty());
+ return data_.front();
+ }
+
+ void replace_top(const T& value) {
+ assert(!empty());
+ data_.front() = value;
+ downheap(get_root());
+ }
+
+ void replace_top(T&& value) {
+ assert(!empty());
+ data_.front() = std::move(value);
+ downheap(get_root());
+ }
+
+ void pop() {
+ assert(!empty());
+ if (data_.size() > 1) {
+ // Avoid self-move-assign, because it could cause problems with
+ // classes which are not prepared for this and it trips up the
+ // STL debugger when activated.
+ data_.front() = std::move(data_.back());
+ }
+ data_.pop_back();
+ if (!empty()) {
+ downheap(get_root());
+ } else {
+ reset_root_cmp_cache();
+ }
+ }
+
+ void swap(BinaryHeap& other) {
+ std::swap(cmp_, other.cmp_);
+ data_.swap(other.data_);
+ std::swap(root_cmp_cache_, other.root_cmp_cache_);
+ }
+
+ void clear() {
+ data_.clear();
+ reset_root_cmp_cache();
+ }
+
+ bool empty() const { return data_.empty(); }
+
+ size_t size() const { return data_.size(); }
+
+ void reset_root_cmp_cache() {
+ root_cmp_cache_ = std::numeric_limits<size_t>::max();
+ }
+
+ private:
+ static inline size_t get_root() { return 0; }
+ static inline size_t get_parent(size_t index) { return (index - 1) / 2; }
+ static inline size_t get_left(size_t index) { return 2 * index + 1; }
+ static inline size_t get_right(size_t index) { return 2 * index + 2; }
+
+ void upheap(size_t index) {
+ T v = std::move(data_[index]);
+ while (index > get_root()) {
+ const size_t parent = get_parent(index);
+ if (!cmp_(data_[parent], v)) {
+ break;
+ }
+ data_[index] = std::move(data_[parent]);
+ index = parent;
+ }
+ data_[index] = std::move(v);
+ reset_root_cmp_cache();
+ }
+
+ void downheap(size_t index) {
+ T v = std::move(data_[index]);
+
+ size_t picked_child = std::numeric_limits<size_t>::max();
+ while (1) {
+ const size_t left_child = get_left(index);
+ if (get_left(index) >= data_.size()) {
+ break;
+ }
+ const size_t right_child = left_child + 1;
+ assert(right_child == get_right(index));
+ picked_child = left_child;
+ if (index == 0 && root_cmp_cache_ < data_.size()) {
+ picked_child = root_cmp_cache_;
+ } else if (right_child < data_.size() &&
+ cmp_(data_[left_child], data_[right_child])) {
+ picked_child = right_child;
+ }
+ if (!cmp_(v, data_[picked_child])) {
+ break;
+ }
+ data_[index] = std::move(data_[picked_child]);
+ index = picked_child;
+ }
+
+ if (index == 0) {
+ // We did not change anything in the tree except for the value
+ // of the root node, left and right child did not change, we can
+ // cache that `picked_child` is the smallest child
+ // so next time we compare againist it directly
+ root_cmp_cache_ = picked_child;
+ } else {
+ // the tree changed, reset cache
+ reset_root_cmp_cache();
+ }
+
+ data_[index] = std::move(v);
+ }
+
+ Compare cmp_;
+ autovector<T> data_;
+ // Used to reduce number of cmp_ calls in downheap()
+ size_t root_cmp_cache_ = std::numeric_limits<size_t>::max();
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/heap_test.cc b/src/rocksdb/util/heap_test.cc
new file mode 100644
index 000000000..bbb93324f
--- /dev/null
+++ b/src/rocksdb/util/heap_test.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/heap.h"
+
+#include <gtest/gtest.h>
+
+#include <climits>
+#include <queue>
+#include <random>
+#include <utility>
+
+#include "port/stack_trace.h"
+
+#ifndef GFLAGS
+const int64_t FLAGS_iters = 100000;
+#else
+#include "util/gflags_compat.h"
+DEFINE_int64(iters, 100000, "number of pseudo-random operations in each test");
+#endif // GFLAGS
+
+/*
+ * Compares the custom heap implementation in util/heap.h against
+ * std::priority_queue on a pseudo-random sequence of operations.
+ */
+
+namespace ROCKSDB_NAMESPACE {
+
+using HeapTestValue = uint64_t;
+using Params = std::tuple<size_t, HeapTestValue, int64_t>;
+
+class HeapTest : public ::testing::TestWithParam<Params> {};
+
+TEST_P(HeapTest, Test) {
+ // This test performs the same pseudorandom sequence of operations on a
+ // BinaryHeap and an std::priority_queue, comparing output. The three
+ // possible operations are insert, replace top and pop.
+ //
+ // Insert is chosen slightly more often than the others so that the size of
+ // the heap slowly grows. Once the size heats the MAX_HEAP_SIZE limit, we
+ // disallow inserting until the heap becomes empty, testing the "draining"
+ // scenario.
+
+ const auto MAX_HEAP_SIZE = std::get<0>(GetParam());
+ const auto MAX_VALUE = std::get<1>(GetParam());
+ const auto RNG_SEED = std::get<2>(GetParam());
+
+ BinaryHeap<HeapTestValue> heap;
+ std::priority_queue<HeapTestValue> ref;
+
+ std::mt19937 rng(static_cast<unsigned int>(RNG_SEED));
+ std::uniform_int_distribution<HeapTestValue> value_dist(0, MAX_VALUE);
+ int ndrains = 0;
+ bool draining = false; // hit max size, draining until we empty the heap
+ size_t size = 0;
+ for (int64_t i = 0; i < FLAGS_iters; ++i) {
+ if (size == 0) {
+ draining = false;
+ }
+
+ if (!draining && (size == 0 || std::bernoulli_distribution(0.4)(rng))) {
+ // insert
+ HeapTestValue val = value_dist(rng);
+ heap.push(val);
+ ref.push(val);
+ ++size;
+ if (size == MAX_HEAP_SIZE) {
+ draining = true;
+ ++ndrains;
+ }
+ } else if (std::bernoulli_distribution(0.5)(rng)) {
+ // replace top
+ HeapTestValue val = value_dist(rng);
+ heap.replace_top(val);
+ ref.pop();
+ ref.push(val);
+ } else {
+ // pop
+ assert(size > 0);
+ heap.pop();
+ ref.pop();
+ --size;
+ }
+
+ // After every operation, check that the public methods give the same
+ // results
+ assert((size == 0) == ref.empty());
+ ASSERT_EQ(size == 0, heap.empty());
+ if (size > 0) {
+ ASSERT_EQ(ref.top(), heap.top());
+ }
+ }
+
+ // Probabilities should be set up to occasionally hit the max heap size and
+ // drain it
+ assert(ndrains > 0);
+
+ heap.clear();
+ ASSERT_TRUE(heap.empty());
+}
+
+// Basic test, MAX_VALUE = 3*MAX_HEAP_SIZE (occasional duplicates)
+INSTANTIATE_TEST_CASE_P(Basic, HeapTest,
+ ::testing::Values(Params(1000, 3000,
+ 0x1b575cf05b708945)));
+// Mid-size heap with small values (many duplicates)
+INSTANTIATE_TEST_CASE_P(SmallValues, HeapTest,
+ ::testing::Values(Params(100, 10, 0x5ae213f7bd5dccd0)));
+// Small heap, large value range (no duplicates)
+INSTANTIATE_TEST_CASE_P(SmallHeap, HeapTest,
+ ::testing::Values(Params(10, ULLONG_MAX,
+ 0x3e1fa8f4d01707cf)));
+// Two-element heap
+INSTANTIATE_TEST_CASE_P(TwoElementHeap, HeapTest,
+ ::testing::Values(Params(2, 5, 0x4b5e13ea988c6abc)));
+// One-element heap
+INSTANTIATE_TEST_CASE_P(OneElementHeap, HeapTest,
+ ::testing::Values(Params(1, 3, 0x176a1019ab0b612e)));
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+#ifdef GFLAGS
+ GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
+#endif // GFLAGS
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/kv_map.h b/src/rocksdb/util/kv_map.h
new file mode 100644
index 000000000..62be6d18e
--- /dev/null
+++ b/src/rocksdb/util/kv_map.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <map>
+#include <string>
+
+#include "rocksdb/comparator.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace stl_wrappers {
+
+struct LessOfComparator {
+ explicit LessOfComparator(const Comparator* c = BytewiseComparator())
+ : cmp(c) {}
+
+ bool operator()(const std::string& a, const std::string& b) const {
+ return cmp->Compare(Slice(a), Slice(b)) < 0;
+ }
+ bool operator()(const Slice& a, const Slice& b) const {
+ return cmp->Compare(a, b) < 0;
+ }
+
+ const Comparator* cmp;
+};
+
+using KVMap = std::map<std::string, std::string, LessOfComparator>;
+} // namespace stl_wrappers
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/log_write_bench.cc b/src/rocksdb/util/log_write_bench.cc
new file mode 100644
index 000000000..c1637db15
--- /dev/null
+++ b/src/rocksdb/util/log_write_bench.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run rocksdb tools\n");
+ return 1;
+}
+#else
+
+#include "file/writable_file_writer.h"
+#include "monitoring/histogram.h"
+#include "rocksdb/env.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/gflags_compat.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+using GFLAGS_NAMESPACE::SetUsageMessage;
+
+// A simple benchmark to simulate transactional logs
+
+DEFINE_int32(num_records, 6000, "Number of records.");
+DEFINE_int32(record_size, 249, "Size of each record.");
+DEFINE_int32(record_interval, 10000, "Interval between records (microSec)");
+DEFINE_int32(bytes_per_sync, 0, "bytes_per_sync parameter in EnvOptions");
+DEFINE_bool(enable_sync, false, "sync after each write.");
+
+namespace ROCKSDB_NAMESPACE {
+void RunBenchmark() {
+ std::string file_name = test::PerThreadDBPath("log_write_benchmark.log");
+ DBOptions options;
+ Env* env = Env::Default();
+ const auto& clock = env->GetSystemClock();
+ EnvOptions env_options = env->OptimizeForLogWrite(EnvOptions(), options);
+ env_options.bytes_per_sync = FLAGS_bytes_per_sync;
+ std::unique_ptr<WritableFile> file;
+ env->NewWritableFile(file_name, &file, env_options);
+ std::unique_ptr<WritableFileWriter> writer;
+ writer.reset(new WritableFileWriter(std::move(file), file_name, env_options,
+ clock, nullptr /* stats */,
+ options.listeners));
+
+ std::string record;
+ record.assign(FLAGS_record_size, 'X');
+
+ HistogramImpl hist;
+
+ uint64_t start_time = clock->NowMicros();
+ for (int i = 0; i < FLAGS_num_records; i++) {
+ uint64_t start_nanos = clock->NowNanos();
+ writer->Append(record);
+ writer->Flush();
+ if (FLAGS_enable_sync) {
+ writer->Sync(false);
+ }
+ hist.Add(clock->NowNanos() - start_nanos);
+
+ if (i % 1000 == 1) {
+ fprintf(stderr, "Wrote %d records...\n", i);
+ }
+
+ int time_to_sleep =
+ (i + 1) * FLAGS_record_interval - (clock->NowMicros() - start_time);
+ if (time_to_sleep > 0) {
+ clock->SleepForMicroseconds(time_to_sleep);
+ }
+ }
+
+ fprintf(stderr, "Distribution of latency of append+flush: \n%s",
+ hist.ToString().c_str());
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
+ " [OPTIONS]...");
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ ROCKSDB_NAMESPACE::RunBenchmark();
+ return 0;
+}
+
+#endif // GFLAGS
diff --git a/src/rocksdb/util/math.h b/src/rocksdb/util/math.h
new file mode 100644
index 000000000..da31b43ec
--- /dev/null
+++ b/src/rocksdb/util/math.h
@@ -0,0 +1,294 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <assert.h>
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+
+#include <cstdint>
+#include <type_traits>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Fast implementation of floor(log2(v)). Undefined for 0 or negative
+// numbers (in case of signed type).
+template <typename T>
+inline int FloorLog2(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+ assert(v > 0);
+#ifdef _MSC_VER
+ static_assert(sizeof(T) <= sizeof(uint64_t), "type too big");
+ unsigned long idx = 0;
+ if (sizeof(T) <= sizeof(uint32_t)) {
+ _BitScanReverse(&idx, static_cast<uint32_t>(v));
+ } else {
+#if defined(_M_X64) || defined(_M_ARM64)
+ _BitScanReverse64(&idx, static_cast<uint64_t>(v));
+#else
+ const auto vh = static_cast<uint32_t>(static_cast<uint64_t>(v) >> 32);
+ if (vh != 0) {
+ _BitScanReverse(&idx, static_cast<uint32_t>(vh));
+ idx += 32;
+ } else {
+ _BitScanReverse(&idx, static_cast<uint32_t>(v));
+ }
+#endif
+ }
+ return idx;
+#else
+ static_assert(sizeof(T) <= sizeof(unsigned long long), "type too big");
+ if (sizeof(T) <= sizeof(unsigned int)) {
+ int lz = __builtin_clz(static_cast<unsigned int>(v));
+ return int{sizeof(unsigned int)} * 8 - 1 - lz;
+ } else if (sizeof(T) <= sizeof(unsigned long)) {
+ int lz = __builtin_clzl(static_cast<unsigned long>(v));
+ return int{sizeof(unsigned long)} * 8 - 1 - lz;
+ } else {
+ int lz = __builtin_clzll(static_cast<unsigned long long>(v));
+ return int{sizeof(unsigned long long)} * 8 - 1 - lz;
+ }
+#endif
+}
+
+// Constexpr version of FloorLog2
+template <typename T>
+constexpr int ConstexprFloorLog2(T v) {
+ int rv = 0;
+ while (v > T{1}) {
+ ++rv;
+ v >>= 1;
+ }
+ return rv;
+}
+
+// Number of low-order zero bits before the first 1 bit. Undefined for 0.
+template <typename T>
+inline int CountTrailingZeroBits(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+ assert(v != 0);
+#ifdef _MSC_VER
+ static_assert(sizeof(T) <= sizeof(uint64_t), "type too big");
+ unsigned long tz = 0;
+ if (sizeof(T) <= sizeof(uint32_t)) {
+ _BitScanForward(&tz, static_cast<uint32_t>(v));
+ } else {
+#if defined(_M_X64) || defined(_M_ARM64)
+ _BitScanForward64(&tz, static_cast<uint64_t>(v));
+#else
+ _BitScanForward(&tz, static_cast<uint32_t>(v));
+ if (tz == 0) {
+ _BitScanForward(&tz,
+ static_cast<uint32_t>(static_cast<uint64_t>(v) >> 32));
+ tz += 32;
+ }
+#endif
+ }
+ return static_cast<int>(tz);
+#else
+ static_assert(sizeof(T) <= sizeof(unsigned long long), "type too big");
+ if (sizeof(T) <= sizeof(unsigned int)) {
+ return __builtin_ctz(static_cast<unsigned int>(v));
+ } else if (sizeof(T) <= sizeof(unsigned long)) {
+ return __builtin_ctzl(static_cast<unsigned long>(v));
+ } else {
+ return __builtin_ctzll(static_cast<unsigned long long>(v));
+ }
+#endif
+}
+
+// Not all MSVC compile settings will use `BitsSetToOneFallback()`. We include
+// the following code at coarse granularity for simpler macros. It's important
+// to exclude at least so our non-MSVC unit test coverage tool doesn't see it.
+#ifdef _MSC_VER
+
+namespace detail {
+
+template <typename T>
+int BitsSetToOneFallback(T v) {
+ const int kBits = static_cast<int>(sizeof(T)) * 8;
+ static_assert((kBits & (kBits - 1)) == 0, "must be power of two bits");
+ // we static_cast these bit patterns in order to truncate them to the correct
+ // size. Warning C4309 dislikes this technique, so disable it here.
+#pragma warning(disable : 4309)
+ v = static_cast<T>(v - ((v >> 1) & static_cast<T>(0x5555555555555555ull)));
+ v = static_cast<T>((v & static_cast<T>(0x3333333333333333ull)) +
+ ((v >> 2) & static_cast<T>(0x3333333333333333ull)));
+ v = static_cast<T>((v + (v >> 4)) & static_cast<T>(0x0F0F0F0F0F0F0F0Full));
+#pragma warning(default : 4309)
+ for (int shift_bits = 8; shift_bits < kBits; shift_bits <<= 1) {
+ v += static_cast<T>(v >> shift_bits);
+ }
+ // we want the bottom "slot" that's big enough to represent a value up to
+ // (and including) kBits.
+ return static_cast<int>(v & static_cast<T>(kBits | (kBits - 1)));
+}
+
+} // namespace detail
+
+#endif // _MSC_VER
+
+// Number of bits set to 1. Also known as "population count".
+template <typename T>
+inline int BitsSetToOne(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+#ifdef _MSC_VER
+ static_assert(sizeof(T) <= sizeof(uint64_t), "type too big");
+ if (sizeof(T) < sizeof(uint32_t)) {
+ // This bit mask is to avoid a compiler warning on unused path
+ constexpr auto mm = 8 * sizeof(uint32_t) - 1;
+ // The bit mask is to neutralize sign extension on small signed types
+ constexpr uint32_t m = (uint32_t{1} << ((8 * sizeof(T)) & mm)) - 1;
+#if defined(HAVE_SSE42) && (defined(_M_X64) || defined(_M_IX86))
+ return static_cast<int>(__popcnt(static_cast<uint32_t>(v) & m));
+#else
+ return static_cast<int>(detail::BitsSetToOneFallback(v) & m);
+#endif
+ } else if (sizeof(T) == sizeof(uint32_t)) {
+#if defined(HAVE_SSE42) && (defined(_M_X64) || defined(_M_IX86))
+ return static_cast<int>(__popcnt(static_cast<uint32_t>(v)));
+#else
+ return detail::BitsSetToOneFallback(static_cast<uint32_t>(v));
+#endif
+ } else {
+#if defined(HAVE_SSE42) && defined(_M_X64)
+ return static_cast<int>(__popcnt64(static_cast<uint64_t>(v)));
+#elif defined(HAVE_SSE42) && defined(_M_IX86)
+ return static_cast<int>(
+ __popcnt(static_cast<uint32_t>(static_cast<uint64_t>(v) >> 32) +
+ __popcnt(static_cast<uint32_t>(v))));
+#else
+ return detail::BitsSetToOneFallback(static_cast<uint64_t>(v));
+#endif
+ }
+#else
+ static_assert(sizeof(T) <= sizeof(unsigned long long), "type too big");
+ if (sizeof(T) < sizeof(unsigned int)) {
+ // This bit mask is to avoid a compiler warning on unused path
+ constexpr auto mm = 8 * sizeof(unsigned int) - 1;
+ // This bit mask is to neutralize sign extension on small signed types
+ constexpr unsigned int m = (1U << ((8 * sizeof(T)) & mm)) - 1;
+ return __builtin_popcount(static_cast<unsigned int>(v) & m);
+ } else if (sizeof(T) == sizeof(unsigned int)) {
+ return __builtin_popcount(static_cast<unsigned int>(v));
+ } else if (sizeof(T) <= sizeof(unsigned long)) {
+ return __builtin_popcountl(static_cast<unsigned long>(v));
+ } else {
+ return __builtin_popcountll(static_cast<unsigned long long>(v));
+ }
+#endif
+}
+
+template <typename T>
+inline int BitParity(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+#ifdef _MSC_VER
+ // bit parity == oddness of popcount
+ return BitsSetToOne(v) & 1;
+#else
+ static_assert(sizeof(T) <= sizeof(unsigned long long), "type too big");
+ if (sizeof(T) <= sizeof(unsigned int)) {
+ // On any sane systen, potential sign extension here won't change parity
+ return __builtin_parity(static_cast<unsigned int>(v));
+ } else if (sizeof(T) <= sizeof(unsigned long)) {
+ return __builtin_parityl(static_cast<unsigned long>(v));
+ } else {
+ return __builtin_parityll(static_cast<unsigned long long>(v));
+ }
+#endif
+}
+
+// Swaps between big and little endian. Can be used in combination with the
+// little-endian encoding/decoding functions in coding_lean.h and coding.h to
+// encode/decode big endian.
+template <typename T>
+inline T EndianSwapValue(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+
+#ifdef _MSC_VER
+ if (sizeof(T) == 2) {
+ return static_cast<T>(_byteswap_ushort(static_cast<uint16_t>(v)));
+ } else if (sizeof(T) == 4) {
+ return static_cast<T>(_byteswap_ulong(static_cast<uint32_t>(v)));
+ } else if (sizeof(T) == 8) {
+ return static_cast<T>(_byteswap_uint64(static_cast<uint64_t>(v)));
+ }
+#else
+ if (sizeof(T) == 2) {
+ return static_cast<T>(__builtin_bswap16(static_cast<uint16_t>(v)));
+ } else if (sizeof(T) == 4) {
+ return static_cast<T>(__builtin_bswap32(static_cast<uint32_t>(v)));
+ } else if (sizeof(T) == 8) {
+ return static_cast<T>(__builtin_bswap64(static_cast<uint64_t>(v)));
+ }
+#endif
+ // Recognized by clang as bswap, but not by gcc :(
+ T ret_val = 0;
+ for (std::size_t i = 0; i < sizeof(T); ++i) {
+ ret_val |= ((v >> (8 * i)) & 0xff) << (8 * (sizeof(T) - 1 - i));
+ }
+ return ret_val;
+}
+
+// Reverses the order of bits in an integral value
+template <typename T>
+inline T ReverseBits(T v) {
+ T r = EndianSwapValue(v);
+ const T kHighestByte = T{1} << ((sizeof(T) - 1) * 8);
+ const T kEveryByte = kHighestByte | (kHighestByte / 255);
+
+ r = ((r & (kEveryByte * 0x0f)) << 4) | ((r >> 4) & (kEveryByte * 0x0f));
+ r = ((r & (kEveryByte * 0x33)) << 2) | ((r >> 2) & (kEveryByte * 0x33));
+ r = ((r & (kEveryByte * 0x55)) << 1) | ((r >> 1) & (kEveryByte * 0x55));
+
+ return r;
+}
+
+// Every output bit depends on many input bits in the same and higher
+// positions, but not lower positions. Specifically, this function
+// * Output highest bit set to 1 is same as input (same FloorLog2, or
+// equivalently, same number of leading zeros)
+// * Is its own inverse (an involution)
+// * Guarantees that b bottom bits of v and c bottom bits of
+// DownwardInvolution(v) uniquely identify b + c bottom bits of v
+// (which is all of v if v < 2**(b + c)).
+// ** A notable special case is that modifying c adjacent bits at
+// some chosen position in the input is bijective with the bottom c
+// output bits.
+// * Distributes over xor, as in DI(a ^ b) == DI(a) ^ DI(b)
+//
+// This transformation is equivalent to a matrix*vector multiplication in
+// GF(2) where the matrix is recursively defined by the pattern matrix
+// P = | 1 1 |
+// | 0 1 |
+// and replacing 1's with P and 0's with 2x2 zero matices to some depth,
+// e.g. depth of 6 for 64-bit T. An essential feature of this matrix
+// is that all square sub-matrices that include the top row are invertible.
+template <typename T>
+inline T DownwardInvolution(T v) {
+ static_assert(std::is_integral<T>::value, "non-integral type");
+ static_assert(sizeof(T) <= 8, "only supported up to 64 bits");
+
+ uint64_t r = static_cast<uint64_t>(v);
+ if constexpr (sizeof(T) > 4) {
+ r ^= r >> 32;
+ }
+ if constexpr (sizeof(T) > 2) {
+ r ^= (r & 0xffff0000ffff0000U) >> 16;
+ }
+ if constexpr (sizeof(T) > 1) {
+ r ^= (r & 0xff00ff00ff00ff00U) >> 8;
+ }
+ r ^= (r & 0xf0f0f0f0f0f0f0f0U) >> 4;
+ r ^= (r & 0xccccccccccccccccU) >> 2;
+ r ^= (r & 0xaaaaaaaaaaaaaaaaU) >> 1;
+ return static_cast<T>(r);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/math128.h b/src/rocksdb/util/math128.h
new file mode 100644
index 000000000..ae490051a
--- /dev/null
+++ b/src/rocksdb/util/math128.h
@@ -0,0 +1,316 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "util/coding_lean.h"
+#include "util/math.h"
+
+#ifdef TEST_UINT128_COMPAT
+#undef HAVE_UINT128_EXTENSION
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+// Unsigned128 is a 128 bit value supporting (at least) bitwise operators,
+// shifts, and comparisons. __uint128_t is not always available.
+
+#ifdef HAVE_UINT128_EXTENSION
+using Unsigned128 = __uint128_t;
+#else
+struct Unsigned128 {
+ uint64_t lo;
+ uint64_t hi;
+
+ inline Unsigned128() {
+ static_assert(sizeof(Unsigned128) == 2 * sizeof(uint64_t),
+ "unexpected overhead in representation");
+ lo = 0;
+ hi = 0;
+ }
+
+ inline Unsigned128(uint64_t lower) {
+ lo = lower;
+ hi = 0;
+ }
+
+ inline Unsigned128(uint64_t lower, uint64_t upper) {
+ lo = lower;
+ hi = upper;
+ }
+
+ explicit operator uint64_t() { return lo; }
+
+ explicit operator uint32_t() { return static_cast<uint32_t>(lo); }
+
+ explicit operator uint16_t() { return static_cast<uint16_t>(lo); }
+
+ explicit operator uint8_t() { return static_cast<uint8_t>(lo); }
+};
+
+inline Unsigned128 operator<<(const Unsigned128& lhs, unsigned shift) {
+ shift &= 127;
+ Unsigned128 rv;
+ if (shift >= 64) {
+ rv.lo = 0;
+ rv.hi = lhs.lo << (shift & 63);
+ } else {
+ uint64_t tmp = lhs.lo;
+ rv.lo = tmp << shift;
+ // Ensure shift==0 shifts away everything. (This avoids another
+ // conditional branch on shift == 0.)
+ tmp = tmp >> 1 >> (63 - shift);
+ rv.hi = tmp | (lhs.hi << shift);
+ }
+ return rv;
+}
+
+inline Unsigned128& operator<<=(Unsigned128& lhs, unsigned shift) {
+ lhs = lhs << shift;
+ return lhs;
+}
+
+inline Unsigned128 operator>>(const Unsigned128& lhs, unsigned shift) {
+ shift &= 127;
+ Unsigned128 rv;
+ if (shift >= 64) {
+ rv.hi = 0;
+ rv.lo = lhs.hi >> (shift & 63);
+ } else {
+ uint64_t tmp = lhs.hi;
+ rv.hi = tmp >> shift;
+ // Ensure shift==0 shifts away everything
+ tmp = tmp << 1 << (63 - shift);
+ rv.lo = tmp | (lhs.lo >> shift);
+ }
+ return rv;
+}
+
+inline Unsigned128& operator>>=(Unsigned128& lhs, unsigned shift) {
+ lhs = lhs >> shift;
+ return lhs;
+}
+
+inline Unsigned128 operator&(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return Unsigned128(lhs.lo & rhs.lo, lhs.hi & rhs.hi);
+}
+
+inline Unsigned128& operator&=(Unsigned128& lhs, const Unsigned128& rhs) {
+ lhs = lhs & rhs;
+ return lhs;
+}
+
+inline Unsigned128 operator|(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return Unsigned128(lhs.lo | rhs.lo, lhs.hi | rhs.hi);
+}
+
+inline Unsigned128& operator|=(Unsigned128& lhs, const Unsigned128& rhs) {
+ lhs = lhs | rhs;
+ return lhs;
+}
+
+inline Unsigned128 operator^(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return Unsigned128(lhs.lo ^ rhs.lo, lhs.hi ^ rhs.hi);
+}
+
+inline Unsigned128& operator^=(Unsigned128& lhs, const Unsigned128& rhs) {
+ lhs = lhs ^ rhs;
+ return lhs;
+}
+
+inline Unsigned128 operator~(const Unsigned128& v) {
+ return Unsigned128(~v.lo, ~v.hi);
+}
+
+inline bool operator==(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.lo == rhs.lo && lhs.hi == rhs.hi;
+}
+
+inline bool operator!=(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.lo != rhs.lo || lhs.hi != rhs.hi;
+}
+
+inline bool operator>(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.hi > rhs.hi || (lhs.hi == rhs.hi && lhs.lo > rhs.lo);
+}
+
+inline bool operator<(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.hi < rhs.hi || (lhs.hi == rhs.hi && lhs.lo < rhs.lo);
+}
+
+inline bool operator>=(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.hi > rhs.hi || (lhs.hi == rhs.hi && lhs.lo >= rhs.lo);
+}
+
+inline bool operator<=(const Unsigned128& lhs, const Unsigned128& rhs) {
+ return lhs.hi < rhs.hi || (lhs.hi == rhs.hi && lhs.lo <= rhs.lo);
+}
+#endif
+
+inline uint64_t Lower64of128(Unsigned128 v) {
+#ifdef HAVE_UINT128_EXTENSION
+ return static_cast<uint64_t>(v);
+#else
+ return v.lo;
+#endif
+}
+
+inline uint64_t Upper64of128(Unsigned128 v) {
+#ifdef HAVE_UINT128_EXTENSION
+ return static_cast<uint64_t>(v >> 64);
+#else
+ return v.hi;
+#endif
+}
+
+// This generally compiles down to a single fast instruction on 64-bit.
+// This doesn't really make sense as operator* because it's not a
+// general 128x128 multiply and provides more output than 64x64 multiply.
+inline Unsigned128 Multiply64to128(uint64_t a, uint64_t b) {
+#ifdef HAVE_UINT128_EXTENSION
+ return Unsigned128{a} * Unsigned128{b};
+#else
+ // Full decomposition
+ // NOTE: GCC seems to fully understand this code as 64-bit x 64-bit
+ // -> 128-bit multiplication and optimize it appropriately.
+ uint64_t tmp = uint64_t{b & 0xffffFFFF} * uint64_t{a & 0xffffFFFF};
+ uint64_t lower = tmp & 0xffffFFFF;
+ tmp >>= 32;
+ tmp += uint64_t{b & 0xffffFFFF} * uint64_t{a >> 32};
+ // Avoid overflow: first add lower 32 of tmp2, and later upper 32
+ uint64_t tmp2 = uint64_t{b >> 32} * uint64_t{a & 0xffffFFFF};
+ tmp += tmp2 & 0xffffFFFF;
+ lower |= tmp << 32;
+ tmp >>= 32;
+ tmp += tmp2 >> 32;
+ tmp += uint64_t{b >> 32} * uint64_t{a >> 32};
+ return Unsigned128(lower, tmp);
+#endif
+}
+
+template <>
+inline int FloorLog2(Unsigned128 v) {
+ if (Upper64of128(v) == 0) {
+ return FloorLog2(Lower64of128(v));
+ } else {
+ return FloorLog2(Upper64of128(v)) + 64;
+ }
+}
+
+template <>
+inline int CountTrailingZeroBits(Unsigned128 v) {
+ if (Lower64of128(v) != 0) {
+ return CountTrailingZeroBits(Lower64of128(v));
+ } else {
+ return CountTrailingZeroBits(Upper64of128(v)) + 64;
+ }
+}
+
+template <>
+inline int BitsSetToOne(Unsigned128 v) {
+ return BitsSetToOne(Lower64of128(v)) + BitsSetToOne(Upper64of128(v));
+}
+
+template <>
+inline int BitParity(Unsigned128 v) {
+ return BitParity(Lower64of128(v) ^ Upper64of128(v));
+}
+
+template <>
+inline Unsigned128 EndianSwapValue(Unsigned128 v) {
+ return (Unsigned128{EndianSwapValue(Lower64of128(v))} << 64) |
+ EndianSwapValue(Upper64of128(v));
+}
+
+template <>
+inline Unsigned128 ReverseBits(Unsigned128 v) {
+ return (Unsigned128{ReverseBits(Lower64of128(v))} << 64) |
+ ReverseBits(Upper64of128(v));
+}
+
+template <>
+inline Unsigned128 DownwardInvolution(Unsigned128 v) {
+ return (Unsigned128{DownwardInvolution(Upper64of128(v))} << 64) |
+ DownwardInvolution(Upper64of128(v) ^ Lower64of128(v));
+}
+
+template <typename T>
+struct IsUnsignedUpTo128
+ : std::integral_constant<bool, std::is_unsigned<T>::value ||
+ std::is_same<T, Unsigned128>::value> {};
+
+inline void EncodeFixed128(char* dst, Unsigned128 value) {
+ EncodeFixed64(dst, Lower64of128(value));
+ EncodeFixed64(dst + 8, Upper64of128(value));
+}
+
+inline Unsigned128 DecodeFixed128(const char* ptr) {
+ Unsigned128 rv = DecodeFixed64(ptr + 8);
+ return (rv << 64) | DecodeFixed64(ptr);
+}
+
+// A version of EncodeFixed* for generic algorithms. Likely to be used
+// with Unsigned128, so lives here for now.
+template <typename T>
+inline void EncodeFixedGeneric(char* /*dst*/, T /*value*/) {
+ // Unfortunately, GCC does not appear to optimize this simple code down
+ // to a trivial load on Intel:
+ //
+ // T ret_val = 0;
+ // for (size_t i = 0; i < sizeof(T); ++i) {
+ // ret_val |= (static_cast<T>(static_cast<unsigned char>(ptr[i])) << (8 *
+ // i));
+ // }
+ // return ret_val;
+ //
+ // But does unroll the loop, and does optimize manually unrolled version
+ // for specific sizes down to a trivial load. I have no idea why it doesn't
+ // do both on this code.
+
+ // So instead, we rely on specializations
+ static_assert(sizeof(T) == 0, "No specialization provided for this type");
+}
+
+template <>
+inline void EncodeFixedGeneric(char* dst, uint16_t value) {
+ return EncodeFixed16(dst, value);
+}
+template <>
+inline void EncodeFixedGeneric(char* dst, uint32_t value) {
+ return EncodeFixed32(dst, value);
+}
+template <>
+inline void EncodeFixedGeneric(char* dst, uint64_t value) {
+ return EncodeFixed64(dst, value);
+}
+template <>
+inline void EncodeFixedGeneric(char* dst, Unsigned128 value) {
+ return EncodeFixed128(dst, value);
+}
+
+// A version of EncodeFixed* for generic algorithms.
+template <typename T>
+inline T DecodeFixedGeneric(const char* /*dst*/) {
+ static_assert(sizeof(T) == 0, "No specialization provided for this type");
+}
+
+template <>
+inline uint16_t DecodeFixedGeneric(const char* dst) {
+ return DecodeFixed16(dst);
+}
+template <>
+inline uint32_t DecodeFixedGeneric(const char* dst) {
+ return DecodeFixed32(dst);
+}
+template <>
+inline uint64_t DecodeFixedGeneric(const char* dst) {
+ return DecodeFixed64(dst);
+}
+template <>
+inline Unsigned128 DecodeFixedGeneric(const char* dst) {
+ return DecodeFixed128(dst);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/murmurhash.cc b/src/rocksdb/util/murmurhash.cc
new file mode 100644
index 000000000..a69f3918a
--- /dev/null
+++ b/src/rocksdb/util/murmurhash.cc
@@ -0,0 +1,196 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+/*
+ Murmurhash from http://sites.google.com/site/murmurhash/
+
+ All code is released to the public domain. For business purposes, Murmurhash
+ is under the MIT license.
+*/
+#include "murmurhash.h"
+
+#include "port/lang.h"
+
+#if defined(__x86_64__)
+
+// -------------------------------------------------------------------
+//
+// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
+// and endian-ness issues if used across multiple platforms.
+//
+// 64-bit hash for 64-bit platforms
+
+#ifdef ROCKSDB_UBSAN_RUN
+#if defined(__clang__)
+__attribute__((__no_sanitize__("alignment")))
+#elif defined(__GNUC__)
+__attribute__((__no_sanitize_undefined__))
+#endif
+#endif
+// clang-format off
+uint64_t MurmurHash64A ( const void * key, int len, unsigned int seed )
+{
+ const uint64_t m = 0xc6a4a7935bd1e995;
+ const int r = 47;
+
+ uint64_t h = seed ^ (len * m);
+
+ const uint64_t * data = (const uint64_t *)key;
+ const uint64_t * end = data + (len/8);
+
+ while(data != end)
+ {
+ uint64_t k = *data++;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ const unsigned char * data2 = (const unsigned char*)data;
+
+ switch(len & 7)
+ {
+ case 7: h ^= ((uint64_t)data2[6]) << 48; FALLTHROUGH_INTENDED;
+ case 6: h ^= ((uint64_t)data2[5]) << 40; FALLTHROUGH_INTENDED;
+ case 5: h ^= ((uint64_t)data2[4]) << 32; FALLTHROUGH_INTENDED;
+ case 4: h ^= ((uint64_t)data2[3]) << 24; FALLTHROUGH_INTENDED;
+ case 3: h ^= ((uint64_t)data2[2]) << 16; FALLTHROUGH_INTENDED;
+ case 2: h ^= ((uint64_t)data2[1]) << 8; FALLTHROUGH_INTENDED;
+ case 1: h ^= ((uint64_t)data2[0]);
+ h *= m;
+ };
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+// clang-format on
+
+#elif defined(__i386__)
+
+// -------------------------------------------------------------------
+//
+// Note - This code makes a few assumptions about how your machine behaves -
+//
+// 1. We can read a 4-byte value from any address without crashing
+// 2. sizeof(int) == 4
+//
+// And it has a few limitations -
+//
+// 1. It will not work incrementally.
+// 2. It will not produce the same results on little-endian and big-endian
+// machines.
+// clang-format off
+unsigned int MurmurHash2 ( const void * key, int len, unsigned int seed )
+{
+ // 'm' and 'r' are mixing constants generated offline.
+ // They're not really 'magic', they just happen to work well.
+
+ const unsigned int m = 0x5bd1e995;
+ const int r = 24;
+
+ // Initialize the hash to a 'random' value
+
+ unsigned int h = seed ^ len;
+
+ // Mix 4 bytes at a time into the hash
+
+ const unsigned char * data = (const unsigned char *)key;
+
+ while(len >= 4)
+ {
+ unsigned int k = *(unsigned int *)data;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h *= m;
+ h ^= k;
+
+ data += 4;
+ len -= 4;
+ }
+
+ // Handle the last few bytes of the input array
+
+ switch(len)
+ {
+ case 3: h ^= data[2] << 16; FALLTHROUGH_INTENDED;
+ case 2: h ^= data[1] << 8; FALLTHROUGH_INTENDED;
+ case 1: h ^= data[0];
+ h *= m;
+ };
+
+ // Do a few final mixes of the hash to ensure the last few
+ // bytes are well-incorporated.
+
+ h ^= h >> 13;
+ h *= m;
+ h ^= h >> 15;
+
+ return h;
+}
+// clang-format on
+
+#else
+
+// -------------------------------------------------------------------
+//
+// Same as MurmurHash2, but endian- and alignment-neutral.
+// Half the speed though, alas.
+// clang-format off
+unsigned int MurmurHashNeutral2 ( const void * key, int len, unsigned int seed )
+{
+ const unsigned int m = 0x5bd1e995;
+ const int r = 24;
+
+ unsigned int h = seed ^ len;
+
+ const unsigned char * data = (const unsigned char *)key;
+
+ while(len >= 4)
+ {
+ unsigned int k;
+
+ k = data[0];
+ k |= data[1] << 8;
+ k |= data[2] << 16;
+ k |= data[3] << 24;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h *= m;
+ h ^= k;
+
+ data += 4;
+ len -= 4;
+ }
+
+ switch(len)
+ {
+ case 3: h ^= data[2] << 16; FALLTHROUGH_INTENDED;
+ case 2: h ^= data[1] << 8; FALLTHROUGH_INTENDED;
+ case 1: h ^= data[0];
+ h *= m;
+ };
+
+ h ^= h >> 13;
+ h *= m;
+ h ^= h >> 15;
+
+ return h;
+}
+// clang-format on
+
+#endif
diff --git a/src/rocksdb/util/murmurhash.h b/src/rocksdb/util/murmurhash.h
new file mode 100644
index 000000000..7ef4cbbec
--- /dev/null
+++ b/src/rocksdb/util/murmurhash.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+/*
+ Murmurhash from http://sites.google.com/site/murmurhash/
+
+ All code is released to the public domain. For business purposes, Murmurhash
+ is under the MIT license.
+*/
+#pragma once
+#include <stdint.h>
+
+#include "rocksdb/slice.h"
+
+#if defined(__x86_64__)
+#define MURMUR_HASH MurmurHash64A
+uint64_t MurmurHash64A(const void* key, int len, unsigned int seed);
+#define MurmurHash MurmurHash64A
+using murmur_t = uint64_t;
+
+#elif defined(__i386__)
+#define MURMUR_HASH MurmurHash2
+unsigned int MurmurHash2(const void* key, int len, unsigned int seed);
+#define MurmurHash MurmurHash2
+using murmur_t = unsigned int;
+
+#else
+#define MURMUR_HASH MurmurHashNeutral2
+unsigned int MurmurHashNeutral2(const void* key, int len, unsigned int seed);
+#define MurmurHash MurmurHashNeutral2
+using murmur_t = unsigned int;
+#endif
+
+// Allow slice to be hashable by murmur hash.
+namespace ROCKSDB_NAMESPACE {
+struct murmur_hash {
+ size_t operator()(const Slice& slice) const {
+ return MurmurHash(slice.data(), static_cast<int>(slice.size()), 0);
+ }
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/mutexlock.h b/src/rocksdb/util/mutexlock.h
new file mode 100644
index 000000000..94066b29e
--- /dev/null
+++ b/src/rocksdb/util/mutexlock.h
@@ -0,0 +1,180 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#include <assert.h>
+
+#include <atomic>
+#include <mutex>
+#include <thread>
+
+#include "port/port.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Helper class that locks a mutex on construction and unlocks the mutex when
+// the destructor of the MutexLock object is invoked.
+//
+// Typical usage:
+//
+// void MyClass::MyMethod() {
+// MutexLock l(&mu_); // mu_ is an instance variable
+// ... some complex code, possibly with multiple return paths ...
+// }
+
+class MutexLock {
+ public:
+ explicit MutexLock(port::Mutex *mu) : mu_(mu) { this->mu_->Lock(); }
+ // No copying allowed
+ MutexLock(const MutexLock &) = delete;
+ void operator=(const MutexLock &) = delete;
+
+ ~MutexLock() { this->mu_->Unlock(); }
+
+ private:
+ port::Mutex *const mu_;
+};
+
+//
+// Acquire a ReadLock on the specified RWMutex.
+// The Lock will be automatically released when the
+// object goes out of scope.
+//
+class ReadLock {
+ public:
+ explicit ReadLock(port::RWMutex *mu) : mu_(mu) { this->mu_->ReadLock(); }
+ // No copying allowed
+ ReadLock(const ReadLock &) = delete;
+ void operator=(const ReadLock &) = delete;
+
+ ~ReadLock() { this->mu_->ReadUnlock(); }
+
+ private:
+ port::RWMutex *const mu_;
+};
+
+//
+// Automatically unlock a locked mutex when the object is destroyed
+//
+class ReadUnlock {
+ public:
+ explicit ReadUnlock(port::RWMutex *mu) : mu_(mu) { mu->AssertHeld(); }
+ // No copying allowed
+ ReadUnlock(const ReadUnlock &) = delete;
+ ReadUnlock &operator=(const ReadUnlock &) = delete;
+
+ ~ReadUnlock() { mu_->ReadUnlock(); }
+
+ private:
+ port::RWMutex *const mu_;
+};
+
+//
+// Acquire a WriteLock on the specified RWMutex.
+// The Lock will be automatically released then the
+// object goes out of scope.
+//
+class WriteLock {
+ public:
+ explicit WriteLock(port::RWMutex *mu) : mu_(mu) { this->mu_->WriteLock(); }
+ // No copying allowed
+ WriteLock(const WriteLock &) = delete;
+ void operator=(const WriteLock &) = delete;
+
+ ~WriteLock() { this->mu_->WriteUnlock(); }
+
+ private:
+ port::RWMutex *const mu_;
+};
+
+//
+// SpinMutex has very low overhead for low-contention cases. Method names
+// are chosen so you can use std::unique_lock or std::lock_guard with it.
+//
+class SpinMutex {
+ public:
+ SpinMutex() : locked_(false) {}
+
+ bool try_lock() {
+ auto currently_locked = locked_.load(std::memory_order_relaxed);
+ return !currently_locked &&
+ locked_.compare_exchange_weak(currently_locked, true,
+ std::memory_order_acquire,
+ std::memory_order_relaxed);
+ }
+
+ void lock() {
+ for (size_t tries = 0;; ++tries) {
+ if (try_lock()) {
+ // success
+ break;
+ }
+ port::AsmVolatilePause();
+ if (tries > 100) {
+ std::this_thread::yield();
+ }
+ }
+ }
+
+ void unlock() { locked_.store(false, std::memory_order_release); }
+
+ private:
+ std::atomic<bool> locked_;
+};
+
+// We want to prevent false sharing
+template <class T>
+struct ALIGN_AS(CACHE_LINE_SIZE) LockData {
+ T lock_;
+};
+
+//
+// Inspired by Guava: https://github.com/google/guava/wiki/StripedExplained
+// A striped Lock. This offers the underlying lock striping similar
+// to that of ConcurrentHashMap in a reusable form, and extends it for
+// semaphores and read-write locks. Conceptually, lock striping is the technique
+// of dividing a lock into many <i>stripes</i>, increasing the granularity of a
+// single lock and allowing independent operations to lock different stripes and
+// proceed concurrently, instead of creating contention for a single lock.
+//
+template <class T, class P>
+class Striped {
+ public:
+ Striped(size_t stripes, std::function<uint64_t(const P &)> hash)
+ : stripes_(stripes), hash_(hash) {
+ locks_ = reinterpret_cast<LockData<T> *>(
+ port::cacheline_aligned_alloc(sizeof(LockData<T>) * stripes));
+ for (size_t i = 0; i < stripes; i++) {
+ new (&locks_[i]) LockData<T>();
+ }
+ }
+
+ virtual ~Striped() {
+ if (locks_ != nullptr) {
+ assert(stripes_ > 0);
+ for (size_t i = 0; i < stripes_; i++) {
+ locks_[i].~LockData<T>();
+ }
+ port::cacheline_aligned_free(locks_);
+ }
+ }
+
+ T *get(const P &key) {
+ uint64_t h = hash_(key);
+ size_t index = h % stripes_;
+ return &reinterpret_cast<LockData<T> *>(&locks_[index])->lock_;
+ }
+
+ private:
+ size_t stripes_;
+ LockData<T> *locks_;
+ std::function<uint64_t(const P &)> hash_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/ppc-opcode.h b/src/rocksdb/util/ppc-opcode.h
new file mode 100644
index 000000000..5cc5af0e3
--- /dev/null
+++ b/src/rocksdb/util/ppc-opcode.h
@@ -0,0 +1,27 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2017 International Business Machines Corp.
+// All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#define __PPC_RA(a) (((a)&0x1f) << 16)
+#define __PPC_RB(b) (((b)&0x1f) << 11)
+#define __PPC_XA(a) ((((a)&0x1f) << 16) | (((a)&0x20) >> 3))
+#define __PPC_XB(b) ((((b)&0x1f) << 11) | (((b)&0x20) >> 4))
+#define __PPC_XS(s) ((((s)&0x1f) << 21) | (((s)&0x20) >> 5))
+#define __PPC_XT(s) __PPC_XS(s)
+#define VSX_XX3(t, a, b) (__PPC_XT(t) | __PPC_XA(a) | __PPC_XB(b))
+#define VSX_XX1(s, a, b) (__PPC_XS(s) | __PPC_RA(a) | __PPC_RB(b))
+
+#define PPC_INST_VPMSUMW 0x10000488
+#define PPC_INST_VPMSUMD 0x100004c8
+#define PPC_INST_MFVSRD 0x7c000066
+#define PPC_INST_MTVSRD 0x7c000166
+
+#define VPMSUMW(t, a, b) .long PPC_INST_VPMSUMW | VSX_XX3((t), a, b)
+#define VPMSUMD(t, a, b) .long PPC_INST_VPMSUMD | VSX_XX3((t), a, b)
+#define MFVRD(a, t) .long PPC_INST_MFVSRD | VSX_XX1((t) + 32, a, 0)
+#define MTVRD(t, a) .long PPC_INST_MTVSRD | VSX_XX1((t) + 32, a, 0)
diff --git a/src/rocksdb/util/random.cc b/src/rocksdb/util/random.cc
new file mode 100644
index 000000000..c94c28dfb
--- /dev/null
+++ b/src/rocksdb/util/random.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#include "util/random.h"
+
+#include <stdint.h>
+#include <string.h>
+
+#include <thread>
+#include <utility>
+
+#include "port/likely.h"
+#include "util/thread_local.h"
+
+#define STORAGE_DECL static thread_local
+
+namespace ROCKSDB_NAMESPACE {
+
+Random* Random::GetTLSInstance() {
+ STORAGE_DECL Random* tls_instance;
+ STORAGE_DECL std::aligned_storage<sizeof(Random)>::type tls_instance_bytes;
+
+ auto rv = tls_instance;
+ if (UNLIKELY(rv == nullptr)) {
+ size_t seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ rv = new (&tls_instance_bytes) Random((uint32_t)seed);
+ tls_instance = rv;
+ }
+ return rv;
+}
+
+std::string Random::HumanReadableString(int len) {
+ std::string ret;
+ ret.resize(len);
+ for (int i = 0; i < len; ++i) {
+ ret[i] = static_cast<char>('a' + Uniform(26));
+ }
+ return ret;
+}
+
+std::string Random::RandomString(int len) {
+ std::string ret;
+ ret.resize(len);
+ for (int i = 0; i < len; i++) {
+ ret[i] = static_cast<char>(' ' + Uniform(95)); // ' ' .. '~'
+ }
+ return ret;
+}
+
+std::string Random::RandomBinaryString(int len) {
+ std::string ret;
+ ret.resize(len);
+ for (int i = 0; i < len; i++) {
+ ret[i] = static_cast<char>(Uniform(CHAR_MAX));
+ }
+ return ret;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/random.h b/src/rocksdb/util/random.h
new file mode 100644
index 000000000..8923bdc4f
--- /dev/null
+++ b/src/rocksdb/util/random.h
@@ -0,0 +1,190 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#include <stdint.h>
+
+#include <algorithm>
+#include <random>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A very simple random number generator. Not especially good at
+// generating truly random bits, but good enough for our needs in this
+// package.
+class Random {
+ private:
+ enum : uint32_t {
+ M = 2147483647L // 2^31-1
+ };
+ enum : uint64_t {
+ A = 16807 // bits 14, 8, 7, 5, 2, 1, 0
+ };
+
+ uint32_t seed_;
+
+ static uint32_t GoodSeed(uint32_t s) { return (s & M) != 0 ? (s & M) : 1; }
+
+ public:
+ // This is the largest value that can be returned from Next()
+ enum : uint32_t { kMaxNext = M };
+
+ explicit Random(uint32_t s) : seed_(GoodSeed(s)) {}
+
+ void Reset(uint32_t s) { seed_ = GoodSeed(s); }
+
+ uint32_t Next() {
+ // We are computing
+ // seed_ = (seed_ * A) % M, where M = 2^31-1
+ //
+ // seed_ must not be zero or M, or else all subsequent computed values
+ // will be zero or M respectively. For all other values, seed_ will end
+ // up cycling through every number in [1,M-1]
+ uint64_t product = seed_ * A;
+
+ // Compute (product % M) using the fact that ((x << 31) % M) == x.
+ seed_ = static_cast<uint32_t>((product >> 31) + (product & M));
+ // The first reduction may overflow by 1 bit, so we may need to
+ // repeat. mod == M is not possible; using > allows the faster
+ // sign-bit-based test.
+ if (seed_ > M) {
+ seed_ -= M;
+ }
+ return seed_;
+ }
+
+ uint64_t Next64() { return (uint64_t{Next()} << 32) | Next(); }
+
+ // Returns a uniformly distributed value in the range [0..n-1]
+ // REQUIRES: n > 0
+ uint32_t Uniform(int n) { return Next() % n; }
+
+ // Randomly returns true ~"1/n" of the time, and false otherwise.
+ // REQUIRES: n > 0
+ bool OneIn(int n) { return Uniform(n) == 0; }
+
+ // "Optional" one-in-n, where 0 or negative always returns false
+ // (may or may not consume a random value)
+ bool OneInOpt(int n) { return n > 0 && OneIn(n); }
+
+ // Returns random bool that is true for the given percentage of
+ // calls on average. Zero or less is always false and 100 or more
+ // is always true (may or may not consume a random value)
+ bool PercentTrue(int percentage) {
+ return static_cast<int>(Uniform(100)) < percentage;
+ }
+
+ // Skewed: pick "base" uniformly from range [0,max_log] and then
+ // return "base" random bits. The effect is to pick a number in the
+ // range [0,2^max_log-1] with exponential bias towards smaller numbers.
+ uint32_t Skewed(int max_log) { return Uniform(1 << Uniform(max_log + 1)); }
+
+ // Returns a random string of length "len"
+ std::string RandomString(int len);
+
+ // Generates a random string of len bytes using human-readable characters
+ std::string HumanReadableString(int len);
+
+ // Generates a random binary data
+ std::string RandomBinaryString(int len);
+
+ // Returns a Random instance for use by the current thread without
+ // additional locking
+ static Random* GetTLSInstance();
+};
+
+// A good 32-bit random number generator based on std::mt19937.
+// This exists in part to avoid compiler variance in warning about coercing
+// uint_fast32_t from mt19937 to uint32_t.
+class Random32 {
+ private:
+ std::mt19937 generator_;
+
+ public:
+ explicit Random32(uint32_t s) : generator_(s) {}
+
+ // Generates the next random number
+ uint32_t Next() { return static_cast<uint32_t>(generator_()); }
+
+ // Returns a uniformly distributed value in the range [0..n-1]
+ // REQUIRES: n > 0
+ uint32_t Uniform(uint32_t n) {
+ return static_cast<uint32_t>(
+ std::uniform_int_distribution<std::mt19937::result_type>(
+ 0, n - 1)(generator_));
+ }
+
+ // Returns an *almost* uniformly distributed value in the range [0..n-1].
+ // Much faster than Uniform().
+ // REQUIRES: n > 0
+ uint32_t Uniformish(uint32_t n) {
+ // fastrange (without the header)
+ return static_cast<uint32_t>((uint64_t(generator_()) * uint64_t(n)) >> 32);
+ }
+
+ // Randomly returns true ~"1/n" of the time, and false otherwise.
+ // REQUIRES: n > 0
+ bool OneIn(uint32_t n) { return Uniform(n) == 0; }
+
+ // Skewed: pick "base" uniformly from range [0,max_log] and then
+ // return "base" random bits. The effect is to pick a number in the
+ // range [0,2^max_log-1] with exponential bias towards smaller numbers.
+ uint32_t Skewed(int max_log) {
+ return Uniform(uint32_t{1} << Uniform(max_log + 1));
+ }
+
+ // Reset the seed of the generator to the given value
+ void Seed(uint32_t new_seed) { generator_.seed(new_seed); }
+};
+
+// A good 64-bit random number generator based on std::mt19937_64
+class Random64 {
+ private:
+ std::mt19937_64 generator_;
+
+ public:
+ explicit Random64(uint64_t s) : generator_(s) {}
+
+ // Generates the next random number
+ uint64_t Next() { return generator_(); }
+
+ // Returns a uniformly distributed value in the range [0..n-1]
+ // REQUIRES: n > 0
+ uint64_t Uniform(uint64_t n) {
+ return std::uniform_int_distribution<uint64_t>(0, n - 1)(generator_);
+ }
+
+ // Randomly returns true ~"1/n" of the time, and false otherwise.
+ // REQUIRES: n > 0
+ bool OneIn(uint64_t n) { return Uniform(n) == 0; }
+
+ // Skewed: pick "base" uniformly from range [0,max_log] and then
+ // return "base" random bits. The effect is to pick a number in the
+ // range [0,2^max_log-1] with exponential bias towards smaller numbers.
+ uint64_t Skewed(int max_log) {
+ return Uniform(uint64_t(1) << Uniform(max_log + 1));
+ }
+};
+
+// A seeded replacement for removed std::random_shuffle
+template <class RandomIt>
+void RandomShuffle(RandomIt first, RandomIt last, uint32_t seed) {
+ std::mt19937 rng(seed);
+ std::shuffle(first, last, rng);
+}
+
+// A replacement for removed std::random_shuffle
+template <class RandomIt>
+void RandomShuffle(RandomIt first, RandomIt last) {
+ RandomShuffle(first, last, std::random_device{}());
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/random_test.cc b/src/rocksdb/util/random_test.cc
new file mode 100644
index 000000000..1aa62c5da
--- /dev/null
+++ b/src/rocksdb/util/random_test.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/random.h"
+
+#include <cstring>
+#include <vector>
+
+#include "test_util/testharness.h"
+
+using ROCKSDB_NAMESPACE::Random;
+
+TEST(RandomTest, Uniform) {
+ const int average = 20;
+ for (uint32_t seed : {0, 1, 2, 37, 4096}) {
+ Random r(seed);
+ for (int range : {1, 2, 8, 12, 100}) {
+ std::vector<int> counts(range, 0);
+
+ for (int i = 0; i < range * average; ++i) {
+ ++counts.at(r.Uniform(range));
+ }
+ int max_variance = static_cast<int>(std::sqrt(range) * 2 + 4);
+ for (int i = 0; i < range; ++i) {
+ EXPECT_GE(counts[i], std::max(1, average - max_variance));
+ EXPECT_LE(counts[i], average + max_variance + 1);
+ }
+ }
+ }
+}
+
+TEST(RandomTest, OneIn) {
+ Random r(42);
+ for (int range : {1, 2, 8, 12, 100, 1234}) {
+ const int average = 100;
+ int count = 0;
+ for (int i = 0; i < average * range; ++i) {
+ if (r.OneIn(range)) {
+ ++count;
+ }
+ }
+ if (range == 1) {
+ EXPECT_EQ(count, average);
+ } else {
+ int max_variance = static_cast<int>(std::sqrt(average) * 1.5);
+ EXPECT_GE(count, average - max_variance);
+ EXPECT_LE(count, average + max_variance);
+ }
+ }
+}
+
+TEST(RandomTest, OneInOpt) {
+ Random r(42);
+ for (int range : {-12, 0, 1, 2, 8, 12, 100, 1234}) {
+ const int average = 100;
+ int count = 0;
+ for (int i = 0; i < average * range; ++i) {
+ if (r.OneInOpt(range)) {
+ ++count;
+ }
+ }
+ if (range < 1) {
+ EXPECT_EQ(count, 0);
+ } else if (range == 1) {
+ EXPECT_EQ(count, average);
+ } else {
+ int max_variance = static_cast<int>(std::sqrt(average) * 1.5);
+ EXPECT_GE(count, average - max_variance);
+ EXPECT_LE(count, average + max_variance);
+ }
+ }
+}
+
+TEST(RandomTest, PercentTrue) {
+ Random r(42);
+ for (int pct : {-12, 0, 1, 2, 10, 50, 90, 98, 99, 100, 1234}) {
+ const int samples = 10000;
+
+ int count = 0;
+ for (int i = 0; i < samples; ++i) {
+ if (r.PercentTrue(pct)) {
+ ++count;
+ }
+ }
+ if (pct <= 0) {
+ EXPECT_EQ(count, 0);
+ } else if (pct >= 100) {
+ EXPECT_EQ(count, samples);
+ } else {
+ int est = (count * 100 + (samples / 2)) / samples;
+ EXPECT_EQ(est, pct);
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/rate_limiter.cc b/src/rocksdb/util/rate_limiter.cc
new file mode 100644
index 000000000..6bbcabfae
--- /dev/null
+++ b/src/rocksdb/util/rate_limiter.cc
@@ -0,0 +1,378 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/rate_limiter.h"
+
+#include <algorithm>
+
+#include "monitoring/statistics.h"
+#include "port/port.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/sync_point.h"
+#include "util/aligned_buffer.h"
+
+namespace ROCKSDB_NAMESPACE {
+size_t RateLimiter::RequestToken(size_t bytes, size_t alignment,
+ Env::IOPriority io_priority, Statistics* stats,
+ RateLimiter::OpType op_type) {
+ if (io_priority < Env::IO_TOTAL && IsRateLimited(op_type)) {
+ bytes = std::min(bytes, static_cast<size_t>(GetSingleBurstBytes()));
+
+ if (alignment > 0) {
+ // Here we may actually require more than burst and block
+ // as we can not write/read less than one page at a time on direct I/O
+ // thus we do not want to be strictly constrained by burst
+ bytes = std::max(alignment, TruncateToPageBoundary(alignment, bytes));
+ }
+ Request(bytes, io_priority, stats, op_type);
+ }
+ return bytes;
+}
+
+// Pending request
+struct GenericRateLimiter::Req {
+ explicit Req(int64_t _bytes, port::Mutex* _mu)
+ : request_bytes(_bytes), bytes(_bytes), cv(_mu), granted(false) {}
+ int64_t request_bytes;
+ int64_t bytes;
+ port::CondVar cv;
+ bool granted;
+};
+
+GenericRateLimiter::GenericRateLimiter(
+ int64_t rate_bytes_per_sec, int64_t refill_period_us, int32_t fairness,
+ RateLimiter::Mode mode, const std::shared_ptr<SystemClock>& clock,
+ bool auto_tuned)
+ : RateLimiter(mode),
+ refill_period_us_(refill_period_us),
+ rate_bytes_per_sec_(auto_tuned ? rate_bytes_per_sec / 2
+ : rate_bytes_per_sec),
+ refill_bytes_per_period_(
+ CalculateRefillBytesPerPeriodLocked(rate_bytes_per_sec_)),
+ clock_(clock),
+ stop_(false),
+ exit_cv_(&request_mutex_),
+ requests_to_wait_(0),
+ available_bytes_(0),
+ next_refill_us_(NowMicrosMonotonicLocked()),
+ fairness_(fairness > 100 ? 100 : fairness),
+ rnd_((uint32_t)time(nullptr)),
+ wait_until_refill_pending_(false),
+ auto_tuned_(auto_tuned),
+ num_drains_(0),
+ max_bytes_per_sec_(rate_bytes_per_sec),
+ tuned_time_(NowMicrosMonotonicLocked()) {
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ total_requests_[i] = 0;
+ total_bytes_through_[i] = 0;
+ }
+}
+
+GenericRateLimiter::~GenericRateLimiter() {
+ MutexLock g(&request_mutex_);
+ stop_ = true;
+ std::deque<Req*>::size_type queues_size_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ queues_size_sum += queue_[i].size();
+ }
+ requests_to_wait_ = static_cast<int32_t>(queues_size_sum);
+
+ for (int i = Env::IO_TOTAL - 1; i >= Env::IO_LOW; --i) {
+ std::deque<Req*> queue = queue_[i];
+ for (auto& r : queue) {
+ r->cv.Signal();
+ }
+ }
+
+ while (requests_to_wait_ > 0) {
+ exit_cv_.Wait();
+ }
+}
+
+// This API allows user to dynamically change rate limiter's bytes per second.
+void GenericRateLimiter::SetBytesPerSecond(int64_t bytes_per_second) {
+ MutexLock g(&request_mutex_);
+ SetBytesPerSecondLocked(bytes_per_second);
+}
+
+void GenericRateLimiter::SetBytesPerSecondLocked(int64_t bytes_per_second) {
+ assert(bytes_per_second > 0);
+ rate_bytes_per_sec_.store(bytes_per_second, std::memory_order_relaxed);
+ refill_bytes_per_period_.store(
+ CalculateRefillBytesPerPeriodLocked(bytes_per_second),
+ std::memory_order_relaxed);
+}
+
+void GenericRateLimiter::Request(int64_t bytes, const Env::IOPriority pri,
+ Statistics* stats) {
+ assert(bytes <= refill_bytes_per_period_.load(std::memory_order_relaxed));
+ bytes = std::max(static_cast<int64_t>(0), bytes);
+ TEST_SYNC_POINT("GenericRateLimiter::Request");
+ TEST_SYNC_POINT_CALLBACK("GenericRateLimiter::Request:1",
+ &rate_bytes_per_sec_);
+ MutexLock g(&request_mutex_);
+
+ if (auto_tuned_) {
+ static const int kRefillsPerTune = 100;
+ std::chrono::microseconds now(NowMicrosMonotonicLocked());
+ if (now - tuned_time_ >=
+ kRefillsPerTune * std::chrono::microseconds(refill_period_us_)) {
+ Status s = TuneLocked();
+ s.PermitUncheckedError(); //**TODO: What to do on error?
+ }
+ }
+
+ if (stop_) {
+ // It is now in the clean-up of ~GenericRateLimiter().
+ // Therefore any new incoming request will exit from here
+ // and not get satiesfied.
+ return;
+ }
+
+ ++total_requests_[pri];
+
+ if (available_bytes_ >= bytes) {
+ // Refill thread assigns quota and notifies requests waiting on
+ // the queue under mutex. So if we get here, that means nobody
+ // is waiting?
+ available_bytes_ -= bytes;
+ total_bytes_through_[pri] += bytes;
+ return;
+ }
+
+ // Request cannot be satisfied at this moment, enqueue
+ Req r(bytes, &request_mutex_);
+ queue_[pri].push_back(&r);
+ TEST_SYNC_POINT_CALLBACK("GenericRateLimiter::Request:PostEnqueueRequest",
+ &request_mutex_);
+ // A thread representing a queued request coordinates with other such threads.
+ // There are two main duties.
+ //
+ // (1) Waiting for the next refill time.
+ // (2) Refilling the bytes and granting requests.
+ do {
+ int64_t time_until_refill_us = next_refill_us_ - NowMicrosMonotonicLocked();
+ if (time_until_refill_us > 0) {
+ if (wait_until_refill_pending_) {
+ // Somebody is performing (1). Trust we'll be woken up when our request
+ // is granted or we are needed for future duties.
+ r.cv.Wait();
+ } else {
+ // Whichever thread reaches here first performs duty (1) as described
+ // above.
+ int64_t wait_until = clock_->NowMicros() + time_until_refill_us;
+ RecordTick(stats, NUMBER_RATE_LIMITER_DRAINS);
+ ++num_drains_;
+ wait_until_refill_pending_ = true;
+ r.cv.TimedWait(wait_until);
+ TEST_SYNC_POINT_CALLBACK("GenericRateLimiter::Request:PostTimedWait",
+ &time_until_refill_us);
+ wait_until_refill_pending_ = false;
+ }
+ } else {
+ // Whichever thread reaches here first performs duty (2) as described
+ // above.
+ RefillBytesAndGrantRequestsLocked();
+ if (r.granted) {
+ // If there is any remaining requests, make sure there exists at least
+ // one candidate is awake for future duties by signaling a front request
+ // of a queue.
+ for (int i = Env::IO_TOTAL - 1; i >= Env::IO_LOW; --i) {
+ std::deque<Req*> queue = queue_[i];
+ if (!queue.empty()) {
+ queue.front()->cv.Signal();
+ break;
+ }
+ }
+ }
+ }
+ // Invariant: non-granted request is always in one queue, and granted
+ // request is always in zero queues.
+#ifndef NDEBUG
+ int num_found = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ if (std::find(queue_[i].begin(), queue_[i].end(), &r) !=
+ queue_[i].end()) {
+ ++num_found;
+ }
+ }
+ if (r.granted) {
+ assert(num_found == 0);
+ } else {
+ assert(num_found == 1);
+ }
+#endif // NDEBUG
+ } while (!stop_ && !r.granted);
+
+ if (stop_) {
+ // It is now in the clean-up of ~GenericRateLimiter().
+ // Therefore any woken-up request will have come out of the loop and then
+ // exit here. It might or might not have been satisfied.
+ --requests_to_wait_;
+ exit_cv_.Signal();
+ }
+}
+
+std::vector<Env::IOPriority>
+GenericRateLimiter::GeneratePriorityIterationOrderLocked() {
+ std::vector<Env::IOPriority> pri_iteration_order(Env::IO_TOTAL /* 4 */);
+ // We make Env::IO_USER a superior priority by always iterating its queue
+ // first
+ pri_iteration_order[0] = Env::IO_USER;
+
+ bool high_pri_iterated_after_mid_low_pri = rnd_.OneIn(fairness_);
+ TEST_SYNC_POINT_CALLBACK(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForHighPri",
+ &high_pri_iterated_after_mid_low_pri);
+ bool mid_pri_itereated_after_low_pri = rnd_.OneIn(fairness_);
+ TEST_SYNC_POINT_CALLBACK(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForMidPri",
+ &mid_pri_itereated_after_low_pri);
+
+ if (high_pri_iterated_after_mid_low_pri) {
+ pri_iteration_order[3] = Env::IO_HIGH;
+ pri_iteration_order[2] =
+ mid_pri_itereated_after_low_pri ? Env::IO_MID : Env::IO_LOW;
+ pri_iteration_order[1] =
+ (pri_iteration_order[2] == Env::IO_MID) ? Env::IO_LOW : Env::IO_MID;
+ } else {
+ pri_iteration_order[1] = Env::IO_HIGH;
+ pri_iteration_order[3] =
+ mid_pri_itereated_after_low_pri ? Env::IO_MID : Env::IO_LOW;
+ pri_iteration_order[2] =
+ (pri_iteration_order[3] == Env::IO_MID) ? Env::IO_LOW : Env::IO_MID;
+ }
+
+ TEST_SYNC_POINT_CALLBACK(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PreReturnPriIterationOrder",
+ &pri_iteration_order);
+ return pri_iteration_order;
+}
+
+void GenericRateLimiter::RefillBytesAndGrantRequestsLocked() {
+ TEST_SYNC_POINT_CALLBACK(
+ "GenericRateLimiter::RefillBytesAndGrantRequestsLocked", &request_mutex_);
+ next_refill_us_ = NowMicrosMonotonicLocked() + refill_period_us_;
+ // Carry over the left over quota from the last period
+ auto refill_bytes_per_period =
+ refill_bytes_per_period_.load(std::memory_order_relaxed);
+ if (available_bytes_ < refill_bytes_per_period) {
+ available_bytes_ += refill_bytes_per_period;
+ }
+
+ std::vector<Env::IOPriority> pri_iteration_order =
+ GeneratePriorityIterationOrderLocked();
+
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ assert(!pri_iteration_order.empty());
+ Env::IOPriority current_pri = pri_iteration_order[i];
+ auto* queue = &queue_[current_pri];
+ while (!queue->empty()) {
+ auto* next_req = queue->front();
+ if (available_bytes_ < next_req->request_bytes) {
+ // Grant partial request_bytes to avoid starvation of requests
+ // that become asking for more bytes than available_bytes_
+ // due to dynamically reduced rate limiter's bytes_per_second that
+ // leads to reduced refill_bytes_per_period hence available_bytes_
+ next_req->request_bytes -= available_bytes_;
+ available_bytes_ = 0;
+ break;
+ }
+ available_bytes_ -= next_req->request_bytes;
+ next_req->request_bytes = 0;
+ total_bytes_through_[current_pri] += next_req->bytes;
+ queue->pop_front();
+
+ next_req->granted = true;
+ // Quota granted, signal the thread to exit
+ next_req->cv.Signal();
+ }
+ }
+}
+
+int64_t GenericRateLimiter::CalculateRefillBytesPerPeriodLocked(
+ int64_t rate_bytes_per_sec) {
+ if (std::numeric_limits<int64_t>::max() / rate_bytes_per_sec <
+ refill_period_us_) {
+ // Avoid unexpected result in the overflow case. The result now is still
+ // inaccurate but is a number that is large enough.
+ return std::numeric_limits<int64_t>::max() / 1000000;
+ } else {
+ return rate_bytes_per_sec * refill_period_us_ / 1000000;
+ }
+}
+
+Status GenericRateLimiter::TuneLocked() {
+ const int kLowWatermarkPct = 50;
+ const int kHighWatermarkPct = 90;
+ const int kAdjustFactorPct = 5;
+ // computed rate limit will be in
+ // `[max_bytes_per_sec_ / kAllowedRangeFactor, max_bytes_per_sec_]`.
+ const int kAllowedRangeFactor = 20;
+
+ std::chrono::microseconds prev_tuned_time = tuned_time_;
+ tuned_time_ = std::chrono::microseconds(NowMicrosMonotonicLocked());
+
+ int64_t elapsed_intervals = (tuned_time_ - prev_tuned_time +
+ std::chrono::microseconds(refill_period_us_) -
+ std::chrono::microseconds(1)) /
+ std::chrono::microseconds(refill_period_us_);
+ // We tune every kRefillsPerTune intervals, so the overflow and division-by-
+ // zero conditions should never happen.
+ assert(num_drains_ <= std::numeric_limits<int64_t>::max() / 100);
+ assert(elapsed_intervals > 0);
+ int64_t drained_pct = num_drains_ * 100 / elapsed_intervals;
+
+ int64_t prev_bytes_per_sec = GetBytesPerSecond();
+ int64_t new_bytes_per_sec;
+ if (drained_pct == 0) {
+ new_bytes_per_sec = max_bytes_per_sec_ / kAllowedRangeFactor;
+ } else if (drained_pct < kLowWatermarkPct) {
+ // sanitize to prevent overflow
+ int64_t sanitized_prev_bytes_per_sec =
+ std::min(prev_bytes_per_sec, std::numeric_limits<int64_t>::max() / 100);
+ new_bytes_per_sec =
+ std::max(max_bytes_per_sec_ / kAllowedRangeFactor,
+ sanitized_prev_bytes_per_sec * 100 / (100 + kAdjustFactorPct));
+ } else if (drained_pct > kHighWatermarkPct) {
+ // sanitize to prevent overflow
+ int64_t sanitized_prev_bytes_per_sec =
+ std::min(prev_bytes_per_sec, std::numeric_limits<int64_t>::max() /
+ (100 + kAdjustFactorPct));
+ new_bytes_per_sec =
+ std::min(max_bytes_per_sec_,
+ sanitized_prev_bytes_per_sec * (100 + kAdjustFactorPct) / 100);
+ } else {
+ new_bytes_per_sec = prev_bytes_per_sec;
+ }
+ if (new_bytes_per_sec != prev_bytes_per_sec) {
+ SetBytesPerSecondLocked(new_bytes_per_sec);
+ }
+ num_drains_ = 0;
+ return Status::OK();
+}
+
+RateLimiter* NewGenericRateLimiter(
+ int64_t rate_bytes_per_sec, int64_t refill_period_us /* = 100 * 1000 */,
+ int32_t fairness /* = 10 */,
+ RateLimiter::Mode mode /* = RateLimiter::Mode::kWritesOnly */,
+ bool auto_tuned /* = false */) {
+ assert(rate_bytes_per_sec > 0);
+ assert(refill_period_us > 0);
+ assert(fairness > 0);
+ std::unique_ptr<RateLimiter> limiter(
+ new GenericRateLimiter(rate_bytes_per_sec, refill_period_us, fairness,
+ mode, SystemClock::Default(), auto_tuned));
+ return limiter.release();
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/rate_limiter.h b/src/rocksdb/util/rate_limiter.h
new file mode 100644
index 000000000..4c078f5a0
--- /dev/null
+++ b/src/rocksdb/util/rate_limiter.h
@@ -0,0 +1,146 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <deque>
+
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "rocksdb/rate_limiter.h"
+#include "rocksdb/status.h"
+#include "rocksdb/system_clock.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class GenericRateLimiter : public RateLimiter {
+ public:
+ GenericRateLimiter(int64_t refill_bytes, int64_t refill_period_us,
+ int32_t fairness, RateLimiter::Mode mode,
+ const std::shared_ptr<SystemClock>& clock,
+ bool auto_tuned);
+
+ virtual ~GenericRateLimiter();
+
+ // This API allows user to dynamically change rate limiter's bytes per second.
+ virtual void SetBytesPerSecond(int64_t bytes_per_second) override;
+
+ // Request for token to write bytes. If this request can not be satisfied,
+ // the call is blocked. Caller is responsible to make sure
+ // bytes <= GetSingleBurstBytes() and bytes >= 0. Negative bytes
+ // passed in will be rounded up to 0.
+ using RateLimiter::Request;
+ virtual void Request(const int64_t bytes, const Env::IOPriority pri,
+ Statistics* stats) override;
+
+ virtual int64_t GetSingleBurstBytes() const override {
+ return refill_bytes_per_period_.load(std::memory_order_relaxed);
+ }
+
+ virtual int64_t GetTotalBytesThrough(
+ const Env::IOPriority pri = Env::IO_TOTAL) const override {
+ MutexLock g(&request_mutex_);
+ if (pri == Env::IO_TOTAL) {
+ int64_t total_bytes_through_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ total_bytes_through_sum += total_bytes_through_[i];
+ }
+ return total_bytes_through_sum;
+ }
+ return total_bytes_through_[pri];
+ }
+
+ virtual int64_t GetTotalRequests(
+ const Env::IOPriority pri = Env::IO_TOTAL) const override {
+ MutexLock g(&request_mutex_);
+ if (pri == Env::IO_TOTAL) {
+ int64_t total_requests_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ total_requests_sum += total_requests_[i];
+ }
+ return total_requests_sum;
+ }
+ return total_requests_[pri];
+ }
+
+ virtual Status GetTotalPendingRequests(
+ int64_t* total_pending_requests,
+ const Env::IOPriority pri = Env::IO_TOTAL) const override {
+ assert(total_pending_requests != nullptr);
+ MutexLock g(&request_mutex_);
+ if (pri == Env::IO_TOTAL) {
+ int64_t total_pending_requests_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ total_pending_requests_sum += static_cast<int64_t>(queue_[i].size());
+ }
+ *total_pending_requests = total_pending_requests_sum;
+ } else {
+ *total_pending_requests = static_cast<int64_t>(queue_[pri].size());
+ }
+ return Status::OK();
+ }
+
+ virtual int64_t GetBytesPerSecond() const override {
+ return rate_bytes_per_sec_.load(std::memory_order_relaxed);
+ }
+
+ virtual void TEST_SetClock(std::shared_ptr<SystemClock> clock) {
+ MutexLock g(&request_mutex_);
+ clock_ = std::move(clock);
+ next_refill_us_ = NowMicrosMonotonicLocked();
+ }
+
+ private:
+ void RefillBytesAndGrantRequestsLocked();
+ std::vector<Env::IOPriority> GeneratePriorityIterationOrderLocked();
+ int64_t CalculateRefillBytesPerPeriodLocked(int64_t rate_bytes_per_sec);
+ Status TuneLocked();
+ void SetBytesPerSecondLocked(int64_t bytes_per_second);
+
+ uint64_t NowMicrosMonotonicLocked() {
+ return clock_->NowNanos() / std::milli::den;
+ }
+
+ // This mutex guard all internal states
+ mutable port::Mutex request_mutex_;
+
+ const int64_t refill_period_us_;
+
+ std::atomic<int64_t> rate_bytes_per_sec_;
+ std::atomic<int64_t> refill_bytes_per_period_;
+ std::shared_ptr<SystemClock> clock_;
+
+ bool stop_;
+ port::CondVar exit_cv_;
+ int32_t requests_to_wait_;
+
+ int64_t total_requests_[Env::IO_TOTAL];
+ int64_t total_bytes_through_[Env::IO_TOTAL];
+ int64_t available_bytes_;
+ int64_t next_refill_us_;
+
+ int32_t fairness_;
+ Random rnd_;
+
+ struct Req;
+ std::deque<Req*> queue_[Env::IO_TOTAL];
+ bool wait_until_refill_pending_;
+
+ bool auto_tuned_;
+ int64_t num_drains_;
+ const int64_t max_bytes_per_sec_;
+ std::chrono::microseconds tuned_time_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/rate_limiter_test.cc b/src/rocksdb/util/rate_limiter_test.cc
new file mode 100644
index 000000000..cda134867
--- /dev/null
+++ b/src/rocksdb/util/rate_limiter_test.cc
@@ -0,0 +1,476 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/rate_limiter.h"
+
+#include <chrono>
+#include <cinttypes>
+#include <cstdint>
+#include <limits>
+
+#include "db/db_test_util.h"
+#include "port/port.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// TODO(yhchiang): the rate will not be accurate when we run test in parallel.
+class RateLimiterTest : public testing::Test {
+ protected:
+ ~RateLimiterTest() override {
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+};
+
+TEST_F(RateLimiterTest, OverflowRate) {
+ GenericRateLimiter limiter(std::numeric_limits<int64_t>::max(), 1000, 10,
+ RateLimiter::Mode::kWritesOnly,
+ SystemClock::Default(), false /* auto_tuned */);
+ ASSERT_GT(limiter.GetSingleBurstBytes(), 1000000000ll);
+}
+
+TEST_F(RateLimiterTest, StartStop) {
+ std::unique_ptr<RateLimiter> limiter(NewGenericRateLimiter(100, 100, 10));
+}
+
+TEST_F(RateLimiterTest, GetTotalBytesThrough) {
+ std::unique_ptr<RateLimiter> limiter(NewGenericRateLimiter(
+ 200 /* rate_bytes_per_sec */, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */));
+ for (int i = Env::IO_LOW; i <= Env::IO_TOTAL; ++i) {
+ ASSERT_EQ(limiter->GetTotalBytesThrough(static_cast<Env::IOPriority>(i)),
+ 0);
+ }
+
+ std::int64_t request_byte = 200;
+ std::int64_t request_byte_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ limiter->Request(request_byte, static_cast<Env::IOPriority>(i),
+ nullptr /* stats */, RateLimiter::OpType::kWrite);
+ request_byte_sum += request_byte;
+ }
+
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ EXPECT_EQ(limiter->GetTotalBytesThrough(static_cast<Env::IOPriority>(i)),
+ request_byte)
+ << "Failed to track total_bytes_through_ correctly when IOPriority = "
+ << static_cast<Env::IOPriority>(i);
+ }
+ EXPECT_EQ(limiter->GetTotalBytesThrough(Env::IO_TOTAL), request_byte_sum)
+ << "Failed to track total_bytes_through_ correctly when IOPriority = "
+ "Env::IO_TOTAL";
+}
+
+TEST_F(RateLimiterTest, GetTotalRequests) {
+ std::unique_ptr<RateLimiter> limiter(NewGenericRateLimiter(
+ 200 /* rate_bytes_per_sec */, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */));
+ for (int i = Env::IO_LOW; i <= Env::IO_TOTAL; ++i) {
+ ASSERT_EQ(limiter->GetTotalRequests(static_cast<Env::IOPriority>(i)), 0);
+ }
+
+ std::int64_t total_requests_sum = 0;
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ limiter->Request(200, static_cast<Env::IOPriority>(i), nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ total_requests_sum += 1;
+ }
+
+ for (int i = Env::IO_LOW; i < Env::IO_TOTAL; ++i) {
+ EXPECT_EQ(limiter->GetTotalRequests(static_cast<Env::IOPriority>(i)), 1)
+ << "Failed to track total_requests_ correctly when IOPriority = "
+ << static_cast<Env::IOPriority>(i);
+ }
+ EXPECT_EQ(limiter->GetTotalRequests(Env::IO_TOTAL), total_requests_sum)
+ << "Failed to track total_requests_ correctly when IOPriority = "
+ "Env::IO_TOTAL";
+}
+
+TEST_F(RateLimiterTest, GetTotalPendingRequests) {
+ std::unique_ptr<RateLimiter> limiter(NewGenericRateLimiter(
+ 200 /* rate_bytes_per_sec */, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */));
+ int64_t total_pending_requests = 0;
+ for (int i = Env::IO_LOW; i <= Env::IO_TOTAL; ++i) {
+ ASSERT_OK(limiter->GetTotalPendingRequests(
+ &total_pending_requests, static_cast<Env::IOPriority>(i)));
+ ASSERT_EQ(total_pending_requests, 0);
+ }
+ // This is a variable for making sure the following callback is called
+ // and the assertions in it are indeed excuted
+ bool nonzero_pending_requests_verified = false;
+ SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::Request:PostEnqueueRequest", [&](void* arg) {
+ port::Mutex* request_mutex = (port::Mutex*)arg;
+ // We temporarily unlock the mutex so that the following
+ // GetTotalPendingRequests() can acquire it
+ request_mutex->Unlock();
+ for (int i = Env::IO_LOW; i <= Env::IO_TOTAL; ++i) {
+ EXPECT_OK(limiter->GetTotalPendingRequests(
+ &total_pending_requests, static_cast<Env::IOPriority>(i)))
+ << "Failed to return total pending requests for priority level = "
+ << static_cast<Env::IOPriority>(i);
+ if (i == Env::IO_USER || i == Env::IO_TOTAL) {
+ EXPECT_EQ(total_pending_requests, 1)
+ << "Failed to correctly return total pending requests for "
+ "priority level = "
+ << static_cast<Env::IOPriority>(i);
+ } else {
+ EXPECT_EQ(total_pending_requests, 0)
+ << "Failed to correctly return total pending requests for "
+ "priority level = "
+ << static_cast<Env::IOPriority>(i);
+ }
+ }
+ // We lock the mutex again so that the request thread can resume running
+ // with the mutex locked
+ request_mutex->Lock();
+ nonzero_pending_requests_verified = true;
+ });
+
+ SyncPoint::GetInstance()->EnableProcessing();
+ limiter->Request(200, Env::IO_USER, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ ASSERT_EQ(nonzero_pending_requests_verified, true);
+ for (int i = Env::IO_LOW; i <= Env::IO_TOTAL; ++i) {
+ EXPECT_OK(limiter->GetTotalPendingRequests(&total_pending_requests,
+ static_cast<Env::IOPriority>(i)))
+ << "Failed to return total pending requests for priority level = "
+ << static_cast<Env::IOPriority>(i);
+ EXPECT_EQ(total_pending_requests, 0)
+ << "Failed to correctly return total pending requests for priority "
+ "level = "
+ << static_cast<Env::IOPriority>(i);
+ }
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::Request:PostEnqueueRequest");
+}
+
+TEST_F(RateLimiterTest, Modes) {
+ for (auto mode : {RateLimiter::Mode::kWritesOnly,
+ RateLimiter::Mode::kReadsOnly, RateLimiter::Mode::kAllIo}) {
+ GenericRateLimiter limiter(2000 /* rate_bytes_per_sec */,
+ 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */, mode, SystemClock::Default(),
+ false /* auto_tuned */);
+ limiter.Request(1000 /* bytes */, Env::IO_HIGH, nullptr /* stats */,
+ RateLimiter::OpType::kRead);
+ if (mode == RateLimiter::Mode::kWritesOnly) {
+ ASSERT_EQ(0, limiter.GetTotalBytesThrough(Env::IO_HIGH));
+ } else {
+ ASSERT_EQ(1000, limiter.GetTotalBytesThrough(Env::IO_HIGH));
+ }
+
+ limiter.Request(1000 /* bytes */, Env::IO_HIGH, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ if (mode == RateLimiter::Mode::kAllIo) {
+ ASSERT_EQ(2000, limiter.GetTotalBytesThrough(Env::IO_HIGH));
+ } else {
+ ASSERT_EQ(1000, limiter.GetTotalBytesThrough(Env::IO_HIGH));
+ }
+ }
+}
+
+TEST_F(RateLimiterTest, GeneratePriorityIterationOrder) {
+ std::unique_ptr<RateLimiter> limiter(NewGenericRateLimiter(
+ 200 /* rate_bytes_per_sec */, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */));
+
+ bool possible_random_one_in_fairness_results_for_high_mid_pri[4][2] = {
+ {false, false}, {false, true}, {true, false}, {true, true}};
+ std::vector<Env::IOPriority> possible_priority_iteration_orders[4] = {
+ {Env::IO_USER, Env::IO_HIGH, Env::IO_MID, Env::IO_LOW},
+ {Env::IO_USER, Env::IO_HIGH, Env::IO_LOW, Env::IO_MID},
+ {Env::IO_USER, Env::IO_MID, Env::IO_LOW, Env::IO_HIGH},
+ {Env::IO_USER, Env::IO_LOW, Env::IO_MID, Env::IO_HIGH}};
+
+ for (int i = 0; i < 4; ++i) {
+ // These are variables for making sure the following callbacks are called
+ // and the assertion in the last callback is indeed excuted
+ bool high_pri_iterated_after_mid_low_pri_set = false;
+ bool mid_pri_itereated_after_low_pri_set = false;
+ bool pri_iteration_order_verified = false;
+ SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForHighPri",
+ [&](void* arg) {
+ bool* high_pri_iterated_after_mid_low_pri = (bool*)arg;
+ *high_pri_iterated_after_mid_low_pri =
+ possible_random_one_in_fairness_results_for_high_mid_pri[i][0];
+ high_pri_iterated_after_mid_low_pri_set = true;
+ });
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForMidPri",
+ [&](void* arg) {
+ bool* mid_pri_itereated_after_low_pri = (bool*)arg;
+ *mid_pri_itereated_after_low_pri =
+ possible_random_one_in_fairness_results_for_high_mid_pri[i][1];
+ mid_pri_itereated_after_low_pri_set = true;
+ });
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PreReturnPriIterationOrder",
+ [&](void* arg) {
+ std::vector<Env::IOPriority>* pri_iteration_order =
+ (std::vector<Env::IOPriority>*)arg;
+ EXPECT_EQ(*pri_iteration_order, possible_priority_iteration_orders[i])
+ << "Failed to generate priority iteration order correctly when "
+ "high_pri_iterated_after_mid_low_pri = "
+ << possible_random_one_in_fairness_results_for_high_mid_pri[i][0]
+ << ", mid_pri_itereated_after_low_pri = "
+ << possible_random_one_in_fairness_results_for_high_mid_pri[i][1]
+ << std::endl;
+ pri_iteration_order_verified = true;
+ });
+
+ SyncPoint::GetInstance()->EnableProcessing();
+ limiter->Request(200 /* request max bytes to drain so that refill and order
+ generation will be triggered every time
+ GenericRateLimiter::Request() is called */
+ ,
+ Env::IO_USER, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ ASSERT_EQ(high_pri_iterated_after_mid_low_pri_set, true);
+ ASSERT_EQ(mid_pri_itereated_after_low_pri_set, true);
+ ASSERT_EQ(pri_iteration_order_verified, true);
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PreReturnPriIterationOrder");
+ SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForMidPri");
+ SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::GeneratePriorityIterationOrderLocked::"
+ "PostRandomOneInFairnessForHighPri");
+ }
+}
+
+TEST_F(RateLimiterTest, Rate) {
+ auto* env = Env::Default();
+ struct Arg {
+ Arg(int32_t _target_rate, int _burst)
+ : limiter(NewGenericRateLimiter(_target_rate /* rate_bytes_per_sec */,
+ 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */)),
+ request_size(_target_rate /
+ 10 /* refill period here is 1/10 second */),
+ burst(_burst) {}
+ std::unique_ptr<RateLimiter> limiter;
+ int32_t request_size;
+ int burst;
+ };
+
+ auto writer = [](void* p) {
+ const auto& thread_clock = SystemClock::Default();
+ auto* arg = static_cast<Arg*>(p);
+ // Test for 2 seconds
+ auto until = thread_clock->NowMicros() + 2 * 1000000;
+ Random r((uint32_t)(thread_clock->NowNanos() %
+ std::numeric_limits<uint32_t>::max()));
+ while (thread_clock->NowMicros() < until) {
+ for (int i = 0; i < static_cast<int>(r.Skewed(arg->burst * 2) + 1); ++i) {
+ arg->limiter->Request(r.Uniform(arg->request_size - 1) + 1,
+ Env::IO_USER, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ }
+
+ for (int i = 0; i < static_cast<int>(r.Skewed(arg->burst) + 1); ++i) {
+ arg->limiter->Request(r.Uniform(arg->request_size - 1) + 1,
+ Env::IO_HIGH, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ }
+
+ for (int i = 0; i < static_cast<int>(r.Skewed(arg->burst / 2 + 1) + 1);
+ ++i) {
+ arg->limiter->Request(r.Uniform(arg->request_size - 1) + 1, Env::IO_MID,
+ nullptr /* stats */, RateLimiter::OpType::kWrite);
+ }
+
+ arg->limiter->Request(r.Uniform(arg->request_size - 1) + 1, Env::IO_LOW,
+ nullptr /* stats */, RateLimiter::OpType::kWrite);
+ }
+ };
+
+ int samples = 0;
+ int samples_at_minimum = 0;
+
+ for (int i = 1; i <= 16; i *= 2) {
+ int32_t target = i * 1024 * 10;
+ Arg arg(target, i / 4 + 1);
+ int64_t old_total_bytes_through = 0;
+ for (int iter = 1; iter <= 2; ++iter) {
+ // second iteration changes the target dynamically
+ if (iter == 2) {
+ target *= 2;
+ arg.limiter->SetBytesPerSecond(target);
+ }
+ auto start = env->NowMicros();
+ for (int t = 0; t < i; ++t) {
+ env->StartThread(writer, &arg);
+ }
+ env->WaitForJoin();
+
+ auto elapsed = env->NowMicros() - start;
+ double rate =
+ (arg.limiter->GetTotalBytesThrough() - old_total_bytes_through) *
+ 1000000.0 / elapsed;
+ old_total_bytes_through = arg.limiter->GetTotalBytesThrough();
+ fprintf(stderr,
+ "request size [1 - %" PRIi32 "], limit %" PRIi32
+ " KB/sec, actual rate: %lf KB/sec, elapsed %.2lf seconds\n",
+ arg.request_size - 1, target / 1024, rate / 1024,
+ elapsed / 1000000.0);
+
+ ++samples;
+ if (rate / target >= 0.80) {
+ ++samples_at_minimum;
+ }
+ ASSERT_LE(rate / target, 1.25);
+ }
+ }
+
+ // This can fail due to slow execution speed, like when using valgrind or in
+ // heavily loaded CI environments
+ bool skip_minimum_rate_check =
+#if (defined(CIRCLECI) && defined(OS_MACOSX)) || defined(ROCKSDB_VALGRIND_RUN)
+ true;
+#else
+ getenv("SANDCASTLE");
+#endif
+ if (skip_minimum_rate_check) {
+ fprintf(stderr, "Skipped minimum rate check (%d / %d passed)\n",
+ samples_at_minimum, samples);
+ } else {
+ ASSERT_EQ(samples_at_minimum, samples);
+ }
+}
+
+TEST_F(RateLimiterTest, LimitChangeTest) {
+ // starvation test when limit changes to a smaller value
+ int64_t refill_period = 1000 * 1000;
+ auto* env = Env::Default();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ struct Arg {
+ Arg(int32_t _request_size, Env::IOPriority _pri,
+ std::shared_ptr<RateLimiter> _limiter)
+ : request_size(_request_size), pri(_pri), limiter(_limiter) {}
+ int32_t request_size;
+ Env::IOPriority pri;
+ std::shared_ptr<RateLimiter> limiter;
+ };
+
+ auto writer = [](void* p) {
+ auto* arg = static_cast<Arg*>(p);
+ arg->limiter->Request(arg->request_size, arg->pri, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ };
+
+ for (uint32_t i = 1; i <= 16; i <<= 1) {
+ int32_t target = i * 1024 * 10;
+ // refill per second
+ for (int iter = 0; iter < 2; iter++) {
+ std::shared_ptr<RateLimiter> limiter =
+ std::make_shared<GenericRateLimiter>(
+ target, refill_period, 10, RateLimiter::Mode::kWritesOnly,
+ SystemClock::Default(), false /* auto_tuned */);
+ // After "GenericRateLimiter::Request:1" the mutex is held until the bytes
+ // are refilled. This test could be improved to change the limit when lock
+ // is released in `TimedWait()`.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"GenericRateLimiter::Request",
+ "RateLimiterTest::LimitChangeTest:changeLimitStart"},
+ {"RateLimiterTest::LimitChangeTest:changeLimitEnd",
+ "GenericRateLimiter::Request:1"}});
+ Arg arg(target, Env::IO_HIGH, limiter);
+ // The idea behind is to start a request first, then before it refills,
+ // update limit to a different value (2X/0.5X). No starvation should
+ // be guaranteed under any situation
+ // TODO(lightmark): more test cases are welcome.
+ env->StartThread(writer, &arg);
+ int32_t new_limit = (target << 1) >> (iter << 1);
+ TEST_SYNC_POINT("RateLimiterTest::LimitChangeTest:changeLimitStart");
+ arg.limiter->SetBytesPerSecond(new_limit);
+ TEST_SYNC_POINT("RateLimiterTest::LimitChangeTest:changeLimitEnd");
+ env->WaitForJoin();
+ fprintf(stderr,
+ "[COMPLETE] request size %" PRIi32 " KB, new limit %" PRIi32
+ "KB/sec, refill period %" PRIi64 " ms\n",
+ target / 1024, new_limit / 1024, refill_period / 1000);
+ }
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_F(RateLimiterTest, AutoTuneIncreaseWhenFull) {
+ const std::chrono::seconds kTimePerRefill(1);
+ const int kRefillsPerTune = 100; // needs to match util/rate_limiter.cc
+
+ SpecialEnv special_env(Env::Default(), /*time_elapse_only_sleep*/ true);
+
+ auto stats = CreateDBStatistics();
+ std::unique_ptr<RateLimiter> rate_limiter(new GenericRateLimiter(
+ 1000 /* rate_bytes_per_sec */,
+ std::chrono::microseconds(kTimePerRefill).count(), 10 /* fairness */,
+ RateLimiter::Mode::kWritesOnly, special_env.GetSystemClock(),
+ true /* auto_tuned */));
+
+ // Rate limiter uses `CondVar::TimedWait()`, which does not have access to the
+ // `Env` to advance its time according to the fake wait duration. The
+ // workaround is to install a callback that advance the `Env`'s mock time.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::Request:PostTimedWait", [&](void* arg) {
+ int64_t time_waited_us = *static_cast<int64_t*>(arg);
+ special_env.SleepForMicroseconds(static_cast<int>(time_waited_us));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // verify rate limit increases after a sequence of periods where rate limiter
+ // is always drained
+ int64_t orig_bytes_per_sec = rate_limiter->GetSingleBurstBytes();
+ rate_limiter->Request(orig_bytes_per_sec, Env::IO_HIGH, stats.get(),
+ RateLimiter::OpType::kWrite);
+ while (std::chrono::microseconds(special_env.NowMicros()) <=
+ kRefillsPerTune * kTimePerRefill) {
+ rate_limiter->Request(orig_bytes_per_sec, Env::IO_HIGH, stats.get(),
+ RateLimiter::OpType::kWrite);
+ }
+ int64_t new_bytes_per_sec = rate_limiter->GetSingleBurstBytes();
+ ASSERT_GT(new_bytes_per_sec, orig_bytes_per_sec);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::Request:PostTimedWait");
+
+ // decreases after a sequence of periods where rate limiter is not drained
+ orig_bytes_per_sec = new_bytes_per_sec;
+ special_env.SleepForMicroseconds(static_cast<int>(
+ kRefillsPerTune * std::chrono::microseconds(kTimePerRefill).count()));
+ // make a request so tuner can be triggered
+ rate_limiter->Request(1 /* bytes */, Env::IO_HIGH, stats.get(),
+ RateLimiter::OpType::kWrite);
+ new_bytes_per_sec = rate_limiter->GetSingleBurstBytes();
+ ASSERT_LT(new_bytes_per_sec, orig_bytes_per_sec);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/repeatable_thread.h b/src/rocksdb/util/repeatable_thread.h
new file mode 100644
index 000000000..c75ad7c49
--- /dev/null
+++ b/src/rocksdb/util/repeatable_thread.h
@@ -0,0 +1,149 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <functional>
+#include <string>
+
+#include "monitoring/instrumented_mutex.h"
+#include "port/port.h"
+#include "rocksdb/system_clock.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Simple wrapper around port::Thread that supports calling a callback every
+// X seconds. If you pass in 0, then it will call your callback repeatedly
+// without delay.
+class RepeatableThread {
+ public:
+ RepeatableThread(std::function<void()> function,
+ const std::string& thread_name, SystemClock* clock,
+ uint64_t delay_us, uint64_t initial_delay_us = 0)
+ : function_(function),
+ thread_name_("rocksdb:" + thread_name),
+ clock_(clock),
+ delay_us_(delay_us),
+ initial_delay_us_(initial_delay_us),
+ mutex_(clock),
+ cond_var_(&mutex_),
+ running_(true),
+#ifndef NDEBUG
+ waiting_(false),
+ run_count_(0),
+#endif
+ thread_([this] { thread(); }) {
+ }
+
+ void cancel() {
+ {
+ InstrumentedMutexLock l(&mutex_);
+ if (!running_) {
+ return;
+ }
+ running_ = false;
+ cond_var_.SignalAll();
+ }
+ thread_.join();
+ }
+
+ bool IsRunning() { return running_; }
+
+ ~RepeatableThread() { cancel(); }
+
+#ifndef NDEBUG
+ // Wait until RepeatableThread starting waiting, call the optional callback,
+ // then wait for one run of RepeatableThread. Tests can use provide a
+ // custom clock object to mock time, and use the callback here to bump current
+ // time and trigger RepeatableThread. See repeatable_thread_test for example.
+ //
+ // Note: only support one caller of this method.
+ void TEST_WaitForRun(std::function<void()> callback = nullptr) {
+ InstrumentedMutexLock l(&mutex_);
+ while (!waiting_) {
+ cond_var_.Wait();
+ }
+ uint64_t prev_count = run_count_;
+ if (callback != nullptr) {
+ callback();
+ }
+ cond_var_.SignalAll();
+ while (!(run_count_ > prev_count)) {
+ cond_var_.Wait();
+ }
+ }
+#endif
+
+ private:
+ bool wait(uint64_t delay) {
+ InstrumentedMutexLock l(&mutex_);
+ if (running_ && delay > 0) {
+ uint64_t wait_until = clock_->NowMicros() + delay;
+#ifndef NDEBUG
+ waiting_ = true;
+ cond_var_.SignalAll();
+#endif
+ while (running_) {
+ cond_var_.TimedWait(wait_until);
+ if (clock_->NowMicros() >= wait_until) {
+ break;
+ }
+ }
+#ifndef NDEBUG
+ waiting_ = false;
+#endif
+ }
+ return running_;
+ }
+
+ void thread() {
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
+#if __GLIBC_PREREQ(2, 12)
+ // Set thread name.
+ auto thread_handle = thread_.native_handle();
+ int ret __attribute__((__unused__)) =
+ pthread_setname_np(thread_handle, thread_name_.c_str());
+ assert(ret == 0);
+#endif
+#endif
+
+ assert(delay_us_ > 0);
+ if (!wait(initial_delay_us_)) {
+ return;
+ }
+ do {
+ function_();
+#ifndef NDEBUG
+ {
+ InstrumentedMutexLock l(&mutex_);
+ run_count_++;
+ cond_var_.SignalAll();
+ }
+#endif
+ } while (wait(delay_us_));
+ }
+
+ const std::function<void()> function_;
+ const std::string thread_name_;
+ SystemClock* clock_;
+ const uint64_t delay_us_;
+ const uint64_t initial_delay_us_;
+
+ // Mutex lock should be held when accessing running_, waiting_
+ // and run_count_.
+ InstrumentedMutex mutex_;
+ InstrumentedCondVar cond_var_;
+ bool running_;
+#ifndef NDEBUG
+ // RepeatableThread waiting for timeout.
+ bool waiting_;
+ // Times function_ had run.
+ uint64_t run_count_;
+#endif
+ port::Thread thread_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/repeatable_thread_test.cc b/src/rocksdb/util/repeatable_thread_test.cc
new file mode 100644
index 000000000..0b3e95464
--- /dev/null
+++ b/src/rocksdb/util/repeatable_thread_test.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/repeatable_thread.h"
+
+#include <atomic>
+#include <memory>
+
+#include "db/db_test_util.h"
+#include "test_util/mock_time_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+
+class RepeatableThreadTest : public testing::Test {
+ public:
+ RepeatableThreadTest()
+ : mock_clock_(std::make_shared<ROCKSDB_NAMESPACE::MockSystemClock>(
+ ROCKSDB_NAMESPACE::SystemClock::Default())) {}
+
+ protected:
+ std::shared_ptr<ROCKSDB_NAMESPACE::MockSystemClock> mock_clock_;
+};
+
+TEST_F(RepeatableThreadTest, TimedTest) {
+ constexpr uint64_t kSecond = 1000000; // 1s = 1000000us
+ constexpr int kIteration = 3;
+ const auto& clock = ROCKSDB_NAMESPACE::SystemClock::Default();
+ ROCKSDB_NAMESPACE::port::Mutex mutex;
+ ROCKSDB_NAMESPACE::port::CondVar test_cv(&mutex);
+ int count = 0;
+ uint64_t prev_time = clock->NowMicros();
+ ROCKSDB_NAMESPACE::RepeatableThread thread(
+ [&] {
+ ROCKSDB_NAMESPACE::MutexLock l(&mutex);
+ count++;
+ uint64_t now = clock->NowMicros();
+ assert(count == 1 || prev_time + 1 * kSecond <= now);
+ prev_time = now;
+ if (count >= kIteration) {
+ test_cv.SignalAll();
+ }
+ },
+ "rt_test", clock.get(), 1 * kSecond);
+ // Wait for execution finish.
+ {
+ ROCKSDB_NAMESPACE::MutexLock l(&mutex);
+ while (count < kIteration) {
+ test_cv.Wait();
+ }
+ }
+
+ // Test cancel
+ thread.cancel();
+}
+
+TEST_F(RepeatableThreadTest, MockEnvTest) {
+ constexpr uint64_t kSecond = 1000000; // 1s = 1000000us
+ constexpr int kIteration = 3;
+ mock_clock_->SetCurrentTime(0); // in seconds
+ std::atomic<int> count{0};
+
+#if defined(OS_MACOSX) && !defined(NDEBUG)
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "InstrumentedCondVar::TimedWaitInternal", [&](void* arg) {
+ // Obtain the current (real) time in seconds and add 1000 extra seconds
+ // to ensure that RepeatableThread::wait invokes TimedWait with a time
+ // greater than (real) current time. This is to prevent the TimedWait
+ // function from returning immediately without sleeping and releasing
+ // the mutex on certain platforms, e.g. OS X. If TimedWait returns
+ // immediately, the mutex will not be released, and
+ // RepeatableThread::TEST_WaitForRun never has a chance to execute the
+ // callback which, in this case, updates the result returned by
+ // mock_clock->NowMicros. Consequently, RepeatableThread::wait cannot
+ // break out of the loop, causing test to hang. The extra 1000 seconds
+ // is a best-effort approach because there seems no reliable and
+ // deterministic way to provide the aforementioned guarantee. By the
+ // time RepeatableThread::wait is called, it is no guarantee that the
+ // delay + mock_clock->NowMicros will be greater than the current real
+ // time. However, 1000 seconds should be sufficient in most cases.
+ uint64_t time_us = *reinterpret_cast<uint64_t*>(arg);
+ if (time_us < mock_clock_->RealNowMicros()) {
+ *reinterpret_cast<uint64_t*>(arg) =
+ mock_clock_->RealNowMicros() + 1000;
+ }
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+#endif // OS_MACOSX && !NDEBUG
+
+ ROCKSDB_NAMESPACE::RepeatableThread thread(
+ [&] { count++; }, "rt_test", mock_clock_.get(), 1 * kSecond, 1 * kSecond);
+ for (int i = 1; i <= kIteration; i++) {
+ // Bump current time
+ thread.TEST_WaitForRun([&] { mock_clock_->SetCurrentTime(i); });
+ }
+ // Test function should be exectued exactly kIteraion times.
+ ASSERT_EQ(kIteration, count.load());
+
+ // Test cancel
+ thread.cancel();
+}
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/ribbon_alg.h b/src/rocksdb/util/ribbon_alg.h
new file mode 100644
index 000000000..f9afefc23
--- /dev/null
+++ b/src/rocksdb/util/ribbon_alg.h
@@ -0,0 +1,1225 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <array>
+#include <memory>
+
+#include "rocksdb/rocksdb_namespace.h"
+#include "util/math128.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace ribbon {
+
+// RIBBON PHSF & RIBBON Filter (Rapid Incremental Boolean Banding ON-the-fly)
+//
+// ribbon_alg.h: generic versions of core algorithms.
+//
+// Ribbon is a Perfect Hash Static Function construction useful as a compact
+// static Bloom filter alternative. It combines (a) a boolean (GF(2)) linear
+// system construction that approximates a Band Matrix with hashing,
+// (b) an incremental, on-the-fly Gaussian Elimination algorithm that is
+// remarkably efficient and adaptable at constructing an upper-triangular
+// band matrix from a set of band-approximating inputs from (a), and
+// (c) a storage layout that is fast and adaptable as a filter.
+//
+// Footnotes: (a) "Efficient Gauss Elimination for Near-Quadratic Matrices
+// with One Short Random Block per Row, with Applications" by Stefan
+// Walzer and Martin Dietzfelbinger ("DW paper")
+// (b) developed by Peter C. Dillinger, though not the first on-the-fly
+// GE algorithm. See "On the fly Gaussian Elimination for LT codes" by
+// Bioglio, Grangetto, Gaeta, and Sereno.
+// (c) see "interleaved" solution storage below.
+//
+// See ribbon_impl.h for high-level behavioral summary. This file focuses
+// on the core design details.
+//
+// ######################################################################
+// ################# PHSF -> static filter reduction ####################
+//
+// A Perfect Hash Static Function is a data structure representing a
+// map from anything hashable (a "key") to values of some fixed size.
+// Crucially, it is allowed to return garbage values for anything not in
+// the original set of map keys, and it is a "static" structure: entries
+// cannot be added or deleted after construction. PHSFs representing n
+// mappings to b-bit values (assume uniformly distributed) require at least
+// n * b bits to represent, or at least b bits per entry. We typically
+// describe the compactness of a PHSF by typical bits per entry as some
+// function of b. For example, the MWHC construction (k=3 "peeling")
+// requires about 1.0222*b and a variant called Xor+ requires about
+// 1.08*b + 0.5 bits per entry.
+//
+// With more hashing, a PHSF can over-approximate a set as a Bloom filter
+// does, with no FN queries and predictable false positive (FP) query
+// rate. Instead of the user providing a value to map each input key to,
+// a hash function provides the value. Keys in the original set will
+// return a positive membership query because the underlying PHSF returns
+// the same value as hashing the key. When a key is not in the original set,
+// the PHSF returns a "garbage" value, which is only equal to the key's
+// hash with (false positive) probability 1 in 2^b.
+//
+// For a matching false positive rate, standard Bloom filters require
+// 1.44*b bits per entry. Cache-local Bloom filters (like bloom_impl.h)
+// require a bit more, around 1.5*b bits per entry. Thus, a Bloom
+// alternative could save up to or nearly 1/3rd of memory and storage
+// that RocksDB uses for SST (static) Bloom filters. (Memtable Bloom filter
+// is dynamic.)
+//
+// Recommended reading:
+// "Xor Filters: Faster and Smaller Than Bloom and Cuckoo Filters"
+// by Graf and Lemire
+// First three sections of "Fast Scalable Construction of (Minimal
+// Perfect Hash) Functions" by Genuzio, Ottaviano, and Vigna
+//
+// ######################################################################
+// ################## PHSF vs. hash table vs. Bloom #####################
+//
+// You can think of traditional hash tables and related filter variants
+// such as Cuckoo filters as utilizing an "OR" construction: a hash
+// function associates a key with some slots and the data is returned if
+// the data is found in any one of those slots. The collision resolution
+// is visible in the final data structure and requires extra information.
+// For example, Cuckoo filter uses roughly 1.05b + 2 bits per entry, and
+// Golomb-Rice code (aka "GCS") as little as b + 1.5. When the data
+// structure associates each input key with data in one slot, the
+// structure implicitly constructs a (near-)minimal (near-)perfect hash
+// (MPH) of the keys, which requires at least 1.44 bits per key to
+// represent. This is why approaches with visible collision resolution
+// have a fixed + 1.5 or more in storage overhead per entry, often in
+// addition to an overhead multiplier on b.
+//
+// By contrast Bloom filters utilize an "AND" construction: a query only
+// returns true if all bit positions associated with a key are set to 1.
+// There is no collision resolution, so Bloom filters do not suffer a
+// fixed bits per entry overhead like the above structures.
+//
+// PHSFs typically use a bitwise XOR construction: the data you want is
+// not in a single slot, but in a linear combination of several slots.
+// For static data, this gives the best of "AND" and "OR" constructions:
+// avoids the +1.44 or more fixed overhead by not approximating a MPH and
+// can do much better than Bloom's 1.44 factor on b with collision
+// resolution, which here is done ahead of time and invisible at query
+// time.
+//
+// ######################################################################
+// ######################## PHSF construction ###########################
+//
+// For a typical PHSF, construction is solving a linear system of
+// equations, typically in GF(2), which is to say that values are boolean
+// and XOR serves both as addition and subtraction. We can use matrices to
+// represent the problem:
+//
+// C * S = R
+// (n x m) (m x b) (n x b)
+// where C = coefficients, S = solution, R = results
+// and solving for S given C and R.
+//
+// Note that C and R each have n rows, one for each input entry for the
+// PHSF. A row in C is given by a hash function on the PHSF input key,
+// and the corresponding row in R is the b-bit value to associate with
+// that input key. (In a filter, rows of R are given by another hash
+// function on the input key.)
+//
+// On solving, the matrix S (solution) is the final PHSF data, as it
+// maps any row from the original C to its corresponding desired result
+// in R. We just have to hash our query inputs and compute a linear
+// combination of rows in S.
+//
+// In theory, we could chose m = n and let a hash function associate
+// each input key with random rows in C. A solution exists with high
+// probability, and uses essentially minimum space, b bits per entry
+// (because we set m = n) but this has terrible scaling, something
+// like O(n^2) space and O(n^3) time during construction (Gaussian
+// elimination) and O(n) query time. But computational efficiency is
+// key, and the core of this is avoiding scanning all of S to answer
+// each query.
+//
+// The traditional approach (MWHC, aka Xor filter) starts with setting
+// only some small fixed number of columns (typically k=3) to 1 for each
+// row of C, with remaining entries implicitly 0. This is implemented as
+// three hash functions over [0,m), and S can be implemented as a vector
+// of b-bit values. Now, a query only involves looking up k rows
+// (values) in S and computing their bitwise XOR. Additionally, this
+// construction can use a linear time algorithm called "peeling" for
+// finding a solution in many cases of one existing, but peeling
+// generally requires a larger space overhead factor in the solution
+// (m/n) than is required with Gaussian elimination.
+//
+// Recommended reading:
+// "Peeling Close to the Orientability Threshold - Spatial Coupling in
+// Hashing-Based Data Structures" by Stefan Walzer
+//
+// ######################################################################
+// ##################### Ribbon PHSF construction #######################
+//
+// Ribbon constructs coefficient rows essentially the same as in the
+// Walzer/Dietzfelbinger paper cited above: for some chosen fixed width
+// r (kCoeffBits in code), each key is hashed to a starting column in
+// [0, m - r] (GetStart() in code) and an r-bit sequence of boolean
+// coefficients (GetCoeffRow() in code). If you sort the rows by start,
+// the C matrix would look something like this:
+//
+// [####00000000000000000000]
+// [####00000000000000000000]
+// [000####00000000000000000]
+// [0000####0000000000000000]
+// [0000000####0000000000000]
+// [000000000####00000000000]
+// [000000000####00000000000]
+// [0000000000000####0000000]
+// [0000000000000000####0000]
+// [00000000000000000####000]
+// [00000000000000000000####]
+//
+// where each # could be a 0 or 1, chosen uniformly by a hash function.
+// (Except we typically set the start column value to 1.) This scheme
+// uses hashing to approximate a band matrix, and it has a solution iff
+// it reduces to an upper-triangular boolean r-band matrix, like this:
+//
+// [1###00000000000000000000]
+// [01##00000000000000000000]
+// [000000000000000000000000]
+// [0001###00000000000000000]
+// [000000000000000000000000]
+// [000001##0000000000000000]
+// [000000000000000000000000]
+// [00000001###0000000000000]
+// [000000001###000000000000]
+// [0000000001##000000000000]
+// ...
+// [00000000000000000000001#]
+// [000000000000000000000001]
+//
+// where we have expanded to an m x m matrix by filling with rows of
+// all zeros as needed. As in Gaussian elimination, this form is ready for
+// generating a solution through back-substitution.
+//
+// The awesome thing about the Ribbon construction (from the DW paper) is
+// how row reductions keep each row representable as a start column and
+// r coefficients, because row reductions are only needed when two rows
+// have the same number of leading zero columns. Thus, the combination
+// of those rows, the bitwise XOR of the r-bit coefficient rows, cancels
+// out the leading 1s, so starts (at least) one column later and only
+// needs (at most) r - 1 coefficients.
+//
+// ######################################################################
+// ###################### Ribbon PHSF scalability #######################
+//
+// Although more practical detail is in ribbon_impl.h, it's worth
+// understanding some of the overall benefits and limitations of the
+// Ribbon PHSFs.
+//
+// High-end scalability is a primary issue for Ribbon PHSFs, because in
+// a single Ribbon linear system with fixed r and fixed m/n ratio, the
+// solution probability approaches zero as n approaches infinity.
+// For a given n, solution probability improves with larger r and larger
+// m/n.
+//
+// By contrast, peeling-based PHSFs have somewhat worse storage ratio
+// or solution probability for small n (less than ~1000). This is
+// especially true with spatial-coupling, where benefits are only
+// notable for n on the order of 100k or 1m or more.
+//
+// To make best use of current hardware, r=128 seems to be closest to
+// a "generally good" choice for Ribbon, at least in RocksDB where SST
+// Bloom filters typically hold around 10-100k keys, and almost always
+// less than 10m keys. r=128 ribbon has a high chance of encoding success
+// (with first hash seed) when storage overhead is around 5% (m/n ~ 1.05)
+// for roughly 10k - 10m keys in a single linear system. r=64 only scales
+// up to about 10k keys with the same storage overhead. Construction and
+// access times for r=128 are similar to r=64. r=128 tracks nearly
+// twice as much data during construction, but in most cases we expect
+// the scalability benefits of r=128 vs. r=64 to make it preferred.
+//
+// A natural approach to scaling Ribbon beyond ~10m keys is splitting
+// (or "sharding") the inputs into multiple linear systems with their
+// own hash seeds. This can also help to control peak memory consumption.
+// TODO: much more to come
+//
+// ######################################################################
+// #################### Ribbon on-the-fly banding #######################
+//
+// "Banding" is what we call the process of reducing the inputs to an
+// upper-triangular r-band matrix ready for finishing a solution with
+// back-substitution. Although the DW paper presents an algorithm for
+// this ("SGauss"), the awesome properties of their construction enable
+// an even simpler, faster, and more backtrackable algorithm. In simplest
+// terms, the SGauss algorithm requires sorting the inputs by start
+// columns, but it's possible to make Gaussian elimination resemble hash
+// table insertion!
+//
+// The enhanced algorithm is based on these observations:
+// - When processing a coefficient row with first 1 in column j,
+// - If it's the first at column j to be processed, it can be part of
+// the banding at row j. (And that decision never overwritten, with
+// no loss of generality!)
+// - Else, it can be combined with existing row j and re-processed,
+// which will look for a later "empty" row or reach "no solution".
+//
+// We call our banding algorithm "incremental" and "on-the-fly" because
+// (like hash table insertion) we are "finished" after each input
+// processed, with respect to all inputs processed so far. Although the
+// band matrix is an intermediate step to the solution structure, we have
+// eliminated intermediate steps and unnecessary data tracking for
+// banding.
+//
+// Building on "incremental" and "on-the-fly", the banding algorithm is
+// easily backtrackable because no (non-empty) rows are overwritten in
+// the banding. Thus, if we want to "try" adding an additional set of
+// inputs to the banding, we only have to record which rows were written
+// in order to efficiently backtrack to our state before considering
+// the additional set. (TODO: how this can mitigate scalability and
+// reach sub-1% overheads)
+//
+// Like in a linear-probed hash table, as the occupancy approaches and
+// surpasses 90-95%, collision resolution dominates the construction
+// time. (Ribbon doesn't usually pay at query time; see solution
+// storage below.) This means that we can speed up construction time
+// by using a higher m/n ratio, up to negative returns around 1.2.
+// At m/n ~= 1.2, which still saves memory substantially vs. Bloom
+// filter's 1.5, construction speed (including back-substitution) is not
+// far from sorting speed, but still a few times slower than cache-local
+// Bloom construction speed.
+//
+// Back-substitution from an upper-triangular boolean band matrix is
+// especially fast and easy. All the memory accesses are sequential or at
+// least local, no random. If the number of result bits (b) is a
+// compile-time constant, the back-substitution state can even be tracked
+// in CPU registers. Regardless of the solution representation, we prefer
+// column-major representation for tracking back-substitution state, as
+// r (the band width) will typically be much larger than b (result bits
+// or columns), so better to handle r-bit values b times (per solution
+// row) than b-bit values r times.
+//
+// ######################################################################
+// ##################### Ribbon solution storage ########################
+//
+// Row-major layout is typical for boolean (bit) matrices, including for
+// MWHC (Xor) filters where a query combines k b-bit values, and k is
+// typically smaller than b. Even for k=4 and b=2, at least k=4 random
+// look-ups are required regardless of layout.
+//
+// Ribbon PHSFs are quite different, however, because
+// (a) all of the solution rows relevant to a query are within a single
+// range of r rows, and
+// (b) the number of solution rows involved (r/2 on average, or r if
+// avoiding conditional accesses) is typically much greater than
+// b, the number of solution columns.
+//
+// Row-major for Ribbon PHSFs therefore tends to incur undue CPU overhead
+// by processing (up to) r entries of b bits each, where b is typically
+// less than 10 for filter applications.
+//
+// Column-major layout has poor locality because of accessing up to b
+// memory locations in different pages (and obviously cache lines). Note
+// that negative filter queries do not typically need to access all
+// solution columns, as they can return when a mismatch is found in any
+// result/solution column. This optimization doesn't always pay off on
+// recent hardware, where the penalty for unpredictable conditional
+// branching can exceed the penalty for unnecessary work, but the
+// optimization is essentially unavailable with row-major layout.
+//
+// The best compromise seems to be interleaving column-major on the small
+// scale with row-major on the large scale. For example, let a solution
+// "block" be r rows column-major encoded as b r-bit values in sequence.
+// Each query accesses (up to) 2 adjacent blocks, which will typically
+// span 1-3 cache lines in adjacent memory. We get very close to the same
+// locality as row-major, but with much faster reconstruction of each
+// result column, at least for filter applications where b is relatively
+// small and negative queries can return early.
+//
+// ######################################################################
+// ###################### Fractional result bits ########################
+//
+// Bloom filters have great flexibility that alternatives mostly do not
+// have. One of those flexibilities is in utilizing any ratio of data
+// structure bits per key. With a typical memory allocator like jemalloc,
+// this flexibility can save roughly 10% of the filters' footprint in
+// DRAM by rounding up and down filter sizes to minimize memory internal
+// fragmentation (see optimize_filters_for_memory RocksDB option).
+//
+// At first glance, PHSFs only offer a whole number of bits per "slot"
+// (m rather than number of keys n), but coefficient locality in the
+// Ribbon construction makes fractional bits/key quite possible and
+// attractive for filter applications. This works by a prefix of the
+// structure using b-1 solution columns and the rest using b solution
+// columns. See InterleavedSolutionStorage below for more detail.
+//
+// Because false positive rates are non-linear in bits/key, this approach
+// is not quite optimal in terms of information theory. In common cases,
+// we see additional space overhead up to about 1.5% vs. theoretical
+// optimal to achieve the same FP rate. We consider this a quite acceptable
+// overhead for very efficiently utilizing space that might otherwise be
+// wasted.
+//
+// This property of Ribbon even makes it "elastic." A Ribbon filter and
+// its small metadata for answering queries can be adapted into another
+// Ribbon filter filling any smaller multiple of r bits (plus small
+// metadata), with a correspondingly higher FP rate. None of the data
+// thrown away during construction needs to be recalled for this reduction.
+// Similarly a single Ribbon construction can be separated (by solution
+// column) into two or more structures (or "layers" or "levels") with
+// independent filtering ability (no FP correlation, just as solution or
+// result columns in a single structure) despite being constructed as part
+// of a single linear system. (TODO: implement)
+// See also "ElasticBF: Fine-grained and Elastic Bloom Filter Towards
+// Efficient Read for LSM-tree-based KV Stores."
+//
+
+// ######################################################################
+// ################### CODE: Ribbon core algorithms #####################
+// ######################################################################
+//
+// These algorithms are templatized for genericity but near-maximum
+// performance in a given application. The template parameters
+// adhere to informal class/struct type concepts outlined below. (This
+// code is written for C++11 so does not use formal C++ concepts.)
+
+// Rough architecture for these algorithms:
+//
+// +-----------+ +---+ +-----------------+
+// | AddInputs | --> | H | --> | BandingStorage |
+// +-----------+ | a | +-----------------+
+// | s | |
+// | h | Back substitution
+// | e | V
+// +-----------+ | r | +-----------------+
+// | Query Key | --> | | >+< | SolutionStorage |
+// +-----------+ +---+ | +-----------------+
+// V
+// Query result
+
+// Common to other concepts
+// concept RibbonTypes {
+// // An unsigned integer type for an r-bit subsequence of coefficients.
+// // r (or kCoeffBits) is taken to be sizeof(CoeffRow) * 8, as it would
+// // generally only hurt scalability to leave bits of CoeffRow unused.
+// typename CoeffRow;
+// // An unsigned integer type big enough to hold a result row (b bits,
+// // or number of solution/result columns).
+// // In many applications, especially filters, the number of result
+// // columns is decided at run time, so ResultRow simply needs to be
+// // big enough for the largest number of columns allowed.
+// typename ResultRow;
+// // An unsigned integer type sufficient for representing the number of
+// // rows in the solution structure, and at least the arithmetic
+// // promotion size (usually 32 bits). uint32_t recommended because a
+// // single Ribbon construction doesn't really scale to billions of
+// // entries.
+// typename Index;
+// };
+
+// ######################################################################
+// ######################## Hashers and Banding #########################
+
+// Hasher concepts abstract out hashing details.
+
+// concept PhsfQueryHasher extends RibbonTypes {
+// // Type for a lookup key, which is hashable.
+// typename Key;
+//
+// // Type for hashed summary of a Key. uint64_t is recommended.
+// typename Hash;
+//
+// // Compute a hash value summarizing a Key
+// Hash GetHash(const Key &) const;
+//
+// // Given a hash value and a number of columns that can start an
+// // r-sequence of coefficients (== m - r + 1), return the start
+// // column to associate with that hash value. (Starts can be chosen
+// // uniformly or "smash" extra entries into the beginning and end for
+// // better utilization at those extremes of the structure. Details in
+// // ribbon.impl.h)
+// Index GetStart(Hash, Index num_starts) const;
+//
+// // Given a hash value, return the r-bit sequence of coefficients to
+// // associate with it. It's generally OK if
+// // sizeof(CoeffRow) > sizeof(Hash)
+// // as long as the hash itself is not too prone to collisions for the
+// // applications and the CoeffRow is generated uniformly from
+// // available hash data, but relatively independent of the start.
+// //
+// // Must be non-zero, because that's required for a solution to exist
+// // when mapping to non-zero result row. (Note: BandingAdd could be
+// // modified to allow 0 coeff row if that only occurs with 0 result
+// // row, which really only makes sense for filter implementation,
+// // where both values are hash-derived. Or BandingAdd could reject 0
+// // coeff row, forcing next seed, but that has potential problems with
+// // generality/scalability.)
+// CoeffRow GetCoeffRow(Hash) const;
+// };
+
+// concept FilterQueryHasher extends PhsfQueryHasher {
+// // For building or querying a filter, this returns the expected
+// // result row associated with a hashed input. For general PHSF,
+// // this must return 0.
+// //
+// // Although not strictly required, there's a slightly better chance of
+// // solver success if result row is masked down here to only the bits
+// // actually needed.
+// ResultRow GetResultRowFromHash(Hash) const;
+// }
+
+// concept BandingHasher extends FilterQueryHasher {
+// // For a filter, this will generally be the same as Key.
+// // For a general PHSF, it must either
+// // (a) include a key and a result it maps to (e.g. in a std::pair), or
+// // (b) GetResultRowFromInput looks up the result somewhere rather than
+// // extracting it.
+// typename AddInput;
+//
+// // Instead of requiring a way to extract a Key from an
+// // AddInput, we require getting the hash of the Key part
+// // of an AddInput, which is trivial if AddInput == Key.
+// Hash GetHash(const AddInput &) const;
+//
+// // For building a non-filter PHSF, this extracts or looks up the result
+// // row to associate with an input. For filter PHSF, this must return 0.
+// ResultRow GetResultRowFromInput(const AddInput &) const;
+//
+// // Whether the solver can assume the lowest bit of GetCoeffRow is
+// // always 1. When true, it should improve solver efficiency slightly.
+// static bool kFirstCoeffAlwaysOne;
+// }
+
+// Abstract storage for the the result of "banding" the inputs (Gaussian
+// elimination to an upper-triangular boolean band matrix). Because the
+// banding is an incremental / on-the-fly algorithm, this also represents
+// all the intermediate state between input entries.
+//
+// concept BandingStorage extends RibbonTypes {
+// // Tells the banding algorithm to prefetch memory associated with
+// // the next input before processing the current input. Generally
+// // recommended iff the BandingStorage doesn't easily fit in CPU
+// // cache.
+// bool UsePrefetch() const;
+//
+// // Prefetches (e.g. __builtin_prefetch) memory associated with a
+// // slot index i.
+// void Prefetch(Index i) const;
+//
+// // Load or store CoeffRow and ResultRow for slot index i.
+// // (Gaussian row operations involve both sides of the equation.)
+// // Bool `for_back_subst` indicates that customizing values for
+// // unconstrained solution rows (cr == 0) is allowed.
+// void LoadRow(Index i, CoeffRow *cr, ResultRow *rr, bool for_back_subst)
+// const;
+// void StoreRow(Index i, CoeffRow cr, ResultRow rr);
+//
+// // Returns the number of columns that can start an r-sequence of
+// // coefficients, which is the number of slots minus r (kCoeffBits)
+// // plus one. (m - r + 1)
+// Index GetNumStarts() const;
+// };
+
+// Optional storage for backtracking data in banding a set of input
+// entries. It exposes an array structure which will generally be
+// used as a stack. It must be able to accommodate as many entries
+// as are passed in as inputs to `BandingAddRange`.
+//
+// concept BacktrackStorage extends RibbonTypes {
+// // If false, backtracking support will be disabled in the algorithm.
+// // This should preferably be an inline compile-time constant function.
+// bool UseBacktrack() const;
+//
+// // Records `to_save` as the `i`th backtrack entry
+// void BacktrackPut(Index i, Index to_save);
+//
+// // Recalls the `i`th backtrack entry
+// Index BacktrackGet(Index i) const;
+// }
+
+// Adds a single entry to BandingStorage (and optionally, BacktrackStorage),
+// returning true if successful or false if solution is impossible with
+// current hasher (and presumably its seed) and number of "slots" (solution
+// or banding rows). (A solution is impossible when there is a linear
+// dependence among the inputs that doesn't "cancel out".)
+//
+// Pre- and post-condition: the BandingStorage represents a band matrix
+// ready for back substitution (row echelon form except for zero rows),
+// augmented with result values such that back substitution would give a
+// solution satisfying all the cr@start -> rr entries added.
+template <bool kFirstCoeffAlwaysOne, typename BandingStorage,
+ typename BacktrackStorage>
+bool BandingAdd(BandingStorage *bs, typename BandingStorage::Index start,
+ typename BandingStorage::ResultRow rr,
+ typename BandingStorage::CoeffRow cr, BacktrackStorage *bts,
+ typename BandingStorage::Index *backtrack_pos) {
+ using CoeffRow = typename BandingStorage::CoeffRow;
+ using ResultRow = typename BandingStorage::ResultRow;
+ using Index = typename BandingStorage::Index;
+
+ Index i = start;
+
+ if (!kFirstCoeffAlwaysOne) {
+ // Requires/asserts that cr != 0
+ int tz = CountTrailingZeroBits(cr);
+ i += static_cast<Index>(tz);
+ cr >>= tz;
+ }
+
+ for (;;) {
+ assert((cr & 1) == 1);
+ CoeffRow cr_at_i;
+ ResultRow rr_at_i;
+ bs->LoadRow(i, &cr_at_i, &rr_at_i, /* for_back_subst */ false);
+ if (cr_at_i == 0) {
+ bs->StoreRow(i, cr, rr);
+ bts->BacktrackPut(*backtrack_pos, i);
+ ++*backtrack_pos;
+ return true;
+ }
+ assert((cr_at_i & 1) == 1);
+ // Gaussian row reduction
+ cr ^= cr_at_i;
+ rr ^= rr_at_i;
+ if (cr == 0) {
+ // Inconsistency or (less likely) redundancy
+ break;
+ }
+ // Find relative offset of next non-zero coefficient.
+ int tz = CountTrailingZeroBits(cr);
+ i += static_cast<Index>(tz);
+ cr >>= tz;
+ }
+
+ // Failed, unless result row == 0 because e.g. a duplicate input or a
+ // stock hash collision, with same result row. (For filter, stock hash
+ // collision implies same result row.) Or we could have a full equation
+ // equal to sum of other equations, which is very possible with
+ // small range of values for result row.
+ return rr == 0;
+}
+
+// Adds a range of entries to BandingStorage returning true if successful
+// or false if solution is impossible with current hasher (and presumably
+// its seed) and number of "slots" (solution or banding rows). (A solution
+// is impossible when there is a linear dependence among the inputs that
+// doesn't "cancel out".) Here "InputIterator" is an iterator over AddInputs.
+//
+// If UseBacktrack in the BacktrackStorage, this function call rolls back
+// to prior state on failure. If !UseBacktrack, some subset of the entries
+// will have been added to the BandingStorage, so best considered to be in
+// an indeterminate state.
+//
+template <typename BandingStorage, typename BacktrackStorage,
+ typename BandingHasher, typename InputIterator>
+bool BandingAddRange(BandingStorage *bs, BacktrackStorage *bts,
+ const BandingHasher &bh, InputIterator begin,
+ InputIterator end) {
+ using CoeffRow = typename BandingStorage::CoeffRow;
+ using Index = typename BandingStorage::Index;
+ using ResultRow = typename BandingStorage::ResultRow;
+ using Hash = typename BandingHasher::Hash;
+
+ static_assert(IsUnsignedUpTo128<CoeffRow>::value, "must be unsigned");
+ static_assert(IsUnsignedUpTo128<Index>::value, "must be unsigned");
+ static_assert(IsUnsignedUpTo128<ResultRow>::value, "must be unsigned");
+
+ constexpr bool kFCA1 = BandingHasher::kFirstCoeffAlwaysOne;
+
+ if (begin == end) {
+ // trivial
+ return true;
+ }
+
+ const Index num_starts = bs->GetNumStarts();
+
+ InputIterator cur = begin;
+ Index backtrack_pos = 0;
+ if (!bs->UsePrefetch()) {
+ // Simple version, no prefetch
+ for (;;) {
+ Hash h = bh.GetHash(*cur);
+ Index start = bh.GetStart(h, num_starts);
+ ResultRow rr =
+ bh.GetResultRowFromInput(*cur) | bh.GetResultRowFromHash(h);
+ CoeffRow cr = bh.GetCoeffRow(h);
+
+ if (!BandingAdd<kFCA1>(bs, start, rr, cr, bts, &backtrack_pos)) {
+ break;
+ }
+ if ((++cur) == end) {
+ return true;
+ }
+ }
+ } else {
+ // Pipelined w/prefetch
+ // Prime the pipeline
+ Hash h = bh.GetHash(*cur);
+ Index start = bh.GetStart(h, num_starts);
+ ResultRow rr = bh.GetResultRowFromInput(*cur);
+ bs->Prefetch(start);
+
+ // Pipeline
+ for (;;) {
+ rr |= bh.GetResultRowFromHash(h);
+ CoeffRow cr = bh.GetCoeffRow(h);
+ if ((++cur) == end) {
+ if (!BandingAdd<kFCA1>(bs, start, rr, cr, bts, &backtrack_pos)) {
+ break;
+ }
+ return true;
+ }
+ Hash next_h = bh.GetHash(*cur);
+ Index next_start = bh.GetStart(next_h, num_starts);
+ ResultRow next_rr = bh.GetResultRowFromInput(*cur);
+ bs->Prefetch(next_start);
+ if (!BandingAdd<kFCA1>(bs, start, rr, cr, bts, &backtrack_pos)) {
+ break;
+ }
+ h = next_h;
+ start = next_start;
+ rr = next_rr;
+ }
+ }
+ // failed; backtrack (if implemented)
+ if (bts->UseBacktrack()) {
+ while (backtrack_pos > 0) {
+ --backtrack_pos;
+ Index i = bts->BacktrackGet(backtrack_pos);
+ // Clearing the ResultRow is not strictly required, but is required
+ // for good FP rate on inputs that might have been backtracked out.
+ // (We don't want anything we've backtracked on to leak into final
+ // result, as that might not be "harmless".)
+ bs->StoreRow(i, 0, 0);
+ }
+ }
+ return false;
+}
+
+// Adds a range of entries to BandingStorage returning true if successful
+// or false if solution is impossible with current hasher (and presumably
+// its seed) and number of "slots" (solution or banding rows). (A solution
+// is impossible when there is a linear dependence among the inputs that
+// doesn't "cancel out".) Here "InputIterator" is an iterator over AddInputs.
+//
+// On failure, some subset of the entries will have been added to the
+// BandingStorage, so best considered to be in an indeterminate state.
+//
+template <typename BandingStorage, typename BandingHasher,
+ typename InputIterator>
+bool BandingAddRange(BandingStorage *bs, const BandingHasher &bh,
+ InputIterator begin, InputIterator end) {
+ using Index = typename BandingStorage::Index;
+ struct NoopBacktrackStorage {
+ bool UseBacktrack() { return false; }
+ void BacktrackPut(Index, Index) {}
+ Index BacktrackGet(Index) {
+ assert(false);
+ return 0;
+ }
+ } nbts;
+ return BandingAddRange(bs, &nbts, bh, begin, end);
+}
+
+// ######################################################################
+// ######################### Solution Storage ###########################
+
+// Back-substitution and query algorithms unfortunately depend on some
+// details of data layout in the final data structure ("solution"). Thus,
+// there is no common SolutionStorage covering all the reasonable
+// possibilities.
+
+// ###################### SimpleSolutionStorage #########################
+
+// SimpleSolutionStorage is for a row-major storage, typically with no
+// unused bits in each ResultRow. This is mostly for demonstration
+// purposes as the simplest solution storage scheme. It is relatively slow
+// for filter queries.
+
+// concept SimpleSolutionStorage extends RibbonTypes {
+// // This is called at the beginning of back-substitution for the
+// // solution storage to do any remaining configuration before data
+// // is stored to it. If configuration is previously finalized, this
+// // could be a simple assertion or even no-op. Ribbon algorithms
+// // only call this from back-substitution, and only once per call,
+// // before other functions here.
+// void PrepareForNumStarts(Index num_starts) const;
+// // Must return num_starts passed to PrepareForNumStarts, or the most
+// // recent call to PrepareForNumStarts if this storage object can be
+// // reused. Note that num_starts == num_slots - kCoeffBits + 1 because
+// // there must be a run of kCoeffBits slots starting from each start.
+// Index GetNumStarts() const;
+// // Load the solution row (type ResultRow) for a slot
+// ResultRow Load(Index slot_num) const;
+// // Store the solution row (type ResultRow) for a slot
+// void Store(Index slot_num, ResultRow data);
+// };
+
+// Back-substitution for generating a solution from BandingStorage to
+// SimpleSolutionStorage.
+template <typename SimpleSolutionStorage, typename BandingStorage>
+void SimpleBackSubst(SimpleSolutionStorage *sss, const BandingStorage &bs) {
+ using CoeffRow = typename BandingStorage::CoeffRow;
+ using Index = typename BandingStorage::Index;
+ using ResultRow = typename BandingStorage::ResultRow;
+
+ static_assert(sizeof(Index) == sizeof(typename SimpleSolutionStorage::Index),
+ "must be same");
+ static_assert(
+ sizeof(CoeffRow) == sizeof(typename SimpleSolutionStorage::CoeffRow),
+ "must be same");
+ static_assert(
+ sizeof(ResultRow) == sizeof(typename SimpleSolutionStorage::ResultRow),
+ "must be same");
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+ constexpr auto kResultBits = static_cast<Index>(sizeof(ResultRow) * 8U);
+
+ // A column-major buffer of the solution matrix, containing enough
+ // recently-computed solution data to compute the next solution row
+ // (based also on banding data).
+ std::array<CoeffRow, kResultBits> state;
+ state.fill(0);
+
+ const Index num_starts = bs.GetNumStarts();
+ sss->PrepareForNumStarts(num_starts);
+ const Index num_slots = num_starts + kCoeffBits - 1;
+
+ for (Index i = num_slots; i > 0;) {
+ --i;
+ CoeffRow cr;
+ ResultRow rr;
+ bs.LoadRow(i, &cr, &rr, /* for_back_subst */ true);
+ // solution row
+ ResultRow sr = 0;
+ for (Index j = 0; j < kResultBits; ++j) {
+ // Compute next solution bit at row i, column j (see derivation below)
+ CoeffRow tmp = state[j] << 1;
+ bool bit = (BitParity(tmp & cr) ^ ((rr >> j) & 1)) != 0;
+ tmp |= bit ? CoeffRow{1} : CoeffRow{0};
+
+ // Now tmp is solution at column j from row i for next kCoeffBits
+ // more rows. Thus, for valid solution, the dot product of the
+ // solution column with the coefficient row has to equal the result
+ // at that column,
+ // BitParity(tmp & cr) == ((rr >> j) & 1)
+
+ // Update state.
+ state[j] = tmp;
+ // add to solution row
+ sr |= (bit ? ResultRow{1} : ResultRow{0}) << j;
+ }
+ sss->Store(i, sr);
+ }
+}
+
+// Common functionality for querying a key (already hashed) in
+// SimpleSolutionStorage.
+template <typename SimpleSolutionStorage>
+typename SimpleSolutionStorage::ResultRow SimpleQueryHelper(
+ typename SimpleSolutionStorage::Index start_slot,
+ typename SimpleSolutionStorage::CoeffRow cr,
+ const SimpleSolutionStorage &sss) {
+ using CoeffRow = typename SimpleSolutionStorage::CoeffRow;
+ using ResultRow = typename SimpleSolutionStorage::ResultRow;
+
+ constexpr unsigned kCoeffBits = static_cast<unsigned>(sizeof(CoeffRow) * 8U);
+
+ ResultRow result = 0;
+ for (unsigned i = 0; i < kCoeffBits; ++i) {
+ // Bit masking whole value is generally faster here than 'if'
+ result ^= sss.Load(start_slot + i) &
+ (ResultRow{0} - (static_cast<ResultRow>(cr >> i) & ResultRow{1}));
+ }
+ return result;
+}
+
+// General PHSF query a key from SimpleSolutionStorage.
+template <typename SimpleSolutionStorage, typename PhsfQueryHasher>
+typename SimpleSolutionStorage::ResultRow SimplePhsfQuery(
+ const typename PhsfQueryHasher::Key &key, const PhsfQueryHasher &hasher,
+ const SimpleSolutionStorage &sss) {
+ const typename PhsfQueryHasher::Hash hash = hasher.GetHash(key);
+
+ static_assert(sizeof(typename SimpleSolutionStorage::Index) ==
+ sizeof(typename PhsfQueryHasher::Index),
+ "must be same");
+ static_assert(sizeof(typename SimpleSolutionStorage::CoeffRow) ==
+ sizeof(typename PhsfQueryHasher::CoeffRow),
+ "must be same");
+
+ return SimpleQueryHelper(hasher.GetStart(hash, sss.GetNumStarts()),
+ hasher.GetCoeffRow(hash), sss);
+}
+
+// Filter query a key from SimpleSolutionStorage.
+template <typename SimpleSolutionStorage, typename FilterQueryHasher>
+bool SimpleFilterQuery(const typename FilterQueryHasher::Key &key,
+ const FilterQueryHasher &hasher,
+ const SimpleSolutionStorage &sss) {
+ const typename FilterQueryHasher::Hash hash = hasher.GetHash(key);
+ const typename SimpleSolutionStorage::ResultRow expected =
+ hasher.GetResultRowFromHash(hash);
+
+ static_assert(sizeof(typename SimpleSolutionStorage::Index) ==
+ sizeof(typename FilterQueryHasher::Index),
+ "must be same");
+ static_assert(sizeof(typename SimpleSolutionStorage::CoeffRow) ==
+ sizeof(typename FilterQueryHasher::CoeffRow),
+ "must be same");
+ static_assert(sizeof(typename SimpleSolutionStorage::ResultRow) ==
+ sizeof(typename FilterQueryHasher::ResultRow),
+ "must be same");
+
+ return expected ==
+ SimpleQueryHelper(hasher.GetStart(hash, sss.GetNumStarts()),
+ hasher.GetCoeffRow(hash), sss);
+}
+
+// #################### InterleavedSolutionStorage ######################
+
+// InterleavedSolutionStorage is row-major at a high level, for good
+// locality, and column-major at a low level, for CPU efficiency
+// especially in filter queries or relatively small number of result bits
+// (== solution columns). The storage is a sequence of "blocks" where a
+// block has one CoeffRow-sized segment for each solution column. Each
+// query spans at most two blocks; the starting solution row is typically
+// in the row-logical middle of a block and spans to the middle of the
+// next block. (See diagram below.)
+//
+// InterleavedSolutionStorage supports choosing b (number of result or
+// solution columns) at run time, and even supports mixing b and b-1 solution
+// columns in a single linear system solution, for filters that can
+// effectively utilize any size space (multiple of CoeffRow) for minimizing
+// FP rate for any number of added keys. To simplify query implementation
+// (with lower-index columns first), the b-bit portion comes after the b-1
+// portion of the structure.
+//
+// Diagram (=== marks logical block boundary; b=4; ### is data used by a
+// query crossing the b-1 to b boundary, each Segment has type CoeffRow):
+// ...
+// +======================+
+// | S e g m e n t col=0 |
+// +----------------------+
+// | S e g m e n t col=1 |
+// +----------------------+
+// | S e g m e n t col=2 |
+// +======================+
+// | S e g m e n #########|
+// +----------------------+
+// | S e g m e n #########|
+// +----------------------+
+// | S e g m e n #########|
+// +======================+ Result/solution columns: above = 3, below = 4
+// |#############t col=0 |
+// +----------------------+
+// |#############t col=1 |
+// +----------------------+
+// |#############t col=2 |
+// +----------------------+
+// | S e g m e n t col=3 |
+// +======================+
+// | S e g m e n t col=0 |
+// +----------------------+
+// | S e g m e n t col=1 |
+// +----------------------+
+// | S e g m e n t col=2 |
+// +----------------------+
+// | S e g m e n t col=3 |
+// +======================+
+// ...
+//
+// InterleavedSolutionStorage will be adapted by the algorithms from
+// simple array-like segment storage. That array-like storage is templatized
+// in part so that an implementation may choose to handle byte ordering
+// at access time.
+//
+// concept InterleavedSolutionStorage extends RibbonTypes {
+// // This is called at the beginning of back-substitution for the
+// // solution storage to do any remaining configuration before data
+// // is stored to it. If configuration is previously finalized, this
+// // could be a simple assertion or even no-op. Ribbon algorithms
+// // only call this from back-substitution, and only once per call,
+// // before other functions here.
+// void PrepareForNumStarts(Index num_starts) const;
+// // Must return num_starts passed to PrepareForNumStarts, or the most
+// // recent call to PrepareForNumStarts if this storage object can be
+// // reused. Note that num_starts == num_slots - kCoeffBits + 1 because
+// // there must be a run of kCoeffBits slots starting from each start.
+// Index GetNumStarts() const;
+// // The larger number of solution columns used (called "b" above).
+// Index GetUpperNumColumns() const;
+// // If returns > 0, then block numbers below that use
+// // GetUpperNumColumns() - 1 columns per solution row, and the rest
+// // use GetUpperNumColumns(). A block represents kCoeffBits "slots",
+// // where all but the last kCoeffBits - 1 slots are also starts. And
+// // a block contains a segment for each solution column.
+// // An implementation may only support uniform columns per solution
+// // row and return constant 0 here.
+// Index GetUpperStartBlock() const;
+//
+// // ### "Array of segments" portion of API ###
+// // The number of values of type CoeffRow used in this solution
+// // representation. (This value can be inferred from the previous
+// // three functions, but is expected at least for sanity / assertion
+// // checking.)
+// Index GetNumSegments() const;
+// // Load an entry from the logical array of segments
+// CoeffRow LoadSegment(Index segment_num) const;
+// // Store an entry to the logical array of segments
+// void StoreSegment(Index segment_num, CoeffRow data);
+// };
+
+// A helper for InterleavedBackSubst.
+template <typename BandingStorage>
+inline void BackSubstBlock(typename BandingStorage::CoeffRow *state,
+ typename BandingStorage::Index num_columns,
+ const BandingStorage &bs,
+ typename BandingStorage::Index start_slot) {
+ using CoeffRow = typename BandingStorage::CoeffRow;
+ using Index = typename BandingStorage::Index;
+ using ResultRow = typename BandingStorage::ResultRow;
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+
+ for (Index i = start_slot + kCoeffBits; i > start_slot;) {
+ --i;
+ CoeffRow cr;
+ ResultRow rr;
+ bs.LoadRow(i, &cr, &rr, /* for_back_subst */ true);
+ for (Index j = 0; j < num_columns; ++j) {
+ // Compute next solution bit at row i, column j (see derivation below)
+ CoeffRow tmp = state[j] << 1;
+ int bit = BitParity(tmp & cr) ^ ((rr >> j) & 1);
+ tmp |= static_cast<CoeffRow>(bit);
+
+ // Now tmp is solution at column j from row i for next kCoeffBits
+ // more rows. Thus, for valid solution, the dot product of the
+ // solution column with the coefficient row has to equal the result
+ // at that column,
+ // BitParity(tmp & cr) == ((rr >> j) & 1)
+
+ // Update state.
+ state[j] = tmp;
+ }
+ }
+}
+
+// Back-substitution for generating a solution from BandingStorage to
+// InterleavedSolutionStorage.
+template <typename InterleavedSolutionStorage, typename BandingStorage>
+void InterleavedBackSubst(InterleavedSolutionStorage *iss,
+ const BandingStorage &bs) {
+ using CoeffRow = typename BandingStorage::CoeffRow;
+ using Index = typename BandingStorage::Index;
+
+ static_assert(
+ sizeof(Index) == sizeof(typename InterleavedSolutionStorage::Index),
+ "must be same");
+ static_assert(
+ sizeof(CoeffRow) == sizeof(typename InterleavedSolutionStorage::CoeffRow),
+ "must be same");
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+
+ const Index num_starts = bs.GetNumStarts();
+ // Although it might be nice to have a filter that returns "always false"
+ // when no key is added, we aren't specifically supporting that here
+ // because it would require another condition branch in the query.
+ assert(num_starts > 0);
+ iss->PrepareForNumStarts(num_starts);
+
+ const Index num_slots = num_starts + kCoeffBits - 1;
+ assert(num_slots % kCoeffBits == 0);
+ const Index num_blocks = num_slots / kCoeffBits;
+ const Index num_segments = iss->GetNumSegments();
+
+ // For now upper, then lower
+ Index num_columns = iss->GetUpperNumColumns();
+ const Index upper_start_block = iss->GetUpperStartBlock();
+
+ if (num_columns == 0) {
+ // Nothing to do, presumably because there's not enough space for even
+ // a single segment.
+ assert(num_segments == 0);
+ // When num_columns == 0, a Ribbon filter query will always return true,
+ // or a PHSF query always 0.
+ return;
+ }
+
+ // We should be utilizing all available segments
+ assert(num_segments == (upper_start_block * (num_columns - 1)) +
+ ((num_blocks - upper_start_block) * num_columns));
+
+ // TODO: consider fixed-column specializations with stack-allocated state
+
+ // A column-major buffer of the solution matrix, containing enough
+ // recently-computed solution data to compute the next solution row
+ // (based also on banding data).
+ std::unique_ptr<CoeffRow[]> state{new CoeffRow[num_columns]()};
+
+ Index block = num_blocks;
+ Index segment_num = num_segments;
+ while (block > upper_start_block) {
+ --block;
+ BackSubstBlock(state.get(), num_columns, bs, block * kCoeffBits);
+ segment_num -= num_columns;
+ for (Index i = 0; i < num_columns; ++i) {
+ iss->StoreSegment(segment_num + i, state[i]);
+ }
+ }
+ // Now (if applicable), region using lower number of columns
+ // (This should be optimized away if GetUpperStartBlock() returns
+ // constant 0.)
+ --num_columns;
+ while (block > 0) {
+ --block;
+ BackSubstBlock(state.get(), num_columns, bs, block * kCoeffBits);
+ segment_num -= num_columns;
+ for (Index i = 0; i < num_columns; ++i) {
+ iss->StoreSegment(segment_num + i, state[i]);
+ }
+ }
+ // Verify everything processed
+ assert(block == 0);
+ assert(segment_num == 0);
+}
+
+// Prefetch memory for a key in InterleavedSolutionStorage.
+template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
+inline void InterleavedPrepareQuery(
+ const typename PhsfQueryHasher::Key &key, const PhsfQueryHasher &hasher,
+ const InterleavedSolutionStorage &iss,
+ typename PhsfQueryHasher::Hash *saved_hash,
+ typename InterleavedSolutionStorage::Index *saved_segment_num,
+ typename InterleavedSolutionStorage::Index *saved_num_columns,
+ typename InterleavedSolutionStorage::Index *saved_start_bit) {
+ using Hash = typename PhsfQueryHasher::Hash;
+ using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
+ using Index = typename InterleavedSolutionStorage::Index;
+
+ static_assert(sizeof(Index) == sizeof(typename PhsfQueryHasher::Index),
+ "must be same");
+
+ const Hash hash = hasher.GetHash(key);
+ const Index start_slot = hasher.GetStart(hash, iss.GetNumStarts());
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+
+ const Index upper_start_block = iss.GetUpperStartBlock();
+ Index num_columns = iss.GetUpperNumColumns();
+ Index start_block_num = start_slot / kCoeffBits;
+ Index segment_num = start_block_num * num_columns -
+ std::min(start_block_num, upper_start_block);
+ // Change to lower num columns if applicable.
+ // (This should not compile to a conditional branch.)
+ num_columns -= (start_block_num < upper_start_block) ? 1 : 0;
+
+ Index start_bit = start_slot % kCoeffBits;
+
+ Index segment_count = num_columns + (start_bit == 0 ? 0 : num_columns);
+
+ iss.PrefetchSegmentRange(segment_num, segment_num + segment_count);
+
+ *saved_hash = hash;
+ *saved_segment_num = segment_num;
+ *saved_num_columns = num_columns;
+ *saved_start_bit = start_bit;
+}
+
+// General PHSF query from InterleavedSolutionStorage, using data for
+// the query key from InterleavedPrepareQuery
+template <typename InterleavedSolutionStorage, typename PhsfQueryHasher>
+inline typename InterleavedSolutionStorage::ResultRow InterleavedPhsfQuery(
+ typename PhsfQueryHasher::Hash hash,
+ typename InterleavedSolutionStorage::Index segment_num,
+ typename InterleavedSolutionStorage::Index num_columns,
+ typename InterleavedSolutionStorage::Index start_bit,
+ const PhsfQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
+ using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
+ using Index = typename InterleavedSolutionStorage::Index;
+ using ResultRow = typename InterleavedSolutionStorage::ResultRow;
+
+ static_assert(sizeof(Index) == sizeof(typename PhsfQueryHasher::Index),
+ "must be same");
+ static_assert(sizeof(CoeffRow) == sizeof(typename PhsfQueryHasher::CoeffRow),
+ "must be same");
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+
+ const CoeffRow cr = hasher.GetCoeffRow(hash);
+
+ ResultRow sr = 0;
+ const CoeffRow cr_left = cr << static_cast<unsigned>(start_bit);
+ for (Index i = 0; i < num_columns; ++i) {
+ sr ^= BitParity(iss.LoadSegment(segment_num + i) & cr_left) << i;
+ }
+
+ if (start_bit > 0) {
+ segment_num += num_columns;
+ const CoeffRow cr_right =
+ cr >> static_cast<unsigned>(kCoeffBits - start_bit);
+ for (Index i = 0; i < num_columns; ++i) {
+ sr ^= BitParity(iss.LoadSegment(segment_num + i) & cr_right) << i;
+ }
+ }
+
+ return sr;
+}
+
+// Filter query a key from InterleavedFilterQuery.
+template <typename InterleavedSolutionStorage, typename FilterQueryHasher>
+inline bool InterleavedFilterQuery(
+ typename FilterQueryHasher::Hash hash,
+ typename InterleavedSolutionStorage::Index segment_num,
+ typename InterleavedSolutionStorage::Index num_columns,
+ typename InterleavedSolutionStorage::Index start_bit,
+ const FilterQueryHasher &hasher, const InterleavedSolutionStorage &iss) {
+ using CoeffRow = typename InterleavedSolutionStorage::CoeffRow;
+ using Index = typename InterleavedSolutionStorage::Index;
+ using ResultRow = typename InterleavedSolutionStorage::ResultRow;
+
+ static_assert(sizeof(Index) == sizeof(typename FilterQueryHasher::Index),
+ "must be same");
+ static_assert(
+ sizeof(CoeffRow) == sizeof(typename FilterQueryHasher::CoeffRow),
+ "must be same");
+ static_assert(
+ sizeof(ResultRow) == sizeof(typename FilterQueryHasher::ResultRow),
+ "must be same");
+
+ constexpr auto kCoeffBits = static_cast<Index>(sizeof(CoeffRow) * 8U);
+
+ const CoeffRow cr = hasher.GetCoeffRow(hash);
+ const ResultRow expected = hasher.GetResultRowFromHash(hash);
+
+ // TODO: consider optimizations such as
+ // * get rid of start_bit == 0 condition with careful fetching & shifting
+ if (start_bit == 0) {
+ for (Index i = 0; i < num_columns; ++i) {
+ if (BitParity(iss.LoadSegment(segment_num + i) & cr) !=
+ (static_cast<int>(expected >> i) & 1)) {
+ return false;
+ }
+ }
+ } else {
+ const CoeffRow cr_left = cr << static_cast<unsigned>(start_bit);
+ const CoeffRow cr_right =
+ cr >> static_cast<unsigned>(kCoeffBits - start_bit);
+
+ for (Index i = 0; i < num_columns; ++i) {
+ CoeffRow soln_data =
+ (iss.LoadSegment(segment_num + i) & cr_left) ^
+ (iss.LoadSegment(segment_num + num_columns + i) & cr_right);
+ if (BitParity(soln_data) != (static_cast<int>(expected >> i) & 1)) {
+ return false;
+ }
+ }
+ }
+ // otherwise, all match
+ return true;
+}
+
+// TODO: refactor Interleaved*Query so that queries can be "prepared" by
+// prefetching memory, to hide memory latency for multiple queries in a
+// single thread.
+
+} // namespace ribbon
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/ribbon_config.cc b/src/rocksdb/util/ribbon_config.cc
new file mode 100644
index 000000000..c1046f4aa
--- /dev/null
+++ b/src/rocksdb/util/ribbon_config.cc
@@ -0,0 +1,506 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/ribbon_config.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace ribbon {
+
+namespace detail {
+
+// Each instantiation of this struct is sufficiently unique for configuration
+// purposes, and is only instantiated for settings where we support the
+// configuration API. An application might only reference one instantiation,
+// meaning the rest could be pruned at link time.
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash>
+struct BandingConfigHelperData {
+ static constexpr size_t kKnownSize = 18U;
+
+ // Because of complexity in the data, for smaller numbers of slots
+ // (powers of two up to 2^17), we record known numbers that can be added
+ // with kCfc chance of construction failure and settings in template
+ // parameters. Zero means "unsupported (too small) number of slots".
+ // (GetNumToAdd below will use interpolation for numbers of slots
+ // between powers of two; double rather than integer values here make
+ // that more accurate.)
+ static const std::array<double, kKnownSize> kKnownToAddByPow2;
+
+ // For sufficiently large number of slots, doubling the number of
+ // slots will increase the expected overhead (slots over number added)
+ // by approximately this constant.
+ // (This is roughly constant regardless of ConstructionFailureChance and
+ // smash setting.)
+ // (Would be a constant if we had partial template specialization for
+ // static const members.)
+ static inline double GetFactorPerPow2() {
+ if (kCoeffBits == 128U) {
+ return 0.0038;
+ } else {
+ assert(kCoeffBits == 64U);
+ return 0.0083;
+ }
+ }
+
+ // Overhead factor for 2^(kKnownSize-1) slots
+ // (Would be a constant if we had partial template specialization for
+ // static const members.)
+ static inline double GetFinalKnownFactor() {
+ return 1.0 * (uint32_t{1} << (kKnownSize - 1)) /
+ kKnownToAddByPow2[kKnownSize - 1];
+ }
+
+ // GetFinalKnownFactor() - (kKnownSize-1) * GetFactorPerPow2()
+ // (Would be a constant if we had partial template specialization for
+ // static const members.)
+ static inline double GetBaseFactor() {
+ return GetFinalKnownFactor() - (kKnownSize - 1) * GetFactorPerPow2();
+ }
+
+ // Get overhead factor (slots over number to add) for sufficiently large
+ // number of slots (by log base 2)
+ static inline double GetFactorForLarge(double log2_num_slots) {
+ return GetBaseFactor() + log2_num_slots * GetFactorPerPow2();
+ }
+
+ // For a given power of two number of slots (specified by whole number
+ // log base 2), implements GetNumToAdd for such limited case, returning
+ // double for better interpolation in GetNumToAdd and GetNumSlots.
+ static inline double GetNumToAddForPow2(uint32_t log2_num_slots) {
+ assert(log2_num_slots <= 32); // help clang-analyze
+ if (log2_num_slots < kKnownSize) {
+ return kKnownToAddByPow2[log2_num_slots];
+ } else {
+ return 1.0 * (uint64_t{1} << log2_num_slots) /
+ GetFactorForLarge(1.0 * log2_num_slots);
+ }
+ }
+};
+
+// Based on data from FindOccupancy in ribbon_test
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn2, 128U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 252.984,
+ 506.109,
+ 1013.71,
+ 2029.47,
+ 4060.43,
+ 8115.63,
+ 16202.2,
+ 32305.1,
+ 64383.5,
+ 128274,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn2, 128U, /*smash*/ true>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 126.274,
+ 254.279,
+ 510.27,
+ 1022.24,
+ 2046.02,
+ 4091.99,
+ 8154.98,
+ 16244.3,
+ 32349.7,
+ 64426.6,
+ 128307,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn2, 64U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 124.94,
+ 249.968,
+ 501.234,
+ 1004.06,
+ 2006.15,
+ 3997.89,
+ 7946.99,
+ 15778.4,
+ 31306.9,
+ 62115.3,
+ 123284,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn2, 64U, /*smash*/ true>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 62.2683,
+ 126.259,
+ 254.268,
+ 509.975,
+ 1019.98,
+ 2026.16,
+ 4019.75,
+ 7969.8,
+ 15798.2,
+ 31330.3,
+ 62134.2,
+ 123255,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn20, 128U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 248.851,
+ 499.532,
+ 1001.26,
+ 2003.97,
+ 4005.59,
+ 8000.39,
+ 15966.6,
+ 31828.1,
+ 63447.3,
+ 126506,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn20, 128U, /*smash*/ true>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 122.637,
+ 250.651,
+ 506.625,
+ 1018.54,
+ 2036.43,
+ 4041.6,
+ 8039.25,
+ 16005,
+ 31869.6,
+ 63492.8,
+ 126537,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn20, 64U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 120.659,
+ 243.346,
+ 488.168,
+ 976.373,
+ 1948.86,
+ 3875.85,
+ 7704.97,
+ 15312.4,
+ 30395.1,
+ 60321.8,
+ 119813,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn20, 64U, /*smash*/ true>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 58.6016,
+ 122.619,
+ 250.641,
+ 503.595,
+ 994.165,
+ 1967.36,
+ 3898.17,
+ 7727.21,
+ 15331.5,
+ 30405.8,
+ 60376.2,
+ 119836,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn1000, 128U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 242.61,
+ 491.887,
+ 983.603,
+ 1968.21,
+ 3926.98,
+ 7833.99,
+ 15629,
+ 31199.9,
+ 62307.8,
+ 123870,
+ }};
+
+template <>
+const std::array<double, 18> BandingConfigHelperData<
+ kOneIn1000, 128U, /*smash*/ true>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 117.19,
+ 245.105,
+ 500.748,
+ 1010.67,
+ 1993.4,
+ 3950.01,
+ 7863.31,
+ 15652,
+ 31262.1,
+ 62462.8,
+ 124095,
+}};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn1000, 64U, false>::kKnownToAddByPow2{{
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 114,
+ 234.8,
+ 471.498,
+ 940.165,
+ 1874,
+ 3721.5,
+ 7387.5,
+ 14592,
+ 29160,
+ 57745,
+ 115082,
+ }};
+
+template <>
+const std::array<double, 18>
+ BandingConfigHelperData<kOneIn1000, 64U, /*smash*/ true>::kKnownToAddByPow2{
+ {
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0, // unsupported
+ 53.0434,
+ 117,
+ 245.312,
+ 483.571,
+ 950.251,
+ 1878,
+ 3736.34,
+ 7387.97,
+ 14618,
+ 29142.9,
+ 57838.8,
+ 114932,
+ }};
+
+// We hide these implementation details from the .h file with explicit
+// instantiations below these partial specializations.
+
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash,
+ bool kHomogeneous>
+uint32_t BandingConfigHelper1MaybeSupported<
+ kCfc, kCoeffBits, kUseSmash, kHomogeneous,
+ true /* kIsSupported */>::GetNumToAdd(uint32_t num_slots) {
+ using Data = detail::BandingConfigHelperData<kCfc, kCoeffBits, kUseSmash>;
+ if (num_slots == 0) {
+ return 0;
+ }
+ uint32_t num_to_add;
+ double log2_num_slots = std::log(num_slots) * 1.4426950409;
+ uint32_t floor_log2 = static_cast<uint32_t>(log2_num_slots);
+ if (floor_log2 + 1 < Data::kKnownSize) {
+ double ceil_portion = 1.0 * num_slots / (uint32_t{1} << floor_log2) - 1.0;
+ // Must be a supported number of slots
+ assert(Data::kKnownToAddByPow2[floor_log2] > 0.0);
+ // Weighted average of two nearest known data points
+ num_to_add = static_cast<uint32_t>(
+ ceil_portion * Data::kKnownToAddByPow2[floor_log2 + 1] +
+ (1.0 - ceil_portion) * Data::kKnownToAddByPow2[floor_log2]);
+ } else {
+ // Use formula for large values
+ double factor = Data::GetFactorForLarge(log2_num_slots);
+ assert(factor >= 1.0);
+ num_to_add = static_cast<uint32_t>(num_slots / factor);
+ }
+ if (kHomogeneous) {
+ // Even when standard filter construction would succeed, we might
+ // have loaded things up too much for Homogeneous filter. (Complete
+ // explanation not known but observed empirically.) This seems to
+ // correct for that, mostly affecting small filter configurations.
+ if (num_to_add >= 8) {
+ num_to_add -= 8;
+ } else {
+ assert(false);
+ }
+ }
+ return num_to_add;
+}
+
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash,
+ bool kHomogeneous>
+uint32_t BandingConfigHelper1MaybeSupported<
+ kCfc, kCoeffBits, kUseSmash, kHomogeneous,
+ true /* kIsSupported */>::GetNumSlots(uint32_t num_to_add) {
+ using Data = detail::BandingConfigHelperData<kCfc, kCoeffBits, kUseSmash>;
+
+ if (num_to_add == 0) {
+ return 0;
+ }
+ if (kHomogeneous) {
+ // Reverse of above in GetNumToAdd
+ num_to_add += 8;
+ }
+ double log2_num_to_add = std::log(num_to_add) * 1.4426950409;
+ uint32_t approx_log2_slots = static_cast<uint32_t>(log2_num_to_add + 0.5);
+ assert(approx_log2_slots <= 32); // help clang-analyze
+
+ double lower_num_to_add = Data::GetNumToAddForPow2(approx_log2_slots);
+ double upper_num_to_add;
+ if (approx_log2_slots == 0 || lower_num_to_add == /* unsupported */ 0) {
+ // Return minimum non-zero slots in standard implementation
+ return kUseSmash ? kCoeffBits : 2 * kCoeffBits;
+ } else if (num_to_add < lower_num_to_add) {
+ upper_num_to_add = lower_num_to_add;
+ --approx_log2_slots;
+ lower_num_to_add = Data::GetNumToAddForPow2(approx_log2_slots);
+ } else {
+ upper_num_to_add = Data::GetNumToAddForPow2(approx_log2_slots + 1);
+ }
+
+ assert(num_to_add >= lower_num_to_add);
+ assert(num_to_add < upper_num_to_add);
+
+ double upper_portion =
+ (num_to_add - lower_num_to_add) / (upper_num_to_add - lower_num_to_add);
+
+ double lower_num_slots = 1.0 * (uint64_t{1} << approx_log2_slots);
+
+ // Interpolation, round up
+ return static_cast<uint32_t>(upper_portion * lower_num_slots +
+ lower_num_slots + 0.999999999);
+}
+
+// These explicit instantiations enable us to hide most of the
+// implementation details from the .h file. (The .h file currently
+// needs to determine whether settings are "supported" or not.)
+
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 128U, /*sm*/ false,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 128U, /*sm*/ true,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 128U, /*sm*/ false,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 128U, /*sm*/ true,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 64U, /*sm*/ false,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 64U, /*sm*/ true,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 64U, /*sm*/ false,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn2, 64U, /*sm*/ true,
+ /*hm*/ true, /*sup*/ true>;
+
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 128U, /*sm*/ false,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 128U, /*sm*/ true,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 128U, /*sm*/ false,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 128U, /*sm*/ true,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 64U, /*sm*/ false,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 64U, /*sm*/ true,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 64U, /*sm*/ false,
+ /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn20, 64U, /*sm*/ true,
+ /*hm*/ true, /*sup*/ true>;
+
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 128U, /*sm*/ false, /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 128U, /*sm*/ true, /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 128U, /*sm*/ false, /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 128U, /*sm*/ true, /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 64U, /*sm*/ false, /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn1000, 64U, /*sm*/ true,
+ /*hm*/ false, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<
+ kOneIn1000, 64U, /*sm*/ false, /*hm*/ true, /*sup*/ true>;
+template struct BandingConfigHelper1MaybeSupported<kOneIn1000, 64U, /*sm*/ true,
+ /*hm*/ true, /*sup*/ true>;
+
+} // namespace detail
+
+} // namespace ribbon
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/ribbon_config.h b/src/rocksdb/util/ribbon_config.h
new file mode 100644
index 000000000..0e3edf073
--- /dev/null
+++ b/src/rocksdb/util/ribbon_config.h
@@ -0,0 +1,182 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+
+#include "port/lang.h" // for FALLTHROUGH_INTENDED
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace ribbon {
+
+// RIBBON PHSF & RIBBON Filter (Rapid Incremental Boolean Banding ON-the-fly)
+//
+// ribbon_config.h: APIs for relating numbers of slots with numbers of
+// additions for tolerable construction failure probabilities. This is
+// separate from ribbon_impl.h because it might not be needed for
+// some applications.
+//
+// This API assumes uint32_t for number of slots, as a single Ribbon
+// linear system should not normally overflow that without big penalties.
+//
+// Template parameter kCoeffBits uses uint64_t for convenience in case it
+// comes from size_t.
+//
+// Most of the complexity here is trying to optimize speed and
+// compiled code size, using templates to minimize table look-ups and
+// the compiled size of all linked look-up tables. Look-up tables are
+// required because we don't have good formulas, and the data comes
+// from running FindOccupancy in ribbon_test.
+
+// Represents a chosen chance of successful Ribbon construction for a single
+// seed. Allowing higher chance of failed construction can reduce space
+// overhead but takes extra time in construction.
+enum ConstructionFailureChance {
+ kOneIn2,
+ kOneIn20,
+ // When using kHomogeneous==true, construction failure chance should
+ // not generally exceed target FP rate, so it unlikely useful to
+ // allow a higher "failure" chance. In some cases, even more overhead
+ // is appropriate. (TODO)
+ kOneIn1000,
+};
+
+namespace detail {
+
+// It is useful to compile ribbon_test linking to BandingConfigHelper with
+// settings for which we do not have configuration data, as long as we don't
+// run the code. This template hack supports that.
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash,
+ bool kHomogeneous, bool kIsSupported>
+struct BandingConfigHelper1MaybeSupported {
+ public:
+ static uint32_t GetNumToAdd(uint32_t num_slots) {
+ // Unsupported
+ assert(num_slots == 0);
+ (void)num_slots;
+ return 0;
+ }
+
+ static uint32_t GetNumSlots(uint32_t num_to_add) {
+ // Unsupported
+ assert(num_to_add == 0);
+ (void)num_to_add;
+ return 0;
+ }
+};
+
+// Base class for BandingConfigHelper1 and helper for BandingConfigHelper
+// with core implementations built on above data
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash,
+ bool kHomogeneous>
+struct BandingConfigHelper1MaybeSupported<
+ kCfc, kCoeffBits, kUseSmash, kHomogeneous, true /* kIsSupported */> {
+ public:
+ // See BandingConfigHelper1. Implementation in ribbon_config.cc
+ static uint32_t GetNumToAdd(uint32_t num_slots);
+
+ // See BandingConfigHelper1. Implementation in ribbon_config.cc
+ static uint32_t GetNumSlots(uint32_t num_to_add);
+};
+
+} // namespace detail
+
+template <ConstructionFailureChance kCfc, uint64_t kCoeffBits, bool kUseSmash,
+ bool kHomogeneous>
+struct BandingConfigHelper1
+ : public detail::BandingConfigHelper1MaybeSupported<
+ kCfc, kCoeffBits, kUseSmash, kHomogeneous,
+ /* kIsSupported */ kCoeffBits == 64 || kCoeffBits == 128> {
+ public:
+ // Returns a number of entries that can be added to a given number of
+ // slots, with roughly kCfc chance of construction failure per seed,
+ // or better. Does NOT do rounding for InterleavedSoln; call
+ // RoundUpNumSlots for that.
+ //
+ // inherited:
+ // static uint32_t GetNumToAdd(uint32_t num_slots);
+
+ // Returns a number of slots for a given number of entries to add
+ // that should have roughly kCfc chance of construction failure per
+ // seed, or better. Does NOT do rounding for InterleavedSoln; call
+ // RoundUpNumSlots for that.
+ //
+ // num_to_add should not exceed roughly 2/3rds of the maximum value
+ // of the uint32_t type to avoid overflow.
+ //
+ // inherited:
+ // static uint32_t GetNumSlots(uint32_t num_to_add);
+};
+
+// Configured using TypesAndSettings as in ribbon_impl.h
+template <ConstructionFailureChance kCfc, class TypesAndSettings>
+struct BandingConfigHelper1TS
+ : public BandingConfigHelper1<
+ kCfc,
+ /* kCoeffBits */ sizeof(typename TypesAndSettings::CoeffRow) * 8U,
+ TypesAndSettings::kUseSmash, TypesAndSettings::kHomogeneous> {};
+
+// Like BandingConfigHelper1TS except failure chance can be a runtime rather
+// than compile time value.
+template <class TypesAndSettings>
+struct BandingConfigHelper {
+ public:
+ static constexpr ConstructionFailureChance kDefaultFailureChance =
+ TypesAndSettings::kHomogeneous ? kOneIn1000 : kOneIn20;
+
+ static uint32_t GetNumToAdd(
+ uint32_t num_slots,
+ ConstructionFailureChance max_failure = kDefaultFailureChance) {
+ switch (max_failure) {
+ default:
+ assert(false);
+ FALLTHROUGH_INTENDED;
+ case kOneIn20: {
+ using H1 = BandingConfigHelper1TS<kOneIn20, TypesAndSettings>;
+ return H1::GetNumToAdd(num_slots);
+ }
+ case kOneIn2: {
+ using H1 = BandingConfigHelper1TS<kOneIn2, TypesAndSettings>;
+ return H1::GetNumToAdd(num_slots);
+ }
+ case kOneIn1000: {
+ using H1 = BandingConfigHelper1TS<kOneIn1000, TypesAndSettings>;
+ return H1::GetNumToAdd(num_slots);
+ }
+ }
+ }
+
+ static uint32_t GetNumSlots(
+ uint32_t num_to_add,
+ ConstructionFailureChance max_failure = kDefaultFailureChance) {
+ switch (max_failure) {
+ default:
+ assert(false);
+ FALLTHROUGH_INTENDED;
+ case kOneIn20: {
+ using H1 = BandingConfigHelper1TS<kOneIn20, TypesAndSettings>;
+ return H1::GetNumSlots(num_to_add);
+ }
+ case kOneIn2: {
+ using H1 = BandingConfigHelper1TS<kOneIn2, TypesAndSettings>;
+ return H1::GetNumSlots(num_to_add);
+ }
+ case kOneIn1000: {
+ using H1 = BandingConfigHelper1TS<kOneIn1000, TypesAndSettings>;
+ return H1::GetNumSlots(num_to_add);
+ }
+ }
+ }
+};
+
+} // namespace ribbon
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/ribbon_impl.h b/src/rocksdb/util/ribbon_impl.h
new file mode 100644
index 000000000..0afecc67d
--- /dev/null
+++ b/src/rocksdb/util/ribbon_impl.h
@@ -0,0 +1,1137 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <cmath>
+
+#include "port/port.h" // for PREFETCH
+#include "util/fastrange.h"
+#include "util/ribbon_alg.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace ribbon {
+
+// RIBBON PHSF & RIBBON Filter (Rapid Incremental Boolean Banding ON-the-fly)
+//
+// ribbon_impl.h: templated (parameterized) standard implementations
+//
+// Ribbon is a Perfect Hash Static Function construction useful as a compact
+// static Bloom filter alternative. See ribbon_alg.h for core algorithms
+// and core design details.
+//
+// TODO: more details on trade-offs and practical issues.
+//
+// APIs for configuring Ribbon are in ribbon_config.h
+
+// Ribbon implementations in this file take these parameters, which must be
+// provided in a class/struct type with members expressed in this concept:
+
+// concept TypesAndSettings {
+// // See RibbonTypes and *Hasher in ribbon_alg.h, except here we have
+// // the added constraint that Hash be equivalent to either uint32_t or
+// // uint64_t.
+// typename Hash;
+// typename CoeffRow;
+// typename ResultRow;
+// typename Index;
+// typename Key;
+// static constexpr bool kFirstCoeffAlwaysOne;
+//
+// // An unsigned integer type for identifying a hash seed, typically
+// // uint32_t or uint64_t. Importantly, this is the amount of data
+// // stored in memory for identifying a raw seed. See StandardHasher.
+// typename Seed;
+//
+// // When true, the PHSF implements a static filter, expecting just
+// // keys as inputs for construction. When false, implements a general
+// // PHSF and expects std::pair<Key, ResultRow> as inputs for
+// // construction.
+// static constexpr bool kIsFilter;
+//
+// // When true, enables a special "homogeneous" filter implementation that
+// // is slightly faster to construct, and never fails to construct though
+// // FP rate can quickly explode in cases where corresponding
+// // non-homogeneous filter would fail (or nearly fail?) to construct.
+// // For smaller filters, you can configure with ConstructionFailureChance
+// // smaller than desired FP rate to largely counteract this effect.
+// // TODO: configuring Homogeneous Ribbon for arbitrarily large filters
+// // based on data from OptimizeHomogAtScale
+// static constexpr bool kHomogeneous;
+//
+// // When true, adds a tiny bit more hashing logic on queries and
+// // construction to improve utilization at the beginning and end of
+// // the structure. Recommended when CoeffRow is only 64 bits (or
+// // less), so typical num_starts < 10k. Although this is compatible
+// // with kHomogeneous, the competing space vs. time priorities might
+// // not be useful.
+// static constexpr bool kUseSmash;
+//
+// // When true, allows number of "starts" to be zero, for best support
+// // of the "no keys to add" case by always returning false for filter
+// // queries. (This is distinct from the "keys added but no space for
+// // any data" case, in which a filter always returns true.) The cost
+// // supporting this is a conditional branch (probably predictable) in
+// // queries.
+// static constexpr bool kAllowZeroStarts;
+//
+// // A seedable stock hash function on Keys. All bits of Hash must
+// // be reasonably high quality. XXH functions recommended, but
+// // Murmur, City, Farm, etc. also work.
+// static Hash HashFn(const Key &, Seed raw_seed);
+// };
+
+// A bit of a hack to automatically construct the type for
+// AddInput based on a constexpr bool.
+template <typename Key, typename ResultRow, bool IsFilter>
+struct AddInputSelector {
+ // For general PHSF, not filter
+ using T = std::pair<Key, ResultRow>;
+};
+
+template <typename Key, typename ResultRow>
+struct AddInputSelector<Key, ResultRow, true /*IsFilter*/> {
+ // For Filter
+ using T = Key;
+};
+
+// To avoid writing 'typename' everywhere that we use types like 'Index'
+#define IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings) \
+ using CoeffRow = typename TypesAndSettings::CoeffRow; \
+ using ResultRow = typename TypesAndSettings::ResultRow; \
+ using Index = typename TypesAndSettings::Index; \
+ using Hash = typename TypesAndSettings::Hash; \
+ using Key = typename TypesAndSettings::Key; \
+ using Seed = typename TypesAndSettings::Seed; \
+ \
+ /* Some more additions */ \
+ using QueryInput = Key; \
+ using AddInput = typename ROCKSDB_NAMESPACE::ribbon::AddInputSelector< \
+ Key, ResultRow, TypesAndSettings::kIsFilter>::T; \
+ static constexpr auto kCoeffBits = \
+ static_cast<Index>(sizeof(CoeffRow) * 8U); \
+ \
+ /* Export to algorithm */ \
+ static constexpr bool kFirstCoeffAlwaysOne = \
+ TypesAndSettings::kFirstCoeffAlwaysOne; \
+ \
+ static_assert(sizeof(CoeffRow) + sizeof(ResultRow) + sizeof(Index) + \
+ sizeof(Hash) + sizeof(Key) + sizeof(Seed) + \
+ sizeof(QueryInput) + sizeof(AddInput) + kCoeffBits + \
+ kFirstCoeffAlwaysOne > \
+ 0, \
+ "avoid unused warnings, semicolon expected after macro call")
+
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4309) // cast truncating constant
+#pragma warning(disable : 4307) // arithmetic constant overflow
+#endif
+
+// StandardHasher: A standard implementation of concepts RibbonTypes,
+// PhsfQueryHasher, FilterQueryHasher, and BandingHasher from ribbon_alg.h.
+//
+// This implementation should be suitable for most all practical purposes
+// as it "behaves" across a wide range of settings, with little room left
+// for improvement. The key functionality in this hasher is generating
+// CoeffRows, starts, and (for filters) ResultRows, which could be ~150
+// bits of data or more, from a modest hash of 64 or even just 32 bits, with
+// enough uniformity and bitwise independence to be close to "the best you
+// can do" with available hash information in terms of FP rate and
+// compactness. (64 bits recommended and sufficient for PHSF practical
+// purposes.)
+//
+// Another feature of this hasher is a minimal "premixing" of seeds before
+// they are provided to TypesAndSettings::HashFn in case that function does
+// not provide sufficiently independent hashes when iterating merely
+// sequentially on seeds. (This for example works around a problem with the
+// preview version 0.7.2 of XXH3 used in RocksDB, a.k.a. XXPH3 or Hash64, and
+// MurmurHash1 used in RocksDB, a.k.a. Hash.) We say this pre-mixing step
+// translates "ordinal seeds," which we iterate sequentially to find a
+// solution, into "raw seeds," with many more bits changing for each
+// iteration. The translation is an easily reversible lightweight mixing,
+// not suitable for hashing on its own. An advantage of this approach is that
+// StandardHasher can store just the raw seed (e.g. 64 bits) for fast query
+// times, while from the application perspective, we can limit to a small
+// number of ordinal keys (e.g. 64 in 6 bits) for saving in metadata.
+//
+// The default constructor initializes the seed to ordinal seed zero, which
+// is equal to raw seed zero.
+//
+template <class TypesAndSettings>
+class StandardHasher {
+ public:
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings);
+
+ inline Hash GetHash(const Key& key) const {
+ return TypesAndSettings::HashFn(key, raw_seed_);
+ };
+ // For when AddInput == pair<Key, ResultRow> (kIsFilter == false)
+ inline Hash GetHash(const std::pair<Key, ResultRow>& bi) const {
+ return GetHash(bi.first);
+ };
+ inline Index GetStart(Hash h, Index num_starts) const {
+ // This is "critical path" code because it's required before memory
+ // lookup.
+ //
+ // FastRange gives us a fast and effective mapping from h to the
+ // appropriate range. This depends most, sometimes exclusively, on
+ // upper bits of h.
+ //
+ if (TypesAndSettings::kUseSmash) {
+ // Extra logic to "smash" entries at beginning and end, for
+ // better utilization. For example, without smash and with
+ // kFirstCoeffAlwaysOne, there's about a 30% chance that the
+ // first slot in the banding will be unused, and worse without
+ // kFirstCoeffAlwaysOne. The ending slots are even less utilized
+ // without smash.
+ //
+ // But since this only affects roughly kCoeffBits of the slots,
+ // it's usually small enough to be ignorable (less computation in
+ // this function) when number of slots is roughly 10k or larger.
+ //
+ // The best values for these smash weights might depend on how
+ // densely you're packing entries, and also kCoeffBits, but this
+ // seems to work well for roughly 95% success probability.
+ //
+ constexpr Index kFrontSmash = kCoeffBits / 4;
+ constexpr Index kBackSmash = kCoeffBits / 4;
+ Index start = FastRangeGeneric(h, num_starts + kFrontSmash + kBackSmash);
+ start = std::max(start, kFrontSmash);
+ start -= kFrontSmash;
+ start = std::min(start, num_starts - 1);
+ return start;
+ } else {
+ // For query speed, we allow small number of initial and final
+ // entries to be under-utilized.
+ // NOTE: This call statically enforces that Hash is equivalent to
+ // either uint32_t or uint64_t.
+ return FastRangeGeneric(h, num_starts);
+ }
+ }
+ inline CoeffRow GetCoeffRow(Hash h) const {
+ // This is not so much "critical path" code because it can be done in
+ // parallel (instruction level) with memory lookup.
+ //
+ // When we might have many entries squeezed into a single start,
+ // we need reasonably good remixing for CoeffRow.
+ if (TypesAndSettings::kUseSmash) {
+ // Reasonably good, reasonably fast, reasonably general.
+ // Probably not 1:1 but probably close enough.
+ Unsigned128 a = Multiply64to128(h, kAltCoeffFactor1);
+ Unsigned128 b = Multiply64to128(h, kAltCoeffFactor2);
+ auto cr = static_cast<CoeffRow>(b ^ (a << 64) ^ (a >> 64));
+
+ // Now ensure the value is non-zero
+ if (kFirstCoeffAlwaysOne) {
+ cr |= 1;
+ } else {
+ // Still have to ensure some bit is non-zero
+ cr |= (cr == 0) ? 1 : 0;
+ }
+ return cr;
+ }
+ // If not kUseSmash, we ensure we're not squeezing many entries into a
+ // single start, in part by ensuring num_starts > num_slots / 2. Thus,
+ // here we do not need good remixing for CoeffRow, but just enough that
+ // (a) every bit is reasonably independent from Start.
+ // (b) every Hash-length bit subsequence of the CoeffRow has full or
+ // nearly full entropy from h.
+ // (c) if nontrivial bit subsequences within are correlated, it needs to
+ // be more complicated than exact copy or bitwise not (at least without
+ // kFirstCoeffAlwaysOne), or else there seems to be a kind of
+ // correlated clustering effect.
+ // (d) the CoeffRow is not zero, so that no one input on its own can
+ // doom construction success. (Preferably a mix of 1's and 0's if
+ // satisfying above.)
+
+ // First, establish sufficient bitwise independence from Start, with
+ // multiplication by a large random prime.
+ // Note that we cast to Hash because if we use product bits beyond
+ // original input size, that's going to correlate with Start (FastRange)
+ // even with a (likely) different multiplier here.
+ Hash a = h * kCoeffAndResultFactor;
+
+ static_assert(
+ sizeof(Hash) == sizeof(uint64_t) || sizeof(Hash) == sizeof(uint32_t),
+ "Supported sizes");
+ // If that's big enough, we're done. If not, we have to expand it,
+ // maybe up to 4x size.
+ uint64_t b;
+ if (sizeof(Hash) < sizeof(uint64_t)) {
+ // Almost-trivial hash expansion (OK - see above), favoring roughly
+ // equal number of 1's and 0's in result
+ b = (uint64_t{a} << 32) ^ (a ^ kCoeffXor32);
+ } else {
+ b = a;
+ }
+ static_assert(sizeof(CoeffRow) <= sizeof(Unsigned128), "Supported sizes");
+ Unsigned128 c;
+ if (sizeof(uint64_t) < sizeof(CoeffRow)) {
+ // Almost-trivial hash expansion (OK - see above), favoring roughly
+ // equal number of 1's and 0's in result
+ c = (Unsigned128{b} << 64) ^ (b ^ kCoeffXor64);
+ } else {
+ c = b;
+ }
+ auto cr = static_cast<CoeffRow>(c);
+
+ // Now ensure the value is non-zero
+ if (kFirstCoeffAlwaysOne) {
+ cr |= 1;
+ } else if (sizeof(CoeffRow) == sizeof(Hash)) {
+ // Still have to ensure some bit is non-zero
+ cr |= (cr == 0) ? 1 : 0;
+ } else {
+ // (We did trivial expansion with constant xor, which ensures some
+ // bits are non-zero.)
+ }
+ return cr;
+ }
+ inline ResultRow GetResultRowMask() const {
+ // TODO: will be used with InterleavedSolutionStorage?
+ // For now, all bits set (note: might be a small type so might need to
+ // narrow after promotion)
+ return static_cast<ResultRow>(~ResultRow{0});
+ }
+ inline ResultRow GetResultRowFromHash(Hash h) const {
+ if (TypesAndSettings::kIsFilter && !TypesAndSettings::kHomogeneous) {
+ // This is not so much "critical path" code because it can be done in
+ // parallel (instruction level) with memory lookup.
+ //
+ // ResultRow bits only needs to be independent from CoeffRow bits if
+ // many entries might have the same start location, where "many" is
+ // comparable to number of hash bits or kCoeffBits. If !kUseSmash
+ // and num_starts > kCoeffBits, it is safe and efficient to draw from
+ // the same bits computed for CoeffRow, which are reasonably
+ // independent from Start. (Inlining and common subexpression
+ // elimination with GetCoeffRow should make this
+ // a single shared multiplication in generated code when !kUseSmash.)
+ Hash a = h * kCoeffAndResultFactor;
+
+ // The bits here that are *most* independent of Start are the highest
+ // order bits (as in Knuth multiplicative hash). To make those the
+ // most preferred for use in the result row, we do a bswap here.
+ auto rr = static_cast<ResultRow>(EndianSwapValue(a));
+ return rr & GetResultRowMask();
+ } else {
+ // Must be zero
+ return 0;
+ }
+ }
+ // For when AddInput == Key (kIsFilter == true)
+ inline ResultRow GetResultRowFromInput(const Key&) const {
+ // Must be zero
+ return 0;
+ }
+ // For when AddInput == pair<Key, ResultRow> (kIsFilter == false)
+ inline ResultRow GetResultRowFromInput(
+ const std::pair<Key, ResultRow>& bi) const {
+ // Simple extraction
+ return bi.second;
+ }
+
+ // Seed tracking APIs - see class comment
+ void SetRawSeed(Seed seed) { raw_seed_ = seed; }
+ Seed GetRawSeed() { return raw_seed_; }
+ void SetOrdinalSeed(Seed count) {
+ // A simple, reversible mixing of any size (whole bytes) up to 64 bits.
+ // This allows casting the raw seed to any smaller size we use for
+ // ordinal seeds without risk of duplicate raw seeds for unique ordinal
+ // seeds.
+
+ // Seed type might be smaller than numerical promotion size, but Hash
+ // should be at least that size, so we use Hash as intermediate type.
+ static_assert(sizeof(Seed) <= sizeof(Hash),
+ "Hash must be at least size of Seed");
+
+ // Multiply by a large random prime (one-to-one for any prefix of bits)
+ Hash tmp = count * kToRawSeedFactor;
+ // Within-byte one-to-one mixing
+ static_assert((kSeedMixMask & (kSeedMixMask >> kSeedMixShift)) == 0,
+ "Illegal mask+shift");
+ tmp ^= (tmp & kSeedMixMask) >> kSeedMixShift;
+ raw_seed_ = static_cast<Seed>(tmp);
+ // dynamic verification
+ assert(GetOrdinalSeed() == count);
+ }
+ Seed GetOrdinalSeed() {
+ Hash tmp = raw_seed_;
+ // Within-byte one-to-one mixing (its own inverse)
+ tmp ^= (tmp & kSeedMixMask) >> kSeedMixShift;
+ // Multiply by 64-bit multiplicative inverse
+ static_assert(kToRawSeedFactor * kFromRawSeedFactor == Hash{1},
+ "Must be inverses");
+ return static_cast<Seed>(tmp * kFromRawSeedFactor);
+ }
+
+ protected:
+ // For expanding hash:
+ // large random prime
+ static constexpr Hash kCoeffAndResultFactor =
+ static_cast<Hash>(0xc28f82822b650bedULL);
+ static constexpr uint64_t kAltCoeffFactor1 = 0x876f170be4f1fcb9U;
+ static constexpr uint64_t kAltCoeffFactor2 = 0xf0433a4aecda4c5fU;
+ // random-ish data
+ static constexpr uint32_t kCoeffXor32 = 0xa6293635U;
+ static constexpr uint64_t kCoeffXor64 = 0xc367844a6e52731dU;
+
+ // For pre-mixing seeds
+ static constexpr Hash kSeedMixMask = static_cast<Hash>(0xf0f0f0f0f0f0f0f0ULL);
+ static constexpr unsigned kSeedMixShift = 4U;
+ static constexpr Hash kToRawSeedFactor =
+ static_cast<Hash>(0xc78219a23eeadd03ULL);
+ static constexpr Hash kFromRawSeedFactor =
+ static_cast<Hash>(0xfe1a137d14b475abULL);
+
+ // See class description
+ Seed raw_seed_ = 0;
+};
+
+// StandardRehasher (and StandardRehasherAdapter): A variant of
+// StandardHasher that uses the same type for keys as for hashes.
+// This is primarily intended for building a Ribbon filter
+// from existing hashes without going back to original inputs in
+// order to apply a different seed. This hasher seeds a 1-to-1 mixing
+// transformation to apply a seed to an existing hash. (Untested for
+// hash-sized keys that are not already uniformly distributed.) This
+// transformation builds on the seed pre-mixing done in StandardHasher.
+//
+// Testing suggests essentially no degradation of solution success rate
+// vs. going back to original inputs when changing hash seeds. For example:
+// Average re-seeds for solution with r=128, 1.02x overhead, and ~100k keys
+// is about 1.10 for both StandardHasher and StandardRehasher.
+//
+// StandardRehasher is not really recommended for general PHSFs (not
+// filters) because a collision in the original hash could prevent
+// construction despite re-seeding the Rehasher. (Such collisions
+// do not interfere with filter construction.)
+//
+// concept RehasherTypesAndSettings: like TypesAndSettings but
+// does not require Key or HashFn.
+template <class RehasherTypesAndSettings>
+class StandardRehasherAdapter : public RehasherTypesAndSettings {
+ public:
+ using Hash = typename RehasherTypesAndSettings::Hash;
+ using Key = Hash;
+ using Seed = typename RehasherTypesAndSettings::Seed;
+
+ static Hash HashFn(const Hash& input, Seed raw_seed) {
+ // Note: raw_seed is already lightly pre-mixed, and this multiplication
+ // by a large prime is sufficient mixing (low-to-high bits) on top of
+ // that for good FastRange results, which depends primarily on highest
+ // bits. (The hashed CoeffRow and ResultRow are less sensitive to
+ // mixing than Start.)
+ // Also note: did consider adding ^ (input >> some) before the
+ // multiplication, but doesn't appear to be necessary.
+ return (input ^ raw_seed) * kRehashFactor;
+ }
+
+ private:
+ static constexpr Hash kRehashFactor =
+ static_cast<Hash>(0x6193d459236a3a0dULL);
+};
+
+// See comment on StandardRehasherAdapter
+template <class RehasherTypesAndSettings>
+using StandardRehasher =
+ StandardHasher<StandardRehasherAdapter<RehasherTypesAndSettings>>;
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+// Especially with smaller hashes (e.g. 32 bit), there can be noticeable
+// false positives due to collisions in the Hash returned by GetHash.
+// This function returns the expected FP rate due to those collisions,
+// which can be added to the expected FP rate from the underlying data
+// structure. (Note: technically, a + b is only a good approximation of
+// 1-(1-a)(1-b) == a + b - a*b, if a and b are much closer to 0 than to 1.)
+// The number of entries added can be a double here in case it's an
+// average.
+template <class Hasher, typename Numerical>
+double ExpectedCollisionFpRate(const Hasher& hasher, Numerical added) {
+ // Standardize on the 'double' specialization
+ return ExpectedCollisionFpRate(hasher, 1.0 * added);
+}
+template <class Hasher>
+double ExpectedCollisionFpRate(const Hasher& /*hasher*/, double added) {
+ // Technically, there could be overlap among the added, but ignoring that
+ // is typically close enough.
+ return added / std::pow(256.0, sizeof(typename Hasher::Hash));
+}
+
+// StandardBanding: a canonical implementation of BandingStorage and
+// BacktrackStorage, with convenience API for banding (solving with on-the-fly
+// Gaussian elimination) with and without backtracking.
+template <class TypesAndSettings>
+class StandardBanding : public StandardHasher<TypesAndSettings> {
+ public:
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings);
+
+ StandardBanding(Index num_slots = 0, Index backtrack_size = 0) {
+ Reset(num_slots, backtrack_size);
+ }
+
+ void Reset(Index num_slots, Index backtrack_size = 0) {
+ if (num_slots == 0) {
+ // Unusual (TypesAndSettings::kAllowZeroStarts) or "uninitialized"
+ num_starts_ = 0;
+ } else {
+ // Normal
+ assert(num_slots >= kCoeffBits);
+ if (num_slots > num_slots_allocated_) {
+ coeff_rows_.reset(new CoeffRow[num_slots]());
+ if (!TypesAndSettings::kHomogeneous) {
+ // Note: don't strictly have to zero-init result_rows,
+ // except possible information leakage, etc ;)
+ result_rows_.reset(new ResultRow[num_slots]());
+ }
+ num_slots_allocated_ = num_slots;
+ } else {
+ for (Index i = 0; i < num_slots; ++i) {
+ coeff_rows_[i] = 0;
+ if (!TypesAndSettings::kHomogeneous) {
+ // Note: don't strictly have to zero-init result_rows,
+ // except possible information leakage, etc ;)
+ result_rows_[i] = 0;
+ }
+ }
+ }
+ num_starts_ = num_slots - kCoeffBits + 1;
+ }
+ EnsureBacktrackSize(backtrack_size);
+ }
+
+ void EnsureBacktrackSize(Index backtrack_size) {
+ if (backtrack_size > backtrack_size_) {
+ backtrack_.reset(new Index[backtrack_size]);
+ backtrack_size_ = backtrack_size;
+ }
+ }
+
+ // ********************************************************************
+ // From concept BandingStorage
+
+ inline bool UsePrefetch() const {
+ // A rough guesstimate of when prefetching during construction pays off.
+ // TODO: verify/validate
+ return num_starts_ > 1500;
+ }
+ inline void Prefetch(Index i) const {
+ PREFETCH(&coeff_rows_[i], 1 /* rw */, 1 /* locality */);
+ if (!TypesAndSettings::kHomogeneous) {
+ PREFETCH(&result_rows_[i], 1 /* rw */, 1 /* locality */);
+ }
+ }
+ inline void LoadRow(Index i, CoeffRow* cr, ResultRow* rr,
+ bool for_back_subst) const {
+ *cr = coeff_rows_[i];
+ if (TypesAndSettings::kHomogeneous) {
+ if (for_back_subst && *cr == 0) {
+ // Cheap pseudorandom data to fill unconstrained solution rows
+ *rr = static_cast<ResultRow>(i * 0x9E3779B185EBCA87ULL);
+ } else {
+ *rr = 0;
+ }
+ } else {
+ *rr = result_rows_[i];
+ }
+ }
+ inline void StoreRow(Index i, CoeffRow cr, ResultRow rr) {
+ coeff_rows_[i] = cr;
+ if (TypesAndSettings::kHomogeneous) {
+ assert(rr == 0);
+ } else {
+ result_rows_[i] = rr;
+ }
+ }
+ inline Index GetNumStarts() const { return num_starts_; }
+
+ // from concept BacktrackStorage, for when backtracking is used
+ inline bool UseBacktrack() const { return true; }
+ inline void BacktrackPut(Index i, Index to_save) { backtrack_[i] = to_save; }
+ inline Index BacktrackGet(Index i) const { return backtrack_[i]; }
+
+ // ********************************************************************
+ // Some useful API, still somewhat low level. Here an input is
+ // a Key for filters, or std::pair<Key, ResultRow> for general PHSF.
+
+ // Adds a range of inputs to the banding, returning true if successful.
+ // False means none or some may have been successfully added, so it's
+ // best to Reset this banding before any further use.
+ //
+ // Adding can fail even before all the "slots" are completely "full".
+ //
+ template <typename InputIterator>
+ bool AddRange(InputIterator begin, InputIterator end) {
+ assert(num_starts_ > 0 || TypesAndSettings::kAllowZeroStarts);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual. Can't add any in this case.
+ return begin == end;
+ }
+ // Normal
+ return BandingAddRange(this, *this, begin, end);
+ }
+
+ // Adds a range of inputs to the banding, returning true if successful,
+ // or if unsuccessful, rolls back to state before this call and returns
+ // false. Caller guarantees that the number of inputs in this batch
+ // does not exceed `backtrack_size` provided to Reset.
+ //
+ // Adding can fail even before all the "slots" are completely "full".
+ //
+ template <typename InputIterator>
+ bool AddRangeOrRollBack(InputIterator begin, InputIterator end) {
+ assert(num_starts_ > 0 || TypesAndSettings::kAllowZeroStarts);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual. Can't add any in this case.
+ return begin == end;
+ }
+ // else Normal
+ return BandingAddRange(this, this, *this, begin, end);
+ }
+
+ // Adds a single input to the banding, returning true if successful.
+ // If unsuccessful, returns false and banding state is unchanged.
+ //
+ // Adding can fail even before all the "slots" are completely "full".
+ //
+ bool Add(const AddInput& input) {
+ // Pointer can act as iterator
+ return AddRange(&input, &input + 1);
+ }
+
+ // Return the number of "occupied" rows (with non-zero coefficients stored).
+ Index GetOccupiedCount() const {
+ Index count = 0;
+ if (num_starts_ > 0) {
+ const Index num_slots = num_starts_ + kCoeffBits - 1;
+ for (Index i = 0; i < num_slots; ++i) {
+ if (coeff_rows_[i] != 0) {
+ ++count;
+ }
+ }
+ }
+ return count;
+ }
+
+ // Returns whether a row is "occupied" in the banding (non-zero
+ // coefficients stored). (Only recommended for debug/test)
+ bool IsOccupied(Index i) { return coeff_rows_[i] != 0; }
+
+ // ********************************************************************
+ // High-level API
+
+ // Iteratively (a) resets the structure for `num_slots`, (b) attempts
+ // to add the range of inputs, and (c) if unsuccessful, chooses next
+ // hash seed, until either successful or unsuccessful with all the
+ // allowed seeds. Returns true if successful. In that case, use
+ // GetOrdinalSeed() or GetRawSeed() to get the successful seed.
+ //
+ // The allowed sequence of hash seeds is determined by
+ // `starting_ordinal_seed,` the first ordinal seed to be attempted
+ // (see StandardHasher), and `ordinal_seed_mask,` a bit mask (power of
+ // two minus one) for the range of ordinal seeds to consider. The
+ // max number of seeds considered will be ordinal_seed_mask + 1.
+ // For filters we suggest `starting_ordinal_seed` be chosen randomly
+ // or round-robin, to minimize false positive correlations between keys.
+ //
+ // If unsuccessful, how best to continue is going to be application
+ // specific. It should be possible to choose parameters such that
+ // failure is extremely unlikely, using max_seed around 32 to 64.
+ // (TODO: APIs to help choose parameters) One option for fallback in
+ // constructing a filter is to construct a Bloom filter instead.
+ // Increasing num_slots is an option, but should not be used often
+ // unless construction maximum latency is a concern (rather than
+ // average running time of construction). Instead, choose parameters
+ // appropriately and trust that seeds are independent. (Also,
+ // increasing num_slots without changing hash seed would have a
+ // significant correlation in success, rather than independence.)
+ template <typename InputIterator>
+ bool ResetAndFindSeedToSolve(Index num_slots, InputIterator begin,
+ InputIterator end,
+ Seed starting_ordinal_seed = 0U,
+ Seed ordinal_seed_mask = 63U) {
+ // power of 2 minus 1
+ assert((ordinal_seed_mask & (ordinal_seed_mask + 1)) == 0);
+ // starting seed is within mask
+ assert((starting_ordinal_seed & ordinal_seed_mask) ==
+ starting_ordinal_seed);
+ starting_ordinal_seed &= ordinal_seed_mask; // if not debug
+
+ Seed cur_ordinal_seed = starting_ordinal_seed;
+ do {
+ StandardHasher<TypesAndSettings>::SetOrdinalSeed(cur_ordinal_seed);
+ Reset(num_slots);
+ bool success = AddRange(begin, end);
+ if (success) {
+ return true;
+ }
+ cur_ordinal_seed = (cur_ordinal_seed + 1) & ordinal_seed_mask;
+ } while (cur_ordinal_seed != starting_ordinal_seed);
+ // Reached limit by circling around
+ return false;
+ }
+
+ static std::size_t EstimateMemoryUsage(uint32_t num_slots) {
+ std::size_t bytes_coeff_rows = num_slots * sizeof(CoeffRow);
+ std::size_t bytes_result_rows = num_slots * sizeof(ResultRow);
+ std::size_t bytes_backtrack = 0;
+ std::size_t bytes_banding =
+ bytes_coeff_rows + bytes_result_rows + bytes_backtrack;
+
+ return bytes_banding;
+ }
+
+ protected:
+ // TODO: explore combining in a struct
+ std::unique_ptr<CoeffRow[]> coeff_rows_;
+ std::unique_ptr<ResultRow[]> result_rows_;
+ // We generally store "starts" instead of slots for speed of GetStart(),
+ // as in StandardHasher.
+ Index num_starts_ = 0;
+ Index num_slots_allocated_ = 0;
+ std::unique_ptr<Index[]> backtrack_;
+ Index backtrack_size_ = 0;
+};
+
+// Implements concept SimpleSolutionStorage, mostly for demonstration
+// purposes. This is "in memory" only because it does not handle byte
+// ordering issues for serialization.
+template <class TypesAndSettings>
+class InMemSimpleSolution {
+ public:
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings);
+
+ void PrepareForNumStarts(Index num_starts) {
+ if (TypesAndSettings::kAllowZeroStarts && num_starts == 0) {
+ // Unusual
+ num_starts_ = 0;
+ } else {
+ // Normal
+ const Index num_slots = num_starts + kCoeffBits - 1;
+ assert(num_slots >= kCoeffBits);
+ if (num_slots > num_slots_allocated_) {
+ // Do not need to init the memory
+ solution_rows_.reset(new ResultRow[num_slots]);
+ num_slots_allocated_ = num_slots;
+ }
+ num_starts_ = num_starts;
+ }
+ }
+
+ Index GetNumStarts() const { return num_starts_; }
+
+ ResultRow Load(Index slot_num) const { return solution_rows_[slot_num]; }
+
+ void Store(Index slot_num, ResultRow solution_row) {
+ solution_rows_[slot_num] = solution_row;
+ }
+
+ // ********************************************************************
+ // High-level API
+
+ template <typename BandingStorage>
+ void BackSubstFrom(const BandingStorage& bs) {
+ if (TypesAndSettings::kAllowZeroStarts && bs.GetNumStarts() == 0) {
+ // Unusual
+ PrepareForNumStarts(0);
+ } else {
+ // Normal
+ SimpleBackSubst(this, bs);
+ }
+ }
+
+ template <typename PhsfQueryHasher>
+ ResultRow PhsfQuery(const Key& input, const PhsfQueryHasher& hasher) const {
+ // assert(!TypesAndSettings::kIsFilter); Can be useful in testing
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual
+ return 0;
+ } else {
+ // Normal
+ return SimplePhsfQuery(input, hasher, *this);
+ }
+ }
+
+ template <typename FilterQueryHasher>
+ bool FilterQuery(const Key& input, const FilterQueryHasher& hasher) const {
+ assert(TypesAndSettings::kIsFilter);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual. Zero starts presumes no keys added -> always false
+ return false;
+ } else {
+ // Normal, or upper_num_columns_ == 0 means "no space for data" and
+ // thus will always return true.
+ return SimpleFilterQuery(input, hasher, *this);
+ }
+ }
+
+ double ExpectedFpRate() const {
+ assert(TypesAndSettings::kIsFilter);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual, but we don't have FPs if we always return false.
+ return 0.0;
+ }
+ // else Normal
+
+ // Each result (solution) bit (column) cuts FP rate in half
+ return std::pow(0.5, 8U * sizeof(ResultRow));
+ }
+
+ // ********************************************************************
+ // Static high-level API
+
+ // Round up to a number of slots supported by this structure. Note that
+ // this needs to be must be taken into account for the banding if this
+ // solution layout/storage is to be used.
+ static Index RoundUpNumSlots(Index num_slots) {
+ // Must be at least kCoeffBits for at least one start
+ // Or if not smash, even more because hashing not equipped
+ // for stacking up so many entries on a single start location
+ auto min_slots = kCoeffBits * (TypesAndSettings::kUseSmash ? 1 : 2);
+ return std::max(num_slots, static_cast<Index>(min_slots));
+ }
+
+ protected:
+ // We generally store "starts" instead of slots for speed of GetStart(),
+ // as in StandardHasher.
+ Index num_starts_ = 0;
+ Index num_slots_allocated_ = 0;
+ std::unique_ptr<ResultRow[]> solution_rows_;
+};
+
+// Implements concept InterleavedSolutionStorage always using little-endian
+// byte order, so easy for serialization/deserialization. This implementation
+// fully supports fractional bits per key, where any number of segments
+// (number of bytes multiple of sizeof(CoeffRow)) can be used with any number
+// of slots that is a multiple of kCoeffBits.
+//
+// The structure is passed an externally allocated/de-allocated byte buffer
+// that is optionally pre-populated (from storage) for answering queries,
+// or can be populated by BackSubstFrom.
+//
+template <class TypesAndSettings>
+class SerializableInterleavedSolution {
+ public:
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings);
+
+ // Does not take ownership of `data` but uses it (up to `data_len` bytes)
+ // throughout lifetime
+ SerializableInterleavedSolution(char* data, size_t data_len)
+ : data_(data), data_len_(data_len) {}
+
+ void PrepareForNumStarts(Index num_starts) {
+ assert(num_starts == 0 || (num_starts % kCoeffBits == 1));
+ num_starts_ = num_starts;
+
+ InternalConfigure();
+ }
+
+ Index GetNumStarts() const { return num_starts_; }
+
+ Index GetNumBlocks() const {
+ const Index num_slots = num_starts_ + kCoeffBits - 1;
+ return num_slots / kCoeffBits;
+ }
+
+ Index GetUpperNumColumns() const { return upper_num_columns_; }
+
+ Index GetUpperStartBlock() const { return upper_start_block_; }
+
+ Index GetNumSegments() const {
+ return static_cast<Index>(data_len_ / sizeof(CoeffRow));
+ }
+
+ CoeffRow LoadSegment(Index segment_num) const {
+ assert(data_ != nullptr); // suppress clang analyzer report
+ return DecodeFixedGeneric<CoeffRow>(data_ + segment_num * sizeof(CoeffRow));
+ }
+ void StoreSegment(Index segment_num, CoeffRow val) {
+ assert(data_ != nullptr); // suppress clang analyzer report
+ EncodeFixedGeneric(data_ + segment_num * sizeof(CoeffRow), val);
+ }
+ void PrefetchSegmentRange(Index begin_segment_num,
+ Index end_segment_num) const {
+ if (end_segment_num == begin_segment_num) {
+ // Nothing to do
+ return;
+ }
+ char* cur = data_ + begin_segment_num * sizeof(CoeffRow);
+ char* last = data_ + (end_segment_num - 1) * sizeof(CoeffRow);
+ while (cur < last) {
+ PREFETCH(cur, 0 /* rw */, 1 /* locality */);
+ cur += CACHE_LINE_SIZE;
+ }
+ PREFETCH(last, 0 /* rw */, 1 /* locality */);
+ }
+
+ // ********************************************************************
+ // High-level API
+
+ void ConfigureForNumBlocks(Index num_blocks) {
+ if (num_blocks == 0) {
+ PrepareForNumStarts(0);
+ } else {
+ PrepareForNumStarts(num_blocks * kCoeffBits - kCoeffBits + 1);
+ }
+ }
+
+ void ConfigureForNumSlots(Index num_slots) {
+ assert(num_slots % kCoeffBits == 0);
+ ConfigureForNumBlocks(num_slots / kCoeffBits);
+ }
+
+ template <typename BandingStorage>
+ void BackSubstFrom(const BandingStorage& bs) {
+ if (TypesAndSettings::kAllowZeroStarts && bs.GetNumStarts() == 0) {
+ // Unusual
+ PrepareForNumStarts(0);
+ } else {
+ // Normal
+ InterleavedBackSubst(this, bs);
+ }
+ }
+
+ template <typename PhsfQueryHasher>
+ ResultRow PhsfQuery(const Key& input, const PhsfQueryHasher& hasher) const {
+ // assert(!TypesAndSettings::kIsFilter); Can be useful in testing
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual
+ return 0;
+ } else {
+ // Normal
+ // NOTE: not using a struct to encourage compiler optimization
+ Hash hash;
+ Index segment_num;
+ Index num_columns;
+ Index start_bit;
+ InterleavedPrepareQuery(input, hasher, *this, &hash, &segment_num,
+ &num_columns, &start_bit);
+ return InterleavedPhsfQuery(hash, segment_num, num_columns, start_bit,
+ hasher, *this);
+ }
+ }
+
+ template <typename FilterQueryHasher>
+ bool FilterQuery(const Key& input, const FilterQueryHasher& hasher) const {
+ assert(TypesAndSettings::kIsFilter);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual. Zero starts presumes no keys added -> always false
+ return false;
+ } else {
+ // Normal, or upper_num_columns_ == 0 means "no space for data" and
+ // thus will always return true.
+ // NOTE: not using a struct to encourage compiler optimization
+ Hash hash;
+ Index segment_num;
+ Index num_columns;
+ Index start_bit;
+ InterleavedPrepareQuery(input, hasher, *this, &hash, &segment_num,
+ &num_columns, &start_bit);
+ return InterleavedFilterQuery(hash, segment_num, num_columns, start_bit,
+ hasher, *this);
+ }
+ }
+
+ double ExpectedFpRate() const {
+ assert(TypesAndSettings::kIsFilter);
+ if (TypesAndSettings::kAllowZeroStarts && num_starts_ == 0) {
+ // Unusual. Zero starts presumes no keys added -> always false
+ return 0.0;
+ }
+ // else Normal
+
+ // Note: Ignoring smash setting; still close enough in that case
+ double lower_portion =
+ (upper_start_block_ * 1.0 * kCoeffBits) / num_starts_;
+
+ // Each result (solution) bit (column) cuts FP rate in half. Weight that
+ // for upper and lower number of bits (columns).
+ return lower_portion * std::pow(0.5, upper_num_columns_ - 1) +
+ (1.0 - lower_portion) * std::pow(0.5, upper_num_columns_);
+ }
+
+ // ********************************************************************
+ // Static high-level API
+
+ // Round up to a number of slots supported by this structure. Note that
+ // this needs to be must be taken into account for the banding if this
+ // solution layout/storage is to be used.
+ static Index RoundUpNumSlots(Index num_slots) {
+ // Must be multiple of kCoeffBits
+ Index corrected = (num_slots + kCoeffBits - 1) / kCoeffBits * kCoeffBits;
+
+ // Do not use num_starts==1 unless kUseSmash, because the hashing
+ // might not be equipped for stacking up so many entries on a
+ // single start location.
+ if (!TypesAndSettings::kUseSmash && corrected == kCoeffBits) {
+ corrected += kCoeffBits;
+ }
+ return corrected;
+ }
+
+ // Round down to a number of slots supported by this structure. Note that
+ // this needs to be must be taken into account for the banding if this
+ // solution layout/storage is to be used.
+ static Index RoundDownNumSlots(Index num_slots) {
+ // Must be multiple of kCoeffBits
+ Index corrected = num_slots / kCoeffBits * kCoeffBits;
+
+ // Do not use num_starts==1 unless kUseSmash, because the hashing
+ // might not be equipped for stacking up so many entries on a
+ // single start location.
+ if (!TypesAndSettings::kUseSmash && corrected == kCoeffBits) {
+ corrected = 0;
+ }
+ return corrected;
+ }
+
+ // Compute the number of bytes for a given number of slots and desired
+ // FP rate. Since desired FP rate might not be exactly achievable,
+ // rounding_bias32==0 means to always round toward lower FP rate
+ // than desired (more bytes); rounding_bias32==max uint32_t means always
+ // round toward higher FP rate than desired (fewer bytes); other values
+ // act as a proportional threshold or bias between the two.
+ static size_t GetBytesForFpRate(Index num_slots, double desired_fp_rate,
+ uint32_t rounding_bias32) {
+ return InternalGetBytesForFpRate(num_slots, desired_fp_rate,
+ 1.0 / desired_fp_rate, rounding_bias32);
+ }
+
+ // The same, but specifying desired accuracy as 1.0 / FP rate, or
+ // one_in_fp_rate. E.g. desired_one_in_fp_rate=100 means 1% FP rate.
+ static size_t GetBytesForOneInFpRate(Index num_slots,
+ double desired_one_in_fp_rate,
+ uint32_t rounding_bias32) {
+ return InternalGetBytesForFpRate(num_slots, 1.0 / desired_one_in_fp_rate,
+ desired_one_in_fp_rate, rounding_bias32);
+ }
+
+ protected:
+ static size_t InternalGetBytesForFpRate(Index num_slots,
+ double desired_fp_rate,
+ double desired_one_in_fp_rate,
+ uint32_t rounding_bias32) {
+ assert(TypesAndSettings::kIsFilter);
+ if (TypesAndSettings::kAllowZeroStarts) {
+ if (num_slots == 0) {
+ // Unusual. Zero starts presumes no keys added -> always false (no FPs)
+ return 0U;
+ }
+ } else {
+ assert(num_slots > 0);
+ }
+ // Must be rounded up already.
+ assert(RoundUpNumSlots(num_slots) == num_slots);
+
+ if (desired_one_in_fp_rate > 1.0 && desired_fp_rate < 1.0) {
+ // Typical: less than 100% FP rate
+ if (desired_one_in_fp_rate <= static_cast<ResultRow>(-1)) {
+ // Typical: Less than maximum result row entropy
+ ResultRow rounded = static_cast<ResultRow>(desired_one_in_fp_rate);
+ int lower_columns = FloorLog2(rounded);
+ double lower_columns_fp_rate = std::pow(2.0, -lower_columns);
+ double upper_columns_fp_rate = std::pow(2.0, -(lower_columns + 1));
+ // Floating point don't let me down!
+ assert(lower_columns_fp_rate >= desired_fp_rate);
+ assert(upper_columns_fp_rate <= desired_fp_rate);
+
+ double lower_portion = (desired_fp_rate - upper_columns_fp_rate) /
+ (lower_columns_fp_rate - upper_columns_fp_rate);
+ // Floating point don't let me down!
+ assert(lower_portion >= 0.0);
+ assert(lower_portion <= 1.0);
+
+ double rounding_bias = (rounding_bias32 + 0.5) / double{0x100000000};
+ assert(rounding_bias > 0.0);
+ assert(rounding_bias < 1.0);
+
+ // Note: Ignoring smash setting; still close enough in that case
+ Index num_starts = num_slots - kCoeffBits + 1;
+ // Lower upper_start_block means lower FP rate (higher accuracy)
+ Index upper_start_block = static_cast<Index>(
+ (lower_portion * num_starts + rounding_bias) / kCoeffBits);
+ Index num_blocks = num_slots / kCoeffBits;
+ assert(upper_start_block < num_blocks);
+
+ // Start by assuming all blocks use lower number of columns
+ Index num_segments = num_blocks * static_cast<Index>(lower_columns);
+ // Correct by 1 each for blocks using upper number of columns
+ num_segments += (num_blocks - upper_start_block);
+ // Total bytes
+ return num_segments * sizeof(CoeffRow);
+ } else {
+ // one_in_fp_rate too big, thus requested FP rate is smaller than
+ // supported. Use max number of columns for minimum supported FP rate.
+ return num_slots * sizeof(ResultRow);
+ }
+ } else {
+ // Effectively asking for 100% FP rate, or NaN etc.
+ if (TypesAndSettings::kAllowZeroStarts) {
+ // Zero segments
+ return 0U;
+ } else {
+ // One segment (minimum size, maximizing FP rate)
+ return sizeof(CoeffRow);
+ }
+ }
+ }
+
+ void InternalConfigure() {
+ const Index num_blocks = GetNumBlocks();
+ Index num_segments = GetNumSegments();
+
+ if (num_blocks == 0) {
+ // Exceptional
+ upper_num_columns_ = 0;
+ upper_start_block_ = 0;
+ } else {
+ // Normal
+ upper_num_columns_ =
+ (num_segments + /*round up*/ num_blocks - 1) / num_blocks;
+ upper_start_block_ = upper_num_columns_ * num_blocks - num_segments;
+ // Unless that's more columns than supported by ResultRow data type
+ if (upper_num_columns_ > 8U * sizeof(ResultRow)) {
+ // Use maximum columns (there will be space unused)
+ upper_num_columns_ = static_cast<Index>(8U * sizeof(ResultRow));
+ upper_start_block_ = 0;
+ num_segments = num_blocks * upper_num_columns_;
+ }
+ }
+ // Update data_len_ for correct rounding and/or unused space
+ // NOTE: unused space stays gone if we PrepareForNumStarts again.
+ // We are prioritizing minimizing the number of fields over making
+ // the "unusued space" feature work well.
+ data_len_ = num_segments * sizeof(CoeffRow);
+ }
+
+ char* const data_;
+ size_t data_len_;
+ Index num_starts_ = 0;
+ Index upper_num_columns_ = 0;
+ Index upper_start_block_ = 0;
+};
+
+} // namespace ribbon
+
+} // namespace ROCKSDB_NAMESPACE
+
+// For convenience working with templates
+#define IMPORT_RIBBON_IMPL_TYPES(TypesAndSettings) \
+ using Hasher = ROCKSDB_NAMESPACE::ribbon::StandardHasher<TypesAndSettings>; \
+ using Banding = \
+ ROCKSDB_NAMESPACE::ribbon::StandardBanding<TypesAndSettings>; \
+ using SimpleSoln = \
+ ROCKSDB_NAMESPACE::ribbon::InMemSimpleSolution<TypesAndSettings>; \
+ using InterleavedSoln = \
+ ROCKSDB_NAMESPACE::ribbon::SerializableInterleavedSolution< \
+ TypesAndSettings>; \
+ static_assert(sizeof(Hasher) + sizeof(Banding) + sizeof(SimpleSoln) + \
+ sizeof(InterleavedSoln) > \
+ 0, \
+ "avoid unused warnings, semicolon expected after macro call")
diff --git a/src/rocksdb/util/ribbon_test.cc b/src/rocksdb/util/ribbon_test.cc
new file mode 100644
index 000000000..6519df3d5
--- /dev/null
+++ b/src/rocksdb/util/ribbon_test.cc
@@ -0,0 +1,1308 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/system_clock.h"
+#include "test_util/testharness.h"
+#include "util/bloom_impl.h"
+#include "util/coding.h"
+#include "util/hash.h"
+#include "util/ribbon_config.h"
+#include "util/ribbon_impl.h"
+#include "util/stop_watch.h"
+#include "util/string_util.h"
+
+#ifndef GFLAGS
+uint32_t FLAGS_thoroughness = 5;
+uint32_t FLAGS_max_add = 0;
+uint32_t FLAGS_min_check = 4000;
+uint32_t FLAGS_max_check = 100000;
+bool FLAGS_verbose = false;
+
+bool FLAGS_find_occ = false;
+bool FLAGS_find_slot_occ = false;
+double FLAGS_find_next_factor = 1.618;
+uint32_t FLAGS_find_iters = 10000;
+uint32_t FLAGS_find_min_slots = 128;
+uint32_t FLAGS_find_max_slots = 1000000;
+
+bool FLAGS_optimize_homog = false;
+uint32_t FLAGS_optimize_homog_slots = 30000000;
+uint32_t FLAGS_optimize_homog_check = 200000;
+double FLAGS_optimize_homog_granularity = 0.002;
+#else
+#include "util/gflags_compat.h"
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+// Using 500 is a good test when you have time to be thorough.
+// Default is for general RocksDB regression test runs.
+DEFINE_uint32(thoroughness, 5, "iterations per configuration");
+DEFINE_uint32(max_add, 0,
+ "Add up to this number of entries to a single filter in "
+ "CompactnessAndBacktrackAndFpRate; 0 == reasonable default");
+DEFINE_uint32(min_check, 4000,
+ "Minimum number of novel entries for testing FP rate");
+DEFINE_uint32(max_check, 10000,
+ "Maximum number of novel entries for testing FP rate");
+DEFINE_bool(verbose, false, "Print extra details");
+
+// Options for FindOccupancy, which is more of a tool than a test.
+DEFINE_bool(find_occ, false, "whether to run the FindOccupancy tool");
+DEFINE_bool(find_slot_occ, false,
+ "whether to show individual slot occupancies with "
+ "FindOccupancy tool");
+DEFINE_double(find_next_factor, 1.618,
+ "factor to next num_slots for FindOccupancy");
+DEFINE_uint32(find_iters, 10000, "number of samples for FindOccupancy");
+DEFINE_uint32(find_min_slots, 128, "number of slots for FindOccupancy");
+DEFINE_uint32(find_max_slots, 1000000, "number of slots for FindOccupancy");
+
+// Options for OptimizeHomogAtScale, which is more of a tool than a test.
+DEFINE_bool(optimize_homog, false,
+ "whether to run the OptimizeHomogAtScale tool");
+DEFINE_uint32(optimize_homog_slots, 30000000,
+ "number of slots for OptimizeHomogAtScale");
+DEFINE_uint32(optimize_homog_check, 200000,
+ "number of queries for checking FP rate in OptimizeHomogAtScale");
+DEFINE_double(
+ optimize_homog_granularity, 0.002,
+ "overhead change between FP rate checking in OptimizeHomogAtScale");
+
+#endif // GFLAGS
+
+template <typename TypesAndSettings>
+class RibbonTypeParamTest : public ::testing::Test {};
+
+class RibbonTest : public ::testing::Test {};
+
+namespace {
+
+// Different ways of generating keys for testing
+
+// Generate semi-sequential keys
+struct StandardKeyGen {
+ StandardKeyGen(const std::string& prefix, uint64_t id)
+ : id_(id), str_(prefix) {
+ ROCKSDB_NAMESPACE::PutFixed64(&str_, /*placeholder*/ 0);
+ }
+
+ // Prefix (only one required)
+ StandardKeyGen& operator++() {
+ ++id_;
+ return *this;
+ }
+
+ StandardKeyGen& operator+=(uint64_t i) {
+ id_ += i;
+ return *this;
+ }
+
+ const std::string& operator*() {
+ // Use multiplication to mix things up a little in the key
+ ROCKSDB_NAMESPACE::EncodeFixed64(&str_[str_.size() - 8],
+ id_ * uint64_t{0x1500000001});
+ return str_;
+ }
+
+ bool operator==(const StandardKeyGen& other) {
+ // Same prefix is assumed
+ return id_ == other.id_;
+ }
+ bool operator!=(const StandardKeyGen& other) {
+ // Same prefix is assumed
+ return id_ != other.id_;
+ }
+
+ uint64_t id_;
+ std::string str_;
+};
+
+// Generate small sequential keys, that can misbehave with sequential seeds
+// as in https://github.com/Cyan4973/xxHash/issues/469.
+// These keys are only heuristically unique, but that's OK with 64 bits,
+// for testing purposes.
+struct SmallKeyGen {
+ SmallKeyGen(const std::string& prefix, uint64_t id) : id_(id) {
+ // Hash the prefix for a heuristically unique offset
+ id_ += ROCKSDB_NAMESPACE::GetSliceHash64(prefix);
+ ROCKSDB_NAMESPACE::PutFixed64(&str_, id_);
+ }
+
+ // Prefix (only one required)
+ SmallKeyGen& operator++() {
+ ++id_;
+ return *this;
+ }
+
+ SmallKeyGen& operator+=(uint64_t i) {
+ id_ += i;
+ return *this;
+ }
+
+ const std::string& operator*() {
+ ROCKSDB_NAMESPACE::EncodeFixed64(&str_[str_.size() - 8], id_);
+ return str_;
+ }
+
+ bool operator==(const SmallKeyGen& other) { return id_ == other.id_; }
+ bool operator!=(const SmallKeyGen& other) { return id_ != other.id_; }
+
+ uint64_t id_;
+ std::string str_;
+};
+
+template <typename KeyGen>
+struct Hash32KeyGenWrapper : public KeyGen {
+ Hash32KeyGenWrapper(const std::string& prefix, uint64_t id)
+ : KeyGen(prefix, id) {}
+ uint32_t operator*() {
+ auto& key = *static_cast<KeyGen&>(*this);
+ // unseeded
+ return ROCKSDB_NAMESPACE::GetSliceHash(key);
+ }
+};
+
+template <typename KeyGen>
+struct Hash64KeyGenWrapper : public KeyGen {
+ Hash64KeyGenWrapper(const std::string& prefix, uint64_t id)
+ : KeyGen(prefix, id) {}
+ uint64_t operator*() {
+ auto& key = *static_cast<KeyGen&>(*this);
+ // unseeded
+ return ROCKSDB_NAMESPACE::GetSliceHash64(key);
+ }
+};
+
+using ROCKSDB_NAMESPACE::ribbon::ConstructionFailureChance;
+
+const std::vector<ConstructionFailureChance> kFailureOnly50Pct = {
+ ROCKSDB_NAMESPACE::ribbon::kOneIn2};
+
+const std::vector<ConstructionFailureChance> kFailureOnlyRare = {
+ ROCKSDB_NAMESPACE::ribbon::kOneIn1000};
+
+const std::vector<ConstructionFailureChance> kFailureAll = {
+ ROCKSDB_NAMESPACE::ribbon::kOneIn2, ROCKSDB_NAMESPACE::ribbon::kOneIn20,
+ ROCKSDB_NAMESPACE::ribbon::kOneIn1000};
+
+} // namespace
+
+using ROCKSDB_NAMESPACE::ribbon::ExpectedCollisionFpRate;
+using ROCKSDB_NAMESPACE::ribbon::StandardHasher;
+using ROCKSDB_NAMESPACE::ribbon::StandardRehasherAdapter;
+
+struct DefaultTypesAndSettings {
+ using CoeffRow = ROCKSDB_NAMESPACE::Unsigned128;
+ using ResultRow = uint8_t;
+ using Index = uint32_t;
+ using Hash = uint64_t;
+ using Seed = uint32_t;
+ using Key = ROCKSDB_NAMESPACE::Slice;
+ static constexpr bool kIsFilter = true;
+ static constexpr bool kHomogeneous = false;
+ static constexpr bool kFirstCoeffAlwaysOne = true;
+ static constexpr bool kUseSmash = false;
+ static constexpr bool kAllowZeroStarts = false;
+ static Hash HashFn(const Key& key, uint64_t raw_seed) {
+ // This version 0.7.2 preview of XXH3 (a.k.a. XXPH3) function does
+ // not pass SmallKeyGen tests below without some seed premixing from
+ // StandardHasher. See https://github.com/Cyan4973/xxHash/issues/469
+ return ROCKSDB_NAMESPACE::Hash64(key.data(), key.size(), raw_seed);
+ }
+ // For testing
+ using KeyGen = StandardKeyGen;
+ static const std::vector<ConstructionFailureChance>& FailureChanceToTest() {
+ return kFailureAll;
+ }
+};
+
+using TypesAndSettings_Coeff128 = DefaultTypesAndSettings;
+struct TypesAndSettings_Coeff128Smash : public DefaultTypesAndSettings {
+ static constexpr bool kUseSmash = true;
+};
+struct TypesAndSettings_Coeff64 : public DefaultTypesAndSettings {
+ using CoeffRow = uint64_t;
+};
+struct TypesAndSettings_Coeff64Smash : public TypesAndSettings_Coeff64 {
+ static constexpr bool kUseSmash = true;
+};
+struct TypesAndSettings_Coeff64Smash0 : public TypesAndSettings_Coeff64Smash {
+ static constexpr bool kFirstCoeffAlwaysOne = false;
+};
+
+// Homogeneous Ribbon configurations
+struct TypesAndSettings_Coeff128_Homog : public DefaultTypesAndSettings {
+ static constexpr bool kHomogeneous = true;
+ // Since our best construction success setting still has 1/1000 failure
+ // rate, the best FP rate we test is 1/256
+ using ResultRow = uint8_t;
+ // Homogeneous only makes sense with sufficient slots for equivalent of
+ // almost sure construction success
+ static const std::vector<ConstructionFailureChance>& FailureChanceToTest() {
+ return kFailureOnlyRare;
+ }
+};
+struct TypesAndSettings_Coeff128Smash_Homog
+ : public TypesAndSettings_Coeff128_Homog {
+ // Smash (extra time to save space) + Homog (extra space to save time)
+ // doesn't make much sense in practice, but we minimally test it
+ static constexpr bool kUseSmash = true;
+};
+struct TypesAndSettings_Coeff64_Homog : public TypesAndSettings_Coeff128_Homog {
+ using CoeffRow = uint64_t;
+};
+struct TypesAndSettings_Coeff64Smash_Homog
+ : public TypesAndSettings_Coeff64_Homog {
+ // Smash (extra time to save space) + Homog (extra space to save time)
+ // doesn't make much sense in practice, but we minimally test it
+ static constexpr bool kUseSmash = true;
+};
+
+// Less exhaustive mix of coverage, but still covering the most stressful case
+// (only 50% construction success)
+struct AbridgedTypesAndSettings : public DefaultTypesAndSettings {
+ static const std::vector<ConstructionFailureChance>& FailureChanceToTest() {
+ return kFailureOnly50Pct;
+ }
+};
+struct TypesAndSettings_Result16 : public AbridgedTypesAndSettings {
+ using ResultRow = uint16_t;
+};
+struct TypesAndSettings_Result32 : public AbridgedTypesAndSettings {
+ using ResultRow = uint32_t;
+};
+struct TypesAndSettings_IndexSizeT : public AbridgedTypesAndSettings {
+ using Index = size_t;
+};
+struct TypesAndSettings_Hash32 : public AbridgedTypesAndSettings {
+ using Hash = uint32_t;
+ static Hash HashFn(const Key& key, Hash raw_seed) {
+ // This MurmurHash1 function does not pass tests below without the
+ // seed premixing from StandardHasher. In fact, it needs more than
+ // just a multiplication mixer on the ordinal seed.
+ return ROCKSDB_NAMESPACE::Hash(key.data(), key.size(), raw_seed);
+ }
+};
+struct TypesAndSettings_Hash32_Result16 : public AbridgedTypesAndSettings {
+ using ResultRow = uint16_t;
+};
+struct TypesAndSettings_KeyString : public AbridgedTypesAndSettings {
+ using Key = std::string;
+};
+struct TypesAndSettings_Seed8 : public AbridgedTypesAndSettings {
+ // This is not a generally recommended configuration. With the configured
+ // hash function, it would fail with SmallKeyGen due to insufficient
+ // independence among the seeds.
+ using Seed = uint8_t;
+};
+struct TypesAndSettings_NoAlwaysOne : public AbridgedTypesAndSettings {
+ static constexpr bool kFirstCoeffAlwaysOne = false;
+};
+struct TypesAndSettings_AllowZeroStarts : public AbridgedTypesAndSettings {
+ static constexpr bool kAllowZeroStarts = true;
+};
+struct TypesAndSettings_Seed64 : public AbridgedTypesAndSettings {
+ using Seed = uint64_t;
+};
+struct TypesAndSettings_Rehasher
+ : public StandardRehasherAdapter<AbridgedTypesAndSettings> {
+ using KeyGen = Hash64KeyGenWrapper<StandardKeyGen>;
+};
+struct TypesAndSettings_Rehasher_Result16 : public TypesAndSettings_Rehasher {
+ using ResultRow = uint16_t;
+};
+struct TypesAndSettings_Rehasher_Result32 : public TypesAndSettings_Rehasher {
+ using ResultRow = uint32_t;
+};
+struct TypesAndSettings_Rehasher_Seed64
+ : public StandardRehasherAdapter<TypesAndSettings_Seed64> {
+ using KeyGen = Hash64KeyGenWrapper<StandardKeyGen>;
+ // Note: 64-bit seed with Rehasher gives slightly better average reseeds
+};
+struct TypesAndSettings_Rehasher32
+ : public StandardRehasherAdapter<TypesAndSettings_Hash32> {
+ using KeyGen = Hash32KeyGenWrapper<StandardKeyGen>;
+};
+struct TypesAndSettings_Rehasher32_Coeff64
+ : public TypesAndSettings_Rehasher32 {
+ using CoeffRow = uint64_t;
+};
+struct TypesAndSettings_SmallKeyGen : public AbridgedTypesAndSettings {
+ // SmallKeyGen stresses the independence of different hash seeds
+ using KeyGen = SmallKeyGen;
+};
+struct TypesAndSettings_Hash32_SmallKeyGen : public TypesAndSettings_Hash32 {
+ // SmallKeyGen stresses the independence of different hash seeds
+ using KeyGen = SmallKeyGen;
+};
+struct TypesAndSettings_Coeff32 : public DefaultTypesAndSettings {
+ using CoeffRow = uint32_t;
+};
+struct TypesAndSettings_Coeff32Smash : public TypesAndSettings_Coeff32 {
+ static constexpr bool kUseSmash = true;
+};
+struct TypesAndSettings_Coeff16 : public DefaultTypesAndSettings {
+ using CoeffRow = uint16_t;
+};
+struct TypesAndSettings_Coeff16Smash : public TypesAndSettings_Coeff16 {
+ static constexpr bool kUseSmash = true;
+};
+
+using TestTypesAndSettings = ::testing::Types<
+ TypesAndSettings_Coeff128, TypesAndSettings_Coeff128Smash,
+ TypesAndSettings_Coeff64, TypesAndSettings_Coeff64Smash,
+ TypesAndSettings_Coeff64Smash0, TypesAndSettings_Coeff128_Homog,
+ TypesAndSettings_Coeff128Smash_Homog, TypesAndSettings_Coeff64_Homog,
+ TypesAndSettings_Coeff64Smash_Homog, TypesAndSettings_Result16,
+ TypesAndSettings_Result32, TypesAndSettings_IndexSizeT,
+ TypesAndSettings_Hash32, TypesAndSettings_Hash32_Result16,
+ TypesAndSettings_KeyString, TypesAndSettings_Seed8,
+ TypesAndSettings_NoAlwaysOne, TypesAndSettings_AllowZeroStarts,
+ TypesAndSettings_Seed64, TypesAndSettings_Rehasher,
+ TypesAndSettings_Rehasher_Result16, TypesAndSettings_Rehasher_Result32,
+ TypesAndSettings_Rehasher_Seed64, TypesAndSettings_Rehasher32,
+ TypesAndSettings_Rehasher32_Coeff64, TypesAndSettings_SmallKeyGen,
+ TypesAndSettings_Hash32_SmallKeyGen, TypesAndSettings_Coeff32,
+ TypesAndSettings_Coeff32Smash, TypesAndSettings_Coeff16,
+ TypesAndSettings_Coeff16Smash>;
+TYPED_TEST_CASE(RibbonTypeParamTest, TestTypesAndSettings);
+
+namespace {
+
+// For testing Poisson-distributed (or similar) statistics, get value for
+// `stddevs_allowed` standard deviations above expected mean
+// `expected_count`.
+// (Poisson approximates Binomial only if probability of a trial being
+// in the count is low.)
+uint64_t PoissonUpperBound(double expected_count, double stddevs_allowed) {
+ return static_cast<uint64_t>(
+ expected_count + stddevs_allowed * std::sqrt(expected_count) + 1.0);
+}
+
+uint64_t PoissonLowerBound(double expected_count, double stddevs_allowed) {
+ return static_cast<uint64_t>(std::max(
+ 0.0, expected_count - stddevs_allowed * std::sqrt(expected_count)));
+}
+
+uint64_t FrequentPoissonUpperBound(double expected_count) {
+ // Allow up to 5.0 standard deviations for frequently checked statistics
+ return PoissonUpperBound(expected_count, 5.0);
+}
+
+uint64_t FrequentPoissonLowerBound(double expected_count) {
+ return PoissonLowerBound(expected_count, 5.0);
+}
+
+uint64_t InfrequentPoissonUpperBound(double expected_count) {
+ // Allow up to 3 standard deviations for infrequently checked statistics
+ return PoissonUpperBound(expected_count, 3.0);
+}
+
+uint64_t InfrequentPoissonLowerBound(double expected_count) {
+ return PoissonLowerBound(expected_count, 3.0);
+}
+
+} // namespace
+
+TYPED_TEST(RibbonTypeParamTest, CompactnessAndBacktrackAndFpRate) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypeParam);
+ IMPORT_RIBBON_IMPL_TYPES(TypeParam);
+ using KeyGen = typename TypeParam::KeyGen;
+ using ConfigHelper =
+ ROCKSDB_NAMESPACE::ribbon::BandingConfigHelper<TypeParam>;
+
+ if (sizeof(CoeffRow) < 8) {
+ ROCKSDB_GTEST_BYPASS("Not fully supported");
+ return;
+ }
+
+ const auto log2_thoroughness =
+ static_cast<uint32_t>(ROCKSDB_NAMESPACE::FloorLog2(FLAGS_thoroughness));
+
+ // We are going to choose num_to_add using an exponential distribution,
+ // so that we have good representation of small-to-medium filters.
+ // Here we just pick some reasonable, practical upper bound based on
+ // kCoeffBits or option.
+ const double log_max_add = std::log(
+ FLAGS_max_add > 0 ? FLAGS_max_add
+ : static_cast<uint32_t>(kCoeffBits * kCoeffBits) *
+ std::max(FLAGS_thoroughness, uint32_t{32}));
+
+ // This needs to be enough below the minimum number of slots to get a
+ // reasonable number of samples with the minimum number of slots.
+ const double log_min_add = std::log(0.66 * SimpleSoln::RoundUpNumSlots(1));
+
+ ASSERT_GT(log_max_add, log_min_add);
+
+ const double diff_log_add = log_max_add - log_min_add;
+
+ for (ConstructionFailureChance cs : TypeParam::FailureChanceToTest()) {
+ double expected_reseeds;
+ switch (cs) {
+ default:
+ assert(false);
+ FALLTHROUGH_INTENDED;
+ case ROCKSDB_NAMESPACE::ribbon::kOneIn2:
+ fprintf(stderr, "== Failure: 50 percent\n");
+ expected_reseeds = 1.0;
+ break;
+ case ROCKSDB_NAMESPACE::ribbon::kOneIn20:
+ fprintf(stderr, "== Failure: 95 percent\n");
+ expected_reseeds = 0.053;
+ break;
+ case ROCKSDB_NAMESPACE::ribbon::kOneIn1000:
+ fprintf(stderr, "== Failure: 1/1000\n");
+ expected_reseeds = 0.001;
+ break;
+ }
+
+ uint64_t total_reseeds = 0;
+ uint64_t total_singles = 0;
+ uint64_t total_single_failures = 0;
+ uint64_t total_batch = 0;
+ uint64_t total_batch_successes = 0;
+ uint64_t total_fp_count = 0;
+ uint64_t total_added = 0;
+ uint64_t total_expand_trials = 0;
+ uint64_t total_expand_failures = 0;
+ double total_expand_overhead = 0.0;
+
+ uint64_t soln_query_nanos = 0;
+ uint64_t soln_query_count = 0;
+ uint64_t bloom_query_nanos = 0;
+ uint64_t isoln_query_nanos = 0;
+ uint64_t isoln_query_count = 0;
+
+ // Take different samples if you change thoroughness
+ ROCKSDB_NAMESPACE::Random32 rnd(FLAGS_thoroughness);
+
+ for (uint32_t i = 0; i < FLAGS_thoroughness; ++i) {
+ // We are going to choose num_to_add using an exponential distribution
+ // as noted above, but instead of randomly choosing them, we generate
+ // samples linearly using the golden ratio, which ensures a nice spread
+ // even for a small number of samples, and starting with the minimum
+ // number of slots to ensure it is tested.
+ double log_add =
+ std::fmod(0.6180339887498948482 * diff_log_add * i, diff_log_add) +
+ log_min_add;
+ uint32_t num_to_add = static_cast<uint32_t>(std::exp(log_add));
+
+ // Most of the time, test the Interleaved solution storage, but when
+ // we do we have to make num_slots a multiple of kCoeffBits. So
+ // sometimes we want to test without that limitation.
+ bool test_interleaved = (i % 7) != 6;
+
+ // Compute num_slots, and re-adjust num_to_add to get as close as possible
+ // to next num_slots, to stress that num_slots in terms of construction
+ // success. Ensure at least one iteration:
+ Index num_slots = Index{0} - 1;
+ --num_to_add;
+ for (;;) {
+ Index next_num_slots = SimpleSoln::RoundUpNumSlots(
+ ConfigHelper::GetNumSlots(num_to_add + 1, cs));
+ if (test_interleaved) {
+ next_num_slots = InterleavedSoln::RoundUpNumSlots(next_num_slots);
+ // assert idempotent
+ EXPECT_EQ(next_num_slots,
+ InterleavedSoln::RoundUpNumSlots(next_num_slots));
+ }
+ // assert idempotent with InterleavedSoln::RoundUpNumSlots
+ EXPECT_EQ(next_num_slots, SimpleSoln::RoundUpNumSlots(next_num_slots));
+
+ if (next_num_slots > num_slots) {
+ break;
+ }
+ num_slots = next_num_slots;
+ ++num_to_add;
+ }
+ assert(num_slots < Index{0} - 1);
+
+ total_added += num_to_add;
+
+ std::string prefix;
+ ROCKSDB_NAMESPACE::PutFixed32(&prefix, rnd.Next());
+
+ // Batch that must be added
+ std::string added_str = prefix + "added";
+ KeyGen keys_begin(added_str, 0);
+ KeyGen keys_end(added_str, num_to_add);
+
+ // A couple more that will probably be added
+ KeyGen one_more(prefix + "more", 1);
+ KeyGen two_more(prefix + "more", 2);
+
+ // Batch that may or may not be added
+ uint32_t batch_size =
+ static_cast<uint32_t>(2.0 * std::sqrt(num_slots - num_to_add));
+ if (batch_size < 10U) {
+ batch_size = 0;
+ }
+ std::string batch_str = prefix + "batch";
+ KeyGen batch_begin(batch_str, 0);
+ KeyGen batch_end(batch_str, batch_size);
+
+ // Batch never (successfully) added, but used for querying FP rate
+ std::string not_str = prefix + "not";
+ KeyGen other_keys_begin(not_str, 0);
+ KeyGen other_keys_end(not_str, FLAGS_max_check);
+
+ double overhead_ratio = 1.0 * num_slots / num_to_add;
+ if (FLAGS_verbose) {
+ fprintf(stderr, "Adding(%s) %u / %u Overhead: %g Batch size: %u\n",
+ test_interleaved ? "i" : "s", (unsigned)num_to_add,
+ (unsigned)num_slots, overhead_ratio, (unsigned)batch_size);
+ }
+
+ // Vary bytes for InterleavedSoln to use number of solution columns
+ // from 0 to max allowed by ResultRow type (and used by SimpleSoln).
+ // Specifically include 0 and max, and otherwise skew toward max.
+ uint32_t max_ibytes =
+ static_cast<uint32_t>(sizeof(ResultRow) * num_slots);
+ size_t ibytes;
+ if (i == 0) {
+ ibytes = 0;
+ } else if (i == 1) {
+ ibytes = max_ibytes;
+ } else {
+ // Skewed
+ ibytes =
+ std::max(rnd.Uniformish(max_ibytes), rnd.Uniformish(max_ibytes));
+ }
+ std::unique_ptr<char[]> idata(new char[ibytes]);
+ InterleavedSoln isoln(idata.get(), ibytes);
+
+ SimpleSoln soln;
+ Hasher hasher;
+ bool first_single;
+ bool second_single;
+ bool batch_success;
+ {
+ Banding banding;
+ // Traditional solve for a fixed set.
+ ASSERT_TRUE(
+ banding.ResetAndFindSeedToSolve(num_slots, keys_begin, keys_end));
+
+ Index occupied_count = banding.GetOccupiedCount();
+ Index more_added = 0;
+
+ if (TypeParam::kHomogeneous || overhead_ratio < 1.01 ||
+ batch_size == 0) {
+ // Homogeneous not compatible with backtracking because add
+ // doesn't fail. Small overhead ratio too packed to expect more
+ first_single = false;
+ second_single = false;
+ batch_success = false;
+ } else {
+ // Now to test backtracking, starting with guaranteed fail. By using
+ // the keys that will be used to test FP rate, we are then doing an
+ // extra check that after backtracking there are no remnants (e.g. in
+ // result side of banding) of these entries.
+ KeyGen other_keys_too_big_end = other_keys_begin;
+ other_keys_too_big_end += num_to_add;
+ banding.EnsureBacktrackSize(std::max(num_to_add, batch_size));
+ EXPECT_FALSE(banding.AddRangeOrRollBack(other_keys_begin,
+ other_keys_too_big_end));
+ EXPECT_EQ(occupied_count, banding.GetOccupiedCount());
+
+ // Check that we still have a good chance of adding a couple more
+ // individually
+ first_single = banding.Add(*one_more);
+ second_single = banding.Add(*two_more);
+ more_added += (first_single ? 1 : 0) + (second_single ? 1 : 0);
+ total_singles += 2U;
+ total_single_failures += 2U - more_added;
+
+ // Or as a batch
+ batch_success = banding.AddRangeOrRollBack(batch_begin, batch_end);
+ ++total_batch;
+ if (batch_success) {
+ more_added += batch_size;
+ ++total_batch_successes;
+ }
+ EXPECT_LE(banding.GetOccupiedCount(), occupied_count + more_added);
+ }
+
+ // Also verify that redundant adds are OK (no effect)
+ ASSERT_TRUE(
+ banding.AddRange(keys_begin, KeyGen(added_str, num_to_add / 8)));
+ EXPECT_LE(banding.GetOccupiedCount(), occupied_count + more_added);
+
+ // Now back-substitution
+ soln.BackSubstFrom(banding);
+ if (test_interleaved) {
+ isoln.BackSubstFrom(banding);
+ }
+
+ Seed reseeds = banding.GetOrdinalSeed();
+ total_reseeds += reseeds;
+
+ EXPECT_LE(reseeds, 8 + log2_thoroughness);
+ if (reseeds > log2_thoroughness + 1) {
+ fprintf(
+ stderr, "%s high reseeds at %u, %u/%u: %u\n",
+ reseeds > log2_thoroughness + 8 ? "ERROR Extremely" : "Somewhat",
+ static_cast<unsigned>(i), static_cast<unsigned>(num_to_add),
+ static_cast<unsigned>(num_slots), static_cast<unsigned>(reseeds));
+ }
+
+ if (reseeds > 0) {
+ // "Expand" test: given a failed construction, how likely is it to
+ // pass with same seed and more slots. At each step, we increase
+ // enough to ensure there is at least one shift within each coeff
+ // block.
+ ++total_expand_trials;
+ Index expand_count = 0;
+ Index ex_slots = num_slots;
+ banding.SetOrdinalSeed(0);
+ for (;; ++expand_count) {
+ ASSERT_LE(expand_count, log2_thoroughness);
+ ex_slots += ex_slots / kCoeffBits;
+ if (test_interleaved) {
+ ex_slots = InterleavedSoln::RoundUpNumSlots(ex_slots);
+ }
+ banding.Reset(ex_slots);
+ bool success = banding.AddRange(keys_begin, keys_end);
+ if (success) {
+ break;
+ }
+ }
+ total_expand_failures += expand_count;
+ total_expand_overhead += 1.0 * (ex_slots - num_slots) / num_slots;
+ }
+
+ hasher.SetOrdinalSeed(reseeds);
+ }
+ // soln and hasher now independent of Banding object
+
+ // Verify keys added
+ KeyGen cur = keys_begin;
+ while (cur != keys_end) {
+ ASSERT_TRUE(soln.FilterQuery(*cur, hasher));
+ ASSERT_TRUE(!test_interleaved || isoln.FilterQuery(*cur, hasher));
+ ++cur;
+ }
+ // We (maybe) snuck these in!
+ if (first_single) {
+ ASSERT_TRUE(soln.FilterQuery(*one_more, hasher));
+ ASSERT_TRUE(!test_interleaved || isoln.FilterQuery(*one_more, hasher));
+ }
+ if (second_single) {
+ ASSERT_TRUE(soln.FilterQuery(*two_more, hasher));
+ ASSERT_TRUE(!test_interleaved || isoln.FilterQuery(*two_more, hasher));
+ }
+ if (batch_success) {
+ cur = batch_begin;
+ while (cur != batch_end) {
+ ASSERT_TRUE(soln.FilterQuery(*cur, hasher));
+ ASSERT_TRUE(!test_interleaved || isoln.FilterQuery(*cur, hasher));
+ ++cur;
+ }
+ }
+
+ // Check FP rate (depends only on number of result bits == solution
+ // columns)
+ Index fp_count = 0;
+ cur = other_keys_begin;
+ {
+ ROCKSDB_NAMESPACE::StopWatchNano timer(
+ ROCKSDB_NAMESPACE::SystemClock::Default().get(), true);
+ while (cur != other_keys_end) {
+ bool fp = soln.FilterQuery(*cur, hasher);
+ fp_count += fp ? 1 : 0;
+ ++cur;
+ }
+ soln_query_nanos += timer.ElapsedNanos();
+ soln_query_count += FLAGS_max_check;
+ }
+ {
+ double expected_fp_count = soln.ExpectedFpRate() * FLAGS_max_check;
+ // For expected FP rate, also include false positives due to collisions
+ // in Hash value. (Negligible for 64-bit, can matter for 32-bit.)
+ double correction =
+ FLAGS_max_check * ExpectedCollisionFpRate(hasher, num_to_add);
+
+ // NOTE: rare violations expected with kHomogeneous
+ EXPECT_LE(fp_count,
+ FrequentPoissonUpperBound(expected_fp_count + correction));
+ EXPECT_GE(fp_count,
+ FrequentPoissonLowerBound(expected_fp_count + correction));
+ }
+ total_fp_count += fp_count;
+
+ // And also check FP rate for isoln
+ if (test_interleaved) {
+ Index ifp_count = 0;
+ cur = other_keys_begin;
+ ROCKSDB_NAMESPACE::StopWatchNano timer(
+ ROCKSDB_NAMESPACE::SystemClock::Default().get(), true);
+ while (cur != other_keys_end) {
+ ifp_count += isoln.FilterQuery(*cur, hasher) ? 1 : 0;
+ ++cur;
+ }
+ isoln_query_nanos += timer.ElapsedNanos();
+ isoln_query_count += FLAGS_max_check;
+ {
+ double expected_fp_count = isoln.ExpectedFpRate() * FLAGS_max_check;
+ // For expected FP rate, also include false positives due to
+ // collisions in Hash value. (Negligible for 64-bit, can matter for
+ // 32-bit.)
+ double correction =
+ FLAGS_max_check * ExpectedCollisionFpRate(hasher, num_to_add);
+
+ // NOTE: rare violations expected with kHomogeneous
+ EXPECT_LE(ifp_count,
+ FrequentPoissonUpperBound(expected_fp_count + correction));
+
+ // FIXME: why sometimes can we slightly "beat the odds"?
+ // (0.95 factor should not be needed)
+ EXPECT_GE(ifp_count, FrequentPoissonLowerBound(
+ 0.95 * expected_fp_count + correction));
+ }
+ // Since the bits used in isoln are a subset of the bits used in soln,
+ // it cannot have fewer FPs
+ EXPECT_GE(ifp_count, fp_count);
+ }
+
+ // And compare to Bloom time, for fun
+ if (ibytes >= /* minimum Bloom impl bytes*/ 64) {
+ Index bfp_count = 0;
+ cur = other_keys_begin;
+ ROCKSDB_NAMESPACE::StopWatchNano timer(
+ ROCKSDB_NAMESPACE::SystemClock::Default().get(), true);
+ while (cur != other_keys_end) {
+ uint64_t h = hasher.GetHash(*cur);
+ uint32_t h1 = ROCKSDB_NAMESPACE::Lower32of64(h);
+ uint32_t h2 = sizeof(Hash) >= 8 ? ROCKSDB_NAMESPACE::Upper32of64(h)
+ : h1 * 0x9e3779b9;
+ bfp_count +=
+ ROCKSDB_NAMESPACE::FastLocalBloomImpl::HashMayMatch(
+ h1, h2, static_cast<uint32_t>(ibytes), 6, idata.get())
+ ? 1
+ : 0;
+ ++cur;
+ }
+ bloom_query_nanos += timer.ElapsedNanos();
+ // ensure bfp_count is used
+ ASSERT_LT(bfp_count, FLAGS_max_check);
+ }
+ }
+
+ // "outside" == key not in original set so either negative or false positive
+ fprintf(stderr,
+ "Simple outside query, hot, incl hashing, ns/key: %g\n",
+ 1.0 * soln_query_nanos / soln_query_count);
+ fprintf(stderr,
+ "Interleaved outside query, hot, incl hashing, ns/key: %g\n",
+ 1.0 * isoln_query_nanos / isoln_query_count);
+ fprintf(stderr,
+ "Bloom outside query, hot, incl hashing, ns/key: %g\n",
+ 1.0 * bloom_query_nanos / soln_query_count);
+
+ if (TypeParam::kHomogeneous) {
+ EXPECT_EQ(total_reseeds, 0U);
+ } else {
+ double average_reseeds = 1.0 * total_reseeds / FLAGS_thoroughness;
+ fprintf(stderr, "Average re-seeds: %g\n", average_reseeds);
+ // Values above were chosen to target around 50% chance of encoding
+ // success rate (average of 1.0 re-seeds) or slightly better. But 1.15 is
+ // also close enough.
+ EXPECT_LE(total_reseeds,
+ InfrequentPoissonUpperBound(1.15 * expected_reseeds *
+ FLAGS_thoroughness));
+ // Would use 0.85 here instead of 0.75, but
+ // TypesAndSettings_Hash32_SmallKeyGen can "beat the odds" because of
+ // sequential keys with a small, cheap hash function. We accept that
+ // there are surely inputs that are somewhat bad for this setup, but
+ // these somewhat good inputs are probably more likely.
+ EXPECT_GE(total_reseeds,
+ InfrequentPoissonLowerBound(0.75 * expected_reseeds *
+ FLAGS_thoroughness));
+ }
+
+ if (total_expand_trials > 0) {
+ double average_expand_failures =
+ 1.0 * total_expand_failures / total_expand_trials;
+ fprintf(stderr, "Average expand failures, and overhead: %g, %g\n",
+ average_expand_failures,
+ total_expand_overhead / total_expand_trials);
+ // Seems to be a generous allowance
+ EXPECT_LE(total_expand_failures,
+ InfrequentPoissonUpperBound(1.0 * total_expand_trials));
+ } else {
+ fprintf(stderr, "Average expand failures: N/A\n");
+ }
+
+ if (total_singles > 0) {
+ double single_failure_rate = 1.0 * total_single_failures / total_singles;
+ fprintf(stderr, "Add'l single, failure rate: %g\n", single_failure_rate);
+ // A rough bound (one sided) based on nothing in particular
+ double expected_single_failures = 1.0 * total_singles /
+ (sizeof(CoeffRow) == 16 ? 128
+ : TypeParam::kUseSmash ? 64
+ : 32);
+ EXPECT_LE(total_single_failures,
+ InfrequentPoissonUpperBound(expected_single_failures));
+ }
+
+ if (total_batch > 0) {
+ // Counting successes here for Poisson to approximate the Binomial
+ // distribution.
+ // A rough bound (one sided) based on nothing in particular.
+ double expected_batch_successes = 1.0 * total_batch / 2;
+ uint64_t lower_bound =
+ InfrequentPoissonLowerBound(expected_batch_successes);
+ fprintf(stderr, "Add'l batch, success rate: %g (>= %g)\n",
+ 1.0 * total_batch_successes / total_batch,
+ 1.0 * lower_bound / total_batch);
+ EXPECT_GE(total_batch_successes, lower_bound);
+ }
+
+ {
+ uint64_t total_checked = uint64_t{FLAGS_max_check} * FLAGS_thoroughness;
+ double expected_total_fp_count =
+ total_checked * std::pow(0.5, 8U * sizeof(ResultRow));
+ // For expected FP rate, also include false positives due to collisions
+ // in Hash value. (Negligible for 64-bit, can matter for 32-bit.)
+ double average_added = 1.0 * total_added / FLAGS_thoroughness;
+ expected_total_fp_count +=
+ total_checked * ExpectedCollisionFpRate(Hasher(), average_added);
+
+ uint64_t upper_bound =
+ InfrequentPoissonUpperBound(expected_total_fp_count);
+ uint64_t lower_bound =
+ InfrequentPoissonLowerBound(expected_total_fp_count);
+ fprintf(stderr, "Average FP rate: %g (~= %g, <= %g, >= %g)\n",
+ 1.0 * total_fp_count / total_checked,
+ expected_total_fp_count / total_checked,
+ 1.0 * upper_bound / total_checked,
+ 1.0 * lower_bound / total_checked);
+ EXPECT_LE(total_fp_count, upper_bound);
+ EXPECT_GE(total_fp_count, lower_bound);
+ }
+ }
+}
+
+TYPED_TEST(RibbonTypeParamTest, Extremes) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypeParam);
+ IMPORT_RIBBON_IMPL_TYPES(TypeParam);
+ using KeyGen = typename TypeParam::KeyGen;
+
+ size_t bytes = 128 * 1024;
+ std::unique_ptr<char[]> buf(new char[bytes]);
+ InterleavedSoln isoln(buf.get(), bytes);
+ SimpleSoln soln;
+ Hasher hasher;
+ Banding banding;
+
+ // ########################################
+ // Add zero keys to minimal number of slots
+ KeyGen begin_and_end("foo", 123);
+ ASSERT_TRUE(banding.ResetAndFindSeedToSolve(
+ /*slots*/ kCoeffBits, begin_and_end, begin_and_end, /*first seed*/ 0,
+ /* seed mask*/ 0));
+
+ soln.BackSubstFrom(banding);
+ isoln.BackSubstFrom(banding);
+
+ // Because there's plenty of memory, we expect the interleaved solution to
+ // use maximum supported columns (same as simple solution)
+ ASSERT_EQ(isoln.GetUpperNumColumns(), 8U * sizeof(ResultRow));
+ ASSERT_EQ(isoln.GetUpperStartBlock(), 0U);
+
+ // Somewhat oddly, we expect same FP rate as if we had essentially filled
+ // up the slots.
+ KeyGen other_keys_begin("not", 0);
+ KeyGen other_keys_end("not", FLAGS_max_check);
+
+ Index fp_count = 0;
+ KeyGen cur = other_keys_begin;
+ while (cur != other_keys_end) {
+ bool isoln_query_result = isoln.FilterQuery(*cur, hasher);
+ bool soln_query_result = soln.FilterQuery(*cur, hasher);
+ // Solutions are equivalent
+ ASSERT_EQ(isoln_query_result, soln_query_result);
+ if (!TypeParam::kHomogeneous) {
+ // And in fact we only expect an FP when ResultRow is 0
+ // (except Homogeneous)
+ ASSERT_EQ(soln_query_result, hasher.GetResultRowFromHash(
+ hasher.GetHash(*cur)) == ResultRow{0});
+ }
+ fp_count += soln_query_result ? 1 : 0;
+ ++cur;
+ }
+ {
+ ASSERT_EQ(isoln.ExpectedFpRate(), soln.ExpectedFpRate());
+ double expected_fp_count = isoln.ExpectedFpRate() * FLAGS_max_check;
+ EXPECT_LE(fp_count, InfrequentPoissonUpperBound(expected_fp_count));
+ if (TypeParam::kHomogeneous) {
+ // Pseudorandom garbage in Homogeneous filter can "beat the odds" if
+ // nothing added
+ } else {
+ EXPECT_GE(fp_count, InfrequentPoissonLowerBound(expected_fp_count));
+ }
+ }
+
+ // ######################################################
+ // Use zero bytes for interleaved solution (key(s) added)
+
+ // Add one key
+ KeyGen key_begin("added", 0);
+ KeyGen key_end("added", 1);
+ ASSERT_TRUE(banding.ResetAndFindSeedToSolve(
+ /*slots*/ kCoeffBits, key_begin, key_end, /*first seed*/ 0,
+ /* seed mask*/ 0));
+
+ InterleavedSoln isoln2(nullptr, /*bytes*/ 0);
+
+ isoln2.BackSubstFrom(banding);
+
+ ASSERT_EQ(isoln2.GetUpperNumColumns(), 0U);
+ ASSERT_EQ(isoln2.GetUpperStartBlock(), 0U);
+
+ // All queries return true
+ ASSERT_TRUE(isoln2.FilterQuery(*other_keys_begin, hasher));
+ ASSERT_EQ(isoln2.ExpectedFpRate(), 1.0);
+}
+
+TEST(RibbonTest, AllowZeroStarts) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypesAndSettings_AllowZeroStarts);
+ IMPORT_RIBBON_IMPL_TYPES(TypesAndSettings_AllowZeroStarts);
+ using KeyGen = StandardKeyGen;
+
+ InterleavedSoln isoln(nullptr, /*bytes*/ 0);
+ SimpleSoln soln;
+ Hasher hasher;
+ Banding banding;
+
+ KeyGen begin("foo", 0);
+ KeyGen end("foo", 1);
+ // Can't add 1 entry
+ ASSERT_FALSE(banding.ResetAndFindSeedToSolve(/*slots*/ 0, begin, end));
+
+ KeyGen begin_and_end("foo", 123);
+ // Can add 0 entries
+ ASSERT_TRUE(banding.ResetAndFindSeedToSolve(/*slots*/ 0, begin_and_end,
+ begin_and_end));
+
+ Seed reseeds = banding.GetOrdinalSeed();
+ ASSERT_EQ(reseeds, 0U);
+ hasher.SetOrdinalSeed(reseeds);
+
+ // Can construct 0-slot solutions
+ isoln.BackSubstFrom(banding);
+ soln.BackSubstFrom(banding);
+
+ // Should always return false
+ ASSERT_FALSE(isoln.FilterQuery(*begin, hasher));
+ ASSERT_FALSE(soln.FilterQuery(*begin, hasher));
+
+ // And report that in FP rate
+ ASSERT_EQ(isoln.ExpectedFpRate(), 0.0);
+ ASSERT_EQ(soln.ExpectedFpRate(), 0.0);
+}
+
+TEST(RibbonTest, RawAndOrdinalSeeds) {
+ StandardHasher<TypesAndSettings_Seed64> hasher64;
+ StandardHasher<DefaultTypesAndSettings> hasher64_32;
+ StandardHasher<TypesAndSettings_Hash32> hasher32;
+ StandardHasher<TypesAndSettings_Seed8> hasher8;
+
+ for (uint32_t limit : {0xffU, 0xffffU}) {
+ std::vector<bool> seen(limit + 1);
+ for (uint32_t i = 0; i < limit; ++i) {
+ hasher64.SetOrdinalSeed(i);
+ auto raw64 = hasher64.GetRawSeed();
+ hasher32.SetOrdinalSeed(i);
+ auto raw32 = hasher32.GetRawSeed();
+ hasher8.SetOrdinalSeed(static_cast<uint8_t>(i));
+ auto raw8 = hasher8.GetRawSeed();
+ {
+ hasher64_32.SetOrdinalSeed(i);
+ auto raw64_32 = hasher64_32.GetRawSeed();
+ ASSERT_EQ(raw64_32, raw32); // Same size seed
+ }
+ if (i == 0) {
+ // Documented that ordinal seed 0 == raw seed 0
+ ASSERT_EQ(raw64, 0U);
+ ASSERT_EQ(raw32, 0U);
+ ASSERT_EQ(raw8, 0U);
+ } else {
+ // Extremely likely that upper bits are set
+ ASSERT_GT(raw64, raw32);
+ ASSERT_GT(raw32, raw8);
+ }
+ // Hashers agree on lower bits
+ ASSERT_EQ(static_cast<uint32_t>(raw64), raw32);
+ ASSERT_EQ(static_cast<uint8_t>(raw32), raw8);
+
+ // The translation is one-to-one for this size prefix
+ uint32_t v = static_cast<uint32_t>(raw32 & limit);
+ ASSERT_EQ(raw64 & limit, v);
+ ASSERT_FALSE(seen[v]);
+ seen[v] = true;
+ }
+ }
+}
+
+namespace {
+
+struct PhsfInputGen {
+ PhsfInputGen(const std::string& prefix, uint64_t id) : id_(id) {
+ val_.first = prefix;
+ ROCKSDB_NAMESPACE::PutFixed64(&val_.first, /*placeholder*/ 0);
+ }
+
+ // Prefix (only one required)
+ PhsfInputGen& operator++() {
+ ++id_;
+ return *this;
+ }
+
+ const std::pair<std::string, uint8_t>& operator*() {
+ // Use multiplication to mix things up a little in the key
+ ROCKSDB_NAMESPACE::EncodeFixed64(&val_.first[val_.first.size() - 8],
+ id_ * uint64_t{0x1500000001});
+ // Occasionally repeat values etc.
+ val_.second = static_cast<uint8_t>(id_ * 7 / 8);
+ return val_;
+ }
+
+ const std::pair<std::string, uint8_t>* operator->() { return &**this; }
+
+ bool operator==(const PhsfInputGen& other) {
+ // Same prefix is assumed
+ return id_ == other.id_;
+ }
+ bool operator!=(const PhsfInputGen& other) {
+ // Same prefix is assumed
+ return id_ != other.id_;
+ }
+
+ uint64_t id_;
+ std::pair<std::string, uint8_t> val_;
+};
+
+struct PhsfTypesAndSettings : public DefaultTypesAndSettings {
+ static constexpr bool kIsFilter = false;
+};
+} // namespace
+
+TEST(RibbonTest, PhsfBasic) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(PhsfTypesAndSettings);
+ IMPORT_RIBBON_IMPL_TYPES(PhsfTypesAndSettings);
+
+ Index num_slots = 12800;
+ Index num_to_add = static_cast<Index>(num_slots / 1.02);
+
+ PhsfInputGen begin("in", 0);
+ PhsfInputGen end("in", num_to_add);
+
+ std::unique_ptr<char[]> idata(new char[/*bytes*/ num_slots]);
+ InterleavedSoln isoln(idata.get(), /*bytes*/ num_slots);
+ SimpleSoln soln;
+ Hasher hasher;
+
+ {
+ Banding banding;
+ ASSERT_TRUE(banding.ResetAndFindSeedToSolve(num_slots, begin, end));
+
+ soln.BackSubstFrom(banding);
+ isoln.BackSubstFrom(banding);
+
+ hasher.SetOrdinalSeed(banding.GetOrdinalSeed());
+ }
+
+ for (PhsfInputGen cur = begin; cur != end; ++cur) {
+ ASSERT_EQ(cur->second, soln.PhsfQuery(cur->first, hasher));
+ ASSERT_EQ(cur->second, isoln.PhsfQuery(cur->first, hasher));
+ }
+}
+
+// Not a real test, but a tool used to build APIs in ribbon_config.h
+TYPED_TEST(RibbonTypeParamTest, FindOccupancy) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypeParam);
+ IMPORT_RIBBON_IMPL_TYPES(TypeParam);
+ using KeyGen = typename TypeParam::KeyGen;
+
+ if (!FLAGS_find_occ) {
+ ROCKSDB_GTEST_BYPASS("Tool disabled during unit test runs");
+ return;
+ }
+
+ KeyGen cur(std::to_string(testing::UnitTest::GetInstance()->random_seed()),
+ 0);
+
+ Banding banding;
+ Index num_slots = InterleavedSoln::RoundUpNumSlots(FLAGS_find_min_slots);
+ Index max_slots = InterleavedSoln::RoundUpNumSlots(FLAGS_find_max_slots);
+ while (num_slots <= max_slots) {
+ std::map<int32_t, uint32_t> rem_histogram;
+ std::map<Index, uint32_t> slot_histogram;
+ if (FLAGS_find_slot_occ) {
+ for (Index i = 0; i < kCoeffBits; ++i) {
+ slot_histogram[i] = 0;
+ slot_histogram[num_slots - 1 - i] = 0;
+ slot_histogram[num_slots / 2 - kCoeffBits / 2 + i] = 0;
+ }
+ }
+ uint64_t total_added = 0;
+ for (uint32_t i = 0; i < FLAGS_find_iters; ++i) {
+ banding.Reset(num_slots);
+ uint32_t j = 0;
+ KeyGen end = cur;
+ end += num_slots + num_slots / 10;
+ for (; cur != end; ++cur) {
+ if (banding.Add(*cur)) {
+ ++j;
+ } else {
+ break;
+ }
+ }
+ total_added += j;
+ for (auto& slot : slot_histogram) {
+ slot.second += banding.IsOccupied(slot.first);
+ }
+
+ int32_t bucket =
+ static_cast<int32_t>(num_slots) - static_cast<int32_t>(j);
+ rem_histogram[bucket]++;
+ if (FLAGS_verbose) {
+ fprintf(stderr, "num_slots: %u i: %u / %u avg_overhead: %g\r",
+ static_cast<unsigned>(num_slots), static_cast<unsigned>(i),
+ static_cast<unsigned>(FLAGS_find_iters),
+ 1.0 * (i + 1) * num_slots / total_added);
+ }
+ }
+ if (FLAGS_verbose) {
+ fprintf(stderr, "\n");
+ }
+
+ uint32_t cumulative = 0;
+
+ double p50_rem = 0;
+ double p95_rem = 0;
+ double p99_9_rem = 0;
+
+ for (auto& h : rem_histogram) {
+ double before = 1.0 * cumulative / FLAGS_find_iters;
+ double not_after = 1.0 * (cumulative + h.second) / FLAGS_find_iters;
+ if (FLAGS_verbose) {
+ fprintf(stderr, "overhead: %g before: %g not_after: %g\n",
+ 1.0 * num_slots / (num_slots - h.first), before, not_after);
+ }
+ cumulative += h.second;
+ if (before < 0.5 && 0.5 <= not_after) {
+ // fake it with linear interpolation
+ double portion = (0.5 - before) / (not_after - before);
+ p50_rem = h.first + portion;
+ } else if (before < 0.95 && 0.95 <= not_after) {
+ // fake it with linear interpolation
+ double portion = (0.95 - before) / (not_after - before);
+ p95_rem = h.first + portion;
+ } else if (before < 0.999 && 0.999 <= not_after) {
+ // fake it with linear interpolation
+ double portion = (0.999 - before) / (not_after - before);
+ p99_9_rem = h.first + portion;
+ }
+ }
+ for (auto& slot : slot_histogram) {
+ fprintf(stderr, "slot[%u] occupied: %g\n", (unsigned)slot.first,
+ 1.0 * slot.second / FLAGS_find_iters);
+ }
+
+ double mean_rem =
+ (1.0 * FLAGS_find_iters * num_slots - total_added) / FLAGS_find_iters;
+ fprintf(
+ stderr,
+ "num_slots: %u iters: %u mean_ovr: %g p50_ovr: %g p95_ovr: %g "
+ "p99.9_ovr: %g mean_rem: %g p50_rem: %g p95_rem: %g p99.9_rem: %g\n",
+ static_cast<unsigned>(num_slots),
+ static_cast<unsigned>(FLAGS_find_iters),
+ 1.0 * num_slots / (num_slots - mean_rem),
+ 1.0 * num_slots / (num_slots - p50_rem),
+ 1.0 * num_slots / (num_slots - p95_rem),
+ 1.0 * num_slots / (num_slots - p99_9_rem), mean_rem, p50_rem, p95_rem,
+ p99_9_rem);
+
+ num_slots = std::max(
+ num_slots + 1, static_cast<Index>(num_slots * FLAGS_find_next_factor));
+ num_slots = InterleavedSoln::RoundUpNumSlots(num_slots);
+ }
+}
+
+// Not a real test, but a tool to understand Homogeneous Ribbon
+// behavior (TODO: configuration APIs & tests)
+TYPED_TEST(RibbonTypeParamTest, OptimizeHomogAtScale) {
+ IMPORT_RIBBON_TYPES_AND_SETTINGS(TypeParam);
+ IMPORT_RIBBON_IMPL_TYPES(TypeParam);
+ using KeyGen = typename TypeParam::KeyGen;
+
+ if (!FLAGS_optimize_homog) {
+ ROCKSDB_GTEST_BYPASS("Tool disabled during unit test runs");
+ return;
+ }
+
+ if (!TypeParam::kHomogeneous) {
+ ROCKSDB_GTEST_BYPASS("Only for Homogeneous Ribbon");
+ return;
+ }
+
+ KeyGen cur(std::to_string(testing::UnitTest::GetInstance()->random_seed()),
+ 0);
+
+ Banding banding;
+ Index num_slots = SimpleSoln::RoundUpNumSlots(FLAGS_optimize_homog_slots);
+ banding.Reset(num_slots);
+
+ // This and "band_ovr" is the "allocated overhead", or slots over added.
+ // It does not take into account FP rates.
+ double target_overhead = 1.20;
+ uint32_t num_added = 0;
+
+ do {
+ do {
+ (void)banding.Add(*cur);
+ ++cur;
+ ++num_added;
+ } while (1.0 * num_slots / num_added > target_overhead);
+
+ SimpleSoln soln;
+ soln.BackSubstFrom(banding);
+
+ std::array<uint32_t, 8U * sizeof(ResultRow)> fp_counts_by_cols;
+ fp_counts_by_cols.fill(0U);
+ for (uint32_t i = 0; i < FLAGS_optimize_homog_check; ++i) {
+ ResultRow r = soln.PhsfQuery(*cur, banding);
+ ++cur;
+ for (size_t j = 0; j < fp_counts_by_cols.size(); ++j) {
+ if ((r & 1) == 1) {
+ break;
+ }
+ fp_counts_by_cols[j]++;
+ r /= 2;
+ }
+ }
+ fprintf(stderr, "band_ovr: %g ", 1.0 * num_slots / num_added);
+ for (unsigned j = 0; j < fp_counts_by_cols.size(); ++j) {
+ double inv_fp_rate =
+ 1.0 * FLAGS_optimize_homog_check / fp_counts_by_cols[j];
+ double equiv_cols = std::log(inv_fp_rate) * 1.4426950409;
+ // Overhead vs. information-theoretic minimum based on observed
+ // FP rate (subject to sampling error, especially for low FP rates)
+ double actual_overhead =
+ 1.0 * (j + 1) * num_slots / (equiv_cols * num_added);
+ fprintf(stderr, "ovr_%u: %g ", j + 1, actual_overhead);
+ }
+ fprintf(stderr, "\n");
+ target_overhead -= FLAGS_optimize_homog_granularity;
+ } while (target_overhead > 1.0);
+}
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+#ifdef GFLAGS
+ ParseCommandLineFlags(&argc, &argv, true);
+#endif // GFLAGS
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/set_comparator.h b/src/rocksdb/util/set_comparator.h
new file mode 100644
index 000000000..e0e64436a
--- /dev/null
+++ b/src/rocksdb/util/set_comparator.h
@@ -0,0 +1,24 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "rocksdb/comparator.h"
+
+namespace ROCKSDB_NAMESPACE {
+// A comparator to be used in std::set
+struct SetComparator {
+ explicit SetComparator() : user_comparator_(BytewiseComparator()) {}
+ explicit SetComparator(const Comparator* user_comparator)
+ : user_comparator_(user_comparator ? user_comparator
+ : BytewiseComparator()) {}
+ bool operator()(const Slice& lhs, const Slice& rhs) const {
+ return user_comparator_->Compare(lhs, rhs) < 0;
+ }
+
+ private:
+ const Comparator* user_comparator_;
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/single_thread_executor.h b/src/rocksdb/util/single_thread_executor.h
new file mode 100644
index 000000000..c69f2a292
--- /dev/null
+++ b/src/rocksdb/util/single_thread_executor.h
@@ -0,0 +1,56 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#if USE_COROUTINES
+#include <atomic>
+
+#include "folly/CPortability.h"
+#include "folly/CppAttributes.h"
+#include "folly/Executor.h"
+#include "util/async_file_reader.h"
+
+namespace ROCKSDB_NAMESPACE {
+// Implements a simple executor that runs callback functions in the same
+// thread, unlike CPUThreadExecutor which may schedule the callback on
+// another thread. Runs in a tight loop calling the queued callbacks,
+// and polls for async IO completions when idle. The completions will
+// resume suspended coroutines and they get added to the queue, which
+// will get picked up by this loop.
+// Any possibility of deadlock is precluded because the file system
+// guarantees that async IO completion callbacks will not be scheduled
+// to run in this thread or this executor.
+class SingleThreadExecutor : public folly::Executor {
+ public:
+ explicit SingleThreadExecutor(AsyncFileReader& reader)
+ : reader_(reader), busy_(false) {}
+
+ void add(folly::Func callback) override {
+ auto& q = q_;
+ q.push(std::move(callback));
+ if (q.size() == 1 && !busy_) {
+ while (!q.empty()) {
+ q.front()();
+ q.pop();
+
+ if (q.empty()) {
+ // Prevent recursion, as the Wait may queue resumed coroutines
+ busy_ = true;
+ reader_.Wait();
+ busy_ = false;
+ }
+ }
+ }
+ }
+
+ private:
+ std::queue<folly::Func> q_;
+ AsyncFileReader& reader_;
+ bool busy_;
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif // USE_COROUTINES
diff --git a/src/rocksdb/util/slice.cc b/src/rocksdb/util/slice.cc
new file mode 100644
index 000000000..1fa21afcb
--- /dev/null
+++ b/src/rocksdb/util/slice.cc
@@ -0,0 +1,405 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/slice.h"
+
+#include <stdio.h>
+
+#include <algorithm>
+
+#include "rocksdb/convenience.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+class FixedPrefixTransform : public SliceTransform {
+ private:
+ size_t prefix_len_;
+ std::string id_;
+
+ public:
+ explicit FixedPrefixTransform(size_t prefix_len) : prefix_len_(prefix_len) {
+ id_ = std::string(kClassName()) + "." + std::to_string(prefix_len_);
+ }
+
+ static const char* kClassName() { return "rocksdb.FixedPrefix"; }
+ static const char* kNickName() { return "fixed"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+
+ bool IsInstanceOf(const std::string& name) const override {
+ if (name == id_) {
+ return true;
+ } else if (StartsWith(name, kNickName())) {
+ std::string alt_id =
+ std::string(kNickName()) + ":" + std::to_string(prefix_len_);
+ if (name == alt_id) {
+ return true;
+ }
+ }
+ return SliceTransform::IsInstanceOf(name);
+ }
+
+ std::string GetId() const override { return id_; }
+
+ Slice Transform(const Slice& src) const override {
+ assert(InDomain(src));
+ return Slice(src.data(), prefix_len_);
+ }
+
+ bool InDomain(const Slice& src) const override {
+ return (src.size() >= prefix_len_);
+ }
+
+ bool InRange(const Slice& dst) const override {
+ return (dst.size() == prefix_len_);
+ }
+
+ bool FullLengthEnabled(size_t* len) const override {
+ *len = prefix_len_;
+ return true;
+ }
+
+ bool SameResultWhenAppended(const Slice& prefix) const override {
+ return InDomain(prefix);
+ }
+};
+
+class CappedPrefixTransform : public SliceTransform {
+ private:
+ size_t cap_len_;
+ std::string id_;
+
+ public:
+ explicit CappedPrefixTransform(size_t cap_len) : cap_len_(cap_len) {
+ id_ = std::string(kClassName()) + "." + std::to_string(cap_len_);
+ }
+
+ static const char* kClassName() { return "rocksdb.CappedPrefix"; }
+ static const char* kNickName() { return "capped"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+ std::string GetId() const override { return id_; }
+
+ bool IsInstanceOf(const std::string& name) const override {
+ if (name == id_) {
+ return true;
+ } else if (StartsWith(name, kNickName())) {
+ std::string alt_id =
+ std::string(kNickName()) + ":" + std::to_string(cap_len_);
+ if (name == alt_id) {
+ return true;
+ }
+ }
+ return SliceTransform::IsInstanceOf(name);
+ }
+
+ Slice Transform(const Slice& src) const override {
+ assert(InDomain(src));
+ return Slice(src.data(), std::min(cap_len_, src.size()));
+ }
+
+ bool InDomain(const Slice& /*src*/) const override { return true; }
+
+ bool InRange(const Slice& dst) const override {
+ return (dst.size() <= cap_len_);
+ }
+
+ bool FullLengthEnabled(size_t* len) const override {
+ *len = cap_len_;
+ return true;
+ }
+
+ bool SameResultWhenAppended(const Slice& prefix) const override {
+ return prefix.size() >= cap_len_;
+ }
+};
+
+class NoopTransform : public SliceTransform {
+ public:
+ explicit NoopTransform() {}
+
+ static const char* kClassName() { return "rocksdb.Noop"; }
+ const char* Name() const override { return kClassName(); }
+
+ Slice Transform(const Slice& src) const override { return src; }
+
+ bool InDomain(const Slice& /*src*/) const override { return true; }
+
+ bool InRange(const Slice& /*dst*/) const override { return true; }
+
+ bool SameResultWhenAppended(const Slice& /*prefix*/) const override {
+ return false;
+ }
+};
+
+} // end namespace
+
+const SliceTransform* NewFixedPrefixTransform(size_t prefix_len) {
+ return new FixedPrefixTransform(prefix_len);
+}
+
+const SliceTransform* NewCappedPrefixTransform(size_t cap_len) {
+ return new CappedPrefixTransform(cap_len);
+}
+
+const SliceTransform* NewNoopTransform() { return new NoopTransform; }
+
+#ifndef ROCKSDB_LITE
+static int RegisterBuiltinSliceTransform(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ // For the builtin transforms, the format is typically
+ // [Name].[0-9]+ or [NickName]:[0-9]+
+ library.AddFactory<const SliceTransform>(
+ NoopTransform::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<const SliceTransform>* guard,
+ std::string* /*errmsg*/) {
+ guard->reset(NewNoopTransform());
+ return guard->get();
+ });
+ library.AddFactory<const SliceTransform>(
+ ObjectLibrary::PatternEntry(FixedPrefixTransform::kNickName(), false)
+ .AddNumber(":"),
+ [](const std::string& uri, std::unique_ptr<const SliceTransform>* guard,
+ std::string* /*errmsg*/) {
+ auto colon = uri.find(":");
+ auto len = ParseSizeT(uri.substr(colon + 1));
+ guard->reset(NewFixedPrefixTransform(len));
+ return guard->get();
+ });
+ library.AddFactory<const SliceTransform>(
+ ObjectLibrary::PatternEntry(FixedPrefixTransform::kClassName(), false)
+ .AddNumber("."),
+ [](const std::string& uri, std::unique_ptr<const SliceTransform>* guard,
+ std::string* /*errmsg*/) {
+ auto len = ParseSizeT(
+ uri.substr(strlen(FixedPrefixTransform::kClassName()) + 1));
+ guard->reset(NewFixedPrefixTransform(len));
+ return guard->get();
+ });
+ library.AddFactory<const SliceTransform>(
+ ObjectLibrary::PatternEntry(CappedPrefixTransform::kNickName(), false)
+ .AddNumber(":"),
+ [](const std::string& uri, std::unique_ptr<const SliceTransform>* guard,
+ std::string* /*errmsg*/) {
+ auto colon = uri.find(":");
+ auto len = ParseSizeT(uri.substr(colon + 1));
+ guard->reset(NewCappedPrefixTransform(len));
+ return guard->get();
+ });
+ library.AddFactory<const SliceTransform>(
+ ObjectLibrary::PatternEntry(CappedPrefixTransform::kClassName(), false)
+ .AddNumber("."),
+ [](const std::string& uri, std::unique_ptr<const SliceTransform>* guard,
+ std::string* /*errmsg*/) {
+ auto len = ParseSizeT(
+ uri.substr(strlen(CappedPrefixTransform::kClassName()) + 1));
+ guard->reset(NewCappedPrefixTransform(len));
+ return guard->get();
+ });
+ size_t num_types;
+ return static_cast<int>(library.GetFactoryCount(&num_types));
+}
+#endif // ROCKSDB_LITE
+
+Status SliceTransform::CreateFromString(
+ const ConfigOptions& config_options, const std::string& value,
+ std::shared_ptr<const SliceTransform>* result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterBuiltinSliceTransform(*(ObjectLibrary::Default().get()), "");
+ });
+#endif // ROCKSDB_LITE
+ std::string id;
+ std::unordered_map<std::string, std::string> opt_map;
+ Status status = Customizable::GetOptionsMap(config_options, result->get(),
+ value, &id, &opt_map);
+ if (!status.ok()) { // GetOptionsMap failed
+ return status;
+ } else if (id.empty() && opt_map.empty()) {
+ result->reset();
+ } else {
+#ifndef ROCKSDB_LITE
+ status = config_options.registry->NewSharedObject(id, result);
+#else
+ auto Matches = [](const std::string& input, size_t size,
+ const char* pattern, char sep) {
+ auto plen = strlen(pattern);
+ return (size > plen + 2 && input[plen] == sep &&
+ StartsWith(input, pattern));
+ };
+
+ auto size = id.size();
+ if (id == NoopTransform::kClassName()) {
+ result->reset(NewNoopTransform());
+ } else if (Matches(id, size, FixedPrefixTransform::kNickName(), ':')) {
+ auto fixed = strlen(FixedPrefixTransform::kNickName());
+ auto len = ParseSizeT(id.substr(fixed + 1));
+ result->reset(NewFixedPrefixTransform(len));
+ } else if (Matches(id, size, CappedPrefixTransform::kNickName(), ':')) {
+ auto capped = strlen(CappedPrefixTransform::kNickName());
+ auto len = ParseSizeT(id.substr(capped + 1));
+ result->reset(NewCappedPrefixTransform(len));
+ } else if (Matches(id, size, CappedPrefixTransform::kClassName(), '.')) {
+ auto capped = strlen(CappedPrefixTransform::kClassName());
+ auto len = ParseSizeT(id.substr(capped + 1));
+ result->reset(NewCappedPrefixTransform(len));
+ } else if (Matches(id, size, FixedPrefixTransform::kClassName(), '.')) {
+ auto fixed = strlen(FixedPrefixTransform::kClassName());
+ auto len = ParseSizeT(id.substr(fixed + 1));
+ result->reset(NewFixedPrefixTransform(len));
+ } else {
+ status = Status::NotSupported("Cannot load object in LITE mode ", id);
+ }
+#endif // ROCKSDB_LITE
+ if (config_options.ignore_unsupported_options && status.IsNotSupported()) {
+ return Status::OK();
+ } else if (status.ok()) {
+ SliceTransform* transform = const_cast<SliceTransform*>(result->get());
+ status =
+ Customizable::ConfigureNewObject(config_options, transform, opt_map);
+ }
+ }
+ return status;
+}
+
+std::string SliceTransform::AsString() const {
+#ifndef ROCKSDB_LITE
+ if (HasRegisteredOptions()) {
+ ConfigOptions opts;
+ opts.delimiter = ";";
+ return ToString(opts);
+ }
+#endif // ROCKSDB_LITE
+ return GetId();
+}
+
+// 2 small internal utility functions, for efficient hex conversions
+// and no need for snprintf, toupper etc...
+// Originally from wdt/util/EncryptionUtils.cpp - for
+// std::to_string(true)/DecodeHex:
+char toHex(unsigned char v) {
+ if (v <= 9) {
+ return '0' + v;
+ }
+ return 'A' + v - 10;
+}
+// most of the code is for validation/error check
+int fromHex(char c) {
+ // toupper:
+ if (c >= 'a' && c <= 'f') {
+ c -= ('a' - 'A'); // aka 0x20
+ }
+ // validation
+ if (c < '0' || (c > '9' && (c < 'A' || c > 'F'))) {
+ return -1; // invalid not 0-9A-F hex char
+ }
+ if (c <= '9') {
+ return c - '0';
+ }
+ return c - 'A' + 10;
+}
+
+Slice::Slice(const SliceParts& parts, std::string* buf) {
+ size_t length = 0;
+ for (int i = 0; i < parts.num_parts; ++i) {
+ length += parts.parts[i].size();
+ }
+ buf->reserve(length);
+
+ for (int i = 0; i < parts.num_parts; ++i) {
+ buf->append(parts.parts[i].data(), parts.parts[i].size());
+ }
+ data_ = buf->data();
+ size_ = buf->size();
+}
+
+// Return a string that contains the copy of the referenced data.
+std::string Slice::ToString(bool hex) const {
+ std::string result; // RVO/NRVO/move
+ if (hex) {
+ result.reserve(2 * size_);
+ for (size_t i = 0; i < size_; ++i) {
+ unsigned char c = data_[i];
+ result.push_back(toHex(c >> 4));
+ result.push_back(toHex(c & 0xf));
+ }
+ return result;
+ } else {
+ result.assign(data_, size_);
+ return result;
+ }
+}
+
+// Originally from rocksdb/utilities/ldb_cmd.h
+bool Slice::DecodeHex(std::string* result) const {
+ std::string::size_type len = size_;
+ if (len % 2) {
+ // Hex string must be even number of hex digits to get complete bytes back
+ return false;
+ }
+ if (!result) {
+ return false;
+ }
+ result->clear();
+ result->reserve(len / 2);
+
+ for (size_t i = 0; i < len;) {
+ int h1 = fromHex(data_[i++]);
+ if (h1 < 0) {
+ return false;
+ }
+ int h2 = fromHex(data_[i++]);
+ if (h2 < 0) {
+ return false;
+ }
+ result->push_back(static_cast<char>((h1 << 4) | h2));
+ }
+ return true;
+}
+
+PinnableSlice::PinnableSlice(PinnableSlice&& other) {
+ *this = std::move(other);
+}
+
+PinnableSlice& PinnableSlice::operator=(PinnableSlice&& other) {
+ if (this != &other) {
+ Cleanable::Reset();
+ Cleanable::operator=(std::move(other));
+ size_ = other.size_;
+ pinned_ = other.pinned_;
+ if (pinned_) {
+ data_ = other.data_;
+ // When it's pinned, buf should no longer be of use.
+ } else {
+ if (other.buf_ == &other.self_space_) {
+ self_space_ = std::move(other.self_space_);
+ buf_ = &self_space_;
+ data_ = buf_->data();
+ } else {
+ buf_ = other.buf_;
+ data_ = other.data_;
+ }
+ }
+ other.self_space_.clear();
+ other.buf_ = &other.self_space_;
+ other.pinned_ = false;
+ other.PinSelf();
+ }
+ return *this;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/slice_test.cc b/src/rocksdb/util/slice_test.cc
new file mode 100644
index 000000000..e1c35d567
--- /dev/null
+++ b/src/rocksdb/util/slice_test.cc
@@ -0,0 +1,191 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/slice.h"
+
+#include <gtest/gtest.h>
+
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/data_structure.h"
+#include "rocksdb/types.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+TEST(SliceTest, StringView) {
+ std::string s = "foo";
+ std::string_view sv = s;
+ ASSERT_EQ(Slice(s), Slice(sv));
+ ASSERT_EQ(Slice(s), Slice(std::move(sv)));
+}
+
+// Use this to keep track of the cleanups that were actually performed
+void Multiplier(void* arg1, void* arg2) {
+ int* res = reinterpret_cast<int*>(arg1);
+ int* num = reinterpret_cast<int*>(arg2);
+ *res *= *num;
+}
+
+class PinnableSliceTest : public testing::Test {
+ public:
+ void AssertSameData(const std::string& expected, const PinnableSlice& slice) {
+ std::string got;
+ got.assign(slice.data(), slice.size());
+ ASSERT_EQ(expected, got);
+ }
+};
+
+// Test that the external buffer is moved instead of being copied.
+TEST_F(PinnableSliceTest, MoveExternalBuffer) {
+ Slice s("123");
+ std::string buf;
+ PinnableSlice v1(&buf);
+ v1.PinSelf(s);
+
+ PinnableSlice v2(std::move(v1));
+ ASSERT_EQ(buf.data(), v2.data());
+ ASSERT_EQ(&buf, v2.GetSelf());
+
+ PinnableSlice v3;
+ v3 = std::move(v2);
+ ASSERT_EQ(buf.data(), v3.data());
+ ASSERT_EQ(&buf, v3.GetSelf());
+}
+
+TEST_F(PinnableSliceTest, Move) {
+ int n2 = 2;
+ int res = 1;
+ const std::string const_str1 = "123";
+ const std::string const_str2 = "ABC";
+ Slice slice1(const_str1);
+ Slice slice2(const_str2);
+
+ {
+ // Test move constructor on a pinned slice.
+ res = 1;
+ PinnableSlice v1;
+ v1.PinSlice(slice1, Multiplier, &res, &n2);
+ PinnableSlice v2(std::move(v1));
+
+ // Since v1's Cleanable has been moved to v2,
+ // no cleanup should happen in Reset.
+ v1.Reset();
+ ASSERT_EQ(1, res);
+
+ AssertSameData(const_str1, v2);
+ }
+ // v2 is cleaned up.
+ ASSERT_EQ(2, res);
+
+ {
+ // Test move constructor on an unpinned slice.
+ PinnableSlice v1;
+ v1.PinSelf(slice1);
+ PinnableSlice v2(std::move(v1));
+
+ AssertSameData(const_str1, v2);
+ }
+
+ {
+ // Test move assignment from a pinned slice to
+ // another pinned slice.
+ res = 1;
+ PinnableSlice v1;
+ v1.PinSlice(slice1, Multiplier, &res, &n2);
+ PinnableSlice v2;
+ v2.PinSlice(slice2, Multiplier, &res, &n2);
+ v2 = std::move(v1);
+
+ // v2's Cleanable will be Reset before moving
+ // anything from v1.
+ ASSERT_EQ(2, res);
+ // Since v1's Cleanable has been moved to v2,
+ // no cleanup should happen in Reset.
+ v1.Reset();
+ ASSERT_EQ(2, res);
+
+ AssertSameData(const_str1, v2);
+ }
+ // The Cleanable moved from v1 to v2 will be Reset.
+ ASSERT_EQ(4, res);
+
+ {
+ // Test move assignment from a pinned slice to
+ // an unpinned slice.
+ res = 1;
+ PinnableSlice v1;
+ v1.PinSlice(slice1, Multiplier, &res, &n2);
+ PinnableSlice v2;
+ v2.PinSelf(slice2);
+ v2 = std::move(v1);
+
+ // Since v1's Cleanable has been moved to v2,
+ // no cleanup should happen in Reset.
+ v1.Reset();
+ ASSERT_EQ(1, res);
+
+ AssertSameData(const_str1, v2);
+ }
+ // The Cleanable moved from v1 to v2 will be Reset.
+ ASSERT_EQ(2, res);
+
+ {
+ // Test move assignment from an upinned slice to
+ // another unpinned slice.
+ PinnableSlice v1;
+ v1.PinSelf(slice1);
+ PinnableSlice v2;
+ v2.PinSelf(slice2);
+ v2 = std::move(v1);
+
+ AssertSameData(const_str1, v2);
+ }
+
+ {
+ // Test move assignment from an upinned slice to
+ // a pinned slice.
+ res = 1;
+ PinnableSlice v1;
+ v1.PinSelf(slice1);
+ PinnableSlice v2;
+ v2.PinSlice(slice2, Multiplier, &res, &n2);
+ v2 = std::move(v1);
+
+ // v2's Cleanable will be Reset before moving
+ // anything from v1.
+ ASSERT_EQ(2, res);
+
+ AssertSameData(const_str1, v2);
+ }
+ // No Cleanable is moved from v1 to v2, so no more cleanup.
+ ASSERT_EQ(2, res);
+}
+
+// ***************************************************************** //
+// Unit test for SmallEnumSet
+class SmallEnumSetTest : public testing::Test {
+ public:
+ SmallEnumSetTest() {}
+ ~SmallEnumSetTest() {}
+};
+
+TEST_F(SmallEnumSetTest, SmallSetTest) {
+ FileTypeSet fs;
+ ASSERT_TRUE(fs.Add(FileType::kIdentityFile));
+ ASSERT_FALSE(fs.Add(FileType::kIdentityFile));
+ ASSERT_TRUE(fs.Add(FileType::kInfoLogFile));
+ ASSERT_TRUE(fs.Contains(FileType::kIdentityFile));
+ ASSERT_FALSE(fs.Contains(FileType::kDBLockFile));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/slice_transform_test.cc b/src/rocksdb/util/slice_transform_test.cc
new file mode 100644
index 000000000..64ac8bb1f
--- /dev/null
+++ b/src/rocksdb/util/slice_transform_test.cc
@@ -0,0 +1,154 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/slice_transform.h"
+
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/filter_policy.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/table.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SliceTransformTest : public testing::Test {};
+
+TEST_F(SliceTransformTest, CapPrefixTransform) {
+ std::string s;
+ s = "abcdefge";
+
+ std::unique_ptr<const SliceTransform> transform;
+
+ transform.reset(NewCappedPrefixTransform(6));
+ ASSERT_EQ(transform->Transform(s).ToString(), "abcdef");
+ ASSERT_TRUE(transform->SameResultWhenAppended("123456"));
+ ASSERT_TRUE(transform->SameResultWhenAppended("1234567"));
+ ASSERT_TRUE(!transform->SameResultWhenAppended("12345"));
+
+ transform.reset(NewCappedPrefixTransform(8));
+ ASSERT_EQ(transform->Transform(s).ToString(), "abcdefge");
+
+ transform.reset(NewCappedPrefixTransform(10));
+ ASSERT_EQ(transform->Transform(s).ToString(), "abcdefge");
+
+ transform.reset(NewCappedPrefixTransform(0));
+ ASSERT_EQ(transform->Transform(s).ToString(), "");
+
+ transform.reset(NewCappedPrefixTransform(0));
+ ASSERT_EQ(transform->Transform("").ToString(), "");
+}
+
+class SliceTransformDBTest : public testing::Test {
+ private:
+ std::string dbname_;
+ Env* env_;
+ DB* db_;
+
+ public:
+ SliceTransformDBTest() : env_(Env::Default()), db_(nullptr) {
+ dbname_ = test::PerThreadDBPath("slice_transform_db_test");
+ EXPECT_OK(DestroyDB(dbname_, last_options_));
+ }
+
+ ~SliceTransformDBTest() override {
+ delete db_;
+ EXPECT_OK(DestroyDB(dbname_, last_options_));
+ }
+
+ DB* db() { return db_; }
+
+ // Return the current option configuration.
+ Options* GetOptions() { return &last_options_; }
+
+ void DestroyAndReopen() {
+ // Destroy using last options
+ Destroy();
+ ASSERT_OK(TryReopen());
+ }
+
+ void Destroy() {
+ delete db_;
+ db_ = nullptr;
+ ASSERT_OK(DestroyDB(dbname_, last_options_));
+ }
+
+ Status TryReopen() {
+ delete db_;
+ db_ = nullptr;
+ last_options_.create_if_missing = true;
+
+ return DB::Open(last_options_, dbname_, &db_);
+ }
+
+ Options last_options_;
+};
+
+namespace {
+uint64_t TestGetTickerCount(const Options& options, Tickers ticker_type) {
+ return options.statistics->getTickerCount(ticker_type);
+}
+} // namespace
+
+TEST_F(SliceTransformDBTest, CapPrefix) {
+ last_options_.prefix_extractor.reset(NewCappedPrefixTransform(8));
+ last_options_.statistics = ROCKSDB_NAMESPACE::CreateDBStatistics();
+ BlockBasedTableOptions bbto;
+ bbto.filter_policy.reset(NewBloomFilterPolicy(10, false));
+ bbto.whole_key_filtering = false;
+ last_options_.table_factory.reset(NewBlockBasedTableFactory(bbto));
+ ASSERT_OK(TryReopen());
+
+ ReadOptions ro;
+ FlushOptions fo;
+ WriteOptions wo;
+
+ ASSERT_OK(db()->Put(wo, "barbarbar", "foo"));
+ ASSERT_OK(db()->Put(wo, "barbarbar2", "foo2"));
+ ASSERT_OK(db()->Put(wo, "foo", "bar"));
+ ASSERT_OK(db()->Put(wo, "foo3", "bar3"));
+ ASSERT_OK(db()->Flush(fo));
+
+ std::unique_ptr<Iterator> iter(db()->NewIterator(ro));
+
+ iter->Seek("foo");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->value().ToString(), "bar");
+ ASSERT_EQ(TestGetTickerCount(last_options_, BLOOM_FILTER_PREFIX_USEFUL), 0U);
+
+ iter->Seek("foo2");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ ASSERT_EQ(TestGetTickerCount(last_options_, BLOOM_FILTER_PREFIX_USEFUL), 1U);
+
+ iter->Seek("barbarbar");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->value().ToString(), "foo");
+ ASSERT_EQ(TestGetTickerCount(last_options_, BLOOM_FILTER_PREFIX_USEFUL), 1U);
+
+ iter->Seek("barfoofoo");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ ASSERT_EQ(TestGetTickerCount(last_options_, BLOOM_FILTER_PREFIX_USEFUL), 2U);
+
+ iter->Seek("foobarbar");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ ASSERT_EQ(TestGetTickerCount(last_options_, BLOOM_FILTER_PREFIX_USEFUL), 3U);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/status.cc b/src/rocksdb/util/status.cc
new file mode 100644
index 000000000..72fdfdbcc
--- /dev/null
+++ b/src/rocksdb/util/status.cc
@@ -0,0 +1,154 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/status.h"
+
+#include <stdio.h>
+#ifdef OS_WIN
+#include <string.h>
+#endif
+#include <cstring>
+
+#include "port/port.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+std::unique_ptr<const char[]> Status::CopyState(const char* s) {
+ const size_t cch = std::strlen(s) + 1; // +1 for the null terminator
+ char* rv = new char[cch];
+ std::strncpy(rv, s, cch);
+ return std::unique_ptr<const char[]>(rv);
+}
+
+static const char* msgs[static_cast<int>(Status::kMaxSubCode)] = {
+ "", // kNone
+ "Timeout Acquiring Mutex", // kMutexTimeout
+ "Timeout waiting to lock key", // kLockTimeout
+ "Failed to acquire lock due to max_num_locks limit", // kLockLimit
+ "No space left on device", // kNoSpace
+ "Deadlock", // kDeadlock
+ "Stale file handle", // kStaleFile
+ "Memory limit reached", // kMemoryLimit
+ "Space limit reached", // kSpaceLimit
+ "No such file or directory", // kPathNotFound
+ // KMergeOperandsInsufficientCapacity
+ "Insufficient capacity for merge operands",
+ // kManualCompactionPaused
+ "Manual compaction paused",
+ " (overwritten)", // kOverwritten, subcode of OK
+ "Txn not prepared", // kTxnNotPrepared
+ "IO fenced off", // kIOFenced
+};
+
+Status::Status(Code _code, SubCode _subcode, const Slice& msg,
+ const Slice& msg2, Severity sev)
+ : code_(_code),
+ subcode_(_subcode),
+ sev_(sev),
+ retryable_(false),
+ data_loss_(false),
+ scope_(0) {
+ assert(subcode_ != kMaxSubCode);
+ const size_t len1 = msg.size();
+ const size_t len2 = msg2.size();
+ const size_t size = len1 + (len2 ? (2 + len2) : 0);
+ char* const result = new char[size + 1]; // +1 for null terminator
+ memcpy(result, msg.data(), len1);
+ if (len2) {
+ result[len1] = ':';
+ result[len1 + 1] = ' ';
+ memcpy(result + len1 + 2, msg2.data(), len2);
+ }
+ result[size] = '\0'; // null terminator for C style string
+ state_.reset(result);
+}
+
+std::string Status::ToString() const {
+#ifdef ROCKSDB_ASSERT_STATUS_CHECKED
+ checked_ = true;
+#endif // ROCKSDB_ASSERT_STATUS_CHECKED
+ const char* type = nullptr;
+ switch (code_) {
+ case kOk:
+ return "OK";
+ case kNotFound:
+ type = "NotFound: ";
+ break;
+ case kCorruption:
+ type = "Corruption: ";
+ break;
+ case kNotSupported:
+ type = "Not implemented: ";
+ break;
+ case kInvalidArgument:
+ type = "Invalid argument: ";
+ break;
+ case kIOError:
+ type = "IO error: ";
+ break;
+ case kMergeInProgress:
+ type = "Merge in progress: ";
+ break;
+ case kIncomplete:
+ type = "Result incomplete: ";
+ break;
+ case kShutdownInProgress:
+ type = "Shutdown in progress: ";
+ break;
+ case kTimedOut:
+ type = "Operation timed out: ";
+ break;
+ case kAborted:
+ type = "Operation aborted: ";
+ break;
+ case kBusy:
+ type = "Resource busy: ";
+ break;
+ case kExpired:
+ type = "Operation expired: ";
+ break;
+ case kTryAgain:
+ type = "Operation failed. Try again.: ";
+ break;
+ case kCompactionTooLarge:
+ type = "Compaction too large: ";
+ break;
+ case kColumnFamilyDropped:
+ type = "Column family dropped: ";
+ break;
+ case kMaxCode:
+ assert(false);
+ break;
+ }
+ char tmp[30];
+ if (type == nullptr) {
+ // This should not happen since `code_` should be a valid non-`kMaxCode`
+ // member of the `Code` enum. The above switch-statement should have had a
+ // case assigning `type` to a corresponding string.
+ assert(false);
+ snprintf(tmp, sizeof(tmp), "Unknown code(%d): ", static_cast<int>(code()));
+ type = tmp;
+ }
+ std::string result(type);
+ if (subcode_ != kNone) {
+ uint32_t index = static_cast<int32_t>(subcode_);
+ assert(sizeof(msgs) / sizeof(msgs[0]) > index);
+ result.append(msgs[index]);
+ }
+
+ if (state_ != nullptr) {
+ if (subcode_ != kNone) {
+ result.append(": ");
+ }
+ result.append(state_.get());
+ }
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/stderr_logger.cc b/src/rocksdb/util/stderr_logger.cc
new file mode 100644
index 000000000..6044b8b93
--- /dev/null
+++ b/src/rocksdb/util/stderr_logger.cc
@@ -0,0 +1,30 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/stderr_logger.h"
+
+#include "port/sys_time.h"
+
+namespace ROCKSDB_NAMESPACE {
+StderrLogger::~StderrLogger() {}
+
+void StderrLogger::Logv(const char* format, va_list ap) {
+ const uint64_t thread_id = Env::Default()->GetThreadID();
+
+ port::TimeVal now_tv;
+ port::GetTimeOfDay(&now_tv, nullptr);
+ const time_t seconds = now_tv.tv_sec;
+ struct tm t;
+ port::LocalTimeR(&seconds, &t);
+ fprintf(stderr, "%04d/%02d/%02d-%02d:%02d:%02d.%06d %llx ", t.tm_year + 1900,
+ t.tm_mon + 1, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec,
+ static_cast<int>(now_tv.tv_usec),
+ static_cast<long long unsigned int>(thread_id));
+
+ vfprintf(stderr, format, ap);
+ fprintf(stderr, "\n");
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/stderr_logger.h b/src/rocksdb/util/stderr_logger.h
new file mode 100644
index 000000000..c3b01210c
--- /dev/null
+++ b/src/rocksdb/util/stderr_logger.h
@@ -0,0 +1,31 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <stdarg.h>
+#include <stdio.h>
+
+#include "rocksdb/env.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Prints logs to stderr for faster debugging
+class StderrLogger : public Logger {
+ public:
+ explicit StderrLogger(const InfoLogLevel log_level = InfoLogLevel::INFO_LEVEL)
+ : Logger(log_level) {}
+
+ ~StderrLogger() override;
+
+ // Brings overloaded Logv()s into scope so they're not hidden when we override
+ // a subset of them.
+ using Logger::Logv;
+
+ virtual void Logv(const char* format, va_list ap) override;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/stop_watch.h b/src/rocksdb/util/stop_watch.h
new file mode 100644
index 000000000..e26380d97
--- /dev/null
+++ b/src/rocksdb/util/stop_watch.h
@@ -0,0 +1,118 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+#include "monitoring/statistics.h"
+#include "rocksdb/system_clock.h"
+
+namespace ROCKSDB_NAMESPACE {
+// Auto-scoped.
+// Records the measure time into the corresponding histogram if statistics
+// is not nullptr. It is also saved into *elapsed if the pointer is not nullptr
+// and overwrite is true, it will be added to *elapsed if overwrite is false.
+class StopWatch {
+ public:
+ StopWatch(SystemClock* clock, Statistics* statistics,
+ const uint32_t hist_type, uint64_t* elapsed = nullptr,
+ bool overwrite = true, bool delay_enabled = false)
+ : clock_(clock),
+ statistics_(statistics),
+ hist_type_(hist_type),
+ elapsed_(elapsed),
+ overwrite_(overwrite),
+ stats_enabled_(statistics &&
+ statistics->get_stats_level() >=
+ StatsLevel::kExceptTimers &&
+ statistics->HistEnabledForType(hist_type)),
+ delay_enabled_(delay_enabled),
+ total_delay_(0),
+ delay_start_time_(0),
+ start_time_((stats_enabled_ || elapsed != nullptr) ? clock->NowMicros()
+ : 0) {}
+
+ ~StopWatch() {
+ if (elapsed_) {
+ if (overwrite_) {
+ *elapsed_ = clock_->NowMicros() - start_time_;
+ } else {
+ *elapsed_ += clock_->NowMicros() - start_time_;
+ }
+ }
+ if (elapsed_ && delay_enabled_) {
+ *elapsed_ -= total_delay_;
+ }
+ if (stats_enabled_) {
+ statistics_->reportTimeToHistogram(
+ hist_type_, (elapsed_ != nullptr)
+ ? *elapsed_
+ : (clock_->NowMicros() - start_time_));
+ }
+ }
+
+ void DelayStart() {
+ // if delay_start_time_ is not 0, it means we are already tracking delay,
+ // so delay_start_time_ should not be overwritten
+ if (elapsed_ && delay_enabled_ && delay_start_time_ == 0) {
+ delay_start_time_ = clock_->NowMicros();
+ }
+ }
+
+ void DelayStop() {
+ if (elapsed_ && delay_enabled_ && delay_start_time_ != 0) {
+ total_delay_ += clock_->NowMicros() - delay_start_time_;
+ }
+ // reset to 0 means currently no delay is being tracked, so two consecutive
+ // calls to DelayStop will not increase total_delay_
+ delay_start_time_ = 0;
+ }
+
+ uint64_t GetDelay() const { return delay_enabled_ ? total_delay_ : 0; }
+
+ uint64_t start_time() const { return start_time_; }
+
+ private:
+ SystemClock* clock_;
+ Statistics* statistics_;
+ const uint32_t hist_type_;
+ uint64_t* elapsed_;
+ bool overwrite_;
+ bool stats_enabled_;
+ bool delay_enabled_;
+ uint64_t total_delay_;
+ uint64_t delay_start_time_;
+ const uint64_t start_time_;
+};
+
+// a nano second precision stopwatch
+class StopWatchNano {
+ public:
+ explicit StopWatchNano(SystemClock* clock, bool auto_start = false)
+ : clock_(clock), start_(0) {
+ if (auto_start) {
+ Start();
+ }
+ }
+
+ void Start() { start_ = clock_->NowNanos(); }
+
+ uint64_t ElapsedNanos(bool reset = false) {
+ auto now = clock_->NowNanos();
+ auto elapsed = now - start_;
+ if (reset) {
+ start_ = now;
+ }
+ return elapsed;
+ }
+
+ uint64_t ElapsedNanosSafe(bool reset = false) {
+ return (clock_ != nullptr) ? ElapsedNanos(reset) : 0U;
+ }
+
+ private:
+ SystemClock* clock_;
+ uint64_t start_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/string_util.cc b/src/rocksdb/util/string_util.cc
new file mode 100644
index 000000000..324482a4c
--- /dev/null
+++ b/src/rocksdb/util/string_util.cc
@@ -0,0 +1,504 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include "util/string_util.h"
+
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <cinttypes>
+#include <cmath>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "port/port.h"
+#include "port/sys_time.h"
+#include "rocksdb/slice.h"
+
+#ifndef __has_cpp_attribute
+#define ROCKSDB_HAS_CPP_ATTRIBUTE(x) 0
+#else
+#define ROCKSDB_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
+#endif
+
+#if ROCKSDB_HAS_CPP_ATTRIBUTE(maybe_unused) && __cplusplus >= 201703L
+#define ROCKSDB_MAYBE_UNUSED [[maybe_unused]]
+#elif ROCKSDB_HAS_CPP_ATTRIBUTE(gnu::unused) || __GNUC__
+#define ROCKSDB_MAYBE_UNUSED [[gnu::unused]]
+#else
+#define ROCKSDB_MAYBE_UNUSED
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+const std::string kNullptrString = "nullptr";
+
+std::vector<std::string> StringSplit(const std::string& arg, char delim) {
+ std::vector<std::string> splits;
+ std::stringstream ss(arg);
+ std::string item;
+ while (std::getline(ss, item, delim)) {
+ splits.push_back(item);
+ }
+ return splits;
+}
+
+// for micros < 10ms, print "XX us".
+// for micros < 10sec, print "XX ms".
+// for micros >= 10 sec, print "XX sec".
+// for micros <= 1 hour, print Y:X M:S".
+// for micros > 1 hour, print Z:Y:X H:M:S".
+int AppendHumanMicros(uint64_t micros, char* output, int len,
+ bool fixed_format) {
+ if (micros < 10000 && !fixed_format) {
+ return snprintf(output, len, "%" PRIu64 " us", micros);
+ } else if (micros < 10000000 && !fixed_format) {
+ return snprintf(output, len, "%.3lf ms",
+ static_cast<double>(micros) / 1000);
+ } else if (micros < 1000000l * 60 && !fixed_format) {
+ return snprintf(output, len, "%.3lf sec",
+ static_cast<double>(micros) / 1000000);
+ } else if (micros < 1000000ll * 60 * 60 && !fixed_format) {
+ return snprintf(output, len, "%02" PRIu64 ":%05.3f M:S",
+ micros / 1000000 / 60,
+ static_cast<double>(micros % 60000000) / 1000000);
+ } else {
+ return snprintf(output, len, "%02" PRIu64 ":%02" PRIu64 ":%05.3f H:M:S",
+ micros / 1000000 / 3600, (micros / 1000000 / 60) % 60,
+ static_cast<double>(micros % 60000000) / 1000000);
+ }
+}
+
+// for sizes >=10TB, print "XXTB"
+// for sizes >=10GB, print "XXGB"
+// etc.
+// append file size summary to output and return the len
+int AppendHumanBytes(uint64_t bytes, char* output, int len) {
+ const uint64_t ull10 = 10;
+ if (bytes >= ull10 << 40) {
+ return snprintf(output, len, "%" PRIu64 "TB", bytes >> 40);
+ } else if (bytes >= ull10 << 30) {
+ return snprintf(output, len, "%" PRIu64 "GB", bytes >> 30);
+ } else if (bytes >= ull10 << 20) {
+ return snprintf(output, len, "%" PRIu64 "MB", bytes >> 20);
+ } else if (bytes >= ull10 << 10) {
+ return snprintf(output, len, "%" PRIu64 "KB", bytes >> 10);
+ } else {
+ return snprintf(output, len, "%" PRIu64 "B", bytes);
+ }
+}
+
+void AppendNumberTo(std::string* str, uint64_t num) {
+ char buf[30];
+ snprintf(buf, sizeof(buf), "%" PRIu64, num);
+ str->append(buf);
+}
+
+void AppendEscapedStringTo(std::string* str, const Slice& value) {
+ for (size_t i = 0; i < value.size(); i++) {
+ char c = value[i];
+ if (c >= ' ' && c <= '~') {
+ str->push_back(c);
+ } else {
+ char buf[10];
+ snprintf(buf, sizeof(buf), "\\x%02x",
+ static_cast<unsigned int>(c) & 0xff);
+ str->append(buf);
+ }
+ }
+}
+
+std::string NumberToHumanString(int64_t num) {
+ char buf[19];
+ int64_t absnum = num < 0 ? -num : num;
+ if (absnum < 10000) {
+ snprintf(buf, sizeof(buf), "%" PRIi64, num);
+ } else if (absnum < 10000000) {
+ snprintf(buf, sizeof(buf), "%" PRIi64 "K", num / 1000);
+ } else if (absnum < 10000000000LL) {
+ snprintf(buf, sizeof(buf), "%" PRIi64 "M", num / 1000000);
+ } else {
+ snprintf(buf, sizeof(buf), "%" PRIi64 "G", num / 1000000000);
+ }
+ return std::string(buf);
+}
+
+std::string BytesToHumanString(uint64_t bytes) {
+ const char* size_name[] = {"KB", "MB", "GB", "TB"};
+ double final_size = static_cast<double>(bytes);
+ size_t size_idx;
+
+ // always start with KB
+ final_size /= 1024;
+ size_idx = 0;
+
+ while (size_idx < 3 && final_size >= 1024) {
+ final_size /= 1024;
+ size_idx++;
+ }
+
+ char buf[20];
+ snprintf(buf, sizeof(buf), "%.2f %s", final_size, size_name[size_idx]);
+ return std::string(buf);
+}
+
+std::string TimeToHumanString(int unixtime) {
+ char time_buffer[80];
+ time_t rawtime = unixtime;
+ struct tm tInfo;
+ struct tm* timeinfo = port::LocalTimeR(&rawtime, &tInfo);
+ assert(timeinfo == &tInfo);
+ strftime(time_buffer, 80, "%c", timeinfo);
+ return std::string(time_buffer);
+}
+
+std::string EscapeString(const Slice& value) {
+ std::string r;
+ AppendEscapedStringTo(&r, value);
+ return r;
+}
+
+bool ConsumeDecimalNumber(Slice* in, uint64_t* val) {
+ uint64_t v = 0;
+ int digits = 0;
+ while (!in->empty()) {
+ char c = (*in)[0];
+ if (c >= '0' && c <= '9') {
+ ++digits;
+ const unsigned int delta = (c - '0');
+ static const uint64_t kMaxUint64 = ~static_cast<uint64_t>(0);
+ if (v > kMaxUint64 / 10 ||
+ (v == kMaxUint64 / 10 && delta > kMaxUint64 % 10)) {
+ // Overflow
+ return false;
+ }
+ v = (v * 10) + delta;
+ in->remove_prefix(1);
+ } else {
+ break;
+ }
+ }
+ *val = v;
+ return (digits > 0);
+}
+
+bool isSpecialChar(const char c) {
+ if (c == '\\' || c == '#' || c == ':' || c == '\r' || c == '\n') {
+ return true;
+ }
+ return false;
+}
+
+namespace {
+using CharMap = std::pair<char, char>;
+}
+
+char UnescapeChar(const char c) {
+ static const CharMap convert_map[] = {{'r', '\r'}, {'n', '\n'}};
+
+ auto iter = std::find_if(std::begin(convert_map), std::end(convert_map),
+ [c](const CharMap& p) { return p.first == c; });
+
+ if (iter == std::end(convert_map)) {
+ return c;
+ }
+ return iter->second;
+}
+
+char EscapeChar(const char c) {
+ static const CharMap convert_map[] = {{'\n', 'n'}, {'\r', 'r'}};
+
+ auto iter = std::find_if(std::begin(convert_map), std::end(convert_map),
+ [c](const CharMap& p) { return p.first == c; });
+
+ if (iter == std::end(convert_map)) {
+ return c;
+ }
+ return iter->second;
+}
+
+std::string EscapeOptionString(const std::string& raw_string) {
+ std::string output;
+ for (auto c : raw_string) {
+ if (isSpecialChar(c)) {
+ output += '\\';
+ output += EscapeChar(c);
+ } else {
+ output += c;
+ }
+ }
+
+ return output;
+}
+
+std::string UnescapeOptionString(const std::string& escaped_string) {
+ bool escaped = false;
+ std::string output;
+
+ for (auto c : escaped_string) {
+ if (escaped) {
+ output += UnescapeChar(c);
+ escaped = false;
+ } else {
+ if (c == '\\') {
+ escaped = true;
+ continue;
+ }
+ output += c;
+ }
+ }
+ return output;
+}
+
+std::string trim(const std::string& str) {
+ if (str.empty()) return std::string();
+ size_t start = 0;
+ size_t end = str.size() - 1;
+ while (isspace(str[start]) != 0 && start < end) {
+ ++start;
+ }
+ while (isspace(str[end]) != 0 && start < end) {
+ --end;
+ }
+ if (start <= end) {
+ return str.substr(start, end - start + 1);
+ }
+ return std::string();
+}
+
+bool EndsWith(const std::string& string, const std::string& pattern) {
+ size_t plen = pattern.size();
+ size_t slen = string.size();
+ if (plen <= slen) {
+ return string.compare(slen - plen, plen, pattern) == 0;
+ } else {
+ return false;
+ }
+}
+
+bool StartsWith(const std::string& string, const std::string& pattern) {
+ return string.compare(0, pattern.size(), pattern) == 0;
+}
+
+#ifndef ROCKSDB_LITE
+
+bool ParseBoolean(const std::string& type, const std::string& value) {
+ if (value == "true" || value == "1") {
+ return true;
+ } else if (value == "false" || value == "0") {
+ return false;
+ }
+ throw std::invalid_argument(type);
+}
+
+uint8_t ParseUint8(const std::string& value) {
+ uint64_t num = ParseUint64(value);
+ if ((num >> 8LL) == 0) {
+ return static_cast<uint8_t>(num);
+ } else {
+ throw std::out_of_range(value);
+ }
+}
+
+uint32_t ParseUint32(const std::string& value) {
+ uint64_t num = ParseUint64(value);
+ if ((num >> 32LL) == 0) {
+ return static_cast<uint32_t>(num);
+ } else {
+ throw std::out_of_range(value);
+ }
+}
+
+int32_t ParseInt32(const std::string& value) {
+ int64_t num = ParseInt64(value);
+ if (num <= std::numeric_limits<int32_t>::max() &&
+ num >= std::numeric_limits<int32_t>::min()) {
+ return static_cast<int32_t>(num);
+ } else {
+ throw std::out_of_range(value);
+ }
+}
+
+#endif
+
+uint64_t ParseUint64(const std::string& value) {
+ size_t endchar;
+#ifndef CYGWIN
+ uint64_t num = std::stoull(value.c_str(), &endchar);
+#else
+ char* endptr;
+ uint64_t num = std::strtoul(value.c_str(), &endptr, 0);
+ endchar = endptr - value.c_str();
+#endif
+
+ if (endchar < value.length()) {
+ char c = value[endchar];
+ if (c == 'k' || c == 'K')
+ num <<= 10LL;
+ else if (c == 'm' || c == 'M')
+ num <<= 20LL;
+ else if (c == 'g' || c == 'G')
+ num <<= 30LL;
+ else if (c == 't' || c == 'T')
+ num <<= 40LL;
+ }
+
+ return num;
+}
+
+int64_t ParseInt64(const std::string& value) {
+ size_t endchar;
+#ifndef CYGWIN
+ int64_t num = std::stoll(value.c_str(), &endchar);
+#else
+ char* endptr;
+ int64_t num = std::strtoll(value.c_str(), &endptr, 0);
+ endchar = endptr - value.c_str();
+#endif
+
+ if (endchar < value.length()) {
+ char c = value[endchar];
+ if (c == 'k' || c == 'K')
+ num <<= 10LL;
+ else if (c == 'm' || c == 'M')
+ num <<= 20LL;
+ else if (c == 'g' || c == 'G')
+ num <<= 30LL;
+ else if (c == 't' || c == 'T')
+ num <<= 40LL;
+ }
+
+ return num;
+}
+
+int ParseInt(const std::string& value) {
+ size_t endchar;
+#ifndef CYGWIN
+ int num = std::stoi(value.c_str(), &endchar);
+#else
+ char* endptr;
+ int num = std::strtoul(value.c_str(), &endptr, 0);
+ endchar = endptr - value.c_str();
+#endif
+
+ if (endchar < value.length()) {
+ char c = value[endchar];
+ if (c == 'k' || c == 'K')
+ num <<= 10;
+ else if (c == 'm' || c == 'M')
+ num <<= 20;
+ else if (c == 'g' || c == 'G')
+ num <<= 30;
+ }
+
+ return num;
+}
+
+double ParseDouble(const std::string& value) {
+#ifndef CYGWIN
+ return std::stod(value);
+#else
+ return std::strtod(value.c_str(), 0);
+#endif
+}
+
+size_t ParseSizeT(const std::string& value) {
+ return static_cast<size_t>(ParseUint64(value));
+}
+
+std::vector<int> ParseVectorInt(const std::string& value) {
+ std::vector<int> result;
+ size_t start = 0;
+ while (start < value.size()) {
+ size_t end = value.find(':', start);
+ if (end == std::string::npos) {
+ result.push_back(ParseInt(value.substr(start)));
+ break;
+ } else {
+ result.push_back(ParseInt(value.substr(start, end - start)));
+ start = end + 1;
+ }
+ }
+ return result;
+}
+
+bool SerializeIntVector(const std::vector<int>& vec, std::string* value) {
+ *value = "";
+ for (size_t i = 0; i < vec.size(); ++i) {
+ if (i > 0) {
+ *value += ":";
+ }
+ *value += std::to_string(vec[i]);
+ }
+ return true;
+}
+
+// Copied from folly/string.cpp:
+// https://github.com/facebook/folly/blob/0deef031cb8aab76dc7e736f8b7c22d701d5f36b/folly/String.cpp#L457
+// There are two variants of `strerror_r` function, one returns
+// `int`, and another returns `char*`. Selecting proper version using
+// preprocessor macros portably is extremely hard.
+//
+// For example, on Android function signature depends on `__USE_GNU` and
+// `__ANDROID_API__` macros (https://git.io/fjBBE).
+//
+// So we are using C++ overloading trick: we pass a pointer of
+// `strerror_r` to `invoke_strerror_r` function, and C++ compiler
+// selects proper function.
+
+#if !(defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)))
+ROCKSDB_MAYBE_UNUSED
+static std::string invoke_strerror_r(int (*strerror_r)(int, char*, size_t),
+ int err, char* buf, size_t buflen) {
+ // Using XSI-compatible strerror_r
+ int r = strerror_r(err, buf, buflen);
+
+ // OSX/FreeBSD use EINVAL and Linux uses -1 so just check for non-zero
+ if (r != 0) {
+ snprintf(buf, buflen, "Unknown error %d (strerror_r failed with error %d)",
+ err, errno);
+ }
+ return buf;
+}
+
+ROCKSDB_MAYBE_UNUSED
+static std::string invoke_strerror_r(char* (*strerror_r)(int, char*, size_t),
+ int err, char* buf, size_t buflen) {
+ // Using GNU strerror_r
+ return strerror_r(err, buf, buflen);
+}
+#endif // !(defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)))
+
+std::string errnoStr(int err) {
+ char buf[1024];
+ buf[0] = '\0';
+
+ std::string result;
+
+ // https://developer.apple.com/library/mac/documentation/Darwin/Reference/ManPages/man3/strerror_r.3.html
+ // http://www.kernel.org/doc/man-pages/online/pages/man3/strerror.3.html
+#if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER))
+ // mingw64 has no strerror_r, but Windows has strerror_s, which C11 added
+ // as well. So maybe we should use this across all platforms (together
+ // with strerrorlen_s). Note strerror_r and _s have swapped args.
+ int r = strerror_s(buf, sizeof(buf), err);
+ if (r != 0) {
+ snprintf(buf, sizeof(buf),
+ "Unknown error %d (strerror_r failed with error %d)", err, errno);
+ }
+ result.assign(buf);
+#else
+ // Using any strerror_r
+ result.assign(invoke_strerror_r(strerror_r, err, buf, sizeof(buf)));
+#endif
+
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/string_util.h b/src/rocksdb/util/string_util.h
new file mode 100644
index 000000000..11178fd1d
--- /dev/null
+++ b/src/rocksdb/util/string_util.h
@@ -0,0 +1,177 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#pragma once
+
+#include <cstdint>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class Slice;
+
+extern std::vector<std::string> StringSplit(const std::string& arg, char delim);
+
+// Append a human-readable printout of "num" to *str
+extern void AppendNumberTo(std::string* str, uint64_t num);
+
+// Append a human-readable printout of "value" to *str.
+// Escapes any non-printable characters found in "value".
+extern void AppendEscapedStringTo(std::string* str, const Slice& value);
+
+// Put n digits from v in base kBase to (*buf)[0] to (*buf)[n-1] and
+// advance *buf to the position after what was written.
+template <size_t kBase>
+inline void PutBaseChars(char** buf, size_t n, uint64_t v, bool uppercase) {
+ const char* digitChars = uppercase ? "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ : "0123456789abcdefghijklmnopqrstuvwxyz";
+ for (size_t i = n; i > 0; --i) {
+ (*buf)[i - 1] = digitChars[static_cast<size_t>(v % kBase)];
+ v /= kBase;
+ }
+ *buf += n;
+}
+
+// Parse n digits from *buf in base kBase to *v and advance *buf to the
+// position after what was read. On success, true is returned. On failure,
+// false is returned, *buf is placed at the first bad character, and *v
+// contains the partial parsed data. Overflow is not checked but the
+// result is accurate mod 2^64. Requires the starting value of *v to be
+// zero or previously accumulated parsed digits, i.e.
+// ParseBaseChars(&b, n, &v);
+// is equivalent to n calls to
+// ParseBaseChars(&b, 1, &v);
+template <int kBase>
+inline bool ParseBaseChars(const char** buf, size_t n, uint64_t* v) {
+ while (n) {
+ char c = **buf;
+ *v *= static_cast<uint64_t>(kBase);
+ if (c >= '0' && (kBase >= 10 ? c <= '9' : c < '0' + kBase)) {
+ *v += static_cast<uint64_t>(c - '0');
+ } else if (kBase > 10 && c >= 'A' && c < 'A' + kBase - 10) {
+ *v += static_cast<uint64_t>(c - 'A' + 10);
+ } else if (kBase > 10 && c >= 'a' && c < 'a' + kBase - 10) {
+ *v += static_cast<uint64_t>(c - 'a' + 10);
+ } else {
+ return false;
+ }
+ --n;
+ ++*buf;
+ }
+ return true;
+}
+
+// Return a human-readable version of num.
+// for num >= 10.000, prints "xxK"
+// for num >= 10.000.000, prints "xxM"
+// for num >= 10.000.000.000, prints "xxG"
+extern std::string NumberToHumanString(int64_t num);
+
+// Return a human-readable version of bytes
+// ex: 1048576 -> 1.00 GB
+extern std::string BytesToHumanString(uint64_t bytes);
+
+// Return a human-readable version of unix time
+// ex: 1562116015 -> "Tue Jul 2 18:06:55 2019"
+extern std::string TimeToHumanString(int unixtime);
+
+// Append a human-readable time in micros.
+int AppendHumanMicros(uint64_t micros, char* output, int len,
+ bool fixed_format);
+
+// Append a human-readable size in bytes
+int AppendHumanBytes(uint64_t bytes, char* output, int len);
+
+// Return a human-readable version of "value".
+// Escapes any non-printable characters found in "value".
+extern std::string EscapeString(const Slice& value);
+
+// Parse a human-readable number from "*in" into *value. On success,
+// advances "*in" past the consumed number and sets "*val" to the
+// numeric value. Otherwise, returns false and leaves *in in an
+// unspecified state.
+extern bool ConsumeDecimalNumber(Slice* in, uint64_t* val);
+
+// Returns true if the input char "c" is considered as a special character
+// that will be escaped when EscapeOptionString() is called.
+//
+// @param c the input char
+// @return true if the input char "c" is considered as a special character.
+// @see EscapeOptionString
+bool isSpecialChar(const char c);
+
+// If the input char is an escaped char, it will return the its
+// associated raw-char. Otherwise, the function will simply return
+// the original input char.
+char UnescapeChar(const char c);
+
+// If the input char is a control char, it will return the its
+// associated escaped char. Otherwise, the function will simply return
+// the original input char.
+char EscapeChar(const char c);
+
+// Converts a raw string to an escaped string. Escaped-characters are
+// defined via the isSpecialChar() function. When a char in the input
+// string "raw_string" is classified as a special characters, then it
+// will be prefixed by '\' in the output.
+//
+// It's inverse function is UnescapeOptionString().
+// @param raw_string the input string
+// @return the '\' escaped string of the input "raw_string"
+// @see isSpecialChar, UnescapeOptionString
+std::string EscapeOptionString(const std::string& raw_string);
+
+// The inverse function of EscapeOptionString. It converts
+// an '\' escaped string back to a raw string.
+//
+// @param escaped_string the input '\' escaped string
+// @return the raw string of the input "escaped_string"
+std::string UnescapeOptionString(const std::string& escaped_string);
+
+std::string trim(const std::string& str);
+
+// Returns true if "string" ends with "pattern"
+bool EndsWith(const std::string& string, const std::string& pattern);
+
+// Returns true if "string" starts with "pattern"
+bool StartsWith(const std::string& string, const std::string& pattern);
+
+#ifndef ROCKSDB_LITE
+bool ParseBoolean(const std::string& type, const std::string& value);
+
+uint8_t ParseUint8(const std::string& value);
+
+uint32_t ParseUint32(const std::string& value);
+
+int32_t ParseInt32(const std::string& value);
+#endif
+
+uint64_t ParseUint64(const std::string& value);
+
+int ParseInt(const std::string& value);
+
+int64_t ParseInt64(const std::string& value);
+
+double ParseDouble(const std::string& value);
+
+size_t ParseSizeT(const std::string& value);
+
+std::vector<int> ParseVectorInt(const std::string& value);
+
+bool SerializeIntVector(const std::vector<int>& vec, std::string* value);
+
+extern const std::string kNullptrString;
+
+// errnoStr() function returns a string that describes the error code passed in
+// the argument err
+extern std::string errnoStr(int err);
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/thread_guard.h b/src/rocksdb/util/thread_guard.h
new file mode 100644
index 000000000..b2bb06a1b
--- /dev/null
+++ b/src/rocksdb/util/thread_guard.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "port/port.h"
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Resource management object for threads that joins the thread upon
+// destruction. Has unique ownership of the thread object, so copying it is not
+// allowed, while moving it transfers ownership.
+class ThreadGuard {
+ public:
+ ThreadGuard() = default;
+
+ explicit ThreadGuard(port::Thread&& thread) : thread_(std::move(thread)) {}
+
+ ThreadGuard(const ThreadGuard&) = delete;
+ ThreadGuard& operator=(const ThreadGuard&) = delete;
+
+ ThreadGuard(ThreadGuard&&) noexcept = default;
+ ThreadGuard& operator=(ThreadGuard&&) noexcept = default;
+
+ ~ThreadGuard() {
+ if (thread_.joinable()) {
+ thread_.join();
+ }
+ }
+
+ const port::Thread& GetThread() const { return thread_; }
+ port::Thread& GetThread() { return thread_; }
+
+ private:
+ port::Thread thread_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/thread_list_test.cc b/src/rocksdb/util/thread_list_test.cc
new file mode 100644
index 000000000..af4e62355
--- /dev/null
+++ b/src/rocksdb/util/thread_list_test.cc
@@ -0,0 +1,360 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <condition_variable>
+#include <mutex>
+
+#include "monitoring/thread_status_updater.h"
+#include "rocksdb/db.h"
+#include "test_util/testharness.h"
+
+#ifdef ROCKSDB_USING_THREAD_STATUS
+
+namespace ROCKSDB_NAMESPACE {
+
+class SimulatedBackgroundTask {
+ public:
+ SimulatedBackgroundTask(
+ const void* db_key, const std::string& db_name, const void* cf_key,
+ const std::string& cf_name,
+ const ThreadStatus::OperationType operation_type =
+ ThreadStatus::OP_UNKNOWN,
+ const ThreadStatus::StateType state_type = ThreadStatus::STATE_UNKNOWN)
+ : db_key_(db_key),
+ db_name_(db_name),
+ cf_key_(cf_key),
+ cf_name_(cf_name),
+ operation_type_(operation_type),
+ state_type_(state_type),
+ should_run_(true),
+ running_count_(0) {
+ Env::Default()->GetThreadStatusUpdater()->NewColumnFamilyInfo(
+ db_key_, db_name_, cf_key_, cf_name_);
+ }
+
+ ~SimulatedBackgroundTask() {
+ Env::Default()->GetThreadStatusUpdater()->EraseDatabaseInfo(db_key_);
+ }
+
+ void Run() {
+ std::unique_lock<std::mutex> l(mutex_);
+ running_count_++;
+ bg_cv_.notify_all();
+ Env::Default()->GetThreadStatusUpdater()->SetColumnFamilyInfoKey(cf_key_);
+ Env::Default()->GetThreadStatusUpdater()->SetThreadOperation(
+ operation_type_);
+ Env::Default()->GetThreadStatusUpdater()->SetThreadState(state_type_);
+ while (should_run_) {
+ bg_cv_.wait(l);
+ }
+ Env::Default()->GetThreadStatusUpdater()->ClearThreadState();
+ Env::Default()->GetThreadStatusUpdater()->ClearThreadOperation();
+ Env::Default()->GetThreadStatusUpdater()->SetColumnFamilyInfoKey(nullptr);
+ running_count_--;
+ bg_cv_.notify_all();
+ }
+
+ void FinishAllTasks() {
+ std::unique_lock<std::mutex> l(mutex_);
+ should_run_ = false;
+ bg_cv_.notify_all();
+ }
+
+ void WaitUntilScheduled(int job_count) {
+ std::unique_lock<std::mutex> l(mutex_);
+ while (running_count_ < job_count) {
+ bg_cv_.wait(l);
+ }
+ }
+
+ void WaitUntilDone() {
+ std::unique_lock<std::mutex> l(mutex_);
+ while (running_count_ > 0) {
+ bg_cv_.wait(l);
+ }
+ }
+
+ static void DoSimulatedTask(void* arg) {
+ reinterpret_cast<SimulatedBackgroundTask*>(arg)->Run();
+ }
+
+ private:
+ const void* db_key_;
+ const std::string db_name_;
+ const void* cf_key_;
+ const std::string cf_name_;
+ const ThreadStatus::OperationType operation_type_;
+ const ThreadStatus::StateType state_type_;
+ std::mutex mutex_;
+ std::condition_variable bg_cv_;
+ bool should_run_;
+ std::atomic<int> running_count_;
+};
+
+class ThreadListTest : public testing::Test {
+ public:
+ ThreadListTest() {}
+};
+
+TEST_F(ThreadListTest, GlobalTables) {
+ // verify the global tables for operations and states are properly indexed.
+ for (int type = 0; type != ThreadStatus::NUM_OP_TYPES; ++type) {
+ ASSERT_EQ(global_operation_table[type].type, type);
+ ASSERT_EQ(
+ global_operation_table[type].name,
+ ThreadStatus::GetOperationName(ThreadStatus::OperationType(type)));
+ }
+
+ for (int type = 0; type != ThreadStatus::NUM_STATE_TYPES; ++type) {
+ ASSERT_EQ(global_state_table[type].type, type);
+ ASSERT_EQ(global_state_table[type].name,
+ ThreadStatus::GetStateName(ThreadStatus::StateType(type)));
+ }
+
+ for (int stage = 0; stage != ThreadStatus::NUM_OP_STAGES; ++stage) {
+ ASSERT_EQ(global_op_stage_table[stage].stage, stage);
+ ASSERT_EQ(global_op_stage_table[stage].name,
+ ThreadStatus::GetOperationStageName(
+ ThreadStatus::OperationStage(stage)));
+ }
+}
+
+TEST_F(ThreadListTest, SimpleColumnFamilyInfoTest) {
+ Env* env = Env::Default();
+ const int kHighPriorityThreads = 3;
+ const int kLowPriorityThreads = 5;
+ const int kSimulatedHighPriThreads = kHighPriorityThreads - 1;
+ const int kSimulatedLowPriThreads = kLowPriorityThreads / 3;
+ const int kDelayMicros = 1000000;
+ env->SetBackgroundThreads(kHighPriorityThreads, Env::HIGH);
+ env->SetBackgroundThreads(kLowPriorityThreads, Env::LOW);
+ // Wait 1 second so that threads start
+ Env::Default()->SleepForMicroseconds(kDelayMicros);
+ SimulatedBackgroundTask running_task(reinterpret_cast<void*>(1234), "running",
+ reinterpret_cast<void*>(5678),
+ "pikachu");
+
+ for (int test = 0; test < kSimulatedHighPriThreads; ++test) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask, &running_task,
+ Env::Priority::HIGH);
+ }
+
+ for (int test = 0; test < kSimulatedLowPriThreads; ++test) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask, &running_task,
+ Env::Priority::LOW);
+ }
+ running_task.WaitUntilScheduled(kSimulatedHighPriThreads +
+ kSimulatedLowPriThreads);
+ // We can only reserve limited number of waiting threads
+ ASSERT_EQ(kHighPriorityThreads - kSimulatedHighPriThreads,
+ env->ReserveThreads(kHighPriorityThreads, Env::Priority::HIGH));
+ ASSERT_EQ(kLowPriorityThreads - kSimulatedLowPriThreads,
+ env->ReserveThreads(kLowPriorityThreads, Env::Priority::LOW));
+
+ // Reservation shall not affect the existing thread list
+ std::vector<ThreadStatus> thread_list;
+
+ // Verify the number of running threads in each pool.
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ int running_count[ThreadStatus::NUM_THREAD_TYPES] = {0};
+ for (auto thread_status : thread_list) {
+ if (thread_status.cf_name == "pikachu" &&
+ thread_status.db_name == "running") {
+ running_count[thread_status.thread_type]++;
+ }
+ }
+ // Cannot reserve more threads
+ ASSERT_EQ(0, env->ReserveThreads(kHighPriorityThreads, Env::Priority::HIGH));
+ ASSERT_EQ(0, env->ReserveThreads(kLowPriorityThreads, Env::Priority::LOW));
+
+ ASSERT_EQ(running_count[ThreadStatus::HIGH_PRIORITY],
+ kSimulatedHighPriThreads);
+ ASSERT_EQ(running_count[ThreadStatus::LOW_PRIORITY], kSimulatedLowPriThreads);
+ ASSERT_EQ(running_count[ThreadStatus::USER], 0);
+
+ running_task.FinishAllTasks();
+ running_task.WaitUntilDone();
+
+ ASSERT_EQ(kHighPriorityThreads - kSimulatedHighPriThreads,
+ env->ReleaseThreads(kHighPriorityThreads, Env::Priority::HIGH));
+ ASSERT_EQ(kLowPriorityThreads - kSimulatedLowPriThreads,
+ env->ReleaseThreads(kLowPriorityThreads, Env::Priority::LOW));
+ // Verify none of the threads are running
+ ASSERT_OK(env->GetThreadList(&thread_list));
+
+ for (int i = 0; i < ThreadStatus::NUM_THREAD_TYPES; ++i) {
+ running_count[i] = 0;
+ }
+ for (auto thread_status : thread_list) {
+ if (thread_status.cf_name == "pikachu" &&
+ thread_status.db_name == "running") {
+ running_count[thread_status.thread_type]++;
+ }
+ }
+
+ ASSERT_EQ(running_count[ThreadStatus::HIGH_PRIORITY], 0);
+ ASSERT_EQ(running_count[ThreadStatus::LOW_PRIORITY], 0);
+ ASSERT_EQ(running_count[ThreadStatus::USER], 0);
+}
+
+namespace {
+void UpdateStatusCounts(const std::vector<ThreadStatus>& thread_list,
+ int operation_counts[], int state_counts[]) {
+ for (auto thread_status : thread_list) {
+ operation_counts[thread_status.operation_type]++;
+ state_counts[thread_status.state_type]++;
+ }
+}
+
+void VerifyAndResetCounts(const int correct_counts[], int collected_counts[],
+ int size) {
+ for (int i = 0; i < size; ++i) {
+ ASSERT_EQ(collected_counts[i], correct_counts[i]);
+ collected_counts[i] = 0;
+ }
+}
+
+void UpdateCount(int operation_counts[], int from_event, int to_event,
+ int amount) {
+ operation_counts[from_event] -= amount;
+ operation_counts[to_event] += amount;
+}
+} // namespace
+
+TEST_F(ThreadListTest, SimpleEventTest) {
+ Env* env = Env::Default();
+
+ // simulated tasks
+ const int kFlushWriteTasks = 3;
+ SimulatedBackgroundTask flush_write_task(
+ reinterpret_cast<void*>(1234), "running", reinterpret_cast<void*>(5678),
+ "pikachu", ThreadStatus::OP_FLUSH);
+
+ const int kCompactionWriteTasks = 4;
+ SimulatedBackgroundTask compaction_write_task(
+ reinterpret_cast<void*>(1234), "running", reinterpret_cast<void*>(5678),
+ "pikachu", ThreadStatus::OP_COMPACTION);
+
+ const int kCompactionReadTasks = 5;
+ SimulatedBackgroundTask compaction_read_task(
+ reinterpret_cast<void*>(1234), "running", reinterpret_cast<void*>(5678),
+ "pikachu", ThreadStatus::OP_COMPACTION);
+
+ const int kCompactionWaitTasks = 6;
+ SimulatedBackgroundTask compaction_wait_task(
+ reinterpret_cast<void*>(1234), "running", reinterpret_cast<void*>(5678),
+ "pikachu", ThreadStatus::OP_COMPACTION);
+
+ // setup right answers
+ int correct_operation_counts[ThreadStatus::NUM_OP_TYPES] = {0};
+ correct_operation_counts[ThreadStatus::OP_FLUSH] = kFlushWriteTasks;
+ correct_operation_counts[ThreadStatus::OP_COMPACTION] =
+ kCompactionWriteTasks + kCompactionReadTasks + kCompactionWaitTasks;
+
+ env->SetBackgroundThreads(correct_operation_counts[ThreadStatus::OP_FLUSH],
+ Env::HIGH);
+ env->SetBackgroundThreads(
+ correct_operation_counts[ThreadStatus::OP_COMPACTION], Env::LOW);
+
+ // schedule the simulated tasks
+ for (int t = 0; t < kFlushWriteTasks; ++t) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask, &flush_write_task,
+ Env::Priority::HIGH);
+ }
+ flush_write_task.WaitUntilScheduled(kFlushWriteTasks);
+
+ for (int t = 0; t < kCompactionWriteTasks; ++t) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask,
+ &compaction_write_task, Env::Priority::LOW);
+ }
+ compaction_write_task.WaitUntilScheduled(kCompactionWriteTasks);
+
+ for (int t = 0; t < kCompactionReadTasks; ++t) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask,
+ &compaction_read_task, Env::Priority::LOW);
+ }
+ compaction_read_task.WaitUntilScheduled(kCompactionReadTasks);
+
+ for (int t = 0; t < kCompactionWaitTasks; ++t) {
+ env->Schedule(&SimulatedBackgroundTask::DoSimulatedTask,
+ &compaction_wait_task, Env::Priority::LOW);
+ }
+ compaction_wait_task.WaitUntilScheduled(kCompactionWaitTasks);
+
+ // verify the thread-status
+ int operation_counts[ThreadStatus::NUM_OP_TYPES] = {0};
+ int state_counts[ThreadStatus::NUM_STATE_TYPES] = {0};
+
+ std::vector<ThreadStatus> thread_list;
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ UpdateStatusCounts(thread_list, operation_counts, state_counts);
+ VerifyAndResetCounts(correct_operation_counts, operation_counts,
+ ThreadStatus::NUM_OP_TYPES);
+
+ // terminate compaction-wait tasks and see if the thread-status
+ // reflects this update
+ compaction_wait_task.FinishAllTasks();
+ compaction_wait_task.WaitUntilDone();
+ UpdateCount(correct_operation_counts, ThreadStatus::OP_COMPACTION,
+ ThreadStatus::OP_UNKNOWN, kCompactionWaitTasks);
+
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ UpdateStatusCounts(thread_list, operation_counts, state_counts);
+ VerifyAndResetCounts(correct_operation_counts, operation_counts,
+ ThreadStatus::NUM_OP_TYPES);
+
+ // terminate flush-write tasks and see if the thread-status
+ // reflects this update
+ flush_write_task.FinishAllTasks();
+ flush_write_task.WaitUntilDone();
+ UpdateCount(correct_operation_counts, ThreadStatus::OP_FLUSH,
+ ThreadStatus::OP_UNKNOWN, kFlushWriteTasks);
+
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ UpdateStatusCounts(thread_list, operation_counts, state_counts);
+ VerifyAndResetCounts(correct_operation_counts, operation_counts,
+ ThreadStatus::NUM_OP_TYPES);
+
+ // terminate compaction-write tasks and see if the thread-status
+ // reflects this update
+ compaction_write_task.FinishAllTasks();
+ compaction_write_task.WaitUntilDone();
+ UpdateCount(correct_operation_counts, ThreadStatus::OP_COMPACTION,
+ ThreadStatus::OP_UNKNOWN, kCompactionWriteTasks);
+
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ UpdateStatusCounts(thread_list, operation_counts, state_counts);
+ VerifyAndResetCounts(correct_operation_counts, operation_counts,
+ ThreadStatus::NUM_OP_TYPES);
+
+ // terminate compaction-write tasks and see if the thread-status
+ // reflects this update
+ compaction_read_task.FinishAllTasks();
+ compaction_read_task.WaitUntilDone();
+ UpdateCount(correct_operation_counts, ThreadStatus::OP_COMPACTION,
+ ThreadStatus::OP_UNKNOWN, kCompactionReadTasks);
+
+ ASSERT_OK(env->GetThreadList(&thread_list));
+ UpdateStatusCounts(thread_list, operation_counts, state_counts);
+ VerifyAndResetCounts(correct_operation_counts, operation_counts,
+ ThreadStatus::NUM_OP_TYPES);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return 0;
+}
+
+#endif // ROCKSDB_USING_THREAD_STATUS
diff --git a/src/rocksdb/util/thread_local.cc b/src/rocksdb/util/thread_local.cc
new file mode 100644
index 000000000..969639d9b
--- /dev/null
+++ b/src/rocksdb/util/thread_local.cc
@@ -0,0 +1,521 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/thread_local.h"
+
+#include <stdlib.h>
+
+#include "port/likely.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct Entry {
+ Entry() : ptr(nullptr) {}
+ Entry(const Entry& e) : ptr(e.ptr.load(std::memory_order_relaxed)) {}
+ std::atomic<void*> ptr;
+};
+
+class StaticMeta;
+
+// This is the structure that is declared as "thread_local" storage.
+// The vector keep list of atomic pointer for all instances for "current"
+// thread. The vector is indexed by an Id that is unique in process and
+// associated with one ThreadLocalPtr instance. The Id is assigned by a
+// global StaticMeta singleton. So if we instantiated 3 ThreadLocalPtr
+// instances, each thread will have a ThreadData with a vector of size 3:
+// ---------------------------------------------------
+// | | instance 1 | instance 2 | instance 3 |
+// ---------------------------------------------------
+// | thread 1 | void* | void* | void* | <- ThreadData
+// ---------------------------------------------------
+// | thread 2 | void* | void* | void* | <- ThreadData
+// ---------------------------------------------------
+// | thread 3 | void* | void* | void* | <- ThreadData
+// ---------------------------------------------------
+struct ThreadData {
+ explicit ThreadData(ThreadLocalPtr::StaticMeta* _inst)
+ : entries(), next(nullptr), prev(nullptr), inst(_inst) {}
+ std::vector<Entry> entries;
+ ThreadData* next;
+ ThreadData* prev;
+ ThreadLocalPtr::StaticMeta* inst;
+};
+
+class ThreadLocalPtr::StaticMeta {
+ public:
+ StaticMeta();
+
+ // Return the next available Id
+ uint32_t GetId();
+ // Return the next available Id without claiming it
+ uint32_t PeekId() const;
+ // Return the given Id back to the free pool. This also triggers
+ // UnrefHandler for associated pointer value (if not NULL) for all threads.
+ void ReclaimId(uint32_t id);
+
+ // Return the pointer value for the given id for the current thread.
+ void* Get(uint32_t id) const;
+ // Reset the pointer value for the given id for the current thread.
+ void Reset(uint32_t id, void* ptr);
+ // Atomically swap the supplied ptr and return the previous value
+ void* Swap(uint32_t id, void* ptr);
+ // Atomically compare and swap the provided value only if it equals
+ // to expected value.
+ bool CompareAndSwap(uint32_t id, void* ptr, void*& expected);
+ // Reset all thread local data to replacement, and return non-nullptr
+ // data for all existing threads
+ void Scrape(uint32_t id, autovector<void*>* ptrs, void* const replacement);
+ // Update res by applying func on each thread-local value. Holds a lock that
+ // prevents unref handler from running during this call, but clients must
+ // still provide external synchronization since the owning thread can
+ // access the values without internal locking, e.g., via Get() and Reset().
+ void Fold(uint32_t id, FoldFunc func, void* res);
+
+ // Register the UnrefHandler for id
+ void SetHandler(uint32_t id, UnrefHandler handler);
+
+ // protect inst, next_instance_id_, free_instance_ids_, head_,
+ // ThreadData.entries
+ //
+ // Note that here we prefer function static variable instead of the usual
+ // global static variable. The reason is that c++ destruction order of
+ // static variables in the reverse order of their construction order.
+ // However, C++ does not guarantee any construction order when global
+ // static variables are defined in different files, while the function
+ // static variables are initialized when their function are first called.
+ // As a result, the construction order of the function static variables
+ // can be controlled by properly invoke their first function calls in
+ // the right order.
+ //
+ // For instance, the following function contains a function static
+ // variable. We place a dummy function call of this inside
+ // Env::Default() to ensure the construction order of the construction
+ // order.
+ static port::Mutex* Mutex();
+
+ // Returns the member mutex of the current StaticMeta. In general,
+ // Mutex() should be used instead of this one. However, in case where
+ // the static variable inside Instance() goes out of scope, MemberMutex()
+ // should be used. One example is OnThreadExit() function.
+ port::Mutex* MemberMutex() { return &mutex_; }
+
+ private:
+ // Get UnrefHandler for id with acquiring mutex
+ // REQUIRES: mutex locked
+ UnrefHandler GetHandler(uint32_t id);
+
+ // Triggered before a thread terminates
+ static void OnThreadExit(void* ptr);
+
+ // Add current thread's ThreadData to the global chain
+ // REQUIRES: mutex locked
+ void AddThreadData(ThreadData* d);
+
+ // Remove current thread's ThreadData from the global chain
+ // REQUIRES: mutex locked
+ void RemoveThreadData(ThreadData* d);
+
+ static ThreadData* GetThreadLocal();
+
+ uint32_t next_instance_id_;
+ // Used to recycle Ids in case ThreadLocalPtr is instantiated and destroyed
+ // frequently. This also prevents it from blowing up the vector space.
+ autovector<uint32_t> free_instance_ids_;
+ // Chain all thread local structure together. This is necessary since
+ // when one ThreadLocalPtr gets destroyed, we need to loop over each
+ // thread's version of pointer corresponding to that instance and
+ // call UnrefHandler for it.
+ ThreadData head_;
+
+ std::unordered_map<uint32_t, UnrefHandler> handler_map_;
+
+ // The private mutex. Developers should always use Mutex() instead of
+ // using this variable directly.
+ port::Mutex mutex_;
+ // Thread local storage
+ static thread_local ThreadData* tls_;
+
+ // Used to make thread exit trigger possible if !defined(OS_MACOSX).
+ // Otherwise, used to retrieve thread data.
+ pthread_key_t pthread_key_;
+};
+
+thread_local ThreadData* ThreadLocalPtr::StaticMeta::tls_ = nullptr;
+
+// Windows doesn't support a per-thread destructor with its
+// TLS primitives. So, we build it manually by inserting a
+// function to be called on each thread's exit.
+// See http://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way
+// and http://www.nynaeve.net/?p=183
+//
+// really we do this to have clear conscience since using TLS with thread-pools
+// is iffy
+// although OK within a request. But otherwise, threads have no identity in its
+// modern use.
+
+// This runs on windows only called from the System Loader
+#ifdef OS_WIN
+
+// Windows cleanup routine is invoked from a System Loader with a different
+// signature so we can not directly hookup the original OnThreadExit which is
+// private member
+// so we make StaticMeta class share with the us the address of the function so
+// we can invoke it.
+namespace wintlscleanup {
+
+// This is set to OnThreadExit in StaticMeta singleton constructor
+UnrefHandler thread_local_inclass_routine = nullptr;
+pthread_key_t thread_local_key = pthread_key_t(-1);
+
+// Static callback function to call with each thread termination.
+void NTAPI WinOnThreadExit(PVOID module, DWORD reason, PVOID reserved) {
+ // We decided to punt on PROCESS_EXIT
+ if (DLL_THREAD_DETACH == reason) {
+ if (thread_local_key != pthread_key_t(-1) &&
+ thread_local_inclass_routine != nullptr) {
+ void* tls = TlsGetValue(thread_local_key);
+ if (tls != nullptr) {
+ thread_local_inclass_routine(tls);
+ }
+ }
+ }
+}
+
+} // namespace wintlscleanup
+
+// extern "C" suppresses C++ name mangling so we know the symbol name for the
+// linker /INCLUDE:symbol pragma above.
+extern "C" {
+
+#ifdef _MSC_VER
+// The linker must not discard thread_callback_on_exit. (We force a reference
+// to this variable with a linker /include:symbol pragma to ensure that.) If
+// this variable is discarded, the OnThreadExit function will never be called.
+#ifndef _X86_
+
+// .CRT section is merged with .rdata on x64 so it must be constant data.
+#pragma const_seg(".CRT$XLB")
+// When defining a const variable, it must have external linkage to be sure the
+// linker doesn't discard it.
+extern const PIMAGE_TLS_CALLBACK p_thread_callback_on_exit;
+const PIMAGE_TLS_CALLBACK p_thread_callback_on_exit =
+ wintlscleanup::WinOnThreadExit;
+// Reset the default section.
+#pragma const_seg()
+
+#pragma comment(linker, "/include:_tls_used")
+#pragma comment(linker, "/include:p_thread_callback_on_exit")
+
+#else // _X86_
+
+#pragma data_seg(".CRT$XLB")
+PIMAGE_TLS_CALLBACK p_thread_callback_on_exit = wintlscleanup::WinOnThreadExit;
+// Reset the default section.
+#pragma data_seg()
+
+#pragma comment(linker, "/INCLUDE:__tls_used")
+#pragma comment(linker, "/INCLUDE:_p_thread_callback_on_exit")
+
+#endif // _X86_
+
+#else
+// https://github.com/couchbase/gperftools/blob/master/src/windows/port.cc
+BOOL WINAPI DllMain(HINSTANCE h, DWORD dwReason, PVOID pv) {
+ if (dwReason == DLL_THREAD_DETACH)
+ wintlscleanup::WinOnThreadExit(h, dwReason, pv);
+ return TRUE;
+}
+#endif
+} // extern "C"
+
+#endif // OS_WIN
+
+void ThreadLocalPtr::InitSingletons() { ThreadLocalPtr::Instance(); }
+
+ThreadLocalPtr::StaticMeta* ThreadLocalPtr::Instance() {
+ // Here we prefer function static variable instead of global
+ // static variable as function static variable is initialized
+ // when the function is first call. As a result, we can properly
+ // control their construction order by properly preparing their
+ // first function call.
+ //
+ // Note that here we decide to make "inst" a static pointer w/o deleting
+ // it at the end instead of a static variable. This is to avoid the following
+ // destruction order disaster happens when a child thread using ThreadLocalPtr
+ // dies AFTER the main thread dies: When a child thread happens to use
+ // ThreadLocalPtr, it will try to delete its thread-local data on its
+ // OnThreadExit when the child thread dies. However, OnThreadExit depends
+ // on the following variable. As a result, if the main thread dies before any
+ // child thread happen to use ThreadLocalPtr dies, then the destruction of
+ // the following variable will go first, then OnThreadExit, therefore causing
+ // invalid access.
+ //
+ // The above problem can be solved by using thread_local to store tls_.
+ // thread_local supports dynamic construction and destruction of
+ // non-primitive typed variables. As a result, we can guarantee the
+ // destruction order even when the main thread dies before any child threads.
+ static ThreadLocalPtr::StaticMeta* inst = new ThreadLocalPtr::StaticMeta();
+ return inst;
+}
+
+port::Mutex* ThreadLocalPtr::StaticMeta::Mutex() { return &Instance()->mutex_; }
+
+void ThreadLocalPtr::StaticMeta::OnThreadExit(void* ptr) {
+ auto* tls = static_cast<ThreadData*>(ptr);
+ assert(tls != nullptr);
+
+ // Use the cached StaticMeta::Instance() instead of directly calling
+ // the variable inside StaticMeta::Instance() might already go out of
+ // scope here in case this OnThreadExit is called after the main thread
+ // dies.
+ auto* inst = tls->inst;
+ pthread_setspecific(inst->pthread_key_, nullptr);
+
+ MutexLock l(inst->MemberMutex());
+ inst->RemoveThreadData(tls);
+ // Unref stored pointers of current thread from all instances
+ uint32_t id = 0;
+ for (auto& e : tls->entries) {
+ void* raw = e.ptr.load();
+ if (raw != nullptr) {
+ auto unref = inst->GetHandler(id);
+ if (unref != nullptr) {
+ unref(raw);
+ }
+ }
+ ++id;
+ }
+ // Delete thread local structure no matter if it is Mac platform
+ delete tls;
+}
+
+ThreadLocalPtr::StaticMeta::StaticMeta()
+ : next_instance_id_(0), head_(this), pthread_key_(0) {
+ if (pthread_key_create(&pthread_key_, &OnThreadExit) != 0) {
+ abort();
+ }
+
+ // OnThreadExit is not getting called on the main thread.
+ // Call through the static destructor mechanism to avoid memory leak.
+ //
+ // Caveats: ~A() will be invoked _after_ ~StaticMeta for the global
+ // singleton (destructors are invoked in reverse order of constructor
+ // _completion_); the latter must not mutate internal members. This
+ // cleanup mechanism inherently relies on use-after-release of the
+ // StaticMeta, and is brittle with respect to compiler-specific handling
+ // of memory backing destructed statically-scoped objects. Perhaps
+ // registering with atexit(3) would be more robust.
+ //
+// This is not required on Windows.
+#if !defined(OS_WIN)
+ static struct A {
+ ~A() {
+ if (tls_) {
+ OnThreadExit(tls_);
+ }
+ }
+ } a;
+#endif // !defined(OS_WIN)
+
+ head_.next = &head_;
+ head_.prev = &head_;
+
+#ifdef OS_WIN
+ // Share with Windows its cleanup routine and the key
+ wintlscleanup::thread_local_inclass_routine = OnThreadExit;
+ wintlscleanup::thread_local_key = pthread_key_;
+#endif
+}
+
+void ThreadLocalPtr::StaticMeta::AddThreadData(ThreadData* d) {
+ Mutex()->AssertHeld();
+ d->next = &head_;
+ d->prev = head_.prev;
+ head_.prev->next = d;
+ head_.prev = d;
+}
+
+void ThreadLocalPtr::StaticMeta::RemoveThreadData(ThreadData* d) {
+ Mutex()->AssertHeld();
+ d->next->prev = d->prev;
+ d->prev->next = d->next;
+ d->next = d->prev = d;
+}
+
+ThreadData* ThreadLocalPtr::StaticMeta::GetThreadLocal() {
+ if (UNLIKELY(tls_ == nullptr)) {
+ auto* inst = Instance();
+ tls_ = new ThreadData(inst);
+ {
+ // Register it in the global chain, needs to be done before thread exit
+ // handler registration
+ MutexLock l(Mutex());
+ inst->AddThreadData(tls_);
+ }
+ // Even it is not OS_MACOSX, need to register value for pthread_key_ so that
+ // its exit handler will be triggered.
+ if (pthread_setspecific(inst->pthread_key_, tls_) != 0) {
+ {
+ MutexLock l(Mutex());
+ inst->RemoveThreadData(tls_);
+ }
+ delete tls_;
+ abort();
+ }
+ }
+ return tls_;
+}
+
+void* ThreadLocalPtr::StaticMeta::Get(uint32_t id) const {
+ auto* tls = GetThreadLocal();
+ if (UNLIKELY(id >= tls->entries.size())) {
+ return nullptr;
+ }
+ return tls->entries[id].ptr.load(std::memory_order_acquire);
+}
+
+void ThreadLocalPtr::StaticMeta::Reset(uint32_t id, void* ptr) {
+ auto* tls = GetThreadLocal();
+ if (UNLIKELY(id >= tls->entries.size())) {
+ // Need mutex to protect entries access within ReclaimId
+ MutexLock l(Mutex());
+ tls->entries.resize(id + 1);
+ }
+ tls->entries[id].ptr.store(ptr, std::memory_order_release);
+}
+
+void* ThreadLocalPtr::StaticMeta::Swap(uint32_t id, void* ptr) {
+ auto* tls = GetThreadLocal();
+ if (UNLIKELY(id >= tls->entries.size())) {
+ // Need mutex to protect entries access within ReclaimId
+ MutexLock l(Mutex());
+ tls->entries.resize(id + 1);
+ }
+ return tls->entries[id].ptr.exchange(ptr, std::memory_order_acquire);
+}
+
+bool ThreadLocalPtr::StaticMeta::CompareAndSwap(uint32_t id, void* ptr,
+ void*& expected) {
+ auto* tls = GetThreadLocal();
+ if (UNLIKELY(id >= tls->entries.size())) {
+ // Need mutex to protect entries access within ReclaimId
+ MutexLock l(Mutex());
+ tls->entries.resize(id + 1);
+ }
+ return tls->entries[id].ptr.compare_exchange_strong(
+ expected, ptr, std::memory_order_release, std::memory_order_relaxed);
+}
+
+void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector<void*>* ptrs,
+ void* const replacement) {
+ MutexLock l(Mutex());
+ for (ThreadData* t = head_.next; t != &head_; t = t->next) {
+ if (id < t->entries.size()) {
+ void* ptr =
+ t->entries[id].ptr.exchange(replacement, std::memory_order_acquire);
+ if (ptr != nullptr) {
+ ptrs->push_back(ptr);
+ }
+ }
+ }
+}
+
+void ThreadLocalPtr::StaticMeta::Fold(uint32_t id, FoldFunc func, void* res) {
+ MutexLock l(Mutex());
+ for (ThreadData* t = head_.next; t != &head_; t = t->next) {
+ if (id < t->entries.size()) {
+ void* ptr = t->entries[id].ptr.load();
+ if (ptr != nullptr) {
+ func(ptr, res);
+ }
+ }
+ }
+}
+
+uint32_t ThreadLocalPtr::TEST_PeekId() { return Instance()->PeekId(); }
+
+void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) {
+ MutexLock l(Mutex());
+ handler_map_[id] = handler;
+}
+
+UnrefHandler ThreadLocalPtr::StaticMeta::GetHandler(uint32_t id) {
+ Mutex()->AssertHeld();
+ auto iter = handler_map_.find(id);
+ if (iter == handler_map_.end()) {
+ return nullptr;
+ }
+ return iter->second;
+}
+
+uint32_t ThreadLocalPtr::StaticMeta::GetId() {
+ MutexLock l(Mutex());
+ if (free_instance_ids_.empty()) {
+ return next_instance_id_++;
+ }
+
+ uint32_t id = free_instance_ids_.back();
+ free_instance_ids_.pop_back();
+ return id;
+}
+
+uint32_t ThreadLocalPtr::StaticMeta::PeekId() const {
+ MutexLock l(Mutex());
+ if (!free_instance_ids_.empty()) {
+ return free_instance_ids_.back();
+ }
+ return next_instance_id_;
+}
+
+void ThreadLocalPtr::StaticMeta::ReclaimId(uint32_t id) {
+ // This id is not used, go through all thread local data and release
+ // corresponding value
+ MutexLock l(Mutex());
+ auto unref = GetHandler(id);
+ for (ThreadData* t = head_.next; t != &head_; t = t->next) {
+ if (id < t->entries.size()) {
+ void* ptr = t->entries[id].ptr.exchange(nullptr);
+ if (ptr != nullptr && unref != nullptr) {
+ unref(ptr);
+ }
+ }
+ }
+ handler_map_[id] = nullptr;
+ free_instance_ids_.push_back(id);
+}
+
+ThreadLocalPtr::ThreadLocalPtr(UnrefHandler handler)
+ : id_(Instance()->GetId()) {
+ if (handler != nullptr) {
+ Instance()->SetHandler(id_, handler);
+ }
+}
+
+ThreadLocalPtr::~ThreadLocalPtr() { Instance()->ReclaimId(id_); }
+
+void* ThreadLocalPtr::Get() const { return Instance()->Get(id_); }
+
+void ThreadLocalPtr::Reset(void* ptr) { Instance()->Reset(id_, ptr); }
+
+void* ThreadLocalPtr::Swap(void* ptr) { return Instance()->Swap(id_, ptr); }
+
+bool ThreadLocalPtr::CompareAndSwap(void* ptr, void*& expected) {
+ return Instance()->CompareAndSwap(id_, ptr, expected);
+}
+
+void ThreadLocalPtr::Scrape(autovector<void*>* ptrs, void* const replacement) {
+ Instance()->Scrape(id_, ptrs, replacement);
+}
+
+void ThreadLocalPtr::Fold(FoldFunc func, void* res) {
+ Instance()->Fold(id_, func, res);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/thread_local.h b/src/rocksdb/util/thread_local.h
new file mode 100644
index 000000000..fde68f86f
--- /dev/null
+++ b/src/rocksdb/util/thread_local.h
@@ -0,0 +1,100 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "port/port.h"
+#include "util/autovector.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Cleanup function that will be called for a stored thread local
+// pointer (if not NULL) when one of the following happens:
+// (1) a thread terminates
+// (2) a ThreadLocalPtr is destroyed
+//
+// Warning: this function is called while holding a global mutex. The same mutex
+// is used (at least in some cases) by most methods of ThreadLocalPtr, and it's
+// shared across all instances of ThreadLocalPtr. Thereforere extra care
+// is needed to avoid deadlocks. In particular, the handler shouldn't lock any
+// mutexes and shouldn't call any methods of any ThreadLocalPtr instances,
+// unless you know what you're doing.
+using UnrefHandler = void (*)(void* ptr);
+
+// ThreadLocalPtr stores only values of pointer type. Different from
+// the usual thread-local-storage, ThreadLocalPtr has the ability to
+// distinguish data coming from different threads and different
+// ThreadLocalPtr instances. For example, if a regular thread_local
+// variable A is declared in DBImpl, two DBImpl objects would share
+// the same A. However, a ThreadLocalPtr that is defined under the
+// scope of DBImpl can avoid such confliction. As a result, its memory
+// usage would be O(# of threads * # of ThreadLocalPtr instances).
+class ThreadLocalPtr {
+ public:
+ explicit ThreadLocalPtr(UnrefHandler handler = nullptr);
+
+ ThreadLocalPtr(const ThreadLocalPtr&) = delete;
+ ThreadLocalPtr& operator=(const ThreadLocalPtr&) = delete;
+
+ ~ThreadLocalPtr();
+
+ // Return the current pointer stored in thread local
+ void* Get() const;
+
+ // Set a new pointer value to the thread local storage.
+ void Reset(void* ptr);
+
+ // Atomically swap the supplied ptr and return the previous value
+ void* Swap(void* ptr);
+
+ // Atomically compare the stored value with expected. Set the new
+ // pointer value to thread local only if the comparison is true.
+ // Otherwise, expected returns the stored value.
+ // Return true on success, false on failure
+ bool CompareAndSwap(void* ptr, void*& expected);
+
+ // Reset all thread local data to replacement, and return non-nullptr
+ // data for all existing threads
+ void Scrape(autovector<void*>* ptrs, void* const replacement);
+
+ using FoldFunc = std::function<void(void*, void*)>;
+ // Update res by applying func on each thread-local value. Holds a lock that
+ // prevents unref handler from running during this call, but clients must
+ // still provide external synchronization since the owning thread can
+ // access the values without internal locking, e.g., via Get() and Reset().
+ void Fold(FoldFunc func, void* res);
+
+ // Add here for testing
+ // Return the next available Id without claiming it
+ static uint32_t TEST_PeekId();
+
+ // Initialize the static singletons of the ThreadLocalPtr.
+ //
+ // If this function is not called, then the singletons will be
+ // automatically initialized when they are used.
+ //
+ // Calling this function twice or after the singletons have been
+ // initialized will be no-op.
+ static void InitSingletons();
+
+ class StaticMeta;
+
+ private:
+ static StaticMeta* Instance();
+
+ const uint32_t id_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/thread_local_test.cc b/src/rocksdb/util/thread_local_test.cc
new file mode 100644
index 000000000..25ef5c0ee
--- /dev/null
+++ b/src/rocksdb/util/thread_local_test.cc
@@ -0,0 +1,582 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/thread_local.h"
+
+#include <atomic>
+#include <string>
+#include <thread>
+
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/autovector.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ThreadLocalTest : public testing::Test {
+ public:
+ ThreadLocalTest() : env_(Env::Default()) {}
+
+ Env* env_;
+};
+
+namespace {
+
+struct Params {
+ Params(port::Mutex* m, port::CondVar* c, int* u, int n,
+ UnrefHandler handler = nullptr)
+ : mu(m),
+ cv(c),
+ unref(u),
+ total(n),
+ started(0),
+ completed(0),
+ doWrite(false),
+ tls1(handler),
+ tls2(nullptr) {}
+
+ port::Mutex* mu;
+ port::CondVar* cv;
+ int* unref;
+ int total;
+ int started;
+ int completed;
+ bool doWrite;
+ ThreadLocalPtr tls1;
+ ThreadLocalPtr* tls2;
+};
+
+class IDChecker : public ThreadLocalPtr {
+ public:
+ static uint32_t PeekId() { return TEST_PeekId(); }
+};
+
+} // anonymous namespace
+
+// Suppress false positive clang analyzer warnings.
+#ifndef __clang_analyzer__
+TEST_F(ThreadLocalTest, UniqueIdTest) {
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+
+ uint32_t base_id = IDChecker::PeekId();
+ // New ThreadLocal instance bumps id by 1
+ {
+ // Id used 0
+ Params p1(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
+ // Id used 1
+ Params p2(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
+ // Id used 2
+ Params p3(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
+ // Id used 3
+ Params p4(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
+ }
+ // id 3, 2, 1, 0 are in the free queue in order
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);
+
+ // pick up 0
+ Params p1(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
+ // pick up 1
+ Params* p2 = new Params(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
+ // pick up 2
+ Params p3(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
+ // return up 1
+ delete p2;
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
+ // Now we have 3, 1 in queue
+ // pick up 1
+ Params p4(&mu, &cv, nullptr, 1u);
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
+ // pick up 3
+ Params p5(&mu, &cv, nullptr, 1u);
+ // next new id
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
+ // After exit, id sequence in queue:
+ // 3, 1, 2, 0
+}
+#endif // __clang_analyzer__
+
+TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
+ // global id list carries over 3, 1, 2, 0
+ uint32_t base_id = IDChecker::PeekId();
+
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ Params p(&mu, &cv, nullptr, 1);
+ ThreadLocalPtr tls2;
+ p.tls2 = &tls2;
+
+ ASSERT_GT(IDChecker::PeekId(), base_id);
+ base_id = IDChecker::PeekId();
+
+ auto func = [](Params* ptr) {
+ Params& params = *ptr;
+ ASSERT_TRUE(params.tls1.Get() == nullptr);
+ params.tls1.Reset(reinterpret_cast<int*>(1));
+ ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
+ params.tls1.Reset(reinterpret_cast<int*>(2));
+ ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2));
+
+ ASSERT_TRUE(params.tls2->Get() == nullptr);
+ params.tls2->Reset(reinterpret_cast<int*>(1));
+ ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1));
+ params.tls2->Reset(reinterpret_cast<int*>(2));
+ ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2));
+
+ params.mu->Lock();
+ ++(params.completed);
+ params.cv->SignalAll();
+ params.mu->Unlock();
+ };
+
+ for (int iter = 0; iter < 1024; ++iter) {
+ ASSERT_EQ(IDChecker::PeekId(), base_id);
+ // Another new thread, read/write should not see value from previous thread
+ env_->StartThreadTyped(func, &p);
+
+ mu.Lock();
+ while (p.completed != iter + 1) {
+ cv.Wait();
+ }
+ mu.Unlock();
+ ASSERT_EQ(IDChecker::PeekId(), base_id);
+ }
+}
+
+TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
+ // global id list carries over 3, 1, 2, 0
+ uint32_t base_id = IDChecker::PeekId();
+
+ ThreadLocalPtr tls2;
+ port::Mutex mu1;
+ port::CondVar cv1(&mu1);
+ Params p1(&mu1, &cv1, nullptr, 16);
+ p1.tls2 = &tls2;
+
+ port::Mutex mu2;
+ port::CondVar cv2(&mu2);
+ Params p2(&mu2, &cv2, nullptr, 16);
+ p2.doWrite = true;
+ p2.tls2 = &tls2;
+
+ auto func = [](void* ptr) {
+ auto& p = *static_cast<Params*>(ptr);
+
+ p.mu->Lock();
+ // Size_T switches size along with the ptr size
+ // we want to cast to.
+ size_t own = ++(p.started);
+ p.cv->SignalAll();
+ while (p.started != p.total) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+
+ // Let write threads write a different value from the read threads
+ if (p.doWrite) {
+ own += 8192;
+ }
+
+ ASSERT_TRUE(p.tls1.Get() == nullptr);
+ ASSERT_TRUE(p.tls2->Get() == nullptr);
+
+ auto* env = Env::Default();
+ auto start = env->NowMicros();
+
+ p.tls1.Reset(reinterpret_cast<size_t*>(own));
+ p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
+ // Loop for 1 second
+ while (env->NowMicros() - start < 1000 * 1000) {
+ for (int iter = 0; iter < 100000; ++iter) {
+ ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
+ ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
+ if (p.doWrite) {
+ p.tls1.Reset(reinterpret_cast<size_t*>(own));
+ p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
+ }
+ }
+ }
+
+ p.mu->Lock();
+ ++(p.completed);
+ p.cv->SignalAll();
+ p.mu->Unlock();
+ };
+
+ // Initiate 2 instnaces: one keeps writing and one keeps reading.
+ // The read instance should not see data from the write instance.
+ // Each thread local copy of the value are also different from each
+ // other.
+ for (int th = 0; th < p1.total; ++th) {
+ env_->StartThreadTyped(func, &p1);
+ }
+ for (int th = 0; th < p2.total; ++th) {
+ env_->StartThreadTyped(func, &p2);
+ }
+
+ mu1.Lock();
+ while (p1.completed != p1.total) {
+ cv1.Wait();
+ }
+ mu1.Unlock();
+
+ mu2.Lock();
+ while (p2.completed != p2.total) {
+ cv2.Wait();
+ }
+ mu2.Unlock();
+
+ ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
+}
+
+TEST_F(ThreadLocalTest, Unref) {
+ auto unref = [](void* ptr) {
+ auto& p = *static_cast<Params*>(ptr);
+ p.mu->Lock();
+ ++(*p.unref);
+ p.mu->Unlock();
+ };
+
+ // Case 0: no unref triggered if ThreadLocalPtr is never accessed
+ auto func0 = [](Params* ptr) {
+ auto& p = *ptr;
+ p.mu->Lock();
+ ++(p.started);
+ p.cv->SignalAll();
+ while (p.started != p.total) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+ };
+
+ for (int th = 1; th <= 128; th += th) {
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ int unref_count = 0;
+ Params p(&mu, &cv, &unref_count, th, unref);
+
+ for (int i = 0; i < p.total; ++i) {
+ env_->StartThreadTyped(func0, &p);
+ }
+ env_->WaitForJoin();
+ ASSERT_EQ(unref_count, 0);
+ }
+
+ // Case 1: unref triggered by thread exit
+ auto func1 = [](Params* ptr) {
+ auto& p = *ptr;
+
+ p.mu->Lock();
+ ++(p.started);
+ p.cv->SignalAll();
+ while (p.started != p.total) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+
+ ASSERT_TRUE(p.tls1.Get() == nullptr);
+ ASSERT_TRUE(p.tls2->Get() == nullptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+ };
+
+ for (int th = 1; th <= 128; th += th) {
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ int unref_count = 0;
+ ThreadLocalPtr tls2(unref);
+ Params p(&mu, &cv, &unref_count, th, unref);
+ p.tls2 = &tls2;
+
+ for (int i = 0; i < p.total; ++i) {
+ env_->StartThreadTyped(func1, &p);
+ }
+
+ env_->WaitForJoin();
+
+ // N threads x 2 ThreadLocal instance cleanup on thread exit
+ ASSERT_EQ(unref_count, 2 * p.total);
+ }
+
+ // Case 2: unref triggered by ThreadLocal instance destruction
+ auto func2 = [](Params* ptr) {
+ auto& p = *ptr;
+
+ p.mu->Lock();
+ ++(p.started);
+ p.cv->SignalAll();
+ while (p.started != p.total) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+
+ ASSERT_TRUE(p.tls1.Get() == nullptr);
+ ASSERT_TRUE(p.tls2->Get() == nullptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+
+ p.mu->Lock();
+ ++(p.completed);
+ p.cv->SignalAll();
+
+ // Waiting for instruction to exit thread
+ while (p.completed != 0) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+ };
+
+ for (int th = 1; th <= 128; th += th) {
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ int unref_count = 0;
+ Params p(&mu, &cv, &unref_count, th, unref);
+ p.tls2 = new ThreadLocalPtr(unref);
+
+ for (int i = 0; i < p.total; ++i) {
+ env_->StartThreadTyped(func2, &p);
+ }
+
+ // Wait for all threads to finish using Params
+ mu.Lock();
+ while (p.completed != p.total) {
+ cv.Wait();
+ }
+ mu.Unlock();
+
+ // Now destroy one ThreadLocal instance
+ delete p.tls2;
+ p.tls2 = nullptr;
+ // instance destroy for N threads
+ ASSERT_EQ(unref_count, p.total);
+
+ // Signal to exit
+ mu.Lock();
+ p.completed = 0;
+ cv.SignalAll();
+ mu.Unlock();
+ env_->WaitForJoin();
+ // additional N threads exit unref for the left instance
+ ASSERT_EQ(unref_count, 2 * p.total);
+ }
+}
+
+TEST_F(ThreadLocalTest, Swap) {
+ ThreadLocalPtr tls;
+ tls.Reset(reinterpret_cast<void*>(1));
+ ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
+ ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
+ ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
+ ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
+}
+
+TEST_F(ThreadLocalTest, Scrape) {
+ auto unref = [](void* ptr) {
+ auto& p = *static_cast<Params*>(ptr);
+ p.mu->Lock();
+ ++(*p.unref);
+ p.mu->Unlock();
+ };
+
+ auto func = [](void* ptr) {
+ auto& p = *static_cast<Params*>(ptr);
+
+ ASSERT_TRUE(p.tls1.Get() == nullptr);
+ ASSERT_TRUE(p.tls2->Get() == nullptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+
+ p.tls1.Reset(ptr);
+ p.tls2->Reset(ptr);
+
+ p.mu->Lock();
+ ++(p.completed);
+ p.cv->SignalAll();
+
+ // Waiting for instruction to exit thread
+ while (p.completed != 0) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+ };
+
+ for (int th = 1; th <= 128; th += th) {
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ int unref_count = 0;
+ Params p(&mu, &cv, &unref_count, th, unref);
+ p.tls2 = new ThreadLocalPtr(unref);
+
+ for (int i = 0; i < p.total; ++i) {
+ env_->StartThreadTyped(func, &p);
+ }
+
+ // Wait for all threads to finish using Params
+ mu.Lock();
+ while (p.completed != p.total) {
+ cv.Wait();
+ }
+ mu.Unlock();
+
+ ASSERT_EQ(unref_count, 0);
+
+ // Scrape all thread local data. No unref at thread
+ // exit or ThreadLocalPtr destruction
+ autovector<void*> ptrs;
+ p.tls1.Scrape(&ptrs, nullptr);
+ p.tls2->Scrape(&ptrs, nullptr);
+ delete p.tls2;
+ // Signal to exit
+ mu.Lock();
+ p.completed = 0;
+ cv.SignalAll();
+ mu.Unlock();
+ env_->WaitForJoin();
+
+ ASSERT_EQ(unref_count, 0);
+ }
+}
+
+TEST_F(ThreadLocalTest, Fold) {
+ auto unref = [](void* ptr) {
+ delete static_cast<std::atomic<int64_t>*>(ptr);
+ };
+ static const int kNumThreads = 16;
+ static const int kItersPerThread = 10;
+ port::Mutex mu;
+ port::CondVar cv(&mu);
+ Params params(&mu, &cv, nullptr, kNumThreads, unref);
+ auto func = [](void* ptr) {
+ auto& p = *static_cast<Params*>(ptr);
+ ASSERT_TRUE(p.tls1.Get() == nullptr);
+ p.tls1.Reset(new std::atomic<int64_t>(0));
+
+ for (int i = 0; i < kItersPerThread; ++i) {
+ static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
+ }
+
+ p.mu->Lock();
+ ++(p.completed);
+ p.cv->SignalAll();
+
+ // Waiting for instruction to exit thread
+ while (p.completed != 0) {
+ p.cv->Wait();
+ }
+ p.mu->Unlock();
+ };
+
+ for (int th = 0; th < params.total; ++th) {
+ env_->StartThread(func, &params);
+ }
+
+ // Wait for all threads to finish using Params
+ mu.Lock();
+ while (params.completed != params.total) {
+ cv.Wait();
+ }
+ mu.Unlock();
+
+ // Verify Fold() behavior
+ int64_t sum = 0;
+ params.tls1.Fold(
+ [](void* ptr, void* res) {
+ auto sum_ptr = static_cast<int64_t*>(res);
+ *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
+ },
+ &sum);
+ ASSERT_EQ(sum, kNumThreads * kItersPerThread);
+
+ // Signal to exit
+ mu.Lock();
+ params.completed = 0;
+ cv.SignalAll();
+ mu.Unlock();
+ env_->WaitForJoin();
+}
+
+TEST_F(ThreadLocalTest, CompareAndSwap) {
+ ThreadLocalPtr tls;
+ ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);
+ void* expected = reinterpret_cast<void*>(1);
+ // Swap in 2
+ ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
+ expected = reinterpret_cast<void*>(100);
+ // Fail Swap, still 2
+ ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
+ ASSERT_EQ(expected, reinterpret_cast<void*>(2));
+ // Swap in 3
+ expected = reinterpret_cast<void*>(2);
+ ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected));
+ ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3));
+}
+
+namespace {
+
+void* AccessThreadLocal(void* /*arg*/) {
+ TEST_SYNC_POINT("AccessThreadLocal:Start");
+ ThreadLocalPtr tlp;
+ tlp.Reset(new std::string("hello RocksDB"));
+ TEST_SYNC_POINT("AccessThreadLocal:End");
+ return nullptr;
+}
+
+} // namespace
+
+// The following test is disabled as it requires manual steps to run it
+// correctly.
+//
+// Currently we have no way to acess SyncPoint w/o ASAN error when the
+// child thread dies after the main thread dies. So if you manually enable
+// this test and only see an ASAN error on SyncPoint, it means you pass the
+// test.
+TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
+ {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});
+
+ // Triggers the initialization of singletons.
+ Env::Default();
+
+#ifndef ROCKSDB_LITE
+ try {
+#endif // ROCKSDB_LITE
+ ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
+ th.detach();
+ TEST_SYNC_POINT("MainThreadDiesFirst:End");
+#ifndef ROCKSDB_LITE
+ } catch (const std::system_error& ex) {
+ std::cerr << "Start thread: " << ex.code() << std::endl;
+ FAIL();
+ }
+#endif // ROCKSDB_LITE
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/thread_operation.h b/src/rocksdb/util/thread_operation.h
new file mode 100644
index 000000000..c24fccd5c
--- /dev/null
+++ b/src/rocksdb/util/thread_operation.h
@@ -0,0 +1,112 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// This file defines the structures for thread operation and state.
+// Thread operations are used to describe high level action of a
+// thread such as doing compaction or flush, while thread state
+// are used to describe lower-level action such as reading /
+// writing a file or waiting for a mutex. Operations and states
+// are designed to be independent. Typically, a thread usually involves
+// in one operation and one state at any specific point in time.
+
+#pragma once
+
+#include <string>
+
+#include "rocksdb/thread_status.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifdef ROCKSDB_USING_THREAD_STATUS
+
+// The structure that describes a major thread operation.
+struct OperationInfo {
+ const ThreadStatus::OperationType type;
+ const std::string name;
+};
+
+// The global operation table.
+//
+// When updating a status of a thread, the pointer of the OperationInfo
+// of the current ThreadStatusData will be pointing to one of the
+// rows in this global table.
+//
+// Note that it's not designed to be constant as in the future we
+// might consider adding global count to the OperationInfo.
+static OperationInfo global_operation_table[] = {
+ {ThreadStatus::OP_UNKNOWN, ""},
+ {ThreadStatus::OP_COMPACTION, "Compaction"},
+ {ThreadStatus::OP_FLUSH, "Flush"}};
+
+struct OperationStageInfo {
+ const ThreadStatus::OperationStage stage;
+ const std::string name;
+};
+
+// A table maintains the mapping from stage type to stage string.
+// Note that the string must be changed accordingly when the
+// associated function name changed.
+static OperationStageInfo global_op_stage_table[] = {
+ {ThreadStatus::STAGE_UNKNOWN, ""},
+ {ThreadStatus::STAGE_FLUSH_RUN, "FlushJob::Run"},
+ {ThreadStatus::STAGE_FLUSH_WRITE_L0, "FlushJob::WriteLevel0Table"},
+ {ThreadStatus::STAGE_COMPACTION_PREPARE, "CompactionJob::Prepare"},
+ {ThreadStatus::STAGE_COMPACTION_RUN, "CompactionJob::Run"},
+ {ThreadStatus::STAGE_COMPACTION_PROCESS_KV,
+ "CompactionJob::ProcessKeyValueCompaction"},
+ {ThreadStatus::STAGE_COMPACTION_INSTALL, "CompactionJob::Install"},
+ {ThreadStatus::STAGE_COMPACTION_SYNC_FILE,
+ "CompactionJob::FinishCompactionOutputFile"},
+ {ThreadStatus::STAGE_PICK_MEMTABLES_TO_FLUSH,
+ "MemTableList::PickMemtablesToFlush"},
+ {ThreadStatus::STAGE_MEMTABLE_ROLLBACK,
+ "MemTableList::RollbackMemtableFlush"},
+ {ThreadStatus::STAGE_MEMTABLE_INSTALL_FLUSH_RESULTS,
+ "MemTableList::TryInstallMemtableFlushResults"},
+};
+
+// The structure that describes a state.
+struct StateInfo {
+ const ThreadStatus::StateType type;
+ const std::string name;
+};
+
+// The global state table.
+//
+// When updating a status of a thread, the pointer of the StateInfo
+// of the current ThreadStatusData will be pointing to one of the
+// rows in this global table.
+static StateInfo global_state_table[] = {
+ {ThreadStatus::STATE_UNKNOWN, ""},
+ {ThreadStatus::STATE_MUTEX_WAIT, "Mutex Wait"},
+};
+
+struct OperationProperty {
+ int code;
+ std::string name;
+};
+
+static OperationProperty compaction_operation_properties[] = {
+ {ThreadStatus::COMPACTION_JOB_ID, "JobID"},
+ {ThreadStatus::COMPACTION_INPUT_OUTPUT_LEVEL, "InputOutputLevel"},
+ {ThreadStatus::COMPACTION_PROP_FLAGS, "Manual/Deletion/Trivial"},
+ {ThreadStatus::COMPACTION_TOTAL_INPUT_BYTES, "TotalInputBytes"},
+ {ThreadStatus::COMPACTION_BYTES_READ, "BytesRead"},
+ {ThreadStatus::COMPACTION_BYTES_WRITTEN, "BytesWritten"},
+};
+
+static OperationProperty flush_operation_properties[] = {
+ {ThreadStatus::FLUSH_JOB_ID, "JobID"},
+ {ThreadStatus::FLUSH_BYTES_MEMTABLES, "BytesMemtables"},
+ {ThreadStatus::FLUSH_BYTES_WRITTEN, "BytesWritten"}};
+
+#else
+
+struct OperationInfo {};
+
+struct StateInfo {};
+
+#endif // ROCKSDB_USING_THREAD_STATUS
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/threadpool_imp.cc b/src/rocksdb/util/threadpool_imp.cc
new file mode 100644
index 000000000..09706cac5
--- /dev/null
+++ b/src/rocksdb/util/threadpool_imp.cc
@@ -0,0 +1,551 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "util/threadpool_imp.h"
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+
+#ifdef OS_LINUX
+#include <sys/resource.h>
+#include <sys/syscall.h>
+#endif
+
+#include <stdlib.h>
+
+#include <algorithm>
+#include <atomic>
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <vector>
+
+#include "monitoring/thread_status_util.h"
+#include "port/port.h"
+#include "test_util/sync_point.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+void ThreadPoolImpl::PthreadCall(const char* label, int result) {
+ if (result != 0) {
+ fprintf(stderr, "pthread %s: %s\n", label, errnoStr(result).c_str());
+ abort();
+ }
+}
+
+struct ThreadPoolImpl::Impl {
+ Impl();
+ ~Impl();
+
+ void JoinThreads(bool wait_for_jobs_to_complete);
+
+ void SetBackgroundThreadsInternal(int num, bool allow_reduce);
+ int GetBackgroundThreads();
+
+ unsigned int GetQueueLen() const {
+ return queue_len_.load(std::memory_order_relaxed);
+ }
+
+ void LowerIOPriority();
+
+ void LowerCPUPriority(CpuPriority pri);
+
+ void WakeUpAllThreads() { bgsignal_.notify_all(); }
+
+ void BGThread(size_t thread_id);
+
+ void StartBGThreads();
+
+ void Submit(std::function<void()>&& schedule,
+ std::function<void()>&& unschedule, void* tag);
+
+ int UnSchedule(void* arg);
+
+ void SetHostEnv(Env* env) { env_ = env; }
+
+ Env* GetHostEnv() const { return env_; }
+
+ bool HasExcessiveThread() const {
+ return static_cast<int>(bgthreads_.size()) > total_threads_limit_;
+ }
+
+ // Return true iff the current thread is the excessive thread to terminate.
+ // Always terminate the running thread that is added last, even if there are
+ // more than one thread to terminate.
+ bool IsLastExcessiveThread(size_t thread_id) const {
+ return HasExcessiveThread() && thread_id == bgthreads_.size() - 1;
+ }
+
+ bool IsExcessiveThread(size_t thread_id) const {
+ return static_cast<int>(thread_id) >= total_threads_limit_;
+ }
+
+ // Return the thread priority.
+ // This would allow its member-thread to know its priority.
+ Env::Priority GetThreadPriority() const { return priority_; }
+
+ // Set the thread priority.
+ void SetThreadPriority(Env::Priority priority) { priority_ = priority; }
+
+ int ReserveThreads(int threads_to_be_reserved) {
+ std::unique_lock<std::mutex> lock(mu_);
+ // We can reserve at most num_waiting_threads_ in total so the number of
+ // threads that can be reserved might be fewer than the desired one. In
+ // rare cases, num_waiting_threads_ could be less than reserved_threads
+ // due to SetBackgroundThreadInternal or last excessive threads. If that
+ // happens, we cannot reserve any other threads.
+ int reserved_threads_in_success =
+ std::min(std::max(num_waiting_threads_ - reserved_threads_, 0),
+ threads_to_be_reserved);
+ reserved_threads_ += reserved_threads_in_success;
+ return reserved_threads_in_success;
+ }
+
+ int ReleaseThreads(int threads_to_be_released) {
+ std::unique_lock<std::mutex> lock(mu_);
+ // We cannot release more than reserved_threads_
+ int released_threads_in_success =
+ std::min(reserved_threads_, threads_to_be_released);
+ reserved_threads_ -= released_threads_in_success;
+ WakeUpAllThreads();
+ return released_threads_in_success;
+ }
+
+ private:
+ static void BGThreadWrapper(void* arg);
+
+ bool low_io_priority_;
+ CpuPriority cpu_priority_;
+ Env::Priority priority_;
+ Env* env_;
+
+ int total_threads_limit_;
+ std::atomic_uint queue_len_; // Queue length. Used for stats reporting
+ // Number of reserved threads, managed by ReserveThreads(..) and
+ // ReleaseThreads(..), if num_waiting_threads_ is no larger than
+ // reserved_threads_, its thread will be blocked to ensure the reservation
+ // mechanism
+ int reserved_threads_;
+ // Number of waiting threads (Maximum number of threads that can be
+ // reserved), in rare cases, num_waiting_threads_ could be less than
+ // reserved_threads due to SetBackgroundThreadInternal or last
+ // excessive threads.
+ int num_waiting_threads_;
+ bool exit_all_threads_;
+ bool wait_for_jobs_to_complete_;
+
+ // Entry per Schedule()/Submit() call
+ struct BGItem {
+ void* tag = nullptr;
+ std::function<void()> function;
+ std::function<void()> unschedFunction;
+ };
+
+ using BGQueue = std::deque<BGItem>;
+ BGQueue queue_;
+
+ std::mutex mu_;
+ std::condition_variable bgsignal_;
+ std::vector<port::Thread> bgthreads_;
+};
+
+inline ThreadPoolImpl::Impl::Impl()
+ : low_io_priority_(false),
+ cpu_priority_(CpuPriority::kNormal),
+ priority_(Env::LOW),
+ env_(nullptr),
+ total_threads_limit_(0),
+ queue_len_(),
+ reserved_threads_(0),
+ num_waiting_threads_(0),
+ exit_all_threads_(false),
+ wait_for_jobs_to_complete_(false),
+ queue_(),
+ mu_(),
+ bgsignal_(),
+ bgthreads_() {}
+
+inline ThreadPoolImpl::Impl::~Impl() { assert(bgthreads_.size() == 0U); }
+
+void ThreadPoolImpl::Impl::JoinThreads(bool wait_for_jobs_to_complete) {
+ std::unique_lock<std::mutex> lock(mu_);
+ assert(!exit_all_threads_);
+
+ wait_for_jobs_to_complete_ = wait_for_jobs_to_complete;
+ exit_all_threads_ = true;
+ // prevent threads from being recreated right after they're joined, in case
+ // the user is concurrently submitting jobs.
+ total_threads_limit_ = 0;
+ reserved_threads_ = 0;
+ num_waiting_threads_ = 0;
+
+ lock.unlock();
+
+ bgsignal_.notify_all();
+
+ for (auto& th : bgthreads_) {
+ th.join();
+ }
+
+ bgthreads_.clear();
+
+ exit_all_threads_ = false;
+ wait_for_jobs_to_complete_ = false;
+}
+
+inline void ThreadPoolImpl::Impl::LowerIOPriority() {
+ std::lock_guard<std::mutex> lock(mu_);
+ low_io_priority_ = true;
+}
+
+inline void ThreadPoolImpl::Impl::LowerCPUPriority(CpuPriority pri) {
+ std::lock_guard<std::mutex> lock(mu_);
+ cpu_priority_ = pri;
+}
+
+void ThreadPoolImpl::Impl::BGThread(size_t thread_id) {
+ bool low_io_priority = false;
+ CpuPriority current_cpu_priority = CpuPriority::kNormal;
+
+ while (true) {
+ // Wait until there is an item that is ready to run
+ std::unique_lock<std::mutex> lock(mu_);
+ // Stop waiting if the thread needs to do work or needs to terminate.
+ // Increase num_waiting_threads_ once this task has started waiting
+ num_waiting_threads_++;
+
+ TEST_SYNC_POINT("ThreadPoolImpl::BGThread::WaitingThreadsInc");
+ TEST_IDX_SYNC_POINT("ThreadPoolImpl::BGThread::Start:th", thread_id);
+ // When not exist_all_threads and the current thread id is not the last
+ // excessive thread, it may be blocked due to 3 reasons: 1) queue is empty
+ // 2) it is the excessive thread (not the last one)
+ // 3) the number of waiting threads is not greater than reserved threads
+ // (i.e, no available threads due to full reservation")
+ while (!exit_all_threads_ && !IsLastExcessiveThread(thread_id) &&
+ (queue_.empty() || IsExcessiveThread(thread_id) ||
+ num_waiting_threads_ <= reserved_threads_)) {
+ bgsignal_.wait(lock);
+ }
+ // Decrease num_waiting_threads_ once the thread is not waiting
+ num_waiting_threads_--;
+
+ if (exit_all_threads_) { // mechanism to let BG threads exit safely
+
+ if (!wait_for_jobs_to_complete_ || queue_.empty()) {
+ break;
+ }
+ } else if (IsLastExcessiveThread(thread_id)) {
+ // Current thread is the last generated one and is excessive.
+ // We always terminate excessive thread in the reverse order of
+ // generation time. But not when `exit_all_threads_ == true`,
+ // otherwise `JoinThreads()` could try to `join()` a `detach()`ed
+ // thread.
+ auto& terminating_thread = bgthreads_.back();
+ terminating_thread.detach();
+ bgthreads_.pop_back();
+ if (HasExcessiveThread()) {
+ // There is still at least more excessive thread to terminate.
+ WakeUpAllThreads();
+ }
+ TEST_IDX_SYNC_POINT("ThreadPoolImpl::BGThread::Termination:th",
+ thread_id);
+ TEST_SYNC_POINT("ThreadPoolImpl::BGThread::Termination");
+ break;
+ }
+
+ auto func = std::move(queue_.front().function);
+ queue_.pop_front();
+
+ queue_len_.store(static_cast<unsigned int>(queue_.size()),
+ std::memory_order_relaxed);
+
+ bool decrease_io_priority = (low_io_priority != low_io_priority_);
+ CpuPriority cpu_priority = cpu_priority_;
+ lock.unlock();
+
+ if (cpu_priority < current_cpu_priority) {
+ TEST_SYNC_POINT_CALLBACK("ThreadPoolImpl::BGThread::BeforeSetCpuPriority",
+ &current_cpu_priority);
+ // 0 means current thread.
+ port::SetCpuPriority(0, cpu_priority);
+ current_cpu_priority = cpu_priority;
+ TEST_SYNC_POINT_CALLBACK("ThreadPoolImpl::BGThread::AfterSetCpuPriority",
+ &current_cpu_priority);
+ }
+
+#ifdef OS_LINUX
+ if (decrease_io_priority) {
+#define IOPRIO_CLASS_SHIFT (13)
+#define IOPRIO_PRIO_VALUE(class, data) (((class) << IOPRIO_CLASS_SHIFT) | data)
+ // Put schedule into IOPRIO_CLASS_IDLE class (lowest)
+ // These system calls only have an effect when used in conjunction
+ // with an I/O scheduler that supports I/O priorities. As at
+ // kernel 2.6.17 the only such scheduler is the Completely
+ // Fair Queuing (CFQ) I/O scheduler.
+ // To change scheduler:
+ // echo cfq > /sys/block/<device_name>/queue/schedule
+ // Tunables to consider:
+ // /sys/block/<device_name>/queue/slice_idle
+ // /sys/block/<device_name>/queue/slice_sync
+ syscall(SYS_ioprio_set, 1, // IOPRIO_WHO_PROCESS
+ 0, // current thread
+ IOPRIO_PRIO_VALUE(3, 0));
+ low_io_priority = true;
+ }
+#else
+ (void)decrease_io_priority; // avoid 'unused variable' error
+#endif
+
+ TEST_SYNC_POINT_CALLBACK("ThreadPoolImpl::Impl::BGThread:BeforeRun",
+ &priority_);
+
+ func();
+ }
+}
+
+// Helper struct for passing arguments when creating threads.
+struct BGThreadMetadata {
+ ThreadPoolImpl::Impl* thread_pool_;
+ size_t thread_id_; // Thread count in the thread.
+ BGThreadMetadata(ThreadPoolImpl::Impl* thread_pool, size_t thread_id)
+ : thread_pool_(thread_pool), thread_id_(thread_id) {}
+};
+
+void ThreadPoolImpl::Impl::BGThreadWrapper(void* arg) {
+ BGThreadMetadata* meta = reinterpret_cast<BGThreadMetadata*>(arg);
+ size_t thread_id = meta->thread_id_;
+ ThreadPoolImpl::Impl* tp = meta->thread_pool_;
+#ifdef ROCKSDB_USING_THREAD_STATUS
+ // initialize it because compiler isn't good enough to see we don't use it
+ // uninitialized
+ ThreadStatus::ThreadType thread_type = ThreadStatus::NUM_THREAD_TYPES;
+ switch (tp->GetThreadPriority()) {
+ case Env::Priority::HIGH:
+ thread_type = ThreadStatus::HIGH_PRIORITY;
+ break;
+ case Env::Priority::LOW:
+ thread_type = ThreadStatus::LOW_PRIORITY;
+ break;
+ case Env::Priority::BOTTOM:
+ thread_type = ThreadStatus::BOTTOM_PRIORITY;
+ break;
+ case Env::Priority::USER:
+ thread_type = ThreadStatus::USER;
+ break;
+ case Env::Priority::TOTAL:
+ assert(false);
+ return;
+ }
+ assert(thread_type != ThreadStatus::NUM_THREAD_TYPES);
+ ThreadStatusUtil::RegisterThread(tp->GetHostEnv(), thread_type);
+#endif
+ delete meta;
+ tp->BGThread(thread_id);
+#ifdef ROCKSDB_USING_THREAD_STATUS
+ ThreadStatusUtil::UnregisterThread();
+#endif
+ return;
+}
+
+void ThreadPoolImpl::Impl::SetBackgroundThreadsInternal(int num,
+ bool allow_reduce) {
+ std::lock_guard<std::mutex> lock(mu_);
+ if (exit_all_threads_) {
+ return;
+ }
+ if (num > total_threads_limit_ ||
+ (num < total_threads_limit_ && allow_reduce)) {
+ total_threads_limit_ = std::max(0, num);
+ WakeUpAllThreads();
+ StartBGThreads();
+ }
+}
+
+int ThreadPoolImpl::Impl::GetBackgroundThreads() {
+ std::unique_lock<std::mutex> lock(mu_);
+ return total_threads_limit_;
+}
+
+void ThreadPoolImpl::Impl::StartBGThreads() {
+ // Start background thread if necessary
+ while ((int)bgthreads_.size() < total_threads_limit_) {
+ port::Thread p_t(&BGThreadWrapper,
+ new BGThreadMetadata(this, bgthreads_.size()));
+
+// Set the thread name to aid debugging
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
+#if __GLIBC_PREREQ(2, 12)
+ auto th_handle = p_t.native_handle();
+ std::string thread_priority = Env::PriorityToString(GetThreadPriority());
+ std::ostringstream thread_name_stream;
+ thread_name_stream << "rocksdb:";
+ for (char c : thread_priority) {
+ thread_name_stream << static_cast<char>(tolower(c));
+ }
+ pthread_setname_np(th_handle, thread_name_stream.str().c_str());
+#endif
+#endif
+ bgthreads_.push_back(std::move(p_t));
+ }
+}
+
+void ThreadPoolImpl::Impl::Submit(std::function<void()>&& schedule,
+ std::function<void()>&& unschedule,
+ void* tag) {
+ std::lock_guard<std::mutex> lock(mu_);
+
+ if (exit_all_threads_) {
+ return;
+ }
+
+ StartBGThreads();
+
+ // Add to priority queue
+ queue_.push_back(BGItem());
+ TEST_SYNC_POINT("ThreadPoolImpl::Submit::Enqueue");
+ auto& item = queue_.back();
+ item.tag = tag;
+ item.function = std::move(schedule);
+ item.unschedFunction = std::move(unschedule);
+
+ queue_len_.store(static_cast<unsigned int>(queue_.size()),
+ std::memory_order_relaxed);
+
+ if (!HasExcessiveThread()) {
+ // Wake up at least one waiting thread.
+ bgsignal_.notify_one();
+ } else {
+ // Need to wake up all threads to make sure the one woken
+ // up is not the one to terminate.
+ WakeUpAllThreads();
+ }
+}
+
+int ThreadPoolImpl::Impl::UnSchedule(void* arg) {
+ int count = 0;
+
+ std::vector<std::function<void()>> candidates;
+ {
+ std::lock_guard<std::mutex> lock(mu_);
+
+ // Remove from priority queue
+ BGQueue::iterator it = queue_.begin();
+ while (it != queue_.end()) {
+ if (arg == (*it).tag) {
+ if (it->unschedFunction) {
+ candidates.push_back(std::move(it->unschedFunction));
+ }
+ it = queue_.erase(it);
+ count++;
+ } else {
+ ++it;
+ }
+ }
+ queue_len_.store(static_cast<unsigned int>(queue_.size()),
+ std::memory_order_relaxed);
+ }
+
+ // Run unschedule functions outside the mutex
+ for (auto& f : candidates) {
+ f();
+ }
+
+ return count;
+}
+
+ThreadPoolImpl::ThreadPoolImpl() : impl_(new Impl()) {}
+
+ThreadPoolImpl::~ThreadPoolImpl() {}
+
+void ThreadPoolImpl::JoinAllThreads() { impl_->JoinThreads(false); }
+
+void ThreadPoolImpl::SetBackgroundThreads(int num) {
+ impl_->SetBackgroundThreadsInternal(num, true);
+}
+
+int ThreadPoolImpl::GetBackgroundThreads() {
+ return impl_->GetBackgroundThreads();
+}
+
+unsigned int ThreadPoolImpl::GetQueueLen() const {
+ return impl_->GetQueueLen();
+}
+
+void ThreadPoolImpl::WaitForJobsAndJoinAllThreads() {
+ impl_->JoinThreads(true);
+}
+
+void ThreadPoolImpl::LowerIOPriority() { impl_->LowerIOPriority(); }
+
+void ThreadPoolImpl::LowerCPUPriority(CpuPriority pri) {
+ impl_->LowerCPUPriority(pri);
+}
+
+void ThreadPoolImpl::IncBackgroundThreadsIfNeeded(int num) {
+ impl_->SetBackgroundThreadsInternal(num, false);
+}
+
+void ThreadPoolImpl::SubmitJob(const std::function<void()>& job) {
+ auto copy(job);
+ impl_->Submit(std::move(copy), std::function<void()>(), nullptr);
+}
+
+void ThreadPoolImpl::SubmitJob(std::function<void()>&& job) {
+ impl_->Submit(std::move(job), std::function<void()>(), nullptr);
+}
+
+void ThreadPoolImpl::Schedule(void (*function)(void* arg1), void* arg,
+ void* tag, void (*unschedFunction)(void* arg)) {
+ if (unschedFunction == nullptr) {
+ impl_->Submit(std::bind(function, arg), std::function<void()>(), tag);
+ } else {
+ impl_->Submit(std::bind(function, arg), std::bind(unschedFunction, arg),
+ tag);
+ }
+}
+
+int ThreadPoolImpl::UnSchedule(void* arg) { return impl_->UnSchedule(arg); }
+
+void ThreadPoolImpl::SetHostEnv(Env* env) { impl_->SetHostEnv(env); }
+
+Env* ThreadPoolImpl::GetHostEnv() const { return impl_->GetHostEnv(); }
+
+// Return the thread priority.
+// This would allow its member-thread to know its priority.
+Env::Priority ThreadPoolImpl::GetThreadPriority() const {
+ return impl_->GetThreadPriority();
+}
+
+// Set the thread priority.
+void ThreadPoolImpl::SetThreadPriority(Env::Priority priority) {
+ impl_->SetThreadPriority(priority);
+}
+
+// Reserve a specific number of threads, prevent them from running other
+// functions The number of reserved threads could be fewer than the desired one
+int ThreadPoolImpl::ReserveThreads(int threads_to_be_reserved) {
+ return impl_->ReserveThreads(threads_to_be_reserved);
+}
+
+// Release a specific number of threads
+int ThreadPoolImpl::ReleaseThreads(int threads_to_be_released) {
+ return impl_->ReleaseThreads(threads_to_be_released);
+}
+
+ThreadPool* NewThreadPool(int num_threads) {
+ ThreadPoolImpl* thread_pool = new ThreadPoolImpl();
+ thread_pool->SetBackgroundThreads(num_threads);
+ return thread_pool;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/threadpool_imp.h b/src/rocksdb/util/threadpool_imp.h
new file mode 100644
index 000000000..a5109e38f
--- /dev/null
+++ b/src/rocksdb/util/threadpool_imp.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#pragma once
+
+#include <functional>
+#include <memory>
+
+#include "rocksdb/env.h"
+#include "rocksdb/threadpool.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ThreadPoolImpl : public ThreadPool {
+ public:
+ ThreadPoolImpl();
+ ~ThreadPoolImpl();
+
+ ThreadPoolImpl(ThreadPoolImpl&&) = delete;
+ ThreadPoolImpl& operator=(ThreadPoolImpl&&) = delete;
+
+ // Implement ThreadPool interfaces
+
+ // Wait for all threads to finish.
+ // Discards all the jobs that did not
+ // start executing and waits for those running
+ // to complete
+ void JoinAllThreads() override;
+
+ // Set the number of background threads that will be executing the
+ // scheduled jobs.
+ void SetBackgroundThreads(int num) override;
+ int GetBackgroundThreads() override;
+
+ // Get the number of jobs scheduled in the ThreadPool queue.
+ unsigned int GetQueueLen() const override;
+
+ // Waits for all jobs to complete those
+ // that already started running and those that did not
+ // start yet
+ void WaitForJobsAndJoinAllThreads() override;
+
+ // Make threads to run at a lower kernel IO priority
+ // Currently only has effect on Linux
+ void LowerIOPriority();
+
+ // Make threads to run at a lower kernel CPU priority
+ // Currently only has effect on Linux
+ void LowerCPUPriority(CpuPriority pri);
+
+ // Ensure there is at aleast num threads in the pool
+ // but do not kill threads if there are more
+ void IncBackgroundThreadsIfNeeded(int num);
+
+ // Submit a fire and forget job
+ // These jobs can not be unscheduled
+
+ // This allows to submit the same job multiple times
+ void SubmitJob(const std::function<void()>&) override;
+ // This moves the function in for efficiency
+ void SubmitJob(std::function<void()>&&) override;
+
+ // Schedule a job with an unschedule tag and unschedule function
+ // Can be used to filter and unschedule jobs by a tag
+ // that are still in the queue and did not start running
+ void Schedule(void (*function)(void* arg1), void* arg, void* tag,
+ void (*unschedFunction)(void* arg));
+
+ // Filter jobs that are still in a queue and match
+ // the given tag. Remove them from a queue if any
+ // and for each such job execute an unschedule function
+ // if such was given at scheduling time.
+ int UnSchedule(void* tag);
+
+ void SetHostEnv(Env* env);
+
+ Env* GetHostEnv() const;
+
+ // Return the thread priority.
+ // This would allow its member-thread to know its priority.
+ Env::Priority GetThreadPriority() const;
+
+ // Set the thread priority.
+ void SetThreadPriority(Env::Priority priority);
+
+ // Reserve a specific number of threads, prevent them from running other
+ // functions The number of reserved threads could be fewer than the desired
+ // one
+ int ReserveThreads(int threads_to_be_reserved) override;
+
+ // Release a specific number of threads
+ int ReleaseThreads(int threads_to_be_released) override;
+
+ static void PthreadCall(const char* label, int result);
+
+ struct Impl;
+
+ private:
+ // Current public virtual interface does not provide usable
+ // functionality and thus can not be used internally to
+ // facade different implementations.
+ //
+ // We propose a pimpl idiom in order to easily replace the thread pool impl
+ // w/o touching the header file but providing a different .cc potentially
+ // CMake option driven.
+ //
+ // Another option is to introduce a Env::MakeThreadPool() virtual interface
+ // and override the environment. This would require refactoring ThreadPool
+ // usage.
+ //
+ // We can also combine these two approaches
+ std::unique_ptr<Impl> impl_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/timer.h b/src/rocksdb/util/timer.h
new file mode 100644
index 000000000..db71cefaf
--- /dev/null
+++ b/src/rocksdb/util/timer.h
@@ -0,0 +1,340 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "monitoring/instrumented_mutex.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/sync_point.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A Timer class to handle repeated work.
+//
+// `Start()` and `Shutdown()` are currently not thread-safe. The client must
+// serialize calls to these two member functions.
+//
+// A single timer instance can handle multiple functions via a single thread.
+// It is better to leave long running work to a dedicated thread pool.
+//
+// Timer can be started by calling `Start()`, and ended by calling `Shutdown()`.
+// Work (in terms of a `void function`) can be scheduled by calling `Add` with
+// a unique function name and de-scheduled by calling `Cancel`.
+// Many functions can be added.
+//
+// Impl Details:
+// A heap is used to keep track of when the next timer goes off.
+// A map from a function name to the function keeps track of all the functions.
+class Timer {
+ public:
+ explicit Timer(SystemClock* clock)
+ : clock_(clock),
+ mutex_(clock),
+ cond_var_(&mutex_),
+ running_(false),
+ executing_task_(false) {}
+
+ ~Timer() { Shutdown(); }
+
+ // Add a new function to run.
+ // fn_name has to be identical, otherwise it will fail to add and return false
+ // start_after_us is the initial delay.
+ // repeat_every_us is the interval between ending time of the last call and
+ // starting time of the next call. For example, repeat_every_us = 2000 and
+ // the function takes 1000us to run. If it starts at time [now]us, then it
+ // finishes at [now]+1000us, 2nd run starting time will be at [now]+3000us.
+ // repeat_every_us == 0 means do not repeat.
+ bool Add(std::function<void()> fn, const std::string& fn_name,
+ uint64_t start_after_us, uint64_t repeat_every_us) {
+ auto fn_info = std::make_unique<FunctionInfo>(std::move(fn), fn_name, 0,
+ repeat_every_us);
+ InstrumentedMutexLock l(&mutex_);
+ // Assign time within mutex to make sure the next_run_time is larger than
+ // the current running one
+ fn_info->next_run_time_us = clock_->NowMicros() + start_after_us;
+ // the new task start time should never before the current task executing
+ // time, as the executing task can only be running if it's next_run_time_us
+ // is due (<= clock_->NowMicros()).
+ if (executing_task_ &&
+ fn_info->next_run_time_us < heap_.top()->next_run_time_us) {
+ return false;
+ }
+ auto it = map_.find(fn_name);
+ if (it == map_.end()) {
+ heap_.push(fn_info.get());
+ map_.try_emplace(fn_name, std::move(fn_info));
+ } else {
+ // timer doesn't support duplicated function name
+ return false;
+ }
+ cond_var_.SignalAll();
+ return true;
+ }
+
+ void Cancel(const std::string& fn_name) {
+ InstrumentedMutexLock l(&mutex_);
+
+ // Mark the function with fn_name as invalid so that it will not be
+ // requeued.
+ auto it = map_.find(fn_name);
+ if (it != map_.end() && it->second) {
+ it->second->Cancel();
+ }
+
+ // If the currently running function is fn_name, then we need to wait
+ // until it finishes before returning to caller.
+ while (!heap_.empty() && executing_task_) {
+ FunctionInfo* func_info = heap_.top();
+ assert(func_info);
+ if (func_info->name == fn_name) {
+ WaitForTaskCompleteIfNecessary();
+ } else {
+ break;
+ }
+ }
+ }
+
+ void CancelAll() {
+ InstrumentedMutexLock l(&mutex_);
+ CancelAllWithLock();
+ }
+
+ // Start the Timer
+ bool Start() {
+ InstrumentedMutexLock l(&mutex_);
+ if (running_) {
+ return false;
+ }
+
+ running_ = true;
+ thread_ = std::make_unique<port::Thread>(&Timer::Run, this);
+ return true;
+ }
+
+ // Shutdown the Timer
+ bool Shutdown() {
+ {
+ InstrumentedMutexLock l(&mutex_);
+ if (!running_) {
+ return false;
+ }
+ running_ = false;
+ CancelAllWithLock();
+ cond_var_.SignalAll();
+ }
+
+ if (thread_) {
+ thread_->join();
+ }
+ return true;
+ }
+
+ bool HasPendingTask() const {
+ InstrumentedMutexLock l(&mutex_);
+ for (const auto& fn_info : map_) {
+ if (fn_info.second->IsValid()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+#ifndef NDEBUG
+ // Wait until Timer starting waiting, call the optional callback, then wait
+ // for Timer waiting again.
+ // Tests can provide a custom Clock object to mock time, and use the callback
+ // here to bump current time and trigger Timer. See timer_test for example.
+ //
+ // Note: only support one caller of this method.
+ void TEST_WaitForRun(const std::function<void()>& callback = nullptr) {
+ InstrumentedMutexLock l(&mutex_);
+ // It act as a spin lock
+ while (executing_task_ ||
+ (!heap_.empty() &&
+ heap_.top()->next_run_time_us <= clock_->NowMicros())) {
+ cond_var_.TimedWait(clock_->NowMicros() + 1000);
+ }
+ if (callback != nullptr) {
+ callback();
+ }
+ cond_var_.SignalAll();
+ do {
+ cond_var_.TimedWait(clock_->NowMicros() + 1000);
+ } while (executing_task_ ||
+ (!heap_.empty() &&
+ heap_.top()->next_run_time_us <= clock_->NowMicros()));
+ }
+
+ size_t TEST_GetPendingTaskNum() const {
+ InstrumentedMutexLock l(&mutex_);
+ size_t ret = 0;
+ for (const auto& fn_info : map_) {
+ if (fn_info.second->IsValid()) {
+ ret++;
+ }
+ }
+ return ret;
+ }
+
+ void TEST_OverrideTimer(SystemClock* clock) {
+ InstrumentedMutexLock l(&mutex_);
+ clock_ = clock;
+ }
+#endif // NDEBUG
+
+ private:
+ void Run() {
+ InstrumentedMutexLock l(&mutex_);
+
+ while (running_) {
+ if (heap_.empty()) {
+ // wait
+ TEST_SYNC_POINT("Timer::Run::Waiting");
+ cond_var_.Wait();
+ continue;
+ }
+
+ FunctionInfo* current_fn = heap_.top();
+ assert(current_fn);
+
+ if (!current_fn->IsValid()) {
+ heap_.pop();
+ map_.erase(current_fn->name);
+ continue;
+ }
+
+ if (current_fn->next_run_time_us <= clock_->NowMicros()) {
+ // make a copy of the function so it won't be changed after
+ // mutex_.unlock.
+ std::function<void()> fn = current_fn->fn;
+ executing_task_ = true;
+ mutex_.Unlock();
+ // Execute the work
+ fn();
+ mutex_.Lock();
+ executing_task_ = false;
+ cond_var_.SignalAll();
+
+ // Remove the work from the heap once it is done executing, make sure
+ // it's the same function after executing the work while mutex is
+ // released.
+ // Note that we are just removing the pointer from the heap. Its
+ // memory is still managed in the map (as it holds a unique ptr).
+ // So current_fn is still a valid ptr.
+ assert(heap_.top() == current_fn);
+ heap_.pop();
+
+ // current_fn may be cancelled already.
+ if (current_fn->IsValid() && current_fn->repeat_every_us > 0) {
+ assert(running_);
+ current_fn->next_run_time_us =
+ clock_->NowMicros() + current_fn->repeat_every_us;
+
+ // Schedule new work into the heap with new time.
+ heap_.push(current_fn);
+ } else {
+ // if current_fn is cancelled or no need to repeat, remove it from the
+ // map to avoid leak.
+ map_.erase(current_fn->name);
+ }
+ } else {
+ cond_var_.TimedWait(current_fn->next_run_time_us);
+ }
+ }
+ }
+
+ void CancelAllWithLock() {
+ mutex_.AssertHeld();
+ if (map_.empty() && heap_.empty()) {
+ return;
+ }
+
+ // With mutex_ held, set all tasks to invalid so that they will not be
+ // re-queued.
+ for (auto& elem : map_) {
+ auto& func_info = elem.second;
+ assert(func_info);
+ func_info->Cancel();
+ }
+
+ // WaitForTaskCompleteIfNecessary() may release mutex_
+ WaitForTaskCompleteIfNecessary();
+
+ while (!heap_.empty()) {
+ heap_.pop();
+ }
+ map_.clear();
+ }
+
+ // A wrapper around std::function to keep track when it should run next
+ // and at what frequency.
+ struct FunctionInfo {
+ // the actual work
+ std::function<void()> fn;
+ // name of the function
+ std::string name;
+ // when the function should run next
+ uint64_t next_run_time_us;
+ // repeat interval
+ uint64_t repeat_every_us;
+ // controls whether this function is valid.
+ // A function is valid upon construction and until someone explicitly
+ // calls `Cancel()`.
+ bool valid;
+
+ FunctionInfo(std::function<void()>&& _fn, std::string _name,
+ const uint64_t _next_run_time_us, uint64_t _repeat_every_us)
+ : fn(std::move(_fn)),
+ name(std::move(_name)),
+ next_run_time_us(_next_run_time_us),
+ repeat_every_us(_repeat_every_us),
+ valid(true) {}
+
+ void Cancel() { valid = false; }
+
+ bool IsValid() const { return valid; }
+ };
+
+ void WaitForTaskCompleteIfNecessary() {
+ mutex_.AssertHeld();
+ while (executing_task_) {
+ TEST_SYNC_POINT("Timer::WaitForTaskCompleteIfNecessary:TaskExecuting");
+ cond_var_.Wait();
+ }
+ }
+
+ struct RunTimeOrder {
+ bool operator()(const FunctionInfo* f1, const FunctionInfo* f2) {
+ return f1->next_run_time_us > f2->next_run_time_us;
+ }
+ };
+
+ SystemClock* clock_;
+ // This mutex controls both the heap_ and the map_. It needs to be held for
+ // making any changes in them.
+ mutable InstrumentedMutex mutex_;
+ InstrumentedCondVar cond_var_;
+ std::unique_ptr<port::Thread> thread_;
+ bool running_;
+ bool executing_task_;
+
+ std::priority_queue<FunctionInfo*, std::vector<FunctionInfo*>, RunTimeOrder>
+ heap_;
+
+ // In addition to providing a mapping from a function name to a function,
+ // it is also responsible for memory management.
+ std::unordered_map<std::string, std::unique_ptr<FunctionInfo>> map_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/timer_queue.h b/src/rocksdb/util/timer_queue.h
new file mode 100644
index 000000000..36a1744ac
--- /dev/null
+++ b/src/rocksdb/util/timer_queue.h
@@ -0,0 +1,231 @@
+// Portions Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Borrowed from
+// http://www.crazygaze.com/blog/2016/03/24/portable-c-timer-queue/
+// Timer Queue
+//
+// License
+//
+// The source code in this article is licensed under the CC0 license, so feel
+// free to copy, modify, share, do whatever you want with it.
+// No attribution is required, but Ill be happy if you do.
+// CC0 license
+
+// The person who associated a work with this deed has dedicated the work to the
+// public domain by waiving all of his or her rights to the work worldwide
+// under copyright law, including all related and neighboring rights, to the
+// extent allowed by law. You can copy, modify, distribute and perform the
+// work, even for commercial purposes, all without asking permission.
+
+#pragma once
+
+#include <assert.h>
+
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <queue>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "port/port.h"
+#include "test_util/sync_point.h"
+
+// Allows execution of handlers at a specified time in the future
+// Guarantees:
+// - All handlers are executed ONCE, even if cancelled (aborted parameter will
+// be set to true)
+// - If TimerQueue is destroyed, it will cancel all handlers.
+// - Handlers are ALWAYS executed in the Timer Queue worker thread.
+// - Handlers execution order is NOT guaranteed
+//
+////////////////////////////////////////////////////////////////////////////////
+// borrowed from
+// http://www.crazygaze.com/blog/2016/03/24/portable-c-timer-queue/
+class TimerQueue {
+ public:
+ TimerQueue() : m_th(&TimerQueue::run, this) {}
+
+ ~TimerQueue() { shutdown(); }
+
+ // This function is not thread-safe.
+ void shutdown() {
+ if (closed_) {
+ return;
+ }
+ cancelAll();
+ // Abusing the timer queue to trigger the shutdown.
+ add(0, [this](bool) {
+ m_finish = true;
+ return std::make_pair(false, 0);
+ });
+ m_th.join();
+ closed_ = true;
+ }
+
+ // Adds a new timer
+ // \return
+ // Returns the ID of the new timer. You can use this ID to cancel the
+ // timer
+ uint64_t add(int64_t milliseconds,
+ std::function<std::pair<bool, int64_t>(bool)> handler) {
+ WorkItem item;
+ Clock::time_point tp = Clock::now();
+ item.end = tp + std::chrono::milliseconds(milliseconds);
+ TEST_SYNC_POINT_CALLBACK("TimeQueue::Add:item.end", &item.end);
+ item.period = milliseconds;
+ item.handler = std::move(handler);
+
+ std::unique_lock<std::mutex> lk(m_mtx);
+ uint64_t id = ++m_idcounter;
+ item.id = id;
+ m_items.push(std::move(item));
+
+ // Something changed, so wake up timer thread
+ m_checkWork.notify_one();
+ return id;
+ }
+
+ // Cancels the specified timer
+ // \return
+ // 1 if the timer was cancelled.
+ // 0 if you were too late to cancel (or the timer ID was never valid to
+ // start with)
+ size_t cancel(uint64_t id) {
+ // Instead of removing the item from the container (thus breaking the
+ // heap integrity), we set the item as having no handler, and put
+ // that handler on a new item at the top for immediate execution
+ // The timer thread will then ignore the original item, since it has no
+ // handler.
+ std::unique_lock<std::mutex> lk(m_mtx);
+ for (auto&& item : m_items.getContainer()) {
+ if (item.id == id && item.handler) {
+ WorkItem newItem;
+ // Zero time, so it stays at the top for immediate execution
+ newItem.end = Clock::time_point();
+ newItem.id = 0; // Means it is a canceled item
+ // Move the handler from item to newitem (thus clearing item)
+ newItem.handler = std::move(item.handler);
+ m_items.push(std::move(newItem));
+
+ // Something changed, so wake up timer thread
+ m_checkWork.notify_one();
+ return 1;
+ }
+ }
+ return 0;
+ }
+
+ // Cancels all timers
+ // \return
+ // The number of timers cancelled
+ size_t cancelAll() {
+ // Setting all "end" to 0 (for immediate execution) is ok,
+ // since it maintains the heap integrity
+ std::unique_lock<std::mutex> lk(m_mtx);
+ m_cancel = true;
+ for (auto&& item : m_items.getContainer()) {
+ if (item.id && item.handler) {
+ item.end = Clock::time_point();
+ item.id = 0;
+ }
+ }
+ auto ret = m_items.size();
+
+ m_checkWork.notify_one();
+ return ret;
+ }
+
+ private:
+ using Clock = std::chrono::steady_clock;
+ TimerQueue(const TimerQueue&) = delete;
+ TimerQueue& operator=(const TimerQueue&) = delete;
+
+ void run() {
+ std::unique_lock<std::mutex> lk(m_mtx);
+ while (!m_finish) {
+ auto end = calcWaitTime_lock();
+ if (end.first) {
+ // Timers found, so wait until it expires (or something else
+ // changes)
+ m_checkWork.wait_until(lk, end.second);
+ } else {
+ // No timers exist, so wait forever until something changes
+ m_checkWork.wait(lk);
+ }
+
+ // Check and execute as much work as possible, such as, all expired
+ // timers
+ checkWork(&lk);
+ }
+
+ // If we are shutting down, we should not have any items left,
+ // since the shutdown cancels all items
+ assert(m_items.size() == 0);
+ }
+
+ std::pair<bool, Clock::time_point> calcWaitTime_lock() {
+ while (m_items.size()) {
+ if (m_items.top().handler) {
+ // Item present, so return the new wait time
+ return std::make_pair(true, m_items.top().end);
+ } else {
+ // Discard empty handlers (they were cancelled)
+ m_items.pop();
+ }
+ }
+
+ // No items found, so return no wait time (causes the thread to wait
+ // indefinitely)
+ return std::make_pair(false, Clock::time_point());
+ }
+
+ void checkWork(std::unique_lock<std::mutex>* lk) {
+ while (m_items.size() && m_items.top().end <= Clock::now()) {
+ WorkItem item(m_items.top());
+ m_items.pop();
+
+ if (item.handler) {
+ (*lk).unlock();
+ auto reschedule_pair = item.handler(item.id == 0);
+ (*lk).lock();
+ if (!m_cancel && reschedule_pair.first) {
+ int64_t new_period = (reschedule_pair.second == -1)
+ ? item.period
+ : reschedule_pair.second;
+
+ item.period = new_period;
+ item.end = Clock::now() + std::chrono::milliseconds(new_period);
+ m_items.push(std::move(item));
+ }
+ }
+ }
+ }
+
+ bool m_finish = false;
+ bool m_cancel = false;
+ uint64_t m_idcounter = 0;
+ std::condition_variable m_checkWork;
+
+ struct WorkItem {
+ Clock::time_point end;
+ int64_t period;
+ uint64_t id; // id==0 means it was cancelled
+ std::function<std::pair<bool, int64_t>(bool)> handler;
+ bool operator>(const WorkItem& other) const { return end > other.end; }
+ };
+
+ std::mutex m_mtx;
+ // Inheriting from priority_queue, so we can access the internal container
+ class Queue : public std::priority_queue<WorkItem, std::vector<WorkItem>,
+ std::greater<WorkItem>> {
+ public:
+ std::vector<WorkItem>& getContainer() { return this->c; }
+ } m_items;
+ ROCKSDB_NAMESPACE::port::Thread m_th;
+ bool closed_ = false;
+};
diff --git a/src/rocksdb/util/timer_queue_test.cc b/src/rocksdb/util/timer_queue_test.cc
new file mode 100644
index 000000000..b3c3768ec
--- /dev/null
+++ b/src/rocksdb/util/timer_queue_test.cc
@@ -0,0 +1,73 @@
+// Portions Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// borrowed from
+// http://www.crazygaze.com/blog/2016/03/24/portable-c-timer-queue/
+// Timer Queue
+//
+// License
+//
+// The source code in this article is licensed under the CC0 license, so feel
+// free
+// to copy, modify, share, do whatever you want with it.
+// No attribution is required, but Ill be happy if you do.
+// CC0 license
+
+// The person who associated a work with this deed has dedicated the work to the
+// public domain by waiving all of his or her rights to the work worldwide
+// under copyright law, including all related and neighboring rights, to the
+// extent allowed by law. You can copy, modify, distribute and perform the
+// work, even for
+// commercial purposes, all without asking permission. See Other Information
+// below.
+//
+
+#include "util/timer_queue.h"
+
+#include <future>
+
+namespace Timing {
+
+using Clock = std::chrono::high_resolution_clock;
+double now() {
+ static auto start = Clock::now();
+ return std::chrono::duration<double, std::milli>(Clock::now() - start)
+ .count();
+}
+
+} // namespace Timing
+
+int main() {
+ TimerQueue q;
+
+ double tnow = Timing::now();
+
+ q.add(10000, [tnow](bool aborted) mutable {
+ printf("T 1: %d, Elapsed %4.2fms\n", aborted, Timing::now() - tnow);
+ return std::make_pair(false, 0);
+ });
+ q.add(10001, [tnow](bool aborted) mutable {
+ printf("T 2: %d, Elapsed %4.2fms\n", aborted, Timing::now() - tnow);
+ return std::make_pair(false, 0);
+ });
+
+ q.add(1000, [tnow](bool aborted) mutable {
+ printf("T 3: %d, Elapsed %4.2fms\n", aborted, Timing::now() - tnow);
+ return std::make_pair(!aborted, 1000);
+ });
+
+ auto id = q.add(2000, [tnow](bool aborted) mutable {
+ printf("T 4: %d, Elapsed %4.2fms\n", aborted, Timing::now() - tnow);
+ return std::make_pair(!aborted, 2000);
+ });
+
+ (void)id;
+ // auto ret = q.cancel(id);
+ // assert(ret == 1);
+ // q.cancelAll();
+
+ return 0;
+}
+//////////////////////////////////////////
diff --git a/src/rocksdb/util/timer_test.cc b/src/rocksdb/util/timer_test.cc
new file mode 100644
index 000000000..0ebfa9f3d
--- /dev/null
+++ b/src/rocksdb/util/timer_test.cc
@@ -0,0 +1,402 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "util/timer.h"
+
+#include "db/db_test_util.h"
+#include "rocksdb/file_system.h"
+#include "test_util/mock_time_env.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TimerTest : public testing::Test {
+ public:
+ TimerTest()
+ : mock_clock_(std::make_shared<MockSystemClock>(SystemClock::Default())) {
+ }
+
+ protected:
+ std::shared_ptr<MockSystemClock> mock_clock_;
+
+ void SetUp() override { mock_clock_->InstallTimedWaitFixCallback(); }
+
+ const int kUsPerSec = 1000000;
+};
+
+TEST_F(TimerTest, SingleScheduleOnce) {
+ const int kInitDelayUs = 1 * kUsPerSec;
+ Timer timer(mock_clock_.get());
+
+ int count = 0;
+ timer.Add([&] { count++; }, "fn_sch_test", kInitDelayUs, 0);
+
+ ASSERT_TRUE(timer.Start());
+
+ ASSERT_EQ(0, count);
+ // Wait for execution to finish
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+ ASSERT_EQ(1, count);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, MultipleScheduleOnce) {
+ const int kInitDelay1Us = 1 * kUsPerSec;
+ const int kInitDelay2Us = 3 * kUsPerSec;
+ Timer timer(mock_clock_.get());
+
+ int count1 = 0;
+ timer.Add([&] { count1++; }, "fn_sch_test1", kInitDelay1Us, 0);
+
+ int count2 = 0;
+ timer.Add([&] { count2++; }, "fn_sch_test2", kInitDelay2Us, 0);
+
+ ASSERT_TRUE(timer.Start());
+ ASSERT_EQ(0, count1);
+ ASSERT_EQ(0, count2);
+
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelay1Us); });
+
+ ASSERT_EQ(1, count1);
+ ASSERT_EQ(0, count2);
+
+ timer.TEST_WaitForRun([&] {
+ mock_clock_->SleepForMicroseconds(kInitDelay2Us - kInitDelay1Us);
+ });
+
+ ASSERT_EQ(1, count1);
+ ASSERT_EQ(1, count2);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, SingleScheduleRepeatedly) {
+ const int kIterations = 5;
+ const int kInitDelayUs = 1 * kUsPerSec;
+ const int kRepeatUs = 1 * kUsPerSec;
+
+ Timer timer(mock_clock_.get());
+ int count = 0;
+ timer.Add([&] { count++; }, "fn_sch_test", kInitDelayUs, kRepeatUs);
+
+ ASSERT_TRUE(timer.Start());
+ ASSERT_EQ(0, count);
+
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+
+ ASSERT_EQ(1, count);
+
+ // Wait for execution to finish
+ for (int i = 1; i < kIterations; i++) {
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kRepeatUs); });
+ }
+ ASSERT_EQ(kIterations, count);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, MultipleScheduleRepeatedly) {
+ const int kIterations = 5;
+ const int kInitDelay1Us = 0 * kUsPerSec;
+ const int kInitDelay2Us = 1 * kUsPerSec;
+ const int kInitDelay3Us = 0 * kUsPerSec;
+ const int kRepeatUs = 2 * kUsPerSec;
+ const int kLargeRepeatUs = 100 * kUsPerSec;
+
+ Timer timer(mock_clock_.get());
+
+ int count1 = 0;
+ timer.Add([&] { count1++; }, "fn_sch_test1", kInitDelay1Us, kRepeatUs);
+
+ int count2 = 0;
+ timer.Add([&] { count2++; }, "fn_sch_test2", kInitDelay2Us, kRepeatUs);
+
+ // Add a function with relatively large repeat interval
+ int count3 = 0;
+ timer.Add([&] { count3++; }, "fn_sch_test3", kInitDelay3Us, kLargeRepeatUs);
+
+ ASSERT_TRUE(timer.Start());
+
+ ASSERT_EQ(0, count2);
+ // Wait for execution to finish
+ for (int i = 1; i < kIterations * (kRepeatUs / kUsPerSec); i++) {
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(1 * kUsPerSec); });
+ ASSERT_EQ((i + 2) / (kRepeatUs / kUsPerSec), count1);
+ ASSERT_EQ((i + 1) / (kRepeatUs / kUsPerSec), count2);
+
+ // large interval function should only run once (the first one).
+ ASSERT_EQ(1, count3);
+ }
+
+ timer.Cancel("fn_sch_test1");
+
+ // Wait for execution to finish
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(1 * kUsPerSec); });
+ ASSERT_EQ(kIterations, count1);
+ ASSERT_EQ(kIterations, count2);
+ ASSERT_EQ(1, count3);
+
+ timer.Cancel("fn_sch_test2");
+
+ ASSERT_EQ(kIterations, count1);
+ ASSERT_EQ(kIterations, count2);
+
+ // execute the long interval one
+ timer.TEST_WaitForRun([&] {
+ mock_clock_->SleepForMicroseconds(
+ kLargeRepeatUs - static_cast<int>(mock_clock_->NowMicros()));
+ });
+ ASSERT_EQ(2, count3);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, AddAfterStartTest) {
+ const int kIterations = 5;
+ const int kInitDelayUs = 1 * kUsPerSec;
+ const int kRepeatUs = 1 * kUsPerSec;
+
+ // wait timer to run and then add a new job
+ SyncPoint::GetInstance()->LoadDependency(
+ {{"Timer::Run::Waiting", "TimerTest:AddAfterStartTest:1"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ Timer timer(mock_clock_.get());
+
+ ASSERT_TRUE(timer.Start());
+
+ TEST_SYNC_POINT("TimerTest:AddAfterStartTest:1");
+ int count = 0;
+ timer.Add([&] { count++; }, "fn_sch_test", kInitDelayUs, kRepeatUs);
+ ASSERT_EQ(0, count);
+ // Wait for execution to finish
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+ ASSERT_EQ(1, count);
+
+ for (int i = 1; i < kIterations; i++) {
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kRepeatUs); });
+ }
+ ASSERT_EQ(kIterations, count);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, CancelRunningTask) {
+ static constexpr char kTestFuncName[] = "test_func";
+ const int kRepeatUs = 1 * kUsPerSec;
+ Timer timer(mock_clock_.get());
+ ASSERT_TRUE(timer.Start());
+ int* value = new int;
+ *value = 0;
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->LoadDependency({
+ {"TimerTest::CancelRunningTask:test_func:0",
+ "TimerTest::CancelRunningTask:BeforeCancel"},
+ {"Timer::WaitForTaskCompleteIfNecessary:TaskExecuting",
+ "TimerTest::CancelRunningTask:test_func:1"},
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ timer.Add(
+ [&]() {
+ *value = 1;
+ TEST_SYNC_POINT("TimerTest::CancelRunningTask:test_func:0");
+ TEST_SYNC_POINT("TimerTest::CancelRunningTask:test_func:1");
+ },
+ kTestFuncName, 0, kRepeatUs);
+ port::Thread control_thr([&]() {
+ TEST_SYNC_POINT("TimerTest::CancelRunningTask:BeforeCancel");
+ timer.Cancel(kTestFuncName);
+ // Verify that *value has been set to 1.
+ ASSERT_EQ(1, *value);
+ delete value;
+ value = nullptr;
+ });
+ mock_clock_->SleepForMicroseconds(kRepeatUs);
+ control_thr.join();
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, ShutdownRunningTask) {
+ const int kRepeatUs = 1 * kUsPerSec;
+ constexpr char kTestFunc1Name[] = "test_func1";
+ constexpr char kTestFunc2Name[] = "test_func2";
+ Timer timer(mock_clock_.get());
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->LoadDependency({
+ {"TimerTest::ShutdownRunningTest:test_func:0",
+ "TimerTest::ShutdownRunningTest:BeforeShutdown"},
+ {"Timer::WaitForTaskCompleteIfNecessary:TaskExecuting",
+ "TimerTest::ShutdownRunningTest:test_func:1"},
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_TRUE(timer.Start());
+
+ int* value = new int;
+ *value = 0;
+ timer.Add(
+ [&]() {
+ TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:test_func:0");
+ *value = 1;
+ TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:test_func:1");
+ },
+ kTestFunc1Name, 0, kRepeatUs);
+
+ timer.Add([&]() { ++(*value); }, kTestFunc2Name, 0, kRepeatUs);
+
+ port::Thread control_thr([&]() {
+ TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:BeforeShutdown");
+ timer.Shutdown();
+ });
+ mock_clock_->SleepForMicroseconds(kRepeatUs);
+ control_thr.join();
+ delete value;
+}
+
+TEST_F(TimerTest, AddSameFuncName) {
+ const int kInitDelayUs = 1 * kUsPerSec;
+ const int kRepeat1Us = 5 * kUsPerSec;
+ const int kRepeat2Us = 4 * kUsPerSec;
+
+ Timer timer(mock_clock_.get());
+ ASSERT_TRUE(timer.Start());
+
+ int func_counter1 = 0;
+ ASSERT_TRUE(timer.Add([&] { func_counter1++; }, "duplicated_func",
+ kInitDelayUs, kRepeat1Us));
+
+ int func2_counter = 0;
+ ASSERT_TRUE(
+ timer.Add([&] { func2_counter++; }, "func2", kInitDelayUs, kRepeat2Us));
+
+ // New function with the same name should fail to add
+ int func_counter2 = 0;
+ ASSERT_FALSE(timer.Add([&] { func_counter2++; }, "duplicated_func",
+ kInitDelayUs, kRepeat1Us));
+
+ ASSERT_EQ(0, func_counter1);
+ ASSERT_EQ(0, func2_counter);
+
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+
+ ASSERT_EQ(1, func_counter1);
+ ASSERT_EQ(1, func2_counter);
+
+ timer.TEST_WaitForRun([&] { mock_clock_->SleepForMicroseconds(kRepeat1Us); });
+
+ ASSERT_EQ(2, func_counter1);
+ ASSERT_EQ(2, func2_counter);
+ ASSERT_EQ(0, func_counter2);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, RepeatIntervalWithFuncRunningTime) {
+ const int kInitDelayUs = 1 * kUsPerSec;
+ const int kRepeatUs = 5 * kUsPerSec;
+ const int kFuncRunningTimeUs = 1 * kUsPerSec;
+
+ Timer timer(mock_clock_.get());
+ ASSERT_TRUE(timer.Start());
+
+ int func_counter = 0;
+ timer.Add(
+ [&] {
+ mock_clock_->SleepForMicroseconds(kFuncRunningTimeUs);
+ func_counter++;
+ },
+ "func", kInitDelayUs, kRepeatUs);
+
+ ASSERT_EQ(0, func_counter);
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+ ASSERT_EQ(1, func_counter);
+ ASSERT_EQ(kInitDelayUs + kFuncRunningTimeUs, mock_clock_->NowMicros());
+
+ // After repeat interval time, the function is not executed, as running
+ // the function takes some time (`kFuncRunningTimeSec`). The repeat interval
+ // is the time between ending time of the last call and starting time of the
+ // next call.
+ uint64_t next_abs_interval_time_us = kInitDelayUs + kRepeatUs;
+ timer.TEST_WaitForRun([&] {
+ mock_clock_->SetCurrentTime(next_abs_interval_time_us / kUsPerSec);
+ });
+ ASSERT_EQ(1, func_counter);
+
+ // After the function running time, it's executed again
+ timer.TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kFuncRunningTimeUs); });
+ ASSERT_EQ(2, func_counter);
+
+ ASSERT_TRUE(timer.Shutdown());
+}
+
+TEST_F(TimerTest, DestroyRunningTimer) {
+ const int kInitDelayUs = 1 * kUsPerSec;
+ const int kRepeatUs = 1 * kUsPerSec;
+
+ auto timer_ptr = new Timer(mock_clock_.get());
+
+ int count = 0;
+ timer_ptr->Add([&] { count++; }, "fn_sch_test", kInitDelayUs, kRepeatUs);
+ ASSERT_TRUE(timer_ptr->Start());
+
+ timer_ptr->TEST_WaitForRun(
+ [&] { mock_clock_->SleepForMicroseconds(kInitDelayUs); });
+
+ // delete a running timer should not cause any exception
+ delete timer_ptr;
+}
+
+TEST_F(TimerTest, DestroyTimerWithRunningFunc) {
+ const int kRepeatUs = 1 * kUsPerSec;
+ auto timer_ptr = new Timer(mock_clock_.get());
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->LoadDependency({
+ {"TimerTest::DestroyTimerWithRunningFunc:test_func:0",
+ "TimerTest::DestroyTimerWithRunningFunc:BeforeDelete"},
+ {"Timer::WaitForTaskCompleteIfNecessary:TaskExecuting",
+ "TimerTest::DestroyTimerWithRunningFunc:test_func:1"},
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_TRUE(timer_ptr->Start());
+
+ int count = 0;
+ timer_ptr->Add(
+ [&]() {
+ TEST_SYNC_POINT("TimerTest::DestroyTimerWithRunningFunc:test_func:0");
+ count++;
+ TEST_SYNC_POINT("TimerTest::DestroyTimerWithRunningFunc:test_func:1");
+ },
+ "fn_running_test", 0, kRepeatUs);
+
+ port::Thread control_thr([&] {
+ TEST_SYNC_POINT("TimerTest::DestroyTimerWithRunningFunc:BeforeDelete");
+ delete timer_ptr;
+ });
+ mock_clock_->SleepForMicroseconds(kRepeatUs);
+ control_thr.join();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/user_comparator_wrapper.h b/src/rocksdb/util/user_comparator_wrapper.h
new file mode 100644
index 000000000..59ebada12
--- /dev/null
+++ b/src/rocksdb/util/user_comparator_wrapper.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/comparator.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Wrapper of user comparator, with auto increment to
+// perf_context.user_key_comparison_count.
+class UserComparatorWrapper {
+ public:
+ // `UserComparatorWrapper`s constructed with the default constructor are not
+ // usable and will segfault on any attempt to use them for comparisons.
+ UserComparatorWrapper() : user_comparator_(nullptr) {}
+
+ explicit UserComparatorWrapper(const Comparator* const user_cmp)
+ : user_comparator_(user_cmp) {}
+
+ ~UserComparatorWrapper() = default;
+
+ const Comparator* user_comparator() const { return user_comparator_; }
+
+ int Compare(const Slice& a, const Slice& b) const {
+ PERF_COUNTER_ADD(user_key_comparison_count, 1);
+ return user_comparator_->Compare(a, b);
+ }
+
+ bool Equal(const Slice& a, const Slice& b) const {
+ PERF_COUNTER_ADD(user_key_comparison_count, 1);
+ return user_comparator_->Equal(a, b);
+ }
+
+ int CompareTimestamp(const Slice& ts1, const Slice& ts2) const {
+ return user_comparator_->CompareTimestamp(ts1, ts2);
+ }
+
+ int CompareWithoutTimestamp(const Slice& a, const Slice& b) const {
+ PERF_COUNTER_ADD(user_key_comparison_count, 1);
+ return user_comparator_->CompareWithoutTimestamp(a, b);
+ }
+
+ int CompareWithoutTimestamp(const Slice& a, bool a_has_ts, const Slice& b,
+ bool b_has_ts) const {
+ PERF_COUNTER_ADD(user_key_comparison_count, 1);
+ return user_comparator_->CompareWithoutTimestamp(a, a_has_ts, b, b_has_ts);
+ }
+
+ bool EqualWithoutTimestamp(const Slice& a, const Slice& b) const {
+ return user_comparator_->EqualWithoutTimestamp(a, b);
+ }
+
+ private:
+ const Comparator* user_comparator_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/vector_iterator.h b/src/rocksdb/util/vector_iterator.h
new file mode 100644
index 000000000..c4cc01d56
--- /dev/null
+++ b/src/rocksdb/util/vector_iterator.h
@@ -0,0 +1,118 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "db/dbformat.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/slice.h"
+#include "table/internal_iterator.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Iterator over a vector of keys/values
+class VectorIterator : public InternalIterator {
+ public:
+ VectorIterator(std::vector<std::string> keys, std::vector<std::string> values,
+ const CompareInterface* icmp = nullptr)
+ : keys_(std::move(keys)),
+ values_(std::move(values)),
+ current_(keys_.size()),
+ indexed_cmp_(icmp, &keys_) {
+ assert(keys_.size() == values_.size());
+
+ indices_.reserve(keys_.size());
+ for (size_t i = 0; i < keys_.size(); i++) {
+ indices_.push_back(i);
+ }
+ if (icmp != nullptr) {
+ std::sort(indices_.begin(), indices_.end(), indexed_cmp_);
+ }
+ }
+
+ virtual bool Valid() const override {
+ return !indices_.empty() && current_ < indices_.size();
+ }
+
+ virtual void SeekToFirst() override { current_ = 0; }
+ virtual void SeekToLast() override { current_ = indices_.size() - 1; }
+
+ virtual void Seek(const Slice& target) override {
+ if (indexed_cmp_.cmp != nullptr) {
+ current_ = std::lower_bound(indices_.begin(), indices_.end(), target,
+ indexed_cmp_) -
+ indices_.begin();
+ } else {
+ current_ =
+ std::lower_bound(keys_.begin(), keys_.end(), target.ToString()) -
+ keys_.begin();
+ }
+ }
+
+ virtual void SeekForPrev(const Slice& target) override {
+ if (indexed_cmp_.cmp != nullptr) {
+ current_ = std::upper_bound(indices_.begin(), indices_.end(), target,
+ indexed_cmp_) -
+ indices_.begin();
+ } else {
+ current_ =
+ std::upper_bound(keys_.begin(), keys_.end(), target.ToString()) -
+ keys_.begin();
+ }
+ if (!Valid()) {
+ SeekToLast();
+ } else {
+ Prev();
+ }
+ }
+
+ virtual void Next() override { current_++; }
+ virtual void Prev() override { current_--; }
+
+ virtual Slice key() const override {
+ return Slice(keys_[indices_[current_]]);
+ }
+ virtual Slice value() const override {
+ return Slice(values_[indices_[current_]]);
+ }
+
+ virtual Status status() const override { return Status::OK(); }
+
+ virtual bool IsKeyPinned() const override { return true; }
+ virtual bool IsValuePinned() const override { return true; }
+
+ protected:
+ std::vector<std::string> keys_;
+ std::vector<std::string> values_;
+ size_t current_;
+
+ private:
+ struct IndexedKeyComparator {
+ IndexedKeyComparator(const CompareInterface* c,
+ const std::vector<std::string>* ks)
+ : cmp(c), keys(ks) {}
+
+ bool operator()(size_t a, size_t b) const {
+ return cmp->Compare((*keys)[a], (*keys)[b]) < 0;
+ }
+
+ bool operator()(size_t a, const Slice& b) const {
+ return cmp->Compare((*keys)[a], b) < 0;
+ }
+
+ bool operator()(const Slice& a, size_t b) const {
+ return cmp->Compare(a, (*keys)[b]) < 0;
+ }
+
+ const CompareInterface* cmp;
+ const std::vector<std::string>* keys;
+ };
+
+ IndexedKeyComparator indexed_cmp_;
+ std::vector<size_t> indices_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/work_queue.h b/src/rocksdb/util/work_queue.h
new file mode 100644
index 000000000..94ece85d9
--- /dev/null
+++ b/src/rocksdb/util/work_queue.h
@@ -0,0 +1,150 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/*
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+#pragma once
+
+#include <atomic>
+#include <cassert>
+#include <condition_variable>
+#include <cstddef>
+#include <functional>
+#include <mutex>
+#include <queue>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+/// Unbounded thread-safe work queue.
+//
+// This file is an excerpt from Facebook's zstd repo at
+// https://github.com/facebook/zstd/. The relevant file is
+// contrib/pzstd/utils/WorkQueue.h.
+
+template <typename T>
+class WorkQueue {
+ // Protects all member variable access
+ std::mutex mutex_;
+ std::condition_variable readerCv_;
+ std::condition_variable writerCv_;
+ std::condition_variable finishCv_;
+
+ std::queue<T> queue_;
+ bool done_;
+ std::size_t maxSize_;
+
+ // Must have lock to call this function
+ bool full() const {
+ if (maxSize_ == 0) {
+ return false;
+ }
+ return queue_.size() >= maxSize_;
+ }
+
+ public:
+ /**
+ * Constructs an empty work queue with an optional max size.
+ * If `maxSize == 0` the queue size is unbounded.
+ *
+ * @param maxSize The maximum allowed size of the work queue.
+ */
+ WorkQueue(std::size_t maxSize = 0) : done_(false), maxSize_(maxSize) {}
+
+ /**
+ * Push an item onto the work queue. Notify a single thread that work is
+ * available. If `finish()` has been called, do nothing and return false.
+ * If `push()` returns false, then `item` has not been copied from.
+ *
+ * @param item Item to push onto the queue.
+ * @returns True upon success, false if `finish()` has been called. An
+ * item was pushed iff `push()` returns true.
+ */
+ template <typename U>
+ bool push(U&& item) {
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (full() && !done_) {
+ writerCv_.wait(lock);
+ }
+ if (done_) {
+ return false;
+ }
+ queue_.push(std::forward<U>(item));
+ }
+ readerCv_.notify_one();
+ return true;
+ }
+
+ /**
+ * Attempts to pop an item off the work queue. It will block until data is
+ * available or `finish()` has been called.
+ *
+ * @param[out] item If `pop` returns `true`, it contains the popped item.
+ * If `pop` returns `false`, it is unmodified.
+ * @returns True upon success. False if the queue is empty and
+ * `finish()` has been called.
+ */
+ bool pop(T& item) {
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (queue_.empty() && !done_) {
+ readerCv_.wait(lock);
+ }
+ if (queue_.empty()) {
+ assert(done_);
+ return false;
+ }
+ item = queue_.front();
+ queue_.pop();
+ }
+ writerCv_.notify_one();
+ return true;
+ }
+
+ /**
+ * Sets the maximum queue size. If `maxSize == 0` then it is unbounded.
+ *
+ * @param maxSize The new maximum queue size.
+ */
+ void setMaxSize(std::size_t maxSize) {
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ maxSize_ = maxSize;
+ }
+ writerCv_.notify_all();
+ }
+
+ /**
+ * Promise that `push()` won't be called again, so once the queue is empty
+ * there will never any more work.
+ */
+ void finish() {
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ assert(!done_);
+ done_ = true;
+ }
+ readerCv_.notify_all();
+ writerCv_.notify_all();
+ finishCv_.notify_all();
+ }
+
+ /// Blocks until `finish()` has been called (but the queue may not be empty).
+ void waitUntilFinished() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (!done_) {
+ finishCv_.wait(lock);
+ }
+ }
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/util/work_queue_test.cc b/src/rocksdb/util/work_queue_test.cc
new file mode 100644
index 000000000..c23a51279
--- /dev/null
+++ b/src/rocksdb/util/work_queue_test.cc
@@ -0,0 +1,272 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/*
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+#include "util/work_queue.h"
+
+#include <gtest/gtest.h>
+
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include "port/stack_trace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Unit test for work_queue.h.
+//
+// This file is an excerpt from Facebook's zstd repo at
+// https://github.com/facebook/zstd/. The relevant file is
+// contrib/pzstd/utils/test/WorkQueueTest.cpp.
+
+struct Popper {
+ WorkQueue<int>* queue;
+ int* results;
+ std::mutex* mutex;
+
+ void operator()() {
+ int result;
+ while (queue->pop(result)) {
+ std::lock_guard<std::mutex> lock(*mutex);
+ results[result] = result;
+ }
+ }
+};
+
+TEST(WorkQueue, SingleThreaded) {
+ WorkQueue<int> queue;
+ int result;
+
+ queue.push(5);
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(5, result);
+
+ queue.push(1);
+ queue.push(2);
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(1, result);
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(2, result);
+
+ queue.push(1);
+ queue.push(2);
+ queue.finish();
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(1, result);
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(2, result);
+ EXPECT_FALSE(queue.pop(result));
+
+ queue.waitUntilFinished();
+}
+
+TEST(WorkQueue, SPSC) {
+ WorkQueue<int> queue;
+ const int max = 100;
+
+ for (int i = 0; i < 10; ++i) {
+ queue.push(i);
+ }
+
+ std::thread thread([&queue, max] {
+ int result;
+ for (int i = 0;; ++i) {
+ if (!queue.pop(result)) {
+ EXPECT_EQ(i, max);
+ break;
+ }
+ EXPECT_EQ(i, result);
+ }
+ });
+
+ std::this_thread::yield();
+ for (int i = 10; i < max; ++i) {
+ queue.push(i);
+ }
+ queue.finish();
+
+ thread.join();
+}
+
+TEST(WorkQueue, SPMC) {
+ WorkQueue<int> queue;
+ std::vector<int> results(50, -1);
+ std::mutex mutex;
+ std::vector<std::thread> threads;
+ for (int i = 0; i < 5; ++i) {
+ threads.emplace_back(Popper{&queue, results.data(), &mutex});
+ }
+
+ for (int i = 0; i < 50; ++i) {
+ queue.push(i);
+ }
+ queue.finish();
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ for (int i = 0; i < 50; ++i) {
+ EXPECT_EQ(i, results[i]);
+ }
+}
+
+TEST(WorkQueue, MPMC) {
+ WorkQueue<int> queue;
+ std::vector<int> results(100, -1);
+ std::mutex mutex;
+ std::vector<std::thread> popperThreads;
+ for (int i = 0; i < 4; ++i) {
+ popperThreads.emplace_back(Popper{&queue, results.data(), &mutex});
+ }
+
+ std::vector<std::thread> pusherThreads;
+ for (int i = 0; i < 2; ++i) {
+ auto min = i * 50;
+ auto max = (i + 1) * 50;
+ pusherThreads.emplace_back([&queue, min, max] {
+ for (int j = min; j < max; ++j) {
+ queue.push(j);
+ }
+ });
+ }
+
+ for (auto& thread : pusherThreads) {
+ thread.join();
+ }
+ queue.finish();
+
+ for (auto& thread : popperThreads) {
+ thread.join();
+ }
+
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_EQ(i, results[i]);
+ }
+}
+
+TEST(WorkQueue, BoundedSizeWorks) {
+ WorkQueue<int> queue(1);
+ int result;
+ queue.push(5);
+ queue.pop(result);
+ queue.push(5);
+ queue.pop(result);
+ queue.push(5);
+ queue.finish();
+ queue.pop(result);
+ EXPECT_EQ(5, result);
+}
+
+TEST(WorkQueue, BoundedSizePushAfterFinish) {
+ WorkQueue<int> queue(1);
+ int result;
+ queue.push(5);
+ std::thread pusher([&queue] { queue.push(6); });
+ // Dirtily try and make sure that pusher has run.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ queue.finish();
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(5, result);
+ EXPECT_FALSE(queue.pop(result));
+
+ pusher.join();
+}
+
+TEST(WorkQueue, SetMaxSize) {
+ WorkQueue<int> queue(2);
+ int result;
+ queue.push(5);
+ queue.push(6);
+ queue.setMaxSize(1);
+ std::thread pusher([&queue] { queue.push(7); });
+ // Dirtily try and make sure that pusher has run.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ queue.finish();
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(5, result);
+ EXPECT_TRUE(queue.pop(result));
+ EXPECT_EQ(6, result);
+ EXPECT_FALSE(queue.pop(result));
+
+ pusher.join();
+}
+
+TEST(WorkQueue, BoundedSizeMPMC) {
+ WorkQueue<int> queue(10);
+ std::vector<int> results(200, -1);
+ std::mutex mutex;
+ std::cerr << "Creating popperThreads" << std::endl;
+ std::vector<std::thread> popperThreads;
+ for (int i = 0; i < 4; ++i) {
+ popperThreads.emplace_back(Popper{&queue, results.data(), &mutex});
+ }
+
+ std::cerr << "Creating pusherThreads" << std::endl;
+ std::vector<std::thread> pusherThreads;
+ for (int i = 0; i < 2; ++i) {
+ auto min = i * 100;
+ auto max = (i + 1) * 100;
+ pusherThreads.emplace_back([&queue, min, max] {
+ for (int j = min; j < max; ++j) {
+ queue.push(j);
+ }
+ });
+ }
+
+ std::cerr << "Joining pusherThreads" << std::endl;
+ for (auto& thread : pusherThreads) {
+ thread.join();
+ }
+ std::cerr << "Finishing queue" << std::endl;
+ queue.finish();
+
+ std::cerr << "Joining popperThreads" << std::endl;
+ for (auto& thread : popperThreads) {
+ thread.join();
+ }
+
+ std::cerr << "Inspecting results" << std::endl;
+ for (int i = 0; i < 200; ++i) {
+ EXPECT_EQ(i, results[i]);
+ }
+}
+
+TEST(WorkQueue, FailedPush) {
+ WorkQueue<int> queue;
+ EXPECT_TRUE(queue.push(1));
+ queue.finish();
+ EXPECT_FALSE(queue.push(1));
+}
+
+TEST(WorkQueue, FailedPop) {
+ WorkQueue<int> queue;
+ int x = 5;
+ EXPECT_TRUE(queue.push(x));
+ queue.finish();
+ x = 0;
+ EXPECT_TRUE(queue.pop(x));
+ EXPECT_EQ(5, x);
+ EXPECT_FALSE(queue.pop(x));
+ EXPECT_EQ(5, x);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/util/xxhash.cc b/src/rocksdb/util/xxhash.cc
new file mode 100644
index 000000000..88852c330
--- /dev/null
+++ b/src/rocksdb/util/xxhash.cc
@@ -0,0 +1,48 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+/*
+ * xxHash - Extremely Fast Hash algorithm
+ * Copyright (C) 2012-2020 Yann Collet
+ *
+ * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * You can contact the author at:
+ * - xxHash homepage: https://www.xxhash.com
+ * - xxHash source repository: https://github.com/Cyan4973/xxHash
+ */
+
+/*
+ * xxhash.c instantiates functions defined in xxhash.h
+ */
+// clang-format off
+#ifndef XXH_STATIC_LINKING_ONLY
+#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
+#endif // !defined(XXH_STATIC_LINKING_ONLY)
+#define XXH_IMPLEMENTATION /* access definitions */
+
+#include "xxhash.h"
diff --git a/src/rocksdb/util/xxhash.h b/src/rocksdb/util/xxhash.h
new file mode 100644
index 000000000..195f06b39
--- /dev/null
+++ b/src/rocksdb/util/xxhash.h
@@ -0,0 +1,5346 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/* BEGIN RocksDB customizations */
+#ifndef XXH_STATIC_LINKING_ONLY
+// Using compiled xxhash.cc
+#define XXH_STATIC_LINKING_ONLY 1
+#endif // !defined(XXH_STATIC_LINKING_ONLY)
+#ifndef XXH_NAMESPACE
+#define XXH_NAMESPACE ROCKSDB_
+#endif // !defined(XXH_NAMESPACE)
+
+// for FALLTHROUGH_INTENDED, inserted as appropriate
+#include "port/lang.h"
+/* END RocksDB customizations */
+
+// clang-format off
+/*
+ * xxHash - Extremely Fast Hash algorithm
+ * Header File
+ * Copyright (C) 2012-2020 Yann Collet
+ *
+ * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * You can contact the author at:
+ * - xxHash homepage: https://www.xxhash.com
+ * - xxHash source repository: https://github.com/Cyan4973/xxHash
+ */
+/*!
+ * @mainpage xxHash
+ *
+ * @file xxhash.h
+ * xxHash prototypes and implementation
+ */
+/* TODO: update */
+/* Notice extracted from xxHash homepage:
+
+xxHash is an extremely fast hash algorithm, running at RAM speed limits.
+It also successfully passes all tests from the SMHasher suite.
+
+Comparison (single thread, Windows Seven 32 bits, using SMHasher on a Core 2 Duo @3GHz)
+
+Name Speed Q.Score Author
+xxHash 5.4 GB/s 10
+CrapWow 3.2 GB/s 2 Andrew
+MurmurHash 3a 2.7 GB/s 10 Austin Appleby
+SpookyHash 2.0 GB/s 10 Bob Jenkins
+SBox 1.4 GB/s 9 Bret Mulvey
+Lookup3 1.2 GB/s 9 Bob Jenkins
+SuperFastHash 1.2 GB/s 1 Paul Hsieh
+CityHash64 1.05 GB/s 10 Pike & Alakuijala
+FNV 0.55 GB/s 5 Fowler, Noll, Vo
+CRC32 0.43 GB/s 9
+MD5-32 0.33 GB/s 10 Ronald L. Rivest
+SHA1-32 0.28 GB/s 10
+
+Q.Score is a measure of quality of the hash function.
+It depends on successfully passing SMHasher test set.
+10 is a perfect score.
+
+Note: SMHasher's CRC32 implementation is not the fastest one.
+Other speed-oriented implementations can be faster,
+especially in combination with PCLMUL instruction:
+https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html?showComment=1552696407071#c3490092340461170735
+
+A 64-bit version, named XXH64, is available since r35.
+It offers much better speed, but for 64-bit applications only.
+Name Speed on 64 bits Speed on 32 bits
+XXH64 13.8 GB/s 1.9 GB/s
+XXH32 6.8 GB/s 6.0 GB/s
+*/
+
+#if defined (__cplusplus)
+extern "C" {
+#endif
+
+/* ****************************
+ * INLINE mode
+ ******************************/
+/*!
+ * XXH_INLINE_ALL (and XXH_PRIVATE_API)
+ * Use these build macros to inline xxhash into the target unit.
+ * Inlining improves performance on small inputs, especially when the length is
+ * expressed as a compile-time constant:
+ *
+ * https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html
+ *
+ * It also keeps xxHash symbols private to the unit, so they are not exported.
+ *
+ * Usage:
+ * #define XXH_INLINE_ALL
+ * #include "xxhash.h"
+ *
+ * Do not compile and link xxhash.o as a separate object, as it is not useful.
+ */
+#if (defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)) \
+ && !defined(XXH_INLINE_ALL_31684351384)
+ /* this section should be traversed only once */
+# define XXH_INLINE_ALL_31684351384
+ /* give access to the advanced API, required to compile implementations */
+# undef XXH_STATIC_LINKING_ONLY /* avoid macro redef */
+# define XXH_STATIC_LINKING_ONLY
+ /* make all functions private */
+# undef XXH_PUBLIC_API
+# if defined(__GNUC__)
+# define XXH_PUBLIC_API static __inline __attribute__((unused))
+# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
+# define XXH_PUBLIC_API static inline
+# elif defined(_MSC_VER)
+# define XXH_PUBLIC_API static __inline
+# else
+ /* note: this version may generate warnings for unused static functions */
+# define XXH_PUBLIC_API static
+# endif
+
+ /*
+ * This part deals with the special case where a unit wants to inline xxHash,
+ * but "xxhash.h" has previously been included without XXH_INLINE_ALL, such
+ * as part of some previously included *.h header file.
+ * Without further action, the new include would just be ignored,
+ * and functions would effectively _not_ be inlined (silent failure).
+ * The following macros solve this situation by prefixing all inlined names,
+ * avoiding naming collision with previous inclusions.
+ */
+# ifdef XXH_NAMESPACE
+# error "XXH_INLINE_ALL with XXH_NAMESPACE is not supported"
+ /*
+ * Note: Alternative: #undef all symbols (it's a pretty large list).
+ * Without #error: it compiles, but functions are actually not inlined.
+ */
+# endif
+# define XXH_NAMESPACE XXH_INLINE_
+ /*
+ * Some identifiers (enums, type names) are not symbols, but they must
+ * still be renamed to avoid redeclaration.
+ * Alternative solution: do not redeclare them.
+ * However, this requires some #ifdefs, and is a more dispersed action.
+ * Meanwhile, renaming can be achieved in a single block
+ */
+# define XXH_IPREF(Id) XXH_INLINE_ ## Id
+# define XXH_OK XXH_IPREF(XXH_OK)
+# define XXH_ERROR XXH_IPREF(XXH_ERROR)
+# define XXH_errorcode XXH_IPREF(XXH_errorcode)
+# define XXH32_canonical_t XXH_IPREF(XXH32_canonical_t)
+# define XXH64_canonical_t XXH_IPREF(XXH64_canonical_t)
+# define XXH128_canonical_t XXH_IPREF(XXH128_canonical_t)
+# define XXH32_state_s XXH_IPREF(XXH32_state_s)
+# define XXH32_state_t XXH_IPREF(XXH32_state_t)
+# define XXH64_state_s XXH_IPREF(XXH64_state_s)
+# define XXH64_state_t XXH_IPREF(XXH64_state_t)
+# define XXH3_state_s XXH_IPREF(XXH3_state_s)
+# define XXH3_state_t XXH_IPREF(XXH3_state_t)
+# define XXH128_hash_t XXH_IPREF(XXH128_hash_t)
+ /* Ensure the header is parsed again, even if it was previously included */
+# undef XXHASH_H_5627135585666179
+# undef XXHASH_H_STATIC_13879238742
+#endif /* XXH_INLINE_ALL || XXH_PRIVATE_API */
+
+
+
+/* ****************************************************************
+ * Stable API
+ *****************************************************************/
+#ifndef XXHASH_H_5627135585666179
+#define XXHASH_H_5627135585666179 1
+
+
+/*!
+ * @defgroup public Public API
+ * Contains details on the public xxHash functions.
+ * @{
+ */
+/* specific declaration modes for Windows */
+#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API)
+# if defined(WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT))
+# ifdef XXH_EXPORT
+# define XXH_PUBLIC_API __declspec(dllexport)
+# elif XXH_IMPORT
+# define XXH_PUBLIC_API __declspec(dllimport)
+# endif
+# else
+# define XXH_PUBLIC_API /* do nothing */
+# endif
+#endif
+
+#ifdef XXH_DOXYGEN
+/*!
+ * @brief Emulate a namespace by transparently prefixing all symbols.
+ *
+ * If you want to include _and expose_ xxHash functions from within your own
+ * library, but also want to avoid symbol collisions with other libraries which
+ * may also include xxHash, you can use XXH_NAMESPACE to automatically prefix
+ * any public symbol from xxhash library with the value of XXH_NAMESPACE
+ * (therefore, avoid empty or numeric values).
+ *
+ * Note that no change is required within the calling program as long as it
+ * includes `xxhash.h`: Regular symbol names will be automatically translated
+ * by this header.
+ */
+# define XXH_NAMESPACE /* YOUR NAME HERE */
+# undef XXH_NAMESPACE
+#endif
+
+#ifdef XXH_NAMESPACE
+# define XXH_CAT(A,B) A##B
+# define XXH_NAME2(A,B) XXH_CAT(A,B)
+# define XXH_versionNumber XXH_NAME2(XXH_NAMESPACE, XXH_versionNumber)
+/* XXH32 */
+# define XXH32 XXH_NAME2(XXH_NAMESPACE, XXH32)
+# define XXH32_createState XXH_NAME2(XXH_NAMESPACE, XXH32_createState)
+# define XXH32_freeState XXH_NAME2(XXH_NAMESPACE, XXH32_freeState)
+# define XXH32_reset XXH_NAME2(XXH_NAMESPACE, XXH32_reset)
+# define XXH32_update XXH_NAME2(XXH_NAMESPACE, XXH32_update)
+# define XXH32_digest XXH_NAME2(XXH_NAMESPACE, XXH32_digest)
+# define XXH32_copyState XXH_NAME2(XXH_NAMESPACE, XXH32_copyState)
+# define XXH32_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH32_canonicalFromHash)
+# define XXH32_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH32_hashFromCanonical)
+/* XXH64 */
+# define XXH64 XXH_NAME2(XXH_NAMESPACE, XXH64)
+# define XXH64_createState XXH_NAME2(XXH_NAMESPACE, XXH64_createState)
+# define XXH64_freeState XXH_NAME2(XXH_NAMESPACE, XXH64_freeState)
+# define XXH64_reset XXH_NAME2(XXH_NAMESPACE, XXH64_reset)
+# define XXH64_update XXH_NAME2(XXH_NAMESPACE, XXH64_update)
+# define XXH64_digest XXH_NAME2(XXH_NAMESPACE, XXH64_digest)
+# define XXH64_copyState XXH_NAME2(XXH_NAMESPACE, XXH64_copyState)
+# define XXH64_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH64_canonicalFromHash)
+# define XXH64_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH64_hashFromCanonical)
+/* XXH3_64bits */
+# define XXH3_64bits XXH_NAME2(XXH_NAMESPACE, XXH3_64bits)
+# define XXH3_64bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecret)
+# define XXH3_64bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSeed)
+# define XXH3_createState XXH_NAME2(XXH_NAMESPACE, XXH3_createState)
+# define XXH3_freeState XXH_NAME2(XXH_NAMESPACE, XXH3_freeState)
+# define XXH3_copyState XXH_NAME2(XXH_NAMESPACE, XXH3_copyState)
+# define XXH3_64bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset)
+# define XXH3_64bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSeed)
+# define XXH3_64bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecret)
+# define XXH3_64bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_update)
+# define XXH3_64bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_digest)
+# define XXH3_generateSecret XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret)
+/* XXH3_128bits */
+# define XXH128 XXH_NAME2(XXH_NAMESPACE, XXH128)
+# define XXH3_128bits XXH_NAME2(XXH_NAMESPACE, XXH3_128bits)
+# define XXH3_128bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSeed)
+# define XXH3_128bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecret)
+# define XXH3_128bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset)
+# define XXH3_128bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSeed)
+# define XXH3_128bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecret)
+# define XXH3_128bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_update)
+# define XXH3_128bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_digest)
+# define XXH128_isEqual XXH_NAME2(XXH_NAMESPACE, XXH128_isEqual)
+# define XXH128_cmp XXH_NAME2(XXH_NAMESPACE, XXH128_cmp)
+# define XXH128_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH128_canonicalFromHash)
+# define XXH128_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH128_hashFromCanonical)
+#endif
+
+
+/* *************************************
+* Version
+***************************************/
+#define XXH_VERSION_MAJOR 0
+#define XXH_VERSION_MINOR 8
+#define XXH_VERSION_RELEASE 1
+#define XXH_VERSION_NUMBER (XXH_VERSION_MAJOR *100*100 + XXH_VERSION_MINOR *100 + XXH_VERSION_RELEASE)
+
+/*!
+ * @brief Obtains the xxHash version.
+ *
+ * This is only useful when xxHash is compiled as a shared library, as it is
+ * independent of the version defined in the header.
+ *
+ * @return `XXH_VERSION_NUMBER` as of when the libray was compiled.
+ */
+XXH_PUBLIC_API unsigned XXH_versionNumber (void);
+
+
+/* ****************************
+* Definitions
+******************************/
+#include <stddef.h> /* size_t */
+typedef enum { XXH_OK=0, XXH_ERROR } XXH_errorcode;
+
+
+/*-**********************************************************************
+* 32-bit hash
+************************************************************************/
+#if defined(XXH_DOXYGEN) /* Don't show <stdint.h> include */
+/*!
+ * @brief An unsigned 32-bit integer.
+ *
+ * Not necessarily defined to `uint32_t` but functionally equivalent.
+ */
+typedef uint32_t XXH32_hash_t;
+#elif !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint32_t XXH32_hash_t;
+#else
+# include <limits.h>
+# if UINT_MAX == 0xFFFFFFFFUL
+ typedef unsigned int XXH32_hash_t;
+# else
+# if ULONG_MAX == 0xFFFFFFFFUL
+ typedef unsigned long XXH32_hash_t;
+# else
+# error "unsupported platform: need a 32-bit type"
+# endif
+# endif
+#endif
+
+/*!
+ * @}
+ *
+ * @defgroup xxh32_family XXH32 family
+ * @ingroup public
+ * Contains functions used in the classic 32-bit xxHash algorithm.
+ *
+ * @note
+ * XXH32 is considered rather weak by today's standards.
+ * The @ref xxh3_family provides competitive speed for both 32-bit and 64-bit
+ * systems, and offers true 64/128 bit hash results. It provides a superior
+ * level of dispersion, and greatly reduces the risks of collisions.
+ *
+ * @see @ref xxh64_family, @ref xxh3_family : Other xxHash families
+ * @see @ref xxh32_impl for implementation details
+ * @{
+ */
+
+/*!
+ * @brief Calculates the 32-bit hash of @p input using xxHash32.
+ *
+ * Speed on Core 2 Duo @ 3 GHz (single thread, SMHasher benchmark): 5.4 GB/s
+ *
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ * @param seed The 32-bit seed to alter the hash's output predictably.
+ *
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return The calculated 32-bit hash value.
+ *
+ * @see
+ * XXH64(), XXH3_64bits_withSeed(), XXH3_128bits_withSeed(), XXH128():
+ * Direct equivalents for the other variants of xxHash.
+ * @see
+ * XXH32_createState(), XXH32_update(), XXH32_digest(): Streaming version.
+ */
+XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t length, XXH32_hash_t seed);
+
+/*!
+ * Streaming functions generate the xxHash value from an incremental input.
+ * This method is slower than single-call functions, due to state management.
+ * For small inputs, prefer `XXH32()` and `XXH64()`, which are better optimized.
+ *
+ * An XXH state must first be allocated using `XXH*_createState()`.
+ *
+ * Start a new hash by initializing the state with a seed using `XXH*_reset()`.
+ *
+ * Then, feed the hash state by calling `XXH*_update()` as many times as necessary.
+ *
+ * The function returns an error code, with 0 meaning OK, and any other value
+ * meaning there is an error.
+ *
+ * Finally, a hash value can be produced anytime, by using `XXH*_digest()`.
+ * This function returns the nn-bits hash as an int or long long.
+ *
+ * It's still possible to continue inserting input into the hash state after a
+ * digest, and generate new hash values later on by invoking `XXH*_digest()`.
+ *
+ * When done, release the state using `XXH*_freeState()`.
+ *
+ * Example code for incrementally hashing a file:
+ * @code{.c}
+ * #include <stdio.h>
+ * #include <xxhash.h>
+ * #define BUFFER_SIZE 256
+ *
+ * // Note: XXH64 and XXH3 use the same interface.
+ * XXH32_hash_t
+ * hashFile(FILE* stream)
+ * {
+ * XXH32_state_t* state;
+ * unsigned char buf[BUFFER_SIZE];
+ * size_t amt;
+ * XXH32_hash_t hash;
+ *
+ * state = XXH32_createState(); // Create a state
+ * assert(state != NULL); // Error check here
+ * XXH32_reset(state, 0xbaad5eed); // Reset state with our seed
+ * while ((amt = fread(buf, 1, sizeof(buf), stream)) != 0) {
+ * XXH32_update(state, buf, amt); // Hash the file in chunks
+ * }
+ * hash = XXH32_digest(state); // Finalize the hash
+ * XXH32_freeState(state); // Clean up
+ * return hash;
+ * }
+ * @endcode
+ */
+
+/*!
+ * @typedef struct XXH32_state_s XXH32_state_t
+ * @brief The opaque state struct for the XXH32 streaming API.
+ *
+ * @see XXH32_state_s for details.
+ */
+typedef struct XXH32_state_s XXH32_state_t;
+
+/*!
+ * @brief Allocates an @ref XXH32_state_t.
+ *
+ * Must be freed with XXH32_freeState().
+ * @return An allocated XXH32_state_t on success, `NULL` on failure.
+ */
+XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void);
+/*!
+ * @brief Frees an @ref XXH32_state_t.
+ *
+ * Must be allocated with XXH32_createState().
+ * @param statePtr A pointer to an @ref XXH32_state_t allocated with @ref XXH32_createState().
+ * @return XXH_OK.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr);
+/*!
+ * @brief Copies one @ref XXH32_state_t to another.
+ *
+ * @param dst_state The state to copy to.
+ * @param src_state The state to copy from.
+ * @pre
+ * @p dst_state and @p src_state must not be `NULL` and must not overlap.
+ */
+XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dst_state, const XXH32_state_t* src_state);
+
+/*!
+ * @brief Resets an @ref XXH32_state_t to begin a new hash.
+ *
+ * This function resets and seeds a state. Call it before @ref XXH32_update().
+ *
+ * @param statePtr The state struct to reset.
+ * @param seed The 32-bit seed to alter the hash result predictably.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success, @ref XXH_ERROR on failure.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH32_reset (XXH32_state_t* statePtr, XXH32_hash_t seed);
+
+/*!
+ * @brief Consumes a block of @p input to an @ref XXH32_state_t.
+ *
+ * Call this to incrementally consume blocks of data.
+ *
+ * @param statePtr The state struct to update.
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return @ref XXH_OK on success, @ref XXH_ERROR on failure.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH32_update (XXH32_state_t* statePtr, const void* input, size_t length);
+
+/*!
+ * @brief Returns the calculated hash value from an @ref XXH32_state_t.
+ *
+ * @note
+ * Calling XXH32_digest() will not affect @p statePtr, so you can update,
+ * digest, and update again.
+ *
+ * @param statePtr The state struct to calculate the hash from.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return The calculated xxHash32 value from that state.
+ */
+XXH_PUBLIC_API XXH32_hash_t XXH32_digest (const XXH32_state_t* statePtr);
+
+/******* Canonical representation *******/
+
+/*
+ * The default return values from XXH functions are unsigned 32 and 64 bit
+ * integers.
+ * This the simplest and fastest format for further post-processing.
+ *
+ * However, this leaves open the question of what is the order on the byte level,
+ * since little and big endian conventions will store the same number differently.
+ *
+ * The canonical representation settles this issue by mandating big-endian
+ * convention, the same convention as human-readable numbers (large digits first).
+ *
+ * When writing hash values to storage, sending them over a network, or printing
+ * them, it's highly recommended to use the canonical representation to ensure
+ * portability across a wider range of systems, present and future.
+ *
+ * The following functions allow transformation of hash values to and from
+ * canonical format.
+ */
+
+/*!
+ * @brief Canonical (big endian) representation of @ref XXH32_hash_t.
+ */
+typedef struct {
+ unsigned char digest[4]; /*!< Hash bytes, big endian */
+} XXH32_canonical_t;
+
+/*!
+ * @brief Converts an @ref XXH32_hash_t to a big endian @ref XXH32_canonical_t.
+ *
+ * @param dst The @ref XXH32_canonical_t pointer to be stored to.
+ * @param hash The @ref XXH32_hash_t to be converted.
+ *
+ * @pre
+ * @p dst must not be `NULL`.
+ */
+XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash);
+
+/*!
+ * @brief Converts an @ref XXH32_canonical_t to a native @ref XXH32_hash_t.
+ *
+ * @param src The @ref XXH32_canonical_t to convert.
+ *
+ * @pre
+ * @p src must not be `NULL`.
+ *
+ * @return The converted hash.
+ */
+XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src);
+
+
+/*!
+ * @}
+ * @ingroup public
+ * @{
+ */
+
+#ifndef XXH_NO_LONG_LONG
+/*-**********************************************************************
+* 64-bit hash
+************************************************************************/
+#if defined(XXH_DOXYGEN) /* don't include <stdint.h> */
+/*!
+ * @brief An unsigned 64-bit integer.
+ *
+ * Not necessarily defined to `uint64_t` but functionally equivalent.
+ */
+typedef uint64_t XXH64_hash_t;
+#elif !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint64_t XXH64_hash_t;
+#else
+# include <limits.h>
+# if defined(__LP64__) && ULONG_MAX == 0xFFFFFFFFFFFFFFFFULL
+ /* LP64 ABI says uint64_t is unsigned long */
+ typedef unsigned long XXH64_hash_t;
+# else
+ /* the following type must have a width of 64-bit */
+ typedef unsigned long long XXH64_hash_t;
+# endif
+#endif
+
+/*!
+ * @}
+ *
+ * @defgroup xxh64_family XXH64 family
+ * @ingroup public
+ * @{
+ * Contains functions used in the classic 64-bit xxHash algorithm.
+ *
+ * @note
+ * XXH3 provides competitive speed for both 32-bit and 64-bit systems,
+ * and offers true 64/128 bit hash results. It provides a superior level of
+ * dispersion, and greatly reduces the risks of collisions.
+ */
+
+
+/*!
+ * @brief Calculates the 64-bit hash of @p input using xxHash64.
+ *
+ * This function usually runs faster on 64-bit systems, but slower on 32-bit
+ * systems (see benchmark).
+ *
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ * @param seed The 64-bit seed to alter the hash's output predictably.
+ *
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return The calculated 64-bit hash.
+ *
+ * @see
+ * XXH32(), XXH3_64bits_withSeed(), XXH3_128bits_withSeed(), XXH128():
+ * Direct equivalents for the other variants of xxHash.
+ * @see
+ * XXH64_createState(), XXH64_update(), XXH64_digest(): Streaming version.
+ */
+XXH_PUBLIC_API XXH64_hash_t XXH64(const void* input, size_t length, XXH64_hash_t seed);
+
+/******* Streaming *******/
+/*!
+ * @brief The opaque state struct for the XXH64 streaming API.
+ *
+ * @see XXH64_state_s for details.
+ */
+typedef struct XXH64_state_s XXH64_state_t; /* incomplete type */
+XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr);
+XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dst_state, const XXH64_state_t* src_state);
+
+XXH_PUBLIC_API XXH_errorcode XXH64_reset (XXH64_state_t* statePtr, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH_errorcode XXH64_update (XXH64_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH64_hash_t XXH64_digest (const XXH64_state_t* statePtr);
+
+/******* Canonical representation *******/
+typedef struct { unsigned char digest[sizeof(XXH64_hash_t)]; } XXH64_canonical_t;
+XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash);
+XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src);
+
+/*!
+ * @}
+ * ************************************************************************
+ * @defgroup xxh3_family XXH3 family
+ * @ingroup public
+ * @{
+ *
+ * XXH3 is a more recent hash algorithm featuring:
+ * - Improved speed for both small and large inputs
+ * - True 64-bit and 128-bit outputs
+ * - SIMD acceleration
+ * - Improved 32-bit viability
+ *
+ * Speed analysis methodology is explained here:
+ *
+ * https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html
+ *
+ * Compared to XXH64, expect XXH3 to run approximately
+ * ~2x faster on large inputs and >3x faster on small ones,
+ * exact differences vary depending on platform.
+ *
+ * XXH3's speed benefits greatly from SIMD and 64-bit arithmetic,
+ * but does not require it.
+ * Any 32-bit and 64-bit targets that can run XXH32 smoothly
+ * can run XXH3 at competitive speeds, even without vector support.
+ * Further details are explained in the implementation.
+ *
+ * Optimized implementations are provided for AVX512, AVX2, SSE2, NEON, POWER8,
+ * ZVector and scalar targets. This can be controlled via the XXH_VECTOR macro.
+ *
+ * XXH3 implementation is portable:
+ * it has a generic C90 formulation that can be compiled on any platform,
+ * all implementations generage exactly the same hash value on all platforms.
+ * Starting from v0.8.0, it's also labelled "stable", meaning that
+ * any future version will also generate the same hash value.
+ *
+ * XXH3 offers 2 variants, _64bits and _128bits.
+ *
+ * When only 64 bits are needed, prefer invoking the _64bits variant, as it
+ * reduces the amount of mixing, resulting in faster speed on small inputs.
+ * It's also generally simpler to manipulate a scalar return type than a struct.
+ *
+ * The API supports one-shot hashing, streaming mode, and custom secrets.
+ */
+
+/*-**********************************************************************
+* XXH3 64-bit variant
+************************************************************************/
+
+/* XXH3_64bits():
+ * default 64-bit variant, using default secret and default seed of 0.
+ * It's the fastest variant. */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* data, size_t len);
+
+/*
+ * XXH3_64bits_withSeed():
+ * This variant generates a custom secret on the fly
+ * based on default secret altered using the `seed` value.
+ * While this operation is decently fast, note that it's not completely free.
+ * Note: seed==0 produces the same results as XXH3_64bits().
+ */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSeed(const void* data, size_t len, XXH64_hash_t seed);
+
+/*!
+ * The bare minimum size for a custom secret.
+ *
+ * @see
+ * XXH3_64bits_withSecret(), XXH3_64bits_reset_withSecret(),
+ * XXH3_128bits_withSecret(), XXH3_128bits_reset_withSecret().
+ */
+#define XXH3_SECRET_SIZE_MIN 136
+
+/*
+ * XXH3_64bits_withSecret():
+ * It's possible to provide any blob of bytes as a "secret" to generate the hash.
+ * This makes it more difficult for an external actor to prepare an intentional collision.
+ * The main condition is that secretSize *must* be large enough (>= XXH3_SECRET_SIZE_MIN).
+ * However, the quality of produced hash values depends on secret's entropy.
+ * Technically, the secret must look like a bunch of random bytes.
+ * Avoid "trivial" or structured data such as repeated sequences or a text document.
+ * Whenever unsure about the "randomness" of the blob of bytes,
+ * consider relabelling it as a "custom seed" instead,
+ * and employ "XXH3_generateSecret()" (see below)
+ * to generate a high entropy secret derived from the custom seed.
+ */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize);
+
+
+/******* Streaming *******/
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ */
+
+/*!
+ * @brief The state struct for the XXH3 streaming API.
+ *
+ * @see XXH3_state_s for details.
+ */
+typedef struct XXH3_state_s XXH3_state_t;
+XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr);
+XXH_PUBLIC_API void XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state);
+
+/*
+ * XXH3_64bits_reset():
+ * Initialize with default parameters.
+ * digest will be equivalent to `XXH3_64bits()`.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset(XXH3_state_t* statePtr);
+/*
+ * XXH3_64bits_reset_withSeed():
+ * Generate a custom secret from `seed`, and store it into `statePtr`.
+ * digest will be equivalent to `XXH3_64bits_withSeed()`.
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed);
+/*
+ * XXH3_64bits_reset_withSecret():
+ * `secret` is referenced, it _must outlive_ the hash streaming session.
+ * Similar to one-shot API, `secretSize` must be >= `XXH3_SECRET_SIZE_MIN`,
+ * and the quality of produced hash values depends on secret's entropy
+ * (secret's content should look like a bunch of random bytes).
+ * When in doubt about the randomness of a candidate `secret`,
+ * consider employing `XXH3_generateSecret()` instead (see below).
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize);
+
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_update (XXH3_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* statePtr);
+
+/* note : canonical representation of XXH3 is the same as XXH64
+ * since they both produce XXH64_hash_t values */
+
+
+/*-**********************************************************************
+* XXH3 128-bit variant
+************************************************************************/
+
+/*!
+ * @brief The return value from 128-bit hashes.
+ *
+ * Stored in little endian order, although the fields themselves are in native
+ * endianness.
+ */
+typedef struct {
+ XXH64_hash_t low64; /*!< `value & 0xFFFFFFFFFFFFFFFF` */
+ XXH64_hash_t high64; /*!< `value >> 64` */
+} XXH128_hash_t;
+
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* data, size_t len);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSeed(const void* data, size_t len, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize);
+
+/******* Streaming *******/
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ *
+ * XXH3_128bits uses the same XXH3_state_t as XXH3_64bits().
+ * Use already declared XXH3_createState() and XXH3_freeState().
+ *
+ * All reset and streaming functions have same meaning as their 64-bit counterpart.
+ */
+
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset(XXH3_state_t* statePtr);
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed);
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize);
+
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_update (XXH3_state_t* statePtr, const void* input, size_t length);
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* statePtr);
+
+/* Following helper functions make it possible to compare XXH128_hast_t values.
+ * Since XXH128_hash_t is a structure, this capability is not offered by the language.
+ * Note: For better performance, these functions can be inlined using XXH_INLINE_ALL */
+
+/*!
+ * XXH128_isEqual():
+ * Return: 1 if `h1` and `h2` are equal, 0 if they are not.
+ */
+XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2);
+
+/*!
+ * XXH128_cmp():
+ *
+ * This comparator is compatible with stdlib's `qsort()`/`bsearch()`.
+ *
+ * return: >0 if *h128_1 > *h128_2
+ * =0 if *h128_1 == *h128_2
+ * <0 if *h128_1 < *h128_2
+ */
+XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2);
+
+
+/******* Canonical representation *******/
+typedef struct { unsigned char digest[sizeof(XXH128_hash_t)]; } XXH128_canonical_t;
+XXH_PUBLIC_API void XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash);
+XXH_PUBLIC_API XXH128_hash_t XXH128_hashFromCanonical(const XXH128_canonical_t* src);
+
+
+#endif /* XXH_NO_LONG_LONG */
+
+/*!
+ * @}
+ */
+#endif /* XXHASH_H_5627135585666179 */
+
+
+
+#if defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742)
+#define XXHASH_H_STATIC_13879238742
+/* ****************************************************************************
+ * This section contains declarations which are not guaranteed to remain stable.
+ * They may change in future versions, becoming incompatible with a different
+ * version of the library.
+ * These declarations should only be used with static linking.
+ * Never use them in association with dynamic linking!
+ ***************************************************************************** */
+
+/*
+ * These definitions are only present to allow static allocation
+ * of XXH states, on stack or in a struct, for example.
+ * Never **ever** access their members directly.
+ */
+
+/*!
+ * @internal
+ * @brief Structure for XXH32 streaming API.
+ *
+ * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY,
+ * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is
+ * an opaque type. This allows fields to safely be changed.
+ *
+ * Typedef'd to @ref XXH32_state_t.
+ * Do not access the members of this struct directly.
+ * @see XXH64_state_s, XXH3_state_s
+ */
+struct XXH32_state_s {
+ XXH32_hash_t total_len_32; /*!< Total length hashed, modulo 2^32 */
+ XXH32_hash_t large_len; /*!< Whether the hash is >= 16 (handles @ref total_len_32 overflow) */
+ XXH32_hash_t v1; /*!< First accumulator lane */
+ XXH32_hash_t v2; /*!< Second accumulator lane */
+ XXH32_hash_t v3; /*!< Third accumulator lane */
+ XXH32_hash_t v4; /*!< Fourth accumulator lane */
+ XXH32_hash_t mem32[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[16]. */
+ XXH32_hash_t memsize; /*!< Amount of data in @ref mem32 */
+ XXH32_hash_t reserved; /*!< Reserved field. Do not read or write to it, it may be removed. */
+}; /* typedef'd to XXH32_state_t */
+
+
+#ifndef XXH_NO_LONG_LONG /* defined when there is no 64-bit support */
+
+/*!
+ * @internal
+ * @brief Structure for XXH64 streaming API.
+ *
+ * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY,
+ * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is
+ * an opaque type. This allows fields to safely be changed.
+ *
+ * Typedef'd to @ref XXH64_state_t.
+ * Do not access the members of this struct directly.
+ * @see XXH32_state_s, XXH3_state_s
+ */
+struct XXH64_state_s {
+ XXH64_hash_t total_len; /*!< Total length hashed. This is always 64-bit. */
+ XXH64_hash_t v1; /*!< First accumulator lane */
+ XXH64_hash_t v2; /*!< Second accumulator lane */
+ XXH64_hash_t v3; /*!< Third accumulator lane */
+ XXH64_hash_t v4; /*!< Fourth accumulator lane */
+ XXH64_hash_t mem64[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[32]. */
+ XXH32_hash_t memsize; /*!< Amount of data in @ref mem64 */
+ XXH32_hash_t reserved32; /*!< Reserved field, needed for padding anyways*/
+ XXH64_hash_t reserved64; /*!< Reserved field. Do not read or write to it, it may be removed. */
+}; /* typedef'd to XXH64_state_t */
+
+#if defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11+ */
+# include <stdalign.h>
+# define XXH_ALIGN(n) alignas(n)
+#elif defined(__GNUC__)
+# define XXH_ALIGN(n) __attribute__ ((aligned(n)))
+#elif defined(_MSC_VER)
+# define XXH_ALIGN(n) __declspec(align(n))
+#else
+# define XXH_ALIGN(n) /* disabled */
+#endif
+
+/* Old GCC versions only accept the attribute after the type in structures. */
+#if !(defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) /* C11+ */ \
+ && defined(__GNUC__)
+# define XXH_ALIGN_MEMBER(align, type) type XXH_ALIGN(align)
+#else
+# define XXH_ALIGN_MEMBER(align, type) XXH_ALIGN(align) type
+#endif
+
+/*!
+ * @brief The size of the internal XXH3 buffer.
+ *
+ * This is the optimal update size for incremental hashing.
+ *
+ * @see XXH3_64b_update(), XXH3_128b_update().
+ */
+#define XXH3_INTERNALBUFFER_SIZE 256
+
+/*!
+ * @brief Default size of the secret buffer (and @ref XXH3_kSecret).
+ *
+ * This is the size used in @ref XXH3_kSecret and the seeded functions.
+ *
+ * Not to be confused with @ref XXH3_SECRET_SIZE_MIN.
+ */
+#define XXH3_SECRET_DEFAULT_SIZE 192
+
+/*!
+ * @internal
+ * @brief Structure for XXH3 streaming API.
+ *
+ * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY,
+ * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is
+ * an opaque type. This allows fields to safely be changed.
+ *
+ * @note **This structure has a strict alignment requirement of 64 bytes.** Do
+ * not allocate this with `malloc()` or `new`, it will not be sufficiently
+ * aligned. Use @ref XXH3_createState() and @ref XXH3_freeState(), or stack
+ * allocation.
+ *
+ * Typedef'd to @ref XXH3_state_t.
+ * Do not access the members of this struct directly.
+ *
+ * @see XXH3_INITSTATE() for stack initialization.
+ * @see XXH3_createState(), XXH3_freeState().
+ * @see XXH32_state_s, XXH64_state_s
+ */
+struct XXH3_state_s {
+ XXH_ALIGN_MEMBER(64, XXH64_hash_t acc[8]);
+ /*!< The 8 accumulators. Similar to `vN` in @ref XXH32_state_s::v1 and @ref XXH64_state_s */
+ XXH_ALIGN_MEMBER(64, unsigned char customSecret[XXH3_SECRET_DEFAULT_SIZE]);
+ /*!< Used to store a custom secret generated from a seed. */
+ XXH_ALIGN_MEMBER(64, unsigned char buffer[XXH3_INTERNALBUFFER_SIZE]);
+ /*!< The internal buffer. @see XXH32_state_s::mem32 */
+ XXH32_hash_t bufferedSize;
+ /*!< The amount of memory in @ref buffer, @see XXH32_state_s::memsize */
+ XXH32_hash_t reserved32;
+ /*!< Reserved field. Needed for padding on 64-bit. */
+ size_t nbStripesSoFar;
+ /*!< Number or stripes processed. */
+ XXH64_hash_t totalLen;
+ /*!< Total length hashed. 64-bit even on 32-bit targets. */
+ size_t nbStripesPerBlock;
+ /*!< Number of stripes per block. */
+ size_t secretLimit;
+ /*!< Size of @ref customSecret or @ref extSecret */
+ XXH64_hash_t seed;
+ /*!< Seed for _withSeed variants. Must be zero otherwise, @see XXH3_INITSTATE() */
+ XXH64_hash_t reserved64;
+ /*!< Reserved field. */
+ const unsigned char* extSecret;
+ /*!< Reference to an external secret for the _withSecret variants, NULL
+ * for other variants. */
+ /* note: there may be some padding at the end due to alignment on 64 bytes */
+}; /* typedef'd to XXH3_state_t */
+
+#undef XXH_ALIGN_MEMBER
+
+/*!
+ * @brief Initializes a stack-allocated `XXH3_state_s`.
+ *
+ * When the @ref XXH3_state_t structure is merely emplaced on stack,
+ * it should be initialized with XXH3_INITSTATE() or a memset()
+ * in case its first reset uses XXH3_NNbits_reset_withSeed().
+ * This init can be omitted if the first reset uses default or _withSecret mode.
+ * This operation isn't necessary when the state is created with XXH3_createState().
+ * Note that this doesn't prepare the state for a streaming operation,
+ * it's still necessary to use XXH3_NNbits_reset*() afterwards.
+ */
+#define XXH3_INITSTATE(XXH3_state_ptr) { (XXH3_state_ptr)->seed = 0; }
+
+
+/* === Experimental API === */
+/* Symbols defined below must be considered tied to a specific library version. */
+
+/*
+ * XXH3_generateSecret():
+ *
+ * Derive a high-entropy secret from any user-defined content, named customSeed.
+ * The generated secret can be used in combination with `*_withSecret()` functions.
+ * The `_withSecret()` variants are useful to provide a higher level of protection than 64-bit seed,
+ * as it becomes much more difficult for an external actor to guess how to impact the calculation logic.
+ *
+ * The function accepts as input a custom seed of any length and any content,
+ * and derives from it a high-entropy secret of length XXH3_SECRET_DEFAULT_SIZE
+ * into an already allocated buffer secretBuffer.
+ * The generated secret is _always_ XXH_SECRET_DEFAULT_SIZE bytes long.
+ *
+ * The generated secret can then be used with any `*_withSecret()` variant.
+ * Functions `XXH3_128bits_withSecret()`, `XXH3_64bits_withSecret()`,
+ * `XXH3_128bits_reset_withSecret()` and `XXH3_64bits_reset_withSecret()`
+ * are part of this list. They all accept a `secret` parameter
+ * which must be very long for implementation reasons (>= XXH3_SECRET_SIZE_MIN)
+ * _and_ feature very high entropy (consist of random-looking bytes).
+ * These conditions can be a high bar to meet, so
+ * this function can be used to generate a secret of proper quality.
+ *
+ * customSeed can be anything. It can have any size, even small ones,
+ * and its content can be anything, even stupidly "low entropy" source such as a bunch of zeroes.
+ * The resulting `secret` will nonetheless provide all expected qualities.
+ *
+ * Supplying NULL as the customSeed copies the default secret into `secretBuffer`.
+ * When customSeedSize > 0, supplying NULL as customSeed is undefined behavior.
+ */
+XXH_PUBLIC_API void XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize);
+
+
+/* simple short-cut to pre-selected XXH3_128bits variant */
+XXH_PUBLIC_API XXH128_hash_t XXH128(const void* data, size_t len, XXH64_hash_t seed);
+
+
+#endif /* XXH_NO_LONG_LONG */
+#if defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)
+# define XXH_IMPLEMENTATION
+#endif
+
+#endif /* defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) */
+
+
+/* ======================================================================== */
+/* ======================================================================== */
+/* ======================================================================== */
+
+
+/*-**********************************************************************
+ * xxHash implementation
+ *-**********************************************************************
+ * xxHash's implementation used to be hosted inside xxhash.c.
+ *
+ * However, inlining requires implementation to be visible to the compiler,
+ * hence be included alongside the header.
+ * Previously, implementation was hosted inside xxhash.c,
+ * which was then #included when inlining was activated.
+ * This construction created issues with a few build and install systems,
+ * as it required xxhash.c to be stored in /include directory.
+ *
+ * xxHash implementation is now directly integrated within xxhash.h.
+ * As a consequence, xxhash.c is no longer needed in /include.
+ *
+ * xxhash.c is still available and is still useful.
+ * In a "normal" setup, when xxhash is not inlined,
+ * xxhash.h only exposes the prototypes and public symbols,
+ * while xxhash.c can be built into an object file xxhash.o
+ * which can then be linked into the final binary.
+ ************************************************************************/
+
+#if ( defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) \
+ || defined(XXH_IMPLEMENTATION) ) && !defined(XXH_IMPLEM_13a8737387)
+# define XXH_IMPLEM_13a8737387
+
+/* *************************************
+* Tuning parameters
+***************************************/
+
+/*!
+ * @defgroup tuning Tuning parameters
+ * @{
+ *
+ * Various macros to control xxHash's behavior.
+ */
+#ifdef XXH_DOXYGEN
+/*!
+ * @brief Define this to disable 64-bit code.
+ *
+ * Useful if only using the @ref xxh32_family and you have a strict C90 compiler.
+ */
+# define XXH_NO_LONG_LONG
+# undef XXH_NO_LONG_LONG /* don't actually */
+/*!
+ * @brief Controls how unaligned memory is accessed.
+ *
+ * By default, access to unaligned memory is controlled by `memcpy()`, which is
+ * safe and portable.
+ *
+ * Unfortunately, on some target/compiler combinations, the generated assembly
+ * is sub-optimal.
+ *
+ * The below switch allow selection of a different access method
+ * in the search for improved performance.
+ *
+ * @par Possible options:
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=0` (default): `memcpy`
+ * @par
+ * Use `memcpy()`. Safe and portable. Note that most modern compilers will
+ * eliminate the function call and treat it as an unaligned access.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=1`: `__attribute__((packed))`
+ * @par
+ * Depends on compiler extensions and is therefore not portable.
+ * This method is safe _if_ your compiler supports it,
+ * and *generally* as fast or faster than `memcpy`.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=2`: Direct cast
+ * @par
+ * Casts directly and dereferences. This method doesn't depend on the
+ * compiler, but it violates the C standard as it directly dereferences an
+ * unaligned pointer. It can generate buggy code on targets which do not
+ * support unaligned memory accesses, but in some circumstances, it's the
+ * only known way to get the most performance.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=3`: Byteshift
+ * @par
+ * Also portable. This can generate the best code on old compilers which don't
+ * inline small `memcpy()` calls, and it might also be faster on big-endian
+ * systems which lack a native byteswap instruction. However, some compilers
+ * will emit literal byteshifts even if the target supports unaligned access.
+ * .
+ *
+ * @warning
+ * Methods 1 and 2 rely on implementation-defined behavior. Use these with
+ * care, as what works on one compiler/platform/optimization level may cause
+ * another to read garbage data or even crash.
+ *
+ * See https://stackoverflow.com/a/32095106/646947 for details.
+ *
+ * Prefer these methods in priority order (0 > 3 > 1 > 2)
+ */
+# define XXH_FORCE_MEMORY_ACCESS 0
+/*!
+ * @def XXH_ACCEPT_NULL_INPUT_POINTER
+ * @brief Whether to add explicit `NULL` checks.
+ *
+ * If the input pointer is `NULL` and the length is non-zero, xxHash's default
+ * behavior is to dereference it, triggering a segfault.
+ *
+ * When this macro is enabled, xxHash actively checks the input for a null pointer.
+ * If it is, the result for null input pointers is the same as a zero-length input.
+ */
+# define XXH_ACCEPT_NULL_INPUT_POINTER 0
+/*!
+ * @def XXH_FORCE_ALIGN_CHECK
+ * @brief If defined to non-zero, adds a special path for aligned inputs (XXH32()
+ * and XXH64() only).
+ *
+ * This is an important performance trick for architectures without decent
+ * unaligned memory access performance.
+ *
+ * It checks for input alignment, and when conditions are met, uses a "fast
+ * path" employing direct 32-bit/64-bit reads, resulting in _dramatically
+ * faster_ read speed.
+ *
+ * The check costs one initial branch per hash, which is generally negligible,
+ * but not zero.
+ *
+ * Moreover, it's not useful to generate an additional code path if memory
+ * access uses the same instruction for both aligned and unaligned
+ * addresses (e.g. x86 and aarch64).
+ *
+ * In these cases, the alignment check can be removed by setting this macro to 0.
+ * Then the code will always use unaligned memory access.
+ * Align check is automatically disabled on x86, x64 & arm64,
+ * which are platforms known to offer good unaligned memory accesses performance.
+ *
+ * This option does not affect XXH3 (only XXH32 and XXH64).
+ */
+# define XXH_FORCE_ALIGN_CHECK 0
+
+/*!
+ * @def XXH_NO_INLINE_HINTS
+ * @brief When non-zero, sets all functions to `static`.
+ *
+ * By default, xxHash tries to force the compiler to inline almost all internal
+ * functions.
+ *
+ * This can usually improve performance due to reduced jumping and improved
+ * constant folding, but significantly increases the size of the binary which
+ * might not be favorable.
+ *
+ * Additionally, sometimes the forced inlining can be detrimental to performance,
+ * depending on the architecture.
+ *
+ * XXH_NO_INLINE_HINTS marks all internal functions as static, giving the
+ * compiler full control on whether to inline or not.
+ *
+ * When not optimizing (-O0), optimizing for size (-Os, -Oz), or using
+ * -fno-inline with GCC or Clang, this will automatically be defined.
+ */
+# define XXH_NO_INLINE_HINTS 0
+
+/*!
+ * @def XXH_REROLL
+ * @brief Whether to reroll `XXH32_finalize` and `XXH64_finalize`.
+ *
+ * For performance, `XXH32_finalize` and `XXH64_finalize` use an unrolled loop
+ * in the form of a switch statement.
+ *
+ * This is not always desirable, as it generates larger code, and depending on
+ * the architecture, may even be slower
+ *
+ * This is automatically defined with `-Os`/`-Oz` on GCC and Clang.
+ */
+# define XXH_REROLL 0
+
+/*!
+ * @internal
+ * @brief Redefines old internal names.
+ *
+ * For compatibility with code that uses xxHash's internals before the names
+ * were changed to improve namespacing. There is no other reason to use this.
+ */
+# define XXH_OLD_NAMES
+# undef XXH_OLD_NAMES /* don't actually use, it is ugly. */
+#endif /* XXH_DOXYGEN */
+/*!
+ * @}
+ */
+
+#ifndef XXH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */
+ /* prefer __packed__ structures (method 1) for gcc on armv7 and armv8 */
+# if !defined(__clang__) && ( \
+ (defined(__INTEL_COMPILER) && !defined(_WIN32)) || \
+ (defined(__GNUC__) && (defined(__ARM_ARCH) && __ARM_ARCH >= 7)) )
+# define XXH_FORCE_MEMORY_ACCESS 1
+# endif
+#endif
+
+#ifndef XXH_ACCEPT_NULL_INPUT_POINTER /* can be defined externally */
+# define XXH_ACCEPT_NULL_INPUT_POINTER 0
+#endif
+
+#ifndef XXH_FORCE_ALIGN_CHECK /* can be defined externally */
+# if defined(__i386) || defined(__x86_64__) || defined(__aarch64__) \
+ || defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM64) /* visual */
+# define XXH_FORCE_ALIGN_CHECK 0
+# else
+# define XXH_FORCE_ALIGN_CHECK 1
+# endif
+#endif
+
+#ifndef XXH_NO_INLINE_HINTS
+# if defined(__OPTIMIZE_SIZE__) /* -Os, -Oz */ \
+ || defined(__NO_INLINE__) /* -O0, -fno-inline */
+# define XXH_NO_INLINE_HINTS 1
+# else
+# define XXH_NO_INLINE_HINTS 0
+# endif
+#endif
+
+#ifndef XXH_REROLL
+# if defined(__OPTIMIZE_SIZE__)
+# define XXH_REROLL 1
+# else
+# define XXH_REROLL 0
+# endif
+#endif
+
+/*!
+ * @defgroup impl Implementation
+ * @{
+ */
+
+
+/* *************************************
+* Includes & Memory related functions
+***************************************/
+/*
+ * Modify the local functions below should you wish to use
+ * different memory routines for malloc() and free()
+ */
+#include <stdlib.h>
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than malloc().
+ */
+static void* XXH_malloc(size_t s) { return malloc(s); }
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than free().
+ */
+static void XXH_free(void* p) { free(p); }
+
+#include <string.h>
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than memcpy().
+ */
+static void* XXH_memcpy(void* dest, const void* src, size_t size)
+{
+ return memcpy(dest,src,size);
+}
+
+#include <limits.h> /* ULLONG_MAX */
+
+
+/* *************************************
+* Compiler Specific Options
+***************************************/
+#ifdef _MSC_VER /* Visual Studio warning fix */
+# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
+#endif
+
+#if XXH_NO_INLINE_HINTS /* disable inlining hints */
+# if defined(__GNUC__)
+# define XXH_FORCE_INLINE static __attribute__((unused))
+# else
+# define XXH_FORCE_INLINE static
+# endif
+# define XXH_NO_INLINE static
+/* enable inlining hints */
+#elif defined(_MSC_VER) /* Visual Studio */
+# define XXH_FORCE_INLINE static __forceinline
+# define XXH_NO_INLINE static __declspec(noinline)
+#elif defined(__GNUC__)
+# define XXH_FORCE_INLINE static __inline__ __attribute__((always_inline, unused))
+# define XXH_NO_INLINE static __attribute__((noinline))
+#elif defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) /* C99 */
+# define XXH_FORCE_INLINE static inline
+# define XXH_NO_INLINE static
+#else
+# define XXH_FORCE_INLINE static
+# define XXH_NO_INLINE static
+#endif
+
+
+
+/* *************************************
+* Debug
+***************************************/
+/*!
+ * @ingroup tuning
+ * @def XXH_DEBUGLEVEL
+ * @brief Sets the debugging level.
+ *
+ * XXH_DEBUGLEVEL is expected to be defined externally, typically via the
+ * compiler's command line options. The value must be a number.
+ */
+#ifndef XXH_DEBUGLEVEL
+# ifdef DEBUGLEVEL /* backwards compat */
+# define XXH_DEBUGLEVEL DEBUGLEVEL
+# else
+# define XXH_DEBUGLEVEL 0
+# endif
+#endif
+
+#if (XXH_DEBUGLEVEL>=1)
+# include <assert.h> /* note: can still be disabled with NDEBUG */
+# define XXH_ASSERT(c) assert(c)
+#else
+# define XXH_ASSERT(c) ((void)0)
+#endif
+
+/* note: use after variable declarations */
+#define XXH_STATIC_ASSERT(c) do { enum { XXH_sa = 1/(int)(!!(c)) }; } while (0)
+
+/*!
+ * @internal
+ * @def XXH_COMPILER_GUARD(var)
+ * @brief Used to prevent unwanted optimizations for @p var.
+ *
+ * It uses an empty GCC inline assembly statement with a register constraint
+ * which forces @p var into a general purpose register (eg eax, ebx, ecx
+ * on x86) and marks it as modified.
+ *
+ * This is used in a few places to avoid unwanted autovectorization (e.g.
+ * XXH32_round()). All vectorization we want is explicit via intrinsics,
+ * and _usually_ isn't wanted elsewhere.
+ *
+ * We also use it to prevent unwanted constant folding for AArch64 in
+ * XXH3_initCustomSecret_scalar().
+ */
+#ifdef __GNUC__
+# define XXH_COMPILER_GUARD(var) __asm__ __volatile__("" : "+r" (var))
+#else
+# define XXH_COMPILER_GUARD(var) ((void)0)
+#endif
+
+/* *************************************
+* Basic Types
+***************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint8_t xxh_u8;
+#else
+ typedef unsigned char xxh_u8;
+#endif
+typedef XXH32_hash_t xxh_u32;
+
+#ifdef XXH_OLD_NAMES
+# define BYTE xxh_u8
+# define U8 xxh_u8
+# define U32 xxh_u32
+#endif
+
+/* *** Memory access *** */
+
+/*!
+ * @internal
+ * @fn xxh_u32 XXH_read32(const void* ptr)
+ * @brief Reads an unaligned 32-bit integer from @p ptr in native endianness.
+ *
+ * Affected by @ref XXH_FORCE_MEMORY_ACCESS.
+ *
+ * @param ptr The pointer to read from.
+ * @return The 32-bit native endian integer from the bytes at @p ptr.
+ */
+
+/*!
+ * @internal
+ * @fn xxh_u32 XXH_readLE32(const void* ptr)
+ * @brief Reads an unaligned 32-bit little endian integer from @p ptr.
+ *
+ * Affected by @ref XXH_FORCE_MEMORY_ACCESS.
+ *
+ * @param ptr The pointer to read from.
+ * @return The 32-bit little endian integer from the bytes at @p ptr.
+ */
+
+/*!
+ * @internal
+ * @fn xxh_u32 XXH_readBE32(const void* ptr)
+ * @brief Reads an unaligned 32-bit big endian integer from @p ptr.
+ *
+ * Affected by @ref XXH_FORCE_MEMORY_ACCESS.
+ *
+ * @param ptr The pointer to read from.
+ * @return The 32-bit big endian integer from the bytes at @p ptr.
+ */
+
+/*!
+ * @internal
+ * @fn xxh_u32 XXH_readLE32_align(const void* ptr, XXH_alignment align)
+ * @brief Like @ref XXH_readLE32(), but has an option for aligned reads.
+ *
+ * Affected by @ref XXH_FORCE_MEMORY_ACCESS.
+ * Note that when @ref XXH_FORCE_ALIGN_CHECK == 0, the @p align parameter is
+ * always @ref XXH_alignment::XXH_unaligned.
+ *
+ * @param ptr The pointer to read from.
+ * @param align Whether @p ptr is aligned.
+ * @pre
+ * If @p align == @ref XXH_alignment::XXH_aligned, @p ptr must be 4 byte
+ * aligned.
+ * @return The 32-bit little endian integer from the bytes at @p ptr.
+ */
+
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+/*
+ * Manual byteshift. Best for old compilers which don't inline memcpy.
+ * We actually directly use XXH_readLE32 and XXH_readBE32.
+ */
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2))
+
+/*
+ * Force direct memory access. Only works on CPU which support unaligned memory
+ * access in hardware.
+ */
+static xxh_u32 XXH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; }
+
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1))
+
+/*
+ * __pack instructions are safer but compiler specific, hence potentially
+ * problematic for some compilers.
+ *
+ * Currently only defined for GCC and ICC.
+ */
+#ifdef XXH_OLD_NAMES
+typedef union { xxh_u32 u32; } __attribute__((packed)) unalign;
+#endif
+static xxh_u32 XXH_read32(const void* ptr)
+{
+ typedef union { xxh_u32 u32; } __attribute__((packed)) xxh_unalign;
+ return ((const xxh_unalign*)ptr)->u32;
+}
+
+#else
+
+/*
+ * Portable and safe solution. Generally efficient.
+ * see: https://stackoverflow.com/a/32095106/646947
+ */
+static xxh_u32 XXH_read32(const void* memPtr)
+{
+ xxh_u32 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */
+
+
+/* *** Endianness *** */
+/*!
+ * @ingroup tuning
+ * @def XXH_CPU_LITTLE_ENDIAN
+ * @brief Whether the target is little endian.
+ *
+ * Defined to 1 if the target is little endian, or 0 if it is big endian.
+ * It can be defined externally, for example on the compiler command line.
+ *
+ * If it is not defined, a runtime check (which is usually constant folded)
+ * is used instead.
+ *
+ * @note
+ * This is not necessarily defined to an integer constant.
+ *
+ * @see XXH_isLittleEndian() for the runtime check.
+ */
+#ifndef XXH_CPU_LITTLE_ENDIAN
+/*
+ * Try to detect endianness automatically, to avoid the nonstandard behavior
+ * in `XXH_isLittleEndian()`
+ */
+# if defined(_WIN32) /* Windows is always little endian */ \
+ || defined(__LITTLE_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+# define XXH_CPU_LITTLE_ENDIAN 1
+# elif defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXH_CPU_LITTLE_ENDIAN 0
+# else
+/*!
+ * @internal
+ * @brief Runtime check for @ref XXH_CPU_LITTLE_ENDIAN.
+ *
+ * Most compilers will constant fold this.
+ */
+static int XXH_isLittleEndian(void)
+{
+ /*
+ * Portable and well-defined behavior.
+ * Don't use static: it is detrimental to performance.
+ */
+ const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 };
+ return one.c[0];
+}
+# define XXH_CPU_LITTLE_ENDIAN XXH_isLittleEndian()
+# endif
+#endif
+
+
+
+
+/* ****************************************
+* Compiler-specific Functions and Macros
+******************************************/
+#define XXH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
+
+#ifdef __has_builtin
+# define XXH_HAS_BUILTIN(x) __has_builtin(x)
+#else
+# define XXH_HAS_BUILTIN(x) 0
+#endif
+
+/*!
+ * @internal
+ * @def XXH_rotl32(x,r)
+ * @brief 32-bit rotate left.
+ *
+ * @param x The 32-bit integer to be rotated.
+ * @param r The number of bits to rotate.
+ * @pre
+ * @p r > 0 && @p r < 32
+ * @note
+ * @p x and @p r may be evaluated multiple times.
+ * @return The rotated result.
+ */
+#if !defined(NO_CLANG_BUILTIN) && XXH_HAS_BUILTIN(__builtin_rotateleft32) \
+ && XXH_HAS_BUILTIN(__builtin_rotateleft64)
+# define XXH_rotl32 __builtin_rotateleft32
+# define XXH_rotl64 __builtin_rotateleft64
+/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */
+#elif defined(_MSC_VER)
+# define XXH_rotl32(x,r) _rotl(x,r)
+# define XXH_rotl64(x,r) _rotl64(x,r)
+#else
+# define XXH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r))))
+# define XXH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))
+#endif
+
+/*!
+ * @internal
+ * @fn xxh_u32 XXH_swap32(xxh_u32 x)
+ * @brief A 32-bit byteswap.
+ *
+ * @param x The 32-bit integer to byteswap.
+ * @return @p x, byteswapped.
+ */
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXH_swap32 _byteswap_ulong
+#elif XXH_GCC_VERSION >= 403
+# define XXH_swap32 __builtin_bswap32
+#else
+static xxh_u32 XXH_swap32 (xxh_u32 x)
+{
+ return ((x << 24) & 0xff000000 ) |
+ ((x << 8) & 0x00ff0000 ) |
+ ((x >> 8) & 0x0000ff00 ) |
+ ((x >> 24) & 0x000000ff );
+}
+#endif
+
+
+/* ***************************
+* Memory reads
+*****************************/
+
+/*!
+ * @internal
+ * @brief Enum to indicate whether a pointer is aligned.
+ */
+typedef enum {
+ XXH_aligned, /*!< Aligned */
+ XXH_unaligned /*!< Possibly unaligned */
+} XXH_alignment;
+
+/*
+ * XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load.
+ *
+ * This is ideal for older compilers which don't inline memcpy.
+ */
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+
+XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[0]
+ | ((xxh_u32)bytePtr[1] << 8)
+ | ((xxh_u32)bytePtr[2] << 16)
+ | ((xxh_u32)bytePtr[3] << 24);
+}
+
+XXH_FORCE_INLINE xxh_u32 XXH_readBE32(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[3]
+ | ((xxh_u32)bytePtr[2] << 8)
+ | ((xxh_u32)bytePtr[1] << 16)
+ | ((xxh_u32)bytePtr[0] << 24);
+}
+
+#else
+XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_read32(ptr) : XXH_swap32(XXH_read32(ptr));
+}
+
+static xxh_u32 XXH_readBE32(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_swap32(XXH_read32(ptr)) : XXH_read32(ptr);
+}
+#endif
+
+XXH_FORCE_INLINE xxh_u32
+XXH_readLE32_align(const void* ptr, XXH_alignment align)
+{
+ if (align==XXH_unaligned) {
+ return XXH_readLE32(ptr);
+ } else {
+ return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXH_swap32(*(const xxh_u32*)ptr);
+ }
+}
+
+
+/* *************************************
+* Misc
+***************************************/
+/*! @ingroup public */
+XXH_PUBLIC_API unsigned XXH_versionNumber (void) { return XXH_VERSION_NUMBER; }
+
+
+/* *******************************************************************
+* 32-bit hash functions
+*********************************************************************/
+/*!
+ * @}
+ * @defgroup xxh32_impl XXH32 implementation
+ * @ingroup impl
+ * @{
+ */
+ /* #define instead of static const, to be used as initializers */
+#define XXH_PRIME32_1 0x9E3779B1U /*!< 0b10011110001101110111100110110001 */
+#define XXH_PRIME32_2 0x85EBCA77U /*!< 0b10000101111010111100101001110111 */
+#define XXH_PRIME32_3 0xC2B2AE3DU /*!< 0b11000010101100101010111000111101 */
+#define XXH_PRIME32_4 0x27D4EB2FU /*!< 0b00100111110101001110101100101111 */
+#define XXH_PRIME32_5 0x165667B1U /*!< 0b00010110010101100110011110110001 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME32_1 XXH_PRIME32_1
+# define PRIME32_2 XXH_PRIME32_2
+# define PRIME32_3 XXH_PRIME32_3
+# define PRIME32_4 XXH_PRIME32_4
+# define PRIME32_5 XXH_PRIME32_5
+#endif
+
+/*!
+ * @internal
+ * @brief Normal stripe processing routine.
+ *
+ * This shuffles the bits so that any bit from @p input impacts several bits in
+ * @p acc.
+ *
+ * @param acc The accumulator lane.
+ * @param input The stripe of input to mix.
+ * @return The mixed accumulator lane.
+ */
+static xxh_u32 XXH32_round(xxh_u32 acc, xxh_u32 input)
+{
+ acc += input * XXH_PRIME32_2;
+ acc = XXH_rotl32(acc, 13);
+ acc *= XXH_PRIME32_1;
+#if (defined(__SSE4_1__) || defined(__aarch64__)) && !defined(XXH_ENABLE_AUTOVECTORIZE)
+ /*
+ * UGLY HACK:
+ * A compiler fence is the only thing that prevents GCC and Clang from
+ * autovectorizing the XXH32 loop (pragmas and attributes don't work for some
+ * reason) without globally disabling SSE4.1.
+ *
+ * The reason we want to avoid vectorization is because despite working on
+ * 4 integers at a time, there are multiple factors slowing XXH32 down on
+ * SSE4:
+ * - There's a ridiculous amount of lag from pmulld (10 cycles of latency on
+ * newer chips!) making it slightly slower to multiply four integers at
+ * once compared to four integers independently. Even when pmulld was
+ * fastest, Sandy/Ivy Bridge, it is still not worth it to go into SSE
+ * just to multiply unless doing a long operation.
+ *
+ * - Four instructions are required to rotate,
+ * movqda tmp, v // not required with VEX encoding
+ * pslld tmp, 13 // tmp <<= 13
+ * psrld v, 19 // x >>= 19
+ * por v, tmp // x |= tmp
+ * compared to one for scalar:
+ * roll v, 13 // reliably fast across the board
+ * shldl v, v, 13 // Sandy Bridge and later prefer this for some reason
+ *
+ * - Instruction level parallelism is actually more beneficial here because
+ * the SIMD actually serializes this operation: While v1 is rotating, v2
+ * can load data, while v3 can multiply. SSE forces them to operate
+ * together.
+ *
+ * This is also enabled on AArch64, as Clang autovectorizes it incorrectly
+ * and it is pointless writing a NEON implementation that is basically the
+ * same speed as scalar for XXH32.
+ */
+ XXH_COMPILER_GUARD(acc);
+#endif
+ return acc;
+}
+
+/*!
+ * @internal
+ * @brief Mixes all bits to finalize the hash.
+ *
+ * The final mix ensures that all input bits have a chance to impact any bit in
+ * the output digest, resulting in an unbiased distribution.
+ *
+ * @param h32 The hash to avalanche.
+ * @return The avalanched hash.
+ */
+static xxh_u32 XXH32_avalanche(xxh_u32 h32)
+{
+ h32 ^= h32 >> 15;
+ h32 *= XXH_PRIME32_2;
+ h32 ^= h32 >> 13;
+ h32 *= XXH_PRIME32_3;
+ h32 ^= h32 >> 16;
+ return(h32);
+}
+
+#define XXH_get32bits(p) XXH_readLE32_align(p, align)
+
+/*!
+ * @internal
+ * @brief Processes the last 0-15 bytes of @p ptr.
+ *
+ * There may be up to 15 bytes remaining to consume from the input.
+ * This final stage will digest them to ensure that all input bytes are present
+ * in the final mix.
+ *
+ * @param h32 The hash to finalize.
+ * @param ptr The pointer to the remaining input.
+ * @param len The remaining length, modulo 16.
+ * @param align Whether @p ptr is aligned.
+ * @return The finalized hash.
+ */
+static xxh_u32
+XXH32_finalize(xxh_u32 h32, const xxh_u8* ptr, size_t len, XXH_alignment align)
+{
+#define XXH_PROCESS1 do { \
+ h32 += (*ptr++) * XXH_PRIME32_5; \
+ h32 = XXH_rotl32(h32, 11) * XXH_PRIME32_1; \
+} while (0)
+
+#define XXH_PROCESS4 do { \
+ h32 += XXH_get32bits(ptr) * XXH_PRIME32_3; \
+ ptr += 4; \
+ h32 = XXH_rotl32(h32, 17) * XXH_PRIME32_4; \
+} while (0)
+
+ /* Compact rerolled version */
+ if (XXH_REROLL) {
+ len &= 15;
+ while (len >= 4) {
+ XXH_PROCESS4;
+ len -= 4;
+ }
+ while (len > 0) {
+ XXH_PROCESS1;
+ --len;
+ }
+ return XXH32_avalanche(h32);
+ } else {
+ switch(len&15) /* or switch(bEnd - p) */ {
+ case 12: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 8: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 4: XXH_PROCESS4;
+ return XXH32_avalanche(h32);
+
+ case 13: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 9: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 5: XXH_PROCESS4;
+ XXH_PROCESS1;
+ return XXH32_avalanche(h32);
+
+ case 14: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 10: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 6: XXH_PROCESS4;
+ XXH_PROCESS1;
+ XXH_PROCESS1;
+ return XXH32_avalanche(h32);
+
+ case 15: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 11: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 7: XXH_PROCESS4;
+ FALLTHROUGH_INTENDED;
+ case 3: XXH_PROCESS1;
+ FALLTHROUGH_INTENDED;
+ case 2: XXH_PROCESS1;
+ FALLTHROUGH_INTENDED;
+ case 1: XXH_PROCESS1;
+ FALLTHROUGH_INTENDED;
+ case 0: return XXH32_avalanche(h32);
+ }
+ XXH_ASSERT(0);
+ return h32; /* reaching this point is deemed impossible */
+ }
+}
+
+#ifdef XXH_OLD_NAMES
+# define PROCESS1 XXH_PROCESS1
+# define PROCESS4 XXH_PROCESS4
+#else
+# undef XXH_PROCESS1
+# undef XXH_PROCESS4
+#endif
+
+/*!
+ * @internal
+ * @brief The implementation for @ref XXH32().
+ *
+ * @param input, len, seed Directly passed from @ref XXH32().
+ * @param align Whether @p input is aligned.
+ * @return The calculated hash.
+ */
+XXH_FORCE_INLINE xxh_u32
+XXH32_endian_align(const xxh_u8* input, size_t len, xxh_u32 seed, XXH_alignment align)
+{
+ const xxh_u8* bEnd = input ? input + len : NULL;
+ xxh_u32 h32;
+
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ if (input==NULL) {
+ len=0;
+ bEnd=input=(const xxh_u8*)(size_t)16;
+ }
+#endif
+
+ if (len>=16) {
+ const xxh_u8* const limit = bEnd - 15;
+ xxh_u32 v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2;
+ xxh_u32 v2 = seed + XXH_PRIME32_2;
+ xxh_u32 v3 = seed + 0;
+ xxh_u32 v4 = seed - XXH_PRIME32_1;
+
+ do {
+ v1 = XXH32_round(v1, XXH_get32bits(input)); input += 4;
+ v2 = XXH32_round(v2, XXH_get32bits(input)); input += 4;
+ v3 = XXH32_round(v3, XXH_get32bits(input)); input += 4;
+ v4 = XXH32_round(v4, XXH_get32bits(input)); input += 4;
+ } while (input < limit);
+
+ h32 = XXH_rotl32(v1, 1) + XXH_rotl32(v2, 7)
+ + XXH_rotl32(v3, 12) + XXH_rotl32(v4, 18);
+ } else {
+ h32 = seed + XXH_PRIME32_5;
+ }
+
+ h32 += (xxh_u32)len;
+
+ return XXH32_finalize(h32, input, len&15, align);
+}
+
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t len, XXH32_hash_t seed)
+{
+#if 0
+ /* Simple version, good for code maintenance, but unfortunately slow for small inputs */
+ XXH32_state_t state;
+ XXH32_reset(&state, seed);
+ XXH32_update(&state, (const xxh_u8*)input, len);
+ return XXH32_digest(&state);
+#else
+ if (XXH_FORCE_ALIGN_CHECK) {
+ if ((((size_t)input) & 3) == 0) { /* Input is 4-bytes aligned, leverage the speed benefit */
+ return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_aligned);
+ } }
+
+ return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned);
+#endif
+}
+
+
+
+/******* Hash streaming *******/
+/*!
+ * @ingroup xxh32_family
+ */
+XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void)
+{
+ return (XXH32_state_t*)XXH_malloc(sizeof(XXH32_state_t));
+}
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr)
+{
+ XXH_free(statePtr);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dstState, const XXH32_state_t* srcState)
+{
+ memcpy(dstState, srcState, sizeof(*dstState));
+}
+
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH_errorcode XXH32_reset(XXH32_state_t* statePtr, XXH32_hash_t seed)
+{
+ XXH32_state_t state; /* using a local state to memcpy() in order to avoid strict-aliasing warnings */
+ memset(&state, 0, sizeof(state));
+ state.v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2;
+ state.v2 = seed + XXH_PRIME32_2;
+ state.v3 = seed + 0;
+ state.v4 = seed - XXH_PRIME32_1;
+ /* do not write into reserved, planned to be removed in a future version */
+ memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved));
+ return XXH_OK;
+}
+
+
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH32_update(XXH32_state_t* state, const void* input, size_t len)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* p = (const xxh_u8*)input;
+ const xxh_u8* const bEnd = p + len;
+
+ state->total_len_32 += (XXH32_hash_t)len;
+ state->large_len |= (XXH32_hash_t)((len>=16) | (state->total_len_32>=16));
+
+ if (state->memsize + len < 16) { /* fill in tmp buffer */
+ XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, len);
+ state->memsize += (XXH32_hash_t)len;
+ return XXH_OK;
+ }
+
+ if (state->memsize) { /* some data left from previous update */
+ XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, 16-state->memsize);
+ { const xxh_u32* p32 = state->mem32;
+ state->v1 = XXH32_round(state->v1, XXH_readLE32(p32)); p32++;
+ state->v2 = XXH32_round(state->v2, XXH_readLE32(p32)); p32++;
+ state->v3 = XXH32_round(state->v3, XXH_readLE32(p32)); p32++;
+ state->v4 = XXH32_round(state->v4, XXH_readLE32(p32));
+ }
+ p += 16-state->memsize;
+ state->memsize = 0;
+ }
+
+ /* uintptr_t casts avoid UB or compiler warning on out-of-bounds
+ * pointer arithmetic */
+ if ((uintptr_t)p <= (uintptr_t)bEnd - 16) {
+ const uintptr_t limit = (uintptr_t)bEnd - 16;
+ xxh_u32 v1 = state->v1;
+ xxh_u32 v2 = state->v2;
+ xxh_u32 v3 = state->v3;
+ xxh_u32 v4 = state->v4;
+
+ do {
+ v1 = XXH32_round(v1, XXH_readLE32(p)); p+=4;
+ v2 = XXH32_round(v2, XXH_readLE32(p)); p+=4;
+ v3 = XXH32_round(v3, XXH_readLE32(p)); p+=4;
+ v4 = XXH32_round(v4, XXH_readLE32(p)); p+=4;
+ } while ((uintptr_t)p<=limit);
+
+ state->v1 = v1;
+ state->v2 = v2;
+ state->v3 = v3;
+ state->v4 = v4;
+ }
+
+ if (p < bEnd) {
+ XXH_memcpy(state->mem32, p, (size_t)(bEnd-p));
+ state->memsize = (unsigned)(bEnd-p);
+ }
+ }
+
+ return XXH_OK;
+}
+
+
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH32_hash_t XXH32_digest(const XXH32_state_t* state)
+{
+ xxh_u32 h32;
+
+ if (state->large_len) {
+ h32 = XXH_rotl32(state->v1, 1)
+ + XXH_rotl32(state->v2, 7)
+ + XXH_rotl32(state->v3, 12)
+ + XXH_rotl32(state->v4, 18);
+ } else {
+ h32 = state->v3 /* == seed */ + XXH_PRIME32_5;
+ }
+
+ h32 += state->total_len_32;
+
+ return XXH32_finalize(h32, (const xxh_u8*)state->mem32, state->memsize, XXH_aligned);
+}
+
+
+/******* Canonical representation *******/
+
+/*!
+ * @ingroup xxh32_family
+ * The default return values from XXH functions are unsigned 32 and 64 bit
+ * integers.
+ *
+ * The canonical representation uses big endian convention, the same convention
+ * as human-readable numbers (large digits first).
+ *
+ * This way, hash values can be written into a file or buffer, remaining
+ * comparable across different systems.
+ *
+ * The following functions allow transformation of hash values to and from their
+ * canonical format.
+ */
+XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH32_canonical_t) == sizeof(XXH32_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap32(hash);
+ memcpy(dst, &hash, sizeof(*dst));
+}
+/*! @ingroup xxh32_family */
+XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src)
+{
+ return XXH_readBE32(src);
+}
+
+
+#ifndef XXH_NO_LONG_LONG
+
+/* *******************************************************************
+* 64-bit hash functions
+*********************************************************************/
+/*!
+ * @}
+ * @ingroup impl
+ * @{
+ */
+/******* Memory access *******/
+
+typedef XXH64_hash_t xxh_u64;
+
+#ifdef XXH_OLD_NAMES
+# define U64 xxh_u64
+#endif
+
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+/*
+ * Manual byteshift. Best for old compilers which don't inline memcpy.
+ * We actually directly use XXH_readLE64 and XXH_readBE64.
+ */
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2))
+
+/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */
+static xxh_u64 XXH_read64(const void* memPtr)
+{
+ return *(const xxh_u64*) memPtr;
+}
+
+#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1))
+
+/*
+ * __pack instructions are safer, but compiler specific, hence potentially
+ * problematic for some compilers.
+ *
+ * Currently only defined for GCC and ICC.
+ */
+#ifdef XXH_OLD_NAMES
+typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) unalign64;
+#endif
+static xxh_u64 XXH_read64(const void* ptr)
+{
+ typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) xxh_unalign64;
+ return ((const xxh_unalign64*)ptr)->u64;
+}
+
+#else
+
+/*
+ * Portable and safe solution. Generally efficient.
+ * see: https://stackoverflow.com/a/32095106/646947
+ */
+static xxh_u64 XXH_read64(const void* memPtr)
+{
+ xxh_u64 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */
+
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXH_swap64 _byteswap_uint64
+#elif XXH_GCC_VERSION >= 403
+# define XXH_swap64 __builtin_bswap64
+#else
+static xxh_u64 XXH_swap64(xxh_u64 x)
+{
+ return ((x << 56) & 0xff00000000000000ULL) |
+ ((x << 40) & 0x00ff000000000000ULL) |
+ ((x << 24) & 0x0000ff0000000000ULL) |
+ ((x << 8) & 0x000000ff00000000ULL) |
+ ((x >> 8) & 0x00000000ff000000ULL) |
+ ((x >> 24) & 0x0000000000ff0000ULL) |
+ ((x >> 40) & 0x000000000000ff00ULL) |
+ ((x >> 56) & 0x00000000000000ffULL);
+}
+#endif
+
+
+/* XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. */
+#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3))
+
+XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[0]
+ | ((xxh_u64)bytePtr[1] << 8)
+ | ((xxh_u64)bytePtr[2] << 16)
+ | ((xxh_u64)bytePtr[3] << 24)
+ | ((xxh_u64)bytePtr[4] << 32)
+ | ((xxh_u64)bytePtr[5] << 40)
+ | ((xxh_u64)bytePtr[6] << 48)
+ | ((xxh_u64)bytePtr[7] << 56);
+}
+
+XXH_FORCE_INLINE xxh_u64 XXH_readBE64(const void* memPtr)
+{
+ const xxh_u8* bytePtr = (const xxh_u8 *)memPtr;
+ return bytePtr[7]
+ | ((xxh_u64)bytePtr[6] << 8)
+ | ((xxh_u64)bytePtr[5] << 16)
+ | ((xxh_u64)bytePtr[4] << 24)
+ | ((xxh_u64)bytePtr[3] << 32)
+ | ((xxh_u64)bytePtr[2] << 40)
+ | ((xxh_u64)bytePtr[1] << 48)
+ | ((xxh_u64)bytePtr[0] << 56);
+}
+
+#else
+XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_read64(ptr) : XXH_swap64(XXH_read64(ptr));
+}
+
+static xxh_u64 XXH_readBE64(const void* ptr)
+{
+ return XXH_CPU_LITTLE_ENDIAN ? XXH_swap64(XXH_read64(ptr)) : XXH_read64(ptr);
+}
+#endif
+
+XXH_FORCE_INLINE xxh_u64
+XXH_readLE64_align(const void* ptr, XXH_alignment align)
+{
+ if (align==XXH_unaligned)
+ return XXH_readLE64(ptr);
+ else
+ return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXH_swap64(*(const xxh_u64*)ptr);
+}
+
+
+/******* xxh64 *******/
+/*!
+ * @}
+ * @defgroup xxh64_impl XXH64 implementation
+ * @ingroup impl
+ * @{
+ */
+/* #define rather that static const, to be used as initializers */
+#define XXH_PRIME64_1 0x9E3779B185EBCA87ULL /*!< 0b1001111000110111011110011011000110000101111010111100101010000111 */
+#define XXH_PRIME64_2 0xC2B2AE3D27D4EB4FULL /*!< 0b1100001010110010101011100011110100100111110101001110101101001111 */
+#define XXH_PRIME64_3 0x165667B19E3779F9ULL /*!< 0b0001011001010110011001111011000110011110001101110111100111111001 */
+#define XXH_PRIME64_4 0x85EBCA77C2B2AE63ULL /*!< 0b1000010111101011110010100111011111000010101100101010111001100011 */
+#define XXH_PRIME64_5 0x27D4EB2F165667C5ULL /*!< 0b0010011111010100111010110010111100010110010101100110011111000101 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME64_1 XXH_PRIME64_1
+# define PRIME64_2 XXH_PRIME64_2
+# define PRIME64_3 XXH_PRIME64_3
+# define PRIME64_4 XXH_PRIME64_4
+# define PRIME64_5 XXH_PRIME64_5
+#endif
+
+static xxh_u64 XXH64_round(xxh_u64 acc, xxh_u64 input)
+{
+ acc += input * XXH_PRIME64_2;
+ acc = XXH_rotl64(acc, 31);
+ acc *= XXH_PRIME64_1;
+ return acc;
+}
+
+static xxh_u64 XXH64_mergeRound(xxh_u64 acc, xxh_u64 val)
+{
+ val = XXH64_round(0, val);
+ acc ^= val;
+ acc = acc * XXH_PRIME64_1 + XXH_PRIME64_4;
+ return acc;
+}
+
+static xxh_u64 XXH64_avalanche(xxh_u64 h64)
+{
+ h64 ^= h64 >> 33;
+ h64 *= XXH_PRIME64_2;
+ h64 ^= h64 >> 29;
+ h64 *= XXH_PRIME64_3;
+ h64 ^= h64 >> 32;
+ return h64;
+}
+
+
+#define XXH_get64bits(p) XXH_readLE64_align(p, align)
+
+static xxh_u64
+XXH64_finalize(xxh_u64 h64, const xxh_u8* ptr, size_t len, XXH_alignment align)
+{
+ len &= 31;
+ while (len >= 8) {
+ xxh_u64 const k1 = XXH64_round(0, XXH_get64bits(ptr));
+ ptr += 8;
+ h64 ^= k1;
+ h64 = XXH_rotl64(h64,27) * XXH_PRIME64_1 + XXH_PRIME64_4;
+ len -= 8;
+ }
+ if (len >= 4) {
+ h64 ^= (xxh_u64)(XXH_get32bits(ptr)) * XXH_PRIME64_1;
+ ptr += 4;
+ h64 = XXH_rotl64(h64, 23) * XXH_PRIME64_2 + XXH_PRIME64_3;
+ len -= 4;
+ }
+ while (len > 0) {
+ h64 ^= (*ptr++) * XXH_PRIME64_5;
+ h64 = XXH_rotl64(h64, 11) * XXH_PRIME64_1;
+ --len;
+ }
+ return XXH64_avalanche(h64);
+}
+
+#ifdef XXH_OLD_NAMES
+# define PROCESS1_64 XXH_PROCESS1_64
+# define PROCESS4_64 XXH_PROCESS4_64
+# define PROCESS8_64 XXH_PROCESS8_64
+#else
+# undef XXH_PROCESS1_64
+# undef XXH_PROCESS4_64
+# undef XXH_PROCESS8_64
+#endif
+
+XXH_FORCE_INLINE xxh_u64
+XXH64_endian_align(const xxh_u8* input, size_t len, xxh_u64 seed, XXH_alignment align)
+{
+ const xxh_u8* bEnd = input ? input + len : NULL;
+ xxh_u64 h64;
+
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ if (input==NULL) {
+ len=0;
+ bEnd=input=(const xxh_u8*)(size_t)32;
+ }
+#endif
+
+ if (len>=32) {
+ const xxh_u8* const limit = bEnd - 32;
+ xxh_u64 v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2;
+ xxh_u64 v2 = seed + XXH_PRIME64_2;
+ xxh_u64 v3 = seed + 0;
+ xxh_u64 v4 = seed - XXH_PRIME64_1;
+
+ do {
+ v1 = XXH64_round(v1, XXH_get64bits(input)); input+=8;
+ v2 = XXH64_round(v2, XXH_get64bits(input)); input+=8;
+ v3 = XXH64_round(v3, XXH_get64bits(input)); input+=8;
+ v4 = XXH64_round(v4, XXH_get64bits(input)); input+=8;
+ } while (input<=limit);
+
+ h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18);
+ h64 = XXH64_mergeRound(h64, v1);
+ h64 = XXH64_mergeRound(h64, v2);
+ h64 = XXH64_mergeRound(h64, v3);
+ h64 = XXH64_mergeRound(h64, v4);
+
+ } else {
+ h64 = seed + XXH_PRIME64_5;
+ }
+
+ h64 += (xxh_u64) len;
+
+ return XXH64_finalize(h64, input, len, align);
+}
+
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH64_hash_t XXH64 (const void* input, size_t len, XXH64_hash_t seed)
+{
+#if 0
+ /* Simple version, good for code maintenance, but unfortunately slow for small inputs */
+ XXH64_state_t state;
+ XXH64_reset(&state, seed);
+ XXH64_update(&state, (const xxh_u8*)input, len);
+ return XXH64_digest(&state);
+#else
+ if (XXH_FORCE_ALIGN_CHECK) {
+ if ((((size_t)input) & 7)==0) { /* Input is aligned, let's leverage the speed advantage */
+ return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_aligned);
+ } }
+
+ return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned);
+
+#endif
+}
+
+/******* Hash Streaming *******/
+
+/*! @ingroup xxh64_family*/
+XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void)
+{
+ return (XXH64_state_t*)XXH_malloc(sizeof(XXH64_state_t));
+}
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr)
+{
+ XXH_free(statePtr);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dstState, const XXH64_state_t* srcState)
+{
+ memcpy(dstState, srcState, sizeof(*dstState));
+}
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH_errorcode XXH64_reset(XXH64_state_t* statePtr, XXH64_hash_t seed)
+{
+ XXH64_state_t state; /* use a local state to memcpy() in order to avoid strict-aliasing warnings */
+ memset(&state, 0, sizeof(state));
+ state.v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2;
+ state.v2 = seed + XXH_PRIME64_2;
+ state.v3 = seed + 0;
+ state.v4 = seed - XXH_PRIME64_1;
+ /* do not write into reserved64, might be removed in a future version */
+ memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved64));
+ return XXH_OK;
+}
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH64_update (XXH64_state_t* state, const void* input, size_t len)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* p = (const xxh_u8*)input;
+ const xxh_u8* const bEnd = p + len;
+
+ state->total_len += len;
+
+ if (state->memsize + len < 32) { /* fill in tmp buffer */
+ XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, len);
+ state->memsize += (xxh_u32)len;
+ return XXH_OK;
+ }
+
+ if (state->memsize) { /* tmp buffer is full */
+ XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, 32-state->memsize);
+ state->v1 = XXH64_round(state->v1, XXH_readLE64(state->mem64+0));
+ state->v2 = XXH64_round(state->v2, XXH_readLE64(state->mem64+1));
+ state->v3 = XXH64_round(state->v3, XXH_readLE64(state->mem64+2));
+ state->v4 = XXH64_round(state->v4, XXH_readLE64(state->mem64+3));
+ p += 32 - state->memsize;
+ state->memsize = 0;
+ }
+
+ /* uintptr_t casts avoid UB or compiler warning on out-of-bounds
+ * pointer arithmetic */
+ if ((uintptr_t)p + 32 <= (uintptr_t)bEnd) {
+ const uintptr_t limit = (uintptr_t)bEnd - 32;
+ xxh_u64 v1 = state->v1;
+ xxh_u64 v2 = state->v2;
+ xxh_u64 v3 = state->v3;
+ xxh_u64 v4 = state->v4;
+
+ do {
+ v1 = XXH64_round(v1, XXH_readLE64(p)); p+=8;
+ v2 = XXH64_round(v2, XXH_readLE64(p)); p+=8;
+ v3 = XXH64_round(v3, XXH_readLE64(p)); p+=8;
+ v4 = XXH64_round(v4, XXH_readLE64(p)); p+=8;
+ } while ((uintptr_t)p<=limit);
+
+ state->v1 = v1;
+ state->v2 = v2;
+ state->v3 = v3;
+ state->v4 = v4;
+ }
+
+ if (p < bEnd) {
+ XXH_memcpy(state->mem64, p, (size_t)(bEnd-p));
+ state->memsize = (unsigned)(bEnd-p);
+ }
+ }
+
+ return XXH_OK;
+}
+
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH64_hash_t XXH64_digest(const XXH64_state_t* state)
+{
+ xxh_u64 h64;
+
+ if (state->total_len >= 32) {
+ xxh_u64 const v1 = state->v1;
+ xxh_u64 const v2 = state->v2;
+ xxh_u64 const v3 = state->v3;
+ xxh_u64 const v4 = state->v4;
+
+ h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18);
+ h64 = XXH64_mergeRound(h64, v1);
+ h64 = XXH64_mergeRound(h64, v2);
+ h64 = XXH64_mergeRound(h64, v3);
+ h64 = XXH64_mergeRound(h64, v4);
+ } else {
+ h64 = state->v3 /*seed*/ + XXH_PRIME64_5;
+ }
+
+ h64 += (xxh_u64) state->total_len;
+
+ return XXH64_finalize(h64, (const xxh_u8*)state->mem64, (size_t)state->total_len, XXH_aligned);
+}
+
+
+/******* Canonical representation *******/
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH64_canonical_t) == sizeof(XXH64_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap64(hash);
+ memcpy(dst, &hash, sizeof(*dst));
+}
+
+/*! @ingroup xxh64_family */
+XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src)
+{
+ return XXH_readBE64(src);
+}
+
+#ifndef XXH_NO_XXH3
+
+/* *********************************************************************
+* XXH3
+* New generation hash designed for speed on small keys and vectorization
+************************************************************************ */
+/*!
+ * @}
+ * @defgroup xxh3_impl XXH3 implementation
+ * @ingroup impl
+ * @{
+ */
+
+/* === Compiler specifics === */
+
+#if ((defined(sun) || defined(__sun)) && __cplusplus) /* Solaris includes __STDC_VERSION__ with C++. Tested with GCC 5.5 */
+# define XXH_RESTRICT /* disable */
+#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */
+# define XXH_RESTRICT restrict
+#else
+/* Note: it might be useful to define __restrict or __restrict__ for some C++ compilers */
+# define XXH_RESTRICT /* disable */
+#endif
+
+#if (defined(__GNUC__) && (__GNUC__ >= 3)) \
+ || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) \
+ || defined(__clang__)
+# define XXH_likely(x) __builtin_expect(x, 1)
+# define XXH_unlikely(x) __builtin_expect(x, 0)
+#else
+# define XXH_likely(x) (x)
+# define XXH_unlikely(x) (x)
+#endif
+
+#if defined(__GNUC__)
+# if defined(__AVX2__)
+# include <immintrin.h>
+# elif defined(__SSE2__)
+# include <emmintrin.h>
+# elif defined(__ARM_NEON__) || defined(__ARM_NEON)
+# define inline __inline__ /* circumvent a clang bug */
+# include <arm_neon.h>
+# undef inline
+# endif
+#elif defined(_MSC_VER)
+# include <intrin.h>
+#endif
+
+/*
+ * One goal of XXH3 is to make it fast on both 32-bit and 64-bit, while
+ * remaining a true 64-bit/128-bit hash function.
+ *
+ * This is done by prioritizing a subset of 64-bit operations that can be
+ * emulated without too many steps on the average 32-bit machine.
+ *
+ * For example, these two lines seem similar, and run equally fast on 64-bit:
+ *
+ * xxh_u64 x;
+ * x ^= (x >> 47); // good
+ * x ^= (x >> 13); // bad
+ *
+ * However, to a 32-bit machine, there is a major difference.
+ *
+ * x ^= (x >> 47) looks like this:
+ *
+ * x.lo ^= (x.hi >> (47 - 32));
+ *
+ * while x ^= (x >> 13) looks like this:
+ *
+ * // note: funnel shifts are not usually cheap.
+ * x.lo ^= (x.lo >> 13) | (x.hi << (32 - 13));
+ * x.hi ^= (x.hi >> 13);
+ *
+ * The first one is significantly faster than the second, simply because the
+ * shift is larger than 32. This means:
+ * - All the bits we need are in the upper 32 bits, so we can ignore the lower
+ * 32 bits in the shift.
+ * - The shift result will always fit in the lower 32 bits, and therefore,
+ * we can ignore the upper 32 bits in the xor.
+ *
+ * Thanks to this optimization, XXH3 only requires these features to be efficient:
+ *
+ * - Usable unaligned access
+ * - A 32-bit or 64-bit ALU
+ * - If 32-bit, a decent ADC instruction
+ * - A 32 or 64-bit multiply with a 64-bit result
+ * - For the 128-bit variant, a decent byteswap helps short inputs.
+ *
+ * The first two are already required by XXH32, and almost all 32-bit and 64-bit
+ * platforms which can run XXH32 can run XXH3 efficiently.
+ *
+ * Thumb-1, the classic 16-bit only subset of ARM's instruction set, is one
+ * notable exception.
+ *
+ * First of all, Thumb-1 lacks support for the UMULL instruction which
+ * performs the important long multiply. This means numerous __aeabi_lmul
+ * calls.
+ *
+ * Second of all, the 8 functional registers are just not enough.
+ * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need
+ * Lo registers, and this shuffling results in thousands more MOVs than A32.
+ *
+ * A32 and T32 don't have this limitation. They can access all 14 registers,
+ * do a 32->64 multiply with UMULL, and the flexible operand allowing free
+ * shifts is helpful, too.
+ *
+ * Therefore, we do a quick sanity check.
+ *
+ * If compiling Thumb-1 for a target which supports ARM instructions, we will
+ * emit a warning, as it is not a "sane" platform to compile for.
+ *
+ * Usually, if this happens, it is because of an accident and you probably need
+ * to specify -march, as you likely meant to compile for a newer architecture.
+ *
+ * Credit: large sections of the vectorial and asm source code paths
+ * have been contributed by @easyaspi314
+ */
+#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM)
+# warning "XXH3 is highly inefficient without ARM or Thumb-2."
+#endif
+
+/* ==========================================
+ * Vectorization detection
+ * ========================================== */
+
+#ifdef XXH_DOXYGEN
+/*!
+ * @ingroup tuning
+ * @brief Overrides the vectorization implementation chosen for XXH3.
+ *
+ * Can be defined to 0 to disable SIMD or any of the values mentioned in
+ * @ref XXH_VECTOR_TYPE.
+ *
+ * If this is not defined, it uses predefined macros to determine the best
+ * implementation.
+ */
+# define XXH_VECTOR XXH_SCALAR
+/*!
+ * @ingroup tuning
+ * @brief Possible values for @ref XXH_VECTOR.
+ *
+ * Note that these are actually implemented as macros.
+ *
+ * If this is not defined, it is detected automatically.
+ * @ref XXH_X86DISPATCH overrides this.
+ */
+enum XXH_VECTOR_TYPE /* fake enum */ {
+ XXH_SCALAR = 0, /*!< Portable scalar version */
+ XXH_SSE2 = 1, /*!<
+ * SSE2 for Pentium 4, Opteron, all x86_64.
+ *
+ * @note SSE2 is also guaranteed on Windows 10, macOS, and
+ * Android x86.
+ */
+ XXH_AVX2 = 2, /*!< AVX2 for Haswell and Bulldozer */
+ XXH_AVX512 = 3, /*!< AVX512 for Skylake and Icelake */
+ XXH_NEON = 4, /*!< NEON for most ARMv7-A and all AArch64 */
+ XXH_VSX = 5, /*!< VSX and ZVector for POWER8/z13 (64-bit) */
+};
+/*!
+ * @ingroup tuning
+ * @brief Selects the minimum alignment for XXH3's accumulators.
+ *
+ * When using SIMD, this should match the alignment reqired for said vector
+ * type, so, for example, 32 for AVX2.
+ *
+ * Default: Auto detected.
+ */
+# define XXH_ACC_ALIGN 8
+#endif
+
+/* Actual definition */
+#ifndef XXH_DOXYGEN
+# define XXH_SCALAR 0
+# define XXH_SSE2 1
+# define XXH_AVX2 2
+# define XXH_AVX512 3
+# define XXH_NEON 4
+# define XXH_VSX 5
+#endif
+
+#ifndef XXH_VECTOR /* can be defined on command line */
+# if defined(__AVX512F__)
+# define XXH_VECTOR XXH_AVX512
+# elif defined(__AVX2__)
+# define XXH_VECTOR XXH_AVX2
+# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2))
+# define XXH_VECTOR XXH_SSE2
+# elif defined(__GNUC__) /* msvc support maybe later */ \
+ && (defined(__ARM_NEON__) || defined(__ARM_NEON)) \
+ && (defined(__LITTLE_ENDIAN__) /* We only support little endian NEON */ \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
+# define XXH_VECTOR XXH_NEON
+# elif (defined(__PPC64__) && defined(__POWER8_VECTOR__)) \
+ || (defined(__s390x__) && defined(__VEC__)) \
+ && defined(__GNUC__) /* TODO: IBM XL */
+# define XXH_VECTOR XXH_VSX
+# else
+# define XXH_VECTOR XXH_SCALAR
+# endif
+#endif
+
+/*
+ * Controls the alignment of the accumulator,
+ * for compatibility with aligned vector loads, which are usually faster.
+ */
+#ifndef XXH_ACC_ALIGN
+# if defined(XXH_X86DISPATCH)
+# define XXH_ACC_ALIGN 64 /* for compatibility with avx512 */
+# elif XXH_VECTOR == XXH_SCALAR /* scalar */
+# define XXH_ACC_ALIGN 8
+# elif XXH_VECTOR == XXH_SSE2 /* sse2 */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_AVX2 /* avx2 */
+# define XXH_ACC_ALIGN 32
+# elif XXH_VECTOR == XXH_NEON /* neon */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_VSX /* vsx */
+# define XXH_ACC_ALIGN 16
+# elif XXH_VECTOR == XXH_AVX512 /* avx512 */
+# define XXH_ACC_ALIGN 64
+# endif
+#endif
+
+#if defined(XXH_X86DISPATCH) || XXH_VECTOR == XXH_SSE2 \
+ || XXH_VECTOR == XXH_AVX2 || XXH_VECTOR == XXH_AVX512
+# define XXH_SEC_ALIGN XXH_ACC_ALIGN
+#else
+# define XXH_SEC_ALIGN 8
+#endif
+
+/*
+ * UGLY HACK:
+ * GCC usually generates the best code with -O3 for xxHash.
+ *
+ * However, when targeting AVX2, it is overzealous in its unrolling resulting
+ * in code roughly 3/4 the speed of Clang.
+ *
+ * There are other issues, such as GCC splitting _mm256_loadu_si256 into
+ * _mm_loadu_si128 + _mm256_inserti128_si256. This is an optimization which
+ * only applies to Sandy and Ivy Bridge... which don't even support AVX2.
+ *
+ * That is why when compiling the AVX2 version, it is recommended to use either
+ * -O2 -mavx2 -march=haswell
+ * or
+ * -O2 -mavx2 -mno-avx256-split-unaligned-load
+ * for decent performance, or to use Clang instead.
+ *
+ * Fortunately, we can control the first one with a pragma that forces GCC into
+ * -O2, but the other one we can't control without "failed to inline always
+ * inline function due to target mismatch" warnings.
+ */
+#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \
+ && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */
+# pragma GCC push_options
+# pragma GCC optimize("-O2")
+#endif
+
+
+#if XXH_VECTOR == XXH_NEON
+/*
+ * NEON's setup for vmlal_u32 is a little more complicated than it is on
+ * SSE2, AVX2, and VSX.
+ *
+ * While PMULUDQ and VMULEUW both perform a mask, VMLAL.U32 performs an upcast.
+ *
+ * To do the same operation, the 128-bit 'Q' register needs to be split into
+ * two 64-bit 'D' registers, performing this operation::
+ *
+ * [ a | b ]
+ * | '---------. .--------' |
+ * | x |
+ * | .---------' '--------. |
+ * [ a & 0xFFFFFFFF | b & 0xFFFFFFFF ],[ a >> 32 | b >> 32 ]
+ *
+ * Due to significant changes in aarch64, the fastest method for aarch64 is
+ * completely different than the fastest method for ARMv7-A.
+ *
+ * ARMv7-A treats D registers as unions overlaying Q registers, so modifying
+ * D11 will modify the high half of Q5. This is similar to how modifying AH
+ * will only affect bits 8-15 of AX on x86.
+ *
+ * VZIP takes two registers, and puts even lanes in one register and odd lanes
+ * in the other.
+ *
+ * On ARMv7-A, this strangely modifies both parameters in place instead of
+ * taking the usual 3-operand form.
+ *
+ * Therefore, if we want to do this, we can simply use a D-form VZIP.32 on the
+ * lower and upper halves of the Q register to end up with the high and low
+ * halves where we want - all in one instruction.
+ *
+ * vzip.32 d10, d11 @ d10 = { d10[0], d11[0] }; d11 = { d10[1], d11[1] }
+ *
+ * Unfortunately we need inline assembly for this: Instructions modifying two
+ * registers at once is not possible in GCC or Clang's IR, and they have to
+ * create a copy.
+ *
+ * aarch64 requires a different approach.
+ *
+ * In order to make it easier to write a decent compiler for aarch64, many
+ * quirks were removed, such as conditional execution.
+ *
+ * NEON was also affected by this.
+ *
+ * aarch64 cannot access the high bits of a Q-form register, and writes to a
+ * D-form register zero the high bits, similar to how writes to W-form scalar
+ * registers (or DWORD registers on x86_64) work.
+ *
+ * The formerly free vget_high intrinsics now require a vext (with a few
+ * exceptions)
+ *
+ * Additionally, VZIP was replaced by ZIP1 and ZIP2, which are the equivalent
+ * of PUNPCKL* and PUNPCKH* in SSE, respectively, in order to only modify one
+ * operand.
+ *
+ * The equivalent of the VZIP.32 on the lower and upper halves would be this
+ * mess:
+ *
+ * ext v2.4s, v0.4s, v0.4s, #2 // v2 = { v0[2], v0[3], v0[0], v0[1] }
+ * zip1 v1.2s, v0.2s, v2.2s // v1 = { v0[0], v2[0] }
+ * zip2 v0.2s, v0.2s, v1.2s // v0 = { v0[1], v2[1] }
+ *
+ * Instead, we use a literal downcast, vmovn_u64 (XTN), and vshrn_n_u64 (SHRN):
+ *
+ * shrn v1.2s, v0.2d, #32 // v1 = (uint32x2_t)(v0 >> 32);
+ * xtn v0.2s, v0.2d // v0 = (uint32x2_t)(v0 & 0xFFFFFFFF);
+ *
+ * This is available on ARMv7-A, but is less efficient than a single VZIP.32.
+ */
+
+/*!
+ * Function-like macro:
+ * void XXH_SPLIT_IN_PLACE(uint64x2_t &in, uint32x2_t &outLo, uint32x2_t &outHi)
+ * {
+ * outLo = (uint32x2_t)(in & 0xFFFFFFFF);
+ * outHi = (uint32x2_t)(in >> 32);
+ * in = UNDEFINED;
+ * }
+ */
+# if !defined(XXH_NO_VZIP_HACK) /* define to disable */ \
+ && defined(__GNUC__) \
+ && !defined(__aarch64__) && !defined(__arm64__)
+# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \
+ do { \
+ /* Undocumented GCC/Clang operand modifier: %e0 = lower D half, %f0 = upper D half */ \
+ /* https://github.com/gcc-mirror/gcc/blob/38cf91e5/gcc/config/arm/arm.c#L22486 */ \
+ /* https://github.com/llvm-mirror/llvm/blob/2c4ca683/lib/Target/ARM/ARMAsmPrinter.cpp#L399 */ \
+ __asm__("vzip.32 %e0, %f0" : "+w" (in)); \
+ (outLo) = vget_low_u32 (vreinterpretq_u32_u64(in)); \
+ (outHi) = vget_high_u32(vreinterpretq_u32_u64(in)); \
+ } while (0)
+# else
+# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \
+ do { \
+ (outLo) = vmovn_u64 (in); \
+ (outHi) = vshrn_n_u64 ((in), 32); \
+ } while (0)
+# endif
+#endif /* XXH_VECTOR == XXH_NEON */
+
+/*
+ * VSX and Z Vector helpers.
+ *
+ * This is very messy, and any pull requests to clean this up are welcome.
+ *
+ * There are a lot of problems with supporting VSX and s390x, due to
+ * inconsistent intrinsics, spotty coverage, and multiple endiannesses.
+ */
+#if XXH_VECTOR == XXH_VSX
+# if defined(__s390x__)
+# include <s390intrin.h>
+# else
+/* gcc's altivec.h can have the unwanted consequence to unconditionally
+ * #define bool, vector, and pixel keywords,
+ * with bad consequences for programs already using these keywords for other purposes.
+ * The paragraph defining these macros is skipped when __APPLE_ALTIVEC__ is defined.
+ * __APPLE_ALTIVEC__ is _generally_ defined automatically by the compiler,
+ * but it seems that, in some cases, it isn't.
+ * Force the build macro to be defined, so that keywords are not altered.
+ */
+# if defined(__GNUC__) && !defined(__APPLE_ALTIVEC__)
+# define __APPLE_ALTIVEC__
+# endif
+# include <altivec.h>
+# endif
+
+typedef __vector unsigned long long xxh_u64x2;
+typedef __vector unsigned char xxh_u8x16;
+typedef __vector unsigned xxh_u32x4;
+
+# ifndef XXH_VSX_BE
+# if defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXH_VSX_BE 1
+# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__
+# warning "-maltivec=be is not recommended. Please use native endianness."
+# define XXH_VSX_BE 1
+# else
+# define XXH_VSX_BE 0
+# endif
+# endif /* !defined(XXH_VSX_BE) */
+
+# if XXH_VSX_BE
+# if defined(__POWER9_VECTOR__) || (defined(__clang__) && defined(__s390x__))
+# define XXH_vec_revb vec_revb
+# else
+/*!
+ * A polyfill for POWER9's vec_revb().
+ */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_revb(xxh_u64x2 val)
+{
+ xxh_u8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00,
+ 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 };
+ return vec_perm(val, val, vByteSwap);
+}
+# endif
+# endif /* XXH_VSX_BE */
+
+/*!
+ * Performs an unaligned vector load and byte swaps it on big endian.
+ */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_loadu(const void *ptr)
+{
+ xxh_u64x2 ret;
+ memcpy(&ret, ptr, sizeof(xxh_u64x2));
+# if XXH_VSX_BE
+ ret = XXH_vec_revb(ret);
+# endif
+ return ret;
+}
+
+/*
+ * vec_mulo and vec_mule are very problematic intrinsics on PowerPC
+ *
+ * These intrinsics weren't added until GCC 8, despite existing for a while,
+ * and they are endian dependent. Also, their meaning swap depending on version.
+ * */
+# if defined(__s390x__)
+ /* s390x is always big endian, no issue on this platform */
+# define XXH_vec_mulo vec_mulo
+# define XXH_vec_mule vec_mule
+# elif defined(__clang__) && XXH_HAS_BUILTIN(__builtin_altivec_vmuleuw)
+/* Clang has a better way to control this, we can just use the builtin which doesn't swap. */
+# define XXH_vec_mulo __builtin_altivec_vmulouw
+# define XXH_vec_mule __builtin_altivec_vmuleuw
+# else
+/* gcc needs inline assembly */
+/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mulo(xxh_u32x4 a, xxh_u32x4 b)
+{
+ xxh_u64x2 result;
+ __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mule(xxh_u32x4 a, xxh_u32x4 b)
+{
+ xxh_u64x2 result;
+ __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+# endif /* XXH_vec_mulo, XXH_vec_mule */
+#endif /* XXH_VECTOR == XXH_VSX */
+
+
+/* prefetch
+ * can be disabled, by declaring XXH_NO_PREFETCH build macro */
+#if defined(XXH_NO_PREFETCH)
+# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */
+#else
+# if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)) /* _mm_prefetch() not defined outside of x86/x64 */
+# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
+# define XXH_PREFETCH(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0)
+# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) )
+# define XXH_PREFETCH(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */)
+# else
+# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */
+# endif
+#endif /* XXH_NO_PREFETCH */
+
+
+/* ==========================================
+ * XXH3 default settings
+ * ========================================== */
+
+#define XXH_SECRET_DEFAULT_SIZE 192 /* minimum XXH3_SECRET_SIZE_MIN */
+
+#if (XXH_SECRET_DEFAULT_SIZE < XXH3_SECRET_SIZE_MIN)
+# error "default keyset is not large enough"
+#endif
+
+/*! Pseudorandom secret taken directly from FARSH. */
+XXH_ALIGN(64) static const xxh_u8 XXH3_kSecret[XXH_SECRET_DEFAULT_SIZE] = {
+ 0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c,
+ 0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f,
+ 0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21,
+ 0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c,
+ 0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3,
+ 0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8,
+ 0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d,
+ 0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64,
+ 0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb,
+ 0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e,
+ 0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce,
+ 0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e,
+};
+
+
+#ifdef XXH_OLD_NAMES
+# define kSecret XXH3_kSecret
+#endif
+
+#ifdef XXH_DOXYGEN
+/*!
+ * @brief Calculates a 32-bit to 64-bit long multiply.
+ *
+ * Implemented as a macro.
+ *
+ * Wraps `__emulu` on MSVC x86 because it tends to call `__allmul` when it doesn't
+ * need to (but it shouldn't need to anyways, it is about 7 instructions to do
+ * a 64x64 multiply...). Since we know that this will _always_ emit `MULL`, we
+ * use that instead of the normal method.
+ *
+ * If you are compiling for platforms like Thumb-1 and don't have a better option,
+ * you may also want to write your own long multiply routine here.
+ *
+ * @param x, y Numbers to be multiplied
+ * @return 64-bit product of the low 32 bits of @p x and @p y.
+ */
+XXH_FORCE_INLINE xxh_u64
+XXH_mult32to64(xxh_u64 x, xxh_u64 y)
+{
+ return (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF);
+}
+#elif defined(_MSC_VER) && defined(_M_IX86)
+# include <intrin.h>
+# define XXH_mult32to64(x, y) __emulu((unsigned)(x), (unsigned)(y))
+#else
+/*
+ * Downcast + upcast is usually better than masking on older compilers like
+ * GCC 4.2 (especially 32-bit ones), all without affecting newer compilers.
+ *
+ * The other method, (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF), will AND both operands
+ * and perform a full 64x64 multiply -- entirely redundant on 32-bit.
+ */
+# define XXH_mult32to64(x, y) ((xxh_u64)(xxh_u32)(x) * (xxh_u64)(xxh_u32)(y))
+#endif
+
+/*!
+ * @brief Calculates a 64->128-bit long multiply.
+ *
+ * Uses `__uint128_t` and `_umul128` if available, otherwise uses a scalar
+ * version.
+ *
+ * @param lhs, rhs The 64-bit integers to be multiplied
+ * @return The 128-bit result represented in an @ref XXH128_hash_t.
+ */
+static XXH128_hash_t
+XXH_mult64to128(xxh_u64 lhs, xxh_u64 rhs)
+{
+ /*
+ * GCC/Clang __uint128_t method.
+ *
+ * On most 64-bit targets, GCC and Clang define a __uint128_t type.
+ * This is usually the best way as it usually uses a native long 64-bit
+ * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64.
+ *
+ * Usually.
+ *
+ * Despite being a 32-bit platform, Clang (and emscripten) define this type
+ * despite not having the arithmetic for it. This results in a laggy
+ * compiler builtin call which calculates a full 128-bit multiply.
+ * In that case it is best to use the portable one.
+ * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677
+ */
+#if defined(__GNUC__) && !defined(__wasm__) \
+ && defined(__SIZEOF_INT128__) \
+ || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128)
+
+ __uint128_t const product = (__uint128_t)lhs * (__uint128_t)rhs;
+ XXH128_hash_t r128;
+ r128.low64 = (xxh_u64)(product);
+ r128.high64 = (xxh_u64)(product >> 64);
+ return r128;
+
+ /*
+ * MSVC for x64's _umul128 method.
+ *
+ * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct);
+ *
+ * This compiles to single operand MUL on x64.
+ */
+#elif defined(_M_X64) || defined(_M_IA64)
+
+#ifndef _MSC_VER
+# pragma intrinsic(_umul128)
+#endif
+ xxh_u64 product_high;
+ xxh_u64 const product_low = _umul128(lhs, rhs, &product_high);
+ XXH128_hash_t r128;
+ r128.low64 = product_low;
+ r128.high64 = product_high;
+ return r128;
+
+#else
+ /*
+ * Portable scalar method. Optimized for 32-bit and 64-bit ALUs.
+ *
+ * This is a fast and simple grade school multiply, which is shown below
+ * with base 10 arithmetic instead of base 0x100000000.
+ *
+ * 9 3 // D2 lhs = 93
+ * x 7 5 // D2 rhs = 75
+ * ----------
+ * 1 5 // D2 lo_lo = (93 % 10) * (75 % 10) = 15
+ * 4 5 | // D2 hi_lo = (93 / 10) * (75 % 10) = 45
+ * 2 1 | // D2 lo_hi = (93 % 10) * (75 / 10) = 21
+ * + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10) = 63
+ * ---------
+ * 2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21 = 27
+ * + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63 = 67
+ * ---------
+ * 6 9 7 5 // D4 res = (27 * 10) + (15 % 10) + (67 * 100) = 6975
+ *
+ * The reasons for adding the products like this are:
+ * 1. It avoids manual carry tracking. Just like how
+ * (9 * 9) + 9 + 9 = 99, the same applies with this for UINT64_MAX.
+ * This avoids a lot of complexity.
+ *
+ * 2. It hints for, and on Clang, compiles to, the powerful UMAAL
+ * instruction available in ARM's Digital Signal Processing extension
+ * in 32-bit ARMv6 and later, which is shown below:
+ *
+ * void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm)
+ * {
+ * xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm;
+ * *RdLo = (xxh_u32)(product & 0xFFFFFFFF);
+ * *RdHi = (xxh_u32)(product >> 32);
+ * }
+ *
+ * This instruction was designed for efficient long multiplication, and
+ * allows this to be calculated in only 4 instructions at speeds
+ * comparable to some 64-bit ALUs.
+ *
+ * 3. It isn't terrible on other platforms. Usually this will be a couple
+ * of 32-bit ADD/ADCs.
+ */
+
+ /* First calculate all of the cross products. */
+ xxh_u64 const lo_lo = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF);
+ xxh_u64 const hi_lo = XXH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF);
+ xxh_u64 const lo_hi = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32);
+ xxh_u64 const hi_hi = XXH_mult32to64(lhs >> 32, rhs >> 32);
+
+ /* Now add the products together. These will never overflow. */
+ xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi;
+ xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32) + hi_hi;
+ xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF);
+
+ XXH128_hash_t r128;
+ r128.low64 = lower;
+ r128.high64 = upper;
+ return r128;
+#endif
+}
+
+/*!
+ * @brief Calculates a 64-bit to 128-bit multiply, then XOR folds it.
+ *
+ * The reason for the separate function is to prevent passing too many structs
+ * around by value. This will hopefully inline the multiply, but we don't force it.
+ *
+ * @param lhs, rhs The 64-bit integers to multiply
+ * @return The low 64 bits of the product XOR'd by the high 64 bits.
+ * @see XXH_mult64to128()
+ */
+static xxh_u64
+XXH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs)
+{
+ XXH128_hash_t product = XXH_mult64to128(lhs, rhs);
+ return product.low64 ^ product.high64;
+}
+
+/*! Seems to produce slightly better code on GCC for some reason. */
+XXH_FORCE_INLINE xxh_u64 XXH_xorshift64(xxh_u64 v64, int shift)
+{
+ XXH_ASSERT(0 <= shift && shift < 64);
+ return v64 ^ (v64 >> shift);
+}
+
+/*
+ * This is a fast avalanche stage,
+ * suitable when input bits are already partially mixed
+ */
+static XXH64_hash_t XXH3_avalanche(xxh_u64 h64)
+{
+ h64 = XXH_xorshift64(h64, 37);
+ h64 *= 0x165667919E3779F9ULL;
+ h64 = XXH_xorshift64(h64, 32);
+ return h64;
+}
+
+/*
+ * This is a stronger avalanche,
+ * inspired by Pelle Evensen's rrmxmx
+ * preferable when input has not been previously mixed
+ */
+static XXH64_hash_t XXH3_rrmxmx(xxh_u64 h64, xxh_u64 len)
+{
+ /* this mix is inspired by Pelle Evensen's rrmxmx */
+ h64 ^= XXH_rotl64(h64, 49) ^ XXH_rotl64(h64, 24);
+ h64 *= 0x9FB21C651E98DF25ULL;
+ h64 ^= (h64 >> 35) + len ;
+ h64 *= 0x9FB21C651E98DF25ULL;
+ return XXH_xorshift64(h64, 28);
+}
+
+
+/* ==========================================
+ * Short keys
+ * ==========================================
+ * One of the shortcomings of XXH32 and XXH64 was that their performance was
+ * sub-optimal on short lengths. It used an iterative algorithm which strongly
+ * favored lengths that were a multiple of 4 or 8.
+ *
+ * Instead of iterating over individual inputs, we use a set of single shot
+ * functions which piece together a range of lengths and operate in constant time.
+ *
+ * Additionally, the number of multiplies has been significantly reduced. This
+ * reduces latency, especially when emulating 64-bit multiplies on 32-bit.
+ *
+ * Depending on the platform, this may or may not be faster than XXH32, but it
+ * is almost guaranteed to be faster than XXH64.
+ */
+
+/*
+ * At very short lengths, there isn't enough input to fully hide secrets, or use
+ * the entire secret.
+ *
+ * There is also only a limited amount of mixing we can do before significantly
+ * impacting performance.
+ *
+ * Therefore, we use different sections of the secret and always mix two secret
+ * samples with an XOR. This should have no effect on performance on the
+ * seedless or withSeed variants because everything _should_ be constant folded
+ * by modern compilers.
+ *
+ * The XOR mixing hides individual parts of the secret and increases entropy.
+ *
+ * This adds an extra layer of strength for custom secrets.
+ */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(1 <= len && len <= 3);
+ XXH_ASSERT(secret != NULL);
+ /*
+ * len = 1: combined = { input[0], 0x01, input[0], input[0] }
+ * len = 2: combined = { input[1], 0x02, input[0], input[1] }
+ * len = 3: combined = { input[2], 0x03, input[0], input[1] }
+ */
+ { xxh_u8 const c1 = input[0];
+ xxh_u8 const c2 = input[len >> 1];
+ xxh_u8 const c3 = input[len - 1];
+ xxh_u32 const combined = ((xxh_u32)c1 << 16) | ((xxh_u32)c2 << 24)
+ | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8);
+ xxh_u64 const bitflip = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed;
+ xxh_u64 const keyed = (xxh_u64)combined ^ bitflip;
+ return XXH64_avalanche(keyed);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(4 <= len && len <= 8);
+ seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32;
+ { xxh_u32 const input1 = XXH_readLE32(input);
+ xxh_u32 const input2 = XXH_readLE32(input + len - 4);
+ xxh_u64 const bitflip = (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) - seed;
+ xxh_u64 const input64 = input2 + (((xxh_u64)input1) << 32);
+ xxh_u64 const keyed = input64 ^ bitflip;
+ return XXH3_rrmxmx(keyed, len);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(9 <= len && len <= 16);
+ { xxh_u64 const bitflip1 = (XXH_readLE64(secret+24) ^ XXH_readLE64(secret+32)) + seed;
+ xxh_u64 const bitflip2 = (XXH_readLE64(secret+40) ^ XXH_readLE64(secret+48)) - seed;
+ xxh_u64 const input_lo = XXH_readLE64(input) ^ bitflip1;
+ xxh_u64 const input_hi = XXH_readLE64(input + len - 8) ^ bitflip2;
+ xxh_u64 const acc = len
+ + XXH_swap64(input_lo) + input_hi
+ + XXH3_mul128_fold64(input_lo, input_hi);
+ return XXH3_avalanche(acc);
+ }
+}
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(len <= 16);
+ { if (XXH_likely(len > 8)) return XXH3_len_9to16_64b(input, len, secret, seed);
+ if (XXH_likely(len >= 4)) return XXH3_len_4to8_64b(input, len, secret, seed);
+ if (len) return XXH3_len_1to3_64b(input, len, secret, seed);
+ return XXH64_avalanche(seed ^ (XXH_readLE64(secret+56) ^ XXH_readLE64(secret+64)));
+ }
+}
+
+/*
+ * DISCLAIMER: There are known *seed-dependent* multicollisions here due to
+ * multiplication by zero, affecting hashes of lengths 17 to 240.
+ *
+ * However, they are very unlikely.
+ *
+ * Keep this in mind when using the unseeded XXH3_64bits() variant: As with all
+ * unseeded non-cryptographic hashes, it does not attempt to defend itself
+ * against specially crafted inputs, only random inputs.
+ *
+ * Compared to classic UMAC where a 1 in 2^31 chance of 4 consecutive bytes
+ * cancelling out the secret is taken an arbitrary number of times (addressed
+ * in XXH3_accumulate_512), this collision is very unlikely with random inputs
+ * and/or proper seeding:
+ *
+ * This only has a 1 in 2^63 chance of 8 consecutive bytes cancelling out, in a
+ * function that is only called up to 16 times per hash with up to 240 bytes of
+ * input.
+ *
+ * This is not too bad for a non-cryptographic hash function, especially with
+ * only 64 bit outputs.
+ *
+ * The 128-bit variant (which trades some speed for strength) is NOT affected
+ * by this, although it is always a good idea to use a proper seed if you care
+ * about strength.
+ */
+XXH_FORCE_INLINE xxh_u64 XXH3_mix16B(const xxh_u8* XXH_RESTRICT input,
+ const xxh_u8* XXH_RESTRICT secret, xxh_u64 seed64)
+{
+#if defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__i386__) && defined(__SSE2__) /* x86 + SSE2 */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable like XXH32 hack */
+ /*
+ * UGLY HACK:
+ * GCC for x86 tends to autovectorize the 128-bit multiply, resulting in
+ * slower code.
+ *
+ * By forcing seed64 into a register, we disrupt the cost model and
+ * cause it to scalarize. See `XXH32_round()`
+ *
+ * FIXME: Clang's output is still _much_ faster -- On an AMD Ryzen 3600,
+ * XXH3_64bits @ len=240 runs at 4.6 GB/s with Clang 9, but 3.3 GB/s on
+ * GCC 9.2, despite both emitting scalar code.
+ *
+ * GCC generates much better scalar code than Clang for the rest of XXH3,
+ * which is why finding a more optimal codepath is an interest.
+ */
+ XXH_COMPILER_GUARD(seed64);
+#endif
+ { xxh_u64 const input_lo = XXH_readLE64(input);
+ xxh_u64 const input_hi = XXH_readLE64(input+8);
+ return XXH3_mul128_fold64(
+ input_lo ^ (XXH_readLE64(secret) + seed64),
+ input_hi ^ (XXH_readLE64(secret+8) - seed64)
+ );
+ }
+}
+
+/* For mid range keys, XXH3 uses a Mum-hash variant. */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_len_17to128_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(16 < len && len <= 128);
+
+ { xxh_u64 acc = len * XXH_PRIME64_1;
+ if (len > 32) {
+ if (len > 64) {
+ if (len > 96) {
+ acc += XXH3_mix16B(input+48, secret+96, seed);
+ acc += XXH3_mix16B(input+len-64, secret+112, seed);
+ }
+ acc += XXH3_mix16B(input+32, secret+64, seed);
+ acc += XXH3_mix16B(input+len-48, secret+80, seed);
+ }
+ acc += XXH3_mix16B(input+16, secret+32, seed);
+ acc += XXH3_mix16B(input+len-32, secret+48, seed);
+ }
+ acc += XXH3_mix16B(input+0, secret+0, seed);
+ acc += XXH3_mix16B(input+len-16, secret+16, seed);
+
+ return XXH3_avalanche(acc);
+ }
+}
+
+#define XXH3_MIDSIZE_MAX 240
+
+XXH_NO_INLINE XXH64_hash_t
+XXH3_len_129to240_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+ #define XXH3_MIDSIZE_STARTOFFSET 3
+ #define XXH3_MIDSIZE_LASTOFFSET 17
+
+ { xxh_u64 acc = len * XXH_PRIME64_1;
+ int const nbRounds = (int)len / 16;
+ int i;
+ for (i=0; i<8; i++) {
+ acc += XXH3_mix16B(input+(16*i), secret+(16*i), seed);
+ }
+ acc = XXH3_avalanche(acc);
+ XXH_ASSERT(nbRounds >= 8);
+#if defined(__clang__) /* Clang */ \
+ && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */
+ /*
+ * UGLY HACK:
+ * Clang for ARMv7-A tries to vectorize this loop, similar to GCC x86.
+ * In everywhere else, it uses scalar code.
+ *
+ * For 64->128-bit multiplies, even if the NEON was 100% optimal, it
+ * would still be slower than UMAAL (see XXH_mult64to128).
+ *
+ * Unfortunately, Clang doesn't handle the long multiplies properly and
+ * converts them to the nonexistent "vmulq_u64" intrinsic, which is then
+ * scalarized into an ugly mess of VMOV.32 instructions.
+ *
+ * This mess is difficult to avoid without turning autovectorization
+ * off completely, but they are usually relatively minor and/or not
+ * worth it to fix.
+ *
+ * This loop is the easiest to fix, as unlike XXH32, this pragma
+ * _actually works_ because it is a loop vectorization instead of an
+ * SLP vectorization.
+ */
+ #pragma clang loop vectorize(disable)
+#endif
+ for (i=8 ; i < nbRounds; i++) {
+ acc += XXH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXH3_MIDSIZE_STARTOFFSET, seed);
+ }
+ /* last bytes */
+ acc += XXH3_mix16B(input + len - 16, secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET, seed);
+ return XXH3_avalanche(acc);
+ }
+}
+
+
+/* ======= Long Keys ======= */
+
+#define XXH_STRIPE_LEN 64
+#define XXH_SECRET_CONSUME_RATE 8 /* nb of secret bytes consumed at each accumulation */
+#define XXH_ACC_NB (XXH_STRIPE_LEN / sizeof(xxh_u64))
+
+#ifdef XXH_OLD_NAMES
+# define STRIPE_LEN XXH_STRIPE_LEN
+# define ACC_NB XXH_ACC_NB
+#endif
+
+XXH_FORCE_INLINE void XXH_writeLE64(void* dst, xxh_u64 v64)
+{
+ if (!XXH_CPU_LITTLE_ENDIAN) v64 = XXH_swap64(v64);
+ memcpy(dst, &v64, sizeof(v64));
+}
+
+/* Several intrinsic functions below are supposed to accept __int64 as argument,
+ * as documented in https://software.intel.com/sites/landingpage/IntrinsicsGuide/ .
+ * However, several environments do not define __int64 type,
+ * requiring a workaround.
+ */
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+ typedef int64_t xxh_i64;
+#else
+ /* the following type must have a width of 64-bit */
+ typedef long long xxh_i64;
+#endif
+
+/*
+ * XXH3_accumulate_512 is the tightest loop for long inputs, and it is the most optimized.
+ *
+ * It is a hardened version of UMAC, based off of FARSH's implementation.
+ *
+ * This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD
+ * implementations, and it is ridiculously fast.
+ *
+ * We harden it by mixing the original input to the accumulators as well as the product.
+ *
+ * This means that in the (relatively likely) case of a multiply by zero, the
+ * original input is preserved.
+ *
+ * On 128-bit inputs, we swap 64-bit pairs when we add the input to improve
+ * cross-pollination, as otherwise the upper and lower halves would be
+ * essentially independent.
+ *
+ * This doesn't matter on 64-bit hashes since they all get merged together in
+ * the end, so we skip the extra step.
+ *
+ * Both XXH3_64bits and XXH3_128bits use this subroutine.
+ */
+
+#if (XXH_VECTOR == XXH_AVX512) \
+ || (defined(XXH_DISPATCH_AVX512) && XXH_DISPATCH_AVX512 != 0)
+
+#ifndef XXH_TARGET_AVX512
+# define XXH_TARGET_AVX512 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_accumulate_512_avx512(void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(64) __m512i* const xacc = (__m512i *) acc;
+ XXH_ASSERT((((size_t)acc) & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i));
+
+ {
+ /* data_vec = input[0]; */
+ __m512i const data_vec = _mm512_loadu_si512 (input);
+ /* key_vec = secret[0]; */
+ __m512i const key_vec = _mm512_loadu_si512 (secret);
+ /* data_key = data_vec ^ key_vec; */
+ __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m512i const data_key_lo = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m512i const product = _mm512_mul_epu32 (data_key, data_key_lo);
+ /* xacc[0] += swap(data_vec); */
+ __m512i const data_swap = _mm512_shuffle_epi32(data_vec, (_MM_PERM_ENUM)_MM_SHUFFLE(1, 0, 3, 2));
+ __m512i const sum = _mm512_add_epi64(*xacc, data_swap);
+ /* xacc[0] += product; */
+ *xacc = _mm512_add_epi64(product, sum);
+ }
+}
+
+/*
+ * XXH3_scrambleAcc: Scrambles the accumulators to improve mixing.
+ *
+ * Multiplication isn't perfect, as explained by Google in HighwayHash:
+ *
+ * // Multiplication mixes/scrambles bytes 0-7 of the 64-bit result to
+ * // varying degrees. In descending order of goodness, bytes
+ * // 3 4 2 5 1 6 0 7 have quality 228 224 164 160 100 96 36 32.
+ * // As expected, the upper and lower bytes are much worse.
+ *
+ * Source: https://github.com/google/highwayhash/blob/0aaf66b/highwayhash/hh_avx2.h#L291
+ *
+ * Since our algorithm uses a pseudorandom secret to add some variance into the
+ * mix, we don't need to (or want to) mix as often or as much as HighwayHash does.
+ *
+ * This isn't as tight as XXH3_accumulate, but still written in SIMD to avoid
+ * extraction.
+ *
+ * Both XXH3_64bits and XXH3_128bits use this subroutine.
+ */
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_scrambleAcc_avx512(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i));
+ { XXH_ALIGN(64) __m512i* const xacc = (__m512i*) acc;
+ const __m512i prime32 = _mm512_set1_epi32((int)XXH_PRIME32_1);
+
+ /* xacc[0] ^= (xacc[0] >> 47) */
+ __m512i const acc_vec = *xacc;
+ __m512i const shifted = _mm512_srli_epi64 (acc_vec, 47);
+ __m512i const data_vec = _mm512_xor_si512 (acc_vec, shifted);
+ /* xacc[0] ^= secret; */
+ __m512i const key_vec = _mm512_loadu_si512 (secret);
+ __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec);
+
+ /* xacc[0] *= XXH_PRIME32_1; */
+ __m512i const data_key_hi = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1));
+ __m512i const prod_lo = _mm512_mul_epu32 (data_key, prime32);
+ __m512i const prod_hi = _mm512_mul_epu32 (data_key_hi, prime32);
+ *xacc = _mm512_add_epi64(prod_lo, _mm512_slli_epi64(prod_hi, 32));
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX512 void
+XXH3_initCustomSecret_avx512(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 63) == 0);
+ XXH_STATIC_ASSERT(XXH_SEC_ALIGN == 64);
+ XXH_ASSERT(((size_t)customSecret & 63) == 0);
+ (void)(&XXH_writeLE64);
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m512i);
+ __m512i const seed = _mm512_mask_set1_epi64(_mm512_set1_epi64((xxh_i64)seed64), 0xAA, (xxh_i64)(0U - seed64));
+
+ XXH_ALIGN(64) const __m512i* const src = (const __m512i*) XXH3_kSecret;
+ XXH_ALIGN(64) __m512i* const dest = ( __m512i*) customSecret;
+ int i;
+ for (i=0; i < nbRounds; ++i) {
+ /* GCC has a bug, _mm512_stream_load_si512 accepts 'void*', not 'void const*',
+ * this will warn "discards 'const' qualifier". */
+ union {
+ XXH_ALIGN(64) const __m512i* cp;
+ XXH_ALIGN(64) void* p;
+ } remote_const_void;
+ remote_const_void.cp = src + i;
+ dest[i] = _mm512_add_epi64(_mm512_stream_load_si512(remote_const_void.p), seed);
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_AVX2) \
+ || (defined(XXH_DISPATCH_AVX2) && XXH_DISPATCH_AVX2 != 0)
+
+#ifndef XXH_TARGET_AVX2
+# define XXH_TARGET_AVX2 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void
+XXH3_accumulate_512_avx2( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXH_ALIGN(32) __m256i* const xacc = (__m256i *) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xinput = (const __m256i *) input;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xsecret = (const __m256i *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+ /* data_vec = xinput[i]; */
+ __m256i const data_vec = _mm256_loadu_si256 (xinput+i);
+ /* key_vec = xsecret[i]; */
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ /* data_key = data_vec ^ key_vec; */
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m256i const data_key_lo = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m256i const product = _mm256_mul_epu32 (data_key, data_key_lo);
+ /* xacc[i] += swap(data_vec); */
+ __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2));
+ __m256i const sum = _mm256_add_epi64(xacc[i], data_swap);
+ /* xacc[i] += product; */
+ xacc[i] = _mm256_add_epi64(product, sum);
+ } }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void
+XXH3_scrambleAcc_avx2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXH_ALIGN(32) __m256i* const xacc = (__m256i*) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+ const __m256i* const xsecret = (const __m256i *) secret;
+ const __m256i prime32 = _mm256_set1_epi32((int)XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m256i const acc_vec = xacc[i];
+ __m256i const shifted = _mm256_srli_epi64 (acc_vec, 47);
+ __m256i const data_vec = _mm256_xor_si256 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret; */
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec);
+
+ /* xacc[i] *= XXH_PRIME32_1; */
+ __m256i const data_key_hi = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ __m256i const prod_lo = _mm256_mul_epu32 (data_key, prime32);
+ __m256i const prod_hi = _mm256_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32));
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void XXH3_initCustomSecret_avx2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 31) == 0);
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE / sizeof(__m256i)) == 6);
+ XXH_STATIC_ASSERT(XXH_SEC_ALIGN <= 64);
+ (void)(&XXH_writeLE64);
+ XXH_PREFETCH(customSecret);
+ { __m256i const seed = _mm256_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64, (xxh_i64)(0U - seed64), (xxh_i64)seed64);
+
+ XXH_ALIGN(64) const __m256i* const src = (const __m256i*) XXH3_kSecret;
+ XXH_ALIGN(64) __m256i* dest = ( __m256i*) customSecret;
+
+# if defined(__GNUC__) || defined(__clang__)
+ /*
+ * On GCC & Clang, marking 'dest' as modified will cause the compiler:
+ * - do not extract the secret from sse registers in the internal loop
+ * - use less common registers, and avoid pushing these reg into stack
+ */
+ XXH_COMPILER_GUARD(dest);
+# endif
+
+ /* GCC -O2 need unroll loop manually */
+ dest[0] = _mm256_add_epi64(_mm256_stream_load_si256(src+0), seed);
+ dest[1] = _mm256_add_epi64(_mm256_stream_load_si256(src+1), seed);
+ dest[2] = _mm256_add_epi64(_mm256_stream_load_si256(src+2), seed);
+ dest[3] = _mm256_add_epi64(_mm256_stream_load_si256(src+3), seed);
+ dest[4] = _mm256_add_epi64(_mm256_stream_load_si256(src+4), seed);
+ dest[5] = _mm256_add_epi64(_mm256_stream_load_si256(src+5), seed);
+ }
+}
+
+#endif
+
+/* x86dispatch always generates SSE2 */
+#if (XXH_VECTOR == XXH_SSE2) || defined(XXH_X86DISPATCH)
+
+#ifndef XXH_TARGET_SSE2
+# define XXH_TARGET_SSE2 /* disable attribute target */
+#endif
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void
+XXH3_accumulate_512_sse2( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ /* SSE2 is just a half-scale version of the AVX2 version. */
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXH_ALIGN(16) __m128i* const xacc = (__m128i *) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xinput = (const __m128i *) input;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xsecret = (const __m128i *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) {
+ /* data_vec = xinput[i]; */
+ __m128i const data_vec = _mm_loadu_si128 (xinput+i);
+ /* key_vec = xsecret[i]; */
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ /* data_key = data_vec ^ key_vec; */
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec);
+ /* data_key_lo = data_key >> 32; */
+ __m128i const data_key_lo = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+ __m128i const product = _mm_mul_epu32 (data_key, data_key_lo);
+ /* xacc[i] += swap(data_vec); */
+ __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2));
+ __m128i const sum = _mm_add_epi64(xacc[i], data_swap);
+ /* xacc[i] += product; */
+ xacc[i] = _mm_add_epi64(product, sum);
+ } }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void
+XXH3_scrambleAcc_sse2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXH_ALIGN(16) __m128i* const xacc = (__m128i*) acc;
+ /* Unaligned. This is mainly for pointer arithmetic, and because
+ * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */
+ const __m128i* const xsecret = (const __m128i *) secret;
+ const __m128i prime32 = _mm_set1_epi32((int)XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m128i const acc_vec = xacc[i];
+ __m128i const shifted = _mm_srli_epi64 (acc_vec, 47);
+ __m128i const data_vec = _mm_xor_si128 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret[i]; */
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec);
+
+ /* xacc[i] *= XXH_PRIME32_1; */
+ __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+ __m128i const prod_lo = _mm_mul_epu32 (data_key, prime32);
+ __m128i const prod_hi = _mm_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32));
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void XXH3_initCustomSecret_sse2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+ (void)(&XXH_writeLE64);
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m128i);
+
+# if defined(_MSC_VER) && defined(_M_IX86) && _MSC_VER < 1900
+ // MSVC 32bit mode does not support _mm_set_epi64x before 2015
+ XXH_ALIGN(16) const xxh_i64 seed64x2[2] = { (xxh_i64)seed64, (xxh_i64)(0U - seed64) };
+ __m128i const seed = _mm_load_si128((__m128i const*)seed64x2);
+# else
+ __m128i const seed = _mm_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64);
+# endif
+ int i;
+
+ XXH_ALIGN(64) const float* const src = (float const*) XXH3_kSecret;
+ XXH_ALIGN(XXH_SEC_ALIGN) __m128i* dest = (__m128i*) customSecret;
+# if defined(__GNUC__) || defined(__clang__)
+ /*
+ * On GCC & Clang, marking 'dest' as modified will cause the compiler:
+ * - do not extract the secret from sse registers in the internal loop
+ * - use less common registers, and avoid pushing these reg into stack
+ */
+ XXH_COMPILER_GUARD(dest);
+# endif
+
+ for (i=0; i < nbRounds; ++i) {
+ dest[i] = _mm_add_epi64(_mm_castps_si128(_mm_load_ps(src+i*4)), seed);
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_NEON)
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_neon( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+ {
+ XXH_ALIGN(16) uint64x2_t* const xacc = (uint64x2_t *) acc;
+ /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */
+ uint8_t const* const xinput = (const uint8_t *) input;
+ uint8_t const* const xsecret = (const uint8_t *) secret;
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN / sizeof(uint64x2_t); i++) {
+ /* data_vec = xinput[i]; */
+ uint8x16_t data_vec = vld1q_u8(xinput + (i * 16));
+ /* key_vec = xsecret[i]; */
+ uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16));
+ uint64x2_t data_key;
+ uint32x2_t data_key_lo, data_key_hi;
+ /* xacc[i] += swap(data_vec); */
+ uint64x2_t const data64 = vreinterpretq_u64_u8(data_vec);
+ uint64x2_t const swapped = vextq_u64(data64, data64, 1);
+ xacc[i] = vaddq_u64 (xacc[i], swapped);
+ /* data_key = data_vec ^ key_vec; */
+ data_key = vreinterpretq_u64_u8(veorq_u8(data_vec, key_vec));
+ /* data_key_lo = (uint32x2_t) (data_key & 0xFFFFFFFF);
+ * data_key_hi = (uint32x2_t) (data_key >> 32);
+ * data_key = UNDEFINED; */
+ XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi);
+ /* xacc[i] += (uint64x2_t) data_key_lo * (uint64x2_t) data_key_hi; */
+ xacc[i] = vmlal_u32 (xacc[i], data_key_lo, data_key_hi);
+
+ }
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_neon(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+
+ { uint64x2_t* xacc = (uint64x2_t*) acc;
+ uint8_t const* xsecret = (uint8_t const*) secret;
+ uint32x2_t prime = vdup_n_u32 (XXH_PRIME32_1);
+
+ size_t i;
+ for (i=0; i < XXH_STRIPE_LEN/sizeof(uint64x2_t); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47); */
+ uint64x2_t acc_vec = xacc[i];
+ uint64x2_t shifted = vshrq_n_u64 (acc_vec, 47);
+ uint64x2_t data_vec = veorq_u64 (acc_vec, shifted);
+
+ /* xacc[i] ^= xsecret[i]; */
+ uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16));
+ uint64x2_t data_key = veorq_u64(data_vec, vreinterpretq_u64_u8(key_vec));
+
+ /* xacc[i] *= XXH_PRIME32_1 */
+ uint32x2_t data_key_lo, data_key_hi;
+ /* data_key_lo = (uint32x2_t) (xacc[i] & 0xFFFFFFFF);
+ * data_key_hi = (uint32x2_t) (xacc[i] >> 32);
+ * xacc[i] = UNDEFINED; */
+ XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi);
+ { /*
+ * prod_hi = (data_key >> 32) * XXH_PRIME32_1;
+ *
+ * Avoid vmul_u32 + vshll_n_u32 since Clang 6 and 7 will
+ * incorrectly "optimize" this:
+ * tmp = vmul_u32(vmovn_u64(a), vmovn_u64(b));
+ * shifted = vshll_n_u32(tmp, 32);
+ * to this:
+ * tmp = "vmulq_u64"(a, b); // no such thing!
+ * shifted = vshlq_n_u64(tmp, 32);
+ *
+ * However, unlike SSE, Clang lacks a 64-bit multiply routine
+ * for NEON, and it scalarizes two 64-bit multiplies instead.
+ *
+ * vmull_u32 has the same timing as vmul_u32, and it avoids
+ * this bug completely.
+ * See https://bugs.llvm.org/show_bug.cgi?id=39967
+ */
+ uint64x2_t prod_hi = vmull_u32 (data_key_hi, prime);
+ /* xacc[i] = prod_hi << 32; */
+ xacc[i] = vshlq_n_u64(prod_hi, 32);
+ /* xacc[i] += (prod_hi & 0xFFFFFFFF) * XXH_PRIME32_1; */
+ xacc[i] = vmlal_u32(xacc[i], data_key_lo, prime);
+ }
+ } }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_VSX)
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_vsx( void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ xxh_u64x2* const xacc = (xxh_u64x2*) acc; /* presumed aligned */
+ xxh_u64x2 const* const xinput = (xxh_u64x2 const*) input; /* no alignment restriction */
+ xxh_u64x2 const* const xsecret = (xxh_u64x2 const*) secret; /* no alignment restriction */
+ xxh_u64x2 const v32 = { 32, 32 };
+ size_t i;
+ for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) {
+ /* data_vec = xinput[i]; */
+ xxh_u64x2 const data_vec = XXH_vec_loadu(xinput + i);
+ /* key_vec = xsecret[i]; */
+ xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i);
+ xxh_u64x2 const data_key = data_vec ^ key_vec;
+ /* shuffled = (data_key << 32) | (data_key >> 32); */
+ xxh_u32x4 const shuffled = (xxh_u32x4)vec_rl(data_key, v32);
+ /* product = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)shuffled & 0xFFFFFFFF); */
+ xxh_u64x2 const product = XXH_vec_mulo((xxh_u32x4)data_key, shuffled);
+ xacc[i] += product;
+
+ /* swap high and low halves */
+#ifdef __s390x__
+ xacc[i] += vec_permi(data_vec, data_vec, 2);
+#else
+ xacc[i] += vec_xxpermdi(data_vec, data_vec, 2);
+#endif
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_vsx(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ASSERT((((size_t)acc) & 15) == 0);
+
+ { xxh_u64x2* const xacc = (xxh_u64x2*) acc;
+ const xxh_u64x2* const xsecret = (const xxh_u64x2*) secret;
+ /* constants */
+ xxh_u64x2 const v32 = { 32, 32 };
+ xxh_u64x2 const v47 = { 47, 47 };
+ xxh_u32x4 const prime = { XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1 };
+ size_t i;
+ for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47); */
+ xxh_u64x2 const acc_vec = xacc[i];
+ xxh_u64x2 const data_vec = acc_vec ^ (acc_vec >> v47);
+
+ /* xacc[i] ^= xsecret[i]; */
+ xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i);
+ xxh_u64x2 const data_key = data_vec ^ key_vec;
+
+ /* xacc[i] *= XXH_PRIME32_1 */
+ /* prod_lo = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)prime & 0xFFFFFFFF); */
+ xxh_u64x2 const prod_even = XXH_vec_mule((xxh_u32x4)data_key, prime);
+ /* prod_hi = ((xxh_u64x2)data_key >> 32) * ((xxh_u64x2)prime >> 32); */
+ xxh_u64x2 const prod_odd = XXH_vec_mulo((xxh_u32x4)data_key, prime);
+ xacc[i] = prod_odd + (prod_even << v32);
+ } }
+}
+
+#endif
+
+/* scalar variants - universal */
+
+XXH_FORCE_INLINE void
+XXH3_accumulate_512_scalar(void* XXH_RESTRICT acc,
+ const void* XXH_RESTRICT input,
+ const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */
+ const xxh_u8* const xinput = (const xxh_u8*) input; /* no alignment restriction */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXH_ASSERT(((size_t)acc & (XXH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < XXH_ACC_NB; i++) {
+ xxh_u64 const data_val = XXH_readLE64(xinput + 8*i);
+ xxh_u64 const data_key = data_val ^ XXH_readLE64(xsecret + i*8);
+ xacc[i ^ 1] += data_val; /* swap adjacent lanes */
+ xacc[i] += XXH_mult32to64(data_key & 0xFFFFFFFF, data_key >> 32);
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_scalar(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXH_ASSERT((((size_t)acc) & (XXH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < XXH_ACC_NB; i++) {
+ xxh_u64 const key64 = XXH_readLE64(xsecret + 8*i);
+ xxh_u64 acc64 = xacc[i];
+ acc64 = XXH_xorshift64(acc64, 47);
+ acc64 ^= key64;
+ acc64 *= XXH_PRIME32_1;
+ xacc[i] = acc64;
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_initCustomSecret_scalar(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+ /*
+ * We need a separate pointer for the hack below,
+ * which requires a non-const pointer.
+ * Any decent compiler will optimize this out otherwise.
+ */
+ const xxh_u8* kSecretPtr = XXH3_kSecret;
+ XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+
+#if defined(__clang__) && defined(__aarch64__)
+ /*
+ * UGLY HACK:
+ * Clang generates a bunch of MOV/MOVK pairs for aarch64, and they are
+ * placed sequentially, in order, at the top of the unrolled loop.
+ *
+ * While MOVK is great for generating constants (2 cycles for a 64-bit
+ * constant compared to 4 cycles for LDR), long MOVK chains stall the
+ * integer pipelines:
+ * I L S
+ * MOVK
+ * MOVK
+ * MOVK
+ * MOVK
+ * ADD
+ * SUB STR
+ * STR
+ * By forcing loads from memory (as the asm line causes Clang to assume
+ * that XXH3_kSecretPtr has been changed), the pipelines are used more
+ * efficiently:
+ * I L S
+ * LDR
+ * ADD LDR
+ * SUB STR
+ * STR
+ * XXH3_64bits_withSeed, len == 256, Snapdragon 835
+ * without hack: 2654.4 MB/s
+ * with hack: 3202.9 MB/s
+ */
+ XXH_COMPILER_GUARD(kSecretPtr);
+#endif
+ /*
+ * Note: in debug mode, this overrides the asm optimization
+ * and Clang will emit MOVK chains again.
+ */
+ XXH_ASSERT(kSecretPtr == XXH3_kSecret);
+
+ { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / 16;
+ int i;
+ for (i=0; i < nbRounds; i++) {
+ /*
+ * The asm hack causes Clang to assume that kSecretPtr aliases with
+ * customSecret, and on aarch64, this prevented LDP from merging two
+ * loads together for free. Putting the loads together before the stores
+ * properly generates LDP.
+ */
+ xxh_u64 lo = XXH_readLE64(kSecretPtr + 16*i) + seed64;
+ xxh_u64 hi = XXH_readLE64(kSecretPtr + 16*i + 8) - seed64;
+ XXH_writeLE64((xxh_u8*)customSecret + 16*i, lo);
+ XXH_writeLE64((xxh_u8*)customSecret + 16*i + 8, hi);
+ } }
+}
+
+
+typedef void (*XXH3_f_accumulate_512)(void* XXH_RESTRICT, const void*, const void*);
+typedef void (*XXH3_f_scrambleAcc)(void* XXH_RESTRICT, const void*);
+typedef void (*XXH3_f_initCustomSecret)(void* XXH_RESTRICT, xxh_u64);
+
+
+#if (XXH_VECTOR == XXH_AVX512)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_avx512
+#define XXH3_scrambleAcc XXH3_scrambleAcc_avx512
+#define XXH3_initCustomSecret XXH3_initCustomSecret_avx512
+
+#elif (XXH_VECTOR == XXH_AVX2)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_avx2
+#define XXH3_scrambleAcc XXH3_scrambleAcc_avx2
+#define XXH3_initCustomSecret XXH3_initCustomSecret_avx2
+
+#elif (XXH_VECTOR == XXH_SSE2)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_sse2
+#define XXH3_scrambleAcc XXH3_scrambleAcc_sse2
+#define XXH3_initCustomSecret XXH3_initCustomSecret_sse2
+
+#elif (XXH_VECTOR == XXH_NEON)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_neon
+#define XXH3_scrambleAcc XXH3_scrambleAcc_neon
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#elif (XXH_VECTOR == XXH_VSX)
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_vsx
+#define XXH3_scrambleAcc XXH3_scrambleAcc_vsx
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#else /* scalar */
+
+#define XXH3_accumulate_512 XXH3_accumulate_512_scalar
+#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar
+#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar
+
+#endif
+
+
+
+#ifndef XXH_PREFETCH_DIST
+# ifdef __clang__
+# define XXH_PREFETCH_DIST 320
+# else
+# if (XXH_VECTOR == XXH_AVX512)
+# define XXH_PREFETCH_DIST 512
+# else
+# define XXH_PREFETCH_DIST 384
+# endif
+# endif /* __clang__ */
+#endif /* XXH_PREFETCH_DIST */
+
+/*
+ * XXH3_accumulate()
+ * Loops over XXH3_accumulate_512().
+ * Assumption: nbStripes will not overflow the secret size
+ */
+XXH_FORCE_INLINE void
+XXH3_accumulate( xxh_u64* XXH_RESTRICT acc,
+ const xxh_u8* XXH_RESTRICT input,
+ const xxh_u8* XXH_RESTRICT secret,
+ size_t nbStripes,
+ XXH3_f_accumulate_512 f_acc512)
+{
+ size_t n;
+ for (n = 0; n < nbStripes; n++ ) {
+ const xxh_u8* const in = input + n*XXH_STRIPE_LEN;
+ XXH_PREFETCH(in + XXH_PREFETCH_DIST);
+ f_acc512(acc,
+ in,
+ secret + n*XXH_SECRET_CONSUME_RATE);
+ }
+}
+
+XXH_FORCE_INLINE void
+XXH3_hashLong_internal_loop(xxh_u64* XXH_RESTRICT acc,
+ const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ size_t const nbStripesPerBlock = (secretSize - XXH_STRIPE_LEN) / XXH_SECRET_CONSUME_RATE;
+ size_t const block_len = XXH_STRIPE_LEN * nbStripesPerBlock;
+ size_t const nb_blocks = (len - 1) / block_len;
+
+ size_t n;
+
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN);
+
+ for (n = 0; n < nb_blocks; n++) {
+ XXH3_accumulate(acc, input + n*block_len, secret, nbStripesPerBlock, f_acc512);
+ f_scramble(acc, secret + secretSize - XXH_STRIPE_LEN);
+ }
+
+ /* last partial block */
+ XXH_ASSERT(len > XXH_STRIPE_LEN);
+ { size_t const nbStripes = ((len - 1) - (block_len * nb_blocks)) / XXH_STRIPE_LEN;
+ XXH_ASSERT(nbStripes <= (secretSize / XXH_SECRET_CONSUME_RATE));
+ XXH3_accumulate(acc, input + nb_blocks*block_len, secret, nbStripes, f_acc512);
+
+ /* last stripe */
+ { const xxh_u8* const p = input + len - XXH_STRIPE_LEN;
+#define XXH_SECRET_LASTACC_START 7 /* not aligned on 8, last secret is different from acc & scrambler */
+ f_acc512(acc, p, secret + secretSize - XXH_STRIPE_LEN - XXH_SECRET_LASTACC_START);
+ } }
+}
+
+XXH_FORCE_INLINE xxh_u64
+XXH3_mix2Accs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret)
+{
+ return XXH3_mul128_fold64(
+ acc[0] ^ XXH_readLE64(secret),
+ acc[1] ^ XXH_readLE64(secret+8) );
+}
+
+static XXH64_hash_t
+XXH3_mergeAccs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret, xxh_u64 start)
+{
+ xxh_u64 result64 = start;
+ size_t i = 0;
+
+ for (i = 0; i < 4; i++) {
+ result64 += XXH3_mix2Accs(acc+2*i, secret + 16*i);
+#if defined(__clang__) /* Clang */ \
+ && (defined(__arm__) || defined(__thumb__)) /* ARMv7 */ \
+ && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \
+ && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */
+ /*
+ * UGLY HACK:
+ * Prevent autovectorization on Clang ARMv7-a. Exact same problem as
+ * the one in XXH3_len_129to240_64b. Speeds up shorter keys > 240b.
+ * XXH3_64bits, len == 256, Snapdragon 835:
+ * without hack: 2063.7 MB/s
+ * with hack: 2560.7 MB/s
+ */
+ XXH_COMPILER_GUARD(result64);
+#endif
+ }
+
+ return XXH3_avalanche(result64);
+}
+
+#define XXH3_INIT_ACC { XXH_PRIME32_3, XXH_PRIME64_1, XXH_PRIME64_2, XXH_PRIME64_3, \
+ XXH_PRIME64_4, XXH_PRIME32_2, XXH_PRIME64_5, XXH_PRIME32_1 }
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_hashLong_64b_internal(const void* XXH_RESTRICT input, size_t len,
+ const void* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+ XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, f_acc512, f_scramble);
+
+ /* converge into final hash */
+ XXH_STATIC_ASSERT(sizeof(acc) == 64);
+ /* do not align on 8, so that the secret is different from the accumulator */
+#define XXH_SECRET_MERGEACCS_START 11
+ XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ return XXH3_mergeAccs(acc, (const xxh_u8*)secret + XXH_SECRET_MERGEACCS_START, (xxh_u64)len * XXH_PRIME64_1);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSecret(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64;
+ return XXH3_hashLong_64b_internal(input, len, secret, secretLen, XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ * Since the function is not inlined, the compiler may not be able to understand that,
+ * in some scenarios, its `secret` argument is actually a compile time constant.
+ * This variant enforces that the compiler can detect that,
+ * and uses this opportunity to streamline the generated code for better performance.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_default(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64; (void)secret; (void)secretLen;
+ return XXH3_hashLong_64b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * XXH3_hashLong_64b_withSeed():
+ * Generate a custom key based on alteration of default XXH3_kSecret with the seed,
+ * and then use this key for long mode hashing.
+ *
+ * This operation is decently fast but nonetheless costs a little bit of time.
+ * Try to avoid it whenever possible (typically when seed==0).
+ *
+ * It's important for performance that XXH3_hashLong is not inlined. Not sure
+ * why (uop cache maybe?), but the difference is large and easily measurable.
+ */
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSeed_internal(const void* input, size_t len,
+ XXH64_hash_t seed,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble,
+ XXH3_f_initCustomSecret f_initSec)
+{
+ if (seed == 0)
+ return XXH3_hashLong_64b_internal(input, len,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ f_acc512, f_scramble);
+ { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE];
+ f_initSec(secret, seed);
+ return XXH3_hashLong_64b_internal(input, len, secret, sizeof(secret),
+ f_acc512, f_scramble);
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSeed(const void* input, size_t len,
+ XXH64_hash_t seed, const xxh_u8* secret, size_t secretLen)
+{
+ (void)secret; (void)secretLen;
+ return XXH3_hashLong_64b_withSeed_internal(input, len, seed,
+ XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret);
+}
+
+
+typedef XXH64_hash_t (*XXH3_hashLong64_f)(const void* XXH_RESTRICT, size_t,
+ XXH64_hash_t, const xxh_u8* XXH_RESTRICT, size_t);
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_64bits_internal(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen,
+ XXH3_hashLong64_f f_hashLong)
+{
+ XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN);
+ /*
+ * If an action is to be taken if `secretLen` condition is not respected,
+ * it should be done here.
+ * For now, it's a contract pre-condition.
+ * Adding a check and a branch here would cost performance at every hash.
+ * Also, note that function signature doesn't offer room to return an error.
+ */
+ if (len <= 16)
+ return XXH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64);
+ if (len <= 128)
+ return XXH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ if (len <= XXH3_MIDSIZE_MAX)
+ return XXH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ return f_hashLong(input, len, seed64, (const xxh_u8*)secret, secretLen);
+}
+
+
+/* === Public entry point === */
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* input, size_t len)
+{
+ return XXH3_64bits_internal(input, len, 0, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_default);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH64_hash_t
+XXH3_64bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize)
+{
+ return XXH3_64bits_internal(input, len, 0, secret, secretSize, XXH3_hashLong_64b_withSecret);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH64_hash_t
+XXH3_64bits_withSeed(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_64bits_internal(input, len, seed, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_withSeed);
+}
+
+
+/* === XXH3 streaming === */
+
+/*
+ * Malloc's a pointer that is always aligned to align.
+ *
+ * This must be freed with `XXH_alignedFree()`.
+ *
+ * malloc typically guarantees 16 byte alignment on 64-bit systems and 8 byte
+ * alignment on 32-bit. This isn't enough for the 32 byte aligned loads in AVX2
+ * or on 32-bit, the 16 byte aligned loads in SSE2 and NEON.
+ *
+ * This underalignment previously caused a rather obvious crash which went
+ * completely unnoticed due to XXH3_createState() not actually being tested.
+ * Credit to RedSpah for noticing this bug.
+ *
+ * The alignment is done manually: Functions like posix_memalign or _mm_malloc
+ * are avoided: To maintain portability, we would have to write a fallback
+ * like this anyways, and besides, testing for the existence of library
+ * functions without relying on external build tools is impossible.
+ *
+ * The method is simple: Overallocate, manually align, and store the offset
+ * to the original behind the returned pointer.
+ *
+ * Align must be a power of 2 and 8 <= align <= 128.
+ */
+static void* XXH_alignedMalloc(size_t s, size_t align)
+{
+ XXH_ASSERT(align <= 128 && align >= 8); /* range check */
+ XXH_ASSERT((align & (align-1)) == 0); /* power of 2 */
+ XXH_ASSERT(s != 0 && s < (s + align)); /* empty/overflow */
+ { /* Overallocate to make room for manual realignment and an offset byte */
+ xxh_u8* base = (xxh_u8*)XXH_malloc(s + align);
+ if (base != NULL) {
+ /*
+ * Get the offset needed to align this pointer.
+ *
+ * Even if the returned pointer is aligned, there will always be
+ * at least one byte to store the offset to the original pointer.
+ */
+ size_t offset = align - ((size_t)base & (align - 1)); /* base % align */
+ /* Add the offset for the now-aligned pointer */
+ xxh_u8* ptr = base + offset;
+
+ XXH_ASSERT((size_t)ptr % align == 0);
+
+ /* Store the offset immediately before the returned pointer. */
+ ptr[-1] = (xxh_u8)offset;
+ return ptr;
+ }
+ return NULL;
+ }
+}
+/*
+ * Frees an aligned pointer allocated by XXH_alignedMalloc(). Don't pass
+ * normal malloc'd pointers, XXH_alignedMalloc has a specific data layout.
+ */
+static void XXH_alignedFree(void* p)
+{
+ if (p != NULL) {
+ xxh_u8* ptr = (xxh_u8*)p;
+ /* Get the offset byte we added in XXH_malloc. */
+ xxh_u8 offset = ptr[-1];
+ /* Free the original malloc'd pointer */
+ xxh_u8* base = ptr - offset;
+ XXH_free(base);
+ }
+}
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void)
+{
+ XXH3_state_t* const state = (XXH3_state_t*)XXH_alignedMalloc(sizeof(XXH3_state_t), 64);
+ if (state==NULL) return NULL;
+ XXH3_INITSTATE(state);
+ return state;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr)
+{
+ XXH_alignedFree(statePtr);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API void
+XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state)
+{
+ memcpy(dst_state, src_state, sizeof(*dst_state));
+}
+
+static void
+XXH3_reset_internal(XXH3_state_t* statePtr,
+ XXH64_hash_t seed,
+ const void* secret, size_t secretSize)
+{
+ size_t const initStart = offsetof(XXH3_state_t, bufferedSize);
+ size_t const initLength = offsetof(XXH3_state_t, nbStripesPerBlock) - initStart;
+ XXH_ASSERT(offsetof(XXH3_state_t, nbStripesPerBlock) > initStart);
+ XXH_ASSERT(statePtr != NULL);
+ /* set members from bufferedSize to nbStripesPerBlock (excluded) to 0 */
+ memset((char*)statePtr + initStart, 0, initLength);
+ statePtr->acc[0] = XXH_PRIME32_3;
+ statePtr->acc[1] = XXH_PRIME64_1;
+ statePtr->acc[2] = XXH_PRIME64_2;
+ statePtr->acc[3] = XXH_PRIME64_3;
+ statePtr->acc[4] = XXH_PRIME64_4;
+ statePtr->acc[5] = XXH_PRIME32_2;
+ statePtr->acc[6] = XXH_PRIME64_5;
+ statePtr->acc[7] = XXH_PRIME32_1;
+ statePtr->seed = seed;
+ statePtr->extSecret = (const unsigned char*)secret;
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN);
+ statePtr->secretLimit = secretSize - XXH_STRIPE_LEN;
+ statePtr->nbStripesPerBlock = statePtr->secretLimit / XXH_SECRET_CONSUME_RATE;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset(XXH3_state_t* statePtr)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_reset_internal(statePtr, 0, secret, secretSize);
+ if (secret == NULL) return XXH_ERROR;
+ if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR;
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ if (seed==0) return XXH3_64bits_reset(statePtr);
+ if (seed != statePtr->seed) XXH3_initCustomSecret(statePtr->customSecret, seed);
+ XXH3_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+/* Note : when XXH3_consumeStripes() is invoked,
+ * there must be a guarantee that at least one more byte must be consumed from input
+ * so that the function can blindly consume all stripes using the "normal" secret segment */
+XXH_FORCE_INLINE void
+XXH3_consumeStripes(xxh_u64* XXH_RESTRICT acc,
+ size_t* XXH_RESTRICT nbStripesSoFarPtr, size_t nbStripesPerBlock,
+ const xxh_u8* XXH_RESTRICT input, size_t nbStripes,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretLimit,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ASSERT(nbStripes <= nbStripesPerBlock); /* can handle max 1 scramble per invocation */
+ XXH_ASSERT(*nbStripesSoFarPtr < nbStripesPerBlock);
+ if (nbStripesPerBlock - *nbStripesSoFarPtr <= nbStripes) {
+ /* need a scrambling operation */
+ size_t const nbStripesToEndofBlock = nbStripesPerBlock - *nbStripesSoFarPtr;
+ size_t const nbStripesAfterBlock = nbStripes - nbStripesToEndofBlock;
+ XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripesToEndofBlock, f_acc512);
+ f_scramble(acc, secret + secretLimit);
+ XXH3_accumulate(acc, input + nbStripesToEndofBlock * XXH_STRIPE_LEN, secret, nbStripesAfterBlock, f_acc512);
+ *nbStripesSoFarPtr = nbStripesAfterBlock;
+ } else {
+ XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripes, f_acc512);
+ *nbStripesSoFarPtr += nbStripes;
+ }
+}
+
+/*
+ * Both XXH3_64bits_update and XXH3_128bits_update use this routine.
+ */
+XXH_FORCE_INLINE XXH_errorcode
+XXH3_update(XXH3_state_t* state,
+ const xxh_u8* input, size_t len,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ if (input==NULL)
+#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1)
+ return XXH_OK;
+#else
+ return XXH_ERROR;
+#endif
+
+ { const xxh_u8* const bEnd = input + len;
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+
+ state->totalLen += len;
+ XXH_ASSERT(state->bufferedSize <= XXH3_INTERNALBUFFER_SIZE);
+
+ if (state->bufferedSize + len <= XXH3_INTERNALBUFFER_SIZE) { /* fill in tmp buffer */
+ XXH_memcpy(state->buffer + state->bufferedSize, input, len);
+ state->bufferedSize += (XXH32_hash_t)len;
+ return XXH_OK;
+ }
+ /* total input is now > XXH3_INTERNALBUFFER_SIZE */
+
+ #define XXH3_INTERNALBUFFER_STRIPES (XXH3_INTERNALBUFFER_SIZE / XXH_STRIPE_LEN)
+ XXH_STATIC_ASSERT(XXH3_INTERNALBUFFER_SIZE % XXH_STRIPE_LEN == 0); /* clean multiple */
+
+ /*
+ * Internal buffer is partially filled (always, except at beginning)
+ * Complete it, then consume it.
+ */
+ if (state->bufferedSize) {
+ size_t const loadSize = XXH3_INTERNALBUFFER_SIZE - state->bufferedSize;
+ XXH_memcpy(state->buffer + state->bufferedSize, input, loadSize);
+ input += loadSize;
+ XXH3_consumeStripes(state->acc,
+ &state->nbStripesSoFar, state->nbStripesPerBlock,
+ state->buffer, XXH3_INTERNALBUFFER_STRIPES,
+ secret, state->secretLimit,
+ f_acc512, f_scramble);
+ state->bufferedSize = 0;
+ }
+ XXH_ASSERT(input < bEnd);
+
+ /* Consume input by a multiple of internal buffer size */
+ if (input+XXH3_INTERNALBUFFER_SIZE < bEnd) {
+ const xxh_u8* const limit = bEnd - XXH3_INTERNALBUFFER_SIZE;
+ do {
+ XXH3_consumeStripes(state->acc,
+ &state->nbStripesSoFar, state->nbStripesPerBlock,
+ input, XXH3_INTERNALBUFFER_STRIPES,
+ secret, state->secretLimit,
+ f_acc512, f_scramble);
+ input += XXH3_INTERNALBUFFER_SIZE;
+ } while (input<limit);
+ /* for last partial stripe */
+ memcpy(state->buffer + sizeof(state->buffer) - XXH_STRIPE_LEN, input - XXH_STRIPE_LEN, XXH_STRIPE_LEN);
+ }
+ XXH_ASSERT(input < bEnd);
+
+ /* Some remaining input (always) : buffer it */
+ XXH_memcpy(state->buffer, input, (size_t)(bEnd-input));
+ state->bufferedSize = (XXH32_hash_t)(bEnd-input);
+ }
+
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_update(XXH3_state_t* state, const void* input, size_t len)
+{
+ return XXH3_update(state, (const xxh_u8*)input, len,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+
+XXH_FORCE_INLINE void
+XXH3_digest_long (XXH64_hash_t* acc,
+ const XXH3_state_t* state,
+ const unsigned char* secret)
+{
+ /*
+ * Digest on a local copy. This way, the state remains unaltered, and it can
+ * continue ingesting more input afterwards.
+ */
+ memcpy(acc, state->acc, sizeof(state->acc));
+ if (state->bufferedSize >= XXH_STRIPE_LEN) {
+ size_t const nbStripes = (state->bufferedSize - 1) / XXH_STRIPE_LEN;
+ size_t nbStripesSoFar = state->nbStripesSoFar;
+ XXH3_consumeStripes(acc,
+ &nbStripesSoFar, state->nbStripesPerBlock,
+ state->buffer, nbStripes,
+ secret, state->secretLimit,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+ /* last stripe */
+ XXH3_accumulate_512(acc,
+ state->buffer + state->bufferedSize - XXH_STRIPE_LEN,
+ secret + state->secretLimit - XXH_SECRET_LASTACC_START);
+ } else { /* bufferedSize < XXH_STRIPE_LEN */
+ xxh_u8 lastStripe[XXH_STRIPE_LEN];
+ size_t const catchupSize = XXH_STRIPE_LEN - state->bufferedSize;
+ XXH_ASSERT(state->bufferedSize > 0); /* there is always some input buffered */
+ memcpy(lastStripe, state->buffer + sizeof(state->buffer) - catchupSize, catchupSize);
+ memcpy(lastStripe + catchupSize, state->buffer, state->bufferedSize);
+ XXH3_accumulate_512(acc,
+ lastStripe,
+ secret + state->secretLimit - XXH_SECRET_LASTACC_START);
+ }
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* state)
+{
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+ if (state->totalLen > XXH3_MIDSIZE_MAX) {
+ XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB];
+ XXH3_digest_long(acc, state, secret);
+ return XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)state->totalLen * XXH_PRIME64_1);
+ }
+ /* totalLen <= XXH3_MIDSIZE_MAX: digesting a short input */
+ if (state->seed)
+ return XXH3_64bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed);
+ return XXH3_64bits_withSecret(state->buffer, (size_t)(state->totalLen),
+ secret, state->secretLimit + XXH_STRIPE_LEN);
+}
+
+
+#define XXH_MIN(x, y) (((x) > (y)) ? (y) : (x))
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API void
+XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize)
+{
+ XXH_ASSERT(secretBuffer != NULL);
+ if (customSeedSize == 0) {
+ memcpy(secretBuffer, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return;
+ }
+ XXH_ASSERT(customSeed != NULL);
+
+ { size_t const segmentSize = sizeof(XXH128_hash_t);
+ size_t const nbSegments = XXH_SECRET_DEFAULT_SIZE / segmentSize;
+ XXH128_canonical_t scrambler;
+ XXH64_hash_t seeds[12];
+ size_t segnb;
+ XXH_ASSERT(nbSegments == 12);
+ XXH_ASSERT(segmentSize * nbSegments == XXH_SECRET_DEFAULT_SIZE); /* exact multiple */
+ XXH128_canonicalFromHash(&scrambler, XXH128(customSeed, customSeedSize, 0));
+
+ /*
+ * Copy customSeed to seeds[], truncating or repeating as necessary.
+ */
+ { size_t toFill = XXH_MIN(customSeedSize, sizeof(seeds));
+ size_t filled = toFill;
+ memcpy(seeds, customSeed, toFill);
+ while (filled < sizeof(seeds)) {
+ toFill = XXH_MIN(filled, sizeof(seeds) - filled);
+ memcpy((char*)seeds + filled, seeds, toFill);
+ filled += toFill;
+ } }
+
+ /* generate secret */
+ memcpy(secretBuffer, &scrambler, sizeof(scrambler));
+ for (segnb=1; segnb < nbSegments; segnb++) {
+ size_t const segmentStart = segnb * segmentSize;
+ XXH128_canonical_t segment;
+ XXH128_canonicalFromHash(&segment,
+ XXH128(&scrambler, sizeof(scrambler), XXH_readLE64(seeds + segnb) + segnb) );
+ memcpy((char*)secretBuffer + segmentStart, &segment, sizeof(segment));
+ } }
+}
+
+
+/* ==========================================
+ * XXH3 128 bits (a.k.a XXH128)
+ * ==========================================
+ * XXH3's 128-bit variant has better mixing and strength than the 64-bit variant,
+ * even without counting the significantly larger output size.
+ *
+ * For example, extra steps are taken to avoid the seed-dependent collisions
+ * in 17-240 byte inputs (See XXH3_mix16B and XXH128_mix32B).
+ *
+ * This strength naturally comes at the cost of some speed, especially on short
+ * lengths. Note that longer hashes are about as fast as the 64-bit version
+ * due to it using only a slight modification of the 64-bit loop.
+ *
+ * XXH128 is also more oriented towards 64-bit machines. It is still extremely
+ * fast for a _128-bit_ hash on 32-bit (it usually clears XXH64).
+ */
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_1to3_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ /* A doubled version of 1to3_64b with different constants. */
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(1 <= len && len <= 3);
+ XXH_ASSERT(secret != NULL);
+ /*
+ * len = 1: combinedl = { input[0], 0x01, input[0], input[0] }
+ * len = 2: combinedl = { input[1], 0x02, input[0], input[1] }
+ * len = 3: combinedl = { input[2], 0x03, input[0], input[1] }
+ */
+ { xxh_u8 const c1 = input[0];
+ xxh_u8 const c2 = input[len >> 1];
+ xxh_u8 const c3 = input[len - 1];
+ xxh_u32 const combinedl = ((xxh_u32)c1 <<16) | ((xxh_u32)c2 << 24)
+ | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8);
+ xxh_u32 const combinedh = XXH_rotl32(XXH_swap32(combinedl), 13);
+ xxh_u64 const bitflipl = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed;
+ xxh_u64 const bitfliph = (XXH_readLE32(secret+8) ^ XXH_readLE32(secret+12)) - seed;
+ xxh_u64 const keyed_lo = (xxh_u64)combinedl ^ bitflipl;
+ xxh_u64 const keyed_hi = (xxh_u64)combinedh ^ bitfliph;
+ XXH128_hash_t h128;
+ h128.low64 = XXH64_avalanche(keyed_lo);
+ h128.high64 = XXH64_avalanche(keyed_hi);
+ return h128;
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_4to8_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(4 <= len && len <= 8);
+ seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32;
+ { xxh_u32 const input_lo = XXH_readLE32(input);
+ xxh_u32 const input_hi = XXH_readLE32(input + len - 4);
+ xxh_u64 const input_64 = input_lo + ((xxh_u64)input_hi << 32);
+ xxh_u64 const bitflip = (XXH_readLE64(secret+16) ^ XXH_readLE64(secret+24)) + seed;
+ xxh_u64 const keyed = input_64 ^ bitflip;
+
+ /* Shift len to the left to ensure it is even, this avoids even multiplies. */
+ XXH128_hash_t m128 = XXH_mult64to128(keyed, XXH_PRIME64_1 + (len << 2));
+
+ m128.high64 += (m128.low64 << 1);
+ m128.low64 ^= (m128.high64 >> 3);
+
+ m128.low64 = XXH_xorshift64(m128.low64, 35);
+ m128.low64 *= 0x9FB21C651E98DF25ULL;
+ m128.low64 = XXH_xorshift64(m128.low64, 28);
+ m128.high64 = XXH3_avalanche(m128.high64);
+ return m128;
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_9to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(input != NULL);
+ XXH_ASSERT(secret != NULL);
+ XXH_ASSERT(9 <= len && len <= 16);
+ { xxh_u64 const bitflipl = (XXH_readLE64(secret+32) ^ XXH_readLE64(secret+40)) - seed;
+ xxh_u64 const bitfliph = (XXH_readLE64(secret+48) ^ XXH_readLE64(secret+56)) + seed;
+ xxh_u64 const input_lo = XXH_readLE64(input);
+ xxh_u64 input_hi = XXH_readLE64(input + len - 8);
+ XXH128_hash_t m128 = XXH_mult64to128(input_lo ^ input_hi ^ bitflipl, XXH_PRIME64_1);
+ /*
+ * Put len in the middle of m128 to ensure that the length gets mixed to
+ * both the low and high bits in the 128x64 multiply below.
+ */
+ m128.low64 += (xxh_u64)(len - 1) << 54;
+ input_hi ^= bitfliph;
+ /*
+ * Add the high 32 bits of input_hi to the high 32 bits of m128, then
+ * add the long product of the low 32 bits of input_hi and XXH_PRIME32_2 to
+ * the high 64 bits of m128.
+ *
+ * The best approach to this operation is different on 32-bit and 64-bit.
+ */
+ if (sizeof(void *) < sizeof(xxh_u64)) { /* 32-bit */
+ /*
+ * 32-bit optimized version, which is more readable.
+ *
+ * On 32-bit, it removes an ADC and delays a dependency between the two
+ * halves of m128.high64, but it generates an extra mask on 64-bit.
+ */
+ m128.high64 += (input_hi & 0xFFFFFFFF00000000ULL) + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2);
+ } else {
+ /*
+ * 64-bit optimized (albeit more confusing) version.
+ *
+ * Uses some properties of addition and multiplication to remove the mask:
+ *
+ * Let:
+ * a = input_hi.lo = (input_hi & 0x00000000FFFFFFFF)
+ * b = input_hi.hi = (input_hi & 0xFFFFFFFF00000000)
+ * c = XXH_PRIME32_2
+ *
+ * a + (b * c)
+ * Inverse Property: x + y - x == y
+ * a + (b * (1 + c - 1))
+ * Distributive Property: x * (y + z) == (x * y) + (x * z)
+ * a + (b * 1) + (b * (c - 1))
+ * Identity Property: x * 1 == x
+ * a + b + (b * (c - 1))
+ *
+ * Substitute a, b, and c:
+ * input_hi.hi + input_hi.lo + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1))
+ *
+ * Since input_hi.hi + input_hi.lo == input_hi, we get this:
+ * input_hi + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1))
+ */
+ m128.high64 += input_hi + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2 - 1);
+ }
+ /* m128 ^= XXH_swap64(m128 >> 64); */
+ m128.low64 ^= XXH_swap64(m128.high64);
+
+ { /* 128x64 multiply: h128 = m128 * XXH_PRIME64_2; */
+ XXH128_hash_t h128 = XXH_mult64to128(m128.low64, XXH_PRIME64_2);
+ h128.high64 += m128.high64 * XXH_PRIME64_2;
+
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = XXH3_avalanche(h128.high64);
+ return h128;
+ } }
+}
+
+/*
+ * Assumption: `secret` size is >= XXH3_SECRET_SIZE_MIN
+ */
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_0to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed)
+{
+ XXH_ASSERT(len <= 16);
+ { if (len > 8) return XXH3_len_9to16_128b(input, len, secret, seed);
+ if (len >= 4) return XXH3_len_4to8_128b(input, len, secret, seed);
+ if (len) return XXH3_len_1to3_128b(input, len, secret, seed);
+ { XXH128_hash_t h128;
+ xxh_u64 const bitflipl = XXH_readLE64(secret+64) ^ XXH_readLE64(secret+72);
+ xxh_u64 const bitfliph = XXH_readLE64(secret+80) ^ XXH_readLE64(secret+88);
+ h128.low64 = XXH64_avalanche(seed ^ bitflipl);
+ h128.high64 = XXH64_avalanche( seed ^ bitfliph);
+ return h128;
+ } }
+}
+
+/*
+ * A bit slower than XXH3_mix16B, but handles multiply by zero better.
+ */
+XXH_FORCE_INLINE XXH128_hash_t
+XXH128_mix32B(XXH128_hash_t acc, const xxh_u8* input_1, const xxh_u8* input_2,
+ const xxh_u8* secret, XXH64_hash_t seed)
+{
+ acc.low64 += XXH3_mix16B (input_1, secret+0, seed);
+ acc.low64 ^= XXH_readLE64(input_2) + XXH_readLE64(input_2 + 8);
+ acc.high64 += XXH3_mix16B (input_2, secret+16, seed);
+ acc.high64 ^= XXH_readLE64(input_1) + XXH_readLE64(input_1 + 8);
+ return acc;
+}
+
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_len_17to128_128b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(16 < len && len <= 128);
+
+ { XXH128_hash_t acc;
+ acc.low64 = len * XXH_PRIME64_1;
+ acc.high64 = 0;
+ if (len > 32) {
+ if (len > 64) {
+ if (len > 96) {
+ acc = XXH128_mix32B(acc, input+48, input+len-64, secret+96, seed);
+ }
+ acc = XXH128_mix32B(acc, input+32, input+len-48, secret+64, seed);
+ }
+ acc = XXH128_mix32B(acc, input+16, input+len-32, secret+32, seed);
+ }
+ acc = XXH128_mix32B(acc, input, input+len-16, secret, seed);
+ { XXH128_hash_t h128;
+ h128.low64 = acc.low64 + acc.high64;
+ h128.high64 = (acc.low64 * XXH_PRIME64_1)
+ + (acc.high64 * XXH_PRIME64_4)
+ + ((len - seed) * XXH_PRIME64_2);
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+ return h128;
+ }
+ }
+}
+
+XXH_NO_INLINE XXH128_hash_t
+XXH3_len_129to240_128b(const xxh_u8* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH64_hash_t seed)
+{
+ XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+ { XXH128_hash_t acc;
+ int const nbRounds = (int)len / 32;
+ int i;
+ acc.low64 = len * XXH_PRIME64_1;
+ acc.high64 = 0;
+ for (i=0; i<4; i++) {
+ acc = XXH128_mix32B(acc,
+ input + (32 * i),
+ input + (32 * i) + 16,
+ secret + (32 * i),
+ seed);
+ }
+ acc.low64 = XXH3_avalanche(acc.low64);
+ acc.high64 = XXH3_avalanche(acc.high64);
+ XXH_ASSERT(nbRounds >= 4);
+ for (i=4 ; i < nbRounds; i++) {
+ acc = XXH128_mix32B(acc,
+ input + (32 * i),
+ input + (32 * i) + 16,
+ secret + XXH3_MIDSIZE_STARTOFFSET + (32 * (i - 4)),
+ seed);
+ }
+ /* last bytes */
+ acc = XXH128_mix32B(acc,
+ input + len - 16,
+ input + len - 32,
+ secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET - 16,
+ 0ULL - seed);
+
+ { XXH128_hash_t h128;
+ h128.low64 = acc.low64 + acc.high64;
+ h128.high64 = (acc.low64 * XXH_PRIME64_1)
+ + (acc.high64 * XXH_PRIME64_4)
+ + ((len - seed) * XXH_PRIME64_2);
+ h128.low64 = XXH3_avalanche(h128.low64);
+ h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+ return h128;
+ }
+ }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_hashLong_128b_internal(const void* XXH_RESTRICT input, size_t len,
+ const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble)
+{
+ XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+ XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, secret, secretSize, f_acc512, f_scramble);
+
+ /* converge into final hash */
+ XXH_STATIC_ASSERT(sizeof(acc) == 64);
+ XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ { XXH128_hash_t h128;
+ h128.low64 = XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)len * XXH_PRIME64_1);
+ h128.high64 = XXH3_mergeAccs(acc,
+ secret + secretSize
+ - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+ ~((xxh_u64)len * XXH_PRIME64_2));
+ return h128;
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_default(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64; (void)secret; (void)secretLen;
+ return XXH3_hashLong_128b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSecret(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)seed64;
+ return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, secretLen,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSeed_internal(const void* XXH_RESTRICT input, size_t len,
+ XXH64_hash_t seed64,
+ XXH3_f_accumulate_512 f_acc512,
+ XXH3_f_scrambleAcc f_scramble,
+ XXH3_f_initCustomSecret f_initSec)
+{
+ if (seed64 == 0)
+ return XXH3_hashLong_128b_internal(input, len,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ f_acc512, f_scramble);
+ { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE];
+ f_initSec(secret, seed64);
+ return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, sizeof(secret),
+ f_acc512, f_scramble);
+ }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong is not inlined.
+ */
+XXH_NO_INLINE XXH128_hash_t
+XXH3_hashLong_128b_withSeed(const void* input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen)
+{
+ (void)secret; (void)secretLen;
+ return XXH3_hashLong_128b_withSeed_internal(input, len, seed64,
+ XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret);
+}
+
+typedef XXH128_hash_t (*XXH3_hashLong128_f)(const void* XXH_RESTRICT, size_t,
+ XXH64_hash_t, const void* XXH_RESTRICT, size_t);
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_128bits_internal(const void* input, size_t len,
+ XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen,
+ XXH3_hashLong128_f f_hl128)
+{
+ XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN);
+ /*
+ * If an action is to be taken if `secret` conditions are not respected,
+ * it should be done here.
+ * For now, it's a contract pre-condition.
+ * Adding a check and a branch here would cost performance at every hash.
+ */
+ if (len <= 16)
+ return XXH3_len_0to16_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64);
+ if (len <= 128)
+ return XXH3_len_17to128_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ if (len <= XXH3_MIDSIZE_MAX)
+ return XXH3_len_129to240_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64);
+ return f_hl128(input, len, seed64, secret, secretLen);
+}
+
+
+/* === Public XXH128 API === */
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* input, size_t len)
+{
+ return XXH3_128bits_internal(input, len, 0,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_hashLong_128b_default);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t
+XXH3_128bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize)
+{
+ return XXH3_128bits_internal(input, len, 0,
+ (const xxh_u8*)secret, secretSize,
+ XXH3_hashLong_128b_withSecret);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t
+XXH3_128bits_withSeed(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_128bits_internal(input, len, seed,
+ XXH3_kSecret, sizeof(XXH3_kSecret),
+ XXH3_hashLong_128b_withSeed);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t
+XXH128(const void* input, size_t len, XXH64_hash_t seed)
+{
+ return XXH3_128bits_withSeed(input, len, seed);
+}
+
+
+/* === XXH3 128-bit streaming === */
+
+/*
+ * All the functions are actually the same as for 64-bit streaming variant.
+ * The only difference is the finalization routine.
+ */
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset(XXH3_state_t* statePtr)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ XXH3_reset_internal(statePtr, 0, secret, secretSize);
+ if (secret == NULL) return XXH_ERROR;
+ if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR;
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed)
+{
+ if (statePtr == NULL) return XXH_ERROR;
+ if (seed==0) return XXH3_128bits_reset(statePtr);
+ if (seed != statePtr->seed) XXH3_initCustomSecret(statePtr->customSecret, seed);
+ XXH3_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE);
+ return XXH_OK;
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_update(XXH3_state_t* state, const void* input, size_t len)
+{
+ return XXH3_update(state, (const xxh_u8*)input, len,
+ XXH3_accumulate_512, XXH3_scrambleAcc);
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* state)
+{
+ const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+ if (state->totalLen > XXH3_MIDSIZE_MAX) {
+ XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB];
+ XXH3_digest_long(acc, state, secret);
+ XXH_ASSERT(state->secretLimit + XXH_STRIPE_LEN >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+ { XXH128_hash_t h128;
+ h128.low64 = XXH3_mergeAccs(acc,
+ secret + XXH_SECRET_MERGEACCS_START,
+ (xxh_u64)state->totalLen * XXH_PRIME64_1);
+ h128.high64 = XXH3_mergeAccs(acc,
+ secret + state->secretLimit + XXH_STRIPE_LEN
+ - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+ ~((xxh_u64)state->totalLen * XXH_PRIME64_2));
+ return h128;
+ }
+ }
+ /* len <= XXH3_MIDSIZE_MAX : short code */
+ if (state->seed)
+ return XXH3_128bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed);
+ return XXH3_128bits_withSecret(state->buffer, (size_t)(state->totalLen),
+ secret, state->secretLimit + XXH_STRIPE_LEN);
+}
+
+/* 128-bit utility functions */
+
+#include <string.h> /* memcmp, memcpy */
+
+/* return : 1 is equal, 0 if different */
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2)
+{
+ /* note : XXH128_hash_t is compact, it has no padding byte */
+ return !(memcmp(&h1, &h2, sizeof(h1)));
+}
+
+/* This prototype is compatible with stdlib's qsort().
+ * return : >0 if *h128_1 > *h128_2
+ * <0 if *h128_1 < *h128_2
+ * =0 if *h128_1 == *h128_2 */
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2)
+{
+ XXH128_hash_t const h1 = *(const XXH128_hash_t*)h128_1;
+ XXH128_hash_t const h2 = *(const XXH128_hash_t*)h128_2;
+ int const hcmp = (h1.high64 > h2.high64) - (h2.high64 > h1.high64);
+ /* note : bets that, in most cases, hash values are different */
+ if (hcmp) return hcmp;
+ return (h1.low64 > h2.low64) - (h2.low64 > h1.low64);
+}
+
+
+/*====== Canonical representation ======*/
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API void
+XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash)
+{
+ XXH_STATIC_ASSERT(sizeof(XXH128_canonical_t) == sizeof(XXH128_hash_t));
+ if (XXH_CPU_LITTLE_ENDIAN) {
+ hash.high64 = XXH_swap64(hash.high64);
+ hash.low64 = XXH_swap64(hash.low64);
+ }
+ memcpy(dst, &hash.high64, sizeof(hash.high64));
+ memcpy((char*)dst + sizeof(hash.high64), &hash.low64, sizeof(hash.low64));
+}
+
+/*! @ingroup xxh3_family */
+XXH_PUBLIC_API XXH128_hash_t
+XXH128_hashFromCanonical(const XXH128_canonical_t* src)
+{
+ XXH128_hash_t h;
+ h.high64 = XXH_readBE64(src);
+ h.low64 = XXH_readBE64(src->digest + 8);
+ return h;
+}
+
+/* Pop our optimization override from above */
+#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \
+ && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+ && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */
+# pragma GCC pop_options
+#endif
+
+#endif /* XXH_NO_LONG_LONG */
+
+#endif /* XXH_NO_XXH3 */
+
+/*!
+ * @}
+ */
+#endif /* XXH_IMPLEMENTATION */
+
+
+#if defined (__cplusplus)
+}
+#endif
diff --git a/src/rocksdb/util/xxph3.h b/src/rocksdb/util/xxph3.h
new file mode 100644
index 000000000..968000c3a
--- /dev/null
+++ b/src/rocksdb/util/xxph3.h
@@ -0,0 +1,1764 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+/*
+ xxHash - Extremely Fast Hash algorithm
+ Header File
+ Copyright (C) 2012-2016, Yann Collet.
+
+ BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ You can contact the author at :
+ - xxHash source repository : https://github.com/Cyan4973/xxHash
+*/
+
+// This is a fork of a preview version of xxHash, as RocksDB depends on
+// this preview version of XXH3. To allow this to coexist with the
+// standard xxHash, including in the "unity" build where all source files
+// and headers go into a single translation unit, here "XXH" has been
+// replaced with "XXPH" for XX Preview Hash.
+
+#ifndef XXPHASH_H_5627135585666179
+#define XXPHASH_H_5627135585666179 1
+
+/* BEGIN RocksDB customizations */
+#ifndef XXPH_STATIC_LINKING_ONLY
+// Access experimental APIs
+#define XXPH_STATIC_LINKING_ONLY 1
+#endif
+#define XXPH_NAMESPACE ROCKSDB_
+#define XXPH_INLINE_ALL
+#include <cstring>
+/* END RocksDB customizations */
+
+// clang-format off
+#if defined (__cplusplus)
+extern "C" {
+#endif
+
+
+/* ****************************
+* Definitions
+******************************/
+#include <stddef.h> /* size_t */
+typedef enum { XXPH_OK=0, XXPH_ERROR } XXPH_errorcode;
+
+
+/* ****************************
+ * API modifier
+ ******************************/
+/** XXPH_INLINE_ALL (and XXPH_PRIVATE_API)
+ * This build macro includes xxhash functions in `static` mode
+ * in order to inline them, and remove their symbol from the public list.
+ * Inlining offers great performance improvement on small keys,
+ * and dramatic ones when length is expressed as a compile-time constant.
+ * See https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html .
+ * Methodology :
+ * #define XXPH_INLINE_ALL
+ * #include "xxhash.h"
+ * `xxhash.c` is automatically included.
+ * It's not useful to compile and link it as a separate object.
+ */
+#if defined(XXPH_INLINE_ALL) || defined(XXPH_PRIVATE_API)
+# ifndef XXPH_STATIC_LINKING_ONLY
+# define XXPH_STATIC_LINKING_ONLY
+# endif
+# if defined(__GNUC__)
+# define XXPH_PUBLIC_API static __inline __attribute__((unused))
+# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
+# define XXPH_PUBLIC_API static inline
+# elif defined(_MSC_VER)
+# define XXPH_PUBLIC_API static __inline
+# else
+ /* this version may generate warnings for unused static functions */
+# define XXPH_PUBLIC_API static
+# endif
+#else
+# if defined(WIN32) && defined(_MSC_VER) && (defined(XXPH_IMPORT) || defined(XXPH_EXPORT))
+# ifdef XXPH_EXPORT
+# define XXPH_PUBLIC_API __declspec(dllexport)
+# elif XXPH_IMPORT
+# define XXPH_PUBLIC_API __declspec(dllimport)
+# endif
+# else
+# define XXPH_PUBLIC_API /* do nothing */
+# endif
+#endif /* XXPH_INLINE_ALL || XXPH_PRIVATE_API */
+
+/*! XXPH_NAMESPACE, aka Namespace Emulation :
+ *
+ * If you want to include _and expose_ xxHash functions from within your own library,
+ * but also want to avoid symbol collisions with other libraries which may also include xxHash,
+ *
+ * you can use XXPH_NAMESPACE, to automatically prefix any public symbol from xxhash library
+ * with the value of XXPH_NAMESPACE (therefore, avoid NULL and numeric values).
+ *
+ * Note that no change is required within the calling program as long as it includes `xxhash.h` :
+ * regular symbol name will be automatically translated by this header.
+ */
+#ifdef XXPH_NAMESPACE
+# define XXPH_CAT(A,B) A##B
+# define XXPH_NAME2(A,B) XXPH_CAT(A,B)
+# define XXPH_versionNumber XXPH_NAME2(XXPH_NAMESPACE, XXPH_versionNumber)
+#endif
+
+
+/* *************************************
+* Version
+***************************************/
+#define XXPH_VERSION_MAJOR 0
+#define XXPH_VERSION_MINOR 7
+#define XXPH_VERSION_RELEASE 2
+#define XXPH_VERSION_NUMBER (XXPH_VERSION_MAJOR *100*100 + XXPH_VERSION_MINOR *100 + XXPH_VERSION_RELEASE)
+XXPH_PUBLIC_API unsigned XXPH_versionNumber (void);
+
+
+/*-**********************************************************************
+* 32-bit hash
+************************************************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint32_t XXPH32_hash_t;
+#else
+# include <limits.h>
+# if UINT_MAX == 0xFFFFFFFFUL
+ typedef unsigned int XXPH32_hash_t;
+# else
+# if ULONG_MAX == 0xFFFFFFFFUL
+ typedef unsigned long XXPH32_hash_t;
+# else
+# error "unsupported platform : need a 32-bit type"
+# endif
+# endif
+#endif
+
+#ifndef XXPH_NO_LONG_LONG
+/*-**********************************************************************
+* 64-bit hash
+************************************************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint64_t XXPH64_hash_t;
+#else
+ /* the following type must have a width of 64-bit */
+ typedef unsigned long long XXPH64_hash_t;
+#endif
+
+#endif /* XXPH_NO_LONG_LONG */
+
+
+
+#ifdef XXPH_STATIC_LINKING_ONLY
+
+/* ================================================================================================
+ This section contains declarations which are not guaranteed to remain stable.
+ They may change in future versions, becoming incompatible with a different version of the library.
+ These declarations should only be used with static linking.
+ Never use them in association with dynamic linking !
+=================================================================================================== */
+
+
+/*-**********************************************************************
+* XXPH3
+* New experimental hash
+************************************************************************/
+#ifndef XXPH_NO_LONG_LONG
+
+
+/* ============================================
+ * XXPH3 is a new hash algorithm,
+ * featuring improved speed performance for both small and large inputs.
+ * See full speed analysis at : http://fastcompression.blogspot.com/2019/03/presenting-xxh3.html
+ * In general, expect XXPH3 to run about ~2x faster on large inputs,
+ * and >3x faster on small ones, though exact differences depend on platform.
+ *
+ * The algorithm is portable, will generate the same hash on all platforms.
+ * It benefits greatly from vectorization units, but does not require it.
+ *
+ * XXPH3 offers 2 variants, _64bits and _128bits.
+ * When only 64 bits are needed, prefer calling the _64bits variant :
+ * it reduces the amount of mixing, resulting in faster speed on small inputs.
+ * It's also generally simpler to manipulate a scalar return type than a struct.
+ *
+ * The XXPH3 algorithm is still considered experimental.
+ * Produced results can still change between versions.
+ * Results produced by v0.7.x are not comparable with results from v0.7.y .
+ * It's nonetheless possible to use XXPH3 for ephemeral data (local sessions),
+ * but avoid storing values in long-term storage for later reads.
+ *
+ * The API supports one-shot hashing, streaming mode, and custom secrets.
+ *
+ * There are still a number of opened questions that community can influence during the experimental period.
+ * I'm trying to list a few of them below, though don't consider this list as complete.
+ *
+ * - 128-bits output type : currently defined as a structure of two 64-bits fields.
+ * That's because 128-bit values do not exist in C standard.
+ * Note that it means that, at byte level, result is not identical depending on endianess.
+ * However, at field level, they are identical on all platforms.
+ * The canonical representation solves the issue of identical byte-level representation across platforms,
+ * which is necessary for serialization.
+ * Q1 : Would there be a better representation for a 128-bit hash result ?
+ * Q2 : Are the names of the inner 64-bit fields important ? Should they be changed ?
+ *
+ * - Prototype XXPH128() : XXPH128() uses the same arguments as XXPH64(), for consistency.
+ * It means it maps to XXPH3_128bits_withSeed().
+ * This variant is slightly slower than XXPH3_128bits(),
+ * because the seed is now part of the algorithm, and can't be simplified.
+ * Is that a good idea ?
+ *
+ * - Seed type for XXPH128() : currently, it's a single 64-bit value, like the 64-bit variant.
+ * It could be argued that it's more logical to offer a 128-bit seed input parameter for a 128-bit hash.
+ * But 128-bit seed is more difficult to use, since it requires to pass a structure instead of a scalar value.
+ * Such a variant could either replace current one, or become an additional one.
+ * Farmhash, for example, offers both variants (the 128-bits seed variant is called `doubleSeed`).
+ * Follow up question : if both 64-bit and 128-bit seeds are allowed, which variant should be called XXPH128 ?
+ *
+ * - Result for len==0 : Currently, the result of hashing a zero-length input is always `0`.
+ * It seems okay as a return value when using "default" secret and seed.
+ * But is it still fine to return `0` when secret or seed are non-default ?
+ * Are there use cases which could depend on generating a different hash result for zero-length input when the secret is different ?
+ *
+ * - Consistency (1) : Streaming XXPH128 uses an XXPH3 state, which is the same state as XXPH3_64bits().
+ * It means a 128bit streaming loop must invoke the following symbols :
+ * XXPH3_createState(), XXPH3_128bits_reset(), XXPH3_128bits_update() (loop), XXPH3_128bits_digest(), XXPH3_freeState().
+ * Is that consistent enough ?
+ *
+ * - Consistency (2) : The canonical representation of `XXPH3_64bits` is provided by existing functions
+ * XXPH64_canonicalFromHash(), and reverse operation XXPH64_hashFromCanonical().
+ * As a mirror, canonical functions for XXPH128_hash_t results generated by `XXPH3_128bits`
+ * are XXPH128_canonicalFromHash() and XXPH128_hashFromCanonical().
+ * Which means, `XXPH3` doesn't appear in the names, because canonical functions operate on a type,
+ * independently of which algorithm was used to generate that type.
+ * Is that consistent enough ?
+ */
+
+#ifdef XXPH_NAMESPACE
+# define XXPH3_64bits XXPH_NAME2(XXPH_NAMESPACE, XXPH3_64bits)
+# define XXPH3_64bits_withSecret XXPH_NAME2(XXPH_NAMESPACE, XXPH3_64bits_withSecret)
+# define XXPH3_64bits_withSeed XXPH_NAME2(XXPH_NAMESPACE, XXPH3_64bits_withSeed)
+#endif
+
+/* XXPH3_64bits() :
+ * default 64-bit variant, using default secret and default seed of 0.
+ * It's the fastest variant. */
+XXPH_PUBLIC_API XXPH64_hash_t XXPH3_64bits(const void* data, size_t len);
+
+/* XXPH3_64bits_withSecret() :
+ * It's possible to provide any blob of bytes as a "secret" to generate the hash.
+ * This makes it more difficult for an external actor to prepare an intentional collision.
+ * The secret *must* be large enough (>= XXPH3_SECRET_SIZE_MIN).
+ * It should consist of random bytes.
+ * Avoid repeating same character, or sequences of bytes,
+ * and especially avoid swathes of \0.
+ * Failure to respect these conditions will result in a poor quality hash.
+ */
+#define XXPH3_SECRET_SIZE_MIN 136
+XXPH_PUBLIC_API XXPH64_hash_t XXPH3_64bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize);
+
+/* XXPH3_64bits_withSeed() :
+ * This variant generates on the fly a custom secret,
+ * based on the default secret, altered using the `seed` value.
+ * While this operation is decently fast, note that it's not completely free.
+ * note : seed==0 produces same results as XXPH3_64bits() */
+XXPH_PUBLIC_API XXPH64_hash_t XXPH3_64bits_withSeed(const void* data, size_t len, XXPH64_hash_t seed);
+
+#if defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11+ */
+# include <stdalign.h>
+# define XXPH_ALIGN(n) alignas(n)
+#elif defined(__GNUC__)
+# define XXPH_ALIGN(n) __attribute__ ((aligned(n)))
+#elif defined(_MSC_VER)
+# define XXPH_ALIGN(n) __declspec(align(n))
+#else
+# define XXPH_ALIGN(n) /* disabled */
+#endif
+
+#define XXPH3_SECRET_DEFAULT_SIZE 192 /* minimum XXPH3_SECRET_SIZE_MIN */
+
+#endif /* XXPH_NO_LONG_LONG */
+
+
+/*-**********************************************************************
+* XXPH_INLINE_ALL
+************************************************************************/
+#if defined(XXPH_INLINE_ALL) || defined(XXPH_PRIVATE_API)
+
+/* === RocksDB modification: was #include here but permanently inlining === */
+
+typedef struct {
+ XXPH64_hash_t low64;
+ XXPH64_hash_t high64;
+} XXPH128_hash_t;
+
+/* *************************************
+* Tuning parameters
+***************************************/
+/*!XXPH_FORCE_MEMORY_ACCESS :
+ * By default, access to unaligned memory is controlled by `memcpy()`, which is safe and portable.
+ * Unfortunately, on some target/compiler combinations, the generated assembly is sub-optimal.
+ * The below switch allow to select different access method for improved performance.
+ * Method 0 (default) : use `memcpy()`. Safe and portable.
+ * Method 1 : `__packed` statement. It depends on compiler extension (ie, not portable).
+ * This method is safe if your compiler supports it, and *generally* as fast or faster than `memcpy`.
+ * Method 2 : direct access. This method doesn't depend on compiler but violate C standard.
+ * It can generate buggy code on targets which do not support unaligned memory accesses.
+ * But in some circumstances, it's the only known way to get the most performance (ie GCC + ARMv6)
+ * See http://stackoverflow.com/a/32095106/646947 for details.
+ * Prefer these methods in priority order (0 > 1 > 2)
+ */
+#ifndef XXPH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */
+# if !defined(__clang__) && defined(__GNUC__) && defined(__ARM_FEATURE_UNALIGNED) && defined(__ARM_ARCH) && (__ARM_ARCH == 6)
+# define XXPH_FORCE_MEMORY_ACCESS 2
+# elif !defined(__clang__) && ((defined(__INTEL_COMPILER) && !defined(_WIN32)) || \
+ (defined(__GNUC__) && (defined(__ARM_ARCH) && __ARM_ARCH >= 7)))
+# define XXPH_FORCE_MEMORY_ACCESS 1
+# endif
+#endif
+
+/*!XXPH_ACCEPT_NULL_INPUT_POINTER :
+ * If input pointer is NULL, xxHash default behavior is to dereference it, triggering a segfault.
+ * When this macro is enabled, xxHash actively checks input for null pointer.
+ * It it is, result for null input pointers is the same as a null-length input.
+ */
+#ifndef XXPH_ACCEPT_NULL_INPUT_POINTER /* can be defined externally */
+# define XXPH_ACCEPT_NULL_INPUT_POINTER 0
+#endif
+
+/*!XXPH_FORCE_ALIGN_CHECK :
+ * This is a minor performance trick, only useful with lots of very small keys.
+ * It means : check for aligned/unaligned input.
+ * The check costs one initial branch per hash;
+ * set it to 0 when the input is guaranteed to be aligned,
+ * or when alignment doesn't matter for performance.
+ */
+#ifndef XXPH_FORCE_ALIGN_CHECK /* can be defined externally */
+# if defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64)
+# define XXPH_FORCE_ALIGN_CHECK 0
+# else
+# define XXPH_FORCE_ALIGN_CHECK 1
+# endif
+#endif
+
+/*!XXPH_REROLL:
+ * Whether to reroll XXPH32_finalize, and XXPH64_finalize,
+ * instead of using an unrolled jump table/if statement loop.
+ *
+ * This is automatically defined on -Os/-Oz on GCC and Clang. */
+#ifndef XXPH_REROLL
+# if defined(__OPTIMIZE_SIZE__)
+# define XXPH_REROLL 1
+# else
+# define XXPH_REROLL 0
+# endif
+#endif
+
+#include <limits.h> /* ULLONG_MAX */
+
+#ifndef XXPH_STATIC_LINKING_ONLY
+#define XXPH_STATIC_LINKING_ONLY
+#endif
+
+/* BEGIN RocksDB customizations */
+#include "port/lang.h" /* for FALLTHROUGH_INTENDED, inserted as appropriate */
+/* END RocksDB customizations */
+
+/* *************************************
+* Compiler Specific Options
+***************************************/
+#ifdef _MSC_VER /* Visual Studio */
+# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
+# define XXPH_FORCE_INLINE static __forceinline
+# define XXPH_NO_INLINE static __declspec(noinline)
+#else
+# if defined (__cplusplus) || defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */
+# ifdef __GNUC__
+# define XXPH_FORCE_INLINE static inline __attribute__((always_inline))
+# define XXPH_NO_INLINE static __attribute__((noinline))
+# else
+# define XXPH_FORCE_INLINE static inline
+# define XXPH_NO_INLINE static
+# endif
+# else
+# define XXPH_FORCE_INLINE static
+# define XXPH_NO_INLINE static
+# endif /* __STDC_VERSION__ */
+#endif
+
+
+
+/* *************************************
+* Debug
+***************************************/
+/* DEBUGLEVEL is expected to be defined externally,
+ * typically through compiler command line.
+ * Value must be a number. */
+#ifndef DEBUGLEVEL
+# define DEBUGLEVEL 0
+#endif
+
+#if (DEBUGLEVEL>=1)
+# include <assert.h> /* note : can still be disabled with NDEBUG */
+# define XXPH_ASSERT(c) assert(c)
+#else
+# define XXPH_ASSERT(c) ((void)0)
+#endif
+
+/* note : use after variable declarations */
+#define XXPH_STATIC_ASSERT(c) { enum { XXPH_sa = 1/(int)(!!(c)) }; }
+
+
+/* *************************************
+* Basic Types
+***************************************/
+#if !defined (__VMS) \
+ && (defined (__cplusplus) \
+ || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+# include <stdint.h>
+ typedef uint8_t xxh_u8;
+#else
+ typedef unsigned char xxh_u8;
+#endif
+typedef XXPH32_hash_t xxh_u32;
+
+
+/* === Memory access === */
+
+#if (defined(XXPH_FORCE_MEMORY_ACCESS) && (XXPH_FORCE_MEMORY_ACCESS==2))
+
+/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */
+static xxh_u32 XXPH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; }
+
+#elif (defined(XXPH_FORCE_MEMORY_ACCESS) && (XXPH_FORCE_MEMORY_ACCESS==1))
+
+/* __pack instructions are safer, but compiler specific, hence potentially problematic for some compilers */
+/* currently only defined for gcc and icc */
+typedef union { xxh_u32 u32; } __attribute__((packed)) unalign;
+static xxh_u32 XXPH_read32(const void* ptr) { return ((const unalign*)ptr)->u32; }
+
+#else
+
+/* portable and safe solution. Generally efficient.
+ * see : http://stackoverflow.com/a/32095106/646947
+ */
+static xxh_u32 XXPH_read32(const void* memPtr)
+{
+ xxh_u32 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXPH_FORCE_DIRECT_MEMORY_ACCESS */
+
+
+/* === Endianess === */
+
+/* XXPH_CPU_LITTLE_ENDIAN can be defined externally, for example on the compiler command line */
+#ifndef XXPH_CPU_LITTLE_ENDIAN
+# if defined(_WIN32) /* Windows is always little endian */ \
+ || defined(__LITTLE_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+# define XXPH_CPU_LITTLE_ENDIAN 1
+# elif defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXPH_CPU_LITTLE_ENDIAN 0
+# else
+static int XXPH_isLittleEndian(void)
+{
+ const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 }; /* don't use static : performance detrimental */
+ return one.c[0];
+}
+# define XXPH_CPU_LITTLE_ENDIAN XXPH_isLittleEndian()
+# endif
+#endif
+
+
+
+
+/* ****************************************
+* Compiler-specific Functions and Macros
+******************************************/
+#define XXPH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
+
+#ifndef __has_builtin
+# define __has_builtin(x) 0
+#endif
+
+#if !defined(NO_CLANG_BUILTIN) && __has_builtin(__builtin_rotateleft32) && __has_builtin(__builtin_rotateleft64)
+# define XXPH_rotl32 __builtin_rotateleft32
+# define XXPH_rotl64 __builtin_rotateleft64
+/* Note : although _rotl exists for minGW (GCC under windows), performance seems poor */
+#elif defined(_MSC_VER)
+# define XXPH_rotl32(x,r) _rotl(x,r)
+# define XXPH_rotl64(x,r) _rotl64(x,r)
+#else
+# define XXPH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r))))
+# define XXPH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r))))
+#endif
+
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXPH_swap32 _byteswap_ulong
+#elif XXPH_GCC_VERSION >= 403
+# define XXPH_swap32 __builtin_bswap32
+#else
+static xxh_u32 XXPH_swap32 (xxh_u32 x)
+{
+ return ((x << 24) & 0xff000000 ) |
+ ((x << 8) & 0x00ff0000 ) |
+ ((x >> 8) & 0x0000ff00 ) |
+ ((x >> 24) & 0x000000ff );
+}
+#endif
+
+
+/* ***************************
+* Memory reads
+*****************************/
+typedef enum { XXPH_aligned, XXPH_unaligned } XXPH_alignment;
+
+XXPH_FORCE_INLINE xxh_u32 XXPH_readLE32(const void* ptr)
+{
+ return XXPH_CPU_LITTLE_ENDIAN ? XXPH_read32(ptr) : XXPH_swap32(XXPH_read32(ptr));
+}
+
+XXPH_FORCE_INLINE xxh_u32
+XXPH_readLE32_align(const void* ptr, XXPH_alignment align)
+{
+ if (align==XXPH_unaligned) {
+ return XXPH_readLE32(ptr);
+ } else {
+ return XXPH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXPH_swap32(*(const xxh_u32*)ptr);
+ }
+}
+
+
+/* *************************************
+* Misc
+***************************************/
+XXPH_PUBLIC_API unsigned XXPH_versionNumber (void) { return XXPH_VERSION_NUMBER; }
+
+
+static const xxh_u32 PRIME32_1 = 0x9E3779B1U; /* 0b10011110001101110111100110110001 */
+static const xxh_u32 PRIME32_2 = 0x85EBCA77U; /* 0b10000101111010111100101001110111 */
+static const xxh_u32 PRIME32_3 = 0xC2B2AE3DU; /* 0b11000010101100101010111000111101 */
+static const xxh_u32 PRIME32_4 = 0x27D4EB2FU; /* 0b00100111110101001110101100101111 */
+static const xxh_u32 PRIME32_5 = 0x165667B1U; /* 0b00010110010101100110011110110001 */
+
+#ifndef XXPH_NO_LONG_LONG
+
+/* *******************************************************************
+* 64-bit hash functions
+*********************************************************************/
+
+/*====== Memory access ======*/
+
+typedef XXPH64_hash_t xxh_u64;
+
+#if (defined(XXPH_FORCE_MEMORY_ACCESS) && (XXPH_FORCE_MEMORY_ACCESS==2))
+
+/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */
+static xxh_u64 XXPH_read64(const void* memPtr) { return *(const xxh_u64*) memPtr; }
+
+#elif (defined(XXPH_FORCE_MEMORY_ACCESS) && (XXPH_FORCE_MEMORY_ACCESS==1))
+
+/* __pack instructions are safer, but compiler specific, hence potentially problematic for some compilers */
+/* currently only defined for gcc and icc */
+typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) unalign64;
+static xxh_u64 XXPH_read64(const void* ptr) { return ((const unalign64*)ptr)->u64; }
+
+#else
+
+/* portable and safe solution. Generally efficient.
+ * see : http://stackoverflow.com/a/32095106/646947
+ */
+
+static xxh_u64 XXPH_read64(const void* memPtr)
+{
+ xxh_u64 val;
+ memcpy(&val, memPtr, sizeof(val));
+ return val;
+}
+
+#endif /* XXPH_FORCE_DIRECT_MEMORY_ACCESS */
+
+#if defined(_MSC_VER) /* Visual Studio */
+# define XXPH_swap64 _byteswap_uint64
+#elif XXPH_GCC_VERSION >= 403
+# define XXPH_swap64 __builtin_bswap64
+#else
+static xxh_u64 XXPH_swap64 (xxh_u64 x)
+{
+ return ((x << 56) & 0xff00000000000000ULL) |
+ ((x << 40) & 0x00ff000000000000ULL) |
+ ((x << 24) & 0x0000ff0000000000ULL) |
+ ((x << 8) & 0x000000ff00000000ULL) |
+ ((x >> 8) & 0x00000000ff000000ULL) |
+ ((x >> 24) & 0x0000000000ff0000ULL) |
+ ((x >> 40) & 0x000000000000ff00ULL) |
+ ((x >> 56) & 0x00000000000000ffULL);
+}
+#endif
+
+XXPH_FORCE_INLINE xxh_u64 XXPH_readLE64(const void* ptr)
+{
+ return XXPH_CPU_LITTLE_ENDIAN ? XXPH_read64(ptr) : XXPH_swap64(XXPH_read64(ptr));
+}
+
+XXPH_FORCE_INLINE xxh_u64
+XXPH_readLE64_align(const void* ptr, XXPH_alignment align)
+{
+ if (align==XXPH_unaligned)
+ return XXPH_readLE64(ptr);
+ else
+ return XXPH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXPH_swap64(*(const xxh_u64*)ptr);
+}
+
+
+/*====== xxh64 ======*/
+
+static const xxh_u64 PRIME64_1 = 0x9E3779B185EBCA87ULL; /* 0b1001111000110111011110011011000110000101111010111100101010000111 */
+static const xxh_u64 PRIME64_2 = 0xC2B2AE3D27D4EB4FULL; /* 0b1100001010110010101011100011110100100111110101001110101101001111 */
+static const xxh_u64 PRIME64_3 = 0x165667B19E3779F9ULL; /* 0b0001011001010110011001111011000110011110001101110111100111111001 */
+static const xxh_u64 PRIME64_4 = 0x85EBCA77C2B2AE63ULL; /* 0b1000010111101011110010100111011111000010101100101010111001100011 */
+static const xxh_u64 PRIME64_5 = 0x27D4EB2F165667C5ULL; /* 0b0010011111010100111010110010111100010110010101100110011111000101 */
+
+
+/* *********************************************************************
+* XXPH3
+* New generation hash designed for speed on small keys and vectorization
+************************************************************************ */
+
+/*======== Was #include "xxh3.h", now inlined below ==========*/
+
+/*
+ xxHash - Extremely Fast Hash algorithm
+ Development source file for `xxh3`
+ Copyright (C) 2019-present, Yann Collet.
+
+ BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ You can contact the author at :
+ - xxHash source repository : https://github.com/Cyan4973/xxHash
+*/
+
+/* RocksDB Note: This file contains a preview release (xxhash repository
+ version 0.7.2) of XXPH3 that is unlikely to be compatible with the final
+ version of XXPH3. We have therefore renamed this XXPH3 ("preview"), for
+ clarity so that we can continue to use this version even after
+ integrating a newer incompatible version.
+*/
+
+/* === Dependencies === */
+
+#undef XXPH_INLINE_ALL /* in case it's already defined */
+#define XXPH_INLINE_ALL
+
+
+/* === Compiler specifics === */
+
+#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */
+# define XXPH_RESTRICT restrict
+#else
+/* note : it might be useful to define __restrict or __restrict__ for some C++ compilers */
+# define XXPH_RESTRICT /* disable */
+#endif
+
+#if defined(__GNUC__)
+# if defined(__AVX2__)
+# include <immintrin.h>
+# elif defined(__SSE2__)
+# include <emmintrin.h>
+# elif defined(__ARM_NEON__) || defined(__ARM_NEON)
+# define inline __inline__ /* clang bug */
+# include <arm_neon.h>
+# undef inline
+# endif
+#elif defined(_MSC_VER)
+# include <intrin.h>
+#endif
+
+/*
+ * Sanity check.
+ *
+ * XXPH3 only requires these features to be efficient:
+ *
+ * - Usable unaligned access
+ * - A 32-bit or 64-bit ALU
+ * - If 32-bit, a decent ADC instruction
+ * - A 32 or 64-bit multiply with a 64-bit result
+ *
+ * Almost all 32-bit and 64-bit targets meet this, except for Thumb-1, the
+ * classic 16-bit only subset of ARM's instruction set.
+ *
+ * First of all, Thumb-1 lacks support for the UMULL instruction which
+ * performs the important long multiply. This means numerous __aeabi_lmul
+ * calls.
+ *
+ * Second of all, the 8 functional registers are just not enough.
+ * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need
+ * Lo registers, and this shuffling results in thousands more MOVs than A32.
+ *
+ * A32 and T32 don't have this limitation. They can access all 14 registers,
+ * do a 32->64 multiply with UMULL, and the flexible operand is helpful too.
+ *
+ * If compiling Thumb-1 for a target which supports ARM instructions, we
+ * will give a warning.
+ *
+ * Usually, if this happens, it is because of an accident and you probably
+ * need to specify -march, as you probably meant to compileh for a newer
+ * architecture.
+ */
+#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM)
+# warning "XXPH3 is highly inefficient without ARM or Thumb-2."
+#endif
+
+/* ==========================================
+ * Vectorization detection
+ * ========================================== */
+#define XXPH_SCALAR 0
+#define XXPH_SSE2 1
+#define XXPH_AVX2 2
+#define XXPH_NEON 3
+#define XXPH_VSX 4
+
+#ifndef XXPH_VECTOR /* can be defined on command line */
+# if defined(__AVX2__)
+# define XXPH_VECTOR XXPH_AVX2
+# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2))
+# define XXPH_VECTOR XXPH_SSE2
+# elif defined(__GNUC__) /* msvc support maybe later */ \
+ && (defined(__ARM_NEON__) || defined(__ARM_NEON)) \
+ && (defined(__LITTLE_ENDIAN__) /* We only support little endian NEON */ \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
+# define XXPH_VECTOR XXPH_NEON
+# elif defined(__PPC64__) && defined(__POWER8_VECTOR__) && defined(__GNUC__)
+# define XXPH_VECTOR XXPH_VSX
+# else
+# define XXPH_VECTOR XXPH_SCALAR
+# endif
+#endif
+
+/* control alignment of accumulator,
+ * for compatibility with fast vector loads */
+#ifndef XXPH_ACC_ALIGN
+# if XXPH_VECTOR == 0 /* scalar */
+# define XXPH_ACC_ALIGN 8
+# elif XXPH_VECTOR == 1 /* sse2 */
+# define XXPH_ACC_ALIGN 16
+# elif XXPH_VECTOR == 2 /* avx2 */
+# define XXPH_ACC_ALIGN 32
+# elif XXPH_VECTOR == 3 /* neon */
+# define XXPH_ACC_ALIGN 16
+# elif XXPH_VECTOR == 4 /* vsx */
+# define XXPH_ACC_ALIGN 16
+# endif
+#endif
+
+/* xxh_u64 XXPH_mult32to64(xxh_u32 a, xxh_u64 b) { return (xxh_u64)a * (xxh_u64)b; } */
+#if defined(_MSC_VER) && defined(_M_IX86)
+# include <intrin.h>
+# define XXPH_mult32to64(x, y) __emulu(x, y)
+#else
+# define XXPH_mult32to64(x, y) ((xxh_u64)((x) & 0xFFFFFFFF) * (xxh_u64)((y) & 0xFFFFFFFF))
+#endif
+
+/* VSX stuff. It's a lot because VSX support is mediocre across compilers and
+ * there is a lot of mischief with endianness. */
+#if XXPH_VECTOR == XXPH_VSX
+# include <altivec.h>
+# undef vector
+typedef __vector unsigned long long U64x2;
+typedef __vector unsigned char U8x16;
+typedef __vector unsigned U32x4;
+
+#ifndef XXPH_VSX_BE
+# if defined(__BIG_ENDIAN__) \
+ || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+# define XXPH_VSX_BE 1
+# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__
+# warning "-maltivec=be is not recommended. Please use native endianness."
+# define XXPH_VSX_BE 1
+# else
+# define XXPH_VSX_BE 0
+# endif
+#endif
+
+/* We need some helpers for big endian mode. */
+#if XXPH_VSX_BE
+/* A wrapper for POWER9's vec_revb. */
+# ifdef __POWER9_VECTOR__
+# define XXPH_vec_revb vec_revb
+# else
+XXPH_FORCE_INLINE U64x2 XXPH_vec_revb(U64x2 val)
+{
+ U8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00,
+ 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 };
+ return vec_perm(val, val, vByteSwap);
+}
+# endif
+
+/* Power8 Crypto gives us vpermxor which is very handy for
+ * PPC64EB.
+ *
+ * U8x16 vpermxor(U8x16 a, U8x16 b, U8x16 mask)
+ * {
+ * U8x16 ret;
+ * for (int i = 0; i < 16; i++) {
+ * ret[i] = a[mask[i] & 0xF] ^ b[mask[i] >> 4];
+ * }
+ * return ret;
+ * }
+ *
+ * Because both of the main loops load the key, swap, and xor it with input,
+ * we can combine the key swap into this instruction.
+ */
+# ifdef vec_permxor
+# define XXPH_vec_permxor vec_permxor
+# else
+# define XXPH_vec_permxor __builtin_crypto_vpermxor
+# endif
+#endif /* XXPH_VSX_BE */
+/*
+ * Because we reinterpret the multiply, there are endian memes: vec_mulo actually becomes
+ * vec_mule.
+ *
+ * Additionally, the intrinsic wasn't added until GCC 8, despite existing for a while.
+ * Clang has an easy way to control this, we can just use the builtin which doesn't swap.
+ * GCC needs inline assembly. */
+#if __has_builtin(__builtin_altivec_vmuleuw)
+# define XXPH_vec_mulo __builtin_altivec_vmulouw
+# define XXPH_vec_mule __builtin_altivec_vmuleuw
+#else
+/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */
+XXPH_FORCE_INLINE U64x2 XXPH_vec_mulo(U32x4 a, U32x4 b) {
+ U64x2 result;
+ __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+XXPH_FORCE_INLINE U64x2 XXPH_vec_mule(U32x4 a, U32x4 b) {
+ U64x2 result;
+ __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+ return result;
+}
+#endif /* __has_builtin(__builtin_altivec_vmuleuw) */
+#endif /* XXPH_VECTOR == XXPH_VSX */
+
+/* prefetch
+ * can be disabled, by declaring XXPH_NO_PREFETCH build macro */
+#if defined(XXPH_NO_PREFETCH)
+# define XXPH_PREFETCH(ptr) (void)(ptr) /* disabled */
+#else
+#if defined(_MSC_VER) && \
+ (defined(_M_X64) || \
+ defined(_M_IX86)) /* _mm_prefetch() is not defined outside of x86/x64 */
+# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
+# define XXPH_PREFETCH(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0)
+# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) )
+# define XXPH_PREFETCH(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */)
+# else
+# define XXPH_PREFETCH(ptr) (void)(ptr) /* disabled */
+# endif
+#endif /* XXPH_NO_PREFETCH */
+
+
+/* ==========================================
+ * XXPH3 default settings
+ * ========================================== */
+
+#define XXPH_SECRET_DEFAULT_SIZE 192 /* minimum XXPH3_SECRET_SIZE_MIN */
+
+#if (XXPH_SECRET_DEFAULT_SIZE < XXPH3_SECRET_SIZE_MIN)
+# error "default keyset is not large enough"
+#endif
+
+XXPH_ALIGN(64) static const xxh_u8 kSecret[XXPH_SECRET_DEFAULT_SIZE] = {
+ 0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c,
+ 0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f,
+ 0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21,
+ 0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c,
+ 0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3,
+ 0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8,
+ 0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d,
+ 0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64,
+
+ 0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb,
+ 0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e,
+ 0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce,
+ 0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e,
+};
+
+/*
+ * GCC for x86 has a tendency to use SSE in this loop. While it
+ * successfully avoids swapping (as MUL overwrites EAX and EDX), it
+ * slows it down because instead of free register swap shifts, it
+ * must use pshufd and punpckl/hd.
+ *
+ * To prevent this, we use this attribute to shut off SSE.
+ */
+#if defined(__GNUC__) && !defined(__clang__) && defined(__i386__)
+__attribute__((__target__("no-sse")))
+#endif
+static XXPH128_hash_t
+XXPH_mult64to128(xxh_u64 lhs, xxh_u64 rhs)
+{
+ /*
+ * GCC/Clang __uint128_t method.
+ *
+ * On most 64-bit targets, GCC and Clang define a __uint128_t type.
+ * This is usually the best way as it usually uses a native long 64-bit
+ * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64.
+ *
+ * Usually.
+ *
+ * Despite being a 32-bit platform, Clang (and emscripten) define this
+ * type despite not having the arithmetic for it. This results in a
+ * laggy compiler builtin call which calculates a full 128-bit multiply.
+ * In that case it is best to use the portable one.
+ * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677
+ */
+#if defined(__GNUC__) && !defined(__wasm__) \
+ && defined(__SIZEOF_INT128__) \
+ || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128)
+
+ __uint128_t product = (__uint128_t)lhs * (__uint128_t)rhs;
+ XXPH128_hash_t const r128 = { (xxh_u64)(product), (xxh_u64)(product >> 64) };
+ return r128;
+
+ /*
+ * MSVC for x64's _umul128 method.
+ *
+ * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct);
+ *
+ * This compiles to single operand MUL on x64.
+ */
+#elif defined(_M_X64) || defined(_M_IA64)
+
+#ifndef _MSC_VER
+# pragma intrinsic(_umul128)
+#endif
+ xxh_u64 product_high;
+ xxh_u64 const product_low = _umul128(lhs, rhs, &product_high);
+ XXPH128_hash_t const r128 = { product_low, product_high };
+ return r128;
+
+#else
+ /*
+ * Portable scalar method. Optimized for 32-bit and 64-bit ALUs.
+ *
+ * This is a fast and simple grade school multiply, which is shown
+ * below with base 10 arithmetic instead of base 0x100000000.
+ *
+ * 9 3 // D2 lhs = 93
+ * x 7 5 // D2 rhs = 75
+ * ----------
+ * 1 5 // D2 lo_lo = (93 % 10) * (75 % 10)
+ * 4 5 | // D2 hi_lo = (93 / 10) * (75 % 10)
+ * 2 1 | // D2 lo_hi = (93 % 10) * (75 / 10)
+ * + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10)
+ * ---------
+ * 2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21
+ * + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63
+ * ---------
+ * 6 9 7 5
+ *
+ * The reasons for adding the products like this are:
+ * 1. It avoids manual carry tracking. Just like how
+ * (9 * 9) + 9 + 9 = 99, the same applies with this for
+ * UINT64_MAX. This avoids a lot of complexity.
+ *
+ * 2. It hints for, and on Clang, compiles to, the powerful UMAAL
+ * instruction available in ARMv6+ A32/T32, which is shown below:
+ *
+ * void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm)
+ * {
+ * xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm;
+ * *RdLo = (xxh_u32)(product & 0xFFFFFFFF);
+ * *RdHi = (xxh_u32)(product >> 32);
+ * }
+ *
+ * This instruction was designed for efficient long multiplication,
+ * and allows this to be calculated in only 4 instructions which
+ * is comparable to some 64-bit ALUs.
+ *
+ * 3. It isn't terrible on other platforms. Usually this will be
+ * a couple of 32-bit ADD/ADCs.
+ */
+
+ /* First calculate all of the cross products. */
+ xxh_u64 const lo_lo = XXPH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF);
+ xxh_u64 const hi_lo = XXPH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF);
+ xxh_u64 const lo_hi = XXPH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32);
+ xxh_u64 const hi_hi = XXPH_mult32to64(lhs >> 32, rhs >> 32);
+
+ /* Now add the products together. These will never overflow. */
+ xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi;
+ xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32) + hi_hi;
+ xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF);
+
+ XXPH128_hash_t r128 = { lower, upper };
+ return r128;
+#endif
+}
+
+/*
+ * We want to keep the attribute here because a target switch
+ * disables inlining.
+ *
+ * Does a 64-bit to 128-bit multiply, then XOR folds it.
+ * The reason for the separate function is to prevent passing
+ * too many structs around by value. This will hopefully inline
+ * the multiply, but we don't force it.
+ */
+#if defined(__GNUC__) && !defined(__clang__) && defined(__i386__)
+__attribute__((__target__("no-sse")))
+#endif
+static xxh_u64
+XXPH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs)
+{
+ XXPH128_hash_t product = XXPH_mult64to128(lhs, rhs);
+ return product.low64 ^ product.high64;
+}
+
+
+static XXPH64_hash_t XXPH3_avalanche(xxh_u64 h64)
+{
+ h64 ^= h64 >> 37;
+ h64 *= PRIME64_3;
+ h64 ^= h64 >> 32;
+ return h64;
+}
+
+
+/* ==========================================
+ * Short keys
+ * ========================================== */
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(input != NULL);
+ XXPH_ASSERT(1 <= len && len <= 3);
+ XXPH_ASSERT(secret != NULL);
+ { xxh_u8 const c1 = input[0];
+ xxh_u8 const c2 = input[len >> 1];
+ xxh_u8 const c3 = input[len - 1];
+ xxh_u32 const combined = ((xxh_u32)c1) | (((xxh_u32)c2) << 8) | (((xxh_u32)c3) << 16) | (((xxh_u32)len) << 24);
+ xxh_u64 const keyed = (xxh_u64)combined ^ (XXPH_readLE32(secret) + seed);
+ xxh_u64 const mixed = keyed * PRIME64_1;
+ return XXPH3_avalanche(mixed);
+ }
+}
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(input != NULL);
+ XXPH_ASSERT(secret != NULL);
+ XXPH_ASSERT(4 <= len && len <= 8);
+ { xxh_u32 const input_lo = XXPH_readLE32(input);
+ xxh_u32 const input_hi = XXPH_readLE32(input + len - 4);
+ xxh_u64 const input_64 = input_lo | ((xxh_u64)input_hi << 32);
+ xxh_u64 const keyed = input_64 ^ (XXPH_readLE64(secret) + seed);
+ xxh_u64 const mix64 = len + ((keyed ^ (keyed >> 51)) * PRIME32_1);
+ return XXPH3_avalanche((mix64 ^ (mix64 >> 47)) * PRIME64_2);
+ }
+}
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(input != NULL);
+ XXPH_ASSERT(secret != NULL);
+ XXPH_ASSERT(9 <= len && len <= 16);
+ { xxh_u64 const input_lo = XXPH_readLE64(input) ^ (XXPH_readLE64(secret) + seed);
+ xxh_u64 const input_hi = XXPH_readLE64(input + len - 8) ^ (XXPH_readLE64(secret + 8) - seed);
+ xxh_u64 const acc = len + (input_lo + input_hi) + XXPH3_mul128_fold64(input_lo, input_hi);
+ return XXPH3_avalanche(acc);
+ }
+}
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(len <= 16);
+ { if (len > 8) return XXPH3_len_9to16_64b(input, len, secret, seed);
+ if (len >= 4) return XXPH3_len_4to8_64b(input, len, secret, seed);
+ if (len) return XXPH3_len_1to3_64b(input, len, secret, seed);
+ /*
+ * RocksDB modification from XXPH3 preview: zero result for empty
+ * string can be problematic for multiplication-based algorithms.
+ * Return a hash of the seed instead.
+ */
+ return XXPH3_mul128_fold64(seed + XXPH_readLE64(secret), PRIME64_2);
+ }
+}
+
+
+/* === Long Keys === */
+
+#define STRIPE_LEN 64
+#define XXPH_SECRET_CONSUME_RATE 8 /* nb of secret bytes consumed at each accumulation */
+#define ACC_NB (STRIPE_LEN / sizeof(xxh_u64))
+
+typedef enum { XXPH3_acc_64bits, XXPH3_acc_128bits } XXPH3_accWidth_e;
+
+XXPH_FORCE_INLINE void
+XXPH3_accumulate_512( void* XXPH_RESTRICT acc,
+ const void* XXPH_RESTRICT input,
+ const void* XXPH_RESTRICT secret,
+ XXPH3_accWidth_e accWidth)
+{
+#if (XXPH_VECTOR == XXPH_AVX2)
+
+ XXPH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXPH_ALIGN(32) __m256i* const xacc = (__m256i *) acc;
+ const __m256i* const xinput = (const __m256i *) input; /* not really aligned, just for ptr arithmetic, and because _mm256_loadu_si256() requires this type */
+ const __m256i* const xsecret = (const __m256i *) secret; /* not really aligned, just for ptr arithmetic, and because _mm256_loadu_si256() requires this type */
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN/sizeof(__m256i); i++) {
+ __m256i const data_vec = _mm256_loadu_si256 (xinput+i);
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec); /* uint32 dk[8] = {d0+k0, d1+k1, d2+k2, d3+k3, ...} */
+ __m256i const product = _mm256_mul_epu32 (data_key, _mm256_shuffle_epi32 (data_key, 0x31)); /* uint64 mul[4] = {dk0*dk1, dk2*dk3, ...} */
+ if (accWidth == XXPH3_acc_128bits) {
+ __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2));
+ __m256i const sum = _mm256_add_epi64(xacc[i], data_swap);
+ xacc[i] = _mm256_add_epi64(product, sum);
+ } else { /* XXPH3_acc_64bits */
+ __m256i const sum = _mm256_add_epi64(xacc[i], data_vec);
+ xacc[i] = _mm256_add_epi64(product, sum);
+ }
+ } }
+
+#elif (XXPH_VECTOR == XXPH_SSE2)
+
+ XXPH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXPH_ALIGN(16) __m128i* const xacc = (__m128i *) acc;
+ const __m128i* const xinput = (const __m128i *) input; /* not really aligned, just for ptr arithmetic, and because _mm_loadu_si128() requires this type */
+ const __m128i* const xsecret = (const __m128i *) secret; /* not really aligned, just for ptr arithmetic, and because _mm_loadu_si128() requires this type */
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN/sizeof(__m128i); i++) {
+ __m128i const data_vec = _mm_loadu_si128 (xinput+i);
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); /* uint32 dk[8] = {d0+k0, d1+k1, d2+k2, d3+k3, ...} */
+ __m128i const product = _mm_mul_epu32 (data_key, _mm_shuffle_epi32 (data_key, 0x31)); /* uint64 mul[4] = {dk0*dk1, dk2*dk3, ...} */
+ if (accWidth == XXPH3_acc_128bits) {
+ __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2));
+ __m128i const sum = _mm_add_epi64(xacc[i], data_swap);
+ xacc[i] = _mm_add_epi64(product, sum);
+ } else { /* XXPH3_acc_64bits */
+ __m128i const sum = _mm_add_epi64(xacc[i], data_vec);
+ xacc[i] = _mm_add_epi64(product, sum);
+ }
+ } }
+
+#elif (XXPH_VECTOR == XXPH_NEON)
+
+ XXPH_ASSERT((((size_t)acc) & 15) == 0);
+ {
+ XXPH_ALIGN(16) uint64x2_t* const xacc = (uint64x2_t *) acc;
+ /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */
+ uint8_t const* const xinput = (const uint8_t *) input;
+ uint8_t const* const xsecret = (const uint8_t *) secret;
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN / sizeof(uint64x2_t); i++) {
+#if !defined(__aarch64__) && !defined(__arm64__) && defined(__GNUC__) /* ARM32-specific hack */
+ /* vzip on ARMv7 Clang generates a lot of vmovs (technically vorrs) without this.
+ * vzip on 32-bit ARM NEON will overwrite the original register, and I think that Clang
+ * assumes I don't want to destroy it and tries to make a copy. This slows down the code
+ * a lot.
+ * aarch64 not only uses an entirely different syntax, but it requires three
+ * instructions...
+ * ext v1.16B, v0.16B, #8 // select high bits because aarch64 can't address them directly
+ * zip1 v3.2s, v0.2s, v1.2s // first zip
+ * zip2 v2.2s, v0.2s, v1.2s // second zip
+ * ...to do what ARM does in one:
+ * vzip.32 d0, d1 // Interleave high and low bits and overwrite. */
+
+ /* data_vec = xsecret[i]; */
+ uint8x16_t const data_vec = vld1q_u8(xinput + (i * 16));
+ /* key_vec = xsecret[i]; */
+ uint8x16_t const key_vec = vld1q_u8(xsecret + (i * 16));
+ /* data_key = data_vec ^ key_vec; */
+ uint32x4_t data_key;
+
+ if (accWidth == XXPH3_acc_64bits) {
+ /* Add first to prevent register swaps */
+ /* xacc[i] += data_vec; */
+ xacc[i] = vaddq_u64 (xacc[i], vreinterpretq_u64_u8(data_vec));
+ } else { /* XXPH3_acc_128bits */
+ /* xacc[i] += swap(data_vec); */
+ /* can probably be optimized better */
+ uint64x2_t const data64 = vreinterpretq_u64_u8(data_vec);
+ uint64x2_t const swapped= vextq_u64(data64, data64, 1);
+ xacc[i] = vaddq_u64 (xacc[i], swapped);
+ }
+
+ data_key = vreinterpretq_u32_u8(veorq_u8(data_vec, key_vec));
+
+ /* Here's the magic. We use the quirkiness of vzip to shuffle data_key in place.
+ * shuffle: data_key[0, 1, 2, 3] = data_key[0, 2, 1, 3] */
+ __asm__("vzip.32 %e0, %f0" : "+w" (data_key));
+ /* xacc[i] += (uint64x2_t) data_key[0, 1] * (uint64x2_t) data_key[2, 3]; */
+ xacc[i] = vmlal_u32(xacc[i], vget_low_u32(data_key), vget_high_u32(data_key));
+
+#else
+ /* On aarch64, vshrn/vmovn seems to be equivalent to, if not faster than, the vzip method. */
+
+ /* data_vec = xsecret[i]; */
+ uint8x16_t const data_vec = vld1q_u8(xinput + (i * 16));
+ /* key_vec = xsecret[i]; */
+ uint8x16_t const key_vec = vld1q_u8(xsecret + (i * 16));
+ /* data_key = data_vec ^ key_vec; */
+ uint64x2_t const data_key = vreinterpretq_u64_u8(veorq_u8(data_vec, key_vec));
+ /* data_key_lo = (uint32x2_t) (data_key & 0xFFFFFFFF); */
+ uint32x2_t const data_key_lo = vmovn_u64 (data_key);
+ /* data_key_hi = (uint32x2_t) (data_key >> 32); */
+ uint32x2_t const data_key_hi = vshrn_n_u64 (data_key, 32);
+ if (accWidth == XXPH3_acc_64bits) {
+ /* xacc[i] += data_vec; */
+ xacc[i] = vaddq_u64 (xacc[i], vreinterpretq_u64_u8(data_vec));
+ } else { /* XXPH3_acc_128bits */
+ /* xacc[i] += swap(data_vec); */
+ uint64x2_t const data64 = vreinterpretq_u64_u8(data_vec);
+ uint64x2_t const swapped= vextq_u64(data64, data64, 1);
+ xacc[i] = vaddq_u64 (xacc[i], swapped);
+ }
+ /* xacc[i] += (uint64x2_t) data_key_lo * (uint64x2_t) data_key_hi; */
+ xacc[i] = vmlal_u32 (xacc[i], data_key_lo, data_key_hi);
+
+#endif
+ }
+ }
+
+#elif (XXPH_VECTOR == XXPH_VSX) && /* work around a compiler bug */ (__GNUC__ > 5)
+ U64x2* const xacc = (U64x2*) acc; /* presumed aligned */
+ U64x2 const* const xinput = (U64x2 const*) input; /* no alignment restriction */
+ U64x2 const* const xsecret = (U64x2 const*) secret; /* no alignment restriction */
+ U64x2 const v32 = { 32, 32 };
+#if XXPH_VSX_BE
+ U8x16 const vXorSwap = { 0x07, 0x16, 0x25, 0x34, 0x43, 0x52, 0x61, 0x70,
+ 0x8F, 0x9E, 0xAD, 0xBC, 0xCB, 0xDA, 0xE9, 0xF8 };
+#endif
+ size_t i;
+ for (i = 0; i < STRIPE_LEN / sizeof(U64x2); i++) {
+ /* data_vec = xinput[i]; */
+ /* key_vec = xsecret[i]; */
+#if XXPH_VSX_BE
+ /* byteswap */
+ U64x2 const data_vec = XXPH_vec_revb(vec_vsx_ld(0, xinput + i));
+ U64x2 const key_raw = vec_vsx_ld(0, xsecret + i);
+ /* See comment above. data_key = data_vec ^ swap(xsecret[i]); */
+ U64x2 const data_key = (U64x2)XXPH_vec_permxor((U8x16)data_vec, (U8x16)key_raw, vXorSwap);
+#else
+ U64x2 const data_vec = vec_vsx_ld(0, xinput + i);
+ U64x2 const key_vec = vec_vsx_ld(0, xsecret + i);
+ U64x2 const data_key = data_vec ^ key_vec;
+#endif
+ /* shuffled = (data_key << 32) | (data_key >> 32); */
+ U32x4 const shuffled = (U32x4)vec_rl(data_key, v32);
+ /* product = ((U64x2)data_key & 0xFFFFFFFF) * ((U64x2)shuffled & 0xFFFFFFFF); */
+ U64x2 const product = XXPH_vec_mulo((U32x4)data_key, shuffled);
+ xacc[i] += product;
+
+ if (accWidth == XXPH3_acc_64bits) {
+ xacc[i] += data_vec;
+ } else { /* XXPH3_acc_128bits */
+ /* swap high and low halves */
+ U64x2 const data_swapped = vec_xxpermdi(data_vec, data_vec, 2);
+ xacc[i] += data_swapped;
+ }
+ }
+
+#else /* scalar variant of Accumulator - universal */
+
+ XXPH_ALIGN(XXPH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned on 32-bytes boundaries, little hint for the auto-vectorizer */
+ const xxh_u8* const xinput = (const xxh_u8*) input; /* no alignment restriction */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXPH_ASSERT(((size_t)acc & (XXPH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < ACC_NB; i++) {
+ xxh_u64 const data_val = XXPH_readLE64(xinput + 8*i);
+ xxh_u64 const data_key = data_val ^ XXPH_readLE64(xsecret + i*8);
+
+ if (accWidth == XXPH3_acc_64bits) {
+ xacc[i] += data_val;
+ } else {
+ xacc[i ^ 1] += data_val; /* swap adjacent lanes */
+ }
+ xacc[i] += XXPH_mult32to64(data_key & 0xFFFFFFFF, data_key >> 32);
+ }
+#endif
+}
+
+XXPH_FORCE_INLINE void
+XXPH3_scrambleAcc(void* XXPH_RESTRICT acc, const void* XXPH_RESTRICT secret)
+{
+#if (XXPH_VECTOR == XXPH_AVX2)
+
+ XXPH_ASSERT((((size_t)acc) & 31) == 0);
+ { XXPH_ALIGN(32) __m256i* const xacc = (__m256i*) acc;
+ const __m256i* const xsecret = (const __m256i *) secret; /* not really aligned, just for ptr arithmetic, and because _mm256_loadu_si256() requires this argument type */
+ const __m256i prime32 = _mm256_set1_epi32((int)PRIME32_1);
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN/sizeof(__m256i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m256i const acc_vec = xacc[i];
+ __m256i const shifted = _mm256_srli_epi64 (acc_vec, 47);
+ __m256i const data_vec = _mm256_xor_si256 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret; */
+ __m256i const key_vec = _mm256_loadu_si256 (xsecret+i);
+ __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec);
+
+ /* xacc[i] *= PRIME32_1; */
+ __m256i const data_key_hi = _mm256_shuffle_epi32 (data_key, 0x31);
+ __m256i const prod_lo = _mm256_mul_epu32 (data_key, prime32);
+ __m256i const prod_hi = _mm256_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32));
+ }
+ }
+
+#elif (XXPH_VECTOR == XXPH_SSE2)
+
+ XXPH_ASSERT((((size_t)acc) & 15) == 0);
+ { XXPH_ALIGN(16) __m128i* const xacc = (__m128i*) acc;
+ const __m128i* const xsecret = (const __m128i *) secret; /* not really aligned, just for ptr arithmetic, and because _mm_loadu_si128() requires this argument type */
+ const __m128i prime32 = _mm_set1_epi32((int)PRIME32_1);
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN/sizeof(__m128i); i++) {
+ /* xacc[i] ^= (xacc[i] >> 47) */
+ __m128i const acc_vec = xacc[i];
+ __m128i const shifted = _mm_srli_epi64 (acc_vec, 47);
+ __m128i const data_vec = _mm_xor_si128 (acc_vec, shifted);
+ /* xacc[i] ^= xsecret; */
+ __m128i const key_vec = _mm_loadu_si128 (xsecret+i);
+ __m128i const data_key = _mm_xor_si128 (data_vec, key_vec);
+
+ /* xacc[i] *= PRIME32_1; */
+ __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, 0x31);
+ __m128i const prod_lo = _mm_mul_epu32 (data_key, prime32);
+ __m128i const prod_hi = _mm_mul_epu32 (data_key_hi, prime32);
+ xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32));
+ }
+ }
+
+#elif (XXPH_VECTOR == XXPH_NEON)
+
+ XXPH_ASSERT((((size_t)acc) & 15) == 0);
+
+ { uint64x2_t* const xacc = (uint64x2_t*) acc;
+ uint8_t const* const xsecret = (uint8_t const*) secret;
+ uint32x2_t const prime = vdup_n_u32 (PRIME32_1);
+
+ size_t i;
+ for (i=0; i < STRIPE_LEN/sizeof(uint64x2_t); i++) {
+ /* data_vec = xacc[i] ^ (xacc[i] >> 47); */
+ uint64x2_t const acc_vec = xacc[i];
+ uint64x2_t const shifted = vshrq_n_u64 (acc_vec, 47);
+ uint64x2_t const data_vec = veorq_u64 (acc_vec, shifted);
+
+ /* key_vec = xsecret[i]; */
+ uint32x4_t const key_vec = vreinterpretq_u32_u8(vld1q_u8(xsecret + (i * 16)));
+ /* data_key = data_vec ^ key_vec; */
+ uint32x4_t const data_key = veorq_u32 (vreinterpretq_u32_u64(data_vec), key_vec);
+ /* shuffled = { data_key[0, 2], data_key[1, 3] }; */
+ uint32x2x2_t const shuffled = vzip_u32 (vget_low_u32(data_key), vget_high_u32(data_key));
+
+ /* data_key *= PRIME32_1 */
+
+ /* prod_hi = (data_key >> 32) * PRIME32_1; */
+ uint64x2_t const prod_hi = vmull_u32 (shuffled.val[1], prime);
+ /* xacc[i] = prod_hi << 32; */
+ xacc[i] = vshlq_n_u64(prod_hi, 32);
+ /* xacc[i] += (prod_hi & 0xFFFFFFFF) * PRIME32_1; */
+ xacc[i] = vmlal_u32(xacc[i], shuffled.val[0], prime);
+ } }
+
+#elif (XXPH_VECTOR == XXPH_VSX) && /* work around a compiler bug */ (__GNUC__ > 5)
+
+ U64x2* const xacc = (U64x2*) acc;
+ const U64x2* const xsecret = (const U64x2*) secret;
+ /* constants */
+ U64x2 const v32 = { 32, 32 };
+ U64x2 const v47 = { 47, 47 };
+ U32x4 const prime = { PRIME32_1, PRIME32_1, PRIME32_1, PRIME32_1 };
+ size_t i;
+#if XXPH_VSX_BE
+ /* endian swap */
+ U8x16 const vXorSwap = { 0x07, 0x16, 0x25, 0x34, 0x43, 0x52, 0x61, 0x70,
+ 0x8F, 0x9E, 0xAD, 0xBC, 0xCB, 0xDA, 0xE9, 0xF8 };
+#endif
+ for (i = 0; i < STRIPE_LEN / sizeof(U64x2); i++) {
+ U64x2 const acc_vec = xacc[i];
+ U64x2 const data_vec = acc_vec ^ (acc_vec >> v47);
+ /* key_vec = xsecret[i]; */
+#if XXPH_VSX_BE
+ /* swap bytes words */
+ U64x2 const key_raw = vec_vsx_ld(0, xsecret + i);
+ U64x2 const data_key = (U64x2)XXPH_vec_permxor((U8x16)data_vec, (U8x16)key_raw, vXorSwap);
+#else
+ U64x2 const key_vec = vec_vsx_ld(0, xsecret + i);
+ U64x2 const data_key = data_vec ^ key_vec;
+#endif
+
+ /* data_key *= PRIME32_1 */
+
+ /* prod_lo = ((U64x2)data_key & 0xFFFFFFFF) * ((U64x2)prime & 0xFFFFFFFF); */
+ U64x2 const prod_even = XXPH_vec_mule((U32x4)data_key, prime);
+ /* prod_hi = ((U64x2)data_key >> 32) * ((U64x2)prime >> 32); */
+ U64x2 const prod_odd = XXPH_vec_mulo((U32x4)data_key, prime);
+ xacc[i] = prod_odd + (prod_even << v32);
+ }
+
+#else /* scalar variant of Scrambler - universal */
+
+ XXPH_ALIGN(XXPH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned on 32-bytes boundaries, little hint for the auto-vectorizer */
+ const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */
+ size_t i;
+ XXPH_ASSERT((((size_t)acc) & (XXPH_ACC_ALIGN-1)) == 0);
+ for (i=0; i < ACC_NB; i++) {
+ xxh_u64 const key64 = XXPH_readLE64(xsecret + 8*i);
+ xxh_u64 acc64 = xacc[i];
+ acc64 ^= acc64 >> 47;
+ acc64 ^= key64;
+ acc64 *= PRIME32_1;
+ xacc[i] = acc64;
+ }
+
+#endif
+}
+
+#define XXPH_PREFETCH_DIST 384
+
+/* assumption : nbStripes will not overflow secret size */
+XXPH_FORCE_INLINE void
+XXPH3_accumulate( xxh_u64* XXPH_RESTRICT acc,
+ const xxh_u8* XXPH_RESTRICT input,
+ const xxh_u8* XXPH_RESTRICT secret,
+ size_t nbStripes,
+ XXPH3_accWidth_e accWidth)
+{
+ size_t n;
+ for (n = 0; n < nbStripes; n++ ) {
+ const xxh_u8* const in = input + n*STRIPE_LEN;
+ XXPH_PREFETCH(in + XXPH_PREFETCH_DIST);
+ XXPH3_accumulate_512(acc,
+ in,
+ secret + n*XXPH_SECRET_CONSUME_RATE,
+ accWidth);
+ }
+}
+
+/* note : clang auto-vectorizes well in SS2 mode _if_ this function is `static`,
+ * and doesn't auto-vectorize it at all if it is `FORCE_INLINE`.
+ * However, it auto-vectorizes better AVX2 if it is `FORCE_INLINE`
+ * Pretty much every other modes and compilers prefer `FORCE_INLINE`.
+ */
+
+#if defined(__clang__) && (XXPH_VECTOR==0) && !defined(__AVX2__) && !defined(__arm__) && !defined(__thumb__)
+static void
+#else
+XXPH_FORCE_INLINE void
+#endif
+XXPH3_hashLong_internal_loop( xxh_u64* XXPH_RESTRICT acc,
+ const xxh_u8* XXPH_RESTRICT input, size_t len,
+ const xxh_u8* XXPH_RESTRICT secret, size_t secretSize,
+ XXPH3_accWidth_e accWidth)
+{
+ size_t const nb_rounds = (secretSize - STRIPE_LEN) / XXPH_SECRET_CONSUME_RATE;
+ size_t const block_len = STRIPE_LEN * nb_rounds;
+ size_t const nb_blocks = len / block_len;
+
+ size_t n;
+
+ XXPH_ASSERT(secretSize >= XXPH3_SECRET_SIZE_MIN);
+
+ for (n = 0; n < nb_blocks; n++) {
+ XXPH3_accumulate(acc, input + n*block_len, secret, nb_rounds, accWidth);
+ XXPH3_scrambleAcc(acc, secret + secretSize - STRIPE_LEN);
+ }
+
+ /* last partial block */
+ XXPH_ASSERT(len > STRIPE_LEN);
+ { size_t const nbStripes = (len - (block_len * nb_blocks)) / STRIPE_LEN;
+ XXPH_ASSERT(nbStripes <= (secretSize / XXPH_SECRET_CONSUME_RATE));
+ XXPH3_accumulate(acc, input + nb_blocks*block_len, secret, nbStripes, accWidth);
+
+ /* last stripe */
+ if (len & (STRIPE_LEN - 1)) {
+ const xxh_u8* const p = input + len - STRIPE_LEN;
+#define XXPH_SECRET_LASTACC_START 7 /* do not align on 8, so that secret is different from scrambler */
+ XXPH3_accumulate_512(acc, p, secret + secretSize - STRIPE_LEN - XXPH_SECRET_LASTACC_START, accWidth);
+ } }
+}
+
+XXPH_FORCE_INLINE xxh_u64
+XXPH3_mix2Accs(const xxh_u64* XXPH_RESTRICT acc, const xxh_u8* XXPH_RESTRICT secret)
+{
+ return XXPH3_mul128_fold64(
+ acc[0] ^ XXPH_readLE64(secret),
+ acc[1] ^ XXPH_readLE64(secret+8) );
+}
+
+static XXPH64_hash_t
+XXPH3_mergeAccs(const xxh_u64* XXPH_RESTRICT acc, const xxh_u8* XXPH_RESTRICT secret, xxh_u64 start)
+{
+ xxh_u64 result64 = start;
+
+ result64 += XXPH3_mix2Accs(acc+0, secret + 0);
+ result64 += XXPH3_mix2Accs(acc+2, secret + 16);
+ result64 += XXPH3_mix2Accs(acc+4, secret + 32);
+ result64 += XXPH3_mix2Accs(acc+6, secret + 48);
+
+ return XXPH3_avalanche(result64);
+}
+
+#define XXPH3_INIT_ACC { PRIME32_3, PRIME64_1, PRIME64_2, PRIME64_3, \
+ PRIME64_4, PRIME32_2, PRIME64_5, PRIME32_1 };
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_hashLong_internal(const xxh_u8* XXPH_RESTRICT input, size_t len,
+ const xxh_u8* XXPH_RESTRICT secret, size_t secretSize)
+{
+ XXPH_ALIGN(XXPH_ACC_ALIGN) xxh_u64 acc[ACC_NB] = XXPH3_INIT_ACC;
+
+ XXPH3_hashLong_internal_loop(acc, input, len, secret, secretSize, XXPH3_acc_64bits);
+
+ /* converge into final hash */
+ XXPH_STATIC_ASSERT(sizeof(acc) == 64);
+#define XXPH_SECRET_MERGEACCS_START 11 /* do not align on 8, so that secret is different from accumulator */
+ XXPH_ASSERT(secretSize >= sizeof(acc) + XXPH_SECRET_MERGEACCS_START);
+ return XXPH3_mergeAccs(acc, secret + XXPH_SECRET_MERGEACCS_START, (xxh_u64)len * PRIME64_1);
+}
+
+
+XXPH_NO_INLINE XXPH64_hash_t /* It's important for performance that XXPH3_hashLong is not inlined. Not sure why (uop cache maybe ?), but difference is large and easily measurable */
+XXPH3_hashLong_64b_defaultSecret(const xxh_u8* XXPH_RESTRICT input, size_t len)
+{
+ return XXPH3_hashLong_internal(input, len, kSecret, sizeof(kSecret));
+}
+
+XXPH_NO_INLINE XXPH64_hash_t /* It's important for performance that XXPH3_hashLong is not inlined. Not sure why (uop cache maybe ?), but difference is large and easily measurable */
+XXPH3_hashLong_64b_withSecret(const xxh_u8* XXPH_RESTRICT input, size_t len,
+ const xxh_u8* XXPH_RESTRICT secret, size_t secretSize)
+{
+ return XXPH3_hashLong_internal(input, len, secret, secretSize);
+}
+
+
+XXPH_FORCE_INLINE void XXPH_writeLE64(void* dst, xxh_u64 v64)
+{
+ if (!XXPH_CPU_LITTLE_ENDIAN) v64 = XXPH_swap64(v64);
+ memcpy(dst, &v64, sizeof(v64));
+}
+
+/* XXPH3_initCustomSecret() :
+ * destination `customSecret` is presumed allocated and same size as `kSecret`.
+ */
+XXPH_FORCE_INLINE void XXPH3_initCustomSecret(xxh_u8* customSecret, xxh_u64 seed64)
+{
+ int const nbRounds = XXPH_SECRET_DEFAULT_SIZE / 16;
+ int i;
+
+ XXPH_STATIC_ASSERT((XXPH_SECRET_DEFAULT_SIZE & 15) == 0);
+
+ for (i=0; i < nbRounds; i++) {
+ XXPH_writeLE64(customSecret + 16*i, XXPH_readLE64(kSecret + 16*i) + seed64);
+ XXPH_writeLE64(customSecret + 16*i + 8, XXPH_readLE64(kSecret + 16*i + 8) - seed64);
+ }
+}
+
+
+/* XXPH3_hashLong_64b_withSeed() :
+ * Generate a custom key,
+ * based on alteration of default kSecret with the seed,
+ * and then use this key for long mode hashing.
+ * This operation is decently fast but nonetheless costs a little bit of time.
+ * Try to avoid it whenever possible (typically when seed==0).
+ */
+XXPH_NO_INLINE XXPH64_hash_t /* It's important for performance that XXPH3_hashLong is not inlined. Not sure why (uop cache maybe ?), but difference is large and easily measurable */
+XXPH3_hashLong_64b_withSeed(const xxh_u8* input, size_t len, XXPH64_hash_t seed)
+{
+ XXPH_ALIGN(8) xxh_u8 secret[XXPH_SECRET_DEFAULT_SIZE];
+ if (seed==0) return XXPH3_hashLong_64b_defaultSecret(input, len);
+ XXPH3_initCustomSecret(secret, seed);
+ return XXPH3_hashLong_internal(input, len, secret, sizeof(secret));
+}
+
+
+XXPH_FORCE_INLINE xxh_u64 XXPH3_mix16B(const xxh_u8* XXPH_RESTRICT input,
+ const xxh_u8* XXPH_RESTRICT secret, xxh_u64 seed64)
+{
+ xxh_u64 const input_lo = XXPH_readLE64(input);
+ xxh_u64 const input_hi = XXPH_readLE64(input+8);
+ return XXPH3_mul128_fold64(
+ input_lo ^ (XXPH_readLE64(secret) + seed64),
+ input_hi ^ (XXPH_readLE64(secret+8) - seed64) );
+}
+
+
+XXPH_FORCE_INLINE XXPH64_hash_t
+XXPH3_len_17to128_64b(const xxh_u8* XXPH_RESTRICT input, size_t len,
+ const xxh_u8* XXPH_RESTRICT secret, size_t secretSize,
+ XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(secretSize >= XXPH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXPH_ASSERT(16 < len && len <= 128);
+
+ { xxh_u64 acc = len * PRIME64_1;
+ if (len > 32) {
+ if (len > 64) {
+ if (len > 96) {
+ acc += XXPH3_mix16B(input+48, secret+96, seed);
+ acc += XXPH3_mix16B(input+len-64, secret+112, seed);
+ }
+ acc += XXPH3_mix16B(input+32, secret+64, seed);
+ acc += XXPH3_mix16B(input+len-48, secret+80, seed);
+ }
+ acc += XXPH3_mix16B(input+16, secret+32, seed);
+ acc += XXPH3_mix16B(input+len-32, secret+48, seed);
+ }
+ acc += XXPH3_mix16B(input+0, secret+0, seed);
+ acc += XXPH3_mix16B(input+len-16, secret+16, seed);
+
+ return XXPH3_avalanche(acc);
+ }
+}
+
+#define XXPH3_MIDSIZE_MAX 240
+
+XXPH_NO_INLINE XXPH64_hash_t
+XXPH3_len_129to240_64b(const xxh_u8* XXPH_RESTRICT input, size_t len,
+ const xxh_u8* XXPH_RESTRICT secret, size_t secretSize,
+ XXPH64_hash_t seed)
+{
+ XXPH_ASSERT(secretSize >= XXPH3_SECRET_SIZE_MIN); (void)secretSize;
+ XXPH_ASSERT(128 < len && len <= XXPH3_MIDSIZE_MAX);
+
+ #define XXPH3_MIDSIZE_STARTOFFSET 3
+ #define XXPH3_MIDSIZE_LASTOFFSET 17
+
+ { xxh_u64 acc = len * PRIME64_1;
+ int const nbRounds = (int)len / 16;
+ int i;
+ for (i=0; i<8; i++) {
+ acc += XXPH3_mix16B(input+(16*i), secret+(16*i), seed);
+ }
+ acc = XXPH3_avalanche(acc);
+ XXPH_ASSERT(nbRounds >= 8);
+ for (i=8 ; i < nbRounds; i++) {
+ acc += XXPH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXPH3_MIDSIZE_STARTOFFSET, seed);
+ }
+ /* last bytes */
+ acc += XXPH3_mix16B(input + len - 16, secret + XXPH3_SECRET_SIZE_MIN - XXPH3_MIDSIZE_LASTOFFSET, seed);
+ return XXPH3_avalanche(acc);
+ }
+}
+
+/* === Public entry point === */
+
+XXPH_PUBLIC_API XXPH64_hash_t XXPH3_64bits(const void* input, size_t len)
+{
+ if (len <= 16) return XXPH3_len_0to16_64b((const xxh_u8*)input, len, kSecret, 0);
+ if (len <= 128) return XXPH3_len_17to128_64b((const xxh_u8*)input, len, kSecret, sizeof(kSecret), 0);
+ if (len <= XXPH3_MIDSIZE_MAX) return XXPH3_len_129to240_64b((const xxh_u8*)input, len, kSecret, sizeof(kSecret), 0);
+ return XXPH3_hashLong_64b_defaultSecret((const xxh_u8*)input, len);
+}
+
+XXPH_PUBLIC_API XXPH64_hash_t
+XXPH3_64bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize)
+{
+ XXPH_ASSERT(secretSize >= XXPH3_SECRET_SIZE_MIN);
+ /* if an action must be taken should `secret` conditions not be respected,
+ * it should be done here.
+ * For now, it's a contract pre-condition.
+ * Adding a check and a branch here would cost performance at every hash */
+ if (len <= 16) return XXPH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, 0);
+ if (len <= 128) return XXPH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, 0);
+ if (len <= XXPH3_MIDSIZE_MAX) return XXPH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, 0);
+ return XXPH3_hashLong_64b_withSecret((const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize);
+}
+
+XXPH_PUBLIC_API XXPH64_hash_t
+XXPH3_64bits_withSeed(const void* input, size_t len, XXPH64_hash_t seed)
+{
+ if (len <= 16) return XXPH3_len_0to16_64b((const xxh_u8*)input, len, kSecret, seed);
+ if (len <= 128) return XXPH3_len_17to128_64b((const xxh_u8*)input, len, kSecret, sizeof(kSecret), seed);
+ if (len <= XXPH3_MIDSIZE_MAX) return XXPH3_len_129to240_64b((const xxh_u8*)input, len, kSecret, sizeof(kSecret), seed);
+ return XXPH3_hashLong_64b_withSeed((const xxh_u8*)input, len, seed);
+}
+
+/* === XXPH3 streaming === */
+
+/* RocksDB Note: unused & removed due to bug in preview version */
+
+/*======== END #include "xxh3.h", now inlined above ==========*/
+
+#endif /* XXPH_NO_LONG_LONG */
+
+/* === END RocksDB modification of permanently inlining === */
+
+#endif /* defined(XXPH_INLINE_ALL) || defined(XXPH_PRIVATE_API) */
+
+#endif /* XXPH_STATIC_LINKING_ONLY */
+
+#if defined (__cplusplus)
+}
+#endif
+
+#endif /* XXPHASH_H_5627135585666179 */
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.cc b/src/rocksdb/utilities/agg_merge/agg_merge.cc
new file mode 100644
index 000000000..a7eab1f12
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge.cc
@@ -0,0 +1,238 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/agg_merge/agg_merge.h"
+
+#include <assert.h>
+
+#include <deque>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "port/lang.h"
+#include "port/likely.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map;
+const std::string kUnnamedFuncName = "";
+const std::string kErrorFuncName = "kErrorFuncName";
+
+Status AddAggregator(const std::string& function_name,
+ std::unique_ptr<Aggregator>&& agg) {
+ if (function_name == kErrorFuncName) {
+ return Status::InvalidArgument(
+ "Cannot register function name kErrorFuncName");
+ }
+ func_map.emplace(function_name, std::move(agg));
+ return Status::OK();
+}
+
+AggMergeOperator::AggMergeOperator() {}
+
+std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
+ const Slice& value) {
+ std::string result;
+ PutLengthPrefixedSlice(&result, function_name);
+ result += value.ToString();
+ return result;
+}
+
+Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload,
+ std::string& output) {
+ if (function_name == kErrorFuncName) {
+ return Status::InvalidArgument("Cannot use error function name");
+ }
+ if (function_name != kUnnamedFuncName &&
+ func_map.find(function_name.ToString()) == func_map.end()) {
+ return Status::InvalidArgument("Function name not registered");
+ }
+ output = EncodeAggFuncAndPayloadNoCheck(function_name, payload);
+ return Status::OK();
+}
+
+bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) {
+ value = op;
+ return GetLengthPrefixedSlice(&value, &func);
+}
+
+bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) {
+ decoded_list.clear();
+ Slice list_slice = encoded_list;
+ Slice item;
+ while (GetLengthPrefixedSlice(&list_slice, &item)) {
+ decoded_list.push_back(item);
+ }
+ return list_slice.empty();
+}
+
+class AggMergeOperator::Accumulator {
+ public:
+ bool Add(const Slice& op, bool is_partial_aggregation) {
+ if (ignore_operands_) {
+ return true;
+ }
+ Slice my_func;
+ Slice my_value;
+ bool ret = ExtractAggFuncAndValue(op, my_func, my_value);
+ if (!ret) {
+ ignore_operands_ = true;
+ return true;
+ }
+
+ // Determine whether we need to do partial merge.
+ if (is_partial_aggregation && !my_func.empty()) {
+ auto f = func_map.find(my_func.ToString());
+ if (f == func_map.end() || !f->second->DoPartialAggregate()) {
+ return false;
+ }
+ }
+
+ if (!func_valid_) {
+ if (my_func != kUnnamedFuncName) {
+ func_ = my_func;
+ func_valid_ = true;
+ }
+ } else if (func_ != my_func) {
+ // User switched aggregation function. Need to aggregate the older
+ // one first.
+
+ // Previous aggreagion can't be done in partial merge
+ if (is_partial_aggregation) {
+ func_valid_ = false;
+ ignore_operands_ = true;
+ return false;
+ }
+
+ // We could consider stashing an iterator into the hash of aggregators
+ // to avoid repeated lookups when the aggregator doesn't change.
+ auto f = func_map.find(func_.ToString());
+ if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) {
+ func_valid_ = false;
+ ignore_operands_ = true;
+ return true;
+ }
+ std::swap(scratch_, aggregated_);
+ values_.clear();
+ values_.push_back(aggregated_);
+ func_ = my_func;
+ }
+ values_.push_back(my_value);
+ return true;
+ }
+
+ // Return false if aggregation fails.
+ // One possible reason
+ bool GetResult(std::string& result) {
+ if (!func_valid_) {
+ return false;
+ }
+ auto f = func_map.find(func_.ToString());
+ if (f == func_map.end()) {
+ return false;
+ }
+ if (!f->second->Aggregate(values_, scratch_)) {
+ return false;
+ }
+ result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_);
+ return true;
+ }
+
+ void Clear() {
+ func_.clear();
+ values_.clear();
+ aggregated_.clear();
+ scratch_.clear();
+ ignore_operands_ = false;
+ func_valid_ = false;
+ }
+
+ private:
+ Slice func_;
+ std::vector<Slice> values_;
+ std::string aggregated_;
+ std::string scratch_;
+ bool ignore_operands_ = false;
+ bool func_valid_ = false;
+};
+
+// Creating and using a new Accumulator might invoke multiple malloc and is
+// expensive if it needs to be done when processing each merge operation.
+// AggMergeOperator's merge operators can be invoked concurrently by multiple
+// threads so we cannot simply create one Aggregator and reuse.
+// We use thread local instances instead.
+AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() {
+ static thread_local Accumulator tls_acc;
+ tls_acc.Clear();
+ return tls_acc;
+}
+
+void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in,
+ MergeOperationOutput& merge_out) {
+ merge_out.new_value = "";
+ PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName);
+ if (merge_in.existing_value != nullptr) {
+ PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value);
+ }
+ for (const Slice& op : merge_in.operand_list) {
+ PutLengthPrefixedSlice(&merge_out.new_value, op);
+ }
+}
+
+bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ Accumulator& agg = GetTLSAccumulator();
+ if (merge_in.existing_value != nullptr) {
+ agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false);
+ }
+ for (const Slice& e : merge_in.operand_list) {
+ agg.Add(e, /*is_partial_aggregation=*/false);
+ }
+
+ bool succ = agg.GetResult(merge_out->new_value);
+ if (!succ) {
+ // If aggregation can't happen, pack all merge operands. In contrast to
+ // merge operator, we don't want to fail the DB. If users insert wrong
+ // format or call unregistered an aggregation function, we still hope
+ // the DB can continue functioning with other keys.
+ PackAllMergeOperands(merge_in, *merge_out);
+ }
+ agg.Clear();
+ return true;
+}
+
+bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const {
+ Accumulator& agg = GetTLSAccumulator();
+ bool do_aggregation = true;
+ for (const Slice& item : operand_list) {
+ do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true);
+ if (!do_aggregation) {
+ break;
+ }
+ }
+ if (do_aggregation) {
+ do_aggregation = agg.GetResult(*new_value);
+ }
+ agg.Clear();
+ return do_aggregation;
+}
+
+std::shared_ptr<MergeOperator> GetAggMergeOperator() {
+ STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance)
+ (std::make_shared<AggMergeOperator>());
+ assert(instance);
+ return instance;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.h b/src/rocksdb/utilities/agg_merge/agg_merge.h
new file mode 100644
index 000000000..00e58de08
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <unordered_map>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+class AggMergeOperator : public MergeOperator {
+ public:
+ explicit AggMergeOperator();
+
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const override;
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "AggMergeOperator.v1"; }
+
+ bool AllowSingleOperand() const override { return true; }
+
+ bool ShouldMerge(const std::vector<Slice>&) const override { return false; }
+
+ private:
+ class Accumulator;
+
+ // Pack all merge operands into one value. This is called when aggregation
+ // fails. The existing values are preserved and returned so that users can
+ // debug the problem.
+ static void PackAllMergeOperands(const MergeOperationInput& merge_in,
+ MergeOperationOutput& merge_out);
+ static Accumulator& GetTLSAccumulator();
+};
+
+extern std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
+ const Slice& value);
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge_test.cc b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc
new file mode 100644
index 000000000..a65441cd0
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/agg_merge.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "db/db_test_util.h"
+#include "rocksdb/options.h"
+#include "test_util/testharness.h"
+#include "utilities/agg_merge/agg_merge.h"
+#include "utilities/agg_merge/test_agg_merge.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class AggMergeTest : public DBTestBase {
+ public:
+ AggMergeTest() : DBTestBase("agg_merge_db_test", /*env_do_fsync=*/true) {}
+};
+
+TEST_F(AggMergeTest, TestUsingMergeOperator) {
+ ASSERT_OK(AddAggregator("sum", std::make_unique<SumAggregator>()));
+ ASSERT_OK(AddAggregator("last3", std::make_unique<Last3Aggregator>()));
+ ASSERT_OK(AddAggregator("mul", std::make_unique<MultipleAggregator>()));
+
+ Options options = CurrentOptions();
+ options.merge_operator = GetAggMergeOperator();
+ Reopen(options);
+ std::string v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("foo", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 15);
+ ASSERT_OK(Merge("foo", v));
+
+ v = EncodeHelper::EncodeFuncAndList("last3", {"a", "b"});
+ ASSERT_OK(Merge("bar", v));
+ v = EncodeHelper::EncodeFuncAndList("last3", {"c", "d", "e"});
+ ASSERT_OK(Merge("bar", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndList("last3", {"f"});
+ ASSERT_OK(Merge("bar", v));
+
+ // Test Put() without aggregation type.
+ v = EncodeHelper::EncodeFuncAndInt(kUnnamedFuncName, 30);
+ ASSERT_OK(Put("foo2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("foo2", v));
+
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 45), Get("foo"));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndList("last3", {"f", "c", "d"}),
+ Get("bar"));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 60), Get("foo2"));
+
+ // Test changing aggregation type
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Put("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("mul", 20);
+ ASSERT_OK(Merge("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 30);
+ ASSERT_OK(Merge("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 40);
+ ASSERT_OK(Merge("bar2", v));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 20 + 30 + 40),
+ Get("bar2"));
+
+ // Changing aggregation type with partial merge
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo3", v));
+ ASSERT_OK(Flush());
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 10 * 10 + 10),
+ Get("foo3"));
+
+ // Merge after full merge
+ v = EncodeHelper::EncodeFuncAndInt("sum", 1);
+ ASSERT_OK(Merge("foo4", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 2);
+ ASSERT_OK(Merge("foo4", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndInt("sum", 3);
+ ASSERT_OK(Merge("foo4", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 4);
+ ASSERT_OK(Merge("foo4", v));
+ ASSERT_OK(Flush());
+ ASSERT_OK(db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 5);
+ ASSERT_OK(Merge("foo4", v));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 15), Get("foo4"));
+
+ // Test unregistered function name
+ v = EncodeAggFuncAndPayloadNoCheck("non_existing", "1");
+ ASSERT_OK(Merge("bar3", v));
+ std::string v1;
+ v1 = EncodeAggFuncAndPayloadNoCheck("non_existing", "invalid");
+ ;
+ ASSERT_OK(Merge("bar3", v1));
+ EXPECT_EQ(EncodeAggFuncAndPayloadNoCheck(kErrorFuncName,
+ EncodeHelper::EncodeList({v, v1})),
+ Get("bar3"));
+
+ // invalidate input
+ ASSERT_OK(EncodeAggFuncAndPayload("sum", "invalid", v));
+ ASSERT_OK(Merge("bar4", v));
+ v1 = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("bar4", v1));
+ std::string aggregated_value = Get("bar4");
+ Slice func, payload;
+ ASSERT_TRUE(ExtractAggFuncAndValue(aggregated_value, func, payload));
+ EXPECT_EQ(kErrorFuncName, func);
+ std::vector<Slice> decoded_list;
+ ASSERT_TRUE(ExtractList(payload, decoded_list));
+ ASSERT_EQ(2, decoded_list.size());
+ ASSERT_EQ(v, decoded_list[0]);
+ ASSERT_EQ(v1, decoded_list[1]);
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.cc b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc
new file mode 100644
index 000000000..06e5b5697
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_agg_merge.h"
+
+#include <assert.h>
+
+#include <deque>
+#include <vector>
+
+#include "util/coding.h"
+#include "utilities/agg_merge/agg_merge.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+std::string EncodeHelper::EncodeFuncAndInt(const Slice& function_name,
+ int64_t value) {
+ std::string encoded_value;
+ PutVarsignedint64(&encoded_value, value);
+ std::string ret;
+ Status s = EncodeAggFuncAndPayload(function_name, encoded_value, ret);
+ assert(s.ok());
+ return ret;
+}
+
+std::string EncodeHelper::EncodeInt(int64_t value) {
+ std::string encoded_value;
+ PutVarsignedint64(&encoded_value, value);
+ return encoded_value;
+}
+
+std::string EncodeHelper::EncodeFuncAndList(const Slice& function_name,
+ const std::vector<Slice>& list) {
+ std::string ret;
+ Status s = EncodeAggFuncAndPayload(function_name, EncodeList(list), ret);
+ assert(s.ok());
+ return ret;
+}
+
+std::string EncodeHelper::EncodeList(const std::vector<Slice>& list) {
+ std::string result;
+ for (const Slice& entity : list) {
+ PutLengthPrefixedSlice(&result, entity);
+ }
+ return result;
+}
+
+bool SumAggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ int64_t sum = 0;
+ for (const Slice& item : item_list) {
+ int64_t ivalue;
+ Slice v = item;
+ if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) {
+ return false;
+ }
+ sum += ivalue;
+ }
+ result = EncodeHelper::EncodeInt(sum);
+ return true;
+}
+
+bool MultipleAggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ int64_t mresult = 1;
+ for (const Slice& item : item_list) {
+ int64_t ivalue;
+ Slice v = item;
+ if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) {
+ return false;
+ }
+ mresult *= ivalue;
+ }
+ result = EncodeHelper::EncodeInt(mresult);
+ return true;
+}
+
+bool Last3Aggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ std::vector<Slice> last3;
+ last3.reserve(3);
+ for (auto it = item_list.rbegin(); it != item_list.rend(); ++it) {
+ Slice input = *it;
+ Slice entity;
+ bool ret;
+ while ((ret = GetLengthPrefixedSlice(&input, &entity)) == true) {
+ last3.push_back(entity);
+ if (last3.size() >= 3) {
+ break;
+ }
+ }
+ if (last3.size() >= 3) {
+ break;
+ }
+ if (!ret) {
+ continue;
+ }
+ }
+ result = EncodeHelper::EncodeList(last3);
+ return true;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.h b/src/rocksdb/utilities/agg_merge/test_agg_merge.h
new file mode 100644
index 000000000..5bdf8b9cc
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <unordered_map>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+class SumAggregator : public Aggregator {
+ public:
+ ~SumAggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+ bool DoPartialAggregate() const override { return true; }
+};
+
+class MultipleAggregator : public Aggregator {
+ public:
+ ~MultipleAggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+ bool DoPartialAggregate() const override { return true; }
+};
+
+class Last3Aggregator : public Aggregator {
+ public:
+ ~Last3Aggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+};
+
+class EncodeHelper {
+ public:
+ static std::string EncodeFuncAndInt(const Slice& function_name,
+ int64_t value);
+ static std::string EncodeInt(int64_t value);
+ static std::string EncodeList(const std::vector<Slice>& list);
+ static std::string EncodeFuncAndList(const Slice& function_name,
+ const std::vector<Slice>& list);
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/backup/backup_engine.cc b/src/rocksdb/utilities/backup/backup_engine.cc
new file mode 100644
index 000000000..81b4a6629
--- /dev/null
+++ b/src/rocksdb/utilities/backup/backup_engine.cc
@@ -0,0 +1,3181 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <cinttypes>
+#include <cstdlib>
+#include <functional>
+#include <future>
+#include <limits>
+#include <map>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "env/composite_env_wrapper.h"
+#include "env/fs_readonly.h"
+#include "env/fs_remap.h"
+#include "file/filename.h"
+#include "file/line_file_reader.h"
+#include "file/sequence_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "logging/logging.h"
+#include "monitoring/iostats_context_imp.h"
+#include "options/options_helper.h"
+#include "port/port.h"
+#include "rocksdb/advanced_options.h"
+#include "rocksdb/env.h"
+#include "rocksdb/rate_limiter.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/transaction_log.h"
+#include "table/sst_file_dumper.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/channel.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/math.h"
+#include "util/rate_limiter.h"
+#include "util/string_util.h"
+#include "utilities/backup/backup_engine_impl.h"
+#include "utilities/checkpoint/checkpoint_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+using ShareFilesNaming = BackupEngineOptions::ShareFilesNaming;
+
+constexpr BackupID kLatestBackupIDMarker = static_cast<BackupID>(-2);
+
+inline uint32_t ChecksumHexToInt32(const std::string& checksum_hex) {
+ std::string checksum_str;
+ Slice(checksum_hex).DecodeHex(&checksum_str);
+ return EndianSwapValue(DecodeFixed32(checksum_str.c_str()));
+}
+inline std::string ChecksumStrToHex(const std::string& checksum_str) {
+ return Slice(checksum_str).ToString(true);
+}
+inline std::string ChecksumInt32ToHex(const uint32_t& checksum_value) {
+ std::string checksum_str;
+ PutFixed32(&checksum_str, EndianSwapValue(checksum_value));
+ return ChecksumStrToHex(checksum_str);
+}
+
+const std::string kPrivateDirName = "private";
+const std::string kMetaDirName = "meta";
+const std::string kSharedDirName = "shared";
+const std::string kSharedChecksumDirName = "shared_checksum";
+const std::string kPrivateDirSlash = kPrivateDirName + "/";
+const std::string kMetaDirSlash = kMetaDirName + "/";
+const std::string kSharedDirSlash = kSharedDirName + "/";
+const std::string kSharedChecksumDirSlash = kSharedChecksumDirName + "/";
+
+} // namespace
+
+void BackupStatistics::IncrementNumberSuccessBackup() {
+ number_success_backup++;
+}
+void BackupStatistics::IncrementNumberFailBackup() { number_fail_backup++; }
+
+uint32_t BackupStatistics::GetNumberSuccessBackup() const {
+ return number_success_backup;
+}
+uint32_t BackupStatistics::GetNumberFailBackup() const {
+ return number_fail_backup;
+}
+
+std::string BackupStatistics::ToString() const {
+ char result[50];
+ snprintf(result, sizeof(result), "# success backup: %u, # fail backup: %u",
+ GetNumberSuccessBackup(), GetNumberFailBackup());
+ return result;
+}
+
+void BackupEngineOptions::Dump(Logger* logger) const {
+ ROCKS_LOG_INFO(logger, " Options.backup_dir: %s",
+ backup_dir.c_str());
+ ROCKS_LOG_INFO(logger, " Options.backup_env: %p", backup_env);
+ ROCKS_LOG_INFO(logger, " Options.share_table_files: %d",
+ static_cast<int>(share_table_files));
+ ROCKS_LOG_INFO(logger, " Options.info_log: %p", info_log);
+ ROCKS_LOG_INFO(logger, " Options.sync: %d",
+ static_cast<int>(sync));
+ ROCKS_LOG_INFO(logger, " Options.destroy_old_data: %d",
+ static_cast<int>(destroy_old_data));
+ ROCKS_LOG_INFO(logger, " Options.backup_log_files: %d",
+ static_cast<int>(backup_log_files));
+ ROCKS_LOG_INFO(logger, " Options.backup_rate_limit: %" PRIu64,
+ backup_rate_limit);
+ ROCKS_LOG_INFO(logger, " Options.restore_rate_limit: %" PRIu64,
+ restore_rate_limit);
+ ROCKS_LOG_INFO(logger, "Options.max_background_operations: %d",
+ max_background_operations);
+}
+
+namespace {
+// -------- BackupEngineImpl class ---------
+class BackupEngineImpl {
+ public:
+ BackupEngineImpl(const BackupEngineOptions& options, Env* db_env,
+ bool read_only = false);
+ ~BackupEngineImpl();
+
+ IOStatus CreateNewBackupWithMetadata(const CreateBackupOptions& options,
+ DB* db, const std::string& app_metadata,
+ BackupID* new_backup_id_ptr);
+
+ IOStatus PurgeOldBackups(uint32_t num_backups_to_keep);
+
+ IOStatus DeleteBackup(BackupID backup_id);
+
+ void StopBackup() { stop_backup_.store(true, std::memory_order_release); }
+
+ IOStatus GarbageCollect();
+
+ // The returned BackupInfos are in chronological order, which means the
+ // latest backup comes last.
+ void GetBackupInfo(std::vector<BackupInfo>* backup_info,
+ bool include_file_details) const;
+
+ Status GetBackupInfo(BackupID backup_id, BackupInfo* backup_info,
+ bool include_file_details = false) const;
+
+ void GetCorruptedBackups(std::vector<BackupID>* corrupt_backup_ids) const;
+
+ IOStatus RestoreDBFromBackup(const RestoreOptions& options,
+ BackupID backup_id, const std::string& db_dir,
+ const std::string& wal_dir) const;
+
+ IOStatus RestoreDBFromLatestBackup(const RestoreOptions& options,
+ const std::string& db_dir,
+ const std::string& wal_dir) const {
+ // Note: don't read latest_valid_backup_id_ outside of lock
+ return RestoreDBFromBackup(options, kLatestBackupIDMarker, db_dir, wal_dir);
+ }
+
+ IOStatus VerifyBackup(BackupID backup_id,
+ bool verify_with_checksum = false) const;
+
+ IOStatus Initialize();
+
+ ShareFilesNaming GetNamingNoFlags() const {
+ return options_.share_files_with_checksum_naming &
+ BackupEngineOptions::kMaskNoNamingFlags;
+ }
+ ShareFilesNaming GetNamingFlags() const {
+ return options_.share_files_with_checksum_naming &
+ BackupEngineOptions::kMaskNamingFlags;
+ }
+
+ void TEST_SetDefaultRateLimitersClock(
+ const std::shared_ptr<SystemClock>& backup_rate_limiter_clock,
+ const std::shared_ptr<SystemClock>& restore_rate_limiter_clock) {
+ if (backup_rate_limiter_clock) {
+ static_cast<GenericRateLimiter*>(options_.backup_rate_limiter.get())
+ ->TEST_SetClock(backup_rate_limiter_clock);
+ }
+
+ if (restore_rate_limiter_clock) {
+ static_cast<GenericRateLimiter*>(options_.restore_rate_limiter.get())
+ ->TEST_SetClock(restore_rate_limiter_clock);
+ }
+ }
+
+ private:
+ void DeleteChildren(const std::string& dir,
+ uint32_t file_type_filter = 0) const;
+ IOStatus DeleteBackupNoGC(BackupID backup_id);
+
+ // Extends the "result" map with pathname->size mappings for the contents of
+ // "dir" in "env". Pathnames are prefixed with "dir".
+ IOStatus ReadChildFileCurrentSizes(
+ const std::string& dir, const std::shared_ptr<FileSystem>&,
+ std::unordered_map<std::string, uint64_t>* result) const;
+
+ struct FileInfo {
+ FileInfo(const std::string& fname, uint64_t sz, const std::string& checksum,
+ const std::string& id, const std::string& sid, Temperature _temp)
+ : refs(0),
+ filename(fname),
+ size(sz),
+ checksum_hex(checksum),
+ db_id(id),
+ db_session_id(sid),
+ temp(_temp) {}
+
+ FileInfo(const FileInfo&) = delete;
+ FileInfo& operator=(const FileInfo&) = delete;
+
+ int refs;
+ const std::string filename;
+ const uint64_t size;
+ // crc32c checksum as hex. empty == unknown / unavailable
+ std::string checksum_hex;
+ // DB identities
+ // db_id is obtained for potential usage in the future but not used
+ // currently
+ const std::string db_id;
+ // db_session_id appears in the backup SST filename if the table naming
+ // option is kUseDbSessionId
+ const std::string db_session_id;
+ Temperature temp;
+
+ std::string GetDbFileName() {
+ std::string rv;
+ // extract the filename part
+ size_t slash = filename.find_last_of('/');
+ // file will either be shared/<file>, shared_checksum/<file_crc32c_size>,
+ // shared_checksum/<file_session>, shared_checksum/<file_crc32c_session>,
+ // or private/<number>/<file>
+ assert(slash != std::string::npos);
+ rv = filename.substr(slash + 1);
+
+ // if the file was in shared_checksum, extract the real file name
+ // in this case the file is <number>_<checksum>_<size>.<type>,
+ // <number>_<session>.<type>, or <number>_<checksum>_<session>.<type>
+ if (filename.substr(0, slash) == kSharedChecksumDirName) {
+ rv = GetFileFromChecksumFile(rv);
+ }
+ return rv;
+ }
+ };
+
+ // TODO: deprecate this function once we migrate all BackupEngine's rate
+ // limiting to lower-level ones (i.e, ones in file access wrapper level like
+ // `WritableFileWriter`)
+ static void LoopRateLimitRequestHelper(const size_t total_bytes_to_request,
+ RateLimiter* rate_limiter,
+ const Env::IOPriority pri,
+ Statistics* stats,
+ const RateLimiter::OpType op_type);
+
+ static inline std::string WithoutTrailingSlash(const std::string& path) {
+ if (path.empty() || path.back() != '/') {
+ return path;
+ } else {
+ return path.substr(path.size() - 1);
+ }
+ }
+
+ static inline std::string WithTrailingSlash(const std::string& path) {
+ if (path.empty() || path.back() != '/') {
+ return path + '/';
+ } else {
+ return path;
+ }
+ }
+
+ // A filesystem wrapper that makes shared backup files appear to be in the
+ // private backup directory (dst_dir), so that the private backup dir can
+ // be opened as a read-only DB.
+ class RemapSharedFileSystem : public RemapFileSystem {
+ public:
+ RemapSharedFileSystem(const std::shared_ptr<FileSystem>& base,
+ const std::string& dst_dir,
+ const std::string& src_base_dir,
+ const std::vector<std::shared_ptr<FileInfo>>& files)
+ : RemapFileSystem(base),
+ dst_dir_(WithoutTrailingSlash(dst_dir)),
+ dst_dir_slash_(WithTrailingSlash(dst_dir)),
+ src_base_dir_(WithTrailingSlash(src_base_dir)) {
+ for (auto& info : files) {
+ if (!StartsWith(info->filename, kPrivateDirSlash)) {
+ assert(StartsWith(info->filename, kSharedDirSlash) ||
+ StartsWith(info->filename, kSharedChecksumDirSlash));
+ remaps_[info->GetDbFileName()] = info;
+ }
+ }
+ }
+
+ const char* Name() const override {
+ return "BackupEngineImpl::RemapSharedFileSystem";
+ }
+
+ // Sometimes a directory listing is required in opening a DB
+ IOStatus GetChildren(const std::string& dir, const IOOptions& options,
+ std::vector<std::string>* result,
+ IODebugContext* dbg) override {
+ IOStatus s = RemapFileSystem::GetChildren(dir, options, result, dbg);
+ if (s.ok() && (dir == dst_dir_ || dir == dst_dir_slash_)) {
+ // Assume remapped files exist
+ for (auto& r : remaps_) {
+ result->push_back(r.first);
+ }
+ }
+ return s;
+ }
+
+ // Sometimes a directory listing is required in opening a DB
+ IOStatus GetChildrenFileAttributes(const std::string& dir,
+ const IOOptions& options,
+ std::vector<FileAttributes>* result,
+ IODebugContext* dbg) override {
+ IOStatus s =
+ RemapFileSystem::GetChildrenFileAttributes(dir, options, result, dbg);
+ if (s.ok() && (dir == dst_dir_ || dir == dst_dir_slash_)) {
+ // Assume remapped files exist with recorded size
+ for (auto& r : remaps_) {
+ result->emplace_back(); // clean up with C++20
+ FileAttributes& attr = result->back();
+ attr.name = r.first;
+ attr.size_bytes = r.second->size;
+ }
+ }
+ return s;
+ }
+
+ protected:
+ // When a file in dst_dir is requested, see if we need to remap to shared
+ // file path.
+ std::pair<IOStatus, std::string> EncodePath(
+ const std::string& path) override {
+ if (path.empty() || path[0] != '/') {
+ return {IOStatus::InvalidArgument(path, "Not an absolute path"), ""};
+ }
+ std::pair<IOStatus, std::string> rv{IOStatus(), path};
+ if (StartsWith(path, dst_dir_slash_)) {
+ std::string relative = path.substr(dst_dir_slash_.size());
+ auto it = remaps_.find(relative);
+ if (it != remaps_.end()) {
+ rv.second = src_base_dir_ + it->second->filename;
+ }
+ }
+ return rv;
+ }
+
+ private:
+ // Absolute path to a directory that some extra files will be mapped into.
+ const std::string dst_dir_;
+ // Includes a trailing slash.
+ const std::string dst_dir_slash_;
+ // Absolute path to a directory containing some files to be mapped into
+ // dst_dir_. Includes a trailing slash.
+ const std::string src_base_dir_;
+ // If remaps_[x] exists, attempt to read dst_dir_ / x should instead read
+ // src_base_dir_ / remaps_[x]->filename. FileInfo is used to maximize
+ // sharing with other backup data in memory.
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>> remaps_;
+ };
+
+ class BackupMeta {
+ public:
+ BackupMeta(
+ const std::string& meta_filename, const std::string& meta_tmp_filename,
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>>* file_infos,
+ Env* env, const std::shared_ptr<FileSystem>& fs)
+ : timestamp_(0),
+ sequence_number_(0),
+ size_(0),
+ meta_filename_(meta_filename),
+ meta_tmp_filename_(meta_tmp_filename),
+ file_infos_(file_infos),
+ env_(env),
+ fs_(fs) {}
+
+ BackupMeta(const BackupMeta&) = delete;
+ BackupMeta& operator=(const BackupMeta&) = delete;
+
+ ~BackupMeta() {}
+
+ void RecordTimestamp() {
+ // Best effort
+ Status s = env_->GetCurrentTime(&timestamp_);
+ if (!s.ok()) {
+ timestamp_ = /* something clearly fabricated */ 1;
+ }
+ }
+ int64_t GetTimestamp() const { return timestamp_; }
+ uint64_t GetSize() const { return size_; }
+ uint32_t GetNumberFiles() const {
+ return static_cast<uint32_t>(files_.size());
+ }
+ void SetSequenceNumber(uint64_t sequence_number) {
+ sequence_number_ = sequence_number;
+ }
+ uint64_t GetSequenceNumber() const { return sequence_number_; }
+
+ const std::string& GetAppMetadata() const { return app_metadata_; }
+
+ void SetAppMetadata(const std::string& app_metadata) {
+ app_metadata_ = app_metadata;
+ }
+
+ IOStatus AddFile(std::shared_ptr<FileInfo> file_info);
+
+ IOStatus Delete(bool delete_meta = true);
+
+ bool Empty() const { return files_.empty(); }
+
+ std::shared_ptr<FileInfo> GetFile(const std::string& filename) const {
+ auto it = file_infos_->find(filename);
+ if (it == file_infos_->end()) {
+ return nullptr;
+ }
+ return it->second;
+ }
+
+ const std::vector<std::shared_ptr<FileInfo>>& GetFiles() const {
+ return files_;
+ }
+
+ // @param abs_path_to_size Pre-fetched file sizes (bytes).
+ IOStatus LoadFromFile(
+ const std::string& backup_dir,
+ const std::unordered_map<std::string, uint64_t>& abs_path_to_size,
+ RateLimiter* rate_limiter, Logger* info_log,
+ std::unordered_set<std::string>* reported_ignored_fields);
+ IOStatus StoreToFile(
+ bool sync, int schema_version,
+ const TEST_BackupMetaSchemaOptions* schema_test_options);
+
+ std::string GetInfoString() {
+ std::ostringstream ss;
+ ss << "Timestamp: " << timestamp_ << std::endl;
+ char human_size[16];
+ AppendHumanBytes(size_, human_size, sizeof(human_size));
+ ss << "Size: " << human_size << std::endl;
+ ss << "Files:" << std::endl;
+ for (const auto& file : files_) {
+ AppendHumanBytes(file->size, human_size, sizeof(human_size));
+ ss << file->filename << ", size " << human_size << ", refs "
+ << file->refs << std::endl;
+ }
+ return ss.str();
+ }
+
+ const std::shared_ptr<Env>& GetEnvForOpen() const {
+ if (!env_for_open_) {
+ // Lazy initialize
+ // Find directories
+ std::string dst_dir = meta_filename_;
+ auto i = dst_dir.rfind(kMetaDirSlash);
+ assert(i != std::string::npos);
+ std::string src_base_dir = dst_dir.substr(0, i);
+ dst_dir.replace(i, kMetaDirSlash.size(), kPrivateDirSlash);
+ // Make the RemapSharedFileSystem
+ std::shared_ptr<FileSystem> remap_fs =
+ std::make_shared<RemapSharedFileSystem>(fs_, dst_dir, src_base_dir,
+ files_);
+ // Make it read-only for safety
+ remap_fs = std::make_shared<ReadOnlyFileSystem>(remap_fs);
+ // Make an Env wrapper
+ env_for_open_ = std::make_shared<CompositeEnvWrapper>(env_, remap_fs);
+ }
+ return env_for_open_;
+ }
+
+ private:
+ int64_t timestamp_;
+ // sequence number is only approximate, should not be used
+ // by clients
+ uint64_t sequence_number_;
+ uint64_t size_;
+ std::string app_metadata_;
+ std::string const meta_filename_;
+ std::string const meta_tmp_filename_;
+ // files with relative paths (without "/" prefix!!)
+ std::vector<std::shared_ptr<FileInfo>> files_;
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>>* file_infos_;
+ Env* env_;
+ mutable std::shared_ptr<Env> env_for_open_;
+ std::shared_ptr<FileSystem> fs_;
+ IOOptions iooptions_ = IOOptions();
+ }; // BackupMeta
+
+ void SetBackupInfoFromBackupMeta(BackupID id, const BackupMeta& meta,
+ BackupInfo* backup_info,
+ bool include_file_details) const;
+
+ inline std::string GetAbsolutePath(
+ const std::string& relative_path = "") const {
+ assert(relative_path.size() == 0 || relative_path[0] != '/');
+ return options_.backup_dir + "/" + relative_path;
+ }
+ inline std::string GetPrivateFileRel(BackupID backup_id, bool tmp = false,
+ const std::string& file = "") const {
+ assert(file.size() == 0 || file[0] != '/');
+ return kPrivateDirSlash + std::to_string(backup_id) + (tmp ? ".tmp" : "") +
+ "/" + file;
+ }
+ inline std::string GetSharedFileRel(const std::string& file = "",
+ bool tmp = false) const {
+ assert(file.size() == 0 || file[0] != '/');
+ return kSharedDirSlash + std::string(tmp ? "." : "") + file +
+ (tmp ? ".tmp" : "");
+ }
+ inline std::string GetSharedFileWithChecksumRel(const std::string& file = "",
+ bool tmp = false) const {
+ assert(file.size() == 0 || file[0] != '/');
+ return kSharedChecksumDirSlash + std::string(tmp ? "." : "") + file +
+ (tmp ? ".tmp" : "");
+ }
+ inline bool UseLegacyNaming(const std::string& sid) const {
+ return GetNamingNoFlags() ==
+ BackupEngineOptions::kLegacyCrc32cAndFileSize ||
+ sid.empty();
+ }
+ inline std::string GetSharedFileWithChecksum(
+ const std::string& file, const std::string& checksum_hex,
+ const uint64_t file_size, const std::string& db_session_id) const {
+ assert(file.size() == 0 || file[0] != '/');
+ std::string file_copy = file;
+ if (UseLegacyNaming(db_session_id)) {
+ assert(!checksum_hex.empty());
+ file_copy.insert(file_copy.find_last_of('.'),
+ "_" + std::to_string(ChecksumHexToInt32(checksum_hex)) +
+ "_" + std::to_string(file_size));
+ } else {
+ file_copy.insert(file_copy.find_last_of('.'), "_s" + db_session_id);
+ if (GetNamingFlags() & BackupEngineOptions::kFlagIncludeFileSize) {
+ file_copy.insert(file_copy.find_last_of('.'),
+ "_" + std::to_string(file_size));
+ }
+ }
+ return file_copy;
+ }
+ static inline std::string GetFileFromChecksumFile(const std::string& file) {
+ assert(file.size() == 0 || file[0] != '/');
+ std::string file_copy = file;
+ size_t first_underscore = file_copy.find_first_of('_');
+ return file_copy.erase(first_underscore,
+ file_copy.find_last_of('.') - first_underscore);
+ }
+ inline std::string GetBackupMetaFile(BackupID backup_id, bool tmp) const {
+ return GetAbsolutePath(kMetaDirName) + "/" + (tmp ? "." : "") +
+ std::to_string(backup_id) + (tmp ? ".tmp" : "");
+ }
+
+ // If size_limit == 0, there is no size limit, copy everything.
+ //
+ // Exactly one of src and contents must be non-empty.
+ //
+ // @param src If non-empty, the file is copied from this pathname.
+ // @param contents If non-empty, the file will be created with these contents.
+ // @param src_temperature Pass in expected temperature of src, return back
+ // temperature reported by FileSystem
+ IOStatus CopyOrCreateFile(const std::string& src, const std::string& dst,
+ const std::string& contents, uint64_t size_limit,
+ Env* src_env, Env* dst_env,
+ const EnvOptions& src_env_options, bool sync,
+ RateLimiter* rate_limiter,
+ std::function<void()> progress_callback,
+ Temperature* src_temperature,
+ Temperature dst_temperature,
+ uint64_t* bytes_toward_next_callback,
+ uint64_t* size, std::string* checksum_hex);
+
+ IOStatus ReadFileAndComputeChecksum(const std::string& src,
+ const std::shared_ptr<FileSystem>& src_fs,
+ const EnvOptions& src_env_options,
+ uint64_t size_limit,
+ std::string* checksum_hex,
+ const Temperature src_temperature) const;
+
+ // Obtain db_id and db_session_id from the table properties of file_path
+ Status GetFileDbIdentities(Env* src_env, const EnvOptions& src_env_options,
+ const std::string& file_path,
+ Temperature file_temp, RateLimiter* rate_limiter,
+ std::string* db_id, std::string* db_session_id);
+
+ struct CopyOrCreateResult {
+ ~CopyOrCreateResult() {
+ // The Status needs to be ignored here for two reasons.
+ // First, if the BackupEngineImpl shuts down with jobs outstanding, then
+ // it is possible that the Status in the future/promise is never read,
+ // resulting in an unchecked Status. Second, if there are items in the
+ // channel when the BackupEngineImpl is shutdown, these will also have
+ // Status that have not been checked. This
+ // TODO: Fix those issues so that the Status
+ io_status.PermitUncheckedError();
+ }
+ uint64_t size;
+ std::string checksum_hex;
+ std::string db_id;
+ std::string db_session_id;
+ IOStatus io_status;
+ Temperature expected_src_temperature = Temperature::kUnknown;
+ Temperature current_src_temperature = Temperature::kUnknown;
+ };
+
+ // Exactly one of src_path and contents must be non-empty. If src_path is
+ // non-empty, the file is copied from this pathname. Otherwise, if contents is
+ // non-empty, the file will be created at dst_path with these contents.
+ struct CopyOrCreateWorkItem {
+ std::string src_path;
+ std::string dst_path;
+ Temperature src_temperature;
+ Temperature dst_temperature;
+ std::string contents;
+ Env* src_env;
+ Env* dst_env;
+ EnvOptions src_env_options;
+ bool sync;
+ RateLimiter* rate_limiter;
+ uint64_t size_limit;
+ Statistics* stats;
+ std::promise<CopyOrCreateResult> result;
+ std::function<void()> progress_callback;
+ std::string src_checksum_func_name;
+ std::string src_checksum_hex;
+ std::string db_id;
+ std::string db_session_id;
+
+ CopyOrCreateWorkItem()
+ : src_path(""),
+ dst_path(""),
+ src_temperature(Temperature::kUnknown),
+ dst_temperature(Temperature::kUnknown),
+ contents(""),
+ src_env(nullptr),
+ dst_env(nullptr),
+ src_env_options(),
+ sync(false),
+ rate_limiter(nullptr),
+ size_limit(0),
+ stats(nullptr),
+ src_checksum_func_name(kUnknownFileChecksumFuncName),
+ src_checksum_hex(""),
+ db_id(""),
+ db_session_id("") {}
+
+ CopyOrCreateWorkItem(const CopyOrCreateWorkItem&) = delete;
+ CopyOrCreateWorkItem& operator=(const CopyOrCreateWorkItem&) = delete;
+
+ CopyOrCreateWorkItem(CopyOrCreateWorkItem&& o) noexcept {
+ *this = std::move(o);
+ }
+
+ CopyOrCreateWorkItem& operator=(CopyOrCreateWorkItem&& o) noexcept {
+ src_path = std::move(o.src_path);
+ dst_path = std::move(o.dst_path);
+ src_temperature = std::move(o.src_temperature);
+ dst_temperature = std::move(o.dst_temperature);
+ contents = std::move(o.contents);
+ src_env = o.src_env;
+ dst_env = o.dst_env;
+ src_env_options = std::move(o.src_env_options);
+ sync = o.sync;
+ rate_limiter = o.rate_limiter;
+ size_limit = o.size_limit;
+ stats = o.stats;
+ result = std::move(o.result);
+ progress_callback = std::move(o.progress_callback);
+ src_checksum_func_name = std::move(o.src_checksum_func_name);
+ src_checksum_hex = std::move(o.src_checksum_hex);
+ db_id = std::move(o.db_id);
+ db_session_id = std::move(o.db_session_id);
+ src_temperature = o.src_temperature;
+ return *this;
+ }
+
+ CopyOrCreateWorkItem(
+ std::string _src_path, std::string _dst_path,
+ const Temperature _src_temperature, const Temperature _dst_temperature,
+ std::string _contents, Env* _src_env, Env* _dst_env,
+ EnvOptions _src_env_options, bool _sync, RateLimiter* _rate_limiter,
+ uint64_t _size_limit, Statistics* _stats,
+ std::function<void()> _progress_callback = []() {},
+ const std::string& _src_checksum_func_name =
+ kUnknownFileChecksumFuncName,
+ const std::string& _src_checksum_hex = "",
+ const std::string& _db_id = "", const std::string& _db_session_id = "")
+ : src_path(std::move(_src_path)),
+ dst_path(std::move(_dst_path)),
+ src_temperature(_src_temperature),
+ dst_temperature(_dst_temperature),
+ contents(std::move(_contents)),
+ src_env(_src_env),
+ dst_env(_dst_env),
+ src_env_options(std::move(_src_env_options)),
+ sync(_sync),
+ rate_limiter(_rate_limiter),
+ size_limit(_size_limit),
+ stats(_stats),
+ progress_callback(_progress_callback),
+ src_checksum_func_name(_src_checksum_func_name),
+ src_checksum_hex(_src_checksum_hex),
+ db_id(_db_id),
+ db_session_id(_db_session_id) {}
+ };
+
+ struct BackupAfterCopyOrCreateWorkItem {
+ std::future<CopyOrCreateResult> result;
+ bool shared;
+ bool needed_to_copy;
+ Env* backup_env;
+ std::string dst_path_tmp;
+ std::string dst_path;
+ std::string dst_relative;
+ BackupAfterCopyOrCreateWorkItem()
+ : shared(false),
+ needed_to_copy(false),
+ backup_env(nullptr),
+ dst_path_tmp(""),
+ dst_path(""),
+ dst_relative("") {}
+
+ BackupAfterCopyOrCreateWorkItem(
+ BackupAfterCopyOrCreateWorkItem&& o) noexcept {
+ *this = std::move(o);
+ }
+
+ BackupAfterCopyOrCreateWorkItem& operator=(
+ BackupAfterCopyOrCreateWorkItem&& o) noexcept {
+ result = std::move(o.result);
+ shared = o.shared;
+ needed_to_copy = o.needed_to_copy;
+ backup_env = o.backup_env;
+ dst_path_tmp = std::move(o.dst_path_tmp);
+ dst_path = std::move(o.dst_path);
+ dst_relative = std::move(o.dst_relative);
+ return *this;
+ }
+
+ BackupAfterCopyOrCreateWorkItem(std::future<CopyOrCreateResult>&& _result,
+ bool _shared, bool _needed_to_copy,
+ Env* _backup_env, std::string _dst_path_tmp,
+ std::string _dst_path,
+ std::string _dst_relative)
+ : result(std::move(_result)),
+ shared(_shared),
+ needed_to_copy(_needed_to_copy),
+ backup_env(_backup_env),
+ dst_path_tmp(std::move(_dst_path_tmp)),
+ dst_path(std::move(_dst_path)),
+ dst_relative(std::move(_dst_relative)) {}
+ };
+
+ struct RestoreAfterCopyOrCreateWorkItem {
+ std::future<CopyOrCreateResult> result;
+ std::string from_file;
+ std::string to_file;
+ std::string checksum_hex;
+ RestoreAfterCopyOrCreateWorkItem() : checksum_hex("") {}
+ RestoreAfterCopyOrCreateWorkItem(std::future<CopyOrCreateResult>&& _result,
+ const std::string& _from_file,
+ const std::string& _to_file,
+ const std::string& _checksum_hex)
+ : result(std::move(_result)),
+ from_file(_from_file),
+ to_file(_to_file),
+ checksum_hex(_checksum_hex) {}
+ RestoreAfterCopyOrCreateWorkItem(
+ RestoreAfterCopyOrCreateWorkItem&& o) noexcept {
+ *this = std::move(o);
+ }
+
+ RestoreAfterCopyOrCreateWorkItem& operator=(
+ RestoreAfterCopyOrCreateWorkItem&& o) noexcept {
+ result = std::move(o.result);
+ checksum_hex = std::move(o.checksum_hex);
+ return *this;
+ }
+ };
+
+ bool initialized_;
+ std::mutex byte_report_mutex_;
+ mutable channel<CopyOrCreateWorkItem> files_to_copy_or_create_;
+ std::vector<port::Thread> threads_;
+ std::atomic<CpuPriority> threads_cpu_priority_;
+
+ // Certain operations like PurgeOldBackups and DeleteBackup will trigger
+ // automatic GarbageCollect (true) unless we've already done one in this
+ // session and have not failed to delete backup files since then (false).
+ bool might_need_garbage_collect_ = true;
+
+ // Adds a file to the backup work queue to be copied or created if it doesn't
+ // already exist.
+ //
+ // Exactly one of src_dir and contents must be non-empty.
+ //
+ // @param src_dir If non-empty, the file in this directory named fname will be
+ // copied.
+ // @param fname Name of destination file and, in case of copy, source file.
+ // @param contents If non-empty, the file will be created with these contents.
+ IOStatus AddBackupFileWorkItem(
+ std::unordered_set<std::string>& live_dst_paths,
+ std::vector<BackupAfterCopyOrCreateWorkItem>& backup_items_to_finish,
+ BackupID backup_id, bool shared, const std::string& src_dir,
+ const std::string& fname, // starts with "/"
+ const EnvOptions& src_env_options, RateLimiter* rate_limiter,
+ FileType file_type, uint64_t size_bytes, Statistics* stats,
+ uint64_t size_limit = 0, bool shared_checksum = false,
+ std::function<void()> progress_callback = []() {},
+ const std::string& contents = std::string(),
+ const std::string& src_checksum_func_name = kUnknownFileChecksumFuncName,
+ const std::string& src_checksum_str = kUnknownFileChecksum,
+ const Temperature src_temperature = Temperature::kUnknown);
+
+ // backup state data
+ BackupID latest_backup_id_;
+ BackupID latest_valid_backup_id_;
+ std::map<BackupID, std::unique_ptr<BackupMeta>> backups_;
+ std::map<BackupID, std::pair<IOStatus, std::unique_ptr<BackupMeta>>>
+ corrupt_backups_;
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>>
+ backuped_file_infos_;
+ std::atomic<bool> stop_backup_;
+
+ // options data
+ BackupEngineOptions options_;
+ Env* db_env_;
+ Env* backup_env_;
+
+ // directories
+ std::unique_ptr<FSDirectory> backup_directory_;
+ std::unique_ptr<FSDirectory> shared_directory_;
+ std::unique_ptr<FSDirectory> meta_directory_;
+ std::unique_ptr<FSDirectory> private_directory_;
+
+ static const size_t kDefaultCopyFileBufferSize = 5 * 1024 * 1024LL; // 5MB
+ bool read_only_;
+ BackupStatistics backup_statistics_;
+ std::unordered_set<std::string> reported_ignored_fields_;
+ static const size_t kMaxAppMetaSize = 1024 * 1024; // 1MB
+ std::shared_ptr<FileSystem> db_fs_;
+ std::shared_ptr<FileSystem> backup_fs_;
+ IOOptions io_options_ = IOOptions();
+
+ public:
+ std::unique_ptr<TEST_BackupMetaSchemaOptions> schema_test_options_;
+};
+
+// -------- BackupEngineImplThreadSafe class ---------
+// This locking layer for thread safety in the public API is layered on
+// top to prevent accidental recursive locking with RWMutex, which is UB.
+// Note: BackupEngineReadOnlyBase inherited twice, but has no fields
+class BackupEngineImplThreadSafe : public BackupEngine,
+ public BackupEngineReadOnly {
+ public:
+ BackupEngineImplThreadSafe(const BackupEngineOptions& options, Env* db_env,
+ bool read_only = false)
+ : impl_(options, db_env, read_only) {}
+ ~BackupEngineImplThreadSafe() override {}
+
+ using BackupEngine::CreateNewBackupWithMetadata;
+ IOStatus CreateNewBackupWithMetadata(const CreateBackupOptions& options,
+ DB* db, const std::string& app_metadata,
+ BackupID* new_backup_id) override {
+ WriteLock lock(&mutex_);
+ return impl_.CreateNewBackupWithMetadata(options, db, app_metadata,
+ new_backup_id);
+ }
+
+ IOStatus PurgeOldBackups(uint32_t num_backups_to_keep) override {
+ WriteLock lock(&mutex_);
+ return impl_.PurgeOldBackups(num_backups_to_keep);
+ }
+
+ IOStatus DeleteBackup(BackupID backup_id) override {
+ WriteLock lock(&mutex_);
+ return impl_.DeleteBackup(backup_id);
+ }
+
+ void StopBackup() override {
+ // No locking needed
+ impl_.StopBackup();
+ }
+
+ IOStatus GarbageCollect() override {
+ WriteLock lock(&mutex_);
+ return impl_.GarbageCollect();
+ }
+
+ Status GetLatestBackupInfo(BackupInfo* backup_info,
+ bool include_file_details = false) const override {
+ ReadLock lock(&mutex_);
+ return impl_.GetBackupInfo(kLatestBackupIDMarker, backup_info,
+ include_file_details);
+ }
+
+ Status GetBackupInfo(BackupID backup_id, BackupInfo* backup_info,
+ bool include_file_details = false) const override {
+ ReadLock lock(&mutex_);
+ return impl_.GetBackupInfo(backup_id, backup_info, include_file_details);
+ }
+
+ void GetBackupInfo(std::vector<BackupInfo>* backup_info,
+ bool include_file_details) const override {
+ ReadLock lock(&mutex_);
+ impl_.GetBackupInfo(backup_info, include_file_details);
+ }
+
+ void GetCorruptedBackups(
+ std::vector<BackupID>* corrupt_backup_ids) const override {
+ ReadLock lock(&mutex_);
+ impl_.GetCorruptedBackups(corrupt_backup_ids);
+ }
+
+ using BackupEngine::RestoreDBFromBackup;
+ IOStatus RestoreDBFromBackup(const RestoreOptions& options,
+ BackupID backup_id, const std::string& db_dir,
+ const std::string& wal_dir) const override {
+ ReadLock lock(&mutex_);
+ return impl_.RestoreDBFromBackup(options, backup_id, db_dir, wal_dir);
+ }
+
+ using BackupEngine::RestoreDBFromLatestBackup;
+ IOStatus RestoreDBFromLatestBackup(
+ const RestoreOptions& options, const std::string& db_dir,
+ const std::string& wal_dir) const override {
+ // Defer to above function, which locks
+ return RestoreDBFromBackup(options, kLatestBackupIDMarker, db_dir, wal_dir);
+ }
+
+ IOStatus VerifyBackup(BackupID backup_id,
+ bool verify_with_checksum = false) const override {
+ ReadLock lock(&mutex_);
+ return impl_.VerifyBackup(backup_id, verify_with_checksum);
+ }
+
+ // Not public API but needed
+ IOStatus Initialize() {
+ // No locking needed
+ return impl_.Initialize();
+ }
+
+ // Not public API but used in testing
+ void TEST_SetBackupMetaSchemaOptions(
+ const TEST_BackupMetaSchemaOptions& options) {
+ impl_.schema_test_options_.reset(new TEST_BackupMetaSchemaOptions(options));
+ }
+
+ // Not public API but used in testing
+ void TEST_SetDefaultRateLimitersClock(
+ const std::shared_ptr<SystemClock>& backup_rate_limiter_clock = nullptr,
+ const std::shared_ptr<SystemClock>& restore_rate_limiter_clock =
+ nullptr) {
+ impl_.TEST_SetDefaultRateLimitersClock(backup_rate_limiter_clock,
+ restore_rate_limiter_clock);
+ }
+
+ private:
+ mutable port::RWMutex mutex_;
+ BackupEngineImpl impl_;
+};
+} // namespace
+
+IOStatus BackupEngine::Open(const BackupEngineOptions& options, Env* env,
+ BackupEngine** backup_engine_ptr) {
+ std::unique_ptr<BackupEngineImplThreadSafe> backup_engine(
+ new BackupEngineImplThreadSafe(options, env));
+ auto s = backup_engine->Initialize();
+ if (!s.ok()) {
+ *backup_engine_ptr = nullptr;
+ return s;
+ }
+ *backup_engine_ptr = backup_engine.release();
+ return IOStatus::OK();
+}
+
+namespace {
+BackupEngineImpl::BackupEngineImpl(const BackupEngineOptions& options,
+ Env* db_env, bool read_only)
+ : initialized_(false),
+ threads_cpu_priority_(),
+ latest_backup_id_(0),
+ latest_valid_backup_id_(0),
+ stop_backup_(false),
+ options_(options),
+ db_env_(db_env),
+ backup_env_(options.backup_env != nullptr ? options.backup_env : db_env_),
+ read_only_(read_only) {
+ if (options_.backup_rate_limiter == nullptr &&
+ options_.backup_rate_limit > 0) {
+ options_.backup_rate_limiter.reset(
+ NewGenericRateLimiter(options_.backup_rate_limit));
+ }
+ if (options_.restore_rate_limiter == nullptr &&
+ options_.restore_rate_limit > 0) {
+ options_.restore_rate_limiter.reset(
+ NewGenericRateLimiter(options_.restore_rate_limit));
+ }
+ db_fs_ = db_env_->GetFileSystem();
+ backup_fs_ = backup_env_->GetFileSystem();
+}
+
+BackupEngineImpl::~BackupEngineImpl() {
+ files_to_copy_or_create_.sendEof();
+ for (auto& t : threads_) {
+ t.join();
+ }
+ LogFlush(options_.info_log);
+ for (const auto& it : corrupt_backups_) {
+ it.second.first.PermitUncheckedError();
+ }
+}
+
+IOStatus BackupEngineImpl::Initialize() {
+ assert(!initialized_);
+ initialized_ = true;
+ if (read_only_) {
+ ROCKS_LOG_INFO(options_.info_log, "Starting read_only backup engine");
+ }
+ options_.Dump(options_.info_log);
+
+ auto meta_path = GetAbsolutePath(kMetaDirName);
+
+ if (!read_only_) {
+ // we might need to clean up from previous crash or I/O errors
+ might_need_garbage_collect_ = true;
+
+ if (options_.max_valid_backups_to_open !=
+ std::numeric_limits<int32_t>::max()) {
+ options_.max_valid_backups_to_open = std::numeric_limits<int32_t>::max();
+ ROCKS_LOG_WARN(
+ options_.info_log,
+ "`max_valid_backups_to_open` is not set to the default value. "
+ "Ignoring its value since BackupEngine is not read-only.");
+ }
+
+ // gather the list of directories that we need to create
+ std::vector<std::pair<std::string, std::unique_ptr<FSDirectory>*>>
+ directories;
+ directories.emplace_back(GetAbsolutePath(), &backup_directory_);
+ if (options_.share_table_files) {
+ if (options_.share_files_with_checksum) {
+ directories.emplace_back(
+ GetAbsolutePath(GetSharedFileWithChecksumRel()),
+ &shared_directory_);
+ } else {
+ directories.emplace_back(GetAbsolutePath(GetSharedFileRel()),
+ &shared_directory_);
+ }
+ }
+ directories.emplace_back(GetAbsolutePath(kPrivateDirName),
+ &private_directory_);
+ directories.emplace_back(meta_path, &meta_directory_);
+ // create all the dirs we need
+ for (const auto& d : directories) {
+ IOStatus io_s =
+ backup_fs_->CreateDirIfMissing(d.first, io_options_, nullptr);
+ if (io_s.ok()) {
+ io_s =
+ backup_fs_->NewDirectory(d.first, io_options_, d.second, nullptr);
+ }
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ }
+
+ std::vector<std::string> backup_meta_files;
+ {
+ IOStatus io_s = backup_fs_->GetChildren(meta_path, io_options_,
+ &backup_meta_files, nullptr);
+ if (io_s.IsNotFound()) {
+ return IOStatus::NotFound(meta_path + " is missing");
+ } else if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ // create backups_ structure
+ for (auto& file : backup_meta_files) {
+ ROCKS_LOG_INFO(options_.info_log, "Detected backup %s", file.c_str());
+ BackupID backup_id = 0;
+ sscanf(file.c_str(), "%u", &backup_id);
+ if (backup_id == 0 || file != std::to_string(backup_id)) {
+ // Invalid file name, will be deleted with auto-GC when user
+ // initiates an append or write operation. (Behave as read-only until
+ // then.)
+ ROCKS_LOG_INFO(options_.info_log, "Skipping unrecognized meta file %s",
+ file.c_str());
+ continue;
+ }
+ assert(backups_.find(backup_id) == backups_.end());
+ // Insert all the (backup_id, BackupMeta) that will be loaded later
+ // The loading performed later will check whether there are corrupt backups
+ // and move the corrupt backups to corrupt_backups_
+ backups_.insert(std::make_pair(
+ backup_id, std::unique_ptr<BackupMeta>(new BackupMeta(
+ GetBackupMetaFile(backup_id, false /* tmp */),
+ GetBackupMetaFile(backup_id, true /* tmp */),
+ &backuped_file_infos_, backup_env_, backup_fs_))));
+ }
+
+ latest_backup_id_ = 0;
+ latest_valid_backup_id_ = 0;
+ if (options_.destroy_old_data) { // Destroy old data
+ assert(!read_only_);
+ ROCKS_LOG_INFO(
+ options_.info_log,
+ "Backup Engine started with destroy_old_data == true, deleting all "
+ "backups");
+ IOStatus io_s = PurgeOldBackups(0);
+ if (io_s.ok()) {
+ io_s = GarbageCollect();
+ }
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ } else { // Load data from storage
+ // abs_path_to_size: maps absolute paths of files in backup directory to
+ // their corresponding sizes
+ std::unordered_map<std::string, uint64_t> abs_path_to_size;
+ // Insert files and their sizes in backup sub-directories (shared and
+ // shared_checksum) to abs_path_to_size
+ for (const auto& rel_dir :
+ {GetSharedFileRel(), GetSharedFileWithChecksumRel()}) {
+ const auto abs_dir = GetAbsolutePath(rel_dir);
+ IOStatus io_s =
+ ReadChildFileCurrentSizes(abs_dir, backup_fs_, &abs_path_to_size);
+ if (!io_s.ok()) {
+ // I/O error likely impacting all backups
+ return io_s;
+ }
+ }
+ // load the backups if any, until valid_backups_to_open of the latest
+ // non-corrupted backups have been successfully opened.
+ int valid_backups_to_open = options_.max_valid_backups_to_open;
+ for (auto backup_iter = backups_.rbegin(); backup_iter != backups_.rend();
+ ++backup_iter) {
+ assert(latest_backup_id_ == 0 || latest_backup_id_ > backup_iter->first);
+ if (latest_backup_id_ == 0) {
+ latest_backup_id_ = backup_iter->first;
+ }
+ if (valid_backups_to_open == 0) {
+ break;
+ }
+
+ // Insert files and their sizes in backup sub-directories
+ // (private/backup_id) to abs_path_to_size
+ IOStatus io_s = ReadChildFileCurrentSizes(
+ GetAbsolutePath(GetPrivateFileRel(backup_iter->first)), backup_fs_,
+ &abs_path_to_size);
+ if (io_s.ok()) {
+ io_s = backup_iter->second->LoadFromFile(
+ options_.backup_dir, abs_path_to_size,
+ options_.backup_rate_limiter.get(), options_.info_log,
+ &reported_ignored_fields_);
+ }
+ if (io_s.IsCorruption() || io_s.IsNotSupported()) {
+ ROCKS_LOG_INFO(options_.info_log, "Backup %u corrupted -- %s",
+ backup_iter->first, io_s.ToString().c_str());
+ corrupt_backups_.insert(std::make_pair(
+ backup_iter->first,
+ std::make_pair(io_s, std::move(backup_iter->second))));
+ } else if (!io_s.ok()) {
+ // Distinguish corruption errors from errors in the backup Env.
+ // Errors in the backup Env (i.e., this code path) will cause Open() to
+ // fail, whereas corruption errors would not cause Open() failures.
+ return io_s;
+ } else {
+ ROCKS_LOG_INFO(options_.info_log, "Loading backup %" PRIu32 " OK:\n%s",
+ backup_iter->first,
+ backup_iter->second->GetInfoString().c_str());
+ assert(latest_valid_backup_id_ == 0 ||
+ latest_valid_backup_id_ > backup_iter->first);
+ if (latest_valid_backup_id_ == 0) {
+ latest_valid_backup_id_ = backup_iter->first;
+ }
+ --valid_backups_to_open;
+ }
+ }
+
+ for (const auto& corrupt : corrupt_backups_) {
+ backups_.erase(backups_.find(corrupt.first));
+ }
+ // erase the backups before max_valid_backups_to_open
+ int num_unopened_backups;
+ if (options_.max_valid_backups_to_open == 0) {
+ num_unopened_backups = 0;
+ } else {
+ num_unopened_backups =
+ std::max(0, static_cast<int>(backups_.size()) -
+ options_.max_valid_backups_to_open);
+ }
+ for (int i = 0; i < num_unopened_backups; ++i) {
+ assert(backups_.begin()->second->Empty());
+ backups_.erase(backups_.begin());
+ }
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Latest backup is %u", latest_backup_id_);
+ ROCKS_LOG_INFO(options_.info_log, "Latest valid backup is %u",
+ latest_valid_backup_id_);
+
+ // set up threads perform copies from files_to_copy_or_create_ in the
+ // background
+ threads_cpu_priority_ = CpuPriority::kNormal;
+ threads_.reserve(options_.max_background_operations);
+ for (int t = 0; t < options_.max_background_operations; t++) {
+ threads_.emplace_back([this]() {
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
+#if __GLIBC_PREREQ(2, 12)
+ pthread_setname_np(pthread_self(), "backup_engine");
+#endif
+#endif
+ CpuPriority current_priority = CpuPriority::kNormal;
+ CopyOrCreateWorkItem work_item;
+ uint64_t bytes_toward_next_callback = 0;
+ while (files_to_copy_or_create_.read(work_item)) {
+ CpuPriority priority = threads_cpu_priority_;
+ if (current_priority != priority) {
+ TEST_SYNC_POINT_CALLBACK(
+ "BackupEngineImpl::Initialize:SetCpuPriority", &priority);
+ port::SetCpuPriority(0, priority);
+ current_priority = priority;
+ }
+ // `bytes_read` and `bytes_written` stats are enabled based on
+ // compile-time support and cannot be dynamically toggled. So we do not
+ // need to worry about `PerfLevel` here, unlike many other
+ // `IOStatsContext` / `PerfContext` stats.
+ uint64_t prev_bytes_read = IOSTATS(bytes_read);
+ uint64_t prev_bytes_written = IOSTATS(bytes_written);
+
+ CopyOrCreateResult result;
+ Temperature temp = work_item.src_temperature;
+ result.io_status = CopyOrCreateFile(
+ work_item.src_path, work_item.dst_path, work_item.contents,
+ work_item.size_limit, work_item.src_env, work_item.dst_env,
+ work_item.src_env_options, work_item.sync, work_item.rate_limiter,
+ work_item.progress_callback, &temp, work_item.dst_temperature,
+ &bytes_toward_next_callback, &result.size, &result.checksum_hex);
+
+ RecordTick(work_item.stats, BACKUP_READ_BYTES,
+ IOSTATS(bytes_read) - prev_bytes_read);
+ RecordTick(work_item.stats, BACKUP_WRITE_BYTES,
+ IOSTATS(bytes_written) - prev_bytes_written);
+
+ result.db_id = work_item.db_id;
+ result.db_session_id = work_item.db_session_id;
+ result.expected_src_temperature = work_item.src_temperature;
+ result.current_src_temperature = temp;
+ if (result.io_status.ok() && !work_item.src_checksum_hex.empty()) {
+ // unknown checksum function name implies no db table file checksum in
+ // db manifest; work_item.src_checksum_hex not empty means
+ // backup engine has calculated its crc32c checksum for the table
+ // file; therefore, we are able to compare the checksums.
+ if (work_item.src_checksum_func_name ==
+ kUnknownFileChecksumFuncName ||
+ work_item.src_checksum_func_name == kDbFileChecksumFuncName) {
+ if (work_item.src_checksum_hex != result.checksum_hex) {
+ std::string checksum_info(
+ "Expected checksum is " + work_item.src_checksum_hex +
+ " while computed checksum is " + result.checksum_hex);
+ result.io_status = IOStatus::Corruption(
+ "Checksum mismatch after copying to " + work_item.dst_path +
+ ": " + checksum_info);
+ }
+ } else {
+ // FIXME(peterd): dead code?
+ std::string checksum_function_info(
+ "Existing checksum function is " +
+ work_item.src_checksum_func_name +
+ " while provided checksum function is " +
+ kBackupFileChecksumFuncName);
+ ROCKS_LOG_INFO(
+ options_.info_log,
+ "Unable to verify checksum after copying to %s: %s\n",
+ work_item.dst_path.c_str(), checksum_function_info.c_str());
+ }
+ }
+ work_item.result.set_value(std::move(result));
+ }
+ });
+ }
+ ROCKS_LOG_INFO(options_.info_log, "Initialized BackupEngine");
+ return IOStatus::OK();
+}
+
+IOStatus BackupEngineImpl::CreateNewBackupWithMetadata(
+ const CreateBackupOptions& options, DB* db, const std::string& app_metadata,
+ BackupID* new_backup_id_ptr) {
+ assert(initialized_);
+ assert(!read_only_);
+ if (app_metadata.size() > kMaxAppMetaSize) {
+ return IOStatus::InvalidArgument("App metadata too large");
+ }
+
+ if (options.decrease_background_thread_cpu_priority) {
+ if (options.background_thread_cpu_priority < threads_cpu_priority_) {
+ threads_cpu_priority_.store(options.background_thread_cpu_priority);
+ }
+ }
+
+ BackupID new_backup_id = latest_backup_id_ + 1;
+
+ // `bytes_read` and `bytes_written` stats are enabled based on compile-time
+ // support and cannot be dynamically toggled. So we do not need to worry about
+ // `PerfLevel` here, unlike many other `IOStatsContext` / `PerfContext` stats.
+ uint64_t prev_bytes_read = IOSTATS(bytes_read);
+ uint64_t prev_bytes_written = IOSTATS(bytes_written);
+
+ assert(backups_.find(new_backup_id) == backups_.end());
+
+ auto private_dir = GetAbsolutePath(GetPrivateFileRel(new_backup_id));
+ IOStatus io_s = backup_fs_->FileExists(private_dir, io_options_, nullptr);
+ if (io_s.ok()) {
+ // maybe last backup failed and left partial state behind, clean it up.
+ // need to do this before updating backups_ such that a private dir
+ // named after new_backup_id will be cleaned up.
+ // (If an incomplete new backup is followed by an incomplete delete
+ // of the latest full backup, then there could be more than one next
+ // id with a private dir, the last thing to be deleted in delete
+ // backup, but all will be cleaned up with a GarbageCollect.)
+ io_s = GarbageCollect();
+ } else if (io_s.IsNotFound()) {
+ // normal case, the new backup's private dir doesn't exist yet
+ io_s = IOStatus::OK();
+ }
+
+ auto ret = backups_.insert(std::make_pair(
+ new_backup_id, std::unique_ptr<BackupMeta>(new BackupMeta(
+ GetBackupMetaFile(new_backup_id, false /* tmp */),
+ GetBackupMetaFile(new_backup_id, true /* tmp */),
+ &backuped_file_infos_, backup_env_, backup_fs_))));
+ assert(ret.second == true);
+ auto& new_backup = ret.first->second;
+ new_backup->RecordTimestamp();
+ new_backup->SetAppMetadata(app_metadata);
+
+ auto start_backup = backup_env_->NowMicros();
+
+ ROCKS_LOG_INFO(options_.info_log,
+ "Started the backup process -- creating backup %u",
+ new_backup_id);
+
+ if (options_.share_table_files && !options_.share_files_with_checksum) {
+ ROCKS_LOG_WARN(options_.info_log,
+ "BackupEngineOptions::share_files_with_checksum=false is "
+ "DEPRECATED and could lead to data loss.");
+ }
+
+ if (io_s.ok()) {
+ io_s = backup_fs_->CreateDir(private_dir, io_options_, nullptr);
+ }
+
+ // A set into which we will insert the dst_paths that are calculated for live
+ // files and live WAL files.
+ // This is used to check whether a live files shares a dst_path with another
+ // live file.
+ std::unordered_set<std::string> live_dst_paths;
+
+ std::vector<BackupAfterCopyOrCreateWorkItem> backup_items_to_finish;
+ // Add a CopyOrCreateWorkItem to the channel for each live file
+ Status disabled = db->DisableFileDeletions();
+ DBOptions db_options = db->GetDBOptions();
+ Statistics* stats = db_options.statistics.get();
+ if (io_s.ok()) {
+ CheckpointImpl checkpoint(db);
+ uint64_t sequence_number = 0;
+ FileChecksumGenFactory* db_checksum_factory =
+ db_options.file_checksum_gen_factory.get();
+ const std::string kFileChecksumGenFactoryName =
+ "FileChecksumGenCrc32cFactory";
+ bool compare_checksum =
+ db_checksum_factory != nullptr &&
+ db_checksum_factory->Name() == kFileChecksumGenFactoryName
+ ? true
+ : false;
+ EnvOptions src_raw_env_options(db_options);
+ RateLimiter* rate_limiter = options_.backup_rate_limiter.get();
+ io_s = status_to_io_status(checkpoint.CreateCustomCheckpoint(
+ [&](const std::string& /*src_dirname*/, const std::string& /*fname*/,
+ FileType) {
+ // custom checkpoint will switch to calling copy_file_cb after it sees
+ // NotSupported returned from link_file_cb.
+ return IOStatus::NotSupported();
+ } /* link_file_cb */,
+ [&](const std::string& src_dirname, const std::string& fname,
+ uint64_t size_limit_bytes, FileType type,
+ const std::string& checksum_func_name,
+ const std::string& checksum_val,
+ const Temperature src_temperature) {
+ if (type == kWalFile && !options_.backup_log_files) {
+ return IOStatus::OK();
+ }
+ Log(options_.info_log, "add file for backup %s", fname.c_str());
+ uint64_t size_bytes = 0;
+ IOStatus io_st;
+ if (type == kTableFile || type == kBlobFile) {
+ io_st = db_fs_->GetFileSize(src_dirname + "/" + fname, io_options_,
+ &size_bytes, nullptr);
+ if (!io_st.ok()) {
+ Log(options_.info_log, "GetFileSize is failed: %s",
+ io_st.ToString().c_str());
+ return io_st;
+ }
+ }
+ EnvOptions src_env_options;
+ switch (type) {
+ case kWalFile:
+ src_env_options =
+ db_env_->OptimizeForLogRead(src_raw_env_options);
+ break;
+ case kTableFile:
+ src_env_options = db_env_->OptimizeForCompactionTableRead(
+ src_raw_env_options, ImmutableDBOptions(db_options));
+ break;
+ case kDescriptorFile:
+ src_env_options =
+ db_env_->OptimizeForManifestRead(src_raw_env_options);
+ break;
+ case kBlobFile:
+ src_env_options = db_env_->OptimizeForBlobFileRead(
+ src_raw_env_options, ImmutableDBOptions(db_options));
+ break;
+ default:
+ // Other backed up files (like options file) are not read by live
+ // DB, so don't need to worry about avoiding mixing buffered and
+ // direct I/O. Just use plain defaults.
+ src_env_options = src_raw_env_options;
+ break;
+ }
+ io_st = AddBackupFileWorkItem(
+ live_dst_paths, backup_items_to_finish, new_backup_id,
+ options_.share_table_files &&
+ (type == kTableFile || type == kBlobFile),
+ src_dirname, fname, src_env_options, rate_limiter, type,
+ size_bytes, db_options.statistics.get(), size_limit_bytes,
+ options_.share_files_with_checksum &&
+ (type == kTableFile || type == kBlobFile),
+ options.progress_callback, "" /* contents */, checksum_func_name,
+ checksum_val, src_temperature);
+ return io_st;
+ } /* copy_file_cb */,
+ [&](const std::string& fname, const std::string& contents,
+ FileType type) {
+ Log(options_.info_log, "add file for backup %s", fname.c_str());
+ return AddBackupFileWorkItem(
+ live_dst_paths, backup_items_to_finish, new_backup_id,
+ false /* shared */, "" /* src_dir */, fname,
+ EnvOptions() /* src_env_options */, rate_limiter, type,
+ contents.size(), db_options.statistics.get(), 0 /* size_limit */,
+ false /* shared_checksum */, options.progress_callback, contents);
+ } /* create_file_cb */,
+ &sequence_number,
+ options.flush_before_backup ? 0 : std::numeric_limits<uint64_t>::max(),
+ compare_checksum));
+ if (io_s.ok()) {
+ new_backup->SetSequenceNumber(sequence_number);
+ }
+ }
+ ROCKS_LOG_INFO(options_.info_log, "add files for backup done, wait finish.");
+ IOStatus item_io_status;
+ for (auto& item : backup_items_to_finish) {
+ item.result.wait();
+ auto result = item.result.get();
+ item_io_status = result.io_status;
+ Temperature temp = result.expected_src_temperature;
+ if (result.current_src_temperature != Temperature::kUnknown &&
+ (temp == Temperature::kUnknown ||
+ options_.current_temperatures_override_manifest)) {
+ temp = result.current_src_temperature;
+ }
+ if (item_io_status.ok() && item.shared && item.needed_to_copy) {
+ item_io_status = item.backup_env->GetFileSystem()->RenameFile(
+ item.dst_path_tmp, item.dst_path, io_options_, nullptr);
+ }
+ if (item_io_status.ok()) {
+ item_io_status = new_backup.get()->AddFile(std::make_shared<FileInfo>(
+ item.dst_relative, result.size, result.checksum_hex, result.db_id,
+ result.db_session_id, temp));
+ }
+ if (!item_io_status.ok()) {
+ io_s = item_io_status;
+ }
+ }
+
+ // we copied all the files, enable file deletions
+ if (disabled.ok()) { // If we successfully disabled file deletions
+ db->EnableFileDeletions(false).PermitUncheckedError();
+ }
+ auto backup_time = backup_env_->NowMicros() - start_backup;
+
+ if (io_s.ok()) {
+ // persist the backup metadata on the disk
+ io_s = new_backup->StoreToFile(options_.sync, options_.schema_version,
+ schema_test_options_.get());
+ }
+ if (io_s.ok() && options_.sync) {
+ std::unique_ptr<FSDirectory> backup_private_directory;
+ backup_fs_
+ ->NewDirectory(GetAbsolutePath(GetPrivateFileRel(new_backup_id, false)),
+ io_options_, &backup_private_directory, nullptr)
+ .PermitUncheckedError();
+ if (backup_private_directory != nullptr) {
+ io_s = backup_private_directory->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+ if (io_s.ok() && private_directory_ != nullptr) {
+ io_s = private_directory_->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+ if (io_s.ok() && meta_directory_ != nullptr) {
+ io_s = meta_directory_->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+ if (io_s.ok() && shared_directory_ != nullptr) {
+ io_s = shared_directory_->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+ if (io_s.ok() && backup_directory_ != nullptr) {
+ io_s = backup_directory_->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+ }
+
+ if (io_s.ok()) {
+ backup_statistics_.IncrementNumberSuccessBackup();
+ // here we know that we succeeded and installed the new backup
+ latest_backup_id_ = new_backup_id;
+ latest_valid_backup_id_ = new_backup_id;
+ if (new_backup_id_ptr) {
+ *new_backup_id_ptr = new_backup_id;
+ }
+ ROCKS_LOG_INFO(options_.info_log, "Backup DONE. All is good");
+
+ // backup_speed is in byte/second
+ double backup_speed = new_backup->GetSize() / (1.048576 * backup_time);
+ ROCKS_LOG_INFO(options_.info_log, "Backup number of files: %u",
+ new_backup->GetNumberFiles());
+ char human_size[16];
+ AppendHumanBytes(new_backup->GetSize(), human_size, sizeof(human_size));
+ ROCKS_LOG_INFO(options_.info_log, "Backup size: %s", human_size);
+ ROCKS_LOG_INFO(options_.info_log, "Backup time: %" PRIu64 " microseconds",
+ backup_time);
+ ROCKS_LOG_INFO(options_.info_log, "Backup speed: %.3f MB/s", backup_speed);
+ ROCKS_LOG_INFO(options_.info_log, "Backup Statistics %s",
+ backup_statistics_.ToString().c_str());
+ } else {
+ backup_statistics_.IncrementNumberFailBackup();
+ // clean all the files we might have created
+ ROCKS_LOG_INFO(options_.info_log, "Backup failed -- %s",
+ io_s.ToString().c_str());
+ ROCKS_LOG_INFO(options_.info_log, "Backup Statistics %s\n",
+ backup_statistics_.ToString().c_str());
+ // delete files that we might have already written
+ might_need_garbage_collect_ = true;
+ DeleteBackup(new_backup_id).PermitUncheckedError();
+ }
+
+ RecordTick(stats, BACKUP_READ_BYTES, IOSTATS(bytes_read) - prev_bytes_read);
+ RecordTick(stats, BACKUP_WRITE_BYTES,
+ IOSTATS(bytes_written) - prev_bytes_written);
+ return io_s;
+}
+
+IOStatus BackupEngineImpl::PurgeOldBackups(uint32_t num_backups_to_keep) {
+ assert(initialized_);
+ assert(!read_only_);
+
+ // Best effort deletion even with errors
+ IOStatus overall_status = IOStatus::OK();
+
+ ROCKS_LOG_INFO(options_.info_log, "Purging old backups, keeping %u",
+ num_backups_to_keep);
+ std::vector<BackupID> to_delete;
+ auto itr = backups_.begin();
+ while ((backups_.size() - to_delete.size()) > num_backups_to_keep) {
+ to_delete.push_back(itr->first);
+ itr++;
+ }
+ for (auto backup_id : to_delete) {
+ // Do not GC until end
+ IOStatus io_s = DeleteBackupNoGC(backup_id);
+ if (!io_s.ok()) {
+ overall_status = io_s;
+ }
+ }
+ // Clean up after any incomplete backup deletion, potentially from
+ // earlier session.
+ if (might_need_garbage_collect_) {
+ IOStatus io_s = GarbageCollect();
+ if (!io_s.ok() && overall_status.ok()) {
+ overall_status = io_s;
+ }
+ }
+ return overall_status;
+}
+
+IOStatus BackupEngineImpl::DeleteBackup(BackupID backup_id) {
+ IOStatus s1 = DeleteBackupNoGC(backup_id);
+ IOStatus s2 = IOStatus::OK();
+
+ // Clean up after any incomplete backup deletion, potentially from
+ // earlier session.
+ if (might_need_garbage_collect_) {
+ s2 = GarbageCollect();
+ }
+
+ if (!s1.ok()) {
+ // Any failure in the primary objective trumps any failure in the
+ // secondary objective.
+ s2.PermitUncheckedError();
+ return s1;
+ } else {
+ return s2;
+ }
+}
+
+// Does not auto-GarbageCollect nor lock
+IOStatus BackupEngineImpl::DeleteBackupNoGC(BackupID backup_id) {
+ assert(initialized_);
+ assert(!read_only_);
+
+ ROCKS_LOG_INFO(options_.info_log, "Deleting backup %u", backup_id);
+ auto backup = backups_.find(backup_id);
+ if (backup != backups_.end()) {
+ IOStatus io_s = backup->second->Delete();
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ backups_.erase(backup);
+ } else {
+ auto corrupt = corrupt_backups_.find(backup_id);
+ if (corrupt == corrupt_backups_.end()) {
+ return IOStatus::NotFound("Backup not found");
+ }
+ IOStatus io_s = corrupt->second.second->Delete();
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ corrupt->second.first.PermitUncheckedError();
+ corrupt_backups_.erase(corrupt);
+ }
+
+ // After removing meta file, best effort deletion even with errors.
+ // (Don't delete other files if we can't delete the meta file right
+ // now.)
+ std::vector<std::string> to_delete;
+ for (auto& itr : backuped_file_infos_) {
+ if (itr.second->refs == 0) {
+ IOStatus io_s = backup_fs_->DeleteFile(GetAbsolutePath(itr.first),
+ io_options_, nullptr);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s", itr.first.c_str(),
+ io_s.ToString().c_str());
+ to_delete.push_back(itr.first);
+ if (!io_s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ }
+ for (auto& td : to_delete) {
+ backuped_file_infos_.erase(td);
+ }
+
+ // take care of private dirs -- GarbageCollect() will take care of them
+ // if they are not empty
+ std::string private_dir = GetPrivateFileRel(backup_id);
+ IOStatus io_s =
+ backup_fs_->DeleteDir(GetAbsolutePath(private_dir), io_options_, nullptr);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting private dir %s -- %s",
+ private_dir.c_str(), io_s.ToString().c_str());
+ if (!io_s.ok()) {
+ // Full gc or trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ return IOStatus::OK();
+}
+
+void BackupEngineImpl::SetBackupInfoFromBackupMeta(
+ BackupID id, const BackupMeta& meta, BackupInfo* backup_info,
+ bool include_file_details) const {
+ *backup_info = BackupInfo(id, meta.GetTimestamp(), meta.GetSize(),
+ meta.GetNumberFiles(), meta.GetAppMetadata());
+ std::string dir =
+ options_.backup_dir + "/" + kPrivateDirSlash + std::to_string(id);
+ if (include_file_details) {
+ auto& file_details = backup_info->file_details;
+ file_details.reserve(meta.GetFiles().size());
+ for (auto& file_ptr : meta.GetFiles()) {
+ BackupFileInfo& finfo = *file_details.emplace(file_details.end());
+ finfo.relative_filename = file_ptr->filename;
+ finfo.size = file_ptr->size;
+ finfo.directory = dir;
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(file_ptr->filename, &number, &type);
+ if (ok) {
+ finfo.file_number = number;
+ finfo.file_type = type;
+ }
+ // TODO: temperature, file_checksum, file_checksum_func_name
+ }
+ backup_info->name_for_open = GetAbsolutePath(GetPrivateFileRel(id));
+ backup_info->name_for_open.pop_back(); // remove trailing '/'
+ backup_info->env_for_open = meta.GetEnvForOpen();
+ }
+}
+
+Status BackupEngineImpl::GetBackupInfo(BackupID backup_id,
+ BackupInfo* backup_info,
+ bool include_file_details) const {
+ assert(initialized_);
+ if (backup_id == kLatestBackupIDMarker) {
+ // Note: Read latest_valid_backup_id_ inside of lock
+ backup_id = latest_valid_backup_id_;
+ }
+ auto corrupt_itr = corrupt_backups_.find(backup_id);
+ if (corrupt_itr != corrupt_backups_.end()) {
+ return Status::Corruption(corrupt_itr->second.first.ToString());
+ }
+ auto backup_itr = backups_.find(backup_id);
+ if (backup_itr == backups_.end()) {
+ return Status::NotFound("Backup not found");
+ }
+ auto& backup = backup_itr->second;
+ if (backup->Empty()) {
+ return Status::NotFound("Backup not found");
+ }
+
+ SetBackupInfoFromBackupMeta(backup_id, *backup, backup_info,
+ include_file_details);
+ return Status::OK();
+}
+
+void BackupEngineImpl::GetBackupInfo(std::vector<BackupInfo>* backup_info,
+ bool include_file_details) const {
+ assert(initialized_);
+ backup_info->resize(backups_.size());
+ size_t i = 0;
+ for (auto& backup : backups_) {
+ const BackupMeta& meta = *backup.second;
+ if (!meta.Empty()) {
+ SetBackupInfoFromBackupMeta(backup.first, meta, &backup_info->at(i++),
+ include_file_details);
+ }
+ }
+}
+
+void BackupEngineImpl::GetCorruptedBackups(
+ std::vector<BackupID>* corrupt_backup_ids) const {
+ assert(initialized_);
+ corrupt_backup_ids->reserve(corrupt_backups_.size());
+ for (auto& backup : corrupt_backups_) {
+ corrupt_backup_ids->push_back(backup.first);
+ }
+}
+
+IOStatus BackupEngineImpl::RestoreDBFromBackup(
+ const RestoreOptions& options, BackupID backup_id,
+ const std::string& db_dir, const std::string& wal_dir) const {
+ assert(initialized_);
+ if (backup_id == kLatestBackupIDMarker) {
+ // Note: Read latest_valid_backup_id_ inside of lock
+ backup_id = latest_valid_backup_id_;
+ }
+ auto corrupt_itr = corrupt_backups_.find(backup_id);
+ if (corrupt_itr != corrupt_backups_.end()) {
+ return corrupt_itr->second.first;
+ }
+ auto backup_itr = backups_.find(backup_id);
+ if (backup_itr == backups_.end()) {
+ return IOStatus::NotFound("Backup not found");
+ }
+ auto& backup = backup_itr->second;
+ if (backup->Empty()) {
+ return IOStatus::NotFound("Backup not found");
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring backup id %u\n", backup_id);
+ ROCKS_LOG_INFO(options_.info_log, "keep_log_files: %d\n",
+ static_cast<int>(options.keep_log_files));
+
+ // just in case. Ignore errors
+ db_fs_->CreateDirIfMissing(db_dir, io_options_, nullptr)
+ .PermitUncheckedError();
+ db_fs_->CreateDirIfMissing(wal_dir, io_options_, nullptr)
+ .PermitUncheckedError();
+
+ if (options.keep_log_files) {
+ // delete files in db_dir, but keep all the log files
+ DeleteChildren(db_dir, 1 << kWalFile);
+ // move all the files from archive dir to wal_dir
+ std::string archive_dir = ArchivalDirectory(wal_dir);
+ std::vector<std::string> archive_files;
+ db_fs_->GetChildren(archive_dir, io_options_, &archive_files, nullptr)
+ .PermitUncheckedError(); // ignore errors
+ for (const auto& f : archive_files) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && type == kWalFile) {
+ ROCKS_LOG_INFO(options_.info_log,
+ "Moving log file from archive/ to wal_dir: %s",
+ f.c_str());
+ IOStatus io_s = db_fs_->RenameFile(
+ archive_dir + "/" + f, wal_dir + "/" + f, io_options_, nullptr);
+ if (!io_s.ok()) {
+ // if we can't move log file from archive_dir to wal_dir,
+ // we should fail, since it might mean data loss
+ return io_s;
+ }
+ }
+ }
+ } else {
+ DeleteChildren(wal_dir);
+ DeleteChildren(ArchivalDirectory(wal_dir));
+ DeleteChildren(db_dir);
+ }
+
+ IOStatus io_s;
+ std::vector<RestoreAfterCopyOrCreateWorkItem> restore_items_to_finish;
+ std::string temporary_current_file;
+ std::string final_current_file;
+ std::unique_ptr<FSDirectory> db_dir_for_fsync;
+ std::unique_ptr<FSDirectory> wal_dir_for_fsync;
+
+ for (const auto& file_info : backup->GetFiles()) {
+ const std::string& file = file_info->filename;
+ // 1. get DB filename
+ std::string dst = file_info->GetDbFileName();
+
+ // 2. find the filetype
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(dst, &number, &type);
+ if (!ok) {
+ return IOStatus::Corruption("Backup corrupted: Fail to parse filename " +
+ dst);
+ }
+ // 3. Construct the final path
+ // kWalFile lives in wal_dir and all the rest live in db_dir
+ if (type == kWalFile) {
+ dst = wal_dir + "/" + dst;
+ if (options_.sync && !wal_dir_for_fsync) {
+ io_s = db_fs_->NewDirectory(wal_dir, io_options_, &wal_dir_for_fsync,
+ nullptr);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ } else {
+ dst = db_dir + "/" + dst;
+ if (options_.sync && !db_dir_for_fsync) {
+ io_s = db_fs_->NewDirectory(db_dir, io_options_, &db_dir_for_fsync,
+ nullptr);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ }
+ // For atomicity, initially restore CURRENT file to a temporary name.
+ // This is useful even without options_.sync e.g. in case the restore
+ // process is interrupted.
+ if (type == kCurrentFile) {
+ final_current_file = dst;
+ dst = temporary_current_file = dst + ".tmp";
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring %s to %s\n", file.c_str(),
+ dst.c_str());
+ CopyOrCreateWorkItem copy_or_create_work_item(
+ GetAbsolutePath(file), dst, Temperature::kUnknown /* src_temp */,
+ file_info->temp, "" /* contents */, backup_env_, db_env_,
+ EnvOptions() /* src_env_options */, options_.sync,
+ options_.restore_rate_limiter.get(), file_info->size,
+ nullptr /* stats */);
+ RestoreAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ copy_or_create_work_item.result.get_future(), file, dst,
+ file_info->checksum_hex);
+ files_to_copy_or_create_.write(std::move(copy_or_create_work_item));
+ restore_items_to_finish.push_back(
+ std::move(after_copy_or_create_work_item));
+ }
+ IOStatus item_io_status;
+ for (auto& item : restore_items_to_finish) {
+ item.result.wait();
+ auto result = item.result.get();
+ item_io_status = result.io_status;
+ // Note: It is possible that both of the following bad-status cases occur
+ // during copying. But, we only return one status.
+ if (!item_io_status.ok()) {
+ io_s = item_io_status;
+ break;
+ } else if (!item.checksum_hex.empty() &&
+ item.checksum_hex != result.checksum_hex) {
+ io_s = IOStatus::Corruption(
+ "While restoring " + item.from_file + " -> " + item.to_file +
+ ": expected checksum is " + item.checksum_hex +
+ " while computed checksum is " + result.checksum_hex);
+ break;
+ }
+ }
+
+ // When enabled, the first FsyncWithDirOptions is to ensure all files are
+ // fully persisted before renaming CURRENT.tmp
+ if (io_s.ok() && db_dir_for_fsync) {
+ ROCKS_LOG_INFO(options_.info_log, "Restore: fsync\n");
+ io_s = db_dir_for_fsync->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+
+ if (io_s.ok() && wal_dir_for_fsync) {
+ io_s = wal_dir_for_fsync->FsyncWithDirOptions(io_options_, nullptr,
+ DirFsyncOptions());
+ }
+
+ if (io_s.ok() && !temporary_current_file.empty()) {
+ ROCKS_LOG_INFO(options_.info_log, "Restore: atomic rename CURRENT.tmp\n");
+ assert(!final_current_file.empty());
+ io_s = db_fs_->RenameFile(temporary_current_file, final_current_file,
+ io_options_, nullptr);
+ }
+
+ if (io_s.ok() && db_dir_for_fsync && !temporary_current_file.empty()) {
+ // Second FsyncWithDirOptions is to ensure the final atomic rename of DB
+ // restore is fully persisted even if power goes out right after restore
+ // operation returns success
+ assert(db_dir_for_fsync);
+ io_s = db_dir_for_fsync->FsyncWithDirOptions(
+ io_options_, nullptr, DirFsyncOptions(final_current_file));
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring done -- %s\n",
+ io_s.ToString().c_str());
+ return io_s;
+}
+
+IOStatus BackupEngineImpl::VerifyBackup(BackupID backup_id,
+ bool verify_with_checksum) const {
+ assert(initialized_);
+ // Check if backup_id is corrupted, or valid and registered
+ auto corrupt_itr = corrupt_backups_.find(backup_id);
+ if (corrupt_itr != corrupt_backups_.end()) {
+ return corrupt_itr->second.first;
+ }
+
+ auto backup_itr = backups_.find(backup_id);
+ if (backup_itr == backups_.end()) {
+ return IOStatus::NotFound();
+ }
+
+ auto& backup = backup_itr->second;
+ if (backup->Empty()) {
+ return IOStatus::NotFound();
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Verifying backup id %u\n", backup_id);
+
+ // Find all existing backup files belong to backup_id
+ std::unordered_map<std::string, uint64_t> curr_abs_path_to_size;
+ for (const auto& rel_dir : {GetPrivateFileRel(backup_id), GetSharedFileRel(),
+ GetSharedFileWithChecksumRel()}) {
+ const auto abs_dir = GetAbsolutePath(rel_dir);
+ // Shared directories allowed to be missing in some cases. Expected but
+ // missing files will be reported a few lines down.
+ ReadChildFileCurrentSizes(abs_dir, backup_fs_, &curr_abs_path_to_size)
+ .PermitUncheckedError();
+ }
+
+ // For all files registered in backup
+ for (const auto& file_info : backup->GetFiles()) {
+ const auto abs_path = GetAbsolutePath(file_info->filename);
+ // check existence of the file
+ if (curr_abs_path_to_size.find(abs_path) == curr_abs_path_to_size.end()) {
+ return IOStatus::NotFound("File missing: " + abs_path);
+ }
+ // verify file size
+ if (file_info->size != curr_abs_path_to_size[abs_path]) {
+ std::string size_info("Expected file size is " +
+ std::to_string(file_info->size) +
+ " while found file size is " +
+ std::to_string(curr_abs_path_to_size[abs_path]));
+ return IOStatus::Corruption("File corrupted: File size mismatch for " +
+ abs_path + ": " + size_info);
+ }
+ if (verify_with_checksum && !file_info->checksum_hex.empty()) {
+ // verify file checksum
+ std::string checksum_hex;
+ ROCKS_LOG_INFO(options_.info_log, "Verifying %s checksum...\n",
+ abs_path.c_str());
+ IOStatus io_s = ReadFileAndComputeChecksum(
+ abs_path, backup_fs_, EnvOptions(), 0 /* size_limit */, &checksum_hex,
+ Temperature::kUnknown);
+ if (!io_s.ok()) {
+ return io_s;
+ } else if (file_info->checksum_hex != checksum_hex) {
+ std::string checksum_info(
+ "Expected checksum is " + file_info->checksum_hex +
+ " while computed checksum is " + checksum_hex);
+ return IOStatus::Corruption("File corrupted: Checksum mismatch for " +
+ abs_path + ": " + checksum_info);
+ }
+ }
+ }
+ return IOStatus::OK();
+}
+
+IOStatus BackupEngineImpl::CopyOrCreateFile(
+ const std::string& src, const std::string& dst, const std::string& contents,
+ uint64_t size_limit, Env* src_env, Env* dst_env,
+ const EnvOptions& src_env_options, bool sync, RateLimiter* rate_limiter,
+ std::function<void()> progress_callback, Temperature* src_temperature,
+ Temperature dst_temperature, uint64_t* bytes_toward_next_callback,
+ uint64_t* size, std::string* checksum_hex) {
+ assert(src.empty() != contents.empty());
+ IOStatus io_s;
+ std::unique_ptr<FSWritableFile> dst_file;
+ std::unique_ptr<FSSequentialFile> src_file;
+ FileOptions dst_file_options;
+ dst_file_options.use_mmap_writes = false;
+ dst_file_options.temperature = dst_temperature;
+ // TODO:(gzh) maybe use direct reads/writes here if possible
+ if (size != nullptr) {
+ *size = 0;
+ }
+ uint32_t checksum_value = 0;
+
+ // Check if size limit is set. if not, set it to very big number
+ if (size_limit == 0) {
+ size_limit = std::numeric_limits<uint64_t>::max();
+ }
+
+ io_s = dst_env->GetFileSystem()->NewWritableFile(dst, dst_file_options,
+ &dst_file, nullptr);
+ if (io_s.ok() && !src.empty()) {
+ auto src_file_options = FileOptions(src_env_options);
+ src_file_options.temperature = *src_temperature;
+ io_s = src_env->GetFileSystem()->NewSequentialFile(src, src_file_options,
+ &src_file, nullptr);
+ }
+ if (io_s.IsPathNotFound() && *src_temperature != Temperature::kUnknown) {
+ // Retry without temperature hint in case the FileSystem is strict with
+ // non-kUnknown temperature option
+ io_s = src_env->GetFileSystem()->NewSequentialFile(
+ src, FileOptions(src_env_options), &src_file, nullptr);
+ }
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ size_t buf_size =
+ rate_limiter ? static_cast<size_t>(rate_limiter->GetSingleBurstBytes())
+ : kDefaultCopyFileBufferSize;
+
+ std::unique_ptr<WritableFileWriter> dest_writer(
+ new WritableFileWriter(std::move(dst_file), dst, dst_file_options));
+ std::unique_ptr<SequentialFileReader> src_reader;
+ std::unique_ptr<char[]> buf;
+ if (!src.empty()) {
+ // Return back current temperature in FileSystem
+ *src_temperature = src_file->GetTemperature();
+
+ src_reader.reset(new SequentialFileReader(
+ std::move(src_file), src, nullptr /* io_tracer */, {}, rate_limiter));
+ buf.reset(new char[buf_size]);
+ }
+
+ Slice data;
+ do {
+ if (stop_backup_.load(std::memory_order_acquire)) {
+ return status_to_io_status(Status::Incomplete("Backup stopped"));
+ }
+ if (!src.empty()) {
+ size_t buffer_to_read =
+ (buf_size < size_limit) ? buf_size : static_cast<size_t>(size_limit);
+ io_s = src_reader->Read(buffer_to_read, &data, buf.get(),
+ Env::IO_LOW /* rate_limiter_priority */);
+ *bytes_toward_next_callback += data.size();
+ } else {
+ data = contents;
+ }
+ size_limit -= data.size();
+ TEST_SYNC_POINT_CALLBACK(
+ "BackupEngineImpl::CopyOrCreateFile:CorruptionDuringBackup",
+ (src.length() > 4 && src.rfind(".sst") == src.length() - 4) ? &data
+ : nullptr);
+
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ if (size != nullptr) {
+ *size += data.size();
+ }
+ if (checksum_hex != nullptr) {
+ checksum_value = crc32c::Extend(checksum_value, data.data(), data.size());
+ }
+ io_s = dest_writer->Append(data);
+
+ if (rate_limiter != nullptr) {
+ if (!src.empty()) {
+ rate_limiter->Request(data.size(), Env::IO_LOW, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ } else {
+ LoopRateLimitRequestHelper(data.size(), rate_limiter, Env::IO_LOW,
+ nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ }
+ }
+ while (*bytes_toward_next_callback >=
+ options_.callback_trigger_interval_size) {
+ *bytes_toward_next_callback -= options_.callback_trigger_interval_size;
+ std::lock_guard<std::mutex> lock(byte_report_mutex_);
+ progress_callback();
+ }
+ } while (io_s.ok() && contents.empty() && data.size() > 0 && size_limit > 0);
+
+ // Convert uint32_t checksum to hex checksum
+ if (checksum_hex != nullptr) {
+ checksum_hex->assign(ChecksumInt32ToHex(checksum_value));
+ }
+
+ if (io_s.ok() && sync) {
+ io_s = dest_writer->Sync(false);
+ }
+ if (io_s.ok()) {
+ io_s = dest_writer->Close();
+ }
+ return io_s;
+}
+
+// fname will always start with "/"
+IOStatus BackupEngineImpl::AddBackupFileWorkItem(
+ std::unordered_set<std::string>& live_dst_paths,
+ std::vector<BackupAfterCopyOrCreateWorkItem>& backup_items_to_finish,
+ BackupID backup_id, bool shared, const std::string& src_dir,
+ const std::string& fname, const EnvOptions& src_env_options,
+ RateLimiter* rate_limiter, FileType file_type, uint64_t size_bytes,
+ Statistics* stats, uint64_t size_limit, bool shared_checksum,
+ std::function<void()> progress_callback, const std::string& contents,
+ const std::string& src_checksum_func_name,
+ const std::string& src_checksum_str, const Temperature src_temperature) {
+ assert(contents.empty() != src_dir.empty());
+
+ std::string src_path = src_dir + "/" + fname;
+ std::string dst_relative;
+ std::string dst_relative_tmp;
+ std::string db_id;
+ std::string db_session_id;
+ // crc32c checksum in hex. empty == unavailable / unknown
+ std::string checksum_hex;
+
+ // Whenever a default checksum function name is passed in, we will compares
+ // the corresponding checksum values after copying. Note that only table and
+ // blob files may have a known checksum function name passed in.
+ //
+ // If no default checksum function name is passed in and db session id is not
+ // available, we will calculate the checksum *before* copying in two cases
+ // (we always calcuate checksums when copying or creating for any file types):
+ // a) share_files_with_checksum is true and file type is table;
+ // b) share_table_files is true and the file exists already.
+ //
+ // Step 0: Check if default checksum function name is passed in
+ if (kDbFileChecksumFuncName == src_checksum_func_name) {
+ if (src_checksum_str == kUnknownFileChecksum) {
+ return status_to_io_status(
+ Status::Aborted("Unknown checksum value for " + fname));
+ }
+ checksum_hex = ChecksumStrToHex(src_checksum_str);
+ }
+
+ // Step 1: Prepare the relative path to destination
+ if (shared && shared_checksum) {
+ if (GetNamingNoFlags() != BackupEngineOptions::kLegacyCrc32cAndFileSize &&
+ file_type != kBlobFile) {
+ // Prepare db_session_id to add to the file name
+ // Ignore the returned status
+ // In the failed cases, db_id and db_session_id will be empty
+ GetFileDbIdentities(db_env_, src_env_options, src_path, src_temperature,
+ rate_limiter, &db_id, &db_session_id)
+ .PermitUncheckedError();
+ }
+ // Calculate checksum if checksum and db session id are not available.
+ // If db session id is available, we will not calculate the checksum
+ // since the session id should suffice to avoid file name collision in
+ // the shared_checksum directory.
+ if (checksum_hex.empty() && db_session_id.empty()) {
+ IOStatus io_s = ReadFileAndComputeChecksum(
+ src_path, db_fs_, src_env_options, size_limit, &checksum_hex,
+ src_temperature);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ if (size_bytes == std::numeric_limits<uint64_t>::max()) {
+ return IOStatus::NotFound("File missing: " + src_path);
+ }
+ // dst_relative depends on the following conditions:
+ // 1) the naming scheme is kUseDbSessionId,
+ // 2) db_session_id is not empty,
+ // 3) checksum is available in the DB manifest.
+ // If 1,2,3) are satisfied, then dst_relative will be of the form:
+ // shared_checksum/<file_number>_<checksum>_<db_session_id>.sst
+ // If 1,2) are satisfied, then dst_relative will be of the form:
+ // shared_checksum/<file_number>_<db_session_id>.sst
+ // Otherwise, dst_relative is of the form
+ // shared_checksum/<file_number>_<checksum>_<size>.sst
+ //
+ // For blob files, db_session_id is not supported with the blob file format.
+ // It uses original/legacy naming scheme.
+ // dst_relative will be of the form:
+ // shared_checksum/<file_number>_<checksum>_<size>.blob
+ dst_relative = GetSharedFileWithChecksum(fname, checksum_hex, size_bytes,
+ db_session_id);
+ dst_relative_tmp = GetSharedFileWithChecksumRel(dst_relative, true);
+ dst_relative = GetSharedFileWithChecksumRel(dst_relative, false);
+ } else if (shared) {
+ dst_relative_tmp = GetSharedFileRel(fname, true);
+ dst_relative = GetSharedFileRel(fname, false);
+ } else {
+ dst_relative = GetPrivateFileRel(backup_id, false, fname);
+ }
+
+ // We copy into `temp_dest_path` and, once finished, rename it to
+ // `final_dest_path`. This allows files to atomically appear at
+ // `final_dest_path`. We can copy directly to the final path when atomicity
+ // is unnecessary, like for files in private backup directories.
+ const std::string* copy_dest_path;
+ std::string temp_dest_path;
+ std::string final_dest_path = GetAbsolutePath(dst_relative);
+ if (!dst_relative_tmp.empty()) {
+ temp_dest_path = GetAbsolutePath(dst_relative_tmp);
+ copy_dest_path = &temp_dest_path;
+ } else {
+ copy_dest_path = &final_dest_path;
+ }
+
+ // Step 2: Determine whether to copy or not
+ // if it's shared, we also need to check if it exists -- if it does, no need
+ // to copy it again.
+ bool need_to_copy = true;
+ // true if final_dest_path is the same path as another live file
+ const bool same_path =
+ live_dst_paths.find(final_dest_path) != live_dst_paths.end();
+
+ bool file_exists = false;
+ if (shared && !same_path) {
+ // Should be in shared directory but not a live path, check existence in
+ // shared directory
+ IOStatus exist =
+ backup_fs_->FileExists(final_dest_path, io_options_, nullptr);
+ if (exist.ok()) {
+ file_exists = true;
+ } else if (exist.IsNotFound()) {
+ file_exists = false;
+ } else {
+ return exist;
+ }
+ }
+
+ if (!contents.empty()) {
+ need_to_copy = false;
+ } else if (shared && (same_path || file_exists)) {
+ need_to_copy = false;
+ auto find_result = backuped_file_infos_.find(dst_relative);
+ if (find_result == backuped_file_infos_.end() && !same_path) {
+ // file exists but not referenced
+ ROCKS_LOG_INFO(
+ options_.info_log,
+ "%s already present, but not referenced by any backup. We will "
+ "overwrite the file.",
+ fname.c_str());
+ need_to_copy = true;
+ // Defer any failure reporting to when we try to write the file
+ backup_fs_->DeleteFile(final_dest_path, io_options_, nullptr)
+ .PermitUncheckedError();
+ } else {
+ // file exists and referenced
+ if (checksum_hex.empty()) {
+ // same_path should not happen for a standard DB, so OK to
+ // read file contents to check for checksum mismatch between
+ // two files from same DB getting same name.
+ // For compatibility with future meta file that might not have
+ // crc32c checksum available, consider it might be empty, but
+ // we don't currently generate meta file without crc32c checksum.
+ // Therefore we have to read & compute it if we don't have it.
+ if (!same_path && !find_result->second->checksum_hex.empty()) {
+ assert(find_result != backuped_file_infos_.end());
+ // Note: to save I/O on incremental backups, we copy prior known
+ // checksum of the file instead of reading entire file contents
+ // to recompute it.
+ checksum_hex = find_result->second->checksum_hex;
+ // Regarding corruption detection, consider:
+ // (a) the DB file is corrupt (since previous backup) and the backup
+ // file is OK: we failed to detect, but the backup is safe. DB can
+ // be repaired/restored once its corruption is detected.
+ // (b) the backup file is corrupt (since previous backup) and the
+ // db file is OK: we failed to detect, but the backup is corrupt.
+ // CreateNewBackup should support fast incremental backups and
+ // there's no way to support that without reading all the files.
+ // We might add an option for extra checks on incremental backup,
+ // but until then, use VerifyBackups to check existing backup data.
+ // (c) file name collision with legitimately different content.
+ // This is almost inconceivable with a well-generated DB session
+ // ID, but even in that case, we double check the file sizes in
+ // BackupMeta::AddFile.
+ } else {
+ IOStatus io_s = ReadFileAndComputeChecksum(
+ src_path, db_fs_, src_env_options, size_limit, &checksum_hex,
+ src_temperature);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ }
+ if (!db_session_id.empty()) {
+ ROCKS_LOG_INFO(options_.info_log,
+ "%s already present, with checksum %s, size %" PRIu64
+ " and DB session identity %s",
+ fname.c_str(), checksum_hex.c_str(), size_bytes,
+ db_session_id.c_str());
+ } else {
+ ROCKS_LOG_INFO(options_.info_log,
+ "%s already present, with checksum %s and size %" PRIu64,
+ fname.c_str(), checksum_hex.c_str(), size_bytes);
+ }
+ }
+ }
+ live_dst_paths.insert(final_dest_path);
+
+ // Step 3: Add work item
+ if (!contents.empty() || need_to_copy) {
+ ROCKS_LOG_INFO(options_.info_log, "Copying %s to %s", fname.c_str(),
+ copy_dest_path->c_str());
+ CopyOrCreateWorkItem copy_or_create_work_item(
+ src_dir.empty() ? "" : src_path, *copy_dest_path, src_temperature,
+ Temperature::kUnknown /*dst_temp*/, contents, db_env_, backup_env_,
+ src_env_options, options_.sync, rate_limiter, size_limit, stats,
+ progress_callback, src_checksum_func_name, checksum_hex, db_id,
+ db_session_id);
+ BackupAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ copy_or_create_work_item.result.get_future(), shared, need_to_copy,
+ backup_env_, temp_dest_path, final_dest_path, dst_relative);
+ files_to_copy_or_create_.write(std::move(copy_or_create_work_item));
+ backup_items_to_finish.push_back(std::move(after_copy_or_create_work_item));
+ } else {
+ std::promise<CopyOrCreateResult> promise_result;
+ BackupAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ promise_result.get_future(), shared, need_to_copy, backup_env_,
+ temp_dest_path, final_dest_path, dst_relative);
+ backup_items_to_finish.push_back(std::move(after_copy_or_create_work_item));
+ CopyOrCreateResult result;
+ result.io_status = IOStatus::OK();
+ result.size = size_bytes;
+ result.checksum_hex = std::move(checksum_hex);
+ result.db_id = std::move(db_id);
+ result.db_session_id = std::move(db_session_id);
+ promise_result.set_value(std::move(result));
+ }
+ return IOStatus::OK();
+}
+
+IOStatus BackupEngineImpl::ReadFileAndComputeChecksum(
+ const std::string& src, const std::shared_ptr<FileSystem>& src_fs,
+ const EnvOptions& src_env_options, uint64_t size_limit,
+ std::string* checksum_hex, const Temperature src_temperature) const {
+ if (checksum_hex == nullptr) {
+ return status_to_io_status(Status::Aborted("Checksum pointer is null"));
+ }
+ uint32_t checksum_value = 0;
+ if (size_limit == 0) {
+ size_limit = std::numeric_limits<uint64_t>::max();
+ }
+
+ std::unique_ptr<SequentialFileReader> src_reader;
+ auto file_options = FileOptions(src_env_options);
+ file_options.temperature = src_temperature;
+ RateLimiter* rate_limiter = options_.backup_rate_limiter.get();
+ IOStatus io_s = SequentialFileReader::Create(
+ src_fs, src, file_options, &src_reader, nullptr /* dbg */, rate_limiter);
+ if (io_s.IsPathNotFound() && src_temperature != Temperature::kUnknown) {
+ // Retry without temperature hint in case the FileSystem is strict with
+ // non-kUnknown temperature option
+ file_options.temperature = Temperature::kUnknown;
+ io_s = SequentialFileReader::Create(src_fs, src, file_options, &src_reader,
+ nullptr /* dbg */, rate_limiter);
+ }
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ size_t buf_size = kDefaultCopyFileBufferSize;
+ std::unique_ptr<char[]> buf(new char[buf_size]);
+ Slice data;
+
+ do {
+ if (stop_backup_.load(std::memory_order_acquire)) {
+ return status_to_io_status(Status::Incomplete("Backup stopped"));
+ }
+ size_t buffer_to_read =
+ (buf_size < size_limit) ? buf_size : static_cast<size_t>(size_limit);
+ io_s = src_reader->Read(buffer_to_read, &data, buf.get(),
+ Env::IO_LOW /* rate_limiter_priority */);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ size_limit -= data.size();
+ checksum_value = crc32c::Extend(checksum_value, data.data(), data.size());
+ } while (data.size() > 0 && size_limit > 0);
+
+ checksum_hex->assign(ChecksumInt32ToHex(checksum_value));
+
+ return io_s;
+}
+
+Status BackupEngineImpl::GetFileDbIdentities(
+ Env* src_env, const EnvOptions& src_env_options,
+ const std::string& file_path, Temperature file_temp,
+ RateLimiter* rate_limiter, std::string* db_id, std::string* db_session_id) {
+ assert(db_id != nullptr || db_session_id != nullptr);
+
+ Options options;
+ options.env = src_env;
+ SstFileDumper sst_reader(options, file_path, file_temp,
+ 2 * 1024 * 1024
+ /* readahead_size */,
+ false /* verify_checksum */, false /* output_hex */,
+ false /* decode_blob_index */, src_env_options,
+ true /* silent */);
+
+ const TableProperties* table_properties = nullptr;
+ std::shared_ptr<const TableProperties> tp;
+ Status s = sst_reader.getStatus();
+
+ if (s.ok()) {
+ // Try to get table properties from the table reader of sst_reader
+ if (!sst_reader.ReadTableProperties(&tp).ok()) {
+ // Try to use table properites from the initialization of sst_reader
+ table_properties = sst_reader.GetInitTableProperties();
+ } else {
+ table_properties = tp.get();
+ if (table_properties != nullptr && rate_limiter != nullptr) {
+ // sizeof(*table_properties) is a sufficent but far-from-exact
+ // approximation of read bytes due to metaindex block, std::string
+ // properties and varint compression
+ LoopRateLimitRequestHelper(sizeof(*table_properties), rate_limiter,
+ Env::IO_LOW, nullptr /* stats */,
+ RateLimiter::OpType::kRead);
+ }
+ }
+ } else {
+ ROCKS_LOG_INFO(options_.info_log, "Failed to read %s: %s",
+ file_path.c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ if (table_properties != nullptr) {
+ if (db_id != nullptr) {
+ db_id->assign(table_properties->db_id);
+ }
+ if (db_session_id != nullptr) {
+ db_session_id->assign(table_properties->db_session_id);
+ if (db_session_id->empty()) {
+ s = Status::NotFound("DB session identity not found in " + file_path);
+ ROCKS_LOG_INFO(options_.info_log, "%s", s.ToString().c_str());
+ return s;
+ }
+ }
+ return Status::OK();
+ } else {
+ s = Status::Corruption("Table properties missing in " + file_path);
+ ROCKS_LOG_INFO(options_.info_log, "%s", s.ToString().c_str());
+ return s;
+ }
+}
+
+void BackupEngineImpl::LoopRateLimitRequestHelper(
+ const size_t total_bytes_to_request, RateLimiter* rate_limiter,
+ const Env::IOPriority pri, Statistics* stats,
+ const RateLimiter::OpType op_type) {
+ assert(rate_limiter != nullptr);
+ size_t remaining_bytes = total_bytes_to_request;
+ size_t request_bytes = 0;
+ while (remaining_bytes > 0) {
+ request_bytes =
+ std::min(static_cast<size_t>(rate_limiter->GetSingleBurstBytes()),
+ remaining_bytes);
+ rate_limiter->Request(request_bytes, pri, stats, op_type);
+ remaining_bytes -= request_bytes;
+ }
+}
+
+void BackupEngineImpl::DeleteChildren(const std::string& dir,
+ uint32_t file_type_filter) const {
+ std::vector<std::string> children;
+ db_fs_->GetChildren(dir, io_options_, &children, nullptr)
+ .PermitUncheckedError(); // ignore errors
+
+ for (const auto& f : children) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && (file_type_filter & (1 << type))) {
+ // don't delete this file
+ continue;
+ }
+ db_fs_->DeleteFile(dir + "/" + f, io_options_, nullptr)
+ .PermitUncheckedError(); // ignore errors
+ }
+}
+
+IOStatus BackupEngineImpl::ReadChildFileCurrentSizes(
+ const std::string& dir, const std::shared_ptr<FileSystem>& fs,
+ std::unordered_map<std::string, uint64_t>* result) const {
+ assert(result != nullptr);
+ std::vector<Env::FileAttributes> files_attrs;
+ IOStatus io_status = fs->FileExists(dir, io_options_, nullptr);
+ if (io_status.ok()) {
+ io_status =
+ fs->GetChildrenFileAttributes(dir, io_options_, &files_attrs, nullptr);
+ } else if (io_status.IsNotFound()) {
+ // Insert no entries can be considered success
+ io_status = IOStatus::OK();
+ }
+ const bool slash_needed = dir.empty() || dir.back() != '/';
+ for (const auto& file_attrs : files_attrs) {
+ result->emplace(dir + (slash_needed ? "/" : "") + file_attrs.name,
+ file_attrs.size_bytes);
+ }
+ return io_status;
+}
+
+IOStatus BackupEngineImpl::GarbageCollect() {
+ assert(!read_only_);
+
+ // We will make a best effort to remove all garbage even in the presence
+ // of inconsistencies or I/O failures that inhibit finding garbage.
+ IOStatus overall_status = IOStatus::OK();
+ // If all goes well, we don't need another auto-GC this session
+ might_need_garbage_collect_ = false;
+
+ ROCKS_LOG_INFO(options_.info_log, "Starting garbage collection");
+
+ // delete obsolete shared files
+ for (bool with_checksum : {false, true}) {
+ std::vector<std::string> shared_children;
+ {
+ std::string shared_path;
+ if (with_checksum) {
+ shared_path = GetAbsolutePath(GetSharedFileWithChecksumRel());
+ } else {
+ shared_path = GetAbsolutePath(GetSharedFileRel());
+ }
+ IOStatus io_s = backup_fs_->FileExists(shared_path, io_options_, nullptr);
+ if (io_s.ok()) {
+ io_s = backup_fs_->GetChildren(shared_path, io_options_,
+ &shared_children, nullptr);
+ } else if (io_s.IsNotFound()) {
+ io_s = IOStatus::OK();
+ }
+ if (!io_s.ok()) {
+ overall_status = io_s;
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ for (auto& child : shared_children) {
+ std::string rel_fname;
+ if (with_checksum) {
+ rel_fname = GetSharedFileWithChecksumRel(child);
+ } else {
+ rel_fname = GetSharedFileRel(child);
+ }
+ auto child_itr = backuped_file_infos_.find(rel_fname);
+ // if it's not refcounted, delete it
+ if (child_itr == backuped_file_infos_.end() ||
+ child_itr->second->refs == 0) {
+ // this might be a directory, but DeleteFile will just fail in that
+ // case, so we're good
+ IOStatus io_s = backup_fs_->DeleteFile(GetAbsolutePath(rel_fname),
+ io_options_, nullptr);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s",
+ rel_fname.c_str(), io_s.ToString().c_str());
+ backuped_file_infos_.erase(rel_fname);
+ if (!io_s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ }
+ }
+
+ // delete obsolete private files
+ std::vector<std::string> private_children;
+ {
+ IOStatus io_s =
+ backup_fs_->GetChildren(GetAbsolutePath(kPrivateDirName), io_options_,
+ &private_children, nullptr);
+ if (!io_s.ok()) {
+ overall_status = io_s;
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ for (auto& child : private_children) {
+ BackupID backup_id = 0;
+ bool tmp_dir = child.find(".tmp") != std::string::npos;
+ sscanf(child.c_str(), "%u", &backup_id);
+ if (!tmp_dir && // if it's tmp_dir, delete it
+ (backup_id == 0 || backups_.find(backup_id) != backups_.end())) {
+ // it's either not a number or it's still alive. continue
+ continue;
+ }
+ // here we have to delete the dir and all its children
+ std::string full_private_path =
+ GetAbsolutePath(GetPrivateFileRel(backup_id));
+ std::vector<std::string> subchildren;
+ if (backup_fs_
+ ->GetChildren(full_private_path, io_options_, &subchildren, nullptr)
+ .ok()) {
+ for (auto& subchild : subchildren) {
+ IOStatus io_s = backup_fs_->DeleteFile(full_private_path + subchild,
+ io_options_, nullptr);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s",
+ (full_private_path + subchild).c_str(),
+ io_s.ToString().c_str());
+ if (!io_s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ }
+ // finally delete the private dir
+ IOStatus io_s =
+ backup_fs_->DeleteDir(full_private_path, io_options_, nullptr);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting dir %s -- %s",
+ full_private_path.c_str(), io_s.ToString().c_str());
+ if (!io_s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+
+ assert(overall_status.ok() || might_need_garbage_collect_);
+ return overall_status;
+}
+
+// ------- BackupMeta class --------
+
+IOStatus BackupEngineImpl::BackupMeta::AddFile(
+ std::shared_ptr<FileInfo> file_info) {
+ auto itr = file_infos_->find(file_info->filename);
+ if (itr == file_infos_->end()) {
+ auto ret = file_infos_->insert({file_info->filename, file_info});
+ if (ret.second) {
+ itr = ret.first;
+ itr->second->refs = 1;
+ } else {
+ // if this happens, something is seriously wrong
+ return IOStatus::Corruption("In memory metadata insertion error");
+ }
+ } else {
+ // Compare sizes, because we scanned that off the filesystem on both
+ // ends. This is like a check in VerifyBackup.
+ if (itr->second->size != file_info->size) {
+ std::string msg = "Size mismatch for existing backup file: ";
+ msg.append(file_info->filename);
+ msg.append(" Size in backup is " + std::to_string(itr->second->size) +
+ " while size in DB is " + std::to_string(file_info->size));
+ msg.append(
+ " If this DB file checks as not corrupt, try deleting old"
+ " backups or backing up to a different backup directory.");
+ return IOStatus::Corruption(msg);
+ }
+ if (file_info->checksum_hex.empty()) {
+ // No checksum available to check
+ } else if (itr->second->checksum_hex.empty()) {
+ // Remember checksum if newly acquired
+ itr->second->checksum_hex = file_info->checksum_hex;
+ } else if (itr->second->checksum_hex != file_info->checksum_hex) {
+ // Note: to save I/O, these will be equal trivially on already backed
+ // up files that don't have the checksum in their name. And it should
+ // never fail for files that do have checksum in their name.
+
+ // Should never reach here, but produce an appropriate corruption
+ // message in case we do in a release build.
+ assert(false);
+ std::string msg = "Checksum mismatch for existing backup file: ";
+ msg.append(file_info->filename);
+ msg.append(" Expected checksum is " + itr->second->checksum_hex +
+ " while computed checksum is " + file_info->checksum_hex);
+ msg.append(
+ " If this DB file checks as not corrupt, try deleting old"
+ " backups or backing up to a different backup directory.");
+ return IOStatus::Corruption(msg);
+ }
+ ++itr->second->refs; // increase refcount if already present
+ }
+
+ size_ += file_info->size;
+ files_.push_back(itr->second);
+
+ return IOStatus::OK();
+}
+
+IOStatus BackupEngineImpl::BackupMeta::Delete(bool delete_meta) {
+ IOStatus io_s;
+ for (const auto& file : files_) {
+ --file->refs; // decrease refcount
+ }
+ files_.clear();
+ // delete meta file
+ if (delete_meta) {
+ io_s = fs_->FileExists(meta_filename_, iooptions_, nullptr);
+ if (io_s.ok()) {
+ io_s = fs_->DeleteFile(meta_filename_, iooptions_, nullptr);
+ } else if (io_s.IsNotFound()) {
+ io_s = IOStatus::OK(); // nothing to delete
+ }
+ }
+ timestamp_ = 0;
+ return io_s;
+}
+
+// Constants for backup meta file schema (see LoadFromFile)
+const std::string kSchemaVersionPrefix{"schema_version "};
+const std::string kFooterMarker{"// FOOTER"};
+
+const std::string kAppMetaDataFieldName{"metadata"};
+
+// WART: The checksums are crc32c but named "crc32"
+const std::string kFileCrc32cFieldName{"crc32"};
+const std::string kFileSizeFieldName{"size"};
+const std::string kTemperatureFieldName{"temp"};
+
+// Marks a (future) field that should cause failure if not recognized.
+// Other fields are assumed to be ignorable. For example, in the future
+// we might add
+// ni::file_name_escape uri_percent
+// to indicate all file names have had spaces and special characters
+// escaped using a URI percent encoding.
+const std::string kNonIgnorableFieldPrefix{"ni::"};
+
+// Each backup meta file is of the format (schema version 1):
+//----------------------------------------------------------
+// <timestamp>
+// <seq number>
+// metadata <metadata> (optional)
+// <number of files>
+// <file1> crc32 <crc32c_as_unsigned_decimal>
+// <file2> crc32 <crc32c_as_unsigned_decimal>
+// ...
+//----------------------------------------------------------
+//
+// For schema version 2.x (not in public APIs, but
+// forward-compatibility started):
+//----------------------------------------------------------
+// schema_version <ver>
+// <timestamp>
+// <seq number>
+// [<field name> <field data>]
+// ...
+// <number of files>
+// <file1>( <field name> <field data no spaces>)*
+// <file2>( <field name> <field data no spaces>)*
+// ...
+// [// FOOTER]
+// [<field name> <field data>]
+// ...
+//----------------------------------------------------------
+// where
+// <ver> ::= [0-9]+([.][0-9]+)
+// <field name> ::= [A-Za-z_][A-Za-z_0-9.]+
+// <field data> is anything but newline
+// <field data no spaces> is anything but space and newline
+// Although "// FOOTER" wouldn't strictly be required as a delimiter
+// given the number of files is included, it is there for parsing
+// sanity in case of corruption. It is only required if followed
+// by footer fields, such as a checksum of the meta file (so far).
+// Unrecognized fields are ignored, to support schema evolution on
+// non-critical features with forward compatibility. Update schema
+// major version for breaking changes. Schema minor versions are indicated
+// only for diagnostic/debugging purposes.
+//
+// Fields in schema version 2.0:
+// * Top-level meta fields:
+// * Only "metadata" as in schema version 1
+// * File meta fields:
+// * "crc32" - a crc32c checksum as in schema version 1
+// * "size" - the size of the file (new)
+// * Footer meta fields:
+// * None yet (future use for meta file checksum anticipated)
+//
+IOStatus BackupEngineImpl::BackupMeta::LoadFromFile(
+ const std::string& backup_dir,
+ const std::unordered_map<std::string, uint64_t>& abs_path_to_size,
+ RateLimiter* rate_limiter, Logger* info_log,
+ std::unordered_set<std::string>* reported_ignored_fields) {
+ assert(reported_ignored_fields);
+ assert(Empty());
+
+ std::unique_ptr<LineFileReader> backup_meta_reader;
+ {
+ IOStatus io_s = LineFileReader::Create(fs_, meta_filename_, FileOptions(),
+ &backup_meta_reader,
+ nullptr /* dbg */, rate_limiter);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+
+ // If we don't read an explicit schema_version, that implies version 1,
+ // which is what we call the original backup meta schema.
+ int schema_major_version = 1;
+
+ // Failures handled at the end
+ std::string line;
+ if (backup_meta_reader->ReadLine(&line,
+ Env::IO_LOW /* rate_limiter_priority */)) {
+ if (StartsWith(line, kSchemaVersionPrefix)) {
+ std::string ver = line.substr(kSchemaVersionPrefix.size());
+ if (ver == "2" || StartsWith(ver, "2.")) {
+ schema_major_version = 2;
+ } else {
+ return IOStatus::NotSupported(
+ "Unsupported/unrecognized schema version: " + ver);
+ }
+ line.clear();
+ } else if (line.empty()) {
+ return IOStatus::Corruption("Unexpected empty line");
+ }
+ }
+ if (!line.empty()) {
+ timestamp_ = std::strtoull(line.c_str(), nullptr, /*base*/ 10);
+ } else if (backup_meta_reader->ReadLine(
+ &line, Env::IO_LOW /* rate_limiter_priority */)) {
+ timestamp_ = std::strtoull(line.c_str(), nullptr, /*base*/ 10);
+ }
+ if (backup_meta_reader->ReadLine(&line,
+ Env::IO_LOW /* rate_limiter_priority */)) {
+ sequence_number_ = std::strtoull(line.c_str(), nullptr, /*base*/ 10);
+ }
+ uint32_t num_files = UINT32_MAX;
+ while (backup_meta_reader->ReadLine(
+ &line, Env::IO_LOW /* rate_limiter_priority */)) {
+ if (line.empty()) {
+ return IOStatus::Corruption("Unexpected empty line");
+ }
+ // Number -> number of files -> exit loop reading optional meta fields
+ if (line[0] >= '0' && line[0] <= '9') {
+ num_files = static_cast<uint32_t>(strtoul(line.c_str(), nullptr, 10));
+ break;
+ }
+ // else, must be a meta field assignment
+ auto space_pos = line.find_first_of(' ');
+ if (space_pos == std::string::npos) {
+ return IOStatus::Corruption("Expected number of files or meta field");
+ }
+ std::string field_name = line.substr(0, space_pos);
+ std::string field_data = line.substr(space_pos + 1);
+ if (field_name == kAppMetaDataFieldName) {
+ // app metadata present
+ bool decode_success = Slice(field_data).DecodeHex(&app_metadata_);
+ if (!decode_success) {
+ return IOStatus::Corruption(
+ "Failed to decode stored hex encoded app metadata");
+ }
+ } else if (schema_major_version < 2) {
+ return IOStatus::Corruption("Expected number of files or \"" +
+ kAppMetaDataFieldName + "\" field");
+ } else if (StartsWith(field_name, kNonIgnorableFieldPrefix)) {
+ return IOStatus::NotSupported("Unrecognized non-ignorable meta field " +
+ field_name + " (from future version?)");
+ } else {
+ // Warn the first time we see any particular unrecognized meta field
+ if (reported_ignored_fields->insert("meta:" + field_name).second) {
+ ROCKS_LOG_WARN(info_log, "Ignoring unrecognized backup meta field %s",
+ field_name.c_str());
+ }
+ }
+ }
+ std::vector<std::shared_ptr<FileInfo>> files;
+ bool footer_present = false;
+ while (backup_meta_reader->ReadLine(
+ &line, Env::IO_LOW /* rate_limiter_priority */)) {
+ std::vector<std::string> components = StringSplit(line, ' ');
+
+ if (components.size() < 1) {
+ return IOStatus::Corruption("Empty line instead of file entry.");
+ }
+ if (schema_major_version >= 2 && components.size() == 2 &&
+ line == kFooterMarker) {
+ footer_present = true;
+ break;
+ }
+
+ const std::string& filename = components[0];
+
+ uint64_t actual_size;
+ const std::shared_ptr<FileInfo> file_info = GetFile(filename);
+ if (file_info) {
+ actual_size = file_info->size;
+ } else {
+ std::string abs_path = backup_dir + "/" + filename;
+ auto e = abs_path_to_size.find(abs_path);
+ if (e == abs_path_to_size.end()) {
+ return IOStatus::Corruption(
+ "Pathname in meta file not found on disk: " + abs_path);
+ }
+ actual_size = e->second;
+ }
+
+ if (schema_major_version >= 2) {
+ if (components.size() % 2 != 1) {
+ return IOStatus::Corruption(
+ "Bad number of line components for file entry.");
+ }
+ } else {
+ // Check restricted original schema
+ if (components.size() < 3) {
+ return IOStatus::Corruption("File checksum is missing for " + filename +
+ " in " + meta_filename_);
+ }
+ if (components[1] != kFileCrc32cFieldName) {
+ return IOStatus::Corruption("Unknown checksum type for " + filename +
+ " in " + meta_filename_);
+ }
+ if (components.size() > 3) {
+ return IOStatus::Corruption("Extra data for entry " + filename +
+ " in " + meta_filename_);
+ }
+ }
+
+ std::string checksum_hex;
+ Temperature temp = Temperature::kUnknown;
+ for (unsigned i = 1; i < components.size(); i += 2) {
+ const std::string& field_name = components[i];
+ const std::string& field_data = components[i + 1];
+
+ if (field_name == kFileCrc32cFieldName) {
+ uint32_t checksum_value =
+ static_cast<uint32_t>(strtoul(field_data.c_str(), nullptr, 10));
+ if (field_data != std::to_string(checksum_value)) {
+ return IOStatus::Corruption("Invalid checksum value for " + filename +
+ " in " + meta_filename_);
+ }
+ checksum_hex = ChecksumInt32ToHex(checksum_value);
+ } else if (field_name == kFileSizeFieldName) {
+ uint64_t ex_size =
+ std::strtoull(field_data.c_str(), nullptr, /*base*/ 10);
+ if (ex_size != actual_size) {
+ return IOStatus::Corruption(
+ "For file " + filename + " expected size " +
+ std::to_string(ex_size) + " but found size" +
+ std::to_string(actual_size));
+ }
+ } else if (field_name == kTemperatureFieldName) {
+ auto iter = temperature_string_map.find(field_data);
+ if (iter != temperature_string_map.end()) {
+ temp = iter->second;
+ } else {
+ // Could report corruption, but in case of new temperatures added
+ // in future, letting those map to kUnknown which should generally
+ // be safe.
+ temp = Temperature::kUnknown;
+ }
+ } else if (StartsWith(field_name, kNonIgnorableFieldPrefix)) {
+ return IOStatus::NotSupported("Unrecognized non-ignorable file field " +
+ field_name + " (from future version?)");
+ } else {
+ // Warn the first time we see any particular unrecognized file field
+ if (reported_ignored_fields->insert("file:" + field_name).second) {
+ ROCKS_LOG_WARN(info_log, "Ignoring unrecognized backup file field %s",
+ field_name.c_str());
+ }
+ }
+ }
+
+ files.emplace_back(new FileInfo(filename, actual_size, checksum_hex,
+ /*id*/ "", /*sid*/ "", temp));
+ }
+
+ if (footer_present) {
+ assert(schema_major_version >= 2);
+ while (backup_meta_reader->ReadLine(
+ &line, Env::IO_LOW /* rate_limiter_priority */)) {
+ if (line.empty()) {
+ return IOStatus::Corruption("Unexpected empty line");
+ }
+ auto space_pos = line.find_first_of(' ');
+ if (space_pos == std::string::npos) {
+ return IOStatus::Corruption("Expected footer field");
+ }
+ std::string field_name = line.substr(0, space_pos);
+ std::string field_data = line.substr(space_pos + 1);
+ if (StartsWith(field_name, kNonIgnorableFieldPrefix)) {
+ return IOStatus::NotSupported("Unrecognized non-ignorable field " +
+ field_name + " (from future version?)");
+ } else if (reported_ignored_fields->insert("footer:" + field_name)
+ .second) {
+ // Warn the first time we see any particular unrecognized footer field
+ ROCKS_LOG_WARN(info_log,
+ "Ignoring unrecognized backup meta footer field %s",
+ field_name.c_str());
+ }
+ }
+ }
+
+ {
+ IOStatus io_s = backup_meta_reader->GetStatus();
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+
+ if (num_files != files.size()) {
+ return IOStatus::Corruption(
+ "Inconsistent number of files or missing/incomplete header in " +
+ meta_filename_);
+ }
+
+ files_.reserve(files.size());
+ for (const auto& file_info : files) {
+ IOStatus io_s = AddFile(file_info);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+
+ return IOStatus::OK();
+}
+
+const std::vector<std::string> minor_version_strings{
+ "", // invalid major version 0
+ "", // implicit major version 1
+ "2.0",
+};
+
+IOStatus BackupEngineImpl::BackupMeta::StoreToFile(
+ bool sync, int schema_version,
+ const TEST_BackupMetaSchemaOptions* schema_test_options) {
+ if (schema_version < 1) {
+ return IOStatus::InvalidArgument(
+ "BackupEngineOptions::schema_version must be >= 1");
+ }
+ if (schema_version > static_cast<int>(minor_version_strings.size() - 1)) {
+ return IOStatus::NotSupported(
+ "Only BackupEngineOptions::schema_version <= " +
+ std::to_string(minor_version_strings.size() - 1) + " is supported");
+ }
+ std::string ver = minor_version_strings[schema_version];
+
+ // Need schema_version >= 2 for TEST_BackupMetaSchemaOptions
+ assert(schema_version >= 2 || schema_test_options == nullptr);
+
+ IOStatus io_s;
+ std::unique_ptr<FSWritableFile> backup_meta_file;
+ FileOptions file_options;
+ file_options.use_mmap_writes = false;
+ file_options.use_direct_writes = false;
+ io_s = fs_->NewWritableFile(meta_tmp_filename_, file_options,
+ &backup_meta_file, nullptr);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ std::ostringstream buf;
+ if (schema_test_options) {
+ // override for testing
+ ver = schema_test_options->version;
+ }
+ if (!ver.empty()) {
+ assert(schema_version >= 2);
+ buf << kSchemaVersionPrefix << ver << "\n";
+ }
+ buf << static_cast<unsigned long long>(timestamp_) << "\n";
+ buf << sequence_number_ << "\n";
+
+ if (!app_metadata_.empty()) {
+ std::string hex_encoded_metadata =
+ Slice(app_metadata_).ToString(/* hex */ true);
+ buf << kAppMetaDataFieldName << " " << hex_encoded_metadata << "\n";
+ }
+ if (schema_test_options) {
+ for (auto& e : schema_test_options->meta_fields) {
+ buf << e.first << " " << e.second << "\n";
+ }
+ }
+ buf << files_.size() << "\n";
+
+ for (const auto& file : files_) {
+ buf << file->filename;
+ if (schema_test_options == nullptr ||
+ schema_test_options->crc32c_checksums) {
+ // use crc32c for now, switch to something else if needed
+ buf << " " << kFileCrc32cFieldName << " "
+ << ChecksumHexToInt32(file->checksum_hex);
+ }
+ if (schema_version >= 2 && file->temp != Temperature::kUnknown) {
+ buf << " " << kTemperatureFieldName << " "
+ << temperature_to_string[file->temp];
+ }
+ if (schema_test_options && schema_test_options->file_sizes) {
+ buf << " " << kFileSizeFieldName << " " << std::to_string(file->size);
+ }
+ if (schema_test_options) {
+ for (auto& e : schema_test_options->file_fields) {
+ buf << " " << e.first << " " << e.second;
+ }
+ }
+ buf << "\n";
+ }
+
+ if (schema_test_options && !schema_test_options->footer_fields.empty()) {
+ buf << kFooterMarker << "\n";
+ for (auto& e : schema_test_options->footer_fields) {
+ buf << e.first << " " << e.second << "\n";
+ }
+ }
+
+ io_s = backup_meta_file->Append(Slice(buf.str()), iooptions_, nullptr);
+ IOSTATS_ADD(bytes_written, buf.str().size());
+ if (io_s.ok() && sync) {
+ io_s = backup_meta_file->Sync(iooptions_, nullptr);
+ }
+ if (io_s.ok()) {
+ io_s = backup_meta_file->Close(iooptions_, nullptr);
+ }
+ if (io_s.ok()) {
+ io_s = fs_->RenameFile(meta_tmp_filename_, meta_filename_, iooptions_,
+ nullptr);
+ }
+ return io_s;
+}
+} // namespace
+
+IOStatus BackupEngineReadOnly::Open(const BackupEngineOptions& options,
+ Env* env,
+ BackupEngineReadOnly** backup_engine_ptr) {
+ if (options.destroy_old_data) {
+ return IOStatus::InvalidArgument(
+ "Can't destroy old data with ReadOnly BackupEngine");
+ }
+ std::unique_ptr<BackupEngineImplThreadSafe> backup_engine(
+ new BackupEngineImplThreadSafe(options, env, true /*read_only*/));
+ auto s = backup_engine->Initialize();
+ if (!s.ok()) {
+ *backup_engine_ptr = nullptr;
+ return s;
+ }
+ *backup_engine_ptr = backup_engine.release();
+ return IOStatus::OK();
+}
+
+void TEST_SetBackupMetaSchemaOptions(
+ BackupEngine* engine, const TEST_BackupMetaSchemaOptions& options) {
+ BackupEngineImplThreadSafe* impl =
+ static_cast_with_check<BackupEngineImplThreadSafe>(engine);
+ impl->TEST_SetBackupMetaSchemaOptions(options);
+}
+
+void TEST_SetDefaultRateLimitersClock(
+ BackupEngine* engine,
+ const std::shared_ptr<SystemClock>& backup_rate_limiter_clock,
+ const std::shared_ptr<SystemClock>& restore_rate_limiter_clock) {
+ BackupEngineImplThreadSafe* impl =
+ static_cast_with_check<BackupEngineImplThreadSafe>(engine);
+ impl->TEST_SetDefaultRateLimitersClock(backup_rate_limiter_clock,
+ restore_rate_limiter_clock);
+}
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/backup/backup_engine_impl.h b/src/rocksdb/utilities/backup/backup_engine_impl.h
new file mode 100644
index 000000000..398f47f27
--- /dev/null
+++ b/src/rocksdb/utilities/backup/backup_engine_impl.h
@@ -0,0 +1,36 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/backup_engine.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct TEST_BackupMetaSchemaOptions {
+ std::string version = "2";
+ bool crc32c_checksums = false;
+ bool file_sizes = true;
+ std::map<std::string, std::string> meta_fields;
+ std::map<std::string, std::string> file_fields;
+ std::map<std::string, std::string> footer_fields;
+};
+
+// Modifies the BackupEngine(Impl) to write backup meta files using the
+// unpublished schema version 2, for the life of this object (not backup_dir).
+// TEST_BackupMetaSchemaOptions offers some customization for testing.
+void TEST_SetBackupMetaSchemaOptions(
+ BackupEngine* engine, const TEST_BackupMetaSchemaOptions& options);
+
+// Modifies the BackupEngine(Impl) to use specified clocks for backup and
+// restore rate limiters created by default if not specified by users for
+// test speedup.
+void TEST_SetDefaultRateLimitersClock(
+ BackupEngine* engine,
+ const std::shared_ptr<SystemClock>& backup_rate_limiter_clock = nullptr,
+ const std::shared_ptr<SystemClock>& restore_rate_limiter_clock = nullptr);
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/backup/backup_engine_test.cc b/src/rocksdb/utilities/backup/backup_engine_test.cc
new file mode 100644
index 000000000..d1f74f769
--- /dev/null
+++ b/src/rocksdb/utilities/backup/backup_engine_test.cc
@@ -0,0 +1,4219 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#if !defined(ROCKSDB_LITE) && !defined(OS_WIN)
+
+#include "rocksdb/utilities/backup_engine.h"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+
+#include "db/db_impl/db_impl.h"
+#include "db/db_test_util.h"
+#include "env/composite_env_wrapper.h"
+#include "env/env_chroot.h"
+#include "file/filename.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/advanced_options.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_checksum.h"
+#include "rocksdb/rate_limiter.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/transaction_log.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/options_util.h"
+#include "rocksdb/utilities/stackable_db.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/rate_limiter.h"
+#include "util/stderr_logger.h"
+#include "util/string_util.h"
+#include "utilities/backup/backup_engine_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+using ShareFilesNaming = BackupEngineOptions::ShareFilesNaming;
+const auto kLegacyCrc32cAndFileSize =
+ BackupEngineOptions::kLegacyCrc32cAndFileSize;
+const auto kUseDbSessionId = BackupEngineOptions::kUseDbSessionId;
+const auto kFlagIncludeFileSize = BackupEngineOptions::kFlagIncludeFileSize;
+const auto kNamingDefault = kUseDbSessionId | kFlagIncludeFileSize;
+
+class DummyDB : public StackableDB {
+ public:
+ /* implicit */
+ DummyDB(const Options& options, const std::string& dbname)
+ : StackableDB(nullptr),
+ options_(options),
+ dbname_(dbname),
+ deletions_enabled_(true),
+ sequence_number_(0) {}
+
+ SequenceNumber GetLatestSequenceNumber() const override {
+ return ++sequence_number_;
+ }
+
+ const std::string& GetName() const override { return dbname_; }
+
+ Env* GetEnv() const override { return options_.env; }
+
+ using DB::GetOptions;
+ Options GetOptions(ColumnFamilyHandle* /*column_family*/) const override {
+ return options_;
+ }
+
+ DBOptions GetDBOptions() const override { return DBOptions(options_); }
+
+ Status EnableFileDeletions(bool /*force*/) override {
+ EXPECT_TRUE(!deletions_enabled_);
+ deletions_enabled_ = true;
+ return Status::OK();
+ }
+
+ Status DisableFileDeletions() override {
+ EXPECT_TRUE(deletions_enabled_);
+ deletions_enabled_ = false;
+ return Status::OK();
+ }
+
+ ColumnFamilyHandle* DefaultColumnFamily() const override { return nullptr; }
+
+ Status GetLiveFilesStorageInfo(
+ const LiveFilesStorageInfoOptions& opts,
+ std::vector<LiveFileStorageInfo>* files) override {
+ uint64_t number;
+ FileType type;
+ files->clear();
+ for (auto& f : live_files_) {
+ bool success = ParseFileName(f, &number, &type);
+ if (!success) {
+ return Status::InvalidArgument("Bad file name: " + f);
+ }
+ files->emplace_back();
+ LiveFileStorageInfo& info = files->back();
+ info.relative_filename = f;
+ info.directory = dbname_;
+ info.file_number = number;
+ info.file_type = type;
+ if (type == kDescriptorFile) {
+ info.size = 100; // See TestFs::GetChildrenFileAttributes below
+ info.trim_to_size = true;
+ } else if (type == kCurrentFile) {
+ info.size = 0;
+ info.trim_to_size = true;
+ } else {
+ info.size = 200; // See TestFs::GetChildrenFileAttributes below
+ }
+ if (opts.include_checksum_info) {
+ info.file_checksum = kUnknownFileChecksum;
+ info.file_checksum_func_name = kUnknownFileChecksumFuncName;
+ }
+ }
+ return Status::OK();
+ }
+
+ // To avoid FlushWAL called on stacked db which is nullptr
+ Status FlushWAL(bool /*sync*/) override { return Status::OK(); }
+
+ std::vector<std::string> live_files_;
+
+ private:
+ Options options_;
+ std::string dbname_;
+ bool deletions_enabled_;
+ mutable SequenceNumber sequence_number_;
+}; // DummyDB
+
+class TestFs : public FileSystemWrapper {
+ public:
+ explicit TestFs(const std::shared_ptr<FileSystem>& t)
+ : FileSystemWrapper(t) {}
+ const char* Name() const override { return "TestFs"; }
+
+ class DummySequentialFile : public FSSequentialFile {
+ public:
+ explicit DummySequentialFile(bool fail_reads)
+ : FSSequentialFile(), rnd_(5), fail_reads_(fail_reads) {}
+ IOStatus Read(size_t n, const IOOptions&, Slice* result, char* scratch,
+ IODebugContext*) override {
+ if (fail_reads_) {
+ return IOStatus::IOError();
+ }
+ size_t read_size = (n > size_left) ? size_left : n;
+ for (size_t i = 0; i < read_size; ++i) {
+ scratch[i] = rnd_.Next() & 255;
+ }
+ *result = Slice(scratch, read_size);
+ size_left -= read_size;
+ return IOStatus::OK();
+ }
+
+ IOStatus Skip(uint64_t n) override {
+ size_left = (n > size_left) ? size_left - n : 0;
+ return IOStatus::OK();
+ }
+
+ private:
+ size_t size_left = 200;
+ Random rnd_;
+ bool fail_reads_;
+ };
+
+ IOStatus NewSequentialFile(const std::string& f, const FileOptions& file_opts,
+ std::unique_ptr<FSSequentialFile>* r,
+ IODebugContext* dbg) override {
+ MutexLock l(&mutex_);
+ if (dummy_sequential_file_) {
+ r->reset(
+ new TestFs::DummySequentialFile(dummy_sequential_file_fail_reads_));
+ return IOStatus::OK();
+ } else {
+ IOStatus s = FileSystemWrapper::NewSequentialFile(f, file_opts, r, dbg);
+ if (s.ok()) {
+ if ((*r)->use_direct_io()) {
+ ++num_direct_seq_readers_;
+ }
+ ++num_seq_readers_;
+ }
+ return s;
+ }
+ }
+
+ IOStatus NewWritableFile(const std::string& f, const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* r,
+ IODebugContext* dbg) override {
+ MutexLock l(&mutex_);
+ written_files_.push_back(f);
+ if (limit_written_files_ == 0) {
+ return IOStatus::NotSupported("Limit on written files reached");
+ }
+ limit_written_files_--;
+ IOStatus s = FileSystemWrapper::NewWritableFile(f, file_opts, r, dbg);
+ if (s.ok()) {
+ if ((*r)->use_direct_io()) {
+ ++num_direct_writers_;
+ }
+ ++num_writers_;
+ }
+ return s;
+ }
+
+ IOStatus NewRandomAccessFile(const std::string& f,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSRandomAccessFile>* r,
+ IODebugContext* dbg) override {
+ MutexLock l(&mutex_);
+ IOStatus s = FileSystemWrapper::NewRandomAccessFile(f, file_opts, r, dbg);
+ if (s.ok()) {
+ if ((*r)->use_direct_io()) {
+ ++num_direct_rand_readers_;
+ }
+ ++num_rand_readers_;
+ }
+ return s;
+ }
+
+ IOStatus DeleteFile(const std::string& f, const IOOptions& options,
+ IODebugContext* dbg) override {
+ MutexLock l(&mutex_);
+ if (fail_delete_files_) {
+ return IOStatus::IOError();
+ }
+ EXPECT_GT(limit_delete_files_, 0U);
+ limit_delete_files_--;
+ return FileSystemWrapper::DeleteFile(f, options, dbg);
+ }
+
+ IOStatus DeleteDir(const std::string& d, const IOOptions& options,
+ IODebugContext* dbg) override {
+ MutexLock l(&mutex_);
+ if (fail_delete_files_) {
+ return IOStatus::IOError();
+ }
+ return FileSystemWrapper::DeleteDir(d, options, dbg);
+ }
+
+ void AssertWrittenFiles(std::vector<std::string>& should_have_written) {
+ MutexLock l(&mutex_);
+ std::sort(should_have_written.begin(), should_have_written.end());
+ std::sort(written_files_.begin(), written_files_.end());
+
+ ASSERT_EQ(should_have_written, written_files_);
+ }
+
+ void ClearWrittenFiles() {
+ MutexLock l(&mutex_);
+ written_files_.clear();
+ }
+
+ void SetLimitWrittenFiles(uint64_t limit) {
+ MutexLock l(&mutex_);
+ limit_written_files_ = limit;
+ }
+
+ void SetLimitDeleteFiles(uint64_t limit) {
+ MutexLock l(&mutex_);
+ limit_delete_files_ = limit;
+ }
+
+ void SetDeleteFileFailure(bool fail) {
+ MutexLock l(&mutex_);
+ fail_delete_files_ = fail;
+ }
+
+ void SetDummySequentialFile(bool dummy_sequential_file) {
+ MutexLock l(&mutex_);
+ dummy_sequential_file_ = dummy_sequential_file;
+ }
+ void SetDummySequentialFileFailReads(bool dummy_sequential_file_fail_reads) {
+ MutexLock l(&mutex_);
+ dummy_sequential_file_fail_reads_ = dummy_sequential_file_fail_reads;
+ }
+
+ void SetGetChildrenFailure(bool fail) { get_children_failure_ = fail; }
+ IOStatus GetChildren(const std::string& dir, const IOOptions& io_opts,
+ std::vector<std::string>* r,
+ IODebugContext* dbg) override {
+ if (get_children_failure_) {
+ return IOStatus::IOError("SimulatedFailure");
+ }
+ return FileSystemWrapper::GetChildren(dir, io_opts, r, dbg);
+ }
+
+ // Some test cases do not actually create the test files (e.g., see
+ // DummyDB::live_files_) - for those cases, we mock those files' attributes
+ // so CreateNewBackup() can get their attributes.
+ void SetFilenamesForMockedAttrs(const std::vector<std::string>& filenames) {
+ filenames_for_mocked_attrs_ = filenames;
+ }
+ IOStatus GetChildrenFileAttributes(const std::string& dir,
+ const IOOptions& options,
+ std::vector<FileAttributes>* result,
+ IODebugContext* dbg) override {
+ if (filenames_for_mocked_attrs_.size() > 0) {
+ for (const auto& filename : filenames_for_mocked_attrs_) {
+ uint64_t size_bytes = 200; // Match TestFs
+ if (filename.find("MANIFEST") == 0) {
+ size_bytes = 100; // Match DummyDB::GetLiveFiles
+ }
+ result->push_back({dir + "/" + filename, size_bytes});
+ }
+ return IOStatus::OK();
+ }
+ return FileSystemWrapper::GetChildrenFileAttributes(dir, options, result,
+ dbg);
+ }
+
+ IOStatus GetFileSize(const std::string& f, const IOOptions& options,
+ uint64_t* s, IODebugContext* dbg) override {
+ if (filenames_for_mocked_attrs_.size() > 0) {
+ auto fname = f.substr(f.find_last_of('/') + 1);
+ auto filename_iter = std::find(filenames_for_mocked_attrs_.begin(),
+ filenames_for_mocked_attrs_.end(), fname);
+ if (filename_iter != filenames_for_mocked_attrs_.end()) {
+ *s = 200; // Match TestFs
+ if (fname.find("MANIFEST") == 0) {
+ *s = 100; // Match DummyDB::GetLiveFiles
+ }
+ return IOStatus::OK();
+ }
+ return IOStatus::NotFound(fname);
+ }
+ return FileSystemWrapper::GetFileSize(f, options, s, dbg);
+ }
+
+ void SetCreateDirIfMissingFailure(bool fail) {
+ create_dir_if_missing_failure_ = fail;
+ }
+ IOStatus CreateDirIfMissing(const std::string& d, const IOOptions& options,
+ IODebugContext* dbg) override {
+ if (create_dir_if_missing_failure_) {
+ return IOStatus::IOError("SimulatedFailure");
+ }
+ return FileSystemWrapper::CreateDirIfMissing(d, options, dbg);
+ }
+
+ void SetNewDirectoryFailure(bool fail) { new_directory_failure_ = fail; }
+ IOStatus NewDirectory(const std::string& name, const IOOptions& io_opts,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) override {
+ if (new_directory_failure_) {
+ return IOStatus::IOError("SimulatedFailure");
+ }
+ return FileSystemWrapper::NewDirectory(name, io_opts, result, dbg);
+ }
+
+ void ClearFileOpenCounters() {
+ MutexLock l(&mutex_);
+ num_rand_readers_ = 0;
+ num_direct_rand_readers_ = 0;
+ num_seq_readers_ = 0;
+ num_direct_seq_readers_ = 0;
+ num_writers_ = 0;
+ num_direct_writers_ = 0;
+ }
+
+ int num_rand_readers() { return num_rand_readers_; }
+ int num_direct_rand_readers() { return num_direct_rand_readers_; }
+ int num_seq_readers() { return num_seq_readers_; }
+ int num_direct_seq_readers() { return num_direct_seq_readers_; }
+ int num_writers() { return num_writers_; }
+ // FIXME(?): unused
+ int num_direct_writers() { return num_direct_writers_; }
+
+ private:
+ port::Mutex mutex_;
+ bool dummy_sequential_file_ = false;
+ bool dummy_sequential_file_fail_reads_ = false;
+ std::vector<std::string> written_files_;
+ std::vector<std::string> filenames_for_mocked_attrs_;
+ uint64_t limit_written_files_ = 1000000;
+ uint64_t limit_delete_files_ = 1000000;
+ bool fail_delete_files_ = false;
+
+ bool get_children_failure_ = false;
+ bool create_dir_if_missing_failure_ = false;
+ bool new_directory_failure_ = false;
+
+ // Keeps track of how many files of each type were successfully opened, and
+ // out of those, how many were opened with direct I/O.
+ std::atomic<int> num_rand_readers_{};
+ std::atomic<int> num_direct_rand_readers_{};
+ std::atomic<int> num_seq_readers_{};
+ std::atomic<int> num_direct_seq_readers_{};
+ std::atomic<int> num_writers_{};
+ std::atomic<int> num_direct_writers_{};
+}; // TestFs
+
+class FileManager : public EnvWrapper {
+ public:
+ explicit FileManager(Env* t) : EnvWrapper(t), rnd_(5) {}
+ const char* Name() const override { return "FileManager"; }
+
+ Status GetRandomFileInDir(const std::string& dir, std::string* fname,
+ uint64_t* fsize) {
+ std::vector<FileAttributes> children;
+ auto s = GetChildrenFileAttributes(dir, &children);
+ if (!s.ok()) {
+ return s;
+ } else if (children.size() <= 2) { // . and ..
+ return Status::NotFound("Empty directory: " + dir);
+ }
+ assert(fname != nullptr);
+ while (true) {
+ int i = rnd_.Next() % children.size();
+ fname->assign(dir + "/" + children[i].name);
+ *fsize = children[i].size_bytes;
+ return Status::OK();
+ }
+ // should never get here
+ assert(false);
+ return Status::NotFound("");
+ }
+
+ Status DeleteRandomFileInDir(const std::string& dir) {
+ std::vector<std::string> children;
+ Status s = GetChildren(dir, &children);
+ if (!s.ok()) {
+ return s;
+ }
+ while (true) {
+ int i = rnd_.Next() % children.size();
+ return DeleteFile(dir + "/" + children[i]);
+ }
+ // should never get here
+ assert(false);
+ return Status::NotFound("");
+ }
+
+ Status AppendToRandomFileInDir(const std::string& dir,
+ const std::string& data) {
+ std::vector<std::string> children;
+ Status s = GetChildren(dir, &children);
+ if (!s.ok()) {
+ return s;
+ }
+ while (true) {
+ int i = rnd_.Next() % children.size();
+ return WriteToFile(dir + "/" + children[i], data);
+ }
+ // should never get here
+ assert(false);
+ return Status::NotFound("");
+ }
+
+ Status CorruptFile(const std::string& fname, uint64_t bytes_to_corrupt) {
+ std::string file_contents;
+ Status s = ReadFileToString(this, fname, &file_contents);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+
+ for (uint64_t i = 0; i < bytes_to_corrupt; ++i) {
+ std::string tmp = rnd_.RandomString(1);
+ file_contents[rnd_.Next() % file_contents.size()] = tmp[0];
+ }
+ return WriteToFile(fname, file_contents);
+ }
+
+ Status CorruptFileStart(const std::string& fname) {
+ std::string to_xor = "blah";
+ std::string file_contents;
+ Status s = ReadFileToString(this, fname, &file_contents);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+ for (size_t i = 0; i < to_xor.size(); ++i) {
+ file_contents[i] ^= to_xor[i];
+ }
+ return WriteToFile(fname, file_contents);
+ }
+
+ Status CorruptChecksum(const std::string& fname, bool appear_valid) {
+ std::string metadata;
+ Status s = ReadFileToString(this, fname, &metadata);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+
+ auto pos = metadata.find("private");
+ if (pos == std::string::npos) {
+ return Status::Corruption("private file is expected");
+ }
+ pos = metadata.find(" crc32 ", pos + 6);
+ if (pos == std::string::npos) {
+ return Status::Corruption("checksum not found");
+ }
+
+ if (metadata.size() < pos + 7) {
+ return Status::Corruption("bad CRC32 checksum value");
+ }
+
+ if (appear_valid) {
+ if (metadata[pos + 8] == '\n') {
+ // single digit value, safe to insert one more digit
+ metadata.insert(pos + 8, 1, '0');
+ } else {
+ metadata.erase(pos + 8, 1);
+ }
+ } else {
+ metadata[pos + 7] = 'a';
+ }
+
+ return WriteToFile(fname, metadata);
+ }
+
+ Status WriteToFile(const std::string& fname, const std::string& data) {
+ std::unique_ptr<WritableFile> file;
+ EnvOptions env_options;
+ env_options.use_mmap_writes = false;
+ Status s = EnvWrapper::NewWritableFile(fname, &file, env_options);
+ if (!s.ok()) {
+ return s;
+ }
+ return file->Append(Slice(data));
+ }
+
+ private:
+ Random rnd_;
+}; // FileManager
+
+// utility functions
+namespace {
+
+enum FillDBFlushAction {
+ kFlushMost,
+ kFlushAll,
+ kAutoFlushOnly,
+};
+
+// Many tests in this file expect FillDB to write at least one sst file,
+// so the default behavior (if not kAutoFlushOnly) of FillDB is to force
+// a flush. But to ensure coverage of the WAL file case, we also (by default)
+// do one Put after the Flush (kFlushMost).
+size_t FillDB(DB* db, int from, int to,
+ FillDBFlushAction flush_action = kFlushMost) {
+ size_t bytes_written = 0;
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + std::to_string(i);
+ std::string value = "testvalue" + std::to_string(i);
+ bytes_written += key.size() + value.size();
+
+ EXPECT_OK(db->Put(WriteOptions(), Slice(key), Slice(value)));
+
+ if (flush_action == kFlushMost && i == to - 2) {
+ EXPECT_OK(db->Flush(FlushOptions()));
+ }
+ }
+ if (flush_action == kFlushAll) {
+ EXPECT_OK(db->Flush(FlushOptions()));
+ }
+ return bytes_written;
+}
+
+void AssertExists(DB* db, int from, int to) {
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + std::to_string(i);
+ std::string value;
+ Status s = db->Get(ReadOptions(), Slice(key), &value);
+ ASSERT_EQ(value, "testvalue" + std::to_string(i));
+ }
+}
+
+void AssertEmpty(DB* db, int from, int to) {
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + std::to_string(i);
+ std::string value = "testvalue" + std::to_string(i);
+
+ Status s = db->Get(ReadOptions(), Slice(key), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+}
+} // namespace
+
+class BackupEngineTest : public testing::Test {
+ public:
+ enum ShareOption {
+ kNoShare,
+ kShareNoChecksum,
+ kShareWithChecksum,
+ };
+
+ const std::vector<ShareOption> kAllShareOptions = {kNoShare, kShareNoChecksum,
+ kShareWithChecksum};
+
+ BackupEngineTest() {
+ // set up files
+ std::string db_chroot = test::PerThreadDBPath("db_for_backup");
+ std::string backup_chroot = test::PerThreadDBPath("db_backups");
+ EXPECT_OK(Env::Default()->CreateDirIfMissing(db_chroot));
+ EXPECT_OK(Env::Default()->CreateDirIfMissing(backup_chroot));
+ dbname_ = "/tempdb";
+ backupdir_ = "/tempbk";
+ latest_backup_ = backupdir_ + "/LATEST_BACKUP";
+
+ // set up FileSystem & Envs
+ db_chroot_fs_ = NewChrootFileSystem(FileSystem::Default(), db_chroot);
+ backup_chroot_fs_ =
+ NewChrootFileSystem(FileSystem::Default(), backup_chroot);
+ test_db_fs_ = std::make_shared<TestFs>(db_chroot_fs_);
+ test_backup_fs_ = std::make_shared<TestFs>(backup_chroot_fs_);
+ SetEnvsFromFileSystems();
+
+ // set up db options
+ options_.create_if_missing = true;
+ options_.paranoid_checks = true;
+ options_.write_buffer_size = 1 << 17; // 128KB
+ options_.wal_dir = dbname_;
+ options_.enable_blob_files = true;
+
+ // The sync option is not easily testable in unit tests, but should be
+ // smoke tested across all the other backup tests. However, it is
+ // certainly not worth doubling the runtime of backup tests for it.
+ // Thus, we can enable sync for one of our alternate testing
+ // configurations.
+ constexpr bool kUseSync =
+#ifdef ROCKSDB_MODIFY_NPHASH
+ true;
+#else
+ false;
+#endif // ROCKSDB_MODIFY_NPHASH
+
+ // set up backup db options
+ engine_options_.reset(new BackupEngineOptions(
+ backupdir_, test_backup_env_.get(), /*share_table_files*/ true,
+ logger_.get(), kUseSync));
+
+ // most tests will use multi-threaded backups
+ engine_options_->max_background_operations = 7;
+
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // delete old LATEST_BACKUP file, which some tests create for compatibility
+ // testing.
+ backup_chroot_env_->DeleteFile(latest_backup_).PermitUncheckedError();
+ }
+
+ void SetEnvsFromFileSystems() {
+ db_chroot_env_.reset(
+ new CompositeEnvWrapper(Env::Default(), db_chroot_fs_));
+ backup_chroot_env_.reset(
+ new CompositeEnvWrapper(Env::Default(), backup_chroot_fs_));
+ test_db_env_.reset(new CompositeEnvWrapper(Env::Default(), test_db_fs_));
+ options_.env = test_db_env_.get();
+ test_backup_env_.reset(
+ new CompositeEnvWrapper(Env::Default(), test_backup_fs_));
+ if (engine_options_) {
+ engine_options_->backup_env = test_backup_env_.get();
+ }
+ file_manager_.reset(new FileManager(backup_chroot_env_.get()));
+ db_file_manager_.reset(new FileManager(db_chroot_env_.get()));
+
+ // Create logger
+ DBOptions logger_options;
+ logger_options.env = db_chroot_env_.get();
+ ASSERT_OK(CreateLoggerFromOptions(dbname_, logger_options, &logger_));
+ }
+
+ DB* OpenDB() {
+ DB* db;
+ EXPECT_OK(DB::Open(options_, dbname_, &db));
+ return db;
+ }
+
+ void CloseAndReopenDB(bool read_only = false) {
+ // Close DB
+ db_.reset();
+
+ // Open DB
+ test_db_fs_->SetLimitWrittenFiles(1000000);
+ DB* db;
+ if (read_only) {
+ ASSERT_OK(DB::OpenForReadOnly(options_, dbname_, &db));
+ } else {
+ ASSERT_OK(DB::Open(options_, dbname_, &db));
+ }
+ db_.reset(db);
+ }
+
+ void InitializeDBAndBackupEngine(bool dummy = false) {
+ // reset all the db env defaults
+ test_db_fs_->SetLimitWrittenFiles(1000000);
+ test_db_fs_->SetDummySequentialFile(dummy);
+
+ DB* db;
+ if (dummy) {
+ dummy_db_ = new DummyDB(options_, dbname_);
+ db = dummy_db_;
+ } else {
+ ASSERT_OK(DB::Open(options_, dbname_, &db));
+ }
+ db_.reset(db);
+ }
+
+ virtual void OpenDBAndBackupEngine(
+ bool destroy_old_data = false, bool dummy = false,
+ ShareOption shared_option = kShareNoChecksum) {
+ InitializeDBAndBackupEngine(dummy);
+ // reset backup env defaults
+ test_backup_fs_->SetLimitWrittenFiles(1000000);
+ engine_options_->destroy_old_data = destroy_old_data;
+ engine_options_->share_table_files = shared_option != kNoShare;
+ engine_options_->share_files_with_checksum =
+ shared_option == kShareWithChecksum;
+ OpenBackupEngine(destroy_old_data);
+ }
+
+ void CloseDBAndBackupEngine() {
+ db_.reset();
+ backup_engine_.reset();
+ }
+
+ void OpenBackupEngine(bool destroy_old_data = false) {
+ engine_options_->destroy_old_data = destroy_old_data;
+ engine_options_->info_log = logger_.get();
+ BackupEngine* backup_engine;
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ backup_engine_.reset(backup_engine);
+ }
+
+ void CloseBackupEngine() { backup_engine_.reset(nullptr); }
+
+ // cross-cutting test of GetBackupInfo
+ void AssertBackupInfoConsistency() {
+ std::vector<BackupInfo> backup_info;
+ backup_engine_->GetBackupInfo(&backup_info, /*with file details*/ true);
+ std::map<std::string, uint64_t> file_sizes;
+
+ // Find the files that are supposed to be there
+ for (auto& backup : backup_info) {
+ uint64_t sum_for_backup = 0;
+ for (auto& file : backup.file_details) {
+ auto e = file_sizes.find(file.relative_filename);
+ if (e == file_sizes.end()) {
+ // fprintf(stderr, "Adding %s -> %u\n",
+ // file.relative_filename.c_str(), (unsigned)file.size);
+ file_sizes[file.relative_filename] = file.size;
+ } else {
+ ASSERT_EQ(file_sizes[file.relative_filename], file.size);
+ }
+ sum_for_backup += file.size;
+ }
+ ASSERT_EQ(backup.size, sum_for_backup);
+ }
+
+ std::vector<BackupID> corrupt_backup_ids;
+ backup_engine_->GetCorruptedBackups(&corrupt_backup_ids);
+ bool has_corrupt = corrupt_backup_ids.size() > 0;
+
+ // Compare with what's in backup dir
+ std::vector<std::string> child_dirs;
+ ASSERT_OK(
+ test_backup_env_->GetChildren(backupdir_ + "/private", &child_dirs));
+ for (auto& dir : child_dirs) {
+ dir = "private/" + dir;
+ }
+ child_dirs.push_back("shared"); // might not exist
+ child_dirs.push_back("shared_checksum"); // might not exist
+ for (auto& dir : child_dirs) {
+ std::vector<std::string> children;
+ test_backup_env_->GetChildren(backupdir_ + "/" + dir, &children)
+ .PermitUncheckedError();
+ // fprintf(stderr, "ls %s\n", (backupdir_ + "/" + dir).c_str());
+ for (auto& file : children) {
+ uint64_t size;
+ size = UINT64_MAX; // appease clang-analyze
+ std::string rel_file = dir + "/" + file;
+ // fprintf(stderr, "stat %s\n", (backupdir_ + "/" + rel_file).c_str());
+ ASSERT_OK(
+ test_backup_env_->GetFileSize(backupdir_ + "/" + rel_file, &size));
+ auto e = file_sizes.find(rel_file);
+ if (e == file_sizes.end()) {
+ // The only case in which we should find files not reported
+ ASSERT_TRUE(has_corrupt);
+ } else {
+ ASSERT_EQ(e->second, size);
+ file_sizes.erase(e);
+ }
+ }
+ }
+
+ // Everything should have been matched
+ ASSERT_EQ(file_sizes.size(), 0);
+ }
+
+ // restores backup backup_id and asserts the existence of
+ // [start_exist, end_exist> and not-existence of
+ // [end_exist, end>
+ //
+ // if backup_id == 0, it means restore from latest
+ // if end == 0, don't check AssertEmpty
+ void AssertBackupConsistency(BackupID backup_id, uint32_t start_exist,
+ uint32_t end_exist, uint32_t end = 0,
+ bool keep_log_files = false) {
+ RestoreOptions restore_options(keep_log_files);
+ bool opened_backup_engine = false;
+ if (backup_engine_.get() == nullptr) {
+ opened_backup_engine = true;
+ OpenBackupEngine();
+ }
+ AssertBackupInfoConsistency();
+
+ // Now perform restore
+ if (backup_id > 0) {
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(backup_id, dbname_, dbname_,
+ restore_options));
+ } else {
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_,
+ restore_options));
+ }
+ DB* db = OpenDB();
+ // Check DB contents
+ AssertExists(db, start_exist, end_exist);
+ if (end != 0) {
+ AssertEmpty(db, end_exist, end);
+ }
+ delete db;
+ if (opened_backup_engine) {
+ CloseBackupEngine();
+ }
+ }
+
+ void DeleteLogFiles() {
+ std::vector<std::string> delete_logs;
+ ASSERT_OK(db_chroot_env_->GetChildren(dbname_, &delete_logs));
+ for (auto f : delete_logs) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && type == kWalFile) {
+ ASSERT_OK(db_chroot_env_->DeleteFile(dbname_ + "/" + f));
+ }
+ }
+ }
+
+ Status GetDataFilesInDB(const FileType& file_type,
+ std::vector<FileAttributes>* files) {
+ std::vector<std::string> live;
+ uint64_t ignore_manifest_size;
+ Status s = db_->GetLiveFiles(live, &ignore_manifest_size, /*flush*/ false);
+ if (!s.ok()) {
+ return s;
+ }
+ std::vector<FileAttributes> children;
+ s = test_db_env_->GetChildrenFileAttributes(dbname_, &children);
+ for (const auto& child : children) {
+ FileType type;
+ uint64_t number = 0;
+ if (ParseFileName(child.name, &number, &type) && type == file_type &&
+ std::find(live.begin(), live.end(), "/" + child.name) != live.end()) {
+ files->push_back(child);
+ }
+ }
+ return s;
+ }
+
+ Status GetRandomDataFileInDB(const FileType& file_type,
+ std::string* fname_out,
+ uint64_t* fsize_out = nullptr) {
+ Random rnd(6); // NB: hardly "random"
+ std::vector<FileAttributes> files;
+ Status s = GetDataFilesInDB(file_type, &files);
+ if (!s.ok()) {
+ return s;
+ }
+ if (files.empty()) {
+ return Status::NotFound("");
+ }
+ size_t i = rnd.Uniform(static_cast<int>(files.size()));
+ *fname_out = dbname_ + "/" + files[i].name;
+ if (fsize_out) {
+ *fsize_out = files[i].size_bytes;
+ }
+ return Status::OK();
+ }
+
+ Status CorruptRandomDataFileInDB(const FileType& file_type) {
+ std::string fname;
+ uint64_t fsize = 0;
+ Status s = GetRandomDataFileInDB(file_type, &fname, &fsize);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::string file_contents;
+ s = ReadFileToString(test_db_env_.get(), fname, &file_contents);
+ if (!s.ok()) {
+ return s;
+ }
+ s = test_db_env_->DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+
+ file_contents[0] = (file_contents[0] + 257) % 256;
+ return WriteStringToFile(test_db_env_.get(), file_contents, fname);
+ }
+
+ void AssertDirectoryFilesMatchRegex(const std::string& dir,
+ const TestRegex& pattern,
+ const std::string& file_type,
+ int minimum_count) {
+ std::vector<FileAttributes> children;
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(dir, &children));
+ int found_count = 0;
+ for (const auto& child : children) {
+ if (EndsWith(child.name, file_type)) {
+ ASSERT_MATCHES_REGEX(child.name, pattern);
+ ++found_count;
+ }
+ }
+ ASSERT_GE(found_count, minimum_count);
+ }
+
+ void AssertDirectoryFilesSizeIndicators(const std::string& dir,
+ int minimum_count) {
+ std::vector<FileAttributes> children;
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(dir, &children));
+ int found_count = 0;
+ for (const auto& child : children) {
+ auto last_underscore = child.name.find_last_of('_');
+ auto last_dot = child.name.find_last_of('.');
+ ASSERT_NE(child.name, child.name.substr(0, last_underscore));
+ ASSERT_NE(child.name, child.name.substr(0, last_dot));
+ ASSERT_LT(last_underscore, last_dot);
+ std::string s = child.name.substr(last_underscore + 1,
+ last_dot - (last_underscore + 1));
+ ASSERT_EQ(s, std::to_string(child.size_bytes));
+ ++found_count;
+ }
+ ASSERT_GE(found_count, minimum_count);
+ }
+
+ // files
+ std::string dbname_;
+ std::string backupdir_;
+ std::string latest_backup_;
+
+ // logger_ must be above backup_engine_ such that the engine's destructor,
+ // which uses a raw pointer to the logger, executes first.
+ std::shared_ptr<Logger> logger_;
+
+ // FileSystems
+ std::shared_ptr<FileSystem> db_chroot_fs_;
+ std::shared_ptr<FileSystem> backup_chroot_fs_;
+ std::shared_ptr<TestFs> test_db_fs_;
+ std::shared_ptr<TestFs> test_backup_fs_;
+
+ // Env wrappers
+ std::unique_ptr<Env> db_chroot_env_;
+ std::unique_ptr<Env> backup_chroot_env_;
+ std::unique_ptr<Env> test_db_env_;
+ std::unique_ptr<Env> test_backup_env_;
+ std::unique_ptr<FileManager> file_manager_;
+ std::unique_ptr<FileManager> db_file_manager_;
+
+ // all the dbs!
+ DummyDB* dummy_db_; // owned as db_ when present
+ std::unique_ptr<DB> db_;
+ std::unique_ptr<BackupEngine> backup_engine_;
+
+ // options
+ Options options_;
+
+ protected:
+ void DestroyDBWithoutCheck(const std::string& dbname,
+ const Options& options) {
+ // DestroyDB may fail because the db might not be existed for some tests
+ DestroyDB(dbname, options).PermitUncheckedError();
+ }
+
+ std::unique_ptr<BackupEngineOptions> engine_options_;
+}; // BackupEngineTest
+
+void AppendPath(const std::string& path, std::vector<std::string>& v) {
+ for (auto& f : v) {
+ f = path + f;
+ }
+}
+
+class BackupEngineTestWithParam : public BackupEngineTest,
+ public testing::WithParamInterface<bool> {
+ public:
+ BackupEngineTestWithParam() {
+ engine_options_->share_files_with_checksum = GetParam();
+ }
+ void OpenDBAndBackupEngine(
+ bool destroy_old_data = false, bool dummy = false,
+ ShareOption shared_option = kShareNoChecksum) override {
+ BackupEngineTest::InitializeDBAndBackupEngine(dummy);
+ // reset backup env defaults
+ test_backup_fs_->SetLimitWrittenFiles(1000000);
+ engine_options_->destroy_old_data = destroy_old_data;
+ engine_options_->share_table_files = shared_option != kNoShare;
+ // NOTE: keep share_files_with_checksum setting from constructor
+ OpenBackupEngine(destroy_old_data);
+ }
+};
+
+TEST_F(BackupEngineTest, FileCollision) {
+ const int keys_iteration = 100;
+ for (const auto& sopt : kAllShareOptions) {
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */, sopt);
+ FillDB(db_.get(), 0, keys_iteration);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ FillDB(db_.get(), keys_iteration, keys_iteration * 2);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // If the db directory has been cleaned up, it is sensitive to file
+ // collision.
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // open fresh DB, but old backups present
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false /* dummy */,
+ sopt);
+ FillDB(db_.get(), 0, keys_iteration);
+ ASSERT_OK(db_->Flush(FlushOptions())); // like backup would do
+ FillDB(db_.get(), keys_iteration, keys_iteration * 2);
+ if (sopt != kShareNoChecksum) {
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ } else {
+ // The new table files created in FillDB() will clash with the old
+ // backup and sharing tables with no checksum will have the file
+ // collision problem.
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+ ASSERT_OK(backup_engine_->PurgeOldBackups(0));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ }
+ CloseDBAndBackupEngine();
+
+ // delete old data
+ DestroyDBWithoutCheck(dbname_, options_);
+ }
+}
+
+// This test verifies that the verifyBackup method correctly identifies
+// invalid backups
+TEST_P(BackupEngineTestWithParam, VerifyBackup) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ // ---------- case 1. - valid backup -----------
+ ASSERT_TRUE(backup_engine_->VerifyBackup(1).ok());
+
+ // ---------- case 2. - delete a file -----------i
+ ASSERT_OK(file_manager_->DeleteRandomFileInDir(backupdir_ + "/private/1"));
+ ASSERT_TRUE(backup_engine_->VerifyBackup(1).IsNotFound());
+
+ // ---------- case 3. - corrupt a file -----------
+ std::string append_data = "Corrupting a random file";
+ ASSERT_OK(file_manager_->AppendToRandomFileInDir(backupdir_ + "/private/2",
+ append_data));
+ ASSERT_TRUE(backup_engine_->VerifyBackup(2).IsCorruption());
+
+ // ---------- case 4. - invalid backup -----------
+ ASSERT_TRUE(backup_engine_->VerifyBackup(6).IsNotFound());
+ CloseDBAndBackupEngine();
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// open DB, write, close DB, backup, restore, repeat
+TEST_P(BackupEngineTestWithParam, OfflineIntegrationTest) {
+ // has to be a big number, so that it triggers the memtable flush
+ const int keys_iteration = 5000;
+ const int max_key = keys_iteration * 4 + 10;
+ // first iter -- flush before backup
+ // second iter -- don't flush before backup
+ for (int iter = 0; iter < 2; ++iter) {
+ // delete old data
+ DestroyDBWithoutCheck(dbname_, options_);
+ bool destroy_data = true;
+
+ // every iteration --
+ // 1. insert new data in the DB
+ // 2. backup the DB
+ // 3. destroy the db
+ // 4. restore the db, check everything is still there
+ for (int i = 0; i < 5; ++i) {
+ // in last iteration, put smaller amount of data,
+ int fill_up_to = std::min(keys_iteration * (i + 1), max_key);
+ // ---- insert new data and back up ----
+ OpenDBAndBackupEngine(destroy_data);
+ destroy_data = false;
+ // kAutoFlushOnly to preserve legacy test behavior (consider updating)
+ FillDB(db_.get(), keys_iteration * i, fill_up_to, kAutoFlushOnly);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), iter == 0))
+ << "iter: " << iter << ", idx: " << i;
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // ---- make sure it's empty ----
+ DB* db = OpenDB();
+ AssertEmpty(db, 0, fill_up_to);
+ delete db;
+
+ // ---- restore the DB ----
+ OpenBackupEngine();
+ if (i >= 3) { // test purge old backups
+ // when i == 4, purge to only 1 backup
+ // when i == 3, purge to 2 backups
+ ASSERT_OK(backup_engine_->PurgeOldBackups(5 - i));
+ }
+ // ---- make sure the data is there ---
+ AssertBackupConsistency(0, 0, fill_up_to, max_key);
+ CloseBackupEngine();
+ }
+ }
+}
+
+// open DB, write, backup, write, backup, close, restore
+TEST_P(BackupEngineTestWithParam, OnlineIntegrationTest) {
+ // has to be a big number, so that it triggers the memtable flush
+ const int keys_iteration = 5000;
+ const int max_key = keys_iteration * 4 + 10;
+ Random rnd(7);
+ // delete old data
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // TODO: Implement & test db_paths support in backup (not supported in
+ // restore)
+ // options_.db_paths.emplace_back(dbname_, 500 * 1024);
+ // options_.db_paths.emplace_back(dbname_ + "_2", 1024 * 1024 * 1024);
+
+ OpenDBAndBackupEngine(true);
+ // write some data, backup, repeat
+ for (int i = 0; i < 5; ++i) {
+ if (i == 4) {
+ // delete backup number 2, online delete!
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ }
+ // in last iteration, put smaller amount of data,
+ // so that backups can share sst files
+ int fill_up_to = std::min(keys_iteration * (i + 1), max_key);
+ // kAutoFlushOnly to preserve legacy test behavior (consider updating)
+ FillDB(db_.get(), keys_iteration * i, fill_up_to, kAutoFlushOnly);
+ // we should get consistent results with flush_before_backup
+ // set to both true and false
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+ // close and destroy
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // ---- make sure it's empty ----
+ DB* db = OpenDB();
+ AssertEmpty(db, 0, max_key);
+ delete db;
+
+ // ---- restore every backup and verify all the data is there ----
+ OpenBackupEngine();
+ for (int i = 1; i <= 5; ++i) {
+ if (i == 2) {
+ // we deleted backup 2
+ Status s = backup_engine_->RestoreDBFromBackup(2, dbname_, dbname_);
+ ASSERT_TRUE(!s.ok());
+ } else {
+ int fill_up_to = std::min(keys_iteration * i, max_key);
+ AssertBackupConsistency(i, 0, fill_up_to, max_key);
+ }
+ }
+
+ // delete some backups -- this should leave only backups 3 and 5 alive
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->PurgeOldBackups(2));
+
+ std::vector<BackupInfo> backup_info;
+ backup_engine_->GetBackupInfo(&backup_info);
+ ASSERT_EQ(2UL, backup_info.size());
+
+ // check backup 3
+ AssertBackupConsistency(3, 0, 3 * keys_iteration, max_key);
+ // check backup 5
+ AssertBackupConsistency(5, 0, max_key);
+
+ CloseBackupEngine();
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+INSTANTIATE_TEST_CASE_P(BackupEngineTestWithParam, BackupEngineTestWithParam,
+ ::testing::Bool());
+
+// this will make sure that backup does not copy the same file twice
+TEST_F(BackupEngineTest, NoDoubleCopy_And_AutoGC) {
+ OpenDBAndBackupEngine(true, true);
+
+ // should write 5 DB files + one meta file
+ test_backup_fs_->SetLimitWrittenFiles(7);
+ test_backup_fs_->ClearWrittenFiles();
+ test_db_fs_->SetLimitWrittenFiles(0);
+ dummy_db_->live_files_ = {"00010.sst", "00011.sst", "CURRENT", "MANIFEST-01",
+ "00011.log"};
+ test_db_fs_->SetFilenamesForMockedAttrs(dummy_db_->live_files_);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ std::vector<std::string> should_have_written = {
+ "/shared/.00010.sst.tmp", "/shared/.00011.sst.tmp", "/private/1/CURRENT",
+ "/private/1/MANIFEST-01", "/private/1/00011.log", "/meta/.1.tmp"};
+ AppendPath(backupdir_, should_have_written);
+ test_backup_fs_->AssertWrittenFiles(should_have_written);
+
+ char db_number = '1';
+
+ for (std::string other_sst : {"00015.sst", "00017.sst", "00019.sst"}) {
+ // should write 4 new DB files + one meta file
+ // should not write/copy 00010.sst, since it's already there!
+ test_backup_fs_->SetLimitWrittenFiles(6);
+ test_backup_fs_->ClearWrittenFiles();
+
+ dummy_db_->live_files_ = {"00010.sst", other_sst, "CURRENT", "MANIFEST-01",
+ "00011.log"};
+ test_db_fs_->SetFilenamesForMockedAttrs(dummy_db_->live_files_);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ // should not open 00010.sst - it's already there
+
+ ++db_number;
+ std::string private_dir = std::string("/private/") + db_number;
+ should_have_written = {
+ "/shared/." + other_sst + ".tmp", private_dir + "/CURRENT",
+ private_dir + "/MANIFEST-01", private_dir + "/00011.log",
+ std::string("/meta/.") + db_number + ".tmp"};
+ AppendPath(backupdir_, should_have_written);
+ test_backup_fs_->AssertWrittenFiles(should_have_written);
+ }
+
+ ASSERT_OK(backup_engine_->DeleteBackup(1));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00010.sst"));
+
+ // 00011.sst was only in backup 1, should be deleted
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00011.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00015.sst"));
+
+ // MANIFEST file size should be only 100
+ uint64_t size = 0;
+ ASSERT_OK(test_backup_env_->GetFileSize(backupdir_ + "/private/2/MANIFEST-01",
+ &size));
+ ASSERT_EQ(100UL, size);
+ ASSERT_OK(
+ test_backup_env_->GetFileSize(backupdir_ + "/shared/00015.sst", &size));
+ ASSERT_EQ(200UL, size);
+
+ CloseBackupEngine();
+
+ //
+ // Now simulate incomplete delete by removing just meta
+ //
+ ASSERT_OK(test_backup_env_->DeleteFile(backupdir_ + "/meta/2"));
+
+ OpenBackupEngine();
+
+ // 1 appears to be removed, so
+ // 2 non-corrupt and 0 corrupt seen
+ std::vector<BackupInfo> backup_info;
+ std::vector<BackupID> corrupt_backup_ids;
+ backup_engine_->GetBackupInfo(&backup_info);
+ backup_engine_->GetCorruptedBackups(&corrupt_backup_ids);
+ ASSERT_EQ(2UL, backup_info.size());
+ ASSERT_EQ(0UL, corrupt_backup_ids.size());
+
+ // Keep the two we see, but this should suffice to purge unreferenced
+ // shared files from incomplete delete.
+ ASSERT_OK(backup_engine_->PurgeOldBackups(2));
+
+ // Make sure dangling sst file has been removed (somewhere along this
+ // process). GarbageCollect should not be needed.
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00015.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00017.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00019.sst"));
+
+ // Now actually purge a good one
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1));
+
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00017.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00019.sst"));
+
+ CloseDBAndBackupEngine();
+}
+
+// test various kind of corruptions that may happen:
+// 1. Not able to write a file for backup - that backup should fail,
+// everything else should work
+// 2. Corrupted backup meta file or missing backuped file - we should
+// not be able to open that backup, but all other backups should be
+// fine
+// 3. Corrupted checksum value - if the checksum is not a valid uint32_t,
+// db open should fail, otherwise, it aborts during the restore process.
+TEST_F(BackupEngineTest, CorruptionsTest) {
+ const int keys_iteration = 5000;
+ Random rnd(6);
+ Status s;
+
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+
+ // ---------- case 1. - fail a write -----------
+ // try creating backup 6, but fail a write
+ FillDB(db_.get(), keys_iteration * 5, keys_iteration * 6);
+ test_backup_fs_->SetLimitWrittenFiles(2);
+ // should fail
+ s = backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2));
+ ASSERT_NOK(s);
+ test_backup_fs_->SetLimitWrittenFiles(1000000);
+ // latest backup should have all the keys
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration * 5, keys_iteration * 6);
+
+ // --------- case 2. corrupted backup meta or missing backuped file ----
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/5", 3));
+ // since 5 meta is now corrupted, latest backup should be 4
+ AssertBackupConsistency(0, 0, keys_iteration * 4, keys_iteration * 5);
+ OpenBackupEngine();
+ s = backup_engine_->RestoreDBFromBackup(5, dbname_, dbname_);
+ ASSERT_NOK(s);
+ CloseBackupEngine();
+ ASSERT_OK(file_manager_->DeleteRandomFileInDir(backupdir_ + "/private/4"));
+ // 4 is corrupted, 3 is the latest backup now
+ AssertBackupConsistency(0, 0, keys_iteration * 3, keys_iteration * 5);
+ OpenBackupEngine();
+ s = backup_engine_->RestoreDBFromBackup(4, dbname_, dbname_);
+ CloseBackupEngine();
+ ASSERT_NOK(s);
+
+ // --------- case 3. corrupted checksum value ----
+ ASSERT_OK(file_manager_->CorruptChecksum(backupdir_ + "/meta/3", false));
+ // checksum of backup 3 is an invalid value, this can be detected at
+ // db open time, and it reverts to the previous backup automatically
+ AssertBackupConsistency(0, 0, keys_iteration * 2, keys_iteration * 5);
+ // checksum of the backup 2 appears to be valid, this can cause checksum
+ // mismatch and abort restore process
+ ASSERT_OK(file_manager_->CorruptChecksum(backupdir_ + "/meta/2", true));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ OpenBackupEngine();
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ s = backup_engine_->RestoreDBFromBackup(2, dbname_, dbname_);
+ ASSERT_NOK(s);
+
+ // make sure that no corrupt backups have actually been deleted!
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/1"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/3"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/4"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/1"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/2"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/3"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/4"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/5"));
+
+ // delete the corrupt backups and then make sure they're actually deleted
+ ASSERT_OK(backup_engine_->DeleteBackup(5));
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->DeleteBackup(3));
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ // Should not be needed anymore with auto-GC on DeleteBackup
+ //(void)backup_engine_->GarbageCollect();
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/5"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/4"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/4"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/3"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/3"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/2"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/2"));
+ CloseBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration * 1, keys_iteration * 5);
+
+ // new backup should be 2!
+ OpenDBAndBackupEngine();
+ FillDB(db_.get(), keys_iteration * 1, keys_iteration * 2);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(2, 0, keys_iteration * 2, keys_iteration * 5);
+}
+
+// Corrupt a file but maintain its size
+TEST_F(BackupEngineTest, CorruptFileMaintainSize) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true);
+ // create a backup
+ FillDB(db_.get(), 0, keys_iteration);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ // verify with file size
+ ASSERT_OK(backup_engine_->VerifyBackup(1, false));
+ // verify with file checksum
+ ASSERT_OK(backup_engine_->VerifyBackup(1, true));
+
+ std::string file_to_corrupt;
+ uint64_t file_size = 0;
+ // under normal circumstance, there should be at least one nonempty file
+ while (file_size == 0) {
+ // get a random file in /private/1
+ assert(file_manager_
+ ->GetRandomFileInDir(backupdir_ + "/private/1", &file_to_corrupt,
+ &file_size)
+ .ok());
+ // corrupt the file by replacing its content by file_size random bytes
+ ASSERT_OK(file_manager_->CorruptFile(file_to_corrupt, file_size));
+ }
+ // file sizes match
+ ASSERT_OK(backup_engine_->VerifyBackup(1, false));
+ // file checksums mismatch
+ ASSERT_NOK(backup_engine_->VerifyBackup(1, true));
+ // sanity check, use default second argument
+ ASSERT_OK(backup_engine_->VerifyBackup(1));
+ CloseDBAndBackupEngine();
+
+ // an extra challenge
+ // set share_files_with_checksum to true and do two more backups
+ // corrupt all the table files in shared_checksum but maintain their sizes
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum);
+ // creat two backups
+ for (int i = 1; i < 3; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ std::vector<FileAttributes> children;
+ const std::string dir = backupdir_ + "/shared_checksum";
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(dir, &children));
+ for (const auto& child : children) {
+ if (child.size_bytes == 0) {
+ continue;
+ }
+ // corrupt the file by replacing its content by file_size random bytes
+ ASSERT_OK(
+ file_manager_->CorruptFile(dir + "/" + child.name, child.size_bytes));
+ }
+ // file sizes match
+ ASSERT_OK(backup_engine_->VerifyBackup(1, false));
+ ASSERT_OK(backup_engine_->VerifyBackup(2, false));
+ // file checksums mismatch
+ ASSERT_NOK(backup_engine_->VerifyBackup(1, true));
+ ASSERT_NOK(backup_engine_->VerifyBackup(2, true));
+ CloseDBAndBackupEngine();
+}
+
+// Corrupt a blob file but maintain its size
+TEST_P(BackupEngineTestWithParam, CorruptBlobFileMaintainSize) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true);
+ // create a backup
+ FillDB(db_.get(), 0, keys_iteration);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ // verify with file size
+ ASSERT_OK(backup_engine_->VerifyBackup(1, false));
+ // verify with file checksum
+ ASSERT_OK(backup_engine_->VerifyBackup(1, true));
+
+ std::string file_to_corrupt;
+ std::vector<FileAttributes> children;
+
+ std::string dir = backupdir_;
+ if (engine_options_->share_files_with_checksum) {
+ dir += "/shared_checksum";
+ } else {
+ dir += "/shared";
+ }
+
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(dir, &children));
+
+ for (const auto& child : children) {
+ if (EndsWith(child.name, ".blob") && child.size_bytes != 0) {
+ // corrupt the blob files by replacing its content by file_size random
+ // bytes
+ ASSERT_OK(
+ file_manager_->CorruptFile(dir + "/" + child.name, child.size_bytes));
+ }
+ }
+
+ // file sizes match
+ ASSERT_OK(backup_engine_->VerifyBackup(1, false));
+ // file checksums mismatch
+ ASSERT_NOK(backup_engine_->VerifyBackup(1, true));
+ // sanity check, use default second argument
+ ASSERT_OK(backup_engine_->VerifyBackup(1));
+ CloseDBAndBackupEngine();
+}
+
+// Test if BackupEngine will fail to create new backup if some table has been
+// corrupted and the table file checksum is stored in the DB manifest
+TEST_F(BackupEngineTest, TableFileCorruptedBeforeBackup) {
+ const int keys_iteration = 50000;
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kNoShare);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random table file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kTableFile));
+ // file_checksum_gen_factory is null, and thus table checksum is not
+ // verified for creating a new backup; no correction is detected
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // Enable table file checksum in DB manifest
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kNoShare);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random table file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kTableFile));
+ // table file checksum is enabled so we should be able to detect any
+ // corruption
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+}
+
+// Test if BackupEngine will fail to create new backup if some blob files has
+// been corrupted and the blob file checksum is stored in the DB manifest
+TEST_F(BackupEngineTest, BlobFileCorruptedBeforeBackup) {
+ const int keys_iteration = 50000;
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kNoShare);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random blob file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kBlobFile));
+ // file_checksum_gen_factory is null, and thus blob checksum is not
+ // verified for creating a new backup; no correction is detected
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // Enable file checksum in DB manifest
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kNoShare);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random blob file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kBlobFile));
+
+ // file checksum is enabled so we should be able to detect any
+ // corruption
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// Test if BackupEngine will fail to create new backup if some table has been
+// corrupted and the table file checksum is stored in the DB manifest for the
+// case when backup table files will be stored in a shared directory
+TEST_P(BackupEngineTestWithParam, TableFileCorruptedBeforeBackup) {
+ const int keys_iteration = 50000;
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random table file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kTableFile));
+ // cannot detect corruption since DB manifest has no table checksums
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // Enable table checksums in DB manifest
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random table file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kTableFile));
+ // corruption is detected
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+}
+
+// Test if BackupEngine will fail to create new backup if some blob files have
+// been corrupted and the blob file checksum is stored in the DB manifest for
+// the case when backup blob files will be stored in a shared directory
+TEST_P(BackupEngineTestWithParam, BlobFileCorruptedBeforeBackup) {
+ const int keys_iteration = 50000;
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random blob file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kBlobFile));
+ // cannot detect corruption since DB manifest has no blob file checksums
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ // Enable blob file checksums in DB manifest
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseAndReopenDB(/*read_only*/ true);
+ // corrupt a random blob file in the DB directory
+ ASSERT_OK(CorruptRandomDataFileInDB(kBlobFile));
+ // corruption is detected
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_F(BackupEngineTest, TableFileWithoutDbChecksumCorruptedDuringBackup) {
+ const int keys_iteration = 50000;
+ engine_options_->share_files_with_checksum_naming = kLegacyCrc32cAndFileSize;
+ // When share_files_with_checksum is on, we calculate checksums of table
+ // files before and after copying. So we can test whether a corruption has
+ // happened during the file is copied to backup directory.
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum);
+
+ FillDB(db_.get(), 0, keys_iteration);
+ std::atomic<bool> corrupted{false};
+ // corrupt files when copying to the backup directory
+ SyncPoint::GetInstance()->SetCallBack(
+ "BackupEngineImpl::CopyOrCreateFile:CorruptionDuringBackup",
+ [&](void* data) {
+ if (data != nullptr) {
+ Slice* d = reinterpret_cast<Slice*>(data);
+ if (!d->empty()) {
+ d->remove_suffix(1);
+ corrupted = true;
+ }
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ Status s = backup_engine_->CreateNewBackup(db_.get());
+ if (corrupted) {
+ ASSERT_NOK(s);
+ } else {
+ // should not in this path in normal cases
+ ASSERT_OK(s);
+ }
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ CloseDBAndBackupEngine();
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, TableFileWithDbChecksumCorruptedDuringBackup) {
+ const int keys_iteration = 50000;
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ for (auto& sopt : kAllShareOptions) {
+ // Since the default DB table file checksum is on, we obtain checksums of
+ // table files from the DB manifest before copying and verify it with the
+ // one calculated during copying.
+ // Therefore, we can test whether a corruption has happened during the file
+ // being copied to backup directory.
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */, sopt);
+
+ FillDB(db_.get(), 0, keys_iteration);
+
+ // corrupt files when copying to the backup directory
+ SyncPoint::GetInstance()->SetCallBack(
+ "BackupEngineImpl::CopyOrCreateFile:CorruptionDuringBackup",
+ [&](void* data) {
+ if (data != nullptr) {
+ Slice* d = reinterpret_cast<Slice*>(data);
+ if (!d->empty()) {
+ d->remove_suffix(1);
+ }
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ // The only case that we can't detect a corruption is when the file
+ // being backed up is empty. But as keys_iteration is large, such
+ // a case shouldn't have happened and we should be able to detect
+ // the corruption.
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get()));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ CloseDBAndBackupEngine();
+ // delete old files in db
+ DestroyDBWithoutCheck(dbname_, options_);
+ }
+}
+
+TEST_F(BackupEngineTest, InterruptCreationTest) {
+ // Interrupt backup creation by failing new writes and failing cleanup of the
+ // partial state. Then verify a subsequent backup can still succeed.
+ const int keys_iteration = 5000;
+ Random rnd(6);
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ test_backup_fs_->SetLimitWrittenFiles(2);
+ test_backup_fs_->SetDeleteFileFailure(true);
+ // should fail creation
+ ASSERT_NOK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ CloseDBAndBackupEngine();
+ // should also fail cleanup so the tmp directory stays behind
+ ASSERT_OK(backup_chroot_env_->FileExists(backupdir_ + "/private/1/"));
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */);
+ test_backup_fs_->SetLimitWrittenFiles(1000000);
+ test_backup_fs_->SetDeleteFileFailure(false);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ // latest backup should have all the keys
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration);
+}
+
+TEST_F(BackupEngineTest, FlushCompactDuringBackupCheckpoint) {
+ const int keys_iteration = 5000;
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ for (const auto& sopt : kAllShareOptions) {
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */, sopt);
+ FillDB(db_.get(), 0, keys_iteration);
+ // That FillDB leaves a mix of flushed and unflushed data
+ SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointImpl::CreateCustomCheckpoint:AfterGetLive1",
+ "BackupEngineTest::FlushCompactDuringBackupCheckpoint:Before"},
+ {"BackupEngineTest::FlushCompactDuringBackupCheckpoint:After",
+ "CheckpointImpl::CreateCustomCheckpoint:AfterGetLive2"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread flush_thread{[this]() {
+ TEST_SYNC_POINT(
+ "BackupEngineTest::FlushCompactDuringBackupCheckpoint:Before");
+ FillDB(db_.get(), keys_iteration, 2 * keys_iteration);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ DBImpl* dbi = static_cast<DBImpl*>(db_.get());
+ ASSERT_OK(dbi->TEST_WaitForFlushMemTable());
+ ASSERT_OK(db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(dbi->TEST_WaitForCompact());
+ TEST_SYNC_POINT(
+ "BackupEngineTest::FlushCompactDuringBackupCheckpoint:After");
+ }};
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ flush_thread.join();
+ CloseDBAndBackupEngine();
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ /* FIXME(peterd): reinstate with option for checksum in file names
+ if (sopt == kShareWithChecksum) {
+ // Ensure we actually got DB manifest checksums by inspecting
+ // shared_checksum file names for hex checksum component
+ TestRegex expected("[^_]+_[0-9A-F]{8}_[^_]+.sst");
+ std::vector<FileAttributes> children;
+ const std::string dir = backupdir_ + "/shared_checksum";
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(dir, &children));
+ for (const auto& child : children) {
+ if (child.size_bytes == 0) {
+ continue;
+ }
+ EXPECT_MATCHES_REGEX(child.name, expected);
+ }
+ }
+ */
+ AssertBackupConsistency(0, 0, keys_iteration);
+ }
+}
+
+inline std::string OptionsPath(std::string ret, int backupID) {
+ ret += "/private/";
+ ret += std::to_string(backupID);
+ ret += "/";
+ return ret;
+}
+
+// Backup the LATEST options file to
+// "<backup_dir>/private/<backup_id>/OPTIONS<number>"
+
+TEST_F(BackupEngineTest, BackupOptions) {
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i < 5; i++) {
+ std::string name;
+ std::vector<std::string> filenames;
+ // Must reset() before reset(OpenDB()) again.
+ // Calling OpenDB() while *db_ is existing will cause LOCK issue
+ db_.reset();
+ db_.reset(OpenDB());
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ ASSERT_OK(ROCKSDB_NAMESPACE::GetLatestOptionsFileName(db_->GetName(),
+ options_.env, &name));
+ ASSERT_OK(file_manager_->FileExists(OptionsPath(backupdir_, i) + name));
+ ASSERT_OK(backup_chroot_env_->GetChildren(OptionsPath(backupdir_, i),
+ &filenames));
+ for (auto fn : filenames) {
+ if (fn.compare(0, 7, "OPTIONS") == 0) {
+ ASSERT_EQ(name, fn);
+ }
+ }
+ }
+
+ CloseDBAndBackupEngine();
+}
+
+TEST_F(BackupEngineTest, SetOptionsBackupRaceCondition) {
+ OpenDBAndBackupEngine(true);
+ SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "BackupEngineTest::SetOptionsBackupRaceCondition:BeforeSetOptions"},
+ {"BackupEngineTest::SetOptionsBackupRaceCondition:AfterSetOptions",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread setoptions_thread{[this]() {
+ TEST_SYNC_POINT(
+ "BackupEngineTest::SetOptionsBackupRaceCondition:BeforeSetOptions");
+ DBImpl* dbi = static_cast<DBImpl*>(db_.get());
+ // Change arbitrary option to trigger OPTIONS file deletion
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "false"}}));
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "true"}}));
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "false"}}));
+ TEST_SYNC_POINT(
+ "BackupEngineTest::SetOptionsBackupRaceCondition:AfterSetOptions");
+ }};
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ setoptions_thread.join();
+ CloseDBAndBackupEngine();
+}
+
+// This test verifies we don't delete the latest backup when read-only option is
+// set
+TEST_F(BackupEngineTest, NoDeleteWithReadOnly) {
+ const int keys_iteration = 5000;
+ Random rnd(6);
+
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+ CloseDBAndBackupEngine();
+ ASSERT_OK(file_manager_->WriteToFile(latest_backup_, "4"));
+
+ engine_options_->destroy_old_data = false;
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(
+ backup_chroot_env_.get(), *engine_options_, &read_only_backup_engine));
+
+ // assert that data from backup 5 is still here (even though LATEST_BACKUP
+ // says 4 is latest)
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/5"));
+
+ // Behavior change: We now ignore LATEST_BACKUP contents. This means that
+ // we should have 5 backups, even if LATEST_BACKUP says 4.
+ std::vector<BackupInfo> backup_info;
+ read_only_backup_engine->GetBackupInfo(&backup_info);
+ ASSERT_EQ(5UL, backup_info.size());
+ delete read_only_backup_engine;
+}
+
+TEST_F(BackupEngineTest, FailOverwritingBackups) {
+ options_.write_buffer_size = 1024 * 1024 * 1024; // 1GB
+ options_.disable_auto_compactions = true;
+
+ // create backups 1, 2, 3, 4, 5
+ OpenDBAndBackupEngine(true);
+ for (int i = 0; i < 5; ++i) {
+ CloseDBAndBackupEngine();
+ DeleteLogFiles();
+ OpenDBAndBackupEngine(false);
+ FillDB(db_.get(), 100 * i, 100 * (i + 1), kFlushAll);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ }
+ CloseDBAndBackupEngine();
+
+ // restore 3
+ OpenBackupEngine();
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(3, dbname_, dbname_));
+ CloseBackupEngine();
+
+ OpenDBAndBackupEngine(false);
+ // More data, bigger SST
+ FillDB(db_.get(), 1000, 1300, kFlushAll);
+ Status s = backup_engine_->CreateNewBackup(db_.get());
+ // the new backup fails because new table files
+ // clash with old table files from backups 4 and 5
+ // (since write_buffer_size is huge, we can be sure that
+ // each backup will generate only one sst file and that
+ // a file generated here would have the same name as an
+ // sst file generated by backup 4, and will be bigger)
+ ASSERT_TRUE(s.IsCorruption());
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->DeleteBackup(5));
+ // now, the backup can succeed
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+}
+
+TEST_F(BackupEngineTest, NoShareTableFiles) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true, false, kNoShare);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(i % 2)));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+}
+
+// Verify that you can backup and restore with share_files_with_checksum on
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksums) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(i % 2)));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+}
+
+// Verify that you can backup and restore using share_files_with_checksum set to
+// false and then transition this option to true
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksumsTransition) {
+ const int keys_iteration = 5000;
+ // set share_files_with_checksum to false
+ OpenDBAndBackupEngine(true, false, kShareNoChecksum);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+
+ // set share_files_with_checksum to true and do some more backups
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ for (int i = 5; i < 10; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ // Verify first (about to delete)
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * 11);
+
+ // For an extra challenge, make sure that GarbageCollect / DeleteBackup
+ // is OK even if we open without share_table_files
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ ASSERT_OK(backup_engine_->DeleteBackup(1));
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ CloseDBAndBackupEngine();
+
+ // Verify rest (not deleted)
+ for (int i = 1; i < 10; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 11);
+ }
+}
+
+// Verify backup and restore with various naming options, check names
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksumsNewNaming) {
+ ASSERT_TRUE(engine_options_->share_files_with_checksum_naming ==
+ kNamingDefault);
+
+ const int keys_iteration = 5000;
+
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseDBAndBackupEngine();
+
+ static const std::map<ShareFilesNaming, TestRegex> option_to_expected = {
+ {kLegacyCrc32cAndFileSize, "[0-9]+_[0-9]+_[0-9]+[.]sst"},
+ // kFlagIncludeFileSize redundant here
+ {kLegacyCrc32cAndFileSize | kFlagIncludeFileSize,
+ "[0-9]+_[0-9]+_[0-9]+[.]sst"},
+ {kUseDbSessionId, "[0-9]+_s[0-9A-Z]{20}[.]sst"},
+ {kUseDbSessionId | kFlagIncludeFileSize,
+ "[0-9]+_s[0-9A-Z]{20}_[0-9]+[.]sst"},
+ };
+
+ const TestRegex blobfile_pattern = "[0-9]+_[0-9]+_[0-9]+[.]blob";
+
+ for (const auto& pair : option_to_expected) {
+ CloseAndReopenDB();
+ engine_options_->share_files_with_checksum_naming = pair.first;
+ OpenBackupEngine(true /*destroy_old_data*/);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * 2);
+ AssertDirectoryFilesMatchRegex(backupdir_ + "/shared_checksum", pair.second,
+ ".sst", 1 /* minimum_count */);
+ if (std::string::npos != pair.second.GetPattern().find("_[0-9]+[.]sst")) {
+ AssertDirectoryFilesSizeIndicators(backupdir_ + "/shared_checksum",
+ 1 /* minimum_count */);
+ }
+
+ AssertDirectoryFilesMatchRegex(backupdir_ + "/shared_checksum",
+ blobfile_pattern, ".blob",
+ 1 /* minimum_count */);
+ }
+}
+
+// Mimic SST file generated by pre-6.12 releases and verify that
+// old names are always used regardless of naming option.
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksumsOldFileNaming) {
+ const int keys_iteration = 5000;
+
+ // Pre-6.12 release did not include db id and db session id properties.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "PropertyBlockBuilder::AddTableProperty:Start", [&](void* props_vs) {
+ auto props = static_cast<TableProperties*>(props_vs);
+ props->db_id = "";
+ props->db_session_id = "";
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // Corrupting the table properties corrupts the unique id.
+ // Ignore the unique id recorded in the manifest.
+ options_.verify_sst_unique_id_in_manifest = false;
+
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ FillDB(db_.get(), 0, keys_iteration);
+ CloseDBAndBackupEngine();
+
+ // Old names should always be used on old files
+ const TestRegex sstfile_pattern("[0-9]+_[0-9]+_[0-9]+[.]sst");
+
+ const TestRegex blobfile_pattern = "[0-9]+_[0-9]+_[0-9]+[.]blob";
+
+ for (ShareFilesNaming option : {kNamingDefault, kUseDbSessionId}) {
+ CloseAndReopenDB();
+ engine_options_->share_files_with_checksum_naming = option;
+ OpenBackupEngine(true /*destroy_old_data*/);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * 2);
+ AssertDirectoryFilesMatchRegex(backupdir_ + "/shared_checksum",
+ sstfile_pattern, ".sst",
+ 1 /* minimum_count */);
+ AssertDirectoryFilesMatchRegex(backupdir_ + "/shared_checksum",
+ blobfile_pattern, ".blob",
+ 1 /* minimum_count */);
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Test how naming options interact with detecting DB corruption
+// between incremental backups
+TEST_F(BackupEngineTest, TableFileCorruptionBeforeIncremental) {
+ const auto share_no_checksum = static_cast<ShareFilesNaming>(0);
+
+ for (bool corrupt_before_first_backup : {false, true}) {
+ for (ShareFilesNaming option :
+ {share_no_checksum, kLegacyCrc32cAndFileSize, kNamingDefault}) {
+ auto share =
+ option == share_no_checksum ? kShareNoChecksum : kShareWithChecksum;
+ if (option != share_no_checksum) {
+ engine_options_->share_files_with_checksum_naming = option;
+ }
+ OpenDBAndBackupEngine(true, false, share);
+ DBImpl* dbi = static_cast<DBImpl*>(db_.get());
+ // A small SST file
+ ASSERT_OK(dbi->Put(WriteOptions(), "x", "y"));
+ ASSERT_OK(dbi->Flush(FlushOptions()));
+ // And a bigger one
+ ASSERT_OK(dbi->Put(WriteOptions(), "y", Random(42).RandomString(500)));
+ ASSERT_OK(dbi->Flush(FlushOptions()));
+ ASSERT_OK(dbi->TEST_WaitForFlushMemTable());
+ CloseAndReopenDB(/*read_only*/ true);
+
+ std::vector<FileAttributes> table_files;
+ ASSERT_OK(GetDataFilesInDB(kTableFile, &table_files));
+ ASSERT_EQ(table_files.size(), 2);
+ std::string tf0 = dbname_ + "/" + table_files[0].name;
+ std::string tf1 = dbname_ + "/" + table_files[1].name;
+
+ CloseDBAndBackupEngine();
+
+ if (corrupt_before_first_backup) {
+ // This corrupts a data block, which does not cause DB open
+ // failure, only failure on accessing the block.
+ ASSERT_OK(db_file_manager_->CorruptFileStart(tf0));
+ }
+
+ OpenDBAndBackupEngine(false, false, share);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ CloseDBAndBackupEngine();
+
+ // if corrupt_before_first_backup, this undoes the initial corruption
+ ASSERT_OK(db_file_manager_->CorruptFileStart(tf0));
+
+ OpenDBAndBackupEngine(false, false, share);
+ Status s = backup_engine_->CreateNewBackup(db_.get());
+
+ // Even though none of the naming options catch the inconsistency
+ // between the first and second time backing up fname, in the case
+ // of kUseDbSessionId (kNamingDefault), this is an intentional
+ // trade-off to avoid full scan of files from the DB that are
+ // already backed up. If we did the scan, kUseDbSessionId could catch
+ // the corruption. kLegacyCrc32cAndFileSize does the scan (to
+ // compute checksum for name) without catching the corruption,
+ // because the corruption means the names don't merge.
+ EXPECT_OK(s);
+
+ // VerifyBackup doesn't check DB integrity or table file internal
+ // checksums
+ EXPECT_OK(backup_engine_->VerifyBackup(1, true));
+ EXPECT_OK(backup_engine_->VerifyBackup(2, true));
+
+ db_.reset();
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(2, dbname_, dbname_));
+ {
+ DB* db = OpenDB();
+ s = db->VerifyChecksum();
+ delete db;
+ }
+ if (option != kLegacyCrc32cAndFileSize && !corrupt_before_first_backup) {
+ // Second backup is OK because it used (uncorrupt) file from first
+ // backup instead of (corrupt) file from DB.
+ // This is arguably a good trade-off vs. treating the file as distinct
+ // from the old version, because a file should be more likely to be
+ // corrupt as it ages. Although the backed-up file might also corrupt
+ // with age, the alternative approach (checksum in file name computed
+ // from current DB file contents) wouldn't detect that case at backup
+ // time either. Although you would have both copies of the file with
+ // the alternative approach, that would only last until the older
+ // backup is deleted.
+ ASSERT_OK(s);
+ } else if (option == kLegacyCrc32cAndFileSize &&
+ corrupt_before_first_backup) {
+ // Second backup is OK because it saved the updated (uncorrupt)
+ // file from DB, instead of the sharing with first backup.
+ // Recall: if corrupt_before_first_backup, [second CorruptFileStart]
+ // undoes the initial corruption.
+ // This is arguably a bad trade-off vs. sharing the old version of the
+ // file because a file should be more likely to corrupt as it ages.
+ // (Not likely that the previously backed-up version was already
+ // corrupt and the new version is non-corrupt. This approach doesn't
+ // help if backed-up version is corrupted after taking the backup.)
+ ASSERT_OK(s);
+ } else {
+ // Something is legitimately corrupted, but we can't be sure what
+ // with information available (TODO? unless one passes block checksum
+ // test and other doesn't. Probably better to use end-to-end full file
+ // checksum anyway.)
+ ASSERT_TRUE(s.IsCorruption());
+ }
+
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+ }
+ }
+}
+
+// Test how naming options interact with detecting file size corruption
+// between incremental backups
+TEST_F(BackupEngineTest, FileSizeForIncremental) {
+ const auto share_no_checksum = static_cast<ShareFilesNaming>(0);
+ // TODO: enable blob files once Integrated BlobDB supports DB session id.
+ options_.enable_blob_files = false;
+
+ for (ShareFilesNaming option : {share_no_checksum, kLegacyCrc32cAndFileSize,
+ kNamingDefault, kUseDbSessionId}) {
+ auto share =
+ option == share_no_checksum ? kShareNoChecksum : kShareWithChecksum;
+ if (option != share_no_checksum) {
+ engine_options_->share_files_with_checksum_naming = option;
+ }
+ OpenDBAndBackupEngine(true, false, share);
+
+ std::vector<FileAttributes> children;
+ const std::string shared_dir =
+ backupdir_ +
+ (option == share_no_checksum ? "/shared" : "/shared_checksum");
+
+ // A single small SST file
+ ASSERT_OK(db_->Put(WriteOptions(), "x", "y"));
+
+ // First, test that we always detect file size corruption on the shared
+ // backup side on incremental. (Since sizes aren't really part of backup
+ // meta file, this works by querying the filesystem for the sizes.)
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true /*flush*/));
+ CloseDBAndBackupEngine();
+
+ // Corrupt backup SST file
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(shared_dir, &children));
+ ASSERT_EQ(children.size(), 1U); // one sst
+ for (const auto& child : children) {
+ if (child.name.size() > 4 && child.size_bytes > 0) {
+ ASSERT_OK(
+ file_manager_->WriteToFile(shared_dir + "/" + child.name, "asdf"));
+ break;
+ }
+ }
+
+ OpenDBAndBackupEngine(false, false, share);
+ Status s = backup_engine_->CreateNewBackup(db_.get());
+ EXPECT_TRUE(s.IsCorruption());
+
+ ASSERT_OK(backup_engine_->PurgeOldBackups(0));
+ CloseDBAndBackupEngine();
+
+ // Second, test that a hypothetical db session id collision would likely
+ // not suffice to corrupt a backup, because there's a good chance of
+ // file size difference (in this test, guaranteed) so either no name
+ // collision or detected collision.
+
+ // Create backup 1
+ OpenDBAndBackupEngine(false, false, share);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ // Even though we have "the same" DB state as backup 1, we need
+ // to restore to recreate the same conditions as later restore.
+ db_.reset();
+ DestroyDBWithoutCheck(dbname_, options_);
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(1, dbname_, dbname_));
+ CloseDBAndBackupEngine();
+
+ // Forge session id
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "DBImpl::SetDbSessionId", [](void* sid_void_star) {
+ std::string* sid = static_cast<std::string*>(sid_void_star);
+ *sid = "01234567890123456789";
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // Create another SST file
+ OpenDBAndBackupEngine(false, false, share);
+ ASSERT_OK(db_->Put(WriteOptions(), "y", "x"));
+
+ // Create backup 2
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true /*flush*/));
+
+ // Restore backup 1 (again)
+ db_.reset();
+ DestroyDBWithoutCheck(dbname_, options_);
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(1, dbname_, dbname_));
+ CloseDBAndBackupEngine();
+
+ // Create another SST file with same number and db session id, only bigger
+ OpenDBAndBackupEngine(false, false, share);
+ ASSERT_OK(db_->Put(WriteOptions(), "y", Random(42).RandomString(500)));
+
+ // Count backup SSTs files.
+ children.clear();
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(shared_dir, &children));
+ ASSERT_EQ(children.size(), 2U); // two sst files
+
+ // Try create backup 3
+ s = backup_engine_->CreateNewBackup(db_.get(), true /*flush*/);
+
+ // Re-count backup SSTs
+ children.clear();
+ ASSERT_OK(file_manager_->GetChildrenFileAttributes(shared_dir, &children));
+
+ if (option == kUseDbSessionId) {
+ // Acceptable to call it corruption if size is not in name and
+ // db session id collision is practically impossible.
+ EXPECT_TRUE(s.IsCorruption());
+ EXPECT_EQ(children.size(), 2U); // no SST file added
+ } else if (option == share_no_checksum) {
+ // Good to call it corruption if both backups cannot be
+ // accommodated.
+ EXPECT_TRUE(s.IsCorruption());
+ EXPECT_EQ(children.size(), 2U); // no SST file added
+ } else {
+ // Since opening a DB seems sufficient for detecting size corruption
+ // on the DB side, this should be a good thing, ...
+ EXPECT_OK(s);
+ // ... as long as we did actually treat it as a distinct SST file.
+ EXPECT_EQ(children.size(), 3U); // Another SST added
+ }
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+}
+
+// Verify backup and restore with share_files_with_checksum off and then
+// transition this option to on and share_files_with_checksum_naming to be
+// based on kUseDbSessionId
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksumsNewNamingTransition) {
+ const int keys_iteration = 5000;
+ // We may set share_files_with_checksum_naming to kLegacyCrc32cAndFileSize
+ // here but even if we don't, it should have no effect when
+ // share_files_with_checksum is false
+ ASSERT_TRUE(engine_options_->share_files_with_checksum_naming ==
+ kNamingDefault);
+ // set share_files_with_checksum to false
+ OpenDBAndBackupEngine(true, false, kShareNoChecksum);
+ int j = 3;
+ for (int i = 0; i < j; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < j; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * (j + 1));
+ }
+
+ // set share_files_with_checksum to true and do some more backups
+ // and use session id in the name of SST file backup
+ ASSERT_TRUE(engine_options_->share_files_with_checksum_naming ==
+ kNamingDefault);
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ FillDB(db_.get(), keys_iteration * j, keys_iteration * (j + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+ // Use checksum in the name as well
+ ++j;
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ FillDB(db_.get(), keys_iteration * j, keys_iteration * (j + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ // Verify first (about to delete)
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * (j + 1));
+
+ // For an extra challenge, make sure that GarbageCollect / DeleteBackup
+ // is OK even if we open without share_table_files but with
+ // share_files_with_checksum_naming based on kUseDbSessionId
+ ASSERT_TRUE(engine_options_->share_files_with_checksum_naming ==
+ kNamingDefault);
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ ASSERT_OK(backup_engine_->DeleteBackup(1));
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ CloseDBAndBackupEngine();
+
+ // Verify second (about to delete)
+ AssertBackupConsistency(2, 0, keys_iteration * 2, keys_iteration * (j + 1));
+
+ // Use checksum and file size for backup table file names and open without
+ // share_table_files
+ // Again, make sure that GarbageCollect / DeleteBackup is OK
+ engine_options_->share_files_with_checksum_naming = kLegacyCrc32cAndFileSize;
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ CloseDBAndBackupEngine();
+
+ // Verify rest (not deleted)
+ for (int i = 2; i < j; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * (j + 1));
+ }
+}
+
+// Verify backup and restore with share_files_with_checksum on and transition
+// from kLegacyCrc32cAndFileSize to kUseDbSessionId
+TEST_F(BackupEngineTest, ShareTableFilesWithChecksumsNewNamingUpgrade) {
+ engine_options_->share_files_with_checksum_naming = kLegacyCrc32cAndFileSize;
+ const int keys_iteration = 5000;
+ // set share_files_with_checksum to true
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ int j = 3;
+ for (int i = 0; i < j; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < j; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * (j + 1));
+ }
+
+ engine_options_->share_files_with_checksum_naming = kUseDbSessionId;
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ FillDB(db_.get(), keys_iteration * j, keys_iteration * (j + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ ++j;
+ options_.file_checksum_gen_factory = GetFileChecksumGenCrc32cFactory();
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ FillDB(db_.get(), keys_iteration * j, keys_iteration * (j + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ // Verify first (about to delete)
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * (j + 1));
+
+ // For an extra challenge, make sure that GarbageCollect / DeleteBackup
+ // is OK even if we open without share_table_files
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ ASSERT_OK(backup_engine_->DeleteBackup(1));
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ CloseDBAndBackupEngine();
+
+ // Verify second (about to delete)
+ AssertBackupConsistency(2, 0, keys_iteration * 2, keys_iteration * (j + 1));
+
+ // Use checksum and file size for backup table file names and open without
+ // share_table_files
+ // Again, make sure that GarbageCollect / DeleteBackup is OK
+ engine_options_->share_files_with_checksum_naming = kLegacyCrc32cAndFileSize;
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ CloseDBAndBackupEngine();
+
+ // Verify rest (not deleted)
+ for (int i = 2; i < j; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * (j + 1));
+ }
+}
+
+// This test simulates cleaning up after aborted or incomplete creation
+// of a new backup.
+TEST_F(BackupEngineTest, DeleteTmpFiles) {
+ for (int cleanup_fn : {1, 2, 3, 4}) {
+ for (ShareOption shared_option : kAllShareOptions) {
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false /* dummy */,
+ shared_option);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ BackupID next_id = 1;
+ BackupID oldest_id = std::numeric_limits<BackupID>::max();
+ {
+ std::vector<BackupInfo> backup_info;
+ backup_engine_->GetBackupInfo(&backup_info);
+ for (const auto& bi : backup_info) {
+ next_id = std::max(next_id, bi.backup_id + 1);
+ oldest_id = std::min(oldest_id, bi.backup_id);
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ // An aborted or incomplete new backup will always be in the next
+ // id (maybe more)
+ std::string next_private = "private/" + std::to_string(next_id);
+
+ // NOTE: both shared and shared_checksum should be cleaned up
+ // regardless of how the backup engine is opened.
+ std::vector<std::string> tmp_files_and_dirs;
+ for (const auto& dir_and_file : {
+ std::make_pair(std::string("shared"),
+ std::string(".00006.sst.tmp")),
+ std::make_pair(std::string("shared_checksum"),
+ std::string(".00007.sst.tmp")),
+ std::make_pair(next_private, std::string("00003.sst")),
+ }) {
+ std::string dir = backupdir_ + "/" + dir_and_file.first;
+ ASSERT_OK(file_manager_->CreateDirIfMissing(dir));
+ ASSERT_OK(file_manager_->FileExists(dir));
+
+ std::string file = dir + "/" + dir_and_file.second;
+ ASSERT_OK(file_manager_->WriteToFile(file, "tmp"));
+ ASSERT_OK(file_manager_->FileExists(file));
+
+ tmp_files_and_dirs.push_back(file);
+ }
+ if (cleanup_fn != /*CreateNewBackup*/ 4) {
+ // This exists after CreateNewBackup because it's deleted then
+ // re-created.
+ tmp_files_and_dirs.push_back(backupdir_ + "/" + next_private);
+ }
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false /* dummy */,
+ shared_option);
+ // Need to call one of these explicitly to delete tmp files
+ switch (cleanup_fn) {
+ case 1:
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ break;
+ case 2:
+ ASSERT_OK(backup_engine_->DeleteBackup(oldest_id));
+ break;
+ case 3:
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1));
+ break;
+ case 4:
+ // Does a garbage collect if it sees that next private dir exists
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ break;
+ default:
+ assert(false);
+ }
+ CloseDBAndBackupEngine();
+ for (std::string file_or_dir : tmp_files_and_dirs) {
+ if (file_manager_->FileExists(file_or_dir) != Status::NotFound()) {
+ FAIL() << file_or_dir << " was expected to be deleted." << cleanup_fn;
+ }
+ }
+ }
+ }
+}
+
+TEST_F(BackupEngineTest, KeepLogFiles) {
+ engine_options_->backup_log_files = false;
+ // basically infinite
+ options_.WAL_ttl_seconds = 24 * 60 * 60;
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100, kFlushAll);
+ FillDB(db_.get(), 100, 200, kFlushAll);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ FillDB(db_.get(), 200, 300, kFlushAll);
+ FillDB(db_.get(), 300, 400, kFlushAll);
+ FillDB(db_.get(), 400, 500, kFlushAll);
+ CloseDBAndBackupEngine();
+
+ // all data should be there if we call with keep_log_files = true
+ AssertBackupConsistency(0, 0, 500, 600, true);
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+class BackupEngineRateLimitingTestWithParam
+ : public BackupEngineTest,
+ public testing::WithParamInterface<
+ std::tuple<bool /* make throttle */,
+ int /* 0 = single threaded, 1 = multi threaded*/,
+ std::pair<uint64_t, uint64_t> /* limits */>> {
+ public:
+ BackupEngineRateLimitingTestWithParam() {}
+};
+
+uint64_t const MB = 1024 * 1024;
+
+INSTANTIATE_TEST_CASE_P(
+ RateLimiting, BackupEngineRateLimitingTestWithParam,
+ ::testing::Values(std::make_tuple(false, 0, std::make_pair(1 * MB, 5 * MB)),
+ std::make_tuple(false, 0, std::make_pair(2 * MB, 3 * MB)),
+ std::make_tuple(false, 1, std::make_pair(1 * MB, 5 * MB)),
+ std::make_tuple(false, 1, std::make_pair(2 * MB, 3 * MB)),
+ std::make_tuple(true, 0, std::make_pair(1 * MB, 5 * MB)),
+ std::make_tuple(true, 0, std::make_pair(2 * MB, 3 * MB)),
+ std::make_tuple(true, 1, std::make_pair(1 * MB, 5 * MB)),
+ std::make_tuple(true, 1,
+ std::make_pair(2 * MB, 3 * MB))));
+
+TEST_P(BackupEngineRateLimitingTestWithParam, RateLimiting) {
+ size_t const kMicrosPerSec = 1000 * 1000LL;
+ const bool custom_rate_limiter = std::get<0>(GetParam());
+ // iter 0 -- single threaded
+ // iter 1 -- multi threaded
+ const int iter = std::get<1>(GetParam());
+ const std::pair<uint64_t, uint64_t> limit = std::get<2>(GetParam());
+ std::unique_ptr<Env> special_env(
+ new SpecialEnv(db_chroot_env_.get(), /*time_elapse_only_sleep*/ true));
+ // destroy old data
+ Options options;
+ options.env = special_env.get();
+ DestroyDBWithoutCheck(dbname_, options);
+
+ if (custom_rate_limiter) {
+ std::shared_ptr<RateLimiter> backup_rate_limiter =
+ std::make_shared<GenericRateLimiter>(
+ limit.first, 100 * 1000 /* refill_period_us */, 10 /* fairness */,
+ RateLimiter::Mode::kWritesOnly /* mode */,
+ special_env->GetSystemClock(), false /* auto_tuned */);
+ std::shared_ptr<RateLimiter> restore_rate_limiter =
+ std::make_shared<GenericRateLimiter>(
+ limit.second, 100 * 1000 /* refill_period_us */, 10 /* fairness */,
+ RateLimiter::Mode::kWritesOnly /* mode */,
+ special_env->GetSystemClock(), false /* auto_tuned */);
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+ engine_options_->restore_rate_limiter = restore_rate_limiter;
+ } else {
+ engine_options_->backup_rate_limit = limit.first;
+ engine_options_->restore_rate_limit = limit.second;
+ }
+
+ engine_options_->max_background_operations = (iter == 0) ? 1 : 10;
+ options_.compression = kNoCompression;
+
+ // Rate limiter uses `CondVar::TimedWait()`, which does not have access to the
+ // `Env` to advance its time according to the fake wait duration. The
+ // workaround is to install a callback that advance the `Env`'s mock time.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::Request:PostTimedWait", [&](void* arg) {
+ int64_t time_waited_us = *static_cast<int64_t*>(arg);
+ special_env->SleepForMicroseconds(static_cast<int>(time_waited_us));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ OpenDBAndBackupEngine(true);
+ TEST_SetDefaultRateLimitersClock(backup_engine_.get(),
+ special_env->GetSystemClock());
+
+ size_t bytes_written = FillDB(db_.get(), 0, 10000);
+
+ auto start_backup = special_env->NowMicros();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ auto backup_time = special_env->NowMicros() - start_backup;
+ CloseDBAndBackupEngine();
+ auto rate_limited_backup_time = (bytes_written * kMicrosPerSec) / limit.first;
+ ASSERT_GT(backup_time, 0.8 * rate_limited_backup_time);
+
+ OpenBackupEngine();
+ TEST_SetDefaultRateLimitersClock(
+ backup_engine_.get(),
+ special_env->GetSystemClock() /* backup_rate_limiter_clock */,
+ special_env->GetSystemClock() /* restore_rate_limiter_clock */);
+
+ auto start_restore = special_env->NowMicros();
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ auto restore_time = special_env->NowMicros() - start_restore;
+ CloseBackupEngine();
+ auto rate_limited_restore_time =
+ (bytes_written * kMicrosPerSec) / limit.second;
+ ASSERT_GT(restore_time, 0.8 * rate_limited_restore_time);
+
+ AssertBackupConsistency(0, 0, 10000, 10100);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::Request:PostTimedWait");
+}
+
+TEST_P(BackupEngineRateLimitingTestWithParam, RateLimitingVerifyBackup) {
+ const std::size_t kMicrosPerSec = 1000 * 1000LL;
+ const bool custom_rate_limiter = std::get<0>(GetParam());
+ const std::uint64_t backup_rate_limiter_limit = std::get<2>(GetParam()).first;
+ const bool is_single_threaded = std::get<1>(GetParam()) == 0 ? true : false;
+ std::unique_ptr<Env> special_env(
+ new SpecialEnv(db_chroot_env_.get(), /*time_elapse_only_sleep*/ true));
+
+ if (custom_rate_limiter) {
+ std::shared_ptr<RateLimiter> backup_rate_limiter =
+ std::make_shared<GenericRateLimiter>(
+ backup_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */,
+ special_env->GetSystemClock(), false /* auto_tuned */);
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+ } else {
+ engine_options_->backup_rate_limit = backup_rate_limiter_limit;
+ }
+
+ engine_options_->max_background_operations = is_single_threaded ? 1 : 10;
+
+ Options options;
+ options.env = special_env.get();
+ DestroyDBWithoutCheck(dbname_, options);
+ // Rate limiter uses `CondVar::TimedWait()`, which does not have access to the
+ // `Env` to advance its time according to the fake wait duration. The
+ // workaround is to install a callback that advance the `Env`'s mock time.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::Request:PostTimedWait", [&](void* arg) {
+ int64_t time_waited_us = *static_cast<int64_t*>(arg);
+ special_env->SleepForMicroseconds(static_cast<int>(time_waited_us));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ TEST_SetDefaultRateLimitersClock(backup_engine_.get(),
+ special_env->GetSystemClock(), nullptr);
+ FillDB(db_.get(), 0, 10000);
+
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+
+ std::vector<BackupInfo> backup_infos;
+ BackupInfo backup_info;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ const int backup_id = 1;
+ ASSERT_EQ(backup_id, backup_infos[0].backup_id);
+ ASSERT_OK(backup_engine_->GetBackupInfo(backup_id, &backup_info,
+ true /* include_file_details */));
+
+ std::uint64_t bytes_read_during_verify_backup = 0;
+ for (BackupFileInfo backup_file_info : backup_info.file_details) {
+ bytes_read_during_verify_backup += backup_file_info.size;
+ }
+ auto start_verify_backup = special_env->NowMicros();
+ ASSERT_OK(
+ backup_engine_->VerifyBackup(backup_id, true /* verify_with_checksum */));
+ auto verify_backup_time = special_env->NowMicros() - start_verify_backup;
+ auto rate_limited_verify_backup_time =
+ (bytes_read_during_verify_backup * kMicrosPerSec) /
+ backup_rate_limiter_limit;
+ if (custom_rate_limiter) {
+ EXPECT_GE(verify_backup_time, 0.8 * rate_limited_verify_backup_time);
+ }
+
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(backup_id, 0, 10000, 10010);
+ DestroyDBWithoutCheck(dbname_, options);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::Request:PostTimedWait");
+}
+
+TEST_P(BackupEngineRateLimitingTestWithParam, RateLimitingChargeReadInBackup) {
+ bool is_single_threaded = std::get<1>(GetParam()) == 0 ? true : false;
+ engine_options_->max_background_operations = is_single_threaded ? 1 : 10;
+
+ const std::uint64_t backup_rate_limiter_limit = std::get<2>(GetParam()).first;
+ std::shared_ptr<RateLimiter> backup_rate_limiter(NewGenericRateLimiter(
+ backup_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kWritesOnly /* mode */));
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+
+ DestroyDBWithoutCheck(dbname_, Options());
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum /* shared_option */);
+ FillDB(db_.get(), 0, 10);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ std::int64_t total_bytes_through_with_no_read_charged =
+ backup_rate_limiter->GetTotalBytesThrough();
+ CloseBackupEngine();
+
+ backup_rate_limiter.reset(NewGenericRateLimiter(
+ backup_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */));
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+
+ OpenBackupEngine(true);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ std::int64_t total_bytes_through_with_read_charged =
+ backup_rate_limiter->GetTotalBytesThrough();
+ EXPECT_GT(total_bytes_through_with_read_charged,
+ total_bytes_through_with_no_read_charged);
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(1, 0, 10, 20);
+ DestroyDBWithoutCheck(dbname_, Options());
+}
+
+TEST_P(BackupEngineRateLimitingTestWithParam, RateLimitingChargeReadInRestore) {
+ bool is_single_threaded = std::get<1>(GetParam()) == 0 ? true : false;
+ engine_options_->max_background_operations = is_single_threaded ? 1 : 10;
+
+ const std::uint64_t restore_rate_limiter_limit =
+ std::get<2>(GetParam()).second;
+ std::shared_ptr<RateLimiter> restore_rate_limiter(NewGenericRateLimiter(
+ restore_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kWritesOnly /* mode */));
+ engine_options_->restore_rate_limiter = restore_rate_limiter;
+
+ DestroyDBWithoutCheck(dbname_, Options());
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, 10);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, Options());
+
+ OpenBackupEngine(false /* destroy_old_data */);
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ std::int64_t total_bytes_through_with_no_read_charged =
+ restore_rate_limiter->GetTotalBytesThrough();
+ CloseBackupEngine();
+ DestroyDBWithoutCheck(dbname_, Options());
+
+ restore_rate_limiter.reset(NewGenericRateLimiter(
+ restore_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */));
+ engine_options_->restore_rate_limiter = restore_rate_limiter;
+
+ OpenBackupEngine(false /* destroy_old_data */);
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ std::int64_t total_bytes_through_with_read_charged =
+ restore_rate_limiter->GetTotalBytesThrough();
+ EXPECT_EQ(total_bytes_through_with_read_charged,
+ total_bytes_through_with_no_read_charged * 2);
+ CloseBackupEngine();
+ AssertBackupConsistency(1, 0, 10, 20);
+ DestroyDBWithoutCheck(dbname_, Options());
+}
+
+TEST_P(BackupEngineRateLimitingTestWithParam,
+ RateLimitingChargeReadInInitialize) {
+ bool is_single_threaded = std::get<1>(GetParam()) == 0 ? true : false;
+ engine_options_->max_background_operations = is_single_threaded ? 1 : 10;
+
+ const std::uint64_t backup_rate_limiter_limit = std::get<2>(GetParam()).first;
+ std::shared_ptr<RateLimiter> backup_rate_limiter(NewGenericRateLimiter(
+ backup_rate_limiter_limit, 100 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */));
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+
+ DestroyDBWithoutCheck(dbname_, Options());
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, 10);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(1, 0, 10, 20);
+
+ std::int64_t total_bytes_through_before_initialize =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ OpenDBAndBackupEngine(false /* destroy_old_data */);
+ // We charge read in BackupEngineImpl::BackupMeta::LoadFromFile,
+ // which is called in BackupEngineImpl::Initialize() during
+ // OpenBackupEngine(false)
+ EXPECT_GT(engine_options_->backup_rate_limiter->GetTotalBytesThrough(),
+ total_bytes_through_before_initialize);
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, Options());
+}
+
+class BackupEngineRateLimitingTestWithParam2
+ : public BackupEngineTest,
+ public testing::WithParamInterface<
+ std::tuple<std::pair<uint64_t, uint64_t> /* limits */>> {
+ public:
+ BackupEngineRateLimitingTestWithParam2() {}
+};
+
+INSTANTIATE_TEST_CASE_P(
+ LowRefillBytesPerPeriod, BackupEngineRateLimitingTestWithParam2,
+ ::testing::Values(std::make_tuple(std::make_pair(1, 1))));
+// To verify we don't request over-sized bytes relative to
+// refill_bytes_per_period_ in each RateLimiter::Request() called in
+// BackupEngine through verifying we don't trigger assertion
+// failure on over-sized request in GenericRateLimiter in debug builds
+TEST_P(BackupEngineRateLimitingTestWithParam2,
+ RateLimitingWithLowRefillBytesPerPeriod) {
+ SpecialEnv special_env(Env::Default(), /*time_elapse_only_sleep*/ true);
+
+ engine_options_->max_background_operations = 1;
+ const uint64_t backup_rate_limiter_limit = std::get<0>(GetParam()).first;
+ std::shared_ptr<RateLimiter> backup_rate_limiter(
+ std::make_shared<GenericRateLimiter>(
+ backup_rate_limiter_limit, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */,
+ special_env.GetSystemClock(), false /* auto_tuned */));
+
+ engine_options_->backup_rate_limiter = backup_rate_limiter;
+
+ const uint64_t restore_rate_limiter_limit = std::get<0>(GetParam()).second;
+ std::shared_ptr<RateLimiter> restore_rate_limiter(
+ std::make_shared<GenericRateLimiter>(
+ restore_rate_limiter_limit, 1000 * 1000 /* refill_period_us */,
+ 10 /* fairness */, RateLimiter::Mode::kAllIo /* mode */,
+ special_env.GetSystemClock(), false /* auto_tuned */));
+
+ engine_options_->restore_rate_limiter = restore_rate_limiter;
+
+ // Rate limiter uses `CondVar::TimedWait()`, which does not have access to the
+ // `Env` to advance its time according to the fake wait duration. The
+ // workaround is to install a callback that advance the `Env`'s mock time.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GenericRateLimiter::Request:PostTimedWait", [&](void* arg) {
+ int64_t time_waited_us = *static_cast<int64_t*>(arg);
+ special_env.SleepForMicroseconds(static_cast<int>(time_waited_us));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ DestroyDBWithoutCheck(dbname_, Options());
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum /* shared_option */);
+
+ FillDB(db_.get(), 0, 100);
+ int64_t total_bytes_through_before_backup =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ EXPECT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ int64_t total_bytes_through_after_backup =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ ASSERT_GT(total_bytes_through_after_backup,
+ total_bytes_through_before_backup);
+
+ std::vector<BackupInfo> backup_infos;
+ BackupInfo backup_info;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ const int backup_id = 1;
+ ASSERT_EQ(backup_id, backup_infos[0].backup_id);
+ ASSERT_OK(backup_engine_->GetBackupInfo(backup_id, &backup_info,
+ true /* include_file_details */));
+ int64_t total_bytes_through_before_verify_backup =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ EXPECT_OK(
+ backup_engine_->VerifyBackup(backup_id, true /* verify_with_checksum */));
+ int64_t total_bytes_through_after_verify_backup =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ ASSERT_GT(total_bytes_through_after_verify_backup,
+ total_bytes_through_before_verify_backup);
+
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(backup_id, 0, 100, 101);
+
+ int64_t total_bytes_through_before_initialize =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ OpenDBAndBackupEngine(false /* destroy_old_data */);
+ // We charge read in BackupEngineImpl::BackupMeta::LoadFromFile,
+ // which is called in BackupEngineImpl::Initialize() during
+ // OpenBackupEngine(false)
+ int64_t total_bytes_through_after_initialize =
+ engine_options_->backup_rate_limiter->GetTotalBytesThrough();
+ ASSERT_GT(total_bytes_through_after_initialize,
+ total_bytes_through_before_initialize);
+ CloseDBAndBackupEngine();
+
+ DestroyDBWithoutCheck(dbname_, Options());
+ OpenBackupEngine(false /* destroy_old_data */);
+ int64_t total_bytes_through_before_restore =
+ engine_options_->restore_rate_limiter->GetTotalBytesThrough();
+ EXPECT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ int64_t total_bytes_through_after_restore =
+ engine_options_->restore_rate_limiter->GetTotalBytesThrough();
+ ASSERT_GT(total_bytes_through_after_restore,
+ total_bytes_through_before_restore);
+ CloseBackupEngine();
+
+ DestroyDBWithoutCheck(dbname_, Options());
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearCallBack(
+ "GenericRateLimiter::Request:PostTimedWait");
+}
+
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_F(BackupEngineTest, ReadOnlyBackupEngine) {
+ DestroyDBWithoutCheck(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ // Also test read-only DB with CreateNewBackup and flush=true (no flush)
+ CloseAndReopenDB(/*read_only*/ true);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), /*flush*/ true));
+ CloseAndReopenDB(/*read_only*/ false);
+ FillDB(db_.get(), 100, 200);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), /*flush*/ true));
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+
+ engine_options_->destroy_old_data = false;
+ test_backup_fs_->ClearWrittenFiles();
+ test_backup_fs_->SetLimitDeleteFiles(0);
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(db_chroot_env_.get(), *engine_options_,
+ &read_only_backup_engine));
+ std::vector<BackupInfo> backup_info;
+ read_only_backup_engine->GetBackupInfo(&backup_info);
+ ASSERT_EQ(backup_info.size(), 2U);
+
+ RestoreOptions restore_options(false);
+ ASSERT_OK(read_only_backup_engine->RestoreDBFromLatestBackup(
+ dbname_, dbname_, restore_options));
+ delete read_only_backup_engine;
+ std::vector<std::string> should_have_written;
+ test_backup_fs_->AssertWrittenFiles(should_have_written);
+
+ DB* db = OpenDB();
+ AssertExists(db, 0, 200);
+ delete db;
+}
+
+TEST_F(BackupEngineTest, OpenBackupAsReadOnlyDB) {
+ DestroyDBWithoutCheck(dbname_, options_);
+ options_.write_dbid_to_manifest = false;
+
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), /*flush*/ false));
+
+ options_.write_dbid_to_manifest = true; // exercises some read-only DB code
+ CloseAndReopenDB();
+
+ FillDB(db_.get(), 100, 200);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), /*flush*/ false));
+ db_.reset(); // CloseDB
+ DestroyDBWithoutCheck(dbname_, options_);
+ BackupInfo backup_info;
+ // First, check that we get empty fields without include_file_details
+ ASSERT_OK(backup_engine_->GetBackupInfo(/*id*/ 1U, &backup_info,
+ /*with file details*/ false));
+ ASSERT_EQ(backup_info.name_for_open, "");
+ ASSERT_FALSE(backup_info.env_for_open);
+
+ // Now for the real test
+ backup_info = BackupInfo();
+ ASSERT_OK(backup_engine_->GetBackupInfo(/*id*/ 1U, &backup_info,
+ /*with file details*/ true));
+
+ // Caution: DBOptions only holds a raw pointer to Env, so something else
+ // must keep it alive.
+ // Case 1: Keeping BackupEngine open suffices to keep Env alive
+ DB* db = nullptr;
+ Options opts = options_;
+ // Ensure some key defaults are set
+ opts.wal_dir = "";
+ opts.create_if_missing = false;
+ opts.info_log.reset();
+
+ opts.env = backup_info.env_for_open.get();
+ std::string name = backup_info.name_for_open;
+ backup_info = BackupInfo();
+ ASSERT_OK(DB::OpenForReadOnly(opts, name, &db));
+
+ AssertExists(db, 0, 100);
+ AssertEmpty(db, 100, 200);
+
+ delete db;
+ db = nullptr;
+
+ // Case 2: Keeping BackupInfo alive rather than BackupEngine also suffices
+ ASSERT_OK(backup_engine_->GetBackupInfo(/*id*/ 2U, &backup_info,
+ /*with file details*/ true));
+ CloseBackupEngine();
+ opts.create_if_missing = true; // check also OK (though pointless)
+ opts.env = backup_info.env_for_open.get();
+ name = backup_info.name_for_open;
+ // Note: keeping backup_info alive
+ ASSERT_OK(DB::OpenForReadOnly(opts, name, &db));
+
+ AssertExists(db, 0, 200);
+ delete db;
+ db = nullptr;
+
+ // Now try opening read-write and make sure it fails, for safety.
+ ASSERT_TRUE(DB::Open(opts, name, &db).IsIOError());
+}
+
+TEST_F(BackupEngineTest, ProgressCallbackDuringBackup) {
+ DestroyDBWithoutCheck(dbname_, options_);
+ // Too big for this small DB
+ engine_options_->callback_trigger_interval_size = 100000;
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ bool is_callback_invoked = false;
+ ASSERT_OK(backup_engine_->CreateNewBackup(
+ db_.get(), true,
+ [&is_callback_invoked]() { is_callback_invoked = true; }));
+ ASSERT_FALSE(is_callback_invoked);
+ CloseBackupEngine();
+
+ // Easily small enough for this small DB
+ engine_options_->callback_trigger_interval_size = 1000;
+ OpenBackupEngine();
+ ASSERT_OK(backup_engine_->CreateNewBackup(
+ db_.get(), true,
+ [&is_callback_invoked]() { is_callback_invoked = true; }));
+ ASSERT_TRUE(is_callback_invoked);
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, GarbageCollectionBeforeBackup) {
+ DestroyDBWithoutCheck(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+
+ ASSERT_OK(backup_chroot_env_->CreateDirIfMissing(backupdir_ + "/shared"));
+ std::string file_five = backupdir_ + "/shared/000009.sst";
+ std::string file_five_contents = "I'm not really a sst file";
+ // this depends on the fact that 00009.sst is the first file created by the DB
+ ASSERT_OK(file_manager_->WriteToFile(file_five, file_five_contents));
+
+ FillDB(db_.get(), 0, 100);
+ // backup overwrites file 000009.sst
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+
+ std::string new_file_five_contents;
+ ASSERT_OK(ReadFileToString(backup_chroot_env_.get(), file_five,
+ &new_file_five_contents));
+ // file 000009.sst was overwritten
+ ASSERT_TRUE(new_file_five_contents != file_five_contents);
+
+ CloseDBAndBackupEngine();
+
+ AssertBackupConsistency(0, 0, 100);
+}
+
+// Test that we properly propagate Env failures
+TEST_F(BackupEngineTest, EnvFailures) {
+ BackupEngine* backup_engine;
+
+ // get children failure
+ {
+ test_backup_fs_->SetGetChildrenFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ test_backup_fs_->SetGetChildrenFailure(false);
+ }
+
+ // created dir failure
+ {
+ test_backup_fs_->SetCreateDirIfMissingFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ test_backup_fs_->SetCreateDirIfMissingFailure(false);
+ }
+
+ // new directory failure
+ {
+ test_backup_fs_->SetNewDirectoryFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ test_backup_fs_->SetNewDirectoryFailure(false);
+ }
+
+ // Read from meta-file failure
+ {
+ DestroyDBWithoutCheck(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+ test_backup_fs_->SetDummySequentialFile(true);
+ test_backup_fs_->SetDummySequentialFileFailReads(true);
+ engine_options_->destroy_old_data = false;
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ test_backup_fs_->SetDummySequentialFile(false);
+ test_backup_fs_->SetDummySequentialFileFailReads(false);
+ }
+
+ // no failure
+ {
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), *engine_options_,
+ &backup_engine));
+ delete backup_engine;
+ }
+}
+
+// Verify manifest can roll while a backup is being created with the old
+// manifest.
+TEST_F(BackupEngineTest, ChangeManifestDuringBackupCreation) {
+ DestroyDBWithoutCheck(dbname_, options_);
+ options_.max_manifest_file_size = 0; // always rollover manifest for file add
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100, kAutoFlushOnly);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "VersionSet::LogAndApply:WriteManifest"},
+ {"VersionSet::LogAndApply:WriteManifestDone",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread flush_thread{
+ [this]() { ASSERT_OK(db_->Flush(FlushOptions())); }};
+
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+
+ flush_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ // The last manifest roll would've already been cleaned up by the full scan
+ // that happens when CreateNewBackup invokes EnableFileDeletions. We need to
+ // trigger another roll to verify non-full scan purges stale manifests.
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db_.get());
+ std::string prev_manifest_path =
+ DescriptorFileName(dbname_, db_impl->TEST_Current_Manifest_FileNo());
+ FillDB(db_.get(), 0, 100, kAutoFlushOnly);
+ ASSERT_OK(db_chroot_env_->FileExists(prev_manifest_path));
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ // Even though manual flush completed above, the background thread may not
+ // have finished its cleanup work. `TEST_WaitForBackgroundWork()` will wait
+ // until all the background thread's work has completed, including cleanup.
+ ASSERT_OK(db_impl->TEST_WaitForBackgroundWork());
+ ASSERT_TRUE(db_chroot_env_->FileExists(prev_manifest_path).IsNotFound());
+
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+ AssertBackupConsistency(0, 0, 100);
+}
+
+// see https://github.com/facebook/rocksdb/issues/921
+TEST_F(BackupEngineTest, Issue921Test) {
+ BackupEngine* backup_engine;
+ engine_options_->share_table_files = false;
+ ASSERT_OK(
+ backup_chroot_env_->CreateDirIfMissing(engine_options_->backup_dir));
+ engine_options_->backup_dir += "/new_dir";
+ ASSERT_OK(BackupEngine::Open(backup_chroot_env_.get(), *engine_options_,
+ &backup_engine));
+
+ delete backup_engine;
+}
+
+TEST_F(BackupEngineTest, BackupWithMetadata) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ const std::string metadata = std::to_string(i);
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ // Here also test CreateNewBackupWithMetadata with CreateBackupOptions
+ // and outputting saved BackupID.
+ CreateBackupOptions opts;
+ opts.flush_before_backup = true;
+ BackupID new_id = 0;
+ ASSERT_OK(backup_engine_->CreateNewBackupWithMetadata(opts, db_.get(),
+ metadata, &new_id));
+ ASSERT_EQ(new_id, static_cast<BackupID>(i + 1));
+ }
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ { // Verify in bulk BackupInfo
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(5, backup_infos.size());
+ for (int i = 0; i < 5; i++) {
+ ASSERT_EQ(std::to_string(i), backup_infos[i].app_metadata);
+ }
+ }
+ // Also verify in individual BackupInfo
+ for (int i = 0; i < 5; i++) {
+ BackupInfo backup_info;
+ ASSERT_OK(backup_engine_->GetBackupInfo(static_cast<BackupID>(i + 1),
+ &backup_info));
+ ASSERT_EQ(std::to_string(i), backup_info.app_metadata);
+ }
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, BinaryMetadata) {
+ OpenDBAndBackupEngine(true);
+ std::string binaryMetadata = "abc\ndef";
+ binaryMetadata.push_back('\0');
+ binaryMetadata.append("ghi");
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), binaryMetadata));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ ASSERT_EQ(binaryMetadata, backup_infos[0].app_metadata);
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, MetadataTooLarge) {
+ OpenDBAndBackupEngine(true);
+ std::string largeMetadata(1024 * 1024 + 1, 0);
+ ASSERT_NOK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), largeMetadata));
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, MetaSchemaVersion2_SizeCorruption) {
+ engine_options_->schema_version = 1;
+ OpenDBAndBackupEngine(/*destroy_old_data*/ true);
+
+ // Backup 1: no future schema, no sizes, with checksums
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ CloseDBAndBackupEngine();
+ engine_options_->schema_version = 2;
+ OpenDBAndBackupEngine(/*destroy_old_data*/ false);
+
+ // Backup 2: no checksums, no sizes
+ TEST_BackupMetaSchemaOptions test_opts;
+ test_opts.crc32c_checksums = false;
+ test_opts.file_sizes = false;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ // Backup 3: no checksums, with sizes
+ test_opts.file_sizes = true;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ // Backup 4: with checksums and sizes
+ test_opts.crc32c_checksums = true;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ CloseDBAndBackupEngine();
+
+ // Corrupt all the CURRENT files with the wrong size
+ const std::string private_dir = backupdir_ + "/private";
+
+ for (int id = 1; id <= 3; ++id) {
+ ASSERT_OK(file_manager_->WriteToFile(
+ private_dir + "/" + std::to_string(id) + "/CURRENT", "x"));
+ }
+ // Except corrupt Backup 4 with same size CURRENT file
+ {
+ uint64_t size = 0;
+ ASSERT_OK(test_backup_env_->GetFileSize(private_dir + "/4/CURRENT", &size));
+ ASSERT_OK(file_manager_->WriteToFile(private_dir + "/4/CURRENT",
+ std::string(size, 'x')));
+ }
+
+ OpenBackupEngine();
+
+ // Only the one with sizes in metadata will be immediately detected
+ // as corrupt
+ std::vector<BackupID> corrupted;
+ backup_engine_->GetCorruptedBackups(&corrupted);
+ ASSERT_EQ(corrupted.size(), 1);
+ ASSERT_EQ(corrupted[0], 3);
+
+ // Size corruption detected on Restore with checksum
+ ASSERT_TRUE(backup_engine_->RestoreDBFromBackup(1 /*id*/, dbname_, dbname_)
+ .IsCorruption());
+
+ // Size corruption not detected without checksums nor sizes
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(2 /*id*/, dbname_, dbname_));
+
+ // Non-size corruption detected on Restore with checksum
+ ASSERT_TRUE(backup_engine_->RestoreDBFromBackup(4 /*id*/, dbname_, dbname_)
+ .IsCorruption());
+
+ CloseBackupEngine();
+}
+
+TEST_F(BackupEngineTest, MetaSchemaVersion2_NotSupported) {
+ engine_options_->schema_version = 2;
+ TEST_BackupMetaSchemaOptions test_opts;
+ std::string app_metadata = "abc\ndef";
+
+ OpenDBAndBackupEngine(true);
+ // Start with supported
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+
+ // Because we are injecting badness with a TEST API, the badness is only
+ // detected on attempt to restore.
+ // Not supported versions
+ test_opts.version = "3";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+ test_opts.version = "23.45.67";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+ test_opts.version = "2";
+
+ // Non-ignorable fields
+ test_opts.meta_fields["ni::blah"] = "123";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+ test_opts.meta_fields.clear();
+
+ test_opts.file_fields["ni::123"] = "xyz";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+ test_opts.file_fields.clear();
+
+ test_opts.footer_fields["ni::123"] = "xyz";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), app_metadata));
+ test_opts.footer_fields.clear();
+ CloseDBAndBackupEngine();
+
+ OpenBackupEngine();
+ std::vector<BackupID> corrupted;
+ backup_engine_->GetCorruptedBackups(&corrupted);
+ ASSERT_EQ(corrupted.size(), 5);
+
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ CloseBackupEngine();
+}
+
+TEST_F(BackupEngineTest, MetaSchemaVersion2_Restore) {
+ engine_options_->schema_version = 2;
+ TEST_BackupMetaSchemaOptions test_opts;
+ const int keys_iteration = 5000;
+
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ FillDB(db_.get(), 0, keys_iteration);
+ // Start with minimum metadata to ensure it works without it being filled
+ // based on shared files also in other backups with the metadata.
+ test_opts.crc32c_checksums = false;
+ test_opts.file_sizes = false;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ AssertBackupConsistency(1 /* id */, 0, keys_iteration, keys_iteration * 2);
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ test_opts.file_sizes = true;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ for (int id = 1; id <= 2; ++id) {
+ AssertBackupConsistency(id, 0, keys_iteration, keys_iteration * 2);
+ }
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ test_opts.crc32c_checksums = true;
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ for (int id = 1; id <= 3; ++id) {
+ AssertBackupConsistency(id, 0, keys_iteration, keys_iteration * 2);
+ }
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ // No TEST_EnableWriteFutureSchemaVersion2
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ for (int id = 1; id <= 4; ++id) {
+ AssertBackupConsistency(id, 0, keys_iteration, keys_iteration * 2);
+ }
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ // Minor version updates should be forward-compatible
+ test_opts.version = "2.5.70";
+ test_opts.meta_fields["asdf.3456"] = "-42";
+ test_opts.meta_fields["__QRST"] = " 1 $ %%& ";
+ test_opts.file_fields["z94._"] = "^\\";
+ test_opts.file_fields["_7yyyyyyyyy"] = "111111111111";
+ test_opts.footer_fields["Qwzn.tz89"] = "ASDF!!@# ##=\t ";
+ test_opts.footer_fields["yes"] = "no!";
+ TEST_SetBackupMetaSchemaOptions(backup_engine_.get(), test_opts);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ for (int id = 1; id <= 5; ++id) {
+ AssertBackupConsistency(id, 0, keys_iteration, keys_iteration * 2);
+ }
+}
+
+TEST_F(BackupEngineTest, Concurrency) {
+ // Check that we can simultaneously:
+ // * Run several read operations in different threads on a single
+ // BackupEngine object, and
+ // * With another BackupEngine object on the same
+ // backup_dir, run the same read operations in another thread, and
+ // * With yet another BackupEngine object on the same
+ // backup_dir, create two new backups in parallel threads.
+ //
+ // Because of the challenges of integrating this into db_stress,
+ // this is a non-deterministic mini-stress test here instead.
+
+ // To check for a race condition in handling buffer size based on byte
+ // burst limit, we need a (generous) rate limiter
+ std::shared_ptr<RateLimiter> limiter{NewGenericRateLimiter(1000000000)};
+ engine_options_->backup_rate_limiter = limiter;
+ engine_options_->restore_rate_limiter = limiter;
+
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+
+ static constexpr int keys_iteration = 5000;
+ FillDB(db_.get(), 0, keys_iteration);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ FillDB(db_.get(), keys_iteration, 2 * keys_iteration);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ static constexpr int max_factor = 3;
+ FillDB(db_.get(), 2 * keys_iteration, max_factor * keys_iteration);
+ // will create another backup soon...
+
+ Options db_opts = options_;
+ db_opts.wal_dir = "";
+ db_opts.create_if_missing = false;
+ BackupEngineOptions be_opts = *engine_options_;
+ be_opts.destroy_old_data = false;
+
+ std::mt19937 rng{std::random_device()()};
+
+ std::array<std::thread, 4> read_threads;
+ std::array<std::thread, 4> restore_verify_threads;
+ for (uint32_t i = 0; i < read_threads.size(); ++i) {
+ uint32_t sleep_micros = rng() % 100000;
+ read_threads[i] = std::thread([this, i, sleep_micros, &db_opts, &be_opts,
+ &restore_verify_threads, &limiter] {
+ test_db_env_->SleepForMicroseconds(sleep_micros);
+
+ // Whether to also re-open the BackupEngine, potentially seeing
+ // additional backups
+ bool reopen = i == 3;
+ // Whether we are going to restore "latest"
+ bool latest = i > 1;
+
+ BackupEngine* my_be;
+ if (reopen) {
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), be_opts, &my_be));
+ } else {
+ my_be = backup_engine_.get();
+ }
+
+ // Verify metadata (we don't receive updates from concurrently
+ // creating a new backup)
+ std::vector<BackupInfo> infos;
+ my_be->GetBackupInfo(&infos);
+ const uint32_t count = static_cast<uint32_t>(infos.size());
+ infos.clear();
+ if (reopen) {
+ ASSERT_GE(count, 2U);
+ ASSERT_LE(count, 4U);
+ fprintf(stderr, "Reopen saw %u backups\n", count);
+ } else {
+ ASSERT_EQ(count, 2U);
+ }
+ std::vector<BackupID> ids;
+ my_be->GetCorruptedBackups(&ids);
+ ASSERT_EQ(ids.size(), 0U);
+
+ // (Eventually, see below) Restore one of the backups, or "latest"
+ std::string restore_db_dir = dbname_ + "/restore" + std::to_string(i);
+ DestroyDir(test_db_env_.get(), restore_db_dir).PermitUncheckedError();
+ BackupID to_restore;
+ if (latest) {
+ to_restore = count;
+ } else {
+ to_restore = i + 1;
+ }
+
+ // Open restored DB to verify its contents, but test atomic restore
+ // by doing it async and ensuring we either get OK or InvalidArgument
+ restore_verify_threads[i] =
+ std::thread([this, &db_opts, restore_db_dir, to_restore] {
+ DB* restored;
+ Status s;
+ for (;;) {
+ s = DB::Open(db_opts, restore_db_dir, &restored);
+ if (s.IsInvalidArgument()) {
+ // Restore hasn't finished
+ test_db_env_->SleepForMicroseconds(1000);
+ continue;
+ } else {
+ // We should only get InvalidArgument if restore is
+ // incomplete, or OK if complete
+ ASSERT_OK(s);
+ break;
+ }
+ }
+ int factor = std::min(static_cast<int>(to_restore), max_factor);
+ AssertExists(restored, 0, factor * keys_iteration);
+ AssertEmpty(restored, factor * keys_iteration,
+ (factor + 1) * keys_iteration);
+ delete restored;
+ });
+
+ // (Ok now) Restore one of the backups, or "latest"
+ if (latest) {
+ ASSERT_OK(
+ my_be->RestoreDBFromLatestBackup(restore_db_dir, restore_db_dir));
+ } else {
+ ASSERT_OK(my_be->VerifyBackup(to_restore, true));
+ ASSERT_OK(my_be->RestoreDBFromBackup(to_restore, restore_db_dir,
+ restore_db_dir));
+ }
+
+ // Test for race condition in reconfiguring limiter
+ // FIXME: this could set to a different value in all threads, except
+ // GenericRateLimiter::SetBytesPerSecond has a write-write race
+ // reported by TSAN
+ if (i == 0) {
+ limiter->SetBytesPerSecond(2000000000);
+ }
+
+ // Re-verify metadata (we don't receive updates from concurrently
+ // creating a new backup)
+ my_be->GetBackupInfo(&infos);
+ ASSERT_EQ(infos.size(), count);
+ my_be->GetCorruptedBackups(&ids);
+ ASSERT_EQ(ids.size(), 0);
+ // fprintf(stderr, "Finished read thread\n");
+
+ if (reopen) {
+ delete my_be;
+ }
+ });
+ }
+
+ BackupEngine* alt_be;
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), be_opts, &alt_be));
+
+ std::array<std::thread, 2> append_threads;
+ for (unsigned i = 0; i < append_threads.size(); ++i) {
+ uint32_t sleep_micros = rng() % 100000;
+ append_threads[i] = std::thread([this, sleep_micros, alt_be] {
+ test_db_env_->SleepForMicroseconds(sleep_micros);
+ // WART: CreateNewBackup doesn't tell you the BackupID it just created,
+ // which is ugly for multithreaded setting.
+ // TODO: add delete backup also when that is added
+ ASSERT_OK(alt_be->CreateNewBackup(db_.get()));
+ // fprintf(stderr, "Finished append thread\n");
+ });
+ }
+
+ for (auto& t : append_threads) {
+ t.join();
+ }
+ // Verify metadata
+ std::vector<BackupInfo> infos;
+ alt_be->GetBackupInfo(&infos);
+ ASSERT_EQ(infos.size(), 2 + append_threads.size());
+
+ for (auto& t : read_threads) {
+ t.join();
+ }
+
+ delete alt_be;
+
+ for (auto& t : restore_verify_threads) {
+ t.join();
+ }
+
+ CloseDBAndBackupEngine();
+}
+
+TEST_F(BackupEngineTest, LimitBackupsOpened) {
+ // Verify the specified max backups are opened, including skipping over
+ // corrupted backups.
+ //
+ // Setup:
+ // - backups 1, 2, and 4 are valid
+ // - backup 3 is corrupt
+ // - max_valid_backups_to_open == 2
+ //
+ // Expectation: the engine opens backups 4 and 2 since those are latest two
+ // non-corrupt backups.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i <= 4; ++i) {
+ FillDB(db_.get(), kNumKeys * i, kNumKeys * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ if (i == 3) {
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/3", 3));
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ engine_options_->max_valid_backups_to_open = 2;
+ engine_options_->destroy_old_data = false;
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(
+ backup_chroot_env_.get(), *engine_options_, &read_only_backup_engine));
+
+ std::vector<BackupInfo> backup_infos;
+ read_only_backup_engine->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(2, backup_infos.size());
+ ASSERT_EQ(2, backup_infos[0].backup_id);
+ ASSERT_EQ(4, backup_infos[1].backup_id);
+ delete read_only_backup_engine;
+}
+
+TEST_F(BackupEngineTest, IgnoreLimitBackupsOpenedWhenNotReadOnly) {
+ // Verify the specified max_valid_backups_to_open is ignored if the engine
+ // is not read-only.
+ //
+ // Setup:
+ // - backups 1, 2, and 4 are valid
+ // - backup 3 is corrupt
+ // - max_valid_backups_to_open == 2
+ //
+ // Expectation: the engine opens backups 4, 2, and 1 since those are latest
+ // non-corrupt backups, by ignoring max_valid_backups_to_open == 2.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i <= 4; ++i) {
+ FillDB(db_.get(), kNumKeys * i, kNumKeys * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ if (i == 3) {
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/3", 3));
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ engine_options_->max_valid_backups_to_open = 2;
+ OpenDBAndBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(3, backup_infos.size());
+ ASSERT_EQ(1, backup_infos[0].backup_id);
+ ASSERT_EQ(2, backup_infos[1].backup_id);
+ ASSERT_EQ(4, backup_infos[2].backup_id);
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+TEST_F(BackupEngineTest, CreateWhenLatestBackupCorrupted) {
+ // we should pick an ID greater than corrupted backups' IDs so creation can
+ // succeed even when latest backup is corrupted.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ BackupInfo backup_info;
+ ASSERT_TRUE(backup_engine_->GetLatestBackupInfo(&backup_info).IsNotFound());
+ FillDB(db_.get(), 0 /* from */, kNumKeys);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ true /* flush_before_backup */));
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/1",
+ 3 /* bytes_to_corrupt */));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ ASSERT_TRUE(backup_engine_->GetLatestBackupInfo(&backup_info).IsNotFound());
+
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ true /* flush_before_backup */));
+
+ ASSERT_TRUE(backup_engine_->GetLatestBackupInfo(&backup_info).ok());
+ ASSERT_EQ(2, backup_info.backup_id);
+
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ ASSERT_EQ(2, backup_infos[0].backup_id);
+
+ // Verify individual GetBackupInfo by ID
+ ASSERT_TRUE(backup_engine_->GetBackupInfo(0U, &backup_info).IsNotFound());
+ ASSERT_TRUE(backup_engine_->GetBackupInfo(1U, &backup_info).IsCorruption());
+ ASSERT_TRUE(backup_engine_->GetBackupInfo(2U, &backup_info).ok());
+ ASSERT_TRUE(backup_engine_->GetBackupInfo(3U, &backup_info).IsNotFound());
+ ASSERT_TRUE(
+ backup_engine_->GetBackupInfo(999999U, &backup_info).IsNotFound());
+}
+
+TEST_F(BackupEngineTest, WriteOnlyEngineNoSharedFileDeletion) {
+ // Verifies a write-only BackupEngine does not delete files belonging to valid
+ // backups when GarbageCollect, PurgeOldBackups, or DeleteBackup are called.
+ const int kNumKeys = 5000;
+ for (int i = 0; i < 3; ++i) {
+ OpenDBAndBackupEngine(i == 0 /* destroy_old_data */);
+ FillDB(db_.get(), i * kNumKeys, (i + 1) * kNumKeys);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ engine_options_->max_valid_backups_to_open = 0;
+ OpenDBAndBackupEngine();
+ switch (i) {
+ case 0:
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ break;
+ case 1:
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1 /* num_backups_to_keep */));
+ break;
+ case 2:
+ ASSERT_OK(backup_engine_->DeleteBackup(2 /* backup_id */));
+ break;
+ default:
+ assert(false);
+ }
+ CloseDBAndBackupEngine();
+
+ engine_options_->max_valid_backups_to_open =
+ std::numeric_limits<int32_t>::max();
+ AssertBackupConsistency(i + 1, 0, (i + 1) * kNumKeys);
+ }
+}
+
+TEST_P(BackupEngineTestWithParam, BackupUsingDirectIO) {
+ // Tests direct I/O on the backup engine's reads and writes on the DB env and
+ // backup env
+ // We use ChrootEnv underneath so the below line checks for direct I/O support
+ // in the chroot directory, not the true filesystem root.
+ if (!test::IsDirectIOSupported(test_db_env_.get(), "/")) {
+ ROCKSDB_GTEST_SKIP("Test requires Direct I/O Support");
+ return;
+ }
+ const int kNumKeysPerBackup = 100;
+ const int kNumBackups = 3;
+ options_.use_direct_reads = true;
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ for (int i = 0; i < kNumBackups; ++i) {
+ FillDB(db_.get(), i * kNumKeysPerBackup /* from */,
+ (i + 1) * kNumKeysPerBackup /* to */, kFlushAll);
+
+ // Clear the file open counters and then do a bunch of backup engine ops.
+ // For all ops, files should be opened in direct mode.
+ test_backup_fs_->ClearFileOpenCounters();
+ test_db_fs_->ClearFileOpenCounters();
+ CloseBackupEngine();
+ OpenBackupEngine();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ ASSERT_OK(backup_engine_->VerifyBackup(i + 1));
+ CloseBackupEngine();
+ OpenBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(static_cast<size_t>(i + 1), backup_infos.size());
+
+ // Verify backup engine always opened files with direct I/O
+ ASSERT_EQ(0, test_db_fs_->num_writers());
+ ASSERT_GE(test_db_fs_->num_direct_rand_readers(), 0);
+ ASSERT_GT(test_db_fs_->num_direct_seq_readers(), 0);
+ // Currently the DB doesn't support reading WALs or manifest with direct
+ // I/O, so subtract two.
+ ASSERT_EQ(test_db_fs_->num_seq_readers() - 2,
+ test_db_fs_->num_direct_seq_readers());
+ ASSERT_EQ(test_db_fs_->num_rand_readers(),
+ test_db_fs_->num_direct_rand_readers());
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < kNumBackups; ++i) {
+ AssertBackupConsistency(i + 1 /* backup_id */,
+ i * kNumKeysPerBackup /* start_exist */,
+ (i + 1) * kNumKeysPerBackup /* end_exist */,
+ (i + 2) * kNumKeysPerBackup /* end */);
+ }
+}
+
+TEST_F(BackupEngineTest, BackgroundThreadCpuPriority) {
+ std::atomic<CpuPriority> priority(CpuPriority::kNormal);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "BackupEngineImpl::Initialize:SetCpuPriority", [&](void* new_priority) {
+ priority.store(*reinterpret_cast<CpuPriority*>(new_priority));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // 1 thread is easier to test, otherwise, we may not be sure which thread
+ // actually does the work during CreateNewBackup.
+ engine_options_->max_background_operations = 1;
+ OpenDBAndBackupEngine(true);
+
+ {
+ FillDB(db_.get(), 0, 100);
+
+ // by default, cpu priority is not changed.
+ CreateBackupOptions options;
+ ASSERT_OK(backup_engine_->CreateNewBackup(options, db_.get()));
+
+ ASSERT_EQ(priority, CpuPriority::kNormal);
+ }
+
+ {
+ FillDB(db_.get(), 101, 200);
+
+ // decrease cpu priority from normal to low.
+ CreateBackupOptions options;
+ options.decrease_background_thread_cpu_priority = true;
+ options.background_thread_cpu_priority = CpuPriority::kLow;
+ ASSERT_OK(backup_engine_->CreateNewBackup(options, db_.get()));
+
+ ASSERT_EQ(priority, CpuPriority::kLow);
+ }
+
+ {
+ FillDB(db_.get(), 201, 300);
+
+ // try to upgrade cpu priority back to normal,
+ // the priority should still low.
+ CreateBackupOptions options;
+ options.decrease_background_thread_cpu_priority = true;
+ options.background_thread_cpu_priority = CpuPriority::kNormal;
+ ASSERT_OK(backup_engine_->CreateNewBackup(options, db_.get()));
+
+ ASSERT_EQ(priority, CpuPriority::kLow);
+ }
+
+ {
+ FillDB(db_.get(), 301, 400);
+
+ // decrease cpu priority from low to idle.
+ CreateBackupOptions options;
+ options.decrease_background_thread_cpu_priority = true;
+ options.background_thread_cpu_priority = CpuPriority::kIdle;
+ ASSERT_OK(backup_engine_->CreateNewBackup(options, db_.get()));
+
+ ASSERT_EQ(priority, CpuPriority::kIdle);
+ }
+
+ {
+ FillDB(db_.get(), 301, 400);
+
+ // reset priority to later verify that it's not updated by SetCpuPriority.
+ priority = CpuPriority::kNormal;
+
+ // setting the same cpu priority won't call SetCpuPriority.
+ CreateBackupOptions options;
+ options.decrease_background_thread_cpu_priority = true;
+ options.background_thread_cpu_priority = CpuPriority::kIdle;
+
+ // Also check output backup_id with CreateNewBackup
+ BackupID new_id = 0;
+ ASSERT_OK(backup_engine_->CreateNewBackup(options, db_.get(), &new_id));
+ ASSERT_EQ(new_id, 5U);
+
+ ASSERT_EQ(priority, CpuPriority::kNormal);
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ CloseDBAndBackupEngine();
+ DestroyDBWithoutCheck(dbname_, options_);
+}
+
+// Populates `*total_size` with the size of all files under `backup_dir`.
+// We don't go through `BackupEngine` currently because it's hard to figure out
+// the metadata file size.
+Status GetSizeOfBackupFiles(FileSystem* backup_fs,
+ const std::string& backup_dir, size_t* total_size) {
+ *total_size = 0;
+ std::vector<std::string> dir_stack = {backup_dir};
+ Status s;
+ while (s.ok() && !dir_stack.empty()) {
+ std::string dir = std::move(dir_stack.back());
+ dir_stack.pop_back();
+ std::vector<std::string> children;
+ s = backup_fs->GetChildren(dir, IOOptions(), &children, nullptr /* dbg */);
+ for (size_t i = 0; s.ok() && i < children.size(); ++i) {
+ std::string path = dir + "/" + children[i];
+ bool is_dir;
+ s = backup_fs->IsDirectory(path, IOOptions(), &is_dir, nullptr /* dbg */);
+ uint64_t file_size = 0;
+ if (s.ok()) {
+ if (is_dir) {
+ dir_stack.emplace_back(std::move(path));
+ } else {
+ s = backup_fs->GetFileSize(path, IOOptions(), &file_size,
+ nullptr /* dbg */);
+ }
+ }
+ if (s.ok()) {
+ *total_size += file_size;
+ }
+ }
+ }
+ return s;
+}
+
+TEST_F(BackupEngineTest, IOStats) {
+ // Tests the `BACKUP_READ_BYTES` and `BACKUP_WRITE_BYTES` ticker stats have
+ // the expected values according to the files in the backups.
+
+ // These ticker stats are expected to be populated regardless of `PerfLevel`
+ // in user thread
+ SetPerfLevel(kDisable);
+
+ options_.statistics = CreateDBStatistics();
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum);
+
+ FillDB(db_.get(), 0 /* from */, 100 /* to */, kFlushMost);
+
+ ASSERT_EQ(0, options_.statistics->getTickerCount(BACKUP_READ_BYTES));
+ ASSERT_EQ(0, options_.statistics->getTickerCount(BACKUP_WRITE_BYTES));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+
+ size_t orig_backup_files_size;
+ ASSERT_OK(GetSizeOfBackupFiles(test_backup_env_->GetFileSystem().get(),
+ backupdir_, &orig_backup_files_size));
+ size_t expected_bytes_written = orig_backup_files_size;
+ ASSERT_EQ(expected_bytes_written,
+ options_.statistics->getTickerCount(BACKUP_WRITE_BYTES));
+ // Bytes read is more difficult to pin down since there are reads for many
+ // purposes other than creating file, like `GetSortedWalFiles()` to find first
+ // sequence number, or `CreateNewBackup()` thread to find SST file session ID.
+ // So we loosely require there are at least as many reads as needed for
+ // copying, but not as many as twice that.
+ ASSERT_GE(options_.statistics->getTickerCount(BACKUP_READ_BYTES),
+ expected_bytes_written);
+ ASSERT_LT(expected_bytes_written,
+ 2 * options_.statistics->getTickerCount(BACKUP_READ_BYTES));
+
+ FillDB(db_.get(), 100 /* from */, 200 /* to */, kFlushMost);
+
+ ASSERT_OK(options_.statistics->Reset());
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ size_t final_backup_files_size;
+ ASSERT_OK(GetSizeOfBackupFiles(test_backup_env_->GetFileSystem().get(),
+ backupdir_, &final_backup_files_size));
+ expected_bytes_written = final_backup_files_size - orig_backup_files_size;
+ ASSERT_EQ(expected_bytes_written,
+ options_.statistics->getTickerCount(BACKUP_WRITE_BYTES));
+ // See above for why these bounds were chosen.
+ ASSERT_GE(options_.statistics->getTickerCount(BACKUP_READ_BYTES),
+ expected_bytes_written);
+ ASSERT_LT(expected_bytes_written,
+ 2 * options_.statistics->getTickerCount(BACKUP_READ_BYTES));
+}
+
+TEST_F(BackupEngineTest, FileTemperatures) {
+ CloseDBAndBackupEngine();
+
+ // Required for recording+restoring temperatures
+ engine_options_->schema_version = 2;
+
+ // More file IO instrumentation
+ auto my_db_fs = std::make_shared<FileTemperatureTestFS>(db_chroot_fs_);
+ test_db_fs_ = std::make_shared<TestFs>(my_db_fs);
+ SetEnvsFromFileSystems();
+
+ // Use temperatures
+ options_.bottommost_temperature = Temperature::kWarm;
+ options_.level0_file_num_compaction_trigger = 2;
+ // set dynamic_level to true so the compaction would compact the data to the
+ // last level directly which will have the last_level_temperature
+ options_.level_compaction_dynamic_level_bytes = true;
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */, false /* dummy */,
+ kShareWithChecksum);
+
+ // generate a bottommost file (combined from 2) and a non-bottommost file
+ DBImpl* dbi = static_cast_with_check<DBImpl>(db_.get());
+ ASSERT_OK(db_->Put(WriteOptions(), "a", "val"));
+ ASSERT_OK(db_->Put(WriteOptions(), "c", "val"));
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ ASSERT_OK(db_->Put(WriteOptions(), "b", "val"));
+ ASSERT_OK(db_->Put(WriteOptions(), "d", "val"));
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ ASSERT_OK(dbi->TEST_WaitForCompact());
+ ASSERT_OK(db_->Put(WriteOptions(), "e", "val"));
+ ASSERT_OK(db_->Flush(FlushOptions()));
+
+ // Get temperatures from manifest
+ std::map<uint64_t, Temperature> manifest_temps;
+ std::map<Temperature, int> manifest_temp_counts;
+ {
+ std::vector<LiveFileStorageInfo> infos;
+ ASSERT_OK(
+ db_->GetLiveFilesStorageInfo(LiveFilesStorageInfoOptions(), &infos));
+ for (auto info : infos) {
+ if (info.file_type == kTableFile) {
+ manifest_temps.emplace(info.file_number, info.temperature);
+ manifest_temp_counts[info.temperature]++;
+ }
+ }
+ }
+
+ // Verify expected manifest temperatures
+ ASSERT_EQ(manifest_temp_counts.size(), 2);
+ ASSERT_EQ(manifest_temp_counts[Temperature::kWarm], 1);
+ ASSERT_EQ(manifest_temp_counts[Temperature::kUnknown], 1);
+
+ // Verify manifest temperatures match FS temperatures
+ std::map<uint64_t, Temperature> current_temps;
+ my_db_fs->CopyCurrentSstFileTemperatures(&current_temps);
+ for (const auto& manifest_temp : manifest_temps) {
+ ASSERT_EQ(current_temps[manifest_temp.first], manifest_temp.second);
+ }
+
+ // Try a few different things
+ for (int i = 1; i <= 5; ++i) {
+ // Expected temperatures after restore are based on manifest temperatures
+ std::map<uint64_t, Temperature> expected_temps = manifest_temps;
+
+ if (i >= 2) {
+ // For iterations 2 & 3, override current temperature of one file
+ // and vary which temperature is authoritative (current or manifest).
+ // For iterations 4 & 5, override current temperature of both files
+ // but make sure an current temperate always takes precedence over
+ // unknown regardless of current_temperatures_override_manifest setting.
+ bool use_current = ((i % 2) == 1);
+ engine_options_->current_temperatures_override_manifest = use_current;
+ CloseBackupEngine();
+ OpenBackupEngine();
+ for (const auto& manifest_temp : manifest_temps) {
+ if (i <= 3) {
+ if (manifest_temp.second == Temperature::kWarm) {
+ my_db_fs->OverrideSstFileTemperature(manifest_temp.first,
+ Temperature::kCold);
+ if (use_current) {
+ expected_temps[manifest_temp.first] = Temperature::kCold;
+ }
+ }
+ } else {
+ assert(i <= 5);
+ if (manifest_temp.second == Temperature::kWarm) {
+ my_db_fs->OverrideSstFileTemperature(manifest_temp.first,
+ Temperature::kUnknown);
+ } else {
+ ASSERT_EQ(manifest_temp.second, Temperature::kUnknown);
+ my_db_fs->OverrideSstFileTemperature(manifest_temp.first,
+ Temperature::kHot);
+ // regardless of use_current
+ expected_temps[manifest_temp.first] = Temperature::kHot;
+ }
+ }
+ }
+ }
+
+ // Sample requested temperatures in opening files for backup
+ my_db_fs->PopRequestedSstFileTemperatures();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+
+ // Verify requested temperatures against manifest temperatures (before
+ // retry with kUnknown if needed, and before backup finds out current
+ // temperatures in FileSystem)
+ std::vector<std::pair<uint64_t, Temperature>> requested_temps;
+ my_db_fs->PopRequestedSstFileTemperatures(&requested_temps);
+ std::set<uint64_t> distinct_requests;
+ for (const auto& requested_temp : requested_temps) {
+ // Matching manifest temperatures, except allow retry request with
+ // kUnknown
+ auto manifest_temp = manifest_temps.at(requested_temp.first);
+ if (manifest_temp == Temperature::kUnknown ||
+ requested_temp.second != Temperature::kUnknown) {
+ ASSERT_EQ(manifest_temp, requested_temp.second);
+ }
+ distinct_requests.insert(requested_temp.first);
+ }
+ // Two distinct requests
+ ASSERT_EQ(distinct_requests.size(), 2);
+
+ // Verify against backup info file details API
+ BackupInfo info;
+ ASSERT_OK(backup_engine_->GetLatestBackupInfo(
+ &info, /*include_file_details*/ true));
+ ASSERT_GT(info.file_details.size(), 2);
+ for (auto& e : info.file_details) {
+ ASSERT_EQ(expected_temps[e.file_number], e.temperature);
+ }
+
+ // Restore backup to another virtual (tiered) dir
+ const std::string restore_dir = "/restore" + std::to_string(i);
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(
+ RestoreOptions(), restore_dir, restore_dir));
+
+ // Verify restored FS temperatures match expectation
+ // (FileTemperatureTestFS doesn't distinguish directories when reporting
+ // current temperatures, just whatever SST was written or overridden last
+ // with that file number.)
+ my_db_fs->CopyCurrentSstFileTemperatures(&current_temps);
+ for (const auto& expected_temp : expected_temps) {
+ ASSERT_EQ(current_temps[expected_temp.first], expected_temp.second);
+ }
+
+ // Delete backup to force next backup to copy files
+ ASSERT_OK(backup_engine_->PurgeOldBackups(0));
+ }
+}
+
+} // namespace
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as BackupEngine is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !defined(ROCKSDB_LITE) && !defined(OS_WIN)
diff --git a/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc b/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc
new file mode 100644
index 000000000..86907e979
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc
@@ -0,0 +1,490 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_compaction_filter.h"
+
+#include <cinttypes>
+
+#include "db/dbformat.h"
+#include "logging/logging.h"
+#include "rocksdb/system_clock.h"
+#include "test_util/sync_point.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+BlobIndexCompactionFilterBase::~BlobIndexCompactionFilterBase() {
+ if (blob_file_) {
+ CloseAndRegisterNewBlobFile();
+ }
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EXPIRED_COUNT, expired_count_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EXPIRED_SIZE, expired_size_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EVICTED_COUNT, evicted_count_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EVICTED_SIZE, evicted_size_);
+}
+
+CompactionFilter::Decision BlobIndexCompactionFilterBase::FilterV2(
+ int level, const Slice& key, ValueType value_type, const Slice& value,
+ std::string* new_value, std::string* skip_until) const {
+ const CompactionFilter* ucf = user_comp_filter();
+ if (value_type != kBlobIndex) {
+ if (ucf == nullptr) {
+ return Decision::kKeep;
+ }
+ // Apply user compaction filter for inlined data.
+ CompactionFilter::Decision decision =
+ ucf->FilterV2(level, key, value_type, value, new_value, skip_until);
+ if (decision == Decision::kChangeValue) {
+ return HandleValueChange(key, new_value);
+ }
+ return decision;
+ }
+ BlobIndex blob_index;
+ Status s = blob_index.DecodeFrom(value);
+ if (!s.ok()) {
+ // Unable to decode blob index. Keeping the value.
+ return Decision::kKeep;
+ }
+ if (blob_index.HasTTL() && blob_index.expiration() <= current_time_) {
+ // Expired
+ expired_count_++;
+ expired_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ if (!blob_index.IsInlined() &&
+ blob_index.file_number() < context_.next_file_number &&
+ context_.current_blob_files.count(blob_index.file_number()) == 0) {
+ // Corresponding blob file gone (most likely, evicted by FIFO eviction).
+ evicted_count_++;
+ evicted_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ if (context_.fifo_eviction_seq > 0 && blob_index.HasTTL() &&
+ blob_index.expiration() < context_.evict_expiration_up_to) {
+ // Hack: Internal key is passed to BlobIndexCompactionFilter for it to
+ // get sequence number.
+ ParsedInternalKey ikey;
+ if (!ParseInternalKey(
+ key, &ikey,
+ context_.blob_db_impl->db_options_.allow_data_in_errors)
+ .ok()) {
+ assert(false);
+ return Decision::kKeep;
+ }
+ // Remove keys that could have been remove by last FIFO eviction.
+ // If get error while parsing key, ignore and continue.
+ if (ikey.sequence < context_.fifo_eviction_seq) {
+ evicted_count_++;
+ evicted_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ }
+ // Apply user compaction filter for all non-TTL blob data.
+ if (ucf != nullptr && !blob_index.HasTTL()) {
+ // Hack: Internal key is passed to BlobIndexCompactionFilter for it to
+ // get sequence number.
+ ParsedInternalKey ikey;
+ if (!ParseInternalKey(
+ key, &ikey,
+ context_.blob_db_impl->db_options_.allow_data_in_errors)
+ .ok()) {
+ assert(false);
+ return Decision::kKeep;
+ }
+ // Read value from blob file.
+ PinnableSlice blob;
+ CompressionType compression_type = kNoCompression;
+ constexpr bool need_decompress = true;
+ if (!ReadBlobFromOldFile(ikey.user_key, blob_index, &blob, need_decompress,
+ &compression_type)) {
+ return Decision::kIOError;
+ }
+ CompactionFilter::Decision decision = ucf->FilterV2(
+ level, ikey.user_key, kValue, blob, new_value, skip_until);
+ if (decision == Decision::kChangeValue) {
+ return HandleValueChange(ikey.user_key, new_value);
+ }
+ return decision;
+ }
+ return Decision::kKeep;
+}
+
+CompactionFilter::Decision BlobIndexCompactionFilterBase::HandleValueChange(
+ const Slice& key, std::string* new_value) const {
+ BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ if (new_value->size() < blob_db_impl->bdb_options_.min_blob_size) {
+ // Keep new_value inlined.
+ return Decision::kChangeValue;
+ }
+ if (!OpenNewBlobFileIfNeeded()) {
+ return Decision::kIOError;
+ }
+ Slice new_blob_value(*new_value);
+ std::string compression_output;
+ if (blob_db_impl->bdb_options_.compression != kNoCompression) {
+ new_blob_value =
+ blob_db_impl->GetCompressedSlice(new_blob_value, &compression_output);
+ }
+ uint64_t new_blob_file_number = 0;
+ uint64_t new_blob_offset = 0;
+ if (!WriteBlobToNewFile(key, new_blob_value, &new_blob_file_number,
+ &new_blob_offset)) {
+ return Decision::kIOError;
+ }
+ if (!CloseAndRegisterNewBlobFileIfNeeded()) {
+ return Decision::kIOError;
+ }
+ BlobIndex::EncodeBlob(new_value, new_blob_file_number, new_blob_offset,
+ new_blob_value.size(),
+ blob_db_impl->bdb_options_.compression);
+ return Decision::kChangeBlobIndex;
+}
+
+BlobIndexCompactionFilterGC::~BlobIndexCompactionFilterGC() {
+ assert(context().blob_db_impl);
+
+ ROCKS_LOG_INFO(context().blob_db_impl->db_options_.info_log,
+ "GC pass finished %s: encountered %" PRIu64 " blobs (%" PRIu64
+ " bytes), relocated %" PRIu64 " blobs (%" PRIu64
+ " bytes), created %" PRIu64 " new blob file(s)",
+ !gc_stats_.HasError() ? "successfully" : "with failure",
+ gc_stats_.AllBlobs(), gc_stats_.AllBytes(),
+ gc_stats_.RelocatedBlobs(), gc_stats_.RelocatedBytes(),
+ gc_stats_.NewFiles());
+
+ RecordTick(statistics(), BLOB_DB_GC_NUM_KEYS_RELOCATED,
+ gc_stats_.RelocatedBlobs());
+ RecordTick(statistics(), BLOB_DB_GC_BYTES_RELOCATED,
+ gc_stats_.RelocatedBytes());
+ RecordTick(statistics(), BLOB_DB_GC_NUM_NEW_FILES, gc_stats_.NewFiles());
+ RecordTick(statistics(), BLOB_DB_GC_FAILURES, gc_stats_.HasError());
+}
+
+bool BlobIndexCompactionFilterBase::IsBlobFileOpened() const {
+ if (blob_file_) {
+ assert(writer_);
+ return true;
+ }
+ return false;
+}
+
+bool BlobIndexCompactionFilterBase::OpenNewBlobFileIfNeeded() const {
+ if (IsBlobFileOpened()) {
+ return true;
+ }
+
+ BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ const Status s = blob_db_impl->CreateBlobFileAndWriter(
+ /* has_ttl */ false, ExpirationRange(), "compaction/GC", &blob_file_,
+ &writer_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ blob_db_impl->db_options_.info_log,
+ "Error opening new blob file during compaction/GC, status: %s",
+ s.ToString().c_str());
+ blob_file_.reset();
+ writer_.reset();
+ return false;
+ }
+
+ assert(blob_file_);
+ assert(writer_);
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterBase::ReadBlobFromOldFile(
+ const Slice& key, const BlobIndex& blob_index, PinnableSlice* blob,
+ bool need_decompress, CompressionType* compression_type) const {
+ BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ Status s = blob_db_impl->GetRawBlobFromFile(
+ key, blob_index.file_number(), blob_index.offset(), blob_index.size(),
+ blob, compression_type);
+
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ blob_db_impl->db_options_.info_log,
+ "Error reading blob during compaction/GC, key: %s (%s), status: %s",
+ key.ToString(/* output_hex */ true).c_str(),
+ blob_index.DebugString(/* output_hex */ true).c_str(),
+ s.ToString().c_str());
+
+ return false;
+ }
+
+ if (need_decompress && *compression_type != kNoCompression) {
+ s = blob_db_impl->DecompressSlice(*blob, *compression_type, blob);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ blob_db_impl->db_options_.info_log,
+ "Uncompression error during blob read from file: %" PRIu64
+ " blob_offset: %" PRIu64 " blob_size: %" PRIu64
+ " key: %s status: '%s'",
+ blob_index.file_number(), blob_index.offset(), blob_index.size(),
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterBase::WriteBlobToNewFile(
+ const Slice& key, const Slice& blob, uint64_t* new_blob_file_number,
+ uint64_t* new_blob_offset) const {
+ TEST_SYNC_POINT("BlobIndexCompactionFilterBase::WriteBlobToNewFile");
+ assert(new_blob_file_number);
+ assert(new_blob_offset);
+
+ assert(blob_file_);
+ *new_blob_file_number = blob_file_->BlobFileNumber();
+
+ assert(writer_);
+ uint64_t new_key_offset = 0;
+ const Status s = writer_->AddRecord(key, blob, kNoExpiration, &new_key_offset,
+ new_blob_offset);
+
+ if (!s.ok()) {
+ const BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ ROCKS_LOG_ERROR(blob_db_impl->db_options_.info_log,
+ "Error writing blob to new file %s during compaction/GC, "
+ "key: %s, status: %s",
+ blob_file_->PathName().c_str(),
+ key.ToString(/* output_hex */ true).c_str(),
+ s.ToString().c_str());
+ return false;
+ }
+
+ const uint64_t new_size =
+ BlobLogRecord::kHeaderSize + key.size() + blob.size();
+ blob_file_->BlobRecordAdded(new_size);
+
+ BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ blob_db_impl->total_blob_size_ += new_size;
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterBase::CloseAndRegisterNewBlobFileIfNeeded()
+ const {
+ const BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+
+ assert(blob_file_);
+ if (blob_file_->GetFileSize() < blob_db_impl->bdb_options_.blob_file_size) {
+ return true;
+ }
+
+ return CloseAndRegisterNewBlobFile();
+}
+
+bool BlobIndexCompactionFilterBase::CloseAndRegisterNewBlobFile() const {
+ BlobDBImpl* const blob_db_impl = context_.blob_db_impl;
+ assert(blob_db_impl);
+ assert(blob_file_);
+
+ Status s;
+
+ {
+ WriteLock wl(&blob_db_impl->mutex_);
+
+ s = blob_db_impl->CloseBlobFile(blob_file_);
+
+ // Note: we delay registering the new blob file until it's closed to
+ // prevent FIFO eviction from processing it during compaction/GC.
+ blob_db_impl->RegisterBlobFile(blob_file_);
+ }
+
+ assert(blob_file_->Immutable());
+
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ blob_db_impl->db_options_.info_log,
+ "Error closing new blob file %s during compaction/GC, status: %s",
+ blob_file_->PathName().c_str(), s.ToString().c_str());
+ }
+
+ blob_file_.reset();
+ return s.ok();
+}
+
+CompactionFilter::BlobDecision BlobIndexCompactionFilterGC::PrepareBlobOutput(
+ const Slice& key, const Slice& existing_value,
+ std::string* new_value) const {
+ assert(new_value);
+
+ const BlobDBImpl* const blob_db_impl = context().blob_db_impl;
+ (void)blob_db_impl;
+
+ assert(blob_db_impl);
+ assert(blob_db_impl->bdb_options_.enable_garbage_collection);
+
+ BlobIndex blob_index;
+ const Status s = blob_index.DecodeFrom(existing_value);
+ if (!s.ok()) {
+ gc_stats_.SetError();
+ return BlobDecision::kCorruption;
+ }
+
+ if (blob_index.IsInlined()) {
+ gc_stats_.AddBlob(blob_index.value().size());
+
+ return BlobDecision::kKeep;
+ }
+
+ gc_stats_.AddBlob(blob_index.size());
+
+ if (blob_index.HasTTL()) {
+ return BlobDecision::kKeep;
+ }
+
+ if (blob_index.file_number() >= context_gc_.cutoff_file_number) {
+ return BlobDecision::kKeep;
+ }
+
+ // Note: each compaction generates its own blob files, which, depending on the
+ // workload, might result in many small blob files. The total number of files
+ // is bounded though (determined by the number of compactions and the blob
+ // file size option).
+ if (!OpenNewBlobFileIfNeeded()) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ PinnableSlice blob;
+ CompressionType compression_type = kNoCompression;
+ std::string compression_output;
+ if (!ReadBlobFromOldFile(key, blob_index, &blob, false, &compression_type)) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ // If the compression_type is changed, re-compress it with the new compression
+ // type.
+ if (compression_type != blob_db_impl->bdb_options_.compression) {
+ if (compression_type != kNoCompression) {
+ const Status status =
+ blob_db_impl->DecompressSlice(blob, compression_type, &blob);
+ if (!status.ok()) {
+ gc_stats_.SetError();
+ return BlobDecision::kCorruption;
+ }
+ }
+ if (blob_db_impl->bdb_options_.compression != kNoCompression) {
+ blob_db_impl->GetCompressedSlice(blob, &compression_output);
+ blob = PinnableSlice(&compression_output);
+ blob.PinSelf();
+ }
+ }
+
+ uint64_t new_blob_file_number = 0;
+ uint64_t new_blob_offset = 0;
+ if (!WriteBlobToNewFile(key, blob, &new_blob_file_number, &new_blob_offset)) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ if (!CloseAndRegisterNewBlobFileIfNeeded()) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ BlobIndex::EncodeBlob(new_value, new_blob_file_number, new_blob_offset,
+ blob.size(), compression_type);
+
+ gc_stats_.AddRelocatedBlob(blob_index.size());
+
+ return BlobDecision::kChangeValue;
+}
+
+bool BlobIndexCompactionFilterGC::OpenNewBlobFileIfNeeded() const {
+ if (IsBlobFileOpened()) {
+ return true;
+ }
+ bool result = BlobIndexCompactionFilterBase::OpenNewBlobFileIfNeeded();
+ if (result) {
+ gc_stats_.AddNewFile();
+ }
+ return result;
+}
+
+std::unique_ptr<CompactionFilter>
+BlobIndexCompactionFilterFactoryBase::CreateUserCompactionFilterFromFactory(
+ const CompactionFilter::Context& context) const {
+ std::unique_ptr<CompactionFilter> user_comp_filter_from_factory;
+ if (user_comp_filter_factory_) {
+ user_comp_filter_from_factory =
+ user_comp_filter_factory_->CreateCompactionFilter(context);
+ }
+ return user_comp_filter_from_factory;
+}
+
+std::unique_ptr<CompactionFilter>
+BlobIndexCompactionFilterFactory::CreateCompactionFilter(
+ const CompactionFilter::Context& _context) {
+ assert(clock());
+
+ int64_t current_time = 0;
+ Status s = clock()->GetCurrentTime(&current_time);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ assert(current_time >= 0);
+
+ assert(blob_db_impl());
+
+ BlobCompactionContext context;
+ blob_db_impl()->GetCompactionContext(&context);
+
+ std::unique_ptr<CompactionFilter> user_comp_filter_from_factory =
+ CreateUserCompactionFilterFromFactory(_context);
+
+ return std::unique_ptr<CompactionFilter>(new BlobIndexCompactionFilter(
+ std::move(context), user_comp_filter(),
+ std::move(user_comp_filter_from_factory), current_time, statistics()));
+}
+
+std::unique_ptr<CompactionFilter>
+BlobIndexCompactionFilterFactoryGC::CreateCompactionFilter(
+ const CompactionFilter::Context& _context) {
+ assert(clock());
+
+ int64_t current_time = 0;
+ Status s = clock()->GetCurrentTime(&current_time);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ assert(current_time >= 0);
+
+ assert(blob_db_impl());
+
+ BlobCompactionContext context;
+ BlobCompactionContextGC context_gc;
+ blob_db_impl()->GetCompactionContext(&context, &context_gc);
+
+ std::unique_ptr<CompactionFilter> user_comp_filter_from_factory =
+ CreateUserCompactionFilterFromFactory(_context);
+
+ return std::unique_ptr<CompactionFilter>(new BlobIndexCompactionFilterGC(
+ std::move(context), std::move(context_gc), user_comp_filter(),
+ std::move(user_comp_filter_from_factory), current_time, statistics()));
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_compaction_filter.h b/src/rocksdb/utilities/blob_db/blob_compaction_filter.h
new file mode 100644
index 000000000..1493cfc1a
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_compaction_filter.h
@@ -0,0 +1,204 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <unordered_set>
+
+#include "db/blob/blob_index.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/compaction_filter.h"
+#include "utilities/blob_db/blob_db_gc_stats.h"
+#include "utilities/blob_db/blob_db_impl.h"
+#include "utilities/compaction_filters/layered_compaction_filter_base.h"
+
+namespace ROCKSDB_NAMESPACE {
+class SystemClock;
+namespace blob_db {
+
+struct BlobCompactionContext {
+ BlobDBImpl* blob_db_impl = nullptr;
+ uint64_t next_file_number = 0;
+ std::unordered_set<uint64_t> current_blob_files;
+ SequenceNumber fifo_eviction_seq = 0;
+ uint64_t evict_expiration_up_to = 0;
+};
+
+struct BlobCompactionContextGC {
+ uint64_t cutoff_file_number = 0;
+};
+
+// Compaction filter that deletes expired blob indexes from the base DB.
+// Comes into two varieties, one for the non-GC case and one for the GC case.
+class BlobIndexCompactionFilterBase : public LayeredCompactionFilterBase {
+ public:
+ BlobIndexCompactionFilterBase(
+ BlobCompactionContext&& _context,
+ const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter> _user_comp_filter_from_factory,
+ uint64_t current_time, Statistics* stats)
+ : LayeredCompactionFilterBase(_user_comp_filter,
+ std::move(_user_comp_filter_from_factory)),
+ context_(std::move(_context)),
+ current_time_(current_time),
+ statistics_(stats) {}
+
+ ~BlobIndexCompactionFilterBase() override;
+
+ // Filter expired blob indexes regardless of snapshots.
+ bool IgnoreSnapshots() const override { return true; }
+
+ Decision FilterV2(int level, const Slice& key, ValueType value_type,
+ const Slice& value, std::string* new_value,
+ std::string* skip_until) const override;
+
+ bool IsStackedBlobDbInternalCompactionFilter() const override { return true; }
+
+ protected:
+ bool IsBlobFileOpened() const;
+ virtual bool OpenNewBlobFileIfNeeded() const;
+ bool ReadBlobFromOldFile(const Slice& key, const BlobIndex& blob_index,
+ PinnableSlice* blob, bool need_decompress,
+ CompressionType* compression_type) const;
+ bool WriteBlobToNewFile(const Slice& key, const Slice& blob,
+ uint64_t* new_blob_file_number,
+ uint64_t* new_blob_offset) const;
+ bool CloseAndRegisterNewBlobFileIfNeeded() const;
+ bool CloseAndRegisterNewBlobFile() const;
+
+ Statistics* statistics() const { return statistics_; }
+ const BlobCompactionContext& context() const { return context_; }
+
+ private:
+ Decision HandleValueChange(const Slice& key, std::string* new_value) const;
+
+ private:
+ BlobCompactionContext context_;
+ const uint64_t current_time_;
+ Statistics* statistics_;
+
+ mutable std::shared_ptr<BlobFile> blob_file_;
+ mutable std::shared_ptr<BlobLogWriter> writer_;
+
+ // It is safe to not using std::atomic since the compaction filter, created
+ // from a compaction filter factroy, will not be called from multiple threads.
+ mutable uint64_t expired_count_ = 0;
+ mutable uint64_t expired_size_ = 0;
+ mutable uint64_t evicted_count_ = 0;
+ mutable uint64_t evicted_size_ = 0;
+};
+
+class BlobIndexCompactionFilter : public BlobIndexCompactionFilterBase {
+ public:
+ BlobIndexCompactionFilter(
+ BlobCompactionContext&& _context,
+ const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter> _user_comp_filter_from_factory,
+ uint64_t current_time, Statistics* stats)
+ : BlobIndexCompactionFilterBase(std::move(_context), _user_comp_filter,
+ std::move(_user_comp_filter_from_factory),
+ current_time, stats) {}
+
+ const char* Name() const override { return "BlobIndexCompactionFilter"; }
+};
+
+class BlobIndexCompactionFilterGC : public BlobIndexCompactionFilterBase {
+ public:
+ BlobIndexCompactionFilterGC(
+ BlobCompactionContext&& _context, BlobCompactionContextGC&& context_gc,
+ const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter> _user_comp_filter_from_factory,
+ uint64_t current_time, Statistics* stats)
+ : BlobIndexCompactionFilterBase(std::move(_context), _user_comp_filter,
+ std::move(_user_comp_filter_from_factory),
+ current_time, stats),
+ context_gc_(std::move(context_gc)) {}
+
+ ~BlobIndexCompactionFilterGC() override;
+
+ const char* Name() const override { return "BlobIndexCompactionFilterGC"; }
+
+ BlobDecision PrepareBlobOutput(const Slice& key, const Slice& existing_value,
+ std::string* new_value) const override;
+
+ private:
+ bool OpenNewBlobFileIfNeeded() const override;
+
+ private:
+ BlobCompactionContextGC context_gc_;
+ mutable BlobDBGarbageCollectionStats gc_stats_;
+};
+
+// Compaction filter factory; similarly to the filters above, it comes
+// in two flavors, one that creates filters that support GC, and one
+// that creates non-GC filters.
+class BlobIndexCompactionFilterFactoryBase : public CompactionFilterFactory {
+ public:
+ BlobIndexCompactionFilterFactoryBase(BlobDBImpl* _blob_db_impl,
+ SystemClock* _clock,
+ const ColumnFamilyOptions& _cf_options,
+ Statistics* _statistics)
+ : blob_db_impl_(_blob_db_impl),
+ clock_(_clock),
+ statistics_(_statistics),
+ user_comp_filter_(_cf_options.compaction_filter),
+ user_comp_filter_factory_(_cf_options.compaction_filter_factory) {}
+
+ protected:
+ std::unique_ptr<CompactionFilter> CreateUserCompactionFilterFromFactory(
+ const CompactionFilter::Context& context) const;
+
+ BlobDBImpl* blob_db_impl() const { return blob_db_impl_; }
+ SystemClock* clock() const { return clock_; }
+ Statistics* statistics() const { return statistics_; }
+ const CompactionFilter* user_comp_filter() const { return user_comp_filter_; }
+
+ private:
+ BlobDBImpl* blob_db_impl_;
+ SystemClock* clock_;
+ Statistics* statistics_;
+ const CompactionFilter* user_comp_filter_;
+ std::shared_ptr<CompactionFilterFactory> user_comp_filter_factory_;
+};
+
+class BlobIndexCompactionFilterFactory
+ : public BlobIndexCompactionFilterFactoryBase {
+ public:
+ BlobIndexCompactionFilterFactory(BlobDBImpl* _blob_db_impl,
+ SystemClock* _clock,
+ const ColumnFamilyOptions& _cf_options,
+ Statistics* _statistics)
+ : BlobIndexCompactionFilterFactoryBase(_blob_db_impl, _clock, _cf_options,
+ _statistics) {}
+
+ const char* Name() const override {
+ return "BlobIndexCompactionFilterFactory";
+ }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& context) override;
+};
+
+class BlobIndexCompactionFilterFactoryGC
+ : public BlobIndexCompactionFilterFactoryBase {
+ public:
+ BlobIndexCompactionFilterFactoryGC(BlobDBImpl* _blob_db_impl,
+ SystemClock* _clock,
+ const ColumnFamilyOptions& _cf_options,
+ Statistics* _statistics)
+ : BlobIndexCompactionFilterFactoryBase(_blob_db_impl, _clock, _cf_options,
+ _statistics) {}
+
+ const char* Name() const override {
+ return "BlobIndexCompactionFilterFactoryGC";
+ }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& context) override;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db.cc b/src/rocksdb/utilities/blob_db/blob_db.cc
new file mode 100644
index 000000000..cbd02e68e
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db.cc
@@ -0,0 +1,114 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db.h"
+
+#include <cinttypes>
+
+#include "logging/logging.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Status BlobDB::Open(const Options& options, const BlobDBOptions& bdb_options,
+ const std::string& dbname, BlobDB** blob_db) {
+ *blob_db = nullptr;
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = BlobDB::Open(db_options, bdb_options, dbname, column_families,
+ &handles, blob_db);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // i can delete the handle since DBImpl is always holding a reference to
+ // default column family
+ delete handles[0];
+ }
+ return s;
+}
+
+Status BlobDB::Open(const DBOptions& db_options,
+ const BlobDBOptions& bdb_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ BlobDB** blob_db) {
+ assert(handles);
+
+ if (column_families.size() != 1 ||
+ column_families[0].name != kDefaultColumnFamilyName) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+
+ BlobDBImpl* blob_db_impl = new BlobDBImpl(dbname, bdb_options, db_options,
+ column_families[0].options);
+ Status s = blob_db_impl->Open(handles);
+ if (s.ok()) {
+ *blob_db = static_cast<BlobDB*>(blob_db_impl);
+ } else {
+ if (!handles->empty()) {
+ for (ColumnFamilyHandle* cfh : *handles) {
+ blob_db_impl->DestroyColumnFamilyHandle(cfh);
+ }
+
+ handles->clear();
+ }
+
+ delete blob_db_impl;
+ *blob_db = nullptr;
+ }
+ return s;
+}
+
+BlobDB::BlobDB() : StackableDB(nullptr) {}
+
+void BlobDBOptions::Dump(Logger* log) const {
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.blob_dir: %s",
+ blob_dir.c_str());
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.path_relative: %d",
+ path_relative);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.is_fifo: %d",
+ is_fifo);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.max_db_size: %" PRIu64,
+ max_db_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.ttl_range_secs: %" PRIu64,
+ ttl_range_secs);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.min_blob_size: %" PRIu64,
+ min_blob_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.bytes_per_sync: %" PRIu64,
+ bytes_per_sync);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.blob_file_size: %" PRIu64,
+ blob_file_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.compression: %d",
+ static_cast<int>(compression));
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.enable_garbage_collection: %d",
+ enable_garbage_collection);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.garbage_collection_cutoff: %f",
+ garbage_collection_cutoff);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.disable_background_tasks: %d",
+ disable_background_tasks);
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif
diff --git a/src/rocksdb/utilities/blob_db/blob_db.h b/src/rocksdb/utilities/blob_db/blob_db.h
new file mode 100644
index 000000000..e9d92486f
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db.h
@@ -0,0 +1,266 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/stackable_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+// A wrapped database which puts values of KV pairs in a separate log
+// and store location to the log in the underlying DB.
+//
+// The factory needs to be moved to include/rocksdb/utilities to allow
+// users to use blob DB.
+
+constexpr uint64_t kNoExpiration = std::numeric_limits<uint64_t>::max();
+
+struct BlobDBOptions {
+ // Name of the directory under the base DB where blobs will be stored. Using
+ // a directory where the base DB stores its SST files is not supported.
+ // Default is "blob_dir"
+ std::string blob_dir = "blob_dir";
+
+ // whether the blob_dir path is relative or absolute.
+ bool path_relative = true;
+
+ // When max_db_size is reached, evict blob files to free up space
+ // instead of returnning NoSpace error on write. Blob files will be
+ // evicted from oldest to newest, based on file creation time.
+ bool is_fifo = false;
+
+ // Maximum size of the database (including SST files and blob files).
+ //
+ // Default: 0 (no limits)
+ uint64_t max_db_size = 0;
+
+ // a new bucket is opened, for ttl_range. So if ttl_range is 600seconds
+ // (10 minutes), and the first bucket starts at 1471542000
+ // then the blob buckets will be
+ // first bucket is 1471542000 - 1471542600
+ // second bucket is 1471542600 - 1471543200
+ // and so on
+ uint64_t ttl_range_secs = 3600;
+
+ // The smallest value to store in blob log. Values smaller than this threshold
+ // will be inlined in base DB together with the key.
+ uint64_t min_blob_size = 0;
+
+ // Allows OS to incrementally sync blob files to disk for every
+ // bytes_per_sync bytes written. Users shouldn't rely on it for
+ // persistency guarantee.
+ uint64_t bytes_per_sync = 512 * 1024;
+
+ // the target size of each blob file. File will become immutable
+ // after it exceeds that size
+ uint64_t blob_file_size = 256 * 1024 * 1024;
+
+ // what compression to use for Blob's
+ CompressionType compression = kNoCompression;
+
+ // If enabled, BlobDB cleans up stale blobs in non-TTL files during compaction
+ // by rewriting the remaining live blobs to new files.
+ bool enable_garbage_collection = false;
+
+ // The cutoff in terms of blob file age for garbage collection. Blobs in
+ // the oldest N non-TTL blob files will be rewritten when encountered during
+ // compaction, where N = garbage_collection_cutoff * number_of_non_TTL_files.
+ double garbage_collection_cutoff = 0.25;
+
+ // Disable all background job. Used for test only.
+ bool disable_background_tasks = false;
+
+ void Dump(Logger* log) const;
+};
+
+class BlobDB : public StackableDB {
+ public:
+ using ROCKSDB_NAMESPACE::StackableDB::Put;
+ virtual Status Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) override = 0;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return Put(options, key, value);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Delete;
+ virtual Status Delete(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ assert(db_ != nullptr);
+ return db_->Delete(options, column_family, key);
+ }
+
+ virtual Status PutWithTTL(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t ttl) = 0;
+ virtual Status PutWithTTL(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, uint64_t ttl) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return PutWithTTL(options, key, value, ttl);
+ }
+
+ // Put with expiration. Key with expiration time equal to
+ // std::numeric_limits<uint64_t>::max() means the key don't expire.
+ virtual Status PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) = 0;
+ virtual Status PutUntil(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, uint64_t expiration) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return PutUntil(options, key, value, expiration);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override = 0;
+
+ // Get value and expiration.
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) = 0;
+ virtual Status Get(const ReadOptions& options, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ return Get(options, DefaultColumnFamily(), key, value, expiration);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options, const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override = 0;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ for (auto column_family : column_families) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return std::vector<Status>(
+ column_families.size(),
+ Status::NotSupported(
+ "Blob DB doesn't support non-default column family."));
+ }
+ }
+ return MultiGet(options, keys, values);
+ }
+ virtual void MultiGet(const ReadOptions& /*options*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const size_t num_keys, const Slice* /*keys*/,
+ PinnableSlice* /*values*/, Status* statuses,
+ const bool /*sorted_input*/ = false) override {
+ for (size_t i = 0; i < num_keys; ++i) {
+ statuses[i] =
+ Status::NotSupported("Blob DB doesn't support batched MultiGet");
+ }
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::SingleDelete;
+ virtual Status SingleDelete(const WriteOptions& /*wopts*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const Slice& /*key*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& /*options*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const Slice& /*key*/, const Slice& /*value*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ virtual Status Write(const WriteOptions& opts,
+ WriteBatch* updates) override = 0;
+
+ using ROCKSDB_NAMESPACE::StackableDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& options) override = 0;
+ virtual Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ // Blob DB doesn't support non-default column family.
+ return nullptr;
+ }
+ return NewIterator(options);
+ }
+
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override = 0;
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ ColumnFamilyHandle* column_family,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+
+ return CompactFiles(compact_options, input_file_names, output_level,
+ output_path_id, output_file_names, compaction_job_info);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Close;
+ virtual Status Close() override = 0;
+
+ // Opening blob db.
+ static Status Open(const Options& options, const BlobDBOptions& bdb_options,
+ const std::string& dbname, BlobDB** blob_db);
+
+ static Status Open(const DBOptions& db_options,
+ const BlobDBOptions& bdb_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ BlobDB** blob_db);
+
+ virtual BlobDBOptions GetBlobDBOptions() const = 0;
+
+ virtual Status SyncBlobFiles() = 0;
+
+ virtual ~BlobDB() {}
+
+ protected:
+ explicit BlobDB();
+};
+
+// Destroy the content of the database.
+Status DestroyBlobDB(const std::string& dbname, const Options& options,
+ const BlobDBOptions& bdb_options);
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h b/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h
new file mode 100644
index 000000000..fea6b0032
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#include <cstdint>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+#ifndef ROCKSDB_LITE
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+/**
+ * Statistics related to a single garbage collection pass (i.e. a single
+ * (sub)compaction).
+ */
+class BlobDBGarbageCollectionStats {
+ public:
+ uint64_t AllBlobs() const { return all_blobs_; }
+ uint64_t AllBytes() const { return all_bytes_; }
+ uint64_t RelocatedBlobs() const { return relocated_blobs_; }
+ uint64_t RelocatedBytes() const { return relocated_bytes_; }
+ uint64_t NewFiles() const { return new_files_; }
+ bool HasError() const { return error_; }
+
+ void AddBlob(uint64_t size) {
+ ++all_blobs_;
+ all_bytes_ += size;
+ }
+
+ void AddRelocatedBlob(uint64_t size) {
+ ++relocated_blobs_;
+ relocated_bytes_ += size;
+ }
+
+ void AddNewFile() { ++new_files_; }
+
+ void SetError() { error_ = true; }
+
+ private:
+ uint64_t all_blobs_ = 0;
+ uint64_t all_bytes_ = 0;
+ uint64_t relocated_blobs_ = 0;
+ uint64_t relocated_bytes_ = 0;
+ uint64_t new_files_ = 0;
+ bool error_ = false;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl.cc b/src/rocksdb/utilities/blob_db/blob_db_impl.cc
new file mode 100644
index 000000000..87e294c5c
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl.cc
@@ -0,0 +1,2177 @@
+
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db_impl.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <iomanip>
+#include <memory>
+#include <sstream>
+
+#include "db/blob/blob_index.h"
+#include "db/db_impl/db_impl.h"
+#include "db/write_batch_internal.h"
+#include "file/file_util.h"
+#include "file/filename.h"
+#include "file/random_access_file_reader.h"
+#include "file/sst_file_manager_impl.h"
+#include "file/writable_file_writer.h"
+#include "logging/logging.h"
+#include "monitoring/instrumented_mutex.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/env.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/utilities/stackable_db.h"
+#include "rocksdb/utilities/transaction.h"
+#include "table/block_based/block.h"
+#include "table/block_based/block_based_table_builder.h"
+#include "table/block_based/block_builder.h"
+#include "table/meta_blocks.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/stop_watch.h"
+#include "util/timer_queue.h"
+#include "utilities/blob_db/blob_compaction_filter.h"
+#include "utilities/blob_db/blob_db_iterator.h"
+#include "utilities/blob_db/blob_db_listener.h"
+
+namespace {
+int kBlockBasedTableVersionFormat = 2;
+} // end namespace
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+bool BlobFileComparator::operator()(
+ const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const {
+ return lhs->BlobFileNumber() > rhs->BlobFileNumber();
+}
+
+bool BlobFileComparatorTTL::operator()(
+ const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const {
+ assert(lhs->HasTTL() && rhs->HasTTL());
+ if (lhs->expiration_range_.first < rhs->expiration_range_.first) {
+ return true;
+ }
+ if (lhs->expiration_range_.first > rhs->expiration_range_.first) {
+ return false;
+ }
+ return lhs->BlobFileNumber() < rhs->BlobFileNumber();
+}
+
+BlobDBImpl::BlobDBImpl(const std::string& dbname,
+ const BlobDBOptions& blob_db_options,
+ const DBOptions& db_options,
+ const ColumnFamilyOptions& cf_options)
+ : BlobDB(),
+ dbname_(dbname),
+ db_impl_(nullptr),
+ env_(db_options.env),
+ bdb_options_(blob_db_options),
+ db_options_(db_options),
+ cf_options_(cf_options),
+ file_options_(db_options),
+ statistics_(db_options_.statistics.get()),
+ next_file_number_(1),
+ flush_sequence_(0),
+ closed_(true),
+ open_file_count_(0),
+ total_blob_size_(0),
+ live_sst_size_(0),
+ fifo_eviction_seq_(0),
+ evict_expiration_up_to_(0),
+ debug_level_(0) {
+ clock_ = env_->GetSystemClock().get();
+ blob_dir_ = (bdb_options_.path_relative)
+ ? dbname + "/" + bdb_options_.blob_dir
+ : bdb_options_.blob_dir;
+ file_options_.bytes_per_sync = blob_db_options.bytes_per_sync;
+}
+
+BlobDBImpl::~BlobDBImpl() {
+ tqueue_.shutdown();
+ // CancelAllBackgroundWork(db_, true);
+ Status s __attribute__((__unused__)) = Close();
+ assert(s.ok());
+}
+
+Status BlobDBImpl::Close() {
+ if (closed_) {
+ return Status::OK();
+ }
+ closed_ = true;
+
+ // Close base DB before BlobDBImpl destructs to stop event listener and
+ // compaction filter call.
+ Status s = db_->Close();
+ // delete db_ anyway even if close failed.
+ delete db_;
+ // Reset pointers to avoid StackableDB delete the pointer again.
+ db_ = nullptr;
+ db_impl_ = nullptr;
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = SyncBlobFiles();
+ return s;
+}
+
+BlobDBOptions BlobDBImpl::GetBlobDBOptions() const { return bdb_options_; }
+
+Status BlobDBImpl::Open(std::vector<ColumnFamilyHandle*>* handles) {
+ assert(handles != nullptr);
+ assert(db_ == nullptr);
+
+ if (blob_dir_.empty()) {
+ return Status::NotSupported("No blob directory in options");
+ }
+
+ if (bdb_options_.garbage_collection_cutoff < 0.0 ||
+ bdb_options_.garbage_collection_cutoff > 1.0) {
+ return Status::InvalidArgument(
+ "Garbage collection cutoff must be in the interval [0.0, 1.0]");
+ }
+
+ // Temporarily disable compactions in the base DB during open; save the user
+ // defined value beforehand so we can restore it once BlobDB is initialized.
+ // Note: this is only needed if garbage collection is enabled.
+ const bool disable_auto_compactions = cf_options_.disable_auto_compactions;
+
+ if (bdb_options_.enable_garbage_collection) {
+ cf_options_.disable_auto_compactions = true;
+ }
+
+ Status s;
+
+ // Create info log.
+ if (db_options_.info_log == nullptr) {
+ s = CreateLoggerFromOptions(dbname_, db_options_, &db_options_.info_log);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Opening BlobDB...");
+
+ if ((cf_options_.compaction_filter != nullptr ||
+ cf_options_.compaction_filter_factory != nullptr)) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "BlobDB only support compaction filter on non-TTL values.");
+ }
+
+ // Open blob directory.
+ s = env_->CreateDirIfMissing(blob_dir_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to create blob_dir %s, status: %s",
+ blob_dir_.c_str(), s.ToString().c_str());
+ }
+ s = env_->GetFileSystem()->NewDirectory(blob_dir_, IOOptions(), &dir_ent_,
+ nullptr);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to open blob_dir %s, status: %s", blob_dir_.c_str(),
+ s.ToString().c_str());
+ return s;
+ }
+
+ // Open blob files.
+ s = OpenAllBlobFiles();
+ if (!s.ok()) {
+ return s;
+ }
+
+ // Update options
+ if (bdb_options_.enable_garbage_collection) {
+ db_options_.listeners.push_back(std::make_shared<BlobDBListenerGC>(this));
+ cf_options_.compaction_filter_factory =
+ std::make_shared<BlobIndexCompactionFilterFactoryGC>(
+ this, clock_, cf_options_, statistics_);
+ } else {
+ db_options_.listeners.push_back(std::make_shared<BlobDBListener>(this));
+ cf_options_.compaction_filter_factory =
+ std::make_shared<BlobIndexCompactionFilterFactory>(
+ this, clock_, cf_options_, statistics_);
+ }
+
+ // Reset user compaction filter after building into compaction factory.
+ cf_options_.compaction_filter = nullptr;
+
+ // Open base db.
+ ColumnFamilyDescriptor cf_descriptor(kDefaultColumnFamilyName, cf_options_);
+ s = DB::Open(db_options_, dbname_, {cf_descriptor}, handles, &db_);
+ if (!s.ok()) {
+ return s;
+ }
+ db_impl_ = static_cast_with_check<DBImpl>(db_->GetRootDB());
+
+ // Sanitize the blob_dir provided. Using a directory where the
+ // base DB stores its files for the default CF is not supported.
+ const ColumnFamilyData* const cfd =
+ static_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->cfd();
+ assert(cfd);
+
+ const ImmutableCFOptions* const ioptions = cfd->ioptions();
+ assert(ioptions);
+
+ assert(env_);
+
+ for (const auto& cf_path : ioptions->cf_paths) {
+ bool blob_dir_same_as_cf_dir = false;
+ s = env_->AreFilesSame(blob_dir_, cf_path.path, &blob_dir_same_as_cf_dir);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Error while sanitizing blob_dir %s, status: %s",
+ blob_dir_.c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ if (blob_dir_same_as_cf_dir) {
+ return Status::NotSupported(
+ "Using the base DB's storage directories for BlobDB files is not "
+ "supported.");
+ }
+ }
+
+ // Initialize SST file <-> oldest blob file mapping if garbage collection
+ // is enabled.
+ if (bdb_options_.enable_garbage_collection) {
+ std::vector<LiveFileMetaData> live_files;
+ db_->GetLiveFilesMetaData(&live_files);
+
+ InitializeBlobFileToSstMapping(live_files);
+
+ MarkUnreferencedBlobFilesObsoleteDuringOpen();
+
+ if (!disable_auto_compactions) {
+ s = db_->EnableAutoCompaction(*handles);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to enable automatic compactions during open, status: %s",
+ s.ToString().c_str());
+ return s;
+ }
+ }
+ }
+
+ // Add trash files in blob dir to file delete scheduler.
+ SstFileManagerImpl* sfm = static_cast<SstFileManagerImpl*>(
+ db_impl_->immutable_db_options().sst_file_manager.get());
+ DeleteScheduler::CleanupDirectory(env_, sfm, blob_dir_);
+
+ UpdateLiveSSTSize();
+
+ // Start background jobs.
+ if (!bdb_options_.disable_background_tasks) {
+ StartBackgroundTasks();
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "BlobDB pointer %p", this);
+ bdb_options_.Dump(db_options_.info_log.get());
+ closed_ = false;
+ return s;
+}
+
+void BlobDBImpl::StartBackgroundTasks() {
+ // store a call to a member function and object
+ tqueue_.add(
+ kReclaimOpenFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::ReclaimOpenFiles, this, std::placeholders::_1));
+ tqueue_.add(
+ kDeleteObsoleteFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::DeleteObsoleteFiles, this, std::placeholders::_1));
+ tqueue_.add(kSanityCheckPeriodMillisecs,
+ std::bind(&BlobDBImpl::SanityCheck, this, std::placeholders::_1));
+ tqueue_.add(
+ kEvictExpiredFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::EvictExpiredFiles, this, std::placeholders::_1));
+}
+
+Status BlobDBImpl::GetAllBlobFiles(std::set<uint64_t>* file_numbers) {
+ assert(file_numbers != nullptr);
+ std::vector<std::string> all_files;
+ Status s = env_->GetChildren(blob_dir_, &all_files);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to get list of blob files, status: %s",
+ s.ToString().c_str());
+ return s;
+ }
+
+ for (const auto& file_name : all_files) {
+ uint64_t file_number;
+ FileType type;
+ bool success = ParseFileName(file_name, &file_number, &type);
+ if (success && type == kBlobFile) {
+ file_numbers->insert(file_number);
+ } else {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Skipping file in blob directory: %s", file_name.c_str());
+ }
+ }
+
+ return s;
+}
+
+Status BlobDBImpl::OpenAllBlobFiles() {
+ std::set<uint64_t> file_numbers;
+ Status s = GetAllBlobFiles(&file_numbers);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (!file_numbers.empty()) {
+ next_file_number_.store(*file_numbers.rbegin() + 1);
+ }
+
+ std::ostringstream blob_file_oss;
+ std::ostringstream live_imm_oss;
+ std::ostringstream obsolete_file_oss;
+
+ for (auto& file_number : file_numbers) {
+ std::shared_ptr<BlobFile> blob_file = std::make_shared<BlobFile>(
+ this, blob_dir_, file_number, db_options_.info_log.get());
+ blob_file->MarkImmutable(/* sequence */ 0);
+
+ // Read file header and footer
+ Status read_metadata_status =
+ blob_file->ReadMetadata(env_->GetFileSystem(), file_options_);
+ if (read_metadata_status.IsCorruption()) {
+ // Remove incomplete file.
+ if (!obsolete_files_.empty()) {
+ obsolete_file_oss << ", ";
+ }
+ obsolete_file_oss << file_number;
+
+ ObsoleteBlobFile(blob_file, 0 /*obsolete_seq*/, false /*update_size*/);
+ continue;
+ } else if (!read_metadata_status.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Unable to read metadata of blob file %" PRIu64
+ ", status: '%s'",
+ file_number, read_metadata_status.ToString().c_str());
+ return read_metadata_status;
+ }
+
+ total_blob_size_ += blob_file->GetFileSize();
+
+ if (!blob_files_.empty()) {
+ blob_file_oss << ", ";
+ }
+ blob_file_oss << file_number;
+
+ blob_files_[file_number] = blob_file;
+
+ if (!blob_file->HasTTL()) {
+ if (!live_imm_non_ttl_blob_files_.empty()) {
+ live_imm_oss << ", ";
+ }
+ live_imm_oss << file_number;
+
+ live_imm_non_ttl_blob_files_[file_number] = blob_file;
+ }
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Found %" ROCKSDB_PRIszt " blob files: %s", blob_files_.size(),
+ blob_file_oss.str().c_str());
+ ROCKS_LOG_INFO(
+ db_options_.info_log, "Found %" ROCKSDB_PRIszt " non-TTL blob files: %s",
+ live_imm_non_ttl_blob_files_.size(), live_imm_oss.str().c_str());
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Found %" ROCKSDB_PRIszt
+ " incomplete or corrupted blob files: %s",
+ obsolete_files_.size(), obsolete_file_oss.str().c_str());
+ return s;
+}
+
+template <typename Linker>
+void BlobDBImpl::LinkSstToBlobFileImpl(uint64_t sst_file_number,
+ uint64_t blob_file_number,
+ Linker linker) {
+ assert(bdb_options_.enable_garbage_collection);
+ assert(blob_file_number != kInvalidBlobFileNumber);
+
+ auto it = blob_files_.find(blob_file_number);
+ if (it == blob_files_.end()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Blob file %" PRIu64
+ " not found while trying to link "
+ "SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+ return;
+ }
+
+ BlobFile* const blob_file = it->second.get();
+ assert(blob_file);
+
+ linker(blob_file, sst_file_number);
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " linked to SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+}
+
+void BlobDBImpl::LinkSstToBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ auto linker = [](BlobFile* blob_file, uint64_t sst_file) {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->LinkSstFile(sst_file);
+ };
+
+ LinkSstToBlobFileImpl(sst_file_number, blob_file_number, linker);
+}
+
+void BlobDBImpl::LinkSstToBlobFileNoLock(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ auto linker = [](BlobFile* blob_file, uint64_t sst_file) {
+ blob_file->LinkSstFile(sst_file);
+ };
+
+ LinkSstToBlobFileImpl(sst_file_number, blob_file_number, linker);
+}
+
+void BlobDBImpl::UnlinkSstFromBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ assert(bdb_options_.enable_garbage_collection);
+ assert(blob_file_number != kInvalidBlobFileNumber);
+
+ auto it = blob_files_.find(blob_file_number);
+ if (it == blob_files_.end()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Blob file %" PRIu64
+ " not found while trying to unlink "
+ "SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+ return;
+ }
+
+ BlobFile* const blob_file = it->second.get();
+ assert(blob_file);
+
+ {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->UnlinkSstFile(sst_file_number);
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " unlinked from SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+}
+
+void BlobDBImpl::InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ for (const auto& live_file : live_files) {
+ const uint64_t sst_file_number = live_file.file_number;
+ const uint64_t blob_file_number = live_file.oldest_blob_file_number;
+
+ if (blob_file_number == kInvalidBlobFileNumber) {
+ continue;
+ }
+
+ LinkSstToBlobFileNoLock(sst_file_number, blob_file_number);
+ }
+}
+
+void BlobDBImpl::ProcessFlushJobInfo(const FlushJobInfo& info) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ WriteLock lock(&mutex_);
+
+ if (info.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(info.file_number, info.oldest_blob_file_number);
+ }
+
+ assert(flush_sequence_ < info.largest_seqno);
+ flush_sequence_ = info.largest_seqno;
+
+ MarkUnreferencedBlobFilesObsolete();
+}
+
+void BlobDBImpl::ProcessCompactionJobInfo(const CompactionJobInfo& info) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ if (!info.status.ok()) {
+ return;
+ }
+
+ // Note: the same SST file may appear in both the input and the output
+ // file list in case of a trivial move. We walk through the two lists
+ // below in a fashion that's similar to merge sort to detect this.
+
+ auto cmp = [](const CompactionFileInfo& lhs, const CompactionFileInfo& rhs) {
+ return lhs.file_number < rhs.file_number;
+ };
+
+ auto inputs = info.input_file_infos;
+ auto iit = inputs.begin();
+ const auto iit_end = inputs.end();
+
+ std::sort(iit, iit_end, cmp);
+
+ auto outputs = info.output_file_infos;
+ auto oit = outputs.begin();
+ const auto oit_end = outputs.end();
+
+ std::sort(oit, oit_end, cmp);
+
+ WriteLock lock(&mutex_);
+
+ while (iit != iit_end && oit != oit_end) {
+ const auto& input = *iit;
+ const auto& output = *oit;
+
+ if (input.file_number == output.file_number) {
+ ++iit;
+ ++oit;
+ } else if (input.file_number < output.file_number) {
+ if (input.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ UnlinkSstFromBlobFile(input.file_number, input.oldest_blob_file_number);
+ }
+
+ ++iit;
+ } else {
+ assert(output.file_number < input.file_number);
+
+ if (output.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(output.file_number, output.oldest_blob_file_number);
+ }
+
+ ++oit;
+ }
+ }
+
+ while (iit != iit_end) {
+ const auto& input = *iit;
+
+ if (input.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ UnlinkSstFromBlobFile(input.file_number, input.oldest_blob_file_number);
+ }
+
+ ++iit;
+ }
+
+ while (oit != oit_end) {
+ const auto& output = *oit;
+
+ if (output.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(output.file_number, output.oldest_blob_file_number);
+ }
+
+ ++oit;
+ }
+
+ MarkUnreferencedBlobFilesObsolete();
+}
+
+bool BlobDBImpl::MarkBlobFileObsoleteIfNeeded(
+ const std::shared_ptr<BlobFile>& blob_file, SequenceNumber obsolete_seq) {
+ assert(blob_file);
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+ assert(bdb_options_.enable_garbage_collection);
+
+ // Note: FIFO eviction could have marked this file obsolete already.
+ if (blob_file->Obsolete()) {
+ return true;
+ }
+
+ // We cannot mark this file (or any higher-numbered files for that matter)
+ // obsolete if it is referenced by any memtables or SSTs. We keep track of
+ // the SSTs explicitly. To account for memtables, we keep track of the highest
+ // sequence number received in flush notifications, and we do not mark the
+ // blob file obsolete if there are still unflushed memtables from before
+ // the time the blob file was closed.
+ if (blob_file->GetImmutableSequence() > flush_sequence_ ||
+ !blob_file->GetLinkedSstFiles().empty()) {
+ return false;
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " is no longer needed, marking obsolete",
+ blob_file->BlobFileNumber());
+
+ ObsoleteBlobFile(blob_file, obsolete_seq, /* update_size */ true);
+ return true;
+}
+
+template <class Functor>
+void BlobDBImpl::MarkUnreferencedBlobFilesObsoleteImpl(Functor mark_if_needed) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ // Iterate through all live immutable non-TTL blob files, and mark them
+ // obsolete assuming no SST files or memtables rely on the blobs in them.
+ // Note: we need to stop as soon as we find a blob file that has any
+ // linked SSTs (or one potentially referenced by memtables).
+
+ uint64_t obsoleted_files = 0;
+
+ auto it = live_imm_non_ttl_blob_files_.begin();
+ while (it != live_imm_non_ttl_blob_files_.end()) {
+ const auto& blob_file = it->second;
+ assert(blob_file);
+ assert(blob_file->BlobFileNumber() == it->first);
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+
+ // Small optimization: Obsolete() does an atomic read, so we can do
+ // this check without taking a lock on the blob file's mutex.
+ if (blob_file->Obsolete()) {
+ it = live_imm_non_ttl_blob_files_.erase(it);
+ continue;
+ }
+
+ if (!mark_if_needed(blob_file)) {
+ break;
+ }
+
+ it = live_imm_non_ttl_blob_files_.erase(it);
+
+ ++obsoleted_files;
+ }
+
+ if (obsoleted_files > 0) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "%" PRIu64 " blob file(s) marked obsolete by GC",
+ obsoleted_files);
+ RecordTick(statistics_, BLOB_DB_GC_NUM_FILES, obsoleted_files);
+ }
+}
+
+void BlobDBImpl::MarkUnreferencedBlobFilesObsolete() {
+ const SequenceNumber obsolete_seq = GetLatestSequenceNumber();
+
+ MarkUnreferencedBlobFilesObsoleteImpl(
+ [this, obsolete_seq](const std::shared_ptr<BlobFile>& blob_file) {
+ WriteLock file_lock(&blob_file->mutex_);
+ return MarkBlobFileObsoleteIfNeeded(blob_file, obsolete_seq);
+ });
+}
+
+void BlobDBImpl::MarkUnreferencedBlobFilesObsoleteDuringOpen() {
+ MarkUnreferencedBlobFilesObsoleteImpl(
+ [this](const std::shared_ptr<BlobFile>& blob_file) {
+ return MarkBlobFileObsoleteIfNeeded(blob_file, /* obsolete_seq */ 0);
+ });
+}
+
+void BlobDBImpl::CloseRandomAccessLocked(
+ const std::shared_ptr<BlobFile>& bfile) {
+ bfile->CloseRandomAccessLocked();
+ open_file_count_--;
+}
+
+Status BlobDBImpl::GetBlobFileReader(
+ const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<RandomAccessFileReader>* reader) {
+ assert(reader != nullptr);
+ bool fresh_open = false;
+ Status s = blob_file->GetReader(env_, file_options_, reader, &fresh_open);
+ if (s.ok() && fresh_open) {
+ assert(*reader != nullptr);
+ open_file_count_++;
+ }
+ return s;
+}
+
+std::shared_ptr<BlobFile> BlobDBImpl::NewBlobFile(
+ bool has_ttl, const ExpirationRange& expiration_range,
+ const std::string& reason) {
+ assert(has_ttl == (expiration_range.first || expiration_range.second));
+
+ uint64_t file_num = next_file_number_++;
+
+ const uint32_t column_family_id =
+ static_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->GetID();
+ auto blob_file = std::make_shared<BlobFile>(
+ this, blob_dir_, file_num, db_options_.info_log.get(), column_family_id,
+ bdb_options_.compression, has_ttl, expiration_range);
+
+ ROCKS_LOG_DEBUG(db_options_.info_log, "New blob file created: %s reason='%s'",
+ blob_file->PathName().c_str(), reason.c_str());
+ LogFlush(db_options_.info_log);
+
+ return blob_file;
+}
+
+void BlobDBImpl::RegisterBlobFile(std::shared_ptr<BlobFile> blob_file) {
+ const uint64_t blob_file_number = blob_file->BlobFileNumber();
+
+ auto it = blob_files_.lower_bound(blob_file_number);
+ assert(it == blob_files_.end() || it->first != blob_file_number);
+
+ blob_files_.insert(it,
+ std::map<uint64_t, std::shared_ptr<BlobFile>>::value_type(
+ blob_file_number, std::move(blob_file)));
+}
+
+Status BlobDBImpl::CreateWriterLocked(const std::shared_ptr<BlobFile>& bfile) {
+ std::string fpath(bfile->PathName());
+ std::unique_ptr<FSWritableFile> wfile;
+ const auto& fs = env_->GetFileSystem();
+
+ Status s = fs->ReopenWritableFile(fpath, file_options_, &wfile, nullptr);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to open blob file for write: %s status: '%s'"
+ " exists: '%s'",
+ fpath.c_str(), s.ToString().c_str(),
+ fs->FileExists(fpath, file_options_.io_options, nullptr)
+ .ToString()
+ .c_str());
+ return s;
+ }
+
+ std::unique_ptr<WritableFileWriter> fwriter;
+ fwriter.reset(new WritableFileWriter(std::move(wfile), fpath, file_options_));
+
+ uint64_t boffset = bfile->GetFileSize();
+ if (debug_level_ >= 2 && boffset) {
+ ROCKS_LOG_DEBUG(db_options_.info_log,
+ "Open blob file: %s with offset: %" PRIu64, fpath.c_str(),
+ boffset);
+ }
+
+ BlobLogWriter::ElemType et = BlobLogWriter::kEtNone;
+ if (bfile->file_size_ == BlobLogHeader::kSize) {
+ et = BlobLogWriter::kEtFileHdr;
+ } else if (bfile->file_size_ > BlobLogHeader::kSize) {
+ et = BlobLogWriter::kEtRecord;
+ } else if (bfile->file_size_) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Open blob file: %s with wrong size: %" PRIu64,
+ fpath.c_str(), boffset);
+ return Status::Corruption("Invalid blob file size");
+ }
+
+ constexpr bool do_flush = true;
+
+ bfile->log_writer_ = std::make_shared<BlobLogWriter>(
+ std::move(fwriter), clock_, statistics_, bfile->file_number_,
+ db_options_.use_fsync, do_flush, boffset);
+ bfile->log_writer_->last_elem_type_ = et;
+
+ return s;
+}
+
+std::shared_ptr<BlobFile> BlobDBImpl::FindBlobFileLocked(
+ uint64_t expiration) const {
+ if (open_ttl_files_.empty()) {
+ return nullptr;
+ }
+
+ std::shared_ptr<BlobFile> tmp = std::make_shared<BlobFile>();
+ tmp->SetHasTTL(true);
+ tmp->expiration_range_ = std::make_pair(expiration, 0);
+ tmp->file_number_ = std::numeric_limits<uint64_t>::max();
+
+ auto citr = open_ttl_files_.equal_range(tmp);
+ if (citr.first == open_ttl_files_.end()) {
+ assert(citr.second == open_ttl_files_.end());
+
+ std::shared_ptr<BlobFile> check = *(open_ttl_files_.rbegin());
+ return (check->expiration_range_.second <= expiration) ? nullptr : check;
+ }
+
+ if (citr.first != citr.second) {
+ return *(citr.first);
+ }
+
+ auto finditr = citr.second;
+ if (finditr != open_ttl_files_.begin()) {
+ --finditr;
+ }
+
+ bool b2 = (*finditr)->expiration_range_.second <= expiration;
+ bool b1 = (*finditr)->expiration_range_.first > expiration;
+
+ return (b1 || b2) ? nullptr : (*finditr);
+}
+
+Status BlobDBImpl::CheckOrCreateWriterLocked(
+ const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<BlobLogWriter>* writer) {
+ assert(writer != nullptr);
+ *writer = blob_file->GetWriter();
+ if (*writer != nullptr) {
+ return Status::OK();
+ }
+ Status s = CreateWriterLocked(blob_file);
+ if (s.ok()) {
+ *writer = blob_file->GetWriter();
+ }
+ return s;
+}
+
+Status BlobDBImpl::CreateBlobFileAndWriter(
+ bool has_ttl, const ExpirationRange& expiration_range,
+ const std::string& reason, std::shared_ptr<BlobFile>* blob_file,
+ std::shared_ptr<BlobLogWriter>* writer) {
+ TEST_SYNC_POINT("BlobDBImpl::CreateBlobFileAndWriter");
+ assert(has_ttl == (expiration_range.first || expiration_range.second));
+ assert(blob_file);
+ assert(writer);
+
+ *blob_file = NewBlobFile(has_ttl, expiration_range, reason);
+ assert(*blob_file);
+
+ // file not visible, hence no lock
+ Status s = CheckOrCreateWriterLocked(*blob_file, writer);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to get writer for blob file: %s, error: %s",
+ (*blob_file)->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ assert(*writer);
+
+ s = (*writer)->WriteHeader((*blob_file)->header_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to write header to new blob file: %s"
+ " status: '%s'",
+ (*blob_file)->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ (*blob_file)->SetFileSize(BlobLogHeader::kSize);
+ total_blob_size_ += BlobLogHeader::kSize;
+
+ return s;
+}
+
+Status BlobDBImpl::SelectBlobFile(std::shared_ptr<BlobFile>* blob_file) {
+ assert(blob_file);
+
+ {
+ ReadLock rl(&mutex_);
+
+ if (open_non_ttl_file_) {
+ assert(!open_non_ttl_file_->Immutable());
+ *blob_file = open_non_ttl_file_;
+ return Status::OK();
+ }
+ }
+
+ // Check again
+ WriteLock wl(&mutex_);
+
+ if (open_non_ttl_file_) {
+ assert(!open_non_ttl_file_->Immutable());
+ *blob_file = open_non_ttl_file_;
+ return Status::OK();
+ }
+
+ std::shared_ptr<BlobLogWriter> writer;
+ const Status s = CreateBlobFileAndWriter(
+ /* has_ttl */ false, ExpirationRange(),
+ /* reason */ "SelectBlobFile", blob_file, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ RegisterBlobFile(*blob_file);
+ open_non_ttl_file_ = *blob_file;
+
+ return s;
+}
+
+Status BlobDBImpl::SelectBlobFileTTL(uint64_t expiration,
+ std::shared_ptr<BlobFile>* blob_file) {
+ assert(blob_file);
+ assert(expiration != kNoExpiration);
+
+ {
+ ReadLock rl(&mutex_);
+
+ *blob_file = FindBlobFileLocked(expiration);
+ if (*blob_file != nullptr) {
+ assert(!(*blob_file)->Immutable());
+ return Status::OK();
+ }
+ }
+
+ // Check again
+ WriteLock wl(&mutex_);
+
+ *blob_file = FindBlobFileLocked(expiration);
+ if (*blob_file != nullptr) {
+ assert(!(*blob_file)->Immutable());
+ return Status::OK();
+ }
+
+ const uint64_t exp_low =
+ (expiration / bdb_options_.ttl_range_secs) * bdb_options_.ttl_range_secs;
+ const uint64_t exp_high = exp_low + bdb_options_.ttl_range_secs;
+ const ExpirationRange expiration_range(exp_low, exp_high);
+
+ std::ostringstream oss;
+ oss << "SelectBlobFileTTL range: [" << exp_low << ',' << exp_high << ')';
+
+ std::shared_ptr<BlobLogWriter> writer;
+ const Status s =
+ CreateBlobFileAndWriter(/* has_ttl */ true, expiration_range,
+ /* reason */ oss.str(), blob_file, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ RegisterBlobFile(*blob_file);
+ open_ttl_files_.insert(*blob_file);
+
+ return s;
+}
+
+class BlobDBImpl::BlobInserter : public WriteBatch::Handler {
+ private:
+ const WriteOptions& options_;
+ BlobDBImpl* blob_db_impl_;
+ uint32_t default_cf_id_;
+ WriteBatch batch_;
+
+ public:
+ BlobInserter(const WriteOptions& options, BlobDBImpl* blob_db_impl,
+ uint32_t default_cf_id)
+ : options_(options),
+ blob_db_impl_(blob_db_impl),
+ default_cf_id_(default_cf_id) {}
+
+ WriteBatch* batch() { return &batch_; }
+
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = blob_db_impl_->PutBlobValue(options_, key, value, kNoExpiration,
+ &batch_);
+ return s;
+ }
+
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = WriteBatchInternal::Delete(&batch_, column_family_id, key);
+ return s;
+ }
+
+ virtual Status DeleteRange(uint32_t column_family_id, const Slice& begin_key,
+ const Slice& end_key) {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = WriteBatchInternal::DeleteRange(&batch_, column_family_id,
+ begin_key, end_key);
+ return s;
+ }
+
+ Status SingleDeleteCF(uint32_t /*column_family_id*/,
+ const Slice& /*key*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ Status MergeCF(uint32_t /*column_family_id*/, const Slice& /*key*/,
+ const Slice& /*value*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ void LogData(const Slice& blob) override { batch_.PutLogData(blob); }
+};
+
+Status BlobDBImpl::Write(const WriteOptions& options, WriteBatch* updates) {
+ StopWatch write_sw(clock_, statistics_, BLOB_DB_WRITE_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_WRITE);
+ uint32_t default_cf_id =
+ static_cast_with_check<ColumnFamilyHandleImpl>(DefaultColumnFamily())
+ ->GetID();
+ Status s;
+ BlobInserter blob_inserter(options, this, default_cf_id);
+ {
+ // Release write_mutex_ before DB write to avoid race condition with
+ // flush begin listener, which also require write_mutex_ to sync
+ // blob files.
+ MutexLock l(&write_mutex_);
+ s = updates->Iterate(&blob_inserter);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ return db_->Write(options, blob_inserter.batch());
+}
+
+Status BlobDBImpl::Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) {
+ return PutUntil(options, key, value, kNoExpiration);
+}
+
+Status BlobDBImpl::PutWithTTL(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t ttl) {
+ uint64_t now = EpochNow();
+ uint64_t expiration = kNoExpiration - now > ttl ? now + ttl : kNoExpiration;
+ return PutUntil(options, key, value, expiration);
+}
+
+Status BlobDBImpl::PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) {
+ StopWatch write_sw(clock_, statistics_, BLOB_DB_WRITE_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_PUT);
+ Status s;
+ WriteBatch batch;
+ {
+ // Release write_mutex_ before DB write to avoid race condition with
+ // flush begin listener, which also require write_mutex_ to sync
+ // blob files.
+ MutexLock l(&write_mutex_);
+ s = PutBlobValue(options, key, value, expiration, &batch);
+ }
+ if (s.ok()) {
+ s = db_->Write(options, &batch);
+ }
+ return s;
+}
+
+Status BlobDBImpl::PutBlobValue(const WriteOptions& /*options*/,
+ const Slice& key, const Slice& value,
+ uint64_t expiration, WriteBatch* batch) {
+ write_mutex_.AssertHeld();
+ Status s;
+ std::string index_entry;
+ uint32_t column_family_id =
+ static_cast_with_check<ColumnFamilyHandleImpl>(DefaultColumnFamily())
+ ->GetID();
+ if (value.size() < bdb_options_.min_blob_size) {
+ if (expiration == kNoExpiration) {
+ // Put as normal value
+ s = batch->Put(key, value);
+ RecordTick(statistics_, BLOB_DB_WRITE_INLINED);
+ } else {
+ // Inlined with TTL
+ BlobIndex::EncodeInlinedTTL(&index_entry, expiration, value);
+ s = WriteBatchInternal::PutBlobIndex(batch, column_family_id, key,
+ index_entry);
+ RecordTick(statistics_, BLOB_DB_WRITE_INLINED_TTL);
+ }
+ } else {
+ std::string compression_output;
+ Slice value_compressed = GetCompressedSlice(value, &compression_output);
+
+ std::string headerbuf;
+ BlobLogWriter::ConstructBlobHeader(&headerbuf, key, value_compressed,
+ expiration);
+
+ // Check DB size limit before selecting blob file to
+ // Since CheckSizeAndEvictBlobFiles() can close blob files, it needs to be
+ // done before calling SelectBlobFile().
+ s = CheckSizeAndEvictBlobFiles(headerbuf.size() + key.size() +
+ value_compressed.size());
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::shared_ptr<BlobFile> blob_file;
+ if (expiration != kNoExpiration) {
+ s = SelectBlobFileTTL(expiration, &blob_file);
+ } else {
+ s = SelectBlobFile(&blob_file);
+ }
+ if (s.ok()) {
+ assert(blob_file != nullptr);
+ assert(blob_file->GetCompressionType() == bdb_options_.compression);
+ s = AppendBlob(blob_file, headerbuf, key, value_compressed, expiration,
+ &index_entry);
+ }
+ if (s.ok()) {
+ if (expiration != kNoExpiration) {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->ExtendExpirationRange(expiration);
+ }
+ s = CloseBlobFileIfNeeded(blob_file);
+ }
+ if (s.ok()) {
+ s = WriteBatchInternal::PutBlobIndex(batch, column_family_id, key,
+ index_entry);
+ }
+ if (s.ok()) {
+ if (expiration == kNoExpiration) {
+ RecordTick(statistics_, BLOB_DB_WRITE_BLOB);
+ } else {
+ RecordTick(statistics_, BLOB_DB_WRITE_BLOB_TTL);
+ }
+ } else {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to append blob to FILE: %s: KEY: %s VALSZ: %" ROCKSDB_PRIszt
+ " status: '%s' blob_file: '%s'",
+ blob_file->PathName().c_str(), key.ToString().c_str(), value.size(),
+ s.ToString().c_str(), blob_file->DumpState().c_str());
+ }
+ }
+
+ RecordTick(statistics_, BLOB_DB_NUM_KEYS_WRITTEN);
+ RecordTick(statistics_, BLOB_DB_BYTES_WRITTEN, key.size() + value.size());
+ RecordInHistogram(statistics_, BLOB_DB_KEY_SIZE, key.size());
+ RecordInHistogram(statistics_, BLOB_DB_VALUE_SIZE, value.size());
+
+ return s;
+}
+
+Slice BlobDBImpl::GetCompressedSlice(const Slice& raw,
+ std::string* compression_output) const {
+ if (bdb_options_.compression == kNoCompression) {
+ return raw;
+ }
+ StopWatch compression_sw(clock_, statistics_, BLOB_DB_COMPRESSION_MICROS);
+ CompressionType type = bdb_options_.compression;
+ CompressionOptions opts;
+ CompressionContext context(type);
+ CompressionInfo info(opts, context, CompressionDict::GetEmptyDict(), type,
+ 0 /* sample_for_compression */);
+ CompressBlock(raw, info, &type, kBlockBasedTableVersionFormat, false,
+ compression_output, nullptr, nullptr);
+ return *compression_output;
+}
+
+Status BlobDBImpl::DecompressSlice(const Slice& compressed_value,
+ CompressionType compression_type,
+ PinnableSlice* value_output) const {
+ assert(compression_type != kNoCompression);
+
+ BlockContents contents;
+ auto cfh = static_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily());
+
+ {
+ StopWatch decompression_sw(clock_, statistics_,
+ BLOB_DB_DECOMPRESSION_MICROS);
+ UncompressionContext context(compression_type);
+ UncompressionInfo info(context, UncompressionDict::GetEmptyDict(),
+ compression_type);
+ Status s = UncompressBlockData(
+ info, compressed_value.data(), compressed_value.size(), &contents,
+ kBlockBasedTableVersionFormat, *(cfh->cfd()->ioptions()));
+ if (!s.ok()) {
+ return Status::Corruption("Unable to decompress blob.");
+ }
+ }
+
+ value_output->PinSelf(contents.data);
+
+ return Status::OK();
+}
+
+Status BlobDBImpl::CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id, std::vector<std::string>* const output_file_names,
+ CompactionJobInfo* compaction_job_info) {
+ // Note: we need CompactionJobInfo to be able to track updates to the
+ // blob file <-> SST mappings, so we provide one if the user hasn't,
+ // assuming that GC is enabled.
+ CompactionJobInfo info{};
+ if (bdb_options_.enable_garbage_collection && !compaction_job_info) {
+ compaction_job_info = &info;
+ }
+
+ const Status s =
+ db_->CompactFiles(compact_options, input_file_names, output_level,
+ output_path_id, output_file_names, compaction_job_info);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (bdb_options_.enable_garbage_collection) {
+ assert(compaction_job_info);
+ ProcessCompactionJobInfo(*compaction_job_info);
+ }
+
+ return s;
+}
+
+void BlobDBImpl::GetCompactionContextCommon(BlobCompactionContext* context) {
+ assert(context);
+
+ context->blob_db_impl = this;
+ context->next_file_number = next_file_number_.load();
+ context->current_blob_files.clear();
+ for (auto& p : blob_files_) {
+ context->current_blob_files.insert(p.first);
+ }
+ context->fifo_eviction_seq = fifo_eviction_seq_;
+ context->evict_expiration_up_to = evict_expiration_up_to_;
+}
+
+void BlobDBImpl::GetCompactionContext(BlobCompactionContext* context) {
+ assert(context);
+
+ ReadLock l(&mutex_);
+ GetCompactionContextCommon(context);
+}
+
+void BlobDBImpl::GetCompactionContext(BlobCompactionContext* context,
+ BlobCompactionContextGC* context_gc) {
+ assert(context);
+ assert(context_gc);
+
+ ReadLock l(&mutex_);
+ GetCompactionContextCommon(context);
+
+ if (!live_imm_non_ttl_blob_files_.empty()) {
+ auto it = live_imm_non_ttl_blob_files_.begin();
+ std::advance(it, bdb_options_.garbage_collection_cutoff *
+ live_imm_non_ttl_blob_files_.size());
+ context_gc->cutoff_file_number = it != live_imm_non_ttl_blob_files_.end()
+ ? it->first
+ : std::numeric_limits<uint64_t>::max();
+ }
+}
+
+void BlobDBImpl::UpdateLiveSSTSize() {
+ uint64_t live_sst_size = 0;
+ bool ok = GetIntProperty(DB::Properties::kLiveSstFilesSize, &live_sst_size);
+ if (ok) {
+ live_sst_size_.store(live_sst_size);
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Updated total SST file size: %" PRIu64 " bytes.",
+ live_sst_size);
+ } else {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to update total SST file size after flush or compaction.");
+ }
+ {
+ // Trigger FIFO eviction if needed.
+ MutexLock l(&write_mutex_);
+ Status s = CheckSizeAndEvictBlobFiles(0, true /*force*/);
+ if (s.IsNoSpace()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "DB grow out-of-space after SST size updated. Current live"
+ " SST size: %" PRIu64
+ " , current blob files size: %" PRIu64 ".",
+ live_sst_size_.load(), total_blob_size_.load());
+ }
+ }
+}
+
+Status BlobDBImpl::CheckSizeAndEvictBlobFiles(uint64_t blob_size,
+ bool force_evict) {
+ write_mutex_.AssertHeld();
+
+ uint64_t live_sst_size = live_sst_size_.load();
+ if (bdb_options_.max_db_size == 0 ||
+ live_sst_size + total_blob_size_.load() + blob_size <=
+ bdb_options_.max_db_size) {
+ return Status::OK();
+ }
+
+ if (bdb_options_.is_fifo == false ||
+ (!force_evict && live_sst_size + blob_size > bdb_options_.max_db_size)) {
+ // FIFO eviction is disabled, or no space to insert new blob even we evict
+ // all blob files.
+ return Status::NoSpace(
+ "Write failed, as writing it would exceed max_db_size limit.");
+ }
+
+ std::vector<std::shared_ptr<BlobFile>> candidate_files;
+ CopyBlobFiles(&candidate_files);
+ std::sort(candidate_files.begin(), candidate_files.end(),
+ BlobFileComparator());
+ fifo_eviction_seq_ = GetLatestSequenceNumber();
+
+ WriteLock l(&mutex_);
+
+ while (!candidate_files.empty() &&
+ live_sst_size + total_blob_size_.load() + blob_size >
+ bdb_options_.max_db_size) {
+ std::shared_ptr<BlobFile> blob_file = candidate_files.back();
+ candidate_files.pop_back();
+ WriteLock file_lock(&blob_file->mutex_);
+ if (blob_file->Obsolete()) {
+ // File already obsoleted by someone else.
+ assert(blob_file->Immutable());
+ continue;
+ }
+ // FIFO eviction can evict open blob files.
+ if (!blob_file->Immutable()) {
+ Status s = CloseBlobFile(blob_file);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ assert(blob_file->Immutable());
+ auto expiration_range = blob_file->GetExpirationRange();
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Evict oldest blob file since DB out of space. Current "
+ "live SST file size: %" PRIu64 ", total blob size: %" PRIu64
+ ", max db size: %" PRIu64 ", evicted blob file #%" PRIu64
+ ".",
+ live_sst_size, total_blob_size_.load(),
+ bdb_options_.max_db_size, blob_file->BlobFileNumber());
+ ObsoleteBlobFile(blob_file, fifo_eviction_seq_, true /*update_size*/);
+ evict_expiration_up_to_ = expiration_range.first;
+ RecordTick(statistics_, BLOB_DB_FIFO_NUM_FILES_EVICTED);
+ RecordTick(statistics_, BLOB_DB_FIFO_NUM_KEYS_EVICTED,
+ blob_file->BlobCount());
+ RecordTick(statistics_, BLOB_DB_FIFO_BYTES_EVICTED,
+ blob_file->GetFileSize());
+ TEST_SYNC_POINT("BlobDBImpl::EvictOldestBlobFile:Evicted");
+ }
+ if (live_sst_size + total_blob_size_.load() + blob_size >
+ bdb_options_.max_db_size) {
+ return Status::NoSpace(
+ "Write failed, as writing it would exceed max_db_size limit.");
+ }
+ return Status::OK();
+}
+
+Status BlobDBImpl::AppendBlob(const std::shared_ptr<BlobFile>& bfile,
+ const std::string& headerbuf, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ std::string* index_entry) {
+ Status s;
+ uint64_t blob_offset = 0;
+ uint64_t key_offset = 0;
+ {
+ WriteLock lockbfile_w(&bfile->mutex_);
+ std::shared_ptr<BlobLogWriter> writer;
+ s = CheckOrCreateWriterLocked(bfile, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // write the blob to the blob log.
+ s = writer->EmitPhysicalRecord(headerbuf, key, value, &key_offset,
+ &blob_offset);
+ }
+
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Invalid status in AppendBlob: %s status: '%s'",
+ bfile->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ uint64_t size_put = headerbuf.size() + key.size() + value.size();
+ bfile->BlobRecordAdded(size_put);
+ total_blob_size_ += size_put;
+
+ if (expiration == kNoExpiration) {
+ BlobIndex::EncodeBlob(index_entry, bfile->BlobFileNumber(), blob_offset,
+ value.size(), bdb_options_.compression);
+ } else {
+ BlobIndex::EncodeBlobTTL(index_entry, expiration, bfile->BlobFileNumber(),
+ blob_offset, value.size(),
+ bdb_options_.compression);
+ }
+
+ return s;
+}
+
+std::vector<Status> BlobDBImpl::MultiGet(const ReadOptions& read_options,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) {
+ StopWatch multiget_sw(clock_, statistics_, BLOB_DB_MULTIGET_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_MULTIGET);
+ // Get a snapshot to avoid blob file get deleted between we
+ // fetch and index entry and reading from the file.
+ ReadOptions ro(read_options);
+ bool snapshot_created = SetSnapshotIfNeeded(&ro);
+
+ std::vector<Status> statuses;
+ statuses.reserve(keys.size());
+ values->clear();
+ values->reserve(keys.size());
+ PinnableSlice value;
+ for (size_t i = 0; i < keys.size(); i++) {
+ statuses.push_back(Get(ro, DefaultColumnFamily(), keys[i], &value));
+ values->push_back(value.ToString());
+ value.Reset();
+ }
+ if (snapshot_created) {
+ db_->ReleaseSnapshot(ro.snapshot);
+ }
+ return statuses;
+}
+
+bool BlobDBImpl::SetSnapshotIfNeeded(ReadOptions* read_options) {
+ assert(read_options != nullptr);
+ if (read_options->snapshot != nullptr) {
+ return false;
+ }
+ read_options->snapshot = db_->GetSnapshot();
+ return true;
+}
+
+Status BlobDBImpl::GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value, uint64_t* expiration) {
+ assert(value);
+
+ BlobIndex blob_index;
+ Status s = blob_index.DecodeFrom(index_entry);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (blob_index.HasTTL() && blob_index.expiration() <= EpochNow()) {
+ return Status::NotFound("Key expired");
+ }
+
+ if (expiration != nullptr) {
+ if (blob_index.HasTTL()) {
+ *expiration = blob_index.expiration();
+ } else {
+ *expiration = kNoExpiration;
+ }
+ }
+
+ if (blob_index.IsInlined()) {
+ // TODO(yiwu): If index_entry is a PinnableSlice, we can also pin the same
+ // memory buffer to avoid extra copy.
+ value->PinSelf(blob_index.value());
+ return Status::OK();
+ }
+
+ CompressionType compression_type = kNoCompression;
+ s = GetRawBlobFromFile(key, blob_index.file_number(), blob_index.offset(),
+ blob_index.size(), value, &compression_type);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (compression_type != kNoCompression) {
+ s = DecompressSlice(*value, compression_type, value);
+ if (!s.ok()) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Uncompression error during blob read from file: %" PRIu64
+ " blob_offset: %" PRIu64 " blob_size: %" PRIu64
+ " key: %s status: '%s'",
+ blob_index.file_number(), blob_index.offset(), blob_index.size(),
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+ }
+ return s;
+ }
+ }
+
+ return Status::OK();
+}
+
+Status BlobDBImpl::GetRawBlobFromFile(const Slice& key, uint64_t file_number,
+ uint64_t offset, uint64_t size,
+ PinnableSlice* value,
+ CompressionType* compression_type) {
+ assert(value);
+ assert(compression_type);
+ assert(*compression_type == kNoCompression);
+
+ if (!size) {
+ value->PinSelf("");
+ return Status::OK();
+ }
+
+ // offset has to have certain min, as we will read CRC
+ // later from the Blob Header, which needs to be also a
+ // valid offset.
+ if (offset <
+ (BlobLogHeader::kSize + BlobLogRecord::kHeaderSize + key.size())) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Invalid blob index file_number: %" PRIu64
+ " blob_offset: %" PRIu64 " blob_size: %" PRIu64
+ " key: %s",
+ file_number, offset, size,
+ key.ToString(/* output_hex */ true).c_str());
+ }
+
+ return Status::NotFound("Invalid blob offset");
+ }
+
+ std::shared_ptr<BlobFile> blob_file;
+
+ {
+ ReadLock rl(&mutex_);
+ auto it = blob_files_.find(file_number);
+
+ // file was deleted
+ if (it == blob_files_.end()) {
+ return Status::NotFound("Blob Not Found as blob file missing");
+ }
+
+ blob_file = it->second;
+ }
+
+ *compression_type = blob_file->GetCompressionType();
+
+ // takes locks when called
+ std::shared_ptr<RandomAccessFileReader> reader;
+ Status s = GetBlobFileReader(blob_file, &reader);
+ if (!s.ok()) {
+ return s;
+ }
+
+ assert(offset >= key.size() + sizeof(uint32_t));
+ const uint64_t record_offset = offset - key.size() - sizeof(uint32_t);
+ const uint64_t record_size = sizeof(uint32_t) + key.size() + size;
+
+ // Allocate the buffer. This is safe in C++11
+ std::string buf;
+ AlignedBuf aligned_buf;
+
+ // A partial blob record contain checksum, key and value.
+ Slice blob_record;
+
+ {
+ StopWatch read_sw(clock_, statistics_, BLOB_DB_BLOB_FILE_READ_MICROS);
+ // TODO: rate limit old blob DB file reads.
+ if (reader->use_direct_io()) {
+ s = reader->Read(IOOptions(), record_offset,
+ static_cast<size_t>(record_size), &blob_record, nullptr,
+ &aligned_buf, Env::IO_TOTAL /* rate_limiter_priority */);
+ } else {
+ buf.reserve(static_cast<size_t>(record_size));
+ s = reader->Read(IOOptions(), record_offset,
+ static_cast<size_t>(record_size), &blob_record, &buf[0],
+ nullptr, Env::IO_TOTAL /* rate_limiter_priority */);
+ }
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_READ, blob_record.size());
+ }
+
+ if (!s.ok()) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Failed to read blob from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key_size: %" ROCKSDB_PRIszt ", status: '%s'",
+ file_number, offset, size, key.size(), s.ToString().c_str());
+ return s;
+ }
+
+ if (blob_record.size() != record_size) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Failed to read blob from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key_size: %" ROCKSDB_PRIszt
+ ", read %" ROCKSDB_PRIszt " bytes, expected %" PRIu64 " bytes",
+ file_number, offset, size, key.size(), blob_record.size(), record_size);
+
+ return Status::Corruption("Failed to retrieve blob from blob index.");
+ }
+
+ Slice crc_slice(blob_record.data(), sizeof(uint32_t));
+ Slice blob_value(blob_record.data() + sizeof(uint32_t) + key.size(),
+ static_cast<size_t>(size));
+
+ uint32_t crc_exp = 0;
+ if (!GetFixed32(&crc_slice, &crc_exp)) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Unable to decode CRC from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key size: %" ROCKSDB_PRIszt ", status: '%s'",
+ file_number, offset, size, key.size(), s.ToString().c_str());
+ return Status::Corruption("Unable to decode checksum.");
+ }
+
+ uint32_t crc = crc32c::Value(blob_record.data() + sizeof(uint32_t),
+ blob_record.size() - sizeof(uint32_t));
+ crc = crc32c::Mask(crc); // Adjust for storage
+ if (crc != crc_exp) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Blob crc mismatch file: %" PRIu64 " blob_offset: %" PRIu64
+ " blob_size: %" PRIu64 " key: %s status: '%s'",
+ file_number, offset, size,
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+ }
+
+ return Status::Corruption("Corruption. Blob CRC mismatch");
+ }
+
+ value->PinSelf(blob_value);
+
+ return Status::OK();
+}
+
+Status BlobDBImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) {
+ return Get(read_options, column_family, key, value,
+ static_cast<uint64_t*>(nullptr) /*expiration*/);
+}
+
+Status BlobDBImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ StopWatch get_sw(clock_, statistics_, BLOB_DB_GET_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_GET);
+ return GetImpl(read_options, column_family, key, value, expiration);
+}
+
+Status BlobDBImpl::GetImpl(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ // Get a snapshot to avoid blob file get deleted between we
+ // fetch and index entry and reading from the file.
+ // TODO(yiwu): For Get() retry if file not found would be a simpler strategy.
+ ReadOptions ro(read_options);
+ bool snapshot_created = SetSnapshotIfNeeded(&ro);
+
+ PinnableSlice index_entry;
+ Status s;
+ bool is_blob_index = false;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = &index_entry;
+ get_impl_options.is_blob_index = &is_blob_index;
+ s = db_impl_->GetImpl(ro, key, get_impl_options);
+ if (expiration != nullptr) {
+ *expiration = kNoExpiration;
+ }
+ RecordTick(statistics_, BLOB_DB_NUM_KEYS_READ);
+ if (s.ok()) {
+ if (is_blob_index) {
+ s = GetBlobValue(key, index_entry, value, expiration);
+ } else {
+ // The index entry is the value itself in this case.
+ value->PinSelf(index_entry);
+ }
+ RecordTick(statistics_, BLOB_DB_BYTES_READ, value->size());
+ }
+ if (snapshot_created) {
+ db_->ReleaseSnapshot(ro.snapshot);
+ }
+ return s;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::SanityCheck(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ ReadLock rl(&mutex_);
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Starting Sanity Check");
+ ROCKS_LOG_INFO(db_options_.info_log, "Number of files %" ROCKSDB_PRIszt,
+ blob_files_.size());
+ ROCKS_LOG_INFO(db_options_.info_log, "Number of open files %" ROCKSDB_PRIszt,
+ open_ttl_files_.size());
+
+ for (const auto& blob_file : open_ttl_files_) {
+ (void)blob_file;
+ assert(!blob_file->Immutable());
+ }
+
+ for (const auto& pair : live_imm_non_ttl_blob_files_) {
+ const auto& blob_file = pair.second;
+ (void)blob_file;
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+ }
+
+ uint64_t now = EpochNow();
+
+ for (auto blob_file_pair : blob_files_) {
+ auto blob_file = blob_file_pair.second;
+ std::ostringstream buf;
+
+ buf << "Blob file " << blob_file->BlobFileNumber() << ", size "
+ << blob_file->GetFileSize() << ", blob count " << blob_file->BlobCount()
+ << ", immutable " << blob_file->Immutable();
+
+ if (blob_file->HasTTL()) {
+ ExpirationRange expiration_range;
+ {
+ ReadLock file_lock(&blob_file->mutex_);
+ expiration_range = blob_file->GetExpirationRange();
+ }
+ buf << ", expiration range (" << expiration_range.first << ", "
+ << expiration_range.second << ")";
+
+ if (!blob_file->Obsolete()) {
+ buf << ", expire in " << (expiration_range.second - now) << "seconds";
+ }
+ }
+ if (blob_file->Obsolete()) {
+ buf << ", obsolete at " << blob_file->GetObsoleteSequence();
+ }
+ buf << ".";
+ ROCKS_LOG_INFO(db_options_.info_log, "%s", buf.str().c_str());
+ }
+
+ // reschedule
+ return std::make_pair(true, -1);
+}
+
+Status BlobDBImpl::CloseBlobFile(std::shared_ptr<BlobFile> bfile) {
+ TEST_SYNC_POINT("BlobDBImpl::CloseBlobFile");
+ assert(bfile);
+ assert(!bfile->Immutable());
+ assert(!bfile->Obsolete());
+
+ if (bfile->HasTTL() || bfile == open_non_ttl_file_) {
+ write_mutex_.AssertHeld();
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Closing blob file %" PRIu64 ". Path: %s",
+ bfile->BlobFileNumber(), bfile->PathName().c_str());
+
+ const SequenceNumber sequence = GetLatestSequenceNumber();
+
+ const Status s = bfile->WriteFooterAndCloseLocked(sequence);
+
+ if (s.ok()) {
+ total_blob_size_ += BlobLogFooter::kSize;
+ } else {
+ bfile->MarkImmutable(sequence);
+
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to close blob file %" PRIu64 "with error: %s",
+ bfile->BlobFileNumber(), s.ToString().c_str());
+ }
+
+ if (bfile->HasTTL()) {
+ size_t erased __attribute__((__unused__));
+ erased = open_ttl_files_.erase(bfile);
+ } else {
+ if (bfile == open_non_ttl_file_) {
+ open_non_ttl_file_ = nullptr;
+ }
+
+ const uint64_t blob_file_number = bfile->BlobFileNumber();
+ auto it = live_imm_non_ttl_blob_files_.lower_bound(blob_file_number);
+ assert(it == live_imm_non_ttl_blob_files_.end() ||
+ it->first != blob_file_number);
+ live_imm_non_ttl_blob_files_.insert(
+ it, std::map<uint64_t, std::shared_ptr<BlobFile>>::value_type(
+ blob_file_number, bfile));
+ }
+
+ return s;
+}
+
+Status BlobDBImpl::CloseBlobFileIfNeeded(std::shared_ptr<BlobFile>& bfile) {
+ write_mutex_.AssertHeld();
+
+ // atomic read
+ if (bfile->GetFileSize() < bdb_options_.blob_file_size) {
+ return Status::OK();
+ }
+
+ WriteLock lock(&mutex_);
+ WriteLock file_lock(&bfile->mutex_);
+
+ assert(!bfile->Obsolete() || bfile->Immutable());
+ if (bfile->Immutable()) {
+ return Status::OK();
+ }
+
+ return CloseBlobFile(bfile);
+}
+
+void BlobDBImpl::ObsoleteBlobFile(std::shared_ptr<BlobFile> blob_file,
+ SequenceNumber obsolete_seq,
+ bool update_size) {
+ assert(blob_file->Immutable());
+ assert(!blob_file->Obsolete());
+
+ // Should hold write lock of mutex_ or during DB open.
+ blob_file->MarkObsolete(obsolete_seq);
+ obsolete_files_.push_back(blob_file);
+ assert(total_blob_size_.load() >= blob_file->GetFileSize());
+ if (update_size) {
+ total_blob_size_ -= blob_file->GetFileSize();
+ }
+}
+
+bool BlobDBImpl::VisibleToActiveSnapshot(
+ const std::shared_ptr<BlobFile>& bfile) {
+ assert(bfile->Obsolete());
+
+ // We check whether the oldest snapshot is no less than the last sequence
+ // by the time the blob file become obsolete. If so, the blob file is not
+ // visible to all existing snapshots.
+ //
+ // If we keep track of the earliest sequence of the keys in the blob file,
+ // we could instead check if there's a snapshot falls in range
+ // [earliest_sequence, obsolete_sequence). But doing so will make the
+ // implementation more complicated.
+ SequenceNumber obsolete_sequence = bfile->GetObsoleteSequence();
+ SequenceNumber oldest_snapshot = kMaxSequenceNumber;
+ {
+ // Need to lock DBImpl mutex before access snapshot list.
+ InstrumentedMutexLock l(db_impl_->mutex());
+ auto& snapshots = db_impl_->snapshots();
+ if (!snapshots.empty()) {
+ oldest_snapshot = snapshots.oldest()->GetSequenceNumber();
+ }
+ }
+ bool visible = oldest_snapshot < obsolete_sequence;
+ if (visible) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Obsolete blob file %" PRIu64 " (obsolete at %" PRIu64
+ ") visible to oldest snapshot %" PRIu64 ".",
+ bfile->BlobFileNumber(), obsolete_sequence, oldest_snapshot);
+ }
+ return visible;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::EvictExpiredFiles(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:0");
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:1");
+
+ std::vector<std::shared_ptr<BlobFile>> process_files;
+ uint64_t now = EpochNow();
+ {
+ ReadLock rl(&mutex_);
+ for (auto p : blob_files_) {
+ auto& blob_file = p.second;
+ ReadLock file_lock(&blob_file->mutex_);
+ if (blob_file->HasTTL() && !blob_file->Obsolete() &&
+ blob_file->GetExpirationRange().second <= now) {
+ process_files.push_back(blob_file);
+ }
+ }
+ }
+
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:2");
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:3");
+ TEST_SYNC_POINT_CALLBACK("BlobDBImpl::EvictExpiredFiles:cb", nullptr);
+
+ SequenceNumber seq = GetLatestSequenceNumber();
+ {
+ MutexLock l(&write_mutex_);
+ WriteLock lock(&mutex_);
+ for (auto& blob_file : process_files) {
+ WriteLock file_lock(&blob_file->mutex_);
+
+ // Need to double check if the file is obsolete.
+ if (blob_file->Obsolete()) {
+ assert(blob_file->Immutable());
+ continue;
+ }
+
+ if (!blob_file->Immutable()) {
+ CloseBlobFile(blob_file);
+ }
+
+ assert(blob_file->Immutable());
+
+ ObsoleteBlobFile(blob_file, seq, true /*update_size*/);
+ }
+ }
+
+ return std::make_pair(true, -1);
+}
+
+Status BlobDBImpl::SyncBlobFiles() {
+ MutexLock l(&write_mutex_);
+
+ std::vector<std::shared_ptr<BlobFile>> process_files;
+ {
+ ReadLock rl(&mutex_);
+ for (auto fitr : open_ttl_files_) {
+ process_files.push_back(fitr);
+ }
+ if (open_non_ttl_file_ != nullptr) {
+ process_files.push_back(open_non_ttl_file_);
+ }
+ }
+
+ Status s;
+ for (auto& blob_file : process_files) {
+ s = blob_file->Fsync();
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to sync blob file %" PRIu64 ", status: %s",
+ blob_file->BlobFileNumber(), s.ToString().c_str());
+ return s;
+ }
+ }
+
+ s = dir_ent_->FsyncWithDirOptions(IOOptions(), nullptr, DirFsyncOptions());
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to sync blob directory, status: %s",
+ s.ToString().c_str());
+ }
+ return s;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::ReclaimOpenFiles(bool aborted) {
+ if (aborted) return std::make_pair(false, -1);
+
+ if (open_file_count_.load() < kOpenFilesTrigger) {
+ return std::make_pair(true, -1);
+ }
+
+ // in the future, we should sort by last_access_
+ // instead of closing every file
+ ReadLock rl(&mutex_);
+ for (auto const& ent : blob_files_) {
+ auto bfile = ent.second;
+ if (bfile->last_access_.load() == -1) continue;
+
+ WriteLock lockbfile_w(&bfile->mutex_);
+ CloseRandomAccessLocked(bfile);
+ }
+
+ return std::make_pair(true, -1);
+}
+
+std::pair<bool, int64_t> BlobDBImpl::DeleteObsoleteFiles(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ MutexLock delete_file_lock(&delete_file_mutex_);
+ if (disable_file_deletions_ > 0) {
+ return std::make_pair(true, -1);
+ }
+
+ std::list<std::shared_ptr<BlobFile>> tobsolete;
+ {
+ WriteLock wl(&mutex_);
+ if (obsolete_files_.empty()) {
+ return std::make_pair(true, -1);
+ }
+ tobsolete.swap(obsolete_files_);
+ }
+
+ bool file_deleted = false;
+ for (auto iter = tobsolete.begin(); iter != tobsolete.end();) {
+ auto bfile = *iter;
+ {
+ ReadLock lockbfile_r(&bfile->mutex_);
+ if (VisibleToActiveSnapshot(bfile)) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Could not delete file due to snapshot failure %s",
+ bfile->PathName().c_str());
+ ++iter;
+ continue;
+ }
+ }
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Will delete file due to snapshot success %s",
+ bfile->PathName().c_str());
+
+ {
+ WriteLock wl(&mutex_);
+ blob_files_.erase(bfile->BlobFileNumber());
+ }
+
+ Status s = DeleteDBFile(&(db_impl_->immutable_db_options()),
+ bfile->PathName(), blob_dir_, true,
+ /*force_fg=*/false);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "File failed to be deleted as obsolete %s",
+ bfile->PathName().c_str());
+ ++iter;
+ continue;
+ }
+
+ file_deleted = true;
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "File deleted as obsolete from blob dir %s",
+ bfile->PathName().c_str());
+
+ iter = tobsolete.erase(iter);
+ }
+
+ // directory change. Fsync
+ if (file_deleted) {
+ Status s = dir_ent_->FsyncWithDirOptions(
+ IOOptions(), nullptr,
+ DirFsyncOptions(DirFsyncOptions::FsyncReason::kFileDeleted));
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log, "Failed to sync dir %s: %s",
+ blob_dir_.c_str(), s.ToString().c_str());
+ }
+ }
+
+ // put files back into obsolete if for some reason, delete failed
+ if (!tobsolete.empty()) {
+ WriteLock wl(&mutex_);
+ for (auto bfile : tobsolete) {
+ blob_files_.insert(std::make_pair(bfile->BlobFileNumber(), bfile));
+ obsolete_files_.push_front(bfile);
+ }
+ }
+
+ return std::make_pair(!aborted, -1);
+}
+
+void BlobDBImpl::CopyBlobFiles(
+ std::vector<std::shared_ptr<BlobFile>>* bfiles_copy) {
+ ReadLock rl(&mutex_);
+ for (auto const& p : blob_files_) {
+ bfiles_copy->push_back(p.second);
+ }
+}
+
+Iterator* BlobDBImpl::NewIterator(const ReadOptions& read_options) {
+ auto* cfd =
+ static_cast_with_check<ColumnFamilyHandleImpl>(DefaultColumnFamily())
+ ->cfd();
+ // Get a snapshot to avoid blob file get deleted between we
+ // fetch and index entry and reading from the file.
+ ManagedSnapshot* own_snapshot = nullptr;
+ const Snapshot* snapshot = read_options.snapshot;
+ if (snapshot == nullptr) {
+ own_snapshot = new ManagedSnapshot(db_);
+ snapshot = own_snapshot->snapshot();
+ }
+ auto* iter = db_impl_->NewIteratorImpl(
+ read_options, cfd, snapshot->GetSequenceNumber(),
+ nullptr /*read_callback*/, true /*expose_blob_index*/);
+ return new BlobDBIterator(own_snapshot, iter, this, clock_, statistics_);
+}
+
+Status DestroyBlobDB(const std::string& dbname, const Options& options,
+ const BlobDBOptions& bdb_options) {
+ const ImmutableDBOptions soptions(SanitizeOptions(dbname, options));
+ Env* env = soptions.env;
+
+ Status status;
+ std::string blobdir;
+ blobdir = (bdb_options.path_relative) ? dbname + "/" + bdb_options.blob_dir
+ : bdb_options.blob_dir;
+
+ std::vector<std::string> filenames;
+ if (env->GetChildren(blobdir, &filenames).ok()) {
+ for (const auto& f : filenames) {
+ uint64_t number;
+ FileType type;
+ if (ParseFileName(f, &number, &type) && type == kBlobFile) {
+ Status del = DeleteDBFile(&soptions, blobdir + "/" + f, blobdir, true,
+ /*force_fg=*/false);
+ if (status.ok() && !del.ok()) {
+ status = del;
+ }
+ }
+ }
+ // TODO: What to do if we cannot delete the directory?
+ env->DeleteDir(blobdir).PermitUncheckedError();
+ }
+ Status destroy = DestroyDB(dbname, options);
+ if (status.ok() && !destroy.ok()) {
+ status = destroy;
+ }
+
+ return status;
+}
+
+#ifndef NDEBUG
+Status BlobDBImpl::TEST_GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value) {
+ return GetBlobValue(key, index_entry, value);
+}
+
+void BlobDBImpl::TEST_AddDummyBlobFile(uint64_t blob_file_number,
+ SequenceNumber immutable_sequence) {
+ auto blob_file = std::make_shared<BlobFile>(this, blob_dir_, blob_file_number,
+ db_options_.info_log.get());
+ blob_file->MarkImmutable(immutable_sequence);
+
+ blob_files_[blob_file_number] = blob_file;
+ live_imm_non_ttl_blob_files_[blob_file_number] = blob_file;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetBlobFiles() const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> blob_files;
+ for (auto& p : blob_files_) {
+ blob_files.emplace_back(p.second);
+ }
+ return blob_files;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetLiveImmNonTTLFiles()
+ const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> live_imm_non_ttl_files;
+ for (const auto& pair : live_imm_non_ttl_blob_files_) {
+ live_imm_non_ttl_files.emplace_back(pair.second);
+ }
+ return live_imm_non_ttl_files;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetObsoleteFiles()
+ const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> obsolete_files;
+ for (auto& bfile : obsolete_files_) {
+ obsolete_files.emplace_back(bfile);
+ }
+ return obsolete_files;
+}
+
+void BlobDBImpl::TEST_DeleteObsoleteFiles() {
+ DeleteObsoleteFiles(false /*abort*/);
+}
+
+Status BlobDBImpl::TEST_CloseBlobFile(std::shared_ptr<BlobFile>& bfile) {
+ MutexLock l(&write_mutex_);
+ WriteLock lock(&mutex_);
+ WriteLock file_lock(&bfile->mutex_);
+
+ return CloseBlobFile(bfile);
+}
+
+void BlobDBImpl::TEST_ObsoleteBlobFile(std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq,
+ bool update_size) {
+ return ObsoleteBlobFile(blob_file, obsolete_seq, update_size);
+}
+
+void BlobDBImpl::TEST_EvictExpiredFiles() {
+ EvictExpiredFiles(false /*abort*/);
+}
+
+uint64_t BlobDBImpl::TEST_live_sst_size() { return live_sst_size_.load(); }
+
+void BlobDBImpl::TEST_InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files) {
+ InitializeBlobFileToSstMapping(live_files);
+}
+
+void BlobDBImpl::TEST_ProcessFlushJobInfo(const FlushJobInfo& info) {
+ ProcessFlushJobInfo(info);
+}
+
+void BlobDBImpl::TEST_ProcessCompactionJobInfo(const CompactionJobInfo& info) {
+ ProcessCompactionJobInfo(info);
+}
+
+#endif // !NDEBUG
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl.h b/src/rocksdb/utilities/blob_db/blob_db_impl.h
new file mode 100644
index 000000000..0b4dbf5e5
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl.h
@@ -0,0 +1,503 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <condition_variable>
+#include <limits>
+#include <list>
+#include <memory>
+#include <set>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "db/blob/blob_log_format.h"
+#include "db/blob/blob_log_writer.h"
+#include "db/db_iter.h"
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/db.h"
+#include "rocksdb/file_system.h"
+#include "rocksdb/listener.h"
+#include "rocksdb/options.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/wal_filter.h"
+#include "util/mutexlock.h"
+#include "util/timer_queue.h"
+#include "utilities/blob_db/blob_db.h"
+#include "utilities/blob_db/blob_file.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class DBImpl;
+class ColumnFamilyHandle;
+class ColumnFamilyData;
+class SystemClock;
+
+struct FlushJobInfo;
+
+namespace blob_db {
+
+struct BlobCompactionContext;
+struct BlobCompactionContextGC;
+class BlobDBImpl;
+class BlobFile;
+
+// Comparator to sort "TTL" aware Blob files based on the lower value of
+// TTL range.
+struct BlobFileComparatorTTL {
+ bool operator()(const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const;
+};
+
+struct BlobFileComparator {
+ bool operator()(const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const;
+};
+
+/**
+ * The implementation class for BlobDB. It manages the blob logs, which
+ * are sequentially written files. Blob logs can be of the TTL or non-TTL
+ * varieties; the former are cleaned up when they expire, while the latter
+ * are (optionally) garbage collected.
+ */
+class BlobDBImpl : public BlobDB {
+ friend class BlobFile;
+ friend class BlobDBIterator;
+ friend class BlobDBListener;
+ friend class BlobDBListenerGC;
+ friend class BlobIndexCompactionFilterBase;
+ friend class BlobIndexCompactionFilterGC;
+
+ public:
+ // deletions check period
+ static constexpr uint32_t kDeleteCheckPeriodMillisecs = 2 * 1000;
+
+ // sanity check task
+ static constexpr uint32_t kSanityCheckPeriodMillisecs = 20 * 60 * 1000;
+
+ // how many random access open files can we tolerate
+ static constexpr uint32_t kOpenFilesTrigger = 100;
+
+ // how often to schedule reclaim open files.
+ static constexpr uint32_t kReclaimOpenFilesPeriodMillisecs = 1 * 1000;
+
+ // how often to schedule delete obs files periods
+ static constexpr uint32_t kDeleteObsoleteFilesPeriodMillisecs = 10 * 1000;
+
+ // how often to schedule expired files eviction.
+ static constexpr uint32_t kEvictExpiredFilesPeriodMillisecs = 10 * 1000;
+
+ // when should oldest file be evicted:
+ // on reaching 90% of blob_dir_size
+ static constexpr double kEvictOldestFileAtSize = 0.9;
+
+ using BlobDB::Put;
+ Status Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) override;
+
+ using BlobDB::Get;
+ Status Get(const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) override;
+
+ Status Get(const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value,
+ uint64_t* expiration) override;
+
+ using BlobDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& read_options) override;
+
+ using BlobDB::NewIterators;
+ virtual Status NewIterators(
+ const ReadOptions& /*read_options*/,
+ const std::vector<ColumnFamilyHandle*>& /*column_families*/,
+ std::vector<Iterator*>* /*iterators*/) override {
+ return Status::NotSupported("Not implemented");
+ }
+
+ using BlobDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& read_options, const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ using BlobDB::Write;
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ virtual Status Close() override;
+
+ using BlobDB::PutWithTTL;
+ Status PutWithTTL(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t ttl) override;
+
+ using BlobDB::PutUntil;
+ Status PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) override;
+
+ using BlobDB::CompactFiles;
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override;
+
+ BlobDBOptions GetBlobDBOptions() const override;
+
+ BlobDBImpl(const std::string& dbname, const BlobDBOptions& bdb_options,
+ const DBOptions& db_options,
+ const ColumnFamilyOptions& cf_options);
+
+ virtual Status DisableFileDeletions() override;
+
+ virtual Status EnableFileDeletions(bool force) override;
+
+ virtual Status GetLiveFiles(std::vector<std::string>&,
+ uint64_t* manifest_file_size,
+ bool flush_memtable = true) override;
+ virtual void GetLiveFilesMetaData(std::vector<LiveFileMetaData>*) override;
+
+ ~BlobDBImpl();
+
+ Status Open(std::vector<ColumnFamilyHandle*>* handles);
+
+ Status SyncBlobFiles() override;
+
+ // Common part of the two GetCompactionContext methods below.
+ // REQUIRES: read lock on mutex_
+ void GetCompactionContextCommon(BlobCompactionContext* context);
+
+ void GetCompactionContext(BlobCompactionContext* context);
+ void GetCompactionContext(BlobCompactionContext* context,
+ BlobCompactionContextGC* context_gc);
+
+#ifndef NDEBUG
+ Status TEST_GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value);
+
+ void TEST_AddDummyBlobFile(uint64_t blob_file_number,
+ SequenceNumber immutable_sequence);
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetBlobFiles() const;
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetLiveImmNonTTLFiles() const;
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetObsoleteFiles() const;
+
+ Status TEST_CloseBlobFile(std::shared_ptr<BlobFile>& bfile);
+
+ void TEST_ObsoleteBlobFile(std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq = 0,
+ bool update_size = true);
+
+ void TEST_EvictExpiredFiles();
+
+ void TEST_DeleteObsoleteFiles();
+
+ uint64_t TEST_live_sst_size();
+
+ const std::string& TEST_blob_dir() const { return blob_dir_; }
+
+ void TEST_InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files);
+
+ void TEST_ProcessFlushJobInfo(const FlushJobInfo& info);
+
+ void TEST_ProcessCompactionJobInfo(const CompactionJobInfo& info);
+
+#endif // !NDEBUG
+
+ private:
+ class BlobInserter;
+
+ // Create a snapshot if there isn't one in read options.
+ // Return true if a snapshot is created.
+ bool SetSnapshotIfNeeded(ReadOptions* read_options);
+
+ Status GetImpl(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration = nullptr);
+
+ Status GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value, uint64_t* expiration = nullptr);
+
+ Status GetRawBlobFromFile(const Slice& key, uint64_t file_number,
+ uint64_t offset, uint64_t size,
+ PinnableSlice* value,
+ CompressionType* compression_type);
+
+ Slice GetCompressedSlice(const Slice& raw,
+ std::string* compression_output) const;
+
+ Status DecompressSlice(const Slice& compressed_value,
+ CompressionType compression_type,
+ PinnableSlice* value_output) const;
+
+ // Close a file by appending a footer, and removes file from open files list.
+ // REQUIRES: lock held on write_mutex_, write lock held on both the db mutex_
+ // and the blob file's mutex_. If called on a blob file which is visible only
+ // to a single thread (like in the case of new files written during
+ // compaction/GC), the locks on write_mutex_ and the blob file's mutex_ can be
+ // avoided.
+ Status CloseBlobFile(std::shared_ptr<BlobFile> bfile);
+
+ // Close a file if its size exceeds blob_file_size
+ // REQUIRES: lock held on write_mutex_.
+ Status CloseBlobFileIfNeeded(std::shared_ptr<BlobFile>& bfile);
+
+ // Mark file as obsolete and move the file to obsolete file list.
+ //
+ // REQUIRED: hold write lock of mutex_ or during DB open.
+ void ObsoleteBlobFile(std::shared_ptr<BlobFile> blob_file,
+ SequenceNumber obsolete_seq, bool update_size);
+
+ Status PutBlobValue(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ WriteBatch* batch);
+
+ Status AppendBlob(const std::shared_ptr<BlobFile>& bfile,
+ const std::string& headerbuf, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ std::string* index_entry);
+
+ // Create a new blob file and associated writer.
+ Status CreateBlobFileAndWriter(bool has_ttl,
+ const ExpirationRange& expiration_range,
+ const std::string& reason,
+ std::shared_ptr<BlobFile>* blob_file,
+ std::shared_ptr<BlobLogWriter>* writer);
+
+ // Get the open non-TTL blob log file, or create a new one if no such file
+ // exists.
+ Status SelectBlobFile(std::shared_ptr<BlobFile>* blob_file);
+
+ // Get the open TTL blob log file for a certain expiration, or create a new
+ // one if no such file exists.
+ Status SelectBlobFileTTL(uint64_t expiration,
+ std::shared_ptr<BlobFile>* blob_file);
+
+ std::shared_ptr<BlobFile> FindBlobFileLocked(uint64_t expiration) const;
+
+ // periodic sanity check. Bunch of checks
+ std::pair<bool, int64_t> SanityCheck(bool aborted);
+
+ // Delete files that have been marked obsolete (either because of TTL
+ // or GC). Check whether any snapshots exist which refer to the same.
+ std::pair<bool, int64_t> DeleteObsoleteFiles(bool aborted);
+
+ // periodically check if open blob files and their TTL's has expired
+ // if expired, close the sequential writer and make the file immutable
+ std::pair<bool, int64_t> EvictExpiredFiles(bool aborted);
+
+ // if the number of open files, approaches ULIMIT's this
+ // task will close random readers, which are kept around for
+ // efficiency
+ std::pair<bool, int64_t> ReclaimOpenFiles(bool aborted);
+
+ std::pair<bool, int64_t> RemoveTimerQ(TimerQueue* tq, bool aborted);
+
+ // Adds the background tasks to the timer queue
+ void StartBackgroundTasks();
+
+ // add a new Blob File
+ std::shared_ptr<BlobFile> NewBlobFile(bool has_ttl,
+ const ExpirationRange& expiration_range,
+ const std::string& reason);
+
+ // Register a new blob file.
+ // REQUIRES: write lock on mutex_.
+ void RegisterBlobFile(std::shared_ptr<BlobFile> blob_file);
+
+ // collect all the blob log files from the blob directory
+ Status GetAllBlobFiles(std::set<uint64_t>* file_numbers);
+
+ // Open all blob files found in blob_dir.
+ Status OpenAllBlobFiles();
+
+ // Link an SST to a blob file. Comes in locking and non-locking varieties
+ // (the latter is used during Open).
+ template <typename Linker>
+ void LinkSstToBlobFileImpl(uint64_t sst_file_number,
+ uint64_t blob_file_number, Linker linker);
+
+ void LinkSstToBlobFile(uint64_t sst_file_number, uint64_t blob_file_number);
+
+ void LinkSstToBlobFileNoLock(uint64_t sst_file_number,
+ uint64_t blob_file_number);
+
+ // Unlink an SST from a blob file.
+ void UnlinkSstFromBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number);
+
+ // Initialize the mapping between blob files and SSTs during Open.
+ void InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files);
+
+ // Update the mapping between blob files and SSTs after a flush and mark
+ // any unneeded blob files obsolete.
+ void ProcessFlushJobInfo(const FlushJobInfo& info);
+
+ // Update the mapping between blob files and SSTs after a compaction and
+ // mark any unneeded blob files obsolete.
+ void ProcessCompactionJobInfo(const CompactionJobInfo& info);
+
+ // Mark an immutable non-TTL blob file obsolete assuming it has no more SSTs
+ // linked to it, and all memtables from before the blob file became immutable
+ // have been flushed. Note: should only be called if the condition holds for
+ // all lower-numbered non-TTL blob files as well.
+ bool MarkBlobFileObsoleteIfNeeded(const std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq);
+
+ // Mark all immutable non-TTL blob files that aren't needed by any SSTs as
+ // obsolete. Comes in two varieties; the version used during Open need not
+ // worry about locking or snapshots.
+ template <class Functor>
+ void MarkUnreferencedBlobFilesObsoleteImpl(Functor mark_if_needed);
+
+ void MarkUnreferencedBlobFilesObsolete();
+ void MarkUnreferencedBlobFilesObsoleteDuringOpen();
+
+ void UpdateLiveSSTSize();
+
+ Status GetBlobFileReader(const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<RandomAccessFileReader>* reader);
+
+ // hold write mutex on file and call.
+ // Close the above Random Access reader
+ void CloseRandomAccessLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ // hold write mutex on file and call
+ // creates a sequential (append) writer for this blobfile
+ Status CreateWriterLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ // returns a BlobLogWriter object for the file. If writer is not
+ // already present, creates one. Needs Write Mutex to be held
+ Status CheckOrCreateWriterLocked(const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<BlobLogWriter>* writer);
+
+ // checks if there is no snapshot which is referencing the
+ // blobs
+ bool VisibleToActiveSnapshot(const std::shared_ptr<BlobFile>& file);
+ bool FileDeleteOk_SnapshotCheckLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ void CopyBlobFiles(std::vector<std::shared_ptr<BlobFile>>* bfiles_copy);
+
+ uint64_t EpochNow() { return clock_->NowMicros() / 1000000; }
+
+ // Check if inserting a new blob will make DB grow out of space.
+ // If is_fifo = true, FIFO eviction will be triggered to make room for the
+ // new blob. If force_evict = true, FIFO eviction will evict blob files
+ // even eviction will not make enough room for the new blob.
+ Status CheckSizeAndEvictBlobFiles(uint64_t blob_size,
+ bool force_evict = false);
+
+ // name of the database directory
+ std::string dbname_;
+
+ // the base DB
+ DBImpl* db_impl_;
+ Env* env_;
+ SystemClock* clock_;
+ // the options that govern the behavior of Blob Storage
+ BlobDBOptions bdb_options_;
+ DBOptions db_options_;
+ ColumnFamilyOptions cf_options_;
+ FileOptions file_options_;
+
+ // Raw pointer of statistic. db_options_ has a std::shared_ptr to hold
+ // ownership.
+ Statistics* statistics_;
+
+ // by default this is "blob_dir" under dbname_
+ // but can be configured
+ std::string blob_dir_;
+
+ // pointer to directory
+ std::unique_ptr<FSDirectory> dir_ent_;
+
+ // Read Write Mutex, which protects all the data structures
+ // HEAVILY TRAFFICKED
+ mutable port::RWMutex mutex_;
+
+ // Writers has to hold write_mutex_ before writing.
+ mutable port::Mutex write_mutex_;
+
+ // counter for blob file number
+ std::atomic<uint64_t> next_file_number_;
+
+ // entire metadata of all the BLOB files memory
+ std::map<uint64_t, std::shared_ptr<BlobFile>> blob_files_;
+
+ // All live immutable non-TTL blob files.
+ std::map<uint64_t, std::shared_ptr<BlobFile>> live_imm_non_ttl_blob_files_;
+
+ // The largest sequence number that has been flushed.
+ SequenceNumber flush_sequence_;
+
+ // opened non-TTL blob file.
+ std::shared_ptr<BlobFile> open_non_ttl_file_;
+
+ // all the blob files which are currently being appended to based
+ // on variety of incoming TTL's
+ std::set<std::shared_ptr<BlobFile>, BlobFileComparatorTTL> open_ttl_files_;
+
+ // Flag to check whether Close() has been called on this DB
+ bool closed_;
+
+ // timer based queue to execute tasks
+ TimerQueue tqueue_;
+
+ // number of files opened for random access/GET
+ // counter is used to monitor and close excess RA files.
+ std::atomic<uint32_t> open_file_count_;
+
+ // Total size of all live blob files (i.e. exclude obsolete files).
+ std::atomic<uint64_t> total_blob_size_;
+
+ // total size of SST files.
+ std::atomic<uint64_t> live_sst_size_;
+
+ // Latest FIFO eviction timestamp
+ //
+ // REQUIRES: access with metex_ lock held.
+ uint64_t fifo_eviction_seq_;
+
+ // The expiration up to which latest FIFO eviction evicts.
+ //
+ // REQUIRES: access with metex_ lock held.
+ uint64_t evict_expiration_up_to_;
+
+ std::list<std::shared_ptr<BlobFile>> obsolete_files_;
+
+ // DeleteObsoleteFiles, DiableFileDeletions and EnableFileDeletions block
+ // on the mutex to avoid contention.
+ //
+ // While DeleteObsoleteFiles hold both mutex_ and delete_file_mutex_, note
+ // the difference. mutex_ only needs to be held when access the
+ // data-structure, and delete_file_mutex_ needs to be held the whole time
+ // during DeleteObsoleteFiles to avoid being run simultaneously with
+ // DisableFileDeletions.
+ //
+ // If both of mutex_ and delete_file_mutex_ needs to be held, it is adviced
+ // to hold delete_file_mutex_ first to avoid deadlock.
+ mutable port::Mutex delete_file_mutex_;
+
+ // Each call of DisableFileDeletions will increase disable_file_deletion_
+ // by 1. EnableFileDeletions will either decrease the count by 1 or reset
+ // it to zeor, depending on the force flag.
+ //
+ // REQUIRES: access with delete_file_mutex_ held.
+ int disable_file_deletions_ = 0;
+
+ uint32_t debug_level_;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc b/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc
new file mode 100644
index 000000000..87e3f33cc
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "file/filename.h"
+#include "logging/logging.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+// BlobDBImpl methods to get snapshot of files, e.g. for replication.
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Status BlobDBImpl::DisableFileDeletions() {
+ // Disable base DB file deletions.
+ Status s = db_impl_->DisableFileDeletions();
+ if (!s.ok()) {
+ return s;
+ }
+
+ int count = 0;
+ {
+ // Hold delete_file_mutex_ to make sure no DeleteObsoleteFiles job
+ // is running.
+ MutexLock l(&delete_file_mutex_);
+ count = ++disable_file_deletions_;
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Disabled blob file deletions. count: %d", count);
+ return Status::OK();
+}
+
+Status BlobDBImpl::EnableFileDeletions(bool force) {
+ // Enable base DB file deletions.
+ Status s = db_impl_->EnableFileDeletions(force);
+ if (!s.ok()) {
+ return s;
+ }
+
+ int count = 0;
+ {
+ MutexLock l(&delete_file_mutex_);
+ if (force) {
+ disable_file_deletions_ = 0;
+ } else if (disable_file_deletions_ > 0) {
+ count = --disable_file_deletions_;
+ }
+ assert(count >= 0);
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Enabled blob file deletions. count: %d",
+ count);
+ // Consider trigger DeleteobsoleteFiles once after re-enabled, if we are to
+ // make DeleteobsoleteFiles re-run interval configuration.
+ return Status::OK();
+}
+
+Status BlobDBImpl::GetLiveFiles(std::vector<std::string>& ret,
+ uint64_t* manifest_file_size,
+ bool flush_memtable) {
+ if (!bdb_options_.path_relative) {
+ return Status::NotSupported(
+ "Not able to get relative blob file path from absolute blob_dir.");
+ }
+ // Hold a lock in the beginning to avoid updates to base DB during the call
+ ReadLock rl(&mutex_);
+ Status s = db_->GetLiveFiles(ret, manifest_file_size, flush_memtable);
+ if (!s.ok()) {
+ return s;
+ }
+ ret.reserve(ret.size() + blob_files_.size());
+ for (auto bfile_pair : blob_files_) {
+ auto blob_file = bfile_pair.second;
+ // Path should be relative to db_name, but begin with slash.
+ ret.emplace_back(
+ BlobFileName("", bdb_options_.blob_dir, blob_file->BlobFileNumber()));
+ }
+ return Status::OK();
+}
+
+void BlobDBImpl::GetLiveFilesMetaData(std::vector<LiveFileMetaData>* metadata) {
+ // Path should be relative to db_name.
+ assert(bdb_options_.path_relative);
+ // Hold a lock in the beginning to avoid updates to base DB during the call
+ ReadLock rl(&mutex_);
+ db_->GetLiveFilesMetaData(metadata);
+ for (auto bfile_pair : blob_files_) {
+ auto blob_file = bfile_pair.second;
+ LiveFileMetaData filemetadata;
+ filemetadata.size = blob_file->GetFileSize();
+ const uint64_t file_number = blob_file->BlobFileNumber();
+ // Path should be relative to db_name, but begin with slash.
+ filemetadata.name = BlobFileName("", bdb_options_.blob_dir, file_number);
+ filemetadata.file_number = file_number;
+ if (blob_file->HasTTL()) {
+ filemetadata.oldest_ancester_time = blob_file->GetExpirationRange().first;
+ }
+ auto cfh =
+ static_cast_with_check<ColumnFamilyHandleImpl>(DefaultColumnFamily());
+ filemetadata.column_family_name = cfh->GetName();
+ metadata->emplace_back(filemetadata);
+ }
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_iterator.h b/src/rocksdb/utilities/blob_db/blob_db_iterator.h
new file mode 100644
index 000000000..fd2b2f8f5
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_iterator.h
@@ -0,0 +1,150 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "db/arena_wrapped_db_iter.h"
+#include "rocksdb/iterator.h"
+#include "util/stop_watch.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+class Statistics;
+class SystemClock;
+
+namespace blob_db {
+
+using ROCKSDB_NAMESPACE::ManagedSnapshot;
+
+class BlobDBIterator : public Iterator {
+ public:
+ BlobDBIterator(ManagedSnapshot* snapshot, ArenaWrappedDBIter* iter,
+ BlobDBImpl* blob_db, SystemClock* clock,
+ Statistics* statistics)
+ : snapshot_(snapshot),
+ iter_(iter),
+ blob_db_(blob_db),
+ clock_(clock),
+ statistics_(statistics) {}
+
+ virtual ~BlobDBIterator() = default;
+
+ bool Valid() const override {
+ if (!iter_->Valid()) {
+ return false;
+ }
+ return status_.ok();
+ }
+
+ Status status() const override {
+ if (!iter_->status().ok()) {
+ return iter_->status();
+ }
+ return status_;
+ }
+
+ void SeekToFirst() override {
+ StopWatch seek_sw(clock_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekToFirst();
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void SeekToLast() override {
+ StopWatch seek_sw(clock_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekToLast();
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ void Seek(const Slice& target) override {
+ StopWatch seek_sw(clock_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->Seek(target);
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void SeekForPrev(const Slice& target) override {
+ StopWatch seek_sw(clock_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekForPrev(target);
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ void Next() override {
+ assert(Valid());
+ StopWatch next_sw(clock_, statistics_, BLOB_DB_NEXT_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_NEXT);
+ iter_->Next();
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void Prev() override {
+ assert(Valid());
+ StopWatch prev_sw(clock_, statistics_, BLOB_DB_PREV_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_PREV);
+ iter_->Prev();
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ Slice key() const override {
+ assert(Valid());
+ return iter_->key();
+ }
+
+ Slice value() const override {
+ assert(Valid());
+ if (!iter_->IsBlob()) {
+ return iter_->value();
+ }
+ return value_;
+ }
+
+ // Iterator::Refresh() not supported.
+
+ private:
+ // Return true if caller should continue to next value.
+ bool UpdateBlobValue() {
+ value_.Reset();
+ status_ = Status::OK();
+ if (iter_->Valid() && iter_->status().ok() && iter_->IsBlob()) {
+ Status s = blob_db_->GetBlobValue(iter_->key(), iter_->value(), &value_);
+ if (s.IsNotFound()) {
+ return true;
+ } else {
+ if (!s.ok()) {
+ status_ = s;
+ }
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+
+ std::unique_ptr<ManagedSnapshot> snapshot_;
+ std::unique_ptr<ArenaWrappedDBIter> iter_;
+ BlobDBImpl* blob_db_;
+ SystemClock* clock_;
+ Statistics* statistics_;
+ Status status_;
+ PinnableSlice value_;
+};
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_listener.h b/src/rocksdb/utilities/blob_db/blob_db_listener.h
new file mode 100644
index 000000000..d17d29853
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_listener.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+
+#include "rocksdb/listener.h"
+#include "util/mutexlock.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBListener : public EventListener {
+ public:
+ explicit BlobDBListener(BlobDBImpl* blob_db_impl)
+ : blob_db_impl_(blob_db_impl) {}
+
+ void OnFlushBegin(DB* /*db*/, const FlushJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->SyncBlobFiles();
+ }
+
+ void OnFlushCompleted(DB* /*db*/, const FlushJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->UpdateLiveSSTSize();
+ }
+
+ void OnCompactionCompleted(DB* /*db*/,
+ const CompactionJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->UpdateLiveSSTSize();
+ }
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "BlobDBListener"; }
+
+ protected:
+ BlobDBImpl* blob_db_impl_;
+};
+
+class BlobDBListenerGC : public BlobDBListener {
+ public:
+ explicit BlobDBListenerGC(BlobDBImpl* blob_db_impl)
+ : BlobDBListener(blob_db_impl) {}
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "BlobDBListenerGC"; }
+ void OnFlushCompleted(DB* db, const FlushJobInfo& info) override {
+ BlobDBListener::OnFlushCompleted(db, info);
+
+ assert(blob_db_impl_);
+ blob_db_impl_->ProcessFlushJobInfo(info);
+ }
+
+ void OnCompactionCompleted(DB* db, const CompactionJobInfo& info) override {
+ BlobDBListener::OnCompactionCompleted(db, info);
+
+ assert(blob_db_impl_);
+ blob_db_impl_->ProcessCompactionJobInfo(info);
+ }
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_test.cc b/src/rocksdb/utilities/blob_db/blob_db_test.cc
new file mode 100644
index 000000000..e392962b2
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_test.cc
@@ -0,0 +1,2407 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include <iomanip>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "db/blob/blob_index.h"
+#include "db/db_test_util.h"
+#include "env/composite_env_wrapper.h"
+#include "file/file_util.h"
+#include "file/sst_file_manager_impl.h"
+#include "port/port.h"
+#include "rocksdb/utilities/debug.h"
+#include "test_util/mock_time_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/blob_db/blob_db_impl.h"
+#include "utilities/fault_injection_env.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBTest : public testing::Test {
+ public:
+ const int kMaxBlobSize = 1 << 14;
+
+ struct BlobIndexVersion {
+ BlobIndexVersion() = default;
+ BlobIndexVersion(std::string _user_key, uint64_t _file_number,
+ uint64_t _expiration, SequenceNumber _sequence,
+ ValueType _type)
+ : user_key(std::move(_user_key)),
+ file_number(_file_number),
+ expiration(_expiration),
+ sequence(_sequence),
+ type(_type) {}
+
+ std::string user_key;
+ uint64_t file_number = kInvalidBlobFileNumber;
+ uint64_t expiration = kNoExpiration;
+ SequenceNumber sequence = 0;
+ ValueType type = kTypeValue;
+ };
+
+ BlobDBTest()
+ : dbname_(test::PerThreadDBPath("blob_db_test")), blob_db_(nullptr) {
+ mock_clock_ = std::make_shared<MockSystemClock>(SystemClock::Default());
+ mock_env_.reset(new CompositeEnvWrapper(Env::Default(), mock_clock_));
+ fault_injection_env_.reset(new FaultInjectionTestEnv(Env::Default()));
+
+ Status s = DestroyBlobDB(dbname_, Options(), BlobDBOptions());
+ assert(s.ok());
+ }
+
+ ~BlobDBTest() override {
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ Destroy();
+ }
+
+ Status TryOpen(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ options.create_if_missing = true;
+ if (options.env == mock_env_.get()) {
+ // Need to disable stats dumping and persisting which also use
+ // RepeatableThread, which uses InstrumentedCondVar::TimedWaitInternal.
+ // With mocked time, this can hang on some platforms (MacOS)
+ // because (a) on some platforms, pthread_cond_timedwait does not appear
+ // to release the lock for other threads to operate if the deadline time
+ // is already passed, and (b) TimedWait calls are currently a bad
+ // abstraction because the deadline parameter is usually computed from
+ // Env time, but is interpreted in real clock time.
+ options.stats_dump_period_sec = 0;
+ options.stats_persist_period_sec = 0;
+ }
+ return BlobDB::Open(options, bdb_options, dbname_, &blob_db_);
+ }
+
+ void Open(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ ASSERT_OK(TryOpen(bdb_options, options));
+ }
+
+ void Reopen(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ assert(blob_db_ != nullptr);
+ delete blob_db_;
+ blob_db_ = nullptr;
+ Open(bdb_options, options);
+ }
+
+ void Close() {
+ assert(blob_db_ != nullptr);
+ delete blob_db_;
+ blob_db_ = nullptr;
+ }
+
+ void Destroy() {
+ if (blob_db_) {
+ Options options = blob_db_->GetOptions();
+ BlobDBOptions bdb_options = blob_db_->GetBlobDBOptions();
+ delete blob_db_;
+ blob_db_ = nullptr;
+ ASSERT_OK(DestroyBlobDB(dbname_, options, bdb_options));
+ }
+ }
+
+ BlobDBImpl *blob_db_impl() {
+ return reinterpret_cast<BlobDBImpl *>(blob_db_);
+ }
+
+ Status Put(const Slice &key, const Slice &value,
+ std::map<std::string, std::string> *data = nullptr) {
+ Status s = blob_db_->Put(WriteOptions(), key, value);
+ if (data != nullptr) {
+ (*data)[key.ToString()] = value.ToString();
+ }
+ return s;
+ }
+
+ void Delete(const std::string &key,
+ std::map<std::string, std::string> *data = nullptr) {
+ ASSERT_OK(blob_db_->Delete(WriteOptions(), key));
+ if (data != nullptr) {
+ data->erase(key);
+ }
+ }
+
+ Status PutWithTTL(const Slice &key, const Slice &value, uint64_t ttl,
+ std::map<std::string, std::string> *data = nullptr) {
+ Status s = blob_db_->PutWithTTL(WriteOptions(), key, value, ttl);
+ if (data != nullptr) {
+ (*data)[key.ToString()] = value.ToString();
+ }
+ return s;
+ }
+
+ Status PutUntil(const Slice &key, const Slice &value, uint64_t expiration) {
+ return blob_db_->PutUntil(WriteOptions(), key, value, expiration);
+ }
+
+ void PutRandomWithTTL(const std::string &key, uint64_t ttl, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = rnd->HumanReadableString(len);
+ ASSERT_OK(
+ blob_db_->PutWithTTL(WriteOptions(), Slice(key), Slice(value), ttl));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandomUntil(const std::string &key, uint64_t expiration, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = rnd->HumanReadableString(len);
+ ASSERT_OK(blob_db_->PutUntil(WriteOptions(), Slice(key), Slice(value),
+ expiration));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandom(const std::string &key, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ PutRandom(blob_db_, key, rnd, data);
+ }
+
+ void PutRandom(DB *db, const std::string &key, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = rnd->HumanReadableString(len);
+ ASSERT_OK(db->Put(WriteOptions(), Slice(key), Slice(value)));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandomToWriteBatch(
+ const std::string &key, Random *rnd, WriteBatch *batch,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = rnd->HumanReadableString(len);
+ ASSERT_OK(batch->Put(key, value));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ // Verify blob db contain expected data and nothing more.
+ void VerifyDB(const std::map<std::string, std::string> &data) {
+ VerifyDB(blob_db_, data);
+ }
+
+ void VerifyDB(DB *db, const std::map<std::string, std::string> &data) {
+ // Verify normal Get
+ auto *cfh = db->DefaultColumnFamily();
+ for (auto &p : data) {
+ PinnableSlice value_slice;
+ ASSERT_OK(db->Get(ReadOptions(), cfh, p.first, &value_slice));
+ ASSERT_EQ(p.second, value_slice.ToString());
+ std::string value;
+ ASSERT_OK(db->Get(ReadOptions(), cfh, p.first, &value));
+ ASSERT_EQ(p.second, value);
+ }
+
+ // Verify iterators
+ Iterator *iter = db->NewIterator(ReadOptions());
+ iter->SeekToFirst();
+ for (auto &p : data) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(p.first, iter->key().ToString());
+ ASSERT_EQ(p.second, iter->value().ToString());
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+ ASSERT_OK(iter->status());
+ delete iter;
+ }
+
+ void VerifyBaseDB(
+ const std::map<std::string, KeyVersion> &expected_versions) {
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ DB *db = blob_db_->GetRootDB();
+ const size_t kMaxKeys = 10000;
+ std::vector<KeyVersion> versions;
+ ASSERT_OK(GetAllKeyVersions(db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(expected_versions.size(), versions.size());
+ size_t i = 0;
+ for (auto &key_version : expected_versions) {
+ const KeyVersion &expected_version = key_version.second;
+ ASSERT_EQ(expected_version.user_key, versions[i].user_key);
+ ASSERT_EQ(expected_version.sequence, versions[i].sequence);
+ ASSERT_EQ(expected_version.type, versions[i].type);
+ if (versions[i].type == kTypeValue) {
+ ASSERT_EQ(expected_version.value, versions[i].value);
+ } else {
+ ASSERT_EQ(kTypeBlobIndex, versions[i].type);
+ PinnableSlice value;
+ ASSERT_OK(bdb_impl->TEST_GetBlobValue(versions[i].user_key,
+ versions[i].value, &value));
+ ASSERT_EQ(expected_version.value, value.ToString());
+ }
+ i++;
+ }
+ }
+
+ void VerifyBaseDBBlobIndex(
+ const std::map<std::string, BlobIndexVersion> &expected_versions) {
+ const size_t kMaxKeys = 10000;
+ std::vector<KeyVersion> versions;
+ ASSERT_OK(
+ GetAllKeyVersions(blob_db_->GetRootDB(), "", "", kMaxKeys, &versions));
+ ASSERT_EQ(versions.size(), expected_versions.size());
+
+ size_t i = 0;
+ for (const auto &expected_pair : expected_versions) {
+ const BlobIndexVersion &expected_version = expected_pair.second;
+
+ ASSERT_EQ(versions[i].user_key, expected_version.user_key);
+ ASSERT_EQ(versions[i].sequence, expected_version.sequence);
+ ASSERT_EQ(versions[i].type, expected_version.type);
+ if (versions[i].type != kTypeBlobIndex) {
+ ASSERT_EQ(kInvalidBlobFileNumber, expected_version.file_number);
+ ASSERT_EQ(kNoExpiration, expected_version.expiration);
+
+ ++i;
+ continue;
+ }
+
+ BlobIndex blob_index;
+ ASSERT_OK(blob_index.DecodeFrom(versions[i].value));
+
+ const uint64_t file_number = !blob_index.IsInlined()
+ ? blob_index.file_number()
+ : kInvalidBlobFileNumber;
+ ASSERT_EQ(file_number, expected_version.file_number);
+
+ const uint64_t expiration =
+ blob_index.HasTTL() ? blob_index.expiration() : kNoExpiration;
+ ASSERT_EQ(expiration, expected_version.expiration);
+
+ ++i;
+ }
+ }
+
+ void InsertBlobs() {
+ WriteOptions wo;
+ std::string value;
+
+ Random rnd(301);
+ for (size_t i = 0; i < 100000; i++) {
+ uint64_t ttl = rnd.Next() % 86400;
+ PutRandomWithTTL("key" + std::to_string(i % 500), ttl, &rnd, nullptr);
+ }
+
+ for (size_t i = 0; i < 10; i++) {
+ Delete("key" + std::to_string(i % 500));
+ }
+ }
+
+ const std::string dbname_;
+ std::shared_ptr<MockSystemClock> mock_clock_;
+ std::unique_ptr<Env> mock_env_;
+ std::unique_ptr<FaultInjectionTestEnv> fault_injection_env_;
+ BlobDB *blob_db_;
+}; // class BlobDBTest
+
+TEST_F(BlobDBTest, Put) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, &data);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, PutWithTTL) {
+ Random rnd(301);
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 1000;
+ bdb_options.min_blob_size = 0;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ mock_clock_->SetCurrentTime(50);
+ for (size_t i = 0; i < 100; i++) {
+ uint64_t ttl = rnd.Next() % 100;
+ PutRandomWithTTL("key" + std::to_string(i), ttl, &rnd,
+ (ttl <= 50 ? nullptr : &data));
+ }
+ mock_clock_->SetCurrentTime(100);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->HasTTL());
+ ASSERT_OK(bdb_impl->TEST_CloseBlobFile(blob_files[0]));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, PutUntil) {
+ Random rnd(301);
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 1000;
+ bdb_options.min_blob_size = 0;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ mock_clock_->SetCurrentTime(50);
+ for (size_t i = 0; i < 100; i++) {
+ uint64_t expiration = rnd.Next() % 100 + 50;
+ PutRandomUntil("key" + std::to_string(i), expiration, &rnd,
+ (expiration <= 100 ? nullptr : &data));
+ }
+ mock_clock_->SetCurrentTime(100);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->HasTTL());
+ ASSERT_OK(bdb_impl->TEST_CloseBlobFile(blob_files[0]));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, StackableDBGet) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, &data);
+ }
+ for (size_t i = 0; i < 100; i++) {
+ StackableDB *db = blob_db_;
+ ColumnFamilyHandle *column_family = db->DefaultColumnFamily();
+ std::string key = "key" + std::to_string(i);
+ PinnableSlice pinnable_value;
+ ASSERT_OK(db->Get(ReadOptions(), column_family, key, &pinnable_value));
+ std::string string_value;
+ ASSERT_OK(db->Get(ReadOptions(), column_family, key, &string_value));
+ ASSERT_EQ(string_value, pinnable_value.ToString());
+ ASSERT_EQ(string_value, data[key]);
+ }
+}
+
+TEST_F(BlobDBTest, GetExpiration) {
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.disable_background_tasks = true;
+ mock_clock_->SetCurrentTime(100);
+ Open(bdb_options, options);
+ ASSERT_OK(Put("key1", "value1"));
+ ASSERT_OK(PutWithTTL("key2", "value2", 200));
+ PinnableSlice value;
+ uint64_t expiration;
+ ASSERT_OK(blob_db_->Get(ReadOptions(), "key1", &value, &expiration));
+ ASSERT_EQ("value1", value.ToString());
+ ASSERT_EQ(kNoExpiration, expiration);
+ ASSERT_OK(blob_db_->Get(ReadOptions(), "key2", &value, &expiration));
+ ASSERT_EQ("value2", value.ToString());
+ ASSERT_EQ(300 /* = 100 + 200 */, expiration);
+}
+
+TEST_F(BlobDBTest, GetIOError) {
+ Options options;
+ options.env = fault_injection_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0; // Make sure value write to blob file
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ ColumnFamilyHandle *column_family = blob_db_->DefaultColumnFamily();
+ PinnableSlice value;
+ ASSERT_OK(Put("foo", "bar"));
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ Status s = blob_db_->Get(ReadOptions(), column_family, "foo", &value);
+ ASSERT_TRUE(s.IsIOError());
+ // Reactivate file system to allow test to close DB.
+ fault_injection_env_->SetFilesystemActive(true);
+}
+
+TEST_F(BlobDBTest, PutIOError) {
+ Options options;
+ options.env = fault_injection_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0; // Make sure value write to blob file
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ ASSERT_TRUE(Put("foo", "v1").IsIOError());
+ fault_injection_env_->SetFilesystemActive(true, Status::IOError());
+ ASSERT_OK(Put("bar", "v1"));
+}
+
+TEST_F(BlobDBTest, WriteBatch) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ WriteBatch batch;
+ for (size_t j = 0; j < 10; j++) {
+ PutRandomToWriteBatch("key" + std::to_string(j * 100 + i), &rnd, &batch,
+ &data);
+ }
+
+ ASSERT_OK(blob_db_->Write(WriteOptions(), &batch));
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, Delete) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, &data);
+ }
+ for (size_t i = 0; i < 100; i += 5) {
+ Delete("key" + std::to_string(i), &data);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, DeleteBatch) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + std::to_string(i), &rnd);
+ }
+ WriteBatch batch;
+ for (size_t i = 0; i < 100; i++) {
+ ASSERT_OK(batch.Delete("key" + std::to_string(i)));
+ }
+ ASSERT_OK(blob_db_->Write(WriteOptions(), &batch));
+ // DB should be empty.
+ VerifyDB({});
+}
+
+TEST_F(BlobDBTest, Override) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (int i = 0; i < 10000; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, nullptr);
+ }
+ // override all the keys
+ for (int i = 0; i < 10000; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, &data);
+ }
+ VerifyDB(data);
+}
+
+#ifdef SNAPPY
+TEST_F(BlobDBTest, Compression) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = CompressionType::kSnappyCompression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("put-key" + std::to_string(i), &rnd, &data);
+ }
+ for (int i = 0; i < 100; i++) {
+ WriteBatch batch;
+ for (size_t j = 0; j < 10; j++) {
+ PutRandomToWriteBatch("write-batch-key" + std::to_string(j * 100 + i),
+ &rnd, &batch, &data);
+ }
+ ASSERT_OK(blob_db_->Write(WriteOptions(), &batch));
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, DecompressAfterReopen) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = CompressionType::kSnappyCompression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("put-key" + std::to_string(i), &rnd, &data);
+ }
+ VerifyDB(data);
+ bdb_options.compression = CompressionType::kNoCompression;
+ Reopen(bdb_options);
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, EnableDisableCompressionGC) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = kSnappyCompression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ size_t data_idx = 0;
+ for (; data_idx < 100; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_EQ(kSnappyCompression, blob_files[0]->GetCompressionType());
+
+ // disable compression
+ bdb_options.compression = kNoCompression;
+ Reopen(bdb_options);
+
+ // Add more data with new compression type
+ for (; data_idx < 200; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ ASSERT_EQ(kNoCompression, blob_files[1]->GetCompressionType());
+
+ // Enable GC. If we do it earlier the snapshot release triggered compaction
+ // may compact files and trigger GC before we can verify there are two files.
+ bdb_options.enable_garbage_collection = true;
+ Reopen(bdb_options);
+
+ // Trigger compaction
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ VerifyDB(data);
+
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ for (auto bfile : blob_files) {
+ ASSERT_EQ(kNoCompression, bfile->GetCompressionType());
+ }
+
+ // enabling the compression again
+ bdb_options.compression = kSnappyCompression;
+ Reopen(bdb_options);
+
+ // Add more data with new compression type
+ for (; data_idx < 300; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ // Trigger compaction
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ VerifyDB(data);
+
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ for (auto bfile : blob_files) {
+ ASSERT_EQ(kSnappyCompression, bfile->GetCompressionType());
+ }
+}
+
+#ifdef LZ4
+// Test switch compression types and run GC, it needs both Snappy and LZ4
+// support.
+TEST_F(BlobDBTest, ChangeCompressionGC) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = kLZ4Compression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ size_t data_idx = 0;
+ for (; data_idx < 100; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_EQ(kLZ4Compression, blob_files[0]->GetCompressionType());
+
+ // Change compression type
+ bdb_options.compression = kSnappyCompression;
+ Reopen(bdb_options);
+
+ // Add more data with Snappy compression type
+ for (; data_idx < 200; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ // Verify blob file compression type
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ ASSERT_EQ(kSnappyCompression, blob_files[1]->GetCompressionType());
+
+ // Enable GC. If we do it earlier the snapshot release triggered compaction
+ // may compact files and trigger GC before we can verify there are two files.
+ bdb_options.enable_garbage_collection = true;
+ Reopen(bdb_options);
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyDB(data);
+
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ for (auto bfile : blob_files) {
+ ASSERT_EQ(kSnappyCompression, bfile->GetCompressionType());
+ }
+
+ // Disable compression
+ bdb_options.compression = kNoCompression;
+ Reopen(bdb_options);
+ for (; data_idx < 300; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyDB(data);
+
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ for (auto bfile : blob_files) {
+ ASSERT_EQ(kNoCompression, bfile->GetCompressionType());
+ }
+
+ // switching different compression types to generate mixed compression types
+ bdb_options.compression = kSnappyCompression;
+ Reopen(bdb_options);
+ for (; data_idx < 400; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ bdb_options.compression = kLZ4Compression;
+ Reopen(bdb_options);
+ for (; data_idx < 500; data_idx++) {
+ PutRandom("put-key" + std::to_string(data_idx), &rnd, &data);
+ }
+ VerifyDB(data);
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyDB(data);
+
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ for (auto bfile : blob_files) {
+ ASSERT_EQ(kLZ4Compression, bfile->GetCompressionType());
+ }
+}
+#endif // LZ4
+#endif // SNAPPY
+
+TEST_F(BlobDBTest, MultipleWriters) {
+ Open(BlobDBOptions());
+
+ std::vector<port::Thread> workers;
+ std::vector<std::map<std::string, std::string>> data_set(10);
+ for (uint32_t i = 0; i < 10; i++)
+ workers.push_back(port::Thread(
+ [&](uint32_t id) {
+ Random rnd(301 + id);
+ for (int j = 0; j < 100; j++) {
+ std::string key =
+ "key" + std::to_string(id) + "_" + std::to_string(j);
+ if (id < 5) {
+ PutRandom(key, &rnd, &data_set[id]);
+ } else {
+ WriteBatch batch;
+ PutRandomToWriteBatch(key, &rnd, &batch, &data_set[id]);
+ ASSERT_OK(blob_db_->Write(WriteOptions(), &batch));
+ }
+ }
+ },
+ i));
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 10; i++) {
+ workers[i].join();
+ data.insert(data_set[i].begin(), data_set[i].end());
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, SstFileManager) {
+ // run the same test for Get(), MultiGet() and Iterator each.
+ std::shared_ptr<SstFileManager> sst_file_manager(
+ NewSstFileManager(mock_env_.get()));
+ sst_file_manager->SetDeleteRateBytesPerSecond(1);
+ SstFileManagerImpl *sfm =
+ static_cast<SstFileManagerImpl *>(sst_file_manager.get());
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ Options db_options;
+
+ int files_scheduled_to_delete = 0;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "SstFileManagerImpl::ScheduleFileDeletion", [&](void *arg) {
+ assert(arg);
+ const std::string *const file_path =
+ static_cast<const std::string *>(arg);
+ if (file_path->find(".blob") != std::string::npos) {
+ ++files_scheduled_to_delete;
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ db_options.sst_file_manager = sst_file_manager;
+
+ Open(bdb_options, db_options);
+
+ // Create one obselete file and clean it.
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "foo", "bar"));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ std::shared_ptr<BlobFile> bfile = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(bfile));
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+
+ // Even if SSTFileManager is not set, DB is creating a dummy one.
+ ASSERT_EQ(1, files_scheduled_to_delete);
+ Destroy();
+ // Make sure that DestroyBlobDB() also goes through delete scheduler.
+ ASSERT_EQ(2, files_scheduled_to_delete);
+ SyncPoint::GetInstance()->DisableProcessing();
+ sfm->WaitForEmptyTrash();
+}
+
+TEST_F(BlobDBTest, SstFileManagerRestart) {
+ int files_scheduled_to_delete = 0;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "SstFileManagerImpl::ScheduleFileDeletion", [&](void *arg) {
+ assert(arg);
+ const std::string *const file_path =
+ static_cast<const std::string *>(arg);
+ if (file_path->find(".blob") != std::string::npos) {
+ ++files_scheduled_to_delete;
+ }
+ });
+
+ // run the same test for Get(), MultiGet() and Iterator each.
+ std::shared_ptr<SstFileManager> sst_file_manager(
+ NewSstFileManager(mock_env_.get()));
+ sst_file_manager->SetDeleteRateBytesPerSecond(1);
+ SstFileManagerImpl *sfm =
+ static_cast<SstFileManagerImpl *>(sst_file_manager.get());
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ Options db_options;
+
+ SyncPoint::GetInstance()->EnableProcessing();
+ db_options.sst_file_manager = sst_file_manager;
+
+ Open(bdb_options, db_options);
+ std::string blob_dir = blob_db_impl()->TEST_blob_dir();
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "foo", "bar"));
+ Close();
+
+ // Create 3 dummy trash files under the blob_dir
+ const auto &fs = db_options.env->GetFileSystem();
+ ASSERT_OK(CreateFile(fs, blob_dir + "/000666.blob.trash", "", false));
+ ASSERT_OK(CreateFile(fs, blob_dir + "/000888.blob.trash", "", true));
+ ASSERT_OK(CreateFile(fs, blob_dir + "/something_not_match.trash", "", false));
+
+ // Make sure that reopening the DB rescan the existing trash files
+ Open(bdb_options, db_options);
+ ASSERT_EQ(files_scheduled_to_delete, 2);
+
+ sfm->WaitForEmptyTrash();
+
+ // There should be exact one file under the blob dir now.
+ std::vector<std::string> all_files;
+ ASSERT_OK(db_options.env->GetChildren(blob_dir, &all_files));
+ int nfiles = 0;
+ for (const auto &f : all_files) {
+ assert(!f.empty());
+ if (f[0] == '.') {
+ continue;
+ }
+ nfiles++;
+ }
+ ASSERT_EQ(nfiles, 1);
+
+ SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_F(BlobDBTest, SnapshotAndGarbageCollection) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+
+ Options options;
+ options.disable_auto_compactions = true;
+
+ // i = when to take snapshot
+ for (int i = 0; i < 4; i++) {
+ Destroy();
+ Open(bdb_options, options);
+
+ const Snapshot *snapshot = nullptr;
+
+ // First file
+ ASSERT_OK(Put("key1", "value"));
+ if (i == 0) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[0]));
+
+ // Second file
+ ASSERT_OK(Put("key2", "value"));
+ if (i == 1) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ auto bfile = blob_files[1];
+ ASSERT_FALSE(bfile->Immutable());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(bfile));
+
+ // Third file
+ ASSERT_OK(Put("key3", "value"));
+ if (i == 2) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_TRUE(bfile->Obsolete());
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(),
+ bfile->GetObsoleteSequence());
+
+ Delete("key2");
+ if (i == 3) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ ASSERT_EQ(4, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+
+ if (i >= 2) {
+ // The snapshot shouldn't see data in bfile
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_->ReleaseSnapshot(snapshot);
+ } else {
+ // The snapshot will see data in bfile, so the file shouldn't be deleted
+ ASSERT_EQ(4, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_->ReleaseSnapshot(snapshot);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ }
+ }
+}
+
+TEST_F(BlobDBTest, ColumnFamilyNotSupported) {
+ Options options;
+ options.env = mock_env_.get();
+ mock_clock_->SetCurrentTime(0);
+ Open(BlobDBOptions(), options);
+ ColumnFamilyHandle *default_handle = blob_db_->DefaultColumnFamily();
+ ColumnFamilyHandle *handle = nullptr;
+ std::string value;
+ std::vector<std::string> values;
+ // The call simply pass through to base db. It should succeed.
+ ASSERT_OK(
+ blob_db_->CreateColumnFamily(ColumnFamilyOptions(), "foo", &handle));
+ ASSERT_TRUE(blob_db_->Put(WriteOptions(), handle, "k", "v").IsNotSupported());
+ ASSERT_TRUE(blob_db_->PutWithTTL(WriteOptions(), handle, "k", "v", 60)
+ .IsNotSupported());
+ ASSERT_TRUE(blob_db_->PutUntil(WriteOptions(), handle, "k", "v", 100)
+ .IsNotSupported());
+ WriteBatch batch;
+ ASSERT_OK(batch.Put("k1", "v1"));
+ ASSERT_OK(batch.Put(handle, "k2", "v2"));
+ ASSERT_TRUE(blob_db_->Write(WriteOptions(), &batch).IsNotSupported());
+ ASSERT_TRUE(blob_db_->Get(ReadOptions(), "k1", &value).IsNotFound());
+ ASSERT_TRUE(
+ blob_db_->Get(ReadOptions(), handle, "k", &value).IsNotSupported());
+ auto statuses = blob_db_->MultiGet(ReadOptions(), {default_handle, handle},
+ {"k1", "k2"}, &values);
+ ASSERT_EQ(2, statuses.size());
+ ASSERT_TRUE(statuses[0].IsNotSupported());
+ ASSERT_TRUE(statuses[1].IsNotSupported());
+ ASSERT_EQ(nullptr, blob_db_->NewIterator(ReadOptions(), handle));
+ delete handle;
+}
+
+TEST_F(BlobDBTest, GetLiveFilesMetaData) {
+ Random rnd(301);
+
+ BlobDBOptions bdb_options;
+ bdb_options.blob_dir = "blob_dir";
+ bdb_options.path_relative = true;
+ bdb_options.ttl_range_secs = 10;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+
+ Options options;
+ options.env = mock_env_.get();
+
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + std::to_string(i), &rnd, &data);
+ }
+
+ constexpr uint64_t expiration = 1000ULL;
+ PutRandomUntil("key100", expiration, &rnd, &data);
+
+ std::vector<LiveFileMetaData> metadata;
+ blob_db_->GetLiveFilesMetaData(&metadata);
+
+ ASSERT_EQ(2U, metadata.size());
+ // Path should be relative to db_name, but begin with slash.
+ const std::string filename1("/blob_dir/000001.blob");
+ ASSERT_EQ(filename1, metadata[0].name);
+ ASSERT_EQ(1, metadata[0].file_number);
+ ASSERT_EQ(0, metadata[0].oldest_ancester_time);
+ ASSERT_EQ(kDefaultColumnFamilyName, metadata[0].column_family_name);
+
+ const std::string filename2("/blob_dir/000002.blob");
+ ASSERT_EQ(filename2, metadata[1].name);
+ ASSERT_EQ(2, metadata[1].file_number);
+ ASSERT_EQ(expiration, metadata[1].oldest_ancester_time);
+ ASSERT_EQ(kDefaultColumnFamilyName, metadata[1].column_family_name);
+
+ std::vector<std::string> livefile;
+ uint64_t mfs;
+ ASSERT_OK(blob_db_->GetLiveFiles(livefile, &mfs, false));
+ ASSERT_EQ(5U, livefile.size());
+ ASSERT_EQ(filename1, livefile[3]);
+ ASSERT_EQ(filename2, livefile[4]);
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, MigrateFromPlainRocksDB) {
+ constexpr size_t kNumKey = 20;
+ constexpr size_t kNumIteration = 10;
+ Random rnd(301);
+ std::map<std::string, std::string> data;
+ std::vector<bool> is_blob(kNumKey, false);
+
+ // Write to plain rocksdb.
+ Options options;
+ options.create_if_missing = true;
+ DB *db = nullptr;
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ for (size_t i = 0; i < kNumIteration; i++) {
+ auto key_index = rnd.Next() % kNumKey;
+ std::string key = "key" + std::to_string(key_index);
+ PutRandom(db, key, &rnd, &data);
+ }
+ VerifyDB(db, data);
+ delete db;
+ db = nullptr;
+
+ // Open as blob db. Verify it can read existing data.
+ Open();
+ VerifyDB(blob_db_, data);
+ for (size_t i = 0; i < kNumIteration; i++) {
+ auto key_index = rnd.Next() % kNumKey;
+ std::string key = "key" + std::to_string(key_index);
+ is_blob[key_index] = true;
+ PutRandom(blob_db_, key, &rnd, &data);
+ }
+ VerifyDB(blob_db_, data);
+ delete blob_db_;
+ blob_db_ = nullptr;
+
+ // Verify plain db return error for keys written by blob db.
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ std::string value;
+ for (size_t i = 0; i < kNumKey; i++) {
+ std::string key = "key" + std::to_string(i);
+ Status s = db->Get(ReadOptions(), key, &value);
+ if (data.count(key) == 0) {
+ ASSERT_TRUE(s.IsNotFound());
+ } else if (is_blob[i]) {
+ ASSERT_TRUE(s.IsCorruption());
+ } else {
+ ASSERT_OK(s);
+ ASSERT_EQ(data[key], value);
+ }
+ }
+ delete db;
+}
+
+// Test to verify that a NoSpace IOError Status is returned on reaching
+// max_db_size limit.
+TEST_F(BlobDBTest, OutOfSpace) {
+ // Use mock env to stop wall clock.
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 200;
+ bdb_options.is_fifo = false;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ // Each stored blob has an overhead of about 42 bytes currently.
+ // So a small key + a 100 byte blob should take up ~150 bytes in the db.
+ std::string value(100, 'v');
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key1", value, 60));
+
+ // Putting another blob should fail as ading it would exceed the max_db_size
+ // limit.
+ Status s = blob_db_->PutWithTTL(WriteOptions(), "key2", value, 60);
+ ASSERT_TRUE(s.IsIOError());
+ ASSERT_TRUE(s.IsNoSpace());
+}
+
+TEST_F(BlobDBTest, FIFOEviction) {
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 200;
+ bdb_options.blob_file_size = 100;
+ bdb_options.is_fifo = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ std::atomic<int> evict_count{0};
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictOldestBlobFile:Evicted",
+ [&](void *) { evict_count++; });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // Each stored blob has an overhead of 32 bytes currently.
+ // So a 100 byte blob should take up 132 bytes.
+ std::string value(100, 'v');
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key1", value, 10));
+ VerifyDB({{"key1", value}});
+
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+
+ // Adding another 100 bytes blob would take the total size to 264 bytes
+ // (2*132). max_db_size will be exceeded
+ // than max_db_size and trigger FIFO eviction.
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key2", value, 60));
+ ASSERT_EQ(1, evict_count);
+ // key1 will exist until corresponding file be deleted.
+ VerifyDB({{"key1", value}, {"key2", value}});
+
+ // Adding another 100 bytes blob without TTL.
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "key3", value));
+ ASSERT_EQ(2, evict_count);
+ // key1 and key2 will exist until corresponding file be deleted.
+ VerifyDB({{"key1", value}, {"key2", value}, {"key3", value}});
+
+ // The fourth blob file, without TTL.
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "key4", value));
+ ASSERT_EQ(3, evict_count);
+ VerifyDB(
+ {{"key1", value}, {"key2", value}, {"key3", value}, {"key4", value}});
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(4, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->Obsolete());
+ ASSERT_TRUE(blob_files[1]->Obsolete());
+ ASSERT_TRUE(blob_files[2]->Obsolete());
+ ASSERT_FALSE(blob_files[3]->Obsolete());
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(3, obsolete_files.size());
+ ASSERT_EQ(blob_files[0], obsolete_files[0]);
+ ASSERT_EQ(blob_files[1], obsolete_files[1]);
+ ASSERT_EQ(blob_files[2], obsolete_files[2]);
+
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_TRUE(obsolete_files.empty());
+ VerifyDB({{"key4", value}});
+}
+
+TEST_F(BlobDBTest, FIFOEviction_NoOldestFileToEvict) {
+ Options options;
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 1000;
+ bdb_options.blob_file_size = 5000;
+ bdb_options.is_fifo = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ std::atomic<int> evict_count{0};
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictOldestBlobFile:Evicted",
+ [&](void *) { evict_count++; });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ std::string value(2000, 'v');
+ ASSERT_TRUE(Put("foo", std::string(2000, 'v')).IsNoSpace());
+ ASSERT_EQ(0, evict_count);
+}
+
+TEST_F(BlobDBTest, FIFOEviction_NoEnoughBlobFilesToEvict) {
+ BlobDBOptions bdb_options;
+ bdb_options.is_fifo = true;
+ bdb_options.min_blob_size = 100;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ options.env = mock_env_.get();
+ options.disable_auto_compactions = true;
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ Open(bdb_options, options);
+
+ ASSERT_EQ(0, blob_db_impl()->TEST_live_sst_size());
+ std::string small_value(50, 'v');
+ std::map<std::string, std::string> data;
+ // Insert some data into LSM tree to make sure FIFO eviction take SST
+ // file size into account.
+ for (int i = 0; i < 1000; i++) {
+ ASSERT_OK(Put("key" + std::to_string(i), small_value, &data));
+ }
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ uint64_t live_sst_size = 0;
+ ASSERT_TRUE(blob_db_->GetIntProperty(DB::Properties::kTotalSstFilesSize,
+ &live_sst_size));
+ ASSERT_TRUE(live_sst_size > 0);
+ ASSERT_EQ(live_sst_size, blob_db_impl()->TEST_live_sst_size());
+
+ bdb_options.max_db_size = live_sst_size + 2000;
+ Reopen(bdb_options, options);
+ ASSERT_EQ(live_sst_size, blob_db_impl()->TEST_live_sst_size());
+
+ std::string value_1k(1000, 'v');
+ ASSERT_OK(PutWithTTL("large_key1", value_1k, 60, &data));
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB(data);
+ // large_key2 evicts large_key1
+ ASSERT_OK(PutWithTTL("large_key2", value_1k, 60, &data));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ data.erase("large_key1");
+ VerifyDB(data);
+ // large_key3 get no enough space even after evicting large_key2, so it
+ // instead return no space error.
+ std::string value_2k(2000, 'v');
+ ASSERT_TRUE(PutWithTTL("large_key3", value_2k, 60).IsNoSpace());
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ // Verify large_key2 still exists.
+ VerifyDB(data);
+}
+
+// Test flush or compaction will trigger FIFO eviction since they update
+// total SST file size.
+TEST_F(BlobDBTest, FIFOEviction_TriggerOnSSTSizeChange) {
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 1000;
+ bdb_options.is_fifo = true;
+ bdb_options.min_blob_size = 100;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ options.env = mock_env_.get();
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ options.compression = kNoCompression;
+ Open(bdb_options, options);
+
+ std::string value(800, 'v');
+ ASSERT_OK(PutWithTTL("large_key", value, 60));
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB({{"large_key", value}});
+
+ // Insert some small keys and flush to bring DB out of space.
+ std::map<std::string, std::string> data;
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(Put("key" + std::to_string(i), "v", &data));
+ }
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+
+ // Verify large_key is deleted by FIFO eviction.
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, InlineSmallValues) {
+ constexpr uint64_t kMaxExpiration = 1000;
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = kMaxExpiration;
+ bdb_options.min_blob_size = 100;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ mock_clock_->SetCurrentTime(0);
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ std::map<std::string, KeyVersion> versions;
+ for (size_t i = 0; i < 1000; i++) {
+ bool is_small_value = rnd.Next() % 2;
+ bool has_ttl = rnd.Next() % 2;
+ uint64_t expiration = rnd.Next() % kMaxExpiration;
+ int len = is_small_value ? 50 : 200;
+ std::string key = "key" + std::to_string(i);
+ std::string value = rnd.HumanReadableString(len);
+ std::string blob_index;
+ data[key] = value;
+ SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+ if (!has_ttl) {
+ ASSERT_OK(blob_db_->Put(WriteOptions(), key, value));
+ } else {
+ ASSERT_OK(blob_db_->PutUntil(WriteOptions(), key, value, expiration));
+ }
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+ versions[key] =
+ KeyVersion(key, value, sequence,
+ (is_small_value && !has_ttl) ? kTypeValue : kTypeBlobIndex);
+ }
+ VerifyDB(data);
+ VerifyBaseDB(versions);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ std::shared_ptr<BlobFile> non_ttl_file;
+ std::shared_ptr<BlobFile> ttl_file;
+ if (blob_files[0]->HasTTL()) {
+ ttl_file = blob_files[0];
+ non_ttl_file = blob_files[1];
+ } else {
+ non_ttl_file = blob_files[0];
+ ttl_file = blob_files[1];
+ }
+ ASSERT_FALSE(non_ttl_file->HasTTL());
+ ASSERT_TRUE(ttl_file->HasTTL());
+}
+
+TEST_F(BlobDBTest, UserCompactionFilter) {
+ class CustomerFilter : public CompactionFilter {
+ public:
+ bool Filter(int /*level*/, const Slice & /*key*/, const Slice &value,
+ std::string *new_value, bool *value_changed) const override {
+ *value_changed = false;
+ // changing value size to test value transitions between inlined data
+ // and stored-in-blob data
+ if (value.size() % 4 == 1) {
+ *new_value = value.ToString();
+ // double size by duplicating value
+ *new_value += *new_value;
+ *value_changed = true;
+ return false;
+ } else if (value.size() % 3 == 1) {
+ *new_value = value.ToString();
+ // trancate value size by half
+ *new_value = new_value->substr(0, new_value->size() / 2);
+ *value_changed = true;
+ return false;
+ } else if (value.size() % 2 == 1) {
+ return true;
+ }
+ return false;
+ }
+ bool IgnoreSnapshots() const override { return true; }
+ const char *Name() const override { return "CustomerFilter"; }
+ };
+ class CustomerFilterFactory : public CompactionFilterFactory {
+ const char *Name() const override { return "CustomerFilterFactory"; }
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context & /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(new CustomerFilter());
+ }
+ };
+
+ constexpr size_t kNumPuts = 1 << 10;
+ // Generate both inlined and blob value
+ constexpr uint64_t kMinValueSize = 1 << 6;
+ constexpr uint64_t kMaxValueSize = 1 << 8;
+ constexpr uint64_t kMinBlobSize = 1 << 7;
+ static_assert(kMinValueSize < kMinBlobSize, "");
+ static_assert(kMaxValueSize > kMinBlobSize, "");
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = kMinBlobSize;
+ bdb_options.blob_file_size = kMaxValueSize * 10;
+ bdb_options.disable_background_tasks = true;
+ if (Snappy_Supported()) {
+ bdb_options.compression = CompressionType::kSnappyCompression;
+ }
+ // case_num == 0: Test user defined compaction filter
+ // case_num == 1: Test user defined compaction filter factory
+ for (int case_num = 0; case_num < 2; case_num++) {
+ Options options;
+ if (case_num == 0) {
+ options.compaction_filter = new CustomerFilter();
+ } else {
+ options.compaction_filter_factory.reset(new CustomerFilterFactory());
+ }
+ options.disable_auto_compactions = true;
+ options.env = mock_env_.get();
+ options.statistics = CreateDBStatistics();
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, std::string> data_after_compact;
+ Random rnd(301);
+ uint64_t value_size = kMinValueSize;
+ int drop_record = 0;
+ for (size_t i = 0; i < kNumPuts; ++i) {
+ std::ostringstream oss;
+ oss << "key" << std::setw(4) << std::setfill('0') << i;
+
+ const std::string key(oss.str());
+ const std::string value = rnd.HumanReadableString((int)value_size);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ if (value.length() % 4 == 1) {
+ data_after_compact[key] = value + value;
+ } else if (value.length() % 3 == 1) {
+ data_after_compact[key] = value.substr(0, value.size() / 2);
+ } else if (value.length() % 2 == 1) {
+ ++drop_record;
+ } else {
+ data_after_compact[key] = value;
+ }
+
+ if (++value_size > kMaxValueSize) {
+ value_size = kMinValueSize;
+ }
+ }
+ // Verify full data set
+ VerifyDB(data);
+ // Applying compaction filter for records
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // Verify data after compaction, only value with even length left.
+ VerifyDB(data_after_compact);
+ ASSERT_EQ(drop_record,
+ options.statistics->getTickerCount(COMPACTION_KEY_DROP_USER));
+ delete options.compaction_filter;
+ Destroy();
+ }
+}
+
+// Test user comapction filter when there is IO error on blob data.
+TEST_F(BlobDBTest, UserCompactionFilter_BlobIOError) {
+ class CustomerFilter : public CompactionFilter {
+ public:
+ bool Filter(int /*level*/, const Slice & /*key*/, const Slice &value,
+ std::string *new_value, bool *value_changed) const override {
+ *new_value = value.ToString() + "_new";
+ *value_changed = true;
+ return false;
+ }
+ bool IgnoreSnapshots() const override { return true; }
+ const char *Name() const override { return "CustomerFilter"; }
+ };
+
+ constexpr size_t kNumPuts = 100;
+ constexpr int kValueSize = 100;
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.blob_file_size = kValueSize * 10;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = CompressionType::kNoCompression;
+
+ std::vector<std::string> io_failure_cases = {
+ "BlobDBImpl::CreateBlobFileAndWriter",
+ "BlobIndexCompactionFilterBase::WriteBlobToNewFile",
+ "BlobDBImpl::CloseBlobFile"};
+
+ for (size_t case_num = 0; case_num < io_failure_cases.size(); case_num++) {
+ Options options;
+ options.compaction_filter = new CustomerFilter();
+ options.disable_auto_compactions = true;
+ options.env = fault_injection_env_.get();
+ options.statistics = CreateDBStatistics();
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ Random rnd(301);
+ for (size_t i = 0; i < kNumPuts; ++i) {
+ std::ostringstream oss;
+ oss << "key" << std::setw(4) << std::setfill('0') << i;
+
+ const std::string key(oss.str());
+ const std::string value = rnd.HumanReadableString(kValueSize);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+ data[key] = value;
+ }
+
+ // Verify full data set
+ VerifyDB(data);
+
+ SyncPoint::GetInstance()->SetCallBack(
+ io_failure_cases[case_num], [&](void * /*arg*/) {
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ auto s = blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ ASSERT_TRUE(s.IsIOError());
+
+ // Reactivate file system to allow test to verify and close DB.
+ fault_injection_env_->SetFilesystemActive(true);
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Verify full data set after compaction failure
+ VerifyDB(data);
+
+ delete options.compaction_filter;
+ Destroy();
+ }
+}
+
+// Test comapction filter should remove any expired blob index.
+TEST_F(BlobDBTest, FilterExpiredBlobIndex) {
+ constexpr size_t kNumKeys = 100;
+ constexpr size_t kNumPuts = 1000;
+ constexpr uint64_t kMaxExpiration = 1000;
+ constexpr uint64_t kCompactTime = 500;
+ constexpr uint64_t kMinBlobSize = 100;
+ Random rnd(301);
+ mock_clock_->SetCurrentTime(0);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = kMinBlobSize;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, std::string> data_after_compact;
+ for (size_t i = 0; i < kNumPuts; i++) {
+ bool is_small_value = rnd.Next() % 2;
+ bool has_ttl = rnd.Next() % 2;
+ uint64_t expiration = rnd.Next() % kMaxExpiration;
+ int len = is_small_value ? 10 : 200;
+ std::string key = "key" + std::to_string(rnd.Next() % kNumKeys);
+ std::string value = rnd.HumanReadableString(len);
+ if (!has_ttl) {
+ if (is_small_value) {
+ std::string blob_entry;
+ BlobIndex::EncodeInlinedTTL(&blob_entry, expiration, value);
+ // Fake blob index with TTL. See what it will do.
+ ASSERT_GT(kMinBlobSize, blob_entry.size());
+ value = blob_entry;
+ }
+ ASSERT_OK(Put(key, value));
+ data_after_compact[key] = value;
+ } else {
+ ASSERT_OK(PutUntil(key, value, expiration));
+ if (expiration <= kCompactTime) {
+ data_after_compact.erase(key);
+ } else {
+ data_after_compact[key] = value;
+ }
+ }
+ data[key] = value;
+ }
+ VerifyDB(data);
+
+ mock_clock_->SetCurrentTime(kCompactTime);
+ // Take a snapshot before compaction. Make sure expired blob indexes is
+ // filtered regardless of snapshot.
+ const Snapshot *snapshot = blob_db_->GetSnapshot();
+ // Issue manual compaction to trigger compaction filter.
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_->ReleaseSnapshot(snapshot);
+ // Verify expired blob index are filtered.
+ std::vector<KeyVersion> versions;
+ const size_t kMaxKeys = 10000;
+ ASSERT_OK(GetAllKeyVersions(blob_db_, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(data_after_compact.size(), versions.size());
+ for (auto &version : versions) {
+ ASSERT_TRUE(data_after_compact.count(version.user_key) > 0);
+ }
+ VerifyDB(data_after_compact);
+}
+
+// Test compaction filter should remove any blob index where corresponding
+// blob file has been removed.
+TEST_F(BlobDBTest, FilterFileNotAvailable) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.disable_auto_compactions = true;
+ Open(bdb_options, options);
+
+ ASSERT_OK(Put("foo", "v1"));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_EQ(1, blob_files[0]->BlobFileNumber());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[0]));
+
+ ASSERT_OK(Put("bar", "v2"));
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ ASSERT_EQ(2, blob_files[1]->BlobFileNumber());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[1]));
+
+ const size_t kMaxKeys = 10000;
+
+ DB *base_db = blob_db_->GetRootDB();
+ std::vector<KeyVersion> versions;
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(2, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ ASSERT_EQ("foo", versions[1].user_key);
+ VerifyDB({{"bar", "v2"}, {"foo", "v1"}});
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(2, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ ASSERT_EQ("foo", versions[1].user_key);
+ VerifyDB({{"bar", "v2"}, {"foo", "v1"}});
+
+ // Remove the first blob file and compact. foo should be remove from base db.
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_files[0]);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(1, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ VerifyDB({{"bar", "v2"}});
+
+ // Remove the second blob file and compact. bar should be remove from base db.
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_files[1]);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(0, versions.size());
+ VerifyDB({});
+}
+
+// Test compaction filter should filter any inlined TTL keys that would have
+// been dropped by last FIFO eviction if they are store out-of-line.
+TEST_F(BlobDBTest, FilterForFIFOEviction) {
+ Random rnd(215);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 100;
+ bdb_options.ttl_range_secs = 60;
+ bdb_options.max_db_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ mock_clock_->SetCurrentTime(0);
+ options.env = mock_env_.get();
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ options.disable_auto_compactions = true;
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, std::string> data_after_compact;
+ // Insert some small values that will be inlined.
+ for (int i = 0; i < 1000; i++) {
+ std::string key = "key" + std::to_string(i);
+ std::string value = rnd.HumanReadableString(50);
+ uint64_t ttl = rnd.Next() % 120 + 1;
+ ASSERT_OK(PutWithTTL(key, value, ttl, &data));
+ if (ttl >= 60) {
+ data_after_compact[key] = value;
+ }
+ }
+ uint64_t num_keys_to_evict = data.size() - data_after_compact.size();
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ uint64_t live_sst_size = blob_db_impl()->TEST_live_sst_size();
+ ASSERT_GT(live_sst_size, 0);
+ VerifyDB(data);
+
+ bdb_options.max_db_size = live_sst_size + 30000;
+ bdb_options.is_fifo = true;
+ Reopen(bdb_options, options);
+ VerifyDB(data);
+
+ // Put two large values, each on a different blob file.
+ std::string large_value(10000, 'v');
+ ASSERT_OK(PutWithTTL("large_key1", large_value, 90));
+ ASSERT_OK(PutWithTTL("large_key2", large_value, 150));
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ data["large_key1"] = large_value;
+ data["large_key2"] = large_value;
+ VerifyDB(data);
+
+ // Put a third large value which will bring the DB out of space.
+ // FIFO eviction will evict the file of large_key1.
+ ASSERT_OK(PutWithTTL("large_key3", large_value, 150));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ data.erase("large_key1");
+ data["large_key3"] = large_value;
+ VerifyDB(data);
+
+ // Putting some more small values. These values shouldn't be evicted by
+ // compaction filter since they are inserted after FIFO eviction.
+ ASSERT_OK(PutWithTTL("foo", "v", 30, &data_after_compact));
+ ASSERT_OK(PutWithTTL("bar", "v", 30, &data_after_compact));
+
+ // FIFO eviction doesn't trigger again since there enough room for the flush.
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+
+ // Manual compact and check if compaction filter evict those keys with
+ // expiration < 60.
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // All keys with expiration < 60, plus large_key1 is filtered by
+ // compaction filter.
+ ASSERT_EQ(num_keys_to_evict + 1,
+ statistics->getTickerCount(BLOB_DB_BLOB_INDEX_EVICTED_COUNT));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ data_after_compact["large_key2"] = large_value;
+ data_after_compact["large_key3"] = large_value;
+ VerifyDB(data_after_compact);
+}
+
+TEST_F(BlobDBTest, GarbageCollection) {
+ constexpr size_t kNumPuts = 1 << 10;
+
+ constexpr uint64_t kExpiration = 1000;
+ constexpr uint64_t kCompactTime = 500;
+
+ constexpr uint64_t kKeySize = 7; // "key" + 4 digits
+
+ constexpr uint64_t kSmallValueSize = 1 << 6;
+ constexpr uint64_t kLargeValueSize = 1 << 8;
+ constexpr uint64_t kMinBlobSize = 1 << 7;
+ static_assert(kSmallValueSize < kMinBlobSize, "");
+ static_assert(kLargeValueSize > kMinBlobSize, "");
+
+ constexpr size_t kBlobsPerFile = 8;
+ constexpr size_t kNumBlobFiles = kNumPuts / kBlobsPerFile;
+ constexpr uint64_t kBlobFileSize =
+ BlobLogHeader::kSize +
+ (BlobLogRecord::kHeaderSize + kKeySize + kLargeValueSize) * kBlobsPerFile;
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = kMinBlobSize;
+ bdb_options.blob_file_size = kBlobFileSize;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 0.25;
+ bdb_options.disable_background_tasks = true;
+
+ Options options;
+ options.env = mock_env_.get();
+ options.statistics = CreateDBStatistics();
+
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, KeyVersion> blob_value_versions;
+ std::map<std::string, BlobIndexVersion> blob_index_versions;
+
+ Random rnd(301);
+
+ // Add a bunch of large non-TTL values. These will be written to non-TTL
+ // blob files and will be subject to GC.
+ for (size_t i = 0; i < kNumPuts; ++i) {
+ std::ostringstream oss;
+ oss << "key" << std::setw(4) << std::setfill('0') << i;
+
+ const std::string key(oss.str());
+ const std::string value = rnd.HumanReadableString(kLargeValueSize);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] =
+ BlobIndexVersion(key, /* file_number */ (i >> 3) + 1, kNoExpiration,
+ sequence, kTypeBlobIndex);
+ }
+
+ // Add some small and/or TTL values that will be ignored during GC.
+ // First, add a large TTL value will be written to its own TTL blob file.
+ {
+ const std::string key("key2000");
+ const std::string value = rnd.HumanReadableString(kLargeValueSize);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(PutUntil(key, value, kExpiration));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] =
+ BlobIndexVersion(key, /* file_number */ kNumBlobFiles + 1, kExpiration,
+ sequence, kTypeBlobIndex);
+ }
+
+ // Now add a small TTL value (which will be inlined).
+ {
+ const std::string key("key3000");
+ const std::string value = rnd.HumanReadableString(kSmallValueSize);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(PutUntil(key, value, kExpiration));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] = BlobIndexVersion(
+ key, kInvalidBlobFileNumber, kExpiration, sequence, kTypeBlobIndex);
+ }
+
+ // Finally, add a small non-TTL value (which will be stored as a regular
+ // value).
+ {
+ const std::string key("key4000");
+ const std::string value = rnd.HumanReadableString(kSmallValueSize);
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeValue);
+ blob_index_versions[key] = BlobIndexVersion(
+ key, kInvalidBlobFileNumber, kNoExpiration, sequence, kTypeValue);
+ }
+
+ VerifyDB(data);
+ VerifyBaseDB(blob_value_versions);
+ VerifyBaseDBBlobIndex(blob_index_versions);
+
+ // At this point, we should have 128 immutable non-TTL files with file numbers
+ // 1..128.
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), kNumBlobFiles);
+ for (size_t i = 0; i < kNumBlobFiles; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ ASSERT_EQ(live_imm_files[i]->GetFileSize(),
+ kBlobFileSize + BlobLogFooter::kSize);
+ }
+ }
+
+ mock_clock_->SetCurrentTime(kCompactTime);
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+
+ // We expect the data to remain the same and the blobs from the oldest N files
+ // to be moved to new files. Sequence numbers get zeroed out during the
+ // compaction.
+ VerifyDB(data);
+
+ for (auto &pair : blob_value_versions) {
+ KeyVersion &version = pair.second;
+ version.sequence = 0;
+ }
+
+ VerifyBaseDB(blob_value_versions);
+
+ const uint64_t cutoff = static_cast<uint64_t>(
+ bdb_options.garbage_collection_cutoff * kNumBlobFiles);
+ for (auto &pair : blob_index_versions) {
+ BlobIndexVersion &version = pair.second;
+
+ version.sequence = 0;
+
+ if (version.file_number == kInvalidBlobFileNumber) {
+ continue;
+ }
+
+ if (version.file_number > cutoff) {
+ continue;
+ }
+
+ version.file_number += kNumBlobFiles + 1;
+ }
+
+ VerifyBaseDBBlobIndex(blob_index_versions);
+
+ const Statistics *const statistics = options.statistics.get();
+ assert(statistics);
+
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_FILES), cutoff);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_NEW_FILES), cutoff);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_FAILURES), 0);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_KEYS_RELOCATED),
+ cutoff * kBlobsPerFile);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_BYTES_RELOCATED),
+ cutoff * kBlobsPerFile * kLargeValueSize);
+
+ // At this point, we should have 128 immutable non-TTL files with file numbers
+ // 33..128 and 130..161. (129 was taken by the TTL blob file.)
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), kNumBlobFiles);
+ for (size_t i = 0; i < kNumBlobFiles; ++i) {
+ uint64_t expected_file_number = i + cutoff + 1;
+ if (expected_file_number > kNumBlobFiles) {
+ ++expected_file_number;
+ }
+
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), expected_file_number);
+ ASSERT_EQ(live_imm_files[i]->GetFileSize(),
+ kBlobFileSize + BlobLogFooter::kSize);
+ }
+ }
+}
+
+TEST_F(BlobDBTest, GarbageCollectionFailure) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+
+ Options db_options;
+ db_options.statistics = CreateDBStatistics();
+
+ Open(bdb_options, db_options);
+
+ // Write a couple of valid blobs.
+ ASSERT_OK(Put("foo", "bar"));
+ ASSERT_OK(Put("dead", "beef"));
+
+ // Write a fake blob reference into the base DB that points to a non-existing
+ // blob file.
+ std::string blob_index;
+ BlobIndex::EncodeBlob(&blob_index, /* file_number */ 1000, /* offset */ 1234,
+ /* size */ 5678, kNoCompression);
+
+ WriteBatch batch;
+ ASSERT_OK(WriteBatchInternal::PutBlobIndex(
+ &batch, blob_db_->DefaultColumnFamily()->GetID(), "key", blob_index));
+ ASSERT_OK(blob_db_->GetRootDB()->Write(WriteOptions(), &batch));
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(blob_files.size(), 1);
+ auto blob_file = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_file));
+
+ ASSERT_TRUE(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr)
+ .IsIOError());
+
+ const Statistics *const statistics = db_options.statistics.get();
+ assert(statistics);
+
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_FILES), 0);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_NEW_FILES), 1);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_FAILURES), 1);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_KEYS_RELOCATED), 2);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_BYTES_RELOCATED), 7);
+}
+
+// File should be evicted after expiration.
+TEST_F(BlobDBTest, EvictExpiredFile) {
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 100;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ Open(bdb_options, options);
+ mock_clock_->SetCurrentTime(50);
+ std::map<std::string, std::string> data;
+ ASSERT_OK(PutWithTTL("foo", "bar", 100, &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_FALSE(blob_file->Immutable());
+ ASSERT_FALSE(blob_file->Obsolete());
+ VerifyDB(data);
+ mock_clock_->SetCurrentTime(250);
+ // The key should expired now.
+ blob_db_impl()->TEST_EvictExpiredFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ ASSERT_TRUE(blob_file->Immutable());
+ ASSERT_TRUE(blob_file->Obsolete());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ // Make sure we don't return garbage value after blob file being evicted,
+ // but the blob index still exists in the LSM tree.
+ std::string val = "";
+ ASSERT_TRUE(blob_db_->Get(ReadOptions(), "foo", &val).IsNotFound());
+ ASSERT_EQ("", val);
+}
+
+TEST_F(BlobDBTest, DisableFileDeletions) {
+ BlobDBOptions bdb_options;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (bool force : {true, false}) {
+ ASSERT_OK(Put("foo", "v", &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_file));
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_file);
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ // Call DisableFileDeletions twice.
+ ASSERT_OK(blob_db_->DisableFileDeletions());
+ ASSERT_OK(blob_db_->DisableFileDeletions());
+ // File deletions should be disabled.
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB(data);
+ // Enable file deletions once. If force=true, file deletion is enabled.
+ // Otherwise it needs to enable it for a second time.
+ ASSERT_OK(blob_db_->EnableFileDeletions(force));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ if (!force) {
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB(data);
+ // Call EnableFileDeletions a second time.
+ ASSERT_OK(blob_db_->EnableFileDeletions(false));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ }
+ // Regardless of value of `force`, file should be deleted by now.
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB({});
+ }
+}
+
+TEST_F(BlobDBTest, MaintainBlobFileToSstMapping) {
+ BlobDBOptions bdb_options;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ // Register some dummy blob files.
+ blob_db_impl()->TEST_AddDummyBlobFile(1, /* immutable_sequence */ 200);
+ blob_db_impl()->TEST_AddDummyBlobFile(2, /* immutable_sequence */ 300);
+ blob_db_impl()->TEST_AddDummyBlobFile(3, /* immutable_sequence */ 400);
+ blob_db_impl()->TEST_AddDummyBlobFile(4, /* immutable_sequence */ 500);
+ blob_db_impl()->TEST_AddDummyBlobFile(5, /* immutable_sequence */ 600);
+
+ // Initialize the blob <-> SST file mapping. First, add some SST files with
+ // blob file references, then some without.
+ std::vector<LiveFileMetaData> live_files;
+
+ for (uint64_t i = 1; i <= 10; ++i) {
+ LiveFileMetaData live_file;
+ live_file.file_number = i;
+ live_file.oldest_blob_file_number = ((i - 1) % 5) + 1;
+
+ live_files.emplace_back(live_file);
+ }
+
+ for (uint64_t i = 11; i <= 20; ++i) {
+ LiveFileMetaData live_file;
+ live_file.file_number = i;
+
+ live_files.emplace_back(live_file);
+ }
+
+ blob_db_impl()->TEST_InitializeBlobFileToSstMapping(live_files);
+
+ // Check that the blob <-> SST mappings have been correctly initialized.
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+
+ ASSERT_EQ(blob_files.size(), 5);
+
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ {
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a flush where the SST does not reference any blob files.
+ {
+ FlushJobInfo info{};
+ info.file_number = 21;
+ info.smallest_seqno = 1;
+ info.largest_seqno = 100;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a flush where the SST references a blob file.
+ {
+ FlushJobInfo info{};
+ info.file_number = 22;
+ info.oldest_blob_file_number = 5;
+ info.smallest_seqno = 101;
+ info.largest_seqno = 200;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a compaction. Some inputs and outputs have blob file references,
+ // some don't. There is also a trivial move (which means the SST appears on
+ // both the input and the output list). Blob file 1 loses all its linked SSTs,
+ // and since it got marked immutable at sequence number 200 which has already
+ // been flushed, it can be marked obsolete.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 1, 1});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 2, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 6, 1});
+ info.input_file_infos.emplace_back(
+ CompactionFileInfo{1, 11, kInvalidBlobFileNumber});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 23, 3});
+ info.output_file_infos.emplace_back(
+ CompactionFileInfo{2, 24, kInvalidBlobFileNumber});
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {7}, {3, 8, 23}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate a failed compaction. No mappings should be updated.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 7, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 25, 3});
+ info.status = Status::Corruption();
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {7}, {3, 8, 23}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate another compaction. Blob file 2 loses all its linked SSTs
+ // but since it got marked immutable at sequence number 300 which hasn't
+ // been flushed yet, it cannot be marked obsolete at this point.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 7, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 25, 3});
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {}, {3, 8, 23, 25}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate a flush with largest sequence number 300. This will make it
+ // possible to mark blob file 2 obsolete.
+ {
+ FlushJobInfo info{};
+ info.file_number = 26;
+ info.smallest_seqno = 201;
+ info.largest_seqno = 300;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {}, {3, 8, 23, 25}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{true, true, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 3);
+ for (size_t i = 0; i < 3; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 3);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 2);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ ASSERT_EQ(obsolete_files[1]->BlobFileNumber(), 2);
+ }
+}
+
+TEST_F(BlobDBTest, ShutdownWait) {
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 100;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = false;
+ Options options;
+ options.env = mock_env_.get();
+
+ SyncPoint::GetInstance()->LoadDependency({
+ {"BlobDBImpl::EvictExpiredFiles:0", "BlobDBTest.ShutdownWait:0"},
+ {"BlobDBTest.ShutdownWait:1", "BlobDBImpl::EvictExpiredFiles:1"},
+ {"BlobDBImpl::EvictExpiredFiles:2", "BlobDBTest.ShutdownWait:2"},
+ {"BlobDBTest.ShutdownWait:3", "BlobDBImpl::EvictExpiredFiles:3"},
+ });
+ // Force all tasks to be scheduled immediately.
+ SyncPoint::GetInstance()->SetCallBack(
+ "TimeQueue::Add:item.end", [&](void *arg) {
+ std::chrono::steady_clock::time_point *tp =
+ static_cast<std::chrono::steady_clock::time_point *>(arg);
+ *tp =
+ std::chrono::steady_clock::now() - std::chrono::milliseconds(10000);
+ });
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictExpiredFiles:cb", [&](void * /*arg*/) {
+ // Sleep 3 ms to increase the chance of data race.
+ // We've synced up the code so that EvictExpiredFiles()
+ // is called concurrently with ~BlobDBImpl().
+ // ~BlobDBImpl() is supposed to wait for all background
+ // task to shutdown before doing anything else. In order
+ // to use the same test to reproduce a bug of the waiting
+ // logic, we wait a little bit here, so that TSAN can
+ // catch the data race.
+ // We should improve the test if we find a better way.
+ Env::Default()->SleepForMicroseconds(3000);
+ });
+
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ Open(bdb_options, options);
+ mock_clock_->SetCurrentTime(50);
+ std::map<std::string, std::string> data;
+ ASSERT_OK(PutWithTTL("foo", "bar", 100, &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_FALSE(blob_file->Immutable());
+ ASSERT_FALSE(blob_file->Obsolete());
+ VerifyDB(data);
+
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:0");
+ mock_clock_->SetCurrentTime(250);
+ // The key should expired now.
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:1");
+
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:2");
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:3");
+ Close();
+
+ SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_F(BlobDBTest, SyncBlobFileBeforeClose) {
+ Options options;
+ options.statistics = CreateDBStatistics();
+
+ BlobDBOptions blob_options;
+ blob_options.min_blob_size = 0;
+ blob_options.bytes_per_sync = 1 << 20;
+ blob_options.disable_background_tasks = true;
+
+ Open(blob_options, options);
+
+ ASSERT_OK(Put("foo", "bar"));
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(blob_files.size(), 1);
+
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[0]));
+ ASSERT_EQ(options.statistics->getTickerCount(BLOB_DB_BLOB_FILE_SYNCED), 1);
+}
+
+TEST_F(BlobDBTest, SyncBlobFileBeforeCloseIOError) {
+ Options options;
+ options.env = fault_injection_env_.get();
+
+ BlobDBOptions blob_options;
+ blob_options.min_blob_size = 0;
+ blob_options.bytes_per_sync = 1 << 20;
+ blob_options.disable_background_tasks = true;
+
+ Open(blob_options, options);
+
+ ASSERT_OK(Put("foo", "bar"));
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(blob_files.size(), 1);
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobLogWriter::Sync", [this](void * /* arg */) {
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ const Status s = blob_db_impl()->TEST_CloseBlobFile(blob_files[0]);
+
+ fault_injection_env_->SetFilesystemActive(true);
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ ASSERT_TRUE(s.IsIOError());
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+// A black-box test for the ttl wrapper around rocksdb
+int main(int argc, char **argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as BlobDB is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_dump_tool.cc b/src/rocksdb/utilities/blob_db/blob_dump_tool.cc
new file mode 100644
index 000000000..1e0632990
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_dump_tool.cc
@@ -0,0 +1,282 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_dump_tool.h"
+
+#include <stdio.h>
+
+#include <cinttypes>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "file/random_access_file_reader.h"
+#include "file/readahead_raf.h"
+#include "port/port.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/file_system.h"
+#include "table/format.h"
+#include "util/coding.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+BlobDumpTool::BlobDumpTool()
+ : reader_(nullptr), buffer_(nullptr), buffer_size_(0) {}
+
+Status BlobDumpTool::Run(const std::string& filename, DisplayType show_key,
+ DisplayType show_blob,
+ DisplayType show_uncompressed_blob,
+ bool show_summary) {
+ constexpr size_t kReadaheadSize = 2 * 1024 * 1024;
+ Status s;
+ const auto fs = FileSystem::Default();
+ IOOptions io_opts;
+ s = fs->FileExists(filename, io_opts, nullptr);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t file_size = 0;
+ s = fs->GetFileSize(filename, io_opts, &file_size, nullptr);
+ if (!s.ok()) {
+ return s;
+ }
+ std::unique_ptr<FSRandomAccessFile> file;
+ s = fs->NewRandomAccessFile(filename, FileOptions(), &file, nullptr);
+ if (!s.ok()) {
+ return s;
+ }
+ file = NewReadaheadRandomAccessFile(std::move(file), kReadaheadSize);
+ if (file_size == 0) {
+ return Status::Corruption("File is empty.");
+ }
+ reader_.reset(new RandomAccessFileReader(std::move(file), filename));
+ uint64_t offset = 0;
+ uint64_t footer_offset = 0;
+ CompressionType compression = kNoCompression;
+ s = DumpBlobLogHeader(&offset, &compression);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DumpBlobLogFooter(file_size, &footer_offset);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t total_records = 0;
+ uint64_t total_key_size = 0;
+ uint64_t total_blob_size = 0;
+ uint64_t total_uncompressed_blob_size = 0;
+ if (show_key != DisplayType::kNone || show_summary) {
+ while (offset < footer_offset) {
+ s = DumpRecord(show_key, show_blob, show_uncompressed_blob, show_summary,
+ compression, &offset, &total_records, &total_key_size,
+ &total_blob_size, &total_uncompressed_blob_size);
+ if (!s.ok()) {
+ break;
+ }
+ }
+ }
+ if (show_summary) {
+ fprintf(stdout, "Summary:\n");
+ fprintf(stdout, " total records: %" PRIu64 "\n", total_records);
+ fprintf(stdout, " total key size: %" PRIu64 "\n", total_key_size);
+ fprintf(stdout, " total blob size: %" PRIu64 "\n", total_blob_size);
+ if (compression != kNoCompression) {
+ fprintf(stdout, " total raw blob size: %" PRIu64 "\n",
+ total_uncompressed_blob_size);
+ }
+ }
+ return s;
+}
+
+Status BlobDumpTool::Read(uint64_t offset, size_t size, Slice* result) {
+ if (buffer_size_ < size) {
+ if (buffer_size_ == 0) {
+ buffer_size_ = 4096;
+ }
+ while (buffer_size_ < size) {
+ buffer_size_ *= 2;
+ }
+ buffer_.reset(new char[buffer_size_]);
+ }
+ Status s = reader_->Read(IOOptions(), offset, size, result, buffer_.get(),
+ nullptr, Env::IO_TOTAL /* rate_limiter_priority */);
+ if (!s.ok()) {
+ return s;
+ }
+ if (result->size() != size) {
+ return Status::Corruption("Reach the end of the file unexpectedly.");
+ }
+ return s;
+}
+
+Status BlobDumpTool::DumpBlobLogHeader(uint64_t* offset,
+ CompressionType* compression) {
+ Slice slice;
+ Status s = Read(0, BlobLogHeader::kSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogHeader header;
+ s = header.DecodeFrom(slice);
+ if (!s.ok()) {
+ return s;
+ }
+ fprintf(stdout, "Blob log header:\n");
+ fprintf(stdout, " Version : %" PRIu32 "\n", header.version);
+ fprintf(stdout, " Column Family ID : %" PRIu32 "\n",
+ header.column_family_id);
+ std::string compression_str;
+ if (!GetStringFromCompressionType(&compression_str, header.compression)
+ .ok()) {
+ compression_str = "Unrecongnized compression type (" +
+ std::to_string((int)header.compression) + ")";
+ }
+ fprintf(stdout, " Compression : %s\n", compression_str.c_str());
+ fprintf(stdout, " Expiration range : %s\n",
+ GetString(header.expiration_range).c_str());
+ *offset = BlobLogHeader::kSize;
+ *compression = header.compression;
+ return s;
+}
+
+Status BlobDumpTool::DumpBlobLogFooter(uint64_t file_size,
+ uint64_t* footer_offset) {
+ auto no_footer = [&]() {
+ *footer_offset = file_size;
+ fprintf(stdout, "No blob log footer.\n");
+ return Status::OK();
+ };
+ if (file_size < BlobLogHeader::kSize + BlobLogFooter::kSize) {
+ return no_footer();
+ }
+ Slice slice;
+ *footer_offset = file_size - BlobLogFooter::kSize;
+ Status s = Read(*footer_offset, BlobLogFooter::kSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogFooter footer;
+ s = footer.DecodeFrom(slice);
+ if (!s.ok()) {
+ return no_footer();
+ }
+ fprintf(stdout, "Blob log footer:\n");
+ fprintf(stdout, " Blob count : %" PRIu64 "\n", footer.blob_count);
+ fprintf(stdout, " Expiration Range : %s\n",
+ GetString(footer.expiration_range).c_str());
+ return s;
+}
+
+Status BlobDumpTool::DumpRecord(DisplayType show_key, DisplayType show_blob,
+ DisplayType show_uncompressed_blob,
+ bool show_summary, CompressionType compression,
+ uint64_t* offset, uint64_t* total_records,
+ uint64_t* total_key_size,
+ uint64_t* total_blob_size,
+ uint64_t* total_uncompressed_blob_size) {
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, "Read record with offset 0x%" PRIx64 " (%" PRIu64 "):\n",
+ *offset, *offset);
+ }
+ Slice slice;
+ Status s = Read(*offset, BlobLogRecord::kHeaderSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogRecord record;
+ s = record.DecodeHeaderFrom(slice);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t key_size = record.key_size;
+ uint64_t value_size = record.value_size;
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, " key size : %" PRIu64 "\n", key_size);
+ fprintf(stdout, " value size : %" PRIu64 "\n", value_size);
+ fprintf(stdout, " expiration : %" PRIu64 "\n", record.expiration);
+ }
+ *offset += BlobLogRecord::kHeaderSize;
+ s = Read(*offset, static_cast<size_t>(key_size + value_size), &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ // Decompress value
+ std::string uncompressed_value;
+ if (compression != kNoCompression &&
+ (show_uncompressed_blob != DisplayType::kNone || show_summary)) {
+ BlockContents contents;
+ UncompressionContext context(compression);
+ UncompressionInfo info(context, UncompressionDict::GetEmptyDict(),
+ compression);
+ s = UncompressBlockData(
+ info, slice.data() + key_size, static_cast<size_t>(value_size),
+ &contents, 2 /*compress_format_version*/, ImmutableOptions(Options()));
+ if (!s.ok()) {
+ return s;
+ }
+ uncompressed_value = contents.data.ToString();
+ }
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, " key : ");
+ DumpSlice(Slice(slice.data(), static_cast<size_t>(key_size)), show_key);
+ if (show_blob != DisplayType::kNone) {
+ fprintf(stdout, " blob : ");
+ DumpSlice(Slice(slice.data() + static_cast<size_t>(key_size),
+ static_cast<size_t>(value_size)),
+ show_blob);
+ }
+ if (show_uncompressed_blob != DisplayType::kNone) {
+ fprintf(stdout, " raw blob : ");
+ DumpSlice(Slice(uncompressed_value), show_uncompressed_blob);
+ }
+ }
+ *offset += key_size + value_size;
+ *total_records += 1;
+ *total_key_size += key_size;
+ *total_blob_size += value_size;
+ *total_uncompressed_blob_size += uncompressed_value.size();
+ return s;
+}
+
+void BlobDumpTool::DumpSlice(const Slice s, DisplayType type) {
+ if (type == DisplayType::kRaw) {
+ fprintf(stdout, "%s\n", s.ToString().c_str());
+ } else if (type == DisplayType::kHex) {
+ fprintf(stdout, "%s\n", s.ToString(true /*hex*/).c_str());
+ } else if (type == DisplayType::kDetail) {
+ char buf[100];
+ for (size_t i = 0; i < s.size(); i += 16) {
+ memset(buf, 0, sizeof(buf));
+ for (size_t j = 0; j < 16 && i + j < s.size(); j++) {
+ unsigned char c = s[i + j];
+ snprintf(buf + j * 3 + 15, 2, "%x", c >> 4);
+ snprintf(buf + j * 3 + 16, 2, "%x", c & 0xf);
+ snprintf(buf + j + 65, 2, "%c", (0x20 <= c && c <= 0x7e) ? c : '.');
+ }
+ for (size_t p = 0; p + 1 < sizeof(buf); p++) {
+ if (buf[p] == 0) {
+ buf[p] = ' ';
+ }
+ }
+ fprintf(stdout, "%s\n", i == 0 ? buf + 15 : buf);
+ }
+ }
+}
+
+template <class T>
+std::string BlobDumpTool::GetString(std::pair<T, T> p) {
+ if (p.first == 0 && p.second == 0) {
+ return "nil";
+ }
+ return "(" + std::to_string(p.first) + ", " + std::to_string(p.second) + ")";
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_dump_tool.h b/src/rocksdb/utilities/blob_db/blob_dump_tool.h
new file mode 100644
index 000000000..bece564e1
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_dump_tool.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "db/blob/blob_log_format.h"
+#include "file/random_access_file_reader.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDumpTool {
+ public:
+ enum class DisplayType {
+ kNone,
+ kRaw,
+ kHex,
+ kDetail,
+ };
+
+ BlobDumpTool();
+
+ Status Run(const std::string& filename, DisplayType show_key,
+ DisplayType show_blob, DisplayType show_uncompressed_blob,
+ bool show_summary);
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> reader_;
+ std::unique_ptr<char[]> buffer_;
+ size_t buffer_size_;
+
+ Status Read(uint64_t offset, size_t size, Slice* result);
+ Status DumpBlobLogHeader(uint64_t* offset, CompressionType* compression);
+ Status DumpBlobLogFooter(uint64_t file_size, uint64_t* footer_offset);
+ Status DumpRecord(DisplayType show_key, DisplayType show_blob,
+ DisplayType show_uncompressed_blob, bool show_summary,
+ CompressionType compression, uint64_t* offset,
+ uint64_t* total_records, uint64_t* total_key_size,
+ uint64_t* total_blob_size,
+ uint64_t* total_uncompressed_blob_size);
+ void DumpSlice(const Slice s, DisplayType type);
+
+ template <class T>
+ std::string GetString(std::pair<T, T> p);
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_file.cc b/src/rocksdb/utilities/blob_db/blob_file.cc
new file mode 100644
index 000000000..c68e557c6
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_file.cc
@@ -0,0 +1,318 @@
+
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+#include "utilities/blob_db/blob_file.h"
+
+#include <stdio.h>
+
+#include <algorithm>
+#include <cinttypes>
+#include <memory>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "db/dbformat.h"
+#include "file/filename.h"
+#include "file/readahead_raf.h"
+#include "logging/logging.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+BlobFile::BlobFile(const BlobDBImpl* p, const std::string& bdir, uint64_t fn,
+ Logger* info_log)
+ : parent_(p), path_to_dir_(bdir), file_number_(fn), info_log_(info_log) {}
+
+BlobFile::BlobFile(const BlobDBImpl* p, const std::string& bdir, uint64_t fn,
+ Logger* info_log, uint32_t column_family_id,
+ CompressionType compression, bool has_ttl,
+ const ExpirationRange& expiration_range)
+ : parent_(p),
+ path_to_dir_(bdir),
+ file_number_(fn),
+ info_log_(info_log),
+ column_family_id_(column_family_id),
+ compression_(compression),
+ has_ttl_(has_ttl),
+ expiration_range_(expiration_range),
+ header_(column_family_id, compression, has_ttl, expiration_range),
+ header_valid_(true) {}
+
+BlobFile::~BlobFile() {
+ if (obsolete_) {
+ std::string pn(PathName());
+ Status s = Env::Default()->DeleteFile(PathName());
+ if (!s.ok()) {
+ // ROCKS_LOG_INFO(db_options_.info_log,
+ // "File could not be deleted %s", pn.c_str());
+ }
+ }
+}
+
+uint32_t BlobFile::GetColumnFamilyId() const { return column_family_id_; }
+
+std::string BlobFile::PathName() const {
+ return BlobFileName(path_to_dir_, file_number_);
+}
+
+std::string BlobFile::DumpState() const {
+ char str[1000];
+ snprintf(
+ str, sizeof(str),
+ "path: %s fn: %" PRIu64 " blob_count: %" PRIu64 " file_size: %" PRIu64
+ " closed: %d obsolete: %d expiration_range: (%" PRIu64 ", %" PRIu64
+ "), writer: %d reader: %d",
+ path_to_dir_.c_str(), file_number_, blob_count_.load(), file_size_.load(),
+ closed_.load(), obsolete_.load(), expiration_range_.first,
+ expiration_range_.second, (!!log_writer_), (!!ra_file_reader_));
+ return str;
+}
+
+void BlobFile::MarkObsolete(SequenceNumber sequence) {
+ assert(Immutable());
+ obsolete_sequence_ = sequence;
+ obsolete_.store(true);
+}
+
+Status BlobFile::WriteFooterAndCloseLocked(SequenceNumber sequence) {
+ BlobLogFooter footer;
+ footer.blob_count = blob_count_;
+ if (HasTTL()) {
+ footer.expiration_range = expiration_range_;
+ }
+
+ // this will close the file and reset the Writable File Pointer.
+ Status s = log_writer_->AppendFooter(footer, /* checksum_method */ nullptr,
+ /* checksum_value */ nullptr);
+ if (s.ok()) {
+ closed_ = true;
+ immutable_sequence_ = sequence;
+ file_size_ += BlobLogFooter::kSize;
+ }
+ // delete the sequential writer
+ log_writer_.reset();
+ return s;
+}
+
+Status BlobFile::ReadFooter(BlobLogFooter* bf) {
+ if (file_size_ < (BlobLogHeader::kSize + BlobLogFooter::kSize)) {
+ return Status::IOError("File does not have footer", PathName());
+ }
+
+ uint64_t footer_offset = file_size_ - BlobLogFooter::kSize;
+ // assume that ra_file_reader_ is valid before we enter this
+ assert(ra_file_reader_);
+
+ Slice result;
+ std::string buf;
+ AlignedBuf aligned_buf;
+ Status s;
+ // TODO: rate limit reading footers from blob files.
+ if (ra_file_reader_->use_direct_io()) {
+ s = ra_file_reader_->Read(IOOptions(), footer_offset, BlobLogFooter::kSize,
+ &result, nullptr, &aligned_buf,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ } else {
+ buf.reserve(BlobLogFooter::kSize + 10);
+ s = ra_file_reader_->Read(IOOptions(), footer_offset, BlobLogFooter::kSize,
+ &result, &buf[0], nullptr,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ }
+ if (!s.ok()) return s;
+ if (result.size() != BlobLogFooter::kSize) {
+ // should not happen
+ return Status::IOError("EOF reached before footer");
+ }
+
+ s = bf->DecodeFrom(result);
+ return s;
+}
+
+Status BlobFile::SetFromFooterLocked(const BlobLogFooter& footer) {
+ blob_count_ = footer.blob_count;
+ expiration_range_ = footer.expiration_range;
+ closed_ = true;
+ return Status::OK();
+}
+
+Status BlobFile::Fsync() {
+ Status s;
+ if (log_writer_.get()) {
+ s = log_writer_->Sync();
+ }
+ return s;
+}
+
+void BlobFile::CloseRandomAccessLocked() {
+ ra_file_reader_.reset();
+ last_access_ = -1;
+}
+
+Status BlobFile::GetReader(Env* env, const FileOptions& file_options,
+ std::shared_ptr<RandomAccessFileReader>* reader,
+ bool* fresh_open) {
+ assert(reader != nullptr);
+ assert(fresh_open != nullptr);
+ *fresh_open = false;
+ int64_t current_time = 0;
+ if (env->GetCurrentTime(&current_time).ok()) {
+ last_access_.store(current_time);
+ }
+ Status s;
+
+ {
+ ReadLock lockbfile_r(&mutex_);
+ if (ra_file_reader_) {
+ *reader = ra_file_reader_;
+ return s;
+ }
+ }
+
+ WriteLock lockbfile_w(&mutex_);
+ // Double check.
+ if (ra_file_reader_) {
+ *reader = ra_file_reader_;
+ return s;
+ }
+
+ std::unique_ptr<FSRandomAccessFile> rfile;
+ s = env->GetFileSystem()->NewRandomAccessFile(PathName(), file_options,
+ &rfile, nullptr);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to open blob file for random-read: %s status: '%s'"
+ " exists: '%s'",
+ PathName().c_str(), s.ToString().c_str(),
+ env->FileExists(PathName()).ToString().c_str());
+ return s;
+ }
+
+ ra_file_reader_ =
+ std::make_shared<RandomAccessFileReader>(std::move(rfile), PathName());
+ *reader = ra_file_reader_;
+ *fresh_open = true;
+ return s;
+}
+
+Status BlobFile::ReadMetadata(const std::shared_ptr<FileSystem>& fs,
+ const FileOptions& file_options) {
+ assert(Immutable());
+ // Get file size.
+ uint64_t file_size = 0;
+ Status s =
+ fs->GetFileSize(PathName(), file_options.io_options, &file_size, nullptr);
+ if (s.ok()) {
+ file_size_ = file_size;
+ } else {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to get size of blob file %" PRIu64 ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ if (file_size < BlobLogHeader::kSize) {
+ ROCKS_LOG_ERROR(
+ info_log_, "Incomplete blob file blob file %" PRIu64 ", size: %" PRIu64,
+ file_number_, file_size);
+ return Status::Corruption("Incomplete blob file header.");
+ }
+
+ // Create file reader.
+ std::unique_ptr<RandomAccessFileReader> file_reader;
+ s = RandomAccessFileReader::Create(fs, PathName(), file_options, &file_reader,
+ nullptr);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to open blob file %" PRIu64 ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+
+ // Read file header.
+ std::string header_buf;
+ AlignedBuf aligned_buf;
+ Slice header_slice;
+ // TODO: rate limit reading headers from blob files.
+ if (file_reader->use_direct_io()) {
+ s = file_reader->Read(IOOptions(), 0, BlobLogHeader::kSize, &header_slice,
+ nullptr, &aligned_buf,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ } else {
+ header_buf.reserve(BlobLogHeader::kSize);
+ s = file_reader->Read(IOOptions(), 0, BlobLogHeader::kSize, &header_slice,
+ &header_buf[0], nullptr,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ }
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ info_log_, "Failed to read header of blob file %" PRIu64 ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ BlobLogHeader header;
+ s = header.DecodeFrom(header_slice);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to decode header of blob file %" PRIu64
+ ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ column_family_id_ = header.column_family_id;
+ compression_ = header.compression;
+ has_ttl_ = header.has_ttl;
+ if (has_ttl_) {
+ expiration_range_ = header.expiration_range;
+ }
+ header_valid_ = true;
+
+ // Read file footer.
+ if (file_size_ < BlobLogHeader::kSize + BlobLogFooter::kSize) {
+ // OK not to have footer.
+ assert(!footer_valid_);
+ return Status::OK();
+ }
+ std::string footer_buf;
+ Slice footer_slice;
+ // TODO: rate limit reading footers from blob files.
+ if (file_reader->use_direct_io()) {
+ s = file_reader->Read(IOOptions(), file_size - BlobLogFooter::kSize,
+ BlobLogFooter::kSize, &footer_slice, nullptr,
+ &aligned_buf,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ } else {
+ footer_buf.reserve(BlobLogFooter::kSize);
+ s = file_reader->Read(IOOptions(), file_size - BlobLogFooter::kSize,
+ BlobLogFooter::kSize, &footer_slice, &footer_buf[0],
+ nullptr, Env::IO_TOTAL /* rate_limiter_priority */);
+ }
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ info_log_, "Failed to read footer of blob file %" PRIu64 ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ BlobLogFooter footer;
+ s = footer.DecodeFrom(footer_slice);
+ if (!s.ok()) {
+ // OK not to have footer.
+ assert(!footer_valid_);
+ return Status::OK();
+ }
+ blob_count_ = footer.blob_count;
+ if (has_ttl_) {
+ assert(header.expiration_range.first <= footer.expiration_range.first);
+ assert(header.expiration_range.second >= footer.expiration_range.second);
+ expiration_range_ = footer.expiration_range;
+ }
+ footer_valid_ = true;
+ return Status::OK();
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_file.h b/src/rocksdb/utilities/blob_db/blob_file.h
new file mode 100644
index 000000000..6f3f2bea7
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_file.h
@@ -0,0 +1,246 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <limits>
+#include <memory>
+#include <unordered_set>
+
+#include "db/blob/blob_log_format.h"
+#include "db/blob/blob_log_writer.h"
+#include "file/random_access_file_reader.h"
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_system.h"
+#include "rocksdb/options.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBImpl;
+
+class BlobFile {
+ friend class BlobDBImpl;
+ friend struct BlobFileComparator;
+ friend struct BlobFileComparatorTTL;
+ friend class BlobIndexCompactionFilterBase;
+ friend class BlobIndexCompactionFilterGC;
+
+ private:
+ // access to parent
+ const BlobDBImpl* parent_{nullptr};
+
+ // path to blob directory
+ std::string path_to_dir_;
+
+ // the id of the file.
+ // the above 2 are created during file creation and never changed
+ // after that
+ uint64_t file_number_{0};
+
+ // The file numbers of the SST files whose oldest blob file reference
+ // points to this blob file.
+ std::unordered_set<uint64_t> linked_sst_files_;
+
+ // Info log.
+ Logger* info_log_{nullptr};
+
+ // Column family id.
+ uint32_t column_family_id_{std::numeric_limits<uint32_t>::max()};
+
+ // Compression type of blobs in the file
+ CompressionType compression_{kNoCompression};
+
+ // If true, the keys in this file all has TTL. Otherwise all keys don't
+ // have TTL.
+ bool has_ttl_{false};
+
+ // TTL range of blobs in the file.
+ ExpirationRange expiration_range_;
+
+ // number of blobs in the file
+ std::atomic<uint64_t> blob_count_{0};
+
+ // size of the file
+ std::atomic<uint64_t> file_size_{0};
+
+ BlobLogHeader header_;
+
+ // closed_ = true implies the file is no more mutable
+ // no more blobs will be appended and the footer has been written out
+ std::atomic<bool> closed_{false};
+
+ // The latest sequence number when the file was closed/made immutable.
+ SequenceNumber immutable_sequence_{0};
+
+ // Whether the file was marked obsolete (due to either TTL or GC).
+ // obsolete_ still needs to do iterator/snapshot checks
+ std::atomic<bool> obsolete_{false};
+
+ // The last sequence number by the time the file marked as obsolete.
+ // Data in this file is visible to a snapshot taken before the sequence.
+ SequenceNumber obsolete_sequence_{0};
+
+ // Sequential/Append writer for blobs
+ std::shared_ptr<BlobLogWriter> log_writer_;
+
+ // random access file reader for GET calls
+ std::shared_ptr<RandomAccessFileReader> ra_file_reader_;
+
+ // This Read-Write mutex is per file specific and protects
+ // all the datastructures
+ mutable port::RWMutex mutex_;
+
+ // time when the random access reader was last created.
+ std::atomic<std::int64_t> last_access_{-1};
+
+ bool header_valid_{false};
+
+ bool footer_valid_{false};
+
+ public:
+ BlobFile() = default;
+
+ BlobFile(const BlobDBImpl* parent, const std::string& bdir, uint64_t fnum,
+ Logger* info_log);
+
+ BlobFile(const BlobDBImpl* parent, const std::string& bdir, uint64_t fnum,
+ Logger* info_log, uint32_t column_family_id,
+ CompressionType compression, bool has_ttl,
+ const ExpirationRange& expiration_range);
+
+ ~BlobFile();
+
+ uint32_t GetColumnFamilyId() const;
+
+ // Returns log file's absolute pathname.
+ std::string PathName() const;
+
+ // Primary identifier for blob file.
+ // once the file is created, this never changes
+ uint64_t BlobFileNumber() const { return file_number_; }
+
+ // Get the set of SST files whose oldest blob file reference points to
+ // this file.
+ const std::unordered_set<uint64_t>& GetLinkedSstFiles() const {
+ return linked_sst_files_;
+ }
+
+ // Link an SST file whose oldest blob file reference points to this file.
+ void LinkSstFile(uint64_t sst_file_number) {
+ assert(linked_sst_files_.find(sst_file_number) == linked_sst_files_.end());
+ linked_sst_files_.insert(sst_file_number);
+ }
+
+ // Unlink an SST file whose oldest blob file reference points to this file.
+ void UnlinkSstFile(uint64_t sst_file_number) {
+ auto it = linked_sst_files_.find(sst_file_number);
+ assert(it != linked_sst_files_.end());
+ linked_sst_files_.erase(it);
+ }
+
+ // the following functions are atomic, and don't need
+ // read lock
+ uint64_t BlobCount() const {
+ return blob_count_.load(std::memory_order_acquire);
+ }
+
+ std::string DumpState() const;
+
+ // if the file is not taking any more appends.
+ bool Immutable() const { return closed_.load(); }
+
+ // Mark the file as immutable.
+ // REQUIRES: write lock held, or access from single thread (on DB open).
+ void MarkImmutable(SequenceNumber sequence) {
+ closed_ = true;
+ immutable_sequence_ = sequence;
+ }
+
+ SequenceNumber GetImmutableSequence() const {
+ assert(Immutable());
+ return immutable_sequence_;
+ }
+
+ // Whether the file was marked obsolete (due to either TTL or GC).
+ bool Obsolete() const {
+ assert(Immutable() || !obsolete_.load());
+ return obsolete_.load();
+ }
+
+ // Mark file as obsolete (due to either TTL or GC). The file is not visible to
+ // snapshots with sequence greater or equal to the given sequence.
+ void MarkObsolete(SequenceNumber sequence);
+
+ SequenceNumber GetObsoleteSequence() const {
+ assert(Obsolete());
+ return obsolete_sequence_;
+ }
+
+ Status Fsync();
+
+ uint64_t GetFileSize() const {
+ return file_size_.load(std::memory_order_acquire);
+ }
+
+ // All Get functions which are not atomic, will need ReadLock on the mutex
+
+ const ExpirationRange& GetExpirationRange() const {
+ return expiration_range_;
+ }
+
+ void ExtendExpirationRange(uint64_t expiration) {
+ expiration_range_.first = std::min(expiration_range_.first, expiration);
+ expiration_range_.second = std::max(expiration_range_.second, expiration);
+ }
+
+ bool HasTTL() const { return has_ttl_; }
+
+ void SetHasTTL(bool has_ttl) { has_ttl_ = has_ttl; }
+
+ CompressionType GetCompressionType() const { return compression_; }
+
+ std::shared_ptr<BlobLogWriter> GetWriter() const { return log_writer_; }
+
+ // Read blob file header and footer. Return corruption if file header is
+ // malform or incomplete. If footer is malform or incomplete, set
+ // footer_valid_ to false and return Status::OK.
+ Status ReadMetadata(const std::shared_ptr<FileSystem>& fs,
+ const FileOptions& file_options);
+
+ Status GetReader(Env* env, const FileOptions& file_options,
+ std::shared_ptr<RandomAccessFileReader>* reader,
+ bool* fresh_open);
+
+ private:
+ Status ReadFooter(BlobLogFooter* footer);
+
+ Status WriteFooterAndCloseLocked(SequenceNumber sequence);
+
+ void CloseRandomAccessLocked();
+
+ // this is used, when you are reading only the footer of a
+ // previously closed file
+ Status SetFromFooterLocked(const BlobLogFooter& footer);
+
+ void set_expiration_range(const ExpirationRange& expiration_range) {
+ expiration_range_ = expiration_range;
+ }
+
+ // The following functions are atomic, and don't need locks
+ void SetFileSize(uint64_t fs) { file_size_ = fs; }
+
+ void SetBlobCount(uint64_t bc) { blob_count_ = bc; }
+
+ void BlobRecordAdded(uint64_t record_size) {
+ ++blob_count_;
+ file_size_ += record_size;
+ }
+};
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/cache_dump_load.cc b/src/rocksdb/utilities/cache_dump_load.cc
new file mode 100644
index 000000000..9a7c76798
--- /dev/null
+++ b/src/rocksdb/utilities/cache_dump_load.cc
@@ -0,0 +1,69 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/cache_dump_load.h"
+
+#include "file/writable_file_writer.h"
+#include "port/lang.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_system.h"
+#include "table/format.h"
+#include "util/crc32c.h"
+#include "utilities/cache_dump_load_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+IOStatus NewToFileCacheDumpWriter(const std::shared_ptr<FileSystem>& fs,
+ const FileOptions& file_opts,
+ const std::string& file_name,
+ std::unique_ptr<CacheDumpWriter>* writer) {
+ std::unique_ptr<WritableFileWriter> file_writer;
+ IOStatus io_s = WritableFileWriter::Create(fs, file_name, file_opts,
+ &file_writer, nullptr);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ writer->reset(new ToFileCacheDumpWriter(std::move(file_writer)));
+ return io_s;
+}
+
+IOStatus NewFromFileCacheDumpReader(const std::shared_ptr<FileSystem>& fs,
+ const FileOptions& file_opts,
+ const std::string& file_name,
+ std::unique_ptr<CacheDumpReader>* reader) {
+ std::unique_ptr<RandomAccessFileReader> file_reader;
+ IOStatus io_s = RandomAccessFileReader::Create(fs, file_name, file_opts,
+ &file_reader, nullptr);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ reader->reset(new FromFileCacheDumpReader(std::move(file_reader)));
+ return io_s;
+}
+
+Status NewDefaultCacheDumper(const CacheDumpOptions& dump_options,
+ const std::shared_ptr<Cache>& cache,
+ std::unique_ptr<CacheDumpWriter>&& writer,
+ std::unique_ptr<CacheDumper>* cache_dumper) {
+ cache_dumper->reset(
+ new CacheDumperImpl(dump_options, cache, std::move(writer)));
+ return Status::OK();
+}
+
+Status NewDefaultCacheDumpedLoader(
+ const CacheDumpOptions& dump_options,
+ const BlockBasedTableOptions& toptions,
+ const std::shared_ptr<SecondaryCache>& secondary_cache,
+ std::unique_ptr<CacheDumpReader>&& reader,
+ std::unique_ptr<CacheDumpedLoader>* cache_dump_loader) {
+ cache_dump_loader->reset(new CacheDumpedLoaderImpl(
+ dump_options, toptions, secondary_cache, std::move(reader)));
+ return Status::OK();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/cache_dump_load_impl.cc b/src/rocksdb/utilities/cache_dump_load_impl.cc
new file mode 100644
index 000000000..2b9f2a29d
--- /dev/null
+++ b/src/rocksdb/utilities/cache_dump_load_impl.cc
@@ -0,0 +1,393 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "cache/cache_key.h"
+#include "table/block_based/block_based_table_reader.h"
+#ifndef ROCKSDB_LITE
+
+#include "cache/cache_entry_roles.h"
+#include "file/writable_file_writer.h"
+#include "port/lang.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_system.h"
+#include "rocksdb/utilities/ldb_cmd.h"
+#include "table/format.h"
+#include "util/crc32c.h"
+#include "utilities/cache_dump_load_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Set the dump filter with a list of DBs. Block cache may be shared by multipe
+// DBs and we may only want to dump out the blocks belonging to certain DB(s).
+// Therefore, a filter is need to decide if the key of the block satisfy the
+// requirement.
+Status CacheDumperImpl::SetDumpFilter(std::vector<DB*> db_list) {
+ Status s = Status::OK();
+ for (size_t i = 0; i < db_list.size(); i++) {
+ assert(i < db_list.size());
+ TablePropertiesCollection ptc;
+ assert(db_list[i] != nullptr);
+ s = db_list[i]->GetPropertiesOfAllTables(&ptc);
+ if (!s.ok()) {
+ return s;
+ }
+ for (auto id = ptc.begin(); id != ptc.end(); id++) {
+ OffsetableCacheKey base;
+ // We only want to save cache entries that are portable to another
+ // DB::Open, so only save entries with stable keys.
+ bool is_stable;
+ BlockBasedTable::SetupBaseCacheKey(id->second.get(),
+ /*cur_db_session_id*/ "",
+ /*cur_file_num*/ 0, &base, &is_stable);
+ if (is_stable) {
+ Slice prefix_slice = base.CommonPrefixSlice();
+ assert(prefix_slice.size() == OffsetableCacheKey::kCommonPrefixSize);
+ prefix_filter_.insert(prefix_slice.ToString());
+ }
+ }
+ }
+ return s;
+}
+
+// This is the main function to dump out the cache block entries to the writer.
+// The writer may create a file or write to other systems. Currently, we will
+// iterate the whole block cache, get the blocks, and write them to the writer
+IOStatus CacheDumperImpl::DumpCacheEntriesToWriter() {
+ // Prepare stage, check the parameters.
+ if (cache_ == nullptr) {
+ return IOStatus::InvalidArgument("Cache is null");
+ }
+ if (writer_ == nullptr) {
+ return IOStatus::InvalidArgument("CacheDumpWriter is null");
+ }
+ // Set the system clock
+ if (options_.clock == nullptr) {
+ return IOStatus::InvalidArgument("System clock is null");
+ }
+ clock_ = options_.clock;
+ // We copy the Cache Deleter Role Map as its member.
+ role_map_ = CopyCacheDeleterRoleMap();
+ // Set the sequence number
+ sequence_num_ = 0;
+
+ // Dump stage, first, we write the hader
+ IOStatus io_s = WriteHeader();
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ // Then, we iterate the block cache and dump out the blocks that are not
+ // filtered out.
+ cache_->ApplyToAllEntries(DumpOneBlockCallBack(), {});
+
+ // Finally, write the footer
+ io_s = WriteFooter();
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ io_s = writer_->Close();
+ return io_s;
+}
+
+// Check if we need to filter out the block based on its key
+bool CacheDumperImpl::ShouldFilterOut(const Slice& key) {
+ if (key.size() < OffsetableCacheKey::kCommonPrefixSize) {
+ return /*filter out*/ true;
+ }
+ Slice key_prefix(key.data(), OffsetableCacheKey::kCommonPrefixSize);
+ std::string prefix = key_prefix.ToString();
+ // Filter out if not found
+ return prefix_filter_.find(prefix) == prefix_filter_.end();
+}
+
+// This is the callback function which will be applied to
+// Cache::ApplyToAllEntries. In this callback function, we will get the block
+// type, decide if the block needs to be dumped based on the filter, and write
+// the block through the provided writer.
+std::function<void(const Slice&, void*, size_t, Cache::DeleterFn)>
+CacheDumperImpl::DumpOneBlockCallBack() {
+ return [&](const Slice& key, void* value, size_t /*charge*/,
+ Cache::DeleterFn deleter) {
+ // Step 1: get the type of the block from role_map_
+ auto e = role_map_.find(deleter);
+ CacheEntryRole role;
+ CacheDumpUnitType type = CacheDumpUnitType::kBlockTypeMax;
+ if (e == role_map_.end()) {
+ role = CacheEntryRole::kMisc;
+ } else {
+ role = e->second;
+ }
+ bool filter_out = false;
+
+ // Step 2: based on the key prefix, check if the block should be filter out.
+ if (ShouldFilterOut(key)) {
+ filter_out = true;
+ }
+
+ // Step 3: based on the block type, get the block raw pointer and length.
+ const char* block_start = nullptr;
+ size_t block_len = 0;
+ switch (role) {
+ case CacheEntryRole::kDataBlock:
+ type = CacheDumpUnitType::kData;
+ block_start = (static_cast<Block*>(value))->data();
+ block_len = (static_cast<Block*>(value))->size();
+ break;
+ case CacheEntryRole::kFilterBlock:
+ type = CacheDumpUnitType::kFilter;
+ block_start = (static_cast<ParsedFullFilterBlock*>(value))
+ ->GetBlockContentsData()
+ .data();
+ block_len = (static_cast<ParsedFullFilterBlock*>(value))
+ ->GetBlockContentsData()
+ .size();
+ break;
+ case CacheEntryRole::kFilterMetaBlock:
+ type = CacheDumpUnitType::kFilterMetaBlock;
+ block_start = (static_cast<Block*>(value))->data();
+ block_len = (static_cast<Block*>(value))->size();
+ break;
+ case CacheEntryRole::kIndexBlock:
+ type = CacheDumpUnitType::kIndex;
+ block_start = (static_cast<Block*>(value))->data();
+ block_len = (static_cast<Block*>(value))->size();
+ break;
+ case CacheEntryRole::kDeprecatedFilterBlock:
+ // Obsolete
+ filter_out = true;
+ break;
+ case CacheEntryRole::kMisc:
+ filter_out = true;
+ break;
+ case CacheEntryRole::kOtherBlock:
+ filter_out = true;
+ break;
+ case CacheEntryRole::kWriteBuffer:
+ filter_out = true;
+ break;
+ default:
+ filter_out = true;
+ }
+
+ // Step 4: if the block should not be filter out, write the block to the
+ // CacheDumpWriter
+ if (!filter_out && block_start != nullptr) {
+ WriteBlock(type, key, Slice(block_start, block_len))
+ .PermitUncheckedError();
+ }
+ };
+}
+
+// Write the block to the writer. It takes the timestamp of the
+// block being copied from block cache, block type, key, block pointer,
+// block size and block checksum as the input. When writing the dumper raw
+// block, we first create the dump unit and encoude it to a string. Then,
+// we calculate the checksum of the whole dump unit string and store it in
+// the dump unit metadata.
+// First, we write the metadata first, which is a fixed size string. Then, we
+// Append the dump unit string to the writer.
+IOStatus CacheDumperImpl::WriteBlock(CacheDumpUnitType type, const Slice& key,
+ const Slice& value) {
+ uint64_t timestamp = clock_->NowMicros();
+ uint32_t value_checksum = crc32c::Value(value.data(), value.size());
+
+ // First, serialize the block information in a string
+ DumpUnit dump_unit;
+ dump_unit.timestamp = timestamp;
+ dump_unit.key = key;
+ dump_unit.type = type;
+ dump_unit.value_len = value.size();
+ dump_unit.value = const_cast<char*>(value.data());
+ dump_unit.value_checksum = value_checksum;
+ std::string encoded_data;
+ CacheDumperHelper::EncodeDumpUnit(dump_unit, &encoded_data);
+
+ // Second, create the metadata, which contains a sequence number, the dump
+ // unit string checksum and the string size. The sequence number monotonically
+ // increases from 0.
+ DumpUnitMeta unit_meta;
+ unit_meta.sequence_num = sequence_num_;
+ sequence_num_++;
+ unit_meta.dump_unit_checksum =
+ crc32c::Value(encoded_data.data(), encoded_data.size());
+ unit_meta.dump_unit_size = encoded_data.size();
+ std::string encoded_meta;
+ CacheDumperHelper::EncodeDumpUnitMeta(unit_meta, &encoded_meta);
+
+ // We write the metadata first.
+ assert(writer_ != nullptr);
+ IOStatus io_s = writer_->WriteMetadata(encoded_meta);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ // followed by the dump unit.
+ return writer_->WritePacket(encoded_data);
+}
+
+// Before we write any block, we write the header first to store the cache dump
+// format version, rocksdb version, and brief intro.
+IOStatus CacheDumperImpl::WriteHeader() {
+ std::string header_key = "header";
+ std::ostringstream s;
+ s << kTraceMagic << "\t"
+ << "Cache dump format version: " << kCacheDumpMajorVersion << "."
+ << kCacheDumpMinorVersion << "\t"
+ << "RocksDB Version: " << kMajorVersion << "." << kMinorVersion << "\t"
+ << "Format: dump_unit_metadata <sequence_number, dump_unit_checksum, "
+ "dump_unit_size>, dump_unit <timestamp, key, block_type, "
+ "block_size, block_data, block_checksum> cache_value\n";
+ std::string header_value(s.str());
+ CacheDumpUnitType type = CacheDumpUnitType::kHeader;
+ return WriteBlock(type, header_key, header_value);
+}
+
+// Write the footer after all the blocks are stored to indicate the ending.
+IOStatus CacheDumperImpl::WriteFooter() {
+ std::string footer_key = "footer";
+ std::string footer_value("cache dump completed");
+ CacheDumpUnitType type = CacheDumpUnitType::kFooter;
+ return WriteBlock(type, footer_key, footer_value);
+}
+
+// This is the main function to restore the cache entries to secondary cache.
+// First, we check if all the arguments are valid. Then, we read the block
+// sequentially from the reader and insert them to the secondary cache.
+IOStatus CacheDumpedLoaderImpl::RestoreCacheEntriesToSecondaryCache() {
+ // TODO: remove this line when options are used in the loader
+ (void)options_;
+ // Step 1: we check if all the arguments are valid
+ if (secondary_cache_ == nullptr) {
+ return IOStatus::InvalidArgument("Secondary Cache is null");
+ }
+ if (reader_ == nullptr) {
+ return IOStatus::InvalidArgument("CacheDumpReader is null");
+ }
+ // we copy the Cache Deleter Role Map as its member.
+ role_map_ = CopyCacheDeleterRoleMap();
+
+ // Step 2: read the header
+ // TODO: we need to check the cache dump format version and RocksDB version
+ // after the header is read out.
+ IOStatus io_s;
+ DumpUnit dump_unit;
+ std::string data;
+ io_s = ReadHeader(&data, &dump_unit);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ // Step 3: read out the rest of the blocks from the reader. The loop will stop
+ // either I/O status is not ok or we reach to the the end.
+ while (io_s.ok()) {
+ dump_unit.reset();
+ data.clear();
+ // read the content and store in the dump_unit
+ io_s = ReadCacheBlock(&data, &dump_unit);
+ if (!io_s.ok()) {
+ break;
+ }
+ if (dump_unit.type == CacheDumpUnitType::kFooter) {
+ break;
+ }
+ // Create the uncompressed_block based on the information in the dump_unit
+ // (There is no block trailer here compatible with block-based SST file.)
+ Slice content =
+ Slice(static_cast<char*>(dump_unit.value), dump_unit.value_len);
+ Status s = secondary_cache_->InsertSaved(dump_unit.key, content);
+ if (!s.ok()) {
+ io_s = status_to_io_status(std::move(s));
+ }
+ }
+ if (dump_unit.type == CacheDumpUnitType::kFooter) {
+ return IOStatus::OK();
+ } else {
+ return io_s;
+ }
+}
+
+// Read and copy the dump unit metadata to std::string data, decode and create
+// the unit metadata based on the string
+IOStatus CacheDumpedLoaderImpl::ReadDumpUnitMeta(std::string* data,
+ DumpUnitMeta* unit_meta) {
+ assert(reader_ != nullptr);
+ assert(data != nullptr);
+ assert(unit_meta != nullptr);
+ IOStatus io_s = reader_->ReadMetadata(data);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ return status_to_io_status(
+ CacheDumperHelper::DecodeDumpUnitMeta(*data, unit_meta));
+}
+
+// Read and copy the dump unit to std::string data, decode and create the unit
+// based on the string
+IOStatus CacheDumpedLoaderImpl::ReadDumpUnit(size_t len, std::string* data,
+ DumpUnit* unit) {
+ assert(reader_ != nullptr);
+ assert(data != nullptr);
+ assert(unit != nullptr);
+ IOStatus io_s = reader_->ReadPacket(data);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ if (data->size() != len) {
+ return IOStatus::Corruption(
+ "The data being read out does not match the size stored in metadata!");
+ }
+ Slice block;
+ return status_to_io_status(CacheDumperHelper::DecodeDumpUnit(*data, unit));
+}
+
+// Read the header
+IOStatus CacheDumpedLoaderImpl::ReadHeader(std::string* data,
+ DumpUnit* dump_unit) {
+ DumpUnitMeta header_meta;
+ header_meta.reset();
+ std::string meta_string;
+ IOStatus io_s = ReadDumpUnitMeta(&meta_string, &header_meta);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ io_s = ReadDumpUnit(header_meta.dump_unit_size, data, dump_unit);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ uint32_t unit_checksum = crc32c::Value(data->data(), data->size());
+ if (unit_checksum != header_meta.dump_unit_checksum) {
+ return IOStatus::Corruption("Read header unit corrupted!");
+ }
+ return io_s;
+}
+
+// Read the blocks after header is read out
+IOStatus CacheDumpedLoaderImpl::ReadCacheBlock(std::string* data,
+ DumpUnit* dump_unit) {
+ // According to the write process, we read the dump_unit_metadata first
+ DumpUnitMeta unit_meta;
+ unit_meta.reset();
+ std::string unit_string;
+ IOStatus io_s = ReadDumpUnitMeta(&unit_string, &unit_meta);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+
+ // Based on the information in the dump_unit_metadata, we read the dump_unit
+ // and verify if its content is correct.
+ io_s = ReadDumpUnit(unit_meta.dump_unit_size, data, dump_unit);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ uint32_t unit_checksum = crc32c::Value(data->data(), data->size());
+ if (unit_checksum != unit_meta.dump_unit_checksum) {
+ return IOStatus::Corruption(
+ "Checksum does not match! Read dumped unit corrupted!");
+ }
+ return io_s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/cache_dump_load_impl.h b/src/rocksdb/utilities/cache_dump_load_impl.h
new file mode 100644
index 000000000..9ca1ff45a
--- /dev/null
+++ b/src/rocksdb/utilities/cache_dump_load_impl.h
@@ -0,0 +1,359 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <unordered_map>
+
+#include "file/random_access_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "rocksdb/utilities/cache_dump_load.h"
+#include "table/block_based/block.h"
+#include "table/block_based/block_like_traits.h"
+#include "table/block_based/block_type.h"
+#include "table/block_based/cachable_entry.h"
+#include "table/block_based/parsed_full_filter_block.h"
+#include "table/block_based/reader_common.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// the read buffer size of for the default CacheDumpReader
+const unsigned int kDumpReaderBufferSize = 1024; // 1KB
+static const unsigned int kSizePrefixLen = 4;
+
+enum CacheDumpUnitType : unsigned char {
+ kHeader = 1,
+ kFooter = 2,
+ kData = 3,
+ kFilter = 4,
+ kProperties = 5,
+ kCompressionDictionary = 6,
+ kRangeDeletion = 7,
+ kHashIndexPrefixes = 8,
+ kHashIndexMetadata = 9,
+ kMetaIndex = 10,
+ kIndex = 11,
+ kDeprecatedFilterBlock = 12, // OBSOLETE / DEPRECATED
+ kFilterMetaBlock = 13,
+ kBlockTypeMax,
+};
+
+// The metadata of a dump unit. After it is serilized, its size is fixed 16
+// bytes.
+struct DumpUnitMeta {
+ // sequence number is a monotonically increasing number to indicate the order
+ // of the blocks being written. Header is 0.
+ uint32_t sequence_num;
+ // The Crc32c checksum of its dump unit.
+ uint32_t dump_unit_checksum;
+ // The dump unit size after the dump unit is serilized to a string.
+ uint64_t dump_unit_size;
+
+ void reset() {
+ sequence_num = 0;
+ dump_unit_checksum = 0;
+ dump_unit_size = 0;
+ }
+};
+
+// The data structure to hold a block and its information.
+struct DumpUnit {
+ // The timestamp when the block is identified, copied, and dumped from block
+ // cache
+ uint64_t timestamp;
+ // The type of the block
+ CacheDumpUnitType type;
+ // The key of this block when the block is referenced by this Cache
+ Slice key;
+ // The block size
+ size_t value_len;
+ // The Crc32c checksum of the block
+ uint32_t value_checksum;
+ // Pointer to the block. Note that, in the dump process, it points to a memory
+ // buffer copied from cache block. The buffer is freed when we process the
+ // next block. In the load process, we use an std::string to store the
+ // serialized dump_unit read from the reader. So it points to the memory
+ // address of the begin of the block in this string.
+ void* value;
+
+ DumpUnit() { reset(); }
+
+ void reset() {
+ timestamp = 0;
+ type = CacheDumpUnitType::kBlockTypeMax;
+ key.clear();
+ value_len = 0;
+ value_checksum = 0;
+ value = nullptr;
+ }
+};
+
+// The default implementation of the Cache Dumper
+class CacheDumperImpl : public CacheDumper {
+ public:
+ CacheDumperImpl(const CacheDumpOptions& dump_options,
+ const std::shared_ptr<Cache>& cache,
+ std::unique_ptr<CacheDumpWriter>&& writer)
+ : options_(dump_options), cache_(cache), writer_(std::move(writer)) {}
+ ~CacheDumperImpl() { writer_.reset(); }
+ Status SetDumpFilter(std::vector<DB*> db_list) override;
+ IOStatus DumpCacheEntriesToWriter() override;
+
+ private:
+ IOStatus WriteBlock(CacheDumpUnitType type, const Slice& key,
+ const Slice& value);
+ IOStatus WriteHeader();
+ IOStatus WriteFooter();
+ bool ShouldFilterOut(const Slice& key);
+ std::function<void(const Slice&, void*, size_t, Cache::DeleterFn)>
+ DumpOneBlockCallBack();
+
+ CacheDumpOptions options_;
+ std::shared_ptr<Cache> cache_;
+ std::unique_ptr<CacheDumpWriter> writer_;
+ UnorderedMap<Cache::DeleterFn, CacheEntryRole> role_map_;
+ SystemClock* clock_;
+ uint32_t sequence_num_;
+ // The cache key prefix filter. Currently, we use db_session_id as the prefix,
+ // so using std::set to store the prefixes as filter is enough. Further
+ // improvement can be applied like BloomFilter or others to speedup the
+ // filtering.
+ std::set<std::string> prefix_filter_;
+};
+
+// The default implementation of CacheDumpedLoader
+class CacheDumpedLoaderImpl : public CacheDumpedLoader {
+ public:
+ CacheDumpedLoaderImpl(const CacheDumpOptions& dump_options,
+ const BlockBasedTableOptions& /*toptions*/,
+ const std::shared_ptr<SecondaryCache>& secondary_cache,
+ std::unique_ptr<CacheDumpReader>&& reader)
+ : options_(dump_options),
+ secondary_cache_(secondary_cache),
+ reader_(std::move(reader)) {}
+ ~CacheDumpedLoaderImpl() {}
+ IOStatus RestoreCacheEntriesToSecondaryCache() override;
+
+ private:
+ IOStatus ReadDumpUnitMeta(std::string* data, DumpUnitMeta* unit_meta);
+ IOStatus ReadDumpUnit(size_t len, std::string* data, DumpUnit* unit);
+ IOStatus ReadHeader(std::string* data, DumpUnit* dump_unit);
+ IOStatus ReadCacheBlock(std::string* data, DumpUnit* dump_unit);
+
+ CacheDumpOptions options_;
+ std::shared_ptr<SecondaryCache> secondary_cache_;
+ std::unique_ptr<CacheDumpReader> reader_;
+ UnorderedMap<Cache::DeleterFn, CacheEntryRole> role_map_;
+};
+
+// The default implementation of CacheDumpWriter. We write the blocks to a file
+// sequentially.
+class ToFileCacheDumpWriter : public CacheDumpWriter {
+ public:
+ explicit ToFileCacheDumpWriter(
+ std::unique_ptr<WritableFileWriter>&& file_writer)
+ : file_writer_(std::move(file_writer)) {}
+
+ ~ToFileCacheDumpWriter() { Close().PermitUncheckedError(); }
+
+ // Write the serialized metadata to the file
+ virtual IOStatus WriteMetadata(const Slice& metadata) override {
+ assert(file_writer_ != nullptr);
+ std::string prefix;
+ PutFixed32(&prefix, static_cast<uint32_t>(metadata.size()));
+ IOStatus io_s = file_writer_->Append(Slice(prefix));
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ io_s = file_writer_->Append(metadata);
+ return io_s;
+ }
+
+ // Write the serialized data to the file
+ virtual IOStatus WritePacket(const Slice& data) override {
+ assert(file_writer_ != nullptr);
+ std::string prefix;
+ PutFixed32(&prefix, static_cast<uint32_t>(data.size()));
+ IOStatus io_s = file_writer_->Append(Slice(prefix));
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ io_s = file_writer_->Append(data);
+ return io_s;
+ }
+
+ // Reset the writer
+ virtual IOStatus Close() override {
+ file_writer_.reset();
+ return IOStatus::OK();
+ }
+
+ private:
+ std::unique_ptr<WritableFileWriter> file_writer_;
+};
+
+// The default implementation of CacheDumpReader. It is implemented based on
+// RandomAccessFileReader. Note that, we keep an internal variable to remember
+// the current offset.
+class FromFileCacheDumpReader : public CacheDumpReader {
+ public:
+ explicit FromFileCacheDumpReader(
+ std::unique_ptr<RandomAccessFileReader>&& reader)
+ : file_reader_(std::move(reader)),
+ offset_(0),
+ buffer_(new char[kDumpReaderBufferSize]) {}
+
+ ~FromFileCacheDumpReader() { delete[] buffer_; }
+
+ virtual IOStatus ReadMetadata(std::string* metadata) override {
+ uint32_t metadata_len = 0;
+ IOStatus io_s = ReadSizePrefix(&metadata_len);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ return Read(metadata_len, metadata);
+ }
+
+ virtual IOStatus ReadPacket(std::string* data) override {
+ uint32_t data_len = 0;
+ IOStatus io_s = ReadSizePrefix(&data_len);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ return Read(data_len, data);
+ }
+
+ private:
+ IOStatus ReadSizePrefix(uint32_t* len) {
+ std::string prefix;
+ IOStatus io_s = Read(kSizePrefixLen, &prefix);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ Slice encoded_slice(prefix);
+ if (!GetFixed32(&encoded_slice, len)) {
+ return IOStatus::Corruption("Decode size prefix string failed");
+ }
+ return IOStatus::OK();
+ }
+
+ IOStatus Read(size_t len, std::string* data) {
+ assert(file_reader_ != nullptr);
+ IOStatus io_s;
+
+ unsigned int bytes_to_read = static_cast<unsigned int>(len);
+ unsigned int to_read = bytes_to_read > kDumpReaderBufferSize
+ ? kDumpReaderBufferSize
+ : bytes_to_read;
+
+ while (to_read > 0) {
+ io_s = file_reader_->Read(IOOptions(), offset_, to_read, &result_,
+ buffer_, nullptr,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ if (result_.size() < to_read) {
+ return IOStatus::Corruption("Corrupted cache dump file.");
+ }
+ data->append(result_.data(), result_.size());
+
+ offset_ += to_read;
+ bytes_to_read -= to_read;
+ to_read = bytes_to_read > kDumpReaderBufferSize ? kDumpReaderBufferSize
+ : bytes_to_read;
+ }
+ return io_s;
+ }
+ std::unique_ptr<RandomAccessFileReader> file_reader_;
+ Slice result_;
+ size_t offset_;
+ char* buffer_;
+};
+
+// The cache dump and load helper class
+class CacheDumperHelper {
+ public:
+ // serialize the dump_unit_meta to a string, it is fixed 16 bytes size.
+ static void EncodeDumpUnitMeta(const DumpUnitMeta& meta, std::string* data) {
+ assert(data);
+ PutFixed32(data, static_cast<uint32_t>(meta.sequence_num));
+ PutFixed32(data, static_cast<uint32_t>(meta.dump_unit_checksum));
+ PutFixed64(data, meta.dump_unit_size);
+ }
+
+ // Serialize the dump_unit to a string.
+ static void EncodeDumpUnit(const DumpUnit& dump_unit, std::string* data) {
+ assert(data);
+ PutFixed64(data, dump_unit.timestamp);
+ data->push_back(dump_unit.type);
+ PutLengthPrefixedSlice(data, dump_unit.key);
+ PutFixed32(data, static_cast<uint32_t>(dump_unit.value_len));
+ PutFixed32(data, dump_unit.value_checksum);
+ PutLengthPrefixedSlice(data,
+ Slice((char*)dump_unit.value, dump_unit.value_len));
+ }
+
+ // Deserialize the dump_unit_meta from a string
+ static Status DecodeDumpUnitMeta(const std::string& encoded_data,
+ DumpUnitMeta* unit_meta) {
+ assert(unit_meta != nullptr);
+ Slice encoded_slice = Slice(encoded_data);
+ if (!GetFixed32(&encoded_slice, &(unit_meta->sequence_num))) {
+ return Status::Incomplete("Decode dumped unit meta sequence_num failed");
+ }
+ if (!GetFixed32(&encoded_slice, &(unit_meta->dump_unit_checksum))) {
+ return Status::Incomplete(
+ "Decode dumped unit meta dump_unit_checksum failed");
+ }
+ if (!GetFixed64(&encoded_slice, &(unit_meta->dump_unit_size))) {
+ return Status::Incomplete(
+ "Decode dumped unit meta dump_unit_size failed");
+ }
+ return Status::OK();
+ }
+
+ // Deserialize the dump_unit from a string.
+ static Status DecodeDumpUnit(const std::string& encoded_data,
+ DumpUnit* dump_unit) {
+ assert(dump_unit != nullptr);
+ Slice encoded_slice = Slice(encoded_data);
+
+ // Decode timestamp
+ if (!GetFixed64(&encoded_slice, &dump_unit->timestamp)) {
+ return Status::Incomplete("Decode dumped unit string failed");
+ }
+ // Decode the block type
+ dump_unit->type = static_cast<CacheDumpUnitType>(encoded_slice[0]);
+ encoded_slice.remove_prefix(1);
+ // Decode the key
+ if (!GetLengthPrefixedSlice(&encoded_slice, &(dump_unit->key))) {
+ return Status::Incomplete("Decode dumped unit string failed");
+ }
+ // Decode the value size
+ uint32_t value_len;
+ if (!GetFixed32(&encoded_slice, &value_len)) {
+ return Status::Incomplete("Decode dumped unit string failed");
+ }
+ dump_unit->value_len = static_cast<size_t>(value_len);
+ // Decode the value checksum
+ if (!GetFixed32(&encoded_slice, &(dump_unit->value_checksum))) {
+ return Status::Incomplete("Decode dumped unit string failed");
+ }
+ // Decode the block content and copy to the memory space whose pointer
+ // will be managed by the cache finally.
+ Slice block;
+ if (!GetLengthPrefixedSlice(&encoded_slice, &block)) {
+ return Status::Incomplete("Decode dumped unit string failed");
+ }
+ dump_unit->value = (void*)block.data();
+ assert(block.size() == dump_unit->value_len);
+ return Status::OK();
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc
new file mode 100644
index 000000000..4e48d63aa
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/cassandra/cassandra_compaction_filter.h"
+
+#include <string>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "rocksdb/utilities/options_type.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/merge_operator.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+static std::unordered_map<std::string, OptionTypeInfo>
+ cassandra_filter_type_info = {
+#ifndef ROCKSDB_LITE
+ {"purge_ttl_on_expiration",
+ {offsetof(struct CassandraOptions, purge_ttl_on_expiration),
+ OptionType::kBoolean, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"gc_grace_period_in_seconds",
+ {offsetof(struct CassandraOptions, gc_grace_period_in_seconds),
+ OptionType::kUInt32T, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+#endif // ROCKSDB_LITE
+};
+
+CassandraCompactionFilter::CassandraCompactionFilter(
+ bool purge_ttl_on_expiration, int32_t gc_grace_period_in_seconds)
+ : options_(gc_grace_period_in_seconds, 0, purge_ttl_on_expiration) {
+ RegisterOptions(&options_, &cassandra_filter_type_info);
+}
+
+CompactionFilter::Decision CassandraCompactionFilter::FilterV2(
+ int /*level*/, const Slice& /*key*/, ValueType value_type,
+ const Slice& existing_value, std::string* new_value,
+ std::string* /*skip_until*/) const {
+ bool value_changed = false;
+ RowValue row_value =
+ RowValue::Deserialize(existing_value.data(), existing_value.size());
+ RowValue compacted =
+ options_.purge_ttl_on_expiration
+ ? row_value.RemoveExpiredColumns(&value_changed)
+ : row_value.ConvertExpiredColumnsToTombstones(&value_changed);
+
+ if (value_type == ValueType::kValue) {
+ compacted = compacted.RemoveTombstones(options_.gc_grace_period_in_seconds);
+ }
+
+ if (compacted.Empty()) {
+ return Decision::kRemove;
+ }
+
+ if (value_changed) {
+ compacted.Serialize(new_value);
+ return Decision::kChangeValue;
+ }
+
+ return Decision::kKeep;
+}
+
+CassandraCompactionFilterFactory::CassandraCompactionFilterFactory(
+ bool purge_ttl_on_expiration, int32_t gc_grace_period_in_seconds)
+ : options_(gc_grace_period_in_seconds, 0, purge_ttl_on_expiration) {
+ RegisterOptions(&options_, &cassandra_filter_type_info);
+}
+
+std::unique_ptr<CompactionFilter>
+CassandraCompactionFilterFactory::CreateCompactionFilter(
+ const CompactionFilter::Context&) {
+ std::unique_ptr<CompactionFilter> result(new CassandraCompactionFilter(
+ options_.purge_ttl_on_expiration, options_.gc_grace_period_in_seconds));
+ return result;
+}
+
+#ifndef ROCKSDB_LITE
+int RegisterCassandraObjects(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<MergeOperator>(
+ CassandraValueMergeOperator::kClassName(),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new CassandraValueMergeOperator(0));
+ return guard->get();
+ });
+ library.AddFactory<CompactionFilter>(
+ CassandraCompactionFilter::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilter>* /*guard */,
+ std::string* /* errmsg */) {
+ return new CassandraCompactionFilter(false, 0);
+ });
+ library.AddFactory<CompactionFilterFactory>(
+ CassandraCompactionFilterFactory::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilterFactory>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new CassandraCompactionFilterFactory(false, 0));
+ return guard->get();
+ });
+ size_t num_types;
+ return static_cast<int>(library.GetFactoryCount(&num_types));
+}
+#endif // ROCKSDB_LITE
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h
new file mode 100644
index 000000000..0325a4c39
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <string>
+
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/slice.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+/**
+ * Compaction filter for removing expired Cassandra data with ttl.
+ * If option `purge_ttl_on_expiration` is set to true, expired data
+ * will be directly purged. Otherwise expired data will be converted
+ * tombstones first, then be eventally removed after gc grace period.
+ * `purge_ttl_on_expiration` should only be on in the case all the
+ * writes have same ttl setting, otherwise it could bring old data back.
+ *
+ * Compaction filter is also in charge of removing tombstone that has been
+ * promoted to kValue type after serials of merging in compaction.
+ */
+class CassandraCompactionFilter : public CompactionFilter {
+ public:
+ explicit CassandraCompactionFilter(bool purge_ttl_on_expiration,
+ int32_t gc_grace_period_in_seconds);
+ static const char* kClassName() { return "CassandraCompactionFilter"; }
+ const char* Name() const override { return kClassName(); }
+
+ virtual Decision FilterV2(int level, const Slice& key, ValueType value_type,
+ const Slice& existing_value, std::string* new_value,
+ std::string* skip_until) const override;
+
+ private:
+ CassandraOptions options_;
+};
+
+class CassandraCompactionFilterFactory : public CompactionFilterFactory {
+ public:
+ explicit CassandraCompactionFilterFactory(bool purge_ttl_on_expiration,
+ int32_t gc_grace_period_in_seconds);
+ ~CassandraCompactionFilterFactory() override {}
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& context) override;
+ static const char* kClassName() { return "CassandraCompactionFilterFactory"; }
+ const char* Name() const override { return kClassName(); }
+
+ private:
+ CassandraOptions options_;
+};
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_format_test.cc b/src/rocksdb/utilities/cassandra/cassandra_format_test.cc
new file mode 100644
index 000000000..4f12947ad
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_format_test.cc
@@ -0,0 +1,377 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <cstring>
+#include <memory>
+
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/serialize.h"
+#include "utilities/cassandra/test_utils.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+TEST(ColumnTest, Column) {
+ char data[4] = {'d', 'a', 't', 'a'};
+ int8_t mask = 0;
+ int8_t index = 1;
+ int64_t timestamp = 1494022807044;
+ Column c = Column(mask, index, timestamp, sizeof(data), data);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), timestamp);
+ EXPECT_EQ(c.Size(), 14 + sizeof(data));
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(data, dest.c_str() + offset, sizeof(data)) == 0);
+
+ // Verify the deserialization.
+ std::string saved_dest = dest;
+ std::shared_ptr<Column> c1 = Column::Deserialize(saved_dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), timestamp);
+ EXPECT_EQ(c1->Size(), 14 + sizeof(data));
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) ==
+ 0);
+
+ // Verify the ColumnBase::Deserialization.
+ saved_dest = dest;
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(saved_dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2,
+ c.Size()) == 0);
+}
+
+TEST(ExpiringColumnTest, ExpiringColumn) {
+ char data[4] = {'d', 'a', 't', 'a'};
+ int8_t mask = ColumnTypeMask::EXPIRATION_MASK;
+ int8_t index = 3;
+ int64_t timestamp = 1494022807044;
+ int32_t ttl = 3600;
+ ExpiringColumn c =
+ ExpiringColumn(mask, index, timestamp, sizeof(data), data, ttl);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), timestamp);
+ EXPECT_EQ(c.Size(), 18 + sizeof(data));
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(data, dest.c_str() + offset, sizeof(data)) == 0);
+ offset += sizeof(data);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), ttl);
+
+ // Verify the deserialization.
+ std::string saved_dest = dest;
+ std::shared_ptr<ExpiringColumn> c1 =
+ ExpiringColumn::Deserialize(saved_dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), timestamp);
+ EXPECT_EQ(c1->Size(), 18 + sizeof(data));
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) ==
+ 0);
+
+ // Verify the ColumnBase::Deserialization.
+ saved_dest = dest;
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(saved_dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2,
+ c.Size()) == 0);
+}
+
+TEST(TombstoneTest, TombstoneCollectable) {
+ int32_t now = (int32_t)time(nullptr);
+ int32_t gc_grace_seconds = 16440;
+ int32_t time_delta_seconds = 10;
+ EXPECT_TRUE(
+ Tombstone(ColumnTypeMask::DELETION_MASK, 0,
+ now - gc_grace_seconds - time_delta_seconds,
+ ToMicroSeconds(now - gc_grace_seconds - time_delta_seconds))
+ .Collectable(gc_grace_seconds));
+ EXPECT_FALSE(
+ Tombstone(ColumnTypeMask::DELETION_MASK, 0,
+ now - gc_grace_seconds + time_delta_seconds,
+ ToMicroSeconds(now - gc_grace_seconds + time_delta_seconds))
+ .Collectable(gc_grace_seconds));
+}
+
+TEST(TombstoneTest, Tombstone) {
+ int8_t mask = ColumnTypeMask::DELETION_MASK;
+ int8_t index = 2;
+ int32_t local_deletion_time = 1494022807;
+ int64_t marked_for_delete_at = 1494022807044;
+ Tombstone c =
+ Tombstone(mask, index, local_deletion_time, marked_for_delete_at);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), marked_for_delete_at);
+ EXPECT_EQ(c.Size(), 14);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), marked_for_delete_at);
+
+ // Verify the deserialization.
+ std::shared_ptr<Tombstone> c1 = Tombstone::Deserialize(dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), marked_for_delete_at);
+ EXPECT_EQ(c1->Size(), 14);
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) ==
+ 0);
+
+ // Verify the ColumnBase::Deserialization.
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2,
+ c.Size()) == 0);
+}
+
+class RowValueTest : public testing::Test {};
+
+TEST(RowValueTest, RowTombstone) {
+ int32_t local_deletion_time = 1494022807;
+ int64_t marked_for_delete_at = 1494022807044;
+ RowValue r = RowValue(local_deletion_time, marked_for_delete_at);
+
+ EXPECT_EQ(r.Size(), 12);
+ EXPECT_EQ(r.IsTombstone(), true);
+ EXPECT_EQ(r.LastModifiedTime(), marked_for_delete_at);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(r.Size() * 2);
+ r.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), r.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), marked_for_delete_at);
+
+ // Verify the deserialization.
+ RowValue r1 = RowValue::Deserialize(dest.c_str(), r.Size());
+ EXPECT_EQ(r1.Size(), 12);
+ EXPECT_EQ(r1.IsTombstone(), true);
+ EXPECT_EQ(r1.LastModifiedTime(), marked_for_delete_at);
+
+ r1.Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * r.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str(), dest.c_str() + r.Size(), r.Size()) ==
+ 0);
+}
+
+TEST(RowValueTest, RowWithColumns) {
+ std::vector<std::shared_ptr<ColumnBase>> columns;
+ int64_t last_modified_time = 1494022807048;
+ std::size_t columns_data_size = 0;
+
+ char e_data[5] = {'e', 'd', 'a', 't', 'a'};
+ int8_t e_index = 0;
+ int64_t e_timestamp = 1494022807044;
+ int32_t e_ttl = 3600;
+ columns.push_back(std::shared_ptr<ExpiringColumn>(
+ new ExpiringColumn(ColumnTypeMask::EXPIRATION_MASK, e_index, e_timestamp,
+ sizeof(e_data), e_data, e_ttl)));
+ columns_data_size += columns[0]->Size();
+
+ char c_data[4] = {'d', 'a', 't', 'a'};
+ int8_t c_index = 1;
+ int64_t c_timestamp = 1494022807048;
+ columns.push_back(std::shared_ptr<Column>(
+ new Column(0, c_index, c_timestamp, sizeof(c_data), c_data)));
+ columns_data_size += columns[1]->Size();
+
+ int8_t t_index = 2;
+ int32_t t_local_deletion_time = 1494022801;
+ int64_t t_marked_for_delete_at = 1494022807043;
+ columns.push_back(std::shared_ptr<Tombstone>(
+ new Tombstone(ColumnTypeMask::DELETION_MASK, t_index,
+ t_local_deletion_time, t_marked_for_delete_at)));
+ columns_data_size += columns[2]->Size();
+
+ RowValue r = RowValue(std::move(columns), last_modified_time);
+
+ EXPECT_EQ(r.Size(), columns_data_size + 12);
+ EXPECT_EQ(r.IsTombstone(), false);
+ EXPECT_EQ(r.LastModifiedTime(), last_modified_time);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(r.Size() * 2);
+ r.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), r.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset),
+ std::numeric_limits<int32_t>::max());
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset),
+ std::numeric_limits<int64_t>::min());
+ offset += sizeof(int64_t);
+
+ // Column0: ExpiringColumn
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset),
+ ColumnTypeMask::EXPIRATION_MASK);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), e_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), e_timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(e_data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(e_data, dest.c_str() + offset, sizeof(e_data)) == 0);
+ offset += sizeof(e_data);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), e_ttl);
+ offset += sizeof(int32_t);
+
+ // Column1: Column
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), 0);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), c_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), c_timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(c_data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(c_data, dest.c_str() + offset, sizeof(c_data)) == 0);
+ offset += sizeof(c_data);
+
+ // Column2: Tombstone
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset),
+ ColumnTypeMask::DELETION_MASK);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), t_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), t_local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), t_marked_for_delete_at);
+
+ // Verify the deserialization.
+ RowValue r1 = RowValue::Deserialize(dest.c_str(), r.Size());
+ EXPECT_EQ(r1.Size(), columns_data_size + 12);
+ EXPECT_EQ(r1.IsTombstone(), false);
+ EXPECT_EQ(r1.LastModifiedTime(), last_modified_time);
+
+ r1.Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * r.Size());
+ EXPECT_TRUE(std::memcmp(dest.c_str(), dest.c_str() + r.Size(), r.Size()) ==
+ 0);
+}
+
+TEST(RowValueTest, PurgeTtlShouldRemvoeAllColumnsExpired) {
+ int64_t now = time(nullptr);
+
+ auto row_value = CreateTestRowValue(
+ {CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now)),
+ CreateTestColumnSpec(kExpiringColumn, 1,
+ ToMicroSeconds(now - kTtl - 10)), // expired
+ CreateTestColumnSpec(kExpiringColumn, 2,
+ ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))});
+
+ bool changed = false;
+ auto purged = row_value.RemoveExpiredColumns(&changed);
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(purged.get_columns().size(), 3);
+ VerifyRowValueColumns(purged.get_columns(), 0, kColumn, 0,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(purged.get_columns(), 1, kExpiringColumn, 2,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(purged.get_columns(), 2, kTombstone, 3,
+ ToMicroSeconds(now));
+
+ purged.RemoveExpiredColumns(&changed);
+ EXPECT_FALSE(changed);
+}
+
+TEST(RowValueTest, ExpireTtlShouldConvertExpiredColumnsToTombstones) {
+ int64_t now = time(nullptr);
+
+ auto row_value = CreateTestRowValue(
+ {CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now)),
+ CreateTestColumnSpec(kExpiringColumn, 1,
+ ToMicroSeconds(now - kTtl - 10)), // expired
+ CreateTestColumnSpec(kExpiringColumn, 2,
+ ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))});
+
+ bool changed = false;
+ auto compacted = row_value.ConvertExpiredColumnsToTombstones(&changed);
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(compacted.get_columns().size(), 4);
+ VerifyRowValueColumns(compacted.get_columns(), 0, kColumn, 0,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(compacted.get_columns(), 1, kTombstone, 1,
+ ToMicroSeconds(now - 10));
+ VerifyRowValueColumns(compacted.get_columns(), 2, kExpiringColumn, 2,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(compacted.get_columns(), 3, kTombstone, 3,
+ ToMicroSeconds(now));
+
+ compacted.ConvertExpiredColumnsToTombstones(&changed);
+ EXPECT_FALSE(changed);
+}
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc b/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc
new file mode 100644
index 000000000..c5be836e8
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc
@@ -0,0 +1,446 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <iostream>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/db.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "test_util/testharness.h"
+#include "util/cast_util.h"
+#include "util/random.h"
+#include "utilities/cassandra/cassandra_compaction_filter.h"
+#include "utilities/cassandra/merge_operator.h"
+#include "utilities/cassandra/test_utils.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+// Path to the database on file system
+const std::string kDbName = test::PerThreadDBPath("cassandra_functional_test");
+
+class CassandraStore {
+ public:
+ explicit CassandraStore(std::shared_ptr<DB> db)
+ : db_(db), write_option_(), get_option_() {
+ assert(db);
+ }
+
+ bool Append(const std::string& key, const RowValue& val) {
+ std::string result;
+ val.Serialize(&result);
+ Slice valSlice(result.data(), result.size());
+ auto s = db_->Merge(write_option_, key, valSlice);
+
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ bool Put(const std::string& key, const RowValue& val) {
+ std::string result;
+ val.Serialize(&result);
+ Slice valSlice(result.data(), result.size());
+ auto s = db_->Put(write_option_, key, valSlice);
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ Status Flush() {
+ Status s = dbfull()->TEST_FlushMemTable();
+ if (s.ok()) {
+ s = dbfull()->TEST_WaitForCompact();
+ }
+ return s;
+ }
+
+ Status Compact() {
+ return dbfull()->TEST_CompactRange(0, nullptr, nullptr,
+ db_->DefaultColumnFamily());
+ }
+
+ std::tuple<bool, RowValue> Get(const std::string& key) {
+ std::string result;
+ auto s = db_->Get(get_option_, key, &result);
+
+ if (s.ok()) {
+ return std::make_tuple(
+ true, RowValue::Deserialize(result.data(), result.size()));
+ }
+
+ if (!s.IsNotFound()) {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ }
+
+ return std::make_tuple(false, RowValue(0, 0));
+ }
+
+ private:
+ std::shared_ptr<DB> db_;
+ WriteOptions write_option_;
+ ReadOptions get_option_;
+
+ DBImpl* dbfull() { return static_cast_with_check<DBImpl>(db_.get()); }
+};
+
+class TestCompactionFilterFactory : public CompactionFilterFactory {
+ public:
+ explicit TestCompactionFilterFactory(bool purge_ttl_on_expiration,
+ int32_t gc_grace_period_in_seconds)
+ : purge_ttl_on_expiration_(purge_ttl_on_expiration),
+ gc_grace_period_in_seconds_(gc_grace_period_in_seconds) {}
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(new CassandraCompactionFilter(
+ purge_ttl_on_expiration_, gc_grace_period_in_seconds_));
+ }
+
+ const char* Name() const override { return "TestCompactionFilterFactory"; }
+
+ private:
+ bool purge_ttl_on_expiration_;
+ int32_t gc_grace_period_in_seconds_;
+};
+
+// The class for unit-testing
+class CassandraFunctionalTest : public testing::Test {
+ public:
+ CassandraFunctionalTest() {
+ EXPECT_OK(
+ DestroyDB(kDbName, Options())); // Start each test with a fresh DB
+ }
+
+ std::shared_ptr<DB> OpenDb() {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ options.merge_operator.reset(
+ new CassandraValueMergeOperator(gc_grace_period_in_seconds_));
+ auto* cf_factory = new TestCompactionFilterFactory(
+ purge_ttl_on_expiration_, gc_grace_period_in_seconds_);
+ options.compaction_filter_factory.reset(cf_factory);
+ EXPECT_OK(DB::Open(options, kDbName, &db));
+ return std::shared_ptr<DB>(db);
+ }
+
+ bool purge_ttl_on_expiration_ = false;
+ int32_t gc_grace_period_in_seconds_ = 100;
+};
+
+// THE TEST CASES BEGIN HERE
+
+TEST_F(CassandraFunctionalTest, SimpleMergeTest) {
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append(
+ "k1",
+ CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now + 8)),
+ CreateTestColumnSpec(kExpiringColumn, 2, ToMicroSeconds(now + 5)),
+ }));
+ store.Append(
+ "k1",
+ CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now + 2)),
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kTombstone, 2, ToMicroSeconds(now + 7)),
+ CreateTestColumnSpec(kExpiringColumn, 7, ToMicroSeconds(now + 17)),
+ }));
+ store.Append(
+ "k1",
+ CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now + 6)),
+ CreateTestColumnSpec(kTombstone, 1, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now + 4)),
+ CreateTestColumnSpec(kTombstone, 11, ToMicroSeconds(now + 11)),
+ }));
+
+ auto ret = store.Get("k1");
+
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.get_columns().size(), 5);
+ VerifyRowValueColumns(merged.get_columns(), 0, kExpiringColumn, 0,
+ ToMicroSeconds(now + 6));
+ VerifyRowValueColumns(merged.get_columns(), 1, kColumn, 1,
+ ToMicroSeconds(now + 8));
+ VerifyRowValueColumns(merged.get_columns(), 2, kTombstone, 2,
+ ToMicroSeconds(now + 7));
+ VerifyRowValueColumns(merged.get_columns(), 3, kExpiringColumn, 7,
+ ToMicroSeconds(now + 17));
+ VerifyRowValueColumns(merged.get_columns(), 4, kTombstone, 11,
+ ToMicroSeconds(now + 11));
+}
+
+constexpr int64_t kTestTimeoutSecs = 600;
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldConvertExpiredColumnsToTombstone) {
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append(
+ "k1",
+ CreateTestRowValue(
+ {CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 20)), // expired
+ CreateTestColumnSpec(
+ kExpiringColumn, 1,
+ ToMicroSeconds(now - kTtl + kTestTimeoutSecs)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))}));
+
+ ASSERT_OK(store.Flush());
+
+ store.Append(
+ "k1",
+ CreateTestRowValue(
+ {CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 10)), // expired
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now))}));
+
+ ASSERT_OK(store.Flush());
+ ASSERT_OK(store.Compact());
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.get_columns().size(), 4);
+ VerifyRowValueColumns(merged.get_columns(), 0, kTombstone, 0,
+ ToMicroSeconds(now - 10));
+ VerifyRowValueColumns(merged.get_columns(), 1, kExpiringColumn, 1,
+ ToMicroSeconds(now - kTtl + kTestTimeoutSecs));
+ VerifyRowValueColumns(merged.get_columns(), 2, kColumn, 2,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.get_columns(), 3, kTombstone, 3,
+ ToMicroSeconds(now));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldPurgeExpiredColumnsIfPurgeTtlIsOn) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append(
+ "k1",
+ CreateTestRowValue(
+ {CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 20)), // expired
+ CreateTestColumnSpec(kExpiringColumn, 1,
+ ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))}));
+
+ ASSERT_OK(store.Flush());
+
+ store.Append(
+ "k1",
+ CreateTestRowValue(
+ {CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 10)), // expired
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now))}));
+
+ ASSERT_OK(store.Flush());
+ ASSERT_OK(store.Compact());
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.get_columns().size(), 3);
+ VerifyRowValueColumns(merged.get_columns(), 0, kExpiringColumn, 1,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.get_columns(), 1, kColumn, 2,
+ ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.get_columns(), 2, kTombstone, 3,
+ ToMicroSeconds(now));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldRemoveRowWhenAllColumnsExpiredIfPurgeTtlIsOn) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 20)),
+ CreateTestColumnSpec(kExpiringColumn, 1,
+ ToMicroSeconds(now - kTtl - 20)),
+ }));
+
+ ASSERT_OK(store.Flush());
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0,
+ ToMicroSeconds(now - kTtl - 10)),
+ }));
+
+ ASSERT_OK(store.Flush());
+ ASSERT_OK(store.Compact());
+ ASSERT_FALSE(std::get<0>(store.Get("k1")));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldRemoveTombstoneExceedingGCGracePeriod) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1",
+ CreateTestRowValue(
+ {CreateTestColumnSpec(
+ kTombstone, 0,
+ ToMicroSeconds(now - gc_grace_period_in_seconds_ - 1)),
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now))}));
+
+ store.Append("k2", CreateTestRowValue({CreateTestColumnSpec(
+ kColumn, 0, ToMicroSeconds(now))}));
+
+ ASSERT_OK(store.Flush());
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now)),
+ }));
+
+ ASSERT_OK(store.Flush());
+ ASSERT_OK(store.Compact());
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& gced = std::get<1>(ret);
+ EXPECT_EQ(gced.get_columns().size(), 1);
+ VerifyRowValueColumns(gced.get_columns(), 0, kColumn, 1, ToMicroSeconds(now));
+}
+
+TEST_F(CassandraFunctionalTest, CompactionShouldRemoveTombstoneFromPut) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Put("k1",
+ CreateTestRowValue({
+ CreateTestColumnSpec(
+ kTombstone, 0,
+ ToMicroSeconds(now - gc_grace_period_in_seconds_ - 1)),
+ }));
+
+ ASSERT_OK(store.Flush());
+ ASSERT_OK(store.Compact());
+ ASSERT_FALSE(std::get<0>(store.Get("k1")));
+}
+
+#ifndef ROCKSDB_LITE
+TEST_F(CassandraFunctionalTest, LoadMergeOperator) {
+ ConfigOptions config_options;
+ std::shared_ptr<MergeOperator> mo;
+ config_options.ignore_unsupported_options = false;
+
+ ASSERT_NOK(MergeOperator::CreateFromString(
+ config_options, CassandraValueMergeOperator::kClassName(), &mo));
+
+ config_options.registry->AddLibrary("cassandra", RegisterCassandraObjects,
+ "cassandra");
+
+ ASSERT_OK(MergeOperator::CreateFromString(
+ config_options, CassandraValueMergeOperator::kClassName(), &mo));
+ ASSERT_NE(mo, nullptr);
+ ASSERT_STREQ(mo->Name(), CassandraValueMergeOperator::kClassName());
+ mo.reset();
+ ASSERT_OK(MergeOperator::CreateFromString(
+ config_options,
+ std::string("operands_limit=20;gc_grace_period_in_seconds=42;id=") +
+ CassandraValueMergeOperator::kClassName(),
+ &mo));
+ ASSERT_NE(mo, nullptr);
+ ASSERT_STREQ(mo->Name(), CassandraValueMergeOperator::kClassName());
+ const auto* opts = mo->GetOptions<CassandraOptions>();
+ ASSERT_NE(opts, nullptr);
+ ASSERT_EQ(opts->gc_grace_period_in_seconds, 42);
+ ASSERT_EQ(opts->operands_limit, 20);
+}
+
+TEST_F(CassandraFunctionalTest, LoadCompactionFilter) {
+ ConfigOptions config_options;
+ const CompactionFilter* filter = nullptr;
+ config_options.ignore_unsupported_options = false;
+
+ ASSERT_NOK(CompactionFilter::CreateFromString(
+ config_options, CassandraCompactionFilter::kClassName(), &filter));
+ config_options.registry->AddLibrary("cassandra", RegisterCassandraObjects,
+ "cassandra");
+
+ ASSERT_OK(CompactionFilter::CreateFromString(
+ config_options, CassandraCompactionFilter::kClassName(), &filter));
+ ASSERT_NE(filter, nullptr);
+ ASSERT_STREQ(filter->Name(), CassandraCompactionFilter::kClassName());
+ delete filter;
+ filter = nullptr;
+ ASSERT_OK(CompactionFilter::CreateFromString(
+ config_options,
+ std::string(
+ "purge_ttl_on_expiration=true;gc_grace_period_in_seconds=42;id=") +
+ CassandraCompactionFilter::kClassName(),
+ &filter));
+ ASSERT_NE(filter, nullptr);
+ ASSERT_STREQ(filter->Name(), CassandraCompactionFilter::kClassName());
+ const auto* opts = filter->GetOptions<CassandraOptions>();
+ ASSERT_NE(opts, nullptr);
+ ASSERT_EQ(opts->gc_grace_period_in_seconds, 42);
+ ASSERT_TRUE(opts->purge_ttl_on_expiration);
+ delete filter;
+}
+
+TEST_F(CassandraFunctionalTest, LoadCompactionFilterFactory) {
+ ConfigOptions config_options;
+ std::shared_ptr<CompactionFilterFactory> factory;
+
+ config_options.ignore_unsupported_options = false;
+ ASSERT_NOK(CompactionFilterFactory::CreateFromString(
+ config_options, CassandraCompactionFilterFactory::kClassName(),
+ &factory));
+ config_options.registry->AddLibrary("cassandra", RegisterCassandraObjects,
+ "cassandra");
+
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(
+ config_options, CassandraCompactionFilterFactory::kClassName(),
+ &factory));
+ ASSERT_NE(factory, nullptr);
+ ASSERT_STREQ(factory->Name(), CassandraCompactionFilterFactory::kClassName());
+ factory.reset();
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(
+ config_options,
+ std::string(
+ "purge_ttl_on_expiration=true;gc_grace_period_in_seconds=42;id=") +
+ CassandraCompactionFilterFactory::kClassName(),
+ &factory));
+ ASSERT_NE(factory, nullptr);
+ ASSERT_STREQ(factory->Name(), CassandraCompactionFilterFactory::kClassName());
+ const auto* opts = factory->GetOptions<CassandraOptions>();
+ ASSERT_NE(opts, nullptr);
+ ASSERT_EQ(opts->gc_grace_period_in_seconds, 42);
+ ASSERT_TRUE(opts->purge_ttl_on_expiration);
+}
+#endif // ROCKSDB_LITE
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_options.h b/src/rocksdb/utilities/cassandra/cassandra_options.h
new file mode 100644
index 000000000..efa73a308
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_options.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+class ObjectLibrary;
+namespace cassandra {
+struct CassandraOptions {
+ static const char* kName() { return "CassandraOptions"; }
+ CassandraOptions(int32_t _gc_grace_period_in_seconds, size_t _operands_limit,
+ bool _purge_ttl_on_expiration = false)
+ : operands_limit(_operands_limit),
+ gc_grace_period_in_seconds(_gc_grace_period_in_seconds),
+ purge_ttl_on_expiration(_purge_ttl_on_expiration) {}
+ // Limit on the number of merge operands.
+ size_t operands_limit;
+
+ // How long (in seconds) tombstoned data remains before it is purged
+ int32_t gc_grace_period_in_seconds;
+
+ // If is set to true, expired data will be directly purged.
+ // Otherwise expired data will be converted tombstones first,
+ // then be eventually removed after gc grace period. This value should
+ // only true if all writes have same ttl setting, otherwise it could bring old
+ // data back.
+ bool purge_ttl_on_expiration;
+};
+#ifndef ROCKSDB_LITE
+extern "C" {
+int RegisterCassandraObjects(ObjectLibrary& library, const std::string& arg);
+} // extern "C"
+#endif // ROCKSDB_LITE
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc b/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc
new file mode 100644
index 000000000..0b4a89287
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc
@@ -0,0 +1,98 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/test_utils.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+class RowValueMergeTest : public testing::Test {};
+
+TEST(RowValueMergeTest, Merge) {
+ std::vector<RowValue> row_values;
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, 5),
+ CreateTestColumnSpec(kColumn, 1, 8),
+ CreateTestColumnSpec(kExpiringColumn, 2, 5),
+ }));
+
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, 2),
+ CreateTestColumnSpec(kExpiringColumn, 1, 5),
+ CreateTestColumnSpec(kTombstone, 2, 7),
+ CreateTestColumnSpec(kExpiringColumn, 7, 17),
+ }));
+
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, 6),
+ CreateTestColumnSpec(kTombstone, 1, 5),
+ CreateTestColumnSpec(kColumn, 2, 4),
+ CreateTestColumnSpec(kTombstone, 11, 11),
+ }));
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ EXPECT_FALSE(merged.IsTombstone());
+ EXPECT_EQ(merged.get_columns().size(), 5);
+ VerifyRowValueColumns(merged.get_columns(), 0, kExpiringColumn, 0, 6);
+ VerifyRowValueColumns(merged.get_columns(), 1, kColumn, 1, 8);
+ VerifyRowValueColumns(merged.get_columns(), 2, kTombstone, 2, 7);
+ VerifyRowValueColumns(merged.get_columns(), 3, kExpiringColumn, 7, 17);
+ VerifyRowValueColumns(merged.get_columns(), 4, kTombstone, 11, 11);
+}
+
+TEST(RowValueMergeTest, MergeWithRowTombstone) {
+ std::vector<RowValue> row_values;
+
+ // A row tombstone.
+ row_values.push_back(CreateRowTombstone(11));
+
+ // This row's timestamp is smaller than tombstone.
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, 5),
+ CreateTestColumnSpec(kColumn, 1, 6),
+ }));
+
+ // Some of the column's row is smaller, some is larger.
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 2, 10),
+ CreateTestColumnSpec(kColumn, 3, 12),
+ }));
+
+ // All of the column's rows are larger than tombstone.
+ row_values.push_back(CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 4, 13),
+ CreateTestColumnSpec(kColumn, 5, 14),
+ }));
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ EXPECT_FALSE(merged.IsTombstone());
+ EXPECT_EQ(merged.get_columns().size(), 3);
+ VerifyRowValueColumns(merged.get_columns(), 0, kColumn, 3, 12);
+ VerifyRowValueColumns(merged.get_columns(), 1, kColumn, 4, 13);
+ VerifyRowValueColumns(merged.get_columns(), 2, kColumn, 5, 14);
+
+ // If the tombstone's timestamp is the latest, then it returns a
+ // row tombstone.
+ row_values.push_back(CreateRowTombstone(15));
+
+ row_values.push_back(CreateRowTombstone(17));
+
+ merged = RowValue::Merge(std::move(row_values));
+ EXPECT_TRUE(merged.IsTombstone());
+ EXPECT_EQ(merged.LastModifiedTime(), 17);
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc b/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc
new file mode 100644
index 000000000..c14d8fd80
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc
@@ -0,0 +1,164 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_util/testharness.h"
+#include "utilities/cassandra/serialize.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+TEST(SerializeTest, SerializeI64) {
+ std::string dest;
+ Serialize<int64_t>(0, &dest);
+ EXPECT_EQ(std::string({'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00',
+ '\x00'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(1, &dest);
+ EXPECT_EQ(std::string({'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00',
+ '\x01'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(-1, &dest);
+ EXPECT_EQ(std::string({'\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff',
+ '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(9223372036854775807, &dest);
+ EXPECT_EQ(std::string({'\x7f', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff',
+ '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(-9223372036854775807, &dest);
+ EXPECT_EQ(std::string({'\x80', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00',
+ '\x01'}),
+ dest);
+}
+
+TEST(SerializeTest, DeserializeI64) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int64_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(-9223372036854775807, &dest);
+ EXPECT_EQ(-9223372036854775807, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(9223372036854775807, &dest);
+ EXPECT_EQ(9223372036854775807, Deserialize<int64_t>(dest.c_str(), offset));
+}
+
+TEST(SerializeTest, SerializeI32) {
+ std::string dest;
+ Serialize<int32_t>(0, &dest);
+ EXPECT_EQ(std::string({'\x00', '\x00', '\x00', '\x00'}), dest);
+
+ dest.clear();
+ Serialize<int32_t>(1, &dest);
+ EXPECT_EQ(std::string({'\x00', '\x00', '\x00', '\x01'}), dest);
+
+ dest.clear();
+ Serialize<int32_t>(-1, &dest);
+ EXPECT_EQ(std::string({'\xff', '\xff', '\xff', '\xff'}), dest);
+
+ dest.clear();
+ Serialize<int32_t>(2147483647, &dest);
+ EXPECT_EQ(std::string({'\x7f', '\xff', '\xff', '\xff'}), dest);
+
+ dest.clear();
+ Serialize<int32_t>(-2147483648LL, &dest);
+ EXPECT_EQ(std::string({'\x80', '\x00', '\x00', '\x00'}), dest);
+}
+
+TEST(SerializeTest, DeserializeI32) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int32_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(2147483647, &dest);
+ EXPECT_EQ(2147483647, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(-2147483648LL, &dest);
+ EXPECT_EQ(-2147483648LL, Deserialize<int32_t>(dest.c_str(), offset));
+}
+
+TEST(SerializeTest, SerializeI8) {
+ std::string dest;
+ Serialize<int8_t>(0, &dest);
+ EXPECT_EQ(std::string({'\x00'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(1, &dest);
+ EXPECT_EQ(std::string({'\x01'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(-1, &dest);
+ EXPECT_EQ(std::string({'\xff'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(127, &dest);
+ EXPECT_EQ(std::string({'\x7f'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(-128, &dest);
+ EXPECT_EQ(std::string({'\x80'}), dest);
+}
+
+TEST(SerializeTest, DeserializeI8) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int8_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(127, &dest);
+ EXPECT_EQ(127, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(-128, &dest);
+ EXPECT_EQ(-128, Deserialize<int8_t>(dest.c_str(), offset));
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/format.cc b/src/rocksdb/utilities/cassandra/format.cc
new file mode 100644
index 000000000..cc1dd2f28
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/format.cc
@@ -0,0 +1,367 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "format.h"
+
+#include <algorithm>
+#include <map>
+#include <memory>
+
+#include "utilities/cassandra/serialize.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+namespace {
+const int32_t kDefaultLocalDeletionTime = std::numeric_limits<int32_t>::max();
+const int64_t kDefaultMarkedForDeleteAt = std::numeric_limits<int64_t>::min();
+} // namespace
+
+ColumnBase::ColumnBase(int8_t mask, int8_t index)
+ : mask_(mask), index_(index) {}
+
+std::size_t ColumnBase::Size() const { return sizeof(mask_) + sizeof(index_); }
+
+int8_t ColumnBase::Mask() const { return mask_; }
+
+int8_t ColumnBase::Index() const { return index_; }
+
+void ColumnBase::Serialize(std::string* dest) const {
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int8_t>(mask_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int8_t>(index_, dest);
+}
+
+std::shared_ptr<ColumnBase> ColumnBase::Deserialize(const char* src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ if ((mask & ColumnTypeMask::DELETION_MASK) != 0) {
+ return Tombstone::Deserialize(src, offset);
+ } else if ((mask & ColumnTypeMask::EXPIRATION_MASK) != 0) {
+ return ExpiringColumn::Deserialize(src, offset);
+ } else {
+ return Column::Deserialize(src, offset);
+ }
+}
+
+Column::Column(int8_t mask, int8_t index, int64_t timestamp, int32_t value_size,
+ const char* value)
+ : ColumnBase(mask, index),
+ timestamp_(timestamp),
+ value_size_(value_size),
+ value_(value) {}
+
+int64_t Column::Timestamp() const { return timestamp_; }
+
+std::size_t Column::Size() const {
+ return ColumnBase::Size() + sizeof(timestamp_) + sizeof(value_size_) +
+ value_size_;
+}
+
+void Column::Serialize(std::string* dest) const {
+ ColumnBase::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(timestamp_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(value_size_, dest);
+ dest->append(value_, value_size_);
+}
+
+std::shared_ptr<Column> Column::Deserialize(const char* src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int64_t timestamp =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(timestamp);
+ int32_t value_size =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(value_size);
+ return std::make_shared<Column>(mask, index, timestamp, value_size,
+ src + offset);
+}
+
+ExpiringColumn::ExpiringColumn(int8_t mask, int8_t index, int64_t timestamp,
+ int32_t value_size, const char* value,
+ int32_t ttl)
+ : Column(mask, index, timestamp, value_size, value), ttl_(ttl) {}
+
+std::size_t ExpiringColumn::Size() const {
+ return Column::Size() + sizeof(ttl_);
+}
+
+void ExpiringColumn::Serialize(std::string* dest) const {
+ Column::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(ttl_, dest);
+}
+
+std::chrono::time_point<std::chrono::system_clock> ExpiringColumn::TimePoint()
+ const {
+ return std::chrono::time_point<std::chrono::system_clock>(
+ std::chrono::microseconds(Timestamp()));
+}
+
+std::chrono::seconds ExpiringColumn::Ttl() const {
+ return std::chrono::seconds(ttl_);
+}
+
+bool ExpiringColumn::Expired() const {
+ return TimePoint() + Ttl() < std::chrono::system_clock::now();
+}
+
+std::shared_ptr<Tombstone> ExpiringColumn::ToTombstone() const {
+ auto expired_at = (TimePoint() + Ttl()).time_since_epoch();
+ int32_t local_deletion_time = static_cast<int32_t>(
+ std::chrono::duration_cast<std::chrono::seconds>(expired_at).count());
+ int64_t marked_for_delete_at =
+ std::chrono::duration_cast<std::chrono::microseconds>(expired_at).count();
+ return std::make_shared<Tombstone>(
+ static_cast<int8_t>(ColumnTypeMask::DELETION_MASK), Index(),
+ local_deletion_time, marked_for_delete_at);
+}
+
+std::shared_ptr<ExpiringColumn> ExpiringColumn::Deserialize(
+ const char* src, std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int64_t timestamp =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(timestamp);
+ int32_t value_size =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(value_size);
+ const char* value = src + offset;
+ offset += value_size;
+ int32_t ttl = ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ return std::make_shared<ExpiringColumn>(mask, index, timestamp, value_size,
+ value, ttl);
+}
+
+Tombstone::Tombstone(int8_t mask, int8_t index, int32_t local_deletion_time,
+ int64_t marked_for_delete_at)
+ : ColumnBase(mask, index),
+ local_deletion_time_(local_deletion_time),
+ marked_for_delete_at_(marked_for_delete_at) {}
+
+int64_t Tombstone::Timestamp() const { return marked_for_delete_at_; }
+
+std::size_t Tombstone::Size() const {
+ return ColumnBase::Size() + sizeof(local_deletion_time_) +
+ sizeof(marked_for_delete_at_);
+}
+
+void Tombstone::Serialize(std::string* dest) const {
+ ColumnBase::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(local_deletion_time_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(marked_for_delete_at_, dest);
+}
+
+bool Tombstone::Collectable(int32_t gc_grace_period_in_seconds) const {
+ auto local_deleted_at = std::chrono::time_point<std::chrono::system_clock>(
+ std::chrono::seconds(local_deletion_time_));
+ auto gc_grace_period = std::chrono::seconds(gc_grace_period_in_seconds);
+ return local_deleted_at + gc_grace_period < std::chrono::system_clock::now();
+}
+
+std::shared_ptr<Tombstone> Tombstone::Deserialize(const char* src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int32_t local_deletion_time =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(int32_t);
+ int64_t marked_for_delete_at =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ return std::make_shared<Tombstone>(mask, index, local_deletion_time,
+ marked_for_delete_at);
+}
+
+RowValue::RowValue(int32_t local_deletion_time, int64_t marked_for_delete_at)
+ : local_deletion_time_(local_deletion_time),
+ marked_for_delete_at_(marked_for_delete_at),
+ columns_(),
+ last_modified_time_(0) {}
+
+RowValue::RowValue(Columns columns, int64_t last_modified_time)
+ : local_deletion_time_(kDefaultLocalDeletionTime),
+ marked_for_delete_at_(kDefaultMarkedForDeleteAt),
+ columns_(std::move(columns)),
+ last_modified_time_(last_modified_time) {}
+
+std::size_t RowValue::Size() const {
+ std::size_t size =
+ sizeof(local_deletion_time_) + sizeof(marked_for_delete_at_);
+ for (const auto& column : columns_) {
+ size += column->Size();
+ }
+ return size;
+}
+
+int64_t RowValue::LastModifiedTime() const {
+ if (IsTombstone()) {
+ return marked_for_delete_at_;
+ } else {
+ return last_modified_time_;
+ }
+}
+
+bool RowValue::IsTombstone() const {
+ return marked_for_delete_at_ > kDefaultMarkedForDeleteAt;
+}
+
+void RowValue::Serialize(std::string* dest) const {
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(local_deletion_time_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(marked_for_delete_at_, dest);
+ for (const auto& column : columns_) {
+ column->Serialize(dest);
+ }
+}
+
+RowValue RowValue::RemoveExpiredColumns(bool* changed) const {
+ *changed = false;
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if (column->Mask() == ColumnTypeMask::EXPIRATION_MASK) {
+ std::shared_ptr<ExpiringColumn> expiring_column =
+ std::static_pointer_cast<ExpiringColumn>(column);
+
+ if (expiring_column->Expired()) {
+ *changed = true;
+ continue;
+ }
+ }
+
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+RowValue RowValue::ConvertExpiredColumnsToTombstones(bool* changed) const {
+ *changed = false;
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if (column->Mask() == ColumnTypeMask::EXPIRATION_MASK) {
+ std::shared_ptr<ExpiringColumn> expiring_column =
+ std::static_pointer_cast<ExpiringColumn>(column);
+
+ if (expiring_column->Expired()) {
+ std::shared_ptr<Tombstone> tombstone = expiring_column->ToTombstone();
+ new_columns.push_back(tombstone);
+ *changed = true;
+ continue;
+ }
+ }
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+RowValue RowValue::RemoveTombstones(int32_t gc_grace_period) const {
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if (column->Mask() == ColumnTypeMask::DELETION_MASK) {
+ std::shared_ptr<Tombstone> tombstone =
+ std::static_pointer_cast<Tombstone>(column);
+
+ if (tombstone->Collectable(gc_grace_period)) {
+ continue;
+ }
+ }
+
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+bool RowValue::Empty() const { return columns_.empty(); }
+
+RowValue RowValue::Deserialize(const char* src, std::size_t size) {
+ std::size_t offset = 0;
+ assert(size >= sizeof(local_deletion_time_) + sizeof(marked_for_delete_at_));
+ int32_t local_deletion_time =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(int32_t);
+ int64_t marked_for_delete_at =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(int64_t);
+ if (offset == size) {
+ return RowValue(local_deletion_time, marked_for_delete_at);
+ }
+
+ assert(local_deletion_time == kDefaultLocalDeletionTime);
+ assert(marked_for_delete_at == kDefaultMarkedForDeleteAt);
+ Columns columns;
+ int64_t last_modified_time = 0;
+ while (offset < size) {
+ auto c = ColumnBase::Deserialize(src, offset);
+ offset += c->Size();
+ assert(offset <= size);
+ last_modified_time = std::max(last_modified_time, c->Timestamp());
+ columns.push_back(std::move(c));
+ }
+
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+// Merge multiple row values into one.
+// For each column in rows with same index, we pick the one with latest
+// timestamp. And we also take row tombstone into consideration, by iterating
+// each row from reverse timestamp order, and stop once we hit the first
+// row tombstone.
+RowValue RowValue::Merge(std::vector<RowValue>&& values) {
+ assert(values.size() > 0);
+ if (values.size() == 1) {
+ return std::move(values[0]);
+ }
+
+ // Merge columns by their last modified time, and skip once we hit
+ // a row tombstone.
+ std::sort(values.begin(), values.end(),
+ [](const RowValue& r1, const RowValue& r2) {
+ return r1.LastModifiedTime() > r2.LastModifiedTime();
+ });
+
+ std::map<int8_t, std::shared_ptr<ColumnBase>> merged_columns;
+ int64_t tombstone_timestamp = 0;
+
+ for (auto& value : values) {
+ if (value.IsTombstone()) {
+ if (merged_columns.size() == 0) {
+ return std::move(value);
+ }
+ tombstone_timestamp = value.LastModifiedTime();
+ break;
+ }
+ for (auto& column : value.columns_) {
+ int8_t index = column->Index();
+ if (merged_columns.find(index) == merged_columns.end()) {
+ merged_columns[index] = column;
+ } else {
+ if (column->Timestamp() > merged_columns[index]->Timestamp()) {
+ merged_columns[index] = column;
+ }
+ }
+ }
+ }
+
+ int64_t last_modified_time = 0;
+ Columns columns;
+ for (auto& pair : merged_columns) {
+ // For some row, its last_modified_time > row tombstone_timestamp, but
+ // it might have rows whose timestamp is ealier than tombstone, so we
+ // ned to filter these rows.
+ if (pair.second->Timestamp() <= tombstone_timestamp) {
+ continue;
+ }
+ last_modified_time = std::max(last_modified_time, pair.second->Timestamp());
+ columns.push_back(std::move(pair.second));
+ }
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/format.h b/src/rocksdb/utilities/cassandra/format.h
new file mode 100644
index 000000000..1b2714735
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/format.h
@@ -0,0 +1,183 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/**
+ * The encoding of Cassandra Row Value.
+ *
+ * A Cassandra Row Value could either be a row tombstone,
+ * or contains multiple columns, it has following fields:
+ *
+ * struct row_value {
+ * int32_t local_deletion_time; // Time in second when the row is deleted,
+ * // only used for Cassandra tombstone gc.
+ * int64_t marked_for_delete_at; // Ms that marked this row is deleted.
+ * struct column_base columns[]; // For non tombstone row, all columns
+ * // are stored here.
+ * }
+ *
+ * If the local_deletion_time and marked_for_delete_at is set, then this is
+ * a tombstone, otherwise it contains multiple columns.
+ *
+ * There are three type of Columns: Normal Column, Expiring Column and Column
+ * Tombstone, which have following fields:
+ *
+ * // Identify the type of the column.
+ * enum mask {
+ * DELETION_MASK = 0x01,
+ * EXPIRATION_MASK = 0x02,
+ * };
+ *
+ * struct column {
+ * int8_t mask = 0;
+ * int8_t index;
+ * int64_t timestamp;
+ * int32_t value_length;
+ * char value[value_length];
+ * }
+ *
+ * struct expiring_column {
+ * int8_t mask = mask.EXPIRATION_MASK;
+ * int8_t index;
+ * int64_t timestamp;
+ * int32_t value_length;
+ * char value[value_length];
+ * int32_t ttl;
+ * }
+ *
+ * struct tombstone_column {
+ * int8_t mask = mask.DELETION_MASK;
+ * int8_t index;
+ * int32_t local_deletion_time; // Similar to row_value's field.
+ * int64_t marked_for_delete_at;
+ * }
+ */
+
+#pragma once
+#include <chrono>
+#include <memory>
+#include <vector>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+// Identify the type of the column.
+enum ColumnTypeMask {
+ DELETION_MASK = 0x01,
+ EXPIRATION_MASK = 0x02,
+};
+
+class ColumnBase {
+ public:
+ ColumnBase(int8_t mask, int8_t index);
+ virtual ~ColumnBase() = default;
+
+ virtual int64_t Timestamp() const = 0;
+ virtual int8_t Mask() const;
+ virtual int8_t Index() const;
+ virtual std::size_t Size() const;
+ virtual void Serialize(std::string* dest) const;
+ static std::shared_ptr<ColumnBase> Deserialize(const char* src,
+ std::size_t offset);
+
+ private:
+ int8_t mask_;
+ int8_t index_;
+};
+
+class Column : public ColumnBase {
+ public:
+ Column(int8_t mask, int8_t index, int64_t timestamp, int32_t value_size,
+ const char* value);
+
+ virtual int64_t Timestamp() const override;
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ static std::shared_ptr<Column> Deserialize(const char* src,
+ std::size_t offset);
+
+ private:
+ int64_t timestamp_;
+ int32_t value_size_;
+ const char* value_;
+};
+
+class Tombstone : public ColumnBase {
+ public:
+ Tombstone(int8_t mask, int8_t index, int32_t local_deletion_time,
+ int64_t marked_for_delete_at);
+
+ virtual int64_t Timestamp() const override;
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ bool Collectable(int32_t gc_grace_period) const;
+ static std::shared_ptr<Tombstone> Deserialize(const char* src,
+ std::size_t offset);
+
+ private:
+ int32_t local_deletion_time_;
+ int64_t marked_for_delete_at_;
+};
+
+class ExpiringColumn : public Column {
+ public:
+ ExpiringColumn(int8_t mask, int8_t index, int64_t timestamp,
+ int32_t value_size, const char* value, int32_t ttl);
+
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ bool Expired() const;
+ std::shared_ptr<Tombstone> ToTombstone() const;
+
+ static std::shared_ptr<ExpiringColumn> Deserialize(const char* src,
+ std::size_t offset);
+
+ private:
+ int32_t ttl_;
+ std::chrono::time_point<std::chrono::system_clock> TimePoint() const;
+ std::chrono::seconds Ttl() const;
+};
+
+using Columns = std::vector<std::shared_ptr<ColumnBase>>;
+
+class RowValue {
+ public:
+ // Create a Row Tombstone.
+ RowValue(int32_t local_deletion_time, int64_t marked_for_delete_at);
+ // Create a Row containing columns.
+ RowValue(Columns columns, int64_t last_modified_time);
+ RowValue(const RowValue& /*that*/) = delete;
+ RowValue(RowValue&& /*that*/) noexcept = default;
+ RowValue& operator=(const RowValue& /*that*/) = delete;
+ RowValue& operator=(RowValue&& /*that*/) = default;
+
+ std::size_t Size() const;
+ bool IsTombstone() const;
+ // For Tombstone this returns the marked_for_delete_at_,
+ // otherwise it returns the max timestamp of containing columns.
+ int64_t LastModifiedTime() const;
+ void Serialize(std::string* dest) const;
+ RowValue RemoveExpiredColumns(bool* changed) const;
+ RowValue ConvertExpiredColumnsToTombstones(bool* changed) const;
+ RowValue RemoveTombstones(int32_t gc_grace_period) const;
+ bool Empty() const;
+
+ static RowValue Deserialize(const char* src, std::size_t size);
+ // Merge multiple rows according to their timestamp.
+ static RowValue Merge(std::vector<RowValue>&& values);
+
+ const Columns& get_columns() { return columns_; }
+
+ private:
+ int32_t local_deletion_time_;
+ int64_t marked_for_delete_at_;
+ Columns columns_;
+ int64_t last_modified_time_;
+};
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/merge_operator.cc b/src/rocksdb/utilities/cassandra/merge_operator.cc
new file mode 100644
index 000000000..bde5dcbad
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/merge_operator.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "merge_operator.h"
+
+#include <assert.h>
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/options_type.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+static std::unordered_map<std::string, OptionTypeInfo>
+ merge_operator_options_info = {
+#ifndef ROCKSDB_LITE
+ {"gc_grace_period_in_seconds",
+ {offsetof(struct CassandraOptions, gc_grace_period_in_seconds),
+ OptionType::kUInt32T, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"operands_limit",
+ {offsetof(struct CassandraOptions, operands_limit), OptionType::kSizeT,
+ OptionVerificationType::kNormal, OptionTypeFlags::kNone}},
+#endif // ROCKSDB_LITE
+};
+
+CassandraValueMergeOperator::CassandraValueMergeOperator(
+ int32_t gc_grace_period_in_seconds, size_t operands_limit)
+ : options_(gc_grace_period_in_seconds, operands_limit) {
+ RegisterOptions(&options_, &merge_operator_options_info);
+}
+
+// Implementation for the merge operation (merges two Cassandra values)
+bool CassandraValueMergeOperator::FullMergeV2(
+ const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ // Clear the *new_value for writing.
+ merge_out->new_value.clear();
+ std::vector<RowValue> row_values;
+ if (merge_in.existing_value) {
+ row_values.push_back(RowValue::Deserialize(
+ merge_in.existing_value->data(), merge_in.existing_value->size()));
+ }
+
+ for (auto& operand : merge_in.operand_list) {
+ row_values.push_back(RowValue::Deserialize(operand.data(), operand.size()));
+ }
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ merged = merged.RemoveTombstones(options_.gc_grace_period_in_seconds);
+ merge_out->new_value.reserve(merged.Size());
+ merged.Serialize(&(merge_out->new_value));
+
+ return true;
+}
+
+bool CassandraValueMergeOperator::PartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* /*logger*/) const {
+ // Clear the *new_value for writing.
+ assert(new_value);
+ new_value->clear();
+
+ std::vector<RowValue> row_values;
+ for (auto& operand : operand_list) {
+ row_values.push_back(RowValue::Deserialize(operand.data(), operand.size()));
+ }
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ new_value->reserve(merged.Size());
+ merged.Serialize(new_value);
+ return true;
+}
+
+} // namespace cassandra
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/merge_operator.h b/src/rocksdb/utilities/cassandra/merge_operator.h
new file mode 100644
index 000000000..af8725db7
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/merge_operator.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+/**
+ * A MergeOperator for rocksdb that implements Cassandra row value merge.
+ */
+class CassandraValueMergeOperator : public MergeOperator {
+ public:
+ explicit CassandraValueMergeOperator(int32_t gc_grace_period_in_seconds,
+ size_t operands_limit = 0);
+
+ virtual bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ virtual bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* logger) const override;
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "CassandraValueMergeOperator"; }
+
+ virtual bool AllowSingleOperand() const override { return true; }
+
+ virtual bool ShouldMerge(const std::vector<Slice>& operands) const override {
+ return options_.operands_limit > 0 &&
+ operands.size() >= options_.operands_limit;
+ }
+
+ private:
+ CassandraOptions options_;
+};
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/serialize.h b/src/rocksdb/utilities/cassandra/serialize.h
new file mode 100644
index 000000000..4bd552bfc
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/serialize.h
@@ -0,0 +1,81 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/**
+ * Helper functions which serialize and deserialize integers
+ * into bytes in big endian.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+namespace {
+const int64_t kCharMask = 0xFFLL;
+const int32_t kBitsPerByte = 8;
+} // namespace
+
+template <typename T>
+void Serialize(T val, std::string* dest);
+
+template <typename T>
+T Deserialize(const char* src, std::size_t offset = 0);
+
+// Specializations
+template <>
+inline void Serialize<int8_t>(int8_t t, std::string* dest) {
+ dest->append(1, static_cast<char>(t & kCharMask));
+}
+
+template <>
+inline void Serialize<int32_t>(int32_t t, std::string* dest) {
+ for (unsigned long i = 0; i < sizeof(int32_t); i++) {
+ dest->append(
+ 1, static_cast<char>((t >> (sizeof(int32_t) - 1 - i) * kBitsPerByte) &
+ kCharMask));
+ }
+}
+
+template <>
+inline void Serialize<int64_t>(int64_t t, std::string* dest) {
+ for (unsigned long i = 0; i < sizeof(int64_t); i++) {
+ dest->append(
+ 1, static_cast<char>((t >> (sizeof(int64_t) - 1 - i) * kBitsPerByte) &
+ kCharMask));
+ }
+}
+
+template <>
+inline int8_t Deserialize<int8_t>(const char* src, std::size_t offset) {
+ return static_cast<int8_t>(src[offset]);
+}
+
+template <>
+inline int32_t Deserialize<int32_t>(const char* src, std::size_t offset) {
+ int32_t result = 0;
+ for (unsigned long i = 0; i < sizeof(int32_t); i++) {
+ result |= static_cast<int32_t>(static_cast<unsigned char>(src[offset + i]))
+ << ((sizeof(int32_t) - 1 - i) * kBitsPerByte);
+ }
+ return result;
+}
+
+template <>
+inline int64_t Deserialize<int64_t>(const char* src, std::size_t offset) {
+ int64_t result = 0;
+ for (unsigned long i = 0; i < sizeof(int64_t); i++) {
+ result |= static_cast<int64_t>(static_cast<unsigned char>(src[offset + i]))
+ << ((sizeof(int64_t) - 1 - i) * kBitsPerByte);
+ }
+ return result;
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/test_utils.cc b/src/rocksdb/utilities/cassandra/test_utils.cc
new file mode 100644
index 000000000..ec6e5752d
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/test_utils.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_utils.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+const char kData[] = {'d', 'a', 't', 'a'};
+const char kExpiringData[] = {'e', 'd', 'a', 't', 'a'};
+const int32_t kTtl = 86400;
+const int8_t kColumn = 0;
+const int8_t kTombstone = 1;
+const int8_t kExpiringColumn = 2;
+
+std::shared_ptr<ColumnBase> CreateTestColumn(int8_t mask, int8_t index,
+ int64_t timestamp) {
+ if ((mask & ColumnTypeMask::DELETION_MASK) != 0) {
+ return std::shared_ptr<Tombstone>(
+ new Tombstone(mask, index, ToSeconds(timestamp), timestamp));
+ } else if ((mask & ColumnTypeMask::EXPIRATION_MASK) != 0) {
+ return std::shared_ptr<ExpiringColumn>(new ExpiringColumn(
+ mask, index, timestamp, sizeof(kExpiringData), kExpiringData, kTtl));
+ } else {
+ return std::shared_ptr<Column>(
+ new Column(mask, index, timestamp, sizeof(kData), kData));
+ }
+}
+
+std::tuple<int8_t, int8_t, int64_t> CreateTestColumnSpec(int8_t mask,
+ int8_t index,
+ int64_t timestamp) {
+ return std::make_tuple(mask, index, timestamp);
+}
+
+RowValue CreateTestRowValue(
+ std::vector<std::tuple<int8_t, int8_t, int64_t>> column_specs) {
+ std::vector<std::shared_ptr<ColumnBase>> columns;
+ int64_t last_modified_time = 0;
+ for (auto spec : column_specs) {
+ auto c = CreateTestColumn(std::get<0>(spec), std::get<1>(spec),
+ std::get<2>(spec));
+ last_modified_time = std::max(last_modified_time, c->Timestamp());
+ columns.push_back(std::move(c));
+ }
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+RowValue CreateRowTombstone(int64_t timestamp) {
+ return RowValue(ToSeconds(timestamp), timestamp);
+}
+
+void VerifyRowValueColumns(
+ const std::vector<std::shared_ptr<ColumnBase>> &columns,
+ std::size_t index_of_vector, int8_t expected_mask, int8_t expected_index,
+ int64_t expected_timestamp) {
+ EXPECT_EQ(expected_timestamp, columns[index_of_vector]->Timestamp());
+ EXPECT_EQ(expected_mask, columns[index_of_vector]->Mask());
+ EXPECT_EQ(expected_index, columns[index_of_vector]->Index());
+}
+
+int64_t ToMicroSeconds(int64_t seconds) { return seconds * (int64_t)1000000; }
+
+int32_t ToSeconds(int64_t microseconds) {
+ return (int32_t)(microseconds / (int64_t)1000000);
+}
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/test_utils.h b/src/rocksdb/utilities/cassandra/test_utils.h
new file mode 100644
index 000000000..be23f7076
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/test_utils.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <memory>
+
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/serialize.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+extern const char kData[];
+extern const char kExpiringData[];
+extern const int32_t kTtl;
+extern const int8_t kColumn;
+extern const int8_t kTombstone;
+extern const int8_t kExpiringColumn;
+
+std::shared_ptr<ColumnBase> CreateTestColumn(int8_t mask, int8_t index,
+ int64_t timestamp);
+
+std::tuple<int8_t, int8_t, int64_t> CreateTestColumnSpec(int8_t mask,
+ int8_t index,
+ int64_t timestamp);
+
+RowValue CreateTestRowValue(
+ std::vector<std::tuple<int8_t, int8_t, int64_t>> column_specs);
+
+RowValue CreateRowTombstone(int64_t timestamp);
+
+void VerifyRowValueColumns(
+ const std::vector<std::shared_ptr<ColumnBase>> &columns,
+ std::size_t index_of_vector, int8_t expected_mask, int8_t expected_index,
+ int64_t expected_timestamp);
+
+int64_t ToMicroSeconds(int64_t seconds);
+int32_t ToSeconds(int64_t microseconds);
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc b/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc
new file mode 100644
index 000000000..44ce70b1b
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc
@@ -0,0 +1,469 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 Facebook.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/checkpoint/checkpoint_impl.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <string>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+#include "db/wal_manager.h"
+#include "file/file_util.h"
+#include "file/filename.h"
+#include "logging/logging.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/metadata.h"
+#include "rocksdb/options.h"
+#include "rocksdb/transaction_log.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/checkpoint.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/file_checksum_helper.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status Checkpoint::Create(DB* db, Checkpoint** checkpoint_ptr) {
+ *checkpoint_ptr = new CheckpointImpl(db);
+ return Status::OK();
+}
+
+Status Checkpoint::CreateCheckpoint(const std::string& /*checkpoint_dir*/,
+ uint64_t /*log_size_for_flush*/,
+ uint64_t* /*sequence_number_ptr*/) {
+ return Status::NotSupported("");
+}
+
+void CheckpointImpl::CleanStagingDirectory(const std::string& full_private_path,
+ Logger* info_log) {
+ std::vector<std::string> subchildren;
+ Status s = db_->GetEnv()->FileExists(full_private_path);
+ if (s.IsNotFound()) {
+ return;
+ }
+ ROCKS_LOG_INFO(info_log, "File exists %s -- %s", full_private_path.c_str(),
+ s.ToString().c_str());
+ s = db_->GetEnv()->GetChildren(full_private_path, &subchildren);
+ if (s.ok()) {
+ for (auto& subchild : subchildren) {
+ std::string subchild_path = full_private_path + "/" + subchild;
+ s = db_->GetEnv()->DeleteFile(subchild_path);
+ ROCKS_LOG_INFO(info_log, "Delete file %s -- %s", subchild_path.c_str(),
+ s.ToString().c_str());
+ }
+ }
+ // finally delete the private dir
+ s = db_->GetEnv()->DeleteDir(full_private_path);
+ ROCKS_LOG_INFO(info_log, "Delete dir %s -- %s", full_private_path.c_str(),
+ s.ToString().c_str());
+}
+
+Status Checkpoint::ExportColumnFamily(
+ ColumnFamilyHandle* /*handle*/, const std::string& /*export_dir*/,
+ ExportImportFilesMetaData** /*metadata*/) {
+ return Status::NotSupported("");
+}
+
+// Builds an openable snapshot of RocksDB
+Status CheckpointImpl::CreateCheckpoint(const std::string& checkpoint_dir,
+ uint64_t log_size_for_flush,
+ uint64_t* sequence_number_ptr) {
+ DBOptions db_options = db_->GetDBOptions();
+
+ Status s = db_->GetEnv()->FileExists(checkpoint_dir);
+ if (s.ok()) {
+ return Status::InvalidArgument("Directory exists");
+ } else if (!s.IsNotFound()) {
+ assert(s.IsIOError());
+ return s;
+ }
+
+ ROCKS_LOG_INFO(
+ db_options.info_log,
+ "Started the snapshot process -- creating snapshot in directory %s",
+ checkpoint_dir.c_str());
+
+ size_t final_nonslash_idx = checkpoint_dir.find_last_not_of('/');
+ if (final_nonslash_idx == std::string::npos) {
+ // npos means it's only slashes or empty. Non-empty means it's the root
+ // directory, but it shouldn't be because we verified above the directory
+ // doesn't exist.
+ assert(checkpoint_dir.empty());
+ return Status::InvalidArgument("invalid checkpoint directory name");
+ }
+
+ std::string full_private_path =
+ checkpoint_dir.substr(0, final_nonslash_idx + 1) + ".tmp";
+ ROCKS_LOG_INFO(db_options.info_log,
+ "Snapshot process -- using temporary directory %s",
+ full_private_path.c_str());
+ CleanStagingDirectory(full_private_path, db_options.info_log.get());
+ // create snapshot directory
+ s = db_->GetEnv()->CreateDir(full_private_path);
+ uint64_t sequence_number = 0;
+ if (s.ok()) {
+ // enable file deletions
+ s = db_->DisableFileDeletions();
+ const bool disabled_file_deletions = s.ok();
+
+ if (s.ok() || s.IsNotSupported()) {
+ s = CreateCustomCheckpoint(
+ [&](const std::string& src_dirname, const std::string& fname,
+ FileType) {
+ ROCKS_LOG_INFO(db_options.info_log, "Hard Linking %s",
+ fname.c_str());
+ return db_->GetFileSystem()->LinkFile(
+ src_dirname + "/" + fname, full_private_path + "/" + fname,
+ IOOptions(), nullptr);
+ } /* link_file_cb */,
+ [&](const std::string& src_dirname, const std::string& fname,
+ uint64_t size_limit_bytes, FileType,
+ const std::string& /* checksum_func_name */,
+ const std::string& /* checksum_val */,
+ const Temperature temperature) {
+ ROCKS_LOG_INFO(db_options.info_log, "Copying %s", fname.c_str());
+ return CopyFile(db_->GetFileSystem(), src_dirname + "/" + fname,
+ full_private_path + "/" + fname, size_limit_bytes,
+ db_options.use_fsync, nullptr, temperature);
+ } /* copy_file_cb */,
+ [&](const std::string& fname, const std::string& contents, FileType) {
+ ROCKS_LOG_INFO(db_options.info_log, "Creating %s", fname.c_str());
+ return CreateFile(db_->GetFileSystem(),
+ full_private_path + "/" + fname, contents,
+ db_options.use_fsync);
+ } /* create_file_cb */,
+ &sequence_number, log_size_for_flush);
+
+ // we copied all the files, enable file deletions
+ if (disabled_file_deletions) {
+ Status ss = db_->EnableFileDeletions(false);
+ assert(ss.ok());
+ ss.PermitUncheckedError();
+ }
+ }
+ }
+
+ if (s.ok()) {
+ // move tmp private backup to real snapshot directory
+ s = db_->GetEnv()->RenameFile(full_private_path, checkpoint_dir);
+ }
+ if (s.ok()) {
+ std::unique_ptr<FSDirectory> checkpoint_directory;
+ s = db_->GetFileSystem()->NewDirectory(checkpoint_dir, IOOptions(),
+ &checkpoint_directory, nullptr);
+ if (s.ok() && checkpoint_directory != nullptr) {
+ s = checkpoint_directory->FsyncWithDirOptions(
+ IOOptions(), nullptr,
+ DirFsyncOptions(DirFsyncOptions::FsyncReason::kDirRenamed));
+ }
+ }
+
+ if (s.ok()) {
+ if (sequence_number_ptr != nullptr) {
+ *sequence_number_ptr = sequence_number;
+ }
+ // here we know that we succeeded and installed the new snapshot
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot DONE. All is good");
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot sequence number: %" PRIu64,
+ sequence_number);
+ } else {
+ // clean all the files we might have created
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot failed -- %s",
+ s.ToString().c_str());
+ CleanStagingDirectory(full_private_path, db_options.info_log.get());
+ }
+ return s;
+}
+
+Status CheckpointImpl::CreateCustomCheckpoint(
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname, FileType type)>
+ link_file_cb,
+ std::function<
+ Status(const std::string& src_dirname, const std::string& src_fname,
+ uint64_t size_limit_bytes, FileType type,
+ const std::string& checksum_func_name,
+ const std::string& checksum_val, const Temperature temperature)>
+ copy_file_cb,
+ std::function<Status(const std::string& fname, const std::string& contents,
+ FileType type)>
+ create_file_cb,
+ uint64_t* sequence_number, uint64_t log_size_for_flush,
+ bool get_live_table_checksum) {
+ *sequence_number = db_->GetLatestSequenceNumber();
+
+ LiveFilesStorageInfoOptions opts;
+ opts.include_checksum_info = get_live_table_checksum;
+ opts.wal_size_for_flush = log_size_for_flush;
+
+ std::vector<LiveFileStorageInfo> infos;
+ {
+ Status s = db_->GetLiveFilesStorageInfo(opts, &infos);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ // Verify that everything except WAL files are in same directory
+ // (db_paths / cf_paths not supported)
+ std::unordered_set<std::string> dirs;
+ for (auto& info : infos) {
+ if (info.file_type != kWalFile) {
+ dirs.insert(info.directory);
+ }
+ }
+ if (dirs.size() > 1) {
+ return Status::NotSupported(
+ "db_paths / cf_paths not supported for Checkpoint nor BackupEngine");
+ }
+
+ bool same_fs = true;
+
+ for (auto& info : infos) {
+ Status s;
+ if (!info.replacement_contents.empty()) {
+ // Currently should only be used for CURRENT file.
+ assert(info.file_type == kCurrentFile);
+
+ if (info.size != info.replacement_contents.size()) {
+ s = Status::Corruption("Inconsistent size metadata for " +
+ info.relative_filename);
+ } else {
+ s = create_file_cb(info.relative_filename, info.replacement_contents,
+ info.file_type);
+ }
+ } else {
+ if (same_fs && !info.trim_to_size) {
+ s = link_file_cb(info.directory, info.relative_filename,
+ info.file_type);
+ if (s.IsNotSupported()) {
+ same_fs = false;
+ s = Status::OK();
+ }
+ s.MustCheck();
+ }
+ if (!same_fs || info.trim_to_size) {
+ assert(info.file_checksum_func_name.empty() ==
+ !opts.include_checksum_info);
+ // no assertion on file_checksum because empty is used for both "not
+ // set" and "unknown"
+ if (opts.include_checksum_info) {
+ s = copy_file_cb(info.directory, info.relative_filename, info.size,
+ info.file_type, info.file_checksum_func_name,
+ info.file_checksum, info.temperature);
+ } else {
+ s = copy_file_cb(info.directory, info.relative_filename, info.size,
+ info.file_type, kUnknownFileChecksumFuncName,
+ kUnknownFileChecksum, info.temperature);
+ }
+ }
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ return Status::OK();
+}
+
+// Exports all live SST files of a specified Column Family onto export_dir,
+// returning SST files information in metadata.
+Status CheckpointImpl::ExportColumnFamily(
+ ColumnFamilyHandle* handle, const std::string& export_dir,
+ ExportImportFilesMetaData** metadata) {
+ auto cfh = static_cast_with_check<ColumnFamilyHandleImpl>(handle);
+ const auto cf_name = cfh->GetName();
+ const auto db_options = db_->GetDBOptions();
+
+ assert(metadata != nullptr);
+ assert(*metadata == nullptr);
+ auto s = db_->GetEnv()->FileExists(export_dir);
+ if (s.ok()) {
+ return Status::InvalidArgument("Specified export_dir exists");
+ } else if (!s.IsNotFound()) {
+ assert(s.IsIOError());
+ return s;
+ }
+
+ const auto final_nonslash_idx = export_dir.find_last_not_of('/');
+ if (final_nonslash_idx == std::string::npos) {
+ return Status::InvalidArgument("Specified export_dir invalid");
+ }
+ ROCKS_LOG_INFO(db_options.info_log,
+ "[%s] export column family onto export directory %s",
+ cf_name.c_str(), export_dir.c_str());
+
+ // Create a temporary export directory.
+ const auto tmp_export_dir =
+ export_dir.substr(0, final_nonslash_idx + 1) + ".tmp";
+ s = db_->GetEnv()->CreateDir(tmp_export_dir);
+
+ if (s.ok()) {
+ s = db_->Flush(ROCKSDB_NAMESPACE::FlushOptions(), handle);
+ }
+
+ ColumnFamilyMetaData db_metadata;
+ if (s.ok()) {
+ // Export live sst files with file deletions disabled.
+ s = db_->DisableFileDeletions();
+ if (s.ok()) {
+ db_->GetColumnFamilyMetaData(handle, &db_metadata);
+
+ s = ExportFilesInMetaData(
+ db_options, db_metadata,
+ [&](const std::string& src_dirname, const std::string& fname) {
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] HardLinking %s",
+ cf_name.c_str(), fname.c_str());
+ return db_->GetEnv()->LinkFile(src_dirname + fname,
+ tmp_export_dir + fname);
+ } /*link_file_cb*/,
+ [&](const std::string& src_dirname, const std::string& fname) {
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Copying %s",
+ cf_name.c_str(), fname.c_str());
+ return CopyFile(db_->GetFileSystem(), src_dirname + fname,
+ tmp_export_dir + fname, 0, db_options.use_fsync,
+ nullptr, Temperature::kUnknown);
+ } /*copy_file_cb*/);
+
+ const auto enable_status = db_->EnableFileDeletions(false /*force*/);
+ if (s.ok()) {
+ s = enable_status;
+ }
+ }
+ }
+
+ auto moved_to_user_specified_dir = false;
+ if (s.ok()) {
+ // Move temporary export directory to the actual export directory.
+ s = db_->GetEnv()->RenameFile(tmp_export_dir, export_dir);
+ }
+
+ if (s.ok()) {
+ // Fsync export directory.
+ moved_to_user_specified_dir = true;
+ std::unique_ptr<FSDirectory> dir_ptr;
+ s = db_->GetFileSystem()->NewDirectory(export_dir, IOOptions(), &dir_ptr,
+ nullptr);
+ if (s.ok()) {
+ assert(dir_ptr != nullptr);
+ s = dir_ptr->FsyncWithDirOptions(
+ IOOptions(), nullptr,
+ DirFsyncOptions(DirFsyncOptions::FsyncReason::kDirRenamed));
+ }
+ }
+
+ if (s.ok()) {
+ // Export of files succeeded. Fill in the metadata information.
+ auto result_metadata = new ExportImportFilesMetaData();
+ result_metadata->db_comparator_name = handle->GetComparator()->Name();
+ for (const auto& level_metadata : db_metadata.levels) {
+ for (const auto& file_metadata : level_metadata.files) {
+ LiveFileMetaData live_file_metadata;
+ live_file_metadata.size = file_metadata.size;
+ live_file_metadata.name = std::move(file_metadata.name);
+ live_file_metadata.file_number = file_metadata.file_number;
+ live_file_metadata.db_path = export_dir;
+ live_file_metadata.smallest_seqno = file_metadata.smallest_seqno;
+ live_file_metadata.largest_seqno = file_metadata.largest_seqno;
+ live_file_metadata.smallestkey = std::move(file_metadata.smallestkey);
+ live_file_metadata.largestkey = std::move(file_metadata.largestkey);
+ live_file_metadata.oldest_blob_file_number =
+ file_metadata.oldest_blob_file_number;
+ live_file_metadata.level = level_metadata.level;
+ result_metadata->files.push_back(live_file_metadata);
+ }
+ *metadata = result_metadata;
+ }
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Export succeeded.",
+ cf_name.c_str());
+ } else {
+ // Failure: Clean up all the files/directories created.
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Export failed. %s",
+ cf_name.c_str(), s.ToString().c_str());
+ std::vector<std::string> subchildren;
+ const auto cleanup_dir =
+ moved_to_user_specified_dir ? export_dir : tmp_export_dir;
+ db_->GetEnv()->GetChildren(cleanup_dir, &subchildren);
+ for (const auto& subchild : subchildren) {
+ const auto subchild_path = cleanup_dir + "/" + subchild;
+ const auto status = db_->GetEnv()->DeleteFile(subchild_path);
+ if (!status.ok()) {
+ ROCKS_LOG_WARN(db_options.info_log, "Failed to cleanup file %s: %s",
+ subchild_path.c_str(), status.ToString().c_str());
+ }
+ }
+ const auto status = db_->GetEnv()->DeleteDir(cleanup_dir);
+ if (!status.ok()) {
+ ROCKS_LOG_WARN(db_options.info_log, "Failed to cleanup dir %s: %s",
+ cleanup_dir.c_str(), status.ToString().c_str());
+ }
+ }
+ return s;
+}
+
+Status CheckpointImpl::ExportFilesInMetaData(
+ const DBOptions& db_options, const ColumnFamilyMetaData& metadata,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname)>
+ copy_file_cb) {
+ Status s;
+ auto hardlink_file = true;
+
+ // Copy/hard link files in metadata.
+ size_t num_files = 0;
+ for (const auto& level_metadata : metadata.levels) {
+ for (const auto& file_metadata : level_metadata.files) {
+ uint64_t number;
+ FileType type;
+ const auto ok = ParseFileName(file_metadata.name, &number, &type);
+ if (!ok) {
+ s = Status::Corruption("Could not parse file name");
+ break;
+ }
+
+ // We should only get sst files here.
+ assert(type == kTableFile);
+ assert(file_metadata.size > 0 && file_metadata.name[0] == '/');
+ const auto src_fname = file_metadata.name;
+ ++num_files;
+
+ if (hardlink_file) {
+ s = link_file_cb(db_->GetName(), src_fname);
+ if (num_files == 1 && s.IsNotSupported()) {
+ // Fallback to copy if link failed due to cross-device directories.
+ hardlink_file = false;
+ s = Status::OK();
+ }
+ }
+ if (!hardlink_file) {
+ s = copy_file_cb(db_->GetName(), src_fname);
+ }
+ if (!s.ok()) {
+ break;
+ }
+ }
+ }
+ ROCKS_LOG_INFO(db_options.info_log, "Number of table files %" ROCKSDB_PRIszt,
+ num_files);
+
+ return s;
+}
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_impl.h b/src/rocksdb/utilities/checkpoint/checkpoint_impl.h
new file mode 100644
index 000000000..2947330cc
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_impl.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <string>
+
+#include "file/filename.h"
+#include "rocksdb/db.h"
+#include "rocksdb/utilities/checkpoint.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class CheckpointImpl : public Checkpoint {
+ public:
+ explicit CheckpointImpl(DB* db) : db_(db) {}
+
+ Status CreateCheckpoint(const std::string& checkpoint_dir,
+ uint64_t log_size_for_flush,
+ uint64_t* sequence_number_ptr) override;
+
+ Status ExportColumnFamily(ColumnFamilyHandle* handle,
+ const std::string& export_dir,
+ ExportImportFilesMetaData** metadata) override;
+
+ // Checkpoint logic can be customized by providing callbacks for link, copy,
+ // or create.
+ Status CreateCustomCheckpoint(
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname, FileType type)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname, uint64_t size_limit_bytes,
+ FileType type, const std::string& checksum_func_name,
+ const std::string& checksum_val,
+ const Temperature src_temperature)>
+ copy_file_cb,
+ std::function<Status(const std::string& fname,
+ const std::string& contents, FileType type)>
+ create_file_cb,
+ uint64_t* sequence_number, uint64_t log_size_for_flush,
+ bool get_live_table_checksum = false);
+
+ private:
+ void CleanStagingDirectory(const std::string& path, Logger* info_log);
+
+ // Export logic customization by providing callbacks for link or copy.
+ Status ExportFilesInMetaData(
+ const DBOptions& db_options, const ColumnFamilyMetaData& metadata,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname)>
+ copy_file_cb);
+
+ private:
+ DB* db_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_test.cc b/src/rocksdb/utilities/checkpoint/checkpoint_test.cc
new file mode 100644
index 000000000..3da753d5f
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_test.cc
@@ -0,0 +1,974 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// Syncpoint prevents us building and running tests in release
+#ifndef ROCKSDB_LITE
+#include "rocksdb/utilities/checkpoint.h"
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+#include <iostream>
+#include <thread>
+#include <utility>
+
+#include "db/db_impl/db_impl.h"
+#include "file/file_util.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "utilities/fault_injection_env.h"
+#include "utilities/fault_injection_fs.h"
+
+namespace ROCKSDB_NAMESPACE {
+class CheckpointTest : public testing::Test {
+ protected:
+ // Sequence of option configurations to try
+ enum OptionConfig {
+ kDefault = 0,
+ };
+ int option_config_;
+
+ public:
+ std::string dbname_;
+ std::string alternative_wal_dir_;
+ Env* env_;
+ DB* db_;
+ Options last_options_;
+ std::vector<ColumnFamilyHandle*> handles_;
+ std::string snapshot_name_;
+ std::string export_path_;
+ ColumnFamilyHandle* cfh_reverse_comp_;
+ ExportImportFilesMetaData* metadata_;
+
+ CheckpointTest() : env_(Env::Default()) {
+ env_->SetBackgroundThreads(1, Env::LOW);
+ env_->SetBackgroundThreads(1, Env::HIGH);
+ dbname_ = test::PerThreadDBPath(env_, "checkpoint_test");
+ alternative_wal_dir_ = dbname_ + "/wal";
+ auto options = CurrentOptions();
+ auto delete_options = options;
+ delete_options.wal_dir = alternative_wal_dir_;
+ EXPECT_OK(DestroyDB(dbname_, delete_options));
+ // Destroy it for not alternative WAL dir is used.
+ EXPECT_OK(DestroyDB(dbname_, options));
+ db_ = nullptr;
+ snapshot_name_ = test::PerThreadDBPath(env_, "snapshot");
+ std::string snapshot_tmp_name = snapshot_name_ + ".tmp";
+ EXPECT_OK(DestroyDB(snapshot_name_, options));
+ test::DeleteDir(env_, snapshot_name_);
+ EXPECT_OK(DestroyDB(snapshot_tmp_name, options));
+ test::DeleteDir(env_, snapshot_tmp_name);
+ Reopen(options);
+ export_path_ = test::PerThreadDBPath("/export");
+ DestroyDir(env_, export_path_).PermitUncheckedError();
+ cfh_reverse_comp_ = nullptr;
+ metadata_ = nullptr;
+ }
+
+ ~CheckpointTest() override {
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ if (cfh_reverse_comp_) {
+ EXPECT_OK(db_->DestroyColumnFamilyHandle(cfh_reverse_comp_));
+ cfh_reverse_comp_ = nullptr;
+ }
+ if (metadata_) {
+ delete metadata_;
+ metadata_ = nullptr;
+ }
+ Close();
+ Options options;
+ options.db_paths.emplace_back(dbname_, 0);
+ options.db_paths.emplace_back(dbname_ + "_2", 0);
+ options.db_paths.emplace_back(dbname_ + "_3", 0);
+ options.db_paths.emplace_back(dbname_ + "_4", 0);
+ EXPECT_OK(DestroyDB(dbname_, options));
+ EXPECT_OK(DestroyDB(snapshot_name_, options));
+ DestroyDir(env_, export_path_).PermitUncheckedError();
+ }
+
+ // Return the current option configuration.
+ Options CurrentOptions() {
+ Options options;
+ options.env = env_;
+ options.create_if_missing = true;
+ return options;
+ }
+
+ void CreateColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ ColumnFamilyOptions cf_opts(options);
+ size_t cfi = handles_.size();
+ handles_.resize(cfi + cfs.size());
+ for (auto cf : cfs) {
+ ASSERT_OK(db_->CreateColumnFamily(cf_opts, cf, &handles_[cfi++]));
+ }
+ }
+
+ void CreateAndReopenWithCF(const std::vector<std::string>& cfs,
+ const Options& options) {
+ CreateColumnFamilies(cfs, options);
+ std::vector<std::string> cfs_plus_default = cfs;
+ cfs_plus_default.insert(cfs_plus_default.begin(), kDefaultColumnFamilyName);
+ ReopenWithColumnFamilies(cfs_plus_default, options);
+ }
+
+ void ReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const std::vector<Options>& options) {
+ ASSERT_OK(TryReopenWithColumnFamilies(cfs, options));
+ }
+
+ void ReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ ASSERT_OK(TryReopenWithColumnFamilies(cfs, options));
+ }
+
+ Status TryReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const std::vector<Options>& options) {
+ Close();
+ EXPECT_EQ(cfs.size(), options.size());
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options[i]));
+ }
+ DBOptions db_opts = DBOptions(options[0]);
+ return DB::Open(db_opts, dbname_, column_families, &handles_, &db_);
+ }
+
+ Status TryReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ Close();
+ std::vector<Options> v_opts(cfs.size(), options);
+ return TryReopenWithColumnFamilies(cfs, v_opts);
+ }
+
+ void Reopen(const Options& options) { ASSERT_OK(TryReopen(options)); }
+
+ void CompactAll() {
+ for (auto h : handles_) {
+ ASSERT_OK(db_->CompactRange(CompactRangeOptions(), h, nullptr, nullptr));
+ }
+ }
+
+ void Close() {
+ for (auto h : handles_) {
+ delete h;
+ }
+ handles_.clear();
+ delete db_;
+ db_ = nullptr;
+ }
+
+ void DestroyAndReopen(const Options& options) {
+ // Destroy using last options
+ Destroy(last_options_);
+ ASSERT_OK(TryReopen(options));
+ }
+
+ void Destroy(const Options& options) {
+ Close();
+ ASSERT_OK(DestroyDB(dbname_, options));
+ }
+
+ Status ReadOnlyReopen(const Options& options) {
+ return DB::OpenForReadOnly(options, dbname_, &db_);
+ }
+
+ Status ReadOnlyReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (const auto& cf : cfs) {
+ column_families.emplace_back(cf, options);
+ }
+ return DB::OpenForReadOnly(options, dbname_, column_families, &handles_,
+ &db_);
+ }
+
+ Status TryReopen(const Options& options) {
+ Close();
+ last_options_ = options;
+ return DB::Open(options, dbname_, &db_);
+ }
+
+ Status Flush(int cf = 0) {
+ if (cf == 0) {
+ return db_->Flush(FlushOptions());
+ } else {
+ return db_->Flush(FlushOptions(), handles_[cf]);
+ }
+ }
+
+ Status Put(const Slice& k, const Slice& v, WriteOptions wo = WriteOptions()) {
+ return db_->Put(wo, k, v);
+ }
+
+ Status Put(int cf, const Slice& k, const Slice& v,
+ WriteOptions wo = WriteOptions()) {
+ return db_->Put(wo, handles_[cf], k, v);
+ }
+
+ Status Delete(const std::string& k) { return db_->Delete(WriteOptions(), k); }
+
+ Status Delete(int cf, const std::string& k) {
+ return db_->Delete(WriteOptions(), handles_[cf], k);
+ }
+
+ std::string Get(const std::string& k, const Snapshot* snapshot = nullptr) {
+ ReadOptions options;
+ options.verify_checksums = true;
+ options.snapshot = snapshot;
+ std::string result;
+ Status s = db_->Get(options, k, &result);
+ if (s.IsNotFound()) {
+ result = "NOT_FOUND";
+ } else if (!s.ok()) {
+ result = s.ToString();
+ }
+ return result;
+ }
+
+ std::string Get(int cf, const std::string& k,
+ const Snapshot* snapshot = nullptr) {
+ ReadOptions options;
+ options.verify_checksums = true;
+ options.snapshot = snapshot;
+ std::string result;
+ Status s = db_->Get(options, handles_[cf], k, &result);
+ if (s.IsNotFound()) {
+ result = "NOT_FOUND";
+ } else if (!s.ok()) {
+ result = s.ToString();
+ }
+ return result;
+ }
+};
+
+TEST_F(CheckpointTest, GetSnapshotLink) {
+ for (uint64_t log_size_for_flush : {0, 1000000}) {
+ Options options;
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ Checkpoint* checkpoint;
+
+ options = CurrentOptions();
+ delete db_;
+ db_ = nullptr;
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Create a database
+ options.create_if_missing = true;
+ ASSERT_OK(DB::Open(options, dbname_, &db_));
+ std::string key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+ // Take a snapshot
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_, log_size_for_flush));
+ ASSERT_OK(Put(key, "v2"));
+ ASSERT_EQ("v2", Get(key));
+ ASSERT_OK(Flush());
+ ASSERT_EQ("v2", Get(key));
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, key, &result));
+ ASSERT_EQ("v1", result);
+ delete snapshotDB;
+ snapshotDB = nullptr;
+ delete db_;
+ db_ = nullptr;
+
+ // Destroy original DB
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Open snapshot and verify contents
+ options.create_if_missing = false;
+ dbname_ = snapshot_name_;
+ ASSERT_OK(DB::Open(options, dbname_, &db_));
+ ASSERT_EQ("v1", Get(key));
+ delete db_;
+ db_ = nullptr;
+ ASSERT_OK(DestroyDB(dbname_, options));
+ delete checkpoint;
+
+ // Restore DB name
+ dbname_ = test::PerThreadDBPath(env_, "db_test");
+ }
+}
+
+TEST_F(CheckpointTest, CheckpointWithBlob) {
+ // Create a database with a blob file
+ Options options = CurrentOptions();
+ options.create_if_missing = true;
+ options.enable_blob_files = true;
+ options.min_blob_size = 0;
+
+ Reopen(options);
+
+ constexpr char key[] = "key";
+ constexpr char blob[] = "blob";
+
+ ASSERT_OK(Put(key, blob));
+ ASSERT_OK(Flush());
+
+ // Create a checkpoint
+ Checkpoint* checkpoint = nullptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ std::unique_ptr<Checkpoint> checkpoint_guard(checkpoint);
+
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+
+ // Make sure it contains the blob file
+ std::vector<std::string> files;
+ ASSERT_OK(env_->GetChildren(snapshot_name_, &files));
+
+ bool blob_file_found = false;
+ for (const auto& file : files) {
+ uint64_t number = 0;
+ FileType type = kWalFile;
+
+ if (ParseFileName(file, &number, &type) && type == kBlobFile) {
+ blob_file_found = true;
+ break;
+ }
+ }
+
+ ASSERT_TRUE(blob_file_found);
+
+ // Make sure the checkpoint can be opened and the blob value read
+ options.create_if_missing = false;
+ DB* checkpoint_db = nullptr;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &checkpoint_db));
+
+ std::unique_ptr<DB> checkpoint_db_guard(checkpoint_db);
+
+ PinnableSlice value;
+ ASSERT_OK(checkpoint_db->Get(
+ ReadOptions(), checkpoint_db->DefaultColumnFamily(), key, &value));
+
+ ASSERT_EQ(value, blob);
+}
+
+TEST_F(CheckpointTest, ExportColumnFamilyWithLinks) {
+ // Create a database
+ auto options = CurrentOptions();
+ options.create_if_missing = true;
+ CreateAndReopenWithCF({}, options);
+
+ // Helper to verify the number of files in metadata and export dir
+ auto verify_files_exported = [&](const ExportImportFilesMetaData& metadata,
+ int num_files_expected) {
+ ASSERT_EQ(metadata.files.size(), num_files_expected);
+ std::vector<std::string> subchildren;
+ ASSERT_OK(env_->GetChildren(export_path_, &subchildren));
+ ASSERT_EQ(subchildren.size(), num_files_expected);
+ };
+
+ // Test DefaultColumnFamily
+ {
+ const auto key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export the Tables and verify
+ ASSERT_OK(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_));
+ verify_files_exported(*metadata_, 1);
+ ASSERT_EQ(metadata_->db_comparator_name, options.comparator->Name());
+ ASSERT_OK(DestroyDir(env_, export_path_));
+ delete metadata_;
+ metadata_ = nullptr;
+
+ // Check again after compaction
+ CompactAll();
+ ASSERT_OK(Put(key, "v2"));
+ ASSERT_OK(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_));
+ verify_files_exported(*metadata_, 2);
+ ASSERT_EQ(metadata_->db_comparator_name, options.comparator->Name());
+ ASSERT_OK(DestroyDir(env_, export_path_));
+ delete metadata_;
+ metadata_ = nullptr;
+ delete checkpoint;
+ }
+
+ // Test non default column family with non default comparator
+ {
+ auto cf_options = CurrentOptions();
+ cf_options.comparator = ReverseBytewiseComparator();
+ ASSERT_OK(db_->CreateColumnFamily(cf_options, "yoyo", &cfh_reverse_comp_));
+
+ const auto key = std::string("foo");
+ ASSERT_OK(db_->Put(WriteOptions(), cfh_reverse_comp_, key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export the Tables and verify
+ ASSERT_OK(checkpoint->ExportColumnFamily(cfh_reverse_comp_, export_path_,
+ &metadata_));
+ verify_files_exported(*metadata_, 1);
+ ASSERT_EQ(metadata_->db_comparator_name,
+ ReverseBytewiseComparator()->Name());
+ delete checkpoint;
+ }
+}
+
+TEST_F(CheckpointTest, ExportColumnFamilyNegativeTest) {
+ // Create a database
+ auto options = CurrentOptions();
+ options.create_if_missing = true;
+ CreateAndReopenWithCF({}, options);
+
+ const auto key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export onto existing directory
+ ASSERT_OK(env_->CreateDirIfMissing(export_path_));
+ ASSERT_EQ(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_),
+ Status::InvalidArgument("Specified export_dir exists"));
+ ASSERT_OK(DestroyDir(env_, export_path_));
+
+ // Export with invalid directory specification
+ export_path_ = "";
+ ASSERT_EQ(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_),
+ Status::InvalidArgument("Specified export_dir invalid"));
+ delete checkpoint;
+}
+
+TEST_F(CheckpointTest, CheckpointCF) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"one", "two", "three", "four", "five"}, options);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointTest::CheckpointCF:2", "DBImpl::GetLiveFiles:2"},
+ {"DBImpl::GetLiveFiles:1", "CheckpointTest::CheckpointCF:1"}});
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(Put(0, "Default", "Default"));
+ ASSERT_OK(Put(1, "one", "one"));
+ ASSERT_OK(Put(2, "two", "two"));
+ ASSERT_OK(Put(3, "three", "three"));
+ ASSERT_OK(Put(4, "four", "four"));
+ ASSERT_OK(Put(5, "five", "five"));
+
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ std::vector<ColumnFamilyHandle*> cphandles;
+
+ // Take a snapshot
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT("CheckpointTest::CheckpointCF:1");
+ ASSERT_OK(Put(0, "Default", "Default1"));
+ ASSERT_OK(Put(1, "one", "eleven"));
+ ASSERT_OK(Put(2, "two", "twelve"));
+ ASSERT_OK(Put(3, "three", "thirteen"));
+ ASSERT_OK(Put(4, "four", "fourteen"));
+ ASSERT_OK(Put(5, "five", "fifteen"));
+ TEST_SYNC_POINT("CheckpointTest::CheckpointCF:2");
+ t.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ASSERT_OK(Put(1, "one", "twentyone"));
+ ASSERT_OK(Put(2, "two", "twentytwo"));
+ ASSERT_OK(Put(3, "three", "twentythree"));
+ ASSERT_OK(Put(4, "four", "twentyfour"));
+ ASSERT_OK(Put(5, "five", "twentyfive"));
+ ASSERT_OK(Flush());
+
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ std::vector<std::string> cfs;
+ cfs = {kDefaultColumnFamilyName, "one", "two", "three", "four", "five"};
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options));
+ }
+ ASSERT_OK(DB::Open(options, snapshot_name_, column_families, &cphandles,
+ &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[0], "Default", &result));
+ ASSERT_EQ("Default1", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[1], "one", &result));
+ ASSERT_EQ("eleven", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[2], "two", &result));
+ for (auto h : cphandles) {
+ delete h;
+ }
+ cphandles.clear();
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CheckpointCFNoFlush) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"one", "two", "three", "four", "five"}, options);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(Put(0, "Default", "Default"));
+ ASSERT_OK(Put(1, "one", "one"));
+ ASSERT_OK(Flush());
+ ASSERT_OK(Put(2, "two", "two"));
+
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ std::vector<ColumnFamilyHandle*> cphandles;
+
+ // Take a snapshot
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "DBImpl::BackgroundCallFlush:start", [&](void* /*arg*/) {
+ // Flush should never trigger.
+ FAIL();
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_, 1000000));
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ delete checkpoint;
+ ASSERT_OK(Put(1, "one", "two"));
+ ASSERT_OK(Flush(1));
+ ASSERT_OK(Put(2, "two", "twentytwo"));
+ Close();
+ EXPECT_OK(DestroyDB(dbname_, options));
+
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ std::vector<std::string> cfs;
+ cfs = {kDefaultColumnFamilyName, "one", "two", "three", "four", "five"};
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options));
+ }
+ ASSERT_OK(DB::Open(options, snapshot_name_, column_families, &cphandles,
+ &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[0], "Default", &result));
+ ASSERT_EQ("Default", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[1], "one", &result));
+ ASSERT_EQ("one", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[2], "two", &result));
+ ASSERT_EQ("two", result);
+ for (auto h : cphandles) {
+ delete h;
+ }
+ cphandles.clear();
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CurrentFileModifiedWhileCheckpointing) {
+ Options options = CurrentOptions();
+ options.max_manifest_file_size = 0; // always rollover manifest for file add
+ Reopen(options);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {// Get past the flush in the checkpoint thread before adding any keys to
+ // the db so the checkpoint thread won't hit the WriteManifest
+ // syncpoints.
+ {"CheckpointImpl::CreateCheckpoint:FlushDone",
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing:PrePut"},
+ // Roll the manifest during checkpointing right after live files are
+ // snapshotted.
+ {"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "VersionSet::LogAndApply:WriteManifest"},
+ {"VersionSet::LogAndApply:WriteManifestDone",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing:PrePut");
+ ASSERT_OK(Put("Default", "Default1"));
+ ASSERT_OK(Flush());
+ t.join();
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ DB* snapshotDB;
+ // Successful Open() implies that CURRENT pointed to the manifest in the
+ // checkpoint.
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshotDB));
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CurrentFileModifiedWhileCheckpointing2PC) {
+ Close();
+ const std::string dbname = test::PerThreadDBPath("transaction_testdb");
+ ASSERT_OK(DestroyDB(dbname, CurrentOptions()));
+ test::DeleteDir(env_, dbname);
+
+ Options options = CurrentOptions();
+ options.allow_2pc = true;
+ // allow_2pc is implicitly set with tx prepare
+ // options.allow_2pc = true;
+ TransactionDBOptions txn_db_options;
+ TransactionDB* txdb;
+ Status s = TransactionDB::Open(options, txn_db_options, dbname, &txdb);
+ ASSERT_OK(s);
+ ColumnFamilyHandle* cfa;
+ ColumnFamilyHandle* cfb;
+ ColumnFamilyOptions cf_options;
+ ASSERT_OK(txdb->CreateColumnFamily(cf_options, "CFA", &cfa));
+
+ WriteOptions write_options;
+ // Insert something into CFB so lots of log files will be kept
+ // before creating the checkpoint.
+ ASSERT_OK(txdb->CreateColumnFamily(cf_options, "CFB", &cfb));
+ ASSERT_OK(txdb->Put(write_options, cfb, "", ""));
+
+ ReadOptions read_options;
+ std::string value;
+ TransactionOptions txn_options;
+ Transaction* txn = txdb->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+ ASSERT_EQ(txdb->GetTransactionByName("xid"), txn);
+
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn->Put(cfa, Slice("foocfa"), Slice("barcfa"));
+ ASSERT_OK(s);
+ // Writing prepare into middle of first WAL, then flush WALs many times
+ for (int i = 1; i <= 100000; i++) {
+ Transaction* tx = txdb->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(tx->SetName("x"));
+ ASSERT_OK(tx->Put(Slice(std::to_string(i)), Slice("val")));
+ ASSERT_OK(tx->Put(cfa, Slice("aaa"), Slice("111")));
+ ASSERT_OK(tx->Prepare());
+ ASSERT_OK(tx->Commit());
+ if (i % 10000 == 0) {
+ ASSERT_OK(txdb->Flush(FlushOptions()));
+ }
+ if (i == 88888) {
+ ASSERT_OK(txn->Prepare());
+ }
+ delete tx;
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PreCommit"},
+ {"CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PostCommit",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(txdb, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PreCommit");
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PostCommit");
+ t.join();
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ // No more than two logs files should exist.
+ std::vector<std::string> files;
+ ASSERT_OK(env_->GetChildren(snapshot_name_, &files));
+ int num_log_files = 0;
+ for (auto& file : files) {
+ uint64_t num;
+ FileType type;
+ WalFileType log_type;
+ if (ParseFileName(file, &num, &type, &log_type) && type == kWalFile) {
+ num_log_files++;
+ }
+ }
+ // One flush after preapare + one outstanding file before checkpoint + one log
+ // file generated after checkpoint.
+ ASSERT_LE(num_log_files, 3);
+
+ TransactionDB* snapshotDB;
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+ std::vector<ROCKSDB_NAMESPACE::ColumnFamilyHandle*> cf_handles;
+ ASSERT_OK(TransactionDB::Open(options, txn_db_options, snapshot_name_,
+ column_families, &cf_handles, &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(snapshotDB->Get(read_options, cf_handles[1], "foocfa", &value));
+ ASSERT_EQ(value, "barcfa");
+
+ delete cfa;
+ delete cfb;
+ delete cf_handles[0];
+ delete cf_handles[1];
+ delete cf_handles[2];
+ delete snapshotDB;
+ snapshotDB = nullptr;
+ delete txdb;
+}
+
+TEST_F(CheckpointTest, CheckpointInvalidDirectoryName) {
+ for (std::string checkpoint_dir : {"", "/", "////"}) {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_TRUE(
+ checkpoint->CreateCheckpoint(checkpoint_dir).IsInvalidArgument());
+ delete checkpoint;
+ }
+}
+
+TEST_F(CheckpointTest, CheckpointWithParallelWrites) {
+ // When run with TSAN, this exposes the data race fixed in
+ // https://github.com/facebook/rocksdb/pull/3603
+ ASSERT_OK(Put("key1", "val1"));
+ port::Thread thread([this]() { ASSERT_OK(Put("key2", "val2")); });
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ thread.join();
+}
+
+TEST_F(CheckpointTest, CheckpointWithUnsyncedDataDropped) {
+ Options options = CurrentOptions();
+ std::unique_ptr<FaultInjectionTestEnv> env(new FaultInjectionTestEnv(env_));
+ options.env = env.get();
+ Reopen(options);
+ ASSERT_OK(Put("key1", "val1"));
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ ASSERT_OK(env->DropUnsyncedFileData());
+
+ // make sure it's openable even though whatever data that wasn't synced got
+ // dropped.
+ options.env = env_;
+ DB* snapshot_db;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshot_db));
+ ReadOptions read_opts;
+ std::string get_result;
+ ASSERT_OK(snapshot_db->Get(read_opts, "key1", &get_result));
+ ASSERT_EQ("val1", get_result);
+ delete snapshot_db;
+ delete db_;
+ db_ = nullptr;
+}
+
+TEST_F(CheckpointTest, CheckpointOptionsFileFailedToPersist) {
+ // Regression test for a bug where checkpoint failed on a DB where persisting
+ // OPTIONS file failed and the DB was opened with
+ // `fail_if_options_file_error == false`.
+ Options options = CurrentOptions();
+ options.fail_if_options_file_error = false;
+ auto fault_fs = std::make_shared<FaultInjectionTestFS>(FileSystem::Default());
+
+ // Setup `FaultInjectionTestFS` and `SyncPoint` callbacks to fail one
+ // operation when inside the OPTIONS file persisting code.
+ std::unique_ptr<Env> fault_fs_env(NewCompositeEnv(fault_fs));
+ fault_fs->SetRandomMetadataWriteError(1 /* one_in */);
+ SyncPoint::GetInstance()->SetCallBack(
+ "PersistRocksDBOptions:start", [fault_fs](void* /* arg */) {
+ fault_fs->EnableMetadataWriteErrorInjection();
+ });
+ SyncPoint::GetInstance()->SetCallBack(
+ "FaultInjectionTestFS::InjectMetadataWriteError:Injected",
+ [fault_fs](void* /* arg */) {
+ fault_fs->DisableMetadataWriteErrorInjection();
+ });
+ options.env = fault_fs_env.get();
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ Reopen(options);
+ ASSERT_OK(Put("key1", "val1"));
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+
+ // Make sure it's usable.
+ options.env = env_;
+ DB* snapshot_db;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshot_db));
+ ReadOptions read_opts;
+ std::string get_result;
+ ASSERT_OK(snapshot_db->Get(read_opts, "key1", &get_result));
+ ASSERT_EQ("val1", get_result);
+ delete snapshot_db;
+ delete db_;
+ db_ = nullptr;
+}
+
+TEST_F(CheckpointTest, CheckpointReadOnlyDB) {
+ ASSERT_OK(Put("foo", "foo_value"));
+ ASSERT_OK(Flush());
+ Close();
+ Options options = CurrentOptions();
+ ASSERT_OK(ReadOnlyReopen(options));
+ Checkpoint* checkpoint = nullptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ checkpoint = nullptr;
+ Close();
+ DB* snapshot_db = nullptr;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshot_db));
+ ReadOptions read_opts;
+ std::string get_result;
+ ASSERT_OK(snapshot_db->Get(read_opts, "foo", &get_result));
+ ASSERT_EQ("foo_value", get_result);
+ delete snapshot_db;
+}
+
+TEST_F(CheckpointTest, CheckpointReadOnlyDBWithMultipleColumnFamilies) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"pikachu", "eevee"}, options);
+ for (int i = 0; i != 3; ++i) {
+ ASSERT_OK(Put(i, "foo", "foo_value"));
+ ASSERT_OK(Flush(i));
+ }
+ Close();
+ Status s = ReadOnlyReopenWithColumnFamilies(
+ {kDefaultColumnFamilyName, "pikachu", "eevee"}, options);
+ ASSERT_OK(s);
+ Checkpoint* checkpoint = nullptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ checkpoint = nullptr;
+ Close();
+
+ std::vector<ColumnFamilyDescriptor> column_families{
+ {kDefaultColumnFamilyName, options},
+ {"pikachu", options},
+ {"eevee", options}};
+ DB* snapshot_db = nullptr;
+ std::vector<ColumnFamilyHandle*> snapshot_handles;
+ s = DB::Open(options, snapshot_name_, column_families, &snapshot_handles,
+ &snapshot_db);
+ ASSERT_OK(s);
+ ReadOptions read_opts;
+ for (int i = 0; i != 3; ++i) {
+ std::string get_result;
+ s = snapshot_db->Get(read_opts, snapshot_handles[i], "foo", &get_result);
+ ASSERT_OK(s);
+ ASSERT_EQ("foo_value", get_result);
+ }
+
+ for (auto snapshot_h : snapshot_handles) {
+ delete snapshot_h;
+ }
+ snapshot_handles.clear();
+ delete snapshot_db;
+}
+
+TEST_F(CheckpointTest, CheckpointWithDbPath) {
+ Options options = CurrentOptions();
+ options.db_paths.emplace_back(dbname_ + "_2", 0);
+ Reopen(options);
+ ASSERT_OK(Put("key1", "val1"));
+ Flush();
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ // Currently not supported
+ ASSERT_TRUE(checkpoint->CreateCheckpoint(snapshot_name_).IsNotSupported());
+ delete checkpoint;
+}
+
+TEST_F(CheckpointTest, PutRaceWithCheckpointTrackedWalSync) {
+ // Repro for a race condition where a user write comes in after the checkpoint
+ // syncs WAL for `track_and_verify_wals_in_manifest` but before the
+ // corresponding MANIFEST update. With the bug, that scenario resulted in an
+ // unopenable DB with error "Corruption: Size mismatch: WAL ...".
+ Options options = CurrentOptions();
+ std::unique_ptr<FaultInjectionTestEnv> fault_env(
+ new FaultInjectionTestEnv(env_));
+ options.env = fault_env.get();
+ options.track_and_verify_wals_in_manifest = true;
+ Reopen(options);
+
+ ASSERT_OK(Put("key1", "val1"));
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "DBImpl::SyncWAL:BeforeMarkLogsSynced:1",
+ [this](void* /* arg */) { ASSERT_OK(Put("key2", "val2")); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ std::unique_ptr<Checkpoint> checkpoint;
+ {
+ Checkpoint* checkpoint_ptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint_ptr));
+ checkpoint.reset(checkpoint_ptr);
+ }
+
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+
+ // Ensure callback ran.
+ ASSERT_EQ("val2", Get("key2"));
+
+ Close();
+
+ // Simulate full loss of unsynced data. This drops "key2" -> "val2" from the
+ // DB WAL.
+ fault_env->DropUnsyncedFileData();
+
+ // Before the bug fix, reopening the DB would fail because the MANIFEST's
+ // AddWal entry indicated the WAL should be synced through "key2" -> "val2".
+ Reopen(options);
+
+ // Need to close before `fault_env` goes out of scope.
+ Close();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as Checkpoint is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/compaction_filters.cc b/src/rocksdb/utilities/compaction_filters.cc
new file mode 100644
index 000000000..8763901c3
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters.cc
@@ -0,0 +1,56 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include <memory>
+
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/customizable_util.h"
+#include "rocksdb/utilities/options_type.h"
+#include "utilities/compaction_filters/layered_compaction_filter_base.h"
+#include "utilities/compaction_filters/remove_emptyvalue_compactionfilter.h"
+
+namespace ROCKSDB_NAMESPACE {
+#ifndef ROCKSDB_LITE
+static int RegisterBuiltinCompactionFilters(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<CompactionFilter>(
+ RemoveEmptyValueCompactionFilter::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilter>* /*guard*/,
+ std::string* /*errmsg*/) {
+ return new RemoveEmptyValueCompactionFilter();
+ });
+ return 1;
+}
+#endif // ROCKSDB_LITE
+Status CompactionFilter::CreateFromString(const ConfigOptions& config_options,
+ const std::string& value,
+ const CompactionFilter** result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterBuiltinCompactionFilters(*(ObjectLibrary::Default().get()), "");
+ });
+#endif // ROCKSDB_LITE
+ CompactionFilter* filter = const_cast<CompactionFilter*>(*result);
+ Status status = LoadStaticObject<CompactionFilter>(config_options, value,
+ nullptr, &filter);
+ if (status.ok()) {
+ *result = const_cast<CompactionFilter*>(filter);
+ }
+ return status;
+}
+
+Status CompactionFilterFactory::CreateFromString(
+ const ConfigOptions& config_options, const std::string& value,
+ std::shared_ptr<CompactionFilterFactory>* result) {
+ // Currently there are no builtin CompactionFilterFactories.
+ // If any are introduced, they need to be registered here.
+ Status status = LoadSharedObject<CompactionFilterFactory>(
+ config_options, value, nullptr, result);
+ return status;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/compaction_filters/layered_compaction_filter_base.h b/src/rocksdb/utilities/compaction_filters/layered_compaction_filter_base.h
new file mode 100644
index 000000000..803fa94ae
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters/layered_compaction_filter_base.h
@@ -0,0 +1,41 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#include <memory>
+
+#include "rocksdb/compaction_filter.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Abstract base class for building layered compaction filter on top of
+// user compaction filter.
+// See BlobIndexCompactionFilter or TtlCompactionFilter for a basic usage.
+class LayeredCompactionFilterBase : public CompactionFilter {
+ public:
+ LayeredCompactionFilterBase(
+ const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter> _user_comp_filter_from_factory)
+ : user_comp_filter_(_user_comp_filter),
+ user_comp_filter_from_factory_(
+ std::move(_user_comp_filter_from_factory)) {
+ if (!user_comp_filter_) {
+ user_comp_filter_ = user_comp_filter_from_factory_.get();
+ }
+ }
+
+ // Return a pointer to user compaction filter
+ const CompactionFilter* user_comp_filter() const { return user_comp_filter_; }
+
+ const Customizable* Inner() const override { return user_comp_filter_; }
+
+ protected:
+ const CompactionFilter* user_comp_filter_;
+
+ private:
+ std::unique_ptr<const CompactionFilter> user_comp_filter_from_factory_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc
new file mode 100644
index 000000000..b788dbf9b
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/compaction_filters/remove_emptyvalue_compactionfilter.h"
+
+#include <string>
+
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+bool RemoveEmptyValueCompactionFilter::Filter(int /*level*/,
+ const Slice& /*key*/,
+ const Slice& existing_value,
+ std::string* /*new_value*/,
+ bool* /*value_changed*/) const {
+ // remove kv pairs that have empty values
+ return existing_value.empty();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h
new file mode 100644
index 000000000..864ad15ff
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#pragma once
+
+#include <string>
+
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class RemoveEmptyValueCompactionFilter : public CompactionFilter {
+ public:
+ static const char* kClassName() { return "RemoveEmptyValueCompactionFilter"; }
+
+ const char* Name() const override { return kClassName(); }
+
+ bool Filter(int level, const Slice& key, const Slice& existing_value,
+ std::string* new_value, bool* value_changed) const override;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/convenience/info_log_finder.cc b/src/rocksdb/utilities/convenience/info_log_finder.cc
new file mode 100644
index 000000000..fe62fd561
--- /dev/null
+++ b/src/rocksdb/utilities/convenience/info_log_finder.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 Facebook.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "rocksdb/utilities/info_log_finder.h"
+
+#include "file/filename.h"
+#include "rocksdb/env.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status GetInfoLogList(DB* db, std::vector<std::string>* info_log_list) {
+ if (!db) {
+ return Status::InvalidArgument("DB pointer is not valid");
+ }
+ std::string parent_path;
+ const Options& options = db->GetOptions();
+ return GetInfoLogFiles(options.env->GetFileSystem(), options.db_log_dir,
+ db->GetName(), &parent_path, info_log_list);
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/counted_fs.cc b/src/rocksdb/utilities/counted_fs.cc
new file mode 100644
index 000000000..e43f3a191
--- /dev/null
+++ b/src/rocksdb/utilities/counted_fs.cc
@@ -0,0 +1,379 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/counted_fs.h"
+
+#include <sstream>
+
+#include "rocksdb/file_system.h"
+#include "rocksdb/utilities/options_type.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+class CountedSequentialFile : public FSSequentialFileOwnerWrapper {
+ private:
+ CountedFileSystem* fs_;
+
+ public:
+ CountedSequentialFile(std::unique_ptr<FSSequentialFile>&& f,
+ CountedFileSystem* fs)
+ : FSSequentialFileOwnerWrapper(std::move(f)), fs_(fs) {}
+
+ ~CountedSequentialFile() override { fs_->counters()->closes++; }
+
+ IOStatus Read(size_t n, const IOOptions& options, Slice* result,
+ char* scratch, IODebugContext* dbg) override {
+ IOStatus rv = target()->Read(n, options, result, scratch, dbg);
+ fs_->counters()->reads.RecordOp(rv, result->size());
+ return rv;
+ }
+
+ IOStatus PositionedRead(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) override {
+ IOStatus rv =
+ target()->PositionedRead(offset, n, options, result, scratch, dbg);
+ fs_->counters()->reads.RecordOp(rv, result->size());
+ return rv;
+ }
+};
+
+class CountedRandomAccessFile : public FSRandomAccessFileOwnerWrapper {
+ private:
+ CountedFileSystem* fs_;
+
+ public:
+ CountedRandomAccessFile(std::unique_ptr<FSRandomAccessFile>&& f,
+ CountedFileSystem* fs)
+ : FSRandomAccessFileOwnerWrapper(std::move(f)), fs_(fs) {}
+
+ ~CountedRandomAccessFile() override { fs_->counters()->closes++; }
+
+ IOStatus Read(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) const override {
+ IOStatus rv = target()->Read(offset, n, options, result, scratch, dbg);
+ fs_->counters()->reads.RecordOp(rv, result->size());
+ return rv;
+ }
+
+ IOStatus MultiRead(FSReadRequest* reqs, size_t num_reqs,
+ const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->MultiRead(reqs, num_reqs, options, dbg);
+ for (size_t r = 0; r < num_reqs; r++) {
+ fs_->counters()->reads.RecordOp(reqs[r].status, reqs[r].result.size());
+ }
+ return rv;
+ }
+};
+
+class CountedWritableFile : public FSWritableFileOwnerWrapper {
+ private:
+ CountedFileSystem* fs_;
+
+ public:
+ CountedWritableFile(std::unique_ptr<FSWritableFile>&& f,
+ CountedFileSystem* fs)
+ : FSWritableFileOwnerWrapper(std::move(f)), fs_(fs) {}
+
+ IOStatus Append(const Slice& data, const IOOptions& options,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->Append(data, options, dbg);
+ fs_->counters()->writes.RecordOp(rv, data.size());
+ return rv;
+ }
+
+ IOStatus Append(const Slice& data, const IOOptions& options,
+ const DataVerificationInfo& info,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->Append(data, options, info, dbg);
+ fs_->counters()->writes.RecordOp(rv, data.size());
+ return rv;
+ }
+
+ IOStatus PositionedAppend(const Slice& data, uint64_t offset,
+ const IOOptions& options,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->PositionedAppend(data, offset, options, dbg);
+ fs_->counters()->writes.RecordOp(rv, data.size());
+ return rv;
+ }
+
+ IOStatus PositionedAppend(const Slice& data, uint64_t offset,
+ const IOOptions& options,
+ const DataVerificationInfo& info,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->PositionedAppend(data, offset, options, info, dbg);
+ fs_->counters()->writes.RecordOp(rv, data.size());
+ return rv;
+ }
+
+ IOStatus Close(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Close(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->closes++;
+ }
+ return rv;
+ }
+
+ IOStatus Flush(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Flush(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->flushes++;
+ }
+ return rv;
+ }
+
+ IOStatus Sync(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Sync(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->syncs++;
+ }
+ return rv;
+ }
+
+ IOStatus Fsync(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Fsync(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->fsyncs++;
+ }
+ return rv;
+ }
+
+ IOStatus RangeSync(uint64_t offset, uint64_t nbytes, const IOOptions& options,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->RangeSync(offset, nbytes, options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->syncs++;
+ }
+ return rv;
+ }
+};
+
+class CountedRandomRWFile : public FSRandomRWFileOwnerWrapper {
+ private:
+ mutable CountedFileSystem* fs_;
+
+ public:
+ CountedRandomRWFile(std::unique_ptr<FSRandomRWFile>&& f,
+ CountedFileSystem* fs)
+ : FSRandomRWFileOwnerWrapper(std::move(f)), fs_(fs) {}
+ IOStatus Write(uint64_t offset, const Slice& data, const IOOptions& options,
+ IODebugContext* dbg) override {
+ IOStatus rv = target()->Write(offset, data, options, dbg);
+ fs_->counters()->writes.RecordOp(rv, data.size());
+ return rv;
+ }
+
+ IOStatus Read(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) const override {
+ IOStatus rv = target()->Read(offset, n, options, result, scratch, dbg);
+ fs_->counters()->reads.RecordOp(rv, result->size());
+ return rv;
+ }
+
+ IOStatus Flush(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Flush(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->flushes++;
+ }
+ return rv;
+ }
+
+ IOStatus Sync(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Sync(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->syncs++;
+ }
+ return rv;
+ }
+
+ IOStatus Fsync(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Fsync(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->fsyncs++;
+ }
+ return rv;
+ }
+
+ IOStatus Close(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = target()->Close(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->closes++;
+ }
+ return rv;
+ }
+};
+
+class CountedDirectory : public FSDirectoryWrapper {
+ private:
+ mutable CountedFileSystem* fs_;
+ bool closed_ = false;
+
+ public:
+ CountedDirectory(std::unique_ptr<FSDirectory>&& f, CountedFileSystem* fs)
+ : FSDirectoryWrapper(std::move(f)), fs_(fs) {}
+
+ IOStatus Fsync(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = FSDirectoryWrapper::Fsync(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->dsyncs++;
+ }
+ return rv;
+ }
+
+ IOStatus Close(const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus rv = FSDirectoryWrapper::Close(options, dbg);
+ if (rv.ok()) {
+ fs_->counters()->closes++;
+ fs_->counters()->dir_closes++;
+ closed_ = true;
+ }
+ return rv;
+ }
+
+ IOStatus FsyncWithDirOptions(const IOOptions& options, IODebugContext* dbg,
+ const DirFsyncOptions& dir_options) override {
+ IOStatus rv =
+ FSDirectoryWrapper::FsyncWithDirOptions(options, dbg, dir_options);
+ if (rv.ok()) {
+ fs_->counters()->dsyncs++;
+ }
+ return rv;
+ }
+
+ ~CountedDirectory() {
+ if (!closed_) {
+ // TODO: fix DB+CF code to use explicit Close, not rely on destructor
+ fs_->counters()->closes++;
+ fs_->counters()->dir_closes++;
+ }
+ }
+};
+} // anonymous namespace
+
+std::string FileOpCounters::PrintCounters() const {
+ std::stringstream ss;
+ ss << "Num files opened: " << opens.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num files deleted: " << deletes.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num files renamed: " << renames.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num Flush(): " << flushes.load(std::memory_order_relaxed) << std::endl;
+ ss << "Num Sync(): " << syncs.load(std::memory_order_relaxed) << std::endl;
+ ss << "Num Fsync(): " << fsyncs.load(std::memory_order_relaxed) << std::endl;
+ ss << "Num Dir Fsync(): " << dsyncs.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num Close(): " << closes.load(std::memory_order_relaxed) << std::endl;
+ ss << "Num Dir Open(): " << dir_opens.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num Dir Close(): " << dir_closes.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num Read(): " << reads.ops.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num Append(): " << writes.ops.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num bytes read: " << reads.bytes.load(std::memory_order_relaxed)
+ << std::endl;
+ ss << "Num bytes written: " << writes.bytes.load(std::memory_order_relaxed)
+ << std::endl;
+ return ss.str();
+}
+
+CountedFileSystem::CountedFileSystem(const std::shared_ptr<FileSystem>& base)
+ : FileSystemWrapper(base) {}
+
+IOStatus CountedFileSystem::NewSequentialFile(
+ const std::string& f, const FileOptions& options,
+ std::unique_ptr<FSSequentialFile>* r, IODebugContext* dbg) {
+ std::unique_ptr<FSSequentialFile> base;
+ IOStatus s = target()->NewSequentialFile(f, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ r->reset(new CountedSequentialFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::NewRandomAccessFile(
+ const std::string& f, const FileOptions& options,
+ std::unique_ptr<FSRandomAccessFile>* r, IODebugContext* dbg) {
+ std::unique_ptr<FSRandomAccessFile> base;
+ IOStatus s = target()->NewRandomAccessFile(f, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ r->reset(new CountedRandomAccessFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::NewWritableFile(const std::string& f,
+ const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* r,
+ IODebugContext* dbg) {
+ std::unique_ptr<FSWritableFile> base;
+ IOStatus s = target()->NewWritableFile(f, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ r->reset(new CountedWritableFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::ReopenWritableFile(
+ const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* result, IODebugContext* dbg) {
+ std::unique_ptr<FSWritableFile> base;
+ IOStatus s = target()->ReopenWritableFile(fname, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ result->reset(new CountedWritableFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::ReuseWritableFile(
+ const std::string& fname, const std::string& old_fname,
+ const FileOptions& options, std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) {
+ std::unique_ptr<FSWritableFile> base;
+ IOStatus s =
+ target()->ReuseWritableFile(fname, old_fname, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ result->reset(new CountedWritableFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::NewRandomRWFile(
+ const std::string& name, const FileOptions& options,
+ std::unique_ptr<FSRandomRWFile>* result, IODebugContext* dbg) {
+ std::unique_ptr<FSRandomRWFile> base;
+ IOStatus s = target()->NewRandomRWFile(name, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ result->reset(new CountedRandomRWFile(std::move(base), this));
+ }
+ return s;
+}
+
+IOStatus CountedFileSystem::NewDirectory(const std::string& name,
+ const IOOptions& options,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) {
+ std::unique_ptr<FSDirectory> base;
+ IOStatus s = target()->NewDirectory(name, options, &base, dbg);
+ if (s.ok()) {
+ counters_.opens++;
+ counters_.dir_opens++;
+ result->reset(new CountedDirectory(std::move(base), this));
+ }
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/counted_fs.h b/src/rocksdb/utilities/counted_fs.h
new file mode 100644
index 000000000..cb8a8968f
--- /dev/null
+++ b/src/rocksdb/utilities/counted_fs.h
@@ -0,0 +1,158 @@
+// Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+
+#include "rocksdb/file_system.h"
+#include "rocksdb/io_status.h"
+#include "rocksdb/rocksdb_namespace.h"
+
+namespace ROCKSDB_NAMESPACE {
+class Logger;
+
+struct OpCounter {
+ std::atomic<int> ops;
+ std::atomic<uint64_t> bytes;
+
+ OpCounter() : ops(0), bytes(0) {}
+
+ void Reset() {
+ ops = 0;
+ bytes = 0;
+ }
+ void RecordOp(const IOStatus& io_s, size_t added_bytes) {
+ if (!io_s.IsNotSupported()) {
+ ops.fetch_add(1, std::memory_order_relaxed);
+ }
+ if (io_s.ok()) {
+ bytes.fetch_add(added_bytes, std::memory_order_relaxed);
+ }
+ }
+};
+
+struct FileOpCounters {
+ static const char* kName() { return "FileOpCounters"; }
+
+ std::atomic<int> opens;
+ std::atomic<int> closes;
+ std::atomic<int> deletes;
+ std::atomic<int> renames;
+ std::atomic<int> flushes;
+ std::atomic<int> syncs;
+ std::atomic<int> dsyncs;
+ std::atomic<int> fsyncs;
+ std::atomic<int> dir_opens;
+ std::atomic<int> dir_closes;
+ OpCounter reads;
+ OpCounter writes;
+
+ FileOpCounters()
+ : opens(0),
+ closes(0),
+ deletes(0),
+ renames(0),
+ flushes(0),
+ syncs(0),
+ dsyncs(0),
+ fsyncs(0),
+ dir_opens(0),
+ dir_closes(0) {}
+
+ void Reset() {
+ opens = 0;
+ closes = 0;
+ deletes = 0;
+ renames = 0;
+ flushes = 0;
+ syncs = 0;
+ dsyncs = 0;
+ fsyncs = 0;
+ dir_opens = 0;
+ dir_closes = 0;
+ reads.Reset();
+ writes.Reset();
+ }
+ std::string PrintCounters() const;
+};
+
+// A FileSystem class that counts operations (reads, writes, opens, closes, etc)
+class CountedFileSystem : public FileSystemWrapper {
+ public:
+ private:
+ FileOpCounters counters_;
+
+ public:
+ explicit CountedFileSystem(const std::shared_ptr<FileSystem>& base);
+ static const char* kClassName() { return "CountedFileSystem"; }
+ const char* Name() const override { return kClassName(); }
+
+ IOStatus NewSequentialFile(const std::string& f, const FileOptions& options,
+ std::unique_ptr<FSSequentialFile>* r,
+ IODebugContext* dbg) override;
+
+ IOStatus NewRandomAccessFile(const std::string& f,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSRandomAccessFile>* r,
+ IODebugContext* dbg) override;
+
+ IOStatus NewWritableFile(const std::string& f, const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* r,
+ IODebugContext* dbg) override;
+ IOStatus ReopenWritableFile(const std::string& fname,
+ const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus ReuseWritableFile(const std::string& fname,
+ const std::string& old_fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+ IOStatus NewRandomRWFile(const std::string& name, const FileOptions& options,
+ std::unique_ptr<FSRandomRWFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewDirectory(const std::string& name, const IOOptions& io_opts,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus DeleteFile(const std::string& fname, const IOOptions& options,
+ IODebugContext* dbg) override {
+ IOStatus s = target()->DeleteFile(fname, options, dbg);
+ if (s.ok()) {
+ counters_.deletes++;
+ }
+ return s;
+ }
+
+ IOStatus RenameFile(const std::string& s, const std::string& t,
+ const IOOptions& options, IODebugContext* dbg) override {
+ IOStatus st = target()->RenameFile(s, t, options, dbg);
+ if (st.ok()) {
+ counters_.renames++;
+ }
+ return st;
+ }
+
+ const FileOpCounters* counters() const { return &counters_; }
+
+ FileOpCounters* counters() { return &counters_; }
+
+ const void* GetOptionsPtr(const std::string& name) const override {
+ if (name == FileOpCounters::kName()) {
+ return counters();
+ } else {
+ return FileSystemWrapper::GetOptionsPtr(name);
+ }
+ }
+
+ // Prints the counters to a string
+ std::string PrintCounters() const { return counters_.PrintCounters(); }
+ void ResetCounters() { counters_.Reset(); }
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/debug.cc b/src/rocksdb/utilities/debug.cc
new file mode 100644
index 000000000..f2c3bb513
--- /dev/null
+++ b/src/rocksdb/utilities/debug.cc
@@ -0,0 +1,120 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/debug.h"
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/utilities/options_type.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+static std::unordered_map<std::string, ValueType> value_type_string_map = {
+ {"TypeDeletion", ValueType::kTypeDeletion},
+ {"TypeValue", ValueType::kTypeValue},
+ {"TypeMerge", ValueType::kTypeMerge},
+ {"TypeLogData", ValueType::kTypeLogData},
+ {"TypeColumnFamilyDeletion", ValueType::kTypeColumnFamilyDeletion},
+ {"TypeColumnFamilyValue", ValueType::kTypeColumnFamilyValue},
+ {"TypeColumnFamilyMerge", ValueType::kTypeColumnFamilyMerge},
+ {"TypeSingleDeletion", ValueType::kTypeSingleDeletion},
+ {"TypeColumnFamilySingleDeletion",
+ ValueType::kTypeColumnFamilySingleDeletion},
+ {"TypeBeginPrepareXID", ValueType::kTypeBeginPrepareXID},
+ {"TypeEndPrepareXID", ValueType::kTypeEndPrepareXID},
+ {"TypeCommitXID", ValueType::kTypeCommitXID},
+ {"TypeRollbackXID", ValueType::kTypeRollbackXID},
+ {"TypeNoop", ValueType::kTypeNoop},
+ {"TypeColumnFamilyRangeDeletion",
+ ValueType::kTypeColumnFamilyRangeDeletion},
+ {"TypeRangeDeletion", ValueType::kTypeRangeDeletion},
+ {"TypeColumnFamilyBlobIndex", ValueType::kTypeColumnFamilyBlobIndex},
+ {"TypeBlobIndex", ValueType::kTypeBlobIndex},
+ {"TypeBeginPersistedPrepareXID", ValueType::kTypeBeginPersistedPrepareXID},
+ {"TypeBeginUnprepareXID", ValueType::kTypeBeginUnprepareXID},
+ {"TypeDeletionWithTimestamp", ValueType::kTypeDeletionWithTimestamp},
+ {"TypeCommitXIDAndTimestamp", ValueType::kTypeCommitXIDAndTimestamp},
+ {"TypeWideColumnEntity", ValueType::kTypeWideColumnEntity},
+ {"TypeColumnFamilyWideColumnEntity",
+ ValueType::kTypeColumnFamilyWideColumnEntity}};
+
+std::string KeyVersion::GetTypeName() const {
+ std::string type_name;
+ if (SerializeEnum<ValueType>(value_type_string_map,
+ static_cast<ValueType>(type), &type_name)) {
+ return type_name;
+ } else {
+ return "Invalid";
+ }
+}
+
+Status GetAllKeyVersions(DB* db, Slice begin_key, Slice end_key,
+ size_t max_num_ikeys,
+ std::vector<KeyVersion>* key_versions) {
+ if (nullptr == db) {
+ return Status::InvalidArgument("db cannot be null.");
+ }
+ return GetAllKeyVersions(db, db->DefaultColumnFamily(), begin_key, end_key,
+ max_num_ikeys, key_versions);
+}
+
+Status GetAllKeyVersions(DB* db, ColumnFamilyHandle* cfh, Slice begin_key,
+ Slice end_key, size_t max_num_ikeys,
+ std::vector<KeyVersion>* key_versions) {
+ if (nullptr == db) {
+ return Status::InvalidArgument("db cannot be null.");
+ }
+ if (nullptr == cfh) {
+ return Status::InvalidArgument("Column family handle cannot be null.");
+ }
+ if (nullptr == key_versions) {
+ return Status::InvalidArgument("key_versions cannot be null.");
+ }
+ key_versions->clear();
+
+ DBImpl* idb = static_cast<DBImpl*>(db->GetRootDB());
+ auto icmp = InternalKeyComparator(idb->GetOptions(cfh).comparator);
+ ReadOptions read_options;
+ Arena arena;
+ ScopedArenaIterator iter(
+ idb->NewInternalIterator(read_options, &arena, kMaxSequenceNumber, cfh));
+
+ if (!begin_key.empty()) {
+ InternalKey ikey;
+ ikey.SetMinPossibleForUserKey(begin_key);
+ iter->Seek(ikey.Encode());
+ } else {
+ iter->SeekToFirst();
+ }
+
+ size_t num_keys = 0;
+ for (; iter->Valid(); iter->Next()) {
+ ParsedInternalKey ikey;
+ Status pik_status =
+ ParseInternalKey(iter->key(), &ikey, true /* log_err_key */); // TODO
+ if (!pik_status.ok()) {
+ return pik_status;
+ }
+
+ if (!end_key.empty() &&
+ icmp.user_comparator()->Compare(ikey.user_key, end_key) > 0) {
+ break;
+ }
+
+ key_versions->emplace_back(ikey.user_key.ToString() /* _user_key */,
+ iter->value().ToString() /* _value */,
+ ikey.sequence /* _sequence */,
+ static_cast<int>(ikey.type) /* _type */);
+ if (++num_keys >= max_num_ikeys) {
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/env_mirror.cc b/src/rocksdb/utilities/env_mirror.cc
new file mode 100644
index 000000000..3ea323b42
--- /dev/null
+++ b/src/rocksdb/utilities/env_mirror.cc
@@ -0,0 +1,275 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2015, Red Hat, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/env_mirror.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// An implementation of Env that mirrors all work over two backend
+// Env's. This is useful for debugging purposes.
+class SequentialFileMirror : public SequentialFile {
+ public:
+ std::unique_ptr<SequentialFile> a_, b_;
+ std::string fname;
+ explicit SequentialFileMirror(std::string f) : fname(f) {}
+
+ Status Read(size_t n, Slice* result, char* scratch) override {
+ Slice aslice;
+ Status as = a_->Read(n, &aslice, scratch);
+ if (as == Status::OK()) {
+ char* bscratch = new char[n];
+ Slice bslice;
+#ifndef NDEBUG
+ size_t off = 0;
+#endif
+ size_t left = aslice.size();
+ while (left) {
+ Status bs = b_->Read(left, &bslice, bscratch);
+#ifndef NDEBUG
+ assert(as == bs);
+ assert(memcmp(bscratch, scratch + off, bslice.size()) == 0);
+ off += bslice.size();
+#endif
+ left -= bslice.size();
+ }
+ delete[] bscratch;
+ *result = aslice;
+ } else {
+ Status bs = b_->Read(n, result, scratch);
+ assert(as == bs);
+ }
+ return as;
+ }
+
+ Status Skip(uint64_t n) override {
+ Status as = a_->Skip(n);
+ Status bs = b_->Skip(n);
+ assert(as == bs);
+ return as;
+ }
+ Status InvalidateCache(size_t offset, size_t length) override {
+ Status as = a_->InvalidateCache(offset, length);
+ Status bs = b_->InvalidateCache(offset, length);
+ assert(as == bs);
+ return as;
+ };
+};
+
+class RandomAccessFileMirror : public RandomAccessFile {
+ public:
+ std::unique_ptr<RandomAccessFile> a_, b_;
+ std::string fname;
+ explicit RandomAccessFileMirror(std::string f) : fname(f) {}
+
+ Status Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const override {
+ Status as = a_->Read(offset, n, result, scratch);
+ if (as == Status::OK()) {
+ char* bscratch = new char[n];
+ Slice bslice;
+ size_t off = 0;
+ size_t left = result->size();
+ while (left) {
+ Status bs = b_->Read(offset + off, left, &bslice, bscratch);
+ assert(as == bs);
+ assert(memcmp(bscratch, scratch + off, bslice.size()) == 0);
+ off += bslice.size();
+ left -= bslice.size();
+ }
+ delete[] bscratch;
+ } else {
+ Status bs = b_->Read(offset, n, result, scratch);
+ assert(as == bs);
+ }
+ return as;
+ }
+
+ size_t GetUniqueId(char* id, size_t max_size) const override {
+ // NOTE: not verified
+ return a_->GetUniqueId(id, max_size);
+ }
+};
+
+class WritableFileMirror : public WritableFile {
+ public:
+ std::unique_ptr<WritableFile> a_, b_;
+ std::string fname;
+ explicit WritableFileMirror(std::string f, const EnvOptions& options)
+ : WritableFile(options), fname(f) {}
+
+ Status Append(const Slice& data) override {
+ Status as = a_->Append(data);
+ Status bs = b_->Append(data);
+ assert(as == bs);
+ return as;
+ }
+ Status Append(const Slice& data,
+ const DataVerificationInfo& /* verification_info */) override {
+ return Append(data);
+ }
+ Status PositionedAppend(const Slice& data, uint64_t offset) override {
+ Status as = a_->PositionedAppend(data, offset);
+ Status bs = b_->PositionedAppend(data, offset);
+ assert(as == bs);
+ return as;
+ }
+ Status PositionedAppend(
+ const Slice& data, uint64_t offset,
+ const DataVerificationInfo& /* verification_info */) override {
+ return PositionedAppend(data, offset);
+ }
+ Status Truncate(uint64_t size) override {
+ Status as = a_->Truncate(size);
+ Status bs = b_->Truncate(size);
+ assert(as == bs);
+ return as;
+ }
+ Status Close() override {
+ Status as = a_->Close();
+ Status bs = b_->Close();
+ assert(as == bs);
+ return as;
+ }
+ Status Flush() override {
+ Status as = a_->Flush();
+ Status bs = b_->Flush();
+ assert(as == bs);
+ return as;
+ }
+ Status Sync() override {
+ Status as = a_->Sync();
+ Status bs = b_->Sync();
+ assert(as == bs);
+ return as;
+ }
+ Status Fsync() override {
+ Status as = a_->Fsync();
+ Status bs = b_->Fsync();
+ assert(as == bs);
+ return as;
+ }
+ bool IsSyncThreadSafe() const override {
+ bool as = a_->IsSyncThreadSafe();
+ assert(as == b_->IsSyncThreadSafe());
+ return as;
+ }
+ void SetIOPriority(Env::IOPriority pri) override {
+ a_->SetIOPriority(pri);
+ b_->SetIOPriority(pri);
+ }
+ Env::IOPriority GetIOPriority() override {
+ // NOTE: we don't verify this one
+ return a_->GetIOPriority();
+ }
+ uint64_t GetFileSize() override {
+ uint64_t as = a_->GetFileSize();
+ assert(as == b_->GetFileSize());
+ return as;
+ }
+ void GetPreallocationStatus(size_t* block_size,
+ size_t* last_allocated_block) override {
+ // NOTE: we don't verify this one
+ return a_->GetPreallocationStatus(block_size, last_allocated_block);
+ }
+ size_t GetUniqueId(char* id, size_t max_size) const override {
+ // NOTE: we don't verify this one
+ return a_->GetUniqueId(id, max_size);
+ }
+ Status InvalidateCache(size_t offset, size_t length) override {
+ Status as = a_->InvalidateCache(offset, length);
+ Status bs = b_->InvalidateCache(offset, length);
+ assert(as == bs);
+ return as;
+ }
+
+ protected:
+ Status Allocate(uint64_t offset, uint64_t length) override {
+ Status as = a_->Allocate(offset, length);
+ Status bs = b_->Allocate(offset, length);
+ assert(as == bs);
+ return as;
+ }
+ Status RangeSync(uint64_t offset, uint64_t nbytes) override {
+ Status as = a_->RangeSync(offset, nbytes);
+ Status bs = b_->RangeSync(offset, nbytes);
+ assert(as == bs);
+ return as;
+ }
+};
+
+Status EnvMirror::NewSequentialFile(const std::string& f,
+ std::unique_ptr<SequentialFile>* r,
+ const EnvOptions& options) {
+ if (f.find("/proc/") == 0) {
+ return a_->NewSequentialFile(f, r, options);
+ }
+ SequentialFileMirror* mf = new SequentialFileMirror(f);
+ Status as = a_->NewSequentialFile(f, &mf->a_, options);
+ Status bs = b_->NewSequentialFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::NewRandomAccessFile(const std::string& f,
+ std::unique_ptr<RandomAccessFile>* r,
+ const EnvOptions& options) {
+ if (f.find("/proc/") == 0) {
+ return a_->NewRandomAccessFile(f, r, options);
+ }
+ RandomAccessFileMirror* mf = new RandomAccessFileMirror(f);
+ Status as = a_->NewRandomAccessFile(f, &mf->a_, options);
+ Status bs = b_->NewRandomAccessFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::NewWritableFile(const std::string& f,
+ std::unique_ptr<WritableFile>* r,
+ const EnvOptions& options) {
+ if (f.find("/proc/") == 0) return a_->NewWritableFile(f, r, options);
+ WritableFileMirror* mf = new WritableFileMirror(f, options);
+ Status as = a_->NewWritableFile(f, &mf->a_, options);
+ Status bs = b_->NewWritableFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::ReuseWritableFile(const std::string& fname,
+ const std::string& old_fname,
+ std::unique_ptr<WritableFile>* r,
+ const EnvOptions& options) {
+ if (fname.find("/proc/") == 0)
+ return a_->ReuseWritableFile(fname, old_fname, r, options);
+ WritableFileMirror* mf = new WritableFileMirror(fname, options);
+ Status as = a_->ReuseWritableFile(fname, old_fname, &mf->a_, options);
+ Status bs = b_->ReuseWritableFile(fname, old_fname, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif
diff --git a/src/rocksdb/utilities/env_mirror_test.cc b/src/rocksdb/utilities/env_mirror_test.cc
new file mode 100644
index 000000000..c372de1da
--- /dev/null
+++ b/src/rocksdb/utilities/env_mirror_test.cc
@@ -0,0 +1,226 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2015, Red Hat, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/env_mirror.h"
+
+#include "env/mock_env.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class EnvMirrorTest : public testing::Test {
+ public:
+ Env* default_;
+ MockEnv *a_, *b_;
+ EnvMirror* env_;
+ const EnvOptions soptions_;
+
+ EnvMirrorTest()
+ : default_(Env::Default()),
+ a_(new MockEnv(default_)),
+ b_(new MockEnv(default_)),
+ env_(new EnvMirror(a_, b_)) {}
+ ~EnvMirrorTest() {
+ delete env_;
+ delete a_;
+ delete b_;
+ }
+};
+
+TEST_F(EnvMirrorTest, Basics) {
+ uint64_t file_size;
+ std::unique_ptr<WritableFile> writable_file;
+ std::vector<std::string> children;
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ // Check that the directory is empty.
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/non_existent"));
+ ASSERT_TRUE(!env_->GetFileSize("/dir/non_existent", &file_size).ok());
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+
+ // Create a file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ writable_file.reset();
+
+ // Check that the file exists.
+ ASSERT_OK(env_->FileExists("/dir/f"));
+ ASSERT_OK(a_->FileExists("/dir/f"));
+ ASSERT_OK(b_->FileExists("/dir/f"));
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(0U, file_size);
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+ ASSERT_OK(a_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+ ASSERT_OK(b_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+
+ // Write to the file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("abc"));
+ writable_file.reset();
+
+ // Check for expected size.
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(a_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(b_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that renaming works.
+ ASSERT_TRUE(!env_->RenameFile("/dir/non_existent", "/dir/g").ok());
+ ASSERT_OK(env_->RenameFile("/dir/f", "/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/f"));
+ ASSERT_OK(env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(a_->FileExists("/dir/g"));
+ ASSERT_OK(a_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(b_->FileExists("/dir/g"));
+ ASSERT_OK(b_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that opening non-existent file fails.
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ ASSERT_TRUE(
+ !env_->NewSequentialFile("/dir/non_existent", &seq_file, soptions_).ok());
+ ASSERT_TRUE(!seq_file);
+ ASSERT_TRUE(
+ !env_->NewRandomAccessFile("/dir/non_existent", &rand_file, soptions_)
+ .ok());
+ ASSERT_TRUE(!rand_file);
+
+ // Check that deleting works.
+ ASSERT_TRUE(!env_->DeleteFile("/dir/non_existent").ok());
+ ASSERT_OK(env_->DeleteFile("/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+ ASSERT_OK(env_->DeleteDir("/dir"));
+}
+
+TEST_F(EnvMirrorTest, ReadWrite) {
+ std::unique_ptr<WritableFile> writable_file;
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ Slice result;
+ char scratch[100];
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("hello "));
+ ASSERT_OK(writable_file->Append("world"));
+ writable_file.reset();
+
+ // Read sequentially.
+ ASSERT_OK(env_->NewSequentialFile("/dir/f", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(seq_file->Skip(1));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Try reading past EOF.
+ ASSERT_EQ(0U, result.size());
+ ASSERT_OK(seq_file->Skip(100)); // Try to skip past end of file.
+ ASSERT_OK(seq_file->Read(1000, &result, scratch));
+ ASSERT_EQ(0U, result.size());
+
+ // Random reads.
+ ASSERT_OK(env_->NewRandomAccessFile("/dir/f", &rand_file, soptions_));
+ ASSERT_OK(rand_file->Read(6, 5, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(rand_file->Read(0, 5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(rand_file->Read(10, 100, &result, scratch)); // Read "d".
+ ASSERT_EQ(0, result.compare("d"));
+
+ // Too high offset.
+ ASSERT_TRUE(!rand_file->Read(1000, 5, &result, scratch).ok());
+}
+
+TEST_F(EnvMirrorTest, Locks) {
+ FileLock* lock;
+
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(env_->LockFile("some file", &lock));
+ ASSERT_OK(env_->UnlockFile(lock));
+}
+
+TEST_F(EnvMirrorTest, Misc) {
+ std::string test_dir;
+ ASSERT_OK(env_->GetTestDirectory(&test_dir));
+ ASSERT_TRUE(!test_dir.empty());
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->NewWritableFile("/a/b", &writable_file, soptions_));
+
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(writable_file->Sync());
+ ASSERT_OK(writable_file->Flush());
+ ASSERT_OK(writable_file->Close());
+ writable_file.reset();
+}
+
+TEST_F(EnvMirrorTest, LargeWrite) {
+ const size_t kWriteSize = 300 * 1024;
+ char* scratch = new char[kWriteSize * 2];
+
+ std::string write_data;
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ write_data.append(1, static_cast<char>(i));
+ }
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("foo"));
+ ASSERT_OK(writable_file->Append(write_data));
+ writable_file.reset();
+
+ std::unique_ptr<SequentialFile> seq_file;
+ Slice result;
+ ASSERT_OK(env_->NewSequentialFile("/dir/f", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(3, &result, scratch)); // Read "foo".
+ ASSERT_EQ(0, result.compare("foo"));
+
+ size_t read = 0;
+ std::string read_data;
+ while (read < kWriteSize) {
+ ASSERT_OK(seq_file->Read(kWriteSize - read, &result, scratch));
+ read_data.append(result.data(), result.size());
+ read += result.size();
+ }
+ ASSERT_TRUE(write_data == read_data);
+ delete[] scratch;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int argc, char** argv) {
+ fprintf(stderr, "SKIPPED as EnvMirror is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/env_timed.cc b/src/rocksdb/utilities/env_timed.cc
new file mode 100644
index 000000000..1eb723146
--- /dev/null
+++ b/src/rocksdb/utilities/env_timed.cc
@@ -0,0 +1,187 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#include "utilities/env_timed.h"
+
+#include "env/composite_env_wrapper.h"
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_system.h"
+#include "rocksdb/status.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifndef ROCKSDB_LITE
+TimedFileSystem::TimedFileSystem(const std::shared_ptr<FileSystem>& base)
+ : FileSystemWrapper(base) {}
+IOStatus TimedFileSystem::NewSequentialFile(
+ const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSSequentialFile>* result, IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_sequential_file_nanos);
+ return FileSystemWrapper::NewSequentialFile(fname, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::NewRandomAccessFile(
+ const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSRandomAccessFile>* result, IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_random_access_file_nanos);
+ return FileSystemWrapper::NewRandomAccessFile(fname, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::NewWritableFile(
+ const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* result, IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_writable_file_nanos);
+ return FileSystemWrapper::NewWritableFile(fname, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::ReuseWritableFile(
+ const std::string& fname, const std::string& old_fname,
+ const FileOptions& options, std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_reuse_writable_file_nanos);
+ return FileSystemWrapper::ReuseWritableFile(fname, old_fname, options, result,
+ dbg);
+}
+
+IOStatus TimedFileSystem::NewRandomRWFile(
+ const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSRandomRWFile>* result, IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_random_rw_file_nanos);
+ return FileSystemWrapper::NewRandomRWFile(fname, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::NewDirectory(const std::string& name,
+ const IOOptions& options,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_directory_nanos);
+ return FileSystemWrapper::NewDirectory(name, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::FileExists(const std::string& fname,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_file_exists_nanos);
+ return FileSystemWrapper::FileExists(fname, options, dbg);
+}
+
+IOStatus TimedFileSystem::GetChildren(const std::string& dir,
+ const IOOptions& options,
+ std::vector<std::string>* result,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_get_children_nanos);
+ return FileSystemWrapper::GetChildren(dir, options, result, dbg);
+}
+
+IOStatus TimedFileSystem::GetChildrenFileAttributes(
+ const std::string& dir, const IOOptions& options,
+ std::vector<FileAttributes>* result, IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_get_children_file_attributes_nanos);
+ return FileSystemWrapper::GetChildrenFileAttributes(dir, options, result,
+ dbg);
+}
+
+IOStatus TimedFileSystem::DeleteFile(const std::string& fname,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_delete_file_nanos);
+ return FileSystemWrapper::DeleteFile(fname, options, dbg);
+}
+
+IOStatus TimedFileSystem::CreateDir(const std::string& dirname,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_create_dir_nanos);
+ return FileSystemWrapper::CreateDir(dirname, options, dbg);
+}
+
+IOStatus TimedFileSystem::CreateDirIfMissing(const std::string& dirname,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_create_dir_if_missing_nanos);
+ return FileSystemWrapper::CreateDirIfMissing(dirname, options, dbg);
+}
+
+IOStatus TimedFileSystem::DeleteDir(const std::string& dirname,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_delete_dir_nanos);
+ return FileSystemWrapper::DeleteDir(dirname, options, dbg);
+}
+
+IOStatus TimedFileSystem::GetFileSize(const std::string& fname,
+ const IOOptions& options,
+ uint64_t* file_size,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_get_file_size_nanos);
+ return FileSystemWrapper::GetFileSize(fname, options, file_size, dbg);
+}
+
+IOStatus TimedFileSystem::GetFileModificationTime(const std::string& fname,
+ const IOOptions& options,
+ uint64_t* file_mtime,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_get_file_modification_time_nanos);
+ return FileSystemWrapper::GetFileModificationTime(fname, options, file_mtime,
+ dbg);
+}
+
+IOStatus TimedFileSystem::RenameFile(const std::string& src,
+ const std::string& dst,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_rename_file_nanos);
+ return FileSystemWrapper::RenameFile(src, dst, options, dbg);
+}
+
+IOStatus TimedFileSystem::LinkFile(const std::string& src,
+ const std::string& dst,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_link_file_nanos);
+ return FileSystemWrapper::LinkFile(src, dst, options, dbg);
+}
+
+IOStatus TimedFileSystem::LockFile(const std::string& fname,
+ const IOOptions& options, FileLock** lock,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_lock_file_nanos);
+ return FileSystemWrapper::LockFile(fname, options, lock, dbg);
+}
+
+IOStatus TimedFileSystem::UnlockFile(FileLock* lock, const IOOptions& options,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_unlock_file_nanos);
+ return FileSystemWrapper::UnlockFile(lock, options, dbg);
+}
+
+IOStatus TimedFileSystem::NewLogger(const std::string& fname,
+ const IOOptions& options,
+ std::shared_ptr<Logger>* result,
+ IODebugContext* dbg) {
+ PERF_TIMER_GUARD(env_new_logger_nanos);
+ return FileSystemWrapper::NewLogger(fname, options, result, dbg);
+}
+
+std::shared_ptr<FileSystem> NewTimedFileSystem(
+ const std::shared_ptr<FileSystem>& base) {
+ return std::make_shared<TimedFileSystem>(base);
+}
+
+// An environment that measures function call times for filesystem
+// operations, reporting results to variables in PerfContext.
+Env* NewTimedEnv(Env* base_env) {
+ std::shared_ptr<FileSystem> timed_fs =
+ NewTimedFileSystem(base_env->GetFileSystem());
+ return new CompositeEnvWrapper(base_env, timed_fs);
+}
+
+#else // ROCKSDB_LITE
+
+Env* NewTimedEnv(Env* /*base_env*/) { return nullptr; }
+
+#endif // !ROCKSDB_LITE
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/env_timed.h b/src/rocksdb/utilities/env_timed.h
new file mode 100644
index 000000000..2d34fd590
--- /dev/null
+++ b/src/rocksdb/utilities/env_timed.h
@@ -0,0 +1,97 @@
+// Copyright (c) 2019-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#pragma once
+#include "rocksdb/file_system.h"
+namespace ROCKSDB_NAMESPACE {
+#ifndef ROCKSDB_LITE
+class TimedFileSystem : public FileSystemWrapper {
+ public:
+ explicit TimedFileSystem(const std::shared_ptr<FileSystem>& base);
+
+ static const char* kClassName() { return "TimedFS"; }
+ const char* Name() const override { return kClassName(); }
+
+ IOStatus NewSequentialFile(const std::string& fname,
+ const FileOptions& options,
+ std::unique_ptr<FSSequentialFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewRandomAccessFile(const std::string& fname,
+ const FileOptions& options,
+ std::unique_ptr<FSRandomAccessFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewWritableFile(const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus ReuseWritableFile(const std::string& fname,
+ const std::string& old_fname,
+ const FileOptions& options,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewRandomRWFile(const std::string& fname, const FileOptions& options,
+ std::unique_ptr<FSRandomRWFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewDirectory(const std::string& name, const IOOptions& options,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus FileExists(const std::string& fname, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus GetChildren(const std::string& dir, const IOOptions& options,
+ std::vector<std::string>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus GetChildrenFileAttributes(const std::string& dir,
+ const IOOptions& options,
+ std::vector<FileAttributes>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus DeleteFile(const std::string& fname, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus CreateDir(const std::string& dirname, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus CreateDirIfMissing(const std::string& dirname,
+ const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus DeleteDir(const std::string& dirname, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus GetFileSize(const std::string& fname, const IOOptions& options,
+ uint64_t* file_size, IODebugContext* dbg) override;
+
+ IOStatus GetFileModificationTime(const std::string& fname,
+ const IOOptions& options,
+ uint64_t* file_mtime,
+ IODebugContext* dbg) override;
+
+ IOStatus RenameFile(const std::string& src, const std::string& dst,
+ const IOOptions& options, IODebugContext* dbg) override;
+
+ IOStatus LinkFile(const std::string& src, const std::string& dst,
+ const IOOptions& options, IODebugContext* dbg) override;
+
+ IOStatus LockFile(const std::string& fname, const IOOptions& options,
+ FileLock** lock, IODebugContext* dbg) override;
+
+ IOStatus UnlockFile(FileLock* lock, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ IOStatus NewLogger(const std::string& fname, const IOOptions& options,
+ std::shared_ptr<Logger>* result,
+ IODebugContext* dbg) override;
+};
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/env_timed_test.cc b/src/rocksdb/utilities/env_timed_test.cc
new file mode 100644
index 000000000..6e392579d
--- /dev/null
+++ b/src/rocksdb/utilities/env_timed_test.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/env.h"
+#include "rocksdb/perf_context.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TimedEnvTest : public testing::Test {};
+
+TEST_F(TimedEnvTest, BasicTest) {
+ SetPerfLevel(PerfLevel::kEnableTime);
+ ASSERT_EQ(0, get_perf_context()->env_new_writable_file_nanos);
+
+ std::unique_ptr<Env> mem_env(NewMemEnv(Env::Default()));
+ std::unique_ptr<Env> timed_env(NewTimedEnv(mem_env.get()));
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(timed_env->NewWritableFile("f", &writable_file, EnvOptions()));
+
+ ASSERT_GT(get_perf_context()->env_new_writable_file_nanos, 0);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else // ROCKSDB_LITE
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as TimedEnv is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/fault_injection_env.cc b/src/rocksdb/utilities/fault_injection_env.cc
new file mode 100644
index 000000000..b0495a8c1
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_env.cc
@@ -0,0 +1,555 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright 2014 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// This test uses a custom Env to keep track of the state of a filesystem as of
+// the last "sync". It then checks for data loss errors by purposely dropping
+// file data (or entire files) not protected by a "sync".
+
+#include "utilities/fault_injection_env.h"
+
+#include <functional>
+#include <utility>
+
+#include "util/random.h"
+namespace ROCKSDB_NAMESPACE {
+
+// Assume a filename, and not a directory name like "/foo/bar/"
+std::string GetDirName(const std::string filename) {
+ size_t found = filename.find_last_of("/\\");
+ if (found == std::string::npos) {
+ return "";
+ } else {
+ return filename.substr(0, found);
+ }
+}
+
+// A basic file truncation function suitable for this test.
+Status Truncate(Env* env, const std::string& filename, uint64_t length) {
+ std::unique_ptr<SequentialFile> orig_file;
+ const EnvOptions options;
+ Status s = env->NewSequentialFile(filename, &orig_file, options);
+ if (!s.ok()) {
+ fprintf(stderr, "Cannot open file %s for truncation: %s\n",
+ filename.c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ std::unique_ptr<char[]> scratch(new char[length]);
+ ROCKSDB_NAMESPACE::Slice result;
+ s = orig_file->Read(length, &result, scratch.get());
+#ifdef OS_WIN
+ orig_file.reset();
+#endif
+ if (s.ok()) {
+ std::string tmp_name = GetDirName(filename) + "/truncate.tmp";
+ std::unique_ptr<WritableFile> tmp_file;
+ s = env->NewWritableFile(tmp_name, &tmp_file, options);
+ if (s.ok()) {
+ s = tmp_file->Append(result);
+ if (s.ok()) {
+ s = env->RenameFile(tmp_name, filename);
+ } else {
+ fprintf(stderr, "Cannot rename file %s to %s: %s\n", tmp_name.c_str(),
+ filename.c_str(), s.ToString().c_str());
+ env->DeleteFile(tmp_name);
+ }
+ }
+ }
+ if (!s.ok()) {
+ fprintf(stderr, "Cannot truncate file %s: %s\n", filename.c_str(),
+ s.ToString().c_str());
+ }
+
+ return s;
+}
+
+// Trim the tailing "/" in the end of `str`
+std::string TrimDirname(const std::string& str) {
+ size_t found = str.find_last_not_of("/");
+ if (found == std::string::npos) {
+ return str;
+ }
+ return str.substr(0, found + 1);
+}
+
+// Return pair <parent directory name, file name> of a full path.
+std::pair<std::string, std::string> GetDirAndName(const std::string& name) {
+ std::string dirname = GetDirName(name);
+ std::string fname = name.substr(dirname.size() + 1);
+ return std::make_pair(dirname, fname);
+}
+
+Status FileState::DropUnsyncedData(Env* env) const {
+ ssize_t sync_pos = pos_at_last_sync_ == -1 ? 0 : pos_at_last_sync_;
+ return Truncate(env, filename_, sync_pos);
+}
+
+Status FileState::DropRandomUnsyncedData(Env* env, Random* rand) const {
+ ssize_t sync_pos = pos_at_last_sync_ == -1 ? 0 : pos_at_last_sync_;
+ assert(pos_ >= sync_pos);
+ int range = static_cast<int>(pos_ - sync_pos);
+ uint64_t truncated_size =
+ static_cast<uint64_t>(sync_pos) + rand->Uniform(range);
+ return Truncate(env, filename_, truncated_size);
+}
+
+Status TestDirectory::Fsync() {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ env_->SyncDir(dirname_);
+ return dir_->Fsync();
+}
+
+Status TestDirectory::Close() {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ return dir_->Close();
+}
+
+TestRandomAccessFile::TestRandomAccessFile(
+ std::unique_ptr<RandomAccessFile>&& target, FaultInjectionTestEnv* env)
+ : target_(std::move(target)), env_(env) {
+ assert(target_);
+ assert(env_);
+}
+
+Status TestRandomAccessFile::Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const {
+ assert(env_);
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+
+ assert(target_);
+ return target_->Read(offset, n, result, scratch);
+}
+
+Status TestRandomAccessFile::Prefetch(uint64_t offset, size_t n) {
+ assert(env_);
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+
+ assert(target_);
+ return target_->Prefetch(offset, n);
+}
+
+Status TestRandomAccessFile::MultiRead(ReadRequest* reqs, size_t num_reqs) {
+ assert(env_);
+ if (!env_->IsFilesystemActive()) {
+ const Status s = env_->GetError();
+
+ assert(reqs);
+ for (size_t i = 0; i < num_reqs; ++i) {
+ reqs[i].status = s;
+ }
+
+ return s;
+ }
+
+ assert(target_);
+ return target_->MultiRead(reqs, num_reqs);
+}
+
+TestWritableFile::TestWritableFile(const std::string& fname,
+ std::unique_ptr<WritableFile>&& f,
+ FaultInjectionTestEnv* env)
+ : state_(fname),
+ target_(std::move(f)),
+ writable_file_opened_(true),
+ env_(env) {
+ assert(target_ != nullptr);
+ state_.pos_ = 0;
+}
+
+TestWritableFile::~TestWritableFile() {
+ if (writable_file_opened_) {
+ Close().PermitUncheckedError();
+ }
+}
+
+Status TestWritableFile::Append(const Slice& data) {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ Status s = target_->Append(data);
+ if (s.ok()) {
+ state_.pos_ += data.size();
+ env_->WritableFileAppended(state_);
+ }
+ return s;
+}
+
+Status TestWritableFile::Close() {
+ writable_file_opened_ = false;
+ Status s = target_->Close();
+ if (s.ok()) {
+ env_->WritableFileClosed(state_);
+ }
+ return s;
+}
+
+Status TestWritableFile::Flush() {
+ Status s = target_->Flush();
+ if (s.ok() && env_->IsFilesystemActive()) {
+ state_.pos_at_last_flush_ = state_.pos_;
+ }
+ return s;
+}
+
+Status TestWritableFile::Sync() {
+ if (!env_->IsFilesystemActive()) {
+ return Status::IOError("FaultInjectionTestEnv: not active");
+ }
+ // No need to actual sync.
+ state_.pos_at_last_sync_ = state_.pos_;
+ env_->WritableFileSynced(state_);
+ return Status::OK();
+}
+
+TestRandomRWFile::TestRandomRWFile(const std::string& /*fname*/,
+ std::unique_ptr<RandomRWFile>&& f,
+ FaultInjectionTestEnv* env)
+ : target_(std::move(f)), file_opened_(true), env_(env) {
+ assert(target_ != nullptr);
+}
+
+TestRandomRWFile::~TestRandomRWFile() {
+ if (file_opened_) {
+ Close().PermitUncheckedError();
+ }
+}
+
+Status TestRandomRWFile::Write(uint64_t offset, const Slice& data) {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ return target_->Write(offset, data);
+}
+
+Status TestRandomRWFile::Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ return target_->Read(offset, n, result, scratch);
+}
+
+Status TestRandomRWFile::Close() {
+ file_opened_ = false;
+ return target_->Close();
+}
+
+Status TestRandomRWFile::Flush() {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ return target_->Flush();
+}
+
+Status TestRandomRWFile::Sync() {
+ if (!env_->IsFilesystemActive()) {
+ return env_->GetError();
+ }
+ return target_->Sync();
+}
+
+Status FaultInjectionTestEnv::NewDirectory(const std::string& name,
+ std::unique_ptr<Directory>* result) {
+ std::unique_ptr<Directory> r;
+ Status s = target()->NewDirectory(name, &r);
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+ result->reset(new TestDirectory(this, TrimDirname(name), r.release()));
+ return Status::OK();
+}
+
+Status FaultInjectionTestEnv::NewWritableFile(
+ const std::string& fname, std::unique_ptr<WritableFile>* result,
+ const EnvOptions& soptions) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ // Not allow overwriting files
+ Status s = target()->FileExists(fname);
+ if (s.ok()) {
+ return Status::Corruption("File already exists.");
+ } else if (!s.IsNotFound()) {
+ assert(s.IsIOError());
+ return s;
+ }
+ s = target()->NewWritableFile(fname, result, soptions);
+ if (s.ok()) {
+ result->reset(new TestWritableFile(fname, std::move(*result), this));
+ // WritableFileWriter* file is opened
+ // again then it will be truncated - so forget our saved state.
+ UntrackFile(fname);
+ MutexLock l(&mutex_);
+ open_managed_files_.insert(fname);
+ auto dir_and_name = GetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ list.insert(dir_and_name.second);
+ }
+ return s;
+}
+
+Status FaultInjectionTestEnv::ReopenWritableFile(
+ const std::string& fname, std::unique_ptr<WritableFile>* result,
+ const EnvOptions& soptions) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+
+ bool exists;
+ Status s, exists_s = target()->FileExists(fname);
+ if (exists_s.IsNotFound()) {
+ exists = false;
+ } else if (exists_s.ok()) {
+ exists = true;
+ } else {
+ s = exists_s;
+ exists = false;
+ }
+
+ if (s.ok()) {
+ s = target()->ReopenWritableFile(fname, result, soptions);
+ }
+
+ // Only track files we created. Files created outside of this
+ // `FaultInjectionTestEnv` are not eligible for tracking/data dropping
+ // (for example, they may contain data a previous db_stress run expects to
+ // be recovered). This could be extended to track/drop data appended once
+ // the file is under `FaultInjectionTestEnv`'s control.
+ if (s.ok()) {
+ bool should_track;
+ {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(fname) != db_file_state_.end()) {
+ // It was written by this `Env` earlier.
+ assert(exists);
+ should_track = true;
+ } else if (!exists) {
+ // It was created by this `Env` just now.
+ should_track = true;
+ open_managed_files_.insert(fname);
+ auto dir_and_name = GetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ list.insert(dir_and_name.second);
+ } else {
+ should_track = false;
+ }
+ }
+ if (should_track) {
+ result->reset(new TestWritableFile(fname, std::move(*result), this));
+ }
+ }
+ return s;
+}
+
+Status FaultInjectionTestEnv::NewRandomRWFile(
+ const std::string& fname, std::unique_ptr<RandomRWFile>* result,
+ const EnvOptions& soptions) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ Status s = target()->NewRandomRWFile(fname, result, soptions);
+ if (s.ok()) {
+ result->reset(new TestRandomRWFile(fname, std::move(*result), this));
+ // WritableFileWriter* file is opened
+ // again then it will be truncated - so forget our saved state.
+ UntrackFile(fname);
+ MutexLock l(&mutex_);
+ open_managed_files_.insert(fname);
+ auto dir_and_name = GetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ list.insert(dir_and_name.second);
+ }
+ return s;
+}
+
+Status FaultInjectionTestEnv::NewRandomAccessFile(
+ const std::string& fname, std::unique_ptr<RandomAccessFile>* result,
+ const EnvOptions& soptions) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+
+ assert(target());
+ const Status s = target()->NewRandomAccessFile(fname, result, soptions);
+ if (!s.ok()) {
+ return s;
+ }
+
+ assert(result);
+ result->reset(new TestRandomAccessFile(std::move(*result), this));
+
+ return Status::OK();
+}
+
+Status FaultInjectionTestEnv::DeleteFile(const std::string& f) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ Status s = EnvWrapper::DeleteFile(f);
+ if (s.ok()) {
+ UntrackFile(f);
+ }
+ return s;
+}
+
+Status FaultInjectionTestEnv::RenameFile(const std::string& s,
+ const std::string& t) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ Status ret = EnvWrapper::RenameFile(s, t);
+
+ if (ret.ok()) {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(s) != db_file_state_.end()) {
+ db_file_state_[t] = db_file_state_[s];
+ db_file_state_.erase(s);
+ }
+
+ auto sdn = GetDirAndName(s);
+ auto tdn = GetDirAndName(t);
+ if (dir_to_new_files_since_last_sync_[sdn.first].erase(sdn.second) != 0) {
+ auto& tlist = dir_to_new_files_since_last_sync_[tdn.first];
+ assert(tlist.find(tdn.second) == tlist.end());
+ tlist.insert(tdn.second);
+ }
+ }
+
+ return ret;
+}
+
+Status FaultInjectionTestEnv::LinkFile(const std::string& s,
+ const std::string& t) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ Status ret = EnvWrapper::LinkFile(s, t);
+
+ if (ret.ok()) {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(s) != db_file_state_.end()) {
+ db_file_state_[t] = db_file_state_[s];
+ }
+
+ auto sdn = GetDirAndName(s);
+ auto tdn = GetDirAndName(t);
+ if (dir_to_new_files_since_last_sync_[sdn.first].find(sdn.second) !=
+ dir_to_new_files_since_last_sync_[sdn.first].end()) {
+ auto& tlist = dir_to_new_files_since_last_sync_[tdn.first];
+ assert(tlist.find(tdn.second) == tlist.end());
+ tlist.insert(tdn.second);
+ }
+ }
+
+ return ret;
+}
+
+void FaultInjectionTestEnv::WritableFileClosed(const FileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ db_file_state_[state.filename_] = state;
+ open_managed_files_.erase(state.filename_);
+ }
+}
+
+void FaultInjectionTestEnv::WritableFileSynced(const FileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ if (db_file_state_.find(state.filename_) == db_file_state_.end()) {
+ db_file_state_.insert(std::make_pair(state.filename_, state));
+ } else {
+ db_file_state_[state.filename_] = state;
+ }
+ }
+}
+
+void FaultInjectionTestEnv::WritableFileAppended(const FileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ if (db_file_state_.find(state.filename_) == db_file_state_.end()) {
+ db_file_state_.insert(std::make_pair(state.filename_, state));
+ } else {
+ db_file_state_[state.filename_] = state;
+ }
+ }
+}
+
+// For every file that is not fully synced, make a call to `func` with
+// FileState of the file as the parameter.
+Status FaultInjectionTestEnv::DropFileData(
+ std::function<Status(Env*, FileState)> func) {
+ Status s;
+ MutexLock l(&mutex_);
+ for (std::map<std::string, FileState>::const_iterator it =
+ db_file_state_.begin();
+ s.ok() && it != db_file_state_.end(); ++it) {
+ const FileState& state = it->second;
+ if (!state.IsFullySynced()) {
+ s = func(target(), state);
+ }
+ }
+ return s;
+}
+
+Status FaultInjectionTestEnv::DropUnsyncedFileData() {
+ return DropFileData([&](Env* env, const FileState& state) {
+ return state.DropUnsyncedData(env);
+ });
+}
+
+Status FaultInjectionTestEnv::DropRandomUnsyncedFileData(Random* rnd) {
+ return DropFileData([&](Env* env, const FileState& state) {
+ return state.DropRandomUnsyncedData(env, rnd);
+ });
+}
+
+Status FaultInjectionTestEnv::DeleteFilesCreatedAfterLastDirSync() {
+ // Because DeleteFile access this container make a copy to avoid deadlock
+ std::map<std::string, std::set<std::string>> map_copy;
+ {
+ MutexLock l(&mutex_);
+ map_copy.insert(dir_to_new_files_since_last_sync_.begin(),
+ dir_to_new_files_since_last_sync_.end());
+ }
+
+ for (auto& pair : map_copy) {
+ for (std::string name : pair.second) {
+ Status s = DeleteFile(pair.first + "/" + name);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+ return Status::OK();
+}
+void FaultInjectionTestEnv::ResetState() {
+ MutexLock l(&mutex_);
+ db_file_state_.clear();
+ dir_to_new_files_since_last_sync_.clear();
+ SetFilesystemActiveNoLock(true);
+}
+
+void FaultInjectionTestEnv::UntrackFile(const std::string& f) {
+ MutexLock l(&mutex_);
+ auto dir_and_name = GetDirAndName(f);
+ dir_to_new_files_since_last_sync_[dir_and_name.first].erase(
+ dir_and_name.second);
+ db_file_state_.erase(f);
+ open_managed_files_.erase(f);
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/fault_injection_env.h b/src/rocksdb/utilities/fault_injection_env.h
new file mode 100644
index 000000000..549bfe716
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_env.h
@@ -0,0 +1,258 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright 2014 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// This test uses a custom Env to keep track of the state of a filesystem as of
+// the last "sync". It then checks for data loss errors by purposely dropping
+// file data (or entire files) not protected by a "sync".
+
+#pragma once
+
+#include <map>
+#include <set>
+#include <string>
+
+#include "file/filename.h"
+#include "rocksdb/env.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+class Random;
+class TestWritableFile;
+class FaultInjectionTestEnv;
+
+struct FileState {
+ std::string filename_;
+ ssize_t pos_;
+ ssize_t pos_at_last_sync_;
+ ssize_t pos_at_last_flush_;
+
+ explicit FileState(const std::string& filename)
+ : filename_(filename),
+ pos_(-1),
+ pos_at_last_sync_(-1),
+ pos_at_last_flush_(-1) {}
+
+ FileState() : pos_(-1), pos_at_last_sync_(-1), pos_at_last_flush_(-1) {}
+
+ bool IsFullySynced() const { return pos_ <= 0 || pos_ == pos_at_last_sync_; }
+
+ Status DropUnsyncedData(Env* env) const;
+
+ Status DropRandomUnsyncedData(Env* env, Random* rand) const;
+};
+
+class TestRandomAccessFile : public RandomAccessFile {
+ public:
+ TestRandomAccessFile(std::unique_ptr<RandomAccessFile>&& target,
+ FaultInjectionTestEnv* env);
+
+ Status Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const override;
+
+ Status Prefetch(uint64_t offset, size_t n) override;
+
+ Status MultiRead(ReadRequest* reqs, size_t num_reqs) override;
+
+ private:
+ std::unique_ptr<RandomAccessFile> target_;
+ FaultInjectionTestEnv* env_;
+};
+
+// A wrapper around WritableFileWriter* file
+// is written to or sync'ed.
+class TestWritableFile : public WritableFile {
+ public:
+ explicit TestWritableFile(const std::string& fname,
+ std::unique_ptr<WritableFile>&& f,
+ FaultInjectionTestEnv* env);
+ virtual ~TestWritableFile();
+ virtual Status Append(const Slice& data) override;
+ virtual Status Append(
+ const Slice& data,
+ const DataVerificationInfo& /*verification_info*/) override {
+ return Append(data);
+ }
+ virtual Status Truncate(uint64_t size) override {
+ return target_->Truncate(size);
+ }
+ virtual Status Close() override;
+ virtual Status Flush() override;
+ virtual Status Sync() override;
+ virtual bool IsSyncThreadSafe() const override { return true; }
+ virtual Status PositionedAppend(const Slice& data, uint64_t offset) override {
+ return target_->PositionedAppend(data, offset);
+ }
+ virtual Status PositionedAppend(
+ const Slice& data, uint64_t offset,
+ const DataVerificationInfo& /*verification_info*/) override {
+ return PositionedAppend(data, offset);
+ }
+ virtual bool use_direct_io() const override {
+ return target_->use_direct_io();
+ };
+
+ private:
+ FileState state_;
+ std::unique_ptr<WritableFile> target_;
+ bool writable_file_opened_;
+ FaultInjectionTestEnv* env_;
+};
+
+// A wrapper around WritableFileWriter* file
+// is written to or sync'ed.
+class TestRandomRWFile : public RandomRWFile {
+ public:
+ explicit TestRandomRWFile(const std::string& fname,
+ std::unique_ptr<RandomRWFile>&& f,
+ FaultInjectionTestEnv* env);
+ virtual ~TestRandomRWFile();
+ Status Write(uint64_t offset, const Slice& data) override;
+ Status Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const override;
+ Status Close() override;
+ Status Flush() override;
+ Status Sync() override;
+ size_t GetRequiredBufferAlignment() const override {
+ return target_->GetRequiredBufferAlignment();
+ }
+ bool use_direct_io() const override { return target_->use_direct_io(); };
+
+ private:
+ std::unique_ptr<RandomRWFile> target_;
+ bool file_opened_;
+ FaultInjectionTestEnv* env_;
+};
+
+class TestDirectory : public Directory {
+ public:
+ explicit TestDirectory(FaultInjectionTestEnv* env, std::string dirname,
+ Directory* dir)
+ : env_(env), dirname_(dirname), dir_(dir) {}
+ ~TestDirectory() {}
+
+ virtual Status Fsync() override;
+ virtual Status Close() override;
+
+ private:
+ FaultInjectionTestEnv* env_;
+ std::string dirname_;
+ std::unique_ptr<Directory> dir_;
+};
+
+class FaultInjectionTestEnv : public EnvWrapper {
+ public:
+ explicit FaultInjectionTestEnv(Env* base)
+ : EnvWrapper(base), filesystem_active_(true) {}
+ virtual ~FaultInjectionTestEnv() { error_.PermitUncheckedError(); }
+
+ static const char* kClassName() { return "FaultInjectionTestEnv"; }
+ const char* Name() const override { return kClassName(); }
+
+ Status NewDirectory(const std::string& name,
+ std::unique_ptr<Directory>* result) override;
+
+ Status NewWritableFile(const std::string& fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& soptions) override;
+
+ Status ReopenWritableFile(const std::string& fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& soptions) override;
+
+ Status NewRandomRWFile(const std::string& fname,
+ std::unique_ptr<RandomRWFile>* result,
+ const EnvOptions& soptions) override;
+
+ Status NewRandomAccessFile(const std::string& fname,
+ std::unique_ptr<RandomAccessFile>* result,
+ const EnvOptions& soptions) override;
+
+ virtual Status DeleteFile(const std::string& f) override;
+
+ virtual Status RenameFile(const std::string& s,
+ const std::string& t) override;
+
+ virtual Status LinkFile(const std::string& s, const std::string& t) override;
+
+// Undef to eliminate clash on Windows
+#undef GetFreeSpace
+ virtual Status GetFreeSpace(const std::string& path,
+ uint64_t* disk_free) override {
+ if (!IsFilesystemActive() &&
+ error_.subcode() == IOStatus::SubCode::kNoSpace) {
+ *disk_free = 0;
+ return Status::OK();
+ } else {
+ return target()->GetFreeSpace(path, disk_free);
+ }
+ }
+
+ void WritableFileClosed(const FileState& state);
+
+ void WritableFileSynced(const FileState& state);
+
+ void WritableFileAppended(const FileState& state);
+
+ // For every file that is not fully synced, make a call to `func` with
+ // FileState of the file as the parameter.
+ Status DropFileData(std::function<Status(Env*, FileState)> func);
+
+ Status DropUnsyncedFileData();
+
+ Status DropRandomUnsyncedFileData(Random* rnd);
+
+ Status DeleteFilesCreatedAfterLastDirSync();
+
+ void ResetState();
+
+ void UntrackFile(const std::string& f);
+
+ void SyncDir(const std::string& dirname) {
+ MutexLock l(&mutex_);
+ dir_to_new_files_since_last_sync_.erase(dirname);
+ }
+
+ // Setting the filesystem to inactive is the test equivalent to simulating a
+ // system reset. Setting to inactive will freeze our saved filesystem state so
+ // that it will stop being recorded. It can then be reset back to the state at
+ // the time of the reset.
+ bool IsFilesystemActive() {
+ MutexLock l(&mutex_);
+ return filesystem_active_;
+ }
+ void SetFilesystemActiveNoLock(
+ bool active, Status error = Status::Corruption("Not active")) {
+ error.PermitUncheckedError();
+ filesystem_active_ = active;
+ if (!active) {
+ error_ = error;
+ }
+ error.PermitUncheckedError();
+ }
+ void SetFilesystemActive(bool active,
+ Status error = Status::Corruption("Not active")) {
+ error.PermitUncheckedError();
+ MutexLock l(&mutex_);
+ SetFilesystemActiveNoLock(active, error);
+ error.PermitUncheckedError();
+ }
+ void AssertNoOpenFile() { assert(open_managed_files_.empty()); }
+ Status GetError() { return error_; }
+
+ private:
+ port::Mutex mutex_;
+ std::map<std::string, FileState> db_file_state_;
+ std::set<std::string> open_managed_files_;
+ std::unordered_map<std::string, std::set<std::string>>
+ dir_to_new_files_since_last_sync_;
+ bool filesystem_active_; // Record flushes, syncs, writes
+ Status error_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/fault_injection_fs.cc b/src/rocksdb/utilities/fault_injection_fs.cc
new file mode 100644
index 000000000..549051856
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_fs.cc
@@ -0,0 +1,1032 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright 2014 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// This test uses a custom FileSystem to keep track of the state of a file
+// system the last "Sync". The data being written is cached in a "buffer".
+// Only when "Sync" is called, the data will be persistent. It can simulate
+// file data loss (or entire files) not protected by a "Sync". For any of the
+// FileSystem related operations, by specify the "IOStatus Error", a specific
+// error can be returned when file system is not activated.
+
+#include "utilities/fault_injection_fs.h"
+
+#include <algorithm>
+#include <functional>
+#include <utility>
+
+#include "env/composite_env_wrapper.h"
+#include "port/lang.h"
+#include "port/stack_trace.h"
+#include "test_util/sync_point.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "util/xxhash.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+const std::string kNewFileNoOverwrite = "";
+
+// Assume a filename, and not a directory name like "/foo/bar/"
+std::string TestFSGetDirName(const std::string filename) {
+ size_t found = filename.find_last_of("/\\");
+ if (found == std::string::npos) {
+ return "";
+ } else {
+ return filename.substr(0, found);
+ }
+}
+
+// Trim the tailing "/" in the end of `str`
+std::string TestFSTrimDirname(const std::string& str) {
+ size_t found = str.find_last_not_of("/");
+ if (found == std::string::npos) {
+ return str;
+ }
+ return str.substr(0, found + 1);
+}
+
+// Return pair <parent directory name, file name> of a full path.
+std::pair<std::string, std::string> TestFSGetDirAndName(
+ const std::string& name) {
+ std::string dirname = TestFSGetDirName(name);
+ std::string fname = name.substr(dirname.size() + 1);
+ return std::make_pair(dirname, fname);
+}
+
+// Calculate the checksum of the data with corresponding checksum
+// type. If name does not match, no checksum is returned.
+void CalculateTypedChecksum(const ChecksumType& checksum_type, const char* data,
+ size_t size, std::string* checksum) {
+ if (checksum_type == ChecksumType::kCRC32c) {
+ uint32_t v_crc32c = crc32c::Extend(0, data, size);
+ PutFixed32(checksum, v_crc32c);
+ return;
+ } else if (checksum_type == ChecksumType::kxxHash) {
+ uint32_t v = XXH32(data, size, 0);
+ PutFixed32(checksum, v);
+ }
+ return;
+}
+
+IOStatus FSFileState::DropUnsyncedData() {
+ buffer_.resize(0);
+ return IOStatus::OK();
+}
+
+IOStatus FSFileState::DropRandomUnsyncedData(Random* rand) {
+ int range = static_cast<int>(buffer_.size());
+ size_t truncated_size = static_cast<size_t>(rand->Uniform(range));
+ buffer_.resize(truncated_size);
+ return IOStatus::OK();
+}
+
+IOStatus TestFSDirectory::Fsync(const IOOptions& options, IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ {
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ fs_->SyncDir(dirname_);
+ IOStatus s = dir_->Fsync(options, dbg);
+ {
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ return s;
+}
+
+IOStatus TestFSDirectory::Close(const IOOptions& options, IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ IOStatus s = dir_->Close(options, dbg);
+ return s;
+}
+
+IOStatus TestFSDirectory::FsyncWithDirOptions(
+ const IOOptions& options, IODebugContext* dbg,
+ const DirFsyncOptions& dir_fsync_options) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ {
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ fs_->SyncDir(dirname_);
+ IOStatus s = dir_->FsyncWithDirOptions(options, dbg, dir_fsync_options);
+ {
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ return s;
+}
+
+TestFSWritableFile::TestFSWritableFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>&& f,
+ FaultInjectionTestFS* fs)
+ : state_(fname),
+ file_opts_(file_opts),
+ target_(std::move(f)),
+ writable_file_opened_(true),
+ fs_(fs) {
+ assert(target_ != nullptr);
+ state_.pos_ = 0;
+}
+
+TestFSWritableFile::~TestFSWritableFile() {
+ if (writable_file_opened_) {
+ Close(IOOptions(), nullptr).PermitUncheckedError();
+ }
+}
+
+IOStatus TestFSWritableFile::Append(const Slice& data, const IOOptions& options,
+ IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ if (target_->use_direct_io()) {
+ target_->Append(data, options, dbg).PermitUncheckedError();
+ } else {
+ state_.buffer_.append(data.data(), data.size());
+ state_.pos_ += data.size();
+ fs_->WritableFileAppended(state_);
+ }
+ IOStatus io_s = fs_->InjectWriteError(state_.filename_);
+ return io_s;
+}
+
+// By setting the IngestDataCorruptionBeforeWrite(), the data corruption is
+// simulated.
+IOStatus TestFSWritableFile::Append(
+ const Slice& data, const IOOptions& options,
+ const DataVerificationInfo& verification_info, IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ if (fs_->ShouldDataCorruptionBeforeWrite()) {
+ return IOStatus::Corruption("Data is corrupted!");
+ }
+
+ // Calculate the checksum
+ std::string checksum;
+ CalculateTypedChecksum(fs_->GetChecksumHandoffFuncType(), data.data(),
+ data.size(), &checksum);
+ if (fs_->GetChecksumHandoffFuncType() != ChecksumType::kNoChecksum &&
+ checksum != verification_info.checksum.ToString()) {
+ std::string msg = "Data is corrupted! Origin data checksum: " +
+ verification_info.checksum.ToString() +
+ "current data checksum: " + checksum;
+ return IOStatus::Corruption(msg);
+ }
+ if (target_->use_direct_io()) {
+ target_->Append(data, options, dbg).PermitUncheckedError();
+ } else {
+ state_.buffer_.append(data.data(), data.size());
+ state_.pos_ += data.size();
+ fs_->WritableFileAppended(state_);
+ }
+ IOStatus io_s = fs_->InjectWriteError(state_.filename_);
+ return io_s;
+}
+
+IOStatus TestFSWritableFile::PositionedAppend(
+ const Slice& data, uint64_t offset, const IOOptions& options,
+ const DataVerificationInfo& verification_info, IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ if (fs_->ShouldDataCorruptionBeforeWrite()) {
+ return IOStatus::Corruption("Data is corrupted!");
+ }
+
+ // Calculate the checksum
+ std::string checksum;
+ CalculateTypedChecksum(fs_->GetChecksumHandoffFuncType(), data.data(),
+ data.size(), &checksum);
+ if (fs_->GetChecksumHandoffFuncType() != ChecksumType::kNoChecksum &&
+ checksum != verification_info.checksum.ToString()) {
+ std::string msg = "Data is corrupted! Origin data checksum: " +
+ verification_info.checksum.ToString() +
+ "current data checksum: " + checksum;
+ return IOStatus::Corruption(msg);
+ }
+ target_->PositionedAppend(data, offset, options, dbg);
+ IOStatus io_s = fs_->InjectWriteError(state_.filename_);
+ return io_s;
+}
+
+IOStatus TestFSWritableFile::Close(const IOOptions& options,
+ IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ {
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ writable_file_opened_ = false;
+ IOStatus io_s;
+ if (!target_->use_direct_io()) {
+ io_s = target_->Append(state_.buffer_, options, dbg);
+ }
+ if (io_s.ok()) {
+ state_.buffer_.resize(0);
+ // Ignore sync errors
+ target_->Sync(options, dbg).PermitUncheckedError();
+ io_s = target_->Close(options, dbg);
+ }
+ if (io_s.ok()) {
+ fs_->WritableFileClosed(state_);
+ IOStatus in_s = fs_->InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ return io_s;
+}
+
+IOStatus TestFSWritableFile::Flush(const IOOptions&, IODebugContext*) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ if (fs_->IsFilesystemActive()) {
+ state_.pos_at_last_flush_ = state_.pos_;
+ }
+ return IOStatus::OK();
+}
+
+IOStatus TestFSWritableFile::Sync(const IOOptions& options,
+ IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ if (target_->use_direct_io()) {
+ // For Direct IO mode, we don't buffer anything in TestFSWritableFile.
+ // So just return
+ return IOStatus::OK();
+ }
+ IOStatus io_s = target_->Append(state_.buffer_, options, dbg);
+ state_.buffer_.resize(0);
+ // Ignore sync errors
+ target_->Sync(options, dbg).PermitUncheckedError();
+ state_.pos_at_last_sync_ = state_.pos_;
+ fs_->WritableFileSynced(state_);
+ return io_s;
+}
+
+IOStatus TestFSWritableFile::RangeSync(uint64_t offset, uint64_t nbytes,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ MutexLock l(&mutex_);
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ // Assumes caller passes consecutive byte ranges.
+ uint64_t sync_limit = offset + nbytes;
+ uint64_t buf_begin =
+ state_.pos_at_last_sync_ < 0 ? 0 : state_.pos_at_last_sync_;
+
+ IOStatus io_s;
+ if (sync_limit < buf_begin) {
+ return io_s;
+ }
+ uint64_t num_to_sync = std::min(static_cast<uint64_t>(state_.buffer_.size()),
+ sync_limit - buf_begin);
+ Slice buf_to_sync(state_.buffer_.data(), num_to_sync);
+ io_s = target_->Append(buf_to_sync, options, dbg);
+ state_.buffer_ = state_.buffer_.substr(num_to_sync);
+ // Ignore sync errors
+ target_->RangeSync(offset, nbytes, options, dbg).PermitUncheckedError();
+ state_.pos_at_last_sync_ = offset + num_to_sync;
+ fs_->WritableFileSynced(state_);
+ return io_s;
+}
+
+TestFSRandomRWFile::TestFSRandomRWFile(const std::string& /*fname*/,
+ std::unique_ptr<FSRandomRWFile>&& f,
+ FaultInjectionTestFS* fs)
+ : target_(std::move(f)), file_opened_(true), fs_(fs) {
+ assert(target_ != nullptr);
+}
+
+TestFSRandomRWFile::~TestFSRandomRWFile() {
+ if (file_opened_) {
+ Close(IOOptions(), nullptr).PermitUncheckedError();
+ }
+}
+
+IOStatus TestFSRandomRWFile::Write(uint64_t offset, const Slice& data,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ return target_->Write(offset, data, options, dbg);
+}
+
+IOStatus TestFSRandomRWFile::Read(uint64_t offset, size_t n,
+ const IOOptions& options, Slice* result,
+ char* scratch, IODebugContext* dbg) const {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ return target_->Read(offset, n, options, result, scratch, dbg);
+}
+
+IOStatus TestFSRandomRWFile::Close(const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ file_opened_ = false;
+ return target_->Close(options, dbg);
+}
+
+IOStatus TestFSRandomRWFile::Flush(const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ return target_->Flush(options, dbg);
+}
+
+IOStatus TestFSRandomRWFile::Sync(const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ return target_->Sync(options, dbg);
+}
+
+TestFSRandomAccessFile::TestFSRandomAccessFile(
+ const std::string& /*fname*/, std::unique_ptr<FSRandomAccessFile>&& f,
+ FaultInjectionTestFS* fs)
+ : target_(std::move(f)), fs_(fs) {
+ assert(target_ != nullptr);
+}
+
+IOStatus TestFSRandomAccessFile::Read(uint64_t offset, size_t n,
+ const IOOptions& options, Slice* result,
+ char* scratch,
+ IODebugContext* dbg) const {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ IOStatus s = target_->Read(offset, n, options, result, scratch, dbg);
+ if (s.ok()) {
+ s = fs_->InjectThreadSpecificReadError(
+ FaultInjectionTestFS::ErrorOperation::kRead, result, use_direct_io(),
+ scratch, /*need_count_increase=*/true, /*fault_injected=*/nullptr);
+ }
+ if (s.ok() && fs_->ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected read error");
+ }
+ return s;
+}
+
+IOStatus TestFSRandomAccessFile::MultiRead(FSReadRequest* reqs, size_t num_reqs,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!fs_->IsFilesystemActive()) {
+ return fs_->GetError();
+ }
+ IOStatus s = target_->MultiRead(reqs, num_reqs, options, dbg);
+ bool injected_error = false;
+ for (size_t i = 0; i < num_reqs; i++) {
+ if (!reqs[i].status.ok()) {
+ // Already seeing an error.
+ break;
+ }
+ bool this_injected_error;
+ reqs[i].status = fs_->InjectThreadSpecificReadError(
+ FaultInjectionTestFS::ErrorOperation::kMultiReadSingleReq,
+ &(reqs[i].result), use_direct_io(), reqs[i].scratch,
+ /*need_count_increase=*/true,
+ /*fault_injected=*/&this_injected_error);
+ injected_error |= this_injected_error;
+ }
+ if (s.ok()) {
+ s = fs_->InjectThreadSpecificReadError(
+ FaultInjectionTestFS::ErrorOperation::kMultiRead, nullptr,
+ use_direct_io(), nullptr, /*need_count_increase=*/!injected_error,
+ /*fault_injected=*/nullptr);
+ }
+ if (s.ok() && fs_->ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected read error");
+ }
+ return s;
+}
+
+size_t TestFSRandomAccessFile::GetUniqueId(char* id, size_t max_size) const {
+ if (fs_->ShouldFailGetUniqueId()) {
+ return 0;
+ } else {
+ return target_->GetUniqueId(id, max_size);
+ }
+}
+IOStatus TestFSSequentialFile::Read(size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) {
+ IOStatus s = target()->Read(n, options, result, scratch, dbg);
+ if (s.ok() && fs_->ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected seq read error");
+ }
+ return s;
+}
+
+IOStatus TestFSSequentialFile::PositionedRead(uint64_t offset, size_t n,
+ const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) {
+ IOStatus s =
+ target()->PositionedRead(offset, n, options, result, scratch, dbg);
+ if (s.ok() && fs_->ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected seq positioned read error");
+ }
+ return s;
+}
+
+IOStatus FaultInjectionTestFS::NewDirectory(
+ const std::string& name, const IOOptions& options,
+ std::unique_ptr<FSDirectory>* result, IODebugContext* dbg) {
+ std::unique_ptr<FSDirectory> r;
+ IOStatus io_s = target()->NewDirectory(name, options, &r, dbg);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ result->reset(
+ new TestFSDirectory(this, TestFSTrimDirname(name), r.release()));
+ return IOStatus::OK();
+}
+
+IOStatus FaultInjectionTestFS::NewWritableFile(
+ const std::string& fname, const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* result, IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ if (ShouldUseDiretWritable(fname)) {
+ return target()->NewWritableFile(fname, file_opts, result, dbg);
+ }
+
+ IOStatus io_s = target()->NewWritableFile(fname, file_opts, result, dbg);
+ if (io_s.ok()) {
+ result->reset(
+ new TestFSWritableFile(fname, file_opts, std::move(*result), this));
+ // WritableFileWriter* file is opened
+ // again then it will be truncated - so forget our saved state.
+ UntrackFile(fname);
+ {
+ MutexLock l(&mutex_);
+ open_managed_files_.insert(fname);
+ auto dir_and_name = TestFSGetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ // The new file could overwrite an old one. Here we simplify
+ // the implementation by assuming no file of this name after
+ // dropping unsynced files.
+ list[dir_and_name.second] = kNewFileNoOverwrite;
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::ReopenWritableFile(
+ const std::string& fname, const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* result, IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ if (ShouldUseDiretWritable(fname)) {
+ return target()->ReopenWritableFile(fname, file_opts, result, dbg);
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ bool exists;
+ IOStatus io_s,
+ exists_s = target()->FileExists(fname, IOOptions(), nullptr /* dbg */);
+ if (exists_s.IsNotFound()) {
+ exists = false;
+ } else if (exists_s.ok()) {
+ exists = true;
+ } else {
+ io_s = exists_s;
+ exists = false;
+ }
+
+ if (io_s.ok()) {
+ io_s = target()->ReopenWritableFile(fname, file_opts, result, dbg);
+ }
+
+ // Only track files we created. Files created outside of this
+ // `FaultInjectionTestFS` are not eligible for tracking/data dropping
+ // (for example, they may contain data a previous db_stress run expects to
+ // be recovered). This could be extended to track/drop data appended once
+ // the file is under `FaultInjectionTestFS`'s control.
+ if (io_s.ok()) {
+ bool should_track;
+ {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(fname) != db_file_state_.end()) {
+ // It was written by this `FileSystem` earlier.
+ assert(exists);
+ should_track = true;
+ } else if (!exists) {
+ // It was created by this `FileSystem` just now.
+ should_track = true;
+ open_managed_files_.insert(fname);
+ auto dir_and_name = TestFSGetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ list[dir_and_name.second] = kNewFileNoOverwrite;
+ } else {
+ should_track = false;
+ }
+ }
+ if (should_track) {
+ result->reset(
+ new TestFSWritableFile(fname, file_opts, std::move(*result), this));
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::NewRandomRWFile(
+ const std::string& fname, const FileOptions& file_opts,
+ std::unique_ptr<FSRandomRWFile>* result, IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ if (ShouldUseDiretWritable(fname)) {
+ return target()->NewRandomRWFile(fname, file_opts, result, dbg);
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ IOStatus io_s = target()->NewRandomRWFile(fname, file_opts, result, dbg);
+ if (io_s.ok()) {
+ result->reset(new TestFSRandomRWFile(fname, std::move(*result), this));
+ // WritableFileWriter* file is opened
+ // again then it will be truncated - so forget our saved state.
+ UntrackFile(fname);
+ {
+ MutexLock l(&mutex_);
+ open_managed_files_.insert(fname);
+ auto dir_and_name = TestFSGetDirAndName(fname);
+ auto& list = dir_to_new_files_since_last_sync_[dir_and_name.first];
+ // It could be overwriting an old file, but we simplify the
+ // implementation by ignoring it.
+ list[dir_and_name.second] = kNewFileNoOverwrite;
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::NewRandomAccessFile(
+ const std::string& fname, const FileOptions& file_opts,
+ std::unique_ptr<FSRandomAccessFile>* result, IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ if (ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected error when open random access file");
+ }
+ IOStatus io_s = InjectThreadSpecificReadError(ErrorOperation::kOpen, nullptr,
+ false, nullptr,
+ /*need_count_increase=*/true,
+ /*fault_injected=*/nullptr);
+ if (io_s.ok()) {
+ io_s = target()->NewRandomAccessFile(fname, file_opts, result, dbg);
+ }
+ if (io_s.ok()) {
+ result->reset(new TestFSRandomAccessFile(fname, std::move(*result), this));
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::NewSequentialFile(
+ const std::string& fname, const FileOptions& file_opts,
+ std::unique_ptr<FSSequentialFile>* result, IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+
+ if (ShouldInjectRandomReadError()) {
+ return IOStatus::IOError("Injected read error when creating seq file");
+ }
+ IOStatus io_s = target()->NewSequentialFile(fname, file_opts, result, dbg);
+ if (io_s.ok()) {
+ result->reset(new TestFSSequentialFile(std::move(*result), this));
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::DeleteFile(const std::string& f,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ IOStatus io_s = FileSystemWrapper::DeleteFile(f, options, dbg);
+ if (io_s.ok()) {
+ UntrackFile(f);
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::RenameFile(const std::string& s,
+ const std::string& t,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ // We preserve contents of overwritten files up to a size threshold.
+ // We could keep previous file in another name, but we need to worry about
+ // garbage collect the those files. We do it if it is needed later.
+ // We ignore I/O errors here for simplicity.
+ std::string previous_contents = kNewFileNoOverwrite;
+ if (target()->FileExists(t, IOOptions(), nullptr).ok()) {
+ uint64_t file_size;
+ if (target()->GetFileSize(t, IOOptions(), &file_size, nullptr).ok() &&
+ file_size < 1024) {
+ ReadFileToString(target(), t, &previous_contents).PermitUncheckedError();
+ }
+ }
+ IOStatus io_s = FileSystemWrapper::RenameFile(s, t, options, dbg);
+
+ if (io_s.ok()) {
+ {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(s) != db_file_state_.end()) {
+ db_file_state_[t] = db_file_state_[s];
+ db_file_state_.erase(s);
+ }
+
+ auto sdn = TestFSGetDirAndName(s);
+ auto tdn = TestFSGetDirAndName(t);
+ if (dir_to_new_files_since_last_sync_[sdn.first].erase(sdn.second) != 0) {
+ auto& tlist = dir_to_new_files_since_last_sync_[tdn.first];
+ assert(tlist.find(tdn.second) == tlist.end());
+ tlist[tdn.second] = previous_contents;
+ }
+ }
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::LinkFile(const std::string& s,
+ const std::string& t,
+ const IOOptions& options,
+ IODebugContext* dbg) {
+ if (!IsFilesystemActive()) {
+ return GetError();
+ }
+ {
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ // Using the value in `dir_to_new_files_since_last_sync_` for the source file
+ // may be a more reasonable choice.
+ std::string previous_contents = kNewFileNoOverwrite;
+
+ IOStatus io_s = FileSystemWrapper::LinkFile(s, t, options, dbg);
+
+ if (io_s.ok()) {
+ {
+ MutexLock l(&mutex_);
+ if (db_file_state_.find(s) != db_file_state_.end()) {
+ db_file_state_[t] = db_file_state_[s];
+ }
+
+ auto sdn = TestFSGetDirAndName(s);
+ auto tdn = TestFSGetDirAndName(t);
+ if (dir_to_new_files_since_last_sync_[sdn.first].find(sdn.second) !=
+ dir_to_new_files_since_last_sync_[sdn.first].end()) {
+ auto& tlist = dir_to_new_files_since_last_sync_[tdn.first];
+ assert(tlist.find(tdn.second) == tlist.end());
+ tlist[tdn.second] = previous_contents;
+ }
+ }
+ IOStatus in_s = InjectMetadataWriteError();
+ if (!in_s.ok()) {
+ return in_s;
+ }
+ }
+
+ return io_s;
+}
+
+void FaultInjectionTestFS::WritableFileClosed(const FSFileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ db_file_state_[state.filename_] = state;
+ open_managed_files_.erase(state.filename_);
+ }
+}
+
+void FaultInjectionTestFS::WritableFileSynced(const FSFileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ if (db_file_state_.find(state.filename_) == db_file_state_.end()) {
+ db_file_state_.insert(std::make_pair(state.filename_, state));
+ } else {
+ db_file_state_[state.filename_] = state;
+ }
+ }
+}
+
+void FaultInjectionTestFS::WritableFileAppended(const FSFileState& state) {
+ MutexLock l(&mutex_);
+ if (open_managed_files_.find(state.filename_) != open_managed_files_.end()) {
+ if (db_file_state_.find(state.filename_) == db_file_state_.end()) {
+ db_file_state_.insert(std::make_pair(state.filename_, state));
+ } else {
+ db_file_state_[state.filename_] = state;
+ }
+ }
+}
+
+IOStatus FaultInjectionTestFS::DropUnsyncedFileData() {
+ IOStatus io_s;
+ MutexLock l(&mutex_);
+ for (std::map<std::string, FSFileState>::iterator it = db_file_state_.begin();
+ io_s.ok() && it != db_file_state_.end(); ++it) {
+ FSFileState& fs_state = it->second;
+ if (!fs_state.IsFullySynced()) {
+ io_s = fs_state.DropUnsyncedData();
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::DropRandomUnsyncedFileData(Random* rnd) {
+ IOStatus io_s;
+ MutexLock l(&mutex_);
+ for (std::map<std::string, FSFileState>::iterator it = db_file_state_.begin();
+ io_s.ok() && it != db_file_state_.end(); ++it) {
+ FSFileState& fs_state = it->second;
+ if (!fs_state.IsFullySynced()) {
+ io_s = fs_state.DropRandomUnsyncedData(rnd);
+ }
+ }
+ return io_s;
+}
+
+IOStatus FaultInjectionTestFS::DeleteFilesCreatedAfterLastDirSync(
+ const IOOptions& options, IODebugContext* dbg) {
+ // Because DeleteFile access this container make a copy to avoid deadlock
+ std::map<std::string, std::map<std::string, std::string>> map_copy;
+ {
+ MutexLock l(&mutex_);
+ map_copy.insert(dir_to_new_files_since_last_sync_.begin(),
+ dir_to_new_files_since_last_sync_.end());
+ }
+
+ for (auto& pair : map_copy) {
+ for (auto& file_pair : pair.second) {
+ if (file_pair.second == kNewFileNoOverwrite) {
+ IOStatus io_s =
+ DeleteFile(pair.first + "/" + file_pair.first, options, dbg);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ } else {
+ IOStatus io_s =
+ WriteStringToFile(target(), file_pair.second,
+ pair.first + "/" + file_pair.first, true);
+ if (!io_s.ok()) {
+ return io_s;
+ }
+ }
+ }
+ }
+ return IOStatus::OK();
+}
+
+void FaultInjectionTestFS::ResetState() {
+ MutexLock l(&mutex_);
+ db_file_state_.clear();
+ dir_to_new_files_since_last_sync_.clear();
+ SetFilesystemActiveNoLock(true);
+}
+
+void FaultInjectionTestFS::UntrackFile(const std::string& f) {
+ MutexLock l(&mutex_);
+ auto dir_and_name = TestFSGetDirAndName(f);
+ dir_to_new_files_since_last_sync_[dir_and_name.first].erase(
+ dir_and_name.second);
+ db_file_state_.erase(f);
+ open_managed_files_.erase(f);
+}
+
+IOStatus FaultInjectionTestFS::InjectThreadSpecificReadError(
+ ErrorOperation op, Slice* result, bool direct_io, char* scratch,
+ bool need_count_increase, bool* fault_injected) {
+ bool dummy_bool;
+ bool& ret_fault_injected = fault_injected ? *fault_injected : dummy_bool;
+ ret_fault_injected = false;
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ if (ctx == nullptr || !ctx->enable_error_injection || !ctx->one_in) {
+ return IOStatus::OK();
+ }
+
+ if (ctx->rand.OneIn(ctx->one_in)) {
+ if (ctx->count == 0) {
+ ctx->message = "";
+ }
+ if (need_count_increase) {
+ ctx->count++;
+ }
+ if (ctx->callstack) {
+ free(ctx->callstack);
+ }
+ ctx->callstack = port::SaveStack(&ctx->frames);
+
+ if (op != ErrorOperation::kMultiReadSingleReq) {
+ // Likely non-per read status code for MultiRead
+ ctx->message += "error; ";
+ ret_fault_injected = true;
+ return IOStatus::IOError();
+ } else if (Random::GetTLSInstance()->OneIn(8)) {
+ assert(result);
+ // For a small chance, set the failure to status but turn the
+ // result to be empty, which is supposed to be caught for a check.
+ *result = Slice();
+ ctx->message += "inject empty result; ";
+ ret_fault_injected = true;
+ } else if (!direct_io && Random::GetTLSInstance()->OneIn(7) &&
+ scratch != nullptr && result->data() == scratch) {
+ assert(result);
+ // With direct I/O, many extra bytes might be read so corrupting
+ // one byte might not cause checksum mismatch. Skip checksum
+ // corruption injection.
+ // We only corrupt data if the result is filled to `scratch`. For other
+ // cases, the data might not be able to be modified (e.g mmaped files)
+ // or has unintended side effects.
+ // For a small chance, set the failure to status but corrupt the
+ // result in a way that checksum checking is supposed to fail.
+ // Corrupt the last byte, which is supposed to be a checksum byte
+ // It would work for CRC. Not 100% sure for xxhash and will adjust
+ // if it is not the case.
+ const_cast<char*>(result->data())[result->size() - 1]++;
+ ctx->message += "corrupt last byte; ";
+ ret_fault_injected = true;
+ } else {
+ ctx->message += "error result multiget single; ";
+ ret_fault_injected = true;
+ return IOStatus::IOError();
+ }
+ }
+ return IOStatus::OK();
+}
+
+bool FaultInjectionTestFS::TryParseFileName(const std::string& file_name,
+ uint64_t* number, FileType* type) {
+ std::size_t found = file_name.find_last_of("/");
+ std::string file = file_name.substr(found);
+ return ParseFileName(file, number, type);
+}
+
+IOStatus FaultInjectionTestFS::InjectWriteError(const std::string& file_name) {
+ MutexLock l(&mutex_);
+ if (!enable_write_error_injection_ || !write_error_one_in_) {
+ return IOStatus::OK();
+ }
+ bool allowed_type = false;
+
+ if (inject_for_all_file_types_) {
+ allowed_type = true;
+ } else {
+ uint64_t number;
+ FileType cur_type = kTempFile;
+ if (TryParseFileName(file_name, &number, &cur_type)) {
+ for (const auto& type : write_error_allowed_types_) {
+ if (cur_type == type) {
+ allowed_type = true;
+ }
+ }
+ }
+ }
+
+ if (allowed_type) {
+ if (write_error_rand_.OneIn(write_error_one_in_)) {
+ return GetError();
+ }
+ }
+ return IOStatus::OK();
+}
+
+IOStatus FaultInjectionTestFS::InjectMetadataWriteError() {
+ {
+ MutexLock l(&mutex_);
+ if (!enable_metadata_write_error_injection_ ||
+ !metadata_write_error_one_in_ ||
+ !write_error_rand_.OneIn(metadata_write_error_one_in_)) {
+ return IOStatus::OK();
+ }
+ }
+ TEST_SYNC_POINT("FaultInjectionTestFS::InjectMetadataWriteError:Injected");
+ return IOStatus::IOError();
+}
+
+void FaultInjectionTestFS::PrintFaultBacktrace() {
+#if defined(OS_LINUX)
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ if (ctx == nullptr) {
+ return;
+ }
+ fprintf(stderr, "Injected error type = %d\n", ctx->type);
+ fprintf(stderr, "Message: %s\n", ctx->message.c_str());
+ port::PrintAndFreeStack(ctx->callstack, ctx->frames);
+ ctx->callstack = nullptr;
+#endif
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/fault_injection_fs.h b/src/rocksdb/utilities/fault_injection_fs.h
new file mode 100644
index 000000000..53c9ccb6f
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_fs.h
@@ -0,0 +1,584 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright 2014 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// This test uses a custom FileSystem to keep track of the state of a file
+// system the last "Sync". The data being written is cached in a "buffer".
+// Only when "Sync" is called, the data will be persistent. It can similate
+// file data loss (or entire files) not protected by a "Sync". For any of the
+// FileSystem related operations, by specify the "IOStatus Error", a specific
+// error can be returned when file system is not activated.
+
+#pragma once
+
+#include <algorithm>
+#include <map>
+#include <set>
+#include <string>
+
+#include "file/filename.h"
+#include "rocksdb/file_system.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/thread_local.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TestFSWritableFile;
+class FaultInjectionTestFS;
+
+struct FSFileState {
+ std::string filename_;
+ ssize_t pos_;
+ ssize_t pos_at_last_sync_;
+ ssize_t pos_at_last_flush_;
+ std::string buffer_;
+
+ explicit FSFileState(const std::string& filename)
+ : filename_(filename),
+ pos_(-1),
+ pos_at_last_sync_(-1),
+ pos_at_last_flush_(-1) {}
+
+ FSFileState() : pos_(-1), pos_at_last_sync_(-1), pos_at_last_flush_(-1) {}
+
+ bool IsFullySynced() const { return pos_ <= 0 || pos_ == pos_at_last_sync_; }
+
+ IOStatus DropUnsyncedData();
+
+ IOStatus DropRandomUnsyncedData(Random* rand);
+};
+
+// A wrapper around WritableFileWriter* file
+// is written to or sync'ed.
+class TestFSWritableFile : public FSWritableFile {
+ public:
+ explicit TestFSWritableFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>&& f,
+ FaultInjectionTestFS* fs);
+ virtual ~TestFSWritableFile();
+ virtual IOStatus Append(const Slice& data, const IOOptions&,
+ IODebugContext*) override;
+ virtual IOStatus Append(const Slice& data, const IOOptions& options,
+ const DataVerificationInfo& verification_info,
+ IODebugContext* dbg) override;
+ virtual IOStatus Truncate(uint64_t size, const IOOptions& options,
+ IODebugContext* dbg) override {
+ return target_->Truncate(size, options, dbg);
+ }
+ virtual IOStatus Close(const IOOptions& options,
+ IODebugContext* dbg) override;
+ virtual IOStatus Flush(const IOOptions&, IODebugContext*) override;
+ virtual IOStatus Sync(const IOOptions& options, IODebugContext* dbg) override;
+ virtual IOStatus RangeSync(uint64_t /*offset*/, uint64_t /*nbytes*/,
+ const IOOptions& options,
+ IODebugContext* dbg) override;
+ virtual bool IsSyncThreadSafe() const override { return true; }
+ virtual IOStatus PositionedAppend(const Slice& data, uint64_t offset,
+ const IOOptions& options,
+ IODebugContext* dbg) override {
+ return target_->PositionedAppend(data, offset, options, dbg);
+ }
+ IOStatus PositionedAppend(const Slice& data, uint64_t offset,
+ const IOOptions& options,
+ const DataVerificationInfo& verification_info,
+ IODebugContext* dbg) override;
+ virtual size_t GetRequiredBufferAlignment() const override {
+ return target_->GetRequiredBufferAlignment();
+ }
+ virtual bool use_direct_io() const override {
+ return target_->use_direct_io();
+ };
+
+ private:
+ FSFileState state_; // Need protection by mutex_
+ FileOptions file_opts_;
+ std::unique_ptr<FSWritableFile> target_;
+ bool writable_file_opened_;
+ FaultInjectionTestFS* fs_;
+ port::Mutex mutex_;
+};
+
+// A wrapper around WritableFileWriter* file
+// is written to or sync'ed.
+class TestFSRandomRWFile : public FSRandomRWFile {
+ public:
+ explicit TestFSRandomRWFile(const std::string& fname,
+ std::unique_ptr<FSRandomRWFile>&& f,
+ FaultInjectionTestFS* fs);
+ virtual ~TestFSRandomRWFile();
+ IOStatus Write(uint64_t offset, const Slice& data, const IOOptions& options,
+ IODebugContext* dbg) override;
+ IOStatus Read(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) const override;
+ IOStatus Close(const IOOptions& options, IODebugContext* dbg) override;
+ IOStatus Flush(const IOOptions& options, IODebugContext* dbg) override;
+ IOStatus Sync(const IOOptions& options, IODebugContext* dbg) override;
+ size_t GetRequiredBufferAlignment() const override {
+ return target_->GetRequiredBufferAlignment();
+ }
+ bool use_direct_io() const override { return target_->use_direct_io(); };
+
+ private:
+ std::unique_ptr<FSRandomRWFile> target_;
+ bool file_opened_;
+ FaultInjectionTestFS* fs_;
+};
+
+class TestFSRandomAccessFile : public FSRandomAccessFile {
+ public:
+ explicit TestFSRandomAccessFile(const std::string& fname,
+ std::unique_ptr<FSRandomAccessFile>&& f,
+ FaultInjectionTestFS* fs);
+ ~TestFSRandomAccessFile() override {}
+ IOStatus Read(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) const override;
+ IOStatus MultiRead(FSReadRequest* reqs, size_t num_reqs,
+ const IOOptions& options, IODebugContext* dbg) override;
+ size_t GetRequiredBufferAlignment() const override {
+ return target_->GetRequiredBufferAlignment();
+ }
+ bool use_direct_io() const override { return target_->use_direct_io(); }
+
+ size_t GetUniqueId(char* id, size_t max_size) const override;
+
+ private:
+ std::unique_ptr<FSRandomAccessFile> target_;
+ FaultInjectionTestFS* fs_;
+};
+
+class TestFSSequentialFile : public FSSequentialFileOwnerWrapper {
+ public:
+ explicit TestFSSequentialFile(std::unique_ptr<FSSequentialFile>&& f,
+ FaultInjectionTestFS* fs)
+ : FSSequentialFileOwnerWrapper(std::move(f)), fs_(fs) {}
+ IOStatus Read(size_t n, const IOOptions& options, Slice* result,
+ char* scratch, IODebugContext* dbg) override;
+ IOStatus PositionedRead(uint64_t offset, size_t n, const IOOptions& options,
+ Slice* result, char* scratch,
+ IODebugContext* dbg) override;
+
+ private:
+ FaultInjectionTestFS* fs_;
+};
+
+class TestFSDirectory : public FSDirectory {
+ public:
+ explicit TestFSDirectory(FaultInjectionTestFS* fs, std::string dirname,
+ FSDirectory* dir)
+ : fs_(fs), dirname_(dirname), dir_(dir) {}
+ ~TestFSDirectory() {}
+
+ virtual IOStatus Fsync(const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ virtual IOStatus Close(const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ virtual IOStatus FsyncWithDirOptions(
+ const IOOptions& options, IODebugContext* dbg,
+ const DirFsyncOptions& dir_fsync_options) override;
+
+ private:
+ FaultInjectionTestFS* fs_;
+ std::string dirname_;
+ std::unique_ptr<FSDirectory> dir_;
+};
+
+class FaultInjectionTestFS : public FileSystemWrapper {
+ public:
+ explicit FaultInjectionTestFS(const std::shared_ptr<FileSystem>& base)
+ : FileSystemWrapper(base),
+ filesystem_active_(true),
+ filesystem_writable_(false),
+ thread_local_error_(new ThreadLocalPtr(DeleteThreadLocalErrorContext)),
+ enable_write_error_injection_(false),
+ enable_metadata_write_error_injection_(false),
+ write_error_rand_(0),
+ write_error_one_in_(0),
+ metadata_write_error_one_in_(0),
+ read_error_one_in_(0),
+ ingest_data_corruption_before_write_(false),
+ fail_get_file_unique_id_(false) {}
+ virtual ~FaultInjectionTestFS() { error_.PermitUncheckedError(); }
+
+ static const char* kClassName() { return "FaultInjectionTestFS"; }
+ const char* Name() const override { return kClassName(); }
+
+ IOStatus NewDirectory(const std::string& name, const IOOptions& options,
+ std::unique_ptr<FSDirectory>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewWritableFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus ReopenWritableFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSWritableFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewRandomRWFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSRandomRWFile>* result,
+ IODebugContext* dbg) override;
+
+ IOStatus NewRandomAccessFile(const std::string& fname,
+ const FileOptions& file_opts,
+ std::unique_ptr<FSRandomAccessFile>* result,
+ IODebugContext* dbg) override;
+ IOStatus NewSequentialFile(const std::string& f, const FileOptions& file_opts,
+ std::unique_ptr<FSSequentialFile>* r,
+ IODebugContext* dbg) override;
+
+ virtual IOStatus DeleteFile(const std::string& f, const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ virtual IOStatus RenameFile(const std::string& s, const std::string& t,
+ const IOOptions& options,
+ IODebugContext* dbg) override;
+
+ virtual IOStatus LinkFile(const std::string& src, const std::string& target,
+ const IOOptions& options,
+ IODebugContext* dbg) override;
+
+// Undef to eliminate clash on Windows
+#undef GetFreeSpace
+ virtual IOStatus GetFreeSpace(const std::string& path,
+ const IOOptions& options, uint64_t* disk_free,
+ IODebugContext* dbg) override {
+ IOStatus io_s;
+ if (!IsFilesystemActive() &&
+ error_.subcode() == IOStatus::SubCode::kNoSpace) {
+ *disk_free = 0;
+ } else {
+ io_s = target()->GetFreeSpace(path, options, disk_free, dbg);
+ }
+ return io_s;
+ }
+
+ void WritableFileClosed(const FSFileState& state);
+
+ void WritableFileSynced(const FSFileState& state);
+
+ void WritableFileAppended(const FSFileState& state);
+
+ IOStatus DropUnsyncedFileData();
+
+ IOStatus DropRandomUnsyncedFileData(Random* rnd);
+
+ IOStatus DeleteFilesCreatedAfterLastDirSync(const IOOptions& options,
+ IODebugContext* dbg);
+
+ void ResetState();
+
+ void UntrackFile(const std::string& f);
+
+ void SyncDir(const std::string& dirname) {
+ MutexLock l(&mutex_);
+ dir_to_new_files_since_last_sync_.erase(dirname);
+ }
+
+ // Setting the filesystem to inactive is the test equivalent to simulating a
+ // system reset. Setting to inactive will freeze our saved filesystem state so
+ // that it will stop being recorded. It can then be reset back to the state at
+ // the time of the reset.
+ bool IsFilesystemActive() {
+ MutexLock l(&mutex_);
+ return filesystem_active_;
+ }
+
+ // Setting filesystem_writable_ makes NewWritableFile. ReopenWritableFile,
+ // and NewRandomRWFile bypass FaultInjectionTestFS and go directly to the
+ // target FS
+ bool IsFilesystemDirectWritable() {
+ MutexLock l(&mutex_);
+ return filesystem_writable_;
+ }
+ bool ShouldUseDiretWritable(const std::string& file_name) {
+ MutexLock l(&mutex_);
+ if (filesystem_writable_) {
+ return true;
+ }
+ FileType file_type = kTempFile;
+ uint64_t file_number = 0;
+ if (!TryParseFileName(file_name, &file_number, &file_type)) {
+ return false;
+ }
+ return skip_direct_writable_types_.find(file_type) !=
+ skip_direct_writable_types_.end();
+ }
+ void SetFilesystemActiveNoLock(
+ bool active, IOStatus error = IOStatus::Corruption("Not active")) {
+ error.PermitUncheckedError();
+ filesystem_active_ = active;
+ if (!active) {
+ error_ = error;
+ }
+ }
+ void SetFilesystemActive(
+ bool active, IOStatus error = IOStatus::Corruption("Not active")) {
+ MutexLock l(&mutex_);
+ error.PermitUncheckedError();
+ SetFilesystemActiveNoLock(active, error);
+ }
+ void SetFilesystemDirectWritable(bool writable) {
+ MutexLock l(&mutex_);
+ filesystem_writable_ = writable;
+ }
+ void AssertNoOpenFile() { assert(open_managed_files_.empty()); }
+
+ IOStatus GetError() { return error_; }
+
+ void SetFileSystemIOError(IOStatus io_error) {
+ MutexLock l(&mutex_);
+ io_error.PermitUncheckedError();
+ error_ = io_error;
+ }
+
+ // To simulate the data corruption before data is written in FS
+ void IngestDataCorruptionBeforeWrite() {
+ MutexLock l(&mutex_);
+ ingest_data_corruption_before_write_ = true;
+ }
+
+ void NoDataCorruptionBeforeWrite() {
+ MutexLock l(&mutex_);
+ ingest_data_corruption_before_write_ = false;
+ }
+
+ bool ShouldDataCorruptionBeforeWrite() {
+ MutexLock l(&mutex_);
+ return ingest_data_corruption_before_write_;
+ }
+
+ void SetChecksumHandoffFuncType(const ChecksumType& func_type) {
+ MutexLock l(&mutex_);
+ checksum_handoff_func_tpye_ = func_type;
+ }
+
+ const ChecksumType& GetChecksumHandoffFuncType() {
+ MutexLock l(&mutex_);
+ return checksum_handoff_func_tpye_;
+ }
+
+ void SetFailGetUniqueId(bool flag) {
+ MutexLock l(&mutex_);
+ fail_get_file_unique_id_ = flag;
+ }
+
+ bool ShouldFailGetUniqueId() {
+ MutexLock l(&mutex_);
+ return fail_get_file_unique_id_;
+ }
+
+ // Specify what the operation, so we can inject the right type of error
+ enum ErrorOperation : char {
+ kRead = 0,
+ kMultiReadSingleReq = 1,
+ kMultiRead = 2,
+ kOpen,
+ };
+
+ // Set thread-local parameters for error injection. The first argument,
+ // seed is the seed for the random number generator, and one_in determines
+ // the probability of injecting error (i.e an error is injected with
+ // 1/one_in probability)
+ void SetThreadLocalReadErrorContext(uint32_t seed, int one_in) {
+ struct ErrorContext* ctx =
+ static_cast<struct ErrorContext*>(thread_local_error_->Get());
+ if (ctx == nullptr) {
+ ctx = new ErrorContext(seed);
+ thread_local_error_->Reset(ctx);
+ }
+ ctx->one_in = one_in;
+ ctx->count = 0;
+ }
+
+ static void DeleteThreadLocalErrorContext(void* p) {
+ ErrorContext* ctx = static_cast<ErrorContext*>(p);
+ delete ctx;
+ }
+
+ // This is to set the parameters for the write error injection.
+ // seed is the seed for the random number generator, and one_in determines
+ // the probability of injecting error (i.e an error is injected with
+ // 1/one_in probability). For write error, we can specify the error we
+ // want to inject. Types decides the file types we want to inject the
+ // error (e.g., Wal files, SST files), which is empty by default.
+ void SetRandomWriteError(uint32_t seed, int one_in, IOStatus error,
+ bool inject_for_all_file_types,
+ const std::vector<FileType>& types) {
+ MutexLock l(&mutex_);
+ Random tmp_rand(seed);
+ error.PermitUncheckedError();
+ error_ = error;
+ write_error_rand_ = tmp_rand;
+ write_error_one_in_ = one_in;
+ inject_for_all_file_types_ = inject_for_all_file_types;
+ write_error_allowed_types_ = types;
+ }
+
+ void SetSkipDirectWritableTypes(const std::set<FileType>& types) {
+ MutexLock l(&mutex_);
+ skip_direct_writable_types_ = types;
+ }
+
+ void SetRandomMetadataWriteError(int one_in) {
+ MutexLock l(&mutex_);
+ metadata_write_error_one_in_ = one_in;
+ }
+ // If the value is not 0, it is enabled. Otherwise, it is disabled.
+ void SetRandomReadError(int one_in) { read_error_one_in_ = one_in; }
+
+ bool ShouldInjectRandomReadError() {
+ return read_error_one_in() &&
+ Random::GetTLSInstance()->OneIn(read_error_one_in());
+ }
+
+ // Inject an write error with randomlized parameter and the predefined
+ // error type. Only the allowed file types will inject the write error
+ IOStatus InjectWriteError(const std::string& file_name);
+
+ // Ingest error to metadata operations.
+ IOStatus InjectMetadataWriteError();
+
+ // Inject an error. For a READ operation, a status of IOError(), a
+ // corruption in the contents of scratch, or truncation of slice
+ // are the types of error with equal probability. For OPEN,
+ // its always an IOError.
+ // fault_injected returns whether a fault is injected. It is needed
+ // because some fault is inected with IOStatus to be OK.
+ IOStatus InjectThreadSpecificReadError(ErrorOperation op, Slice* slice,
+ bool direct_io, char* scratch,
+ bool need_count_increase,
+ bool* fault_injected);
+
+ // Get the count of how many times we injected since the previous call
+ int GetAndResetErrorCount() {
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ int count = 0;
+ if (ctx != nullptr) {
+ count = ctx->count;
+ ctx->count = 0;
+ }
+ return count;
+ }
+
+ void EnableErrorInjection() {
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ if (ctx) {
+ ctx->enable_error_injection = true;
+ }
+ }
+
+ void EnableWriteErrorInjection() {
+ MutexLock l(&mutex_);
+ enable_write_error_injection_ = true;
+ }
+ void EnableMetadataWriteErrorInjection() {
+ MutexLock l(&mutex_);
+ enable_metadata_write_error_injection_ = true;
+ }
+
+ void DisableWriteErrorInjection() {
+ MutexLock l(&mutex_);
+ enable_write_error_injection_ = false;
+ }
+
+ void DisableErrorInjection() {
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ if (ctx) {
+ ctx->enable_error_injection = false;
+ }
+ }
+
+ void DisableMetadataWriteErrorInjection() {
+ MutexLock l(&mutex_);
+ enable_metadata_write_error_injection_ = false;
+ }
+
+ int read_error_one_in() const { return read_error_one_in_.load(); }
+
+ int write_error_one_in() const { return write_error_one_in_; }
+
+ // We capture a backtrace every time a fault is injected, for debugging
+ // purposes. This call prints the backtrace to stderr and frees the
+ // saved callstack
+ void PrintFaultBacktrace();
+
+ private:
+ port::Mutex mutex_;
+ std::map<std::string, FSFileState> db_file_state_;
+ std::set<std::string> open_managed_files_;
+ // directory -> (file name -> file contents to recover)
+ // When data is recovered from unsyned parent directory, the files with
+ // empty file contents to recover is deleted. Those with non-empty ones
+ // will be recovered to content accordingly.
+ std::unordered_map<std::string, std::map<std::string, std::string>>
+ dir_to_new_files_since_last_sync_;
+ bool filesystem_active_; // Record flushes, syncs, writes
+ bool filesystem_writable_; // Bypass FaultInjectionTestFS and go directly
+ // to underlying FS for writable files
+ IOStatus error_;
+
+ enum ErrorType : int {
+ kErrorTypeStatus = 0,
+ kErrorTypeCorruption,
+ kErrorTypeTruncated,
+ kErrorTypeMax
+ };
+
+ struct ErrorContext {
+ Random rand;
+ int one_in;
+ int count;
+ bool enable_error_injection;
+ void* callstack;
+ std::string message;
+ int frames;
+ ErrorType type;
+
+ explicit ErrorContext(uint32_t seed)
+ : rand(seed),
+ enable_error_injection(false),
+ callstack(nullptr),
+ frames(0) {}
+ ~ErrorContext() {
+ if (callstack) {
+ free(callstack);
+ }
+ }
+ };
+
+ std::unique_ptr<ThreadLocalPtr> thread_local_error_;
+ bool enable_write_error_injection_;
+ bool enable_metadata_write_error_injection_;
+ Random write_error_rand_;
+ int write_error_one_in_;
+ int metadata_write_error_one_in_;
+ std::atomic<int> read_error_one_in_;
+ bool inject_for_all_file_types_;
+ std::vector<FileType> write_error_allowed_types_;
+ // File types where direct writable is skipped.
+ std::set<FileType> skip_direct_writable_types_;
+ bool ingest_data_corruption_before_write_;
+ ChecksumType checksum_handoff_func_tpye_;
+ bool fail_get_file_unique_id_;
+
+ // Extract number of type from file name. Return false if failing to fine
+ // them.
+ bool TryParseFileName(const std::string& file_name, uint64_t* number,
+ FileType* type);
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/fault_injection_secondary_cache.cc b/src/rocksdb/utilities/fault_injection_secondary_cache.cc
new file mode 100644
index 000000000..2758c2a19
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_secondary_cache.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// This class implements a custom SecondaryCache that randomly injects an
+// error status into Inserts/Lookups based on a specified probability.
+
+#include "utilities/fault_injection_secondary_cache.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+void FaultInjectionSecondaryCache::ResultHandle::UpdateHandleValue(
+ FaultInjectionSecondaryCache::ResultHandle* handle) {
+ ErrorContext* ctx = handle->cache_->GetErrorContext();
+ if (!ctx->rand.OneIn(handle->cache_->prob_)) {
+ handle->value_ = handle->base_->Value();
+ handle->size_ = handle->base_->Size();
+ }
+ handle->base_.reset();
+}
+
+bool FaultInjectionSecondaryCache::ResultHandle::IsReady() {
+ bool ready = true;
+ if (base_) {
+ ready = base_->IsReady();
+ if (ready) {
+ UpdateHandleValue(this);
+ }
+ }
+ return ready;
+}
+
+void FaultInjectionSecondaryCache::ResultHandle::Wait() {
+ base_->Wait();
+ UpdateHandleValue(this);
+}
+
+void* FaultInjectionSecondaryCache::ResultHandle::Value() { return value_; }
+
+size_t FaultInjectionSecondaryCache::ResultHandle::Size() { return size_; }
+
+void FaultInjectionSecondaryCache::ResultHandle::WaitAll(
+ FaultInjectionSecondaryCache* cache,
+ std::vector<SecondaryCacheResultHandle*> handles) {
+ std::vector<SecondaryCacheResultHandle*> base_handles;
+ for (SecondaryCacheResultHandle* hdl : handles) {
+ FaultInjectionSecondaryCache::ResultHandle* handle =
+ static_cast<FaultInjectionSecondaryCache::ResultHandle*>(hdl);
+ if (!handle->base_) {
+ continue;
+ }
+ base_handles.emplace_back(handle->base_.get());
+ }
+
+ cache->base_->WaitAll(base_handles);
+ for (SecondaryCacheResultHandle* hdl : handles) {
+ FaultInjectionSecondaryCache::ResultHandle* handle =
+ static_cast<FaultInjectionSecondaryCache::ResultHandle*>(hdl);
+ if (handle->base_) {
+ UpdateHandleValue(handle);
+ }
+ }
+}
+
+FaultInjectionSecondaryCache::ErrorContext*
+FaultInjectionSecondaryCache::GetErrorContext() {
+ ErrorContext* ctx = static_cast<ErrorContext*>(thread_local_error_->Get());
+ if (!ctx) {
+ ctx = new ErrorContext(seed_);
+ thread_local_error_->Reset(ctx);
+ }
+
+ return ctx;
+}
+
+Status FaultInjectionSecondaryCache::Insert(
+ const Slice& key, void* value, const Cache::CacheItemHelper* helper) {
+ ErrorContext* ctx = GetErrorContext();
+ if (ctx->rand.OneIn(prob_)) {
+ return Status::IOError();
+ }
+
+ return base_->Insert(key, value, helper);
+}
+
+std::unique_ptr<SecondaryCacheResultHandle>
+FaultInjectionSecondaryCache::Lookup(const Slice& key,
+ const Cache::CreateCallback& create_cb,
+ bool wait, bool advise_erase,
+ bool& is_in_sec_cache) {
+ ErrorContext* ctx = GetErrorContext();
+ if (base_is_compressed_sec_cache_) {
+ if (ctx->rand.OneIn(prob_)) {
+ return nullptr;
+ } else {
+ return base_->Lookup(key, create_cb, wait, advise_erase, is_in_sec_cache);
+ }
+ } else {
+ std::unique_ptr<SecondaryCacheResultHandle> hdl =
+ base_->Lookup(key, create_cb, wait, advise_erase, is_in_sec_cache);
+ if (wait && ctx->rand.OneIn(prob_)) {
+ hdl.reset();
+ }
+ return std::unique_ptr<FaultInjectionSecondaryCache::ResultHandle>(
+ new FaultInjectionSecondaryCache::ResultHandle(this, std::move(hdl)));
+ }
+}
+
+void FaultInjectionSecondaryCache::Erase(const Slice& key) {
+ base_->Erase(key);
+}
+
+void FaultInjectionSecondaryCache::WaitAll(
+ std::vector<SecondaryCacheResultHandle*> handles) {
+ if (base_is_compressed_sec_cache_) {
+ ErrorContext* ctx = GetErrorContext();
+ std::vector<SecondaryCacheResultHandle*> base_handles;
+ for (SecondaryCacheResultHandle* hdl : handles) {
+ if (ctx->rand.OneIn(prob_)) {
+ continue;
+ }
+ base_handles.push_back(hdl);
+ }
+ base_->WaitAll(base_handles);
+ } else {
+ FaultInjectionSecondaryCache::ResultHandle::WaitAll(this, handles);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/fault_injection_secondary_cache.h b/src/rocksdb/utilities/fault_injection_secondary_cache.h
new file mode 100644
index 000000000..5321df626
--- /dev/null
+++ b/src/rocksdb/utilities/fault_injection_secondary_cache.h
@@ -0,0 +1,108 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/secondary_cache.h"
+#include "util/random.h"
+#include "util/thread_local.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// This class implements a custom SecondaryCache that randomly injects an
+// error status into Inserts/Lookups based on a specified probability.
+// Its used by db_stress to verify correctness in the presence of
+// secondary cache errors.
+//
+class FaultInjectionSecondaryCache : public SecondaryCache {
+ public:
+ explicit FaultInjectionSecondaryCache(
+ const std::shared_ptr<SecondaryCache>& base, uint32_t seed, int prob)
+ : base_(base),
+ seed_(seed),
+ prob_(prob),
+ thread_local_error_(new ThreadLocalPtr(DeleteThreadLocalErrorContext)) {
+ if (std::strcmp(base_->Name(), "CompressedSecondaryCache") == 0) {
+ base_is_compressed_sec_cache_ = true;
+ }
+ }
+
+ virtual ~FaultInjectionSecondaryCache() override {}
+
+ const char* Name() const override { return "FaultInjectionSecondaryCache"; }
+
+ Status Insert(const Slice& key, void* value,
+ const Cache::CacheItemHelper* helper) override;
+
+ std::unique_ptr<SecondaryCacheResultHandle> Lookup(
+ const Slice& key, const Cache::CreateCallback& create_cb, bool wait,
+ bool advise_erase, bool& is_in_sec_cache) override;
+
+ bool SupportForceErase() const override { return base_->SupportForceErase(); }
+
+ void Erase(const Slice& key) override;
+
+ void WaitAll(std::vector<SecondaryCacheResultHandle*> handles) override;
+
+ Status SetCapacity(size_t capacity) override {
+ return base_->SetCapacity(capacity);
+ }
+
+ Status GetCapacity(size_t& capacity) override {
+ return base_->GetCapacity(capacity);
+ }
+
+ std::string GetPrintableOptions() const override {
+ return base_->GetPrintableOptions();
+ }
+
+ private:
+ class ResultHandle : public SecondaryCacheResultHandle {
+ public:
+ ResultHandle(FaultInjectionSecondaryCache* cache,
+ std::unique_ptr<SecondaryCacheResultHandle>&& base)
+ : cache_(cache), base_(std::move(base)), value_(nullptr), size_(0) {}
+
+ ~ResultHandle() override {}
+
+ bool IsReady() override;
+
+ void Wait() override;
+
+ void* Value() override;
+
+ size_t Size() override;
+
+ static void WaitAll(FaultInjectionSecondaryCache* cache,
+ std::vector<SecondaryCacheResultHandle*> handles);
+
+ private:
+ static void UpdateHandleValue(ResultHandle* handle);
+
+ FaultInjectionSecondaryCache* cache_;
+ std::unique_ptr<SecondaryCacheResultHandle> base_;
+ void* value_;
+ size_t size_;
+ };
+
+ static void DeleteThreadLocalErrorContext(void* p) {
+ ErrorContext* ctx = static_cast<ErrorContext*>(p);
+ delete ctx;
+ }
+
+ const std::shared_ptr<SecondaryCache> base_;
+ uint32_t seed_;
+ int prob_;
+ bool base_is_compressed_sec_cache_{false};
+
+ struct ErrorContext {
+ Random rand;
+
+ explicit ErrorContext(uint32_t seed) : rand(seed) {}
+ };
+ std::unique_ptr<ThreadLocalPtr> thread_local_error_;
+
+ ErrorContext* GetErrorContext();
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/leveldb_options/leveldb_options.cc b/src/rocksdb/utilities/leveldb_options/leveldb_options.cc
new file mode 100644
index 000000000..125c3d956
--- /dev/null
+++ b/src/rocksdb/utilities/leveldb_options/leveldb_options.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/utilities/leveldb_options.h"
+
+#include "rocksdb/cache.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/env.h"
+#include "rocksdb/filter_policy.h"
+#include "rocksdb/options.h"
+#include "rocksdb/table.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+LevelDBOptions::LevelDBOptions()
+ : comparator(BytewiseComparator()),
+ create_if_missing(false),
+ error_if_exists(false),
+ paranoid_checks(false),
+ env(Env::Default()),
+ info_log(nullptr),
+ write_buffer_size(4 << 20),
+ max_open_files(1000),
+ block_cache(nullptr),
+ block_size(4096),
+ block_restart_interval(16),
+ compression(kSnappyCompression),
+ filter_policy(nullptr) {}
+
+Options ConvertOptions(const LevelDBOptions& leveldb_options) {
+ Options options = Options();
+ options.create_if_missing = leveldb_options.create_if_missing;
+ options.error_if_exists = leveldb_options.error_if_exists;
+ options.paranoid_checks = leveldb_options.paranoid_checks;
+ options.env = leveldb_options.env;
+ options.info_log.reset(leveldb_options.info_log);
+ options.write_buffer_size = leveldb_options.write_buffer_size;
+ options.max_open_files = leveldb_options.max_open_files;
+ options.compression = leveldb_options.compression;
+
+ BlockBasedTableOptions table_options;
+ table_options.block_cache.reset(leveldb_options.block_cache);
+ table_options.block_size = leveldb_options.block_size;
+ table_options.block_restart_interval = leveldb_options.block_restart_interval;
+ table_options.filter_policy.reset(leveldb_options.filter_policy);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+
+ return options;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/memory/memory_test.cc b/src/rocksdb/utilities/memory/memory_test.cc
new file mode 100644
index 000000000..0b043af0e
--- /dev/null
+++ b/src/rocksdb/utilities/memory/memory_test.cc
@@ -0,0 +1,279 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/cache.h"
+#include "rocksdb/table.h"
+#include "rocksdb/utilities/memory_util.h"
+#include "rocksdb/utilities/stackable_db.h"
+#include "table/block_based/block_based_table_factory.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/random.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MemoryTest : public testing::Test {
+ public:
+ MemoryTest() : kDbDir(test::PerThreadDBPath("memory_test")), rnd_(301) {
+ assert(Env::Default()->CreateDirIfMissing(kDbDir).ok());
+ }
+
+ std::string GetDBName(int id) { return kDbDir + "db_" + std::to_string(id); }
+
+ void UpdateUsagesHistory(const std::vector<DB*>& dbs) {
+ std::map<MemoryUtil::UsageType, uint64_t> usage_by_type;
+ ASSERT_OK(GetApproximateMemoryUsageByType(dbs, &usage_by_type));
+ for (int i = 0; i < MemoryUtil::kNumUsageTypes; ++i) {
+ usage_history_[i].push_back(
+ usage_by_type[static_cast<MemoryUtil::UsageType>(i)]);
+ }
+ }
+
+ void GetCachePointersFromTableFactory(
+ const TableFactory* factory,
+ std::unordered_set<const Cache*>* cache_set) {
+ const auto bbto = factory->GetOptions<BlockBasedTableOptions>();
+ if (bbto != nullptr) {
+ cache_set->insert(bbto->block_cache.get());
+ cache_set->insert(bbto->block_cache_compressed.get());
+ }
+ }
+
+ void GetCachePointers(const std::vector<DB*>& dbs,
+ std::unordered_set<const Cache*>* cache_set) {
+ cache_set->clear();
+
+ for (auto* db : dbs) {
+ assert(db);
+
+ // Cache from DBImpl
+ StackableDB* sdb = dynamic_cast<StackableDB*>(db);
+ DBImpl* db_impl = dynamic_cast<DBImpl*>(sdb ? sdb->GetBaseDB() : db);
+ if (db_impl != nullptr) {
+ cache_set->insert(db_impl->TEST_table_cache());
+ }
+
+ // Cache from DBOptions
+ cache_set->insert(db->GetDBOptions().row_cache.get());
+
+ // Cache from table factories
+ std::unordered_map<std::string, const ImmutableCFOptions*> iopts_map;
+ if (db_impl != nullptr) {
+ ASSERT_OK(db_impl->TEST_GetAllImmutableCFOptions(&iopts_map));
+ }
+ for (auto pair : iopts_map) {
+ GetCachePointersFromTableFactory(pair.second->table_factory.get(),
+ cache_set);
+ }
+ }
+ }
+
+ Status GetApproximateMemoryUsageByType(
+ const std::vector<DB*>& dbs,
+ std::map<MemoryUtil::UsageType, uint64_t>* usage_by_type) {
+ std::unordered_set<const Cache*> cache_set;
+ GetCachePointers(dbs, &cache_set);
+
+ return MemoryUtil::GetApproximateMemoryUsageByType(dbs, cache_set,
+ usage_by_type);
+ }
+
+ const std::string kDbDir;
+ Random rnd_;
+ std::vector<uint64_t> usage_history_[MemoryUtil::kNumUsageTypes];
+};
+
+TEST_F(MemoryTest, SharedBlockCacheTotal) {
+ std::vector<DB*> dbs;
+ std::vector<uint64_t> usage_by_type;
+ const int kNumDBs = 10;
+ const int kKeySize = 100;
+ const int kValueSize = 500;
+ Options opt;
+ opt.create_if_missing = true;
+ opt.write_buffer_size = kKeySize + kValueSize;
+ opt.max_write_buffer_number = 10;
+ opt.min_write_buffer_number_to_merge = 10;
+ opt.disable_auto_compactions = true;
+ BlockBasedTableOptions bbt_opts;
+ bbt_opts.block_cache = NewLRUCache(4096 * 1000 * 10);
+ for (int i = 0; i < kNumDBs; ++i) {
+ ASSERT_OK(DestroyDB(GetDBName(i), opt));
+ DB* db = nullptr;
+ ASSERT_OK(DB::Open(opt, GetDBName(i), &db));
+ dbs.push_back(db);
+ }
+
+ std::vector<std::string> keys_by_db[kNumDBs];
+
+ // Fill one memtable per Put to make memtable use more memory.
+ for (int p = 0; p < opt.min_write_buffer_number_to_merge / 2; ++p) {
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (int j = 0; j < 100; ++j) {
+ keys_by_db[i].emplace_back(rnd_.RandomString(kKeySize));
+ ASSERT_OK(dbs[i]->Put(WriteOptions(), keys_by_db[i].back(),
+ rnd_.RandomString(kValueSize)));
+ }
+ ASSERT_OK(dbs[i]->Flush(FlushOptions()));
+ }
+ }
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto& key : keys_by_db[i]) {
+ std::string value;
+ ASSERT_OK(dbs[i]->Get(ReadOptions(), key, &value));
+ }
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = 1; i < usage_history_[MemoryUtil::kMemTableTotal].size();
+ ++i) {
+ // Expect EQ as we didn't flush more memtables.
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+ for (int i = 0; i < kNumDBs; ++i) {
+ delete dbs[i];
+ }
+}
+
+TEST_F(MemoryTest, MemTableAndTableReadersTotal) {
+ std::vector<DB*> dbs;
+ std::vector<uint64_t> usage_by_type;
+ std::vector<std::vector<ColumnFamilyHandle*>> vec_handles;
+ const int kNumDBs = 10;
+ // These key/value sizes ensure each KV has its own memtable. Note that the
+ // minimum write_buffer_size allowed is 64 KB.
+ const int kKeySize = 100;
+ const int kValueSize = 1 << 16;
+ Options opt;
+ opt.create_if_missing = true;
+ opt.create_missing_column_families = true;
+ opt.write_buffer_size = kKeySize + kValueSize;
+ opt.max_write_buffer_number = 10;
+ opt.min_write_buffer_number_to_merge = 10;
+ opt.disable_auto_compactions = true;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs = {
+ {kDefaultColumnFamilyName, ColumnFamilyOptions(opt)},
+ {"one", ColumnFamilyOptions(opt)},
+ {"two", ColumnFamilyOptions(opt)},
+ };
+
+ for (int i = 0; i < kNumDBs; ++i) {
+ ASSERT_OK(DestroyDB(GetDBName(i), opt));
+ std::vector<ColumnFamilyHandle*> handles;
+ dbs.emplace_back();
+ vec_handles.emplace_back();
+ ASSERT_OK(DB::Open(DBOptions(opt), GetDBName(i), cf_descs,
+ &vec_handles.back(), &dbs.back()));
+ }
+
+ // Fill one memtable per Put to make memtable use more memory.
+ for (int p = 0; p < opt.min_write_buffer_number_to_merge / 2; ++p) {
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto* handle : vec_handles[i]) {
+ ASSERT_OK(dbs[i]->Put(WriteOptions(), handle,
+ rnd_.RandomString(kKeySize),
+ rnd_.RandomString(kValueSize)));
+ UpdateUsagesHistory(dbs);
+ }
+ }
+ }
+ // Expect the usage history is monotonically increasing
+ for (size_t i = 1; i < usage_history_[MemoryUtil::kMemTableTotal].size();
+ ++i) {
+ ASSERT_GT(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ ASSERT_GT(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+
+ size_t usage_check_point = usage_history_[MemoryUtil::kMemTableTotal].size();
+ std::vector<Iterator*> iters;
+
+ // Create an iterator and flush all memtables for each db
+ for (int i = 0; i < kNumDBs; ++i) {
+ iters.push_back(dbs[i]->NewIterator(ReadOptions()));
+ ASSERT_OK(dbs[i]->Flush(FlushOptions()));
+
+ for (int j = 0; j < 100; ++j) {
+ std::string value;
+ ASSERT_NOK(
+ dbs[i]->Get(ReadOptions(), rnd_.RandomString(kKeySize), &value));
+ }
+
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = usage_check_point;
+ i < usage_history_[MemoryUtil::kMemTableTotal].size(); ++i) {
+ // Since memtables are pinned by iterators, we don't expect the
+ // memory usage of all the memtables decreases as they are pinned
+ // by iterators.
+ ASSERT_GE(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ // Expect the usage history from the "usage_decay_point" is
+ // monotonically decreasing.
+ ASSERT_LT(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ // Expect the usage history of the table readers increases
+ // as we flush tables.
+ ASSERT_GT(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ ASSERT_GT(usage_history_[MemoryUtil::kCacheTotal][i],
+ usage_history_[MemoryUtil::kCacheTotal][i - 1]);
+ }
+ usage_check_point = usage_history_[MemoryUtil::kMemTableTotal].size();
+ for (int i = 0; i < kNumDBs; ++i) {
+ // iterator is not used.
+ ASSERT_OK(iters[i]->status());
+ delete iters[i];
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = usage_check_point;
+ i < usage_history_[MemoryUtil::kMemTableTotal].size(); ++i) {
+ // Expect the usage of all memtables decreasing as we delete iterators.
+ ASSERT_LT(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ // Since the memory usage of un-flushed memtables is only affected
+ // by Put and flush, we expect EQ here as we only delete iterators.
+ ASSERT_EQ(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ // Expect EQ as we didn't flush more memtables.
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto* handle : vec_handles[i]) {
+ delete handle;
+ }
+ delete dbs[i];
+ }
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+#if !(defined NDEBUG) || !defined(OS_WIN)
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+#else
+ return 0;
+#endif
+}
+
+#else
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ printf("Skipped in RocksDBLite as utilities are not supported.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/memory/memory_util.cc b/src/rocksdb/utilities/memory/memory_util.cc
new file mode 100644
index 000000000..13c81aec4
--- /dev/null
+++ b/src/rocksdb/utilities/memory/memory_util.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/memory_util.h"
+
+#include "db/db_impl/db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status MemoryUtil::GetApproximateMemoryUsageByType(
+ const std::vector<DB*>& dbs,
+ const std::unordered_set<const Cache*> cache_set,
+ std::map<MemoryUtil::UsageType, uint64_t>* usage_by_type) {
+ usage_by_type->clear();
+
+ // MemTable
+ for (auto* db : dbs) {
+ uint64_t usage = 0;
+ if (db->GetAggregatedIntProperty(DB::Properties::kSizeAllMemTables,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kMemTableTotal] += usage;
+ }
+ if (db->GetAggregatedIntProperty(DB::Properties::kCurSizeAllMemTables,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kMemTableUnFlushed] += usage;
+ }
+ }
+
+ // Table Readers
+ for (auto* db : dbs) {
+ uint64_t usage = 0;
+ if (db->GetAggregatedIntProperty(DB::Properties::kEstimateTableReadersMem,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kTableReadersTotal] += usage;
+ }
+ }
+
+ // Cache
+ for (const auto* cache : cache_set) {
+ if (cache != nullptr) {
+ (*usage_by_type)[MemoryUtil::kCacheTotal] += cache->GetUsage();
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/memory_allocators.h b/src/rocksdb/utilities/memory_allocators.h
new file mode 100644
index 000000000..c9e77a5b7
--- /dev/null
+++ b/src/rocksdb/utilities/memory_allocators.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <atomic>
+
+#include "rocksdb/memory_allocator.h"
+
+namespace ROCKSDB_NAMESPACE {
+// A memory allocator using new/delete
+class DefaultMemoryAllocator : public MemoryAllocator {
+ public:
+ static const char* kClassName() { return "DefaultMemoryAllocator"; }
+ const char* Name() const override { return kClassName(); }
+ void* Allocate(size_t size) override {
+ return static_cast<void*>(new char[size]);
+ }
+
+ void Deallocate(void* p) override { delete[] static_cast<char*>(p); }
+};
+
+// Base class for a MemoryAllocator. This implementation does nothing
+// and implements the methods in failuse mode (assert if the methods are
+// invoked). Implementations can extend this class and override these methods
+// when they are enabled via compiler switches (e.g., the
+// JeMallocMemoryAllocator can define these methods if ROCKSDB_JEMALLOC is
+// defined at compile time. If compiled in "disabled" mode, this class provides
+// default/failure implementations. If compiled in "enabled" mode, the derived
+// class needs to provide the appopriate "enabled" methods for the "real"
+// implementation. Failure of the "real" implementation to implement ovreride
+// any of these methods will result in an assert failure.
+class BaseMemoryAllocator : public MemoryAllocator {
+ public:
+ void* Allocate(size_t /*size*/) override {
+ assert(false);
+ return nullptr;
+ }
+
+ void Deallocate(void* /*p*/) override { assert(false); }
+};
+
+// A Wrapped MemoryAllocator. Delegates the memory allcator functions to the
+// wrapped one.
+class MemoryAllocatorWrapper : public MemoryAllocator {
+ public:
+ // Initialize an MemoryAllocatorWrapper that delegates all calls to *t
+ explicit MemoryAllocatorWrapper(const std::shared_ptr<MemoryAllocator>& t);
+ ~MemoryAllocatorWrapper() override {}
+
+ // Return the target to which to forward all calls
+ MemoryAllocator* target() const { return target_.get(); }
+ // Allocate a block of at least size. Has to be thread-safe.
+ void* Allocate(size_t size) override { return target_->Allocate(size); }
+
+ // Deallocate previously allocated block. Has to be thread-safe.
+ void Deallocate(void* p) override { return target_->Deallocate(p); }
+
+ // Returns the memory size of the block allocated at p. The default
+ // implementation that just returns the original allocation_size is fine.
+ size_t UsableSize(void* p, size_t allocation_size) const override {
+ return target_->UsableSize(p, allocation_size);
+ }
+
+ const Customizable* Inner() const override { return target_.get(); }
+
+ protected:
+ std::shared_ptr<MemoryAllocator> target_;
+};
+
+// A memory allocator that counts the number of allocations and deallocations
+// This class is useful if the number of memory allocations/dellocations is
+// important.
+class CountedMemoryAllocator : public MemoryAllocatorWrapper {
+ public:
+ CountedMemoryAllocator()
+ : MemoryAllocatorWrapper(std::make_shared<DefaultMemoryAllocator>()),
+ allocations_(0),
+ deallocations_(0) {}
+
+ explicit CountedMemoryAllocator(const std::shared_ptr<MemoryAllocator>& t)
+ : MemoryAllocatorWrapper(t), allocations_(0), deallocations_(0) {}
+ static const char* kClassName() { return "CountedMemoryAllocator"; }
+ const char* Name() const override { return kClassName(); }
+ std::string GetId() const override { return std::string(Name()); }
+ void* Allocate(size_t size) override {
+ allocations_++;
+ return MemoryAllocatorWrapper::Allocate(size);
+ }
+
+ void Deallocate(void* p) override {
+ deallocations_++;
+ MemoryAllocatorWrapper::Deallocate(p);
+ }
+ uint64_t GetNumAllocations() const { return allocations_; }
+ uint64_t GetNumDeallocations() const { return deallocations_; }
+
+ private:
+ std::atomic<uint64_t> allocations_;
+ std::atomic<uint64_t> deallocations_;
+};
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators.cc b/src/rocksdb/utilities/merge_operators.cc
new file mode 100644
index 000000000..c97e9ce25
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators.cc
@@ -0,0 +1,120 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "utilities/merge_operators.h"
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/customizable_util.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "utilities/merge_operators/bytesxor.h"
+#include "utilities/merge_operators/sortlist.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/merge_operators/string_append/stringappend2.h"
+
+namespace ROCKSDB_NAMESPACE {
+static bool LoadMergeOperator(const std::string& id,
+ std::shared_ptr<MergeOperator>* result) {
+ bool success = true;
+ // TODO: Hook the "name" up to the actual Name() of the MergeOperators?
+ // Requires these classes be moved into a header file...
+ if (id == "put" || id == "PutOperator") {
+ *result = MergeOperators::CreatePutOperator();
+ } else if (id == "put_v1") {
+ *result = MergeOperators::CreateDeprecatedPutOperator();
+ } else if (id == "uint64add" || id == "UInt64AddOperator") {
+ *result = MergeOperators::CreateUInt64AddOperator();
+ } else if (id == "max" || id == "MaxOperator") {
+ *result = MergeOperators::CreateMaxOperator();
+#ifdef ROCKSDB_LITE
+ // The remainder of the classes are handled by the ObjectRegistry in
+ // non-LITE mode
+ } else if (id == StringAppendOperator::kNickName() ||
+ id == StringAppendOperator::kClassName()) {
+ *result = MergeOperators::CreateStringAppendOperator();
+ } else if (id == StringAppendTESTOperator::kNickName() ||
+ id == StringAppendTESTOperator::kClassName()) {
+ *result = MergeOperators::CreateStringAppendTESTOperator();
+ } else if (id == BytesXOROperator::kNickName() ||
+ id == BytesXOROperator::kClassName()) {
+ *result = MergeOperators::CreateBytesXOROperator();
+ } else if (id == SortList::kNickName() || id == SortList::kClassName()) {
+ *result = MergeOperators::CreateSortOperator();
+#endif // ROCKSDB_LITE
+ } else {
+ success = false;
+ }
+ return success;
+}
+
+#ifndef ROCKSDB_LITE
+static int RegisterBuiltinMergeOperators(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ size_t num_types;
+ library.AddFactory<MergeOperator>(
+ ObjectLibrary::PatternEntry(StringAppendOperator::kClassName())
+ .AnotherName(StringAppendOperator::kNickName()),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /*errmsg*/) {
+ guard->reset(new StringAppendOperator(","));
+ return guard->get();
+ });
+ library.AddFactory<MergeOperator>(
+ ObjectLibrary::PatternEntry(StringAppendTESTOperator::kClassName())
+ .AnotherName(StringAppendTESTOperator::kNickName()),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /*errmsg*/) {
+ guard->reset(new StringAppendTESTOperator(","));
+ return guard->get();
+ });
+ library.AddFactory<MergeOperator>(
+ ObjectLibrary::PatternEntry(SortList::kClassName())
+ .AnotherName(SortList::kNickName()),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /*errmsg*/) {
+ guard->reset(new SortList());
+ return guard->get();
+ });
+ library.AddFactory<MergeOperator>(
+ ObjectLibrary::PatternEntry(BytesXOROperator::kClassName())
+ .AnotherName(BytesXOROperator::kNickName()),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /*errmsg*/) {
+ guard->reset(new BytesXOROperator());
+ return guard->get();
+ });
+
+ return static_cast<int>(library.GetFactoryCount(&num_types));
+}
+#endif // ROCKSDB_LITE
+
+Status MergeOperator::CreateFromString(const ConfigOptions& config_options,
+ const std::string& value,
+ std::shared_ptr<MergeOperator>* result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterBuiltinMergeOperators(*(ObjectLibrary::Default().get()), "");
+ });
+#endif // ROCKSDB_LITE
+ return LoadSharedObject<MergeOperator>(config_options, value,
+ LoadMergeOperator, result);
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateFromStringId(
+ const std::string& id) {
+ std::shared_ptr<MergeOperator> result;
+ Status s = MergeOperator::CreateFromString(ConfigOptions(), id, &result);
+ if (s.ok()) {
+ return result;
+ } else {
+ // Empty or unknown, just return nullptr
+ return nullptr;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators.h b/src/rocksdb/utilities/merge_operators.h
new file mode 100644
index 000000000..9b90107e3
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+#include <stdio.h>
+
+#include <memory>
+#include <string>
+
+#include "rocksdb/merge_operator.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MergeOperators {
+ public:
+ static std::shared_ptr<MergeOperator> CreatePutOperator();
+ static std::shared_ptr<MergeOperator> CreateDeprecatedPutOperator();
+ static std::shared_ptr<MergeOperator> CreateUInt64AddOperator();
+ static std::shared_ptr<MergeOperator> CreateStringAppendOperator();
+ static std::shared_ptr<MergeOperator> CreateStringAppendOperator(
+ char delim_char);
+ static std::shared_ptr<MergeOperator> CreateStringAppendOperator(
+ const std::string& delim);
+ static std::shared_ptr<MergeOperator> CreateStringAppendTESTOperator();
+ static std::shared_ptr<MergeOperator> CreateMaxOperator();
+ static std::shared_ptr<MergeOperator> CreateBytesXOROperator();
+ static std::shared_ptr<MergeOperator> CreateSortOperator();
+
+ // Will return a different merge operator depending on the string.
+ static std::shared_ptr<MergeOperator> CreateFromStringId(
+ const std::string& name);
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/bytesxor.cc b/src/rocksdb/utilities/merge_operators/bytesxor.cc
new file mode 100644
index 000000000..fa09c18ea
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/bytesxor.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/merge_operators/bytesxor.h"
+
+#include <algorithm>
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateBytesXOROperator() {
+ return std::make_shared<BytesXOROperator>();
+}
+
+bool BytesXOROperator::Merge(const Slice& /*key*/, const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* /*logger*/) const {
+ XOR(existing_value, value, new_value);
+ return true;
+}
+
+void BytesXOROperator::XOR(const Slice* existing_value, const Slice& value,
+ std::string* new_value) const {
+ if (!existing_value) {
+ new_value->clear();
+ new_value->assign(value.data(), value.size());
+ return;
+ }
+
+ size_t min_size = std::min(existing_value->size(), value.size());
+ size_t max_size = std::max(existing_value->size(), value.size());
+
+ new_value->clear();
+ new_value->reserve(max_size);
+
+ const char* existing_value_data = existing_value->data();
+ const char* value_data = value.data();
+
+ for (size_t i = 0; i < min_size; i++) {
+ new_value->push_back(existing_value_data[i] ^ value_data[i]);
+ }
+
+ if (existing_value->size() == max_size) {
+ for (size_t i = min_size; i < max_size; i++) {
+ new_value->push_back(existing_value_data[i]);
+ }
+ } else {
+ assert(value.size() == max_size);
+ for (size_t i = min_size; i < max_size; i++) {
+ new_value->push_back(value_data[i]);
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/bytesxor.h b/src/rocksdb/utilities/merge_operators/bytesxor.h
new file mode 100644
index 000000000..3c7baacce
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/bytesxor.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "rocksdb/env.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A 'model' merge operator that XORs two (same sized) array of bytes.
+// Implemented as an AssociativeMergeOperator for simplicity and example.
+class BytesXOROperator : public AssociativeMergeOperator {
+ public:
+ // XORs the two array of bytes one byte at a time and stores the result
+ // in new_value. len is the number of xored bytes, and the length of new_value
+ virtual bool Merge(const Slice& key, const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* logger) const override;
+
+ static const char* kClassName() { return "BytesXOR"; }
+ static const char* kNickName() { return "bytesxor"; }
+
+ const char* NickName() const override { return kNickName(); }
+ const char* Name() const override { return kClassName(); }
+
+ void XOR(const Slice* existing_value, const Slice& value,
+ std::string* new_value) const;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/max.cc b/src/rocksdb/utilities/merge_operators/max.cc
new file mode 100644
index 000000000..de4abfa6f
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/max.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/merge_operators.h"
+
+using ROCKSDB_NAMESPACE::Logger;
+using ROCKSDB_NAMESPACE::MergeOperator;
+using ROCKSDB_NAMESPACE::Slice;
+
+namespace { // anonymous namespace
+
+// Merge operator that picks the maximum operand, Comparison is based on
+// Slice::compare
+class MaxOperator : public MergeOperator {
+ public:
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override {
+ Slice& max = merge_out->existing_operand;
+ if (merge_in.existing_value) {
+ max = Slice(merge_in.existing_value->data(),
+ merge_in.existing_value->size());
+ } else if (max.data() == nullptr) {
+ max = Slice();
+ }
+
+ for (const auto& op : merge_in.operand_list) {
+ if (max.compare(op) < 0) {
+ max = op;
+ }
+ }
+
+ return true;
+ }
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override {
+ if (left_operand.compare(right_operand) >= 0) {
+ new_value->assign(left_operand.data(), left_operand.size());
+ } else {
+ new_value->assign(right_operand.data(), right_operand.size());
+ }
+ return true;
+ }
+
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const override {
+ Slice max;
+ for (const auto& operand : operand_list) {
+ if (max.compare(operand) < 0) {
+ max = operand;
+ }
+ }
+
+ new_value->assign(max.data(), max.size());
+ return true;
+ }
+
+ static const char* kClassName() { return "MaxOperator"; }
+ static const char* kNickName() { return "max"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+};
+
+} // end of anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateMaxOperator() {
+ return std::make_shared<MaxOperator>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/put.cc b/src/rocksdb/utilities/merge_operators/put.cc
new file mode 100644
index 000000000..ccf9ff21f
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/put.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/merge_operators.h"
+
+namespace { // anonymous namespace
+
+using ROCKSDB_NAMESPACE::Logger;
+using ROCKSDB_NAMESPACE::MergeOperator;
+using ROCKSDB_NAMESPACE::Slice;
+
+// A merge operator that mimics Put semantics
+// Since this merge-operator will not be used in production,
+// it is implemented as a non-associative merge operator to illustrate the
+// new interface and for testing purposes. (That is, we inherit from
+// the MergeOperator class rather than the AssociativeMergeOperator
+// which would be simpler in this case).
+//
+// From the client-perspective, semantics are the same.
+class PutOperator : public MergeOperator {
+ public:
+ bool FullMerge(const Slice& /*key*/, const Slice* /*existing_value*/,
+ const std::deque<std::string>& operand_sequence,
+ std::string* new_value, Logger* /*logger*/) const override {
+ // Put basically only looks at the current/latest value
+ assert(!operand_sequence.empty());
+ assert(new_value != nullptr);
+ new_value->assign(operand_sequence.back());
+ return true;
+ }
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& /*left_operand*/,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override {
+ new_value->assign(right_operand.data(), right_operand.size());
+ return true;
+ }
+
+ using MergeOperator::PartialMergeMulti;
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const override {
+ new_value->assign(operand_list.back().data(), operand_list.back().size());
+ return true;
+ }
+
+ static const char* kClassName() { return "PutOperator"; }
+ static const char* kNickName() { return "put_v1"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+};
+
+class PutOperatorV2 : public PutOperator {
+ bool FullMerge(const Slice& /*key*/, const Slice* /*existing_value*/,
+ const std::deque<std::string>& /*operand_sequence*/,
+ std::string* /*new_value*/,
+ Logger* /*logger*/) const override {
+ assert(false);
+ return false;
+ }
+
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override {
+ // Put basically only looks at the current/latest value
+ assert(!merge_in.operand_list.empty());
+ merge_out->existing_operand = merge_in.operand_list.back();
+ return true;
+ }
+
+ static const char* kNickName() { return "put"; }
+ const char* NickName() const override { return kNickName(); }
+};
+
+} // end of anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateDeprecatedPutOperator() {
+ return std::make_shared<PutOperator>();
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreatePutOperator() {
+ return std::make_shared<PutOperatorV2>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/sortlist.cc b/src/rocksdb/utilities/merge_operators/sortlist.cc
new file mode 100644
index 000000000..67bfc7e5e
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/sortlist.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#include "utilities/merge_operators/sortlist.h"
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+bool SortList::FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ std::vector<int> left;
+ for (Slice slice : merge_in.operand_list) {
+ std::vector<int> right;
+ MakeVector(right, slice);
+ left = Merge(left, right);
+ }
+ for (int i = 0; i < static_cast<int>(left.size()) - 1; i++) {
+ merge_out->new_value.append(std::to_string(left[i])).append(",");
+ }
+ merge_out->new_value.append(std::to_string(left.back()));
+ return true;
+}
+
+bool SortList::PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const {
+ std::vector<int> left;
+ std::vector<int> right;
+ MakeVector(left, left_operand);
+ MakeVector(right, right_operand);
+ left = Merge(left, right);
+ for (int i = 0; i < static_cast<int>(left.size()) - 1; i++) {
+ new_value->append(std::to_string(left[i])).append(",");
+ }
+ new_value->append(std::to_string(left.back()));
+ return true;
+}
+
+bool SortList::PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const {
+ (void)operand_list;
+ (void)new_value;
+ return true;
+}
+
+void SortList::MakeVector(std::vector<int>& operand, Slice slice) const {
+ do {
+ const char* begin = slice.data_;
+ while (*slice.data_ != ',' && *slice.data_) slice.data_++;
+ operand.push_back(std::stoi(std::string(begin, slice.data_)));
+ } while (0 != *slice.data_++);
+}
+
+std::vector<int> SortList::Merge(std::vector<int>& left,
+ std::vector<int>& right) const {
+ // Fill the resultant vector with sorted results from both vectors
+ std::vector<int> result;
+ unsigned left_it = 0, right_it = 0;
+
+ while (left_it < left.size() && right_it < right.size()) {
+ // If the left value is smaller than the right it goes next
+ // into the resultant vector
+ if (left[left_it] < right[right_it]) {
+ result.push_back(left[left_it]);
+ left_it++;
+ } else {
+ result.push_back(right[right_it]);
+ right_it++;
+ }
+ }
+
+ // Push the remaining data from both vectors onto the resultant
+ while (left_it < left.size()) {
+ result.push_back(left[left_it]);
+ left_it++;
+ }
+
+ while (right_it < right.size()) {
+ result.push_back(right[right_it]);
+ right_it++;
+ }
+
+ return result;
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateSortOperator() {
+ return std::make_shared<SortList>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/sortlist.h b/src/rocksdb/utilities/merge_operators/sortlist.h
new file mode 100644
index 000000000..eaa4e76fb
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/sortlist.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// A MergeOperator for RocksDB that implements Merge Sort.
+// It is built using the MergeOperator interface. The operator works by taking
+// an input which contains one or more merge operands where each operand is a
+// list of sorted ints and merges them to form a large sorted list.
+#pragma once
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SortList : public MergeOperator {
+ public:
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override;
+
+ bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const override;
+
+ static const char* kClassName() { return "MergeSortOperator"; }
+ static const char* kNickName() { return "sortlist"; }
+
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+
+ void MakeVector(std::vector<int>& operand, Slice slice) const;
+
+ private:
+ std::vector<int> Merge(std::vector<int>& left, std::vector<int>& right) const;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc
new file mode 100644
index 000000000..5092cabcb
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc
@@ -0,0 +1,78 @@
+/**
+ * A MergeOperator for rocksdb that implements string append.
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#include "stringappend.h"
+
+#include <assert.h>
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/options_type.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+static std::unordered_map<std::string, OptionTypeInfo>
+ stringappend_merge_type_info = {
+#ifndef ROCKSDB_LITE
+ {"delimiter",
+ {0, OptionType::kString, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+#endif // ROCKSDB_LITE
+};
+} // namespace
+// Constructor: also specify the delimiter character.
+StringAppendOperator::StringAppendOperator(char delim_char)
+ : delim_(1, delim_char) {
+ RegisterOptions("Delimiter", &delim_, &stringappend_merge_type_info);
+}
+
+StringAppendOperator::StringAppendOperator(const std::string& delim)
+ : delim_(delim) {
+ RegisterOptions("Delimiter", &delim_, &stringappend_merge_type_info);
+}
+
+// Implementation for the merge operation (concatenates two strings)
+bool StringAppendOperator::Merge(const Slice& /*key*/,
+ const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* /*logger*/) const {
+ // Clear the *new_value for writing.
+ assert(new_value);
+ new_value->clear();
+
+ if (!existing_value) {
+ // No existing_value. Set *new_value = value
+ new_value->assign(value.data(), value.size());
+ } else {
+ // Generic append (existing_value != null).
+ // Reserve *new_value to correct size, and apply concatenation.
+ new_value->reserve(existing_value->size() + delim_.size() + value.size());
+ new_value->assign(existing_value->data(), existing_value->size());
+ new_value->append(delim_);
+ new_value->append(value.data(), value.size());
+ }
+
+ return true;
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateStringAppendOperator() {
+ return std::make_shared<StringAppendOperator>(',');
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateStringAppendOperator(
+ char delim_char) {
+ return std::make_shared<StringAppendOperator>(delim_char);
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateStringAppendOperator(
+ const std::string& delim) {
+ return std::make_shared<StringAppendOperator>(delim);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend.h b/src/rocksdb/utilities/merge_operators/string_append/stringappend.h
new file mode 100644
index 000000000..153532382
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend.h
@@ -0,0 +1,32 @@
+/**
+ * A MergeOperator for rocksdb that implements string append.
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#pragma once
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class StringAppendOperator : public AssociativeMergeOperator {
+ public:
+ // Constructor: specify delimiter
+ explicit StringAppendOperator(char delim_char);
+ explicit StringAppendOperator(const std::string& delim);
+
+ virtual bool Merge(const Slice& key, const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* logger) const override;
+
+ static const char* kClassName() { return "StringAppendOperator"; }
+ static const char* kNickName() { return "stringappend"; }
+ virtual const char* Name() const override { return kClassName(); }
+ virtual const char* NickName() const override { return kNickName(); }
+
+ private:
+ std::string delim_; // The delimiter is inserted between elements
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc
new file mode 100644
index 000000000..36cb9ee34
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc
@@ -0,0 +1,132 @@
+/**
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#include "stringappend2.h"
+
+#include <assert.h>
+
+#include <memory>
+#include <string>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/options_type.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+static std::unordered_map<std::string, OptionTypeInfo>
+ stringappend2_merge_type_info = {
+#ifndef ROCKSDB_LITE
+ {"delimiter",
+ {0, OptionType::kString, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+#endif // ROCKSDB_LITE
+};
+} // namespace
+
+// Constructor: also specify the delimiter character.
+StringAppendTESTOperator::StringAppendTESTOperator(char delim_char)
+ : delim_(1, delim_char) {
+ RegisterOptions("Delimiter", &delim_, &stringappend2_merge_type_info);
+}
+
+StringAppendTESTOperator::StringAppendTESTOperator(const std::string& delim)
+ : delim_(delim) {
+ RegisterOptions("Delimiter", &delim_, &stringappend2_merge_type_info);
+}
+
+// Implementation for the merge operation (concatenates two strings)
+bool StringAppendTESTOperator::FullMergeV2(
+ const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ // Clear the *new_value for writing.
+ merge_out->new_value.clear();
+
+ if (merge_in.existing_value == nullptr && merge_in.operand_list.size() == 1) {
+ // Only one operand
+ merge_out->existing_operand = merge_in.operand_list.back();
+ return true;
+ }
+
+ // Compute the space needed for the final result.
+ size_t numBytes = 0;
+
+ for (auto it = merge_in.operand_list.begin();
+ it != merge_in.operand_list.end(); ++it) {
+ numBytes += it->size() + delim_.size();
+ }
+
+ // Only print the delimiter after the first entry has been printed
+ bool printDelim = false;
+
+ // Prepend the *existing_value if one exists.
+ if (merge_in.existing_value) {
+ merge_out->new_value.reserve(numBytes + merge_in.existing_value->size());
+ merge_out->new_value.append(merge_in.existing_value->data(),
+ merge_in.existing_value->size());
+ printDelim = true;
+ } else if (numBytes) {
+ // Without the existing (initial) value, the delimiter before the first of
+ // subsequent operands becomes redundant.
+ merge_out->new_value.reserve(numBytes - delim_.size());
+ }
+
+ // Concatenate the sequence of strings (and add a delimiter between each)
+ for (auto it = merge_in.operand_list.begin();
+ it != merge_in.operand_list.end(); ++it) {
+ if (printDelim) {
+ merge_out->new_value.append(delim_);
+ }
+ merge_out->new_value.append(it->data(), it->size());
+ printDelim = true;
+ }
+
+ return true;
+}
+
+bool StringAppendTESTOperator::PartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& /*operand_list*/,
+ std::string* /*new_value*/, Logger* /*logger*/) const {
+ return false;
+}
+
+// A version of PartialMerge that actually performs "partial merging".
+// Use this to simulate the exact behaviour of the StringAppendOperator.
+bool StringAppendTESTOperator::_AssocPartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* /*logger*/) const {
+ // Clear the *new_value for writing
+ assert(new_value);
+ new_value->clear();
+ assert(operand_list.size() >= 2);
+
+ // Generic append
+ // Determine and reserve correct size for *new_value.
+ size_t size = 0;
+ for (const auto& operand : operand_list) {
+ size += operand.size();
+ }
+ size += (operand_list.size() - 1) * delim_.length(); // Delimiters
+ new_value->reserve(size);
+
+ // Apply concatenation
+ new_value->assign(operand_list.front().data(), operand_list.front().size());
+
+ for (std::deque<Slice>::const_iterator it = operand_list.begin() + 1;
+ it != operand_list.end(); ++it) {
+ new_value->append(delim_);
+ new_value->append(it->data(), it->size());
+ }
+
+ return true;
+}
+
+std::shared_ptr<MergeOperator>
+MergeOperators::CreateStringAppendTESTOperator() {
+ return std::make_shared<StringAppendTESTOperator>(',');
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h
new file mode 100644
index 000000000..75389e4ae
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h
@@ -0,0 +1,52 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+/**
+ * A TEST MergeOperator for rocksdb that implements string append.
+ * It is built using the MergeOperator interface rather than the simpler
+ * AssociativeMergeOperator interface. This is useful for testing/benchmarking.
+ * While the two operators are semantically the same, all production code
+ * should use the StringAppendOperator defined in stringappend.{h,cc}. The
+ * operator defined in the present file is primarily for testing.
+ *
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#pragma once
+#include <deque>
+#include <string>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class StringAppendTESTOperator : public MergeOperator {
+ public:
+ // Constructor with delimiter
+ explicit StringAppendTESTOperator(char delim_char);
+ explicit StringAppendTESTOperator(const std::string& delim);
+
+ virtual bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ virtual bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* logger) const override;
+
+ static const char* kClassName() { return "StringAppendTESTOperator"; }
+ static const char* kNickName() { return "stringappendtest"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+
+ private:
+ // A version of PartialMerge that actually performs "partial merging".
+ // Use this to simulate the exact behaviour of the StringAppendOperator.
+ bool _AssocPartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const;
+
+ std::string delim_; // The delimiter is inserted between elements
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc
new file mode 100644
index 000000000..22b6144af
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc
@@ -0,0 +1,640 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+/**
+ * An persistent map : key -> (list of strings), using rocksdb merge.
+ * This file is a test-harness / use-case for the StringAppendOperator.
+ *
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook, Inc.
+ */
+
+#include "utilities/merge_operators/string_append/stringappend.h"
+
+#include <iostream>
+#include <map>
+#include <tuple>
+
+#include "port/stack_trace.h"
+#include "rocksdb/db.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend2.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Path to the database on file system
+const std::string kDbName = test::PerThreadDBPath("stringappend_test");
+
+namespace {
+// OpenDb opens a (possibly new) rocksdb database with a StringAppendOperator
+std::shared_ptr<DB> OpenNormalDb(const std::string& delim) {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ MergeOperator* mergeOperator;
+ if (delim.size() == 1) {
+ mergeOperator = new StringAppendOperator(delim[0]);
+ } else {
+ mergeOperator = new StringAppendOperator(delim);
+ }
+ options.merge_operator.reset(mergeOperator);
+ EXPECT_OK(DB::Open(options, kDbName, &db));
+ return std::shared_ptr<DB>(db);
+}
+
+#ifndef ROCKSDB_LITE // TtlDb is not supported in Lite
+// Open a TtlDB with a non-associative StringAppendTESTOperator
+std::shared_ptr<DB> OpenTtlDb(const std::string& delim) {
+ DBWithTTL* db;
+ Options options;
+ options.create_if_missing = true;
+ MergeOperator* mergeOperator;
+ if (delim.size() == 1) {
+ mergeOperator = new StringAppendTESTOperator(delim[0]);
+ } else {
+ mergeOperator = new StringAppendTESTOperator(delim);
+ }
+ options.merge_operator.reset(mergeOperator);
+ EXPECT_OK(DBWithTTL::Open(options, kDbName, &db, 123456));
+ return std::shared_ptr<DB>(db);
+}
+#endif // !ROCKSDB_LITE
+} // namespace
+
+/// StringLists represents a set of string-lists, each with a key-index.
+/// Supports Append(list, string) and Get(list)
+class StringLists {
+ public:
+ // Constructor: specifies the rocksdb db
+ /* implicit */
+ StringLists(std::shared_ptr<DB> db)
+ : db_(db), merge_option_(), get_option_() {
+ assert(db);
+ }
+
+ // Append string val onto the list defined by key; return true on success
+ bool Append(const std::string& key, const std::string& val) {
+ Slice valSlice(val.data(), val.size());
+ auto s = db_->Merge(merge_option_, key, valSlice);
+
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ // Returns the list of strings associated with key (or "" if does not exist)
+ bool Get(const std::string& key, std::string* const result) {
+ assert(result != nullptr); // we should have a place to store the result
+ auto s = db_->Get(get_option_, key, result);
+
+ if (s.ok()) {
+ return true;
+ }
+
+ // Either key does not exist, or there is some error.
+ *result = ""; // Always return empty string (just for convention)
+
+ // NotFound is okay; just return empty (similar to std::map)
+ // But network or db errors, etc, should fail the test (or at least yell)
+ if (!s.IsNotFound()) {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ }
+
+ // Always return false if s.ok() was not true
+ return false;
+ }
+
+ private:
+ std::shared_ptr<DB> db_;
+ WriteOptions merge_option_;
+ ReadOptions get_option_;
+};
+
+// The class for unit-testing
+class StringAppendOperatorTest : public testing::Test,
+ public ::testing::WithParamInterface<bool> {
+ public:
+ StringAppendOperatorTest() {
+ EXPECT_OK(
+ DestroyDB(kDbName, Options())); // Start each test with a fresh DB
+ }
+
+ void SetUp() override {
+#ifndef ROCKSDB_LITE // TtlDb is not supported in Lite
+ bool if_use_ttl = GetParam();
+ if (if_use_ttl) {
+ fprintf(stderr, "Running tests with ttl db and generic operator.\n");
+ StringAppendOperatorTest::SetOpenDbFunction(&OpenTtlDb);
+ return;
+ }
+#endif // !ROCKSDB_LITE
+ fprintf(stderr, "Running tests with regular db and operator.\n");
+ StringAppendOperatorTest::SetOpenDbFunction(&OpenNormalDb);
+ }
+
+ using OpenFuncPtr = std::shared_ptr<DB> (*)(const std::string&);
+
+ // Allows user to open databases with different configurations.
+ // e.g.: Can open a DB or a TtlDB, etc.
+ static void SetOpenDbFunction(OpenFuncPtr func) { OpenDb = func; }
+
+ protected:
+ static OpenFuncPtr OpenDb;
+};
+StringAppendOperatorTest::OpenFuncPtr StringAppendOperatorTest::OpenDb =
+ nullptr;
+
+// THE TEST CASES BEGIN HERE
+
+TEST_P(StringAppendOperatorTest, IteratorTest) {
+ auto db_ = OpenDb(",");
+ StringLists slists(db_);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ slists.Append("k2", "a1");
+ slists.Append("k2", "a2");
+ slists.Append("k2", "a3");
+
+ std::string res;
+ std::unique_ptr<ROCKSDB_NAMESPACE::Iterator> it(
+ db_->NewIterator(ReadOptions()));
+ std::string k1("k1");
+ std::string k2("k2");
+ bool first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3");
+ }
+ }
+ slists.Append("k2", "a4");
+ slists.Append("k1", "v4");
+
+ // Snapshot should still be the same. Should ignore a4 and v4.
+ first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3");
+ }
+ }
+
+ // Should release the snapshot and be aware of the new stuff now
+ it.reset(db_->NewIterator(ReadOptions()));
+ first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3,v4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ }
+ }
+
+ // start from k2 this time.
+ for (it->Seek(k2); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3,v4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ }
+ }
+
+ slists.Append("k3", "g1");
+
+ it.reset(db_->NewIterator(ReadOptions()));
+ first = true;
+ std::string k3("k3");
+ for (it->Seek(k2); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "g1");
+ }
+ }
+ for (it->Seek(k3); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ // should not be hit
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "g1");
+ }
+ }
+}
+
+TEST_P(StringAppendOperatorTest, SimpleTest) {
+ auto db = OpenDb(",");
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+ ASSERT_EQ(res, "v1,v2,v3");
+}
+
+TEST_P(StringAppendOperatorTest, SimpleDelimiterTest) {
+ auto db = OpenDb("|");
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+ ASSERT_EQ(res, "v1|v2|v3");
+}
+
+TEST_P(StringAppendOperatorTest, EmptyDelimiterTest) {
+ auto db = OpenDb("");
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+ ASSERT_EQ(res, "v1v2v3");
+}
+
+TEST_P(StringAppendOperatorTest, MultiCharDelimiterTest) {
+ auto db = OpenDb("<>");
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+ ASSERT_EQ(res, "v1<>v2<>v3");
+}
+
+TEST_P(StringAppendOperatorTest, DelimiterIsDefensivelyCopiedTest) {
+ std::string delimiter = "<>";
+ auto db = OpenDb(delimiter);
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ delimiter.clear();
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+ ASSERT_EQ(res, "v1<>v2<>v3");
+}
+
+TEST_P(StringAppendOperatorTest, OneValueNoDelimiterTest) {
+ auto db = OpenDb("!");
+ StringLists slists(db);
+
+ slists.Append("random_key", "single_val");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("random_key", &res));
+ ASSERT_EQ(res, "single_val");
+}
+
+TEST_P(StringAppendOperatorTest, VariousKeys) {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+
+ slists.Append("c", "asdasd");
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+ slists.Append("c", "asdasd");
+
+ std::string a, b, c;
+ bool sa, sb, sc;
+ sa = slists.Get("a", &a);
+ sb = slists.Get("b", &b);
+ sc = slists.Get("c", &c);
+
+ ASSERT_TRUE(sa && sb && sc); // All three keys should have been found
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+}
+
+// Generate semi random keys/words from a small distribution.
+TEST_P(StringAppendOperatorTest, RandomMixGetAppend) {
+ auto db = OpenDb(" ");
+ StringLists slists(db);
+
+ // Generate a list of random keys and values
+ const int kWordCount = 15;
+ std::string words[] = {"sdasd", "triejf", "fnjsdfn", "dfjisdfsf",
+ "342839", "dsuha", "mabuais", "sadajsid",
+ "jf9834hf", "2d9j89", "dj9823jd", "a",
+ "dk02ed2dh", "$(jd4h984$(*", "mabz"};
+ const int kKeyCount = 6;
+ std::string keys[] = {"dhaiusdhu", "denidw", "daisda",
+ "keykey", "muki", "shzassdianmd"};
+
+ // Will store a local copy of all data in order to verify correctness
+ std::map<std::string, std::string> parallel_copy;
+
+ // Generate a bunch of random queries (Append and Get)!
+ enum query_t { APPEND_OP, GET_OP, NUM_OPS };
+ Random randomGen(1337); // deterministic seed; always get same results!
+
+ const int kNumQueries = 30;
+ for (int q = 0; q < kNumQueries; ++q) {
+ // Generate a random query (Append or Get) and random parameters
+ query_t query = (query_t)randomGen.Uniform((int)NUM_OPS);
+ std::string key = keys[randomGen.Uniform((int)kKeyCount)];
+ std::string word = words[randomGen.Uniform((int)kWordCount)];
+
+ // Apply the query and any checks.
+ if (query == APPEND_OP) {
+ // Apply the rocksdb test-harness Append defined above
+ slists.Append(key, word); // apply the rocksdb append
+
+ // Apply the similar "Append" to the parallel copy
+ if (parallel_copy[key].size() > 0) {
+ parallel_copy[key] += " " + word;
+ } else {
+ parallel_copy[key] = word;
+ }
+
+ } else if (query == GET_OP) {
+ // Assumes that a non-existent key just returns <empty>
+ std::string res;
+ slists.Get(key, &res);
+ ASSERT_EQ(res, parallel_copy[key]);
+ }
+ }
+}
+
+TEST_P(StringAppendOperatorTest, BIGRandomMixGetAppend) {
+ auto db = OpenDb(" ");
+ StringLists slists(db);
+
+ // Generate a list of random keys and values
+ const int kWordCount = 15;
+ std::string words[] = {"sdasd", "triejf", "fnjsdfn", "dfjisdfsf",
+ "342839", "dsuha", "mabuais", "sadajsid",
+ "jf9834hf", "2d9j89", "dj9823jd", "a",
+ "dk02ed2dh", "$(jd4h984$(*", "mabz"};
+ const int kKeyCount = 6;
+ std::string keys[] = {"dhaiusdhu", "denidw", "daisda",
+ "keykey", "muki", "shzassdianmd"};
+
+ // Will store a local copy of all data in order to verify correctness
+ std::map<std::string, std::string> parallel_copy;
+
+ // Generate a bunch of random queries (Append and Get)!
+ enum query_t { APPEND_OP, GET_OP, NUM_OPS };
+ Random randomGen(9138204); // deterministic seed
+
+ const int kNumQueries = 1000;
+ for (int q = 0; q < kNumQueries; ++q) {
+ // Generate a random query (Append or Get) and random parameters
+ query_t query = (query_t)randomGen.Uniform((int)NUM_OPS);
+ std::string key = keys[randomGen.Uniform((int)kKeyCount)];
+ std::string word = words[randomGen.Uniform((int)kWordCount)];
+
+ // Apply the query and any checks.
+ if (query == APPEND_OP) {
+ // Apply the rocksdb test-harness Append defined above
+ slists.Append(key, word); // apply the rocksdb append
+
+ // Apply the similar "Append" to the parallel copy
+ if (parallel_copy[key].size() > 0) {
+ parallel_copy[key] += " " + word;
+ } else {
+ parallel_copy[key] = word;
+ }
+
+ } else if (query == GET_OP) {
+ // Assumes that a non-existent key just returns <empty>
+ std::string res;
+ slists.Get(key, &res);
+ ASSERT_EQ(res, parallel_copy[key]);
+ }
+ }
+}
+
+TEST_P(StringAppendOperatorTest, PersistentVariousKeys) {
+ // Perform the following operations in limited scope
+ {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+
+ slists.Append("c", "asdasd");
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+ slists.Append("c", "asdasd");
+
+ std::string a, b, c;
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+
+ slists.Append("c", "bbnagnagsx");
+ slists.Append("a", "sa");
+ slists.Append("b", "df");
+ slists.Append("a", "gh");
+ slists.Append("a", "jk");
+ slists.Append("b", "l;");
+ slists.Append("c", "rogosh");
+
+ // The previous changes should be on disk (L0)
+ // The most recent changes should be in memory (MemTable)
+ // Hence, this will test both Get() paths.
+ std::string a, b, c;
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+
+ // All changes should be on disk. This will test VersionSet Get()
+ std::string a, b, c;
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+ }
+}
+
+TEST_P(StringAppendOperatorTest, PersistentFlushAndCompaction) {
+ // Perform the following operations in limited scope
+ {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+ std::string a, b, c;
+
+ // Append, Flush, Get
+ slists.Append("c", "asdasd");
+ ASSERT_OK(db->Flush(ROCKSDB_NAMESPACE::FlushOptions()));
+ ASSERT_TRUE(slists.Get("c", &c));
+ ASSERT_EQ(c, "asdasd");
+
+ // Append, Flush, Append, Get
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ ASSERT_OK(db->Flush(ROCKSDB_NAMESPACE::FlushOptions()));
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_EQ(a, "x\nt\nr");
+
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_EQ(b, "y\n2");
+
+ // Append, Get
+ ASSERT_TRUE(slists.Append("c", "asdasd"));
+ ASSERT_TRUE(slists.Append("b", "monkey"));
+
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2\nmonkey");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb("\n");
+ StringLists slists(db);
+ std::string a, b, c;
+
+ // Get (Quick check for persistence of previous database)
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_EQ(a, "x\nt\nr");
+
+ // Append, Compact, Get
+ slists.Append("c", "bbnagnagsx");
+ slists.Append("a", "sa");
+ slists.Append("b", "df");
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+ ASSERT_EQ(a, "x\nt\nr\nsa");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx");
+
+ // Append, Get
+ slists.Append("a", "gh");
+ slists.Append("a", "jk");
+ slists.Append("b", "l;");
+ slists.Append("c", "rogosh");
+ ASSERT_TRUE(slists.Get("a", &a));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_TRUE(slists.Get("c", &c));
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+
+ // Compact, Get
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+
+ // Append, Flush, Compact, Get
+ slists.Append("b", "afcg");
+ ASSERT_OK(db->Flush(ROCKSDB_NAMESPACE::FlushOptions()));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_TRUE(slists.Get("b", &b));
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;\nafcg");
+ }
+}
+
+TEST_P(StringAppendOperatorTest, SimpleTestNullDelimiter) {
+ auto db = OpenDb(std::string(1, '\0'));
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ ASSERT_TRUE(slists.Get("k1", &res));
+
+ // Construct the desired string. Default constructor doesn't like '\0' chars.
+ std::string checker("v1,v2,v3"); // Verify that the string is right size.
+ checker[2] = '\0'; // Use null delimiter instead of comma.
+ checker[5] = '\0';
+ ASSERT_EQ(checker.size(), 8); // Verify it is still the correct size
+
+ // Check that the rocksdb result string matches the desired string
+ ASSERT_EQ(res.size(), checker.size());
+ ASSERT_EQ(res, checker);
+}
+
+INSTANTIATE_TEST_CASE_P(StringAppendOperatorTest, StringAppendOperatorTest,
+ testing::Bool());
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/merge_operators/uint64add.cc b/src/rocksdb/utilities/merge_operators/uint64add.cc
new file mode 100644
index 000000000..5be2f5641
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/uint64add.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "logging/logging.h"
+#include "rocksdb/env.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace { // anonymous namespace
+
+using ROCKSDB_NAMESPACE::AssociativeMergeOperator;
+using ROCKSDB_NAMESPACE::InfoLogLevel;
+using ROCKSDB_NAMESPACE::Logger;
+using ROCKSDB_NAMESPACE::Slice;
+
+// A 'model' merge operator with uint64 addition semantics
+// Implemented as an AssociativeMergeOperator for simplicity and example.
+class UInt64AddOperator : public AssociativeMergeOperator {
+ public:
+ bool Merge(const Slice& /*key*/, const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* logger) const override {
+ uint64_t orig_value = 0;
+ if (existing_value) {
+ orig_value = DecodeInteger(*existing_value, logger);
+ }
+ uint64_t operand = DecodeInteger(value, logger);
+
+ assert(new_value);
+ new_value->clear();
+ ROCKSDB_NAMESPACE::PutFixed64(new_value, orig_value + operand);
+
+ return true; // Return true always since corruption will be treated as 0
+ }
+
+ static const char* kClassName() { return "UInt64AddOperator"; }
+ static const char* kNickName() { return "uint64add"; }
+ const char* Name() const override { return kClassName(); }
+ const char* NickName() const override { return kNickName(); }
+
+ private:
+ // Takes the string and decodes it into a uint64_t
+ // On error, prints a message and returns 0
+ uint64_t DecodeInteger(const Slice& value, Logger* logger) const {
+ uint64_t result = 0;
+
+ if (value.size() == sizeof(uint64_t)) {
+ result = ROCKSDB_NAMESPACE::DecodeFixed64(value.data());
+ } else if (logger != nullptr) {
+ // If value is corrupted, treat it as 0
+ ROCKS_LOG_ERROR(logger,
+ "uint64 value corruption, size: %" ROCKSDB_PRIszt
+ " > %" ROCKSDB_PRIszt,
+ value.size(), sizeof(uint64_t));
+ }
+
+ return result;
+ }
+};
+
+} // anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateUInt64AddOperator() {
+ return std::make_shared<UInt64AddOperator>();
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/object_registry.cc b/src/rocksdb/utilities/object_registry.cc
new file mode 100644
index 000000000..18834783d
--- /dev/null
+++ b/src/rocksdb/utilities/object_registry.cc
@@ -0,0 +1,383 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/object_registry.h"
+
+#include <ctype.h>
+
+#include "logging/logging.h"
+#include "port/lang.h"
+#include "rocksdb/customizable.h"
+#include "rocksdb/env.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+#ifndef ROCKSDB_LITE
+namespace {
+bool MatchesInteger(const std::string &target, size_t start, size_t pos) {
+ // If it is numeric, everything up to the match must be a number
+ int digits = 0;
+ if (target[start] == '-') {
+ start++; // Allow negative numbers
+ }
+ while (start < pos) {
+ if (!isdigit(target[start++])) {
+ return false;
+ } else {
+ digits++;
+ }
+ }
+ return (digits > 0);
+}
+
+bool MatchesDecimal(const std::string &target, size_t start, size_t pos) {
+ int digits = 0;
+ if (target[start] == '-') {
+ start++; // Allow negative numbers
+ }
+ for (bool point = false; start < pos; start++) {
+ if (target[start] == '.') {
+ if (point) {
+ return false;
+ } else {
+ point = true;
+ }
+ } else if (!isdigit(target[start])) {
+ return false;
+ } else {
+ digits++;
+ }
+ }
+ return (digits > 0);
+}
+} // namespace
+
+size_t ObjectLibrary::PatternEntry::MatchSeparatorAt(
+ size_t start, Quantifier mode, const std::string &target, size_t tlen,
+ const std::string &separator) const {
+ size_t slen = separator.size();
+ // See if there is enough space. If so, find the separator
+ if (tlen < start + slen) {
+ return std::string::npos; // not enough space left
+ } else if (mode == kMatchExact) {
+ // Exact mode means the next thing we are looking for is the separator
+ if (target.compare(start, slen, separator) != 0) {
+ return std::string::npos;
+ } else {
+ return start + slen; // Found the separator, return where we found it
+ }
+ } else {
+ auto pos = start + 1;
+ if (!separator.empty()) {
+ pos = target.find(separator, pos);
+ }
+ if (pos == std::string::npos) {
+ return pos;
+ } else if (mode == kMatchInteger) {
+ if (!MatchesInteger(target, start, pos)) {
+ return std::string::npos;
+ }
+ } else if (mode == kMatchDecimal) {
+ if (!MatchesDecimal(target, start, pos)) {
+ return std::string::npos;
+ }
+ }
+ return pos + slen;
+ }
+}
+
+bool ObjectLibrary::PatternEntry::MatchesTarget(const std::string &name,
+ size_t nlen,
+ const std::string &target,
+ size_t tlen) const {
+ if (separators_.empty()) {
+ assert(optional_); // If there are no separators, it must be only a name
+ return nlen == tlen && name == target;
+ } else if (nlen == tlen) { // The lengths are the same
+ return optional_ && name == target;
+ } else if (tlen < nlen + slength_) {
+ // The target is not long enough
+ return false;
+ } else if (target.compare(0, nlen, name) != 0) {
+ return false; // Target does not start with name
+ } else {
+ // Loop through all of the separators one at a time matching them.
+ // Note that we first match the separator and then its quantifiers.
+ // Since we expect the separator first, we start with an exact match
+ // Subsequent matches will use the quantifier of the previous separator
+ size_t start = nlen;
+ auto mode = kMatchExact;
+ for (size_t idx = 0; idx < separators_.size(); ++idx) {
+ const auto &separator = separators_[idx];
+ start = MatchSeparatorAt(start, mode, target, tlen, separator.first);
+ if (start == std::string::npos) {
+ return false;
+ } else {
+ mode = separator.second;
+ }
+ }
+ // We have matched all of the separators. Now check that what is left
+ // unmatched in the target is acceptable.
+ if (mode == kMatchExact) {
+ return (start == tlen);
+ } else if (start > tlen || (start == tlen && mode != kMatchZeroOrMore)) {
+ return false;
+ } else if (mode == kMatchInteger) {
+ return MatchesInteger(target, start, tlen);
+ } else if (mode == kMatchDecimal) {
+ return MatchesDecimal(target, start, tlen);
+ }
+ }
+ return true;
+}
+
+bool ObjectLibrary::PatternEntry::Matches(const std::string &target) const {
+ auto tlen = target.size();
+ if (MatchesTarget(name_, nlength_, target, tlen)) {
+ return true;
+ } else if (!names_.empty()) {
+ for (const auto &alt : names_) {
+ if (MatchesTarget(alt, alt.size(), target, tlen)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+size_t ObjectLibrary::GetFactoryCount(size_t *types) const {
+ std::unique_lock<std::mutex> lock(mu_);
+ *types = factories_.size();
+ size_t factories = 0;
+ for (const auto &e : factories_) {
+ factories += e.second.size();
+ }
+ return factories;
+}
+
+size_t ObjectLibrary::GetFactoryCount(const std::string &type) const {
+ std::unique_lock<std::mutex> lock(mu_);
+ auto iter = factories_.find(type);
+ if (iter != factories_.end()) {
+ return iter->second.size();
+ } else {
+ return 0;
+ }
+}
+
+void ObjectLibrary::GetFactoryNames(const std::string &type,
+ std::vector<std::string> *names) const {
+ assert(names);
+ std::unique_lock<std::mutex> lock(mu_);
+ auto iter = factories_.find(type);
+ if (iter != factories_.end()) {
+ for (const auto &f : iter->second) {
+ names->push_back(f->Name());
+ }
+ }
+}
+
+void ObjectLibrary::GetFactoryTypes(
+ std::unordered_set<std::string> *types) const {
+ assert(types);
+ std::unique_lock<std::mutex> lock(mu_);
+ for (const auto &iter : factories_) {
+ types->insert(iter.first);
+ }
+}
+
+void ObjectLibrary::Dump(Logger *logger) const {
+ std::unique_lock<std::mutex> lock(mu_);
+ if (logger != nullptr && !factories_.empty()) {
+ ROCKS_LOG_HEADER(logger, " Registered Library: %s\n", id_.c_str());
+ for (const auto &iter : factories_) {
+ ROCKS_LOG_HEADER(logger, " Registered factories for type[%s] ",
+ iter.first.c_str());
+ bool printed_one = false;
+ for (const auto &e : iter.second) {
+ ROCKS_LOG_HEADER(logger, "%c %s", (printed_one) ? ',' : ':', e->Name());
+ printed_one = true;
+ }
+ }
+ }
+}
+
+// Returns the Default singleton instance of the ObjectLibrary
+// This instance will contain most of the "standard" registered objects
+std::shared_ptr<ObjectLibrary> &ObjectLibrary::Default() {
+ // Use avoid destruction here so the default ObjectLibrary will not be
+ // statically destroyed and long-lived.
+ STATIC_AVOID_DESTRUCTION(std::shared_ptr<ObjectLibrary>, instance)
+ (std::make_shared<ObjectLibrary>("default"));
+ return instance;
+}
+
+ObjectRegistry::ObjectRegistry(const std::shared_ptr<ObjectLibrary> &library) {
+ libraries_.push_back(library);
+ for (const auto &b : builtins_) {
+ RegisterPlugin(b.first, b.second);
+ }
+}
+
+std::shared_ptr<ObjectRegistry> ObjectRegistry::Default() {
+ // Use avoid destruction here so the default ObjectRegistry will not be
+ // statically destroyed and long-lived.
+ STATIC_AVOID_DESTRUCTION(std::shared_ptr<ObjectRegistry>, instance)
+ (std::make_shared<ObjectRegistry>(ObjectLibrary::Default()));
+ return instance;
+}
+
+std::shared_ptr<ObjectRegistry> ObjectRegistry::NewInstance() {
+ return std::make_shared<ObjectRegistry>(Default());
+}
+
+std::shared_ptr<ObjectRegistry> ObjectRegistry::NewInstance(
+ const std::shared_ptr<ObjectRegistry> &parent) {
+ return std::make_shared<ObjectRegistry>(parent);
+}
+
+Status ObjectRegistry::SetManagedObject(
+ const std::string &type, const std::string &id,
+ const std::shared_ptr<Customizable> &object) {
+ std::string object_key = ToManagedObjectKey(type, id);
+ std::shared_ptr<Customizable> curr;
+ if (parent_ != nullptr) {
+ curr = parent_->GetManagedObject(type, id);
+ }
+ if (curr == nullptr) {
+ // We did not find the object in any parent. Update in the current
+ std::unique_lock<std::mutex> lock(objects_mutex_);
+ auto iter = managed_objects_.find(object_key);
+ if (iter != managed_objects_.end()) { // The object exists
+ curr = iter->second.lock();
+ if (curr != nullptr && curr != object) {
+ return Status::InvalidArgument("Object already exists: ", object_key);
+ } else {
+ iter->second = object;
+ }
+ } else {
+ // The object does not exist. Add it
+ managed_objects_[object_key] = object;
+ }
+ } else if (curr != object) {
+ return Status::InvalidArgument("Object already exists: ", object_key);
+ }
+ return Status::OK();
+}
+
+std::shared_ptr<Customizable> ObjectRegistry::GetManagedObject(
+ const std::string &type, const std::string &id) const {
+ {
+ std::unique_lock<std::mutex> lock(objects_mutex_);
+ auto iter = managed_objects_.find(ToManagedObjectKey(type, id));
+ if (iter != managed_objects_.end()) {
+ return iter->second.lock();
+ }
+ }
+ if (parent_ != nullptr) {
+ return parent_->GetManagedObject(type, id);
+ } else {
+ return nullptr;
+ }
+}
+
+Status ObjectRegistry::ListManagedObjects(
+ const std::string &type, const std::string &name,
+ std::vector<std::shared_ptr<Customizable>> *results) const {
+ {
+ std::string key = ToManagedObjectKey(type, name);
+ std::unique_lock<std::mutex> lock(objects_mutex_);
+ for (auto iter = managed_objects_.lower_bound(key);
+ iter != managed_objects_.end() && StartsWith(iter->first, key);
+ ++iter) {
+ auto shared = iter->second.lock();
+ if (shared != nullptr) {
+ if (name.empty() || shared->IsInstanceOf(name)) {
+ results->emplace_back(shared);
+ }
+ }
+ }
+ }
+ if (parent_ != nullptr) {
+ return parent_->ListManagedObjects(type, name, results);
+ } else {
+ return Status::OK();
+ }
+}
+
+// Returns the number of registered types for this registry.
+// If specified (not-null), types is updated to include the names of the
+// registered types.
+size_t ObjectRegistry::GetFactoryCount(const std::string &type) const {
+ size_t count = 0;
+ if (parent_ != nullptr) {
+ count = parent_->GetFactoryCount(type);
+ }
+ std::unique_lock<std::mutex> lock(library_mutex_);
+ for (const auto &library : libraries_) {
+ count += library->GetFactoryCount(type);
+ }
+ return count;
+}
+
+void ObjectRegistry::GetFactoryNames(const std::string &type,
+ std::vector<std::string> *names) const {
+ assert(names);
+ names->clear();
+ if (parent_ != nullptr) {
+ parent_->GetFactoryNames(type, names);
+ }
+ std::unique_lock<std::mutex> lock(library_mutex_);
+ for (const auto &library : libraries_) {
+ library->GetFactoryNames(type, names);
+ }
+}
+
+void ObjectRegistry::GetFactoryTypes(
+ std::unordered_set<std::string> *types) const {
+ assert(types);
+ if (parent_ != nullptr) {
+ parent_->GetFactoryTypes(types);
+ }
+ std::unique_lock<std::mutex> lock(library_mutex_);
+ for (const auto &library : libraries_) {
+ library->GetFactoryTypes(types);
+ }
+}
+
+void ObjectRegistry::Dump(Logger *logger) const {
+ if (logger != nullptr) {
+ std::unique_lock<std::mutex> lock(library_mutex_);
+ if (!plugins_.empty()) {
+ ROCKS_LOG_HEADER(logger, " Registered Plugins:");
+ bool printed_one = false;
+ for (const auto &plugin : plugins_) {
+ ROCKS_LOG_HEADER(logger, "%s%s", (printed_one) ? ", " : " ",
+ plugin.c_str());
+ printed_one = true;
+ }
+ ROCKS_LOG_HEADER(logger, "\n");
+ }
+ for (auto iter = libraries_.crbegin(); iter != libraries_.crend(); ++iter) {
+ iter->get()->Dump(logger);
+ }
+ }
+ if (parent_ != nullptr) {
+ parent_->Dump(logger);
+ }
+}
+
+int ObjectRegistry::RegisterPlugin(const std::string &name,
+ const RegistrarFunc &func) {
+ if (!name.empty() && func != nullptr) {
+ plugins_.push_back(name);
+ return AddLibrary(name)->Register(func, name);
+ } else {
+ return -1;
+ }
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/object_registry_test.cc b/src/rocksdb/utilities/object_registry_test.cc
new file mode 100644
index 000000000..90cd155ee
--- /dev/null
+++ b/src/rocksdb/utilities/object_registry_test.cc
@@ -0,0 +1,872 @@
+// Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/object_registry.h"
+
+#include "rocksdb/convenience.h"
+#include "rocksdb/customizable.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ObjRegistryTest : public testing::Test {
+ public:
+ static int num_a, num_b;
+};
+
+int ObjRegistryTest::num_a = 0;
+int ObjRegistryTest::num_b = 0;
+static FactoryFunc<Env> test_reg_a = ObjectLibrary::Default()->AddFactory<Env>(
+ ObjectLibrary::PatternEntry("a", false).AddSeparator("://"),
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*env_guard*/,
+ std::string* /* errmsg */) {
+ ++ObjRegistryTest::num_a;
+ return Env::Default();
+ });
+
+class WrappedEnv : public EnvWrapper {
+ private:
+ std::string id_;
+
+ public:
+ WrappedEnv(Env* t, const std::string& id) : EnvWrapper(t), id_(id) {}
+ const char* Name() const override { return id_.c_str(); }
+ std::string GetId() const override { return id_; }
+};
+static FactoryFunc<Env> test_reg_b = ObjectLibrary::Default()->AddFactory<Env>(
+ ObjectLibrary::PatternEntry("b", false).AddSeparator("://"),
+ [](const std::string& uri, std::unique_ptr<Env>* env_guard,
+ std::string* /* errmsg */) {
+ ++ObjRegistryTest::num_b;
+ // Env::Default() is a singleton so we can't grant ownership directly to
+ // the caller - we must wrap it first.
+ env_guard->reset(new WrappedEnv(Env::Default(), uri));
+ return env_guard->get();
+ });
+
+TEST_F(ObjRegistryTest, Basics) {
+ std::string msg;
+ std::unique_ptr<Env> guard;
+ Env* a_env = nullptr;
+
+ auto registry = ObjectRegistry::NewInstance();
+ ASSERT_NOK(registry->NewStaticObject<Env>("c://test", &a_env));
+ ASSERT_NOK(registry->NewUniqueObject<Env>("c://test", &guard));
+ ASSERT_EQ(a_env, nullptr);
+ ASSERT_EQ(guard, nullptr);
+ ASSERT_EQ(0, num_a);
+ ASSERT_EQ(0, num_b);
+
+ ASSERT_OK(registry->NewStaticObject<Env>("a://test", &a_env));
+ ASSERT_NE(a_env, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(0, num_b);
+
+ ASSERT_OK(registry->NewUniqueObject<Env>("b://test", &guard));
+ ASSERT_NE(guard, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(1, num_b);
+
+ Env* b_env = nullptr;
+ ASSERT_NOK(registry->NewStaticObject<Env>("b://test", &b_env));
+ ASSERT_EQ(b_env, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(2, num_b); // Created but rejected as not static
+
+ b_env = a_env;
+ ASSERT_NOK(registry->NewStaticObject<Env>("b://test", &b_env));
+ ASSERT_EQ(b_env, a_env);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(3, num_b);
+
+ b_env = guard.get();
+ ASSERT_NOK(registry->NewUniqueObject<Env>("a://test", &guard));
+ ASSERT_EQ(guard.get(), b_env); // Unchanged
+ ASSERT_EQ(2, num_a); // Created one but rejected it as not unique
+ ASSERT_EQ(3, num_b);
+}
+
+TEST_F(ObjRegistryTest, LocalRegistry) {
+ Env* env = nullptr;
+ auto registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library =
+ std::make_shared<ObjectLibrary>("local");
+ registry->AddLibrary(library);
+ library->AddFactory<Env>(
+ "test-local",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ ObjectLibrary::Default()->AddFactory<Env>(
+ "test-global",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ ASSERT_NOK(
+ ObjectRegistry::NewInstance()->NewStaticObject<Env>("test-local", &env));
+ ASSERT_EQ(env, nullptr);
+ ASSERT_OK(
+ ObjectRegistry::NewInstance()->NewStaticObject<Env>("test-global", &env));
+ ASSERT_NE(env, nullptr);
+ ASSERT_OK(registry->NewStaticObject<Env>("test-local", &env));
+ ASSERT_NE(env, nullptr);
+ ASSERT_OK(registry->NewStaticObject<Env>("test-global", &env));
+ ASSERT_NE(env, nullptr);
+}
+
+static int RegisterTestUnguarded(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<Env>(
+ "unguarded",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+ library.AddFactory<Env>(
+ "guarded", [](const std::string& uri, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new WrappedEnv(Env::Default(), uri));
+ return guard->get();
+ });
+ return 2;
+}
+
+TEST_F(ObjRegistryTest, CheckShared) {
+ std::shared_ptr<Env> shared;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ registry->AddLibrary("shared", RegisterTestUnguarded, "");
+
+ ASSERT_OK(registry->NewSharedObject<Env>("guarded", &shared));
+ ASSERT_NE(shared, nullptr);
+ shared.reset();
+ ASSERT_NOK(registry->NewSharedObject<Env>("unguarded", &shared));
+ ASSERT_EQ(shared, nullptr);
+}
+
+TEST_F(ObjRegistryTest, CheckStatic) {
+ Env* env = nullptr;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ registry->AddLibrary("static", RegisterTestUnguarded, "");
+
+ ASSERT_NOK(registry->NewStaticObject<Env>("guarded", &env));
+ ASSERT_EQ(env, nullptr);
+ env = nullptr;
+ ASSERT_OK(registry->NewStaticObject<Env>("unguarded", &env));
+ ASSERT_NE(env, nullptr);
+}
+
+TEST_F(ObjRegistryTest, CheckUnique) {
+ std::unique_ptr<Env> unique;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ registry->AddLibrary("unique", RegisterTestUnguarded, "");
+
+ ASSERT_OK(registry->NewUniqueObject<Env>("guarded", &unique));
+ ASSERT_NE(unique, nullptr);
+ unique.reset();
+ ASSERT_NOK(registry->NewUniqueObject<Env>("unguarded", &unique));
+ ASSERT_EQ(unique, nullptr);
+}
+
+TEST_F(ObjRegistryTest, FailingFactory) {
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library =
+ std::make_shared<ObjectLibrary>("failing");
+ registry->AddLibrary(library);
+ library->AddFactory<Env>(
+ "failing", [](const std::string& /*uri*/,
+ std::unique_ptr<Env>* /*guard */, std::string* errmsg) {
+ *errmsg = "Bad Factory";
+ return nullptr;
+ });
+ std::unique_ptr<Env> unique;
+ std::shared_ptr<Env> shared;
+ Env* pointer = nullptr;
+ Status s;
+ s = registry->NewUniqueObject<Env>("failing", &unique);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ s = registry->NewSharedObject<Env>("failing", &shared);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ s = registry->NewStaticObject<Env>("failing", &pointer);
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ s = registry->NewUniqueObject<Env>("missing", &unique);
+ ASSERT_TRUE(s.IsNotSupported());
+ s = registry->NewSharedObject<Env>("missing", &shared);
+ ASSERT_TRUE(s.IsNotSupported());
+ s = registry->NewStaticObject<Env>("missing", &pointer);
+ ASSERT_TRUE(s.IsNotSupported());
+}
+
+TEST_F(ObjRegistryTest, TestRegistryParents) {
+ auto grand = ObjectRegistry::Default();
+ auto parent = ObjectRegistry::NewInstance(); // parent with a grandparent
+ auto uncle = ObjectRegistry::NewInstance(grand);
+ auto child = ObjectRegistry::NewInstance(parent);
+ auto cousin = ObjectRegistry::NewInstance(uncle);
+
+ auto library = parent->AddLibrary("parent");
+ library->AddFactory<Env>(
+ "parent", [](const std::string& uri, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new WrappedEnv(Env::Default(), uri));
+ return guard->get();
+ });
+ library = cousin->AddLibrary("cousin");
+ library->AddFactory<Env>(
+ "cousin", [](const std::string& uri, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new WrappedEnv(Env::Default(), uri));
+ return guard->get();
+ });
+
+ Env* env = nullptr;
+ std::unique_ptr<Env> guard;
+ std::string msg;
+
+ // a:://* is registered in Default, so they should all work
+ ASSERT_OK(parent->NewStaticObject<Env>("a://test", &env));
+ ASSERT_OK(child->NewStaticObject<Env>("a://test", &env));
+ ASSERT_OK(uncle->NewStaticObject<Env>("a://test", &env));
+ ASSERT_OK(cousin->NewStaticObject<Env>("a://test", &env));
+
+ // The parent env is only registered for parent, not uncle,
+ // So parent and child should return success and uncle and cousin should fail
+ ASSERT_OK(parent->NewUniqueObject<Env>("parent", &guard));
+ ASSERT_OK(child->NewUniqueObject<Env>("parent", &guard));
+ ASSERT_NOK(uncle->NewUniqueObject<Env>("parent", &guard));
+ ASSERT_NOK(cousin->NewUniqueObject<Env>("parent", &guard));
+
+ // The cousin is only registered in the cousin, so all of the others should
+ // fail
+ ASSERT_OK(cousin->NewUniqueObject<Env>("cousin", &guard));
+ ASSERT_NOK(parent->NewUniqueObject<Env>("cousin", &guard));
+ ASSERT_NOK(child->NewUniqueObject<Env>("cousin", &guard));
+ ASSERT_NOK(uncle->NewUniqueObject<Env>("cousin", &guard));
+}
+
+class MyCustomizable : public Customizable {
+ public:
+ static const char* Type() { return "MyCustomizable"; }
+ MyCustomizable(const char* prefix, const std::string& id) : id_(id) {
+ name_ = id_.substr(0, strlen(prefix) - 1);
+ }
+ const char* Name() const override { return name_.c_str(); }
+ std::string GetId() const override { return id_; }
+
+ private:
+ std::string id_;
+ std::string name_;
+};
+
+TEST_F(ObjRegistryTest, TestFactoryCount) {
+ std::string msg;
+ auto grand = ObjectRegistry::Default();
+ auto local = ObjectRegistry::NewInstance();
+ std::unordered_set<std::string> grand_types, local_types;
+ std::vector<std::string> grand_names, local_names;
+
+ // Check how many types we have on startup.
+ // Grand should equal local
+ grand->GetFactoryTypes(&grand_types);
+ local->GetFactoryTypes(&local_types);
+ ASSERT_EQ(grand_types, local_types);
+ size_t grand_count = grand->GetFactoryCount(Env::Type());
+ size_t local_count = local->GetFactoryCount(Env::Type());
+
+ ASSERT_EQ(grand_count, local_count);
+ grand->GetFactoryNames(Env::Type(), &grand_names);
+ local->GetFactoryNames(Env::Type(), &local_names);
+ ASSERT_EQ(grand_names.size(), grand_count);
+ ASSERT_EQ(local_names.size(), local_count);
+ ASSERT_EQ(grand_names, local_names);
+
+ // Add an Env to the local registry.
+ // This will add one factory.
+ auto library = local->AddLibrary("local");
+ library->AddFactory<Env>(
+ "A", [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return nullptr; });
+ ASSERT_EQ(local_count + 1, local->GetFactoryCount(Env::Type()));
+ ASSERT_EQ(grand_count, grand->GetFactoryCount(Env::Type()));
+ local->GetFactoryTypes(&local_types);
+ local->GetFactoryNames(Env::Type(), &local_names);
+ ASSERT_EQ(grand_names.size() + 1, local_names.size());
+ ASSERT_EQ(local_names.size(), local->GetFactoryCount(Env::Type()));
+
+ if (grand_count == 0) {
+ // There were no Env when we started. Should have one more type
+ // than previously
+ ASSERT_NE(grand_types, local_types);
+ ASSERT_EQ(grand_types.size() + 1, local_types.size());
+ } else {
+ // There was an Env type when we started. The types should match
+ ASSERT_EQ(grand_types, local_types);
+ }
+
+ // Add a MyCustomizable to the registry. This should be a new type
+ library->AddFactory<MyCustomizable>(
+ "MY", [](const std::string& /*uri*/,
+ std::unique_ptr<MyCustomizable>* /*guard */,
+ std::string* /* errmsg */) { return nullptr; });
+ ASSERT_EQ(local_count + 1, local->GetFactoryCount(Env::Type()));
+ ASSERT_EQ(grand_count, grand->GetFactoryCount(Env::Type()));
+ ASSERT_EQ(0U, grand->GetFactoryCount(MyCustomizable::Type()));
+ ASSERT_EQ(1U, local->GetFactoryCount(MyCustomizable::Type()));
+
+ local->GetFactoryNames(MyCustomizable::Type(), &local_names);
+ ASSERT_EQ(1U, local_names.size());
+ ASSERT_EQ(local_names[0], "MY");
+
+ local->GetFactoryTypes(&local_types);
+ ASSERT_EQ(grand_count == 0 ? 2 : grand_types.size() + 1, local_types.size());
+
+ // Add the same name again. We should now have 2 factories.
+ library->AddFactory<MyCustomizable>(
+ "MY", [](const std::string& /*uri*/,
+ std::unique_ptr<MyCustomizable>* /*guard */,
+ std::string* /* errmsg */) { return nullptr; });
+ local->GetFactoryNames(MyCustomizable::Type(), &local_names);
+ ASSERT_EQ(2U, local_names.size());
+}
+
+TEST_F(ObjRegistryTest, TestManagedObjects) {
+ auto registry = ObjectRegistry::NewInstance();
+ auto m_a1 = std::make_shared<MyCustomizable>("", "A");
+ auto m_a2 = std::make_shared<MyCustomizable>("", "A");
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_OK(registry->SetManagedObject<MyCustomizable>(m_a1));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a1);
+
+ ASSERT_NOK(registry->SetManagedObject<MyCustomizable>(m_a2));
+ ASSERT_OK(registry->SetManagedObject<MyCustomizable>(m_a1));
+ m_a1.reset();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_OK(registry->SetManagedObject<MyCustomizable>(m_a2));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a2);
+}
+
+TEST_F(ObjRegistryTest, TestTwoManagedObjects) {
+ auto registry = ObjectRegistry::NewInstance();
+ auto m_a = std::make_shared<MyCustomizable>("", "A");
+ auto m_b = std::make_shared<MyCustomizable>("", "B");
+ std::vector<std::shared_ptr<MyCustomizable>> objects;
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 0U);
+ ASSERT_OK(registry->SetManagedObject(m_a));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_a);
+
+ ASSERT_OK(registry->SetManagedObject(m_b));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), m_b);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 2U);
+ ASSERT_OK(registry->ListManagedObjects("A", &objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_a);
+ ASSERT_OK(registry->ListManagedObjects("B", &objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_b);
+ ASSERT_OK(registry->ListManagedObjects("C", &objects));
+ ASSERT_EQ(objects.size(), 0U);
+
+ m_a.reset();
+ objects.clear();
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), m_b);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_b);
+
+ m_b.reset();
+ objects.clear();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+}
+
+TEST_F(ObjRegistryTest, TestAlternateNames) {
+ auto registry = ObjectRegistry::NewInstance();
+ auto m_a = std::make_shared<MyCustomizable>("", "A");
+ auto m_b = std::make_shared<MyCustomizable>("", "B");
+ std::vector<std::shared_ptr<MyCustomizable>> objects;
+ // Test no objects exist
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), nullptr);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 0U);
+
+ // Mark "TheOne" to be A
+ ASSERT_OK(registry->SetManagedObject("TheOne", m_a));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), m_a);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_a);
+
+ // Try to mark "TheOne" again.
+ ASSERT_NOK(registry->SetManagedObject("TheOne", m_b));
+ ASSERT_OK(registry->SetManagedObject("TheOne", m_a));
+
+ // Add "A" as a managed object. Registered 2x
+ ASSERT_OK(registry->SetManagedObject(m_a));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("B"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), m_a);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 2U);
+
+ // Delete "A".
+ m_a.reset();
+ objects.clear();
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), nullptr);
+ ASSERT_OK(registry->SetManagedObject("TheOne", m_b));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), m_b);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 1U);
+ ASSERT_EQ(objects.front(), m_b);
+
+ m_b.reset();
+ objects.clear();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("TheOne"), nullptr);
+ ASSERT_OK(registry->ListManagedObjects(&objects));
+ ASSERT_EQ(objects.size(), 0U);
+}
+
+TEST_F(ObjRegistryTest, TestTwoManagedClasses) {
+ class MyCustomizable2 : public MyCustomizable {
+ public:
+ static const char* Type() { return "MyCustomizable2"; }
+ MyCustomizable2(const char* prefix, const std::string& id)
+ : MyCustomizable(prefix, id) {}
+ };
+
+ auto registry = ObjectRegistry::NewInstance();
+ auto m_a1 = std::make_shared<MyCustomizable>("", "A");
+ auto m_a2 = std::make_shared<MyCustomizable2>("", "A");
+ std::vector<std::shared_ptr<MyCustomizable>> obj1s;
+ std::vector<std::shared_ptr<MyCustomizable2>> obj2s;
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable2>("A"), nullptr);
+
+ ASSERT_OK(registry->SetManagedObject(m_a1));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a1);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable2>("A"), nullptr);
+
+ ASSERT_OK(registry->SetManagedObject(m_a2));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable2>("A"), m_a2);
+ ASSERT_OK(registry->ListManagedObjects(&obj1s));
+ ASSERT_OK(registry->ListManagedObjects(&obj2s));
+ ASSERT_EQ(obj1s.size(), 1U);
+ ASSERT_EQ(obj2s.size(), 1U);
+ ASSERT_EQ(obj1s.front(), m_a1);
+ ASSERT_EQ(obj2s.front(), m_a2);
+ m_a1.reset();
+ obj1s.clear();
+ obj2s.clear();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable2>("A"), m_a2);
+
+ m_a2.reset();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable2>("A"), nullptr);
+}
+
+TEST_F(ObjRegistryTest, TestManagedObjectsWithParent) {
+ auto base = ObjectRegistry::NewInstance();
+ auto registry = ObjectRegistry::NewInstance(base);
+
+ auto m_a = std::make_shared<MyCustomizable>("", "A");
+ auto m_b = std::make_shared<MyCustomizable>("", "A");
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_OK(base->SetManagedObject(m_a));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_a);
+
+ ASSERT_NOK(registry->SetManagedObject(m_b));
+ ASSERT_OK(registry->SetManagedObject(m_a));
+
+ m_a.reset();
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), nullptr);
+ ASSERT_OK(registry->SetManagedObject(m_b));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("A"), m_b);
+}
+
+TEST_F(ObjRegistryTest, TestGetOrCreateManagedObject) {
+ auto registry = ObjectRegistry::NewInstance();
+ registry->AddLibrary("test")->AddFactory<MyCustomizable>(
+ ObjectLibrary::PatternEntry::AsIndividualId("MC"),
+ [](const std::string& uri, std::unique_ptr<MyCustomizable>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new MyCustomizable("MC", uri));
+ return guard->get();
+ });
+ std::shared_ptr<MyCustomizable> m_a, m_b, obj;
+ std::vector<std::shared_ptr<MyCustomizable>> objs;
+
+ std::unordered_map<std::string, std::string> opt_map;
+
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("MC@A#1"), nullptr);
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("MC@B#1"), nullptr);
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@A#1", &m_a));
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@B#1", &m_b));
+ ASSERT_EQ(registry->GetManagedObject<MyCustomizable>("MC@A#1"), m_a);
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@A#1", &obj));
+ ASSERT_EQ(obj, m_a);
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@B#1", &obj));
+ ASSERT_EQ(obj, m_b);
+ ASSERT_OK(registry->ListManagedObjects(&objs));
+ ASSERT_EQ(objs.size(), 2U);
+
+ objs.clear();
+ m_a.reset();
+ obj.reset();
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@A#1", &m_a));
+ ASSERT_EQ(1, m_a.use_count());
+ ASSERT_OK(registry->GetOrCreateManagedObject("MC@B#1", &obj));
+ ASSERT_EQ(2, obj.use_count());
+}
+
+TEST_F(ObjRegistryTest, RegisterPlugin) {
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ std::unique_ptr<Env> guard;
+ Env* env = nullptr;
+
+ ASSERT_NOK(registry->NewObject<Env>("unguarded", &env, &guard));
+ ASSERT_EQ(registry->RegisterPlugin("Missing", nullptr), -1);
+ ASSERT_EQ(registry->RegisterPlugin("", RegisterTestUnguarded), -1);
+ ASSERT_GT(registry->RegisterPlugin("Valid", RegisterTestUnguarded), 0);
+ ASSERT_OK(registry->NewObject<Env>("unguarded", &env, &guard));
+ ASSERT_NE(env, nullptr);
+}
+class PatternEntryTest : public testing::Test {};
+
+TEST_F(PatternEntryTest, TestSimpleEntry) {
+ ObjectLibrary::PatternEntry entry("ABC", true);
+
+ ASSERT_TRUE(entry.Matches("ABC"));
+ ASSERT_FALSE(entry.Matches("AABC"));
+ ASSERT_FALSE(entry.Matches("ABCA"));
+ ASSERT_FALSE(entry.Matches("AABCA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("BC"));
+ ASSERT_FALSE(entry.Matches("ABD"));
+ ASSERT_FALSE(entry.Matches("BCA"));
+}
+
+TEST_F(PatternEntryTest, TestPatternEntry) {
+ // Matches A:+
+ ObjectLibrary::PatternEntry entry("A", false);
+ entry.AddSeparator(":");
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("B"));
+ ASSERT_FALSE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA:B"));
+ ASSERT_FALSE(entry.Matches("AA:BB"));
+ ASSERT_TRUE(entry.Matches("A:B"));
+ ASSERT_TRUE(entry.Matches("A:BB"));
+
+ entry.SetOptional(true); // Now matches "A" or "A:+"
+ ASSERT_TRUE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("B"));
+ ASSERT_FALSE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA:B"));
+ ASSERT_FALSE(entry.Matches("AA:BB"));
+ ASSERT_TRUE(entry.Matches("A:B"));
+ ASSERT_TRUE(entry.Matches("A:BB"));
+}
+
+TEST_F(PatternEntryTest, MatchZeroOrMore) {
+ // Matches A:*
+ ObjectLibrary::PatternEntry entry("A", false);
+ entry.AddSeparator(":", false);
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("B"));
+ ASSERT_TRUE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("B:"));
+ ASSERT_FALSE(entry.Matches("B:A"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA:B"));
+ ASSERT_FALSE(entry.Matches("AA:BB"));
+ ASSERT_TRUE(entry.Matches("A:B"));
+ ASSERT_TRUE(entry.Matches("A:BB"));
+
+ entry.SetOptional(true); // Now matches "A" or "A:*"
+ ASSERT_TRUE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("B"));
+ ASSERT_TRUE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("B:"));
+ ASSERT_FALSE(entry.Matches("B:A"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA:B"));
+ ASSERT_FALSE(entry.Matches("AA:BB"));
+ ASSERT_TRUE(entry.Matches("A:B"));
+ ASSERT_TRUE(entry.Matches("A:BB"));
+}
+
+TEST_F(PatternEntryTest, TestSuffixEntry) {
+ ObjectLibrary::PatternEntry entry("AA", true);
+ entry.AddSuffix("BB");
+
+ ASSERT_TRUE(entry.Matches("AA"));
+ ASSERT_TRUE(entry.Matches("AABB"));
+
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AB"));
+ ASSERT_FALSE(entry.Matches("B"));
+ ASSERT_FALSE(entry.Matches("BB"));
+ ASSERT_FALSE(entry.Matches("ABA"));
+ ASSERT_FALSE(entry.Matches("BBAA"));
+ ASSERT_FALSE(entry.Matches("AABBA"));
+ ASSERT_FALSE(entry.Matches("AABBB"));
+}
+
+TEST_F(PatternEntryTest, TestNumericEntry) {
+ ObjectLibrary::PatternEntry entry("A", false);
+ entry.AddNumber(":");
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_TRUE(entry.Matches("A:1"));
+ ASSERT_TRUE(entry.Matches("A:11"));
+ ASSERT_FALSE(entry.Matches("AA:1"));
+ ASSERT_FALSE(entry.Matches("AA:11"));
+ ASSERT_FALSE(entry.Matches("A:B"));
+ ASSERT_FALSE(entry.Matches("A:1B"));
+ ASSERT_FALSE(entry.Matches("A:B1"));
+
+ entry.AddSeparator(":", false);
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_TRUE(entry.Matches("A:1:"));
+ ASSERT_TRUE(entry.Matches("A:11:"));
+ ASSERT_FALSE(entry.Matches("A:1"));
+ ASSERT_FALSE(entry.Matches("A:B1:"));
+ ASSERT_FALSE(entry.Matches("A:1B:"));
+ ASSERT_FALSE(entry.Matches("A::"));
+}
+
+TEST_F(PatternEntryTest, TestDoubleEntry) {
+ ObjectLibrary::PatternEntry entry("A", false);
+ entry.AddNumber(":", false);
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("A:"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA:1"));
+ ASSERT_FALSE(entry.Matches("AA:11"));
+ ASSERT_FALSE(entry.Matches("A:B"));
+ ASSERT_FALSE(entry.Matches("A:1B"));
+ ASSERT_FALSE(entry.Matches("A:B1"));
+ ASSERT_TRUE(entry.Matches("A:1"));
+ ASSERT_TRUE(entry.Matches("A:11"));
+ ASSERT_TRUE(entry.Matches("A:1.1"));
+ ASSERT_TRUE(entry.Matches("A:11.11"));
+ ASSERT_TRUE(entry.Matches("A:1."));
+ ASSERT_TRUE(entry.Matches("A:.1"));
+ ASSERT_TRUE(entry.Matches("A:0.1"));
+ ASSERT_TRUE(entry.Matches("A:1.0"));
+ ASSERT_TRUE(entry.Matches("A:1.0"));
+
+ ASSERT_FALSE(entry.Matches("A:1.0."));
+ ASSERT_FALSE(entry.Matches("A:1.0.2"));
+ ASSERT_FALSE(entry.Matches("A:.1.0"));
+ ASSERT_FALSE(entry.Matches("A:..10"));
+ ASSERT_FALSE(entry.Matches("A:10.."));
+ ASSERT_FALSE(entry.Matches("A:."));
+
+ entry.AddSeparator(":", false);
+ ASSERT_FALSE(entry.Matches("A:1"));
+ ASSERT_FALSE(entry.Matches("A:1.0"));
+
+ ASSERT_TRUE(entry.Matches("A:11:"));
+ ASSERT_TRUE(entry.Matches("A:1.1:"));
+ ASSERT_TRUE(entry.Matches("A:11.11:"));
+ ASSERT_TRUE(entry.Matches("A:1.:"));
+ ASSERT_TRUE(entry.Matches("A:.1:"));
+ ASSERT_TRUE(entry.Matches("A:0.1:"));
+ ASSERT_TRUE(entry.Matches("A:1.0:"));
+ ASSERT_TRUE(entry.Matches("A:1.0:"));
+
+ ASSERT_FALSE(entry.Matches("A:1.0.:"));
+ ASSERT_FALSE(entry.Matches("A:1.0.2:"));
+ ASSERT_FALSE(entry.Matches("A:.1.0:"));
+ ASSERT_FALSE(entry.Matches("A:..10:"));
+ ASSERT_FALSE(entry.Matches("A:10..:"));
+ ASSERT_FALSE(entry.Matches("A:.:"));
+ ASSERT_FALSE(entry.Matches("A::"));
+}
+
+TEST_F(PatternEntryTest, TestIndividualIdEntry) {
+ auto entry = ObjectLibrary::PatternEntry::AsIndividualId("AA");
+ ASSERT_TRUE(entry.Matches("AA"));
+ ASSERT_TRUE(entry.Matches("AA@123#456"));
+ ASSERT_TRUE(entry.Matches("AA@deadbeef#id"));
+
+ ASSERT_FALSE(entry.Matches("A"));
+ ASSERT_FALSE(entry.Matches("AAA"));
+ ASSERT_FALSE(entry.Matches("AA@123"));
+ ASSERT_FALSE(entry.Matches("AA@123#"));
+ ASSERT_FALSE(entry.Matches("AA@#123"));
+}
+
+TEST_F(PatternEntryTest, TestTwoNameEntry) {
+ ObjectLibrary::PatternEntry entry("A");
+ entry.AnotherName("B");
+ ASSERT_TRUE(entry.Matches("A"));
+ ASSERT_TRUE(entry.Matches("B"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("BB"));
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("BA"));
+ ASSERT_FALSE(entry.Matches("AB"));
+}
+
+TEST_F(PatternEntryTest, TestTwoPatternEntry) {
+ ObjectLibrary::PatternEntry entry("AA", false);
+ entry.AddSeparator(":");
+ entry.AddSeparator(":");
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA::"));
+ ASSERT_FALSE(entry.Matches("AA::12"));
+ ASSERT_TRUE(entry.Matches("AA:1:2"));
+ ASSERT_TRUE(entry.Matches("AA:1:2:"));
+
+ ObjectLibrary::PatternEntry entry2("AA", false);
+ entry2.AddSeparator("::");
+ entry2.AddSeparator("##");
+ ASSERT_FALSE(entry2.Matches("AA"));
+ ASSERT_FALSE(entry2.Matches("AA:"));
+ ASSERT_FALSE(entry2.Matches("AA::"));
+ ASSERT_FALSE(entry2.Matches("AA::#"));
+ ASSERT_FALSE(entry2.Matches("AA::##"));
+ ASSERT_FALSE(entry2.Matches("AA##1::2"));
+ ASSERT_FALSE(entry2.Matches("AA::123##"));
+ ASSERT_TRUE(entry2.Matches("AA::1##2"));
+ ASSERT_TRUE(entry2.Matches("AA::12##34:"));
+ ASSERT_TRUE(entry2.Matches("AA::12::34##56"));
+ ASSERT_TRUE(entry2.Matches("AA::12##34::56"));
+}
+
+TEST_F(PatternEntryTest, TestTwoNumbersEntry) {
+ ObjectLibrary::PatternEntry entry("AA", false);
+ entry.AddNumber(":");
+ entry.AddNumber(":");
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AA:"));
+ ASSERT_FALSE(entry.Matches("AA::"));
+ ASSERT_FALSE(entry.Matches("AA::12"));
+ ASSERT_FALSE(entry.Matches("AA:1:2:"));
+ ASSERT_TRUE(entry.Matches("AA:1:2"));
+ ASSERT_TRUE(entry.Matches("AA:12:23456"));
+
+ ObjectLibrary::PatternEntry entry2("AA", false);
+ entry2.AddNumber(":");
+ entry2.AddNumber("#");
+ ASSERT_FALSE(entry2.Matches("AA"));
+ ASSERT_FALSE(entry2.Matches("AA:"));
+ ASSERT_FALSE(entry2.Matches("AA:#"));
+ ASSERT_FALSE(entry2.Matches("AA#:"));
+ ASSERT_FALSE(entry2.Matches("AA:123#"));
+ ASSERT_FALSE(entry2.Matches("AA:123#B"));
+ ASSERT_FALSE(entry2.Matches("AA:B#123"));
+ ASSERT_TRUE(entry2.Matches("AA:1#2"));
+ ASSERT_FALSE(entry2.Matches("AA:123#23:"));
+ ASSERT_FALSE(entry2.Matches("AA::12#234"));
+}
+
+TEST_F(PatternEntryTest, TestPatternAndSuffix) {
+ ObjectLibrary::PatternEntry entry("AA", false);
+ entry.AddSeparator("::");
+ entry.AddSuffix("##");
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("AA::"));
+ ASSERT_FALSE(entry.Matches("AA::##"));
+ ASSERT_FALSE(entry.Matches("AB::1##"));
+ ASSERT_FALSE(entry.Matches("AB::1##2"));
+ ASSERT_FALSE(entry.Matches("AA##1::"));
+ ASSERT_TRUE(entry.Matches("AA::1##"));
+ ASSERT_FALSE(entry.Matches("AA::1###"));
+
+ ObjectLibrary::PatternEntry entry2("AA", false);
+ entry2.AddSuffix("::");
+ entry2.AddSeparator("##");
+ ASSERT_FALSE(entry2.Matches("AA"));
+ ASSERT_FALSE(entry2.Matches("AA::"));
+ ASSERT_FALSE(entry2.Matches("AA::##"));
+ ASSERT_FALSE(entry2.Matches("AB::1##"));
+ ASSERT_FALSE(entry2.Matches("AB::1##2"));
+ ASSERT_TRUE(entry2.Matches("AA::##12"));
+}
+
+TEST_F(PatternEntryTest, TestTwoNamesAndPattern) {
+ ObjectLibrary::PatternEntry entry("AA", true);
+ entry.AddSeparator("::");
+ entry.AnotherName("BBB");
+ ASSERT_TRUE(entry.Matches("AA"));
+ ASSERT_TRUE(entry.Matches("AA::1"));
+ ASSERT_TRUE(entry.Matches("BBB"));
+ ASSERT_TRUE(entry.Matches("BBB::2"));
+
+ ASSERT_FALSE(entry.Matches("AA::"));
+ ASSERT_FALSE(entry.Matches("AAA::"));
+ ASSERT_FALSE(entry.Matches("BBB::"));
+
+ entry.SetOptional(false);
+ ASSERT_FALSE(entry.Matches("AA"));
+ ASSERT_FALSE(entry.Matches("BBB"));
+
+ ASSERT_FALSE(entry.Matches("AA::"));
+ ASSERT_FALSE(entry.Matches("AAA::"));
+ ASSERT_FALSE(entry.Matches("BBB::"));
+
+ ASSERT_TRUE(entry.Matches("AA::1"));
+ ASSERT_TRUE(entry.Matches("BBB::2"));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else // ROCKSDB_LITE
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as ObjRegistry is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/option_change_migration/option_change_migration.cc b/src/rocksdb/utilities/option_change_migration/option_change_migration.cc
new file mode 100644
index 000000000..e93d2152d
--- /dev/null
+++ b/src/rocksdb/utilities/option_change_migration/option_change_migration.cc
@@ -0,0 +1,186 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/option_change_migration.h"
+
+#ifndef ROCKSDB_LITE
+#include "rocksdb/db.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+// Return a version of Options `opts` that allow us to open/write into a DB
+// without triggering an automatic compaction or stalling. This is guaranteed
+// by disabling automatic compactions and using huge values for stalling
+// triggers.
+Options GetNoCompactionOptions(const Options& opts) {
+ Options ret_opts = opts;
+ ret_opts.disable_auto_compactions = true;
+ ret_opts.level0_slowdown_writes_trigger = 999999;
+ ret_opts.level0_stop_writes_trigger = 999999;
+ ret_opts.soft_pending_compaction_bytes_limit = 0;
+ ret_opts.hard_pending_compaction_bytes_limit = 0;
+ return ret_opts;
+}
+
+Status OpenDb(const Options& options, const std::string& dbname,
+ std::unique_ptr<DB>* db) {
+ db->reset();
+ DB* tmpdb;
+ Status s = DB::Open(options, dbname, &tmpdb);
+ if (s.ok()) {
+ db->reset(tmpdb);
+ }
+ return s;
+}
+
+// l0_file_size specifies size of file on L0. Files will be range partitioned
+// after a full compaction so they are likely qualified to put on L0. If
+// left as 0, the files are compacted in a single file and put to L0. Otherwise,
+// will try to compact the files as size l0_file_size.
+Status CompactToLevel(const Options& options, const std::string& dbname,
+ int dest_level, uint64_t l0_file_size, bool need_reopen) {
+ std::unique_ptr<DB> db;
+ Options no_compact_opts = GetNoCompactionOptions(options);
+ if (dest_level == 0) {
+ if (l0_file_size == 0) {
+ // Single file.
+ l0_file_size = 999999999999999;
+ }
+ // L0 has strict sequenceID requirements to files to it. It's safer
+ // to only put one compacted file to there.
+ // This is only used for converting to universal compaction with
+ // only one level. In this case, compacting to one file is also
+ // optimal.
+ no_compact_opts.target_file_size_base = l0_file_size;
+ no_compact_opts.max_compaction_bytes = l0_file_size;
+ }
+ Status s = OpenDb(no_compact_opts, dbname, &db);
+ if (!s.ok()) {
+ return s;
+ }
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = dest_level;
+ if (dest_level == 0) {
+ // cannot use kForceOptimized because the compaction is expected to
+ // generate one output file
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ }
+ s = db->CompactRange(cro, nullptr, nullptr);
+
+ if (s.ok() && need_reopen) {
+ // Need to restart DB to rewrite the manifest file.
+ // In order to open a DB with specific num_levels, the manifest file should
+ // contain no record that mentiones any level beyond num_levels. Issuing a
+ // full compaction will move all the data to a level not exceeding
+ // num_levels, but the manifest may still contain previous record mentioning
+ // a higher level. Reopening the DB will force the manifest to be rewritten
+ // so that those records will be cleared.
+ db.reset();
+ s = OpenDb(no_compact_opts, dbname, &db);
+ }
+ return s;
+}
+
+Status MigrateToUniversal(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (old_opts.num_levels <= new_opts.num_levels ||
+ old_opts.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ return Status::OK();
+ } else {
+ bool need_compact = false;
+ {
+ std::unique_ptr<DB> db;
+ Options opts = GetNoCompactionOptions(old_opts);
+ Status s = OpenDb(opts, dbname, &db);
+ if (!s.ok()) {
+ return s;
+ }
+ ColumnFamilyMetaData metadata;
+ db->GetColumnFamilyMetaData(&metadata);
+ if (!metadata.levels.empty() &&
+ metadata.levels.back().level >= new_opts.num_levels) {
+ need_compact = true;
+ }
+ }
+ if (need_compact) {
+ return CompactToLevel(old_opts, dbname, new_opts.num_levels - 1,
+ /*l0_file_size=*/0, true);
+ }
+ return Status::OK();
+ }
+}
+
+Status MigrateToLevelBase(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (!new_opts.level_compaction_dynamic_level_bytes) {
+ if (old_opts.num_levels == 1) {
+ return Status::OK();
+ }
+ // Compact everything to level 1 to guarantee it can be safely opened.
+ Options opts = old_opts;
+ opts.target_file_size_base = new_opts.target_file_size_base;
+ // Although sometimes we can open the DB with the new option without error,
+ // We still want to compact the files to avoid the LSM tree to stuck
+ // in bad shape. For example, if the user changed the level size
+ // multiplier from 4 to 8, with the same data, we will have fewer
+ // levels. Unless we issue a full comaction, the LSM tree may stuck
+ // with more levels than needed and it won't recover automatically.
+ return CompactToLevel(opts, dbname, 1, /*l0_file_size=*/0, true);
+ } else {
+ // Compact everything to the last level to guarantee it can be safely
+ // opened.
+ if (old_opts.num_levels == 1) {
+ return Status::OK();
+ } else if (new_opts.num_levels > old_opts.num_levels) {
+ // Dynamic level mode requires data to be put in the last level first.
+ return CompactToLevel(new_opts, dbname, new_opts.num_levels - 1,
+ /*l0_file_size=*/0, false);
+ } else {
+ Options opts = old_opts;
+ opts.target_file_size_base = new_opts.target_file_size_base;
+ return CompactToLevel(opts, dbname, new_opts.num_levels - 1,
+ /*l0_file_size=*/0, true);
+ }
+ }
+}
+} // namespace
+
+Status OptionChangeMigration(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (old_opts.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ // LSM generated by FIFO compaction can be opened by any compaction.
+ return Status::OK();
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleUniversal) {
+ return MigrateToUniversal(dbname, old_opts, new_opts);
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleLevel) {
+ return MigrateToLevelBase(dbname, old_opts, new_opts);
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleFIFO) {
+ uint64_t l0_file_size = 0;
+ if (new_opts.compaction_options_fifo.max_table_files_size > 0) {
+ // Create at least 8 files when max_table_files_size hits, so that the DB
+ // doesn't just disappear. This in fact violates the FIFO condition, but
+ // otherwise, the migrated DB is unlikley to be usable.
+ l0_file_size = new_opts.compaction_options_fifo.max_table_files_size / 8;
+ }
+ return CompactToLevel(old_opts, dbname, 0, l0_file_size, true);
+ } else {
+ return Status::NotSupported(
+ "Do not how to migrate to this compaction style");
+ }
+}
+} // namespace ROCKSDB_NAMESPACE
+#else
+namespace ROCKSDB_NAMESPACE {
+Status OptionChangeMigration(std::string /*dbname*/,
+ const Options& /*old_opts*/,
+ const Options& /*new_opts*/) {
+ return Status::NotSupported();
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc b/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc
new file mode 100644
index 000000000..71af45db1
--- /dev/null
+++ b/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc
@@ -0,0 +1,550 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/utilities/option_change_migration.h"
+
+#include <set>
+
+#include "db/db_test_util.h"
+#include "port/stack_trace.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class DBOptionChangeMigrationTests
+ : public DBTestBase,
+ public testing::WithParamInterface<
+ std::tuple<int, int, bool, int, int, bool, uint64_t>> {
+ public:
+ DBOptionChangeMigrationTests()
+ : DBTestBase("db_option_change_migration_test", /*env_do_fsync=*/true) {
+ level1_ = std::get<0>(GetParam());
+ compaction_style1_ = std::get<1>(GetParam());
+ is_dynamic1_ = std::get<2>(GetParam());
+
+ level2_ = std::get<3>(GetParam());
+ compaction_style2_ = std::get<4>(GetParam());
+ is_dynamic2_ = std::get<5>(GetParam());
+ fifo_max_table_files_size_ = std::get<6>(GetParam());
+ }
+
+ // Required if inheriting from testing::WithParamInterface<>
+ static void SetUpTestCase() {}
+ static void TearDownTestCase() {}
+
+ int level1_;
+ int compaction_style1_;
+ bool is_dynamic1_;
+
+ int level2_;
+ int compaction_style2_;
+ bool is_dynamic2_;
+
+ uint64_t fifo_max_table_files_size_;
+};
+
+#ifndef ROCKSDB_LITE
+TEST_P(DBOptionChangeMigrationTests, Migrate1) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ old_options.max_open_files = -1;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level1_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+
+ Random rnd(301);
+ int key_idx = 0;
+
+ // Generate at least 2MB of data
+ for (int num = 0; num < 20; num++) {
+ GenerateNewFile(&rnd, &key_idx);
+ }
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ new_options.max_open_files = -1;
+ }
+ if (fifo_max_table_files_size_ != 0) {
+ new_options.compaction_options_fifo.max_table_files_size =
+ fifo_max_table_files_size_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level2_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+
+ // Wait for compaction to finish and make sure it can reopen
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate2) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ old_options.max_open_files = -1;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level2_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+
+ Random rnd(301);
+ int key_idx = 0;
+
+ // Generate at least 2MB of data
+ for (int num = 0; num < 20; num++) {
+ GenerateNewFile(&rnd, &key_idx);
+ }
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ new_options.max_open_files = -1;
+ }
+ if (fifo_max_table_files_size_ != 0) {
+ new_options.compaction_options_fifo.max_table_files_size =
+ fifo_max_table_files_size_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level1_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate3) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ old_options.max_open_files = -1;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level1_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), rnd.RandomString(900)));
+ }
+ Flush();
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ if (num == 9) {
+ // Issue a full compaction to generate some zero-out files
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ ASSERT_OK(dbfull()->CompactRange(cro, nullptr, nullptr));
+ }
+ }
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ new_options.max_open_files = -1;
+ }
+ if (fifo_max_table_files_size_ != 0) {
+ new_options.compaction_options_fifo.max_table_files_size =
+ fifo_max_table_files_size_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level2_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+
+ // Wait for compaction to finish and make sure it can reopen
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate4) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ old_options.max_open_files = -1;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level2_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), rnd.RandomString(900)));
+ }
+ Flush();
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ if (num == 9) {
+ // Issue a full compaction to generate some zero-out files
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ ASSERT_OK(dbfull()->CompactRange(cro, nullptr, nullptr));
+ }
+ }
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ new_options.max_open_files = -1;
+ }
+ if (fifo_max_table_files_size_ != 0) {
+ new_options.compaction_options_fifo.max_table_files_size =
+ fifo_max_table_files_size_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level1_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DBOptionChangeMigrationTests, DBOptionChangeMigrationTests,
+ ::testing::Values(
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 0 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 0 /* new compaction style */,
+ true /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 0 /* new compaction style */,
+ false, 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 0 /* new compaction style */,
+ true /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(1 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 1 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 1 /* old num_levels */, 1 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(1 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 0 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(4 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 1 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 2 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 3 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 0 /*fifo max_table_files_size*/),
+ std::make_tuple(1 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */, 0),
+ std::make_tuple(4 /* old num_levels */, 0 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 1 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 5 * 1024 * 1024 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 0 /* old compaction style */,
+ true /* is dynamic leveling in old option */,
+ 2 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 5 * 1024 * 1024 /*fifo max_table_files_size*/),
+ std::make_tuple(3 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 3 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 5 * 1024 * 1024 /*fifo max_table_files_size*/),
+ std::make_tuple(1 /* old num_levels */, 1 /* old compaction style */,
+ false /* is dynamic leveling in old option */,
+ 4 /* old num_levels */, 2 /* new compaction style */,
+ false /* is dynamic leveling in new option */,
+ 5 * 1024 * 1024 /*fifo max_table_files_size*/)));
+
+class DBOptionChangeMigrationTest : public DBTestBase {
+ public:
+ DBOptionChangeMigrationTest()
+ : DBTestBase("db_option_change_migration_test2", /*env_do_fsync=*/true) {}
+};
+
+TEST_F(DBOptionChangeMigrationTest, CompactedSrcToUniversal) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style = CompactionStyle::kCompactionStyleLevel;
+ old_options.max_compaction_bytes = 200 * 1024;
+ old_options.level_compaction_dynamic_level_bytes = false;
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = 4;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), rnd.RandomString(900)));
+ }
+ }
+ Flush();
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ ASSERT_OK(dbfull()->CompactRange(cro, nullptr, nullptr));
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style = CompactionStyle::kCompactionStyleUniversal;
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = 1;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ ASSERT_OK(dbfull()->TEST_WaitForFlushMemTable());
+ ASSERT_OK(dbfull()->TEST_WaitForCompact());
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ ASSERT_OK(it->status());
+ }
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/options/options_util.cc b/src/rocksdb/utilities/options/options_util.cc
new file mode 100644
index 000000000..00c4b981a
--- /dev/null
+++ b/src/rocksdb/utilities/options/options_util.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/options_util.h"
+
+#include "file/filename.h"
+#include "options/options_parser.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/options.h"
+#include "table/block_based/block_based_table_factory.h"
+
+namespace ROCKSDB_NAMESPACE {
+Status LoadOptionsFromFile(const std::string& file_name, Env* env,
+ DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ bool ignore_unknown_options,
+ std::shared_ptr<Cache>* cache) {
+ ConfigOptions config_options;
+ config_options.ignore_unknown_options = ignore_unknown_options;
+ config_options.input_strings_escaped = true;
+ config_options.env = env;
+
+ return LoadOptionsFromFile(config_options, file_name, db_options, cf_descs,
+ cache);
+}
+
+Status LoadOptionsFromFile(const ConfigOptions& config_options,
+ const std::string& file_name, DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ std::shared_ptr<Cache>* cache) {
+ RocksDBOptionsParser parser;
+ const auto& fs = config_options.env->GetFileSystem();
+ Status s = parser.Parse(config_options, file_name, fs.get());
+ if (!s.ok()) {
+ return s;
+ }
+ *db_options = *parser.db_opt();
+ const std::vector<std::string>& cf_names = *parser.cf_names();
+ const std::vector<ColumnFamilyOptions>& cf_opts = *parser.cf_opts();
+ cf_descs->clear();
+ for (size_t i = 0; i < cf_opts.size(); ++i) {
+ cf_descs->push_back({cf_names[i], cf_opts[i]});
+ if (cache != nullptr) {
+ TableFactory* tf = cf_opts[i].table_factory.get();
+ if (tf != nullptr) {
+ auto* opts = tf->GetOptions<BlockBasedTableOptions>();
+ if (opts != nullptr) {
+ opts->block_cache = *cache;
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status GetLatestOptionsFileName(const std::string& dbpath, Env* env,
+ std::string* options_file_name) {
+ Status s;
+ std::string latest_file_name;
+ uint64_t latest_time_stamp = 0;
+ std::vector<std::string> file_names;
+ s = env->GetChildren(dbpath, &file_names);
+ if (s.IsNotFound()) {
+ return Status::NotFound(Status::kPathNotFound,
+ "No options files found in the DB directory.",
+ dbpath);
+ } else if (!s.ok()) {
+ return s;
+ }
+ for (auto& file_name : file_names) {
+ uint64_t time_stamp;
+ FileType type;
+ if (ParseFileName(file_name, &time_stamp, &type) && type == kOptionsFile) {
+ if (time_stamp > latest_time_stamp) {
+ latest_time_stamp = time_stamp;
+ latest_file_name = file_name;
+ }
+ }
+ }
+ if (latest_file_name.size() == 0) {
+ return Status::NotFound(Status::kPathNotFound,
+ "No options files found in the DB directory.",
+ dbpath);
+ }
+ *options_file_name = latest_file_name;
+ return Status::OK();
+}
+
+Status LoadLatestOptions(const std::string& dbpath, Env* env,
+ DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ bool ignore_unknown_options,
+ std::shared_ptr<Cache>* cache) {
+ ConfigOptions config_options;
+ config_options.ignore_unknown_options = ignore_unknown_options;
+ config_options.input_strings_escaped = true;
+ config_options.env = env;
+
+ return LoadLatestOptions(config_options, dbpath, db_options, cf_descs, cache);
+}
+
+Status LoadLatestOptions(const ConfigOptions& config_options,
+ const std::string& dbpath, DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ std::shared_ptr<Cache>* cache) {
+ std::string options_file_name;
+ Status s =
+ GetLatestOptionsFileName(dbpath, config_options.env, &options_file_name);
+ if (!s.ok()) {
+ return s;
+ }
+ return LoadOptionsFromFile(config_options, dbpath + "/" + options_file_name,
+ db_options, cf_descs, cache);
+}
+
+Status CheckOptionsCompatibility(
+ const std::string& dbpath, Env* env, const DBOptions& db_options,
+ const std::vector<ColumnFamilyDescriptor>& cf_descs,
+ bool ignore_unknown_options) {
+ ConfigOptions config_options(db_options);
+ config_options.sanity_level = ConfigOptions::kSanityLevelLooselyCompatible;
+ config_options.ignore_unknown_options = ignore_unknown_options;
+ config_options.input_strings_escaped = true;
+ config_options.env = env;
+ return CheckOptionsCompatibility(config_options, dbpath, db_options,
+ cf_descs);
+}
+
+Status CheckOptionsCompatibility(
+ const ConfigOptions& config_options, const std::string& dbpath,
+ const DBOptions& db_options,
+ const std::vector<ColumnFamilyDescriptor>& cf_descs) {
+ std::string options_file_name;
+ Status s =
+ GetLatestOptionsFileName(dbpath, config_options.env, &options_file_name);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<std::string> cf_names;
+ std::vector<ColumnFamilyOptions> cf_opts;
+ for (const auto& cf_desc : cf_descs) {
+ cf_names.push_back(cf_desc.name);
+ cf_opts.push_back(cf_desc.options);
+ }
+
+ const auto& fs = config_options.env->GetFileSystem();
+
+ return RocksDBOptionsParser::VerifyRocksDBOptionsFromFile(
+ config_options, db_options, cf_names, cf_opts,
+ dbpath + "/" + options_file_name, fs.get());
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/options/options_util_test.cc b/src/rocksdb/utilities/options/options_util_test.cc
new file mode 100644
index 000000000..1c3b41ff2
--- /dev/null
+++ b/src/rocksdb/utilities/options/options_util_test.cc
@@ -0,0 +1,779 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/options_util.h"
+
+#include <cctype>
+#include <cinttypes>
+#include <unordered_map>
+
+#include "env/mock_env.h"
+#include "file/filename.h"
+#include "options/options_parser.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/db.h"
+#include "rocksdb/table.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/random.h"
+
+#ifndef GFLAGS
+bool FLAGS_enable_print = false;
+#else
+#include "util/gflags_compat.h"
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+DEFINE_bool(enable_print, false, "Print options generated to console.");
+#endif // GFLAGS
+
+namespace ROCKSDB_NAMESPACE {
+class OptionsUtilTest : public testing::Test {
+ public:
+ OptionsUtilTest() : rnd_(0xFB) {
+ env_.reset(NewMemEnv(Env::Default()));
+ dbname_ = test::PerThreadDBPath("options_util_test");
+ }
+
+ protected:
+ std::unique_ptr<Env> env_;
+ std::string dbname_;
+ Random rnd_;
+};
+
+TEST_F(OptionsUtilTest, SaveAndLoad) {
+ const size_t kCFCount = 5;
+
+ DBOptions db_opt;
+ std::vector<std::string> cf_names;
+ std::vector<ColumnFamilyOptions> cf_opts;
+ test::RandomInitDBOptions(&db_opt, &rnd_);
+ for (size_t i = 0; i < kCFCount; ++i) {
+ cf_names.push_back(i == 0 ? kDefaultColumnFamilyName
+ : test::RandomName(&rnd_, 10));
+ cf_opts.emplace_back();
+ test::RandomInitCFOptions(&cf_opts.back(), db_opt, &rnd_);
+ }
+
+ const std::string kFileName = "OPTIONS-123456";
+ ASSERT_OK(PersistRocksDBOptions(db_opt, cf_names, cf_opts, kFileName,
+ env_->GetFileSystem().get()));
+
+ DBOptions loaded_db_opt;
+ std::vector<ColumnFamilyDescriptor> loaded_cf_descs;
+ ASSERT_OK(LoadOptionsFromFile(kFileName, env_.get(), &loaded_db_opt,
+ &loaded_cf_descs));
+ ConfigOptions exact;
+ exact.sanity_level = ConfigOptions::kSanityLevelExactMatch;
+ ASSERT_OK(
+ RocksDBOptionsParser::VerifyDBOptions(exact, db_opt, loaded_db_opt));
+ test::RandomInitDBOptions(&db_opt, &rnd_);
+ ASSERT_NOK(
+ RocksDBOptionsParser::VerifyDBOptions(exact, db_opt, loaded_db_opt));
+
+ for (size_t i = 0; i < kCFCount; ++i) {
+ ASSERT_EQ(cf_names[i], loaded_cf_descs[i].name);
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(
+ exact, cf_opts[i], loaded_cf_descs[i].options));
+ ASSERT_OK(RocksDBOptionsParser::VerifyTableFactory(
+ exact, cf_opts[i].table_factory.get(),
+ loaded_cf_descs[i].options.table_factory.get()));
+ test::RandomInitCFOptions(&cf_opts[i], db_opt, &rnd_);
+ ASSERT_NOK(RocksDBOptionsParser::VerifyCFOptions(
+ exact, cf_opts[i], loaded_cf_descs[i].options));
+ }
+
+ ASSERT_OK(DestroyDB(dbname_, Options(db_opt, cf_opts[0])));
+ for (size_t i = 0; i < kCFCount; ++i) {
+ if (cf_opts[i].compaction_filter) {
+ delete cf_opts[i].compaction_filter;
+ }
+ }
+}
+
+TEST_F(OptionsUtilTest, SaveAndLoadWithCacheCheck) {
+ // creating db
+ DBOptions db_opt;
+ db_opt.create_if_missing = true;
+ // initialize BlockBasedTableOptions
+ std::shared_ptr<Cache> cache = NewLRUCache(1 * 1024);
+ BlockBasedTableOptions bbt_opts;
+ bbt_opts.block_size = 32 * 1024;
+ // saving cf options
+ std::vector<ColumnFamilyOptions> cf_opts;
+ ColumnFamilyOptions default_column_family_opt = ColumnFamilyOptions();
+ default_column_family_opt.table_factory.reset(
+ NewBlockBasedTableFactory(bbt_opts));
+ cf_opts.push_back(default_column_family_opt);
+
+ ColumnFamilyOptions cf_opt_sample = ColumnFamilyOptions();
+ cf_opt_sample.table_factory.reset(NewBlockBasedTableFactory(bbt_opts));
+ cf_opts.push_back(cf_opt_sample);
+
+ ColumnFamilyOptions cf_opt_plain_table_opt = ColumnFamilyOptions();
+ cf_opt_plain_table_opt.table_factory.reset(NewPlainTableFactory());
+ cf_opts.push_back(cf_opt_plain_table_opt);
+
+ std::vector<std::string> cf_names;
+ cf_names.push_back(kDefaultColumnFamilyName);
+ cf_names.push_back("cf_sample");
+ cf_names.push_back("cf_plain_table_sample");
+ // Saving DB in file
+ const std::string kFileName = "OPTIONS-LOAD_CACHE_123456";
+ ASSERT_OK(PersistRocksDBOptions(db_opt, cf_names, cf_opts, kFileName,
+ env_->GetFileSystem().get()));
+ DBOptions loaded_db_opt;
+ std::vector<ColumnFamilyDescriptor> loaded_cf_descs;
+
+ ConfigOptions config_options;
+ config_options.ignore_unknown_options = false;
+ config_options.input_strings_escaped = true;
+ config_options.env = env_.get();
+ ASSERT_OK(LoadOptionsFromFile(config_options, kFileName, &loaded_db_opt,
+ &loaded_cf_descs, &cache));
+ for (size_t i = 0; i < loaded_cf_descs.size(); i++) {
+ auto* loaded_bbt_opt =
+ loaded_cf_descs[i]
+ .options.table_factory->GetOptions<BlockBasedTableOptions>();
+ // Expect the same cache will be loaded
+ if (loaded_bbt_opt != nullptr) {
+ ASSERT_EQ(loaded_bbt_opt->block_cache.get(), cache.get());
+ }
+ }
+
+ // Test the old interface
+ ASSERT_OK(LoadOptionsFromFile(kFileName, env_.get(), &loaded_db_opt,
+ &loaded_cf_descs, false, &cache));
+ for (size_t i = 0; i < loaded_cf_descs.size(); i++) {
+ auto* loaded_bbt_opt =
+ loaded_cf_descs[i]
+ .options.table_factory->GetOptions<BlockBasedTableOptions>();
+ // Expect the same cache will be loaded
+ if (loaded_bbt_opt != nullptr) {
+ ASSERT_EQ(loaded_bbt_opt->block_cache.get(), cache.get());
+ }
+ }
+ ASSERT_OK(DestroyDB(dbname_, Options(loaded_db_opt, cf_opts[0])));
+}
+
+namespace {
+class DummyTableFactory : public TableFactory {
+ public:
+ DummyTableFactory() {}
+ ~DummyTableFactory() override {}
+
+ const char* Name() const override { return "DummyTableFactory"; }
+
+ using TableFactory::NewTableReader;
+ Status NewTableReader(
+ const ReadOptions& /*ro*/,
+ const TableReaderOptions& /*table_reader_options*/,
+ std::unique_ptr<RandomAccessFileReader>&& /*file*/,
+ uint64_t /*file_size*/, std::unique_ptr<TableReader>* /*table_reader*/,
+ bool /*prefetch_index_and_filter_in_cache*/) const override {
+ return Status::NotSupported();
+ }
+
+ TableBuilder* NewTableBuilder(
+ const TableBuilderOptions& /*table_builder_options*/,
+ WritableFileWriter* /*file*/) const override {
+ return nullptr;
+ }
+
+ Status ValidateOptions(
+ const DBOptions& /*db_opts*/,
+ const ColumnFamilyOptions& /*cf_opts*/) const override {
+ return Status::NotSupported();
+ }
+
+ std::string GetPrintableOptions() const override { return ""; }
+};
+
+class DummyMergeOperator : public MergeOperator {
+ public:
+ DummyMergeOperator() {}
+ ~DummyMergeOperator() override {}
+
+ bool FullMergeV2(const MergeOperationInput& /*merge_in*/,
+ MergeOperationOutput* /*merge_out*/) const override {
+ return false;
+ }
+
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& /*operand_list*/,
+ std::string* /*new_value*/,
+ Logger* /*logger*/) const override {
+ return false;
+ }
+
+ const char* Name() const override { return "DummyMergeOperator"; }
+};
+
+class DummySliceTransform : public SliceTransform {
+ public:
+ DummySliceTransform() {}
+ ~DummySliceTransform() override {}
+
+ // Return the name of this transformation.
+ const char* Name() const override { return "DummySliceTransform"; }
+
+ // transform a src in domain to a dst in the range
+ Slice Transform(const Slice& src) const override { return src; }
+
+ // determine whether this is a valid src upon the function applies
+ bool InDomain(const Slice& /*src*/) const override { return false; }
+
+ // determine whether dst=Transform(src) for some src
+ bool InRange(const Slice& /*dst*/) const override { return false; }
+};
+
+} // namespace
+
+TEST_F(OptionsUtilTest, SanityCheck) {
+ DBOptions db_opt;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ const size_t kCFCount = 5;
+ for (size_t i = 0; i < kCFCount; ++i) {
+ cf_descs.emplace_back();
+ cf_descs.back().name =
+ (i == 0) ? kDefaultColumnFamilyName : test::RandomName(&rnd_, 10);
+
+ cf_descs.back().options.table_factory.reset(NewBlockBasedTableFactory());
+ // Assign non-null values to prefix_extractors except the first cf.
+ cf_descs.back().options.prefix_extractor.reset(
+ i != 0 ? test::RandomSliceTransform(&rnd_) : nullptr);
+ cf_descs.back().options.merge_operator.reset(
+ test::RandomMergeOperator(&rnd_));
+ }
+
+ db_opt.create_missing_column_families = true;
+ db_opt.create_if_missing = true;
+
+ ASSERT_OK(DestroyDB(dbname_, Options(db_opt, cf_descs[0].options)));
+ DB* db;
+ std::vector<ColumnFamilyHandle*> handles;
+ // open and persist the options
+ ASSERT_OK(DB::Open(db_opt, dbname_, cf_descs, &handles, &db));
+
+ // close the db
+ for (auto* handle : handles) {
+ delete handle;
+ }
+ delete db;
+
+ ConfigOptions config_options;
+ config_options.ignore_unknown_options = false;
+ config_options.input_strings_escaped = true;
+ config_options.sanity_level = ConfigOptions::kSanityLevelLooselyCompatible;
+ // perform sanity check
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ ASSERT_GE(kCFCount, 5);
+ // merge operator
+ {
+ std::shared_ptr<MergeOperator> merge_op =
+ cf_descs[0].options.merge_operator;
+
+ ASSERT_NE(merge_op.get(), nullptr);
+ cf_descs[0].options.merge_operator.reset();
+ ASSERT_NOK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[0].options.merge_operator.reset(new DummyMergeOperator());
+ ASSERT_NOK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[0].options.merge_operator = merge_op;
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+ }
+
+ // prefix extractor
+ {
+ std::shared_ptr<const SliceTransform> prefix_extractor =
+ cf_descs[1].options.prefix_extractor;
+
+ // It's okay to set prefix_extractor to nullptr.
+ ASSERT_NE(prefix_extractor, nullptr);
+ cf_descs[1].options.prefix_extractor.reset();
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[1].options.prefix_extractor.reset(new DummySliceTransform());
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[1].options.prefix_extractor = prefix_extractor;
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+ }
+
+ // prefix extractor nullptr case
+ {
+ std::shared_ptr<const SliceTransform> prefix_extractor =
+ cf_descs[0].options.prefix_extractor;
+
+ // It's okay to set prefix_extractor to nullptr.
+ ASSERT_EQ(prefix_extractor, nullptr);
+ cf_descs[0].options.prefix_extractor.reset();
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ // It's okay to change prefix_extractor from nullptr to non-nullptr
+ cf_descs[0].options.prefix_extractor.reset(new DummySliceTransform());
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[0].options.prefix_extractor = prefix_extractor;
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+ }
+
+ // comparator
+ {
+ test::SimpleSuffixReverseComparator comparator;
+
+ auto* prev_comparator = cf_descs[2].options.comparator;
+ cf_descs[2].options.comparator = &comparator;
+ ASSERT_NOK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[2].options.comparator = prev_comparator;
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+ }
+
+ // table factory
+ {
+ std::shared_ptr<TableFactory> table_factory =
+ cf_descs[3].options.table_factory;
+
+ ASSERT_NE(table_factory, nullptr);
+ cf_descs[3].options.table_factory.reset(new DummyTableFactory());
+ ASSERT_NOK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+
+ cf_descs[3].options.table_factory = table_factory;
+ ASSERT_OK(
+ CheckOptionsCompatibility(config_options, dbname_, db_opt, cf_descs));
+ }
+ ASSERT_OK(DestroyDB(dbname_, Options(db_opt, cf_descs[0].options)));
+}
+
+TEST_F(OptionsUtilTest, LatestOptionsNotFound) {
+ std::unique_ptr<Env> env(NewMemEnv(Env::Default()));
+ Status s;
+ Options options;
+ ConfigOptions config_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+
+ options.env = env.get();
+ options.create_if_missing = true;
+ config_opts.env = options.env;
+ config_opts.ignore_unknown_options = false;
+
+ std::vector<std::string> children;
+
+ std::string options_file_name;
+ ASSERT_OK(DestroyDB(dbname_, options));
+ // First, test where the db directory does not exist
+ ASSERT_NOK(options.env->GetChildren(dbname_, &children));
+
+ s = GetLatestOptionsFileName(dbname_, options.env, &options_file_name);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ s = LoadLatestOptions(dbname_, options.env, &options, &cf_descs);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ s = LoadLatestOptions(config_opts, dbname_, &options, &cf_descs);
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ s = GetLatestOptionsFileName(dbname_, options.env, &options_file_name);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ // Second, test where the db directory exists but is empty
+ ASSERT_OK(options.env->CreateDir(dbname_));
+
+ s = GetLatestOptionsFileName(dbname_, options.env, &options_file_name);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ s = LoadLatestOptions(dbname_, options.env, &options, &cf_descs);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ // Finally, test where a file exists but is not an "Options" file
+ std::unique_ptr<WritableFile> file;
+ ASSERT_OK(
+ options.env->NewWritableFile(dbname_ + "/temp.txt", &file, EnvOptions()));
+ ASSERT_OK(file->Close());
+ s = GetLatestOptionsFileName(dbname_, options.env, &options_file_name);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+
+ s = LoadLatestOptions(config_opts, dbname_, &options, &cf_descs);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_TRUE(s.IsPathNotFound());
+ ASSERT_OK(options.env->DeleteFile(dbname_ + "/temp.txt"));
+ ASSERT_OK(options.env->DeleteDir(dbname_));
+}
+
+TEST_F(OptionsUtilTest, LoadLatestOptions) {
+ Options options;
+ options.OptimizeForSmallDb();
+ ColumnFamilyDescriptor cf_desc;
+ ConfigOptions config_opts;
+ DBOptions db_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ std::vector<ColumnFamilyHandle*> handles;
+ DB* db;
+ options.create_if_missing = true;
+
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ cf_descs.emplace_back();
+ cf_descs.back().name = kDefaultColumnFamilyName;
+ cf_descs.back().options.table_factory.reset(NewBlockBasedTableFactory());
+ cf_descs.emplace_back();
+ cf_descs.back().name = "Plain";
+ cf_descs.back().options.table_factory.reset(NewPlainTableFactory());
+ db_opts.create_missing_column_families = true;
+ db_opts.create_if_missing = true;
+
+ // open and persist the options
+ ASSERT_OK(DB::Open(db_opts, dbname_, cf_descs, &handles, &db));
+
+ std::string options_file_name;
+ std::string new_options_file;
+
+ ASSERT_OK(GetLatestOptionsFileName(dbname_, options.env, &options_file_name));
+ ASSERT_OK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ ASSERT_EQ(cf_descs.size(), 2U);
+ ASSERT_OK(RocksDBOptionsParser::VerifyDBOptions(config_opts,
+ db->GetDBOptions(), db_opts));
+ ASSERT_OK(handles[0]->GetDescriptor(&cf_desc));
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(config_opts, cf_desc.options,
+ cf_descs[0].options));
+ ASSERT_OK(handles[1]->GetDescriptor(&cf_desc));
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(config_opts, cf_desc.options,
+ cf_descs[1].options));
+
+ // Now change some of the DBOptions
+ ASSERT_OK(db->SetDBOptions(
+ {{"delayed_write_rate", "1234"}, {"bytes_per_sync", "32768"}}));
+ ASSERT_OK(GetLatestOptionsFileName(dbname_, options.env, &new_options_file));
+ ASSERT_NE(options_file_name, new_options_file);
+ ASSERT_OK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ ASSERT_OK(RocksDBOptionsParser::VerifyDBOptions(config_opts,
+ db->GetDBOptions(), db_opts));
+ options_file_name = new_options_file;
+
+ // Now change some of the ColumnFamilyOptions
+ ASSERT_OK(db->SetOptions(handles[1], {{"write_buffer_size", "32768"}}));
+ ASSERT_OK(GetLatestOptionsFileName(dbname_, options.env, &new_options_file));
+ ASSERT_NE(options_file_name, new_options_file);
+ ASSERT_OK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ ASSERT_OK(RocksDBOptionsParser::VerifyDBOptions(config_opts,
+ db->GetDBOptions(), db_opts));
+ ASSERT_OK(handles[0]->GetDescriptor(&cf_desc));
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(config_opts, cf_desc.options,
+ cf_descs[0].options));
+ ASSERT_OK(handles[1]->GetDescriptor(&cf_desc));
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(config_opts, cf_desc.options,
+ cf_descs[1].options));
+
+ // close the db
+ for (auto* handle : handles) {
+ delete handle;
+ }
+ delete db;
+ ASSERT_OK(DestroyDB(dbname_, options, cf_descs));
+}
+
+static void WriteOptionsFile(Env* env, const std::string& path,
+ const std::string& options_file, int major,
+ int minor, const std::string& db_opts,
+ const std::string& cf_opts,
+ const std::string& bbt_opts = "") {
+ std::string options_file_header =
+ "\n"
+ "[Version]\n"
+ " rocksdb_version=" +
+ std::to_string(major) + "." + std::to_string(minor) +
+ ".0\n"
+ " options_file_version=1\n";
+
+ std::unique_ptr<WritableFile> wf;
+ ASSERT_OK(env->NewWritableFile(path + "/" + options_file, &wf, EnvOptions()));
+ ASSERT_OK(
+ wf->Append(options_file_header + "[ DBOptions ]\n" + db_opts + "\n"));
+ ASSERT_OK(wf->Append(
+ "[CFOptions \"default\"] # column family must be specified\n" +
+ cf_opts + "\n"));
+ ASSERT_OK(wf->Append("[TableOptions/BlockBasedTable \"default\"]\n" +
+ bbt_opts + "\n"));
+ ASSERT_OK(wf->Close());
+
+ std::string latest_options_file;
+ ASSERT_OK(GetLatestOptionsFileName(path, env, &latest_options_file));
+ ASSERT_EQ(latest_options_file, options_file);
+}
+
+TEST_F(OptionsUtilTest, BadLatestOptions) {
+ Status s;
+ ConfigOptions config_opts;
+ DBOptions db_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ Options options;
+ options.env = env_.get();
+ config_opts.env = env_.get();
+ config_opts.ignore_unknown_options = false;
+ config_opts.delimiter = "\n";
+
+ ConfigOptions ignore_opts = config_opts;
+ ignore_opts.ignore_unknown_options = true;
+
+ std::string options_file_name;
+
+ // Test where the db directory exists but is empty
+ ASSERT_OK(options.env->CreateDir(dbname_));
+ ASSERT_NOK(
+ GetLatestOptionsFileName(dbname_, options.env, &options_file_name));
+ ASSERT_NOK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+
+ // Write an options file for a previous major release with an unknown DB
+ // Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0001", ROCKSDB_MAJOR - 1,
+ ROCKSDB_MINOR, "unknown_db_opt=true", "");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Write an options file for a previous minor release with an unknown CF
+ // Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0002", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR - 1, "", "unknown_cf_opt=true");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Write an options file for a previous minor release with an unknown BBT
+ // Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0003", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR - 1, "", "", "unknown_bbt_opt=true");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ // Write an options file for the current release with an unknown DB Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0004", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR, "unknown_db_opt=true", "");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ // Write an options file for the current release with an unknown CF Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0005", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR, "", "unknown_cf_opt=true");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ // Write an options file for the current release with an invalid DB Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0006", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR, "create_if_missing=hello", "");
+ s = LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ // Even though ignore_unknown_options=true, we still return an error...
+ s = LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs);
+ ASSERT_NOK(s);
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ // Write an options file for the next release with an invalid DB Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0007", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR + 1, "create_if_missing=hello", "");
+ ASSERT_NOK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ ASSERT_OK(LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs));
+
+ // Write an options file for the next release with an unknown DB Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0008", ROCKSDB_MAJOR,
+ ROCKSDB_MINOR + 1, "unknown_db_opt=true", "");
+ ASSERT_NOK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ // Ignore the errors for future releases when ignore_unknown_options=true
+ ASSERT_OK(LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs));
+
+ // Write an options file for the next major release with an unknown CF Option
+ WriteOptionsFile(options.env, dbname_, "OPTIONS-0009", ROCKSDB_MAJOR + 1,
+ ROCKSDB_MINOR, "", "unknown_cf_opt=true");
+ ASSERT_NOK(LoadLatestOptions(config_opts, dbname_, &db_opts, &cf_descs));
+ // Ignore the errors for future releases when ignore_unknown_options=true
+ ASSERT_OK(LoadLatestOptions(ignore_opts, dbname_, &db_opts, &cf_descs));
+}
+
+TEST_F(OptionsUtilTest, RenameDatabaseDirectory) {
+ DB* db;
+ Options options;
+ DBOptions db_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.create_if_missing = true;
+
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value0"));
+ delete db;
+
+ auto new_dbname = dbname_ + "_2";
+
+ ASSERT_OK(options.env->RenameFile(dbname_, new_dbname));
+ ASSERT_OK(LoadLatestOptions(new_dbname, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(cf_descs.size(), 1U);
+
+ db_opts.create_if_missing = false;
+ ASSERT_OK(DB::Open(db_opts, new_dbname, cf_descs, &handles, &db));
+ std::string value;
+ ASSERT_OK(db->Get(ReadOptions(), "foo", &value));
+ ASSERT_EQ("value0", value);
+ // close the db
+ for (auto* handle : handles) {
+ delete handle;
+ }
+ delete db;
+ Options new_options(db_opts, cf_descs[0].options);
+ ASSERT_OK(DestroyDB(new_dbname, new_options, cf_descs));
+ ASSERT_OK(DestroyDB(dbname_, options));
+}
+
+TEST_F(OptionsUtilTest, WalDirSettings) {
+ DB* db;
+ Options options;
+ DBOptions db_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.create_if_missing = true;
+
+ // Open a DB with no wal dir set. The wal_dir should stay empty
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, "");
+
+ // Open a DB with wal_dir == dbname. The wal_dir should be set to empty
+ options.wal_dir = dbname_;
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, "");
+
+ // Open a DB with no wal_dir but a db_path==dbname_. The wal_dir should be
+ // empty
+ options.wal_dir = "";
+ options.db_paths.emplace_back(dbname_, std::numeric_limits<uint64_t>::max());
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, "");
+
+ // Open a DB with no wal_dir==dbname_ and db_path==dbname_. The wal_dir
+ // should be empty
+ options.wal_dir = dbname_ + "/";
+ options.db_paths.emplace_back(dbname_, std::numeric_limits<uint64_t>::max());
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, "");
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Open a DB with no wal_dir but db_path != db_name. The wal_dir == dbname_
+ options.wal_dir = "";
+ options.db_paths.clear();
+ options.db_paths.emplace_back(dbname_ + "_0",
+ std::numeric_limits<uint64_t>::max());
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, dbname_);
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Open a DB with wal_dir != db_name. The wal_dir remains unchanged
+ options.wal_dir = dbname_ + "/wal";
+ options.db_paths.clear();
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, dbname_ + "/wal");
+ ASSERT_OK(DestroyDB(dbname_, options));
+}
+
+TEST_F(OptionsUtilTest, WalDirInOptins) {
+ DB* db;
+ Options options;
+ DBOptions db_opts;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ std::vector<ColumnFamilyHandle*> handles;
+
+ // Store an options file with wal_dir=dbname_ and make sure it still loads
+ // when the input wal_dir is empty
+ options.create_if_missing = true;
+ options.wal_dir = "";
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ options.wal_dir = dbname_;
+ std::string options_file;
+ ASSERT_OK(GetLatestOptionsFileName(dbname_, options.env, &options_file));
+ ASSERT_OK(PersistRocksDBOptions(options, {"default"}, {options},
+ dbname_ + "/" + options_file,
+ options.env->GetFileSystem().get()));
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, dbname_);
+ options.wal_dir = "";
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ delete db;
+ ASSERT_OK(LoadLatestOptions(dbname_, options.env, &db_opts, &cf_descs));
+ ASSERT_EQ(db_opts.wal_dir, "");
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+#ifdef GFLAGS
+ ParseCommandLineFlags(&argc, &argv, true);
+#endif // GFLAGS
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ printf("Skipped in RocksDBLite as utilities are not supported.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc
new file mode 100644
index 000000000..8ad9bb1b1
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc
@@ -0,0 +1,422 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier.h"
+
+#include <utility>
+#include <vector>
+
+#include "logging/logging.h"
+#include "port/port.h"
+#include "test_util/sync_point.h"
+#include "util/stop_watch.h"
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// BlockCacheImpl
+//
+Status BlockCacheTier::Open() {
+ Status status;
+
+ WriteLock _(&lock_);
+
+ assert(!size_);
+
+ // Check the validity of the options
+ status = opt_.ValidateSettings();
+ assert(status.ok());
+ if (!status.ok()) {
+ Error(opt_.log, "Invalid block cache options");
+ return status;
+ }
+
+ // Create base directory or cleanup existing directory
+ status = opt_.env->CreateDirIfMissing(opt_.path);
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating directory %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+ // Create base/<cache dir> directory
+ status = opt_.env->CreateDir(GetCachePath());
+ if (!status.ok()) {
+ // directory already exists, clean it up
+ status = CleanupCacheFolder(GetCachePath());
+ assert(status.ok());
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating directory %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+ }
+
+ // create a new file
+ assert(!cache_file_);
+ status = NewCacheFile();
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating new file %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+ assert(cache_file_);
+
+ if (opt_.pipeline_writes) {
+ assert(!insert_th_.joinable());
+ insert_th_ = port::Thread(&BlockCacheTier::InsertMain, this);
+ }
+
+ return Status::OK();
+}
+
+bool IsCacheFile(const std::string& file) {
+ // check if the file has .rc suffix
+ // Unfortunately regex support across compilers is not even, so we use simple
+ // string parsing
+ size_t pos = file.find(".");
+ if (pos == std::string::npos) {
+ return false;
+ }
+
+ std::string suffix = file.substr(pos);
+ return suffix == ".rc";
+}
+
+Status BlockCacheTier::CleanupCacheFolder(const std::string& folder) {
+ std::vector<std::string> files;
+ Status status = opt_.env->GetChildren(folder, &files);
+ if (!status.ok()) {
+ Error(opt_.log, "Error getting files for %s. %s", folder.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+ // cleanup files with the patter :digi:.rc
+ for (auto file : files) {
+ if (IsCacheFile(file)) {
+ // cache file
+ Info(opt_.log, "Removing file %s.", file.c_str());
+ status = opt_.env->DeleteFile(folder + "/" + file);
+ if (!status.ok()) {
+ Error(opt_.log, "Error deleting file %s. %s", file.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+ } else {
+ ROCKS_LOG_DEBUG(opt_.log, "Skipping file %s", file.c_str());
+ }
+ }
+ return Status::OK();
+}
+
+Status BlockCacheTier::Close() {
+ // stop the insert thread
+ if (opt_.pipeline_writes && insert_th_.joinable()) {
+ InsertOp op(/*quit=*/true);
+ insert_ops_.Push(std::move(op));
+ insert_th_.join();
+ }
+
+ // stop the writer before
+ writer_.Stop();
+
+ // clear all metadata
+ WriteLock _(&lock_);
+ metadata_.Clear();
+ return Status::OK();
+}
+
+template <class T>
+void Add(std::map<std::string, double>* stats, const std::string& key,
+ const T& t) {
+ stats->insert({key, static_cast<double>(t)});
+}
+
+PersistentCache::StatsType BlockCacheTier::Stats() {
+ std::map<std::string, double> stats;
+ Add(&stats, "persistentcache.blockcachetier.bytes_piplined",
+ stats_.bytes_pipelined_.Average());
+ Add(&stats, "persistentcache.blockcachetier.bytes_written",
+ stats_.bytes_written_.Average());
+ Add(&stats, "persistentcache.blockcachetier.bytes_read",
+ stats_.bytes_read_.Average());
+ Add(&stats, "persistentcache.blockcachetier.insert_dropped",
+ stats_.insert_dropped_);
+ Add(&stats, "persistentcache.blockcachetier.cache_hits", stats_.cache_hits_);
+ Add(&stats, "persistentcache.blockcachetier.cache_misses",
+ stats_.cache_misses_);
+ Add(&stats, "persistentcache.blockcachetier.cache_errors",
+ stats_.cache_errors_);
+ Add(&stats, "persistentcache.blockcachetier.cache_hits_pct",
+ stats_.CacheHitPct());
+ Add(&stats, "persistentcache.blockcachetier.cache_misses_pct",
+ stats_.CacheMissPct());
+ Add(&stats, "persistentcache.blockcachetier.read_hit_latency",
+ stats_.read_hit_latency_.Average());
+ Add(&stats, "persistentcache.blockcachetier.read_miss_latency",
+ stats_.read_miss_latency_.Average());
+ Add(&stats, "persistentcache.blockcachetier.write_latency",
+ stats_.write_latency_.Average());
+
+ auto out = PersistentCacheTier::Stats();
+ out.push_back(stats);
+ return out;
+}
+
+Status BlockCacheTier::Insert(const Slice& key, const char* data,
+ const size_t size) {
+ // update stats
+ stats_.bytes_pipelined_.Add(size);
+
+ if (opt_.pipeline_writes) {
+ // off load the write to the write thread
+ insert_ops_.Push(
+ InsertOp(key.ToString(), std::move(std::string(data, size))));
+ return Status::OK();
+ }
+
+ assert(!opt_.pipeline_writes);
+ return InsertImpl(key, Slice(data, size));
+}
+
+void BlockCacheTier::InsertMain() {
+ while (true) {
+ InsertOp op(insert_ops_.Pop());
+
+ if (op.signal_) {
+ // that is a secret signal to exit
+ break;
+ }
+
+ size_t retry = 0;
+ Status s;
+ while ((s = InsertImpl(Slice(op.key_), Slice(op.data_))).IsTryAgain()) {
+ if (retry > kMaxRetry) {
+ break;
+ }
+
+ // this can happen when the buffers are full, we wait till some buffers
+ // are free. Why don't we wait inside the code. This is because we want
+ // to support both pipelined and non-pipelined mode
+ buffer_allocator_.WaitUntilUsable();
+ retry++;
+ }
+
+ if (!s.ok()) {
+ stats_.insert_dropped_++;
+ }
+ }
+}
+
+Status BlockCacheTier::InsertImpl(const Slice& key, const Slice& data) {
+ // pre-condition
+ assert(key.size());
+ assert(data.size());
+ assert(cache_file_);
+
+ StopWatchNano timer(opt_.clock, /*auto_start=*/true);
+
+ WriteLock _(&lock_);
+
+ LBA lba;
+ if (metadata_.Lookup(key, &lba)) {
+ // the key already exists, this is duplicate insert
+ return Status::OK();
+ }
+
+ while (!cache_file_->Append(key, data, &lba)) {
+ if (!cache_file_->Eof()) {
+ ROCKS_LOG_DEBUG(opt_.log, "Error inserting to cache file %d",
+ cache_file_->cacheid());
+ stats_.write_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::TryAgain();
+ }
+
+ assert(cache_file_->Eof());
+ Status status = NewCacheFile();
+ if (!status.ok()) {
+ return status;
+ }
+ }
+
+ // Insert into lookup index
+ BlockInfo* info = metadata_.Insert(key, lba);
+ assert(info);
+ if (!info) {
+ return Status::IOError("Unexpected error inserting to index");
+ }
+
+ // insert to cache file reverse mapping
+ cache_file_->Add(info);
+
+ // update stats
+ stats_.bytes_written_.Add(data.size());
+ stats_.write_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::OK();
+}
+
+Status BlockCacheTier::Lookup(const Slice& key, std::unique_ptr<char[]>* val,
+ size_t* size) {
+ StopWatchNano timer(opt_.clock, /*auto_start=*/true);
+
+ LBA lba;
+ bool status;
+ status = metadata_.Lookup(key, &lba);
+ if (!status) {
+ stats_.cache_misses_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: key not found");
+ }
+
+ BlockCacheFile* const file = metadata_.Lookup(lba.cache_id_);
+ if (!file) {
+ // this can happen because the block index and cache file index are
+ // different, and the cache file might be removed between the two lookups
+ stats_.cache_misses_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: cache file not found");
+ }
+
+ assert(file->refs_);
+
+ std::unique_ptr<char[]> scratch(new char[lba.size_]);
+ Slice blk_key;
+ Slice blk_val;
+
+ status = file->Read(lba, &blk_key, &blk_val, scratch.get());
+ --file->refs_;
+ if (!status) {
+ stats_.cache_misses_++;
+ stats_.cache_errors_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: error reading data");
+ }
+
+ assert(blk_key == key);
+
+ val->reset(new char[blk_val.size()]);
+ memcpy(val->get(), blk_val.data(), blk_val.size());
+ *size = blk_val.size();
+
+ stats_.bytes_read_.Add(*size);
+ stats_.cache_hits_++;
+ stats_.read_hit_latency_.Add(timer.ElapsedNanos() / 1000);
+
+ return Status::OK();
+}
+
+bool BlockCacheTier::Erase(const Slice& key) {
+ WriteLock _(&lock_);
+ BlockInfo* info = metadata_.Remove(key);
+ assert(info);
+ delete info;
+ return true;
+}
+
+Status BlockCacheTier::NewCacheFile() {
+ lock_.AssertHeld();
+
+ TEST_SYNC_POINT_CALLBACK("BlockCacheTier::NewCacheFile:DeleteDir",
+ (void*)(GetCachePath().c_str()));
+
+ std::unique_ptr<WriteableCacheFile> f(new WriteableCacheFile(
+ opt_.env, &buffer_allocator_, &writer_, GetCachePath(), writer_cache_id_,
+ opt_.cache_file_size, opt_.log));
+
+ bool status = f->Create(opt_.enable_direct_writes, opt_.enable_direct_reads);
+ if (!status) {
+ return Status::IOError("Error creating file");
+ }
+
+ Info(opt_.log, "Created cache file %d", writer_cache_id_);
+
+ writer_cache_id_++;
+ cache_file_ = f.release();
+
+ // insert to cache files tree
+ status = metadata_.Insert(cache_file_);
+ assert(status);
+ if (!status) {
+ Error(opt_.log, "Error inserting to metadata");
+ return Status::IOError("Error inserting to metadata");
+ }
+
+ return Status::OK();
+}
+
+bool BlockCacheTier::Reserve(const size_t size) {
+ WriteLock _(&lock_);
+ assert(size_ <= opt_.cache_size);
+
+ if (size + size_ <= opt_.cache_size) {
+ // there is enough space to write
+ size_ += size;
+ return true;
+ }
+
+ assert(size + size_ >= opt_.cache_size);
+ // there is not enough space to fit the requested data
+ // we can clear some space by evicting cold data
+
+ const double retain_fac = (100 - kEvictPct) / static_cast<double>(100);
+ while (size + size_ > opt_.cache_size * retain_fac) {
+ std::unique_ptr<BlockCacheFile> f(metadata_.Evict());
+ if (!f) {
+ // nothing is evictable
+ return false;
+ }
+ assert(!f->refs_);
+ uint64_t file_size;
+ if (!f->Delete(&file_size).ok()) {
+ // unable to delete file
+ return false;
+ }
+
+ assert(file_size <= size_);
+ size_ -= file_size;
+ }
+
+ size_ += size;
+ assert(size_ <= opt_.cache_size * 0.9);
+ return true;
+}
+
+Status NewPersistentCache(Env* const env, const std::string& path,
+ const uint64_t size,
+ const std::shared_ptr<Logger>& log,
+ const bool optimized_for_nvm,
+ std::shared_ptr<PersistentCache>* cache) {
+ if (!cache) {
+ return Status::IOError("invalid argument cache");
+ }
+
+ auto opt = PersistentCacheConfig(env, path, size, log);
+ if (optimized_for_nvm) {
+ // the default settings are optimized for SSD
+ // NVM devices are better accessed with 4K direct IO and written with
+ // parallelism
+ opt.enable_direct_writes = true;
+ opt.writer_qdepth = 4;
+ opt.writer_dispatch_size = 4 * 1024;
+ }
+
+ auto pcache = std::make_shared<BlockCacheTier>(opt);
+ Status s = pcache->Open();
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ *cache = pcache;
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ifndef ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier.h
new file mode 100644
index 000000000..1aac287cc
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif // ! OS_WIN
+
+#include <atomic>
+#include <list>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "rocksdb/cache.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/persistent_cache.h"
+#include "rocksdb/system_clock.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+#include "utilities/persistent_cache/block_cache_tier_metadata.h"
+#include "utilities/persistent_cache/persistent_cache_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Block cache tier implementation
+//
+class BlockCacheTier : public PersistentCacheTier {
+ public:
+ explicit BlockCacheTier(const PersistentCacheConfig& opt)
+ : opt_(opt),
+ insert_ops_(static_cast<size_t>(opt_.max_write_pipeline_backlog_size)),
+ buffer_allocator_(opt.write_buffer_size, opt.write_buffer_count()),
+ writer_(this, opt_.writer_qdepth,
+ static_cast<size_t>(opt_.writer_dispatch_size)) {
+ Info(opt_.log, "Initializing allocator. size=%d B count=%" ROCKSDB_PRIszt,
+ opt_.write_buffer_size, opt_.write_buffer_count());
+ }
+
+ virtual ~BlockCacheTier() {
+ // Close is re-entrant so we can call close even if it is already closed
+ Close().PermitUncheckedError();
+ assert(!insert_th_.joinable());
+ }
+
+ Status Insert(const Slice& key, const char* data, const size_t size) override;
+ Status Lookup(const Slice& key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+ Status Open() override;
+ Status Close() override;
+ bool Erase(const Slice& key) override;
+ bool Reserve(const size_t size) override;
+
+ bool IsCompressed() override { return opt_.is_compressed; }
+
+ std::string GetPrintableOptions() const override { return opt_.ToString(); }
+
+ PersistentCache::StatsType Stats() override;
+
+ void TEST_Flush() override {
+ while (insert_ops_.Size()) {
+ /* sleep override */
+ SystemClock::Default()->SleepForMicroseconds(1000000);
+ }
+ }
+
+ private:
+ // Percentage of cache to be evicted when the cache is full
+ static const size_t kEvictPct = 10;
+ // Max attempts to insert key, value to cache in pipelined mode
+ static const size_t kMaxRetry = 3;
+
+ // Pipelined operation
+ struct InsertOp {
+ explicit InsertOp(const bool signal) : signal_(signal) {}
+ explicit InsertOp(std::string&& key, const std::string& data)
+ : key_(std::move(key)), data_(data) {}
+ ~InsertOp() {}
+
+ InsertOp() = delete;
+ InsertOp(InsertOp&& /*rhs*/) = default;
+ InsertOp& operator=(InsertOp&& rhs) = default;
+
+ // used for estimating size by bounded queue
+ size_t Size() { return data_.size() + key_.size(); }
+
+ std::string key_;
+ std::string data_;
+ bool signal_ = false; // signal to request processing thread to exit
+ };
+
+ // entry point for insert thread
+ void InsertMain();
+ // insert implementation
+ Status InsertImpl(const Slice& key, const Slice& data);
+ // Create a new cache file
+ Status NewCacheFile();
+ // Get cache directory path
+ std::string GetCachePath() const { return opt_.path + "/cache"; }
+ // Cleanup folder
+ Status CleanupCacheFolder(const std::string& folder);
+
+ // Statistics
+ struct Statistics {
+ HistogramImpl bytes_pipelined_;
+ HistogramImpl bytes_written_;
+ HistogramImpl bytes_read_;
+ HistogramImpl read_hit_latency_;
+ HistogramImpl read_miss_latency_;
+ HistogramImpl write_latency_;
+ std::atomic<uint64_t> cache_hits_{0};
+ std::atomic<uint64_t> cache_misses_{0};
+ std::atomic<uint64_t> cache_errors_{0};
+ std::atomic<uint64_t> insert_dropped_{0};
+
+ double CacheHitPct() const {
+ const auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_hits_ / static_cast<double>(lookups) : 0.0;
+ }
+
+ double CacheMissPct() const {
+ const auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_misses_ / static_cast<double>(lookups) : 0.0;
+ }
+ };
+
+ port::RWMutex lock_; // Synchronization
+ const PersistentCacheConfig opt_; // BlockCache options
+ BoundedQueue<InsertOp> insert_ops_; // Ops waiting for insert
+ ROCKSDB_NAMESPACE::port::Thread insert_th_; // Insert thread
+ uint32_t writer_cache_id_ = 0; // Current cache file identifier
+ WriteableCacheFile* cache_file_ = nullptr; // Current cache file reference
+ CacheWriteBufferAllocator buffer_allocator_; // Buffer provider
+ ThreadedWriter writer_; // Writer threads
+ BlockCacheTierMetadata metadata_; // Cache meta data manager
+ std::atomic<uint64_t> size_{0}; // Size of the cache
+ Statistics stats_; // Statistics
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc
new file mode 100644
index 000000000..f4f8517ab
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc
@@ -0,0 +1,610 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "env/composite_env_wrapper.h"
+#include "logging/logging.h"
+#include "port/port.h"
+#include "rocksdb/system_clock.h"
+#include "util/crc32c.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// File creation factories
+//
+Status NewWritableCacheFile(Env* const env, const std::string& filepath,
+ std::unique_ptr<WritableFile>* file,
+ const bool use_direct_writes = false) {
+ EnvOptions opt;
+ opt.use_direct_writes = use_direct_writes;
+ Status s = env->NewWritableFile(filepath, file, opt);
+ return s;
+}
+
+Status NewRandomAccessCacheFile(const std::shared_ptr<FileSystem>& fs,
+ const std::string& filepath,
+ std::unique_ptr<FSRandomAccessFile>* file,
+ const bool use_direct_reads = true) {
+ assert(fs.get());
+
+ FileOptions opt;
+ opt.use_direct_reads = use_direct_reads;
+ return fs->NewRandomAccessFile(filepath, opt, file, nullptr);
+}
+
+//
+// BlockCacheFile
+//
+Status BlockCacheFile::Delete(uint64_t* size) {
+ assert(env_);
+
+ Status status = env_->GetFileSize(Path(), size);
+ if (!status.ok()) {
+ return status;
+ }
+ return env_->DeleteFile(Path());
+}
+
+//
+// CacheRecord
+//
+// Cache record represents the record on disk
+//
+// +--------+---------+----------+------------+---------------+-------------+
+// | magic | crc | key size | value size | key data | value data |
+// +--------+---------+----------+------------+---------------+-------------+
+// <-- 4 --><-- 4 --><-- 4 --><-- 4 --><-- key size --><-- v-size -->
+//
+struct CacheRecordHeader {
+ CacheRecordHeader() : magic_(0), crc_(0), key_size_(0), val_size_(0) {}
+ CacheRecordHeader(const uint32_t magic, const uint32_t key_size,
+ const uint32_t val_size)
+ : magic_(magic), crc_(0), key_size_(key_size), val_size_(val_size) {}
+
+ uint32_t magic_;
+ uint32_t crc_;
+ uint32_t key_size_;
+ uint32_t val_size_;
+};
+
+struct CacheRecord {
+ CacheRecord() {}
+ CacheRecord(const Slice& key, const Slice& val)
+ : hdr_(MAGIC, static_cast<uint32_t>(key.size()),
+ static_cast<uint32_t>(val.size())),
+ key_(key),
+ val_(val) {
+ hdr_.crc_ = ComputeCRC();
+ }
+
+ uint32_t ComputeCRC() const;
+ bool Serialize(std::vector<CacheWriteBuffer*>* bufs, size_t* woff);
+ bool Deserialize(const Slice& buf);
+
+ static uint32_t CalcSize(const Slice& key, const Slice& val) {
+ return static_cast<uint32_t>(sizeof(CacheRecordHeader) + key.size() +
+ val.size());
+ }
+
+ static const uint32_t MAGIC = 0xfefa;
+
+ bool Append(std::vector<CacheWriteBuffer*>* bufs, size_t* woff,
+ const char* data, const size_t size);
+
+ CacheRecordHeader hdr_;
+ Slice key_;
+ Slice val_;
+};
+
+static_assert(sizeof(CacheRecordHeader) == 16, "DataHeader is not aligned");
+
+uint32_t CacheRecord::ComputeCRC() const {
+ uint32_t crc = 0;
+ CacheRecordHeader tmp = hdr_;
+ tmp.crc_ = 0;
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(&tmp), sizeof(tmp));
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(key_.data()),
+ key_.size());
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(val_.data()),
+ val_.size());
+ return crc;
+}
+
+bool CacheRecord::Serialize(std::vector<CacheWriteBuffer*>* bufs,
+ size_t* woff) {
+ assert(bufs->size());
+ return Append(bufs, woff, reinterpret_cast<const char*>(&hdr_),
+ sizeof(hdr_)) &&
+ Append(bufs, woff, reinterpret_cast<const char*>(key_.data()),
+ key_.size()) &&
+ Append(bufs, woff, reinterpret_cast<const char*>(val_.data()),
+ val_.size());
+}
+
+bool CacheRecord::Append(std::vector<CacheWriteBuffer*>* bufs, size_t* woff,
+ const char* data, const size_t data_size) {
+ assert(*woff < bufs->size());
+
+ const char* p = data;
+ size_t size = data_size;
+
+ while (size && *woff < bufs->size()) {
+ CacheWriteBuffer* buf = (*bufs)[*woff];
+ const size_t free = buf->Free();
+ if (size <= free) {
+ buf->Append(p, size);
+ size = 0;
+ } else {
+ buf->Append(p, free);
+ p += free;
+ size -= free;
+ assert(!buf->Free());
+ assert(buf->Used() == buf->Capacity());
+ }
+
+ if (!buf->Free()) {
+ *woff += 1;
+ }
+ }
+
+ assert(!size);
+
+ return !size;
+}
+
+bool CacheRecord::Deserialize(const Slice& data) {
+ assert(data.size() >= sizeof(CacheRecordHeader));
+ if (data.size() < sizeof(CacheRecordHeader)) {
+ return false;
+ }
+
+ memcpy(&hdr_, data.data(), sizeof(hdr_));
+
+ assert(hdr_.key_size_ + hdr_.val_size_ + sizeof(hdr_) == data.size());
+ if (hdr_.key_size_ + hdr_.val_size_ + sizeof(hdr_) != data.size()) {
+ return false;
+ }
+
+ key_ = Slice(data.data_ + sizeof(hdr_), hdr_.key_size_);
+ val_ = Slice(key_.data_ + hdr_.key_size_, hdr_.val_size_);
+
+ if (!(hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_)) {
+ fprintf(stderr, "** magic %d ** \n", hdr_.magic_);
+ fprintf(stderr, "** key_size %d ** \n", hdr_.key_size_);
+ fprintf(stderr, "** val_size %d ** \n", hdr_.val_size_);
+ fprintf(stderr, "** key %s ** \n", key_.ToString().c_str());
+ fprintf(stderr, "** val %s ** \n", val_.ToString().c_str());
+ for (size_t i = 0; i < hdr_.val_size_; ++i) {
+ fprintf(stderr, "%d.", (uint8_t)val_.data()[i]);
+ }
+ fprintf(stderr, "\n** cksum %d != %d **", hdr_.crc_, ComputeCRC());
+ }
+
+ assert(hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_);
+ return hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_;
+}
+
+//
+// RandomAccessFile
+//
+
+bool RandomAccessCacheFile::Open(const bool enable_direct_reads) {
+ WriteLock _(&rwlock_);
+ return OpenImpl(enable_direct_reads);
+}
+
+bool RandomAccessCacheFile::OpenImpl(const bool enable_direct_reads) {
+ rwlock_.AssertHeld();
+
+ ROCKS_LOG_DEBUG(log_, "Opening cache file %s", Path().c_str());
+ assert(env_);
+
+ std::unique_ptr<FSRandomAccessFile> file;
+ Status status = NewRandomAccessCacheFile(env_->GetFileSystem(), Path(), &file,
+ enable_direct_reads);
+ if (!status.ok()) {
+ Error(log_, "Error opening random access file %s. %s", Path().c_str(),
+ status.ToString().c_str());
+ return false;
+ }
+ freader_.reset(new RandomAccessFileReader(std::move(file), Path(),
+ env_->GetSystemClock().get()));
+
+ return true;
+}
+
+bool RandomAccessCacheFile::Read(const LBA& lba, Slice* key, Slice* val,
+ char* scratch) {
+ ReadLock _(&rwlock_);
+
+ assert(lba.cache_id_ == cache_id_);
+
+ if (!freader_) {
+ return false;
+ }
+
+ Slice result;
+ Status s = freader_->Read(IOOptions(), lba.off_, lba.size_, &result, scratch,
+ nullptr, Env::IO_TOTAL /* rate_limiter_priority */);
+ if (!s.ok()) {
+ Error(log_, "Error reading from file %s. %s", Path().c_str(),
+ s.ToString().c_str());
+ return false;
+ }
+
+ assert(result.data() == scratch);
+
+ return ParseRec(lba, key, val, scratch);
+}
+
+bool RandomAccessCacheFile::ParseRec(const LBA& lba, Slice* key, Slice* val,
+ char* scratch) {
+ Slice data(scratch, lba.size_);
+
+ CacheRecord rec;
+ if (!rec.Deserialize(data)) {
+ assert(!"Error deserializing data");
+ Error(log_, "Error de-serializing record from file %s off %d",
+ Path().c_str(), lba.off_);
+ return false;
+ }
+
+ *key = Slice(rec.key_);
+ *val = Slice(rec.val_);
+
+ return true;
+}
+
+//
+// WriteableCacheFile
+//
+
+WriteableCacheFile::~WriteableCacheFile() {
+ WriteLock _(&rwlock_);
+ if (!eof_) {
+ // This file never flushed. We give priority to shutdown since this is a
+ // cache
+ // TODO(krad): Figure a way to flush the pending data
+ if (file_) {
+ assert(refs_ == 1);
+ --refs_;
+ }
+ }
+ assert(!refs_);
+ ClearBuffers();
+}
+
+bool WriteableCacheFile::Create(const bool /*enable_direct_writes*/,
+ const bool enable_direct_reads) {
+ WriteLock _(&rwlock_);
+
+ enable_direct_reads_ = enable_direct_reads;
+
+ ROCKS_LOG_DEBUG(log_, "Creating new cache %s (max size is %d B)",
+ Path().c_str(), max_size_);
+
+ assert(env_);
+
+ Status s = env_->FileExists(Path());
+ if (s.ok()) {
+ ROCKS_LOG_WARN(log_, "File %s already exists. %s", Path().c_str(),
+ s.ToString().c_str());
+ }
+
+ s = NewWritableCacheFile(env_, Path(), &file_);
+ if (!s.ok()) {
+ ROCKS_LOG_WARN(log_, "Unable to create file %s. %s", Path().c_str(),
+ s.ToString().c_str());
+ return false;
+ }
+
+ assert(!refs_);
+ ++refs_;
+
+ return true;
+}
+
+bool WriteableCacheFile::Append(const Slice& key, const Slice& val, LBA* lba) {
+ WriteLock _(&rwlock_);
+
+ if (eof_) {
+ // We can't append since the file is full
+ return false;
+ }
+
+ // estimate the space required to store the (key, val)
+ uint32_t rec_size = CacheRecord::CalcSize(key, val);
+
+ if (!ExpandBuffer(rec_size)) {
+ // unable to expand the buffer
+ ROCKS_LOG_DEBUG(log_, "Error expanding buffers. size=%d", rec_size);
+ return false;
+ }
+
+ lba->cache_id_ = cache_id_;
+ lba->off_ = disk_woff_;
+ lba->size_ = rec_size;
+
+ CacheRecord rec(key, val);
+ if (!rec.Serialize(&bufs_, &buf_woff_)) {
+ // unexpected error: unable to serialize the data
+ assert(!"Error serializing record");
+ return false;
+ }
+
+ disk_woff_ += rec_size;
+ eof_ = disk_woff_ >= max_size_;
+
+ // dispatch buffer for flush
+ DispatchBuffer();
+
+ return true;
+}
+
+bool WriteableCacheFile::ExpandBuffer(const size_t size) {
+ rwlock_.AssertHeld();
+ assert(!eof_);
+
+ // determine if there is enough space
+ size_t free = 0; // compute the free space left in buffer
+ for (size_t i = buf_woff_; i < bufs_.size(); ++i) {
+ free += bufs_[i]->Free();
+ if (size <= free) {
+ // we have enough space in the buffer
+ return true;
+ }
+ }
+
+ // expand the buffer until there is enough space to write `size` bytes
+ assert(free < size);
+ assert(alloc_);
+
+ while (free < size) {
+ CacheWriteBuffer* const buf = alloc_->Allocate();
+ if (!buf) {
+ ROCKS_LOG_DEBUG(log_, "Unable to allocate buffers");
+ return false;
+ }
+
+ size_ += static_cast<uint32_t>(buf->Free());
+ free += buf->Free();
+ bufs_.push_back(buf);
+ }
+
+ assert(free >= size);
+ return true;
+}
+
+void WriteableCacheFile::DispatchBuffer() {
+ rwlock_.AssertHeld();
+
+ assert(bufs_.size());
+ assert(buf_doff_ <= buf_woff_);
+ assert(buf_woff_ <= bufs_.size());
+
+ if (pending_ios_) {
+ return;
+ }
+
+ if (!eof_ && buf_doff_ == buf_woff_) {
+ // dispatch buffer is pointing to write buffer and we haven't hit eof
+ return;
+ }
+
+ assert(eof_ || buf_doff_ < buf_woff_);
+ assert(buf_doff_ < bufs_.size());
+ assert(file_);
+ assert(alloc_);
+
+ auto* buf = bufs_[buf_doff_];
+ const uint64_t file_off = buf_doff_ * alloc_->BufferSize();
+
+ assert(!buf->Free() ||
+ (eof_ && buf_doff_ == buf_woff_ && buf_woff_ < bufs_.size()));
+ // we have reached end of file, and there is space in the last buffer
+ // pad it with zero for direct IO
+ buf->FillTrailingZeros();
+
+ assert(buf->Used() % kFileAlignmentSize == 0);
+
+ writer_->Write(file_.get(), buf, file_off,
+ std::bind(&WriteableCacheFile::BufferWriteDone, this));
+ pending_ios_++;
+ buf_doff_++;
+}
+
+void WriteableCacheFile::BufferWriteDone() {
+ WriteLock _(&rwlock_);
+
+ assert(bufs_.size());
+
+ pending_ios_--;
+
+ if (buf_doff_ < bufs_.size()) {
+ DispatchBuffer();
+ }
+
+ if (eof_ && buf_doff_ >= bufs_.size() && !pending_ios_) {
+ // end-of-file reached, move to read mode
+ CloseAndOpenForReading();
+ }
+}
+
+void WriteableCacheFile::CloseAndOpenForReading() {
+ // Our env abstraction do not allow reading from a file opened for appending
+ // We need close the file and re-open it for reading
+ Close();
+ RandomAccessCacheFile::OpenImpl(enable_direct_reads_);
+}
+
+bool WriteableCacheFile::ReadBuffer(const LBA& lba, Slice* key, Slice* block,
+ char* scratch) {
+ rwlock_.AssertHeld();
+
+ if (!ReadBuffer(lba, scratch)) {
+ Error(log_, "Error reading from buffer. cache=%d off=%d", cache_id_,
+ lba.off_);
+ return false;
+ }
+
+ return ParseRec(lba, key, block, scratch);
+}
+
+bool WriteableCacheFile::ReadBuffer(const LBA& lba, char* data) {
+ rwlock_.AssertHeld();
+
+ assert(lba.off_ < disk_woff_);
+ assert(alloc_);
+
+ // we read from the buffers like reading from a flat file. The list of buffers
+ // are treated as contiguous stream of data
+
+ char* tmp = data;
+ size_t pending_nbytes = lba.size_;
+ // start buffer
+ size_t start_idx = lba.off_ / alloc_->BufferSize();
+ // offset into the start buffer
+ size_t start_off = lba.off_ % alloc_->BufferSize();
+
+ assert(start_idx <= buf_woff_);
+
+ for (size_t i = start_idx; pending_nbytes && i < bufs_.size(); ++i) {
+ assert(i <= buf_woff_);
+ auto* buf = bufs_[i];
+ assert(i == buf_woff_ || !buf->Free());
+ // bytes to write to the buffer
+ size_t nbytes = pending_nbytes > (buf->Used() - start_off)
+ ? (buf->Used() - start_off)
+ : pending_nbytes;
+ memcpy(tmp, buf->Data() + start_off, nbytes);
+
+ // left over to be written
+ pending_nbytes -= nbytes;
+ start_off = 0;
+ tmp += nbytes;
+ }
+
+ assert(!pending_nbytes);
+ if (pending_nbytes) {
+ return false;
+ }
+
+ assert(tmp == data + lba.size_);
+ return true;
+}
+
+void WriteableCacheFile::Close() {
+ rwlock_.AssertHeld();
+
+ assert(size_ >= max_size_);
+ assert(disk_woff_ >= max_size_);
+ assert(buf_doff_ == bufs_.size());
+ assert(bufs_.size() - buf_woff_ <= 1);
+ assert(!pending_ios_);
+
+ Info(log_, "Closing file %s. size=%d written=%d", Path().c_str(), size_,
+ disk_woff_);
+
+ ClearBuffers();
+ file_.reset();
+
+ assert(refs_);
+ --refs_;
+}
+
+void WriteableCacheFile::ClearBuffers() {
+ assert(alloc_);
+
+ for (size_t i = 0; i < bufs_.size(); ++i) {
+ alloc_->Deallocate(bufs_[i]);
+ }
+
+ bufs_.clear();
+}
+
+//
+// ThreadedFileWriter implementation
+//
+ThreadedWriter::ThreadedWriter(PersistentCacheTier* const cache,
+ const size_t qdepth, const size_t io_size)
+ : Writer(cache), io_size_(io_size) {
+ for (size_t i = 0; i < qdepth; ++i) {
+ port::Thread th(&ThreadedWriter::ThreadMain, this);
+ threads_.push_back(std::move(th));
+ }
+}
+
+void ThreadedWriter::Stop() {
+ // notify all threads to exit
+ for (size_t i = 0; i < threads_.size(); ++i) {
+ q_.Push(IO(/*signal=*/true));
+ }
+
+ // wait for all threads to exit
+ for (auto& th : threads_) {
+ th.join();
+ assert(!th.joinable());
+ }
+ threads_.clear();
+}
+
+void ThreadedWriter::Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) {
+ q_.Push(IO(file, buf, file_off, callback));
+}
+
+void ThreadedWriter::ThreadMain() {
+ while (true) {
+ // Fetch the IO to process
+ IO io(q_.Pop());
+ if (io.signal_) {
+ // that's secret signal to exit
+ break;
+ }
+
+ // Reserve space for writing the buffer
+ while (!cache_->Reserve(io.buf_->Used())) {
+ // We can fail to reserve space if every file in the system
+ // is being currently accessed
+ /* sleep override */
+ SystemClock::Default()->SleepForMicroseconds(1000000);
+ }
+
+ DispatchIO(io);
+
+ io.callback_();
+ }
+}
+
+void ThreadedWriter::DispatchIO(const IO& io) {
+ size_t written = 0;
+ while (written < io.buf_->Used()) {
+ Slice data(io.buf_->Data() + written, io_size_);
+ Status s = io.file_->Append(data);
+ assert(s.ok());
+ if (!s.ok()) {
+ // That is definite IO error to device. There is not much we can
+ // do but ignore the failure. This can lead to corruption of data on
+ // disk, but the cache will skip while reading
+ fprintf(stderr, "Error writing data to file. %s\n", s.ToString().c_str());
+ }
+ written += io_size_;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h
new file mode 100644
index 000000000..1d265ab74
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h
@@ -0,0 +1,293 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "file/random_access_file_reader.h"
+#include "port/port.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/env.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+#include "utilities/persistent_cache/block_cache_tier_file_buffer.h"
+#include "utilities/persistent_cache/lrulist.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+#include "utilities/persistent_cache/persistent_cache_util.h"
+
+// The io code path of persistent cache uses pipelined architecture
+//
+// client -> In Queue <-- BlockCacheTier --> Out Queue <-- Writer <--> Kernel
+//
+// This would enable the system to scale for GB/s of throughput which is
+// expected with modern devies like NVM.
+//
+// The file level operations are encapsulated in the following abstractions
+//
+// BlockCacheFile
+// ^
+// |
+// |
+// RandomAccessCacheFile (For reading)
+// ^
+// |
+// |
+// WriteableCacheFile (For writing)
+//
+// Write IO code path :
+//
+namespace ROCKSDB_NAMESPACE {
+
+class WriteableCacheFile;
+struct BlockInfo;
+
+// Represents a logical record on device
+//
+// (L)ogical (B)lock (Address = { cache-file-id, offset, size }
+struct LogicalBlockAddress {
+ LogicalBlockAddress() {}
+ explicit LogicalBlockAddress(const uint32_t cache_id, const uint32_t off,
+ const uint16_t size)
+ : cache_id_(cache_id), off_(off), size_(size) {}
+
+ uint32_t cache_id_ = 0;
+ uint32_t off_ = 0;
+ uint32_t size_ = 0;
+};
+
+using LBA = LogicalBlockAddress;
+
+// class Writer
+//
+// Writer is the abstraction used for writing data to file. The component can be
+// multithreaded. It is the last step of write pipeline
+class Writer {
+ public:
+ explicit Writer(PersistentCacheTier* const cache) : cache_(cache) {}
+ virtual ~Writer() {}
+
+ // write buffer to file at the given offset
+ virtual void Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) = 0;
+ // stop the writer
+ virtual void Stop() = 0;
+
+ PersistentCacheTier* const cache_;
+};
+
+// class BlockCacheFile
+//
+// Generic interface to support building file specialized for read/writing
+class BlockCacheFile : public LRUElement<BlockCacheFile> {
+ public:
+ explicit BlockCacheFile(const uint32_t cache_id)
+ : LRUElement<BlockCacheFile>(), cache_id_(cache_id) {}
+
+ explicit BlockCacheFile(Env* const env, const std::string& dir,
+ const uint32_t cache_id)
+ : LRUElement<BlockCacheFile>(),
+ env_(env),
+ dir_(dir),
+ cache_id_(cache_id) {}
+
+ virtual ~BlockCacheFile() {}
+
+ // append key/value to file and return LBA locator to user
+ virtual bool Append(const Slice& /*key*/, const Slice& /*val*/,
+ LBA* const /*lba*/) {
+ assert(!"not implemented");
+ return false;
+ }
+
+ // read from the record locator (LBA) and return key, value and status
+ virtual bool Read(const LBA& /*lba*/, Slice* /*key*/, Slice* /*block*/,
+ char* /*scratch*/) {
+ assert(!"not implemented");
+ return false;
+ }
+
+ // get file path
+ std::string Path() const {
+ return dir_ + "/" + std::to_string(cache_id_) + ".rc";
+ }
+ // get cache ID
+ uint32_t cacheid() const { return cache_id_; }
+ // Add block information to file data
+ // Block information is the list of index reference for this file
+ virtual void Add(BlockInfo* binfo) {
+ WriteLock _(&rwlock_);
+ block_infos_.push_back(binfo);
+ }
+ // get block information
+ std::list<BlockInfo*>& block_infos() { return block_infos_; }
+ // delete file and return the size of the file
+ virtual Status Delete(uint64_t* size);
+
+ protected:
+ port::RWMutex rwlock_; // synchronization mutex
+ Env* const env_ = nullptr; // Env for OS
+ const std::string dir_; // Directory name
+ const uint32_t cache_id_; // Cache id for the file
+ std::list<BlockInfo*> block_infos_; // List of index entries mapping to the
+ // file content
+};
+
+// class RandomAccessFile
+//
+// Thread safe implementation for reading random data from file
+class RandomAccessCacheFile : public BlockCacheFile {
+ public:
+ explicit RandomAccessCacheFile(Env* const env, const std::string& dir,
+ const uint32_t cache_id,
+ const std::shared_ptr<Logger>& log)
+ : BlockCacheFile(env, dir, cache_id), log_(log) {}
+
+ virtual ~RandomAccessCacheFile() {}
+
+ // open file for reading
+ bool Open(const bool enable_direct_reads);
+ // read data from the disk
+ bool Read(const LBA& lba, Slice* key, Slice* block, char* scratch) override;
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> freader_;
+
+ protected:
+ bool OpenImpl(const bool enable_direct_reads);
+ bool ParseRec(const LBA& lba, Slice* key, Slice* val, char* scratch);
+
+ std::shared_ptr<Logger> log_; // log file
+};
+
+// class WriteableCacheFile
+//
+// All writes to the files are cached in buffers. The buffers are flushed to
+// disk as they get filled up. When file size reaches a certain size, a new file
+// will be created provided there is free space
+class WriteableCacheFile : public RandomAccessCacheFile {
+ public:
+ explicit WriteableCacheFile(Env* const env, CacheWriteBufferAllocator* alloc,
+ Writer* writer, const std::string& dir,
+ const uint32_t cache_id, const uint32_t max_size,
+ const std::shared_ptr<Logger>& log)
+ : RandomAccessCacheFile(env, dir, cache_id, log),
+ alloc_(alloc),
+ writer_(writer),
+ max_size_(max_size) {}
+
+ virtual ~WriteableCacheFile();
+
+ // create file on disk
+ bool Create(const bool enable_direct_writes, const bool enable_direct_reads);
+
+ // read data from logical file
+ bool Read(const LBA& lba, Slice* key, Slice* block, char* scratch) override {
+ ReadLock _(&rwlock_);
+ const bool closed = eof_ && bufs_.empty();
+ if (closed) {
+ // the file is closed, read from disk
+ return RandomAccessCacheFile::Read(lba, key, block, scratch);
+ }
+ // file is still being written, read from buffers
+ return ReadBuffer(lba, key, block, scratch);
+ }
+
+ // append data to end of file
+ bool Append(const Slice&, const Slice&, LBA* const) override;
+ // End-of-file
+ bool Eof() const { return eof_; }
+
+ private:
+ friend class ThreadedWriter;
+
+ static const size_t kFileAlignmentSize = 4 * 1024; // align file size
+
+ bool ReadBuffer(const LBA& lba, Slice* key, Slice* block, char* scratch);
+ bool ReadBuffer(const LBA& lba, char* data);
+ bool ExpandBuffer(const size_t size);
+ void DispatchBuffer();
+ void BufferWriteDone();
+ void CloseAndOpenForReading();
+ void ClearBuffers();
+ void Close();
+
+ // File layout in memory
+ //
+ // +------+------+------+------+------+------+
+ // | b0 | b1 | b2 | b3 | b4 | b5 |
+ // +------+------+------+------+------+------+
+ // ^ ^
+ // | |
+ // buf_doff_ buf_woff_
+ // (next buffer to (next buffer to fill)
+ // flush to disk)
+ //
+ // The buffers are flushed to disk serially for a given file
+
+ CacheWriteBufferAllocator* const alloc_ = nullptr; // Buffer provider
+ Writer* const writer_ = nullptr; // File writer thread
+ std::unique_ptr<WritableFile> file_; // RocksDB Env file abstraction
+ std::vector<CacheWriteBuffer*> bufs_; // Written buffers
+ uint32_t size_ = 0; // Size of the file
+ const uint32_t max_size_; // Max size of the file
+ bool eof_ = false; // End of file
+ uint32_t disk_woff_ = 0; // Offset to write on disk
+ size_t buf_woff_ = 0; // off into bufs_ to write
+ size_t buf_doff_ = 0; // off into bufs_ to dispatch
+ size_t pending_ios_ = 0; // Number of ios to disk in-progress
+ bool enable_direct_reads_ = false; // Should we enable direct reads
+ // when reading from disk
+};
+
+//
+// Abstraction to do writing to device. It is part of pipelined architecture.
+//
+class ThreadedWriter : public Writer {
+ public:
+ // Representation of IO to device
+ struct IO {
+ explicit IO(const bool signal) : signal_(signal) {}
+ explicit IO(WritableFile* const file, CacheWriteBuffer* const buf,
+ const uint64_t file_off, const std::function<void()> callback)
+ : file_(file), buf_(buf), file_off_(file_off), callback_(callback) {}
+
+ IO(const IO&) = default;
+ IO& operator=(const IO&) = default;
+ size_t Size() const { return sizeof(IO); }
+
+ WritableFile* file_ = nullptr; // File to write to
+ CacheWriteBuffer* buf_ = nullptr; // buffer to write
+ uint64_t file_off_ = 0; // file offset
+ bool signal_ = false; // signal to exit thread loop
+ std::function<void()> callback_; // Callback on completion
+ };
+
+ explicit ThreadedWriter(PersistentCacheTier* const cache, const size_t qdepth,
+ const size_t io_size);
+ virtual ~ThreadedWriter() { assert(threads_.empty()); }
+
+ void Stop() override;
+ void Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) override;
+
+ private:
+ void ThreadMain();
+ void DispatchIO(const IO& io);
+
+ const size_t io_size_ = 0;
+ BoundedQueue<IO> q_;
+ std::vector<port::Thread> threads_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h
new file mode 100644
index 000000000..d4f02455a
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h
@@ -0,0 +1,127 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <list>
+#include <memory>
+#include <string>
+
+#include "memory/arena.h"
+#include "rocksdb/comparator.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// CacheWriteBuffer
+//
+// Buffer abstraction that can be manipulated via append
+// (not thread safe)
+class CacheWriteBuffer {
+ public:
+ explicit CacheWriteBuffer(const size_t size) : size_(size), pos_(0) {
+ buf_.reset(new char[size_]);
+ assert(!pos_);
+ assert(size_);
+ }
+
+ virtual ~CacheWriteBuffer() {}
+
+ void Append(const char* buf, const size_t size) {
+ assert(pos_ + size <= size_);
+ memcpy(buf_.get() + pos_, buf, size);
+ pos_ += size;
+ assert(pos_ <= size_);
+ }
+
+ void FillTrailingZeros() {
+ assert(pos_ <= size_);
+ memset(buf_.get() + pos_, '0', size_ - pos_);
+ pos_ = size_;
+ }
+
+ void Reset() { pos_ = 0; }
+ size_t Free() const { return size_ - pos_; }
+ size_t Capacity() const { return size_; }
+ size_t Used() const { return pos_; }
+ char* Data() const { return buf_.get(); }
+
+ private:
+ std::unique_ptr<char[]> buf_;
+ const size_t size_;
+ size_t pos_;
+};
+
+//
+// CacheWriteBufferAllocator
+//
+// Buffer pool abstraction(not thread safe)
+//
+class CacheWriteBufferAllocator {
+ public:
+ explicit CacheWriteBufferAllocator(const size_t buffer_size,
+ const size_t buffer_count)
+ : cond_empty_(&lock_), buffer_size_(buffer_size) {
+ MutexLock _(&lock_);
+ buffer_size_ = buffer_size;
+ for (uint32_t i = 0; i < buffer_count; i++) {
+ auto* buf = new CacheWriteBuffer(buffer_size_);
+ assert(buf);
+ if (buf) {
+ bufs_.push_back(buf);
+ cond_empty_.Signal();
+ }
+ }
+ }
+
+ virtual ~CacheWriteBufferAllocator() {
+ MutexLock _(&lock_);
+ assert(bufs_.size() * buffer_size_ == Capacity());
+ for (auto* buf : bufs_) {
+ delete buf;
+ }
+ bufs_.clear();
+ }
+
+ CacheWriteBuffer* Allocate() {
+ MutexLock _(&lock_);
+ if (bufs_.empty()) {
+ return nullptr;
+ }
+
+ assert(!bufs_.empty());
+ CacheWriteBuffer* const buf = bufs_.front();
+ bufs_.pop_front();
+ return buf;
+ }
+
+ void Deallocate(CacheWriteBuffer* const buf) {
+ assert(buf);
+ MutexLock _(&lock_);
+ buf->Reset();
+ bufs_.push_back(buf);
+ cond_empty_.Signal();
+ }
+
+ void WaitUntilUsable() {
+ // We are asked to wait till we have buffers available
+ MutexLock _(&lock_);
+ while (bufs_.empty()) {
+ cond_empty_.Wait();
+ }
+ }
+
+ size_t Capacity() const { return bufs_.size() * buffer_size_; }
+ size_t Free() const { return bufs_.size() * buffer_size_; }
+ size_t BufferSize() const { return buffer_size_; }
+
+ private:
+ port::Mutex lock_; // Sync lock
+ port::CondVar cond_empty_; // Condition var for empty buffers
+ size_t buffer_size_; // Size of each buffer
+ std::list<CacheWriteBuffer*> bufs_; // Buffer stash
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc
new file mode 100644
index 000000000..d73b5d0b4
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier_metadata.h"
+
+#include <functional>
+
+namespace ROCKSDB_NAMESPACE {
+
+bool BlockCacheTierMetadata::Insert(BlockCacheFile* file) {
+ return cache_file_index_.Insert(file);
+}
+
+BlockCacheFile* BlockCacheTierMetadata::Lookup(const uint32_t cache_id) {
+ BlockCacheFile* ret = nullptr;
+ BlockCacheFile lookup_key(cache_id);
+ bool ok = cache_file_index_.Find(&lookup_key, &ret);
+ if (ok) {
+ assert(ret->refs_);
+ return ret;
+ }
+ return nullptr;
+}
+
+BlockCacheFile* BlockCacheTierMetadata::Evict() {
+ using std::placeholders::_1;
+ auto fn = std::bind(&BlockCacheTierMetadata::RemoveAllKeys, this, _1);
+ return cache_file_index_.Evict(fn);
+}
+
+void BlockCacheTierMetadata::Clear() {
+ cache_file_index_.Clear([](BlockCacheFile* arg) { delete arg; });
+ block_index_.Clear([](BlockInfo* arg) { delete arg; });
+}
+
+BlockInfo* BlockCacheTierMetadata::Insert(const Slice& key, const LBA& lba) {
+ std::unique_ptr<BlockInfo> binfo(new BlockInfo(key, lba));
+ if (!block_index_.Insert(binfo.get())) {
+ return nullptr;
+ }
+ return binfo.release();
+}
+
+bool BlockCacheTierMetadata::Lookup(const Slice& key, LBA* lba) {
+ BlockInfo lookup_key(key);
+ BlockInfo* block;
+ port::RWMutex* rlock = nullptr;
+ if (!block_index_.Find(&lookup_key, &block, &rlock)) {
+ return false;
+ }
+
+ ReadUnlock _(rlock);
+ assert(block->key_ == key.ToString());
+ if (lba) {
+ *lba = block->lba_;
+ }
+ return true;
+}
+
+BlockInfo* BlockCacheTierMetadata::Remove(const Slice& key) {
+ BlockInfo lookup_key(key);
+ BlockInfo* binfo = nullptr;
+ bool ok __attribute__((__unused__));
+ ok = block_index_.Erase(&lookup_key, &binfo);
+ assert(ok);
+ return binfo;
+}
+
+void BlockCacheTierMetadata::RemoveAllKeys(BlockCacheFile* f) {
+ for (BlockInfo* binfo : f->block_infos()) {
+ BlockInfo* tmp = nullptr;
+ bool status = block_index_.Erase(binfo, &tmp);
+ (void)status;
+ assert(status);
+ assert(tmp == binfo);
+ delete binfo;
+ }
+ f->block_infos().clear();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h
new file mode 100644
index 000000000..2fcd50105
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h
@@ -0,0 +1,124 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "rocksdb/slice.h"
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+#include "utilities/persistent_cache/lrulist.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Block Cache Tier Metadata
+//
+// The BlockCacheTierMetadata holds all the metadata associated with block
+// cache. It
+// fundamentally contains 2 indexes and an LRU.
+//
+// Block Cache Index
+//
+// This is a forward index that maps a given key to a LBA (Logical Block
+// Address). LBA is a disk pointer that points to a record on the cache.
+//
+// LBA = { cache-id, offset, size }
+//
+// Cache File Index
+//
+// This is a forward index that maps a given cache-id to a cache file object.
+// Typically you would lookup using LBA and use the object to read or write
+struct BlockInfo {
+ explicit BlockInfo(const Slice& key, const LBA& lba = LBA())
+ : key_(key.ToString()), lba_(lba) {}
+
+ std::string key_;
+ LBA lba_;
+};
+
+class BlockCacheTierMetadata {
+ public:
+ explicit BlockCacheTierMetadata(const uint32_t blocks_capacity = 1024 * 1024,
+ const uint32_t cachefile_capacity = 10 * 1024)
+ : cache_file_index_(cachefile_capacity), block_index_(blocks_capacity) {}
+
+ virtual ~BlockCacheTierMetadata() {}
+
+ // Insert a given cache file
+ bool Insert(BlockCacheFile* file);
+
+ // Lookup cache file based on cache_id
+ BlockCacheFile* Lookup(const uint32_t cache_id);
+
+ // Insert block information to block index
+ BlockInfo* Insert(const Slice& key, const LBA& lba);
+ // bool Insert(BlockInfo* binfo);
+
+ // Lookup block information from block index
+ bool Lookup(const Slice& key, LBA* lba);
+
+ // Remove a given from the block index
+ BlockInfo* Remove(const Slice& key);
+
+ // Find and evict a cache file using LRU policy
+ BlockCacheFile* Evict();
+
+ // Clear the metadata contents
+ virtual void Clear();
+
+ protected:
+ // Remove all block information from a given file
+ virtual void RemoveAllKeys(BlockCacheFile* file);
+
+ private:
+ // Cache file index definition
+ //
+ // cache-id => BlockCacheFile
+ struct BlockCacheFileHash {
+ uint64_t operator()(const BlockCacheFile* rec) {
+ return std::hash<uint32_t>()(rec->cacheid());
+ }
+ };
+
+ struct BlockCacheFileEqual {
+ uint64_t operator()(const BlockCacheFile* lhs, const BlockCacheFile* rhs) {
+ return lhs->cacheid() == rhs->cacheid();
+ }
+ };
+
+ using CacheFileIndexType =
+ EvictableHashTable<BlockCacheFile, BlockCacheFileHash,
+ BlockCacheFileEqual>;
+
+ // Block Lookup Index
+ //
+ // key => LBA
+ struct Hash {
+ size_t operator()(BlockInfo* node) const {
+ return std::hash<std::string>()(node->key_);
+ }
+ };
+
+ struct Equal {
+ size_t operator()(BlockInfo* lhs, BlockInfo* rhs) const {
+ return lhs->key_ == rhs->key_;
+ }
+ };
+
+ using BlockIndexType = HashTable<BlockInfo*, Hash, Equal>;
+
+ CacheFileIndexType cache_file_index_;
+ BlockIndexType block_index_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table.h b/src/rocksdb/utilities/persistent_cache/hash_table.h
new file mode 100644
index 000000000..b00b294ce
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table.h
@@ -0,0 +1,239 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <assert.h>
+
+#include <list>
+#include <vector>
+
+#ifdef OS_LINUX
+#include <sys/mman.h>
+#endif
+
+#include "rocksdb/env.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// HashTable<T, Hash, Equal>
+//
+// Traditional implementation of hash table with synchronization built on top
+// don't perform very well in multi-core scenarios. This is an implementation
+// designed for multi-core scenarios with high lock contention.
+//
+// |<-------- alpha ------------->|
+// Buckets Collision list
+// ---- +----+ +---+---+--- ...... ---+---+---+
+// / | |--->| | | | | |
+// / +----+ +---+---+--- ...... ---+---+---+
+// / | |
+// Locks/ +----+
+// +--+/ . .
+// | | . .
+// +--+ . .
+// | | . .
+// +--+ . .
+// | | . .
+// +--+ . .
+// \ +----+
+// \ | |
+// \ +----+
+// \ | |
+// \---- +----+
+//
+// The lock contention is spread over an array of locks. This helps improve
+// concurrent access. The spine is designed for a certain capacity and load
+// factor. When the capacity planning is done correctly we can expect
+// O(load_factor = 1) insert, access and remove time.
+//
+// Micro benchmark on debug build gives about .5 Million/sec rate of insert,
+// erase and lookup in parallel (total of about 1.5 Million ops/sec). If the
+// blocks were of 4K, the hash table can support a virtual throughput of
+// 6 GB/s.
+//
+// T Object type (contains both key and value)
+// Hash Function that returns an hash from type T
+// Equal Returns if two objects are equal
+// (We need explicit equal for pointer type)
+//
+template <class T, class Hash, class Equal>
+class HashTable {
+ public:
+ explicit HashTable(const size_t capacity = 1024 * 1024,
+ const float load_factor = 2.0, const uint32_t nlocks = 256)
+ : nbuckets_(
+ static_cast<uint32_t>(load_factor ? capacity / load_factor : 0)),
+ nlocks_(nlocks) {
+ // pre-conditions
+ assert(capacity);
+ assert(load_factor);
+ assert(nbuckets_);
+ assert(nlocks_);
+
+ buckets_.reset(new Bucket[nbuckets_]);
+#ifdef OS_LINUX
+ mlock(buckets_.get(), nbuckets_ * sizeof(Bucket));
+#endif
+
+ // initialize locks
+ locks_.reset(new port::RWMutex[nlocks_]);
+#ifdef OS_LINUX
+ mlock(locks_.get(), nlocks_ * sizeof(port::RWMutex));
+#endif
+
+ // post-conditions
+ assert(buckets_);
+ assert(locks_);
+ }
+
+ virtual ~HashTable() { AssertEmptyBuckets(); }
+
+ //
+ // Insert given record to hash table
+ //
+ bool Insert(const T& t) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ WriteLock _(&locks_[lock_idx]);
+ auto& bucket = buckets_[bucket_idx];
+ return Insert(&bucket, t);
+ }
+
+ // Lookup hash table
+ //
+ // Please note that read lock should be held by the caller. This is because
+ // the caller owns the data, and should hold the read lock as long as he
+ // operates on the data.
+ bool Find(const T& t, T* ret, port::RWMutex** ret_lock) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ port::RWMutex& lock = locks_[lock_idx];
+ lock.ReadLock();
+
+ auto& bucket = buckets_[bucket_idx];
+ if (Find(&bucket, t, ret)) {
+ *ret_lock = &lock;
+ return true;
+ }
+
+ lock.ReadUnlock();
+ return false;
+ }
+
+ //
+ // Erase a given key from the hash table
+ //
+ bool Erase(const T& t, T* ret) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ WriteLock _(&locks_[lock_idx]);
+
+ auto& bucket = buckets_[bucket_idx];
+ return Erase(&bucket, t, ret);
+ }
+
+ // Fetch the mutex associated with a key
+ // This call is used to hold the lock for a given data for extended period of
+ // time.
+ port::RWMutex* GetMutex(const T& t) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ return &locks_[lock_idx];
+ }
+
+ void Clear(void (*fn)(T)) {
+ for (uint32_t i = 0; i < nbuckets_; ++i) {
+ const uint32_t lock_idx = i % nlocks_;
+ WriteLock _(&locks_[lock_idx]);
+ for (auto& t : buckets_[i].list_) {
+ (*fn)(t);
+ }
+ buckets_[i].list_.clear();
+ }
+ }
+
+ protected:
+ // Models bucket of keys that hash to the same bucket number
+ struct Bucket {
+ std::list<T> list_;
+ };
+
+ // Substitute for std::find with custom comparator operator
+ typename std::list<T>::iterator Find(std::list<T>* list, const T& t) {
+ for (auto it = list->begin(); it != list->end(); ++it) {
+ if (Equal()(*it, t)) {
+ return it;
+ }
+ }
+ return list->end();
+ }
+
+ bool Insert(Bucket* bucket, const T& t) {
+ // Check if the key already exists
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ return false;
+ }
+
+ // insert to bucket
+ bucket->list_.push_back(t);
+ return true;
+ }
+
+ bool Find(Bucket* bucket, const T& t, T* ret) {
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ if (ret) {
+ *ret = *it;
+ }
+ return true;
+ }
+ return false;
+ }
+
+ bool Erase(Bucket* bucket, const T& t, T* ret) {
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ if (ret) {
+ *ret = *it;
+ }
+
+ bucket->list_.erase(it);
+ return true;
+ }
+ return false;
+ }
+
+ // assert that all buckets are empty
+ void AssertEmptyBuckets() {
+#ifndef NDEBUG
+ for (size_t i = 0; i < nbuckets_; ++i) {
+ WriteLock _(&locks_[i % nlocks_]);
+ assert(buckets_[i].list_.empty());
+ }
+#endif
+ }
+
+ const uint32_t nbuckets_; // No. of buckets in the spine
+ std::unique_ptr<Bucket[]> buckets_; // Spine of the hash buckets
+ const uint32_t nlocks_; // No. of locks
+ std::unique_ptr<port::RWMutex[]> locks_; // Granular locks
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc b/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc
new file mode 100644
index 000000000..74d7e2edf
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc
@@ -0,0 +1,310 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#if !defined(OS_WIN) && !defined(ROCKSDB_LITE)
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() { fprintf(stderr, "Please install gflags to run tools\n"); }
+#else
+
+#include <sys/time.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "port/port_posix.h"
+#include "port/sys_time.h"
+#include "rocksdb/env.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table.h"
+
+using std::string;
+
+DEFINE_int32(nsec, 10, "nsec");
+DEFINE_int32(nthread_write, 1, "insert %");
+DEFINE_int32(nthread_read, 1, "lookup %");
+DEFINE_int32(nthread_erase, 1, "erase %");
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// HashTableImpl interface
+//
+// Abstraction of a hash table implementation
+template <class Key, class Value>
+class HashTableImpl {
+ public:
+ virtual ~HashTableImpl() {}
+
+ virtual bool Insert(const Key& key, const Value& val) = 0;
+ virtual bool Erase(const Key& key) = 0;
+ virtual bool Lookup(const Key& key, Value* val) = 0;
+};
+
+// HashTableBenchmark
+//
+// Abstraction to test a given hash table implementation. The test mostly
+// focus on insert, lookup and erase. The test can operate in test mode and
+// benchmark mode.
+class HashTableBenchmark {
+ public:
+ explicit HashTableBenchmark(HashTableImpl<size_t, std::string>* impl,
+ const size_t sec = 10,
+ const size_t nthread_write = 1,
+ const size_t nthread_read = 1,
+ const size_t nthread_erase = 1)
+ : impl_(impl),
+ sec_(sec),
+ ninserts_(0),
+ nreads_(0),
+ nerases_(0),
+ nerases_failed_(0),
+ quit_(false) {
+ Prepop();
+
+ StartThreads(nthread_write, WriteMain);
+ StartThreads(nthread_read, ReadMain);
+ StartThreads(nthread_erase, EraseMain);
+
+ uint64_t start = NowInMillSec();
+ while (!quit_) {
+ quit_ = NowInMillSec() - start > sec_ * 1000;
+ /* sleep override */ sleep(1);
+ }
+
+ Env* env = Env::Default();
+ env->WaitForJoin();
+
+ if (sec_) {
+ printf("Result \n");
+ printf("====== \n");
+ printf("insert/sec = %f \n", ninserts_ / static_cast<double>(sec_));
+ printf("read/sec = %f \n", nreads_ / static_cast<double>(sec_));
+ printf("erases/sec = %f \n", nerases_ / static_cast<double>(sec_));
+ const uint64_t ops = ninserts_ + nreads_ + nerases_;
+ printf("ops/sec = %f \n", ops / static_cast<double>(sec_));
+ printf("erase fail = %d (%f%%)\n", static_cast<int>(nerases_failed_),
+ static_cast<float>(nerases_failed_ / nerases_ * 100));
+ printf("====== \n");
+ }
+ }
+
+ void RunWrite() {
+ while (!quit_) {
+ size_t k = insert_key_++;
+ std::string tmp(1000, k % 255);
+ bool status = impl_->Insert(k, tmp);
+ assert(status);
+ ninserts_++;
+ }
+ }
+
+ void RunRead() {
+ Random64 rgen(time(nullptr));
+ while (!quit_) {
+ std::string s;
+ size_t k = rgen.Next() % max_prepop_key;
+ bool status = impl_->Lookup(k, &s);
+ assert(status);
+ assert(s == std::string(1000, k % 255));
+ nreads_++;
+ }
+ }
+
+ void RunErase() {
+ while (!quit_) {
+ size_t k = erase_key_++;
+ bool status = impl_->Erase(k);
+ nerases_failed_ += !status;
+ nerases_++;
+ }
+ }
+
+ private:
+ // Start threads for a given function
+ void StartThreads(const size_t n, void (*fn)(void*)) {
+ Env* env = Env::Default();
+ for (size_t i = 0; i < n; ++i) {
+ env->StartThread(fn, this);
+ }
+ }
+
+ // Prepop the hash table with 1M keys
+ void Prepop() {
+ for (size_t i = 0; i < max_prepop_key; ++i) {
+ bool status = impl_->Insert(i, std::string(1000, i % 255));
+ assert(status);
+ }
+
+ erase_key_ = insert_key_ = max_prepop_key;
+
+ for (size_t i = 0; i < 10 * max_prepop_key; ++i) {
+ bool status = impl_->Insert(insert_key_++, std::string(1000, 'x'));
+ assert(status);
+ }
+ }
+
+ static uint64_t NowInMillSec() {
+ port::TimeVal tv;
+ port::GetTimeOfDay(&tv, /*tz=*/nullptr);
+ return tv.tv_sec * 1000 + tv.tv_usec / 1000;
+ }
+
+ //
+ // Wrapper functions for thread entry
+ //
+ static void WriteMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunWrite();
+ }
+
+ static void ReadMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunRead();
+ }
+
+ static void EraseMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunErase();
+ }
+
+ HashTableImpl<size_t, std::string>* impl_; // Implementation to test
+ const size_t sec_; // Test time
+ const size_t max_prepop_key = 1ULL * 1024 * 1024; // Max prepop key
+ std::atomic<size_t> insert_key_; // Last inserted key
+ std::atomic<size_t> erase_key_; // Erase key
+ std::atomic<size_t> ninserts_; // Number of inserts
+ std::atomic<size_t> nreads_; // Number of reads
+ std::atomic<size_t> nerases_; // Number of erases
+ std::atomic<size_t> nerases_failed_; // Number of erases failed
+ bool quit_; // Should the threads quit ?
+};
+
+//
+// SimpleImpl
+// Lock safe unordered_map implementation
+class SimpleImpl : public HashTableImpl<size_t, string> {
+ public:
+ bool Insert(const size_t& key, const string& val) override {
+ WriteLock _(&rwlock_);
+ map_.insert(make_pair(key, val));
+ return true;
+ }
+
+ bool Erase(const size_t& key) override {
+ WriteLock _(&rwlock_);
+ auto it = map_.find(key);
+ if (it == map_.end()) {
+ return false;
+ }
+ map_.erase(it);
+ return true;
+ }
+
+ bool Lookup(const size_t& key, string* val) override {
+ ReadLock _(&rwlock_);
+ auto it = map_.find(key);
+ if (it != map_.end()) {
+ *val = it->second;
+ }
+ return it != map_.end();
+ }
+
+ private:
+ port::RWMutex rwlock_;
+ std::unordered_map<size_t, string> map_;
+};
+
+//
+// GranularLockImpl
+// Thread safe custom RocksDB implementation of hash table with granular
+// locking
+class GranularLockImpl : public HashTableImpl<size_t, string> {
+ public:
+ bool Insert(const size_t& key, const string& val) override {
+ Node n(key, val);
+ return impl_.Insert(n);
+ }
+
+ bool Erase(const size_t& key) override {
+ Node n(key, string());
+ return impl_.Erase(n, nullptr);
+ }
+
+ bool Lookup(const size_t& key, string* val) override {
+ Node n(key, string());
+ port::RWMutex* rlock;
+ bool status = impl_.Find(n, &n, &rlock);
+ if (status) {
+ ReadUnlock _(rlock);
+ *val = n.val_;
+ }
+ return status;
+ }
+
+ private:
+ struct Node {
+ explicit Node(const size_t key, const string& val) : key_(key), val_(val) {}
+
+ size_t key_ = 0;
+ string val_;
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node& node) {
+ return std::hash<uint64_t>()(node.key_);
+ }
+ };
+
+ struct Equal {
+ bool operator()(const Node& lhs, const Node& rhs) {
+ return lhs.key_ == rhs.key_;
+ }
+ };
+
+ HashTable<Node, Hash, Equal> impl_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+//
+// main
+//
+int main(int argc, char** argv) {
+ GFLAGS_NAMESPACE::SetUsageMessage(std::string("\nUSAGE:\n") +
+ std::string(argv[0]) + " [OPTIONS]...");
+ GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, false);
+
+ //
+ // Micro benchmark unordered_map
+ //
+ printf("Micro benchmarking std::unordered_map \n");
+ {
+ ROCKSDB_NAMESPACE::SimpleImpl impl;
+ ROCKSDB_NAMESPACE::HashTableBenchmark _(
+ &impl, FLAGS_nsec, FLAGS_nthread_write, FLAGS_nthread_read,
+ FLAGS_nthread_erase);
+ }
+ //
+ // Micro benchmark scalable hash table
+ //
+ printf("Micro benchmarking scalable hash map \n");
+ {
+ ROCKSDB_NAMESPACE::GranularLockImpl impl;
+ ROCKSDB_NAMESPACE::HashTableBenchmark _(
+ &impl, FLAGS_nsec, FLAGS_nthread_write, FLAGS_nthread_read,
+ FLAGS_nthread_erase);
+ }
+
+ return 0;
+}
+#endif // #ifndef GFLAGS
+#else
+int main(int /*argc*/, char** /*argv*/) { return 0; }
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h b/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h
new file mode 100644
index 000000000..e10939b2f
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h
@@ -0,0 +1,168 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/lrulist.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Evictable Hash Table
+//
+// Hash table index where least accessed (or one of the least accessed) elements
+// can be evicted.
+//
+// Please note EvictableHashTable can only be created for pointer type objects
+template <class T, class Hash, class Equal>
+class EvictableHashTable : private HashTable<T*, Hash, Equal> {
+ public:
+ using hash_table = HashTable<T*, Hash, Equal>;
+
+ explicit EvictableHashTable(const size_t capacity = 1024 * 1024,
+ const float load_factor = 2.0,
+ const uint32_t nlocks = 256)
+ : HashTable<T*, Hash, Equal>(capacity, load_factor, nlocks),
+ lru_lists_(new LRUList<T>[hash_table::nlocks_]) {
+ assert(lru_lists_);
+ }
+
+ virtual ~EvictableHashTable() { AssertEmptyLRU(); }
+
+ //
+ // Insert given record to hash table (and LRU list)
+ //
+ bool Insert(T* t) {
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ LRUListType& lru = GetLRUList(h);
+ port::RWMutex& lock = GetMutex(h);
+
+ WriteLock _(&lock);
+ if (hash_table::Insert(&bucket, t)) {
+ lru.Push(t);
+ return true;
+ }
+ return false;
+ }
+
+ //
+ // Lookup hash table
+ //
+ // Please note that read lock should be held by the caller. This is because
+ // the caller owns the data, and should hold the read lock as long as he
+ // operates on the data.
+ bool Find(T* t, T** ret) {
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ LRUListType& lru = GetLRUList(h);
+ port::RWMutex& lock = GetMutex(h);
+
+ ReadLock _(&lock);
+ if (hash_table::Find(&bucket, t, ret)) {
+ ++(*ret)->refs_;
+ lru.Touch(*ret);
+ return true;
+ }
+ return false;
+ }
+
+ //
+ // Evict one of the least recently used object
+ //
+ T* Evict(const std::function<void(T*)>& fn = nullptr) {
+ uint32_t random = Random::GetTLSInstance()->Next();
+ const size_t start_idx = random % hash_table::nlocks_;
+ T* t = nullptr;
+
+ // iterate from start_idx .. 0 .. start_idx
+ for (size_t i = 0; !t && i < hash_table::nlocks_; ++i) {
+ const size_t idx = (start_idx + i) % hash_table::nlocks_;
+
+ WriteLock _(&hash_table::locks_[idx]);
+ LRUListType& lru = lru_lists_[idx];
+ if (!lru.IsEmpty() && (t = lru.Pop()) != nullptr) {
+ assert(!t->refs_);
+ // We got an item to evict, erase from the bucket
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ T* tmp = nullptr;
+ bool status = hash_table::Erase(&bucket, t, &tmp);
+ assert(t == tmp);
+ (void)status;
+ assert(status);
+ if (fn) {
+ fn(t);
+ }
+ break;
+ }
+ assert(!t);
+ }
+ return t;
+ }
+
+ void Clear(void (*fn)(T*)) {
+ for (uint32_t i = 0; i < hash_table::nbuckets_; ++i) {
+ const uint32_t lock_idx = i % hash_table::nlocks_;
+ WriteLock _(&hash_table::locks_[lock_idx]);
+ auto& lru_list = lru_lists_[lock_idx];
+ auto& bucket = hash_table::buckets_[i];
+ for (auto* t : bucket.list_) {
+ lru_list.Unlink(t);
+ (*fn)(t);
+ }
+ bucket.list_.clear();
+ }
+ // make sure that all LRU lists are emptied
+ AssertEmptyLRU();
+ }
+
+ void AssertEmptyLRU() {
+#ifndef NDEBUG
+ for (uint32_t i = 0; i < hash_table::nlocks_; ++i) {
+ WriteLock _(&hash_table::locks_[i]);
+ auto& lru_list = lru_lists_[i];
+ assert(lru_list.IsEmpty());
+ }
+#endif
+ }
+
+ //
+ // Fetch the mutex associated with a key
+ // This call is used to hold the lock for a given data for extended period of
+ // time.
+ port::RWMutex* GetMutex(T* t) { return hash_table::GetMutex(t); }
+
+ private:
+ using LRUListType = LRUList<T>;
+
+ typename hash_table::Bucket& GetBucket(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ return hash_table::buckets_[bucket_idx];
+ }
+
+ LRUListType& GetLRUList(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ const uint32_t lock_idx = bucket_idx % hash_table::nlocks_;
+ return lru_lists_[lock_idx];
+ }
+
+ port::RWMutex& GetMutex(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ const uint32_t lock_idx = bucket_idx % hash_table::nlocks_;
+ return hash_table::locks_[lock_idx];
+ }
+
+ std::unique_ptr<LRUListType[]> lru_lists_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_test.cc b/src/rocksdb/utilities/persistent_cache/hash_table_test.cc
new file mode 100644
index 000000000..2f6387f5f
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_test.cc
@@ -0,0 +1,163 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include "utilities/persistent_cache/hash_table.h"
+
+#include <stdlib.h>
+
+#include <iostream>
+#include <set>
+#include <string>
+
+#include "db/db_test_util.h"
+#include "memory/arena.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+
+#ifndef ROCKSDB_LITE
+
+namespace ROCKSDB_NAMESPACE {
+
+struct HashTableTest : public testing::Test {
+ ~HashTableTest() override { map_.Clear(&HashTableTest::ClearNode); }
+
+ struct Node {
+ Node() {}
+ explicit Node(const uint64_t key, const std::string& val = std::string())
+ : key_(key), val_(val) {}
+
+ uint64_t key_ = 0;
+ std::string val_;
+ };
+
+ struct Equal {
+ bool operator()(const Node& lhs, const Node& rhs) {
+ return lhs.key_ == rhs.key_;
+ }
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node& node) {
+ return std::hash<uint64_t>()(node.key_);
+ }
+ };
+
+ static void ClearNode(Node /*node*/) {}
+
+ HashTable<Node, Hash, Equal> map_;
+};
+
+struct EvictableHashTableTest : public testing::Test {
+ ~EvictableHashTableTest() override {
+ map_.Clear(&EvictableHashTableTest::ClearNode);
+ }
+
+ struct Node : LRUElement<Node> {
+ Node() {}
+ explicit Node(const uint64_t key, const std::string& val = std::string())
+ : key_(key), val_(val) {}
+
+ uint64_t key_ = 0;
+ std::string val_;
+ std::atomic<uint32_t> refs_{0};
+ };
+
+ struct Equal {
+ bool operator()(const Node* lhs, const Node* rhs) {
+ return lhs->key_ == rhs->key_;
+ }
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node* node) {
+ return std::hash<uint64_t>()(node->key_);
+ }
+ };
+
+ static void ClearNode(Node* /*node*/) {}
+
+ EvictableHashTable<Node, Hash, Equal> map_;
+};
+
+TEST_F(HashTableTest, TestInsert) {
+ const uint64_t max_keys = 1024 * 1024;
+
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(Node(k, std::string(1000, k % 255)));
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node val;
+ port::RWMutex* rlock = nullptr;
+ assert(map_.Find(Node(k), &val, &rlock));
+ rlock->ReadUnlock();
+ assert(val.val_ == std::string(1000, k % 255));
+ }
+}
+
+TEST_F(HashTableTest, TestErase) {
+ const uint64_t max_keys = 1024 * 1024;
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(Node(k, std::string(1000, k % 255)));
+ }
+
+ auto rand = Random64(time(nullptr));
+ // erase a few keys randomly
+ std::set<uint64_t> erased;
+ for (int i = 0; i < 1024; ++i) {
+ uint64_t k = rand.Next() % max_keys;
+ if (erased.find(k) != erased.end()) {
+ continue;
+ }
+ assert(map_.Erase(Node(k), /*ret=*/nullptr));
+ erased.insert(k);
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node val;
+ port::RWMutex* rlock = nullptr;
+ bool status = map_.Find(Node(k), &val, &rlock);
+ if (erased.find(k) == erased.end()) {
+ assert(status);
+ rlock->ReadUnlock();
+ assert(val.val_ == std::string(1000, k % 255));
+ } else {
+ assert(!status);
+ }
+ }
+}
+
+TEST_F(EvictableHashTableTest, TestEvict) {
+ const uint64_t max_keys = 1024 * 1024;
+
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(new Node(k, std::string(1000, k % 255)));
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node* val = map_.Evict();
+ // unfortunately we can't predict eviction value since it is from any one of
+ // the lock stripe
+ assert(val);
+ assert(val->val_ == std::string(1000, val->key_ % 255));
+ delete val;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/persistent_cache/lrulist.h b/src/rocksdb/utilities/persistent_cache/lrulist.h
new file mode 100644
index 000000000..a608890fc
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/lrulist.h
@@ -0,0 +1,174 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// LRU element definition
+//
+// Any object that needs to be part of the LRU algorithm should extend this
+// class
+template <class T>
+struct LRUElement {
+ explicit LRUElement() : next_(nullptr), prev_(nullptr), refs_(0) {}
+
+ virtual ~LRUElement() { assert(!refs_); }
+
+ T* next_;
+ T* prev_;
+ std::atomic<size_t> refs_;
+};
+
+// LRU implementation
+//
+// In place LRU implementation. There is no copy or allocation involved when
+// inserting or removing an element. This makes the data structure slim
+template <class T>
+class LRUList {
+ public:
+ virtual ~LRUList() {
+ MutexLock _(&lock_);
+ assert(!head_);
+ assert(!tail_);
+ }
+
+ // Push element into the LRU at the cold end
+ inline void Push(T* const t) {
+ assert(t);
+ assert(!t->next_);
+ assert(!t->prev_);
+
+ MutexLock _(&lock_);
+
+ assert((!head_ && !tail_) || (head_ && tail_));
+ assert(!head_ || !head_->prev_);
+ assert(!tail_ || !tail_->next_);
+
+ t->next_ = head_;
+ if (head_) {
+ head_->prev_ = t;
+ }
+
+ head_ = t;
+ if (!tail_) {
+ tail_ = t;
+ }
+ }
+
+ // Unlink the element from the LRU
+ inline void Unlink(T* const t) {
+ MutexLock _(&lock_);
+ UnlinkImpl(t);
+ }
+
+ // Evict an element from the LRU
+ inline T* Pop() {
+ MutexLock _(&lock_);
+
+ assert(tail_ && head_);
+ assert(!tail_->next_);
+ assert(!head_->prev_);
+
+ T* t = head_;
+ while (t && t->refs_) {
+ t = t->next_;
+ }
+
+ if (!t) {
+ // nothing can be evicted
+ return nullptr;
+ }
+
+ assert(!t->refs_);
+
+ // unlike the element
+ UnlinkImpl(t);
+ return t;
+ }
+
+ // Move the element from the front of the list to the back of the list
+ inline void Touch(T* const t) {
+ MutexLock _(&lock_);
+ UnlinkImpl(t);
+ PushBackImpl(t);
+ }
+
+ // Check if the LRU is empty
+ inline bool IsEmpty() const {
+ MutexLock _(&lock_);
+ return !head_ && !tail_;
+ }
+
+ private:
+ // Unlink an element from the LRU
+ void UnlinkImpl(T* const t) {
+ assert(t);
+
+ lock_.AssertHeld();
+
+ assert(head_ && tail_);
+ assert(t->prev_ || head_ == t);
+ assert(t->next_ || tail_ == t);
+
+ if (t->prev_) {
+ t->prev_->next_ = t->next_;
+ }
+ if (t->next_) {
+ t->next_->prev_ = t->prev_;
+ }
+
+ if (tail_ == t) {
+ tail_ = tail_->prev_;
+ }
+ if (head_ == t) {
+ head_ = head_->next_;
+ }
+
+ t->next_ = t->prev_ = nullptr;
+ }
+
+ // Insert an element at the hot end
+ inline void PushBack(T* const t) {
+ MutexLock _(&lock_);
+ PushBackImpl(t);
+ }
+
+ inline void PushBackImpl(T* const t) {
+ assert(t);
+ assert(!t->next_);
+ assert(!t->prev_);
+
+ lock_.AssertHeld();
+
+ assert((!head_ && !tail_) || (head_ && tail_));
+ assert(!head_ || !head_->prev_);
+ assert(!tail_ || !tail_->next_);
+
+ t->prev_ = tail_;
+ if (tail_) {
+ tail_->next_ = t;
+ }
+
+ tail_ = t;
+ if (!head_) {
+ head_ = tail_;
+ }
+ }
+
+ mutable port::Mutex lock_; // synchronization primitive
+ T* head_ = nullptr; // front (cold)
+ T* tail_ = nullptr; // back (hot)
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc
new file mode 100644
index 000000000..9d6e15d6b
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc
@@ -0,0 +1,359 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() { fprintf(stderr, "Please install gflags to run tools\n"); }
+#else
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <sstream>
+#include <unordered_map>
+
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "rocksdb/system_clock.h"
+#include "table/block_based/block_builder.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/stop_watch.h"
+#include "utilities/persistent_cache/block_cache_tier.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+DEFINE_int32(nsec, 10, "nsec");
+DEFINE_int32(nthread_write, 1, "Insert threads");
+DEFINE_int32(nthread_read, 1, "Lookup threads");
+DEFINE_string(path, "/tmp/microbench/blkcache", "Path for cachefile");
+DEFINE_string(log_path, "/tmp/log", "Path for the log file");
+DEFINE_uint64(cache_size, std::numeric_limits<uint64_t>::max(), "Cache size");
+DEFINE_int32(iosize, 4 * 1024, "Read IO size");
+DEFINE_int32(writer_iosize, 4 * 1024, "File writer IO size");
+DEFINE_int32(writer_qdepth, 1, "File writer qdepth");
+DEFINE_bool(enable_pipelined_writes, false, "Enable async writes");
+DEFINE_string(cache_type, "block_cache",
+ "Cache type. (block_cache, volatile, tiered)");
+DEFINE_bool(benchmark, false, "Benchmark mode");
+DEFINE_int32(volatile_cache_pct, 10, "Percentage of cache in memory tier.");
+
+namespace ROCKSDB_NAMESPACE {
+
+std::unique_ptr<PersistentCacheTier> NewVolatileCache() {
+ assert(FLAGS_cache_size != std::numeric_limits<uint64_t>::max());
+ std::unique_ptr<PersistentCacheTier> pcache(
+ new VolatileCacheTier(FLAGS_cache_size));
+ return pcache;
+}
+
+std::unique_ptr<PersistentCacheTier> NewBlockCache() {
+ std::shared_ptr<Logger> log;
+ if (!Env::Default()->NewLogger(FLAGS_log_path, &log).ok()) {
+ fprintf(stderr, "Error creating log %s \n", FLAGS_log_path.c_str());
+ return nullptr;
+ }
+
+ PersistentCacheConfig opt(Env::Default(), FLAGS_path, FLAGS_cache_size, log);
+ opt.writer_dispatch_size = FLAGS_writer_iosize;
+ opt.writer_qdepth = FLAGS_writer_qdepth;
+ opt.pipeline_writes = FLAGS_enable_pipelined_writes;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ std::unique_ptr<PersistentCacheTier> cache(new BlockCacheTier(opt));
+ Status status = cache->Open();
+ return cache;
+}
+
+// create a new cache tier
+// construct a tiered RAM+Block cache
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ const size_t mem_size, const PersistentCacheConfig& opt) {
+ std::unique_ptr<PersistentTieredCache> tcache(new PersistentTieredCache());
+ // create primary tier
+ assert(mem_size);
+ auto pcache =
+ std::shared_ptr<PersistentCacheTier>(new VolatileCacheTier(mem_size));
+ tcache->AddTier(pcache);
+ // create secondary tier
+ auto scache = std::shared_ptr<PersistentCacheTier>(new BlockCacheTier(opt));
+ tcache->AddTier(scache);
+
+ Status s = tcache->Open();
+ assert(s.ok());
+ return tcache;
+}
+
+std::unique_ptr<PersistentTieredCache> NewTieredCache() {
+ std::shared_ptr<Logger> log;
+ if (!Env::Default()->NewLogger(FLAGS_log_path, &log).ok()) {
+ fprintf(stderr, "Error creating log %s \n", FLAGS_log_path.c_str());
+ abort();
+ }
+
+ auto pct = FLAGS_volatile_cache_pct / static_cast<double>(100);
+ PersistentCacheConfig opt(Env::Default(), FLAGS_path,
+ (1 - pct) * FLAGS_cache_size, log);
+ opt.writer_dispatch_size = FLAGS_writer_iosize;
+ opt.writer_qdepth = FLAGS_writer_qdepth;
+ opt.pipeline_writes = FLAGS_enable_pipelined_writes;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ return NewTieredCache(FLAGS_cache_size * pct, opt);
+}
+
+//
+// Benchmark driver
+//
+class CacheTierBenchmark {
+ public:
+ explicit CacheTierBenchmark(std::shared_ptr<PersistentCacheTier>&& cache)
+ : cache_(cache) {
+ if (FLAGS_nthread_read) {
+ fprintf(stdout, "Pre-populating\n");
+ Prepop();
+ fprintf(stdout, "Pre-population completed\n");
+ }
+
+ stats_.Clear();
+
+ // Start IO threads
+ std::list<port::Thread> threads;
+ Spawn(FLAGS_nthread_write, &threads,
+ std::bind(&CacheTierBenchmark::Write, this));
+ Spawn(FLAGS_nthread_read, &threads,
+ std::bind(&CacheTierBenchmark::Read, this));
+
+ // Wait till FLAGS_nsec and then signal to quit
+ StopWatchNano t(SystemClock::Default().get(), /*auto_start=*/true);
+ size_t sec = t.ElapsedNanos() / 1000000000ULL;
+ while (!quit_) {
+ sec = t.ElapsedNanos() / 1000000000ULL;
+ quit_ = sec > size_t(FLAGS_nsec);
+ /* sleep override */ sleep(1);
+ }
+
+ // Wait for threads to exit
+ Join(&threads);
+ // Print stats
+ PrintStats(sec);
+ // Close the cache
+ cache_->TEST_Flush();
+ cache_->Close();
+ }
+
+ private:
+ void PrintStats(const size_t sec) {
+ std::ostringstream msg;
+ msg << "Test stats" << std::endl
+ << "* Elapsed: " << sec << " s" << std::endl
+ << "* Write Latency:" << std::endl
+ << stats_.write_latency_.ToString() << std::endl
+ << "* Read Latency:" << std::endl
+ << stats_.read_latency_.ToString() << std::endl
+ << "* Bytes written:" << std::endl
+ << stats_.bytes_written_.ToString() << std::endl
+ << "* Bytes read:" << std::endl
+ << stats_.bytes_read_.ToString() << std::endl
+ << "Cache stats:" << std::endl
+ << cache_->PrintStats() << std::endl;
+ fprintf(stderr, "%s\n", msg.str().c_str());
+ }
+
+ //
+ // Insert implementation and corresponding helper functions
+ //
+ void Prepop() {
+ for (uint64_t i = 0; i < 1024 * 1024; ++i) {
+ InsertKey(i);
+ insert_key_limit_++;
+ read_key_limit_++;
+ }
+
+ // Wait until data is flushed
+ cache_->TEST_Flush();
+ // warmup the cache
+ for (uint64_t i = 0; i < 1024 * 1024; ReadKey(i++)) {
+ }
+ }
+
+ void Write() {
+ while (!quit_) {
+ InsertKey(insert_key_limit_++);
+ }
+ }
+
+ void InsertKey(const uint64_t key) {
+ // construct key
+ uint64_t k[3];
+ Slice block_key = FillKey(k, key);
+
+ // construct value
+ auto block = NewBlock(key);
+
+ // insert
+ StopWatchNano timer(SystemClock::Default().get(), /*auto_start=*/true);
+ while (true) {
+ Status status = cache_->Insert(block_key, block.get(), FLAGS_iosize);
+ if (status.ok()) {
+ break;
+ }
+
+ // transient error is possible if we run without pipelining
+ assert(!FLAGS_enable_pipelined_writes);
+ }
+
+ // adjust stats
+ const size_t elapsed_micro = timer.ElapsedNanos() / 1000;
+ stats_.write_latency_.Add(elapsed_micro);
+ stats_.bytes_written_.Add(FLAGS_iosize);
+ }
+
+ //
+ // Read implementation
+ //
+ void Read() {
+ while (!quit_) {
+ ReadKey(random() % read_key_limit_);
+ }
+ }
+
+ void ReadKey(const uint64_t val) {
+ // construct key
+ uint64_t k[3];
+ Slice key = FillKey(k, val);
+
+ // Lookup in cache
+ StopWatchNano timer(SystemClock::Default().get(), /*auto_start=*/true);
+ std::unique_ptr<char[]> block;
+ size_t size;
+ Status status = cache_->Lookup(key, &block, &size);
+ if (!status.ok()) {
+ fprintf(stderr, "%s\n", status.ToString().c_str());
+ }
+ assert(status.ok());
+ assert(size == (size_t)FLAGS_iosize);
+
+ // adjust stats
+ const size_t elapsed_micro = timer.ElapsedNanos() / 1000;
+ stats_.read_latency_.Add(elapsed_micro);
+ stats_.bytes_read_.Add(FLAGS_iosize);
+
+ // verify content
+ if (!FLAGS_benchmark) {
+ auto expected_block = NewBlock(val);
+ assert(memcmp(block.get(), expected_block.get(), FLAGS_iosize) == 0);
+ }
+ }
+
+ // create data for a key by filling with a certain pattern
+ std::unique_ptr<char[]> NewBlock(const uint64_t val) {
+ std::unique_ptr<char[]> data(new char[FLAGS_iosize]);
+ memset(data.get(), val % 255, FLAGS_iosize);
+ return data;
+ }
+
+ // spawn threads
+ void Spawn(const size_t n, std::list<port::Thread>* threads,
+ const std::function<void()>& fn) {
+ for (size_t i = 0; i < n; ++i) {
+ threads->emplace_back(fn);
+ }
+ }
+
+ // join threads
+ void Join(std::list<port::Thread>* threads) {
+ for (auto& th : *threads) {
+ th.join();
+ }
+ }
+
+ // construct key
+ Slice FillKey(uint64_t (&k)[3], const uint64_t val) {
+ k[0] = k[1] = 0;
+ k[2] = val;
+ void* p = static_cast<void*>(&k);
+ return Slice(static_cast<char*>(p), sizeof(k));
+ }
+
+ // benchmark stats
+ struct Stats {
+ void Clear() {
+ bytes_written_.Clear();
+ bytes_read_.Clear();
+ read_latency_.Clear();
+ write_latency_.Clear();
+ }
+
+ HistogramImpl bytes_written_;
+ HistogramImpl bytes_read_;
+ HistogramImpl read_latency_;
+ HistogramImpl write_latency_;
+ };
+
+ std::shared_ptr<PersistentCacheTier> cache_; // cache implementation
+ std::atomic<uint64_t> insert_key_limit_{0}; // data inserted upto
+ std::atomic<uint64_t> read_key_limit_{0}; // data can be read safely upto
+ bool quit_ = false; // Quit thread ?
+ mutable Stats stats_; // Stats
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+//
+// main
+//
+int main(int argc, char** argv) {
+ GFLAGS_NAMESPACE::SetUsageMessage(std::string("\nUSAGE:\n") +
+ std::string(argv[0]) + " [OPTIONS]...");
+ GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, false);
+
+ std::ostringstream msg;
+ msg << "Config" << std::endl
+ << "======" << std::endl
+ << "* nsec=" << FLAGS_nsec << std::endl
+ << "* nthread_write=" << FLAGS_nthread_write << std::endl
+ << "* path=" << FLAGS_path << std::endl
+ << "* cache_size=" << FLAGS_cache_size << std::endl
+ << "* iosize=" << FLAGS_iosize << std::endl
+ << "* writer_iosize=" << FLAGS_writer_iosize << std::endl
+ << "* writer_qdepth=" << FLAGS_writer_qdepth << std::endl
+ << "* enable_pipelined_writes=" << FLAGS_enable_pipelined_writes
+ << std::endl
+ << "* cache_type=" << FLAGS_cache_type << std::endl
+ << "* benchmark=" << FLAGS_benchmark << std::endl
+ << "* volatile_cache_pct=" << FLAGS_volatile_cache_pct << std::endl;
+
+ fprintf(stderr, "%s\n", msg.str().c_str());
+
+ std::shared_ptr<ROCKSDB_NAMESPACE::PersistentCacheTier> cache;
+ if (FLAGS_cache_type == "block_cache") {
+ fprintf(stderr, "Using block cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewBlockCache();
+ } else if (FLAGS_cache_type == "volatile") {
+ fprintf(stderr, "Using volatile cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewVolatileCache();
+ } else if (FLAGS_cache_type == "tiered") {
+ fprintf(stderr, "Using tiered cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewTieredCache();
+ } else {
+ fprintf(stderr, "Unknown option for cache\n");
+ }
+
+ assert(cache);
+ if (!cache) {
+ fprintf(stderr, "Error creating cache\n");
+ abort();
+ }
+
+ std::unique_ptr<ROCKSDB_NAMESPACE::CacheTierBenchmark> benchmark(
+ new ROCKSDB_NAMESPACE::CacheTierBenchmark(std::move(cache)));
+
+ return 0;
+}
+#endif // #ifndef GFLAGS
+#else
+int main(int, char**) { return 0; }
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc
new file mode 100644
index 000000000..d1b18b68a
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc
@@ -0,0 +1,462 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#if !defined ROCKSDB_LITE
+
+#include "utilities/persistent_cache/persistent_cache_test.h"
+
+#include <functional>
+#include <memory>
+#include <thread>
+
+#include "file/file_util.h"
+#include "utilities/persistent_cache/block_cache_tier.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+static const double kStressFactor = .125;
+
+#ifdef OS_LINUX
+static void OnOpenForRead(void* arg) {
+ int* val = static_cast<int*>(arg);
+ *val &= ~O_DIRECT;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT",
+ std::bind(OnOpenForRead, std::placeholders::_1));
+}
+
+static void OnOpenForWrite(void* arg) {
+ int* val = static_cast<int*>(arg);
+ *val &= ~O_DIRECT;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewWritableFile:O_DIRECT",
+ std::bind(OnOpenForWrite, std::placeholders::_1));
+}
+#endif
+
+static void OnDeleteDir(void* arg) {
+ char* dir = static_cast<char*>(arg);
+ ASSERT_OK(DestroyDir(Env::Default(), std::string(dir)));
+}
+
+//
+// Simple logger that prints message on stdout
+//
+class ConsoleLogger : public Logger {
+ public:
+ using Logger::Logv;
+ ConsoleLogger() : Logger(InfoLogLevel::ERROR_LEVEL) {}
+
+ void Logv(const char* format, va_list ap) override {
+ MutexLock _(&lock_);
+ vprintf(format, ap);
+ printf("\n");
+ }
+
+ port::Mutex lock_;
+};
+
+// construct a tiered RAM+Block cache
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ const size_t mem_size, const PersistentCacheConfig& opt) {
+ std::unique_ptr<PersistentTieredCache> tcache(new PersistentTieredCache());
+ // create primary tier
+ assert(mem_size);
+ auto pcache = std::shared_ptr<PersistentCacheTier>(new VolatileCacheTier(
+ /*is_compressed*/ true, mem_size));
+ tcache->AddTier(pcache);
+ // create secondary tier
+ auto scache = std::shared_ptr<PersistentCacheTier>(new BlockCacheTier(opt));
+ tcache->AddTier(scache);
+
+ Status s = tcache->Open();
+ assert(s.ok());
+ return tcache;
+}
+
+// create block cache
+std::unique_ptr<PersistentCacheTier> NewBlockCache(
+ Env* env, const std::string& path,
+ const uint64_t max_size = std::numeric_limits<uint64_t>::max(),
+ const bool enable_direct_writes = false) {
+ const uint32_t max_file_size =
+ static_cast<uint32_t>(12 * 1024 * 1024 * kStressFactor);
+ auto log = std::make_shared<ConsoleLogger>();
+ PersistentCacheConfig opt(env, path, max_size, log);
+ opt.cache_file_size = max_file_size;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ opt.enable_direct_writes = enable_direct_writes;
+ std::unique_ptr<PersistentCacheTier> scache(new BlockCacheTier(opt));
+ Status s = scache->Open();
+ assert(s.ok());
+ return scache;
+}
+
+// create a new cache tier
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ Env* env, const std::string& path, const uint64_t max_volatile_cache_size,
+ const uint64_t max_block_cache_size =
+ std::numeric_limits<uint64_t>::max()) {
+ const uint32_t max_file_size =
+ static_cast<uint32_t>(12 * 1024 * 1024 * kStressFactor);
+ auto log = std::make_shared<ConsoleLogger>();
+ auto opt = PersistentCacheConfig(env, path, max_block_cache_size, log);
+ opt.cache_file_size = max_file_size;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ // create tier out of the two caches
+ auto cache = NewTieredCache(max_volatile_cache_size, opt);
+ return cache;
+}
+
+PersistentCacheTierTest::PersistentCacheTierTest()
+ : path_(test::PerThreadDBPath("cache_test")) {
+#ifdef OS_LINUX
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT", OnOpenForRead);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewWritableFile:O_DIRECT", OnOpenForWrite);
+#endif
+}
+
+// Block cache tests
+TEST_F(PersistentCacheTierTest, DISABLED_BlockCacheInsertWithFileCreateError) {
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ /*direct_writes=*/false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "BlockCacheTier::NewCacheFile:DeleteDir", OnDeleteDir);
+
+ RunNegativeInsertTest(/*nthreads=*/1,
+ /*max_keys*/
+ static_cast<size_t>(10 * 1024 * kStressFactor));
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Travis is unable to handle the normal version of the tests running out of
+// fds, out of space and timeouts. This is an easier version of the test
+// specifically written for Travis
+TEST_F(PersistentCacheTierTest, DISABLED_BasicTest) {
+ cache_ = std::make_shared<VolatileCacheTier>();
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ /*direct_writes=*/true);
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+
+ cache_ = NewTieredCache(Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024));
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+}
+
+// Volatile cache tests
+// DISABLED for now (somewhat expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_VolatileCacheInsert) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = std::make_shared<VolatileCacheTier>();
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// DISABLED for now (somewhat expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_VolatileCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = std::make_shared<VolatileCacheTier>(
+ /*compressed=*/true,
+ /*size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// Block cache tests
+// DISABLED for now (expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_BlockCacheInsert) {
+ for (auto direct_writes : {true, false}) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ direct_writes);
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+ }
+}
+
+// DISABLED for now (somewhat expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_BlockCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewBlockCache(
+ Env::Default(), path_,
+ /*max_size=*/static_cast<size_t>(200 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// Tiered cache tests
+// DISABLED for now (expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_TieredCacheInsert) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewTieredCache(
+ Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor));
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// the tests causes a lot of file deletions which Travis limited testing
+// environment cannot handle
+// DISABLED for now (somewhat expensive)
+TEST_F(PersistentCacheTierTest, DISABLED_TieredCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewTieredCache(
+ Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor),
+ /*block_cache_size*/
+ static_cast<size_t>(200 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+std::shared_ptr<PersistentCacheTier> MakeVolatileCache(
+ Env* /*env*/, const std::string& /*dbname*/) {
+ return std::make_shared<VolatileCacheTier>();
+}
+
+std::shared_ptr<PersistentCacheTier> MakeBlockCache(Env* env,
+ const std::string& dbname) {
+ return NewBlockCache(env, dbname);
+}
+
+std::shared_ptr<PersistentCacheTier> MakeTieredCache(
+ Env* env, const std::string& dbname) {
+ const auto memory_size = 1 * 1024 * 1024 * kStressFactor;
+ return NewTieredCache(env, dbname, static_cast<size_t>(memory_size));
+}
+
+#ifdef OS_LINUX
+static void UniqueIdCallback(void* arg) {
+ int* result = reinterpret_cast<int*>(arg);
+ if (*result == -1) {
+ *result = 0;
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearTrace();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GetUniqueIdFromFile:FS_IOC_GETVERSION", UniqueIdCallback);
+}
+#endif
+
+TEST_F(PersistentCacheTierTest, FactoryTest) {
+ for (auto nvm_opt : {true, false}) {
+ ASSERT_FALSE(cache_);
+ auto log = std::make_shared<ConsoleLogger>();
+ std::shared_ptr<PersistentCache> cache;
+ ASSERT_OK(NewPersistentCache(Env::Default(), path_,
+ /*size=*/1 * 1024 * 1024 * 1024, log, nvm_opt,
+ &cache));
+ ASSERT_TRUE(cache);
+ ASSERT_EQ(cache->Stats().size(), 1);
+ ASSERT_TRUE(cache->Stats()[0].size());
+ cache.reset();
+ }
+}
+
+PersistentCacheDBTest::PersistentCacheDBTest()
+ : DBTestBase("cache_test", /*env_do_fsync=*/true) {
+#ifdef OS_LINUX
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GetUniqueIdFromFile:FS_IOC_GETVERSION", UniqueIdCallback);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT", OnOpenForRead);
+#endif
+}
+
+// test template
+void PersistentCacheDBTest::RunTest(
+ const std::function<std::shared_ptr<PersistentCacheTier>(bool)>& new_pcache,
+ const size_t max_keys = 100 * 1024, const size_t max_usecase = 5) {
+ // number of insertion interations
+ int num_iter = static_cast<int>(max_keys * kStressFactor);
+
+ for (size_t iter = 0; iter < max_usecase; iter++) {
+ Options options;
+ options.write_buffer_size =
+ static_cast<size_t>(64 * 1024 * kStressFactor); // small write buffer
+ options.statistics = ROCKSDB_NAMESPACE::CreateDBStatistics();
+ options = CurrentOptions(options);
+
+ // setup page cache
+ std::shared_ptr<PersistentCacheTier> pcache;
+ BlockBasedTableOptions table_options;
+ table_options.cache_index_and_filter_blocks = true;
+
+ const size_t size_max = std::numeric_limits<size_t>::max();
+
+ switch (iter) {
+ case 0:
+ // page cache, block cache, no-compressed cache
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 1:
+ // page cache, block cache, compressed cache
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = NewLRUCache(size_max);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 2:
+ // page cache, block cache, compressed cache + KNoCompression
+ // both block cache and compressed cache, but DB is not compressed
+ // also, make block cache sizes bigger, to trigger block cache hits
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = NewLRUCache(size_max);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ options.compression = kNoCompression;
+ break;
+ case 3:
+ // page cache, no block cache, no compressed cache
+ pcache = new_pcache(/*is_compressed=*/false);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = nullptr;
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 4:
+ // page cache, no block cache, no compressed cache
+ // Page cache caches compressed blocks
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = nullptr;
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ default:
+ FAIL();
+ }
+
+ std::vector<std::string> values;
+ // insert data
+ Insert(options, table_options, num_iter, &values);
+ // flush all data in cache to device
+ pcache->TEST_Flush();
+ // verify data
+ Verify(num_iter, values);
+
+ auto block_miss = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ auto compressed_block_hit =
+ TestGetTickerCount(options, BLOCK_CACHE_COMPRESSED_HIT);
+ auto compressed_block_miss =
+ TestGetTickerCount(options, BLOCK_CACHE_COMPRESSED_MISS);
+ auto page_hit = TestGetTickerCount(options, PERSISTENT_CACHE_HIT);
+ auto page_miss = TestGetTickerCount(options, PERSISTENT_CACHE_MISS);
+
+ // check that we triggered the appropriate code paths in the cache
+ switch (iter) {
+ case 0:
+ // page cache, block cache, no-compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_EQ(compressed_block_miss, 0);
+ ASSERT_EQ(compressed_block_hit, 0);
+ break;
+ case 1:
+ // page cache, block cache, compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_GT(compressed_block_miss, 0);
+ break;
+ case 2:
+ // page cache, block cache, compressed cache + KNoCompression
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_GT(compressed_block_miss, 0);
+ // remember kNoCompression
+ ASSERT_EQ(compressed_block_hit, 0);
+ break;
+ case 3:
+ case 4:
+ // page cache, no block cache, no compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_EQ(compressed_block_hit, 0);
+ ASSERT_EQ(compressed_block_miss, 0);
+ break;
+ default:
+ FAIL();
+ }
+
+ options.create_if_missing = true;
+ DestroyAndReopen(options);
+
+ ASSERT_OK(pcache->Close());
+ }
+}
+
+// Travis is unable to handle the normal version of the tests running out of
+// fds, out of space and timeouts. This is an easier version of the test
+// specifically written for Travis.
+// Now used generally because main tests are too expensive as unit tests.
+TEST_F(PersistentCacheDBTest, BasicTest) {
+ RunTest(std::bind(&MakeBlockCache, env_, dbname_), /*max_keys=*/1024,
+ /*max_usecase=*/1);
+}
+
+// test table with block page cache
+// DISABLED for now (very expensive, especially memory)
+TEST_F(PersistentCacheDBTest, DISABLED_BlockCacheTest) {
+ RunTest(std::bind(&MakeBlockCache, env_, dbname_));
+}
+
+// test table with volatile page cache
+// DISABLED for now (very expensive, especially memory)
+TEST_F(PersistentCacheDBTest, DISABLED_VolatileCacheTest) {
+ RunTest(std::bind(&MakeVolatileCache, env_, dbname_));
+}
+
+// test table with tiered page cache
+// DISABLED for now (very expensive, especially memory)
+TEST_F(PersistentCacheDBTest, DISABLED_TieredCacheTest) {
+ RunTest(std::bind(&MakeTieredCache, env_, dbname_));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+#else // !defined ROCKSDB_LITE
+int main() { return 0; }
+#endif // !defined ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h
new file mode 100644
index 000000000..f13155ed6
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h
@@ -0,0 +1,286 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <limits>
+#include <list>
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "db/db_test_util.h"
+#include "memory/arena.h"
+#include "port/port.h"
+#include "rocksdb/cache.h"
+#include "table/block_based/block_builder.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Unit tests for testing PersistentCacheTier
+//
+class PersistentCacheTierTest : public testing::Test {
+ public:
+ PersistentCacheTierTest();
+ virtual ~PersistentCacheTierTest() {
+ if (cache_) {
+ Status s = cache_->Close();
+ assert(s.ok());
+ }
+ }
+
+ protected:
+ // Flush cache
+ void Flush() {
+ if (cache_) {
+ cache_->TEST_Flush();
+ }
+ }
+
+ // create threaded workload
+ template <class T>
+ std::list<port::Thread> SpawnThreads(const size_t n, const T& fn) {
+ std::list<port::Thread> threads;
+ for (size_t i = 0; i < n; i++) {
+ port::Thread th(fn);
+ threads.push_back(std::move(th));
+ }
+ return threads;
+ }
+
+ // Wait for threads to join
+ void Join(std::list<port::Thread>&& threads) {
+ for (auto& th : threads) {
+ th.join();
+ }
+ threads.clear();
+ }
+
+ // Run insert workload in threads
+ void Insert(const size_t nthreads, const size_t max_keys) {
+ key_ = 0;
+ max_keys_ = max_keys;
+ // spawn threads
+ auto fn = std::bind(&PersistentCacheTierTest::InsertImpl, this);
+ auto threads = SpawnThreads(nthreads, fn);
+ // join with threads
+ Join(std::move(threads));
+ // Flush cache
+ Flush();
+ }
+
+ // Run verification on the cache
+ void Verify(const size_t nthreads = 1, const bool eviction_enabled = false) {
+ stats_verify_hits_ = 0;
+ stats_verify_missed_ = 0;
+ key_ = 0;
+ // spawn threads
+ auto fn =
+ std::bind(&PersistentCacheTierTest::VerifyImpl, this, eviction_enabled);
+ auto threads = SpawnThreads(nthreads, fn);
+ // join with threads
+ Join(std::move(threads));
+ }
+
+ // pad 0 to numbers
+ std::string PaddedNumber(const size_t data, const size_t pad_size) {
+ assert(pad_size);
+ char* ret = new char[pad_size];
+ int pos = static_cast<int>(pad_size) - 1;
+ size_t count = 0;
+ size_t t = data;
+ // copy numbers
+ while (t) {
+ count++;
+ ret[pos--] = '0' + t % 10;
+ t = t / 10;
+ }
+ // copy 0s
+ while (pos >= 0) {
+ ret[pos--] = '0';
+ }
+ // post condition
+ assert(count <= pad_size);
+ assert(pos == -1);
+ std::string result(ret, pad_size);
+ delete[] ret;
+ return result;
+ }
+
+ // Insert workload implementation
+ void InsertImpl() {
+ const std::string prefix = "key_prefix_";
+
+ while (true) {
+ size_t i = key_++;
+ if (i >= max_keys_) {
+ break;
+ }
+
+ char data[4 * 1024];
+ memset(data, '0' + (i % 10), sizeof(data));
+ auto k = prefix + PaddedNumber(i, /*count=*/8);
+ Slice key(k);
+ while (true) {
+ Status status = cache_->Insert(key, data, sizeof(data));
+ if (status.ok()) {
+ break;
+ }
+ ASSERT_TRUE(status.IsTryAgain());
+ Env::Default()->SleepForMicroseconds(1 * 1000 * 1000);
+ }
+ }
+ }
+
+ // Verification implementation
+ void VerifyImpl(const bool eviction_enabled = false) {
+ const std::string prefix = "key_prefix_";
+ while (true) {
+ size_t i = key_++;
+ if (i >= max_keys_) {
+ break;
+ }
+
+ char edata[4 * 1024];
+ memset(edata, '0' + (i % 10), sizeof(edata));
+ auto k = prefix + PaddedNumber(i, /*count=*/8);
+ Slice key(k);
+ std::unique_ptr<char[]> block;
+ size_t block_size;
+
+ if (eviction_enabled) {
+ if (!cache_->Lookup(key, &block, &block_size).ok()) {
+ // assume that the key is evicted
+ stats_verify_missed_++;
+ continue;
+ }
+ }
+
+ ASSERT_OK(cache_->Lookup(key, &block, &block_size));
+ ASSERT_EQ(block_size, sizeof(edata));
+ ASSERT_EQ(memcmp(edata, block.get(), sizeof(edata)), 0);
+ stats_verify_hits_++;
+ }
+ }
+
+ // template for insert test
+ void RunInsertTest(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads);
+ ASSERT_EQ(stats_verify_hits_, max_keys);
+ ASSERT_EQ(stats_verify_missed_, 0);
+
+ ASSERT_OK(cache_->Close());
+ cache_.reset();
+ }
+
+ // template for negative insert test
+ void RunNegativeInsertTest(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads, /*eviction_enabled=*/true);
+ ASSERT_LT(stats_verify_hits_, max_keys);
+ ASSERT_GT(stats_verify_missed_, 0);
+
+ ASSERT_OK(cache_->Close());
+ cache_.reset();
+ }
+
+ // template for insert with eviction test
+ void RunInsertTestWithEviction(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads, /*eviction_enabled=*/true);
+ ASSERT_EQ(stats_verify_hits_ + stats_verify_missed_, max_keys);
+ ASSERT_GT(stats_verify_hits_, 0);
+ ASSERT_GT(stats_verify_missed_, 0);
+
+ ASSERT_OK(cache_->Close());
+ cache_.reset();
+ }
+
+ const std::string path_;
+ std::shared_ptr<Logger> log_;
+ std::shared_ptr<PersistentCacheTier> cache_;
+ std::atomic<size_t> key_{0};
+ size_t max_keys_ = 0;
+ std::atomic<size_t> stats_verify_hits_{0};
+ std::atomic<size_t> stats_verify_missed_{0};
+};
+
+//
+// RocksDB tests
+//
+class PersistentCacheDBTest : public DBTestBase {
+ public:
+ PersistentCacheDBTest();
+
+ static uint64_t TestGetTickerCount(const Options& options,
+ Tickers ticker_type) {
+ return static_cast<uint32_t>(
+ options.statistics->getTickerCount(ticker_type));
+ }
+
+ // insert data to table
+ void Insert(const Options& options,
+ const BlockBasedTableOptions& /*table_options*/,
+ const int num_iter, std::vector<std::string>* values) {
+ CreateAndReopenWithCF({"pikachu"}, options);
+ // default column family doesn't have block cache
+ Options no_block_cache_opts;
+ no_block_cache_opts.statistics = options.statistics;
+ no_block_cache_opts = CurrentOptions(no_block_cache_opts);
+ BlockBasedTableOptions table_options_no_bc;
+ table_options_no_bc.no_block_cache = true;
+ no_block_cache_opts.table_factory.reset(
+ NewBlockBasedTableFactory(table_options_no_bc));
+ ReopenWithColumnFamilies(
+ {"default", "pikachu"},
+ std::vector<Options>({no_block_cache_opts, options}));
+
+ Random rnd(301);
+
+ // Write 8MB (80 values, each 100K)
+ ASSERT_EQ(NumTableFilesAtLevel(0, 1), 0);
+ std::string str;
+ for (int i = 0; i < num_iter; i++) {
+ if (i % 4 == 0) { // high compression ratio
+ str = rnd.RandomString(1000);
+ }
+ values->push_back(str);
+ ASSERT_OK(Put(1, Key(i), (*values)[i]));
+ }
+
+ // flush all data from memtable so that reads are from block cache
+ ASSERT_OK(Flush(1));
+ }
+
+ // verify data
+ void Verify(const int num_iter, const std::vector<std::string>& values) {
+ for (int j = 0; j < 2; ++j) {
+ for (int i = 0; i < num_iter; i++) {
+ ASSERT_EQ(Get(1, Key(i)), values[i]);
+ }
+ }
+ }
+
+ // test template
+ void RunTest(const std::function<std::shared_ptr<PersistentCacheTier>(bool)>&
+ new_pcache,
+ const size_t max_keys, const size_t max_usecase);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc
new file mode 100644
index 000000000..54cbce8f7
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc
@@ -0,0 +1,167 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+
+#include <cinttypes>
+#include <sstream>
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+std::string PersistentCacheConfig::ToString() const {
+ std::string ret;
+ ret.reserve(20000);
+ const int kBufferSize = 200;
+ char buffer[kBufferSize];
+
+ snprintf(buffer, kBufferSize, " path: %s\n", path.c_str());
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " enable_direct_reads: %d\n",
+ enable_direct_reads);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " enable_direct_writes: %d\n",
+ enable_direct_writes);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " cache_size: %" PRIu64 "\n", cache_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " cache_file_size: %" PRIu32 "\n",
+ cache_file_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " writer_qdepth: %" PRIu32 "\n",
+ writer_qdepth);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " pipeline_writes: %d\n", pipeline_writes);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize,
+ " max_write_pipeline_backlog_size: %" PRIu64 "\n",
+ max_write_pipeline_backlog_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " write_buffer_size: %" PRIu32 "\n",
+ write_buffer_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " writer_dispatch_size: %" PRIu64 "\n",
+ writer_dispatch_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " is_compressed: %d\n", is_compressed);
+ ret.append(buffer);
+
+ return ret;
+}
+
+//
+// PersistentCacheTier implementation
+//
+Status PersistentCacheTier::Open() {
+ if (next_tier_) {
+ return next_tier_->Open();
+ }
+ return Status::OK();
+}
+
+Status PersistentCacheTier::Close() {
+ if (next_tier_) {
+ return next_tier_->Close();
+ }
+ return Status::OK();
+}
+
+bool PersistentCacheTier::Reserve(const size_t /*size*/) {
+ // default implementation is a pass through
+ return true;
+}
+
+bool PersistentCacheTier::Erase(const Slice& /*key*/) {
+ // default implementation is a pass through since not all cache tiers might
+ // support erase
+ return true;
+}
+
+std::string PersistentCacheTier::PrintStats() {
+ std::ostringstream os;
+ for (auto tier_stats : Stats()) {
+ os << "---- next tier -----" << std::endl;
+ for (auto stat : tier_stats) {
+ os << stat.first << ": " << stat.second << std::endl;
+ }
+ }
+ return os.str();
+}
+
+PersistentCache::StatsType PersistentCacheTier::Stats() {
+ if (next_tier_) {
+ return next_tier_->Stats();
+ }
+ return PersistentCache::StatsType{};
+}
+
+uint64_t PersistentCacheTier::NewId() {
+ return last_id_.fetch_add(1, std::memory_order_relaxed);
+}
+
+//
+// PersistentTieredCache implementation
+//
+PersistentTieredCache::~PersistentTieredCache() { assert(tiers_.empty()); }
+
+Status PersistentTieredCache::Open() {
+ assert(!tiers_.empty());
+ return tiers_.front()->Open();
+}
+
+Status PersistentTieredCache::Close() {
+ assert(!tiers_.empty());
+ Status status = tiers_.front()->Close();
+ if (status.ok()) {
+ tiers_.clear();
+ }
+ return status;
+}
+
+bool PersistentTieredCache::Erase(const Slice& key) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Erase(key);
+}
+
+PersistentCache::StatsType PersistentTieredCache::Stats() {
+ assert(!tiers_.empty());
+ return tiers_.front()->Stats();
+}
+
+std::string PersistentTieredCache::PrintStats() {
+ assert(!tiers_.empty());
+ return tiers_.front()->PrintStats();
+}
+
+Status PersistentTieredCache::Insert(const Slice& page_key, const char* data,
+ const size_t size) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Insert(page_key, data, size);
+}
+
+Status PersistentTieredCache::Lookup(const Slice& page_key,
+ std::unique_ptr<char[]>* data,
+ size_t* size) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Lookup(page_key, data, size);
+}
+
+void PersistentTieredCache::AddTier(const Tier& tier) {
+ if (!tiers_.empty()) {
+ tiers_.back()->set_next_tier(tier);
+ }
+ tiers_.push_back(tier);
+}
+
+bool PersistentTieredCache::IsCompressed() {
+ assert(tiers_.size());
+ return tiers_.front()->IsCompressed();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h
new file mode 100644
index 000000000..65aadcd3f
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h
@@ -0,0 +1,342 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <limits>
+#include <list>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "monitoring/histogram.h"
+#include "rocksdb/env.h"
+#include "rocksdb/persistent_cache.h"
+#include "rocksdb/status.h"
+#include "rocksdb/system_clock.h"
+
+// Persistent Cache
+//
+// Persistent cache is tiered key-value cache that can use persistent medium. It
+// is a generic design and can leverage any storage medium -- disk/SSD/NVM/RAM.
+// The code has been kept generic but significant benchmark/design/development
+// time has been spent to make sure the cache performs appropriately for
+// respective storage medium.
+// The file defines
+// PersistentCacheTier : Implementation that handles individual cache tier
+// PersistentTieresCache : Implementation that handles all tiers as a logical
+// unit
+//
+// PersistentTieredCache architecture:
+// +--------------------------+ PersistentCacheTier that handles multiple tiers
+// | +----------------+ |
+// | | RAM | PersistentCacheTier that handles RAM (VolatileCacheImpl)
+// | +----------------+ |
+// | | next |
+// | v |
+// | +----------------+ |
+// | | NVM | PersistentCacheTier implementation that handles NVM
+// | +----------------+ (BlockCacheImpl)
+// | | next |
+// | V |
+// | +----------------+ |
+// | | LE-SSD | PersistentCacheTier implementation that handles LE-SSD
+// | +----------------+ (BlockCacheImpl)
+// | | |
+// | V |
+// | null |
+// +--------------------------+
+// |
+// V
+// null
+namespace ROCKSDB_NAMESPACE {
+
+// Persistent Cache Config
+//
+// This struct captures all the options that are used to configure persistent
+// cache. Some of the terminologies used in naming the options are
+//
+// dispatch size :
+// This is the size in which IO is dispatched to the device
+//
+// write buffer size :
+// This is the size of an individual write buffer size. Write buffers are
+// grouped to form buffered file.
+//
+// cache size :
+// This is the logical maximum for the cache size
+//
+// qdepth :
+// This is the max number of IOs that can issues to the device in parallel
+//
+// pepeling :
+// The writer code path follows pipelined architecture, which means the
+// operations are handed off from one stage to another
+//
+// pipelining backlog size :
+// With the pipelined architecture, there can always be backlogging of ops in
+// pipeline queues. This is the maximum backlog size after which ops are dropped
+// from queue
+struct PersistentCacheConfig {
+ explicit PersistentCacheConfig(
+ Env* const _env, const std::string& _path, const uint64_t _cache_size,
+ const std::shared_ptr<Logger>& _log,
+ const uint32_t _write_buffer_size = 1 * 1024 * 1024 /*1MB*/) {
+ env = _env;
+ clock = (env != nullptr) ? env->GetSystemClock().get()
+ : SystemClock::Default().get();
+ path = _path;
+ log = _log;
+ cache_size = _cache_size;
+ writer_dispatch_size = write_buffer_size = _write_buffer_size;
+ }
+
+ //
+ // Validate the settings. Our intentions are to catch erroneous settings ahead
+ // of time instead going violating invariants or causing dead locks.
+ //
+ Status ValidateSettings() const {
+ // (1) check pre-conditions for variables
+ if (!env || path.empty()) {
+ return Status::InvalidArgument("empty or null args");
+ }
+
+ // (2) assert size related invariants
+ // - cache size cannot be less than cache file size
+ // - individual write buffer size cannot be greater than cache file size
+ // - total write buffer size cannot be less than 2X cache file size
+ if (cache_size < cache_file_size || write_buffer_size >= cache_file_size ||
+ write_buffer_size * write_buffer_count() < 2 * cache_file_size) {
+ return Status::InvalidArgument("invalid cache size");
+ }
+
+ // (2) check writer settings
+ // - Queue depth cannot be 0
+ // - writer_dispatch_size cannot be greater than writer_buffer_size
+ // - dispatch size and buffer size need to be aligned
+ if (!writer_qdepth || writer_dispatch_size > write_buffer_size ||
+ write_buffer_size % writer_dispatch_size) {
+ return Status::InvalidArgument("invalid writer settings");
+ }
+
+ return Status::OK();
+ }
+
+ //
+ // Env abstraction to use for system level operations
+ //
+ Env* env;
+ SystemClock* clock;
+ //
+ // Path for the block cache where blocks are persisted
+ //
+ std::string path;
+
+ //
+ // Log handle for logging messages
+ //
+ std::shared_ptr<Logger> log;
+
+ //
+ // Enable direct IO for reading
+ //
+ bool enable_direct_reads = true;
+
+ //
+ // Enable direct IO for writing
+ //
+ bool enable_direct_writes = false;
+
+ //
+ // Logical cache size
+ //
+ uint64_t cache_size = std::numeric_limits<uint64_t>::max();
+
+ // cache-file-size
+ //
+ // Cache consists of multiples of small files. This parameter defines the
+ // size of an individual cache file
+ //
+ // default: 1M
+ uint32_t cache_file_size = 100ULL * 1024 * 1024;
+
+ // writer-qdepth
+ //
+ // The writers can issues IO to the devices in parallel. This parameter
+ // controls the max number if IOs that can issues in parallel to the block
+ // device
+ //
+ // default :1
+ uint32_t writer_qdepth = 1;
+
+ // pipeline-writes
+ //
+ // The write optionally follow pipelined architecture. This helps
+ // avoid regression in the eviction code path of the primary tier. This
+ // parameter defines if pipelining is enabled or disabled
+ //
+ // default: true
+ bool pipeline_writes = true;
+
+ // max-write-pipeline-backlog-size
+ //
+ // Max pipeline buffer size. This is the maximum backlog we can accumulate
+ // while waiting for writes. After the limit, new ops will be dropped.
+ //
+ // Default: 1GiB
+ uint64_t max_write_pipeline_backlog_size = 1ULL * 1024 * 1024 * 1024;
+
+ // write-buffer-size
+ //
+ // This is the size in which buffer slabs are allocated.
+ //
+ // Default: 1M
+ uint32_t write_buffer_size = 1ULL * 1024 * 1024;
+
+ // write-buffer-count
+ //
+ // This is the total number of buffer slabs. This is calculated as a factor of
+ // file size in order to avoid dead lock.
+ size_t write_buffer_count() const {
+ assert(write_buffer_size);
+ return static_cast<size_t>((writer_qdepth + 1.2) * cache_file_size /
+ write_buffer_size);
+ }
+
+ // writer-dispatch-size
+ //
+ // The writer thread will dispatch the IO at the specified IO size
+ //
+ // default: 1M
+ uint64_t writer_dispatch_size = 1ULL * 1024 * 1024;
+
+ // is_compressed
+ //
+ // This option determines if the cache will run in compressed mode or
+ // uncompressed mode
+ bool is_compressed = true;
+
+ PersistentCacheConfig MakePersistentCacheConfig(
+ const std::string& path, const uint64_t size,
+ const std::shared_ptr<Logger>& log);
+
+ std::string ToString() const;
+};
+
+// Persistent Cache Tier
+//
+// This a logical abstraction that defines a tier of the persistent cache. Tiers
+// can be stacked over one another. PersistentCahe provides the basic definition
+// for accessing/storing in the cache. PersistentCacheTier extends the interface
+// to enable management and stacking of tiers.
+class PersistentCacheTier : public PersistentCache {
+ public:
+ using Tier = std::shared_ptr<PersistentCacheTier>;
+
+ virtual ~PersistentCacheTier() {}
+
+ // Open the persistent cache tier
+ virtual Status Open();
+
+ // Close the persistent cache tier
+ virtual Status Close();
+
+ // Reserve space up to 'size' bytes
+ virtual bool Reserve(const size_t size);
+
+ // Erase a key from the cache
+ virtual bool Erase(const Slice& key);
+
+ // Print stats to string recursively
+ virtual std::string PrintStats();
+
+ virtual PersistentCache::StatsType Stats() override;
+
+ // Insert to page cache
+ virtual Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override = 0;
+
+ // Lookup page cache by page identifier
+ virtual Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override = 0;
+
+ // Does it store compressed data ?
+ virtual bool IsCompressed() override = 0;
+
+ virtual std::string GetPrintableOptions() const override = 0;
+
+ virtual uint64_t NewId() override;
+
+ // Return a reference to next tier
+ virtual Tier& next_tier() { return next_tier_; }
+
+ // Set the value for next tier
+ virtual void set_next_tier(const Tier& tier) {
+ assert(!next_tier_);
+ next_tier_ = tier;
+ }
+
+ virtual void TEST_Flush() {
+ if (next_tier_) {
+ next_tier_->TEST_Flush();
+ }
+ }
+
+ private:
+ Tier next_tier_; // next tier
+ std::atomic<uint64_t> last_id_{1};
+};
+
+// PersistentTieredCache
+//
+// Abstraction that helps you construct a tiers of persistent caches as a
+// unified cache. The tier(s) of cache will act a single tier for management
+// ease and support PersistentCache methods for accessing data.
+class PersistentTieredCache : public PersistentCacheTier {
+ public:
+ virtual ~PersistentTieredCache();
+
+ Status Open() override;
+ Status Close() override;
+ bool Erase(const Slice& key) override;
+ std::string PrintStats() override;
+ PersistentCache::StatsType Stats() override;
+ Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override;
+ Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+ bool IsCompressed() override;
+
+ std::string GetPrintableOptions() const override {
+ return "PersistentTieredCache";
+ }
+
+ void AddTier(const Tier& tier);
+
+ Tier& next_tier() override {
+ auto it = tiers_.end();
+ return (*it)->next_tier();
+ }
+
+ void set_next_tier(const Tier& tier) override {
+ auto it = tiers_.end();
+ (*it)->set_next_tier(tier);
+ }
+
+ void TEST_Flush() override {
+ assert(!tiers_.empty());
+ tiers_.front()->TEST_Flush();
+ PersistentCacheTier::TEST_Flush();
+ }
+
+ protected:
+ std::list<Tier> tiers_; // list of tiers top-down
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h
new file mode 100644
index 000000000..2a769652d
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <limits>
+#include <list>
+
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Simple synchronized queue implementation with the option of
+// bounding the queue
+//
+// On overflow, the elements will be discarded
+//
+template <class T>
+class BoundedQueue {
+ public:
+ explicit BoundedQueue(
+ const size_t max_size = std::numeric_limits<size_t>::max())
+ : cond_empty_(&lock_), max_size_(max_size) {}
+
+ virtual ~BoundedQueue() {}
+
+ void Push(T&& t) {
+ MutexLock _(&lock_);
+ if (max_size_ != std::numeric_limits<size_t>::max() &&
+ size_ + t.Size() >= max_size_) {
+ // overflow
+ return;
+ }
+
+ size_ += t.Size();
+ q_.push_back(std::move(t));
+ cond_empty_.SignalAll();
+ }
+
+ T Pop() {
+ MutexLock _(&lock_);
+ while (q_.empty()) {
+ cond_empty_.Wait();
+ }
+
+ T t = std::move(q_.front());
+ size_ -= t.Size();
+ q_.pop_front();
+ return t;
+ }
+
+ size_t Size() const {
+ MutexLock _(&lock_);
+ return size_;
+ }
+
+ private:
+ mutable port::Mutex lock_;
+ port::CondVar cond_empty_;
+ std::list<T> q_;
+ size_t size_ = 0;
+ const size_t max_size_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc
new file mode 100644
index 000000000..45d2830aa
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc
@@ -0,0 +1,140 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+void VolatileCacheTier::DeleteCacheData(VolatileCacheTier::CacheData* data) {
+ assert(data);
+ delete data;
+}
+
+VolatileCacheTier::~VolatileCacheTier() { index_.Clear(&DeleteCacheData); }
+
+PersistentCache::StatsType VolatileCacheTier::Stats() {
+ std::map<std::string, double> stat;
+ stat.insert({"persistent_cache.volatile_cache.hits",
+ static_cast<double>(stats_.cache_hits_)});
+ stat.insert({"persistent_cache.volatile_cache.misses",
+ static_cast<double>(stats_.cache_misses_)});
+ stat.insert({"persistent_cache.volatile_cache.inserts",
+ static_cast<double>(stats_.cache_inserts_)});
+ stat.insert({"persistent_cache.volatile_cache.evicts",
+ static_cast<double>(stats_.cache_evicts_)});
+ stat.insert({"persistent_cache.volatile_cache.hit_pct",
+ static_cast<double>(stats_.CacheHitPct())});
+ stat.insert({"persistent_cache.volatile_cache.miss_pct",
+ static_cast<double>(stats_.CacheMissPct())});
+
+ auto out = PersistentCacheTier::Stats();
+ out.push_back(stat);
+ return out;
+}
+
+Status VolatileCacheTier::Insert(const Slice& page_key, const char* data,
+ const size_t size) {
+ // precondition
+ assert(data);
+ assert(size);
+
+ // increment the size
+ size_ += size;
+
+ // check if we have overshot the limit, if so evict some space
+ while (size_ > max_size_) {
+ if (!Evict()) {
+ // unable to evict data, we give up so we don't spike read
+ // latency
+ assert(size_ >= size);
+ size_ -= size;
+ return Status::TryAgain("Unable to evict any data");
+ }
+ }
+
+ assert(size_ >= size);
+
+ // insert order: LRU, followed by index
+ std::string key(page_key.data(), page_key.size());
+ std::string value(data, size);
+ std::unique_ptr<CacheData> cache_data(
+ new CacheData(std::move(key), std::move(value)));
+ bool ok = index_.Insert(cache_data.get());
+ if (!ok) {
+ // decrement the size that we incremented ahead of time
+ assert(size_ >= size);
+ size_ -= size;
+ // failed to insert to cache, block already in cache
+ return Status::TryAgain("key already exists in volatile cache");
+ }
+
+ cache_data.release();
+ stats_.cache_inserts_++;
+ return Status::OK();
+}
+
+Status VolatileCacheTier::Lookup(const Slice& page_key,
+ std::unique_ptr<char[]>* result,
+ size_t* size) {
+ CacheData key(std::move(page_key.ToString()));
+ CacheData* kv;
+ bool ok = index_.Find(&key, &kv);
+ if (ok) {
+ // set return data
+ result->reset(new char[kv->value.size()]);
+ memcpy(result->get(), kv->value.c_str(), kv->value.size());
+ *size = kv->value.size();
+ // drop the reference on cache data
+ kv->refs_--;
+ // update stats
+ stats_.cache_hits_++;
+ return Status::OK();
+ }
+
+ stats_.cache_misses_++;
+
+ if (next_tier()) {
+ return next_tier()->Lookup(page_key, result, size);
+ }
+
+ return Status::NotFound("key not found in volatile cache");
+}
+
+bool VolatileCacheTier::Erase(const Slice& /*key*/) {
+ assert(!"not supported");
+ return true;
+}
+
+bool VolatileCacheTier::Evict() {
+ CacheData* edata = index_.Evict();
+ if (!edata) {
+ // not able to evict any object
+ return false;
+ }
+
+ stats_.cache_evicts_++;
+
+ // push the evicted object to the next level
+ if (next_tier()) {
+ // TODO: Should the insert error be ignored?
+ Status s = next_tier()->Insert(Slice(edata->key), edata->value.c_str(),
+ edata->value.size());
+ s.PermitUncheckedError();
+ }
+
+ // adjust size and destroy data
+ size_ -= edata->value.size();
+ delete edata;
+
+ return true;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h
new file mode 100644
index 000000000..09265e457
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h
@@ -0,0 +1,141 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "rocksdb/cache.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+
+// VolatileCacheTier
+//
+// This file provides persistent cache tier implementation for caching
+// key/values in RAM.
+//
+// key/values
+// |
+// V
+// +-------------------+
+// | VolatileCacheTier | Store in an evictable hash table
+// +-------------------+
+// |
+// V
+// on eviction
+// pushed to next tier
+//
+// The implementation is designed to be concurrent. The evictable hash table
+// implementation is not concurrent at this point though.
+//
+// The eviction algorithm is LRU
+namespace ROCKSDB_NAMESPACE {
+
+class VolatileCacheTier : public PersistentCacheTier {
+ public:
+ explicit VolatileCacheTier(
+ const bool is_compressed = true,
+ const size_t max_size = std::numeric_limits<size_t>::max())
+ : is_compressed_(is_compressed), max_size_(max_size) {}
+
+ virtual ~VolatileCacheTier();
+
+ // insert to cache
+ Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override;
+ // lookup key in cache
+ Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+
+ // is compressed cache ?
+ bool IsCompressed() override { return is_compressed_; }
+
+ // erase key from cache
+ bool Erase(const Slice& key) override;
+
+ std::string GetPrintableOptions() const override {
+ return "VolatileCacheTier";
+ }
+
+ // Expose stats as map
+ PersistentCache::StatsType Stats() override;
+
+ private:
+ //
+ // Cache data abstraction
+ //
+ struct CacheData : LRUElement<CacheData> {
+ explicit CacheData(CacheData&& rhs) noexcept
+ : key(std::move(rhs.key)), value(std::move(rhs.value)) {}
+
+ explicit CacheData(const std::string& _key, const std::string& _value = "")
+ : key(_key), value(_value) {}
+
+ virtual ~CacheData() {}
+
+ const std::string key;
+ const std::string value;
+ };
+
+ static void DeleteCacheData(CacheData* data);
+
+ //
+ // Index and LRU definition
+ //
+ struct CacheDataHash {
+ uint64_t operator()(const CacheData* obj) const {
+ assert(obj);
+ return std::hash<std::string>()(obj->key);
+ }
+ };
+
+ struct CacheDataEqual {
+ bool operator()(const CacheData* lhs, const CacheData* rhs) const {
+ assert(lhs);
+ assert(rhs);
+ return lhs->key == rhs->key;
+ }
+ };
+
+ struct Statistics {
+ std::atomic<uint64_t> cache_misses_{0};
+ std::atomic<uint64_t> cache_hits_{0};
+ std::atomic<uint64_t> cache_inserts_{0};
+ std::atomic<uint64_t> cache_evicts_{0};
+
+ double CacheHitPct() const {
+ auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_hits_ / static_cast<double>(lookups) : 0.0;
+ }
+
+ double CacheMissPct() const {
+ auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_misses_ / static_cast<double>(lookups) : 0.0;
+ }
+ };
+
+ using IndexType =
+ EvictableHashTable<CacheData, CacheDataHash, CacheDataEqual>;
+
+ // Evict LRU tail
+ bool Evict();
+
+ const bool is_compressed_ = true; // does it store compressed data
+ IndexType index_; // in-memory cache
+ std::atomic<uint64_t> max_size_{0}; // Maximum size of the cache
+ std::atomic<uint64_t> size_{0}; // Size of the cache
+ Statistics stats_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator.cc b/src/rocksdb/utilities/simulator_cache/cache_simulator.cc
new file mode 100644
index 000000000..dc419e51a
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator.cc
@@ -0,0 +1,288 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/simulator_cache/cache_simulator.h"
+
+#include <algorithm>
+
+#include "db/dbformat.h"
+#include "rocksdb/trace_record.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+const std::string kGhostCachePrefix = "ghost_";
+} // namespace
+
+GhostCache::GhostCache(std::shared_ptr<Cache> sim_cache)
+ : sim_cache_(sim_cache) {}
+
+bool GhostCache::Admit(const Slice& lookup_key) {
+ auto handle = sim_cache_->Lookup(lookup_key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ return true;
+ }
+ // TODO: Should we check for errors here?
+ auto s = sim_cache_->Insert(lookup_key, /*value=*/nullptr, lookup_key.size(),
+ /*deleter=*/nullptr);
+ s.PermitUncheckedError();
+ return false;
+}
+
+CacheSimulator::CacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache)
+ : ghost_cache_(std::move(ghost_cache)), sim_cache_(sim_cache) {}
+
+void CacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ bool admit = true;
+ const bool is_user_access =
+ BlockCacheTraceHelper::IsUserAccess(access.caller);
+ bool is_cache_miss = true;
+ if (ghost_cache_ && !access.no_insert) {
+ admit = ghost_cache_->Admit(access.block_key);
+ }
+ auto handle = sim_cache_->Lookup(access.block_key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ is_cache_miss = false;
+ } else {
+ if (!access.no_insert && admit && access.block_size > 0) {
+ // Ignore errors on insert
+ auto s = sim_cache_->Insert(access.block_key, /*value=*/nullptr,
+ access.block_size,
+ /*deleter=*/nullptr);
+ s.PermitUncheckedError();
+ }
+ }
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp, is_user_access,
+ is_cache_miss);
+}
+
+void MissRatioStats::UpdateMetrics(uint64_t timestamp_in_ms,
+ bool is_user_access, bool is_cache_miss) {
+ uint64_t timestamp_in_seconds = timestamp_in_ms / kMicrosInSecond;
+ num_accesses_timeline_[timestamp_in_seconds] += 1;
+ num_accesses_ += 1;
+ if (num_misses_timeline_.find(timestamp_in_seconds) ==
+ num_misses_timeline_.end()) {
+ num_misses_timeline_[timestamp_in_seconds] = 0;
+ }
+ if (is_cache_miss) {
+ num_misses_ += 1;
+ num_misses_timeline_[timestamp_in_seconds] += 1;
+ }
+ if (is_user_access) {
+ user_accesses_ += 1;
+ if (is_cache_miss) {
+ user_misses_ += 1;
+ }
+ }
+}
+
+Cache::Priority PrioritizedCacheSimulator::ComputeBlockPriority(
+ const BlockCacheTraceRecord& access) const {
+ if (access.block_type == TraceType::kBlockTraceFilterBlock ||
+ access.block_type == TraceType::kBlockTraceIndexBlock ||
+ access.block_type == TraceType::kBlockTraceUncompressionDictBlock) {
+ return Cache::Priority::HIGH;
+ }
+ return Cache::Priority::LOW;
+}
+
+void PrioritizedCacheSimulator::AccessKVPair(
+ const Slice& key, uint64_t value_size, Cache::Priority priority,
+ const BlockCacheTraceRecord& access, bool no_insert, bool is_user_access,
+ bool* is_cache_miss, bool* admitted, bool update_metrics) {
+ assert(is_cache_miss);
+ assert(admitted);
+ *is_cache_miss = true;
+ *admitted = true;
+ if (ghost_cache_ && !no_insert) {
+ *admitted = ghost_cache_->Admit(key);
+ }
+ auto handle = sim_cache_->Lookup(key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ *is_cache_miss = false;
+ } else if (!no_insert && *admitted && value_size > 0) {
+ // TODO: Should we check for an error here?
+ auto s = sim_cache_->Insert(key, /*value=*/nullptr, value_size,
+ /*deleter=*/nullptr,
+ /*handle=*/nullptr, priority);
+ s.PermitUncheckedError();
+ }
+ if (update_metrics) {
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp, is_user_access,
+ *is_cache_miss);
+ }
+}
+
+void PrioritizedCacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ bool is_cache_miss = true;
+ bool admitted = true;
+ AccessKVPair(access.block_key, access.block_size,
+ ComputeBlockPriority(access), access, access.no_insert,
+ BlockCacheTraceHelper::IsUserAccess(access.caller),
+ &is_cache_miss, &admitted, /*update_metrics=*/true);
+}
+
+void HybridRowBlockCacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ // TODO (haoyu): We only support Get for now. We need to extend the tracing
+ // for MultiGet, i.e., non-data block accesses must log all keys in a
+ // MultiGet.
+ bool is_cache_miss = true;
+ bool admitted = false;
+ if (access.caller == TableReaderCaller::kUserGet &&
+ access.get_id != BlockCacheTraceHelper::kReservedGetId) {
+ // This is a Get request.
+ const std::string& row_key = BlockCacheTraceHelper::ComputeRowKey(access);
+ GetRequestStatus& status = getid_status_map_[access.get_id];
+ if (status.is_complete) {
+ // This Get request completes.
+ // Skip future accesses to its index/filter/data
+ // blocks. These block lookups are unnecessary if we observe a hit for the
+ // referenced key-value pair already. Thus, we treat these lookups as
+ // hits. This is also to ensure the total number of accesses are the same
+ // when comparing to other policies.
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp,
+ /*is_user_access=*/true,
+ /*is_cache_miss=*/false);
+ return;
+ }
+ if (status.row_key_status.find(row_key) == status.row_key_status.end()) {
+ // This is the first time that this key is accessed. Look up the key-value
+ // pair first. Do not update the miss/accesses metrics here since it will
+ // be updated later.
+ AccessKVPair(row_key, access.referenced_data_size, Cache::Priority::HIGH,
+ access,
+ /*no_insert=*/false,
+ /*is_user_access=*/true, &is_cache_miss, &admitted,
+ /*update_metrics=*/false);
+ InsertResult result = InsertResult::NO_INSERT;
+ if (admitted && access.referenced_data_size > 0) {
+ result = InsertResult::INSERTED;
+ } else if (admitted) {
+ result = InsertResult::ADMITTED;
+ }
+ status.row_key_status[row_key] = result;
+ }
+ if (!is_cache_miss) {
+ // A cache hit.
+ status.is_complete = true;
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp,
+ /*is_user_access=*/true,
+ /*is_cache_miss=*/false);
+ return;
+ }
+ // The row key-value pair observes a cache miss. We need to access its
+ // index/filter/data blocks.
+ InsertResult inserted = status.row_key_status[row_key];
+ AccessKVPair(
+ access.block_key, access.block_size, ComputeBlockPriority(access),
+ access,
+ /*no_insert=*/!insert_blocks_upon_row_kvpair_miss_ || access.no_insert,
+ /*is_user_access=*/true, &is_cache_miss, &admitted,
+ /*update_metrics=*/true);
+ if (access.referenced_data_size > 0 && inserted == InsertResult::ADMITTED) {
+ // TODO: Should we check for an error here?
+ auto s = sim_cache_->Insert(row_key, /*value=*/nullptr,
+ access.referenced_data_size,
+ /*deleter=*/nullptr,
+ /*handle=*/nullptr, Cache::Priority::HIGH);
+ s.PermitUncheckedError();
+ status.row_key_status[row_key] = InsertResult::INSERTED;
+ }
+ return;
+ }
+ AccessKVPair(access.block_key, access.block_size,
+ ComputeBlockPriority(access), access, access.no_insert,
+ BlockCacheTraceHelper::IsUserAccess(access.caller),
+ &is_cache_miss, &admitted, /*update_metrics=*/true);
+}
+
+BlockCacheTraceSimulator::BlockCacheTraceSimulator(
+ uint64_t warmup_seconds, uint32_t downsample_ratio,
+ const std::vector<CacheConfiguration>& cache_configurations)
+ : warmup_seconds_(warmup_seconds),
+ downsample_ratio_(downsample_ratio),
+ cache_configurations_(cache_configurations) {}
+
+Status BlockCacheTraceSimulator::InitializeCaches() {
+ for (auto const& config : cache_configurations_) {
+ for (auto cache_capacity : config.cache_capacities) {
+ // Scale down the cache capacity since the trace contains accesses on
+ // 1/'downsample_ratio' blocks.
+ uint64_t simulate_cache_capacity = cache_capacity / downsample_ratio_;
+ std::shared_ptr<CacheSimulator> sim_cache;
+ std::unique_ptr<GhostCache> ghost_cache;
+ std::string cache_name = config.cache_name;
+ if (cache_name.find(kGhostCachePrefix) != std::string::npos) {
+ ghost_cache.reset(new GhostCache(
+ NewLRUCache(config.ghost_cache_capacity, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_name = cache_name.substr(kGhostCachePrefix.size());
+ }
+ if (cache_name == "lru") {
+ sim_cache = std::make_shared<CacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0));
+ } else if (cache_name == "lru_priority") {
+ sim_cache = std::make_shared<PrioritizedCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5));
+ } else if (cache_name == "lru_hybrid") {
+ sim_cache = std::make_shared<HybridRowBlockCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5),
+ /*insert_blocks_upon_row_kvpair_miss=*/true);
+ } else if (cache_name == "lru_hybrid_no_insert_on_row_miss") {
+ sim_cache = std::make_shared<HybridRowBlockCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5),
+ /*insert_blocks_upon_row_kvpair_miss=*/false);
+ } else {
+ // Not supported.
+ return Status::InvalidArgument("Unknown cache name " +
+ config.cache_name);
+ }
+ sim_caches_[config].push_back(sim_cache);
+ }
+ }
+ return Status::OK();
+}
+
+void BlockCacheTraceSimulator::Access(const BlockCacheTraceRecord& access) {
+ if (trace_start_time_ == 0) {
+ trace_start_time_ = access.access_timestamp;
+ }
+ // access.access_timestamp is in microseconds.
+ if (!warmup_complete_ &&
+ trace_start_time_ + warmup_seconds_ * kMicrosInSecond <=
+ access.access_timestamp) {
+ for (auto& config_caches : sim_caches_) {
+ for (auto& sim_cache : config_caches.second) {
+ sim_cache->reset_counter();
+ }
+ }
+ warmup_complete_ = true;
+ }
+ for (auto& config_caches : sim_caches_) {
+ for (auto& sim_cache : config_caches.second) {
+ sim_cache->Access(access);
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator.h b/src/rocksdb/utilities/simulator_cache/cache_simulator.h
new file mode 100644
index 000000000..6d4979013
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator.h
@@ -0,0 +1,231 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <unordered_map>
+
+#include "cache/lru_cache.h"
+#include "trace_replay/block_cache_tracer.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A cache configuration provided by user.
+struct CacheConfiguration {
+ std::string cache_name; // LRU.
+ uint32_t num_shard_bits;
+ uint64_t ghost_cache_capacity; // ghost cache capacity in bytes.
+ std::vector<uint64_t>
+ cache_capacities; // simulate cache capacities in bytes.
+
+ bool operator==(const CacheConfiguration& o) const {
+ return cache_name == o.cache_name && num_shard_bits == o.num_shard_bits &&
+ ghost_cache_capacity == o.ghost_cache_capacity;
+ }
+ bool operator<(const CacheConfiguration& o) const {
+ return cache_name < o.cache_name ||
+ (cache_name == o.cache_name && num_shard_bits < o.num_shard_bits) ||
+ (cache_name == o.cache_name && num_shard_bits == o.num_shard_bits &&
+ ghost_cache_capacity < o.ghost_cache_capacity);
+ }
+};
+
+class MissRatioStats {
+ public:
+ void reset_counter() {
+ num_misses_ = 0;
+ num_accesses_ = 0;
+ user_accesses_ = 0;
+ user_misses_ = 0;
+ }
+ double miss_ratio() const {
+ if (num_accesses_ == 0) {
+ return -1;
+ }
+ return static_cast<double>(num_misses_ * 100.0 / num_accesses_);
+ }
+ uint64_t total_accesses() const { return num_accesses_; }
+ uint64_t total_misses() const { return num_misses_; }
+
+ const std::map<uint64_t, uint64_t>& num_accesses_timeline() const {
+ return num_accesses_timeline_;
+ }
+
+ const std::map<uint64_t, uint64_t>& num_misses_timeline() const {
+ return num_misses_timeline_;
+ }
+
+ double user_miss_ratio() const {
+ if (user_accesses_ == 0) {
+ return -1;
+ }
+ return static_cast<double>(user_misses_ * 100.0 / user_accesses_);
+ }
+ uint64_t user_accesses() const { return user_accesses_; }
+ uint64_t user_misses() const { return user_misses_; }
+
+ void UpdateMetrics(uint64_t timestamp_in_ms, bool is_user_access,
+ bool is_cache_miss);
+
+ private:
+ uint64_t num_accesses_ = 0;
+ uint64_t num_misses_ = 0;
+ uint64_t user_accesses_ = 0;
+ uint64_t user_misses_ = 0;
+
+ std::map<uint64_t, uint64_t> num_accesses_timeline_;
+ std::map<uint64_t, uint64_t> num_misses_timeline_;
+};
+
+// A ghost cache admits an entry on its second access.
+class GhostCache {
+ public:
+ explicit GhostCache(std::shared_ptr<Cache> sim_cache);
+ ~GhostCache() = default;
+ // No copy and move.
+ GhostCache(const GhostCache&) = delete;
+ GhostCache& operator=(const GhostCache&) = delete;
+ GhostCache(GhostCache&&) = delete;
+ GhostCache& operator=(GhostCache&&) = delete;
+
+ // Returns true if the lookup_key is in the ghost cache.
+ // Returns false otherwise.
+ bool Admit(const Slice& lookup_key);
+
+ private:
+ std::shared_ptr<Cache> sim_cache_;
+};
+
+// A cache simulator that runs against a block cache trace.
+class CacheSimulator {
+ public:
+ CacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache);
+ virtual ~CacheSimulator() = default;
+ // No copy and move.
+ CacheSimulator(const CacheSimulator&) = delete;
+ CacheSimulator& operator=(const CacheSimulator&) = delete;
+ CacheSimulator(CacheSimulator&&) = delete;
+ CacheSimulator& operator=(CacheSimulator&&) = delete;
+
+ virtual void Access(const BlockCacheTraceRecord& access);
+
+ void reset_counter() { miss_ratio_stats_.reset_counter(); }
+
+ const MissRatioStats& miss_ratio_stats() const { return miss_ratio_stats_; }
+
+ protected:
+ MissRatioStats miss_ratio_stats_;
+ std::unique_ptr<GhostCache> ghost_cache_;
+ std::shared_ptr<Cache> sim_cache_;
+};
+
+// A prioritized cache simulator that runs against a block cache trace.
+// It inserts missing index/filter/uncompression-dictionary blocks with high
+// priority in the cache.
+class PrioritizedCacheSimulator : public CacheSimulator {
+ public:
+ PrioritizedCacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache)
+ : CacheSimulator(std::move(ghost_cache), sim_cache) {}
+ void Access(const BlockCacheTraceRecord& access) override;
+
+ protected:
+ // Access the key-value pair and returns true upon a cache miss.
+ void AccessKVPair(const Slice& key, uint64_t value_size,
+ Cache::Priority priority,
+ const BlockCacheTraceRecord& access, bool no_insert,
+ bool is_user_access, bool* is_cache_miss, bool* admitted,
+ bool update_metrics);
+
+ Cache::Priority ComputeBlockPriority(
+ const BlockCacheTraceRecord& access) const;
+};
+
+// A hybrid row and block cache simulator. It looks up/inserts key-value pairs
+// referenced by Get/MultiGet requests, and not their accessed index/filter/data
+// blocks.
+//
+// Upon a Get/MultiGet request, it looks up the referenced key first.
+// If it observes a cache hit, future block accesses on this key-value pair is
+// skipped since the request is served already. Otherwise, it continues to look
+// up/insert its index/filter/data blocks. It also inserts the referenced
+// key-value pair in the cache for future lookups.
+class HybridRowBlockCacheSimulator : public PrioritizedCacheSimulator {
+ public:
+ HybridRowBlockCacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache,
+ bool insert_blocks_upon_row_kvpair_miss)
+ : PrioritizedCacheSimulator(std::move(ghost_cache), sim_cache),
+ insert_blocks_upon_row_kvpair_miss_(
+ insert_blocks_upon_row_kvpair_miss) {}
+ void Access(const BlockCacheTraceRecord& access) override;
+
+ private:
+ enum InsertResult : char {
+ INSERTED,
+ ADMITTED,
+ NO_INSERT,
+ };
+
+ // We set is_complete to true when the referenced row-key of a get request
+ // hits the cache. If is_complete is true, we treat future accesses of this
+ // get request as hits.
+ //
+ // For each row key, it stores an enum. It is INSERTED when the
+ // kv-pair has been inserted into the cache, ADMITTED if it should be inserted
+ // but haven't been, NO_INSERT if it should not be inserted.
+ //
+ // A kv-pair is in ADMITTED state when we encounter this kv-pair but do not
+ // know its size. This may happen if the first access on the referenced key is
+ // an index/filter block.
+ struct GetRequestStatus {
+ bool is_complete = false;
+ std::map<std::string, InsertResult> row_key_status;
+ };
+
+ // A map stores get_id to a map of row keys.
+ std::map<uint64_t, GetRequestStatus> getid_status_map_;
+ bool insert_blocks_upon_row_kvpair_miss_;
+};
+
+// A block cache simulator that reports miss ratio curves given a set of cache
+// configurations.
+class BlockCacheTraceSimulator {
+ public:
+ // warmup_seconds: The number of seconds to warmup simulated caches. The
+ // hit/miss counters are reset after the warmup completes.
+ BlockCacheTraceSimulator(
+ uint64_t warmup_seconds, uint32_t downsample_ratio,
+ const std::vector<CacheConfiguration>& cache_configurations);
+ ~BlockCacheTraceSimulator() = default;
+ // No copy and move.
+ BlockCacheTraceSimulator(const BlockCacheTraceSimulator&) = delete;
+ BlockCacheTraceSimulator& operator=(const BlockCacheTraceSimulator&) = delete;
+ BlockCacheTraceSimulator(BlockCacheTraceSimulator&&) = delete;
+ BlockCacheTraceSimulator& operator=(BlockCacheTraceSimulator&&) = delete;
+
+ Status InitializeCaches();
+
+ void Access(const BlockCacheTraceRecord& access);
+
+ const std::map<CacheConfiguration,
+ std::vector<std::shared_ptr<CacheSimulator>>>&
+ sim_caches() const {
+ return sim_caches_;
+ }
+
+ private:
+ const uint64_t warmup_seconds_;
+ const uint32_t downsample_ratio_;
+ const std::vector<CacheConfiguration> cache_configurations_;
+
+ bool warmup_complete_ = false;
+ std::map<CacheConfiguration, std::vector<std::shared_ptr<CacheSimulator>>>
+ sim_caches_;
+ uint64_t trace_start_time_ = 0;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc b/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc
new file mode 100644
index 000000000..2bc057c92
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc
@@ -0,0 +1,497 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/simulator_cache/cache_simulator.h"
+
+#include <cstdlib>
+
+#include "rocksdb/env.h"
+#include "rocksdb/trace_record.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+const std::string kBlockKeyPrefix = "test-block-";
+const std::string kRefKeyPrefix = "test-get-";
+const std::string kRefKeySequenceNumber = std::string(8, 'c');
+const uint64_t kGetId = 1;
+const uint64_t kGetBlockId = 100;
+const uint64_t kCompactionBlockId = 1000;
+const uint64_t kCacheSize = 1024 * 1024 * 1024;
+const uint64_t kGhostCacheSize = 1024 * 1024;
+} // namespace
+
+class CacheSimulatorTest : public testing::Test {
+ public:
+ const size_t kNumBlocks = 5;
+ const size_t kValueSize = 1000;
+
+ CacheSimulatorTest() { env_ = ROCKSDB_NAMESPACE::Env::Default(); }
+
+ BlockCacheTraceRecord GenerateGetRecord(uint64_t getid) {
+ BlockCacheTraceRecord record;
+ record.block_type = TraceType::kBlockTraceDataBlock;
+ record.block_size = 4096;
+ record.block_key = kBlockKeyPrefix + std::to_string(kGetBlockId);
+ record.access_timestamp = env_->NowMicros();
+ record.cf_id = 0;
+ record.cf_name = "test";
+ record.caller = TableReaderCaller::kUserGet;
+ record.level = 6;
+ record.sst_fd_number = 0;
+ record.get_id = getid;
+ record.is_cache_hit = false;
+ record.no_insert = false;
+ record.referenced_key =
+ kRefKeyPrefix + std::to_string(kGetId) + kRefKeySequenceNumber;
+ record.referenced_key_exist_in_block = true;
+ record.referenced_data_size = 100;
+ record.num_keys_in_block = 300;
+ return record;
+ }
+
+ BlockCacheTraceRecord GenerateCompactionRecord() {
+ BlockCacheTraceRecord record;
+ record.block_type = TraceType::kBlockTraceDataBlock;
+ record.block_size = 4096;
+ record.block_key = kBlockKeyPrefix + std::to_string(kCompactionBlockId);
+ record.access_timestamp = env_->NowMicros();
+ record.cf_id = 0;
+ record.cf_name = "test";
+ record.caller = TableReaderCaller::kCompaction;
+ record.level = 6;
+ record.sst_fd_number = kCompactionBlockId;
+ record.is_cache_hit = false;
+ record.no_insert = true;
+ return record;
+ }
+
+ void AssertCache(std::shared_ptr<Cache> sim_cache,
+ const MissRatioStats& miss_ratio_stats,
+ uint64_t expected_usage, uint64_t expected_num_accesses,
+ uint64_t expected_num_misses,
+ std::vector<std::string> blocks,
+ std::vector<std::string> keys) {
+ EXPECT_EQ(expected_usage, sim_cache->GetUsage());
+ EXPECT_EQ(expected_num_accesses, miss_ratio_stats.total_accesses());
+ EXPECT_EQ(expected_num_misses, miss_ratio_stats.total_misses());
+ for (auto const& block : blocks) {
+ auto handle = sim_cache->Lookup(block);
+ EXPECT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+ for (auto const& key : keys) {
+ std::string row_key = kRefKeyPrefix + key + kRefKeySequenceNumber;
+ auto handle =
+ sim_cache->Lookup("0_" + ExtractUserKey(row_key).ToString());
+ EXPECT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+ }
+
+ Env* env_;
+};
+
+TEST_F(CacheSimulatorTest, GhostCache) {
+ const std::string key1 = "test1";
+ const std::string key2 = "test2";
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ EXPECT_FALSE(ghost_cache->Admit(key1));
+ EXPECT_TRUE(ghost_cache->Admit(key1));
+ EXPECT_TRUE(ghost_cache->Admit(key1));
+ EXPECT_FALSE(ghost_cache->Admit(key2));
+ EXPECT_TRUE(ghost_cache->Admit(key2));
+}
+
+TEST_F(CacheSimulatorTest, CacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ const BlockCacheTraceRecord& compaction_access = GenerateCompactionRecord();
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<CacheSimulator> cache_simulator(
+ new CacheSimulator(nullptr, sim_cache));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().user_miss_ratio());
+
+ cache_simulator->Access(compaction_access);
+ cache_simulator->Access(compaction_access);
+ ASSERT_EQ(4, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(75, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().user_miss_ratio());
+
+ cache_simulator->reset_counter();
+ ASSERT_EQ(0, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(-1, cache_simulator->miss_ratio_stats().miss_ratio());
+ auto handle = sim_cache->Lookup(access.block_key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ handle = sim_cache->Lookup(compaction_access.block_key);
+ ASSERT_EQ(nullptr, handle);
+}
+
+TEST_F(CacheSimulatorTest, GhostCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ std::unique_ptr<CacheSimulator> cache_simulator(new CacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ // Both of them will be miss since we have a ghost cache.
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+}
+
+TEST_F(CacheSimulatorTest, PrioritizedCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<PrioritizedCacheSimulator> cache_simulator(
+ new PrioritizedCacheSimulator(nullptr, sim_cache));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().miss_ratio());
+
+ auto handle = sim_cache->Lookup(access.block_key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+}
+
+TEST_F(CacheSimulatorTest, GhostPrioritizedCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ std::unique_ptr<PrioritizedCacheSimulator> cache_simulator(
+ new PrioritizedCacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ // Both of them will be miss since we have a ghost cache.
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockCacheSimulator) {
+ uint64_t block_id = 100;
+ BlockCacheTraceRecord first_get = GenerateGetRecord(kGetId);
+ first_get.get_from_user_specified_snapshot = true;
+ BlockCacheTraceRecord second_get = GenerateGetRecord(kGetId + 1);
+ second_get.referenced_data_size = 0;
+ second_get.referenced_key_exist_in_block = false;
+ second_get.get_from_user_specified_snapshot = true;
+ BlockCacheTraceRecord third_get = GenerateGetRecord(kGetId + 2);
+ third_get.referenced_data_size = 0;
+ third_get.referenced_key_exist_in_block = false;
+ third_get.referenced_key = kRefKeyPrefix + "third_get";
+ // We didn't find the referenced key in the third get.
+ third_get.referenced_key_exist_in_block = false;
+ third_get.referenced_data_size = 0;
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+ nullptr, sim_cache, /*insert_blocks_row_kvpair_misses=*/true));
+ // The first get request accesses 10 blocks. We should only report 10 accesses
+ // and 100% miss.
+ for (uint32_t i = 0; i < 10; i++) {
+ first_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(first_get);
+ block_id++;
+ }
+
+ ASSERT_EQ(10, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(10, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().user_miss_ratio());
+ auto handle =
+ sim_cache->Lookup(std::to_string(first_get.sst_fd_number) + "_" +
+ ExtractUserKey(first_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+
+ // The second get request accesses the same key. We should report 15
+ // access and 66% miss, 10 misses with 15 accesses.
+ // We do not consider these 5 block lookups as misses since the row hits the
+ // cache.
+ for (uint32_t i = 0; i < 5; i++) {
+ second_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(second_get);
+ block_id++;
+ }
+ ASSERT_EQ(15, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(66, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(15, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(66, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+ handle =
+ sim_cache->Lookup(std::to_string(second_get.sst_fd_number) + "_" +
+ ExtractUserKey(second_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ if (i < 110) {
+ ASSERT_NE(nullptr, handle) << i;
+ sim_cache->Release(handle);
+ } else {
+ ASSERT_EQ(nullptr, handle) << i;
+ }
+ }
+
+ // The third get on a different key and does not have a size.
+ // This key should not be inserted into the cache.
+ for (uint32_t i = 0; i < 5; i++) {
+ third_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(third_get);
+ block_id++;
+ }
+ ASSERT_EQ(20, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(75, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(20, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(75, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+ // Assert that the third key is not inserted into the cache.
+ handle = sim_cache->Lookup(std::to_string(third_get.sst_fd_number) + "_" +
+ third_get.referenced_key);
+ ASSERT_EQ(nullptr, handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ if (i < 110 || i >= 115) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_NE(nullptr, handle) << i;
+ sim_cache->Release(handle);
+ } else {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_EQ(nullptr, handle) << i;
+ }
+ }
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockCacheSimulatorGetTest) {
+ BlockCacheTraceRecord get = GenerateGetRecord(kGetId);
+ get.block_size = 1;
+ get.referenced_data_size = 0;
+ get.access_timestamp = 0;
+ get.block_key = "1";
+ get.get_id = 1;
+ get.get_from_user_specified_snapshot = false;
+ get.referenced_key =
+ kRefKeyPrefix + std::to_string(1) + kRefKeySequenceNumber;
+ get.no_insert = false;
+ get.sst_fd_number = 0;
+ get.get_from_user_specified_snapshot = false;
+
+ LRUCacheOptions co;
+ co.capacity = 16;
+ co.num_shard_bits = 1;
+ co.strict_capacity_limit = false;
+ co.high_pri_pool_ratio = 0;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<Cache> sim_cache = NewLRUCache(co);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+ nullptr, sim_cache, /*insert_blocks_row_kvpair_misses=*/true));
+ // Expect a miss and does not insert the row key-value pair since it does not
+ // have size.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 1, 1, 1, {"1"},
+ {});
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.block_key = "2";
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 3, 2, 2,
+ {"1", "2"}, {"1"});
+ get.access_timestamp += 1;
+ get.block_key = "3";
+ // K1 should not inserted again.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 4, 3, 3,
+ {"1", "2", "3"}, {"1"});
+
+ // A second get request referencing the same key.
+ get.access_timestamp += 1;
+ get.get_id = 2;
+ get.block_key = "4";
+ get.referenced_data_size = 0;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 4, 4, 3,
+ {"1", "2", "3"}, {"1"});
+
+ // A third get request searches three files, three different keys.
+ // And the second key observes a hit.
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "3";
+ get.referenced_key = kRefKeyPrefix + "2" + kRefKeySequenceNumber;
+ // K2 should observe a miss. Block 3 observes a hit.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 5, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "4";
+ get.referenced_data_size = 1;
+ get.referenced_key = kRefKeyPrefix + "1" + kRefKeySequenceNumber;
+ // K1 should observe a hit.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 6, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "4";
+ get.referenced_data_size = 1;
+ get.referenced_key = kRefKeyPrefix + "3" + kRefKeySequenceNumber;
+ // K3 should observe a miss.
+ // However, as the get already complete, we should not access k3 any more.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 7, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ // A fourth get request searches one file and two blocks. One row key.
+ get.access_timestamp += 1;
+ get.get_id = 4;
+ get.block_key = "5";
+ get.referenced_key = kRefKeyPrefix + "4" + kRefKeySequenceNumber;
+ get.referenced_data_size = 1;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 7, 8, 4,
+ {"1", "2", "3", "5"}, {"1", "2", "4"});
+ for (auto const& key : {"1", "2", "4"}) {
+ auto handle = sim_cache->Lookup("0_" + kRefKeyPrefix + key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+
+ // A bunch of insertions which evict cached row keys.
+ for (uint32_t i = 6; i < 100; i++) {
+ get.access_timestamp += 1;
+ get.get_id = 0;
+ get.block_key = std::to_string(i);
+ cache_simulator->Access(get);
+ }
+
+ get.get_id = 4;
+ // A different block.
+ get.block_key = "100";
+ // Same row key and should not be inserted again.
+ get.referenced_key = kRefKeyPrefix + "4" + kRefKeySequenceNumber;
+ get.referenced_data_size = 1;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 16, 103, 99, {},
+ {});
+ for (auto const& key : {"1", "2", "4"}) {
+ auto handle = sim_cache->Lookup("0_" + kRefKeyPrefix + key);
+ ASSERT_EQ(nullptr, handle);
+ }
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockNoInsertCacheSimulator) {
+ uint64_t block_id = 100;
+ BlockCacheTraceRecord first_get = GenerateGetRecord(kGetId);
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+ nullptr, sim_cache, /*insert_blocks_row_kvpair_misses=*/false));
+ for (uint32_t i = 0; i < 9; i++) {
+ first_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(first_get);
+ block_id++;
+ }
+ auto handle =
+ sim_cache->Lookup(std::to_string(first_get.sst_fd_number) + "_" +
+ ExtractUserKey(first_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ // All blocks are missing from the cache since insert_blocks_row_kvpair_misses
+ // is set to false.
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_EQ(nullptr, handle);
+ }
+}
+
+TEST_F(CacheSimulatorTest, GhostHybridRowBlockCacheSimulator) {
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ const BlockCacheTraceRecord& first_get = GenerateGetRecord(kGetId);
+ const BlockCacheTraceRecord& second_get = GenerateGetRecord(kGetId + 1);
+ const BlockCacheTraceRecord& third_get = GenerateGetRecord(kGetId + 2);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0),
+ /*insert_blocks_row_kvpair_misses=*/false));
+ // Two get requests access the same key.
+ cache_simulator->Access(first_get);
+ cache_simulator->Access(second_get);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().user_miss_ratio());
+ // We insert the key-value pair upon the second get request. A third get
+ // request should observe a hit.
+ for (uint32_t i = 0; i < 10; i++) {
+ cache_simulator->Access(third_get);
+ }
+ ASSERT_EQ(12, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(16, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(12, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(16, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/simulator_cache/sim_cache.cc b/src/rocksdb/utilities/simulator_cache/sim_cache.cc
new file mode 100644
index 000000000..a883b52e7
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/sim_cache.cc
@@ -0,0 +1,364 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/sim_cache.h"
+
+#include <atomic>
+#include <iomanip>
+
+#include "file/writable_file_writer.h"
+#include "monitoring/statistics.h"
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "rocksdb/file_system.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+class CacheActivityLogger {
+ public:
+ CacheActivityLogger()
+ : activity_logging_enabled_(false), max_logging_size_(0) {}
+
+ ~CacheActivityLogger() {
+ MutexLock l(&mutex_);
+
+ StopLoggingInternal();
+ bg_status_.PermitUncheckedError();
+ }
+
+ Status StartLogging(const std::string& activity_log_file, Env* env,
+ uint64_t max_logging_size = 0) {
+ assert(activity_log_file != "");
+ assert(env != nullptr);
+
+ Status status;
+ FileOptions file_opts;
+
+ MutexLock l(&mutex_);
+
+ // Stop existing logging if any
+ StopLoggingInternal();
+
+ // Open log file
+ status = WritableFileWriter::Create(env->GetFileSystem(), activity_log_file,
+ file_opts, &file_writer_, nullptr);
+ if (!status.ok()) {
+ return status;
+ }
+
+ max_logging_size_ = max_logging_size;
+ activity_logging_enabled_.store(true);
+
+ return status;
+ }
+
+ void StopLogging() {
+ MutexLock l(&mutex_);
+
+ StopLoggingInternal();
+ }
+
+ void ReportLookup(const Slice& key) {
+ if (activity_logging_enabled_.load() == false) {
+ return;
+ }
+
+ std::ostringstream oss;
+ // line format: "LOOKUP - <KEY>"
+ oss << "LOOKUP - " << key.ToString(true) << std::endl;
+
+ MutexLock l(&mutex_);
+ Status s = file_writer_->Append(oss.str());
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+ if (MaxLoggingSizeReached() || !bg_status_.ok()) {
+ // Stop logging if we have reached the max file size or
+ // encountered an error
+ StopLoggingInternal();
+ }
+ }
+
+ void ReportAdd(const Slice& key, size_t size) {
+ if (activity_logging_enabled_.load() == false) {
+ return;
+ }
+
+ std::ostringstream oss;
+ // line format: "ADD - <KEY> - <KEY-SIZE>"
+ oss << "ADD - " << key.ToString(true) << " - " << size << std::endl;
+ MutexLock l(&mutex_);
+ Status s = file_writer_->Append(oss.str());
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+
+ if (MaxLoggingSizeReached() || !bg_status_.ok()) {
+ // Stop logging if we have reached the max file size or
+ // encountered an error
+ StopLoggingInternal();
+ }
+ }
+
+ Status& bg_status() {
+ MutexLock l(&mutex_);
+ return bg_status_;
+ }
+
+ private:
+ bool MaxLoggingSizeReached() {
+ mutex_.AssertHeld();
+
+ return (max_logging_size_ > 0 &&
+ file_writer_->GetFileSize() >= max_logging_size_);
+ }
+
+ void StopLoggingInternal() {
+ mutex_.AssertHeld();
+
+ if (!activity_logging_enabled_) {
+ return;
+ }
+
+ activity_logging_enabled_.store(false);
+ Status s = file_writer_->Close();
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+ }
+
+ // Mutex to sync writes to file_writer, and all following
+ // class data members
+ port::Mutex mutex_;
+ // Indicates if logging is currently enabled
+ // atomic to allow reads without mutex
+ std::atomic<bool> activity_logging_enabled_;
+ // When reached, we will stop logging and close the file
+ // Value of 0 means unlimited
+ uint64_t max_logging_size_;
+ std::unique_ptr<WritableFileWriter> file_writer_;
+ Status bg_status_;
+};
+
+// SimCacheImpl definition
+class SimCacheImpl : public SimCache {
+ public:
+ // capacity for real cache (ShardedLRUCache)
+ // test_capacity for key only cache
+ SimCacheImpl(std::shared_ptr<Cache> sim_cache, std::shared_ptr<Cache> cache)
+ : cache_(cache),
+ key_only_cache_(sim_cache),
+ miss_times_(0),
+ hit_times_(0),
+ stats_(nullptr) {}
+
+ ~SimCacheImpl() override {}
+ void SetCapacity(size_t capacity) override { cache_->SetCapacity(capacity); }
+
+ void SetStrictCapacityLimit(bool strict_capacity_limit) override {
+ cache_->SetStrictCapacityLimit(strict_capacity_limit);
+ }
+
+ using Cache::Insert;
+ Status Insert(const Slice& key, void* value, size_t charge,
+ void (*deleter)(const Slice& key, void* value), Handle** handle,
+ Priority priority) override {
+ // The handle and value passed in are for real cache, so we pass nullptr
+ // to key_only_cache_ for both instead. Also, the deleter function pointer
+ // will be called by user to perform some external operation which should
+ // be applied only once. Thus key_only_cache accepts an empty function.
+ // *Lambda function without capture can be assgined to a function pointer
+ Handle* h = key_only_cache_->Lookup(key);
+ if (h == nullptr) {
+ // TODO: Check for error here?
+ auto s = key_only_cache_->Insert(
+ key, nullptr, charge, [](const Slice& /*k*/, void* /*v*/) {}, nullptr,
+ priority);
+ s.PermitUncheckedError();
+ } else {
+ key_only_cache_->Release(h);
+ }
+
+ cache_activity_logger_.ReportAdd(key, charge);
+ if (!cache_) {
+ return Status::OK();
+ }
+ return cache_->Insert(key, value, charge, deleter, handle, priority);
+ }
+
+ using Cache::Lookup;
+ Handle* Lookup(const Slice& key, Statistics* stats) override {
+ Handle* h = key_only_cache_->Lookup(key);
+ if (h != nullptr) {
+ key_only_cache_->Release(h);
+ inc_hit_counter();
+ RecordTick(stats, SIM_BLOCK_CACHE_HIT);
+ } else {
+ inc_miss_counter();
+ RecordTick(stats, SIM_BLOCK_CACHE_MISS);
+ }
+
+ cache_activity_logger_.ReportLookup(key);
+ if (!cache_) {
+ return nullptr;
+ }
+ return cache_->Lookup(key, stats);
+ }
+
+ bool Ref(Handle* handle) override { return cache_->Ref(handle); }
+
+ using Cache::Release;
+ bool Release(Handle* handle, bool erase_if_last_ref = false) override {
+ return cache_->Release(handle, erase_if_last_ref);
+ }
+
+ void Erase(const Slice& key) override {
+ cache_->Erase(key);
+ key_only_cache_->Erase(key);
+ }
+
+ void* Value(Handle* handle) override { return cache_->Value(handle); }
+
+ uint64_t NewId() override { return cache_->NewId(); }
+
+ size_t GetCapacity() const override { return cache_->GetCapacity(); }
+
+ bool HasStrictCapacityLimit() const override {
+ return cache_->HasStrictCapacityLimit();
+ }
+
+ size_t GetUsage() const override { return cache_->GetUsage(); }
+
+ size_t GetUsage(Handle* handle) const override {
+ return cache_->GetUsage(handle);
+ }
+
+ size_t GetCharge(Handle* handle) const override {
+ return cache_->GetCharge(handle);
+ }
+
+ DeleterFn GetDeleter(Handle* handle) const override {
+ return cache_->GetDeleter(handle);
+ }
+
+ size_t GetPinnedUsage() const override { return cache_->GetPinnedUsage(); }
+
+ void DisownData() override {
+ cache_->DisownData();
+ key_only_cache_->DisownData();
+ }
+
+ void ApplyToAllCacheEntries(void (*callback)(void*, size_t),
+ bool thread_safe) override {
+ // only apply to _cache since key_only_cache doesn't hold value
+ cache_->ApplyToAllCacheEntries(callback, thread_safe);
+ }
+
+ void ApplyToAllEntries(
+ const std::function<void(const Slice& key, void* value, size_t charge,
+ DeleterFn deleter)>& callback,
+ const ApplyToAllEntriesOptions& opts) override {
+ cache_->ApplyToAllEntries(callback, opts);
+ }
+
+ void EraseUnRefEntries() override {
+ cache_->EraseUnRefEntries();
+ key_only_cache_->EraseUnRefEntries();
+ }
+
+ size_t GetSimCapacity() const override {
+ return key_only_cache_->GetCapacity();
+ }
+ size_t GetSimUsage() const override { return key_only_cache_->GetUsage(); }
+ void SetSimCapacity(size_t capacity) override {
+ key_only_cache_->SetCapacity(capacity);
+ }
+
+ uint64_t get_miss_counter() const override {
+ return miss_times_.load(std::memory_order_relaxed);
+ }
+
+ uint64_t get_hit_counter() const override {
+ return hit_times_.load(std::memory_order_relaxed);
+ }
+
+ void reset_counter() override {
+ miss_times_.store(0, std::memory_order_relaxed);
+ hit_times_.store(0, std::memory_order_relaxed);
+ SetTickerCount(stats_, SIM_BLOCK_CACHE_HIT, 0);
+ SetTickerCount(stats_, SIM_BLOCK_CACHE_MISS, 0);
+ }
+
+ std::string ToString() const override {
+ std::ostringstream oss;
+ oss << "SimCache MISSes: " << get_miss_counter() << std::endl;
+ oss << "SimCache HITs: " << get_hit_counter() << std::endl;
+ auto lookups = get_miss_counter() + get_hit_counter();
+ oss << "SimCache HITRATE: " << std::fixed << std::setprecision(2)
+ << (lookups == 0 ? 0 : get_hit_counter() * 100.0f / lookups)
+ << std::endl;
+ return oss.str();
+ }
+
+ std::string GetPrintableOptions() const override {
+ std::ostringstream oss;
+ oss << " cache_options:" << std::endl;
+ oss << cache_->GetPrintableOptions();
+ oss << " sim_cache_options:" << std::endl;
+ oss << key_only_cache_->GetPrintableOptions();
+ return oss.str();
+ }
+
+ Status StartActivityLogging(const std::string& activity_log_file, Env* env,
+ uint64_t max_logging_size = 0) override {
+ return cache_activity_logger_.StartLogging(activity_log_file, env,
+ max_logging_size);
+ }
+
+ void StopActivityLogging() override { cache_activity_logger_.StopLogging(); }
+
+ Status GetActivityLoggingStatus() override {
+ return cache_activity_logger_.bg_status();
+ }
+
+ private:
+ std::shared_ptr<Cache> cache_;
+ std::shared_ptr<Cache> key_only_cache_;
+ std::atomic<uint64_t> miss_times_;
+ std::atomic<uint64_t> hit_times_;
+ Statistics* stats_;
+ CacheActivityLogger cache_activity_logger_;
+
+ void inc_miss_counter() {
+ miss_times_.fetch_add(1, std::memory_order_relaxed);
+ }
+ void inc_hit_counter() { hit_times_.fetch_add(1, std::memory_order_relaxed); }
+};
+
+} // end anonymous namespace
+
+// For instrumentation purpose, use NewSimCache instead
+std::shared_ptr<SimCache> NewSimCache(std::shared_ptr<Cache> cache,
+ size_t sim_capacity, int num_shard_bits) {
+ LRUCacheOptions co;
+ co.capacity = sim_capacity;
+ co.num_shard_bits = num_shard_bits;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ return NewSimCache(NewLRUCache(co), cache, num_shard_bits);
+}
+
+std::shared_ptr<SimCache> NewSimCache(std::shared_ptr<Cache> sim_cache,
+ std::shared_ptr<Cache> cache,
+ int num_shard_bits) {
+ if (num_shard_bits >= 20) {
+ return nullptr; // the cache cannot be sharded into too many fine pieces
+ }
+ return std::make_shared<SimCacheImpl>(sim_cache, cache);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc b/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc
new file mode 100644
index 000000000..2e37cd347
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc
@@ -0,0 +1,226 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/sim_cache.h"
+
+#include <cstdlib>
+
+#include "db/db_test_util.h"
+#include "port/stack_trace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SimCacheTest : public DBTestBase {
+ private:
+ size_t miss_count_ = 0;
+ size_t hit_count_ = 0;
+ size_t insert_count_ = 0;
+ size_t failure_count_ = 0;
+
+ public:
+ const size_t kNumBlocks = 5;
+ const size_t kValueSize = 1000;
+
+ SimCacheTest() : DBTestBase("sim_cache_test", /*env_do_fsync=*/true) {}
+
+ BlockBasedTableOptions GetTableOptions() {
+ BlockBasedTableOptions table_options;
+ // Set a small enough block size so that each key-value get its own block.
+ table_options.block_size = 1;
+ return table_options;
+ }
+
+ Options GetOptions(const BlockBasedTableOptions& table_options) {
+ Options options = CurrentOptions();
+ options.create_if_missing = true;
+ // options.compression = kNoCompression;
+ options.statistics = ROCKSDB_NAMESPACE::CreateDBStatistics();
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ return options;
+ }
+
+ void InitTable(const Options& /*options*/) {
+ std::string value(kValueSize, 'a');
+ for (size_t i = 0; i < kNumBlocks * 2; i++) {
+ ASSERT_OK(Put(std::to_string(i), value.c_str()));
+ }
+ }
+
+ void RecordCacheCounters(const Options& options) {
+ miss_count_ = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ hit_count_ = TestGetTickerCount(options, BLOCK_CACHE_HIT);
+ insert_count_ = TestGetTickerCount(options, BLOCK_CACHE_ADD);
+ failure_count_ = TestGetTickerCount(options, BLOCK_CACHE_ADD_FAILURES);
+ }
+
+ void CheckCacheCounters(const Options& options, size_t expected_misses,
+ size_t expected_hits, size_t expected_inserts,
+ size_t expected_failures) {
+ size_t new_miss_count = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ size_t new_hit_count = TestGetTickerCount(options, BLOCK_CACHE_HIT);
+ size_t new_insert_count = TestGetTickerCount(options, BLOCK_CACHE_ADD);
+ size_t new_failure_count =
+ TestGetTickerCount(options, BLOCK_CACHE_ADD_FAILURES);
+ ASSERT_EQ(miss_count_ + expected_misses, new_miss_count);
+ ASSERT_EQ(hit_count_ + expected_hits, new_hit_count);
+ ASSERT_EQ(insert_count_ + expected_inserts, new_insert_count);
+ ASSERT_EQ(failure_count_ + expected_failures, new_failure_count);
+ miss_count_ = new_miss_count;
+ hit_count_ = new_hit_count;
+ insert_count_ = new_insert_count;
+ failure_count_ = new_failure_count;
+ }
+};
+
+TEST_F(SimCacheTest, SimCache) {
+ ReadOptions read_options;
+ auto table_options = GetTableOptions();
+ auto options = GetOptions(table_options);
+ InitTable(options);
+ LRUCacheOptions co;
+ co.capacity = 0;
+ co.num_shard_bits = 0;
+ co.strict_capacity_limit = false;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<SimCache> simCache = NewSimCache(NewLRUCache(co), 20000, 0);
+ table_options.block_cache = simCache;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ Reopen(options);
+ RecordCacheCounters(options);
+ // due to cache entry stats collector
+ uint64_t base_misses = simCache->get_miss_counter();
+
+ std::vector<std::unique_ptr<Iterator>> iterators(kNumBlocks);
+ Iterator* iter = nullptr;
+
+ // Load blocks into cache.
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ iter = db_->NewIterator(read_options);
+ iter->Seek(std::to_string(i));
+ ASSERT_OK(iter->status());
+ CheckCacheCounters(options, 1, 0, 1, 0);
+ iterators[i].reset(iter);
+ }
+ ASSERT_EQ(kNumBlocks, simCache->get_hit_counter() +
+ simCache->get_miss_counter() - base_misses);
+ ASSERT_EQ(0, simCache->get_hit_counter());
+ size_t usage = simCache->GetUsage();
+ ASSERT_LT(0, usage);
+ ASSERT_EQ(usage, simCache->GetSimUsage());
+ simCache->SetCapacity(usage);
+ ASSERT_EQ(usage, simCache->GetPinnedUsage());
+
+ // Test with strict capacity limit.
+ simCache->SetStrictCapacityLimit(true);
+ iter = db_->NewIterator(read_options);
+ iter->Seek(std::to_string(kNumBlocks * 2 - 1));
+ ASSERT_TRUE(iter->status().IsMemoryLimit());
+ CheckCacheCounters(options, 1, 0, 0, 1);
+ delete iter;
+ iter = nullptr;
+
+ // Release iterators and access cache again.
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ iterators[i].reset();
+ CheckCacheCounters(options, 0, 0, 0, 0);
+ }
+ // Add kNumBlocks again
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ std::unique_ptr<Iterator> it(db_->NewIterator(read_options));
+ it->Seek(std::to_string(i));
+ ASSERT_OK(it->status());
+ CheckCacheCounters(options, 0, 1, 0, 0);
+ }
+ ASSERT_EQ(5, simCache->get_hit_counter());
+ for (size_t i = kNumBlocks; i < kNumBlocks * 2; i++) {
+ std::unique_ptr<Iterator> it(db_->NewIterator(read_options));
+ it->Seek(std::to_string(i));
+ ASSERT_OK(it->status());
+ CheckCacheCounters(options, 1, 0, 1, 0);
+ }
+ ASSERT_EQ(0, simCache->GetPinnedUsage());
+ ASSERT_EQ(3 * kNumBlocks + 1, simCache->get_hit_counter() +
+ simCache->get_miss_counter() - base_misses);
+ ASSERT_EQ(6, simCache->get_hit_counter());
+}
+
+TEST_F(SimCacheTest, SimCacheLogging) {
+ auto table_options = GetTableOptions();
+ auto options = GetOptions(table_options);
+ options.disable_auto_compactions = true;
+ LRUCacheOptions co;
+ co.capacity = 1024 * 1024;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<SimCache> sim_cache = NewSimCache(NewLRUCache(co), 20000, 0);
+ table_options.block_cache = sim_cache;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ Reopen(options);
+
+ int num_block_entries = 20;
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_OK(Put(Key(i), "val"));
+ ASSERT_OK(Flush());
+ }
+
+ std::string log_file = test::PerThreadDBPath(env_, "cache_log.txt");
+ ASSERT_OK(sim_cache->StartActivityLogging(log_file, env_));
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ sim_cache->StopActivityLogging();
+ ASSERT_OK(sim_cache->GetActivityLoggingStatus());
+
+ std::string file_contents = "";
+ ASSERT_OK(ReadFileToString(env_, log_file, &file_contents));
+ std::istringstream contents(file_contents);
+
+ int lookup_num = 0;
+ int add_num = 0;
+
+ std::string line;
+ // count number of lookups and additions
+ while (std::getline(contents, line)) {
+ // check if the line starts with LOOKUP or ADD
+ if (line.rfind("LOOKUP -", 0) == 0) {
+ ++lookup_num;
+ }
+ if (line.rfind("ADD -", 0) == 0) {
+ ++add_num;
+ }
+ }
+
+ // We asked for every block twice
+ ASSERT_EQ(lookup_num, num_block_entries * 2);
+
+ // We added every block only once, since the cache can hold all blocks
+ ASSERT_EQ(add_num, num_block_entries);
+
+ // Log things again but stop logging automatically after reaching 512 bytes
+ int max_size = 512;
+ ASSERT_OK(sim_cache->StartActivityLogging(log_file, env_, max_size));
+ for (int it = 0; it < 10; it++) {
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ }
+ ASSERT_OK(sim_cache->GetActivityLoggingStatus());
+
+ uint64_t fsize = 0;
+ ASSERT_OK(env_->GetFileSize(log_file, &fsize));
+ // error margin of 100 bytes
+ ASSERT_LT(fsize, max_size + 100);
+ ASSERT_GT(fsize, max_size - 100);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc
new file mode 100644
index 000000000..16f33934d
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc
@@ -0,0 +1,227 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/table_properties_collectors/compact_on_deletion_collector.h"
+
+#include <memory>
+#include <sstream>
+
+#include "rocksdb/utilities/customizable_util.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "rocksdb/utilities/options_type.h"
+#include "rocksdb/utilities/table_properties_collectors.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+#ifndef ROCKSDB_LITE
+
+CompactOnDeletionCollector::CompactOnDeletionCollector(
+ size_t sliding_window_size, size_t deletion_trigger, double deletion_ratio)
+ : bucket_size_((sliding_window_size + kNumBuckets - 1) / kNumBuckets),
+ current_bucket_(0),
+ num_keys_in_current_bucket_(0),
+ num_deletions_in_observation_window_(0),
+ deletion_trigger_(deletion_trigger),
+ deletion_ratio_(deletion_ratio),
+ deletion_ratio_enabled_(deletion_ratio > 0 && deletion_ratio <= 1),
+ need_compaction_(false),
+ finished_(false) {
+ memset(num_deletions_in_buckets_, 0, sizeof(size_t) * kNumBuckets);
+}
+
+// AddUserKey() will be called when a new key/value pair is inserted into the
+// table.
+// @params key the user key that is inserted into the table.
+// @params value the value that is inserted into the table.
+// @params file_size file size up to now
+Status CompactOnDeletionCollector::AddUserKey(const Slice& /*key*/,
+ const Slice& /*value*/,
+ EntryType type,
+ SequenceNumber /*seq*/,
+ uint64_t /*file_size*/) {
+ assert(!finished_);
+ if (!bucket_size_ && !deletion_ratio_enabled_) {
+ // This collector is effectively disabled
+ return Status::OK();
+ }
+
+ if (need_compaction_) {
+ // If the output file already needs to be compacted, skip the check.
+ return Status::OK();
+ }
+
+ if (deletion_ratio_enabled_) {
+ total_entries_++;
+ if (type == kEntryDelete) {
+ deletion_entries_++;
+ }
+ }
+
+ if (bucket_size_) {
+ if (num_keys_in_current_bucket_ == bucket_size_) {
+ // When the current bucket is full, advance the cursor of the
+ // ring buffer to the next bucket.
+ current_bucket_ = (current_bucket_ + 1) % kNumBuckets;
+
+ // Update the current count of observed deletion keys by excluding
+ // the number of deletion keys in the oldest bucket in the
+ // observation window.
+ assert(num_deletions_in_observation_window_ >=
+ num_deletions_in_buckets_[current_bucket_]);
+ num_deletions_in_observation_window_ -=
+ num_deletions_in_buckets_[current_bucket_];
+ num_deletions_in_buckets_[current_bucket_] = 0;
+ num_keys_in_current_bucket_ = 0;
+ }
+
+ num_keys_in_current_bucket_++;
+ if (type == kEntryDelete) {
+ num_deletions_in_observation_window_++;
+ num_deletions_in_buckets_[current_bucket_]++;
+ if (num_deletions_in_observation_window_ >= deletion_trigger_) {
+ need_compaction_ = true;
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status CompactOnDeletionCollector::Finish(
+ UserCollectedProperties* /*properties*/) {
+ if (!need_compaction_ && deletion_ratio_enabled_ && total_entries_ > 0) {
+ double ratio = static_cast<double>(deletion_entries_) / total_entries_;
+ need_compaction_ = ratio >= deletion_ratio_;
+ }
+ finished_ = true;
+ return Status::OK();
+}
+static std::unordered_map<std::string, OptionTypeInfo>
+ on_deletion_collector_type_info = {
+#ifndef ROCKSDB_LITE
+ {"window_size",
+ {0, OptionType::kUnknown, OptionVerificationType::kNormal,
+ OptionTypeFlags::kCompareNever | OptionTypeFlags::kMutable,
+ [](const ConfigOptions&, const std::string&, const std::string& value,
+ void* addr) {
+ auto* factory =
+ static_cast<CompactOnDeletionCollectorFactory*>(addr);
+ factory->SetWindowSize(ParseSizeT(value));
+ return Status::OK();
+ },
+ [](const ConfigOptions&, const std::string&, const void* addr,
+ std::string* value) {
+ const auto* factory =
+ static_cast<const CompactOnDeletionCollectorFactory*>(addr);
+ *value = std::to_string(factory->GetWindowSize());
+ return Status::OK();
+ },
+ nullptr}},
+ {"deletion_trigger",
+ {0, OptionType::kUnknown, OptionVerificationType::kNormal,
+ OptionTypeFlags::kCompareNever | OptionTypeFlags::kMutable,
+ [](const ConfigOptions&, const std::string&, const std::string& value,
+ void* addr) {
+ auto* factory =
+ static_cast<CompactOnDeletionCollectorFactory*>(addr);
+ factory->SetDeletionTrigger(ParseSizeT(value));
+ return Status::OK();
+ },
+ [](const ConfigOptions&, const std::string&, const void* addr,
+ std::string* value) {
+ const auto* factory =
+ static_cast<const CompactOnDeletionCollectorFactory*>(addr);
+ *value = std::to_string(factory->GetDeletionTrigger());
+ return Status::OK();
+ },
+ nullptr}},
+ {"deletion_ratio",
+ {0, OptionType::kUnknown, OptionVerificationType::kNormal,
+ OptionTypeFlags::kCompareNever | OptionTypeFlags::kMutable,
+ [](const ConfigOptions&, const std::string&, const std::string& value,
+ void* addr) {
+ auto* factory =
+ static_cast<CompactOnDeletionCollectorFactory*>(addr);
+ factory->SetDeletionRatio(ParseDouble(value));
+ return Status::OK();
+ },
+ [](const ConfigOptions&, const std::string&, const void* addr,
+ std::string* value) {
+ const auto* factory =
+ static_cast<const CompactOnDeletionCollectorFactory*>(addr);
+ *value = std::to_string(factory->GetDeletionRatio());
+ return Status::OK();
+ },
+ nullptr}},
+
+#endif // ROCKSDB_LITE
+};
+
+CompactOnDeletionCollectorFactory::CompactOnDeletionCollectorFactory(
+ size_t sliding_window_size, size_t deletion_trigger, double deletion_ratio)
+ : sliding_window_size_(sliding_window_size),
+ deletion_trigger_(deletion_trigger),
+ deletion_ratio_(deletion_ratio) {
+ RegisterOptions("", this, &on_deletion_collector_type_info);
+}
+
+TablePropertiesCollector*
+CompactOnDeletionCollectorFactory::CreateTablePropertiesCollector(
+ TablePropertiesCollectorFactory::Context /*context*/) {
+ return new CompactOnDeletionCollector(sliding_window_size_.load(),
+ deletion_trigger_.load(),
+ deletion_ratio_.load());
+}
+
+std::string CompactOnDeletionCollectorFactory::ToString() const {
+ std::ostringstream cfg;
+ cfg << Name() << " (Sliding window size = " << sliding_window_size_.load()
+ << " Deletion trigger = " << deletion_trigger_.load()
+ << " Deletion ratio = " << deletion_ratio_.load() << ')';
+ return cfg.str();
+}
+
+std::shared_ptr<CompactOnDeletionCollectorFactory>
+NewCompactOnDeletionCollectorFactory(size_t sliding_window_size,
+ size_t deletion_trigger,
+ double deletion_ratio) {
+ return std::shared_ptr<CompactOnDeletionCollectorFactory>(
+ new CompactOnDeletionCollectorFactory(sliding_window_size,
+ deletion_trigger, deletion_ratio));
+}
+namespace {
+static int RegisterTablePropertiesCollectorFactories(
+ ObjectLibrary& library, const std::string& /*arg*/) {
+ library.AddFactory<TablePropertiesCollectorFactory>(
+ CompactOnDeletionCollectorFactory::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<TablePropertiesCollectorFactory>* guard,
+ std::string* /* errmsg */) {
+ // By default, create a CompactionOnDeletionCollector that is disabled.
+ // Users will need to provide configuration parameters or call the
+ // corresponding Setter to enable the factory.
+ guard->reset(new CompactOnDeletionCollectorFactory(0, 0, 0));
+ return guard->get();
+ });
+ return 1;
+}
+} // namespace
+#endif // !ROCKSDB_LITE
+
+Status TablePropertiesCollectorFactory::CreateFromString(
+ const ConfigOptions& options, const std::string& value,
+ std::shared_ptr<TablePropertiesCollectorFactory>* result) {
+#ifndef ROCKSDB_LITE
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ RegisterTablePropertiesCollectorFactories(*(ObjectLibrary::Default().get()),
+ "");
+ });
+#endif // ROCKSDB_LITE
+ return LoadSharedObject<TablePropertiesCollectorFactory>(options, value,
+ nullptr, result);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h
new file mode 100644
index 000000000..2f7dc4f1b
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+#include "rocksdb/utilities/table_properties_collectors.h"
+namespace ROCKSDB_NAMESPACE {
+
+class CompactOnDeletionCollector : public TablePropertiesCollector {
+ public:
+ CompactOnDeletionCollector(size_t sliding_window_size,
+ size_t deletion_trigger, double deletion_raatio);
+
+ // AddUserKey() will be called when a new key/value pair is inserted into the
+ // table.
+ // @params key the user key that is inserted into the table.
+ // @params value the value that is inserted into the table.
+ // @params file_size file size up to now
+ virtual Status AddUserKey(const Slice& key, const Slice& value,
+ EntryType type, SequenceNumber seq,
+ uint64_t file_size) override;
+
+ // Finish() will be called when a table has already been built and is ready
+ // for writing the properties block.
+ // @params properties User will add their collected statistics to
+ // `properties`.
+ virtual Status Finish(UserCollectedProperties* /*properties*/) override;
+
+ // Return the human-readable properties, where the key is property name and
+ // the value is the human-readable form of value.
+ virtual UserCollectedProperties GetReadableProperties() const override {
+ return UserCollectedProperties();
+ }
+
+ // The name of the properties collector can be used for debugging purpose.
+ virtual const char* Name() const override {
+ return "CompactOnDeletionCollector";
+ }
+
+ // EXPERIMENTAL Return whether the output file should be further compacted
+ virtual bool NeedCompact() const override { return need_compaction_; }
+
+ static const int kNumBuckets = 128;
+
+ private:
+ void Reset();
+
+ // A ring buffer that used to count the number of deletion entries for every
+ // "bucket_size_" keys.
+ size_t num_deletions_in_buckets_[kNumBuckets];
+ // the number of keys in a bucket
+ size_t bucket_size_;
+
+ size_t current_bucket_;
+ size_t num_keys_in_current_bucket_;
+ size_t num_deletions_in_observation_window_;
+ size_t deletion_trigger_;
+ const double deletion_ratio_;
+ const bool deletion_ratio_enabled_;
+ size_t total_entries_ = 0;
+ size_t deletion_entries_ = 0;
+ // true if the current SST file needs to be compacted.
+ bool need_compaction_;
+ bool finished_;
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc
new file mode 100644
index 000000000..88aeb8d5c
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc
@@ -0,0 +1,245 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include <stdio.h>
+
+#ifndef ROCKSDB_LITE
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "port/stack_trace.h"
+#include "rocksdb/table.h"
+#include "rocksdb/table_properties.h"
+#include "rocksdb/utilities/table_properties_collectors.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/table_properties_collectors/compact_on_deletion_collector.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+TEST(CompactOnDeletionCollector, DeletionRatio) {
+ TablePropertiesCollectorFactory::Context context;
+ context.column_family_id =
+ TablePropertiesCollectorFactory::Context::kUnknownColumnFamily;
+ const size_t kTotalEntries = 100;
+
+ {
+ // Disable deletion ratio.
+ for (double deletion_ratio : {-1.5, -1.0, 0.0, 1.5, 2.0}) {
+ auto factory = NewCompactOnDeletionCollectorFactory(0, 0, deletion_ratio);
+ std::unique_ptr<TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ for (size_t i = 0; i < kTotalEntries; i++) {
+ // All entries are deletion entries.
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryDelete, 0, 0));
+ ASSERT_FALSE(collector->NeedCompact());
+ }
+ ASSERT_OK(collector->Finish(nullptr));
+ ASSERT_FALSE(collector->NeedCompact());
+ }
+ }
+
+ {
+ for (double deletion_ratio : {0.3, 0.5, 0.8, 1.0}) {
+ auto factory = NewCompactOnDeletionCollectorFactory(0, 0, deletion_ratio);
+ const size_t deletion_entries_trigger =
+ static_cast<size_t>(deletion_ratio * kTotalEntries);
+ for (int delta : {-1, 0, 1}) {
+ // Actual deletion entry ratio <, =, > deletion_ratio
+ size_t actual_deletion_entries = deletion_entries_trigger + delta;
+ std::unique_ptr<TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ for (size_t i = 0; i < kTotalEntries; i++) {
+ if (i < actual_deletion_entries) {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryDelete, 0, 0));
+ } else {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ ASSERT_FALSE(collector->NeedCompact());
+ }
+ ASSERT_OK(collector->Finish(nullptr));
+ if (delta >= 0) {
+ // >= deletion_ratio
+ ASSERT_TRUE(collector->NeedCompact());
+ } else {
+ ASSERT_FALSE(collector->NeedCompact());
+ }
+ }
+ }
+ }
+}
+
+TEST(CompactOnDeletionCollector, SlidingWindow) {
+ const int kWindowSizes[] = {1000, 10000, 10000, 127, 128, 129,
+ 255, 256, 257, 2, 10000};
+ const int kDeletionTriggers[] = {500, 9500, 4323, 47, 61, 128,
+ 250, 250, 250, 2, 2};
+ TablePropertiesCollectorFactory::Context context;
+ context.column_family_id =
+ TablePropertiesCollectorFactory::Context::kUnknownColumnFamily;
+
+ std::vector<int> window_sizes;
+ std::vector<int> deletion_triggers;
+ // deterministic tests
+ for (int test = 0; test < 9; ++test) {
+ window_sizes.emplace_back(kWindowSizes[test]);
+ deletion_triggers.emplace_back(kDeletionTriggers[test]);
+ }
+
+ // randomize tests
+ Random rnd(301);
+ const int kMaxTestSize = 100000l;
+ for (int random_test = 0; random_test < 10; random_test++) {
+ int window_size = rnd.Uniform(kMaxTestSize) + 1;
+ int deletion_trigger = rnd.Uniform(window_size);
+ window_sizes.emplace_back(window_size);
+ deletion_triggers.emplace_back(deletion_trigger);
+ }
+
+ assert(window_sizes.size() == deletion_triggers.size());
+
+ for (size_t test = 0; test < window_sizes.size(); ++test) {
+ const int kBucketSize = 128;
+ const int kWindowSize = window_sizes[test];
+ const int kPaddedWindowSize =
+ kBucketSize * ((window_sizes[test] + kBucketSize - 1) / kBucketSize);
+ const int kNumDeletionTrigger = deletion_triggers[test];
+ const int kBias = (kNumDeletionTrigger + kBucketSize - 1) / kBucketSize;
+ // Simple test
+ {
+ auto factory = NewCompactOnDeletionCollectorFactory(kWindowSize,
+ kNumDeletionTrigger);
+ const int kSample = 10;
+ for (int delete_rate = 0; delete_rate <= kSample; ++delete_rate) {
+ std::unique_ptr<TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ int deletions = 0;
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i % kSample < delete_rate) {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryDelete, 0, 0));
+ deletions++;
+ } else {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ }
+ if (collector->NeedCompact() != (deletions >= kNumDeletionTrigger) &&
+ std::abs(deletions - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr,
+ "[Error] collector->NeedCompact() != (%d >= %d)"
+ " with kWindowSize = %d and kNumDeletionTrigger = %d\n",
+ deletions, kNumDeletionTrigger, kWindowSize,
+ kNumDeletionTrigger);
+ ASSERT_TRUE(false);
+ }
+ ASSERT_OK(collector->Finish(nullptr));
+ }
+ }
+
+ // Only one section of a file satisfies the compaction trigger
+ {
+ auto factory = NewCompactOnDeletionCollectorFactory(kWindowSize,
+ kNumDeletionTrigger);
+ const int kSample = 10;
+ for (int delete_rate = 0; delete_rate <= kSample; ++delete_rate) {
+ std::unique_ptr<TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ int deletions = 0;
+ for (int section = 0; section < 5; ++section) {
+ int initial_entries = rnd.Uniform(kWindowSize) + kWindowSize;
+ for (int i = 0; i < initial_entries; ++i) {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ }
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i % kSample < delete_rate) {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryDelete, 0, 0));
+ deletions++;
+ } else {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ }
+ for (int section = 0; section < 5; ++section) {
+ int ending_entries = rnd.Uniform(kWindowSize) + kWindowSize;
+ for (int i = 0; i < ending_entries; ++i) {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ }
+ if (collector->NeedCompact() != (deletions >= kNumDeletionTrigger) &&
+ std::abs(deletions - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr,
+ "[Error] collector->NeedCompact() %d != (%d >= %d)"
+ " with kWindowSize = %d, kNumDeletionTrigger = %d\n",
+ collector->NeedCompact(), deletions, kNumDeletionTrigger,
+ kWindowSize, kNumDeletionTrigger);
+ ASSERT_TRUE(false);
+ }
+ ASSERT_OK(collector->Finish(nullptr));
+ }
+ }
+
+ // TEST 3: Issues a lots of deletes, but their density is not
+ // high enough to trigger compaction.
+ {
+ std::unique_ptr<TablePropertiesCollector> collector;
+ auto factory = NewCompactOnDeletionCollectorFactory(kWindowSize,
+ kNumDeletionTrigger);
+ collector.reset(factory->CreateTablePropertiesCollector(context));
+ assert(collector->NeedCompact() == false);
+ // Insert "kNumDeletionTrigger * 0.95" deletions for every
+ // "kWindowSize" and verify compaction is not needed.
+ const int kDeletionsPerSection = kNumDeletionTrigger * 95 / 100;
+ if (kDeletionsPerSection >= 0) {
+ for (int section = 0; section < 200; ++section) {
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i < kDeletionsPerSection) {
+ ASSERT_OK(collector->AddUserKey("hello", "rocksdb", kEntryDelete,
+ 0, 0));
+ } else {
+ ASSERT_OK(
+ collector->AddUserKey("hello", "rocksdb", kEntryPut, 0, 0));
+ }
+ }
+ }
+ if (collector->NeedCompact() &&
+ std::abs(kDeletionsPerSection - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr,
+ "[Error] collector->NeedCompact() != false"
+ " with kWindowSize = %d and kNumDeletionTrigger = %d\n",
+ kWindowSize, kNumDeletionTrigger);
+ ASSERT_TRUE(false);
+ }
+ ASSERT_OK(collector->Finish(nullptr));
+ }
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+#else
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as RocksDBLite does not include utilities.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/trace/file_trace_reader_writer.cc b/src/rocksdb/utilities/trace/file_trace_reader_writer.cc
new file mode 100644
index 000000000..5886d3539
--- /dev/null
+++ b/src/rocksdb/utilities/trace/file_trace_reader_writer.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/trace/file_trace_reader_writer.h"
+
+#include "env/composite_env_wrapper.h"
+#include "file/random_access_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "trace_replay/trace_replay.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+const unsigned int FileTraceReader::kBufferSize = 1024; // 1KB
+
+FileTraceReader::FileTraceReader(
+ std::unique_ptr<RandomAccessFileReader>&& reader)
+ : file_reader_(std::move(reader)),
+ offset_(0),
+ buffer_(new char[kBufferSize]) {}
+
+FileTraceReader::~FileTraceReader() {
+ Close().PermitUncheckedError();
+ delete[] buffer_;
+}
+
+Status FileTraceReader::Close() {
+ file_reader_.reset();
+ return Status::OK();
+}
+
+Status FileTraceReader::Reset() {
+ if (file_reader_ == nullptr) {
+ return Status::IOError("TraceReader is closed.");
+ }
+ offset_ = 0;
+ return Status::OK();
+}
+
+Status FileTraceReader::Read(std::string* data) {
+ assert(file_reader_ != nullptr);
+ Status s = file_reader_->Read(IOOptions(), offset_, kTraceMetadataSize,
+ &result_, buffer_, nullptr,
+ Env::IO_TOTAL /* rate_limiter_priority */);
+ if (!s.ok()) {
+ return s;
+ }
+ if (result_.size() == 0) {
+ // No more data to read
+ // Todo: Come up with a better way to indicate end of data. May be this
+ // could be avoided once footer is introduced.
+ return Status::Incomplete();
+ }
+ if (result_.size() < kTraceMetadataSize) {
+ return Status::Corruption("Corrupted trace file.");
+ }
+ *data = result_.ToString();
+ offset_ += kTraceMetadataSize;
+
+ uint32_t payload_len =
+ DecodeFixed32(&buffer_[kTraceTimestampSize + kTraceTypeSize]);
+
+ // Read Payload
+ unsigned int bytes_to_read = payload_len;
+ unsigned int to_read =
+ bytes_to_read > kBufferSize ? kBufferSize : bytes_to_read;
+ while (to_read > 0) {
+ s = file_reader_->Read(IOOptions(), offset_, to_read, &result_, buffer_,
+ nullptr, Env::IO_TOTAL /* rate_limiter_priority */);
+ if (!s.ok()) {
+ return s;
+ }
+ if (result_.size() < to_read) {
+ return Status::Corruption("Corrupted trace file.");
+ }
+ data->append(result_.data(), result_.size());
+
+ offset_ += to_read;
+ bytes_to_read -= to_read;
+ to_read = bytes_to_read > kBufferSize ? kBufferSize : bytes_to_read;
+ }
+
+ return s;
+}
+
+FileTraceWriter::FileTraceWriter(
+ std::unique_ptr<WritableFileWriter>&& file_writer)
+ : file_writer_(std::move(file_writer)) {}
+
+FileTraceWriter::~FileTraceWriter() { Close().PermitUncheckedError(); }
+
+Status FileTraceWriter::Close() {
+ file_writer_.reset();
+ return Status::OK();
+}
+
+Status FileTraceWriter::Write(const Slice& data) {
+ return file_writer_->Append(data);
+}
+
+uint64_t FileTraceWriter::GetFileSize() { return file_writer_->GetFileSize(); }
+
+Status NewFileTraceReader(Env* env, const EnvOptions& env_options,
+ const std::string& trace_filename,
+ std::unique_ptr<TraceReader>* trace_reader) {
+ std::unique_ptr<RandomAccessFileReader> file_reader;
+ Status s = RandomAccessFileReader::Create(
+ env->GetFileSystem(), trace_filename, FileOptions(env_options),
+ &file_reader, nullptr);
+ if (!s.ok()) {
+ return s;
+ }
+ trace_reader->reset(new FileTraceReader(std::move(file_reader)));
+ return s;
+}
+
+Status NewFileTraceWriter(Env* env, const EnvOptions& env_options,
+ const std::string& trace_filename,
+ std::unique_ptr<TraceWriter>* trace_writer) {
+ std::unique_ptr<WritableFileWriter> file_writer;
+ Status s = WritableFileWriter::Create(env->GetFileSystem(), trace_filename,
+ FileOptions(env_options), &file_writer,
+ nullptr);
+ if (!s.ok()) {
+ return s;
+ }
+ trace_writer->reset(new FileTraceWriter(std::move(file_writer)));
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/trace/file_trace_reader_writer.h b/src/rocksdb/utilities/trace/file_trace_reader_writer.h
new file mode 100644
index 000000000..65d483108
--- /dev/null
+++ b/src/rocksdb/utilities/trace/file_trace_reader_writer.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "rocksdb/trace_reader_writer.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class RandomAccessFileReader;
+class WritableFileWriter;
+
+// FileTraceReader allows reading RocksDB traces from a file.
+class FileTraceReader : public TraceReader {
+ public:
+ explicit FileTraceReader(std::unique_ptr<RandomAccessFileReader>&& reader);
+ ~FileTraceReader();
+
+ virtual Status Read(std::string* data) override;
+ virtual Status Close() override;
+ virtual Status Reset() override;
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> file_reader_;
+ Slice result_;
+ size_t offset_;
+ char* const buffer_;
+
+ static const unsigned int kBufferSize;
+};
+
+// FileTraceWriter allows writing RocksDB traces to a file.
+class FileTraceWriter : public TraceWriter {
+ public:
+ explicit FileTraceWriter(std::unique_ptr<WritableFileWriter>&& file_writer);
+ ~FileTraceWriter();
+
+ virtual Status Write(const Slice& data) override;
+ virtual Status Close() override;
+ virtual uint64_t GetFileSize() override;
+
+ private:
+ std::unique_ptr<WritableFileWriter> file_writer_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/trace/replayer_impl.cc b/src/rocksdb/utilities/trace/replayer_impl.cc
new file mode 100644
index 000000000..31023f1a2
--- /dev/null
+++ b/src/rocksdb/utilities/trace/replayer_impl.cc
@@ -0,0 +1,316 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/trace/replayer_impl.h"
+
+#include <cmath>
+#include <thread>
+
+#include "rocksdb/options.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/system_clock.h"
+#include "util/threadpool_imp.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+ReplayerImpl::ReplayerImpl(DB* db,
+ const std::vector<ColumnFamilyHandle*>& handles,
+ std::unique_ptr<TraceReader>&& reader)
+ : Replayer(),
+ trace_reader_(std::move(reader)),
+ prepared_(false),
+ trace_end_(false),
+ header_ts_(0),
+ exec_handler_(TraceRecord::NewExecutionHandler(db, handles)),
+ env_(db->GetEnv()),
+ trace_file_version_(-1) {}
+
+ReplayerImpl::~ReplayerImpl() {
+ exec_handler_.reset();
+ trace_reader_.reset();
+}
+
+Status ReplayerImpl::Prepare() {
+ Trace header;
+ int db_version;
+ Status s = ReadHeader(&header);
+ if (!s.ok()) {
+ return s;
+ }
+ s = TracerHelper::ParseTraceHeader(header, &trace_file_version_, &db_version);
+ if (!s.ok()) {
+ return s;
+ }
+ header_ts_ = header.ts;
+ prepared_ = true;
+ trace_end_ = false;
+ return Status::OK();
+}
+
+Status ReplayerImpl::Next(std::unique_ptr<TraceRecord>* record) {
+ if (!prepared_) {
+ return Status::Incomplete("Not prepared!");
+ }
+ if (trace_end_) {
+ return Status::Incomplete("Trace end.");
+ }
+
+ Trace trace;
+ Status s = ReadTrace(&trace); // ReadTrace is atomic
+ // Reached the trace end.
+ if (s.ok() && trace.type == kTraceEnd) {
+ trace_end_ = true;
+ return Status::Incomplete("Trace end.");
+ }
+ if (!s.ok() || record == nullptr) {
+ return s;
+ }
+
+ return TracerHelper::DecodeTraceRecord(&trace, trace_file_version_, record);
+}
+
+Status ReplayerImpl::Execute(const std::unique_ptr<TraceRecord>& record,
+ std::unique_ptr<TraceRecordResult>* result) {
+ return record->Accept(exec_handler_.get(), result);
+}
+
+Status ReplayerImpl::Replay(
+ const ReplayOptions& options,
+ const std::function<void(Status, std::unique_ptr<TraceRecordResult>&&)>&
+ result_callback) {
+ if (options.fast_forward <= 0.0) {
+ return Status::InvalidArgument("Wrong fast forward speed!");
+ }
+
+ if (!prepared_) {
+ return Status::Incomplete("Not prepared!");
+ }
+ if (trace_end_) {
+ return Status::Incomplete("Trace end.");
+ }
+
+ Status s = Status::OK();
+
+ if (options.num_threads <= 1) {
+ // num_threads == 0 or num_threads == 1 uses single thread.
+ std::chrono::system_clock::time_point replay_epoch =
+ std::chrono::system_clock::now();
+
+ while (s.ok()) {
+ Trace trace;
+ s = ReadTrace(&trace);
+ // If already at trace end, ReadTrace should return Status::Incomplete().
+ if (!s.ok()) {
+ break;
+ }
+
+ // No need to sleep before breaking the loop if at the trace end.
+ if (trace.type == kTraceEnd) {
+ trace_end_ = true;
+ s = Status::Incomplete("Trace end.");
+ break;
+ }
+
+ // In single-threaded replay, decode first then sleep.
+ std::unique_ptr<TraceRecord> record;
+ s = TracerHelper::DecodeTraceRecord(&trace, trace_file_version_, &record);
+ if (!s.ok() && !s.IsNotSupported()) {
+ break;
+ }
+
+ std::chrono::system_clock::time_point sleep_to =
+ replay_epoch +
+ std::chrono::microseconds(static_cast<uint64_t>(std::llround(
+ 1.0 * (trace.ts - header_ts_) / options.fast_forward)));
+ if (sleep_to > std::chrono::system_clock::now()) {
+ std::this_thread::sleep_until(sleep_to);
+ }
+
+ // Skip unsupported traces, stop for other errors.
+ if (s.IsNotSupported()) {
+ if (result_callback != nullptr) {
+ result_callback(s, nullptr);
+ }
+ s = Status::OK();
+ continue;
+ }
+
+ if (result_callback == nullptr) {
+ s = Execute(record, nullptr);
+ } else {
+ std::unique_ptr<TraceRecordResult> res;
+ s = Execute(record, &res);
+ result_callback(s, std::move(res));
+ }
+ }
+ } else {
+ // Multi-threaded replay.
+ ThreadPoolImpl thread_pool;
+ thread_pool.SetHostEnv(env_);
+ thread_pool.SetBackgroundThreads(static_cast<int>(options.num_threads));
+
+ std::mutex mtx;
+ // Background decoding and execution status.
+ Status bg_s = Status::OK();
+ uint64_t last_err_ts = static_cast<uint64_t>(-1);
+ // Callback function used in background work to update bg_s for the ealiest
+ // TraceRecord which has execution error. This is different from the
+ // timestamp of the first execution error (either start or end timestamp).
+ //
+ // Suppose TraceRecord R1, R2, with timestamps T1 < T2. Their execution
+ // timestamps are T1_start, T1_end, T2_start, T2_end.
+ // Single-thread: there must be T1_start < T1_end < T2_start < T2_end.
+ // Multi-thread: T1_start < T2_start may not be enforced. Orders of them are
+ // totally unknown.
+ // In order to report the same `first` error in both single-thread and
+ // multi-thread replay, we can only rely on the TraceRecords' timestamps,
+ // rather than their executin timestamps. Although in single-thread replay,
+ // the first error is also the last error, while in multi-thread replay, the
+ // first error may not be the first error in execution, and it may not be
+ // the last error in exeution as well.
+ auto error_cb = [&mtx, &bg_s, &last_err_ts](Status err, uint64_t err_ts) {
+ std::lock_guard<std::mutex> gd(mtx);
+ // Only record the first error.
+ if (!err.ok() && !err.IsNotSupported() && err_ts < last_err_ts) {
+ bg_s = err;
+ last_err_ts = err_ts;
+ }
+ };
+
+ std::chrono::system_clock::time_point replay_epoch =
+ std::chrono::system_clock::now();
+
+ while (bg_s.ok() && s.ok()) {
+ Trace trace;
+ s = ReadTrace(&trace);
+ // If already at trace end, ReadTrace should return Status::Incomplete().
+ if (!s.ok()) {
+ break;
+ }
+
+ TraceType trace_type = trace.type;
+
+ // No need to sleep before breaking the loop if at the trace end.
+ if (trace_type == kTraceEnd) {
+ trace_end_ = true;
+ s = Status::Incomplete("Trace end.");
+ break;
+ }
+
+ // In multi-threaded replay, sleep first then start decoding and
+ // execution in a thread.
+ std::chrono::system_clock::time_point sleep_to =
+ replay_epoch +
+ std::chrono::microseconds(static_cast<uint64_t>(std::llround(
+ 1.0 * (trace.ts - header_ts_) / options.fast_forward)));
+ if (sleep_to > std::chrono::system_clock::now()) {
+ std::this_thread::sleep_until(sleep_to);
+ }
+
+ if (trace_type == kTraceWrite || trace_type == kTraceGet ||
+ trace_type == kTraceIteratorSeek ||
+ trace_type == kTraceIteratorSeekForPrev ||
+ trace_type == kTraceMultiGet) {
+ std::unique_ptr<ReplayerWorkerArg> ra(new ReplayerWorkerArg);
+ ra->trace_entry = std::move(trace);
+ ra->handler = exec_handler_.get();
+ ra->trace_file_version = trace_file_version_;
+ ra->error_cb = error_cb;
+ ra->result_cb = result_callback;
+ thread_pool.Schedule(&ReplayerImpl::BackgroundWork, ra.release(),
+ nullptr, nullptr);
+ } else {
+ // Skip unsupported traces.
+ if (result_callback != nullptr) {
+ result_callback(Status::NotSupported("Unsupported trace type."),
+ nullptr);
+ }
+ }
+ }
+
+ thread_pool.WaitForJobsAndJoinAllThreads();
+ if (!bg_s.ok()) {
+ s = bg_s;
+ }
+ }
+
+ if (s.IsIncomplete()) {
+ // Reaching eof returns Incomplete status at the moment.
+ // Could happen when killing a process without calling EndTrace() API.
+ // TODO: Add better error handling.
+ trace_end_ = true;
+ return Status::OK();
+ }
+ return s;
+}
+
+uint64_t ReplayerImpl::GetHeaderTimestamp() const { return header_ts_; }
+
+Status ReplayerImpl::ReadHeader(Trace* header) {
+ assert(header != nullptr);
+ Status s = trace_reader_->Reset();
+ if (!s.ok()) {
+ return s;
+ }
+ std::string encoded_trace;
+ // Read the trace head
+ s = trace_reader_->Read(&encoded_trace);
+ if (!s.ok()) {
+ return s;
+ }
+
+ return TracerHelper::DecodeHeader(encoded_trace, header);
+}
+
+Status ReplayerImpl::ReadTrace(Trace* trace) {
+ assert(trace != nullptr);
+ std::string encoded_trace;
+ // We don't know if TraceReader is implemented thread-safe, so we protect the
+ // reading trace part with a mutex. The decoding part does not need to be
+ // protected since it's local.
+ {
+ std::lock_guard<std::mutex> guard(mutex_);
+ Status s = trace_reader_->Read(&encoded_trace);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ return TracerHelper::DecodeTrace(encoded_trace, trace);
+}
+
+void ReplayerImpl::BackgroundWork(void* arg) {
+ std::unique_ptr<ReplayerWorkerArg> ra(
+ reinterpret_cast<ReplayerWorkerArg*>(arg));
+ assert(ra != nullptr);
+
+ std::unique_ptr<TraceRecord> record;
+ Status s = TracerHelper::DecodeTraceRecord(&(ra->trace_entry),
+ ra->trace_file_version, &record);
+ if (!s.ok()) {
+ // Stop the replay
+ if (ra->error_cb != nullptr) {
+ ra->error_cb(s, ra->trace_entry.ts);
+ }
+ // Report the result
+ if (ra->result_cb != nullptr) {
+ ra->result_cb(s, nullptr);
+ }
+ return;
+ }
+
+ if (ra->result_cb == nullptr) {
+ s = record->Accept(ra->handler, nullptr);
+ } else {
+ std::unique_ptr<TraceRecordResult> res;
+ s = record->Accept(ra->handler, &res);
+ ra->result_cb(s, std::move(res));
+ }
+ record.reset();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/trace/replayer_impl.h b/src/rocksdb/utilities/trace/replayer_impl.h
new file mode 100644
index 000000000..367b0b51e
--- /dev/null
+++ b/src/rocksdb/utilities/trace/replayer_impl.h
@@ -0,0 +1,86 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/status.h"
+#include "rocksdb/trace_reader_writer.h"
+#include "rocksdb/trace_record.h"
+#include "rocksdb/trace_record_result.h"
+#include "rocksdb/utilities/replayer.h"
+#include "trace_replay/trace_replay.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ReplayerImpl : public Replayer {
+ public:
+ ReplayerImpl(DB* db, const std::vector<ColumnFamilyHandle*>& handles,
+ std::unique_ptr<TraceReader>&& reader);
+ ~ReplayerImpl() override;
+
+ using Replayer::Prepare;
+ Status Prepare() override;
+
+ using Replayer::Next;
+ Status Next(std::unique_ptr<TraceRecord>* record) override;
+
+ using Replayer::Execute;
+ Status Execute(const std::unique_ptr<TraceRecord>& record,
+ std::unique_ptr<TraceRecordResult>* result) override;
+
+ using Replayer::Replay;
+ Status Replay(
+ const ReplayOptions& options,
+ const std::function<void(Status, std::unique_ptr<TraceRecordResult>&&)>&
+ result_callback) override;
+
+ using Replayer::GetHeaderTimestamp;
+ uint64_t GetHeaderTimestamp() const override;
+
+ private:
+ Status ReadHeader(Trace* header);
+ Status ReadTrace(Trace* trace);
+
+ // Generic function to execute a Trace in a thread pool.
+ static void BackgroundWork(void* arg);
+
+ std::unique_ptr<TraceReader> trace_reader_;
+ std::mutex mutex_;
+ std::atomic<bool> prepared_;
+ std::atomic<bool> trace_end_;
+ uint64_t header_ts_;
+ std::unique_ptr<TraceRecord::Handler> exec_handler_;
+ Env* env_;
+ // When reading the trace header, the trace file version can be parsed.
+ // Replayer will use different decode method to get the trace content based
+ // on different trace file version.
+ int trace_file_version_;
+};
+
+// Arguments passed to BackgroundWork() for replaying in a thread pool.
+struct ReplayerWorkerArg {
+ Trace trace_entry;
+ int trace_file_version;
+ // Handler to execute TraceRecord.
+ TraceRecord::Handler* handler;
+ // Callback function to report the error status and the timestamp of the
+ // TraceRecord (not the start/end timestamp of executing the TraceRecord).
+ std::function<void(Status, uint64_t)> error_cb;
+ // Callback function to report the trace execution status and operation
+ // execution status/result(s).
+ std::function<void(Status, std::unique_ptr<TraceRecordResult>&&)> result_cb;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/lock_manager.cc b/src/rocksdb/utilities/transactions/lock/lock_manager.cc
new file mode 100644
index 000000000..df16b32ad
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/lock_manager.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/lock/lock_manager.h"
+
+#include "utilities/transactions/lock/point/point_lock_manager.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<LockManager> NewLockManager(PessimisticTransactionDB* db,
+ const TransactionDBOptions& opt) {
+ assert(db);
+ if (opt.lock_mgr_handle) {
+ // A custom lock manager was provided in options
+ auto mgr = opt.lock_mgr_handle->getLockManager();
+ return std::shared_ptr<LockManager>(opt.lock_mgr_handle, mgr);
+ } else {
+ // Use a point lock manager by default
+ return std::shared_ptr<LockManager>(new PointLockManager(db, opt));
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/lock_manager.h b/src/rocksdb/utilities/transactions/lock/lock_manager.h
new file mode 100644
index 000000000..a5ce1948c
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/lock_manager.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "utilities/transactions/lock/lock_tracker.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class PessimisticTransactionDB;
+
+class LockManager {
+ public:
+ virtual ~LockManager() {}
+
+ // Whether supports locking a specific key.
+ virtual bool IsPointLockSupported() const = 0;
+
+ // Whether supports locking a range of keys.
+ virtual bool IsRangeLockSupported() const = 0;
+
+ // Locks acquired through this LockManager should be tracked by
+ // the LockTrackers created through the returned factory.
+ virtual const LockTrackerFactory& GetLockTrackerFactory() const = 0;
+
+ // Enable locking for the specified column family.
+ // Caller should guarantee that this column family is not already enabled.
+ virtual void AddColumnFamily(const ColumnFamilyHandle* cf) = 0;
+
+ // Disable locking for the specified column family.
+ // Caller should guarantee that this column family is no longer used.
+ virtual void RemoveColumnFamily(const ColumnFamilyHandle* cf) = 0;
+
+ // Attempt to lock a key or a key range. If OK status is returned, the caller
+ // is responsible for calling UnLock() on this key.
+ virtual Status TryLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id,
+ const std::string& key, Env* env, bool exclusive) = 0;
+ // The range [start, end] are inclusive at both sides.
+ virtual Status TryLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id, const Endpoint& start,
+ const Endpoint& end, Env* env, bool exclusive) = 0;
+
+ // Unlock a key or a range locked by TryLock(). txn must be the same
+ // Transaction that locked this key.
+ virtual void UnLock(PessimisticTransaction* txn, const LockTracker& tracker,
+ Env* env) = 0;
+ virtual void UnLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id, const std::string& key,
+ Env* env) = 0;
+ virtual void UnLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id, const Endpoint& start,
+ const Endpoint& end, Env* env) = 0;
+
+ using PointLockStatus = std::unordered_multimap<ColumnFamilyId, KeyLockInfo>;
+ virtual PointLockStatus GetPointLockStatus() = 0;
+
+ using RangeLockStatus =
+ std::unordered_multimap<ColumnFamilyId, RangeLockInfo>;
+ virtual RangeLockStatus GetRangeLockStatus() = 0;
+
+ virtual std::vector<DeadlockPath> GetDeadlockInfoBuffer() = 0;
+
+ virtual void Resize(uint32_t new_size) = 0;
+};
+
+// LockManager should always be constructed through this factory method,
+// instead of constructing through concrete implementations' constructor.
+// Caller owns the returned pointer.
+std::shared_ptr<LockManager> NewLockManager(PessimisticTransactionDB* db,
+ const TransactionDBOptions& opt);
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/lock_tracker.h b/src/rocksdb/utilities/transactions/lock/lock_tracker.h
new file mode 100644
index 000000000..5fa228a82
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/lock_tracker.h
@@ -0,0 +1,209 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+
+#include "rocksdb/rocksdb_namespace.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Request for locking a single key.
+struct PointLockRequest {
+ // The id of the key's column family.
+ ColumnFamilyId column_family_id = 0;
+ // The key to lock.
+ std::string key;
+ // The sequence number from which there is no concurrent update to key.
+ SequenceNumber seq = 0;
+ // Whether the lock is acquired only for read.
+ bool read_only = false;
+ // Whether the lock is in exclusive mode.
+ bool exclusive = true;
+};
+
+// Request for locking a range of keys.
+struct RangeLockRequest {
+ // The id of the key's column family.
+ ColumnFamilyId column_family_id;
+
+ // The range to be locked
+ Endpoint start_endp;
+ Endpoint end_endp;
+};
+
+struct PointLockStatus {
+ // Whether the key is locked.
+ bool locked = false;
+ // Whether the key is locked in exclusive mode.
+ bool exclusive = true;
+ // The sequence number in the tracked PointLockRequest.
+ SequenceNumber seq = 0;
+};
+
+// Return status when calling LockTracker::Untrack.
+enum class UntrackStatus {
+ // The lock is not tracked at all, so no lock to untrack.
+ NOT_TRACKED,
+ // The lock is untracked but not removed from the tracker.
+ UNTRACKED,
+ // The lock is removed from the tracker.
+ REMOVED,
+};
+
+// Tracks the lock requests.
+// In PessimisticTransaction, it tracks the locks acquired through LockMgr;
+// In OptimisticTransaction, since there is no LockMgr, it tracks the lock
+// intention. Not thread-safe.
+class LockTracker {
+ public:
+ virtual ~LockTracker() {}
+
+ // Whether supports locking a specific key.
+ virtual bool IsPointLockSupported() const = 0;
+
+ // Whether supports locking a range of keys.
+ virtual bool IsRangeLockSupported() const = 0;
+
+ // Tracks the acquirement of a lock on key.
+ //
+ // If this method is not supported, leave it as a no-op.
+ virtual void Track(const PointLockRequest& /*lock_request*/) = 0;
+
+ // Untracks the lock on a key.
+ // seq and exclusive in lock_request are not used.
+ //
+ // If this method is not supported, leave it as a no-op and
+ // returns NOT_TRACKED.
+ virtual UntrackStatus Untrack(const PointLockRequest& /*lock_request*/) = 0;
+
+ // Counterpart of Track(const PointLockRequest&) for RangeLockRequest.
+ virtual void Track(const RangeLockRequest& /*lock_request*/) = 0;
+
+ // Counterpart of Untrack(const PointLockRequest&) for RangeLockRequest.
+ virtual UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) = 0;
+
+ // Merges lock requests tracked in the specified tracker into the current
+ // tracker.
+ //
+ // E.g. for point lock, if a key in tracker is not yet tracked,
+ // track this new key; otherwise, merge the tracked information of the key
+ // such as lock's exclusiveness, read/write statistics.
+ //
+ // If this method is not supported, leave it as a no-op.
+ //
+ // REQUIRED: the specified tracker must be of the same concrete class type as
+ // the current tracker.
+ virtual void Merge(const LockTracker& /*tracker*/) = 0;
+
+ // This is a reverse operation of Merge.
+ //
+ // E.g. for point lock, if a key exists in both current and the sepcified
+ // tracker, then subtract the information (such as read/write statistics) of
+ // the key in the specified tracker from the current tracker.
+ //
+ // If this method is not supported, leave it as a no-op.
+ //
+ // REQUIRED:
+ // The specified tracker must be of the same concrete class type as
+ // the current tracker.
+ // The tracked locks in the specified tracker must be a subset of those
+ // tracked by the current tracker.
+ virtual void Subtract(const LockTracker& /*tracker*/) = 0;
+
+ // Clears all tracked locks.
+ virtual void Clear() = 0;
+
+ // Gets the new locks (excluding the locks that have been tracked before the
+ // save point) tracked since the specified save point, the result is stored
+ // in an internally constructed LockTracker and returned.
+ //
+ // save_point_tracker is the tracker used by a SavePoint to track locks
+ // tracked after creating the SavePoint.
+ //
+ // The implementation should document whether point lock, or range lock, or
+ // both are considered in this method.
+ // If this method is not supported, returns nullptr.
+ //
+ // REQUIRED:
+ // The save_point_tracker must be of the same concrete class type as the
+ // current tracker.
+ // The tracked locks in the specified tracker must be a subset of those
+ // tracked by the current tracker.
+ virtual LockTracker* GetTrackedLocksSinceSavePoint(
+ const LockTracker& /*save_point_tracker*/) const = 0;
+
+ // Gets lock related information of the key.
+ //
+ // If point lock is not supported, always returns LockStatus with
+ // locked=false.
+ virtual PointLockStatus GetPointLockStatus(
+ ColumnFamilyId /*column_family_id*/,
+ const std::string& /*key*/) const = 0;
+
+ // Gets number of tracked point locks.
+ //
+ // If point lock is not supported, always returns 0.
+ virtual uint64_t GetNumPointLocks() const = 0;
+
+ class ColumnFamilyIterator {
+ public:
+ virtual ~ColumnFamilyIterator() {}
+
+ // Whether there are remaining column families.
+ virtual bool HasNext() const = 0;
+
+ // Gets next column family id.
+ //
+ // If HasNext is false, calling this method has undefined behavior.
+ virtual ColumnFamilyId Next() = 0;
+ };
+
+ // Gets an iterator for column families.
+ //
+ // Returned iterator must not be nullptr.
+ // If there is no column family to iterate,
+ // returns an empty non-null iterator.
+ // Caller owns the returned pointer.
+ virtual ColumnFamilyIterator* GetColumnFamilyIterator() const = 0;
+
+ class KeyIterator {
+ public:
+ virtual ~KeyIterator() {}
+
+ // Whether there are remaining keys.
+ virtual bool HasNext() const = 0;
+
+ // Gets the next key.
+ //
+ // If HasNext is false, calling this method has undefined behavior.
+ virtual const std::string& Next() = 0;
+ };
+
+ // Gets an iterator for keys with tracked point locks in the column family.
+ //
+ // The column family must exist.
+ // Returned iterator must not be nullptr.
+ // Caller owns the returned pointer.
+ virtual KeyIterator* GetKeyIterator(
+ ColumnFamilyId /*column_family_id*/) const = 0;
+};
+
+// LockTracker should always be constructed through this factory.
+// Each LockManager owns a LockTrackerFactory.
+class LockTrackerFactory {
+ public:
+ // Caller owns the returned pointer.
+ virtual LockTracker* Create() const = 0;
+ virtual ~LockTrackerFactory() {}
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.cc b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.cc
new file mode 100644
index 000000000..b362a164d
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.cc
@@ -0,0 +1,721 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/lock/point/point_lock_manager.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <mutex>
+
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/transaction_db_mutex.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/hash.h"
+#include "util/thread_local.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct LockInfo {
+ bool exclusive;
+ autovector<TransactionID> txn_ids;
+
+ // Transaction locks are not valid after this time in us
+ uint64_t expiration_time;
+
+ LockInfo(TransactionID id, uint64_t time, bool ex)
+ : exclusive(ex), expiration_time(time) {
+ txn_ids.push_back(id);
+ }
+ LockInfo(const LockInfo& lock_info)
+ : exclusive(lock_info.exclusive),
+ txn_ids(lock_info.txn_ids),
+ expiration_time(lock_info.expiration_time) {}
+ void operator=(const LockInfo& lock_info) {
+ exclusive = lock_info.exclusive;
+ txn_ids = lock_info.txn_ids;
+ expiration_time = lock_info.expiration_time;
+ }
+ DECLARE_DEFAULT_MOVES(LockInfo);
+};
+
+struct LockMapStripe {
+ explicit LockMapStripe(std::shared_ptr<TransactionDBMutexFactory> factory) {
+ stripe_mutex = factory->AllocateMutex();
+ stripe_cv = factory->AllocateCondVar();
+ assert(stripe_mutex);
+ assert(stripe_cv);
+ }
+
+ // Mutex must be held before modifying keys map
+ std::shared_ptr<TransactionDBMutex> stripe_mutex;
+
+ // Condition Variable per stripe for waiting on a lock
+ std::shared_ptr<TransactionDBCondVar> stripe_cv;
+
+ // Locked keys mapped to the info about the transactions that locked them.
+ // TODO(agiardullo): Explore performance of other data structures.
+ UnorderedMap<std::string, LockInfo> keys;
+};
+
+// Map of #num_stripes LockMapStripes
+struct LockMap {
+ explicit LockMap(size_t num_stripes,
+ std::shared_ptr<TransactionDBMutexFactory> factory)
+ : num_stripes_(num_stripes) {
+ lock_map_stripes_.reserve(num_stripes);
+ for (size_t i = 0; i < num_stripes; i++) {
+ LockMapStripe* stripe = new LockMapStripe(factory);
+ lock_map_stripes_.push_back(stripe);
+ }
+ }
+
+ ~LockMap() {
+ for (auto stripe : lock_map_stripes_) {
+ delete stripe;
+ }
+ }
+
+ // Number of sepearate LockMapStripes to create, each with their own Mutex
+ const size_t num_stripes_;
+
+ // Count of keys that are currently locked in this column family.
+ // (Only maintained if PointLockManager::max_num_locks_ is positive.)
+ std::atomic<int64_t> lock_cnt{0};
+
+ std::vector<LockMapStripe*> lock_map_stripes_;
+
+ size_t GetStripe(const std::string& key) const;
+};
+
+namespace {
+void UnrefLockMapsCache(void* ptr) {
+ // Called when a thread exits or a ThreadLocalPtr gets destroyed.
+ auto lock_maps_cache =
+ static_cast<UnorderedMap<uint32_t, std::shared_ptr<LockMap>>*>(ptr);
+ delete lock_maps_cache;
+}
+} // anonymous namespace
+
+PointLockManager::PointLockManager(PessimisticTransactionDB* txn_db,
+ const TransactionDBOptions& opt)
+ : txn_db_impl_(txn_db),
+ default_num_stripes_(opt.num_stripes),
+ max_num_locks_(opt.max_num_locks),
+ lock_maps_cache_(new ThreadLocalPtr(&UnrefLockMapsCache)),
+ dlock_buffer_(opt.max_num_deadlocks),
+ mutex_factory_(opt.custom_mutex_factory
+ ? opt.custom_mutex_factory
+ : std::make_shared<TransactionDBMutexFactoryImpl>()) {}
+
+size_t LockMap::GetStripe(const std::string& key) const {
+ assert(num_stripes_ > 0);
+ return FastRange64(GetSliceNPHash64(key), num_stripes_);
+}
+
+void PointLockManager::AddColumnFamily(const ColumnFamilyHandle* cf) {
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ if (lock_maps_.find(cf->GetID()) == lock_maps_.end()) {
+ lock_maps_.emplace(cf->GetID(), std::make_shared<LockMap>(
+ default_num_stripes_, mutex_factory_));
+ } else {
+ // column_family already exists in lock map
+ assert(false);
+ }
+}
+
+void PointLockManager::RemoveColumnFamily(const ColumnFamilyHandle* cf) {
+ // Remove lock_map for this column family. Since the lock map is stored
+ // as a shared ptr, concurrent transactions can still keep using it
+ // until they release their references to it.
+ {
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ auto lock_maps_iter = lock_maps_.find(cf->GetID());
+ if (lock_maps_iter == lock_maps_.end()) {
+ return;
+ }
+
+ lock_maps_.erase(lock_maps_iter);
+ } // lock_map_mutex_
+
+ // Clear all thread-local caches
+ autovector<void*> local_caches;
+ lock_maps_cache_->Scrape(&local_caches, nullptr);
+ for (auto cache : local_caches) {
+ delete static_cast<LockMaps*>(cache);
+ }
+}
+
+// Look up the LockMap std::shared_ptr for a given column_family_id.
+// Note: The LockMap is only valid as long as the caller is still holding on
+// to the returned std::shared_ptr.
+std::shared_ptr<LockMap> PointLockManager::GetLockMap(
+ ColumnFamilyId column_family_id) {
+ // First check thread-local cache
+ if (lock_maps_cache_->Get() == nullptr) {
+ lock_maps_cache_->Reset(new LockMaps());
+ }
+
+ auto lock_maps_cache = static_cast<LockMaps*>(lock_maps_cache_->Get());
+
+ auto lock_map_iter = lock_maps_cache->find(column_family_id);
+ if (lock_map_iter != lock_maps_cache->end()) {
+ // Found lock map for this column family.
+ return lock_map_iter->second;
+ }
+
+ // Not found in local cache, grab mutex and check shared LockMaps
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ lock_map_iter = lock_maps_.find(column_family_id);
+ if (lock_map_iter == lock_maps_.end()) {
+ return std::shared_ptr<LockMap>(nullptr);
+ } else {
+ // Found lock map. Store in thread-local cache and return.
+ std::shared_ptr<LockMap>& lock_map = lock_map_iter->second;
+ lock_maps_cache->insert({column_family_id, lock_map});
+
+ return lock_map;
+ }
+}
+
+// Returns true if this lock has expired and can be acquired by another
+// transaction.
+// If false, sets *expire_time to the expiration time of the lock according
+// to Env->GetMicros() or 0 if no expiration.
+bool PointLockManager::IsLockExpired(TransactionID txn_id,
+ const LockInfo& lock_info, Env* env,
+ uint64_t* expire_time) {
+ if (lock_info.expiration_time == 0) {
+ *expire_time = 0;
+ return false;
+ }
+
+ auto now = env->NowMicros();
+ bool expired = lock_info.expiration_time <= now;
+ if (!expired) {
+ // return how many microseconds until lock will be expired
+ *expire_time = lock_info.expiration_time;
+ } else {
+ for (auto id : lock_info.txn_ids) {
+ if (txn_id == id) {
+ continue;
+ }
+
+ bool success = txn_db_impl_->TryStealingExpiredTransactionLocks(id);
+ if (!success) {
+ expired = false;
+ *expire_time = 0;
+ break;
+ }
+ }
+ }
+
+ return expired;
+}
+
+Status PointLockManager::TryLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id,
+ const std::string& key, Env* env,
+ bool exclusive) {
+ // Lookup lock map for this column family id
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id);
+ LockMap* lock_map = lock_map_ptr.get();
+ if (lock_map == nullptr) {
+ char msg[255];
+ snprintf(msg, sizeof(msg), "Column family id not found: %" PRIu32,
+ column_family_id);
+
+ return Status::InvalidArgument(msg);
+ }
+
+ // Need to lock the mutex for the stripe that this key hashes to
+ size_t stripe_num = lock_map->GetStripe(key);
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ LockInfo lock_info(txn->GetID(), txn->GetExpirationTime(), exclusive);
+ int64_t timeout = txn->GetLockTimeout();
+
+ return AcquireWithTimeout(txn, lock_map, stripe, column_family_id, key, env,
+ timeout, lock_info);
+}
+
+// Helper function for TryLock().
+Status PointLockManager::AcquireWithTimeout(
+ PessimisticTransaction* txn, LockMap* lock_map, LockMapStripe* stripe,
+ ColumnFamilyId column_family_id, const std::string& key, Env* env,
+ int64_t timeout, const LockInfo& lock_info) {
+ Status result;
+ uint64_t end_time = 0;
+
+ if (timeout > 0) {
+ uint64_t start_time = env->NowMicros();
+ end_time = start_time + timeout;
+ }
+
+ if (timeout < 0) {
+ // If timeout is negative, we wait indefinitely to acquire the lock
+ result = stripe->stripe_mutex->Lock();
+ } else {
+ result = stripe->stripe_mutex->TryLockFor(timeout);
+ }
+
+ if (!result.ok()) {
+ // failed to acquire mutex
+ return result;
+ }
+
+ // Acquire lock if we are able to
+ uint64_t expire_time_hint = 0;
+ autovector<TransactionID> wait_ids;
+ result = AcquireLocked(lock_map, stripe, key, env, lock_info,
+ &expire_time_hint, &wait_ids);
+
+ if (!result.ok() && timeout != 0) {
+ PERF_TIMER_GUARD(key_lock_wait_time);
+ PERF_COUNTER_ADD(key_lock_wait_count, 1);
+ // If we weren't able to acquire the lock, we will keep retrying as long
+ // as the timeout allows.
+ bool timed_out = false;
+ do {
+ // Decide how long to wait
+ int64_t cv_end_time = -1;
+ if (expire_time_hint > 0 && end_time > 0) {
+ cv_end_time = std::min(expire_time_hint, end_time);
+ } else if (expire_time_hint > 0) {
+ cv_end_time = expire_time_hint;
+ } else if (end_time > 0) {
+ cv_end_time = end_time;
+ }
+
+ assert(result.IsBusy() || wait_ids.size() != 0);
+
+ // We are dependent on a transaction to finish, so perform deadlock
+ // detection.
+ if (wait_ids.size() != 0) {
+ if (txn->IsDeadlockDetect()) {
+ if (IncrementWaiters(txn, wait_ids, key, column_family_id,
+ lock_info.exclusive, env)) {
+ result = Status::Busy(Status::SubCode::kDeadlock);
+ stripe->stripe_mutex->UnLock();
+ return result;
+ }
+ }
+ txn->SetWaitingTxn(wait_ids, column_family_id, &key);
+ }
+
+ TEST_SYNC_POINT("PointLockManager::AcquireWithTimeout:WaitingTxn");
+ if (cv_end_time < 0) {
+ // Wait indefinitely
+ result = stripe->stripe_cv->Wait(stripe->stripe_mutex);
+ } else {
+ uint64_t now = env->NowMicros();
+ if (static_cast<uint64_t>(cv_end_time) > now) {
+ result = stripe->stripe_cv->WaitFor(stripe->stripe_mutex,
+ cv_end_time - now);
+ }
+ }
+
+ if (wait_ids.size() != 0) {
+ txn->ClearWaitingTxn();
+ if (txn->IsDeadlockDetect()) {
+ DecrementWaiters(txn, wait_ids);
+ }
+ }
+
+ if (result.IsTimedOut()) {
+ timed_out = true;
+ // Even though we timed out, we will still make one more attempt to
+ // acquire lock below (it is possible the lock expired and we
+ // were never signaled).
+ }
+
+ if (result.ok() || result.IsTimedOut()) {
+ result = AcquireLocked(lock_map, stripe, key, env, lock_info,
+ &expire_time_hint, &wait_ids);
+ }
+ } while (!result.ok() && !timed_out);
+ }
+
+ stripe->stripe_mutex->UnLock();
+
+ return result;
+}
+
+void PointLockManager::DecrementWaiters(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids) {
+ std::lock_guard<std::mutex> lock(wait_txn_map_mutex_);
+ DecrementWaitersImpl(txn, wait_ids);
+}
+
+void PointLockManager::DecrementWaitersImpl(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids) {
+ auto id = txn->GetID();
+ assert(wait_txn_map_.Contains(id));
+ wait_txn_map_.Delete(id);
+
+ for (auto wait_id : wait_ids) {
+ rev_wait_txn_map_.Get(wait_id)--;
+ if (rev_wait_txn_map_.Get(wait_id) == 0) {
+ rev_wait_txn_map_.Delete(wait_id);
+ }
+ }
+}
+
+bool PointLockManager::IncrementWaiters(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids, const std::string& key,
+ const uint32_t& cf_id, const bool& exclusive, Env* const env) {
+ auto id = txn->GetID();
+ std::vector<int> queue_parents(
+ static_cast<size_t>(txn->GetDeadlockDetectDepth()));
+ std::vector<TransactionID> queue_values(
+ static_cast<size_t>(txn->GetDeadlockDetectDepth()));
+ std::lock_guard<std::mutex> lock(wait_txn_map_mutex_);
+ assert(!wait_txn_map_.Contains(id));
+
+ wait_txn_map_.Insert(id, {wait_ids, cf_id, exclusive, key});
+
+ for (auto wait_id : wait_ids) {
+ if (rev_wait_txn_map_.Contains(wait_id)) {
+ rev_wait_txn_map_.Get(wait_id)++;
+ } else {
+ rev_wait_txn_map_.Insert(wait_id, 1);
+ }
+ }
+
+ // No deadlock if nobody is waiting on self.
+ if (!rev_wait_txn_map_.Contains(id)) {
+ return false;
+ }
+
+ const auto* next_ids = &wait_ids;
+ int parent = -1;
+ int64_t deadlock_time = 0;
+ for (int tail = 0, head = 0; head < txn->GetDeadlockDetectDepth(); head++) {
+ int i = 0;
+ if (next_ids) {
+ for (; i < static_cast<int>(next_ids->size()) &&
+ tail + i < txn->GetDeadlockDetectDepth();
+ i++) {
+ queue_values[tail + i] = (*next_ids)[i];
+ queue_parents[tail + i] = parent;
+ }
+ tail += i;
+ }
+
+ // No more items in the list, meaning no deadlock.
+ if (tail == head) {
+ return false;
+ }
+
+ auto next = queue_values[head];
+ if (next == id) {
+ std::vector<DeadlockInfo> path;
+ while (head != -1) {
+ assert(wait_txn_map_.Contains(queue_values[head]));
+
+ auto extracted_info = wait_txn_map_.Get(queue_values[head]);
+ path.push_back({queue_values[head], extracted_info.m_cf_id,
+ extracted_info.m_exclusive,
+ extracted_info.m_waiting_key});
+ head = queue_parents[head];
+ }
+ if (!env->GetCurrentTime(&deadlock_time).ok()) {
+ /*
+ TODO(AR) this preserves the current behaviour whilst checking the
+ status of env->GetCurrentTime to ensure that ASSERT_STATUS_CHECKED
+ passes. Should we instead raise an error if !ok() ?
+ */
+ deadlock_time = 0;
+ }
+ std::reverse(path.begin(), path.end());
+ dlock_buffer_.AddNewPath(DeadlockPath(path, deadlock_time));
+ deadlock_time = 0;
+ DecrementWaitersImpl(txn, wait_ids);
+ return true;
+ } else if (!wait_txn_map_.Contains(next)) {
+ next_ids = nullptr;
+ continue;
+ } else {
+ parent = head;
+ next_ids = &(wait_txn_map_.Get(next).m_neighbors);
+ }
+ }
+
+ // Wait cycle too big, just assume deadlock.
+ if (!env->GetCurrentTime(&deadlock_time).ok()) {
+ /*
+ TODO(AR) this preserves the current behaviour whilst checking the status
+ of env->GetCurrentTime to ensure that ASSERT_STATUS_CHECKED passes.
+ Should we instead raise an error if !ok() ?
+ */
+ deadlock_time = 0;
+ }
+ dlock_buffer_.AddNewPath(DeadlockPath(deadlock_time, true));
+ DecrementWaitersImpl(txn, wait_ids);
+ return true;
+}
+
+// Try to lock this key after we have acquired the mutex.
+// Sets *expire_time to the expiration time in microseconds
+// or 0 if no expiration.
+// REQUIRED: Stripe mutex must be held.
+Status PointLockManager::AcquireLocked(LockMap* lock_map, LockMapStripe* stripe,
+ const std::string& key, Env* env,
+ const LockInfo& txn_lock_info,
+ uint64_t* expire_time,
+ autovector<TransactionID>* txn_ids) {
+ assert(txn_lock_info.txn_ids.size() == 1);
+
+ Status result;
+ // Check if this key is already locked
+ auto stripe_iter = stripe->keys.find(key);
+ if (stripe_iter != stripe->keys.end()) {
+ // Lock already held
+ LockInfo& lock_info = stripe_iter->second;
+ assert(lock_info.txn_ids.size() == 1 || !lock_info.exclusive);
+
+ if (lock_info.exclusive || txn_lock_info.exclusive) {
+ if (lock_info.txn_ids.size() == 1 &&
+ lock_info.txn_ids[0] == txn_lock_info.txn_ids[0]) {
+ // The list contains one txn and we're it, so just take it.
+ lock_info.exclusive = txn_lock_info.exclusive;
+ lock_info.expiration_time = txn_lock_info.expiration_time;
+ } else {
+ // Check if it's expired. Skips over txn_lock_info.txn_ids[0] in case
+ // it's there for a shared lock with multiple holders which was not
+ // caught in the first case.
+ if (IsLockExpired(txn_lock_info.txn_ids[0], lock_info, env,
+ expire_time)) {
+ // lock is expired, can steal it
+ lock_info.txn_ids = txn_lock_info.txn_ids;
+ lock_info.exclusive = txn_lock_info.exclusive;
+ lock_info.expiration_time = txn_lock_info.expiration_time;
+ // lock_cnt does not change
+ } else {
+ result = Status::TimedOut(Status::SubCode::kLockTimeout);
+ *txn_ids = lock_info.txn_ids;
+ }
+ }
+ } else {
+ // We are requesting shared access to a shared lock, so just grant it.
+ lock_info.txn_ids.push_back(txn_lock_info.txn_ids[0]);
+ // Using std::max means that expiration time never goes down even when
+ // a transaction is removed from the list. The correct solution would be
+ // to track expiry for every transaction, but this would also work for
+ // now.
+ lock_info.expiration_time =
+ std::max(lock_info.expiration_time, txn_lock_info.expiration_time);
+ }
+ } else { // Lock not held.
+ // Check lock limit
+ if (max_num_locks_ > 0 &&
+ lock_map->lock_cnt.load(std::memory_order_acquire) >= max_num_locks_) {
+ result = Status::Busy(Status::SubCode::kLockLimit);
+ } else {
+ // acquire lock
+ stripe->keys.emplace(key, txn_lock_info);
+
+ // Maintain lock count if there is a limit on the number of locks
+ if (max_num_locks_) {
+ lock_map->lock_cnt++;
+ }
+ }
+ }
+
+ return result;
+}
+
+void PointLockManager::UnLockKey(PessimisticTransaction* txn,
+ const std::string& key, LockMapStripe* stripe,
+ LockMap* lock_map, Env* env) {
+#ifdef NDEBUG
+ (void)env;
+#endif
+ TransactionID txn_id = txn->GetID();
+
+ auto stripe_iter = stripe->keys.find(key);
+ if (stripe_iter != stripe->keys.end()) {
+ auto& txns = stripe_iter->second.txn_ids;
+ auto txn_it = std::find(txns.begin(), txns.end(), txn_id);
+ // Found the key we locked. unlock it.
+ if (txn_it != txns.end()) {
+ if (txns.size() == 1) {
+ stripe->keys.erase(stripe_iter);
+ } else {
+ auto last_it = txns.end() - 1;
+ if (txn_it != last_it) {
+ *txn_it = *last_it;
+ }
+ txns.pop_back();
+ }
+
+ if (max_num_locks_ > 0) {
+ // Maintain lock count if there is a limit on the number of locks.
+ assert(lock_map->lock_cnt.load(std::memory_order_relaxed) > 0);
+ lock_map->lock_cnt--;
+ }
+ }
+ } else {
+ // This key is either not locked or locked by someone else. This should
+ // only happen if the unlocking transaction has expired.
+ assert(txn->GetExpirationTime() > 0 &&
+ txn->GetExpirationTime() < env->NowMicros());
+ }
+}
+
+void PointLockManager::UnLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id,
+ const std::string& key, Env* env) {
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id);
+ LockMap* lock_map = lock_map_ptr.get();
+ if (lock_map == nullptr) {
+ // Column Family must have been dropped.
+ return;
+ }
+
+ // Lock the mutex for the stripe that this key hashes to
+ size_t stripe_num = lock_map->GetStripe(key);
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ stripe->stripe_mutex->Lock().PermitUncheckedError();
+ UnLockKey(txn, key, stripe, lock_map, env);
+ stripe->stripe_mutex->UnLock();
+
+ // Signal waiting threads to retry locking
+ stripe->stripe_cv->NotifyAll();
+}
+
+void PointLockManager::UnLock(PessimisticTransaction* txn,
+ const LockTracker& tracker, Env* env) {
+ std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
+ tracker.GetColumnFamilyIterator());
+ assert(cf_it != nullptr);
+ while (cf_it->HasNext()) {
+ ColumnFamilyId cf = cf_it->Next();
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(cf);
+ LockMap* lock_map = lock_map_ptr.get();
+ if (!lock_map) {
+ // Column Family must have been dropped.
+ return;
+ }
+
+ // Bucket keys by lock_map_ stripe
+ UnorderedMap<size_t, std::vector<const std::string*>> keys_by_stripe(
+ lock_map->num_stripes_);
+ std::unique_ptr<LockTracker::KeyIterator> key_it(
+ tracker.GetKeyIterator(cf));
+ assert(key_it != nullptr);
+ while (key_it->HasNext()) {
+ const std::string& key = key_it->Next();
+ size_t stripe_num = lock_map->GetStripe(key);
+ keys_by_stripe[stripe_num].push_back(&key);
+ }
+
+ // For each stripe, grab the stripe mutex and unlock all keys in this stripe
+ for (auto& stripe_iter : keys_by_stripe) {
+ size_t stripe_num = stripe_iter.first;
+ auto& stripe_keys = stripe_iter.second;
+
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ stripe->stripe_mutex->Lock().PermitUncheckedError();
+
+ for (const std::string* key : stripe_keys) {
+ UnLockKey(txn, *key, stripe, lock_map, env);
+ }
+
+ stripe->stripe_mutex->UnLock();
+
+ // Signal waiting threads to retry locking
+ stripe->stripe_cv->NotifyAll();
+ }
+ }
+}
+
+PointLockManager::PointLockStatus PointLockManager::GetPointLockStatus() {
+ PointLockStatus data;
+ // Lock order here is important. The correct order is lock_map_mutex_, then
+ // for every column family ID in ascending order lock every stripe in
+ // ascending order.
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ std::vector<uint32_t> cf_ids;
+ for (const auto& map : lock_maps_) {
+ cf_ids.push_back(map.first);
+ }
+ std::sort(cf_ids.begin(), cf_ids.end());
+
+ for (auto i : cf_ids) {
+ const auto& stripes = lock_maps_[i]->lock_map_stripes_;
+ // Iterate and lock all stripes in ascending order.
+ for (const auto& j : stripes) {
+ j->stripe_mutex->Lock().PermitUncheckedError();
+ for (const auto& it : j->keys) {
+ struct KeyLockInfo info;
+ info.exclusive = it.second.exclusive;
+ info.key = it.first;
+ for (const auto& id : it.second.txn_ids) {
+ info.ids.push_back(id);
+ }
+ data.insert({i, info});
+ }
+ }
+ }
+
+ // Unlock everything. Unlocking order is not important.
+ for (auto i : cf_ids) {
+ const auto& stripes = lock_maps_[i]->lock_map_stripes_;
+ for (const auto& j : stripes) {
+ j->stripe_mutex->UnLock();
+ }
+ }
+
+ return data;
+}
+
+std::vector<DeadlockPath> PointLockManager::GetDeadlockInfoBuffer() {
+ return dlock_buffer_.PrepareBuffer();
+}
+
+void PointLockManager::Resize(uint32_t target_size) {
+ dlock_buffer_.Resize(target_size);
+}
+
+PointLockManager::RangeLockStatus PointLockManager::GetRangeLockStatus() {
+ return {};
+}
+
+Status PointLockManager::TryLock(PessimisticTransaction* /* txn */,
+ ColumnFamilyId /* cf_id */,
+ const Endpoint& /* start */,
+ const Endpoint& /* end */, Env* /* env */,
+ bool /* exclusive */) {
+ return Status::NotSupported(
+ "PointLockManager does not support range locking");
+}
+
+void PointLockManager::UnLock(PessimisticTransaction* /* txn */,
+ ColumnFamilyId /* cf_id */,
+ const Endpoint& /* start */,
+ const Endpoint& /* end */, Env* /* env */) {
+ // no-op
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.h b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.h
new file mode 100644
index 000000000..eeb34f3be
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager.h
@@ -0,0 +1,224 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "monitoring/instrumented_mutex.h"
+#include "rocksdb/utilities/transaction.h"
+#include "util/autovector.h"
+#include "util/hash_containers.h"
+#include "util/hash_map.h"
+#include "util/thread_local.h"
+#include "utilities/transactions/lock/lock_manager.h"
+#include "utilities/transactions/lock/point/point_lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ColumnFamilyHandle;
+struct LockInfo;
+struct LockMap;
+struct LockMapStripe;
+
+template <class Path>
+class DeadlockInfoBufferTempl {
+ private:
+ std::vector<Path> paths_buffer_;
+ uint32_t buffer_idx_;
+ std::mutex paths_buffer_mutex_;
+
+ std::vector<Path> Normalize() {
+ auto working = paths_buffer_;
+
+ if (working.empty()) {
+ return working;
+ }
+
+ // Next write occurs at a nonexistent path's slot
+ if (paths_buffer_[buffer_idx_].empty()) {
+ working.resize(buffer_idx_);
+ } else {
+ std::rotate(working.begin(), working.begin() + buffer_idx_,
+ working.end());
+ }
+
+ return working;
+ }
+
+ public:
+ explicit DeadlockInfoBufferTempl(uint32_t n_latest_dlocks)
+ : paths_buffer_(n_latest_dlocks), buffer_idx_(0) {}
+
+ void AddNewPath(Path path) {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ if (paths_buffer_.empty()) {
+ return;
+ }
+
+ paths_buffer_[buffer_idx_] = std::move(path);
+ buffer_idx_ = (buffer_idx_ + 1) % paths_buffer_.size();
+ }
+
+ void Resize(uint32_t target_size) {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ paths_buffer_ = Normalize();
+
+ // Drop the deadlocks that will no longer be needed ater the normalize
+ if (target_size < paths_buffer_.size()) {
+ paths_buffer_.erase(
+ paths_buffer_.begin(),
+ paths_buffer_.begin() + (paths_buffer_.size() - target_size));
+ buffer_idx_ = 0;
+ }
+ // Resize the buffer to the target size and restore the buffer's idx
+ else {
+ auto prev_size = paths_buffer_.size();
+ paths_buffer_.resize(target_size);
+ buffer_idx_ = (uint32_t)prev_size;
+ }
+ }
+
+ std::vector<Path> PrepareBuffer() {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ // Reversing the normalized vector returns the latest deadlocks first
+ auto working = Normalize();
+ std::reverse(working.begin(), working.end());
+
+ return working;
+ }
+};
+
+using DeadlockInfoBuffer = DeadlockInfoBufferTempl<DeadlockPath>;
+
+struct TrackedTrxInfo {
+ autovector<TransactionID> m_neighbors;
+ uint32_t m_cf_id;
+ bool m_exclusive;
+ std::string m_waiting_key;
+};
+
+class PointLockManager : public LockManager {
+ public:
+ PointLockManager(PessimisticTransactionDB* db,
+ const TransactionDBOptions& opt);
+ // No copying allowed
+ PointLockManager(const PointLockManager&) = delete;
+ PointLockManager& operator=(const PointLockManager&) = delete;
+
+ ~PointLockManager() override {}
+
+ bool IsPointLockSupported() const override { return true; }
+
+ bool IsRangeLockSupported() const override { return false; }
+
+ const LockTrackerFactory& GetLockTrackerFactory() const override {
+ return PointLockTrackerFactory::Get();
+ }
+
+ // Creates a new LockMap for this column family. Caller should guarantee
+ // that this column family does not already exist.
+ void AddColumnFamily(const ColumnFamilyHandle* cf) override;
+ // Deletes the LockMap for this column family. Caller should guarantee that
+ // this column family is no longer in use.
+ void RemoveColumnFamily(const ColumnFamilyHandle* cf) override;
+
+ Status TryLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const std::string& key, Env* env, bool exclusive) override;
+ Status TryLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const Endpoint& start, const Endpoint& end, Env* env,
+ bool exclusive) override;
+
+ void UnLock(PessimisticTransaction* txn, const LockTracker& tracker,
+ Env* env) override;
+ void UnLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const std::string& key, Env* env) override;
+ void UnLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const Endpoint& start, const Endpoint& end, Env* env) override;
+
+ PointLockStatus GetPointLockStatus() override;
+
+ RangeLockStatus GetRangeLockStatus() override;
+
+ std::vector<DeadlockPath> GetDeadlockInfoBuffer() override;
+
+ void Resize(uint32_t new_size) override;
+
+ private:
+ PessimisticTransactionDB* txn_db_impl_;
+
+ // Default number of lock map stripes per column family
+ const size_t default_num_stripes_;
+
+ // Limit on number of keys locked per column family
+ const int64_t max_num_locks_;
+
+ // The following lock order must be satisfied in order to avoid deadlocking
+ // ourselves.
+ // - lock_map_mutex_
+ // - stripe mutexes in ascending cf id, ascending stripe order
+ // - wait_txn_map_mutex_
+ //
+ // Must be held when accessing/modifying lock_maps_.
+ InstrumentedMutex lock_map_mutex_;
+
+ // Map of ColumnFamilyId to locked key info
+ using LockMaps = UnorderedMap<uint32_t, std::shared_ptr<LockMap>>;
+ LockMaps lock_maps_;
+
+ // Thread-local cache of entries in lock_maps_. This is an optimization
+ // to avoid acquiring a mutex in order to look up a LockMap
+ std::unique_ptr<ThreadLocalPtr> lock_maps_cache_;
+
+ // Must be held when modifying wait_txn_map_ and rev_wait_txn_map_.
+ std::mutex wait_txn_map_mutex_;
+
+ // Maps from waitee -> number of waiters.
+ HashMap<TransactionID, int> rev_wait_txn_map_;
+ // Maps from waiter -> waitee.
+ HashMap<TransactionID, TrackedTrxInfo> wait_txn_map_;
+ DeadlockInfoBuffer dlock_buffer_;
+
+ // Used to allocate mutexes/condvars to use when locking keys
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory_;
+
+ bool IsLockExpired(TransactionID txn_id, const LockInfo& lock_info, Env* env,
+ uint64_t* wait_time);
+
+ std::shared_ptr<LockMap> GetLockMap(uint32_t column_family_id);
+
+ Status AcquireWithTimeout(PessimisticTransaction* txn, LockMap* lock_map,
+ LockMapStripe* stripe, uint32_t column_family_id,
+ const std::string& key, Env* env, int64_t timeout,
+ const LockInfo& lock_info);
+
+ Status AcquireLocked(LockMap* lock_map, LockMapStripe* stripe,
+ const std::string& key, Env* env,
+ const LockInfo& lock_info, uint64_t* wait_time,
+ autovector<TransactionID>* txn_ids);
+
+ void UnLockKey(PessimisticTransaction* txn, const std::string& key,
+ LockMapStripe* stripe, LockMap* lock_map, Env* env);
+
+ bool IncrementWaiters(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids,
+ const std::string& key, const uint32_t& cf_id,
+ const bool& exclusive, Env* const env);
+ void DecrementWaiters(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids);
+ void DecrementWaitersImpl(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.cc b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.cc
new file mode 100644
index 000000000..525fdea71
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.cc
@@ -0,0 +1,181 @@
+// Copyright (c) 2020-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/lock/point/point_lock_manager_test.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// This test is not applicable for Range Lock manager as Range Lock Manager
+// operates on Column Families, not their ids.
+TEST_F(PointLockManagerTest, LockNonExistingColumnFamily) {
+ MockColumnFamilyHandle cf(1024);
+ locker_->RemoveColumnFamily(&cf);
+ auto txn = NewTxn();
+ auto s = locker_->TryLock(txn, 1024, "k", env_, true);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_STREQ(s.getState(), "Column family id not found: 1024");
+ delete txn;
+}
+
+TEST_F(PointLockManagerTest, LockStatus) {
+ MockColumnFamilyHandle cf1(1024), cf2(2048);
+ locker_->AddColumnFamily(&cf1);
+ locker_->AddColumnFamily(&cf2);
+
+ auto txn1 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn1, 1024, "k1", env_, true));
+ ASSERT_OK(locker_->TryLock(txn1, 2048, "k1", env_, true));
+
+ auto txn2 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn2, 1024, "k2", env_, false));
+ ASSERT_OK(locker_->TryLock(txn2, 2048, "k2", env_, false));
+
+ auto s = locker_->GetPointLockStatus();
+ ASSERT_EQ(s.size(), 4u);
+ for (uint32_t cf_id : {1024, 2048}) {
+ ASSERT_EQ(s.count(cf_id), 2u);
+ auto range = s.equal_range(cf_id);
+ for (auto it = range.first; it != range.second; it++) {
+ ASSERT_TRUE(it->second.key == "k1" || it->second.key == "k2");
+ if (it->second.key == "k1") {
+ ASSERT_EQ(it->second.exclusive, true);
+ ASSERT_EQ(it->second.ids.size(), 1u);
+ ASSERT_EQ(it->second.ids[0], txn1->GetID());
+ } else if (it->second.key == "k2") {
+ ASSERT_EQ(it->second.exclusive, false);
+ ASSERT_EQ(it->second.ids.size(), 1u);
+ ASSERT_EQ(it->second.ids[0], txn2->GetID());
+ }
+ }
+ }
+
+ // Cleanup
+ locker_->UnLock(txn1, 1024, "k1", env_);
+ locker_->UnLock(txn1, 2048, "k1", env_);
+ locker_->UnLock(txn2, 1024, "k2", env_);
+ locker_->UnLock(txn2, 2048, "k2", env_);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_F(PointLockManagerTest, UnlockExclusive) {
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+
+ auto txn1 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k", env_, true));
+ locker_->UnLock(txn1, 1, "k", env_);
+
+ auto txn2 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k", env_, true));
+
+ // Cleanup
+ locker_->UnLock(txn2, 1, "k", env_);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_F(PointLockManagerTest, UnlockShared) {
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+
+ auto txn1 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k", env_, false));
+ locker_->UnLock(txn1, 1, "k", env_);
+
+ auto txn2 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k", env_, true));
+
+ // Cleanup
+ locker_->UnLock(txn2, 1, "k", env_);
+
+ delete txn1;
+ delete txn2;
+}
+
+// This test doesn't work with Range Lock Manager, because Range Lock Manager
+// doesn't support deadlock_detect_depth.
+
+TEST_F(PointLockManagerTest, DeadlockDepthExceeded) {
+ // Tests that when detecting deadlock, if the detection depth is exceeded,
+ // it's also viewed as deadlock.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ TransactionOptions txn_opt;
+ txn_opt.deadlock_detect = true;
+ txn_opt.deadlock_detect_depth = 1;
+ txn_opt.lock_timeout = 1000000;
+ auto txn1 = NewTxn(txn_opt);
+ auto txn2 = NewTxn(txn_opt);
+ auto txn3 = NewTxn(txn_opt);
+ auto txn4 = NewTxn(txn_opt);
+ // "a ->(k) b" means transaction a is waiting for transaction b to release
+ // the held lock on key k.
+ // txn4 ->(k3) -> txn3 ->(k2) txn2 ->(k1) txn1
+ // txn3's deadlock detection will exceed the detection depth 1,
+ // which will be viewed as a deadlock.
+ // NOTE:
+ // txn4 ->(k3) -> txn3 must be set up before
+ // txn3 ->(k2) -> txn2, because to trigger deadlock detection for txn3,
+ // it must have another txn waiting on it, which is txn4 in this case.
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k1", env_, true));
+
+ port::Thread t1 = BlockUntilWaitingTxn(wait_sync_point_name_, [&]() {
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k2", env_, true));
+ // block because txn1 is holding a lock on k1.
+ locker_->TryLock(txn2, 1, "k1", env_, true);
+ });
+
+ ASSERT_OK(locker_->TryLock(txn3, 1, "k3", env_, true));
+
+ port::Thread t2 = BlockUntilWaitingTxn(wait_sync_point_name_, [&]() {
+ // block because txn3 is holding a lock on k1.
+ locker_->TryLock(txn4, 1, "k3", env_, true);
+ });
+
+ auto s = locker_->TryLock(txn3, 1, "k2", env_, true);
+ ASSERT_TRUE(s.IsBusy());
+ ASSERT_EQ(s.subcode(), Status::SubCode::kDeadlock);
+
+ std::vector<DeadlockPath> deadlock_paths = locker_->GetDeadlockInfoBuffer();
+ ASSERT_EQ(deadlock_paths.size(), 1u);
+ ASSERT_TRUE(deadlock_paths[0].limit_exceeded);
+
+ locker_->UnLock(txn1, 1, "k1", env_);
+ locker_->UnLock(txn3, 1, "k3", env_);
+ t1.join();
+ t2.join();
+
+ delete txn4;
+ delete txn3;
+ delete txn2;
+ delete txn1;
+}
+
+INSTANTIATE_TEST_CASE_P(PointLockManager, AnyLockManagerTest,
+ ::testing::Values(nullptr));
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED because Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.h b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.h
new file mode 100644
index 000000000..ca9f46bf9
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_manager_test.h
@@ -0,0 +1,324 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "file/file_util.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "utilities/transactions/lock/point/point_lock_manager.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MockColumnFamilyHandle : public ColumnFamilyHandle {
+ public:
+ explicit MockColumnFamilyHandle(ColumnFamilyId cf_id) : cf_id_(cf_id) {}
+
+ ~MockColumnFamilyHandle() override {}
+
+ const std::string& GetName() const override { return name_; }
+
+ ColumnFamilyId GetID() const override { return cf_id_; }
+
+ Status GetDescriptor(ColumnFamilyDescriptor*) override {
+ return Status::OK();
+ }
+
+ const Comparator* GetComparator() const override {
+ return BytewiseComparator();
+ }
+
+ private:
+ ColumnFamilyId cf_id_;
+ std::string name_ = "MockCF";
+};
+
+class PointLockManagerTest : public testing::Test {
+ public:
+ void SetUp() override {
+ env_ = Env::Default();
+ db_dir_ = test::PerThreadDBPath("point_lock_manager_test");
+ ASSERT_OK(env_->CreateDir(db_dir_));
+
+ Options opt;
+ opt.create_if_missing = true;
+ TransactionDBOptions txn_opt;
+ txn_opt.transaction_lock_timeout = 0;
+
+ ASSERT_OK(TransactionDB::Open(opt, txn_opt, db_dir_, &db_));
+
+ // CAUTION: This test creates a separate lock manager object (right, NOT
+ // the one that the TransactionDB is using!), and runs tests on it.
+ locker_.reset(new PointLockManager(
+ static_cast<PessimisticTransactionDB*>(db_), txn_opt));
+
+ wait_sync_point_name_ = "PointLockManager::AcquireWithTimeout:WaitingTxn";
+ }
+
+ void TearDown() override {
+ delete db_;
+ EXPECT_OK(DestroyDir(env_, db_dir_));
+ }
+
+ PessimisticTransaction* NewTxn(
+ TransactionOptions txn_opt = TransactionOptions()) {
+ Transaction* txn = db_->BeginTransaction(WriteOptions(), txn_opt);
+ return reinterpret_cast<PessimisticTransaction*>(txn);
+ }
+
+ protected:
+ Env* env_;
+ std::shared_ptr<LockManager> locker_;
+ const char* wait_sync_point_name_;
+ friend void PointLockManagerTestExternalSetup(PointLockManagerTest*);
+
+ private:
+ std::string db_dir_;
+ TransactionDB* db_;
+};
+
+using init_func_t = void (*)(PointLockManagerTest*);
+
+class AnyLockManagerTest : public PointLockManagerTest,
+ public testing::WithParamInterface<init_func_t> {
+ public:
+ void SetUp() override {
+ // If a custom setup function was provided, use it. Otherwise, use what we
+ // have inherited.
+ auto init_func = GetParam();
+ if (init_func)
+ (*init_func)(this);
+ else
+ PointLockManagerTest::SetUp();
+ }
+};
+
+TEST_P(AnyLockManagerTest, ReentrantExclusiveLock) {
+ // Tests that a txn can acquire exclusive lock on the same key repeatedly.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, true));
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, true));
+
+ // Cleanup
+ locker_->UnLock(txn, 1, "k", env_);
+
+ delete txn;
+}
+
+TEST_P(AnyLockManagerTest, ReentrantSharedLock) {
+ // Tests that a txn can acquire shared lock on the same key repeatedly.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, false));
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, false));
+
+ // Cleanup
+ locker_->UnLock(txn, 1, "k", env_);
+
+ delete txn;
+}
+
+TEST_P(AnyLockManagerTest, LockUpgrade) {
+ // Tests that a txn can upgrade from a shared lock to an exclusive lock.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, false));
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, true));
+
+ // Cleanup
+ locker_->UnLock(txn, 1, "k", env_);
+ delete txn;
+}
+
+TEST_P(AnyLockManagerTest, LockDowngrade) {
+ // Tests that a txn can acquire a shared lock after acquiring an exclusive
+ // lock on the same key.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, true));
+ ASSERT_OK(locker_->TryLock(txn, 1, "k", env_, false));
+
+ // Cleanup
+ locker_->UnLock(txn, 1, "k", env_);
+ delete txn;
+}
+
+TEST_P(AnyLockManagerTest, LockConflict) {
+ // Tests that lock conflicts lead to lock timeout.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn1 = NewTxn();
+ auto txn2 = NewTxn();
+
+ {
+ // exclusive-exclusive conflict.
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k1", env_, true));
+ auto s = locker_->TryLock(txn2, 1, "k1", env_, true);
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ {
+ // exclusive-shared conflict.
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k2", env_, true));
+ auto s = locker_->TryLock(txn2, 1, "k2", env_, false);
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ {
+ // shared-exclusive conflict.
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k2", env_, false));
+ auto s = locker_->TryLock(txn2, 1, "k2", env_, true);
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ // Cleanup
+ locker_->UnLock(txn1, 1, "k1", env_);
+ locker_->UnLock(txn1, 1, "k2", env_);
+
+ delete txn1;
+ delete txn2;
+}
+
+port::Thread BlockUntilWaitingTxn(const char* sync_point_name,
+ std::function<void()> f) {
+ std::atomic<bool> reached(false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ sync_point_name, [&](void* /*arg*/) { reached.store(true); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ port::Thread t(f);
+
+ while (!reached.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ return t;
+}
+
+TEST_P(AnyLockManagerTest, SharedLocks) {
+ // Tests that shared locks can be concurrently held by multiple transactions.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ auto txn1 = NewTxn();
+ auto txn2 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k", env_, false));
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k", env_, false));
+
+ // Cleanup
+ locker_->UnLock(txn1, 1, "k", env_);
+ locker_->UnLock(txn2, 1, "k", env_);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(AnyLockManagerTest, Deadlock) {
+ // Tests that deadlock can be detected.
+ // Deadlock scenario:
+ // txn1 exclusively locks k1, and wants to lock k2;
+ // txn2 exclusively locks k2, and wants to lock k1.
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+ TransactionOptions txn_opt;
+ txn_opt.deadlock_detect = true;
+ txn_opt.lock_timeout = 1000000;
+ auto txn1 = NewTxn(txn_opt);
+ auto txn2 = NewTxn(txn_opt);
+
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k1", env_, true));
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k2", env_, true));
+
+ // txn1 tries to lock k2, will block forever.
+ port::Thread t = BlockUntilWaitingTxn(wait_sync_point_name_, [&]() {
+ // block because txn2 is holding a lock on k2.
+ locker_->TryLock(txn1, 1, "k2", env_, true);
+ });
+
+ auto s = locker_->TryLock(txn2, 1, "k1", env_, true);
+ ASSERT_TRUE(s.IsBusy());
+ ASSERT_EQ(s.subcode(), Status::SubCode::kDeadlock);
+
+ std::vector<DeadlockPath> deadlock_paths = locker_->GetDeadlockInfoBuffer();
+ ASSERT_EQ(deadlock_paths.size(), 1u);
+ ASSERT_FALSE(deadlock_paths[0].limit_exceeded);
+
+ std::vector<DeadlockInfo> deadlocks = deadlock_paths[0].path;
+ ASSERT_EQ(deadlocks.size(), 2u);
+
+ ASSERT_EQ(deadlocks[0].m_txn_id, txn1->GetID());
+ ASSERT_EQ(deadlocks[0].m_cf_id, 1u);
+ ASSERT_TRUE(deadlocks[0].m_exclusive);
+ ASSERT_EQ(deadlocks[0].m_waiting_key, "k2");
+
+ ASSERT_EQ(deadlocks[1].m_txn_id, txn2->GetID());
+ ASSERT_EQ(deadlocks[1].m_cf_id, 1u);
+ ASSERT_TRUE(deadlocks[1].m_exclusive);
+ ASSERT_EQ(deadlocks[1].m_waiting_key, "k1");
+
+ locker_->UnLock(txn2, 1, "k2", env_);
+ t.join();
+
+ // Cleanup
+ locker_->UnLock(txn1, 1, "k1", env_);
+ locker_->UnLock(txn1, 1, "k2", env_);
+ delete txn2;
+ delete txn1;
+}
+
+TEST_P(AnyLockManagerTest, GetWaitingTxns_MultipleTxns) {
+ MockColumnFamilyHandle cf(1);
+ locker_->AddColumnFamily(&cf);
+
+ auto txn1 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn1, 1, "k", env_, false));
+
+ auto txn2 = NewTxn();
+ ASSERT_OK(locker_->TryLock(txn2, 1, "k", env_, false));
+
+ auto txn3 = NewTxn();
+ txn3->SetLockTimeout(10000);
+ port::Thread t1 = BlockUntilWaitingTxn(wait_sync_point_name_, [&]() {
+ ASSERT_OK(locker_->TryLock(txn3, 1, "k", env_, true));
+ locker_->UnLock(txn3, 1, "k", env_);
+ });
+
+ // Ok, now txn3 is waiting for lock on "k", which is owned by two
+ // transactions. Check that GetWaitingTxns reports this correctly
+ uint32_t wait_cf_id;
+ std::string wait_key;
+ auto waiters = txn3->GetWaitingTxns(&wait_cf_id, &wait_key);
+
+ ASSERT_EQ(wait_cf_id, 1u);
+ ASSERT_EQ(wait_key, "k");
+ ASSERT_EQ(waiters.size(), 2);
+ bool waits_correct =
+ (waiters[0] == txn1->GetID() && waiters[1] == txn2->GetID()) ||
+ (waiters[1] == txn1->GetID() && waiters[0] == txn2->GetID());
+ ASSERT_EQ(waits_correct, true);
+
+ // Release locks so txn3 can proceed with execution
+ locker_->UnLock(txn1, 1, "k", env_);
+ locker_->UnLock(txn2, 1, "k", env_);
+
+ // Wait until txn3 finishes
+ t1.join();
+
+ delete txn1;
+ delete txn2;
+ delete txn3;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.cc b/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.cc
new file mode 100644
index 000000000..6204a8f02
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.cc
@@ -0,0 +1,257 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/lock/point/point_lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+class TrackedKeysColumnFamilyIterator
+ : public LockTracker::ColumnFamilyIterator {
+ public:
+ explicit TrackedKeysColumnFamilyIterator(const TrackedKeys& keys)
+ : tracked_keys_(keys), it_(keys.begin()) {}
+
+ bool HasNext() const override { return it_ != tracked_keys_.end(); }
+
+ ColumnFamilyId Next() override { return (it_++)->first; }
+
+ private:
+ const TrackedKeys& tracked_keys_;
+ TrackedKeys::const_iterator it_;
+};
+
+class TrackedKeysIterator : public LockTracker::KeyIterator {
+ public:
+ TrackedKeysIterator(const TrackedKeys& keys, ColumnFamilyId id)
+ : key_infos_(keys.at(id)), it_(key_infos_.begin()) {}
+
+ bool HasNext() const override { return it_ != key_infos_.end(); }
+
+ const std::string& Next() override { return (it_++)->first; }
+
+ private:
+ const TrackedKeyInfos& key_infos_;
+ TrackedKeyInfos::const_iterator it_;
+};
+
+} // namespace
+
+void PointLockTracker::Track(const PointLockRequest& r) {
+ auto& keys = tracked_keys_[r.column_family_id];
+ auto result = keys.try_emplace(r.key, r.seq);
+ auto it = result.first;
+ if (!result.second && r.seq < it->second.seq) {
+ // Now tracking this key with an earlier sequence number
+ it->second.seq = r.seq;
+ }
+ // else we do not update the seq. The smaller the tracked seq, the stronger it
+ // the guarantee since it implies from the seq onward there has not been a
+ // concurrent update to the key. So we update the seq if it implies stronger
+ // guarantees, i.e., if it is smaller than the existing tracked seq.
+
+ if (r.read_only) {
+ it->second.num_reads++;
+ } else {
+ it->second.num_writes++;
+ }
+
+ it->second.exclusive = it->second.exclusive || r.exclusive;
+}
+
+UntrackStatus PointLockTracker::Untrack(const PointLockRequest& r) {
+ auto cf_keys = tracked_keys_.find(r.column_family_id);
+ if (cf_keys == tracked_keys_.end()) {
+ return UntrackStatus::NOT_TRACKED;
+ }
+
+ auto& keys = cf_keys->second;
+ auto it = keys.find(r.key);
+ if (it == keys.end()) {
+ return UntrackStatus::NOT_TRACKED;
+ }
+
+ bool untracked = false;
+ auto& info = it->second;
+ if (r.read_only) {
+ if (info.num_reads > 0) {
+ info.num_reads--;
+ untracked = true;
+ }
+ } else {
+ if (info.num_writes > 0) {
+ info.num_writes--;
+ untracked = true;
+ }
+ }
+
+ bool removed = false;
+ if (info.num_reads == 0 && info.num_writes == 0) {
+ keys.erase(it);
+ if (keys.empty()) {
+ tracked_keys_.erase(cf_keys);
+ }
+ removed = true;
+ }
+
+ if (removed) {
+ return UntrackStatus::REMOVED;
+ }
+ if (untracked) {
+ return UntrackStatus::UNTRACKED;
+ }
+ return UntrackStatus::NOT_TRACKED;
+}
+
+void PointLockTracker::Merge(const LockTracker& tracker) {
+ const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker);
+ for (const auto& cf_keys : t.tracked_keys_) {
+ ColumnFamilyId cf = cf_keys.first;
+ const auto& keys = cf_keys.second;
+
+ auto current_cf_keys = tracked_keys_.find(cf);
+ if (current_cf_keys == tracked_keys_.end()) {
+ tracked_keys_.emplace(cf_keys);
+ } else {
+ auto& current_keys = current_cf_keys->second;
+ for (const auto& key_info : keys) {
+ const std::string& key = key_info.first;
+ const TrackedKeyInfo& info = key_info.second;
+ // If key was not previously tracked, just copy the whole struct over.
+ // Otherwise, some merging needs to occur.
+ auto current_info = current_keys.find(key);
+ if (current_info == current_keys.end()) {
+ current_keys.emplace(key_info);
+ } else {
+ current_info->second.Merge(info);
+ }
+ }
+ }
+ }
+}
+
+void PointLockTracker::Subtract(const LockTracker& tracker) {
+ const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker);
+ for (const auto& cf_keys : t.tracked_keys_) {
+ ColumnFamilyId cf = cf_keys.first;
+ const auto& keys = cf_keys.second;
+
+ auto& current_keys = tracked_keys_.at(cf);
+ for (const auto& key_info : keys) {
+ const std::string& key = key_info.first;
+ const TrackedKeyInfo& info = key_info.second;
+ uint32_t num_reads = info.num_reads;
+ uint32_t num_writes = info.num_writes;
+
+ auto current_key_info = current_keys.find(key);
+ assert(current_key_info != current_keys.end());
+
+ // Decrement the total reads/writes of this key by the number of
+ // reads/writes done since the last SavePoint.
+ if (num_reads > 0) {
+ assert(current_key_info->second.num_reads >= num_reads);
+ current_key_info->second.num_reads -= num_reads;
+ }
+ if (num_writes > 0) {
+ assert(current_key_info->second.num_writes >= num_writes);
+ current_key_info->second.num_writes -= num_writes;
+ }
+ if (current_key_info->second.num_reads == 0 &&
+ current_key_info->second.num_writes == 0) {
+ current_keys.erase(current_key_info);
+ }
+ }
+ }
+}
+
+LockTracker* PointLockTracker::GetTrackedLocksSinceSavePoint(
+ const LockTracker& save_point_tracker) const {
+ // Examine the number of reads/writes performed on all keys written
+ // since the last SavePoint and compare to the total number of reads/writes
+ // for each key.
+ LockTracker* t = new PointLockTracker();
+ const PointLockTracker& save_point_t =
+ static_cast<const PointLockTracker&>(save_point_tracker);
+ for (const auto& cf_keys : save_point_t.tracked_keys_) {
+ ColumnFamilyId cf = cf_keys.first;
+ const auto& keys = cf_keys.second;
+
+ auto& current_keys = tracked_keys_.at(cf);
+ for (const auto& key_info : keys) {
+ const std::string& key = key_info.first;
+ const TrackedKeyInfo& info = key_info.second;
+ uint32_t num_reads = info.num_reads;
+ uint32_t num_writes = info.num_writes;
+
+ auto current_key_info = current_keys.find(key);
+ assert(current_key_info != current_keys.end());
+ assert(current_key_info->second.num_reads >= num_reads);
+ assert(current_key_info->second.num_writes >= num_writes);
+
+ if (current_key_info->second.num_reads == num_reads &&
+ current_key_info->second.num_writes == num_writes) {
+ // All the reads/writes to this key were done in the last savepoint.
+ PointLockRequest r;
+ r.column_family_id = cf;
+ r.key = key;
+ r.seq = info.seq;
+ r.read_only = (num_writes == 0);
+ r.exclusive = info.exclusive;
+ t->Track(r);
+ }
+ }
+ }
+ return t;
+}
+
+PointLockStatus PointLockTracker::GetPointLockStatus(
+ ColumnFamilyId column_family_id, const std::string& key) const {
+ assert(IsPointLockSupported());
+ PointLockStatus status;
+ auto it = tracked_keys_.find(column_family_id);
+ if (it == tracked_keys_.end()) {
+ return status;
+ }
+
+ const auto& keys = it->second;
+ auto key_it = keys.find(key);
+ if (key_it == keys.end()) {
+ return status;
+ }
+
+ const TrackedKeyInfo& key_info = key_it->second;
+ status.locked = true;
+ status.exclusive = key_info.exclusive;
+ status.seq = key_info.seq;
+ return status;
+}
+
+uint64_t PointLockTracker::GetNumPointLocks() const {
+ uint64_t num_keys = 0;
+ for (const auto& cf_keys : tracked_keys_) {
+ num_keys += cf_keys.second.size();
+ }
+ return num_keys;
+}
+
+LockTracker::ColumnFamilyIterator* PointLockTracker::GetColumnFamilyIterator()
+ const {
+ return new TrackedKeysColumnFamilyIterator(tracked_keys_);
+}
+
+LockTracker::KeyIterator* PointLockTracker::GetKeyIterator(
+ ColumnFamilyId column_family_id) const {
+ assert(tracked_keys_.find(column_family_id) != tracked_keys_.end());
+ return new TrackedKeysIterator(tracked_keys_, column_family_id);
+}
+
+void PointLockTracker::Clear() { tracked_keys_.clear(); }
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.h b/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.h
new file mode 100644
index 000000000..daf6f9aa2
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/point/point_lock_tracker.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "utilities/transactions/lock/lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct TrackedKeyInfo {
+ // Earliest sequence number that is relevant to this transaction for this key
+ SequenceNumber seq;
+
+ uint32_t num_writes;
+ uint32_t num_reads;
+
+ bool exclusive;
+
+ explicit TrackedKeyInfo(SequenceNumber seq_no)
+ : seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {}
+
+ void Merge(const TrackedKeyInfo& info) {
+ assert(seq <= info.seq);
+ num_reads += info.num_reads;
+ num_writes += info.num_writes;
+ exclusive = exclusive || info.exclusive;
+ }
+};
+
+using TrackedKeyInfos = std::unordered_map<std::string, TrackedKeyInfo>;
+
+using TrackedKeys = std::unordered_map<ColumnFamilyId, TrackedKeyInfos>;
+
+// Tracks point locks on single keys.
+class PointLockTracker : public LockTracker {
+ public:
+ PointLockTracker() = default;
+
+ PointLockTracker(const PointLockTracker&) = delete;
+ PointLockTracker& operator=(const PointLockTracker&) = delete;
+
+ bool IsPointLockSupported() const override { return true; }
+
+ bool IsRangeLockSupported() const override { return false; }
+
+ void Track(const PointLockRequest& lock_request) override;
+
+ UntrackStatus Untrack(const PointLockRequest& lock_request) override;
+
+ void Track(const RangeLockRequest& /*lock_request*/) override {}
+
+ UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) override {
+ return UntrackStatus::NOT_TRACKED;
+ }
+
+ void Merge(const LockTracker& tracker) override;
+
+ void Subtract(const LockTracker& tracker) override;
+
+ void Clear() override;
+
+ virtual LockTracker* GetTrackedLocksSinceSavePoint(
+ const LockTracker& save_point_tracker) const override;
+
+ PointLockStatus GetPointLockStatus(ColumnFamilyId column_family_id,
+ const std::string& key) const override;
+
+ uint64_t GetNumPointLocks() const override;
+
+ ColumnFamilyIterator* GetColumnFamilyIterator() const override;
+
+ KeyIterator* GetKeyIterator(ColumnFamilyId column_family_id) const override;
+
+ private:
+ TrackedKeys tracked_keys_;
+};
+
+class PointLockTrackerFactory : public LockTrackerFactory {
+ public:
+ static const PointLockTrackerFactory& Get() {
+ static const PointLockTrackerFactory instance;
+ return instance;
+ }
+
+ LockTracker* Create() const override { return new PointLockTracker(); }
+
+ private:
+ PointLockTrackerFactory() {}
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_lock_manager.h b/src/rocksdb/utilities/transactions/lock/range/range_lock_manager.h
new file mode 100644
index 000000000..01899542e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_lock_manager.h
@@ -0,0 +1,36 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+//
+// Generic definitions for a Range-based Lock Manager
+//
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/lock/lock_manager.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+/*
+ A base class for all Range-based lock managers
+
+ See also class RangeLockManagerHandle in
+ include/rocksdb/utilities/transaction_db.h
+*/
+class RangeLockManagerBase : public LockManager {
+ public:
+ // Geting a point lock is reduced to getting a range lock on a single-point
+ // range
+ using LockManager::TryLock;
+ Status TryLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const std::string& key, Env* env, bool exclusive) override {
+ Endpoint endp(key.data(), key.size(), false);
+ return TryLock(txn, column_family_id, endp, endp, env, exclusive);
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_locking_test.cc b/src/rocksdb/utilities/transactions/lock/range/range_locking_test.cc
new file mode 100644
index 000000000..bce66c1f3
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_locking_test.cc
@@ -0,0 +1,459 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/perf_context.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "utilities/transactions/lock/point/point_lock_manager_test.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_test.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+class RangeLockingTest : public ::testing::Test {
+ public:
+ TransactionDB* db;
+ std::string dbname;
+ Options options;
+
+ std::shared_ptr<RangeLockManagerHandle> range_lock_mgr;
+ TransactionDBOptions txn_db_options;
+
+ RangeLockingTest() : db(nullptr) {
+ options.create_if_missing = true;
+ dbname = test::PerThreadDBPath("range_locking_testdb");
+
+ EXPECT_OK(DestroyDB(dbname, options));
+
+ range_lock_mgr.reset(NewRangeLockManager(nullptr));
+ txn_db_options.lock_mgr_handle = range_lock_mgr;
+
+ auto s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ assert(s.ok());
+ }
+
+ ~RangeLockingTest() {
+ delete db;
+ db = nullptr;
+ // This is to skip the assert statement in FaultInjectionTestEnv. There
+ // seems to be a bug in btrfs that the makes readdir return recently
+ // unlink-ed files. By using the default fs we simply ignore errors resulted
+ // from attempting to delete such files in DestroyDB.
+ EXPECT_OK(DestroyDB(dbname, options));
+ }
+
+ PessimisticTransaction* NewTxn(
+ TransactionOptions txn_opt = TransactionOptions()) {
+ Transaction* txn = db->BeginTransaction(WriteOptions(), txn_opt);
+ return reinterpret_cast<PessimisticTransaction*>(txn);
+ }
+};
+
+// TODO: set a smaller lock wait timeout so that the test runs faster.
+TEST_F(RangeLockingTest, BasicRangeLocking) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ std::string value;
+ ReadOptions read_options;
+ auto cf = db->DefaultColumnFamily();
+
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ // Get a range lock
+ ASSERT_OK(txn0->GetRangeLock(cf, Endpoint("a"), Endpoint("c")));
+
+ // Check that range Lock inhibits an overlapping range lock
+ {
+ auto s = txn1->GetRangeLock(cf, Endpoint("b"), Endpoint("z"));
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ // Check that range Lock inhibits an overlapping point lock
+ {
+ auto s = txn1->GetForUpdate(read_options, cf, Slice("b"), &value);
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ // Get a point lock, check that it inhibits range locks
+ ASSERT_OK(txn0->Put(cf, Slice("n"), Slice("value")));
+ {
+ auto s = txn1->GetRangeLock(cf, Endpoint("m"), Endpoint("p"));
+ ASSERT_TRUE(s.IsTimedOut());
+ }
+
+ ASSERT_OK(txn0->Commit());
+ txn1->Rollback();
+
+ delete txn0;
+ delete txn1;
+}
+
+TEST_F(RangeLockingTest, MyRocksLikeUpdate) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ auto cf = db->DefaultColumnFamily();
+ Status s;
+
+ // Get a range lock for the range we are about to update
+ ASSERT_OK(txn0->GetRangeLock(cf, Endpoint("a"), Endpoint("c")));
+
+ bool try_range_lock_called = false;
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "RangeTreeLockManager::TryRangeLock:enter",
+ [&](void* /*arg*/) { try_range_lock_called = true; });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // For performance reasons, the following must NOT call lock_mgr->TryLock():
+ // We verify that by checking the value of try_range_lock_called.
+ ASSERT_OK(txn0->Put(cf, Slice("b"), Slice("value"),
+ /*assume_tracked=*/true));
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ ASSERT_FALSE(try_range_lock_called);
+
+ txn0->Rollback();
+
+ delete txn0;
+}
+
+TEST_F(RangeLockingTest, UpgradeLockAndGetConflict) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ auto cf = db->DefaultColumnFamily();
+ Status s;
+ std::string value;
+ txn_options.lock_timeout = 10;
+
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ // Get the shared lock in txn0
+ s = txn0->GetForUpdate(ReadOptions(), cf, Slice("a"), &value,
+ false /*exclusive*/);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Get the shared lock on the same key in txn1
+ s = txn1->GetForUpdate(ReadOptions(), cf, Slice("a"), &value,
+ false /*exclusive*/);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Now, try getting an exclusive lock that overlaps with the above
+ s = txn0->GetRangeLock(cf, Endpoint("a"), Endpoint("b"));
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn0->Rollback();
+ txn1->Rollback();
+
+ delete txn0;
+ delete txn1;
+}
+
+TEST_F(RangeLockingTest, SnapshotValidation) {
+ Status s;
+ Slice key_slice = Slice("k");
+ ColumnFamilyHandle* cfh = db->DefaultColumnFamily();
+
+ auto txn0 = NewTxn();
+ txn0->Put(key_slice, Slice("initial"));
+ txn0->Commit();
+
+ // txn1
+ auto txn1 = NewTxn();
+ txn1->SetSnapshot();
+ std::string val1;
+ ASSERT_OK(txn1->Get(ReadOptions(), cfh, key_slice, &val1));
+ ASSERT_EQ(val1, "initial");
+ val1 = val1 + std::string("-txn1");
+
+ ASSERT_OK(txn1->Put(cfh, key_slice, Slice(val1)));
+
+ // txn2
+ auto txn2 = NewTxn();
+ txn2->SetSnapshot();
+ std::string val2;
+ // This will see the original value as nothing is committed
+ // This is also Get, so it is doesn't acquire any locks.
+ ASSERT_OK(txn2->Get(ReadOptions(), cfh, key_slice, &val2));
+ ASSERT_EQ(val2, "initial");
+
+ // txn1
+ ASSERT_OK(txn1->Commit());
+
+ // txn2
+ val2 = val2 + std::string("-txn2");
+ // Now, this call should do Snapshot Validation and fail:
+ s = txn2->Put(cfh, key_slice, Slice(val2));
+ ASSERT_TRUE(s.IsBusy());
+
+ ASSERT_OK(txn2->Commit());
+
+ delete txn0;
+ delete txn1;
+ delete txn2;
+}
+
+TEST_F(RangeLockingTest, MultipleTrxLockStatusData) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ auto cf = db->DefaultColumnFamily();
+
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ // Get a range lock
+ ASSERT_OK(txn0->GetRangeLock(cf, Endpoint("z"), Endpoint("z")));
+ ASSERT_OK(txn1->GetRangeLock(cf, Endpoint("b"), Endpoint("e")));
+
+ auto s = range_lock_mgr->GetRangeLockStatusData();
+ ASSERT_EQ(s.size(), 2);
+ for (auto it = s.begin(); it != s.end(); ++it) {
+ ASSERT_EQ(it->first, cf->GetID());
+ auto val = it->second;
+ ASSERT_FALSE(val.start.inf_suffix);
+ ASSERT_FALSE(val.end.inf_suffix);
+ ASSERT_TRUE(val.exclusive);
+ ASSERT_EQ(val.ids.size(), 1);
+ if (val.ids[0] == txn0->GetID()) {
+ ASSERT_EQ(val.start.slice, "z");
+ ASSERT_EQ(val.end.slice, "z");
+ } else if (val.ids[0] == txn1->GetID()) {
+ ASSERT_EQ(val.start.slice, "b");
+ ASSERT_EQ(val.end.slice, "e");
+ } else {
+ FAIL(); // Unknown transaction ID.
+ }
+ }
+
+ delete txn0;
+ delete txn1;
+}
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define SKIP_LOCK_ESCALATION_TEST 1
+#endif
+#else
+#define SKIP_LOCK_ESCALATION_TEST 1
+#endif
+
+#ifndef SKIP_LOCK_ESCALATION_TEST
+TEST_F(RangeLockingTest, BasicLockEscalation) {
+ auto cf = db->DefaultColumnFamily();
+
+ auto counters = range_lock_mgr->GetStatus();
+
+ // Initially not using any lock memory
+ ASSERT_EQ(counters.current_lock_memory, 0);
+ ASSERT_EQ(counters.escalation_count, 0);
+
+ ASSERT_EQ(0, range_lock_mgr->SetMaxLockMemory(2000));
+
+ // Insert until we see lock escalations
+ auto txn = NewTxn();
+
+ // Get the locks until we hit an escalation
+ for (int i = 0; i < 2020; i++) {
+ std::ostringstream buf;
+ buf << std::setw(8) << std::setfill('0') << i;
+ std::string buf_str = buf.str();
+ ASSERT_OK(txn->GetRangeLock(cf, Endpoint(buf_str), Endpoint(buf_str)));
+ }
+ counters = range_lock_mgr->GetStatus();
+ ASSERT_GT(counters.escalation_count, 0);
+ ASSERT_LE(counters.current_lock_memory, 2000);
+
+ delete txn;
+}
+
+// An escalation barrier function. Allow escalation iff the first two bytes are
+// identical.
+static bool escalation_barrier(const Endpoint& a, const Endpoint& b) {
+ assert(a.slice.size() > 2);
+ assert(b.slice.size() > 2);
+ if (memcmp(a.slice.data(), b.slice.data(), 2)) {
+ return true; // This is a barrier
+ } else {
+ return false; // No barrier
+ }
+}
+
+TEST_F(RangeLockingTest, LockEscalationBarrier) {
+ auto cf = db->DefaultColumnFamily();
+
+ auto counters = range_lock_mgr->GetStatus();
+
+ // Initially not using any lock memory
+ ASSERT_EQ(counters.escalation_count, 0);
+
+ range_lock_mgr->SetMaxLockMemory(8000);
+ range_lock_mgr->SetEscalationBarrierFunc(escalation_barrier);
+
+ // Insert enough locks to cause lock escalations to happen
+ auto txn = NewTxn();
+ const int N = 2000;
+ for (int i = 0; i < N; i++) {
+ std::ostringstream buf;
+ buf << std::setw(4) << std::setfill('0') << i;
+ std::string buf_str = buf.str();
+ ASSERT_OK(txn->GetRangeLock(cf, Endpoint(buf_str), Endpoint(buf_str)));
+ }
+ counters = range_lock_mgr->GetStatus();
+ ASSERT_GT(counters.escalation_count, 0);
+
+ // Check that lock escalation was not performed across escalation barriers:
+ // Use another txn to acquire locks near the barriers.
+ auto txn2 = NewTxn();
+ range_lock_mgr->SetMaxLockMemory(500000);
+ for (int i = 100; i < N; i += 100) {
+ std::ostringstream buf;
+ buf << std::setw(4) << std::setfill('0') << i - 1 << "-a";
+ std::string buf_str = buf.str();
+ // Check that we CAN get a lock near the escalation barrier
+ ASSERT_OK(txn2->GetRangeLock(cf, Endpoint(buf_str), Endpoint(buf_str)));
+ }
+
+ txn->Rollback();
+ txn2->Rollback();
+ delete txn;
+ delete txn2;
+}
+
+#endif
+
+TEST_F(RangeLockingTest, LockWaitCount) {
+ TransactionOptions txn_options;
+ auto cf = db->DefaultColumnFamily();
+ txn_options.lock_timeout = 50;
+ Transaction* txn0 = db->BeginTransaction(WriteOptions(), txn_options);
+ Transaction* txn1 = db->BeginTransaction(WriteOptions(), txn_options);
+
+ // Get a range lock
+ ASSERT_OK(txn0->GetRangeLock(cf, Endpoint("a"), Endpoint("c")));
+
+ uint64_t lock_waits1 = range_lock_mgr->GetStatus().lock_wait_count;
+ // Attempt to get a conflicting lock
+ auto s = txn1->GetRangeLock(cf, Endpoint("b"), Endpoint("z"));
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // Check that the counter was incremented
+ uint64_t lock_waits2 = range_lock_mgr->GetStatus().lock_wait_count;
+ ASSERT_EQ(lock_waits1 + 1, lock_waits2);
+
+ txn0->Rollback();
+ txn1->Rollback();
+
+ delete txn0;
+ delete txn1;
+}
+
+TEST_F(RangeLockingTest, LockWaiteeAccess) {
+ TransactionOptions txn_options;
+ auto cf = db->DefaultColumnFamily();
+ txn_options.lock_timeout = 60;
+ Transaction* txn0 = db->BeginTransaction(WriteOptions(), txn_options);
+ Transaction* txn1 = db->BeginTransaction(WriteOptions(), txn_options);
+
+ // Get a range lock
+ ASSERT_OK(txn0->GetRangeLock(cf, Endpoint("a"), Endpoint("c")));
+
+ std::atomic<bool> reached(false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "RangeTreeLockManager::TryRangeLock:EnterWaitingTxn", [&](void* /*arg*/) {
+ reached.store(true);
+ std::this_thread::sleep_for(std::chrono::milliseconds(2000));
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ port::Thread t([&]() {
+ // Attempt to get a conflicting lock
+ auto s = txn1->GetRangeLock(cf, Endpoint("b"), Endpoint("z"));
+ ASSERT_TRUE(s.ok());
+ txn1->Rollback();
+ });
+
+ while (!reached.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Release locks and free the transaction
+ txn0->Rollback();
+ delete txn0;
+
+ t.join();
+
+ delete txn1;
+}
+
+void PointLockManagerTestExternalSetup(PointLockManagerTest* self) {
+ self->env_ = Env::Default();
+ self->db_dir_ = test::PerThreadDBPath("point_lock_manager_test");
+ ASSERT_OK(self->env_->CreateDir(self->db_dir_));
+
+ Options opt;
+ opt.create_if_missing = true;
+ TransactionDBOptions txn_opt;
+ txn_opt.transaction_lock_timeout = 0;
+
+ auto mutex_factory = std::make_shared<TransactionDBMutexFactoryImpl>();
+ self->locker_.reset(NewRangeLockManager(mutex_factory)->getLockManager());
+ std::shared_ptr<RangeLockManagerHandle> range_lock_mgr =
+ std::dynamic_pointer_cast<RangeLockManagerHandle>(self->locker_);
+ txn_opt.lock_mgr_handle = range_lock_mgr;
+
+ ASSERT_OK(TransactionDB::Open(opt, txn_opt, self->db_dir_, &self->db_));
+ self->wait_sync_point_name_ = "RangeTreeLockManager::TryRangeLock:WaitingTxn";
+}
+
+INSTANTIATE_TEST_CASE_P(RangeLockManager, AnyLockManagerTest,
+ ::testing::Values(PointLockManagerTestExternalSetup));
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else // OS_WIN
+
+#include <stdio.h>
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "skipped as Range Locking is not supported on Windows\n");
+ return 0;
+}
+
+#endif // OS_WIN
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "skipped as transactions are not supported in rocksdb_lite\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.AGPLv3 b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.AGPLv3
new file mode 100644
index 000000000..dba13ed2d
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.AGPLv3
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<http://www.gnu.org/licenses/>.
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.APACHEv2 b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.APACHEv2
new file mode 100644
index 000000000..ecbfc770f
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.APACHEv2
@@ -0,0 +1,174 @@
+Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.GPLv2 b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.GPLv2
new file mode 100644
index 000000000..d511905c1
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/COPYING.GPLv2
@@ -0,0 +1,339 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 2, June 1991
+
+ Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The licenses for most software are designed to take away your
+freedom to share and change it. By contrast, the GNU General Public
+License is intended to guarantee your freedom to share and change free
+software--to make sure the software is free for all its users. This
+General Public License applies to most of the Free Software
+Foundation's software and to any other program whose authors commit to
+using it. (Some other Free Software Foundation software is covered by
+the GNU Lesser General Public License instead.) You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+this service if you wish), that you receive source code or can get it
+if you want it, that you can change the software or use pieces of it
+in new free programs; and that you know you can do these things.
+
+ To protect your rights, we need to make restrictions that forbid
+anyone to deny you these rights or to ask you to surrender the rights.
+These restrictions translate to certain responsibilities for you if you
+distribute copies of the software, or if you modify it.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must give the recipients all the rights that
+you have. You must make sure that they, too, receive or can get the
+source code. And you must show them these terms so they know their
+rights.
+
+ We protect your rights with two steps: (1) copyright the software, and
+(2) offer you this license which gives you legal permission to copy,
+distribute and/or modify the software.
+
+ Also, for each author's protection and ours, we want to make certain
+that everyone understands that there is no warranty for this free
+software. If the software is modified by someone else and passed on, we
+want its recipients to know that what they have is not the original, so
+that any problems introduced by others will not reflect on the original
+authors' reputations.
+
+ Finally, any free program is threatened constantly by software
+patents. We wish to avoid the danger that redistributors of a free
+program will individually obtain patent licenses, in effect making the
+program proprietary. To prevent this, we have made it clear that any
+patent must be licensed for everyone's free use or not licensed at all.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ GNU GENERAL PUBLIC LICENSE
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
+
+ 0. This License applies to any program or other work which contains
+a notice placed by the copyright holder saying it may be distributed
+under the terms of this General Public License. The "Program", below,
+refers to any such program or work, and a "work based on the Program"
+means either the Program or any derivative work under copyright law:
+that is to say, a work containing the Program or a portion of it,
+either verbatim or with modifications and/or translated into another
+language. (Hereinafter, translation is included without limitation in
+the term "modification".) Each licensee is addressed as "you".
+
+Activities other than copying, distribution and modification are not
+covered by this License; they are outside its scope. The act of
+running the Program is not restricted, and the output from the Program
+is covered only if its contents constitute a work based on the
+Program (independent of having been made by running the Program).
+Whether that is true depends on what the Program does.
+
+ 1. You may copy and distribute verbatim copies of the Program's
+source code as you receive it, in any medium, provided that you
+conspicuously and appropriately publish on each copy an appropriate
+copyright notice and disclaimer of warranty; keep intact all the
+notices that refer to this License and to the absence of any warranty;
+and give any other recipients of the Program a copy of this License
+along with the Program.
+
+You may charge a fee for the physical act of transferring a copy, and
+you may at your option offer warranty protection in exchange for a fee.
+
+ 2. You may modify your copy or copies of the Program or any portion
+of it, thus forming a work based on the Program, and copy and
+distribute such modifications or work under the terms of Section 1
+above, provided that you also meet all of these conditions:
+
+ a) You must cause the modified files to carry prominent notices
+ stating that you changed the files and the date of any change.
+
+ b) You must cause any work that you distribute or publish, that in
+ whole or in part contains or is derived from the Program or any
+ part thereof, to be licensed as a whole at no charge to all third
+ parties under the terms of this License.
+
+ c) If the modified program normally reads commands interactively
+ when run, you must cause it, when started running for such
+ interactive use in the most ordinary way, to print or display an
+ announcement including an appropriate copyright notice and a
+ notice that there is no warranty (or else, saying that you provide
+ a warranty) and that users may redistribute the program under
+ these conditions, and telling the user how to view a copy of this
+ License. (Exception: if the Program itself is interactive but
+ does not normally print such an announcement, your work based on
+ the Program is not required to print an announcement.)
+
+These requirements apply to the modified work as a whole. If
+identifiable sections of that work are not derived from the Program,
+and can be reasonably considered independent and separate works in
+themselves, then this License, and its terms, do not apply to those
+sections when you distribute them as separate works. But when you
+distribute the same sections as part of a whole which is a work based
+on the Program, the distribution of the whole must be on the terms of
+this License, whose permissions for other licensees extend to the
+entire whole, and thus to each and every part regardless of who wrote it.
+
+Thus, it is not the intent of this section to claim rights or contest
+your rights to work written entirely by you; rather, the intent is to
+exercise the right to control the distribution of derivative or
+collective works based on the Program.
+
+In addition, mere aggregation of another work not based on the Program
+with the Program (or with a work based on the Program) on a volume of
+a storage or distribution medium does not bring the other work under
+the scope of this License.
+
+ 3. You may copy and distribute the Program (or a work based on it,
+under Section 2) in object code or executable form under the terms of
+Sections 1 and 2 above provided that you also do one of the following:
+
+ a) Accompany it with the complete corresponding machine-readable
+ source code, which must be distributed under the terms of Sections
+ 1 and 2 above on a medium customarily used for software interchange; or,
+
+ b) Accompany it with a written offer, valid for at least three
+ years, to give any third party, for a charge no more than your
+ cost of physically performing source distribution, a complete
+ machine-readable copy of the corresponding source code, to be
+ distributed under the terms of Sections 1 and 2 above on a medium
+ customarily used for software interchange; or,
+
+ c) Accompany it with the information you received as to the offer
+ to distribute corresponding source code. (This alternative is
+ allowed only for noncommercial distribution and only if you
+ received the program in object code or executable form with such
+ an offer, in accord with Subsection b above.)
+
+The source code for a work means the preferred form of the work for
+making modifications to it. For an executable work, complete source
+code means all the source code for all modules it contains, plus any
+associated interface definition files, plus the scripts used to
+control compilation and installation of the executable. However, as a
+special exception, the source code distributed need not include
+anything that is normally distributed (in either source or binary
+form) with the major components (compiler, kernel, and so on) of the
+operating system on which the executable runs, unless that component
+itself accompanies the executable.
+
+If distribution of executable or object code is made by offering
+access to copy from a designated place, then offering equivalent
+access to copy the source code from the same place counts as
+distribution of the source code, even though third parties are not
+compelled to copy the source along with the object code.
+
+ 4. You may not copy, modify, sublicense, or distribute the Program
+except as expressly provided under this License. Any attempt
+otherwise to copy, modify, sublicense or distribute the Program is
+void, and will automatically terminate your rights under this License.
+However, parties who have received copies, or rights, from you under
+this License will not have their licenses terminated so long as such
+parties remain in full compliance.
+
+ 5. You are not required to accept this License, since you have not
+signed it. However, nothing else grants you permission to modify or
+distribute the Program or its derivative works. These actions are
+prohibited by law if you do not accept this License. Therefore, by
+modifying or distributing the Program (or any work based on the
+Program), you indicate your acceptance of this License to do so, and
+all its terms and conditions for copying, distributing or modifying
+the Program or works based on it.
+
+ 6. Each time you redistribute the Program (or any work based on the
+Program), the recipient automatically receives a license from the
+original licensor to copy, distribute or modify the Program subject to
+these terms and conditions. You may not impose any further
+restrictions on the recipients' exercise of the rights granted herein.
+You are not responsible for enforcing compliance by third parties to
+this License.
+
+ 7. If, as a consequence of a court judgment or allegation of patent
+infringement or for any other reason (not limited to patent issues),
+conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot
+distribute so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you
+may not distribute the Program at all. For example, if a patent
+license would not permit royalty-free redistribution of the Program by
+all those who receive copies directly or indirectly through you, then
+the only way you could satisfy both it and this License would be to
+refrain entirely from distribution of the Program.
+
+If any portion of this section is held invalid or unenforceable under
+any particular circumstance, the balance of the section is intended to
+apply and the section as a whole is intended to apply in other
+circumstances.
+
+It is not the purpose of this section to induce you to infringe any
+patents or other property right claims or to contest validity of any
+such claims; this section has the sole purpose of protecting the
+integrity of the free software distribution system, which is
+implemented by public license practices. Many people have made
+generous contributions to the wide range of software distributed
+through that system in reliance on consistent application of that
+system; it is up to the author/donor to decide if he or she is willing
+to distribute software through any other system and a licensee cannot
+impose that choice.
+
+This section is intended to make thoroughly clear what is believed to
+be a consequence of the rest of this License.
+
+ 8. If the distribution and/or use of the Program is restricted in
+certain countries either by patents or by copyrighted interfaces, the
+original copyright holder who places the Program under this License
+may add an explicit geographical distribution limitation excluding
+those countries, so that distribution is permitted only in or among
+countries not thus excluded. In such case, this License incorporates
+the limitation as if written in the body of this License.
+
+ 9. The Free Software Foundation may publish revised and/or new versions
+of the General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+Each version is given a distinguishing version number. If the Program
+specifies a version number of this License which applies to it and "any
+later version", you have the option of following the terms and conditions
+either of that version or of any later version published by the Free
+Software Foundation. If the Program does not specify a version number of
+this License, you may choose any version ever published by the Free Software
+Foundation.
+
+ 10. If you wish to incorporate parts of the Program into other free
+programs whose distribution conditions are different, write to the author
+to ask for permission. For software which is copyrighted by the Free
+Software Foundation, write to the Free Software Foundation; we sometimes
+make exceptions for this. Our decision will be guided by the two goals
+of preserving the free status of all derivatives of our free software and
+of promoting the sharing and reuse of software generally.
+
+ NO WARRANTY
+
+ 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
+FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
+OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
+PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
+OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
+TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
+PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
+REPAIR OR CORRECTION.
+
+ 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
+REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
+INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
+OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
+TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
+YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
+PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGES.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+convey the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 2 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License along
+ with this program; if not, write to the Free Software Foundation, Inc.,
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+Also add information on how to contact you by electronic and paper mail.
+
+If the program is interactive, make it output a short notice like this
+when it starts in an interactive mode:
+
+ Gnomovision version 69, Copyright (C) year name of author
+ Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, the commands you use may
+be called something other than `show w' and `show c'; they could even be
+mouse-clicks or menu items--whatever suits your program.
+
+You should also get your employer (if you work as a programmer) or your
+school, if any, to sign a "copyright disclaimer" for the program, if
+necessary. Here is a sample; alter the names:
+
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the program
+ `Gnomovision' (which makes passes at compilers) written by James Hacker.
+
+ <signature of Ty Coon>, 1 April 1989
+ Ty Coon, President of Vice
+
+This General Public License does not permit incorporating your program into
+proprietary programs. If your program is a subroutine library, you may
+consider it more useful to permit linking proprietary applications with the
+library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License.
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/README b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/README
new file mode 100644
index 000000000..2ea86bf46
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/README
@@ -0,0 +1,13 @@
+The files in this directory originally come from
+https://github.com/percona/PerconaFT/.
+
+This directory only includes the "locktree" part of PerconaFT, and its
+dependencies.
+
+The following modifications were made:
+- Make locktree usable outside of PerconaFT library
+- Add shared read-only lock support
+
+The files named *_subst.* are substitutes of the PerconaFT's files, they
+contain replacements of PerconaFT's functionality.
+
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/db.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/db.h
new file mode 100644
index 000000000..5aa826c8e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/db.h
@@ -0,0 +1,76 @@
+#ifndef _DB_H
+#define _DB_H
+
+#include <stdint.h>
+#include <sys/types.h>
+
+typedef struct __toku_dbt DBT;
+
+// port: this is currently not used
+struct simple_dbt {
+ uint32_t len;
+ void *data;
+};
+
+// engine status info
+// engine status is passed to handlerton as an array of
+// TOKU_ENGINE_STATUS_ROW_S[]
+typedef enum {
+ STATUS_FS_STATE = 0, // interpret as file system state (redzone) enum
+ STATUS_UINT64, // interpret as uint64_t
+ STATUS_CHARSTR, // interpret as char *
+ STATUS_UNIXTIME, // interpret as time_t
+ STATUS_TOKUTIME, // interpret as tokutime_t
+ STATUS_PARCOUNT, // interpret as PARTITIONED_COUNTER
+ STATUS_DOUBLE // interpret as double
+} toku_engine_status_display_type;
+
+typedef enum {
+ TOKU_ENGINE_STATUS = (1ULL << 0), // Include when asking for engine status
+ TOKU_GLOBAL_STATUS =
+ (1ULL << 1), // Include when asking for information_schema.global_status
+} toku_engine_status_include_type;
+
+typedef struct __toku_engine_status_row {
+ const char *keyname; // info schema key, should not change across revisions
+ // without good reason
+ const char
+ *columnname; // column for mysql, e.g. information_schema.global_status.
+ // TOKUDB_ will automatically be prefixed.
+ const char *legend; // the text that will appear at user interface
+ toku_engine_status_display_type type; // how to interpret the value
+ toku_engine_status_include_type
+ include; // which kinds of callers should get read this row?
+ union {
+ double dnum;
+ uint64_t num;
+ const char *str;
+ char datebuf[26];
+ struct partitioned_counter *parcount;
+ } value;
+} * TOKU_ENGINE_STATUS_ROW, TOKU_ENGINE_STATUS_ROW_S;
+
+#define DB_BUFFER_SMALL -30999
+#define DB_LOCK_DEADLOCK -30995
+#define DB_LOCK_NOTGRANTED -30994
+#define DB_NOTFOUND -30989
+#define DB_KEYEXIST -30996
+#define DB_DBT_MALLOC 8
+#define DB_DBT_REALLOC 64
+#define DB_DBT_USERMEM 256
+
+/* PerconaFT specific error codes */
+#define TOKUDB_OUT_OF_LOCKS -100000
+
+typedef void (*lock_wait_callback)(void *arg, uint64_t requesting_txnid,
+ uint64_t blocking_txnid);
+
+struct __toku_dbt {
+ void *data;
+ size_t size;
+ size_t ulen;
+ // One of DB_DBT_XXX flags
+ uint32_t flags;
+};
+
+#endif
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/comparator.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/comparator.h
new file mode 100644
index 000000000..718efc623
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/comparator.h
@@ -0,0 +1,138 @@
+/* -*- mode: C; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <string.h>
+
+#include "../db.h"
+#include "../portability/memory.h"
+#include "../util/dbt.h"
+
+typedef int (*ft_compare_func)(void *arg, const DBT *a, const DBT *b);
+
+int toku_keycompare(const void *key1, size_t key1len, const void *key2,
+ size_t key2len);
+
+int toku_builtin_compare_fun(const DBT *, const DBT *)
+ __attribute__((__visibility__("default")));
+
+namespace toku {
+
+// a comparator object encapsulates the data necessary for
+// comparing two keys in a fractal tree. it further understands
+// that points may be positive or negative infinity.
+
+class comparator {
+ void init(ft_compare_func cmp, void *cmp_arg, uint8_t memcmp_magic) {
+ _cmp = cmp;
+ _cmp_arg = cmp_arg;
+ _memcmp_magic = memcmp_magic;
+ }
+
+ public:
+ // This magic value is reserved to mean that the magic has not been set.
+ static const uint8_t MEMCMP_MAGIC_NONE = 0;
+
+ void create(ft_compare_func cmp, void *cmp_arg,
+ uint8_t memcmp_magic = MEMCMP_MAGIC_NONE) {
+ init(cmp, cmp_arg, memcmp_magic);
+ }
+
+ // inherit the attributes of another comparator, but keep our own
+ // copy of fake_db that is owned separately from the one given.
+ void inherit(const comparator &cmp) {
+ invariant_notnull(cmp._cmp);
+ init(cmp._cmp, cmp._cmp_arg, cmp._memcmp_magic);
+ }
+
+ // like inherit, but doesn't require that the this comparator
+ // was already created
+ void create_from(const comparator &cmp) { inherit(cmp); }
+
+ void destroy() {}
+
+ ft_compare_func get_compare_func() const { return _cmp; }
+
+ uint8_t get_memcmp_magic() const { return _memcmp_magic; }
+
+ bool valid() const { return _cmp != nullptr; }
+
+ inline bool dbt_has_memcmp_magic(const DBT *dbt) const {
+ return *reinterpret_cast<const char *>(dbt->data) == _memcmp_magic;
+ }
+
+ int operator()(const DBT *a, const DBT *b) const {
+ if (__builtin_expect(toku_dbt_is_infinite(a) || toku_dbt_is_infinite(b),
+ 0)) {
+ return toku_dbt_infinite_compare(a, b);
+ } else if (_memcmp_magic != MEMCMP_MAGIC_NONE
+ // If `a' has the memcmp magic..
+ && dbt_has_memcmp_magic(a)
+ // ..then we expect `b' to also have the memcmp magic
+ && __builtin_expect(dbt_has_memcmp_magic(b), 1)) {
+ assert(0); // psergey: this branch should not be taken.
+ return toku_builtin_compare_fun(a, b);
+ } else {
+ // yikes, const sadness here
+ return _cmp(_cmp_arg, a, b);
+ }
+ }
+
+ private:
+ ft_compare_func _cmp;
+ void *_cmp_arg;
+
+ uint8_t _memcmp_magic;
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/ft-status.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/ft-status.h
new file mode 100644
index 000000000..1b4511172
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/ft/ft-status.h
@@ -0,0 +1,102 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../db.h"
+#include "../portability/toku_race_tools.h"
+#include "../util/status.h"
+
+//
+// Lock Tree Manager statistics
+//
+class LTM_STATUS_S {
+ public:
+ enum {
+ LTM_SIZE_CURRENT = 0,
+ LTM_SIZE_LIMIT,
+ LTM_ESCALATION_COUNT,
+ LTM_ESCALATION_TIME,
+ LTM_ESCALATION_LATEST_RESULT,
+ LTM_NUM_LOCKTREES,
+ LTM_LOCK_REQUESTS_PENDING,
+ LTM_STO_NUM_ELIGIBLE,
+ LTM_STO_END_EARLY_COUNT,
+ LTM_STO_END_EARLY_TIME,
+ LTM_WAIT_COUNT,
+ LTM_WAIT_TIME,
+ LTM_LONG_WAIT_COUNT,
+ LTM_LONG_WAIT_TIME,
+ LTM_TIMEOUT_COUNT,
+ LTM_WAIT_ESCALATION_COUNT,
+ LTM_WAIT_ESCALATION_TIME,
+ LTM_LONG_WAIT_ESCALATION_COUNT,
+ LTM_LONG_WAIT_ESCALATION_TIME,
+ LTM_STATUS_NUM_ROWS // must be last
+ };
+
+ void init(void);
+ void destroy(void);
+
+ TOKU_ENGINE_STATUS_ROW_S status[LTM_STATUS_NUM_ROWS];
+
+ private:
+ bool m_initialized = false;
+};
+typedef LTM_STATUS_S* LTM_STATUS;
+extern LTM_STATUS_S ltm_status;
+
+#define LTM_STATUS_VAL(x) ltm_status.status[LTM_STATUS_S::x].value.num
+
+void toku_status_init(void); // just call ltm_status.init();
+void toku_status_destroy(void); // just call ltm_status.destroy();
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc
new file mode 100644
index 000000000..5110cd482
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.cc
@@ -0,0 +1,139 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "concurrent_tree.h"
+
+// PORT #include <toku_assert.h>
+namespace toku {
+
+void concurrent_tree::create(const comparator *cmp) {
+ // start with an empty root node. we do this instead of
+ // setting m_root to null so there's always a root to lock
+ m_root.create_root(cmp);
+}
+
+void concurrent_tree::destroy(void) { m_root.destroy_root(); }
+
+bool concurrent_tree::is_empty(void) { return m_root.is_empty(); }
+
+uint64_t concurrent_tree::get_insertion_memory_overhead(void) {
+ return sizeof(treenode);
+}
+
+void concurrent_tree::locked_keyrange::prepare(concurrent_tree *tree) {
+ // the first step in acquiring a locked keyrange is locking the root
+ treenode *const root = &tree->m_root;
+ m_tree = tree;
+ m_subtree = root;
+ m_range = keyrange::get_infinite_range();
+ root->mutex_lock();
+}
+
+void concurrent_tree::locked_keyrange::acquire(const keyrange &range) {
+ treenode *const root = &m_tree->m_root;
+
+ treenode *subtree;
+ if (root->is_empty() || root->range_overlaps(range)) {
+ subtree = root;
+ } else {
+ // we do not have a precomputed comparison hint, so pass null
+ const keyrange::comparison *cmp_hint = nullptr;
+ subtree = root->find_node_with_overlapping_child(range, cmp_hint);
+ }
+
+ // subtree is locked. it will be unlocked when this is release()'d
+ invariant_notnull(subtree);
+ m_range = range;
+ m_subtree = subtree;
+}
+
+bool concurrent_tree::locked_keyrange::add_shared_owner(const keyrange &range,
+ TXNID new_owner) {
+ return m_subtree->insert(range, new_owner, /*is_shared*/ true);
+}
+
+void concurrent_tree::locked_keyrange::release(void) {
+ m_subtree->mutex_unlock();
+}
+
+void concurrent_tree::locked_keyrange::insert(const keyrange &range,
+ TXNID txnid, bool is_shared) {
+ // empty means no children, and only the root should ever be empty
+ if (m_subtree->is_empty()) {
+ m_subtree->set_range_and_txnid(range, txnid, is_shared);
+ } else {
+ m_subtree->insert(range, txnid, is_shared);
+ }
+}
+
+void concurrent_tree::locked_keyrange::remove(const keyrange &range,
+ TXNID txnid) {
+ invariant(!m_subtree->is_empty());
+ treenode *new_subtree = m_subtree->remove(range, txnid);
+ // if removing range changed the root of the subtree,
+ // then the subtree must be the root of the entire tree.
+ if (new_subtree == nullptr) {
+ invariant(m_subtree->is_root());
+ invariant(m_subtree->is_empty());
+ }
+}
+
+void concurrent_tree::locked_keyrange::remove_all(void) {
+ m_subtree->recursive_remove();
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.h
new file mode 100644
index 000000000..e1bfb86c5
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/concurrent_tree.h
@@ -0,0 +1,174 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=2:softtabstop=2:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../ft/comparator.h"
+#include "keyrange.h"
+#include "treenode.h"
+
+namespace toku {
+
+// A concurrent_tree stores non-overlapping ranges.
+// Access to disjoint parts of the tree usually occurs concurrently.
+
+class concurrent_tree {
+ public:
+ // A locked_keyrange gives you exclusive access to read and write
+ // operations that occur on any keys in that range. You only have
+ // the right to operate on keys in that range or keys that were read
+ // from the keyrange using iterate()
+ //
+ // Access model:
+ // - user prepares a locked keyrange. all threads serialize behind prepare().
+ // - user breaks the serialzation point by acquiring a range, or releasing.
+ // - one thread operates on a certain locked_keyrange object at a time.
+ // - when the thread is finished, it releases
+
+ class locked_keyrange {
+ public:
+ // effect: prepare to acquire a locked keyrange over the given
+ // concurrent_tree, preventing other threads from preparing
+ // until this thread either does acquire() or release().
+ // note: operations performed on a prepared keyrange are equivalent
+ // to ones performed on an acquired keyrange over -inf, +inf.
+ // rationale: this provides the user with a serialization point for
+ // descending
+ // or modifying the the tree. it also proives a convenient way of
+ // doing serializable operations on the tree.
+ // There are two valid sequences of calls:
+ // - prepare, acquire, [operations], release
+ // - prepare, [operations],release
+ void prepare(concurrent_tree *tree);
+
+ // requires: the locked keyrange was prepare()'d
+ // effect: acquire a locked keyrange over the given concurrent_tree.
+ // the locked keyrange represents the range of keys overlapped
+ // by the given range
+ void acquire(const keyrange &range);
+
+ // effect: releases a locked keyrange and the mutex it holds
+ void release(void);
+
+ // effect: iterate over each range this locked_keyrange represents,
+ // calling function->fn() on each node's keyrange and txnid
+ // until there are no more or the function returns false
+ template <class F>
+ void iterate(F *function) const {
+ // if the subtree is non-empty, traverse it by calling the given
+ // function on each range, txnid pair found that overlaps.
+ if (!m_subtree->is_empty()) {
+ m_subtree->traverse_overlaps(m_range, function);
+ }
+ }
+
+ // Adds another owner to the lock on the specified keyrange.
+ // requires: the keyrange contains one treenode whose bounds are
+ // exactly equal to the specifed range (no sub/supersets)
+ bool add_shared_owner(const keyrange &range, TXNID new_owner);
+
+ // inserts the given range into the tree, with an associated txnid.
+ // requires: range does not overlap with anything in this locked_keyrange
+ // rationale: caller is responsible for only inserting unique ranges
+ void insert(const keyrange &range, TXNID txnid, bool is_shared);
+
+ // effect: removes the given range from the tree.
+ // - txnid=TXNID_ANY means remove the range no matter what its
+ // owners are
+ // - Other value means remove the specified txnid from
+ // ownership (if the range has other owners, it will remain
+ // in the tree)
+ // requires: range exists exactly in this locked_keyrange
+ // rationale: caller is responsible for only removing existing ranges
+ void remove(const keyrange &range, TXNID txnid);
+
+ // effect: removes all of the keys represented by this locked keyrange
+ // rationale: we'd like a fast way to empty out a tree
+ void remove_all(void);
+
+ private:
+ // the concurrent tree this locked keyrange is for
+ concurrent_tree *m_tree;
+
+ // the range of keys this locked keyrange represents
+ keyrange m_range;
+
+ // the subtree under which all overlapping ranges exist
+ treenode *m_subtree;
+
+ friend class concurrent_tree_unit_test;
+ };
+
+ // effect: initialize the tree to an empty state
+ void create(const comparator *cmp);
+
+ // effect: destroy the tree.
+ // requires: tree is empty
+ void destroy(void);
+
+ // returns: true iff the tree is empty
+ bool is_empty(void);
+
+ // returns: the memory overhead of a single insertion into the tree
+ static uint64_t get_insertion_memory_overhead(void);
+
+ private:
+ // the root needs to always exist so there's a lock to grab
+ // even if the tree is empty. that's why we store a treenode
+ // here and not a pointer to one.
+ treenode m_root;
+
+ friend class concurrent_tree_unit_test;
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc
new file mode 100644
index 000000000..e50ace5a9
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.cc
@@ -0,0 +1,222 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "keyrange.h"
+
+#include "../util/dbt.h"
+
+namespace toku {
+
+// create a keyrange by borrowing the left and right dbt
+// pointers. no memory is copied. no checks for infinity needed.
+void keyrange::create(const DBT *left, const DBT *right) {
+ init_empty();
+ m_left_key = left;
+ m_right_key = right;
+}
+
+// destroy the key copies. if they were never set, then destroy does nothing.
+void keyrange::destroy(void) {
+ toku_destroy_dbt(&m_left_key_copy);
+ toku_destroy_dbt(&m_right_key_copy);
+}
+
+// create a keyrange by copying the keys from the given range.
+void keyrange::create_copy(const keyrange &range) {
+ // start with an initialized, empty range
+ init_empty();
+
+ // optimize the case where the left and right keys are the same.
+ // we'd like to only have one copy of the data.
+ if (toku_dbt_equals(range.get_left_key(), range.get_right_key())) {
+ set_both_keys(range.get_left_key());
+ } else {
+ // replace our empty left and right keys with
+ // copies of the range's left and right keys
+ replace_left_key(range.get_left_key());
+ replace_right_key(range.get_right_key());
+ }
+}
+
+// extend this keyrange by choosing the leftmost and rightmost
+// endpoints between this range and the given. replaced keys
+// in this range are freed and inherited keys are copied.
+void keyrange::extend(const comparator &cmp, const keyrange &range) {
+ const DBT *range_left = range.get_left_key();
+ const DBT *range_right = range.get_right_key();
+ if (cmp(range_left, get_left_key()) < 0) {
+ replace_left_key(range_left);
+ }
+ if (cmp(range_right, get_right_key()) > 0) {
+ replace_right_key(range_right);
+ }
+}
+
+// how much memory does this keyrange take?
+// - the size of the left and right keys
+// --- ignore the fact that we may have optimized the point case.
+// it complicates things for little gain.
+// - the size of the keyrange class itself
+uint64_t keyrange::get_memory_size(void) const {
+ const DBT *left_key = get_left_key();
+ const DBT *right_key = get_right_key();
+ return left_key->size + right_key->size + sizeof(keyrange);
+}
+
+// compare ranges.
+keyrange::comparison keyrange::compare(const comparator &cmp,
+ const keyrange &range) const {
+ if (cmp(get_right_key(), range.get_left_key()) < 0) {
+ return comparison::LESS_THAN;
+ } else if (cmp(get_left_key(), range.get_right_key()) > 0) {
+ return comparison::GREATER_THAN;
+ } else if (cmp(get_left_key(), range.get_left_key()) == 0 &&
+ cmp(get_right_key(), range.get_right_key()) == 0) {
+ return comparison::EQUALS;
+ } else {
+ return comparison::OVERLAPS;
+ }
+}
+
+bool keyrange::overlaps(const comparator &cmp, const keyrange &range) const {
+ // equality is a stronger form of overlapping.
+ // so two ranges "overlap" if they're either equal or just overlapping.
+ comparison c = compare(cmp, range);
+ return c == comparison::EQUALS || c == comparison::OVERLAPS;
+}
+
+keyrange keyrange::get_infinite_range(void) {
+ keyrange range;
+ range.create(toku_dbt_negative_infinity(), toku_dbt_positive_infinity());
+ return range;
+}
+
+void keyrange::init_empty(void) {
+ m_left_key = nullptr;
+ m_right_key = nullptr;
+ toku_init_dbt(&m_left_key_copy);
+ toku_init_dbt(&m_right_key_copy);
+ m_point_range = false;
+}
+
+const DBT *keyrange::get_left_key(void) const {
+ if (m_left_key) {
+ return m_left_key;
+ } else {
+ return &m_left_key_copy;
+ }
+}
+
+const DBT *keyrange::get_right_key(void) const {
+ if (m_right_key) {
+ return m_right_key;
+ } else {
+ return &m_right_key_copy;
+ }
+}
+
+// copy the given once and set both the left and right pointers.
+// optimization for point ranges, so the left and right ranges
+// are not copied twice.
+void keyrange::set_both_keys(const DBT *key) {
+ if (toku_dbt_is_infinite(key)) {
+ m_left_key = key;
+ m_right_key = key;
+ } else {
+ toku_clone_dbt(&m_left_key_copy, *key);
+ toku_copyref_dbt(&m_right_key_copy, m_left_key_copy);
+ }
+ m_point_range = true;
+}
+
+// destroy the current left key. set and possibly copy the new one
+void keyrange::replace_left_key(const DBT *key) {
+ // a little magic:
+ //
+ // if this is a point range, then the left and right keys share
+ // one copy of the data, and it lives in the left key copy. so
+ // if we're replacing the left key, move the real data to the
+ // right key copy instead of destroying it. now, the memory is
+ // owned by the right key and the left key may be replaced.
+ if (m_point_range) {
+ m_right_key_copy = m_left_key_copy;
+ } else {
+ toku_destroy_dbt(&m_left_key_copy);
+ }
+
+ if (toku_dbt_is_infinite(key)) {
+ m_left_key = key;
+ } else {
+ toku_clone_dbt(&m_left_key_copy, *key);
+ m_left_key = nullptr;
+ }
+ m_point_range = false;
+}
+
+// destroy the current right key. set and possibly copy the new one
+void keyrange::replace_right_key(const DBT *key) {
+ toku_destroy_dbt(&m_right_key_copy);
+ if (toku_dbt_is_infinite(key)) {
+ m_right_key = key;
+ } else {
+ toku_clone_dbt(&m_right_key_copy, *key);
+ m_right_key = nullptr;
+ }
+ m_point_range = false;
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.h
new file mode 100644
index 000000000..f9aeea0c4
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/keyrange.h
@@ -0,0 +1,141 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../ft/comparator.h"
+
+namespace toku {
+
+// A keyrange has a left and right key as endpoints.
+//
+// When a keyrange is created it owns no memory, but when it copies
+// or extends another keyrange, it copies memory as necessary. This
+// means it is cheap in the common case.
+
+class keyrange {
+ public:
+ // effect: constructor that borrows left and right key pointers.
+ // no memory is allocated or copied.
+ void create(const DBT *left_key, const DBT *right_key);
+
+ // effect: constructor that allocates and copies another keyrange's points.
+ void create_copy(const keyrange &range);
+
+ // effect: destroys the keyrange, freeing any allocated memory
+ void destroy(void);
+
+ // effect: extends the keyrange by choosing the leftmost and rightmost
+ // endpoints from this range and the given range.
+ // replaced keys in this range are freed, new keys are copied.
+ void extend(const comparator &cmp, const keyrange &range);
+
+ // returns: the amount of memory this keyrange takes. does not account
+ // for point optimizations or malloc overhead.
+ uint64_t get_memory_size(void) const;
+
+ // returns: pointer to the left key of this range
+ const DBT *get_left_key(void) const;
+
+ // returns: pointer to the right key of this range
+ const DBT *get_right_key(void) const;
+
+ // two ranges are either equal, lt, gt, or overlapping
+ enum comparison { EQUALS, LESS_THAN, GREATER_THAN, OVERLAPS };
+
+ // effect: compares this range to the given range
+ // returns: LESS_THAN if given range is strictly to the left
+ // GREATER_THAN if given range is strictly to the right
+ // EQUALS if given range has the same left and right endpoints
+ // OVERLAPS if at least one of the given range's endpoints falls
+ // between this range's endpoints
+ comparison compare(const comparator &cmp, const keyrange &range) const;
+
+ // returns: true if the range and the given range are equal or overlapping
+ bool overlaps(const comparator &cmp, const keyrange &range) const;
+
+ // returns: a keyrange representing -inf, +inf
+ static keyrange get_infinite_range(void);
+
+ private:
+ // some keys should be copied, some keys should not be.
+ //
+ // to support both, we use two DBTs for copies and two pointers
+ // for temporaries. the access rule is:
+ // - if a pointer is non-null, then it reprsents the key.
+ // - otherwise the pointer is null, and the key is in the copy.
+ DBT m_left_key_copy;
+ DBT m_right_key_copy;
+ const DBT *m_left_key;
+ const DBT *m_right_key;
+
+ // if this range is a point range, then m_left_key == m_right_key
+ // and the actual data is stored exactly once in m_left_key_copy.
+ bool m_point_range;
+
+ // effect: initializes a keyrange to be empty
+ void init_empty(void);
+
+ // effect: copies the given key once into the left key copy
+ // and sets the right key copy to share the left.
+ // rationale: optimization for point ranges to only do one malloc
+ void set_both_keys(const DBT *key);
+
+ // effect: destroys the current left key. sets and copies the new one.
+ void replace_left_key(const DBT *key);
+
+ // effect: destroys the current right key. sets and copies the new one.
+ void replace_right_key(const DBT *key);
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.cc
new file mode 100644
index 000000000..3d217be70
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.cc
@@ -0,0 +1,527 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=2:softtabstop=2:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "lock_request.h"
+
+#include "../portability/toku_race_tools.h"
+#include "../portability/txn_subst.h"
+#include "../util/dbt.h"
+#include "locktree.h"
+
+namespace toku {
+
+// initialize a lock request's internals
+void lock_request::create(toku_external_mutex_factory_t mutex_factory) {
+ m_txnid = TXNID_NONE;
+ m_conflicting_txnid = TXNID_NONE;
+ m_start_time = 0;
+ m_left_key = nullptr;
+ m_right_key = nullptr;
+ toku_init_dbt(&m_left_key_copy);
+ toku_init_dbt(&m_right_key_copy);
+
+ m_type = type::UNKNOWN;
+ m_lt = nullptr;
+
+ m_complete_r = 0;
+ m_state = state::UNINITIALIZED;
+ m_info = nullptr;
+
+ // psergey-todo: this condition is for interruptible wait
+ // note: moved to here from lock_request::create:
+ toku_external_cond_init(mutex_factory, &m_wait_cond);
+
+ m_start_test_callback = nullptr;
+ m_start_before_pending_test_callback = nullptr;
+ m_retry_test_callback = nullptr;
+}
+
+// destroy a lock request.
+void lock_request::destroy(void) {
+ invariant(m_state != state::PENDING);
+ invariant(m_state != state::DESTROYED);
+ m_state = state::DESTROYED;
+ toku_destroy_dbt(&m_left_key_copy);
+ toku_destroy_dbt(&m_right_key_copy);
+ toku_external_cond_destroy(&m_wait_cond);
+}
+
+// set the lock request parameters. this API allows a lock request to be reused.
+void lock_request::set(locktree *lt, TXNID txnid, const DBT *left_key,
+ const DBT *right_key, lock_request::type lock_type,
+ bool big_txn, void *extra) {
+ invariant(m_state != state::PENDING);
+ m_lt = lt;
+
+ m_txnid = txnid;
+ m_left_key = left_key;
+ m_right_key = right_key;
+ toku_destroy_dbt(&m_left_key_copy);
+ toku_destroy_dbt(&m_right_key_copy);
+ m_type = lock_type;
+ m_state = state::INITIALIZED;
+ m_info = lt ? lt->get_lock_request_info() : nullptr;
+ m_big_txn = big_txn;
+ m_extra = extra;
+}
+
+// get rid of any stored left and right key copies and
+// replace them with copies of the given left and right key
+void lock_request::copy_keys() {
+ if (!toku_dbt_is_infinite(m_left_key)) {
+ toku_clone_dbt(&m_left_key_copy, *m_left_key);
+ m_left_key = &m_left_key_copy;
+ }
+ if (!toku_dbt_is_infinite(m_right_key)) {
+ toku_clone_dbt(&m_right_key_copy, *m_right_key);
+ m_right_key = &m_right_key_copy;
+ }
+}
+
+// what are the conflicts for this pending lock request?
+void lock_request::get_conflicts(txnid_set *conflicts) {
+ invariant(m_state == state::PENDING);
+ const bool is_write_request = m_type == type::WRITE;
+ m_lt->get_conflicts(is_write_request, m_txnid, m_left_key, m_right_key,
+ conflicts);
+}
+
+// build a wait-for-graph for this lock request and the given conflict set
+// for each transaction B that blocks A's lock request
+// if B is blocked then
+// add (A,T) to the WFG and if B is new, fill in the WFG from B
+void lock_request::build_wait_graph(wfg *wait_graph,
+ const txnid_set &conflicts) {
+ uint32_t num_conflicts = conflicts.size();
+ for (uint32_t i = 0; i < num_conflicts; i++) {
+ TXNID conflicting_txnid = conflicts.get(i);
+ lock_request *conflicting_request = find_lock_request(conflicting_txnid);
+ invariant(conflicting_txnid != m_txnid);
+ invariant(conflicting_request != this);
+ if (conflicting_request) {
+ bool already_exists = wait_graph->node_exists(conflicting_txnid);
+ wait_graph->add_edge(m_txnid, conflicting_txnid);
+ if (!already_exists) {
+ // recursively build the wait for graph rooted at the conflicting
+ // request, given its set of lock conflicts.
+ txnid_set other_conflicts;
+ other_conflicts.create();
+ conflicting_request->get_conflicts(&other_conflicts);
+ conflicting_request->build_wait_graph(wait_graph, other_conflicts);
+ other_conflicts.destroy();
+ }
+ }
+ }
+}
+
+// returns: true if the current set of lock requests contains
+// a deadlock, false otherwise.
+bool lock_request::deadlock_exists(const txnid_set &conflicts) {
+ wfg wait_graph;
+ wait_graph.create();
+
+ build_wait_graph(&wait_graph, conflicts);
+
+ std::function<void(TXNID)> reporter;
+ if (m_deadlock_cb) {
+ reporter = [this](TXNID a) {
+ lock_request *req = find_lock_request(a);
+ if (req) {
+ m_deadlock_cb(req->m_txnid, (req->m_type == lock_request::WRITE),
+ req->m_left_key, req->m_right_key);
+ }
+ };
+ }
+
+ bool deadlock = wait_graph.cycle_exists_from_txnid(m_txnid, reporter);
+ wait_graph.destroy();
+ return deadlock;
+}
+
+// try to acquire a lock described by this lock request.
+int lock_request::start(void) {
+ int r;
+
+ txnid_set conflicts;
+ conflicts.create();
+ if (m_type == type::WRITE) {
+ r = m_lt->acquire_write_lock(m_txnid, m_left_key, m_right_key, &conflicts,
+ m_big_txn);
+ } else {
+ invariant(m_type == type::READ);
+ r = m_lt->acquire_read_lock(m_txnid, m_left_key, m_right_key, &conflicts,
+ m_big_txn);
+ }
+
+ // if the lock is not granted, save it to the set of lock requests
+ // and check for a deadlock. if there is one, complete it as failed
+ if (r == DB_LOCK_NOTGRANTED) {
+ copy_keys();
+ m_state = state::PENDING;
+ m_start_time = toku_current_time_microsec() / 1000;
+ m_conflicting_txnid = conflicts.get(0);
+ if (m_start_before_pending_test_callback)
+ m_start_before_pending_test_callback();
+ toku_external_mutex_lock(&m_info->mutex);
+ insert_into_lock_requests();
+ if (deadlock_exists(conflicts)) {
+ remove_from_lock_requests();
+ r = DB_LOCK_DEADLOCK;
+ }
+ toku_external_mutex_unlock(&m_info->mutex);
+ if (m_start_test_callback) m_start_test_callback(); // test callback
+ }
+
+ if (r != DB_LOCK_NOTGRANTED) {
+ complete(r);
+ }
+
+ conflicts.destroy();
+ return r;
+}
+
+// sleep on the lock request until it becomes resolved or the wait time has
+// elapsed.
+int lock_request::wait(uint64_t wait_time_ms) {
+ return wait(wait_time_ms, 0, nullptr);
+}
+
+int lock_request::wait(uint64_t wait_time_ms, uint64_t killed_time_ms,
+ int (*killed_callback)(void),
+ void (*lock_wait_callback)(void *, lock_wait_infos *),
+ void *callback_arg) {
+ uint64_t t_now = toku_current_time_microsec();
+ uint64_t t_start = t_now;
+ uint64_t t_end = t_start + wait_time_ms * 1000;
+
+ toku_external_mutex_lock(&m_info->mutex);
+
+ // check again, this time locking out other retry calls
+ if (m_state == state::PENDING) {
+ lock_wait_infos conflicts_collector;
+ retry(&conflicts_collector);
+ if (m_state == state::PENDING) {
+ report_waits(&conflicts_collector, lock_wait_callback, callback_arg);
+ }
+ }
+
+ while (m_state == state::PENDING) {
+ // check if this thread is killed
+ if (killed_callback && killed_callback()) {
+ remove_from_lock_requests();
+ complete(DB_LOCK_NOTGRANTED);
+ continue;
+ }
+
+ // compute the time until we should wait
+ uint64_t t_wait;
+ if (killed_time_ms == 0) {
+ t_wait = t_end;
+ } else {
+ t_wait = t_now + killed_time_ms * 1000;
+ if (t_wait > t_end) t_wait = t_end;
+ }
+
+ int r = toku_external_cond_timedwait(&m_wait_cond, &m_info->mutex,
+ (int64_t)(t_wait - t_now));
+ invariant(r == 0 || r == ETIMEDOUT);
+
+ t_now = toku_current_time_microsec();
+ if (m_state == state::PENDING && (t_now >= t_end)) {
+ m_info->counters.timeout_count += 1;
+
+ // if we're still pending and we timed out, then remove our
+ // request from the set of lock requests and fail.
+ remove_from_lock_requests();
+
+ // complete sets m_state to COMPLETE, breaking us out of the loop
+ complete(DB_LOCK_NOTGRANTED);
+ }
+ }
+
+ uint64_t t_real_end = toku_current_time_microsec();
+ uint64_t duration = t_real_end - t_start;
+ m_info->counters.wait_count += 1;
+ m_info->counters.wait_time += duration;
+ if (duration >= 1000000) {
+ m_info->counters.long_wait_count += 1;
+ m_info->counters.long_wait_time += duration;
+ }
+ toku_external_mutex_unlock(&m_info->mutex);
+
+ invariant(m_state == state::COMPLETE);
+ return m_complete_r;
+}
+
+// complete this lock request with the given return value
+void lock_request::complete(int complete_r) {
+ m_complete_r = complete_r;
+ m_state = state::COMPLETE;
+}
+
+const DBT *lock_request::get_left_key(void) const { return m_left_key; }
+
+const DBT *lock_request::get_right_key(void) const { return m_right_key; }
+
+TXNID lock_request::get_txnid(void) const { return m_txnid; }
+
+uint64_t lock_request::get_start_time(void) const { return m_start_time; }
+
+TXNID lock_request::get_conflicting_txnid(void) const {
+ return m_conflicting_txnid;
+}
+
+int lock_request::retry(lock_wait_infos *conflicts_collector) {
+ invariant(m_state == state::PENDING);
+ int r;
+ txnid_set conflicts;
+ conflicts.create();
+
+ if (m_type == type::WRITE) {
+ r = m_lt->acquire_write_lock(m_txnid, m_left_key, m_right_key, &conflicts,
+ m_big_txn);
+ } else {
+ r = m_lt->acquire_read_lock(m_txnid, m_left_key, m_right_key, &conflicts,
+ m_big_txn);
+ }
+
+ // if the acquisition succeeded then remove ourselves from the
+ // set of lock requests, complete, and signal the waiting thread.
+ if (r == 0) {
+ remove_from_lock_requests();
+ complete(r);
+ if (m_retry_test_callback) m_retry_test_callback(); // test callback
+ toku_external_cond_broadcast(&m_wait_cond);
+ } else {
+ m_conflicting_txnid = conflicts.get(0);
+ add_conflicts_to_waits(&conflicts, conflicts_collector);
+ }
+ conflicts.destroy();
+
+ return r;
+}
+
+void lock_request::retry_all_lock_requests(
+ locktree *lt, void (*lock_wait_callback)(void *, lock_wait_infos *),
+ void *callback_arg, void (*after_retry_all_test_callback)(void)) {
+ lt_lock_request_info *info = lt->get_lock_request_info();
+
+ // if there are no pending lock requests than there is nothing to do
+ // the unlocked data race on pending_is_empty is OK since lock requests
+ // are retried after added to the pending set.
+ if (info->pending_is_empty) return;
+
+ // get my retry generation (post increment of retry_want)
+ unsigned long long my_retry_want = (info->retry_want += 1);
+
+ toku_mutex_lock(&info->retry_mutex);
+
+ // here is the group retry algorithm.
+ // get the latest retry_want count and use it as the generation number of
+ // this retry operation. if this retry generation is > the last retry
+ // generation, then do the lock retries. otherwise, no lock retries
+ // are needed.
+ if ((my_retry_want - 1) == info->retry_done) {
+ for (;;) {
+ if (!info->running_retry) {
+ info->running_retry = true;
+ info->retry_done = info->retry_want;
+ toku_mutex_unlock(&info->retry_mutex);
+ retry_all_lock_requests_info(info, lock_wait_callback, callback_arg);
+ if (after_retry_all_test_callback) after_retry_all_test_callback();
+ toku_mutex_lock(&info->retry_mutex);
+ info->running_retry = false;
+ toku_cond_broadcast(&info->retry_cv);
+ break;
+ } else {
+ toku_cond_wait(&info->retry_cv, &info->retry_mutex);
+ }
+ }
+ }
+ toku_mutex_unlock(&info->retry_mutex);
+}
+
+void lock_request::retry_all_lock_requests_info(
+ lt_lock_request_info *info,
+ void (*lock_wait_callback)(void *, lock_wait_infos *), void *callback_arg) {
+ toku_external_mutex_lock(&info->mutex);
+ // retry all of the pending lock requests.
+ lock_wait_infos conflicts_collector;
+ for (uint32_t i = 0; i < info->pending_lock_requests.size();) {
+ lock_request *request;
+ int r = info->pending_lock_requests.fetch(i, &request);
+ invariant_zero(r);
+
+ // retry the lock request. if it didn't succeed,
+ // move on to the next lock request. otherwise
+ // the request is gone from the list so we may
+ // read the i'th entry for the next one.
+ r = request->retry(&conflicts_collector);
+ if (r != 0) {
+ i++;
+ }
+ }
+
+ // call report_waits while holding the pending queue lock since
+ // the waiter object is still valid while it's in the queue
+ report_waits(&conflicts_collector, lock_wait_callback, callback_arg);
+
+ // future threads should only retry lock requests if some still exist
+ info->should_retry_lock_requests = info->pending_lock_requests.size() > 0;
+ toku_external_mutex_unlock(&info->mutex);
+}
+
+void lock_request::add_conflicts_to_waits(txnid_set *conflicts,
+ lock_wait_infos *wait_conflicts) {
+ wait_conflicts->push_back({m_lt, get_txnid(), m_extra, {}});
+ uint32_t num_conflicts = conflicts->size();
+ for (uint32_t i = 0; i < num_conflicts; i++) {
+ wait_conflicts->back().waitees.push_back(conflicts->get(i));
+ }
+}
+
+void lock_request::report_waits(lock_wait_infos *wait_conflicts,
+ void (*lock_wait_callback)(void *,
+ lock_wait_infos *),
+ void *callback_arg) {
+ if (lock_wait_callback) (*lock_wait_callback)(callback_arg, wait_conflicts);
+}
+
+void *lock_request::get_extra(void) const { return m_extra; }
+
+void lock_request::kill_waiter(void) {
+ remove_from_lock_requests();
+ complete(DB_LOCK_NOTGRANTED);
+ toku_external_cond_broadcast(&m_wait_cond);
+}
+
+void lock_request::kill_waiter(locktree *lt, void *extra) {
+ lt_lock_request_info *info = lt->get_lock_request_info();
+ toku_external_mutex_lock(&info->mutex);
+ for (uint32_t i = 0; i < info->pending_lock_requests.size(); i++) {
+ lock_request *request;
+ int r = info->pending_lock_requests.fetch(i, &request);
+ if (r == 0 && request->get_extra() == extra) {
+ request->kill_waiter();
+ break;
+ }
+ }
+ toku_external_mutex_unlock(&info->mutex);
+}
+
+// find another lock request by txnid. must hold the mutex.
+lock_request *lock_request::find_lock_request(const TXNID &txnid) {
+ lock_request *request;
+ int r = m_info->pending_lock_requests.find_zero<TXNID, find_by_txnid>(
+ txnid, &request, nullptr);
+ if (r != 0) {
+ request = nullptr;
+ }
+ return request;
+}
+
+// insert this lock request into the locktree's set. must hold the mutex.
+void lock_request::insert_into_lock_requests(void) {
+ uint32_t idx;
+ lock_request *request;
+ int r = m_info->pending_lock_requests.find_zero<TXNID, find_by_txnid>(
+ m_txnid, &request, &idx);
+ invariant(r == DB_NOTFOUND);
+ r = m_info->pending_lock_requests.insert_at(this, idx);
+ invariant_zero(r);
+ m_info->pending_is_empty = false;
+}
+
+// remove this lock request from the locktree's set. must hold the mutex.
+void lock_request::remove_from_lock_requests(void) {
+ uint32_t idx;
+ lock_request *request;
+ int r = m_info->pending_lock_requests.find_zero<TXNID, find_by_txnid>(
+ m_txnid, &request, &idx);
+ invariant_zero(r);
+ invariant(request == this);
+ r = m_info->pending_lock_requests.delete_at(idx);
+ invariant_zero(r);
+ if (m_info->pending_lock_requests.size() == 0)
+ m_info->pending_is_empty = true;
+}
+
+int lock_request::find_by_txnid(lock_request *const &request,
+ const TXNID &txnid) {
+ TXNID request_txnid = request->m_txnid;
+ if (request_txnid < txnid) {
+ return -1;
+ } else if (request_txnid == txnid) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+void lock_request::set_start_test_callback(void (*f)(void)) {
+ m_start_test_callback = f;
+}
+
+void lock_request::set_start_before_pending_test_callback(void (*f)(void)) {
+ m_start_before_pending_test_callback = f;
+}
+
+void lock_request::set_retry_test_callback(void (*f)(void)) {
+ m_retry_test_callback = f;
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.h
new file mode 100644
index 000000000..d30e1e2ca
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.h
@@ -0,0 +1,255 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=2:softtabstop=2:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../db.h"
+#include "../ft/comparator.h"
+#include "../portability/toku_pthread.h"
+#include "locktree.h"
+#include "txnid_set.h"
+#include "wfg.h"
+
+namespace toku {
+
+// Information about a lock wait
+struct lock_wait_info {
+ locktree *ltree; // the tree where wait happens
+ TXNID waiter; // the waiting transaction
+ void *m_extra; // lock_request's m_extra
+
+ // The transactions that are waited for.
+ std::vector<TXNID> waitees;
+};
+
+typedef std::vector<lock_wait_info> lock_wait_infos;
+
+// A lock request contains the db, the key range, the lock type, and
+// the transaction id that describes a potential row range lock.
+//
+// the typical use case is:
+// - initialize a lock request
+// - start to try to acquire the lock
+// - do something else
+// - wait for the lock request to be resolved on a timed condition
+// - destroy the lock request
+// a lock request is resolved when its state is no longer pending, or
+// when it becomes granted, or timedout, or deadlocked. when resolved, the
+// state of the lock request is changed and any waiting threads are awakened.
+
+class lock_request {
+ public:
+ enum type { UNKNOWN, READ, WRITE };
+
+ // effect: Initializes a lock request.
+ void create(toku_external_mutex_factory_t mutex_factory);
+
+ // effect: Destroys a lock request.
+ void destroy(void);
+
+ // effect: Resets the lock request parameters, allowing it to be reused.
+ // requires: Lock request was already created at some point
+ void set(locktree *lt, TXNID txnid, const DBT *left_key, const DBT *right_key,
+ type lock_type, bool big_txn, void *extra = nullptr);
+
+ // effect: Tries to acquire a lock described by this lock request.
+ // returns: The return code of locktree::acquire_[write,read]_lock()
+ // or DB_LOCK_DEADLOCK if this request would end up deadlocked.
+ int start(void);
+
+ // effect: Sleeps until either the request is granted or the wait time
+ // expires. returns: The return code of locktree::acquire_[write,read]_lock()
+ // or simply DB_LOCK_NOTGRANTED if the wait time expired.
+ int wait(uint64_t wait_time_ms);
+ int wait(uint64_t wait_time_ms, uint64_t killed_time_ms,
+ int (*killed_callback)(void),
+ void (*lock_wait_callback)(void *, lock_wait_infos *) = nullptr,
+ void *callback_arg = nullptr);
+
+ // return: left end-point of the lock range
+ const DBT *get_left_key(void) const;
+
+ // return: right end-point of the lock range
+ const DBT *get_right_key(void) const;
+
+ // return: the txnid waiting for a lock
+ TXNID get_txnid(void) const;
+
+ // return: when this lock request started, as milliseconds from epoch
+ uint64_t get_start_time(void) const;
+
+ // return: which txnid is blocking this request (there may be more, though)
+ TXNID get_conflicting_txnid(void) const;
+
+ // effect: Retries all of the lock requests for the given locktree.
+ // Any lock requests successfully restarted is completed and woken
+ // up.
+ // The rest remain pending.
+ static void retry_all_lock_requests(
+ locktree *lt,
+ void (*lock_wait_callback)(void *, lock_wait_infos *) = nullptr,
+ void *callback_arg = nullptr,
+ void (*after_retry_test_callback)(void) = nullptr);
+ static void retry_all_lock_requests_info(
+ lt_lock_request_info *info,
+ void (*lock_wait_callback)(void *, lock_wait_infos *),
+ void *callback_arg);
+
+ void set_start_test_callback(void (*f)(void));
+ void set_start_before_pending_test_callback(void (*f)(void));
+ void set_retry_test_callback(void (*f)(void));
+
+ void *get_extra(void) const;
+
+ void kill_waiter(void);
+ static void kill_waiter(locktree *lt, void *extra);
+
+ private:
+ enum state {
+ UNINITIALIZED,
+ INITIALIZED,
+ PENDING,
+ COMPLETE,
+ DESTROYED,
+ };
+
+ // The keys for a lock request are stored "unowned" in m_left_key
+ // and m_right_key. When the request is about to go to sleep, it
+ // copies these keys and stores them in m_left_key_copy etc and
+ // sets the temporary pointers to null.
+ TXNID m_txnid;
+ TXNID m_conflicting_txnid;
+ uint64_t m_start_time;
+ const DBT *m_left_key;
+ const DBT *m_right_key;
+ DBT m_left_key_copy;
+ DBT m_right_key_copy;
+
+ // The lock request type and associated locktree
+ type m_type;
+ locktree *m_lt;
+
+ // If the lock request is in the completed state, then its
+ // final return value is stored in m_complete_r
+ int m_complete_r;
+ state m_state;
+
+ toku_external_cond_t m_wait_cond;
+
+ bool m_big_txn;
+
+ // the lock request info state stored in the
+ // locktree that this lock request is for.
+ struct lt_lock_request_info *m_info;
+
+ void *m_extra;
+
+ // effect: tries again to acquire the lock described by this lock request
+ // returns: 0 if retrying the request succeeded and is now complete
+ int retry(lock_wait_infos *collector);
+
+ void complete(int complete_r);
+
+ // effect: Finds another lock request by txnid.
+ // requires: The lock request info mutex is held
+ lock_request *find_lock_request(const TXNID &txnid);
+
+ // effect: Insert this lock request into the locktree's set.
+ // requires: the locktree's mutex is held
+ void insert_into_lock_requests(void);
+
+ // effect: Removes this lock request from the locktree's set.
+ // requires: The lock request info mutex is held
+ void remove_from_lock_requests(void);
+
+ // effect: Asks this request's locktree which txnids are preventing
+ // us from getting the lock described by this request.
+ // returns: conflicts is populated with the txnid's that this request
+ // is blocked on
+ void get_conflicts(txnid_set *conflicts);
+
+ // effect: Builds a wait-for-graph for this lock request and the given
+ // conflict set
+ void build_wait_graph(wfg *wait_graph, const txnid_set &conflicts);
+
+ // returns: True if this lock request is in deadlock with the given conflicts
+ // set
+ bool deadlock_exists(const txnid_set &conflicts);
+
+ void copy_keys(void);
+
+ static int find_by_txnid(lock_request *const &request, const TXNID &txnid);
+
+ // Report list of conflicts to lock wait callback.
+ static void report_waits(lock_wait_infos *wait_conflicts,
+ void (*lock_wait_callback)(void *,
+ lock_wait_infos *),
+ void *callback_arg);
+ void add_conflicts_to_waits(txnid_set *conflicts,
+ lock_wait_infos *wait_conflicts);
+
+ void (*m_start_test_callback)(void);
+ void (*m_start_before_pending_test_callback)(void);
+ void (*m_retry_test_callback)(void);
+
+ public:
+ std::function<void(TXNID, bool, const DBT *, const DBT *)> m_deadlock_cb;
+
+ friend class lock_request_unit_test;
+};
+// PORT: lock_request is not a POD anymore due to use of toku_external_cond_t
+// This is ok as the PODness is not really required: lock_request objects are
+// not moved in memory or anything.
+// ENSURE_POD(lock_request);
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.cc
new file mode 100644
index 000000000..3d6a590c7
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.cc
@@ -0,0 +1,1023 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=2:softtabstop=2:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "locktree.h"
+
+#include <memory.h>
+
+#include "../portability/toku_pthread.h"
+#include "../portability/toku_time.h"
+#include "../util/growable_array.h"
+#include "range_buffer.h"
+
+// including the concurrent_tree here expands the templates
+// and "defines" the implementation, so we do it here in
+// the locktree source file instead of the header.
+#include "concurrent_tree.h"
+
+namespace toku {
+// A locktree represents the set of row locks owned by all transactions
+// over an open dictionary. Read and write ranges are represented as
+// a left and right key which are compared with the given descriptor
+// and comparison fn.
+//
+// Each locktree has a reference count which it manages
+// but does nothing based on the value of the reference count - it is
+// up to the user of the locktree to destroy it when it sees fit.
+
+void locktree::create(locktree_manager *mgr, DICTIONARY_ID dict_id,
+ const comparator &cmp,
+ toku_external_mutex_factory_t mutex_factory) {
+ m_mgr = mgr;
+ m_dict_id = dict_id;
+
+ m_cmp.create_from(cmp);
+ m_reference_count = 1;
+ m_userdata = nullptr;
+
+ XCALLOC(m_rangetree);
+ m_rangetree->create(&m_cmp);
+
+ m_sto_txnid = TXNID_NONE;
+ m_sto_buffer.create();
+ m_sto_score = STO_SCORE_THRESHOLD;
+ m_sto_end_early_count = 0;
+ m_sto_end_early_time = 0;
+
+ m_escalation_barrier = [](const DBT *, const DBT *, void *) -> bool {
+ return false;
+ };
+
+ m_lock_request_info.init(mutex_factory);
+}
+
+void locktree::set_escalation_barrier_func(
+ lt_escalation_barrier_check_func func, void *extra) {
+ m_escalation_barrier = func;
+ m_escalation_barrier_arg = extra;
+}
+
+void lt_lock_request_info::init(toku_external_mutex_factory_t mutex_factory) {
+ pending_lock_requests.create();
+ pending_is_empty = true;
+ toku_external_mutex_init(mutex_factory, &mutex);
+ retry_want = retry_done = 0;
+ ZERO_STRUCT(counters);
+ ZERO_STRUCT(retry_mutex);
+ toku_mutex_init(locktree_request_info_retry_mutex_key, &retry_mutex, nullptr);
+ toku_cond_init(locktree_request_info_retry_cv_key, &retry_cv, nullptr);
+ running_retry = false;
+
+ TOKU_VALGRIND_HG_DISABLE_CHECKING(&pending_is_empty,
+ sizeof(pending_is_empty));
+ TOKU_DRD_IGNORE_VAR(pending_is_empty);
+}
+
+void locktree::destroy(void) {
+ invariant(m_reference_count == 0);
+ invariant(m_lock_request_info.pending_lock_requests.size() == 0);
+ m_cmp.destroy();
+ m_rangetree->destroy();
+ toku_free(m_rangetree);
+ m_sto_buffer.destroy();
+ m_lock_request_info.destroy();
+}
+
+void lt_lock_request_info::destroy(void) {
+ pending_lock_requests.destroy();
+ toku_external_mutex_destroy(&mutex);
+ toku_mutex_destroy(&retry_mutex);
+ toku_cond_destroy(&retry_cv);
+}
+
+void locktree::add_reference(void) {
+ (void)toku_sync_add_and_fetch(&m_reference_count, 1);
+}
+
+uint32_t locktree::release_reference(void) {
+ return toku_sync_sub_and_fetch(&m_reference_count, 1);
+}
+
+uint32_t locktree::get_reference_count(void) { return m_reference_count; }
+
+// a container for a range/txnid pair
+struct row_lock {
+ keyrange range;
+ TXNID txnid;
+ bool is_shared;
+ TxnidVector *owners;
+};
+
+// iterate over a locked keyrange and copy out all of the data,
+// storing each row lock into the given growable array. the
+// caller does not own the range inside the returned row locks,
+// so remove from the tree with care using them as keys.
+static void iterate_and_get_overlapping_row_locks(
+ const concurrent_tree::locked_keyrange *lkr,
+ GrowableArray<row_lock> *row_locks) {
+ struct copy_fn_obj {
+ GrowableArray<row_lock> *row_locks;
+ bool fn(const keyrange &range, TXNID txnid, bool is_shared,
+ TxnidVector *owners) {
+ row_lock lock = {.range = range,
+ .txnid = txnid,
+ .is_shared = is_shared,
+ .owners = owners};
+ row_locks->push(lock);
+ return true;
+ }
+ } copy_fn;
+ copy_fn.row_locks = row_locks;
+ lkr->iterate(&copy_fn);
+}
+
+// given a txnid and a set of overlapping row locks, determine
+// which txnids are conflicting, and store them in the conflicts
+// set, if given.
+static bool determine_conflicting_txnids(
+ const GrowableArray<row_lock> &row_locks, const TXNID &txnid,
+ txnid_set *conflicts) {
+ bool conflicts_exist = false;
+ const size_t num_overlaps = row_locks.get_size();
+ for (size_t i = 0; i < num_overlaps; i++) {
+ const row_lock lock = row_locks.fetch_unchecked(i);
+ const TXNID other_txnid = lock.txnid;
+ if (other_txnid != txnid) {
+ if (conflicts) {
+ if (other_txnid == TXNID_SHARED) {
+ // Add all shared lock owners, except this transaction.
+ for (TXNID shared_id : *lock.owners) {
+ if (shared_id != txnid) conflicts->add(shared_id);
+ }
+ } else {
+ conflicts->add(other_txnid);
+ }
+ }
+ conflicts_exist = true;
+ }
+ }
+ return conflicts_exist;
+}
+
+// how much memory does a row lock take up in a concurrent tree?
+static uint64_t row_lock_size_in_tree(const row_lock &lock) {
+ const uint64_t overhead = concurrent_tree::get_insertion_memory_overhead();
+ return lock.range.get_memory_size() + overhead;
+}
+
+// remove and destroy the given row lock from the locked keyrange,
+// then notify the memory tracker of the newly freed lock.
+static void remove_row_lock_from_tree(concurrent_tree::locked_keyrange *lkr,
+ const row_lock &lock, TXNID txnid,
+ locktree_manager *mgr) {
+ const uint64_t mem_released = row_lock_size_in_tree(lock);
+ lkr->remove(lock.range, txnid);
+ if (mgr != nullptr) {
+ mgr->note_mem_released(mem_released);
+ }
+}
+
+// insert a row lock into the locked keyrange, then notify
+// the memory tracker of this newly acquired lock.
+static void insert_row_lock_into_tree(concurrent_tree::locked_keyrange *lkr,
+ const row_lock &lock,
+ locktree_manager *mgr) {
+ uint64_t mem_used = row_lock_size_in_tree(lock);
+ lkr->insert(lock.range, lock.txnid, lock.is_shared);
+ if (mgr != nullptr) {
+ mgr->note_mem_used(mem_used);
+ }
+}
+
+void locktree::sto_begin(TXNID txnid) {
+ invariant(m_sto_txnid == TXNID_NONE);
+ invariant(m_sto_buffer.is_empty());
+ m_sto_txnid = txnid;
+}
+
+void locktree::sto_append(const DBT *left_key, const DBT *right_key,
+ bool is_write_request) {
+ uint64_t buffer_mem, delta;
+
+ // psergey: the below two lines do not make any sense
+ // (and it's the same in upstream TokuDB)
+ keyrange range;
+ range.create(left_key, right_key);
+
+ buffer_mem = m_sto_buffer.total_memory_size();
+ m_sto_buffer.append(left_key, right_key, is_write_request);
+ delta = m_sto_buffer.total_memory_size() - buffer_mem;
+ if (m_mgr != nullptr) {
+ m_mgr->note_mem_used(delta);
+ }
+}
+
+void locktree::sto_end(void) {
+ uint64_t mem_size = m_sto_buffer.total_memory_size();
+ if (m_mgr != nullptr) {
+ m_mgr->note_mem_released(mem_size);
+ }
+ m_sto_buffer.destroy();
+ m_sto_buffer.create();
+ m_sto_txnid = TXNID_NONE;
+}
+
+void locktree::sto_end_early_no_accounting(void *prepared_lkr) {
+ sto_migrate_buffer_ranges_to_tree(prepared_lkr);
+ sto_end();
+ toku_unsafe_set(m_sto_score, 0);
+}
+
+void locktree::sto_end_early(void *prepared_lkr) {
+ m_sto_end_early_count++;
+
+ tokutime_t t0 = toku_time_now();
+ sto_end_early_no_accounting(prepared_lkr);
+ tokutime_t t1 = toku_time_now();
+
+ m_sto_end_early_time += (t1 - t0);
+}
+
+void locktree::sto_migrate_buffer_ranges_to_tree(void *prepared_lkr) {
+ // There should be something to migrate, and nothing in the rangetree.
+ invariant(!m_sto_buffer.is_empty());
+ invariant(m_rangetree->is_empty());
+
+ concurrent_tree sto_rangetree;
+ concurrent_tree::locked_keyrange sto_lkr;
+ sto_rangetree.create(&m_cmp);
+
+ // insert all of the ranges from the single txnid buffer into a new rangtree
+ range_buffer::iterator iter(&m_sto_buffer);
+ range_buffer::iterator::record rec;
+ while (iter.current(&rec)) {
+ sto_lkr.prepare(&sto_rangetree);
+ int r = acquire_lock_consolidated(&sto_lkr, m_sto_txnid, rec.get_left_key(),
+ rec.get_right_key(),
+ rec.get_exclusive_flag(), nullptr);
+ invariant_zero(r);
+ sto_lkr.release();
+ iter.next();
+ }
+
+ // Iterate the newly created rangetree and insert each range into the
+ // locktree's rangetree, on behalf of the old single txnid.
+ struct migrate_fn_obj {
+ concurrent_tree::locked_keyrange *dst_lkr;
+ bool fn(const keyrange &range, TXNID txnid, bool is_shared,
+ TxnidVector *owners) {
+ // There can't be multiple owners in STO mode
+ invariant_zero(owners);
+ dst_lkr->insert(range, txnid, is_shared);
+ return true;
+ }
+ } migrate_fn;
+ migrate_fn.dst_lkr =
+ static_cast<concurrent_tree::locked_keyrange *>(prepared_lkr);
+ sto_lkr.prepare(&sto_rangetree);
+ sto_lkr.iterate(&migrate_fn);
+ sto_lkr.remove_all();
+ sto_lkr.release();
+ sto_rangetree.destroy();
+ invariant(!m_rangetree->is_empty());
+}
+
+bool locktree::sto_try_acquire(void *prepared_lkr, TXNID txnid,
+ const DBT *left_key, const DBT *right_key,
+ bool is_write_request) {
+ if (m_rangetree->is_empty() && m_sto_buffer.is_empty() &&
+ toku_unsafe_fetch(m_sto_score) >= STO_SCORE_THRESHOLD) {
+ // We can do the optimization because the rangetree is empty, and
+ // we know its worth trying because the sto score is big enough.
+ sto_begin(txnid);
+ } else if (m_sto_txnid != TXNID_NONE) {
+ // We are currently doing the optimization. Check if we need to cancel
+ // it because a new txnid appeared, or if the current single txnid has
+ // taken too many locks already.
+ if (m_sto_txnid != txnid ||
+ m_sto_buffer.get_num_ranges() > STO_BUFFER_MAX_SIZE) {
+ sto_end_early(prepared_lkr);
+ }
+ }
+
+ // At this point the sto txnid is properly set. If it is valid, then
+ // this txnid can append its lock to the sto buffer successfully.
+ if (m_sto_txnid != TXNID_NONE) {
+ invariant(m_sto_txnid == txnid);
+ sto_append(left_key, right_key, is_write_request);
+ return true;
+ } else {
+ invariant(m_sto_buffer.is_empty());
+ return false;
+ }
+}
+
+/*
+ Do the same as iterate_and_get_overlapping_row_locks does, but also check for
+ this:
+ The set of overlapping rows locks consists of just one read-only shared
+ lock with the same endpoints as specified (in that case, we can just add
+ ourselves into that list)
+
+ @return true - One compatible shared lock
+ false - Otherwise
+*/
+static bool iterate_and_get_overlapping_row_locks2(
+ const concurrent_tree::locked_keyrange *lkr, const DBT *left_key,
+ const DBT *right_key, comparator *cmp, TXNID,
+ GrowableArray<row_lock> *row_locks) {
+ struct copy_fn_obj {
+ GrowableArray<row_lock> *row_locks;
+ bool first_call = true;
+ bool matching_lock_found = false;
+ const DBT *left_key, *right_key;
+ comparator *cmp;
+
+ bool fn(const keyrange &range, TXNID txnid, bool is_shared,
+ TxnidVector *owners) {
+ if (first_call) {
+ first_call = false;
+ if (is_shared && !(*cmp)(left_key, range.get_left_key()) &&
+ !(*cmp)(right_key, range.get_right_key())) {
+ matching_lock_found = true;
+ }
+ } else {
+ // if we see multiple matching locks, it doesn't matter whether
+ // the first one was matching.
+ matching_lock_found = false;
+ }
+ row_lock lock = {.range = range,
+ .txnid = txnid,
+ .is_shared = is_shared,
+ .owners = owners};
+ row_locks->push(lock);
+ return true;
+ }
+ } copy_fn;
+ copy_fn.row_locks = row_locks;
+ copy_fn.left_key = left_key;
+ copy_fn.right_key = right_key;
+ copy_fn.cmp = cmp;
+ lkr->iterate(&copy_fn);
+ return copy_fn.matching_lock_found;
+}
+
+// try to acquire a lock and consolidate it with existing locks if possible
+// param: lkr, a prepared locked keyrange
+// return: 0 on success, DB_LOCK_NOTGRANTED if conflicting locks exist.
+int locktree::acquire_lock_consolidated(void *prepared_lkr, TXNID txnid,
+ const DBT *left_key,
+ const DBT *right_key,
+ bool is_write_request,
+ txnid_set *conflicts) {
+ int r = 0;
+ concurrent_tree::locked_keyrange *lkr;
+
+ keyrange requested_range;
+ requested_range.create(left_key, right_key);
+ lkr = static_cast<concurrent_tree::locked_keyrange *>(prepared_lkr);
+ lkr->acquire(requested_range);
+
+ // copy out the set of overlapping row locks.
+ GrowableArray<row_lock> overlapping_row_locks;
+ overlapping_row_locks.init();
+ bool matching_shared_lock_found = false;
+
+ if (is_write_request)
+ iterate_and_get_overlapping_row_locks(lkr, &overlapping_row_locks);
+ else {
+ matching_shared_lock_found = iterate_and_get_overlapping_row_locks2(
+ lkr, left_key, right_key, &m_cmp, txnid, &overlapping_row_locks);
+ // psergey-todo: what to do now? So, we have figured we have just one
+ // shareable lock. Need to add us into it as an owner but the lock
+ // pointer cannot be kept?
+ // A: use find_node_with_overlapping_child(key_range, nullptr);
+ // then, add ourselves to the owner list.
+ // Dont' foreget to release the subtree after that.
+ }
+
+ if (matching_shared_lock_found) {
+ // there is just one non-confliting matching shared lock.
+ // we are hilding a lock on it (see acquire() call above).
+ // we need to modify it to indicate there is another locker...
+ if (lkr->add_shared_owner(requested_range, txnid)) {
+ // Pretend shared lock uses as much memory.
+ row_lock new_lock = {.range = requested_range,
+ .txnid = txnid,
+ .is_shared = false,
+ .owners = nullptr};
+ uint64_t mem_used = row_lock_size_in_tree(new_lock);
+ if (m_mgr) {
+ m_mgr->note_mem_used(mem_used);
+ }
+ }
+ requested_range.destroy();
+ overlapping_row_locks.deinit();
+ return 0;
+ }
+
+ size_t num_overlapping_row_locks = overlapping_row_locks.get_size();
+
+ // if any overlapping row locks conflict with this request, bail out.
+
+ bool conflicts_exist =
+ determine_conflicting_txnids(overlapping_row_locks, txnid, conflicts);
+ if (!conflicts_exist) {
+ // there are no conflicts, so all of the overlaps are for the requesting
+ // txnid. so, we must consolidate all existing overlapping ranges and the
+ // requested range into one dominating range. then we insert the dominating
+ // range.
+ bool all_shared = !is_write_request;
+ for (size_t i = 0; i < num_overlapping_row_locks; i++) {
+ row_lock overlapping_lock = overlapping_row_locks.fetch_unchecked(i);
+ invariant(overlapping_lock.txnid == txnid);
+ requested_range.extend(m_cmp, overlapping_lock.range);
+ remove_row_lock_from_tree(lkr, overlapping_lock, TXNID_ANY, m_mgr);
+ all_shared = all_shared && overlapping_lock.is_shared;
+ }
+
+ row_lock new_lock = {.range = requested_range,
+ .txnid = txnid,
+ .is_shared = all_shared,
+ .owners = nullptr};
+ insert_row_lock_into_tree(lkr, new_lock, m_mgr);
+ } else {
+ r = DB_LOCK_NOTGRANTED;
+ }
+
+ requested_range.destroy();
+ overlapping_row_locks.deinit();
+ return r;
+}
+
+// acquire a lock in the given key range, inclusive. if successful,
+// return 0. otherwise, populate the conflicts txnid_set with the set of
+// transactions that conflict with this request.
+int locktree::acquire_lock(bool is_write_request, TXNID txnid,
+ const DBT *left_key, const DBT *right_key,
+ txnid_set *conflicts) {
+ int r = 0;
+
+ // we are only supporting write locks for simplicity
+ // invariant(is_write_request);
+
+ // acquire and prepare a locked keyrange over the requested range.
+ // prepare is a serialzation point, so we take the opportunity to
+ // try the single txnid optimization first.
+ concurrent_tree::locked_keyrange lkr;
+ lkr.prepare(m_rangetree);
+
+ bool acquired =
+ sto_try_acquire(&lkr, txnid, left_key, right_key, is_write_request);
+ if (!acquired) {
+ r = acquire_lock_consolidated(&lkr, txnid, left_key, right_key,
+ is_write_request, conflicts);
+ }
+
+ lkr.release();
+ return r;
+}
+
+int locktree::try_acquire_lock(bool is_write_request, TXNID txnid,
+ const DBT *left_key, const DBT *right_key,
+ txnid_set *conflicts, bool big_txn) {
+ // All ranges in the locktree must have left endpoints <= right endpoints.
+ // Range comparisons rely on this fact, so we make a paranoid invariant here.
+ paranoid_invariant(m_cmp(left_key, right_key) <= 0);
+ int r = m_mgr == nullptr ? 0 : m_mgr->check_current_lock_constraints(big_txn);
+ if (r == 0) {
+ r = acquire_lock(is_write_request, txnid, left_key, right_key, conflicts);
+ }
+ return r;
+}
+
+// the locktree silently upgrades read locks to write locks for simplicity
+int locktree::acquire_read_lock(TXNID txnid, const DBT *left_key,
+ const DBT *right_key, txnid_set *conflicts,
+ bool big_txn) {
+ return try_acquire_lock(false, txnid, left_key, right_key, conflicts,
+ big_txn);
+}
+
+int locktree::acquire_write_lock(TXNID txnid, const DBT *left_key,
+ const DBT *right_key, txnid_set *conflicts,
+ bool big_txn) {
+ return try_acquire_lock(true, txnid, left_key, right_key, conflicts, big_txn);
+}
+
+// typedef void (*dump_callback)(void *cdata, const DBT *left, const DBT *right,
+// TXNID txnid);
+void locktree::dump_locks(void *cdata, dump_callback cb) {
+ concurrent_tree::locked_keyrange lkr;
+ keyrange range;
+ range.create(toku_dbt_negative_infinity(), toku_dbt_positive_infinity());
+
+ lkr.prepare(m_rangetree);
+ lkr.acquire(range);
+
+ TXNID sto_txn;
+ if ((sto_txn = toku_unsafe_fetch(m_sto_txnid)) != TXNID_NONE) {
+ // insert all of the ranges from the single txnid buffer into a new rangtree
+ range_buffer::iterator iter(&m_sto_buffer);
+ range_buffer::iterator::record rec;
+ while (iter.current(&rec)) {
+ (*cb)(cdata, rec.get_left_key(), rec.get_right_key(), sto_txn,
+ !rec.get_exclusive_flag(), nullptr);
+ iter.next();
+ }
+ } else {
+ GrowableArray<row_lock> all_locks;
+ all_locks.init();
+ iterate_and_get_overlapping_row_locks(&lkr, &all_locks);
+
+ const size_t n_locks = all_locks.get_size();
+ for (size_t i = 0; i < n_locks; i++) {
+ const row_lock lock = all_locks.fetch_unchecked(i);
+ (*cb)(cdata, lock.range.get_left_key(), lock.range.get_right_key(),
+ lock.txnid, lock.is_shared, lock.owners);
+ }
+ all_locks.deinit();
+ }
+ lkr.release();
+ range.destroy();
+}
+
+void locktree::get_conflicts(bool is_write_request, TXNID txnid,
+ const DBT *left_key, const DBT *right_key,
+ txnid_set *conflicts) {
+ // because we only support write locks, ignore this bit for now.
+ (void)is_write_request;
+
+ // preparing and acquire a locked keyrange over the range
+ keyrange range;
+ range.create(left_key, right_key);
+ concurrent_tree::locked_keyrange lkr;
+ lkr.prepare(m_rangetree);
+ lkr.acquire(range);
+
+ // copy out the set of overlapping row locks and determine the conflicts
+ GrowableArray<row_lock> overlapping_row_locks;
+ overlapping_row_locks.init();
+ iterate_and_get_overlapping_row_locks(&lkr, &overlapping_row_locks);
+
+ // we don't care if conflicts exist. we just want the conflicts set populated.
+ (void)determine_conflicting_txnids(overlapping_row_locks, txnid, conflicts);
+
+ lkr.release();
+ overlapping_row_locks.deinit();
+ range.destroy();
+}
+
+// Effect:
+// For each range in the lock tree that overlaps the given range and has
+// the given txnid, remove it.
+// Rationale:
+// In the common case, there is only the range [left_key, right_key] and
+// it is associated with txnid, so this is a single tree delete.
+//
+// However, consolidation and escalation change the objects in the tree
+// without telling the txn anything. In this case, the txn may own a
+// large range lock that represents its ownership of many smaller range
+// locks. For example, the txn may think it owns point locks on keys 1,
+// 2, and 3, but due to escalation, only the object [1,3] exists in the
+// tree.
+//
+// The first call for a small lock will remove the large range lock, and
+// the rest of the calls should do nothing. After the first release,
+// another thread can acquire one of the locks that the txn thinks it
+// still owns. That's ok, because the txn doesn't want it anymore (it
+// unlocks everything at once), but it may find a lock that it does not
+// own.
+//
+// In our example, the txn unlocks key 1, which actually removes the
+// whole lock [1,3]. Now, someone else can lock 2 before our txn gets
+// around to unlocking 2, so we should not remove that lock.
+void locktree::remove_overlapping_locks_for_txnid(TXNID txnid,
+ const DBT *left_key,
+ const DBT *right_key) {
+ keyrange release_range;
+ release_range.create(left_key, right_key);
+
+ // acquire and prepare a locked keyrange over the release range
+ concurrent_tree::locked_keyrange lkr;
+ lkr.prepare(m_rangetree);
+ lkr.acquire(release_range);
+
+ // copy out the set of overlapping row locks.
+ GrowableArray<row_lock> overlapping_row_locks;
+ overlapping_row_locks.init();
+ iterate_and_get_overlapping_row_locks(&lkr, &overlapping_row_locks);
+ size_t num_overlapping_row_locks = overlapping_row_locks.get_size();
+
+ for (size_t i = 0; i < num_overlapping_row_locks; i++) {
+ row_lock lock = overlapping_row_locks.fetch_unchecked(i);
+ // If this isn't our lock, that's ok, just don't remove it.
+ // See rationale above.
+ // psergey-todo: for shared locks, just remove ourselves from the
+ // owners.
+ if (lock.txnid == txnid || (lock.owners && lock.owners->contains(txnid))) {
+ remove_row_lock_from_tree(&lkr, lock, txnid, m_mgr);
+ }
+ }
+
+ lkr.release();
+ overlapping_row_locks.deinit();
+ release_range.destroy();
+}
+
+bool locktree::sto_txnid_is_valid_unsafe(void) const {
+ return toku_unsafe_fetch(m_sto_txnid) != TXNID_NONE;
+}
+
+int locktree::sto_get_score_unsafe(void) const {
+ return toku_unsafe_fetch(m_sto_score);
+}
+
+bool locktree::sto_try_release(TXNID txnid) {
+ bool released = false;
+ if (toku_unsafe_fetch(m_sto_txnid) != TXNID_NONE) {
+ // check the bit again with a prepared locked keyrange,
+ // which protects the optimization bits and rangetree data
+ concurrent_tree::locked_keyrange lkr;
+ lkr.prepare(m_rangetree);
+ if (m_sto_txnid != TXNID_NONE) {
+ // this txnid better be the single txnid on this locktree,
+ // or else we are in big trouble (meaning the logic is broken)
+ invariant(m_sto_txnid == txnid);
+ invariant(m_rangetree->is_empty());
+ sto_end();
+ released = true;
+ }
+ lkr.release();
+ }
+ return released;
+}
+
+// release all of the locks for a txnid whose endpoints are pairs
+// in the given range buffer.
+void locktree::release_locks(TXNID txnid, const range_buffer *ranges,
+ bool all_trx_locks_hint) {
+ // try the single txn optimization. if it worked, then all of the
+ // locks are already released, otherwise we need to do it here.
+ bool released;
+ if (all_trx_locks_hint) {
+ // This will release all of the locks the transaction is holding
+ released = sto_try_release(txnid);
+ } else {
+ /*
+ psergey: we are asked to release *Some* of the locks the transaction
+ is holding.
+ We could try doing that without leaving the STO mode, but right now,
+ the easiest way is to exit the STO mode and let the non-STO code path
+ handle it.
+ */
+ if (toku_unsafe_fetch(m_sto_txnid) != TXNID_NONE) {
+ // check the bit again with a prepared locked keyrange,
+ // which protects the optimization bits and rangetree data
+ concurrent_tree::locked_keyrange lkr;
+ lkr.prepare(m_rangetree);
+ if (m_sto_txnid != TXNID_NONE) {
+ sto_end_early(&lkr);
+ }
+ lkr.release();
+ }
+ released = false;
+ }
+ if (!released) {
+ range_buffer::iterator iter(ranges);
+ range_buffer::iterator::record rec;
+ while (iter.current(&rec)) {
+ const DBT *left_key = rec.get_left_key();
+ const DBT *right_key = rec.get_right_key();
+ // All ranges in the locktree must have left endpoints <= right endpoints.
+ // Range comparisons rely on this fact, so we make a paranoid invariant
+ // here.
+ paranoid_invariant(m_cmp(left_key, right_key) <= 0);
+ remove_overlapping_locks_for_txnid(txnid, left_key, right_key);
+ iter.next();
+ }
+ // Increase the sto score slightly. Eventually it will hit
+ // the threshold and we'll try the optimization again. This
+ // is how a previously multithreaded system transitions into
+ // a single threaded system that benefits from the optimization.
+ if (toku_unsafe_fetch(m_sto_score) < STO_SCORE_THRESHOLD) {
+ toku_sync_fetch_and_add(&m_sto_score, 1);
+ }
+ }
+}
+
+// iterate over a locked keyrange and extract copies of the first N
+// row locks, storing each one into the given array of size N,
+// then removing each extracted lock from the locked keyrange.
+static int extract_first_n_row_locks(concurrent_tree::locked_keyrange *lkr,
+ locktree_manager *mgr, row_lock *row_locks,
+ int num_to_extract) {
+ struct extract_fn_obj {
+ int num_extracted;
+ int num_to_extract;
+ row_lock *row_locks;
+ bool fn(const keyrange &range, TXNID txnid, bool is_shared,
+ TxnidVector *owners) {
+ if (num_extracted < num_to_extract) {
+ row_lock lock;
+ lock.range.create_copy(range);
+ lock.txnid = txnid;
+ lock.is_shared = is_shared;
+ // deep-copy the set of owners:
+ if (owners)
+ lock.owners = new TxnidVector(*owners);
+ else
+ lock.owners = nullptr;
+ row_locks[num_extracted++] = lock;
+ return true;
+ } else {
+ return false;
+ }
+ }
+ } extract_fn;
+
+ extract_fn.row_locks = row_locks;
+ extract_fn.num_to_extract = num_to_extract;
+ extract_fn.num_extracted = 0;
+ lkr->iterate(&extract_fn);
+
+ // now that the ranges have been copied out, complete
+ // the extraction by removing the ranges from the tree.
+ // use remove_row_lock_from_tree() so we properly track the
+ // amount of memory and number of locks freed.
+ int num_extracted = extract_fn.num_extracted;
+ invariant(num_extracted <= num_to_extract);
+ for (int i = 0; i < num_extracted; i++) {
+ remove_row_lock_from_tree(lkr, row_locks[i], TXNID_ANY, mgr);
+ }
+
+ return num_extracted;
+}
+
+// Store each newly escalated lock in a range buffer for appropriate txnid.
+// We'll rebuild the locktree by iterating over these ranges, and then we
+// can pass back each txnid/buffer pair individually through a callback
+// to notify higher layers that locks have changed.
+struct txnid_range_buffer {
+ TXNID txnid;
+ range_buffer buffer;
+
+ static int find_by_txnid(struct txnid_range_buffer *const &other_buffer,
+ const TXNID &txnid) {
+ if (txnid < other_buffer->txnid) {
+ return -1;
+ } else if (other_buffer->txnid == txnid) {
+ return 0;
+ } else {
+ return 1;
+ }
+ }
+};
+
+// escalate the locks in the locktree by merging adjacent
+// locks that have the same txnid into one larger lock.
+//
+// if there's only one txnid in the locktree then this
+// approach works well. if there are many txnids and each
+// has locks in a random/alternating order, then this does
+// not work so well.
+void locktree::escalate(lt_escalate_cb after_escalate_callback,
+ void *after_escalate_callback_extra) {
+ omt<struct txnid_range_buffer *, struct txnid_range_buffer *> range_buffers;
+ range_buffers.create();
+
+ // prepare and acquire a locked keyrange on the entire locktree
+ concurrent_tree::locked_keyrange lkr;
+ keyrange infinite_range = keyrange::get_infinite_range();
+ lkr.prepare(m_rangetree);
+ lkr.acquire(infinite_range);
+
+ // if we're in the single txnid optimization, simply call it off.
+ // if you have to run escalation, you probably don't care about
+ // the optimization anyway, and this makes things easier.
+ if (m_sto_txnid != TXNID_NONE) {
+ // We are already accounting for this escalation time and
+ // count, so don't do it for sto_end_early too.
+ sto_end_early_no_accounting(&lkr);
+ }
+
+ // extract and remove batches of row locks from the locktree
+ int num_extracted;
+ const int num_row_locks_per_batch = 128;
+ row_lock *XCALLOC_N(num_row_locks_per_batch, extracted_buf);
+
+ // we always remove the "first" n because we are removing n
+ // each time we do an extraction. so this loops until its empty.
+ while ((num_extracted = extract_first_n_row_locks(
+ &lkr, m_mgr, extracted_buf, num_row_locks_per_batch)) > 0) {
+ int current_index = 0;
+ while (current_index < num_extracted) {
+ // every batch of extracted locks is in range-sorted order. search
+ // through them and merge adjacent locks with the same txnid into
+ // one dominating lock and save it to a set of escalated locks.
+ //
+ // first, find the index of the next row lock that
+ // - belongs to a different txnid, or
+ // - belongs to several txnids, or
+ // - is a shared lock (we could potentially merge those but
+ // currently we don't), or
+ // - is across a lock escalation barrier.
+ int next_txnid_index = current_index + 1;
+
+ while (next_txnid_index < num_extracted &&
+ (extracted_buf[current_index].txnid ==
+ extracted_buf[next_txnid_index].txnid) &&
+ !extracted_buf[next_txnid_index].is_shared &&
+ !extracted_buf[next_txnid_index].owners &&
+ !m_escalation_barrier(
+ extracted_buf[current_index].range.get_right_key(),
+ extracted_buf[next_txnid_index].range.get_left_key(),
+ m_escalation_barrier_arg)) {
+ next_txnid_index++;
+ }
+
+ // Create an escalated range for the current txnid that dominates
+ // each range between the current indext and the next txnid's index.
+ // const TXNID current_txnid = extracted_buf[current_index].txnid;
+ const DBT *escalated_left_key =
+ extracted_buf[current_index].range.get_left_key();
+ const DBT *escalated_right_key =
+ extracted_buf[next_txnid_index - 1].range.get_right_key();
+
+ // Try to find a range buffer for the current txnid. Create one if it
+ // doesn't exist. Then, append the new escalated range to the buffer. (If
+ // a lock is shared by multiple txnids, append it each of txnid's lists)
+ TxnidVector *owners_ptr;
+ TxnidVector singleton_owner;
+ if (extracted_buf[current_index].owners)
+ owners_ptr = extracted_buf[current_index].owners;
+ else {
+ singleton_owner.insert(extracted_buf[current_index].txnid);
+ owners_ptr = &singleton_owner;
+ }
+
+ for (auto cur_txnid : *owners_ptr) {
+ uint32_t idx;
+ struct txnid_range_buffer *existing_range_buffer;
+ int r =
+ range_buffers.find_zero<TXNID, txnid_range_buffer::find_by_txnid>(
+ cur_txnid, &existing_range_buffer, &idx);
+ if (r == DB_NOTFOUND) {
+ struct txnid_range_buffer *XMALLOC(new_range_buffer);
+ new_range_buffer->txnid = cur_txnid;
+ new_range_buffer->buffer.create();
+ new_range_buffer->buffer.append(
+ escalated_left_key, escalated_right_key,
+ !extracted_buf[current_index].is_shared);
+ range_buffers.insert_at(new_range_buffer, idx);
+ } else {
+ invariant_zero(r);
+ invariant(existing_range_buffer->txnid == cur_txnid);
+ existing_range_buffer->buffer.append(
+ escalated_left_key, escalated_right_key,
+ !extracted_buf[current_index].is_shared);
+ }
+ }
+
+ current_index = next_txnid_index;
+ }
+
+ // destroy the ranges copied during the extraction
+ for (int i = 0; i < num_extracted; i++) {
+ delete extracted_buf[i].owners;
+ extracted_buf[i].range.destroy();
+ }
+ }
+ toku_free(extracted_buf);
+
+ // Rebuild the locktree from each range in each range buffer,
+ // then notify higher layers that the txnid's locks have changed.
+ //
+ // (shared locks: if a lock was initially shared between transactions TRX1,
+ // TRX2, etc, we will now try to acquire it acting on behalf on TRX1, on
+ // TRX2, etc. This will succeed and an identical shared lock will be
+ // constructed)
+
+ invariant(m_rangetree->is_empty());
+ const uint32_t num_range_buffers = range_buffers.size();
+ for (uint32_t i = 0; i < num_range_buffers; i++) {
+ struct txnid_range_buffer *current_range_buffer;
+ int r = range_buffers.fetch(i, &current_range_buffer);
+ invariant_zero(r);
+ if (r == EINVAL) // Shouldn't happen, avoid compiler warning
+ continue;
+
+ const TXNID current_txnid = current_range_buffer->txnid;
+ range_buffer::iterator iter(&current_range_buffer->buffer);
+ range_buffer::iterator::record rec;
+ while (iter.current(&rec)) {
+ keyrange range;
+ range.create(rec.get_left_key(), rec.get_right_key());
+ row_lock lock = {.range = range,
+ .txnid = current_txnid,
+ .is_shared = !rec.get_exclusive_flag(),
+ .owners = nullptr};
+ insert_row_lock_into_tree(&lkr, lock, m_mgr);
+ iter.next();
+ }
+
+ // Notify higher layers that locks have changed for the current txnid
+ if (after_escalate_callback) {
+ after_escalate_callback(current_txnid, this, current_range_buffer->buffer,
+ after_escalate_callback_extra);
+ }
+ current_range_buffer->buffer.destroy();
+ }
+
+ while (range_buffers.size() > 0) {
+ struct txnid_range_buffer *buffer;
+ int r = range_buffers.fetch(0, &buffer);
+ invariant_zero(r);
+ r = range_buffers.delete_at(0);
+ invariant_zero(r);
+ toku_free(buffer);
+ }
+ range_buffers.destroy();
+
+ lkr.release();
+}
+
+void *locktree::get_userdata(void) const { return m_userdata; }
+
+void locktree::set_userdata(void *userdata) { m_userdata = userdata; }
+
+struct lt_lock_request_info *locktree::get_lock_request_info(void) {
+ return &m_lock_request_info;
+}
+
+void locktree::set_comparator(const comparator &cmp) { m_cmp.inherit(cmp); }
+
+locktree_manager *locktree::get_manager(void) const { return m_mgr; }
+
+int locktree::compare(const locktree *lt) const {
+ if (m_dict_id.dictid < lt->m_dict_id.dictid) {
+ return -1;
+ } else if (m_dict_id.dictid == lt->m_dict_id.dictid) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+DICTIONARY_ID locktree::get_dict_id() const { return m_dict_id; }
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.h
new file mode 100644
index 000000000..f0f4b042d
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/locktree.h
@@ -0,0 +1,580 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <atomic>
+
+#include "../db.h"
+#include "../ft/comparator.h"
+#include "../portability/toku_external_pthread.h"
+#include "../portability/toku_pthread.h"
+#include "../portability/toku_time.h"
+// PORT #include <ft/ft-ops.h> // just for DICTIONARY_ID..
+// PORT: ft-status for LTM_STATUS:
+#include "../ft/ft-status.h"
+
+struct DICTIONARY_ID {
+ uint64_t dictid;
+};
+
+#include "../util/omt.h"
+#include "range_buffer.h"
+#include "txnid_set.h"
+#include "wfg.h"
+
+namespace toku {
+
+class locktree;
+class locktree_manager;
+class lock_request;
+class concurrent_tree;
+
+typedef int (*lt_create_cb)(locktree *lt, void *extra);
+typedef void (*lt_destroy_cb)(locktree *lt);
+typedef void (*lt_escalate_cb)(TXNID txnid, const locktree *lt,
+ const range_buffer &buffer, void *extra);
+
+typedef bool (*lt_escalation_barrier_check_func)(const DBT *a, const DBT *b,
+ void *extra);
+
+struct lt_counters {
+ uint64_t wait_count, wait_time;
+ uint64_t long_wait_count, long_wait_time;
+ uint64_t timeout_count;
+
+ void add(const lt_counters &rhs) {
+ wait_count += rhs.wait_count;
+ wait_time += rhs.wait_time;
+ long_wait_count += rhs.long_wait_count;
+ long_wait_time += rhs.long_wait_time;
+ timeout_count += rhs.timeout_count;
+ }
+};
+
+// Lock request state for some locktree
+struct lt_lock_request_info {
+ omt<lock_request *> pending_lock_requests;
+ std::atomic_bool pending_is_empty;
+ toku_external_mutex_t mutex;
+ bool should_retry_lock_requests;
+ lt_counters counters;
+ std::atomic_ullong retry_want;
+ unsigned long long retry_done;
+ toku_mutex_t retry_mutex;
+ toku_cond_t retry_cv;
+ bool running_retry;
+
+ void init(toku_external_mutex_factory_t mutex_factory);
+ void destroy(void);
+};
+
+// The locktree manager manages a set of locktrees, one for each open
+// dictionary. Locktrees are retrieved from the manager. When they are no
+// longer needed, they are be released by the user.
+class locktree_manager {
+ public:
+ // param: create_cb, called just after a locktree is first created.
+ // destroy_cb, called just before a locktree is destroyed.
+ // escalate_cb, called after a locktree is escalated (with extra
+ // param)
+ void create(lt_create_cb create_cb, lt_destroy_cb destroy_cb,
+ lt_escalate_cb escalate_cb, void *extra,
+ toku_external_mutex_factory_t mutex_factory_arg);
+
+ void destroy(void);
+
+ size_t get_max_lock_memory(void);
+
+ int set_max_lock_memory(size_t max_lock_memory);
+
+ // effect: Get a locktree from the manager. If a locktree exists with the
+ // given
+ // dict_id, it is referenced and then returned. If one did not exist,
+ // it is created. It will use the comparator for comparing keys. The
+ // on_create callback (passed to locktree_manager::create()) will be
+ // called with the given extra parameter.
+ locktree *get_lt(DICTIONARY_ID dict_id, const comparator &cmp,
+ void *on_create_extra);
+
+ void reference_lt(locktree *lt);
+
+ // effect: Releases one reference on a locktree. If the reference count
+ // transitions
+ // to zero, the on_destroy callback is called before it gets
+ // destroyed.
+ void release_lt(locktree *lt);
+
+ void get_status(LTM_STATUS status);
+
+ // effect: calls the iterate function on each pending lock request
+ // note: holds the manager's mutex
+ typedef int (*lock_request_iterate_callback)(DICTIONARY_ID dict_id,
+ TXNID txnid, const DBT *left_key,
+ const DBT *right_key,
+ TXNID blocking_txnid,
+ uint64_t start_time,
+ void *extra);
+ int iterate_pending_lock_requests(lock_request_iterate_callback cb,
+ void *extra);
+
+ // effect: Determines if too many locks or too much memory is being used,
+ // Runs escalation on the manager if so.
+ // param: big_txn, if the current transaction is 'big' (has spilled rollback
+ // logs) returns: 0 if there enough resources to create a new lock, or
+ // TOKUDB_OUT_OF_LOCKS
+ // if there are not enough resources and lock escalation failed to
+ // free up enough resources for a new lock.
+ int check_current_lock_constraints(bool big_txn);
+
+ bool over_big_threshold(void);
+
+ void note_mem_used(uint64_t mem_used);
+
+ void note_mem_released(uint64_t mem_freed);
+
+ bool out_of_locks(void) const;
+
+ // Escalate all locktrees
+ void escalate_all_locktrees(void);
+
+ // Escalate a set of locktrees
+ void escalate_locktrees(locktree **locktrees, int num_locktrees);
+
+ // effect: calls the private function run_escalation(), only ok to
+ // do for tests.
+ // rationale: to get better stress test coverage, we want a way to
+ // deterministicly trigger lock escalation.
+ void run_escalation_for_test(void);
+ void run_escalation(void);
+
+ // Add time t to the escalator's wait time statistics
+ void add_escalator_wait_time(uint64_t t);
+
+ void kill_waiter(void *extra);
+
+ private:
+ static const uint64_t DEFAULT_MAX_LOCK_MEMORY = 64L * 1024 * 1024;
+
+ // tracks the current number of locks and lock memory
+ uint64_t m_max_lock_memory;
+ uint64_t m_current_lock_memory;
+
+ struct lt_counters m_lt_counters;
+
+ // the create and destroy callbacks for the locktrees
+ lt_create_cb m_lt_create_callback;
+ lt_destroy_cb m_lt_destroy_callback;
+ lt_escalate_cb m_lt_escalate_callback;
+ void *m_lt_escalate_callback_extra;
+
+ omt<locktree *> m_locktree_map;
+
+ toku_external_mutex_factory_t mutex_factory;
+
+ // the manager's mutex protects the locktree map
+ toku_mutex_t m_mutex;
+
+ void mutex_lock(void);
+
+ void mutex_unlock(void);
+
+ // Manage the set of open locktrees
+ locktree *locktree_map_find(const DICTIONARY_ID &dict_id);
+ void locktree_map_put(locktree *lt);
+ void locktree_map_remove(locktree *lt);
+
+ static int find_by_dict_id(locktree *const &lt, const DICTIONARY_ID &dict_id);
+
+ void escalator_init(void);
+ void escalator_destroy(void);
+
+ // statistics about lock escalation.
+ toku_mutex_t m_escalation_mutex;
+ uint64_t m_escalation_count;
+ tokutime_t m_escalation_time;
+ uint64_t m_escalation_latest_result;
+ uint64_t m_wait_escalation_count;
+ uint64_t m_wait_escalation_time;
+ uint64_t m_long_wait_escalation_count;
+ uint64_t m_long_wait_escalation_time;
+
+ // the escalator coordinates escalation on a set of locktrees for a bunch of
+ // threads
+ class locktree_escalator {
+ public:
+ void create(void);
+ void destroy(void);
+ void run(locktree_manager *mgr, void (*escalate_locktrees_fun)(void *extra),
+ void *extra);
+
+ private:
+ toku_mutex_t m_escalator_mutex;
+ toku_cond_t m_escalator_done;
+ bool m_escalator_running;
+ };
+
+ locktree_escalator m_escalator;
+
+ friend class manager_unit_test;
+};
+
+// A locktree represents the set of row locks owned by all transactions
+// over an open dictionary. Read and write ranges are represented as
+// a left and right key which are compared with the given comparator
+//
+// Locktrees are not created and destroyed by the user. Instead, they are
+// referenced and released using the locktree manager.
+//
+// A sample workflow looks like this:
+// - Create a manager.
+// - Get a locktree by dictionaroy id from the manager.
+// - Perform read/write lock acquision on the locktree, add references to
+// the locktree using the manager, release locks, release references, etc.
+// - ...
+// - Release the final reference to the locktree. It will be destroyed.
+// - Destroy the manager.
+class locktree {
+ public:
+ // effect: Creates a locktree
+ void create(locktree_manager *mgr, DICTIONARY_ID dict_id,
+ const comparator &cmp,
+ toku_external_mutex_factory_t mutex_factory);
+
+ void destroy(void);
+
+ // For thread-safe, external reference counting
+ void add_reference(void);
+
+ // requires: the reference count is > 0
+ // returns: the reference count, after decrementing it by one
+ uint32_t release_reference(void);
+
+ // returns: the current reference count
+ uint32_t get_reference_count(void);
+
+ // effect: Attempts to grant a read lock for the range of keys between
+ // [left_key, right_key]. returns: If the lock cannot be granted, return
+ // DB_LOCK_NOTGRANTED, and populate the
+ // given conflicts set with the txnids that hold conflicting locks in
+ // the range. If the locktree cannot create more locks, return
+ // TOKUDB_OUT_OF_LOCKS.
+ // note: Read locks cannot be shared between txnids, as one would expect.
+ // This is for simplicity since read locks are rare in MySQL.
+ int acquire_read_lock(TXNID txnid, const DBT *left_key, const DBT *right_key,
+ txnid_set *conflicts, bool big_txn);
+
+ // effect: Attempts to grant a write lock for the range of keys between
+ // [left_key, right_key]. returns: If the lock cannot be granted, return
+ // DB_LOCK_NOTGRANTED, and populate the
+ // given conflicts set with the txnids that hold conflicting locks in
+ // the range. If the locktree cannot create more locks, return
+ // TOKUDB_OUT_OF_LOCKS.
+ int acquire_write_lock(TXNID txnid, const DBT *left_key, const DBT *right_key,
+ txnid_set *conflicts, bool big_txn);
+
+ // effect: populate the conflicts set with the txnids that would preventing
+ // the given txnid from getting a lock on [left_key, right_key]
+ void get_conflicts(bool is_write_request, TXNID txnid, const DBT *left_key,
+ const DBT *right_key, txnid_set *conflicts);
+
+ // effect: Release all of the lock ranges represented by the range buffer for
+ // a txnid.
+ void release_locks(TXNID txnid, const range_buffer *ranges,
+ bool all_trx_locks_hint = false);
+
+ // effect: Runs escalation on this locktree
+ void escalate(lt_escalate_cb after_escalate_callback, void *extra);
+
+ // returns: The userdata associated with this locktree, or null if it has not
+ // been set.
+ void *get_userdata(void) const;
+
+ void set_userdata(void *userdata);
+
+ locktree_manager *get_manager(void) const;
+
+ void set_comparator(const comparator &cmp);
+
+ // Set the user-provided Lock Escalation Barrier check function and its
+ // argument
+ //
+ // Lock Escalation Barrier limits the scope of Lock Escalation.
+ // For two keys A and B (such that A < B),
+ // escalation_barrier_check_func(A, B)==true means that there's a lock
+ // escalation barrier between A and B, and lock escalation is not allowed to
+ // bridge the gap between A and B.
+ //
+ // This method sets the user-provided barrier check function and its
+ // parameter.
+ void set_escalation_barrier_func(lt_escalation_barrier_check_func func,
+ void *extra);
+
+ int compare(const locktree *lt) const;
+
+ DICTIONARY_ID get_dict_id() const;
+
+ // Private info struct for storing pending lock request state.
+ // Only to be used by lock requests. We store it here as
+ // something less opaque than usual to strike a tradeoff between
+ // abstraction and code complexity. It is still fairly abstract
+ // since the lock_request object is opaque
+ struct lt_lock_request_info *get_lock_request_info(void);
+
+ typedef void (*dump_callback)(void *cdata, const DBT *left, const DBT *right,
+ TXNID txnid, bool is_shared,
+ TxnidVector *owners);
+ void dump_locks(void *cdata, dump_callback cb);
+
+ private:
+ locktree_manager *m_mgr;
+ DICTIONARY_ID m_dict_id;
+ uint32_t m_reference_count;
+
+ // Since the memory referenced by this comparator is not owned by the
+ // locktree, the user must guarantee it will outlive the locktree.
+ //
+ // The ydb API accomplishes this by opening an ft_handle in the on_create
+ // callback, which will keep the underlying FT (and its descriptor) in memory
+ // for as long as the handle is open. The ft_handle is stored opaquely in the
+ // userdata pointer below. see locktree_manager::get_lt w/ on_create_extra
+ comparator m_cmp;
+
+ lt_escalation_barrier_check_func m_escalation_barrier;
+ void *m_escalation_barrier_arg;
+
+ concurrent_tree *m_rangetree;
+
+ void *m_userdata;
+ struct lt_lock_request_info m_lock_request_info;
+
+ // psergey-todo:
+ // Each transaction also keeps a list of ranges it has locked.
+ // So, when a transaction is running in STO mode, two identical
+ // lists are kept: the STO lock list and transaction's owned locks
+ // list. Why can't we do with just one list?
+
+ // The following fields and members prefixed with "sto_" are for
+ // the single txnid optimization, intended to speed up the case
+ // when only one transaction is using the locktree. If we know
+ // the locktree has only one transaction, then acquiring locks
+ // takes O(1) work and releasing all locks takes O(1) work.
+ //
+ // How do we know that the locktree only has a single txnid?
+ // What do we do if it does?
+ //
+ // When a txn with txnid T requests a lock:
+ // - If the tree is empty, the optimization is possible. Set the single
+ // txnid to T, and insert the lock range into the buffer.
+ // - If the tree is not empty, check if the single txnid is T. If so,
+ // append the lock range to the buffer. Otherwise, migrate all of
+ // the locks in the buffer into the rangetree on behalf of txnid T,
+ // and invalid the single txnid.
+ //
+ // When a txn with txnid T releases its locks:
+ // - If the single txnid is valid, it must be for T. Destroy the buffer.
+ // - If it's not valid, release locks the normal way in the rangetree.
+ //
+ // To carry out the optimization we need to record a single txnid
+ // and a range buffer for each locktree, each protected by the root
+ // lock of the locktree's rangetree. The root lock for a rangetree
+ // is grabbed by preparing a locked keyrange on the rangetree.
+ TXNID m_sto_txnid;
+ range_buffer m_sto_buffer;
+
+ // The single txnid optimization speeds up the case when only one
+ // transaction is using the locktree. But it has the potential to
+ // hurt the case when more than one txnid exists.
+ //
+ // There are two things we need to do to make the optimization only
+ // optimize the case we care about, and not hurt the general case.
+ //
+ // Bound the worst-case latency for lock migration when the
+ // optimization stops working:
+ // - Idea: Stop the optimization and migrate immediate if we notice
+ // the single txnid has takes many locks in the range buffer.
+ // - Implementation: Enforce a max size on the single txnid range buffer.
+ // - Analysis: Choosing the perfect max value, M, is difficult to do
+ // without some feedback from the field. Intuition tells us that M should
+ // not be so small that the optimization is worthless, and it should not
+ // be so big that it's unreasonable to have to wait behind a thread doing
+ // the work of converting M buffer locks into rangetree locks.
+ //
+ // Prevent concurrent-transaction workloads from trying the optimization
+ // in vain:
+ // - Idea: Don't even bother trying the optimization if we think the
+ // system is in a concurrent-transaction state.
+ // - Implementation: Do something even simpler than detecting whether the
+ // system is in a concurent-transaction state. Just keep a "score" value
+ // and some threshold. If at any time the locktree is eligible for the
+ // optimization, only do it if the score is at this threshold. When you
+ // actually do the optimization but someone has to migrate locks in the buffer
+ // (expensive), then reset the score back to zero. Each time a txn
+ // releases locks, the score is incremented by 1.
+ // - Analysis: If you let the threshold be "C", then at most 1 / C txns will
+ // do the optimization in a concurrent-transaction system. Similarly, it
+ // takes at most C txns to start using the single txnid optimzation, which
+ // is good when the system transitions from multithreaded to single threaded.
+ //
+ // STO_BUFFER_MAX_SIZE:
+ //
+ // We choose the max value to be 1 million since most transactions are smaller
+ // than 1 million and we can create a rangetree of 1 million elements in
+ // less than a second. So we can be pretty confident that this threshold
+ // enables the optimization almost always, and prevents super pathological
+ // latency issues for the first lock taken by a second thread.
+ //
+ // STO_SCORE_THRESHOLD:
+ //
+ // A simple first guess at a good value for the score threshold is 100.
+ // By our analysis, we'd end up doing the optimization in vain for
+ // around 1% of all transactions, which seems reasonable. Further,
+ // if the system goes single threaded, it ought to be pretty quick
+ // for 100 transactions to go by, so we won't have to wait long before
+ // we start doing the single txind optimzation again.
+ static const int STO_BUFFER_MAX_SIZE = 50 * 1024;
+ static const int STO_SCORE_THRESHOLD = 100;
+ int m_sto_score;
+
+ // statistics about time spent ending the STO early
+ uint64_t m_sto_end_early_count;
+ tokutime_t m_sto_end_early_time;
+
+ // effect: begins the single txnid optimizaiton, setting m_sto_txnid
+ // to the given txnid.
+ // requires: m_sto_txnid is invalid
+ void sto_begin(TXNID txnid);
+
+ // effect: append a range to the sto buffer
+ // requires: m_sto_txnid is valid
+ void sto_append(const DBT *left_key, const DBT *right_key,
+ bool is_write_request);
+
+ // effect: ends the single txnid optimization, releaseing any memory
+ // stored in the sto buffer, notifying the tracker, and
+ // invalidating m_sto_txnid.
+ // requires: m_sto_txnid is valid
+ void sto_end(void);
+
+ // params: prepared_lkr is a void * to a prepared locked keyrange. see below.
+ // effect: ends the single txnid optimization early, migrating buffer locks
+ // into the rangetree, calling sto_end(), and then setting the
+ // sto_score back to zero.
+ // requires: m_sto_txnid is valid
+ void sto_end_early(void *prepared_lkr);
+ void sto_end_early_no_accounting(void *prepared_lkr);
+
+ // params: prepared_lkr is a void * to a prepared locked keyrange. we can't
+ // use
+ // the real type because the compiler won't allow us to forward
+ // declare concurrent_tree::locked_keyrange without including
+ // concurrent_tree.h, which we cannot do here because it is a template
+ // implementation.
+ // requires: the prepared locked keyrange is for the locktree's rangetree
+ // requires: m_sto_txnid is valid
+ // effect: migrates each lock in the single txnid buffer into the locktree's
+ // rangetree, notifying the memory tracker as necessary.
+ void sto_migrate_buffer_ranges_to_tree(void *prepared_lkr);
+
+ // effect: If m_sto_txnid is valid, then release the txnid's locks
+ // by ending the optimization.
+ // requires: If m_sto_txnid is valid, it is equal to the given txnid
+ // returns: True if locks were released for this txnid
+ bool sto_try_release(TXNID txnid);
+
+ // params: prepared_lkr is a void * to a prepared locked keyrange. see above.
+ // requires: the prepared locked keyrange is for the locktree's rangetree
+ // effect: If m_sto_txnid is valid and equal to the given txnid, then
+ // append a range onto the buffer. Otherwise, if m_sto_txnid is valid
+ // but not equal to this txnid, then migrate the buffer's locks
+ // into the rangetree and end the optimization, setting the score
+ // back to zero.
+ // returns: true if the lock was acquired for this txnid
+ bool sto_try_acquire(void *prepared_lkr, TXNID txnid, const DBT *left_key,
+ const DBT *right_key, bool is_write_request);
+
+ // Effect:
+ // Provides a hook for a helgrind suppression.
+ // Returns:
+ // true if m_sto_txnid is not TXNID_NONE
+ bool sto_txnid_is_valid_unsafe(void) const;
+
+ // Effect:
+ // Provides a hook for a helgrind suppression.
+ // Returns:
+ // m_sto_score
+ int sto_get_score_unsafe(void) const;
+
+ void remove_overlapping_locks_for_txnid(TXNID txnid, const DBT *left_key,
+ const DBT *right_key);
+
+ int acquire_lock_consolidated(void *prepared_lkr, TXNID txnid,
+ const DBT *left_key, const DBT *right_key,
+ bool is_write_request, txnid_set *conflicts);
+
+ int acquire_lock(bool is_write_request, TXNID txnid, const DBT *left_key,
+ const DBT *right_key, txnid_set *conflicts);
+
+ int try_acquire_lock(bool is_write_request, TXNID txnid, const DBT *left_key,
+ const DBT *right_key, txnid_set *conflicts,
+ bool big_txn);
+
+ friend class locktree_unit_test;
+ friend class manager_unit_test;
+ friend class lock_request_unit_test;
+
+ // engine status reaches into the locktree to read some stats
+ friend void locktree_manager::get_status(LTM_STATUS status);
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/manager.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/manager.cc
new file mode 100644
index 000000000..4186182be
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/manager.cc
@@ -0,0 +1,527 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "../portability/toku_pthread.h"
+#include "../util/status.h"
+#include "lock_request.h"
+#include "locktree.h"
+
+namespace toku {
+
+void locktree_manager::create(lt_create_cb create_cb, lt_destroy_cb destroy_cb,
+ lt_escalate_cb escalate_cb, void *escalate_extra,
+ toku_external_mutex_factory_t mutex_factory_arg) {
+ mutex_factory = mutex_factory_arg;
+ m_max_lock_memory = DEFAULT_MAX_LOCK_MEMORY;
+ m_current_lock_memory = 0;
+
+ m_locktree_map.create();
+ m_lt_create_callback = create_cb;
+ m_lt_destroy_callback = destroy_cb;
+ m_lt_escalate_callback = escalate_cb;
+ m_lt_escalate_callback_extra = escalate_extra;
+ ZERO_STRUCT(m_mutex);
+ toku_mutex_init(manager_mutex_key, &m_mutex, nullptr);
+
+ ZERO_STRUCT(m_lt_counters);
+
+ escalator_init();
+}
+
+void locktree_manager::destroy(void) {
+ escalator_destroy();
+ invariant(m_current_lock_memory == 0);
+ invariant(m_locktree_map.size() == 0);
+ m_locktree_map.destroy();
+ toku_mutex_destroy(&m_mutex);
+}
+
+void locktree_manager::mutex_lock(void) { toku_mutex_lock(&m_mutex); }
+
+void locktree_manager::mutex_unlock(void) { toku_mutex_unlock(&m_mutex); }
+
+size_t locktree_manager::get_max_lock_memory(void) { return m_max_lock_memory; }
+
+int locktree_manager::set_max_lock_memory(size_t max_lock_memory) {
+ int r = 0;
+ mutex_lock();
+ if (max_lock_memory < m_current_lock_memory) {
+ r = EDOM;
+ } else {
+ m_max_lock_memory = max_lock_memory;
+ }
+ mutex_unlock();
+ return r;
+}
+
+int locktree_manager::find_by_dict_id(locktree *const &lt,
+ const DICTIONARY_ID &dict_id) {
+ if (lt->get_dict_id().dictid < dict_id.dictid) {
+ return -1;
+ } else if (lt->get_dict_id().dictid == dict_id.dictid) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+locktree *locktree_manager::locktree_map_find(const DICTIONARY_ID &dict_id) {
+ locktree *lt;
+ int r = m_locktree_map.find_zero<DICTIONARY_ID, find_by_dict_id>(dict_id, &lt,
+ nullptr);
+ return r == 0 ? lt : nullptr;
+}
+
+void locktree_manager::locktree_map_put(locktree *lt) {
+ int r = m_locktree_map.insert<DICTIONARY_ID, find_by_dict_id>(
+ lt, lt->get_dict_id(), nullptr);
+ invariant_zero(r);
+}
+
+void locktree_manager::locktree_map_remove(locktree *lt) {
+ uint32_t idx;
+ locktree *found_lt;
+ int r = m_locktree_map.find_zero<DICTIONARY_ID, find_by_dict_id>(
+ lt->get_dict_id(), &found_lt, &idx);
+ invariant_zero(r);
+ invariant(found_lt == lt);
+ r = m_locktree_map.delete_at(idx);
+ invariant_zero(r);
+}
+
+locktree *locktree_manager::get_lt(DICTIONARY_ID dict_id, const comparator &cmp,
+ void *on_create_extra) {
+ // hold the mutex around searching and maybe
+ // inserting into the locktree map
+ mutex_lock();
+
+ locktree *lt = locktree_map_find(dict_id);
+ if (lt == nullptr) {
+ XCALLOC(lt);
+ lt->create(this, dict_id, cmp, mutex_factory);
+
+ // new locktree created - call the on_create callback
+ // and put it in the locktree map
+ if (m_lt_create_callback) {
+ int r = m_lt_create_callback(lt, on_create_extra);
+ if (r != 0) {
+ lt->release_reference();
+ lt->destroy();
+ toku_free(lt);
+ lt = nullptr;
+ }
+ }
+ if (lt) {
+ locktree_map_put(lt);
+ }
+ } else {
+ reference_lt(lt);
+ }
+
+ mutex_unlock();
+
+ return lt;
+}
+
+void locktree_manager::reference_lt(locktree *lt) {
+ // increment using a sync fetch and add.
+ // the caller guarantees that the lt won't be
+ // destroyed while we increment the count here.
+ //
+ // the caller can do this by already having an lt
+ // reference or by holding the manager mutex.
+ //
+ // if the manager's mutex is held, it is ok for the
+ // reference count to transition from 0 to 1 (no race),
+ // since we're serialized with other opens and closes.
+ lt->add_reference();
+}
+
+void locktree_manager::release_lt(locktree *lt) {
+ bool do_destroy = false;
+ DICTIONARY_ID dict_id = lt->get_dict_id();
+
+ // Release a reference on the locktree. If the count transitions to zero,
+ // then we *may* need to do the cleanup.
+ //
+ // Grab the manager's mutex and look for a locktree with this locktree's
+ // dictionary id. Since dictionary id's never get reused, any locktree
+ // found must be the one we just released a reference on.
+ //
+ // At least two things could have happened since we got the mutex:
+ // - Another thread gets a locktree with the same dict_id, increments
+ // the reference count. In this case, we shouldn't destroy it.
+ // - Another thread gets a locktree with the same dict_id and then
+ // releases it quickly, transitioning the reference count from zero to
+ // one and back to zero. In this case, only one of us should destroy it.
+ // It doesn't matter which. We originally missed this case, see #5776.
+ //
+ // After 5776, the high level rule for release is described below.
+ //
+ // If a thread releases a locktree and notices the reference count transition
+ // to zero, then that thread must immediately:
+ // - assume the locktree object is invalid
+ // - grab the manager's mutex
+ // - search the locktree map for a locktree with the same dict_id and remove
+ // it, if it exists. the destroy may be deferred.
+ // - release the manager's mutex
+ //
+ // This way, if many threads transition the same locktree's reference count
+ // from 1 to zero and wait behind the manager's mutex, only one of them will
+ // do the actual destroy and the others will happily do nothing.
+ uint32_t refs = lt->release_reference();
+ if (refs == 0) {
+ mutex_lock();
+ // lt may not have already been destroyed, so look it up.
+ locktree *find_lt = locktree_map_find(dict_id);
+ if (find_lt != nullptr) {
+ // A locktree is still in the map with that dict_id, so it must be
+ // equal to lt. This is true because dictionary ids are never reused.
+ // If the reference count is zero, it's our responsibility to remove
+ // it and do the destroy. Otherwise, someone still wants it.
+ // If the locktree is still valid then check if it should be deleted.
+ if (find_lt == lt) {
+ if (lt->get_reference_count() == 0) {
+ locktree_map_remove(lt);
+ do_destroy = true;
+ }
+ m_lt_counters.add(lt->get_lock_request_info()->counters);
+ }
+ }
+ mutex_unlock();
+ }
+
+ // if necessary, do the destroy without holding the mutex
+ if (do_destroy) {
+ if (m_lt_destroy_callback) {
+ m_lt_destroy_callback(lt);
+ }
+ lt->destroy();
+ toku_free(lt);
+ }
+}
+
+void locktree_manager::run_escalation(void) {
+ struct escalation_fn {
+ static void run(void *extra) {
+ locktree_manager *mgr = (locktree_manager *)extra;
+ mgr->escalate_all_locktrees();
+ };
+ };
+ m_escalator.run(this, escalation_fn::run, this);
+}
+
+// test-only version of lock escalation
+void locktree_manager::run_escalation_for_test(void) { run_escalation(); }
+
+void locktree_manager::escalate_all_locktrees(void) {
+ uint64_t t0 = toku_current_time_microsec();
+
+ // get all locktrees
+ mutex_lock();
+ int num_locktrees = m_locktree_map.size();
+ locktree **locktrees = new locktree *[num_locktrees];
+ for (int i = 0; i < num_locktrees; i++) {
+ int r = m_locktree_map.fetch(i, &locktrees[i]);
+ invariant_zero(r);
+ reference_lt(locktrees[i]);
+ }
+ mutex_unlock();
+
+ // escalate them
+ escalate_locktrees(locktrees, num_locktrees);
+
+ delete[] locktrees;
+
+ uint64_t t1 = toku_current_time_microsec();
+ add_escalator_wait_time(t1 - t0);
+}
+
+void locktree_manager::note_mem_used(uint64_t mem_used) {
+ (void)toku_sync_fetch_and_add(&m_current_lock_memory, mem_used);
+}
+
+void locktree_manager::note_mem_released(uint64_t mem_released) {
+ uint64_t old_mem_used =
+ toku_sync_fetch_and_sub(&m_current_lock_memory, mem_released);
+ invariant(old_mem_used >= mem_released);
+}
+
+bool locktree_manager::out_of_locks(void) const {
+ return m_current_lock_memory >= m_max_lock_memory;
+}
+
+bool locktree_manager::over_big_threshold(void) {
+ return m_current_lock_memory >= m_max_lock_memory / 2;
+}
+
+int locktree_manager::iterate_pending_lock_requests(
+ lock_request_iterate_callback callback, void *extra) {
+ mutex_lock();
+ int r = 0;
+ uint32_t num_locktrees = m_locktree_map.size();
+ for (uint32_t i = 0; i < num_locktrees && r == 0; i++) {
+ locktree *lt;
+ r = m_locktree_map.fetch(i, &lt);
+ invariant_zero(r);
+ if (r == EINVAL) // Shouldn't happen, avoid compiler warning
+ continue;
+
+ struct lt_lock_request_info *info = lt->get_lock_request_info();
+ toku_external_mutex_lock(&info->mutex);
+
+ uint32_t num_requests = info->pending_lock_requests.size();
+ for (uint32_t k = 0; k < num_requests && r == 0; k++) {
+ lock_request *req;
+ r = info->pending_lock_requests.fetch(k, &req);
+ invariant_zero(r);
+ if (r == EINVAL) /* Shouldn't happen, avoid compiler warning */
+ continue;
+ r = callback(lt->get_dict_id(), req->get_txnid(), req->get_left_key(),
+ req->get_right_key(), req->get_conflicting_txnid(),
+ req->get_start_time(), extra);
+ }
+
+ toku_external_mutex_unlock(&info->mutex);
+ }
+ mutex_unlock();
+ return r;
+}
+
+int locktree_manager::check_current_lock_constraints(bool big_txn) {
+ int r = 0;
+ if (big_txn && over_big_threshold()) {
+ run_escalation();
+ if (over_big_threshold()) {
+ r = TOKUDB_OUT_OF_LOCKS;
+ }
+ }
+ if (r == 0 && out_of_locks()) {
+ run_escalation();
+ if (out_of_locks()) {
+ // return an error if we're still out of locks after escalation.
+ r = TOKUDB_OUT_OF_LOCKS;
+ }
+ }
+ return r;
+}
+
+void locktree_manager::escalator_init(void) {
+ ZERO_STRUCT(m_escalation_mutex);
+ toku_mutex_init(manager_escalation_mutex_key, &m_escalation_mutex, nullptr);
+ m_escalation_count = 0;
+ m_escalation_time = 0;
+ m_wait_escalation_count = 0;
+ m_wait_escalation_time = 0;
+ m_long_wait_escalation_count = 0;
+ m_long_wait_escalation_time = 0;
+ m_escalation_latest_result = 0;
+ m_escalator.create();
+}
+
+void locktree_manager::escalator_destroy(void) {
+ m_escalator.destroy();
+ toku_mutex_destroy(&m_escalation_mutex);
+}
+
+void locktree_manager::add_escalator_wait_time(uint64_t t) {
+ toku_mutex_lock(&m_escalation_mutex);
+ m_wait_escalation_count += 1;
+ m_wait_escalation_time += t;
+ if (t >= 1000000) {
+ m_long_wait_escalation_count += 1;
+ m_long_wait_escalation_time += t;
+ }
+ toku_mutex_unlock(&m_escalation_mutex);
+}
+
+void locktree_manager::escalate_locktrees(locktree **locktrees,
+ int num_locktrees) {
+ // there are too many row locks in the system and we need to tidy up.
+ //
+ // a simple implementation of escalation does not attempt
+ // to reduce the memory foot print of each txn's range buffer.
+ // doing so would require some layering hackery (or a callback)
+ // and more complicated locking. for now, just escalate each
+ // locktree individually, in-place.
+ tokutime_t t0 = toku_time_now();
+ for (int i = 0; i < num_locktrees; i++) {
+ locktrees[i]->escalate(m_lt_escalate_callback,
+ m_lt_escalate_callback_extra);
+ release_lt(locktrees[i]);
+ }
+ tokutime_t t1 = toku_time_now();
+
+ toku_mutex_lock(&m_escalation_mutex);
+ m_escalation_count++;
+ m_escalation_time += (t1 - t0);
+ m_escalation_latest_result = m_current_lock_memory;
+ toku_mutex_unlock(&m_escalation_mutex);
+}
+
+struct escalate_args {
+ locktree_manager *mgr;
+ locktree **locktrees;
+ int num_locktrees;
+};
+
+void locktree_manager::locktree_escalator::create(void) {
+ ZERO_STRUCT(m_escalator_mutex);
+ toku_mutex_init(manager_escalator_mutex_key, &m_escalator_mutex, nullptr);
+ toku_cond_init(manager_m_escalator_done_key, &m_escalator_done, nullptr);
+ m_escalator_running = false;
+}
+
+void locktree_manager::locktree_escalator::destroy(void) {
+ toku_cond_destroy(&m_escalator_done);
+ toku_mutex_destroy(&m_escalator_mutex);
+}
+
+void locktree_manager::locktree_escalator::run(
+ locktree_manager *mgr, void (*escalate_locktrees_fun)(void *extra),
+ void *extra) {
+ uint64_t t0 = toku_current_time_microsec();
+ toku_mutex_lock(&m_escalator_mutex);
+ if (!m_escalator_running) {
+ // run escalation on this thread
+ m_escalator_running = true;
+ toku_mutex_unlock(&m_escalator_mutex);
+ escalate_locktrees_fun(extra);
+ toku_mutex_lock(&m_escalator_mutex);
+ m_escalator_running = false;
+ toku_cond_broadcast(&m_escalator_done);
+ } else {
+ toku_cond_wait(&m_escalator_done, &m_escalator_mutex);
+ }
+ toku_mutex_unlock(&m_escalator_mutex);
+ uint64_t t1 = toku_current_time_microsec();
+ mgr->add_escalator_wait_time(t1 - t0);
+}
+
+void locktree_manager::get_status(LTM_STATUS statp) {
+ ltm_status.init();
+ LTM_STATUS_VAL(LTM_SIZE_CURRENT) = m_current_lock_memory;
+ LTM_STATUS_VAL(LTM_SIZE_LIMIT) = m_max_lock_memory;
+ LTM_STATUS_VAL(LTM_ESCALATION_COUNT) = m_escalation_count;
+ LTM_STATUS_VAL(LTM_ESCALATION_TIME) = m_escalation_time;
+ LTM_STATUS_VAL(LTM_ESCALATION_LATEST_RESULT) = m_escalation_latest_result;
+ LTM_STATUS_VAL(LTM_WAIT_ESCALATION_COUNT) = m_wait_escalation_count;
+ LTM_STATUS_VAL(LTM_WAIT_ESCALATION_TIME) = m_wait_escalation_time;
+ LTM_STATUS_VAL(LTM_LONG_WAIT_ESCALATION_COUNT) = m_long_wait_escalation_count;
+ LTM_STATUS_VAL(LTM_LONG_WAIT_ESCALATION_TIME) = m_long_wait_escalation_time;
+
+ uint64_t lock_requests_pending = 0;
+ uint64_t sto_num_eligible = 0;
+ uint64_t sto_end_early_count = 0;
+ tokutime_t sto_end_early_time = 0;
+ uint32_t num_locktrees = 0;
+ struct lt_counters lt_counters;
+ ZERO_STRUCT(lt_counters); // PORT: instead of ={}.
+
+ if (toku_mutex_trylock(&m_mutex) == 0) {
+ lt_counters = m_lt_counters;
+ num_locktrees = m_locktree_map.size();
+ for (uint32_t i = 0; i < num_locktrees; i++) {
+ locktree *lt;
+ int r = m_locktree_map.fetch(i, &lt);
+ invariant_zero(r);
+ if (r == EINVAL) // Shouldn't happen, avoid compiler warning
+ continue;
+ if (toku_external_mutex_trylock(&lt->m_lock_request_info.mutex) == 0) {
+ lock_requests_pending +=
+ lt->m_lock_request_info.pending_lock_requests.size();
+ lt_counters.add(lt->get_lock_request_info()->counters);
+ toku_external_mutex_unlock(&lt->m_lock_request_info.mutex);
+ }
+ sto_num_eligible += lt->sto_txnid_is_valid_unsafe() ? 1 : 0;
+ sto_end_early_count += lt->m_sto_end_early_count;
+ sto_end_early_time += lt->m_sto_end_early_time;
+ }
+ mutex_unlock();
+ }
+
+ LTM_STATUS_VAL(LTM_NUM_LOCKTREES) = num_locktrees;
+ LTM_STATUS_VAL(LTM_LOCK_REQUESTS_PENDING) = lock_requests_pending;
+ LTM_STATUS_VAL(LTM_STO_NUM_ELIGIBLE) = sto_num_eligible;
+ LTM_STATUS_VAL(LTM_STO_END_EARLY_COUNT) = sto_end_early_count;
+ LTM_STATUS_VAL(LTM_STO_END_EARLY_TIME) = sto_end_early_time;
+ LTM_STATUS_VAL(LTM_WAIT_COUNT) = lt_counters.wait_count;
+ LTM_STATUS_VAL(LTM_WAIT_TIME) = lt_counters.wait_time;
+ LTM_STATUS_VAL(LTM_LONG_WAIT_COUNT) = lt_counters.long_wait_count;
+ LTM_STATUS_VAL(LTM_LONG_WAIT_TIME) = lt_counters.long_wait_time;
+ LTM_STATUS_VAL(LTM_TIMEOUT_COUNT) = lt_counters.timeout_count;
+ *statp = ltm_status;
+}
+
+void locktree_manager::kill_waiter(void *extra) {
+ mutex_lock();
+ int r = 0;
+ uint32_t num_locktrees = m_locktree_map.size();
+ for (uint32_t i = 0; i < num_locktrees; i++) {
+ locktree *lt;
+ r = m_locktree_map.fetch(i, &lt);
+ invariant_zero(r);
+ if (r) continue; // Get rid of "may be used uninitialized" warning
+ lock_request::kill_waiter(lt, extra);
+ }
+ mutex_unlock();
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.cc
new file mode 100644
index 000000000..1e1d23ef8
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.cc
@@ -0,0 +1,265 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "range_buffer.h"
+
+#include <string.h>
+
+#include "../portability/memory.h"
+#include "../util/dbt.h"
+
+namespace toku {
+
+bool range_buffer::record_header::left_is_infinite(void) const {
+ return left_neg_inf || left_pos_inf;
+}
+
+bool range_buffer::record_header::right_is_infinite(void) const {
+ return right_neg_inf || right_pos_inf;
+}
+
+void range_buffer::record_header::init(const DBT *left_key,
+ const DBT *right_key,
+ bool is_exclusive) {
+ is_exclusive_lock = is_exclusive;
+ left_neg_inf = left_key == toku_dbt_negative_infinity();
+ left_pos_inf = left_key == toku_dbt_positive_infinity();
+ left_key_size = toku_dbt_is_infinite(left_key) ? 0 : left_key->size;
+ if (right_key) {
+ right_neg_inf = right_key == toku_dbt_negative_infinity();
+ right_pos_inf = right_key == toku_dbt_positive_infinity();
+ right_key_size = toku_dbt_is_infinite(right_key) ? 0 : right_key->size;
+ } else {
+ right_neg_inf = left_neg_inf;
+ right_pos_inf = left_pos_inf;
+ right_key_size = 0;
+ }
+}
+
+const DBT *range_buffer::iterator::record::get_left_key(void) const {
+ if (_header.left_neg_inf) {
+ return toku_dbt_negative_infinity();
+ } else if (_header.left_pos_inf) {
+ return toku_dbt_positive_infinity();
+ } else {
+ return &_left_key;
+ }
+}
+
+const DBT *range_buffer::iterator::record::get_right_key(void) const {
+ if (_header.right_neg_inf) {
+ return toku_dbt_negative_infinity();
+ } else if (_header.right_pos_inf) {
+ return toku_dbt_positive_infinity();
+ } else {
+ return &_right_key;
+ }
+}
+
+size_t range_buffer::iterator::record::size(void) const {
+ return sizeof(record_header) + _header.left_key_size + _header.right_key_size;
+}
+
+void range_buffer::iterator::record::deserialize(const char *buf) {
+ size_t current = 0;
+
+ // deserialize the header
+ memcpy(&_header, buf, sizeof(record_header));
+ current += sizeof(record_header);
+
+ // deserialize the left key if necessary
+ if (!_header.left_is_infinite()) {
+ // point the left DBT's buffer into ours
+ toku_fill_dbt(&_left_key, buf + current, _header.left_key_size);
+ current += _header.left_key_size;
+ }
+
+ // deserialize the right key if necessary
+ if (!_header.right_is_infinite()) {
+ if (_header.right_key_size == 0) {
+ toku_copyref_dbt(&_right_key, _left_key);
+ } else {
+ toku_fill_dbt(&_right_key, buf + current, _header.right_key_size);
+ }
+ }
+}
+
+toku::range_buffer::iterator::iterator()
+ : _ma_chunk_iterator(nullptr),
+ _current_chunk_base(nullptr),
+ _current_chunk_offset(0),
+ _current_chunk_max(0),
+ _current_rec_size(0) {}
+
+toku::range_buffer::iterator::iterator(const range_buffer *buffer)
+ : _ma_chunk_iterator(&buffer->_arena),
+ _current_chunk_base(nullptr),
+ _current_chunk_offset(0),
+ _current_chunk_max(0),
+ _current_rec_size(0) {
+ reset_current_chunk();
+}
+
+void range_buffer::iterator::reset_current_chunk() {
+ _current_chunk_base = _ma_chunk_iterator.current(&_current_chunk_max);
+ _current_chunk_offset = 0;
+}
+
+bool range_buffer::iterator::current(record *rec) {
+ if (_current_chunk_offset < _current_chunk_max) {
+ const char *buf = reinterpret_cast<const char *>(_current_chunk_base);
+ rec->deserialize(buf + _current_chunk_offset);
+ _current_rec_size = rec->size();
+ return true;
+ } else {
+ return false;
+ }
+}
+
+// move the iterator to the next record in the buffer
+void range_buffer::iterator::next(void) {
+ invariant(_current_chunk_offset < _current_chunk_max);
+ invariant(_current_rec_size > 0);
+
+ // the next record is _current_rec_size bytes forward
+ _current_chunk_offset += _current_rec_size;
+ // now, we don't know how big the current is, set it to 0.
+ _current_rec_size = 0;
+
+ if (_current_chunk_offset >= _current_chunk_max) {
+ // current chunk is exhausted, try moving to the next one
+ if (_ma_chunk_iterator.more()) {
+ _ma_chunk_iterator.next();
+ reset_current_chunk();
+ }
+ }
+}
+
+void range_buffer::create(void) {
+ // allocate buffer space lazily instead of on creation. this way,
+ // no malloc/free is done if the transaction ends up taking no locks.
+ _arena.create(0);
+ _num_ranges = 0;
+}
+
+void range_buffer::append(const DBT *left_key, const DBT *right_key,
+ bool is_write_request) {
+ // if the keys are equal, then only one copy is stored.
+ if (toku_dbt_equals(left_key, right_key)) {
+ invariant(left_key->size <= MAX_KEY_SIZE);
+ append_point(left_key, is_write_request);
+ } else {
+ invariant(left_key->size <= MAX_KEY_SIZE);
+ invariant(right_key->size <= MAX_KEY_SIZE);
+ append_range(left_key, right_key, is_write_request);
+ }
+ _num_ranges++;
+}
+
+bool range_buffer::is_empty(void) const { return total_memory_size() == 0; }
+
+uint64_t range_buffer::total_memory_size(void) const {
+ return _arena.total_size_in_use();
+}
+
+int range_buffer::get_num_ranges(void) const { return _num_ranges; }
+
+void range_buffer::destroy(void) { _arena.destroy(); }
+
+void range_buffer::append_range(const DBT *left_key, const DBT *right_key,
+ bool is_exclusive) {
+ size_t record_length =
+ sizeof(record_header) + left_key->size + right_key->size;
+ char *buf = reinterpret_cast<char *>(_arena.malloc_from_arena(record_length));
+
+ record_header h;
+ h.init(left_key, right_key, is_exclusive);
+
+ // serialize the header
+ memcpy(buf, &h, sizeof(record_header));
+ buf += sizeof(record_header);
+
+ // serialize the left key if necessary
+ if (!h.left_is_infinite()) {
+ memcpy(buf, left_key->data, left_key->size);
+ buf += left_key->size;
+ }
+
+ // serialize the right key if necessary
+ if (!h.right_is_infinite()) {
+ memcpy(buf, right_key->data, right_key->size);
+ }
+}
+
+void range_buffer::append_point(const DBT *key, bool is_exclusive) {
+ size_t record_length = sizeof(record_header) + key->size;
+ char *buf = reinterpret_cast<char *>(_arena.malloc_from_arena(record_length));
+
+ record_header h;
+ h.init(key, nullptr, is_exclusive);
+
+ // serialize the header
+ memcpy(buf, &h, sizeof(record_header));
+ buf += sizeof(record_header);
+
+ // serialize the key if necessary
+ if (!h.left_is_infinite()) {
+ memcpy(buf, key->data, key->size);
+ }
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.h
new file mode 100644
index 000000000..76e28d747
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/range_buffer.h
@@ -0,0 +1,178 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <inttypes.h>
+#include <stdint.h>
+
+#include "../util/dbt.h"
+#include "../util/memarena.h"
+
+namespace toku {
+
+// a key range buffer represents a set of key ranges that can
+// be stored, iterated over, and then destroyed all at once.
+class range_buffer {
+ private:
+ // the key range buffer is a bunch of records in a row.
+ // each record has the following header, followed by the
+ // left key and right key data payload, if applicable.
+ // we limit keys to be 2^16, since we store lengths as 2 bytes.
+ static const size_t MAX_KEY_SIZE = 1 << 16;
+
+ struct record_header {
+ bool left_neg_inf;
+ bool left_pos_inf;
+ bool right_pos_inf;
+ bool right_neg_inf;
+ uint16_t left_key_size;
+ uint16_t right_key_size;
+ bool is_exclusive_lock;
+
+ bool left_is_infinite(void) const;
+
+ bool right_is_infinite(void) const;
+
+ void init(const DBT *left_key, const DBT *right_key, bool is_exclusive);
+ };
+ // PORT static_assert(sizeof(record_header) == 8, "record header format is
+ // off");
+
+ public:
+ // the iterator abstracts reading over a buffer of variable length
+ // records one by one until there are no more left.
+ class iterator {
+ public:
+ iterator();
+ iterator(const range_buffer *buffer);
+
+ // a record represents the user-view of a serialized key range.
+ // it handles positive and negative infinity and the optimized
+ // point range case, where left and right points share memory.
+ class record {
+ public:
+ // get a read-only pointer to the left key of this record's range
+ const DBT *get_left_key(void) const;
+
+ // get a read-only pointer to the right key of this record's range
+ const DBT *get_right_key(void) const;
+
+ // how big is this record? this tells us where the next record is
+ size_t size(void) const;
+
+ bool get_exclusive_flag() const { return _header.is_exclusive_lock; }
+
+ // populate a record header and point our DBT's
+ // buffers into ours if they are not infinite.
+ void deserialize(const char *buf);
+
+ private:
+ record_header _header;
+ DBT _left_key;
+ DBT _right_key;
+ };
+
+ // populate the given record object with the current
+ // the memory referred to by record is valid for only
+ // as long as the record exists.
+ bool current(record *rec);
+
+ // move the iterator to the next record in the buffer
+ void next(void);
+
+ private:
+ void reset_current_chunk();
+
+ // the key range buffer we are iterating over, the current
+ // offset in that buffer, and the size of the current record.
+ memarena::chunk_iterator _ma_chunk_iterator;
+ const void *_current_chunk_base;
+ size_t _current_chunk_offset;
+ size_t _current_chunk_max;
+ size_t _current_rec_size;
+ };
+
+ // allocate buffer space lazily instead of on creation. this way,
+ // no malloc/free is done if the transaction ends up taking no locks.
+ void create(void);
+
+ // append a left/right key range to the buffer.
+ // if the keys are equal, then only one copy is stored.
+ void append(const DBT *left_key, const DBT *right_key,
+ bool is_write_request = false);
+
+ // is this range buffer empty?
+ bool is_empty(void) const;
+
+ // how much memory is being used by this range buffer?
+ uint64_t total_memory_size(void) const;
+
+ // how many ranges are stored in this range buffer?
+ int get_num_ranges(void) const;
+
+ void destroy(void);
+
+ private:
+ memarena _arena;
+ int _num_ranges;
+
+ void append_range(const DBT *left_key, const DBT *right_key,
+ bool is_write_request);
+
+ // append a point to the buffer. this is the space/time saving
+ // optimization for key ranges where left == right.
+ void append_point(const DBT *key, bool is_write_request);
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.cc
new file mode 100644
index 000000000..8997f634b
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.cc
@@ -0,0 +1,520 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "treenode.h"
+
+#include "../portability/toku_race_tools.h"
+
+namespace toku {
+
+// TODO: source location info might have to be pulled up one caller
+// to be useful
+void treenode::mutex_lock(void) { toku_mutex_lock(&m_mutex); }
+
+void treenode::mutex_unlock(void) { toku_mutex_unlock(&m_mutex); }
+
+void treenode::init(const comparator *cmp) {
+ m_txnid = TXNID_NONE;
+ m_is_root = false;
+ m_is_empty = true;
+ m_cmp = cmp;
+
+ m_is_shared = false;
+ m_owners = nullptr;
+
+ // use an adaptive mutex at each node since we expect the time the
+ // lock is held to be relatively short compared to a context switch.
+ // indeed, this improves performance at high thread counts considerably.
+ memset(&m_mutex, 0, sizeof(toku_mutex_t));
+ toku_pthread_mutexattr_t attr;
+ toku_mutexattr_init(&attr);
+ toku_mutexattr_settype(&attr, TOKU_MUTEX_ADAPTIVE);
+ toku_mutex_init(treenode_mutex_key, &m_mutex, &attr);
+ toku_mutexattr_destroy(&attr);
+ m_left_child.set(nullptr);
+ m_right_child.set(nullptr);
+}
+
+void treenode::create_root(const comparator *cmp) {
+ init(cmp);
+ m_is_root = true;
+}
+
+void treenode::destroy_root(void) {
+ invariant(is_root());
+ invariant(is_empty());
+ toku_mutex_destroy(&m_mutex);
+ m_cmp = nullptr;
+}
+
+void treenode::set_range_and_txnid(const keyrange &range, TXNID txnid,
+ bool is_shared) {
+ // allocates a new copy of the range for this node
+ m_range.create_copy(range);
+ m_txnid = txnid;
+ m_is_shared = is_shared;
+ m_is_empty = false;
+}
+
+bool treenode::is_root(void) { return m_is_root; }
+
+bool treenode::is_empty(void) { return m_is_empty; }
+
+bool treenode::range_overlaps(const keyrange &range) {
+ return m_range.overlaps(*m_cmp, range);
+}
+
+treenode *treenode::alloc(const comparator *cmp, const keyrange &range,
+ TXNID txnid, bool is_shared) {
+ treenode *XCALLOC(node);
+ node->init(cmp);
+ node->set_range_and_txnid(range, txnid, is_shared);
+ return node;
+}
+
+void treenode::swap_in_place(treenode *node1, treenode *node2) {
+ keyrange tmp_range = node1->m_range;
+ TXNID tmp_txnid = node1->m_txnid;
+ node1->m_range = node2->m_range;
+ node1->m_txnid = node2->m_txnid;
+ node2->m_range = tmp_range;
+ node2->m_txnid = tmp_txnid;
+
+ bool tmp_is_shared = node1->m_is_shared;
+ node1->m_is_shared = node2->m_is_shared;
+ node2->m_is_shared = tmp_is_shared;
+
+ auto tmp_m_owners = node1->m_owners;
+ node1->m_owners = node2->m_owners;
+ node2->m_owners = tmp_m_owners;
+}
+
+bool treenode::add_shared_owner(TXNID txnid) {
+ assert(m_is_shared);
+ if (txnid == m_txnid)
+ return false; // acquiring a lock on the same range by the same trx
+
+ if (m_txnid != TXNID_SHARED) {
+ m_owners = new TxnidVector;
+ m_owners->insert(m_txnid);
+ m_txnid = TXNID_SHARED;
+ }
+ m_owners->insert(txnid);
+ return true;
+}
+
+void treenode::free(treenode *node) {
+ // destroy the range, freeing any copied keys
+ node->m_range.destroy();
+
+ if (node->m_owners) {
+ delete node->m_owners;
+ node->m_owners = nullptr; // need this?
+ }
+
+ // the root is simply marked as empty.
+ if (node->is_root()) {
+ // PORT toku_mutex_assert_locked(&node->m_mutex);
+ node->m_is_empty = true;
+ } else {
+ // PORT toku_mutex_assert_unlocked(&node->m_mutex);
+ toku_mutex_destroy(&node->m_mutex);
+ toku_free(node);
+ }
+}
+
+uint32_t treenode::get_depth_estimate(void) const {
+ const uint32_t left_est = m_left_child.depth_est;
+ const uint32_t right_est = m_right_child.depth_est;
+ return (left_est > right_est ? left_est : right_est) + 1;
+}
+
+treenode *treenode::find_node_with_overlapping_child(
+ const keyrange &range, const keyrange::comparison *cmp_hint) {
+ // determine which child to look at based on a comparison. if we were
+ // given a comparison hint, use that. otherwise, compare them now.
+ keyrange::comparison c =
+ cmp_hint ? *cmp_hint : range.compare(*m_cmp, m_range);
+
+ treenode *child;
+ if (c == keyrange::comparison::LESS_THAN) {
+ child = lock_and_rebalance_left();
+ } else {
+ // The caller (locked_keyrange::acquire) handles the case where
+ // the root of the locked_keyrange is the node that overlaps.
+ // range is guaranteed not to overlap this node.
+ invariant(c == keyrange::comparison::GREATER_THAN);
+ child = lock_and_rebalance_right();
+ }
+
+ // if the search would lead us to an empty subtree (child == nullptr),
+ // or the child overlaps, then we know this node is the parent we want.
+ // otherwise we need to recur into that child.
+ if (child == nullptr) {
+ return this;
+ } else {
+ c = range.compare(*m_cmp, child->m_range);
+ if (c == keyrange::comparison::EQUALS ||
+ c == keyrange::comparison::OVERLAPS) {
+ child->mutex_unlock();
+ return this;
+ } else {
+ // unlock this node before recurring into the locked child,
+ // passing in a comparison hint since we just comapred range
+ // to the child's range.
+ mutex_unlock();
+ return child->find_node_with_overlapping_child(range, &c);
+ }
+ }
+}
+
+bool treenode::insert(const keyrange &range, TXNID txnid, bool is_shared) {
+ int rc = true;
+ // choose a child to check. if that child is null, then insert the new node
+ // there. otherwise recur down that child's subtree
+ keyrange::comparison c = range.compare(*m_cmp, m_range);
+ if (c == keyrange::comparison::LESS_THAN) {
+ treenode *left_child = lock_and_rebalance_left();
+ if (left_child == nullptr) {
+ left_child = treenode::alloc(m_cmp, range, txnid, is_shared);
+ m_left_child.set(left_child);
+ } else {
+ left_child->insert(range, txnid, is_shared);
+ left_child->mutex_unlock();
+ }
+ } else if (c == keyrange::comparison::GREATER_THAN) {
+ // invariant(c == keyrange::comparison::GREATER_THAN);
+ treenode *right_child = lock_and_rebalance_right();
+ if (right_child == nullptr) {
+ right_child = treenode::alloc(m_cmp, range, txnid, is_shared);
+ m_right_child.set(right_child);
+ } else {
+ right_child->insert(range, txnid, is_shared);
+ right_child->mutex_unlock();
+ }
+ } else if (c == keyrange::comparison::EQUALS) {
+ invariant(is_shared);
+ invariant(m_is_shared);
+ rc = add_shared_owner(txnid);
+ } else {
+ invariant(0);
+ }
+ return rc;
+}
+
+treenode *treenode::find_child_at_extreme(int direction, treenode **parent) {
+ treenode *child =
+ direction > 0 ? m_right_child.get_locked() : m_left_child.get_locked();
+
+ if (child) {
+ *parent = this;
+ treenode *child_extreme = child->find_child_at_extreme(direction, parent);
+ child->mutex_unlock();
+ return child_extreme;
+ } else {
+ return this;
+ }
+}
+
+treenode *treenode::find_leftmost_child(treenode **parent) {
+ return find_child_at_extreme(-1, parent);
+}
+
+treenode *treenode::find_rightmost_child(treenode **parent) {
+ return find_child_at_extreme(1, parent);
+}
+
+treenode *treenode::remove_root_of_subtree() {
+ // if this node has no children, just free it and return null
+ if (m_left_child.ptr == nullptr && m_right_child.ptr == nullptr) {
+ // treenode::free requires that non-root nodes are unlocked
+ if (!is_root()) {
+ mutex_unlock();
+ }
+ treenode::free(this);
+ return nullptr;
+ }
+
+ // we have a child, so get either the in-order successor or
+ // predecessor of this node to be our replacement.
+ // replacement_parent is updated by the find functions as
+ // they recur down the tree, so initialize it to this.
+ treenode *child, *replacement;
+ treenode *replacement_parent = this;
+ if (m_left_child.ptr != nullptr) {
+ child = m_left_child.get_locked();
+ replacement = child->find_rightmost_child(&replacement_parent);
+ invariant(replacement == child || replacement_parent != this);
+
+ // detach the replacement from its parent
+ if (replacement_parent == this) {
+ m_left_child = replacement->m_left_child;
+ } else {
+ replacement_parent->m_right_child = replacement->m_left_child;
+ }
+ } else {
+ child = m_right_child.get_locked();
+ replacement = child->find_leftmost_child(&replacement_parent);
+ invariant(replacement == child || replacement_parent != this);
+
+ // detach the replacement from its parent
+ if (replacement_parent == this) {
+ m_right_child = replacement->m_right_child;
+ } else {
+ replacement_parent->m_left_child = replacement->m_right_child;
+ }
+ }
+ child->mutex_unlock();
+
+ // swap in place with the detached replacement, then destroy it
+ treenode::swap_in_place(replacement, this);
+ treenode::free(replacement);
+
+ return this;
+}
+
+void treenode::recursive_remove(void) {
+ treenode *left = m_left_child.ptr;
+ if (left) {
+ left->recursive_remove();
+ }
+ m_left_child.set(nullptr);
+
+ treenode *right = m_right_child.ptr;
+ if (right) {
+ right->recursive_remove();
+ }
+ m_right_child.set(nullptr);
+
+ // we do not take locks on the way down, so we know non-root nodes
+ // are unlocked here and the caller is required to pass a locked
+ // root, so this free is correct.
+ treenode::free(this);
+}
+
+void treenode::remove_shared_owner(TXNID txnid) {
+ assert(m_owners->size() > 1);
+ m_owners->erase(txnid);
+ assert(m_owners->size() > 0);
+ /* if there is just one owner left, move it to m_txnid */
+ if (m_owners->size() == 1) {
+ m_txnid = *m_owners->begin();
+ delete m_owners;
+ m_owners = nullptr;
+ }
+}
+
+treenode *treenode::remove(const keyrange &range, TXNID txnid) {
+ treenode *child;
+ // if the range is equal to this node's range, then just remove
+ // the root of this subtree. otherwise search down the tree
+ // in either the left or right children.
+ keyrange::comparison c = range.compare(*m_cmp, m_range);
+ switch (c) {
+ case keyrange::comparison::EQUALS: {
+ // if we are the only owners, remove. Otherwise, just remove
+ // us from the owners list.
+ if (txnid != TXNID_ANY && has_multiple_owners()) {
+ remove_shared_owner(txnid);
+ return this;
+ } else {
+ return remove_root_of_subtree();
+ }
+ }
+ case keyrange::comparison::LESS_THAN:
+ child = m_left_child.get_locked();
+ invariant_notnull(child);
+ child = child->remove(range, txnid);
+
+ // unlock the child if there still is one.
+ // regardless, set the right child pointer
+ if (child) {
+ child->mutex_unlock();
+ }
+ m_left_child.set(child);
+ break;
+ case keyrange::comparison::GREATER_THAN:
+ child = m_right_child.get_locked();
+ invariant_notnull(child);
+ child = child->remove(range, txnid);
+
+ // unlock the child if there still is one.
+ // regardless, set the right child pointer
+ if (child) {
+ child->mutex_unlock();
+ }
+ m_right_child.set(child);
+ break;
+ case keyrange::comparison::OVERLAPS:
+ // shouldn't be overlapping, since the tree is
+ // non-overlapping and this range must exist
+ abort();
+ }
+
+ return this;
+}
+
+bool treenode::left_imbalanced(int threshold) const {
+ uint32_t left_depth = m_left_child.depth_est;
+ uint32_t right_depth = m_right_child.depth_est;
+ return m_left_child.ptr != nullptr && left_depth > threshold + right_depth;
+}
+
+bool treenode::right_imbalanced(int threshold) const {
+ uint32_t left_depth = m_left_child.depth_est;
+ uint32_t right_depth = m_right_child.depth_est;
+ return m_right_child.ptr != nullptr && right_depth > threshold + left_depth;
+}
+
+// effect: rebalances the subtree rooted at this node
+// using AVL style O(1) rotations. unlocks this
+// node if it is not the new root of the subtree.
+// requires: node is locked by this thread, children are not
+// returns: locked root node of the rebalanced tree
+treenode *treenode::maybe_rebalance(void) {
+ // if we end up not rotating at all, the new root is this
+ treenode *new_root = this;
+ treenode *child = nullptr;
+
+ if (left_imbalanced(IMBALANCE_THRESHOLD)) {
+ child = m_left_child.get_locked();
+ if (child->right_imbalanced(0)) {
+ treenode *grandchild = child->m_right_child.get_locked();
+
+ child->m_right_child = grandchild->m_left_child;
+ grandchild->m_left_child.set(child);
+
+ m_left_child = grandchild->m_right_child;
+ grandchild->m_right_child.set(this);
+
+ new_root = grandchild;
+ } else {
+ m_left_child = child->m_right_child;
+ child->m_right_child.set(this);
+ new_root = child;
+ }
+ } else if (right_imbalanced(IMBALANCE_THRESHOLD)) {
+ child = m_right_child.get_locked();
+ if (child->left_imbalanced(0)) {
+ treenode *grandchild = child->m_left_child.get_locked();
+
+ child->m_left_child = grandchild->m_right_child;
+ grandchild->m_right_child.set(child);
+
+ m_right_child = grandchild->m_left_child;
+ grandchild->m_left_child.set(this);
+
+ new_root = grandchild;
+ } else {
+ m_right_child = child->m_left_child;
+ child->m_left_child.set(this);
+ new_root = child;
+ }
+ }
+
+ // up to three nodes may be locked.
+ // - this
+ // - child
+ // - grandchild (but if it is locked, its the new root)
+ //
+ // one of them is the new root. we unlock everything except the new root.
+ if (child && child != new_root) {
+ TOKU_VALGRIND_RESET_MUTEX_ORDERING_INFO(&child->m_mutex);
+ child->mutex_unlock();
+ }
+ if (this != new_root) {
+ TOKU_VALGRIND_RESET_MUTEX_ORDERING_INFO(&m_mutex);
+ mutex_unlock();
+ }
+ TOKU_VALGRIND_RESET_MUTEX_ORDERING_INFO(&new_root->m_mutex);
+ return new_root;
+}
+
+treenode *treenode::lock_and_rebalance_left(void) {
+ treenode *child = m_left_child.get_locked();
+ if (child) {
+ treenode *new_root = child->maybe_rebalance();
+ m_left_child.set(new_root);
+ child = new_root;
+ }
+ return child;
+}
+
+treenode *treenode::lock_and_rebalance_right(void) {
+ treenode *child = m_right_child.get_locked();
+ if (child) {
+ treenode *new_root = child->maybe_rebalance();
+ m_right_child.set(new_root);
+ child = new_root;
+ }
+ return child;
+}
+
+void treenode::child_ptr::set(treenode *node) {
+ ptr = node;
+ depth_est = ptr ? ptr->get_depth_estimate() : 0;
+}
+
+treenode *treenode::child_ptr::get_locked(void) {
+ if (ptr) {
+ ptr->mutex_lock();
+ depth_est = ptr->get_depth_estimate();
+ }
+ return ptr;
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.h
new file mode 100644
index 000000000..ec25a8c58
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/treenode.h
@@ -0,0 +1,302 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=2:softtabstop=2:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <string.h>
+
+#include "../ft/comparator.h"
+#include "../portability/memory.h"
+#include "../portability/toku_pthread.h"
+// PORT: we need LTM_STATUS
+#include "../ft/ft-status.h"
+#include "../portability/txn_subst.h"
+#include "keyrange.h"
+
+namespace toku {
+
+// a node in a tree with its own mutex
+// - range is the "key" of this node
+// - txnid is the single txnid associated with this node
+// - left and right children may be null
+//
+// to build a tree on top of this abstraction, the user:
+// - provides memory for a root node, initializes it via create_root()
+// - performs tree operations on the root node. memory management
+// below the root node is handled by the abstraction, not the user.
+// this pattern:
+// - guaruntees a root node always exists.
+// - does not allow for rebalances on the root node
+
+class treenode {
+ public:
+ // every treenode function has some common requirements:
+ // - node is locked and children are never locked
+ // - node may be unlocked if no other thread has visibility
+
+ // effect: create the root node
+ void create_root(const comparator *cmp);
+
+ // effect: destroys the root node
+ void destroy_root(void);
+
+ // effect: sets the txnid and copies the given range for this node
+ void set_range_and_txnid(const keyrange &range, TXNID txnid, bool is_shared);
+
+ // returns: true iff this node is marked as empty
+ bool is_empty(void);
+
+ // returns: true if this is the root node, denoted by a null parent
+ bool is_root(void);
+
+ // returns: true if the given range overlaps with this node's range
+ bool range_overlaps(const keyrange &range);
+
+ // effect: locks the node
+ void mutex_lock(void);
+
+ // effect: unlocks the node
+ void mutex_unlock(void);
+
+ // return: node whose child overlaps, or a child that is empty
+ // and would contain range if it existed
+ // given: if cmp_hint is non-null, then it is a precomputed
+ // comparison of this node's range to the given range.
+ treenode *find_node_with_overlapping_child(
+ const keyrange &range, const keyrange::comparison *cmp_hint);
+
+ // effect: performs an in-order traversal of the ranges that overlap the
+ // given range, calling function->fn() on each node that does
+ // requires: function signature is: bool fn(const keyrange &range, TXNID
+ // txnid) requires: fn returns true to keep iterating, false to stop iterating
+ // requires: fn does not attempt to use any ranges read out by value
+ // after removing a node with an overlapping range from the tree.
+ template <class F>
+ void traverse_overlaps(const keyrange &range, F *function) {
+ keyrange::comparison c = range.compare(*m_cmp, m_range);
+ if (c == keyrange::comparison::EQUALS) {
+ // Doesn't matter if fn wants to keep going, there
+ // is nothing left, so return.
+ function->fn(m_range, m_txnid, m_is_shared, m_owners);
+ return;
+ }
+
+ treenode *left = m_left_child.get_locked();
+ if (left) {
+ if (c != keyrange::comparison::GREATER_THAN) {
+ // Target range is less than this node, or it overlaps this
+ // node. There may be something on the left.
+ left->traverse_overlaps(range, function);
+ }
+ left->mutex_unlock();
+ }
+
+ if (c == keyrange::comparison::OVERLAPS) {
+ bool keep_going = function->fn(m_range, m_txnid, m_is_shared, m_owners);
+ if (!keep_going) {
+ return;
+ }
+ }
+
+ treenode *right = m_right_child.get_locked();
+ if (right) {
+ if (c != keyrange::comparison::LESS_THAN) {
+ // Target range is greater than this node, or it overlaps this
+ // node. There may be something on the right.
+ right->traverse_overlaps(range, function);
+ }
+ right->mutex_unlock();
+ }
+ }
+
+ // effect: inserts the given range and txnid into a subtree, recursively
+ // requires: range does not overlap with any node below the subtree
+ bool insert(const keyrange &range, TXNID txnid, bool is_shared);
+
+ // effect: removes the given range from the subtree
+ // requires: range exists in the subtree
+ // returns: the root of the resulting subtree
+ treenode *remove(const keyrange &range, TXNID txnid);
+
+ // effect: removes this node and all of its children, recursively
+ // requires: every node at and below this node is unlocked
+ void recursive_remove(void);
+
+ private:
+ // the child_ptr is a light abstraction for the locking of
+ // a child and the maintenence of its depth estimate.
+
+ struct child_ptr {
+ // set the child pointer
+ void set(treenode *node);
+
+ // get and lock this child if it exists
+ treenode *get_locked(void);
+
+ treenode *ptr;
+ uint32_t depth_est;
+ };
+
+ // the balance factor at which a node is considered imbalanced
+ static const int32_t IMBALANCE_THRESHOLD = 2;
+
+ // node-level mutex
+ toku_mutex_t m_mutex;
+
+ // the range and txnid for this node. the range contains a copy
+ // of the keys originally inserted into the tree. nodes may
+ // swap ranges. but at the end of the day, when a node is
+ // destroyed, it frees the memory associated with whatever range
+ // it has at the time of destruction.
+ keyrange m_range;
+
+ void remove_shared_owner(TXNID txnid);
+
+ bool has_multiple_owners() { return (m_txnid == TXNID_SHARED); }
+
+ private:
+ // Owner transaction id.
+ // A value of TXNID_SHARED means this node has multiple owners
+ TXNID m_txnid;
+
+ // If true, this lock is a non-exclusive lock, and it can have either
+ // one or several owners.
+ bool m_is_shared;
+
+ // List of the owners, or nullptr if there's just one owner.
+ TxnidVector *m_owners;
+
+ // two child pointers
+ child_ptr m_left_child;
+ child_ptr m_right_child;
+
+ // comparator for ranges
+ // psergey-todo: Is there any sense to store the comparator in each tree
+ // node?
+ const comparator *m_cmp;
+
+ // marked for the root node. the root node is never free()'d
+ // when removed, but instead marked as empty.
+ bool m_is_root;
+
+ // marked for an empty node. only valid for the root.
+ bool m_is_empty;
+
+ // effect: initializes an empty node with the given comparator
+ void init(const comparator *cmp);
+
+ // requires: this is a shared node (m_is_shared==true)
+ // effect: another transaction is added as an owner.
+ // returns: true <=> added another owner
+ // false <=> this transaction is already an owner
+ bool add_shared_owner(TXNID txnid);
+
+ // requires: *parent is initialized to something meaningful.
+ // requires: subtree is non-empty
+ // returns: the leftmost child of the given subtree
+ // returns: a pointer to the parent of said child in *parent, only
+ // if this function recurred, otherwise it is untouched.
+ treenode *find_leftmost_child(treenode **parent);
+
+ // requires: *parent is initialized to something meaningful.
+ // requires: subtree is non-empty
+ // returns: the rightmost child of the given subtree
+ // returns: a pointer to the parent of said child in *parent, only
+ // if this function recurred, otherwise it is untouched.
+ treenode *find_rightmost_child(treenode **parent);
+
+ // effect: remove the root of this subtree, destroying the old root
+ // returns: the new root of the subtree
+ treenode *remove_root_of_subtree(void);
+
+ // requires: subtree is non-empty, direction is not 0
+ // returns: the child of the subtree at either the left or rightmost extreme
+ treenode *find_child_at_extreme(int direction, treenode **parent);
+
+ // effect: retrieves and possibly rebalances the left child
+ // returns: a locked left child, if it exists
+ treenode *lock_and_rebalance_left(void);
+
+ // effect: retrieves and possibly rebalances the right child
+ // returns: a locked right child, if it exists
+ treenode *lock_and_rebalance_right(void);
+
+ // returns: the estimated depth of this subtree
+ uint32_t get_depth_estimate(void) const;
+
+ // returns: true iff left subtree depth is sufficiently less than the right
+ bool left_imbalanced(int threshold) const;
+
+ // returns: true iff right subtree depth is sufficiently greater than the left
+ bool right_imbalanced(int threshold) const;
+
+ // effect: performs an O(1) rebalance, which will "heal" an imbalance by at
+ // most 1. effect: if the new root is not this node, then this node is
+ // unlocked. returns: locked node representing the new root of the rebalanced
+ // subtree
+ treenode *maybe_rebalance(void);
+
+ // returns: allocated treenode populated with a copy of the range and txnid
+ static treenode *alloc(const comparator *cmp, const keyrange &range,
+ TXNID txnid, bool is_shared);
+
+ // requires: node is a locked root node, or an unlocked non-root node
+ static void free(treenode *node);
+
+ // effect: swaps the range/txnid pairs for node1 and node2.
+ static void swap_in_place(treenode *node1, treenode *node2);
+
+ friend class concurrent_tree_unit_test;
+};
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.cc
new file mode 100644
index 000000000..4caf1e26f
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.cc
@@ -0,0 +1,120 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "txnid_set.h"
+
+#include "../db.h"
+
+namespace toku {
+
+int find_by_txnid(const TXNID &txnid_a, const TXNID &txnid_b);
+int find_by_txnid(const TXNID &txnid_a, const TXNID &txnid_b) {
+ if (txnid_a < txnid_b) {
+ return -1;
+ } else if (txnid_a == txnid_b) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+void txnid_set::create(void) {
+ // lazily allocate the underlying omt, since it is common
+ // to create a txnid set and never put anything in it.
+ m_txnids.create_no_array();
+}
+
+void txnid_set::destroy(void) { m_txnids.destroy(); }
+
+// Return true if the given transaction id is a member of the set.
+// Otherwise, return false.
+bool txnid_set::contains(TXNID txnid) const {
+ TXNID find_txnid;
+ int r = m_txnids.find_zero<TXNID, find_by_txnid>(txnid, &find_txnid, nullptr);
+ return r == 0 ? true : false;
+}
+
+// Add a given txnid to the set
+void txnid_set::add(TXNID txnid) {
+ int r = m_txnids.insert<TXNID, find_by_txnid>(txnid, txnid, nullptr);
+ invariant(r == 0 || r == DB_KEYEXIST);
+}
+
+// Delete a given txnid from the set.
+void txnid_set::remove(TXNID txnid) {
+ uint32_t idx;
+ int r = m_txnids.find_zero<TXNID, find_by_txnid>(txnid, nullptr, &idx);
+ if (r == 0) {
+ r = m_txnids.delete_at(idx);
+ invariant_zero(r);
+ }
+}
+
+// Return the size of the set
+uint32_t txnid_set::size(void) const { return m_txnids.size(); }
+
+// Get the ith id in the set, assuming that the set is sorted.
+TXNID txnid_set::get(uint32_t i) const {
+ TXNID txnid;
+ int r = m_txnids.fetch(i, &txnid);
+ if (r == EINVAL) /* Shouldn't happen, avoid compiler warning */
+ return TXNID_NONE;
+ invariant_zero(r);
+ return txnid;
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.h
new file mode 100644
index 000000000..d79c24fb0
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/txnid_set.h
@@ -0,0 +1,92 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../portability/txn_subst.h"
+#include "../util/omt.h"
+
+namespace toku {
+
+class txnid_set {
+ public:
+ // effect: Creates an empty set. Does not malloc space for
+ // any entries yet. That is done lazily on add().
+ void create(void);
+
+ // effect: Destroy the set's internals.
+ void destroy(void);
+
+ // returns: True if the given txnid is a member of the set.
+ bool contains(TXNID id) const;
+
+ // effect: Adds a given txnid to the set if it did not exist
+ void add(TXNID txnid);
+
+ // effect: Deletes a txnid from the set if it exists.
+ void remove(TXNID txnid);
+
+ // returns: Size of the set
+ uint32_t size(void) const;
+
+ // returns: The "i'th" id in the set, as if it were sorted.
+ TXNID get(uint32_t i) const;
+
+ private:
+ toku::omt<TXNID> m_txnids;
+
+ friend class txnid_set_unit_test;
+};
+ENSURE_POD(txnid_set);
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.cc
new file mode 100644
index 000000000..24536c88e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.cc
@@ -0,0 +1,213 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "../db.h"
+#include "../portability/memory.h"
+// PORT #include <toku_assert.h>
+#include <memory.h>
+#include <string.h>
+
+#include "txnid_set.h"
+#include "wfg.h"
+
+namespace toku {
+
+// Create a lock request graph
+void wfg::create(void) { m_nodes.create(); }
+
+// Destroy the internals of the lock request graph
+void wfg::destroy(void) {
+ uint32_t n_nodes = m_nodes.size();
+ for (uint32_t i = 0; i < n_nodes; i++) {
+ node *n;
+ int r = m_nodes.fetch(i, &n);
+ invariant_zero(r);
+ invariant_notnull(n);
+ if (r) continue; // Get rid of "may be used uninitialized" warning
+ node::free(n);
+ }
+ m_nodes.destroy();
+}
+
+// Add an edge (a_id, b_id) to the graph
+void wfg::add_edge(TXNID a_txnid, TXNID b_txnid) {
+ node *a_node = find_create_node(a_txnid);
+ node *b_node = find_create_node(b_txnid);
+ a_node->edges.add(b_node->txnid);
+}
+
+// Return true if a node with the given transaction id exists in the graph.
+// Return false otherwise.
+bool wfg::node_exists(TXNID txnid) {
+ node *n = find_node(txnid);
+ return n != NULL;
+}
+
+bool wfg::cycle_exists_from_node(node *target, node *head,
+ std::function<void(TXNID)> reporter) {
+ bool cycle_found = false;
+ head->visited = true;
+ uint32_t n_edges = head->edges.size();
+ for (uint32_t i = 0; i < n_edges && !cycle_found; i++) {
+ TXNID edge_id = head->edges.get(i);
+ if (target->txnid == edge_id) {
+ cycle_found = true;
+ if (reporter) reporter(edge_id);
+ } else {
+ node *new_head = find_node(edge_id);
+ if (new_head && !new_head->visited) {
+ cycle_found = cycle_exists_from_node(target, new_head, reporter);
+ if (cycle_found && reporter) reporter(edge_id);
+ }
+ }
+ }
+ head->visited = false;
+ return cycle_found;
+}
+
+// Return true if there exists a cycle from a given transaction id in the graph.
+// Return false otherwise.
+bool wfg::cycle_exists_from_txnid(TXNID txnid,
+ std::function<void(TXNID)> reporter) {
+ node *a_node = find_node(txnid);
+ bool cycles_found = false;
+ if (a_node) {
+ cycles_found = cycle_exists_from_node(a_node, a_node, reporter);
+ }
+ return cycles_found;
+}
+
+// Apply a given function f to all of the nodes in the graph. The apply
+// function returns when the function f is called for all of the nodes in the
+// graph, or the function f returns non-zero.
+void wfg::apply_nodes(int (*fn)(TXNID id, void *extra), void *extra) {
+ int r = 0;
+ uint32_t n_nodes = m_nodes.size();
+ for (uint32_t i = 0; i < n_nodes && r == 0; i++) {
+ node *n;
+ r = m_nodes.fetch(i, &n);
+ invariant_zero(r);
+ if (r) continue; // Get rid of "may be used uninitialized" warning
+ r = fn(n->txnid, extra);
+ }
+}
+
+// Apply a given function f to all of the edges whose origin is a given node id.
+// The apply function returns when the function f is called for all edges in the
+// graph rooted at node id, or the function f returns non-zero.
+void wfg::apply_edges(TXNID txnid,
+ int (*fn)(TXNID txnid, TXNID edge_txnid, void *extra),
+ void *extra) {
+ node *n = find_node(txnid);
+ if (n) {
+ int r = 0;
+ uint32_t n_edges = n->edges.size();
+ for (uint32_t i = 0; i < n_edges && r == 0; i++) {
+ r = fn(txnid, n->edges.get(i), extra);
+ }
+ }
+}
+
+// find node by id
+wfg::node *wfg::find_node(TXNID txnid) {
+ node *n = nullptr;
+ int r = m_nodes.find_zero<TXNID, find_by_txnid>(txnid, &n, nullptr);
+ invariant(r == 0 || r == DB_NOTFOUND);
+ return n;
+}
+
+// this is the omt comparison function
+// nodes are compared by their txnid.
+int wfg::find_by_txnid(node *const &node_a, const TXNID &txnid_b) {
+ TXNID txnid_a = node_a->txnid;
+ if (txnid_a < txnid_b) {
+ return -1;
+ } else if (txnid_a == txnid_b) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+// insert a new node
+wfg::node *wfg::find_create_node(TXNID txnid) {
+ node *n;
+ uint32_t idx;
+ int r = m_nodes.find_zero<TXNID, find_by_txnid>(txnid, &n, &idx);
+ if (r == DB_NOTFOUND) {
+ n = node::alloc(txnid);
+ r = m_nodes.insert_at(n, idx);
+ invariant_zero(r);
+ }
+ invariant_notnull(n);
+ return n;
+}
+
+wfg::node *wfg::node::alloc(TXNID txnid) {
+ node *XCALLOC(n);
+ n->txnid = txnid;
+ n->visited = false;
+ n->edges.create();
+ return n;
+}
+
+void wfg::node::free(wfg::node *n) {
+ n->edges.destroy();
+ toku_free(n);
+}
+
+} /* namespace toku */
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.h
new file mode 100644
index 000000000..804202170
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/locktree/wfg.h
@@ -0,0 +1,124 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <functional>
+
+#include "../util/omt.h"
+#include "txnid_set.h"
+
+namespace toku {
+
+// A wfg is a 'wait-for' graph. A directed edge in represents one
+// txn waiting for another to finish before it can acquire a lock.
+
+class wfg {
+ public:
+ // Create a lock request graph
+ void create(void);
+
+ // Destroy the internals of the lock request graph
+ void destroy(void);
+
+ // Add an edge (a_id, b_id) to the graph
+ void add_edge(TXNID a_txnid, TXNID b_txnid);
+
+ // Return true if a node with the given transaction id exists in the graph.
+ // Return false otherwise.
+ bool node_exists(TXNID txnid);
+
+ // Return true if there exists a cycle from a given transaction id in the
+ // graph. Return false otherwise.
+ bool cycle_exists_from_txnid(TXNID txnid,
+ std::function<void(TXNID)> reporter);
+
+ // Apply a given function f to all of the nodes in the graph. The apply
+ // function returns when the function f is called for all of the nodes in the
+ // graph, or the function f returns non-zero.
+ void apply_nodes(int (*fn)(TXNID txnid, void *extra), void *extra);
+
+ // Apply a given function f to all of the edges whose origin is a given node
+ // id. The apply function returns when the function f is called for all edges
+ // in the graph rooted at node id, or the function f returns non-zero.
+ void apply_edges(TXNID txnid,
+ int (*fn)(TXNID txnid, TXNID edge_txnid, void *extra),
+ void *extra);
+
+ private:
+ struct node {
+ // txnid for this node and the associated set of edges
+ TXNID txnid;
+ txnid_set edges;
+ bool visited;
+
+ static node *alloc(TXNID txnid);
+
+ static void free(node *n);
+ };
+ ENSURE_POD(node);
+
+ toku::omt<node *> m_nodes;
+
+ node *find_node(TXNID txnid);
+
+ node *find_create_node(TXNID txnid);
+
+ bool cycle_exists_from_node(node *target, node *head,
+ std::function<void(TXNID)> reporter);
+
+ static int find_by_txnid(node *const &node_a, const TXNID &txnid_b);
+};
+ENSURE_POD(wfg);
+
+} /* namespace toku */
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/memory.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/memory.h
new file mode 100644
index 000000000..0a621f8e0
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/memory.h
@@ -0,0 +1,215 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <stdlib.h>
+
+#include "toku_portability.h"
+
+/* Percona memory allocation functions and macros.
+ * These are functions for malloc and free */
+
+int toku_memory_startup(void) __attribute__((constructor));
+void toku_memory_shutdown(void) __attribute__((destructor));
+
+/* Generally: errno is set to 0 or a value to indicate problems. */
+
+// Everything should call toku_malloc() instead of malloc(), and toku_calloc()
+// instead of calloc() That way the tests can can, e.g., replace the malloc
+// function using toku_set_func_malloc().
+void *toku_calloc(size_t nmemb, size_t size)
+ __attribute__((__visibility__("default")));
+void *toku_xcalloc(size_t nmemb, size_t size)
+ __attribute__((__visibility__("default")));
+void *toku_malloc(size_t size) __attribute__((__visibility__("default")));
+void *toku_malloc_aligned(size_t alignment, size_t size)
+ __attribute__((__visibility__("default")));
+
+// xmalloc aborts instead of return NULL if we run out of memory
+void *toku_xmalloc(size_t size) __attribute__((__visibility__("default")));
+void *toku_xrealloc(void *, size_t size)
+ __attribute__((__visibility__("default")));
+void *toku_xmalloc_aligned(size_t alignment, size_t size)
+ __attribute__((__visibility__("default")));
+// Effect: Perform a os_malloc_aligned(size) with the additional property that
+// the returned pointer is a multiple of ALIGNMENT.
+// Fail with a resource_assert if the allocation fails (don't return an error
+// code). If the alloc_aligned function has been set then call it instead.
+// Requires: alignment is a power of two.
+
+void toku_free(void *) __attribute__((__visibility__("default")));
+
+size_t toku_malloc_usable_size(void *p)
+ __attribute__((__visibility__("default")));
+
+/* MALLOC is a macro that helps avoid a common error:
+ * Suppose I write
+ * struct foo *x = malloc(sizeof(struct foo));
+ * That works fine. But if I change it to this, I've probably made an mistake:
+ * struct foo *x = malloc(sizeof(struct bar));
+ * It can get worse, since one might have something like
+ * struct foo *x = malloc(sizeof(struct foo *))
+ * which looks reasonable, but it allocoates enough to hold a pointer instead of
+ * the amount needed for the struct. So instead, write struct foo *MALLOC(x);
+ * and you cannot go wrong.
+ */
+#define MALLOC(v) CAST_FROM_VOIDP(v, toku_malloc(sizeof(*v)))
+/* MALLOC_N is like calloc(Except no 0ing of data): It makes an array. Write
+ * int *MALLOC_N(5,x);
+ * to make an array of 5 integers.
+ */
+#define MALLOC_N(n, v) CAST_FROM_VOIDP(v, toku_malloc((n) * sizeof(*v)))
+#define MALLOC_N_ALIGNED(align, n, v) \
+ CAST_FROM_VOIDP(v, toku_malloc_aligned((align), (n) * sizeof(*v)))
+
+// CALLOC_N is like calloc with auto-figuring out size of members
+#define CALLOC_N(n, v) CAST_FROM_VOIDP(v, toku_calloc((n), sizeof(*v)))
+
+#define CALLOC(v) CALLOC_N(1, v)
+
+// XMALLOC macros are like MALLOC except they abort if the operation fails
+#define XMALLOC(v) CAST_FROM_VOIDP(v, toku_xmalloc(sizeof(*v)))
+#define XMALLOC_N(n, v) CAST_FROM_VOIDP(v, toku_xmalloc((n) * sizeof(*v)))
+#define XCALLOC_N(n, v) CAST_FROM_VOIDP(v, toku_xcalloc((n), (sizeof(*v))))
+#define XCALLOC(v) XCALLOC_N(1, v)
+#define XREALLOC(v, s) CAST_FROM_VOIDP(v, toku_xrealloc(v, s))
+#define XREALLOC_N(n, v) CAST_FROM_VOIDP(v, toku_xrealloc(v, (n) * sizeof(*v)))
+
+#define XMALLOC_N_ALIGNED(align, n, v) \
+ CAST_FROM_VOIDP(v, toku_xmalloc_aligned((align), (n) * sizeof(*v)))
+
+#define XMEMDUP(dst, src) CAST_FROM_VOIDP(dst, toku_xmemdup(src, sizeof(*src)))
+#define XMEMDUP_N(dst, src, len) CAST_FROM_VOIDP(dst, toku_xmemdup(src, len))
+
+// ZERO_ARRAY writes zeroes to a stack-allocated array
+#define ZERO_ARRAY(o) \
+ do { \
+ memset((o), 0, sizeof(o)); \
+ } while (0)
+// ZERO_STRUCT writes zeroes to a stack-allocated struct
+#define ZERO_STRUCT(o) \
+ do { \
+ memset(&(o), 0, sizeof(o)); \
+ } while (0)
+
+/* Copy memory. Analogous to strdup() */
+void *toku_memdup(const void *v, size_t len);
+/* Toku-version of strdup. Use this so that it calls toku_malloc() */
+char *toku_strdup(const char *s) __attribute__((__visibility__("default")));
+/* Toku-version of strndup. Use this so that it calls toku_malloc() */
+char *toku_strndup(const char *s, size_t n)
+ __attribute__((__visibility__("default")));
+/* Copy memory. Analogous to strdup() Crashes instead of returning NULL */
+void *toku_xmemdup(const void *v, size_t len)
+ __attribute__((__visibility__("default")));
+/* Toku-version of strdup. Use this so that it calls toku_xmalloc() Crashes
+ * instead of returning NULL */
+char *toku_xstrdup(const char *s) __attribute__((__visibility__("default")));
+
+void toku_malloc_cleanup(
+ void); /* Before exiting, call this function to free up any internal data
+ structures from toku_malloc. Otherwise valgrind will complain of
+ memory leaks. */
+
+/* Check to see if everything malloc'd was free. Might be a no-op depending on
+ * how memory.c is configured. */
+void toku_memory_check_all_free(void);
+/* Check to see if memory is "sane". Might be a no-op. Probably better to
+ * simply use valgrind. */
+void toku_do_memory_check(void);
+
+typedef void *(*malloc_fun_t)(size_t);
+typedef void (*free_fun_t)(void *);
+typedef void *(*realloc_fun_t)(void *, size_t);
+typedef void *(*malloc_aligned_fun_t)(size_t /*alignment*/, size_t /*size*/);
+typedef void *(*realloc_aligned_fun_t)(size_t /*alignment*/, void * /*pointer*/,
+ size_t /*size*/);
+
+void toku_set_func_malloc(malloc_fun_t f);
+void toku_set_func_xmalloc_only(malloc_fun_t f);
+void toku_set_func_malloc_only(malloc_fun_t f);
+void toku_set_func_realloc(realloc_fun_t f);
+void toku_set_func_xrealloc_only(realloc_fun_t f);
+void toku_set_func_realloc_only(realloc_fun_t f);
+void toku_set_func_free(free_fun_t f);
+
+typedef struct memory_status {
+ uint64_t malloc_count; // number of malloc operations
+ uint64_t free_count; // number of free operations
+ uint64_t realloc_count; // number of realloc operations
+ uint64_t malloc_fail; // number of malloc operations that failed
+ uint64_t realloc_fail; // number of realloc operations that failed
+ uint64_t requested; // number of bytes requested
+ uint64_t used; // number of bytes used (requested + overhead), obtained from
+ // malloc_usable_size()
+ uint64_t freed; // number of bytes freed;
+ uint64_t max_requested_size; // largest attempted allocation size
+ uint64_t last_failed_size; // size of the last failed allocation attempt
+ volatile uint64_t
+ max_in_use; // maximum memory footprint (used - freed), approximate (not
+ // worth threadsafety overhead for exact)
+ const char *mallocator_version;
+ uint64_t mmap_threshold;
+} LOCAL_MEMORY_STATUS_S, *LOCAL_MEMORY_STATUS;
+
+void toku_memory_get_status(LOCAL_MEMORY_STATUS s);
+
+// Effect: Like toku_memory_footprint, except instead of passing p,
+// we pass toku_malloc_usable_size(p).
+size_t toku_memory_footprint_given_usable_size(size_t touched, size_t usable);
+
+// Effect: Return an estimate how how much space an object is using, possibly by
+// using toku_malloc_usable_size(p).
+// If p is NULL then returns 0.
+size_t toku_memory_footprint(void *p, size_t touched);
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_assert_subst.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_assert_subst.h
new file mode 100644
index 000000000..af47800fb
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_assert_subst.h
@@ -0,0 +1,39 @@
+//
+// A replacement for toku_assert.h
+//
+#pragma once
+
+#include <assert.h>
+#include <errno.h>
+
+#ifdef NDEBUG
+
+#define assert_zero(a) ((void)(a))
+#define invariant(a) ((void)(a))
+#define invariant_notnull(a) ((void)(a))
+#define invariant_zero(a) ((void)(a))
+
+#else
+
+#define assert_zero(a) assert((a) == 0)
+#define invariant(a) assert(a)
+#define invariant_notnull(a) assert(a)
+#define invariant_zero(a) assert_zero(a)
+
+#endif
+
+#define lazy_assert_zero(a) assert_zero(a)
+
+#define paranoid_invariant_zero(a) assert_zero(a)
+#define paranoid_invariant_notnull(a) assert(a)
+#define paranoid_invariant(a) assert(a)
+
+#define ENSURE_POD(type) \
+ static_assert( \
+ std::is_standard_layout<type>::value && std::is_trivial<type>::value, \
+ #type "isn't POD")
+
+inline int get_error_errno(void) {
+ invariant(errno);
+ return errno;
+}
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_atomic.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_atomic.h
new file mode 100644
index 000000000..aaa2298fa
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_atomic.h
@@ -0,0 +1,130 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+// PORT2: #include <portability/toku_config.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "toku_assert_subst.h"
+
+__attribute__((const, always_inline)) static inline intptr_t which_cache_line(
+ intptr_t addr) {
+ static const size_t assumed_cache_line_size = 64;
+ return addr / assumed_cache_line_size;
+}
+template <typename T>
+__attribute__((const, always_inline)) static inline bool crosses_boundary(
+ T *addr, size_t width) {
+ const intptr_t int_addr = reinterpret_cast<intptr_t>(addr);
+ const intptr_t last_byte = int_addr + width - 1;
+ return which_cache_line(int_addr) != which_cache_line(last_byte);
+}
+
+template <typename T, typename U>
+__attribute__((always_inline)) static inline T toku_sync_fetch_and_add(T *addr,
+ U diff) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_fetch_and_add(addr, diff);
+}
+template <typename T, typename U>
+__attribute__((always_inline)) static inline T toku_sync_add_and_fetch(T *addr,
+ U diff) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_add_and_fetch(addr, diff);
+}
+template <typename T, typename U>
+__attribute__((always_inline)) static inline T toku_sync_fetch_and_sub(T *addr,
+ U diff) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_fetch_and_sub(addr, diff);
+}
+template <typename T, typename U>
+__attribute__((always_inline)) static inline T toku_sync_sub_and_fetch(T *addr,
+ U diff) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_sub_and_fetch(addr, diff);
+}
+template <typename T, typename U, typename V>
+__attribute__((always_inline)) static inline T toku_sync_val_compare_and_swap(
+ T *addr, U oldval, V newval) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_val_compare_and_swap(addr, oldval, newval);
+}
+template <typename T, typename U, typename V>
+__attribute__((always_inline)) static inline bool
+toku_sync_bool_compare_and_swap(T *addr, U oldval, V newval) {
+ paranoid_invariant(!crosses_boundary(addr, sizeof *addr));
+ return __sync_bool_compare_and_swap(addr, oldval, newval);
+}
+
+// in case you include this but not toku_portability.h
+#pragma GCC poison __sync_fetch_and_add
+#pragma GCC poison __sync_fetch_and_sub
+#pragma GCC poison __sync_fetch_and_or
+#pragma GCC poison __sync_fetch_and_and
+#pragma GCC poison __sync_fetch_and_xor
+#pragma GCC poison __sync_fetch_and_nand
+#pragma GCC poison __sync_add_and_fetch
+#pragma GCC poison __sync_sub_and_fetch
+#pragma GCC poison __sync_or_and_fetch
+#pragma GCC poison __sync_and_and_fetch
+#pragma GCC poison __sync_xor_and_fetch
+#pragma GCC poison __sync_nand_and_fetch
+#pragma GCC poison __sync_bool_compare_and_swap
+#pragma GCC poison __sync_val_compare_and_swap
+#pragma GCC poison __sync_synchronize
+#pragma GCC poison __sync_lock_test_and_set
+#pragma GCC poison __sync_release
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_external_pthread.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_external_pthread.h
new file mode 100644
index 000000000..eb8291c1d
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_external_pthread.h
@@ -0,0 +1,83 @@
+/*
+ A wrapper around ROCKSDB_NAMESPACE::TransactionDBMutexFactory-provided
+ condition and mutex that provides toku_pthread_*-like interface. The functions
+ are named
+
+ toku_external_{mutex|cond}_XXX
+
+ Lock Tree uses this mutex and condition for interruptible (long) lock waits.
+
+ (It also still uses toku_pthread_XXX calls for mutexes/conditions for
+ shorter waits on internal objects)
+*/
+
+#pragma once
+
+#include <pthread.h>
+#include <stdint.h>
+#include <time.h>
+
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/transaction_db_mutex.h"
+#include "toku_portability.h"
+
+using ROCKSDB_NAMESPACE::TransactionDBCondVar;
+using ROCKSDB_NAMESPACE::TransactionDBMutex;
+
+typedef std::shared_ptr<ROCKSDB_NAMESPACE::TransactionDBMutexFactory>
+ toku_external_mutex_factory_t;
+
+typedef std::shared_ptr<TransactionDBMutex> toku_external_mutex_t;
+typedef std::shared_ptr<TransactionDBCondVar> toku_external_cond_t;
+
+static inline void toku_external_cond_init(
+ toku_external_mutex_factory_t mutex_factory, toku_external_cond_t *cond) {
+ *cond = mutex_factory->AllocateCondVar();
+}
+
+inline void toku_external_cond_destroy(toku_external_cond_t *cond) {
+ cond->reset(); // this will destroy the managed object
+}
+
+inline void toku_external_cond_signal(toku_external_cond_t *cond) {
+ (*cond)->Notify();
+}
+
+inline void toku_external_cond_broadcast(toku_external_cond_t *cond) {
+ (*cond)->NotifyAll();
+}
+
+inline int toku_external_cond_timedwait(toku_external_cond_t *cond,
+ toku_external_mutex_t *mutex,
+ int64_t timeout_microsec) {
+ auto res = (*cond)->WaitFor(*mutex, timeout_microsec);
+ if (res.ok())
+ return 0;
+ else
+ return ETIMEDOUT;
+}
+
+inline void toku_external_mutex_init(toku_external_mutex_factory_t factory,
+ toku_external_mutex_t *mutex) {
+ // Use placement new: the memory has been allocated but constructor wasn't
+ // called
+ new (mutex) toku_external_mutex_t;
+ *mutex = factory->AllocateMutex();
+}
+
+inline void toku_external_mutex_lock(toku_external_mutex_t *mutex) {
+ (*mutex)->Lock();
+}
+
+inline int toku_external_mutex_trylock(toku_external_mutex_t *mutex) {
+ (*mutex)->Lock();
+ return 0;
+}
+
+inline void toku_external_mutex_unlock(toku_external_mutex_t *mutex) {
+ (*mutex)->UnLock();
+}
+
+inline void toku_external_mutex_destroy(toku_external_mutex_t *mutex) {
+ mutex->reset(); // this will destroy the managed object
+}
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_instrumentation.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_instrumentation.h
new file mode 100644
index 000000000..c967e7177
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_instrumentation.h
@@ -0,0 +1,286 @@
+/*======
+This file is part of PerconaFT.
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#pragma once
+
+#include <stdio.h> // FILE
+
+// Performance instrumentation object identifier type
+typedef unsigned int pfs_key_t;
+
+enum class toku_instr_object_type { mutex, rwlock, cond, thread, file };
+
+struct PSI_file;
+
+struct TOKU_FILE {
+ /** The real file. */
+ FILE *file;
+ struct PSI_file *key;
+ TOKU_FILE() : file(nullptr), key(nullptr) {}
+};
+
+struct PSI_mutex;
+struct PSI_cond;
+struct PSI_rwlock;
+
+struct toku_mutex_t;
+struct toku_cond_t;
+struct toku_pthread_rwlock_t;
+
+class toku_instr_key;
+
+class toku_instr_probe_empty {
+ public:
+ explicit toku_instr_probe_empty(UU(const toku_instr_key &key)) {}
+
+ void start_with_source_location(UU(const char *src_file), UU(int src_line)) {}
+
+ void stop() {}
+};
+
+#define TOKU_PROBE_START(p) p->start_with_source_location(__FILE__, __LINE__)
+#define TOKU_PROBE_STOP(p) p->stop
+
+extern toku_instr_key toku_uninstrumented;
+
+#ifndef MYSQL_TOKUDB_ENGINE
+
+#include <pthread.h>
+
+class toku_instr_key {
+ public:
+ toku_instr_key(UU(toku_instr_object_type type), UU(const char *group),
+ UU(const char *name)) {}
+
+ explicit toku_instr_key(UU(pfs_key_t key_id)) {}
+ // No-instrumentation constructor:
+ toku_instr_key() {}
+ ~toku_instr_key() {}
+};
+
+typedef toku_instr_probe_empty toku_instr_probe;
+
+enum class toku_instr_file_op {
+ file_stream_open,
+ file_create,
+ file_open,
+ file_delete,
+ file_rename,
+ file_read,
+ file_write,
+ file_sync,
+ file_stream_close,
+ file_close,
+ file_stat
+};
+
+struct PSI_file {};
+struct PSI_mutex {};
+
+struct toku_io_instrumentation {};
+
+inline int toku_pthread_create(UU(const toku_instr_key &key), pthread_t *thread,
+ const pthread_attr_t *attr,
+ void *(*start_routine)(void *), void *arg) {
+ return pthread_create(thread, attr, start_routine, arg);
+}
+
+inline void toku_instr_register_current_thread() {}
+
+inline void toku_instr_delete_current_thread() {}
+
+// Instrument file creation, opening, closing, and renaming
+inline void toku_instr_file_open_begin(UU(toku_io_instrumentation &io_instr),
+ UU(const toku_instr_key &key),
+ UU(toku_instr_file_op op),
+ UU(const char *name),
+ UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_file_stream_open_end(
+ UU(toku_io_instrumentation &io_instr), UU(TOKU_FILE &file)) {}
+
+inline void toku_instr_file_open_end(UU(toku_io_instrumentation &io_instr),
+ UU(int fd)) {}
+
+inline void toku_instr_file_name_close_begin(
+ UU(toku_io_instrumentation &io_instr), UU(const toku_instr_key &key),
+ UU(toku_instr_file_op op), UU(const char *name), UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_file_stream_close_begin(
+ UU(toku_io_instrumentation &io_instr), UU(toku_instr_file_op op),
+ UU(TOKU_FILE &file), UU(const char *src_file), UU(int src_line)) {}
+
+inline void toku_instr_file_fd_close_begin(
+ UU(toku_io_instrumentation &io_instr), UU(toku_instr_file_op op),
+ UU(int fd), UU(const char *src_file), UU(int src_line)) {}
+
+inline void toku_instr_file_close_end(UU(toku_io_instrumentation &io_instr),
+ UU(int result)) {}
+
+inline void toku_instr_file_io_begin(UU(toku_io_instrumentation &io_instr),
+ UU(toku_instr_file_op op), UU(int fd),
+ UU(unsigned int count),
+ UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_file_name_io_begin(
+ UU(toku_io_instrumentation &io_instr), UU(const toku_instr_key &key),
+ UU(toku_instr_file_op op), UU(const char *name), UU(unsigned int count),
+ UU(const char *src_file), UU(int src_line)) {}
+
+inline void toku_instr_file_stream_io_begin(
+ UU(toku_io_instrumentation &io_instr), UU(toku_instr_file_op op),
+ UU(TOKU_FILE &file), UU(unsigned int count), UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_file_io_end(UU(toku_io_instrumentation &io_instr),
+ UU(unsigned int count)) {}
+
+struct toku_mutex_t;
+
+struct toku_mutex_instrumentation {};
+
+inline PSI_mutex *toku_instr_mutex_init(UU(const toku_instr_key &key),
+ UU(toku_mutex_t &mutex)) {
+ return nullptr;
+}
+
+inline void toku_instr_mutex_destroy(UU(PSI_mutex *&mutex_instr)) {}
+
+inline void toku_instr_mutex_lock_start(
+ UU(toku_mutex_instrumentation &mutex_instr), UU(toku_mutex_t &mutex),
+ UU(const char *src_file), UU(int src_line)) {}
+
+inline void toku_instr_mutex_trylock_start(
+ UU(toku_mutex_instrumentation &mutex_instr), UU(toku_mutex_t &mutex),
+ UU(const char *src_file), UU(int src_line)) {}
+
+inline void toku_instr_mutex_lock_end(
+ UU(toku_mutex_instrumentation &mutex_instr),
+ UU(int pthread_mutex_lock_result)) {}
+
+inline void toku_instr_mutex_unlock(UU(PSI_mutex *mutex_instr)) {}
+
+struct toku_cond_instrumentation {};
+
+enum class toku_instr_cond_op {
+ cond_wait,
+ cond_timedwait,
+};
+
+inline PSI_cond *toku_instr_cond_init(UU(const toku_instr_key &key),
+ UU(toku_cond_t &cond)) {
+ return nullptr;
+}
+
+inline void toku_instr_cond_destroy(UU(PSI_cond *&cond_instr)) {}
+
+inline void toku_instr_cond_wait_start(
+ UU(toku_cond_instrumentation &cond_instr), UU(toku_instr_cond_op op),
+ UU(toku_cond_t &cond), UU(toku_mutex_t &mutex), UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_cond_wait_end(UU(toku_cond_instrumentation &cond_instr),
+ UU(int pthread_cond_wait_result)) {}
+
+inline void toku_instr_cond_signal(UU(toku_cond_t &cond)) {}
+
+inline void toku_instr_cond_broadcast(UU(toku_cond_t &cond)) {}
+
+#if 0
+// rw locks are not used
+// rwlock instrumentation
+struct toku_rwlock_instrumentation {};
+
+inline PSI_rwlock *toku_instr_rwlock_init(UU(const toku_instr_key &key),
+ UU(toku_pthread_rwlock_t &rwlock)) {
+ return nullptr;
+}
+
+inline void toku_instr_rwlock_destroy(UU(PSI_rwlock *&rwlock_instr)) {}
+
+inline void toku_instr_rwlock_rdlock_wait_start(
+ UU(toku_rwlock_instrumentation &rwlock_instr),
+ UU(toku_pthread_rwlock_t &rwlock),
+ UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_rwlock_wrlock_wait_start(
+ UU(toku_rwlock_instrumentation &rwlock_instr),
+ UU(toku_pthread_rwlock_t &rwlock),
+ UU(const char *src_file),
+ UU(int src_line)) {}
+
+inline void toku_instr_rwlock_rdlock_wait_end(
+ UU(toku_rwlock_instrumentation &rwlock_instr),
+ UU(int pthread_rwlock_wait_result)) {}
+
+inline void toku_instr_rwlock_wrlock_wait_end(
+ UU(toku_rwlock_instrumentation &rwlock_instr),
+ UU(int pthread_rwlock_wait_result)) {}
+
+inline void toku_instr_rwlock_unlock(UU(toku_pthread_rwlock_t &rwlock)) {}
+#endif
+
+#else // MYSQL_TOKUDB_ENGINE
+// There can be not only mysql but also mongodb or any other PFS stuff
+#include <toku_instr_mysql.h>
+#endif // MYSQL_TOKUDB_ENGINE
+
+// Mutexes
+extern toku_instr_key manager_escalation_mutex_key;
+extern toku_instr_key manager_escalator_mutex_key;
+extern toku_instr_key manager_mutex_key;
+extern toku_instr_key treenode_mutex_key;
+extern toku_instr_key locktree_request_info_mutex_key;
+extern toku_instr_key locktree_request_info_retry_mutex_key;
+
+// condition vars
+extern toku_instr_key lock_request_m_wait_cond_key;
+extern toku_instr_key locktree_request_info_retry_cv_key;
+extern toku_instr_key manager_m_escalator_done_key; // unused
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_portability.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_portability.h
new file mode 100644
index 000000000..9a95b38bd
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_portability.h
@@ -0,0 +1,87 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#if defined(__clang__)
+#define constexpr_static_assert(a, b)
+#else
+#define constexpr_static_assert(a, b) static_assert(a, b)
+#endif
+
+// include here, before they get deprecated
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "toku_atomic.h"
+
+#if defined(__cplusplus)
+#include <type_traits>
+#endif
+
+#if defined(__cplusplus)
+// decltype() here gives a reference-to-pointer instead of just a pointer,
+// just use __typeof__
+#define CAST_FROM_VOIDP(name, value) name = static_cast<__typeof__(name)>(value)
+#else
+#define CAST_FROM_VOIDP(name, value) name = cast_to_typeof(name)(value)
+#endif
+
+#define UU(x) x __attribute__((__unused__))
+
+#include "toku_instrumentation.h"
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_pthread.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_pthread.h
new file mode 100644
index 000000000..571b950e1
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_pthread.h
@@ -0,0 +1,520 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <pthread.h>
+#include <stdint.h>
+#include <time.h>
+
+#include "toku_portability.h"
+// PORT2: #include "toku_assert.h"
+
+// TODO: some things moved toku_instrumentation.h, not necessarily the best
+// place
+typedef pthread_attr_t toku_pthread_attr_t;
+typedef pthread_t toku_pthread_t;
+typedef pthread_mutex_t toku_pthread_mutex_t;
+typedef pthread_condattr_t toku_pthread_condattr_t;
+typedef pthread_cond_t toku_pthread_cond_t;
+typedef pthread_rwlockattr_t toku_pthread_rwlockattr_t;
+typedef pthread_key_t toku_pthread_key_t;
+typedef struct timespec toku_timespec_t;
+
+// TODO: break this include loop
+#include <pthread.h>
+typedef pthread_mutexattr_t toku_pthread_mutexattr_t;
+
+struct toku_mutex_t {
+ pthread_mutex_t pmutex;
+ struct PSI_mutex *psi_mutex; /* The performance schema instrumentation hook */
+#if defined(TOKU_PTHREAD_DEBUG)
+ pthread_t owner; // = pthread_self(); // for debugging
+ bool locked;
+ bool valid;
+ pfs_key_t instr_key_id;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+};
+
+struct toku_cond_t {
+ pthread_cond_t pcond;
+ struct PSI_cond *psi_cond;
+#if defined(TOKU_PTHREAD_DEBUG)
+ pfs_key_t instr_key_id;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+};
+
+#if defined(TOKU_PTHREAD_DEBUG)
+#define TOKU_COND_INITIALIZER \
+ { .pcond = PTHREAD_COND_INITIALIZER, .psi_cond = nullptr, .instr_key_id = 0 }
+#else
+#define TOKU_COND_INITIALIZER \
+ { .pcond = PTHREAD_COND_INITIALIZER, .psi_cond = nullptr }
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+struct toku_pthread_rwlock_t {
+ pthread_rwlock_t rwlock;
+ struct PSI_rwlock *psi_rwlock;
+#if defined(TOKU_PTHREAD_DEBUG)
+ pfs_key_t instr_key_id;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+};
+
+typedef struct toku_mutex_aligned {
+ toku_mutex_t aligned_mutex __attribute__((__aligned__(64)));
+} toku_mutex_aligned_t;
+
+// Initializing with {} will fill in a struct with all zeros.
+// But you may also need a pragma to suppress the warnings, as follows
+//
+// #pragma GCC diagnostic push
+// #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+// toku_mutex_t foo = ZERO_MUTEX_INITIALIZER;
+// #pragma GCC diagnostic pop
+//
+// In general it will be a lot of busy work to make this codebase compile
+// cleanly with -Wmissing-field-initializers
+
+#define ZERO_MUTEX_INITIALIZER \
+ {}
+
+#if defined(TOKU_PTHREAD_DEBUG)
+#define TOKU_MUTEX_INITIALIZER \
+ { \
+ .pmutex = PTHREAD_MUTEX_INITIALIZER, .psi_mutex = nullptr, .owner = 0, \
+ .locked = false, .valid = true, .instr_key_id = 0 \
+ }
+#else
+#define TOKU_MUTEX_INITIALIZER \
+ { .pmutex = PTHREAD_MUTEX_INITIALIZER, .psi_mutex = nullptr }
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+// Darwin doesn't provide adaptive mutexes
+#if defined(__APPLE__)
+#define TOKU_MUTEX_ADAPTIVE PTHREAD_MUTEX_DEFAULT
+#if defined(TOKU_PTHREAD_DEBUG)
+#define TOKU_ADAPTIVE_MUTEX_INITIALIZER \
+ { \
+ .pmutex = PTHREAD_MUTEX_INITIALIZER, .psi_mutex = nullptr, .owner = 0, \
+ .locked = false, .valid = true, .instr_key_id = 0 \
+ }
+#else
+#define TOKU_ADAPTIVE_MUTEX_INITIALIZER \
+ { .pmutex = PTHREAD_MUTEX_INITIALIZER, .psi_mutex = nullptr }
+#endif // defined(TOKU_PTHREAD_DEBUG)
+#else // __FreeBSD__, __linux__, at least
+#if defined(__GLIBC__)
+#define TOKU_MUTEX_ADAPTIVE PTHREAD_MUTEX_ADAPTIVE_NP
+#else
+// not all libc (e.g. musl) implement NP (Non-POSIX) attributes
+#define TOKU_MUTEX_ADAPTIVE PTHREAD_MUTEX_DEFAULT
+#endif
+#if defined(TOKU_PTHREAD_DEBUG)
+#define TOKU_ADAPTIVE_MUTEX_INITIALIZER \
+ { \
+ .pmutex = PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP, .psi_mutex = nullptr, \
+ .owner = 0, .locked = false, .valid = true, .instr_key_id = 0 \
+ }
+#else
+#define TOKU_ADAPTIVE_MUTEX_INITIALIZER \
+ { .pmutex = PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP, .psi_mutex = nullptr }
+#endif // defined(TOKU_PTHREAD_DEBUG)
+#endif // defined(__APPLE__)
+
+// Different OSes implement mutexes as different amounts of nested structs.
+// C++ will fill out all missing values with zeroes if you provide at least one
+// zero, but it needs the right amount of nesting.
+#if defined(__FreeBSD__)
+#define ZERO_COND_INITIALIZER \
+ { 0 }
+#elif defined(__APPLE__)
+#define ZERO_COND_INITIALIZER \
+ { \
+ { 0 } \
+ }
+#else // __linux__, at least
+#define ZERO_COND_INITIALIZER \
+ {}
+#endif
+
+static inline void toku_mutexattr_init(toku_pthread_mutexattr_t *attr) {
+ int r = pthread_mutexattr_init(attr);
+ assert_zero(r);
+}
+
+static inline void toku_mutexattr_settype(toku_pthread_mutexattr_t *attr,
+ int type) {
+ int r = pthread_mutexattr_settype(attr, type);
+ assert_zero(r);
+}
+
+static inline void toku_mutexattr_destroy(toku_pthread_mutexattr_t *attr) {
+ int r = pthread_mutexattr_destroy(attr);
+ assert_zero(r);
+}
+
+#if defined(TOKU_PTHREAD_DEBUG)
+static inline void toku_mutex_assert_locked(const toku_mutex_t *mutex) {
+ invariant(mutex->locked);
+ invariant(mutex->owner == pthread_self());
+}
+#else
+static inline void toku_mutex_assert_locked(const toku_mutex_t *mutex
+ __attribute__((unused))) {}
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+// asserting that a mutex is unlocked only makes sense
+// if the calling thread can guaruntee that no other threads
+// are trying to lock this mutex at the time of the assertion
+//
+// a good example of this is a tree with mutexes on each node.
+// when a node is locked the caller knows that no other threads
+// can be trying to lock its childrens' mutexes. the children
+// are in one of two fixed states: locked or unlocked.
+#if defined(TOKU_PTHREAD_DEBUG)
+static inline void toku_mutex_assert_unlocked(toku_mutex_t *mutex) {
+ invariant(mutex->owner == 0);
+ invariant(!mutex->locked);
+}
+#else
+static inline void toku_mutex_assert_unlocked(toku_mutex_t *mutex
+ __attribute__((unused))) {}
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+#define toku_mutex_lock(M) \
+ toku_mutex_lock_with_source_location(M, __FILE__, __LINE__)
+
+static inline void toku_cond_init(toku_cond_t *cond,
+ const toku_pthread_condattr_t *attr) {
+ int r = pthread_cond_init(&cond->pcond, attr);
+ assert_zero(r);
+}
+
+#define toku_mutex_trylock(M) \
+ toku_mutex_trylock_with_source_location(M, __FILE__, __LINE__)
+
+inline void toku_mutex_unlock(toku_mutex_t *mutex) {
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(mutex->owner == pthread_self());
+ invariant(mutex->valid);
+ invariant(mutex->locked);
+ mutex->locked = false;
+ mutex->owner = 0;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+ toku_instr_mutex_unlock(mutex->psi_mutex);
+ int r = pthread_mutex_unlock(&mutex->pmutex);
+ assert_zero(r);
+}
+
+inline void toku_mutex_lock_with_source_location(toku_mutex_t *mutex,
+ const char *src_file,
+ int src_line) {
+ toku_mutex_instrumentation mutex_instr;
+ toku_instr_mutex_lock_start(mutex_instr, *mutex, src_file, src_line);
+
+ const int r = pthread_mutex_lock(&mutex->pmutex);
+ toku_instr_mutex_lock_end(mutex_instr, r);
+
+ assert_zero(r);
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(mutex->valid);
+ invariant(!mutex->locked);
+ invariant(mutex->owner == 0);
+ mutex->locked = true;
+ mutex->owner = pthread_self();
+#endif // defined(TOKU_PTHREAD_DEBUG)
+}
+
+inline int toku_mutex_trylock_with_source_location(toku_mutex_t *mutex,
+ const char *src_file,
+ int src_line) {
+ toku_mutex_instrumentation mutex_instr;
+ toku_instr_mutex_trylock_start(mutex_instr, *mutex, src_file, src_line);
+
+ const int r = pthread_mutex_lock(&mutex->pmutex);
+ toku_instr_mutex_lock_end(mutex_instr, r);
+
+#if defined(TOKU_PTHREAD_DEBUG)
+ if (r == 0) {
+ invariant(mutex->valid);
+ invariant(!mutex->locked);
+ invariant(mutex->owner == 0);
+ mutex->locked = true;
+ mutex->owner = pthread_self();
+ }
+#endif // defined(TOKU_PTHREAD_DEBUG)
+ return r;
+}
+
+#define toku_cond_wait(C, M) \
+ toku_cond_wait_with_source_location(C, M, __FILE__, __LINE__)
+
+#define toku_cond_timedwait(C, M, W) \
+ toku_cond_timedwait_with_source_location(C, M, W, __FILE__, __LINE__)
+
+inline void toku_cond_init(const toku_instr_key &key, toku_cond_t *cond,
+ const pthread_condattr_t *attr) {
+ toku_instr_cond_init(key, *cond);
+ int r = pthread_cond_init(&cond->pcond, attr);
+ assert_zero(r);
+}
+
+inline void toku_cond_destroy(toku_cond_t *cond) {
+ toku_instr_cond_destroy(cond->psi_cond);
+ int r = pthread_cond_destroy(&cond->pcond);
+ assert_zero(r);
+}
+
+inline void toku_cond_wait_with_source_location(toku_cond_t *cond,
+ toku_mutex_t *mutex,
+ const char *src_file,
+ int src_line) {
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(mutex->locked);
+ mutex->locked = false;
+ mutex->owner = 0;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+ /* Instrumentation start */
+ toku_cond_instrumentation cond_instr;
+ toku_instr_cond_wait_start(cond_instr, toku_instr_cond_op::cond_wait, *cond,
+ *mutex, src_file, src_line);
+
+ /* Instrumented code */
+ const int r = pthread_cond_wait(&cond->pcond, &mutex->pmutex);
+
+ /* Instrumentation end */
+ toku_instr_cond_wait_end(cond_instr, r);
+
+ assert_zero(r);
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(!mutex->locked);
+ mutex->locked = true;
+ mutex->owner = pthread_self();
+#endif // defined(TOKU_PTHREAD_DEBUG)
+}
+
+inline int toku_cond_timedwait_with_source_location(toku_cond_t *cond,
+ toku_mutex_t *mutex,
+ toku_timespec_t *wakeup_at,
+ const char *src_file,
+ int src_line) {
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(mutex->locked);
+ mutex->locked = false;
+ mutex->owner = 0;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+
+ /* Instrumentation start */
+ toku_cond_instrumentation cond_instr;
+ toku_instr_cond_wait_start(cond_instr, toku_instr_cond_op::cond_timedwait,
+ *cond, *mutex, src_file, src_line);
+
+ /* Instrumented code */
+ const int r = pthread_cond_timedwait(&cond->pcond, &mutex->pmutex, wakeup_at);
+
+ /* Instrumentation end */
+ toku_instr_cond_wait_end(cond_instr, r);
+
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(!mutex->locked);
+ mutex->locked = true;
+ mutex->owner = pthread_self();
+#endif // defined(TOKU_PTHREAD_DEBUG)
+ return r;
+}
+
+inline void toku_cond_signal(toku_cond_t *cond) {
+ toku_instr_cond_signal(*cond);
+ const int r = pthread_cond_signal(&cond->pcond);
+ assert_zero(r);
+}
+
+inline void toku_cond_broadcast(toku_cond_t *cond) {
+ toku_instr_cond_broadcast(*cond);
+ const int r = pthread_cond_broadcast(&cond->pcond);
+ assert_zero(r);
+}
+
+inline void toku_mutex_init(const toku_instr_key &key, toku_mutex_t *mutex,
+ const toku_pthread_mutexattr_t *attr) {
+#if defined(TOKU_PTHREAD_DEBUG)
+ mutex->valid = true;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+ toku_instr_mutex_init(key, *mutex);
+ const int r = pthread_mutex_init(&mutex->pmutex, attr);
+ assert_zero(r);
+#if defined(TOKU_PTHREAD_DEBUG)
+ mutex->locked = false;
+ invariant(mutex->valid);
+ mutex->valid = true;
+ mutex->owner = 0;
+#endif // defined(TOKU_PTHREAD_DEBUG)
+}
+
+inline void toku_mutex_destroy(toku_mutex_t *mutex) {
+#if defined(TOKU_PTHREAD_DEBUG)
+ invariant(mutex->valid);
+ mutex->valid = false;
+ invariant(!mutex->locked);
+#endif // defined(TOKU_PTHREAD_DEBUG)
+ toku_instr_mutex_destroy(mutex->psi_mutex);
+ int r = pthread_mutex_destroy(&mutex->pmutex);
+ assert_zero(r);
+}
+
+#define toku_pthread_rwlock_rdlock(RW) \
+ toku_pthread_rwlock_rdlock_with_source_location(RW, __FILE__, __LINE__)
+
+#define toku_pthread_rwlock_wrlock(RW) \
+ toku_pthread_rwlock_wrlock_with_source_location(RW, __FILE__, __LINE__)
+
+#if 0
+inline void toku_pthread_rwlock_init(
+ const toku_instr_key &key,
+ toku_pthread_rwlock_t *__restrict rwlock,
+ const toku_pthread_rwlockattr_t *__restrict attr) {
+ toku_instr_rwlock_init(key, *rwlock);
+ int r = pthread_rwlock_init(&rwlock->rwlock, attr);
+ assert_zero(r);
+}
+
+inline void toku_pthread_rwlock_destroy(toku_pthread_rwlock_t *rwlock) {
+ toku_instr_rwlock_destroy(rwlock->psi_rwlock);
+ int r = pthread_rwlock_destroy(&rwlock->rwlock);
+ assert_zero(r);
+}
+
+inline void toku_pthread_rwlock_rdlock_with_source_location(
+ toku_pthread_rwlock_t *rwlock,
+ const char *src_file,
+ uint src_line) {
+
+ /* Instrumentation start */
+ toku_rwlock_instrumentation rwlock_instr;
+ toku_instr_rwlock_rdlock_wait_start(
+ rwlock_instr, *rwlock, src_file, src_line);
+ /* Instrumented code */
+ const int r = pthread_rwlock_rdlock(&rwlock->rwlock);
+
+ /* Instrumentation end */
+ toku_instr_rwlock_rdlock_wait_end(rwlock_instr, r);
+
+ assert_zero(r);
+}
+
+inline void toku_pthread_rwlock_wrlock_with_source_location(
+ toku_pthread_rwlock_t *rwlock,
+ const char *src_file,
+ uint src_line) {
+
+ /* Instrumentation start */
+ toku_rwlock_instrumentation rwlock_instr;
+ toku_instr_rwlock_wrlock_wait_start(
+ rwlock_instr, *rwlock, src_file, src_line);
+ /* Instrumented code */
+ const int r = pthread_rwlock_wrlock(&rwlock->rwlock);
+
+ /* Instrumentation end */
+ toku_instr_rwlock_wrlock_wait_end(rwlock_instr, r);
+
+ assert_zero(r);
+}
+
+inline void toku_pthread_rwlock_rdunlock(toku_pthread_rwlock_t *rwlock) {
+ toku_instr_rwlock_unlock(*rwlock);
+ const int r = pthread_rwlock_unlock(&rwlock->rwlock);
+ assert_zero(r);
+}
+
+inline void toku_pthread_rwlock_wrunlock(toku_pthread_rwlock_t *rwlock) {
+ toku_instr_rwlock_unlock(*rwlock);
+ const int r = pthread_rwlock_unlock(&rwlock->rwlock);
+ assert_zero(r);
+}
+#endif
+
+static inline int toku_pthread_join(toku_pthread_t thread, void **value_ptr) {
+ return pthread_join(thread, value_ptr);
+}
+
+static inline int toku_pthread_detach(toku_pthread_t thread) {
+ return pthread_detach(thread);
+}
+
+static inline int toku_pthread_key_create(toku_pthread_key_t *key,
+ void (*destroyf)(void *)) {
+ return pthread_key_create(key, destroyf);
+}
+
+static inline int toku_pthread_key_delete(toku_pthread_key_t key) {
+ return pthread_key_delete(key);
+}
+
+static inline void *toku_pthread_getspecific(toku_pthread_key_t key) {
+ return pthread_getspecific(key);
+}
+
+static inline int toku_pthread_setspecific(toku_pthread_key_t key, void *data) {
+ return pthread_setspecific(key, data);
+}
+
+int toku_pthread_yield(void) __attribute__((__visibility__("default")));
+
+static inline toku_pthread_t toku_pthread_self(void) { return pthread_self(); }
+
+static inline void *toku_pthread_done(void *exit_value) {
+ toku_instr_delete_current_thread();
+ pthread_exit(exit_value);
+ return nullptr; // Avoid compiler warning
+}
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_race_tools.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_race_tools.h
new file mode 100644
index 000000000..3cb5b5790
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_race_tools.h
@@ -0,0 +1,179 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+// PORT2: #include <portability/toku_config.h>
+
+#ifdef HAVE_valgrind
+#undef USE_VALGRIND
+#define USE_VALGRIND 1
+#endif
+
+#if defined(__linux__) && USE_VALGRIND
+
+#include <valgrind/drd.h>
+#include <valgrind/helgrind.h>
+
+#define TOKU_ANNOTATE_NEW_MEMORY(p, size) ANNOTATE_NEW_MEMORY(p, size)
+#define TOKU_VALGRIND_HG_ENABLE_CHECKING(p, size) \
+ VALGRIND_HG_ENABLE_CHECKING(p, size)
+#define TOKU_VALGRIND_HG_DISABLE_CHECKING(p, size) \
+ VALGRIND_HG_DISABLE_CHECKING(p, size)
+#define TOKU_DRD_IGNORE_VAR(v) DRD_IGNORE_VAR(v)
+#define TOKU_DRD_STOP_IGNORING_VAR(v) DRD_STOP_IGNORING_VAR(v)
+#define TOKU_ANNOTATE_IGNORE_READS_BEGIN() ANNOTATE_IGNORE_READS_BEGIN()
+#define TOKU_ANNOTATE_IGNORE_READS_END() ANNOTATE_IGNORE_READS_END()
+#define TOKU_ANNOTATE_IGNORE_WRITES_BEGIN() ANNOTATE_IGNORE_WRITES_BEGIN()
+#define TOKU_ANNOTATE_IGNORE_WRITES_END() ANNOTATE_IGNORE_WRITES_END()
+
+/*
+ * How to make helgrind happy about tree rotations and new mutex orderings:
+ *
+ * // Tell helgrind that we unlocked it so that the next call doesn't get a
+ * "destroyed a locked mutex" error.
+ * // Tell helgrind that we destroyed the mutex.
+ * VALGRIND_HG_MUTEX_UNLOCK_PRE(&locka);
+ * VALGRIND_HG_MUTEX_DESTROY_PRE(&locka);
+ *
+ * // And recreate it. It would be better to simply be able to say that the
+ * order on these two can now be reversed, because this code forgets all the
+ * ordering information for this mutex.
+ * // Then tell helgrind that we have locked it again.
+ * VALGRIND_HG_MUTEX_INIT_POST(&locka, 0);
+ * VALGRIND_HG_MUTEX_LOCK_POST(&locka);
+ *
+ * When the ordering of two locks changes, we don't need tell Helgrind about do
+ * both locks. Just one is good enough.
+ */
+
+#define TOKU_VALGRIND_RESET_MUTEX_ORDERING_INFO(mutex) \
+ VALGRIND_HG_MUTEX_UNLOCK_PRE(mutex); \
+ VALGRIND_HG_MUTEX_DESTROY_PRE(mutex); \
+ VALGRIND_HG_MUTEX_INIT_POST(mutex, 0); \
+ VALGRIND_HG_MUTEX_LOCK_POST(mutex);
+
+#else // !defined(__linux__) || !USE_VALGRIND
+
+#define NVALGRIND 1
+#define TOKU_ANNOTATE_NEW_MEMORY(p, size) ((void)0)
+#define TOKU_VALGRIND_HG_ENABLE_CHECKING(p, size) ((void)0)
+#define TOKU_VALGRIND_HG_DISABLE_CHECKING(p, size) ((void)0)
+#define TOKU_DRD_IGNORE_VAR(v)
+#define TOKU_DRD_STOP_IGNORING_VAR(v)
+#define TOKU_ANNOTATE_IGNORE_READS_BEGIN() ((void)0)
+#define TOKU_ANNOTATE_IGNORE_READS_END() ((void)0)
+#define TOKU_ANNOTATE_IGNORE_WRITES_BEGIN() ((void)0)
+#define TOKU_ANNOTATE_IGNORE_WRITES_END() ((void)0)
+#define TOKU_VALGRIND_RESET_MUTEX_ORDERING_INFO(mutex)
+#undef RUNNING_ON_VALGRIND
+#define RUNNING_ON_VALGRIND (0U)
+#endif
+
+// Valgrind 3.10.1 (and previous versions).
+// Problems with VALGRIND_HG_DISABLE_CHECKING and VALGRIND_HG_ENABLE_CHECKING.
+// Helgrind's implementation of disable and enable checking causes false races
+// to be reported. In addition, the race report does not include ANY
+// information about the code that uses the helgrind disable and enable
+// functions. Therefore, it is very difficult to figure out the cause of the
+// race. DRD does implement the disable and enable functions.
+
+// Problems with ANNOTATE_IGNORE_READS.
+// Helgrind does not implement ignore reads.
+// Annotate ignore reads is the way to inform DRD to ignore racy reads.
+
+// FT code uses unsafe reads in several places. These unsafe reads have been
+// noted as valid since they use the toku_unsafe_fetch function. Unfortunately,
+// this causes helgrind to report erroneous data races which makes use of
+// helgrind problematic.
+
+// Unsafely fetch and return a `T' from src, telling drd to ignore
+// racey access to src for the next sizeof(*src) bytes
+template <typename T>
+T toku_unsafe_fetch(T *src) {
+ if (0)
+ TOKU_VALGRIND_HG_DISABLE_CHECKING(src,
+ sizeof *src); // disabled, see comment
+ TOKU_ANNOTATE_IGNORE_READS_BEGIN();
+ T r = *src;
+ TOKU_ANNOTATE_IGNORE_READS_END();
+ if (0)
+ TOKU_VALGRIND_HG_ENABLE_CHECKING(src,
+ sizeof *src); // disabled, see comment
+ return r;
+}
+
+template <typename T>
+T toku_unsafe_fetch(T &src) {
+ return toku_unsafe_fetch(&src);
+}
+
+// Unsafely set a `T' value into *dest from src, telling drd to ignore
+// racey access to dest for the next sizeof(*dest) bytes
+template <typename T>
+void toku_unsafe_set(T *dest, const T src) {
+ if (0)
+ TOKU_VALGRIND_HG_DISABLE_CHECKING(dest,
+ sizeof *dest); // disabled, see comment
+ TOKU_ANNOTATE_IGNORE_WRITES_BEGIN();
+ *dest = src;
+ TOKU_ANNOTATE_IGNORE_WRITES_END();
+ if (0)
+ TOKU_VALGRIND_HG_ENABLE_CHECKING(dest,
+ sizeof *dest); // disabled, see comment
+}
+
+template <typename T>
+void toku_unsafe_set(T &dest, const T src) {
+ toku_unsafe_set(&dest, src);
+}
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_time.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_time.h
new file mode 100644
index 000000000..46111e7f0
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/toku_time.h
@@ -0,0 +1,193 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+// PORT2: #include "toku_config.h"
+
+#include <stdint.h>
+#include <sys/time.h>
+#include <time.h>
+#if defined(__powerpc__)
+#include <sys/platform/ppc.h>
+#endif
+
+#if 0
+static inline float toku_tdiff (struct timeval *a, struct timeval *b) {
+ return (float)((a->tv_sec - b->tv_sec) + 1e-6 * (a->tv_usec - b->tv_usec));
+}
+// PORT2: temporary:
+#define HAVE_CLOCK_REALTIME
+#if !defined(HAVE_CLOCK_REALTIME)
+// OS X does not have clock_gettime, we fake clockid_t for the interface, and we'll implement it with clock_get_time.
+typedef int clockid_t;
+// just something bogus, it doesn't matter, we just want to make sure we're
+// only supporting this mode because we're not sure we can support other modes
+// without a real clock_gettime()
+#define CLOCK_REALTIME 0x01867234
+#endif
+int toku_clock_gettime(clockid_t clk_id, struct timespec *ts) __attribute__((__visibility__("default")));
+#endif
+
+// *************** Performance timers ************************
+// What do you really want from a performance timer:
+// (1) Can determine actual time of day from the performance time.
+// (2) Time goes forward, never backward.
+// (3) Same time on different processors (or even different machines).
+// (4) Time goes forward at a constant rate (doesn't get faster and slower)
+// (5) Portable.
+// (6) Getting the time is cheap.
+// Unfortuately it seems tough to get Properties 1-5. So we go for Property 6,,
+// but we abstract it. We offer a type tokutime_t which can hold the time. This
+// type can be subtracted to get a time difference. We can get the present time
+// cheaply. We can convert this type to seconds (but that can be expensive). The
+// implementation is to use RDTSC (hence we lose property 3: not portable).
+// Recent machines have constant_tsc in which case we get property (4).
+// Recent OSs on recent machines (that have RDTSCP) fix the per-processor clock
+// skew, so we get property (3). We get property 2 with RDTSC (as long as
+// there's not any skew). We don't even try to get propety 1, since we don't
+// need it. The decision here is that these times are really accurate only on
+// modern machines with modern OSs.
+typedef uint64_t tokutime_t; // Time type used in by tokutek timers.
+
+#if 0
+// The value of tokutime_t is not specified here.
+// It might be microseconds since 1/1/1970 (if gettimeofday() is
+// used), or clock cycles since boot (if rdtsc is used). Or something
+// else.
+// Two tokutime_t values can be subtracted to get a time difference.
+// Use tokutime_to_seconds to that convert difference to seconds.
+// We want get_tokutime() to be fast, but don't care so much about tokutime_to_seconds();
+//
+// For accurate time calculations do the subtraction in the right order:
+// Right: tokutime_to_seconds(t1-t2);
+// Wrong tokutime_to_seconds(t1)-toku_time_to_seconds(t2);
+// Doing it the wrong way is likely to result in loss of precision.
+// A double can hold numbers up to about 53 bits. RDTSC which uses about 33 bits every second, so that leaves
+// 2^20 seconds from booting (about 2 weeks) before the RDTSC value cannot be represented accurately as a double.
+//
+double tokutime_to_seconds(tokutime_t) __attribute__((__visibility__("default"))); // Convert tokutime to seconds.
+
+#endif
+
+// Get the value of tokutime for right now. We want this to be fast, so we
+// expose the implementation as RDTSC.
+static inline tokutime_t toku_time_now(void) {
+#if defined(__x86_64__) || defined(__i386__)
+ uint32_t lo, hi;
+ __asm__ __volatile__("rdtsc" : "=a"(lo), "=d"(hi));
+ return (uint64_t)hi << 32 | lo;
+#elif defined(__aarch64__)
+ uint64_t result;
+ __asm __volatile__("mrs %[rt], cntvct_el0" : [rt] "=r"(result));
+ return result;
+#elif defined(__powerpc__)
+ return __ppc_get_timebase();
+#elif defined(__s390x__)
+ uint64_t result;
+ asm volatile("stckf %0" : "=Q"(result) : : "cc");
+ return result;
+#elif defined(__riscv) && __riscv_xlen == 32
+ uint32_t cycles_lo, cycles_hi0, cycles_hi1;
+ // Implemented in assembly because Clang insisted on branching.
+ asm volatile(
+ "rdcycleh %0\n"
+ "rdcycle %1\n"
+ "rdcycleh %2\n"
+ "sub %0, %0, %2\n"
+ "seqz %0, %0\n"
+ "sub %0, zero, %0\n"
+ "and %1, %1, %0\n"
+ : "=r"(cycles_hi0), "=r"(cycles_lo), "=r"(cycles_hi1));
+ return (static_cast<uint64_t>(cycles_hi1) << 32) | cycles_lo;
+#elif defined(__riscv) && __riscv_xlen == 64
+ uint64_t cycles;
+ asm volatile("rdcycle %0" : "=r"(cycles));
+ return cycles;
+#else
+#error No timer implementation for this platform
+#endif
+}
+
+static inline uint64_t toku_current_time_microsec(void) {
+ struct timeval t;
+ gettimeofday(&t, NULL);
+ return t.tv_sec * (1UL * 1000 * 1000) + t.tv_usec;
+}
+
+#if 0
+// sleep microseconds
+static inline void toku_sleep_microsec(uint64_t ms) {
+ struct timeval t;
+
+ t.tv_sec = ms / 1000000;
+ t.tv_usec = ms % 1000000;
+
+ select(0, NULL, NULL, NULL, &t);
+}
+#endif
+
+/*
+ PORT: Usage of this file:
+
+ uint64_t toku_current_time_microsec() // uses gettimeoday
+ is used to track how much time various operations took (for example, lock
+ escalation). (TODO: it is not clear why these operations are tracked with
+ microsecond precision while others use nanoseconds)
+
+ tokutime_t toku_time_now() // uses rdtsc
+ seems to be used for a very similar purpose. This has greater precision
+
+ RocksDB environment provides Env::Default()->NowMicros() and NowNanos() which
+ should be adequate substitutes.
+*/
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/txn_subst.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/txn_subst.h
new file mode 100644
index 000000000..803914862
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/portability/txn_subst.h
@@ -0,0 +1,27 @@
+//
+// A substitute for ft/txn/txn.h
+//
+#pragma once
+
+#include <set>
+
+#include "../util/omt.h"
+
+typedef uint64_t TXNID;
+#define TXNID_NONE ((TXNID)0)
+
+// A set of transactions
+// (TODO: consider using class toku::txnid_set. The reason for using STL
+// container was that its API is easier)
+class TxnidVector : public std::set<TXNID> {
+ public:
+ bool contains(TXNID txnid) { return find(txnid) != end(); }
+};
+
+// A value for lock structures with a meaning "the lock is owned by multiple
+// transactions (and one has to check the TxnidVector to get their ids)
+#define TXNID_SHARED (TXNID(-1))
+
+// Auxiliary value meaning "any transaction id will do". No real transaction
+// may have this is as id.
+#define TXNID_ANY (TXNID(-2))
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/standalone_port.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/standalone_port.cc
new file mode 100644
index 000000000..50dc879ce
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/standalone_port.cc
@@ -0,0 +1,132 @@
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+/*
+ This is a dump ground to make Lock Tree work without the rest of TokuDB.
+*/
+#include <string.h>
+
+#include "db.h"
+#include "ft/ft-status.h"
+#include "portability/memory.h"
+#include "util/dbt.h"
+
+// portability/os_malloc.cc
+
+void toku_free(void *p) { free(p); }
+
+void *toku_xmalloc(size_t size) { return malloc(size); }
+
+void *toku_xrealloc(void *v, size_t size) { return realloc(v, size); }
+
+void *toku_xmemdup(const void *v, size_t len) {
+ void *p = toku_xmalloc(len);
+ memcpy(p, v, len);
+ return p;
+}
+
+// TODO: what are the X-functions? Xcalloc, Xrealloc?
+void *toku_xcalloc(size_t nmemb, size_t size) { return calloc(nmemb, size); }
+
+// ft-ft-opts.cc:
+
+// locktree
+toku_instr_key lock_request_m_wait_cond_key;
+toku_instr_key manager_m_escalator_done_key;
+toku_instr_key locktree_request_info_mutex_key;
+toku_instr_key locktree_request_info_retry_mutex_key;
+toku_instr_key locktree_request_info_retry_cv_key;
+
+toku_instr_key treenode_mutex_key;
+toku_instr_key manager_mutex_key;
+toku_instr_key manager_escalation_mutex_key;
+toku_instr_key manager_escalator_mutex_key;
+
+// portability/memory.cc
+size_t toku_memory_footprint(void *, size_t touched) { return touched; }
+
+// ft/ft-status.c
+// PORT2: note: the @c parameter to TOKUFT_STATUS_INIT must not start with
+// "TOKU"
+LTM_STATUS_S ltm_status;
+void LTM_STATUS_S::init() {
+ if (m_initialized) return;
+#define LTM_STATUS_INIT(k, c, t, l) \
+ TOKUFT_STATUS_INIT((*this), k, c, t, "locktree: " l, \
+ TOKU_ENGINE_STATUS | TOKU_GLOBAL_STATUS)
+ LTM_STATUS_INIT(LTM_SIZE_CURRENT, LOCKTREE_MEMORY_SIZE, STATUS_UINT64,
+ "memory size");
+ LTM_STATUS_INIT(LTM_SIZE_LIMIT, LOCKTREE_MEMORY_SIZE_LIMIT, STATUS_UINT64,
+ "memory size limit");
+ LTM_STATUS_INIT(LTM_ESCALATION_COUNT, LOCKTREE_ESCALATION_NUM, STATUS_UINT64,
+ "number of times lock escalation ran");
+ LTM_STATUS_INIT(LTM_ESCALATION_TIME, LOCKTREE_ESCALATION_SECONDS,
+ STATUS_TOKUTIME, "time spent running escalation (seconds)");
+ LTM_STATUS_INIT(LTM_ESCALATION_LATEST_RESULT,
+ LOCKTREE_LATEST_POST_ESCALATION_MEMORY_SIZE, STATUS_UINT64,
+ "latest post-escalation memory size");
+ LTM_STATUS_INIT(LTM_NUM_LOCKTREES, LOCKTREE_OPEN_CURRENT, STATUS_UINT64,
+ "number of locktrees open now");
+ LTM_STATUS_INIT(LTM_LOCK_REQUESTS_PENDING, LOCKTREE_PENDING_LOCK_REQUESTS,
+ STATUS_UINT64, "number of pending lock requests");
+ LTM_STATUS_INIT(LTM_STO_NUM_ELIGIBLE, LOCKTREE_STO_ELIGIBLE_NUM,
+ STATUS_UINT64, "number of locktrees eligible for the STO");
+ LTM_STATUS_INIT(LTM_STO_END_EARLY_COUNT, LOCKTREE_STO_ENDED_NUM,
+ STATUS_UINT64,
+ "number of times a locktree ended the STO early");
+ LTM_STATUS_INIT(LTM_STO_END_EARLY_TIME, LOCKTREE_STO_ENDED_SECONDS,
+ STATUS_TOKUTIME, "time spent ending the STO early (seconds)");
+ LTM_STATUS_INIT(LTM_WAIT_COUNT, LOCKTREE_WAIT_COUNT, STATUS_UINT64,
+ "number of wait locks");
+ LTM_STATUS_INIT(LTM_WAIT_TIME, LOCKTREE_WAIT_TIME, STATUS_UINT64,
+ "time waiting for locks");
+ LTM_STATUS_INIT(LTM_LONG_WAIT_COUNT, LOCKTREE_LONG_WAIT_COUNT, STATUS_UINT64,
+ "number of long wait locks");
+ LTM_STATUS_INIT(LTM_LONG_WAIT_TIME, LOCKTREE_LONG_WAIT_TIME, STATUS_UINT64,
+ "long time waiting for locks");
+ LTM_STATUS_INIT(LTM_TIMEOUT_COUNT, LOCKTREE_TIMEOUT_COUNT, STATUS_UINT64,
+ "number of lock timeouts");
+ LTM_STATUS_INIT(LTM_WAIT_ESCALATION_COUNT, LOCKTREE_WAIT_ESCALATION_COUNT,
+ STATUS_UINT64, "number of waits on lock escalation");
+ LTM_STATUS_INIT(LTM_WAIT_ESCALATION_TIME, LOCKTREE_WAIT_ESCALATION_TIME,
+ STATUS_UINT64, "time waiting on lock escalation");
+ LTM_STATUS_INIT(LTM_LONG_WAIT_ESCALATION_COUNT,
+ LOCKTREE_LONG_WAIT_ESCALATION_COUNT, STATUS_UINT64,
+ "number of long waits on lock escalation");
+ LTM_STATUS_INIT(LTM_LONG_WAIT_ESCALATION_TIME,
+ LOCKTREE_LONG_WAIT_ESCALATION_TIME, STATUS_UINT64,
+ "long time waiting on lock escalation");
+
+ m_initialized = true;
+#undef LTM_STATUS_INIT
+}
+void LTM_STATUS_S::destroy() {
+ if (!m_initialized) return;
+ for (int i = 0; i < LTM_STATUS_NUM_ROWS; ++i) {
+ if (status[i].type == STATUS_PARCOUNT) {
+ // PORT: TODO?? destroy_partitioned_counter(status[i].value.parcount);
+ }
+ }
+}
+
+int toku_keycompare(const void *key1, size_t key1len, const void *key2,
+ size_t key2len) {
+ size_t comparelen = key1len < key2len ? key1len : key2len;
+ int c = memcmp(key1, key2, comparelen);
+ if (__builtin_expect(c != 0, 1)) {
+ return c;
+ } else {
+ if (key1len < key2len) {
+ return -1;
+ } else if (key1len > key2len) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+}
+
+int toku_builtin_compare_fun(const DBT *a, const DBT *b) {
+ return toku_keycompare(a->data, a->size, b->data, b->size);
+}
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.cc
new file mode 100644
index 000000000..63cc3a267
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.cc
@@ -0,0 +1,153 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "dbt.h"
+
+#include <string.h>
+
+#include "../db.h"
+#include "../portability/memory.h"
+
+DBT *toku_init_dbt(DBT *dbt) {
+ memset(dbt, 0, sizeof(*dbt));
+ return dbt;
+}
+
+DBT toku_empty_dbt(void) {
+ static const DBT empty_dbt = {.data = 0, .size = 0, .ulen = 0, .flags = 0};
+ return empty_dbt;
+}
+
+DBT *toku_init_dbt_flags(DBT *dbt, uint32_t flags) {
+ toku_init_dbt(dbt);
+ dbt->flags = flags;
+ return dbt;
+}
+
+void toku_destroy_dbt(DBT *dbt) {
+ switch (dbt->flags) {
+ case DB_DBT_MALLOC:
+ case DB_DBT_REALLOC:
+ toku_free(dbt->data);
+ toku_init_dbt(dbt);
+ break;
+ }
+}
+
+DBT *toku_fill_dbt(DBT *dbt, const void *k, size_t len) {
+ toku_init_dbt(dbt);
+ dbt->size = len;
+ dbt->data = (char *)k;
+ return dbt;
+}
+
+DBT *toku_memdup_dbt(DBT *dbt, const void *k, size_t len) {
+ toku_init_dbt_flags(dbt, DB_DBT_MALLOC);
+ dbt->size = len;
+ dbt->data = toku_xmemdup(k, len);
+ return dbt;
+}
+
+DBT *toku_copyref_dbt(DBT *dst, const DBT src) {
+ dst->flags = 0;
+ dst->ulen = 0;
+ dst->size = src.size;
+ dst->data = src.data;
+ return dst;
+}
+
+DBT *toku_clone_dbt(DBT *dst, const DBT &src) {
+ return toku_memdup_dbt(dst, src.data, src.size);
+}
+
+void toku_sdbt_cleanup(struct simple_dbt *sdbt) {
+ if (sdbt->data) toku_free(sdbt->data);
+ memset(sdbt, 0, sizeof(*sdbt));
+}
+
+const DBT *toku_dbt_positive_infinity(void) {
+ static DBT positive_infinity_dbt = {
+ .data = 0, .size = 0, .ulen = 0, .flags = 0}; // port
+ return &positive_infinity_dbt;
+}
+
+const DBT *toku_dbt_negative_infinity(void) {
+ static DBT negative_infinity_dbt = {
+ .data = 0, .size = 0, .ulen = 0, .flags = 0}; // port
+ return &negative_infinity_dbt;
+}
+
+bool toku_dbt_is_infinite(const DBT *dbt) {
+ return dbt == toku_dbt_positive_infinity() ||
+ dbt == toku_dbt_negative_infinity();
+}
+
+bool toku_dbt_is_empty(const DBT *dbt) {
+ // can't have a null data field with a non-zero size
+ paranoid_invariant(dbt->data != nullptr || dbt->size == 0);
+ return dbt->data == nullptr;
+}
+
+int toku_dbt_infinite_compare(const DBT *a, const DBT *b) {
+ if (a == b) {
+ return 0;
+ } else if (a == toku_dbt_positive_infinity()) {
+ return 1;
+ } else if (b == toku_dbt_positive_infinity()) {
+ return -1;
+ } else if (a == toku_dbt_negative_infinity()) {
+ return -1;
+ } else {
+ invariant(b == toku_dbt_negative_infinity());
+ return 1;
+ }
+}
+
+bool toku_dbt_equals(const DBT *a, const DBT *b) {
+ if (!toku_dbt_is_infinite(a) && !toku_dbt_is_infinite(b)) {
+ return a->data == b->data && a->size == b->size;
+ } else {
+ // a or b is infinite, so they're equal if they are the same infinite
+ return a == b ? true : false;
+ }
+}
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.h
new file mode 100644
index 000000000..d86c440f8
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/dbt.h
@@ -0,0 +1,98 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "../db.h"
+
+// TODO: John
+// Document this API a little better so that DBT
+// memory management can be morm widely understood.
+
+DBT *toku_init_dbt(DBT *);
+
+// returns: an initialized but empty dbt (for which toku_dbt_is_empty() is true)
+DBT toku_empty_dbt(void);
+
+DBT *toku_init_dbt_flags(DBT *, uint32_t flags);
+
+void toku_destroy_dbt(DBT *);
+
+DBT *toku_fill_dbt(DBT *dbt, const void *k, size_t len);
+
+DBT *toku_memdup_dbt(DBT *dbt, const void *k, size_t len);
+
+DBT *toku_copyref_dbt(DBT *dst, const DBT src);
+
+DBT *toku_clone_dbt(DBT *dst, const DBT &src);
+
+void toku_sdbt_cleanup(struct simple_dbt *sdbt);
+
+// returns: special DBT pointer representing positive infinity
+const DBT *toku_dbt_positive_infinity(void);
+
+// returns: special DBT pointer representing negative infinity
+const DBT *toku_dbt_negative_infinity(void);
+
+// returns: true if the given dbt is either positive or negative infinity
+bool toku_dbt_is_infinite(const DBT *dbt);
+
+// returns: true if the given dbt has no data (ie: dbt->data == nullptr)
+bool toku_dbt_is_empty(const DBT *dbt);
+
+// effect: compares two potentially infinity-valued dbts
+// requires: at least one is infinite (assert otherwise)
+int toku_dbt_infinite_compare(const DBT *a, const DBT *b);
+
+// returns: true if the given dbts have the same data pointer and size
+bool toku_dbt_equals(const DBT *a, const DBT *b);
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/growable_array.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/growable_array.h
new file mode 100644
index 000000000..158750fdb
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/growable_array.h
@@ -0,0 +1,144 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <memory.h>
+
+//******************************************************************************
+//
+// Overview: A growable array is a little bit like std::vector except that
+// it doesn't have constructors (hence can be used in static constructs, since
+// the google style guide says no constructors), and it's a little simpler.
+// Operations:
+// init and deinit (we don't have constructors and destructors).
+// fetch_unchecked to get values out.
+// store_unchecked to put values in.
+// push to add an element at the end
+// get_size to find out the size
+// get_memory_size to find out how much memory the data stucture is using.
+//
+//******************************************************************************
+
+namespace toku {
+
+template <typename T>
+class GrowableArray {
+ public:
+ void init(void)
+ // Effect: Initialize the array to contain no elements.
+ {
+ m_array = NULL;
+ m_size = 0;
+ m_size_limit = 0;
+ }
+
+ void deinit(void)
+ // Effect: Deinitialize the array (freeing any memory it uses, for example).
+ {
+ toku_free(m_array);
+ m_array = NULL;
+ m_size = 0;
+ m_size_limit = 0;
+ }
+
+ T fetch_unchecked(size_t i) const
+ // Effect: Fetch the ith element. If i is out of range, the system asserts.
+ {
+ return m_array[i];
+ }
+
+ void store_unchecked(size_t i, T v)
+ // Effect: Store v in the ith element. If i is out of range, the system
+ // asserts.
+ {
+ paranoid_invariant(i < m_size);
+ m_array[i] = v;
+ }
+
+ void push(T v)
+ // Effect: Add v to the end of the array (increasing the size). The amortized
+ // cost of this operation is constant. Implementation hint: Double the size
+ // of the array when it gets too big so that the amortized cost stays
+ // constant.
+ {
+ if (m_size >= m_size_limit) {
+ if (m_array == NULL) {
+ m_size_limit = 1;
+ } else {
+ m_size_limit *= 2;
+ }
+ XREALLOC_N(m_size_limit, m_array);
+ }
+ m_array[m_size++] = v;
+ }
+
+ size_t get_size(void) const
+ // Effect: Return the number of elements in the array.
+ {
+ return m_size;
+ }
+ size_t memory_size(void) const
+ // Effect: Return the size (in bytes) that the array occupies in memory. This
+ // is really only an estimate.
+ {
+ return sizeof(*this) + sizeof(T) * m_size_limit;
+ }
+
+ private:
+ T *m_array;
+ size_t m_size;
+ size_t m_size_limit; // How much space is allocated in array.
+};
+
+} // namespace toku
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.cc
new file mode 100644
index 000000000..0e7a9880b
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.cc
@@ -0,0 +1,201 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include "memarena.h"
+
+#include <string.h>
+
+#include <algorithm>
+
+#include "../portability/memory.h"
+
+void memarena::create(size_t initial_size) {
+ _current_chunk = arena_chunk();
+ _other_chunks = nullptr;
+ _size_of_other_chunks = 0;
+ _footprint_of_other_chunks = 0;
+ _n_other_chunks = 0;
+
+ _current_chunk.size = initial_size;
+ if (_current_chunk.size > 0) {
+ XMALLOC_N(_current_chunk.size, _current_chunk.buf);
+ }
+}
+
+void memarena::destroy(void) {
+ if (_current_chunk.buf) {
+ toku_free(_current_chunk.buf);
+ }
+ for (int i = 0; i < _n_other_chunks; i++) {
+ toku_free(_other_chunks[i].buf);
+ }
+ if (_other_chunks) {
+ toku_free(_other_chunks);
+ }
+ _current_chunk = arena_chunk();
+ _other_chunks = nullptr;
+ _n_other_chunks = 0;
+}
+
+static size_t round_to_page(size_t size) {
+ const size_t page_size = 4096;
+ const size_t r = page_size + ((size - 1) & ~(page_size - 1));
+ assert((r & (page_size - 1)) == 0); // make sure it's aligned
+ assert(r >= size); // make sure it's not too small
+ assert(r <
+ size + page_size); // make sure we didn't grow by more than a page.
+ return r;
+}
+
+static const size_t MEMARENA_MAX_CHUNK_SIZE = 64 * 1024 * 1024;
+
+void *memarena::malloc_from_arena(size_t size) {
+ if (_current_chunk.buf == nullptr ||
+ _current_chunk.size < _current_chunk.used + size) {
+ // The existing block isn't big enough.
+ // Add the block to the vector of blocks.
+ if (_current_chunk.buf) {
+ invariant(_current_chunk.size > 0);
+ int old_n = _n_other_chunks;
+ XREALLOC_N(old_n + 1, _other_chunks);
+ _other_chunks[old_n] = _current_chunk;
+ _n_other_chunks = old_n + 1;
+ _size_of_other_chunks += _current_chunk.size;
+ _footprint_of_other_chunks +=
+ toku_memory_footprint(_current_chunk.buf, _current_chunk.used);
+ }
+
+ // Make a new one. Grow the buffer size exponentially until we hit
+ // the max chunk size, but make it at least `size' bytes so the
+ // current allocation always fit.
+ size_t new_size =
+ std::min(MEMARENA_MAX_CHUNK_SIZE, 2 * _current_chunk.size);
+ if (new_size < size) {
+ new_size = size;
+ }
+ new_size = round_to_page(
+ new_size); // at least size, but round to the next page size
+ XMALLOC_N(new_size, _current_chunk.buf);
+ _current_chunk.used = 0;
+ _current_chunk.size = new_size;
+ }
+ invariant(_current_chunk.buf != nullptr);
+
+ // allocate in the existing block.
+ char *p = _current_chunk.buf + _current_chunk.used;
+ _current_chunk.used += size;
+ return p;
+}
+
+void memarena::move_memory(memarena *dest) {
+ // Move memory to dest
+ XREALLOC_N(dest->_n_other_chunks + _n_other_chunks + 1, dest->_other_chunks);
+ dest->_size_of_other_chunks += _size_of_other_chunks + _current_chunk.size;
+ dest->_footprint_of_other_chunks +=
+ _footprint_of_other_chunks +
+ toku_memory_footprint(_current_chunk.buf, _current_chunk.used);
+ for (int i = 0; i < _n_other_chunks; i++) {
+ dest->_other_chunks[dest->_n_other_chunks++] = _other_chunks[i];
+ }
+ dest->_other_chunks[dest->_n_other_chunks++] = _current_chunk;
+
+ // Clear out this memarena's memory
+ toku_free(_other_chunks);
+ _current_chunk = arena_chunk();
+ _other_chunks = nullptr;
+ _size_of_other_chunks = 0;
+ _footprint_of_other_chunks = 0;
+ _n_other_chunks = 0;
+}
+
+size_t memarena::total_memory_size(void) const {
+ return sizeof(*this) + total_size_in_use() +
+ _n_other_chunks * sizeof(*_other_chunks);
+}
+
+size_t memarena::total_size_in_use(void) const {
+ return _size_of_other_chunks + _current_chunk.used;
+}
+
+size_t memarena::total_footprint(void) const {
+ return sizeof(*this) + _footprint_of_other_chunks +
+ toku_memory_footprint(_current_chunk.buf, _current_chunk.used) +
+ _n_other_chunks * sizeof(*_other_chunks);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+const void *memarena::chunk_iterator::current(size_t *used) const {
+ if (_chunk_idx < 0) {
+ *used = _ma->_current_chunk.used;
+ return _ma->_current_chunk.buf;
+ } else if (_chunk_idx < _ma->_n_other_chunks) {
+ *used = _ma->_other_chunks[_chunk_idx].used;
+ return _ma->_other_chunks[_chunk_idx].buf;
+ }
+ *used = 0;
+ return nullptr;
+}
+
+void memarena::chunk_iterator::next() { _chunk_idx++; }
+
+bool memarena::chunk_iterator::more() const {
+ if (_chunk_idx < 0) {
+ return _ma->_current_chunk.buf != nullptr;
+ }
+ return _chunk_idx < _ma->_n_other_chunks;
+}
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.h
new file mode 100644
index 000000000..ddcc1144f
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/memarena.h
@@ -0,0 +1,141 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <stdlib.h>
+
+/*
+ * A memarena is used to efficiently store a collection of objects that never
+ * move The pattern is allocate more and more stuff and free all of the items at
+ * once. The underlying memory will store 1 or more objects per chunk. Each
+ * chunk is contiguously laid out in memory but chunks are not necessarily
+ * contiguous with each other.
+ */
+class memarena {
+ public:
+ memarena()
+ : _current_chunk(arena_chunk()),
+ _other_chunks(nullptr),
+ _n_other_chunks(0),
+ _size_of_other_chunks(0),
+ _footprint_of_other_chunks(0) {}
+
+ // Effect: Create a memarena with the specified initial size
+ void create(size_t initial_size);
+
+ void destroy(void);
+
+ // Effect: Allocate some memory. The returned value remains valid until the
+ // memarena is cleared or closed.
+ // In case of ENOMEM, aborts.
+ void *malloc_from_arena(size_t size);
+
+ // Effect: Move all the memory from this memarena into DEST.
+ // When SOURCE is closed the memory won't be freed.
+ // When DEST is closed, the memory will be freed, unless DEST moves
+ // its memory to another memarena...
+ void move_memory(memarena *dest);
+
+ // Effect: Calculate the amount of memory used by a memory arena.
+ size_t total_memory_size(void) const;
+
+ // Effect: Calculate the used space of the memory arena (ie: excludes unused
+ // space)
+ size_t total_size_in_use(void) const;
+
+ // Effect: Calculate the amount of memory used, according to
+ // toku_memory_footprint(),
+ // which is a more expensive but more accurate count of memory used.
+ size_t total_footprint(void) const;
+
+ // iterator over the underlying chunks that store objects in the memarena.
+ // a chunk is represented by a pointer to const memory and a usable byte
+ // count.
+ class chunk_iterator {
+ public:
+ chunk_iterator(const memarena *ma) : _ma(ma), _chunk_idx(-1) {}
+
+ // returns: base pointer to the current chunk
+ // *used set to the number of usable bytes
+ // if more() is false, returns nullptr and *used = 0
+ const void *current(size_t *used) const;
+
+ // requires: more() is true
+ void next();
+
+ bool more() const;
+
+ private:
+ // -1 represents the 'initial' chunk in a memarena, ie: ma->_current_chunk
+ // >= 0 represents the i'th chunk in the ma->_other_chunks array
+ const memarena *_ma;
+ int _chunk_idx;
+ };
+
+ private:
+ struct arena_chunk {
+ arena_chunk() : buf(nullptr), used(0), size(0) {}
+ char *buf;
+ size_t used;
+ size_t size;
+ };
+
+ struct arena_chunk _current_chunk;
+ struct arena_chunk *_other_chunks;
+ int _n_other_chunks;
+ size_t _size_of_other_chunks; // the buf_size of all the other chunks.
+ size_t _footprint_of_other_chunks; // the footprint of all the other chunks.
+
+ friend class memarena_unit_test;
+};
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt.h
new file mode 100644
index 000000000..f208002d3
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt.h
@@ -0,0 +1,794 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include <memory.h>
+#include <stdint.h>
+
+#include "../portability/toku_portability.h"
+#include "../portability/toku_race_tools.h"
+#include "growable_array.h"
+
+namespace toku {
+
+/**
+ * Order Maintenance Tree (OMT)
+ *
+ * Maintains a collection of totally ordered values, where each value has an
+ * integer weight. The OMT is a mutable datatype.
+ *
+ * The Abstraction:
+ *
+ * An OMT is a vector of values, $V$, where $|V|$ is the length of the vector.
+ * The vector is numbered from $0$ to $|V|-1$.
+ * Each value has a weight. The weight of the $i$th element is denoted
+ * $w(V_i)$.
+ *
+ * We can create a new OMT, which is the empty vector.
+ *
+ * We can insert a new element $x$ into slot $i$, changing $V$ into $V'$ where
+ * $|V'|=1+|V|$ and
+ *
+ * V'_j = V_j if $j<i$
+ * x if $j=i$
+ * V_{j-1} if $j>i$.
+ *
+ * We can specify $i$ using a kind of function instead of as an integer.
+ * Let $b$ be a function mapping from values to nonzero integers, such that
+ * the signum of $b$ is monotically increasing.
+ * We can specify $i$ as the minimum integer such that $b(V_i)>0$.
+ *
+ * We look up a value using its index, or using a Heaviside function.
+ * For lookups, we allow $b$ to be zero for some values, and again the signum of
+ * $b$ must be monotonically increasing. When lookup up values, we can look up
+ * $V_i$ where $i$ is the minimum integer such that $b(V_i)=0$. (With a
+ * special return code if no such value exists.) (Rationale: Ordinarily we want
+ * $i$ to be unique. But for various reasons we want to allow multiple zeros,
+ * and we want the smallest $i$ in that case.) $V_i$ where $i$ is the minimum
+ * integer such that $b(V_i)>0$. (Or an indication that no such value exists.)
+ * $V_i$ where $i$ is the maximum integer such that $b(V_i)<0$. (Or an
+ * indication that no such value exists.)
+ *
+ * When looking up a value using a Heaviside function, we get the value and its
+ * index.
+ *
+ * We can also split an OMT into two OMTs, splitting the weight of the values
+ * evenly. Find a value $j$ such that the values to the left of $j$ have about
+ * the same total weight as the values to the right of $j$. The resulting two
+ * OMTs contain the values to the left of $j$ and the values to the right of $j$
+ * respectively. All of the values from the original OMT go into one of the new
+ * OMTs. If the weights of the values don't split exactly evenly, then the
+ * implementation has the freedom to choose whether the new left OMT or the new
+ * right OMT is larger.
+ *
+ * Performance:
+ * Insertion and deletion should run with $O(\log |V|)$ time and $O(\log |V|)$
+ * calls to the Heaviside function. The memory required is O(|V|).
+ *
+ * Usage:
+ * The omt is templated by two parameters:
+ * - omtdata_t is what will be stored within the omt. These could be pointers
+ * or real data types (ints, structs).
+ * - omtdataout_t is what will be returned by find and related functions. By
+ * default, it is the same as omtdata_t, but you can set it to (omtdata_t *). To
+ * create an omt which will store "TXNID"s, for example, it is a good idea to
+ * typedef the template: typedef omt<TXNID> txnid_omt_t; If you are storing
+ * structs, you may want to be able to get a pointer to the data actually stored
+ * in the omt (see find_zero). To do this, use the second template parameter:
+ * typedef omt<struct foo, struct foo *> foo_omt_t;
+ */
+
+namespace omt_internal {
+
+template <bool subtree_supports_marks>
+class subtree_templated {
+ private:
+ uint32_t m_index;
+
+ public:
+ static const uint32_t NODE_NULL = UINT32_MAX;
+ inline void set_to_null(void) { m_index = NODE_NULL; }
+
+ inline bool is_null(void) const { return NODE_NULL == this->get_index(); }
+
+ inline uint32_t get_index(void) const { return m_index; }
+
+ inline void set_index(uint32_t index) {
+ paranoid_invariant(index != NODE_NULL);
+ m_index = index;
+ }
+} __attribute__((__packed__, aligned(4)));
+
+template <>
+class subtree_templated<true> {
+ private:
+ uint32_t m_bitfield;
+ static const uint32_t MASK_INDEX = ~(((uint32_t)1) << 31);
+ static const uint32_t MASK_BIT = ((uint32_t)1) << 31;
+
+ inline void set_index_internal(uint32_t new_index) {
+ m_bitfield = (m_bitfield & MASK_BIT) | new_index;
+ }
+
+ public:
+ static const uint32_t NODE_NULL = INT32_MAX;
+ inline void set_to_null(void) { this->set_index_internal(NODE_NULL); }
+
+ inline bool is_null(void) const { return NODE_NULL == this->get_index(); }
+
+ inline uint32_t get_index(void) const {
+ TOKU_DRD_IGNORE_VAR(m_bitfield);
+ const uint32_t bits = m_bitfield;
+ TOKU_DRD_STOP_IGNORING_VAR(m_bitfield);
+ return bits & MASK_INDEX;
+ }
+
+ inline void set_index(uint32_t index) {
+ paranoid_invariant(index < NODE_NULL);
+ this->set_index_internal(index);
+ }
+
+ inline bool get_bit(void) const {
+ TOKU_DRD_IGNORE_VAR(m_bitfield);
+ const uint32_t bits = m_bitfield;
+ TOKU_DRD_STOP_IGNORING_VAR(m_bitfield);
+ return (bits & MASK_BIT) != 0;
+ }
+
+ inline void enable_bit(void) {
+ // These bits may be set by a thread with a write lock on some
+ // leaf, and the index can be read by another thread with a (read
+ // or write) lock on another thread. Also, the has_marks_below
+ // bit can be set by two threads simultaneously. Neither of these
+ // are real races, so if we are using DRD we should tell it to
+ // ignore these bits just while we set this bit. If there were a
+ // race in setting the index, that would be a real race.
+ TOKU_DRD_IGNORE_VAR(m_bitfield);
+ m_bitfield |= MASK_BIT;
+ TOKU_DRD_STOP_IGNORING_VAR(m_bitfield);
+ }
+
+ inline void disable_bit(void) { m_bitfield &= MASK_INDEX; }
+} __attribute__((__packed__));
+
+template <typename omtdata_t, bool subtree_supports_marks>
+class omt_node_templated {
+ public:
+ omtdata_t value;
+ uint32_t weight;
+ subtree_templated<subtree_supports_marks> left;
+ subtree_templated<subtree_supports_marks> right;
+
+ // this needs to be in both implementations because we don't have
+ // a "static if" the caller can use
+ inline void clear_stolen_bits(void) {}
+}; // note: originally this class had __attribute__((__packed__, aligned(4)))
+
+template <typename omtdata_t>
+class omt_node_templated<omtdata_t, true> {
+ public:
+ omtdata_t value;
+ uint32_t weight;
+ subtree_templated<true> left;
+ subtree_templated<true> right;
+ inline bool get_marked(void) const { return left.get_bit(); }
+ inline void set_marked_bit(void) { return left.enable_bit(); }
+ inline void unset_marked_bit(void) { return left.disable_bit(); }
+
+ inline bool get_marks_below(void) const { return right.get_bit(); }
+ inline void set_marks_below_bit(void) {
+ // This function can be called by multiple threads.
+ // Checking first reduces cache invalidation.
+ if (!this->get_marks_below()) {
+ right.enable_bit();
+ }
+ }
+ inline void unset_marks_below_bit(void) { right.disable_bit(); }
+
+ inline void clear_stolen_bits(void) {
+ this->unset_marked_bit();
+ this->unset_marks_below_bit();
+ }
+}; // note: originally this class had __attribute__((__packed__, aligned(4)))
+
+} // namespace omt_internal
+
+template <typename omtdata_t, typename omtdataout_t = omtdata_t,
+ bool supports_marks = false>
+class omt {
+ public:
+ /**
+ * Effect: Create an empty OMT.
+ * Performance: constant time.
+ */
+ void create(void);
+
+ /**
+ * Effect: Create an empty OMT with no internal allocated space.
+ * Performance: constant time.
+ * Rationale: In some cases we need a valid omt but don't want to malloc.
+ */
+ void create_no_array(void);
+
+ /**
+ * Effect: Create a OMT containing values. The number of values is in
+ * numvalues. Stores the new OMT in *omtp. Requires: this has not been created
+ * yet Requires: values != NULL Requires: values is sorted Performance:
+ * time=O(numvalues) Rationale: Normally to insert N values takes O(N lg N)
+ * amortized time. If the N values are known in advance, are sorted, and the
+ * structure is empty, we can batch insert them much faster.
+ */
+ __attribute__((nonnull)) void create_from_sorted_array(
+ const omtdata_t *const values, const uint32_t numvalues);
+
+ /**
+ * Effect: Create an OMT containing values. The number of values is in
+ * numvalues. On success the OMT takes ownership of *values array, and sets
+ * values=NULL. Requires: this has not been created yet Requires: values !=
+ * NULL Requires: *values is sorted Requires: *values was allocated with
+ * toku_malloc Requires: Capacity of the *values array is <= new_capacity
+ * Requires: On success, *values may not be accessed again by the caller.
+ * Performance: time=O(1)
+ * Rational: create_from_sorted_array takes O(numvalues) time.
+ * By taking ownership of the array, we save a malloc and
+ * memcpy, and possibly a free (if the caller is done with the array).
+ */
+ void create_steal_sorted_array(omtdata_t **const values,
+ const uint32_t numvalues,
+ const uint32_t new_capacity);
+
+ /**
+ * Effect: Create a new OMT, storing it in *newomt.
+ * The values to the right of index (starting at index) are moved to *newomt.
+ * Requires: newomt != NULL
+ * Returns
+ * 0 success,
+ * EINVAL if index > toku_omt_size(omt)
+ * On nonzero return, omt and *newomt are unmodified.
+ * Performance: time=O(n)
+ * Rationale: We don't need a split-evenly operation. We need to split items
+ * so that their total sizes are even, and other similar splitting criteria.
+ * It's easy to split evenly by calling size(), and dividing by two.
+ */
+ __attribute__((nonnull)) int split_at(omt *const newomt, const uint32_t idx);
+
+ /**
+ * Effect: Appends leftomt and rightomt to produce a new omt.
+ * Creates this as the new omt.
+ * leftomt and rightomt are destroyed.
+ * Performance: time=O(n) is acceptable, but one can imagine implementations
+ * that are O(\log n) worst-case.
+ */
+ __attribute__((nonnull)) void merge(omt *const leftomt, omt *const rightomt);
+
+ /**
+ * Effect: Creates a copy of an omt.
+ * Creates this as the clone.
+ * Each element is copied directly. If they are pointers, the underlying
+ * data is not duplicated. Performance: O(n) or the running time of
+ * fill_array_with_subtree_values()
+ */
+ void clone(const omt &src);
+
+ /**
+ * Effect: Set the tree to be empty.
+ * Note: Will not reallocate or resize any memory.
+ * Performance: time=O(1)
+ */
+ void clear(void);
+
+ /**
+ * Effect: Destroy an OMT, freeing all its memory.
+ * If the values being stored are pointers, their underlying data is not
+ * freed. See free_items() Those values may be freed before or after calling
+ * toku_omt_destroy. Rationale: Returns no values since free() cannot fail.
+ * Rationale: Does not free the underlying pointers to reduce complexity.
+ * Performance: time=O(1)
+ */
+ void destroy(void);
+
+ /**
+ * Effect: return |this|.
+ * Performance: time=O(1)
+ */
+ uint32_t size(void) const;
+
+ /**
+ * Effect: Insert value into the OMT.
+ * If there is some i such that $h(V_i, v)=0$ then returns DB_KEYEXIST.
+ * Otherwise, let i be the minimum value such that $h(V_i, v)>0$.
+ * If no such i exists, then let i be |V|
+ * Then this has the same effect as
+ * insert_at(tree, value, i);
+ * If idx!=NULL then i is stored in *idx
+ * Requires: The signum of h must be monotonically increasing.
+ * Returns:
+ * 0 success
+ * DB_KEYEXIST the key is present (h was equal to zero for some value)
+ * On nonzero return, omt is unchanged.
+ * Performance: time=O(\log N) amortized.
+ * Rationale: Some future implementation may be O(\log N) worst-case time, but
+ * O(\log N) amortized is good enough for now.
+ */
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int insert(const omtdata_t &value, const omtcmp_t &v, uint32_t *const idx);
+
+ /**
+ * Effect: Increases indexes of all items at slot >= idx by 1.
+ * Insert value into the position at idx.
+ * Returns:
+ * 0 success
+ * EINVAL if idx > this->size()
+ * On error, omt is unchanged.
+ * Performance: time=O(\log N) amortized time.
+ * Rationale: Some future implementation may be O(\log N) worst-case time, but
+ * O(\log N) amortized is good enough for now.
+ */
+ int insert_at(const omtdata_t &value, const uint32_t idx);
+
+ /**
+ * Effect: Replaces the item at idx with value.
+ * Returns:
+ * 0 success
+ * EINVAL if idx>=this->size()
+ * On error, omt is unchanged.
+ * Performance: time=O(\log N)
+ * Rationale: The FT needs to be able to replace a value with another copy of
+ * the same value (allocated in a different location)
+ *
+ */
+ int set_at(const omtdata_t &value, const uint32_t idx);
+
+ /**
+ * Effect: Delete the item in slot idx.
+ * Decreases indexes of all items at slot > idx by 1.
+ * Returns
+ * 0 success
+ * EINVAL if idx>=this->size()
+ * On error, omt is unchanged.
+ * Rationale: To delete an item, first find its index using find or find_zero,
+ * then delete it. Performance: time=O(\log N) amortized.
+ */
+ int delete_at(const uint32_t idx);
+
+ /**
+ * Effect: Iterate over the values of the omt, from left to right, calling f
+ * on each value. The first argument passed to f is a ref-to-const of the
+ * value stored in the omt. The second argument passed to f is the index of
+ * the value. The third argument passed to f is iterate_extra. The indices run
+ * from 0 (inclusive) to this->size() (exclusive). Requires: f != NULL
+ * Returns:
+ * If f ever returns nonzero, then the iteration stops, and the value
+ * returned by f is returned by iterate. If f always returns zero, then
+ * iterate returns 0. Requires: Don't modify the omt while running. (E.g., f
+ * may not insert or delete values from the omt.) Performance: time=O(i+\log
+ * N) where i is the number of times f is called, and N is the number of
+ * elements in the omt. Rationale: Although the functional iterator requires
+ * defining another function (as opposed to C++ style iterator), it is much
+ * easier to read. Rationale: We may at some point use functors, but for now
+ * this is a smaller change from the old OMT.
+ */
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate(iterate_extra_t *const iterate_extra) const;
+
+ /**
+ * Effect: Iterate over the values of the omt, from left to right, calling f
+ * on each value. The first argument passed to f is a ref-to-const of the
+ * value stored in the omt. The second argument passed to f is the index of
+ * the value. The third argument passed to f is iterate_extra. The indices run
+ * from 0 (inclusive) to this->size() (exclusive). We will iterate only over
+ * [left,right)
+ *
+ * Requires: left <= right
+ * Requires: f != NULL
+ * Returns:
+ * EINVAL if right > this->size()
+ * If f ever returns nonzero, then the iteration stops, and the value
+ * returned by f is returned by iterate_on_range. If f always returns zero,
+ * then iterate_on_range returns 0. Requires: Don't modify the omt while
+ * running. (E.g., f may not insert or delete values from the omt.)
+ * Performance: time=O(i+\log N) where i is the number of times f is called,
+ * and N is the number of elements in the omt. Rational: Although the
+ * functional iterator requires defining another function (as opposed to C++
+ * style iterator), it is much easier to read.
+ */
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_on_range(const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) const;
+
+ /**
+ * Effect: Iterate over the values of the omt, and mark the nodes that are
+ * visited. Other than the marks, this behaves the same as iterate_on_range.
+ * Requires: supports_marks == true
+ * Performance: time=O(i+\log N) where i is the number of times f is called,
+ * and N is the number of elements in the omt. Notes: This function MAY be
+ * called concurrently by multiple threads, but not concurrently with any
+ * other non-const function.
+ */
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_and_mark_range(const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra);
+
+ /**
+ * Effect: Iterate over the values of the omt, from left to right, calling f
+ * on each value whose node has been marked. Other than the marks, this
+ * behaves the same as iterate. Requires: supports_marks == true Performance:
+ * time=O(i+\log N) where i is the number of times f is called, and N is the
+ * number of elements in the omt.
+ */
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_over_marked(iterate_extra_t *const iterate_extra) const;
+
+ /**
+ * Effect: Delete all elements from the omt, whose nodes have been marked.
+ * Requires: supports_marks == true
+ * Performance: time=O(N + i\log N) where i is the number of marked elements,
+ * {c,sh}ould be faster
+ */
+ void delete_all_marked(void);
+
+ /**
+ * Effect: Verify that the internal state of the marks in the tree are
+ * self-consistent. Crashes the system if the marks are in a bad state.
+ * Requires: supports_marks == true
+ * Performance: time=O(N)
+ * Notes:
+ * Even though this is a const function, it requires exclusive access.
+ * Rationale:
+ * The current implementation of the marks relies on a sort of
+ * "cache" bit representing the state of bits below it in the tree.
+ * This allows glass-box testing that these bits are correct.
+ */
+ void verify_marks_consistent(void) const;
+
+ /**
+ * Effect: None
+ * Returns whether there are any marks in the tree.
+ */
+ bool has_marks(void) const;
+
+ /**
+ * Effect: Iterate over the values of the omt, from left to right, calling f
+ * on each value. The first argument passed to f is a pointer to the value
+ * stored in the omt. The second argument passed to f is the index of the
+ * value. The third argument passed to f is iterate_extra. The indices run
+ * from 0 (inclusive) to this->size() (exclusive). Requires: same as for
+ * iterate() Returns: same as for iterate() Performance: same as for iterate()
+ * Rationale: In general, most iterators should use iterate() since they
+ * should not modify the data stored in the omt. This function is for
+ * iterators which need to modify values (for example, free_items). Rationale:
+ * We assume if you are transforming the data in place, you want to do it to
+ * everything at once, so there is not yet an iterate_on_range_ptr (but there
+ * could be).
+ */
+ template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+ void iterate_ptr(iterate_extra_t *const iterate_extra);
+
+ /**
+ * Effect: Set *value=V_idx
+ * Returns
+ * 0 success
+ * EINVAL if index>=toku_omt_size(omt)
+ * On nonzero return, *value is unchanged
+ * Performance: time=O(\log N)
+ */
+ int fetch(const uint32_t idx, omtdataout_t *const value) const;
+
+ /**
+ * Effect: Find the smallest i such that h(V_i, extra)>=0
+ * If there is such an i and h(V_i,extra)==0 then set *idxp=i, set *value =
+ * V_i, and return 0. If there is such an i and h(V_i,extra)>0 then set
+ * *idxp=i and return DB_NOTFOUND. If there is no such i then set
+ * *idx=this->size() and return DB_NOTFOUND. Note: value is of type
+ * omtdataout_t, which may be of type (omtdata_t) or (omtdata_t *) but is
+ * fixed by the instantiation. If it is the value type, then the value is
+ * copied out (even if the value type is a pointer to something else) If it is
+ * the pointer type, then *value is set to a pointer to the data within the
+ * omt. This is determined by the type of the omt as initially declared. If
+ * the omt is declared as omt<foo_t>, then foo_t's will be stored and foo_t's
+ * will be returned by find and related functions. If the omt is declared as
+ * omt<foo_t, foo_t *>, then foo_t's will be stored, and pointers to the
+ * stored items will be returned by find and related functions. Rationale:
+ * Structs too small for malloc should be stored directly in the omt.
+ * These structs may need to be edited as they exist inside the omt, so we
+ * need a way to get a pointer within the omt. Using separate functions for
+ * returning pointers and values increases code duplication and reduces
+ * type-checking. That also reduces the ability of the creator of a data
+ * structure to give advice to its future users. Slight overloading in this
+ * case seemed to provide a better API and better type checking.
+ */
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_zero(const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const;
+
+ /**
+ * Effect:
+ * If direction >0 then find the smallest i such that h(V_i,extra)>0.
+ * If direction <0 then find the largest i such that h(V_i,extra)<0.
+ * (Direction may not be equal to zero.)
+ * If value!=NULL then store V_i in *value
+ * If idxp!=NULL then store i in *idxp.
+ * Requires: The signum of h is monotically increasing.
+ * Returns
+ * 0 success
+ * DB_NOTFOUND no such value is found.
+ * On nonzero return, *value and *idxp are unchanged
+ * Performance: time=O(\log N)
+ * Rationale:
+ * Here's how to use the find function to find various things
+ * Cases for find:
+ * find first value: ( h(v)=+1, direction=+1 )
+ * find last value ( h(v)=-1, direction=-1 )
+ * find first X ( h(v)=(v< x) ? -1 : 1 direction=+1 )
+ * find last X ( h(v)=(v<=x) ? -1 : 1 direction=-1 )
+ * find X or successor to X ( same as find first X. )
+ *
+ * Rationale: To help understand heaviside functions and behavor of find:
+ * There are 7 kinds of heaviside functions.
+ * The signus of the h must be monotonically increasing.
+ * Given a function of the following form, A is the element
+ * returned for direction>0, B is the element returned
+ * for direction<0, C is the element returned for
+ * direction==0 (see find_zero) (with a return of 0), and D is the element
+ * returned for direction==0 (see find_zero) with a return of DB_NOTFOUND.
+ * If any of A, B, or C are not found, then asking for the
+ * associated direction will return DB_NOTFOUND.
+ * See find_zero for more information.
+ *
+ * Let the following represent the signus of the heaviside function.
+ *
+ * -...-
+ * A
+ * D
+ *
+ * +...+
+ * B
+ * D
+ *
+ * 0...0
+ * C
+ *
+ * -...-0...0
+ * AC
+ *
+ * 0...0+...+
+ * C B
+ *
+ * -...-+...+
+ * AB
+ * D
+ *
+ * -...-0...0+...+
+ * AC B
+ */
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find(const omtcmp_t &extra, int direction, omtdataout_t *const value,
+ uint32_t *const idxp) const;
+
+ /**
+ * Effect: Return the size (in bytes) of the omt, as it resides in main
+ * memory. If the data stored are pointers, don't include the size of what
+ * they all point to.
+ */
+ size_t memory_size(void);
+
+ private:
+ typedef uint32_t node_idx;
+ typedef omt_internal::subtree_templated<supports_marks> subtree;
+ typedef omt_internal::omt_node_templated<omtdata_t, supports_marks> omt_node;
+ ENSURE_POD(subtree);
+
+ struct omt_array {
+ uint32_t start_idx;
+ uint32_t num_values;
+ omtdata_t *values;
+ };
+
+ struct omt_tree {
+ subtree root;
+ uint32_t free_idx;
+ omt_node *nodes;
+ };
+
+ bool is_array;
+ uint32_t capacity;
+ union {
+ struct omt_array a;
+ struct omt_tree t;
+ } d;
+
+ __attribute__((nonnull)) void unmark(const subtree &subtree,
+ const uint32_t index,
+ GrowableArray<node_idx> *const indexes);
+
+ void create_internal_no_array(const uint32_t new_capacity);
+
+ void create_internal(const uint32_t new_capacity);
+
+ uint32_t nweight(const subtree &subtree) const;
+
+ node_idx node_malloc(void);
+
+ void node_free(const node_idx idx);
+
+ void maybe_resize_array(const uint32_t n);
+
+ __attribute__((nonnull)) void fill_array_with_subtree_values(
+ omtdata_t *const array, const subtree &subtree) const;
+
+ void convert_to_array(void);
+
+ __attribute__((nonnull)) void rebuild_from_sorted_array(
+ subtree *const subtree, const omtdata_t *const values,
+ const uint32_t numvalues);
+
+ void convert_to_tree(void);
+
+ void maybe_resize_or_convert(const uint32_t n);
+
+ bool will_need_rebalance(const subtree &subtree, const int leftmod,
+ const int rightmod) const;
+
+ __attribute__((nonnull)) void insert_internal(
+ subtree *const subtreep, const omtdata_t &value, const uint32_t idx,
+ subtree **const rebalance_subtree);
+
+ void set_at_internal_array(const omtdata_t &value, const uint32_t idx);
+
+ void set_at_internal(const subtree &subtree, const omtdata_t &value,
+ const uint32_t idx);
+
+ void delete_internal(subtree *const subtreep, const uint32_t idx,
+ omt_node *const copyn,
+ subtree **const rebalance_subtree);
+
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_internal_array(const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) const;
+
+ template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+ void iterate_ptr_internal(const uint32_t left, const uint32_t right,
+ const subtree &subtree, const uint32_t idx,
+ iterate_extra_t *const iterate_extra);
+
+ template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+ void iterate_ptr_internal_array(const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra);
+
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_internal(const uint32_t left, const uint32_t right,
+ const subtree &subtree, const uint32_t idx,
+ iterate_extra_t *const iterate_extra) const;
+
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_and_mark_range_internal(const uint32_t left, const uint32_t right,
+ const subtree &subtree,
+ const uint32_t idx,
+ iterate_extra_t *const iterate_extra);
+
+ template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+ int iterate_over_marked_internal(const subtree &subtree, const uint32_t idx,
+ iterate_extra_t *const iterate_extra) const;
+
+ uint32_t verify_marks_consistent_internal(const subtree &subtree,
+ const bool allow_marks) const;
+
+ void fetch_internal_array(const uint32_t i, omtdataout_t *const value) const;
+
+ void fetch_internal(const subtree &subtree, const uint32_t i,
+ omtdataout_t *const value) const;
+
+ __attribute__((nonnull)) void fill_array_with_subtree_idxs(
+ node_idx *const array, const subtree &subtree) const;
+
+ __attribute__((nonnull)) void rebuild_subtree_from_idxs(
+ subtree *const subtree, const node_idx *const idxs,
+ const uint32_t numvalues);
+
+ __attribute__((nonnull)) void rebalance(subtree *const subtree);
+
+ __attribute__((nonnull)) static void copyout(omtdata_t *const out,
+ const omt_node *const n);
+
+ __attribute__((nonnull)) static void copyout(omtdata_t **const out,
+ omt_node *const n);
+
+ __attribute__((nonnull)) static void copyout(
+ omtdata_t *const out, const omtdata_t *const stored_value_ptr);
+
+ __attribute__((nonnull)) static void copyout(
+ omtdata_t **const out, omtdata_t *const stored_value_ptr);
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_zero_array(const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const;
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_zero(const subtree &subtree, const omtcmp_t &extra,
+ omtdataout_t *const value, uint32_t *const idxp) const;
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_plus_array(const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const;
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_plus(const subtree &subtree, const omtcmp_t &extra,
+ omtdataout_t *const value, uint32_t *const idxp) const;
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_minus_array(const omtcmp_t &extra,
+ omtdataout_t *const value,
+ uint32_t *const idxp) const;
+
+ template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+ int find_internal_minus(const subtree &subtree, const omtcmp_t &extra,
+ omtdataout_t *const value,
+ uint32_t *const idxp) const;
+};
+
+} // namespace toku
+
+// include the implementation here
+#include "omt_impl.h"
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt_impl.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt_impl.h
new file mode 100644
index 000000000..e77986716
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/omt_impl.h
@@ -0,0 +1,1295 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#include <string.h>
+
+#include "../db.h"
+#include "../portability/memory.h"
+
+namespace toku {
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create(void) {
+ this->create_internal(2);
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create_no_array(void) {
+ if (!supports_marks) {
+ this->create_internal_no_array(0);
+ } else {
+ this->is_array = false;
+ this->capacity = 0;
+ this->d.t.nodes = nullptr;
+ this->d.t.root.set_to_null();
+ this->d.t.free_idx = 0;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create_from_sorted_array(
+ const omtdata_t *const values, const uint32_t numvalues) {
+ this->create_internal(numvalues);
+ memcpy(this->d.a.values, values, numvalues * (sizeof values[0]));
+ this->d.a.num_values = numvalues;
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create_steal_sorted_array(
+ omtdata_t **const values, const uint32_t numvalues,
+ const uint32_t new_capacity) {
+ paranoid_invariant_notnull(values);
+ this->create_internal_no_array(new_capacity);
+ this->d.a.num_values = numvalues;
+ this->d.a.values = *values;
+ *values = nullptr;
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+int omt<omtdata_t, omtdataout_t, supports_marks>::split_at(omt *const newomt,
+ const uint32_t idx) {
+ barf_if_marked(*this);
+ paranoid_invariant_notnull(newomt);
+ if (idx > this->size()) {
+ return EINVAL;
+ }
+ this->convert_to_array();
+ const uint32_t newsize = this->size() - idx;
+ newomt->create_from_sorted_array(&this->d.a.values[this->d.a.start_idx + idx],
+ newsize);
+ this->d.a.num_values = idx;
+ this->maybe_resize_array(idx);
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::merge(omt *const leftomt,
+ omt *const rightomt) {
+ barf_if_marked(*this);
+ paranoid_invariant_notnull(leftomt);
+ paranoid_invariant_notnull(rightomt);
+ const uint32_t leftsize = leftomt->size();
+ const uint32_t rightsize = rightomt->size();
+ const uint32_t newsize = leftsize + rightsize;
+
+ if (leftomt->is_array) {
+ if (leftomt->capacity -
+ (leftomt->d.a.start_idx + leftomt->d.a.num_values) >=
+ rightsize) {
+ this->create_steal_sorted_array(
+ &leftomt->d.a.values, leftomt->d.a.num_values, leftomt->capacity);
+ this->d.a.start_idx = leftomt->d.a.start_idx;
+ } else {
+ this->create_internal(newsize);
+ memcpy(&this->d.a.values[0], &leftomt->d.a.values[leftomt->d.a.start_idx],
+ leftomt->d.a.num_values * (sizeof this->d.a.values[0]));
+ }
+ } else {
+ this->create_internal(newsize);
+ leftomt->fill_array_with_subtree_values(&this->d.a.values[0],
+ leftomt->d.t.root);
+ }
+ leftomt->destroy();
+ this->d.a.num_values = leftsize;
+
+ if (rightomt->is_array) {
+ memcpy(&this->d.a.values[this->d.a.start_idx + this->d.a.num_values],
+ &rightomt->d.a.values[rightomt->d.a.start_idx],
+ rightomt->d.a.num_values * (sizeof this->d.a.values[0]));
+ } else {
+ rightomt->fill_array_with_subtree_values(
+ &this->d.a.values[this->d.a.start_idx + this->d.a.num_values],
+ rightomt->d.t.root);
+ }
+ rightomt->destroy();
+ this->d.a.num_values += rightsize;
+ paranoid_invariant(this->size() == newsize);
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::clone(const omt &src) {
+ barf_if_marked(*this);
+ this->create_internal(src.size());
+ if (src.is_array) {
+ memcpy(&this->d.a.values[0], &src.d.a.values[src.d.a.start_idx],
+ src.d.a.num_values * (sizeof this->d.a.values[0]));
+ } else {
+ src.fill_array_with_subtree_values(&this->d.a.values[0], src.d.t.root);
+ }
+ this->d.a.num_values = src.size();
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::clear(void) {
+ if (this->is_array) {
+ this->d.a.start_idx = 0;
+ this->d.a.num_values = 0;
+ } else {
+ this->d.t.root.set_to_null();
+ this->d.t.free_idx = 0;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::destroy(void) {
+ this->clear();
+ this->capacity = 0;
+ if (this->is_array) {
+ if (this->d.a.values != nullptr) {
+ toku_free(this->d.a.values);
+ }
+ this->d.a.values = nullptr;
+ } else {
+ if (this->d.t.nodes != nullptr) {
+ toku_free(this->d.t.nodes);
+ }
+ this->d.t.nodes = nullptr;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+uint32_t omt<omtdata_t, omtdataout_t, supports_marks>::size(void) const {
+ if (this->is_array) {
+ return this->d.a.num_values;
+ } else {
+ return this->nweight(this->d.t.root);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::insert(const omtdata_t &value,
+ const omtcmp_t &v,
+ uint32_t *const idx) {
+ int r;
+ uint32_t insert_idx;
+
+ r = this->find_zero<omtcmp_t, h>(v, nullptr, &insert_idx);
+ if (r == 0) {
+ if (idx) *idx = insert_idx;
+ return DB_KEYEXIST;
+ }
+ if (r != DB_NOTFOUND) return r;
+
+ if ((r = this->insert_at(value, insert_idx))) return r;
+ if (idx) *idx = insert_idx;
+
+ return 0;
+}
+
+// The following 3 functions implement a static if for us.
+template <typename omtdata_t, typename omtdataout_t>
+static void barf_if_marked(const omt<omtdata_t, omtdataout_t, false> &UU(omt)) {
+}
+
+template <typename omtdata_t, typename omtdataout_t>
+static void barf_if_marked(const omt<omtdata_t, omtdataout_t, true> &omt) {
+ invariant(!omt.has_marks());
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+bool omt<omtdata_t, omtdataout_t, supports_marks>::has_marks(void) const {
+ static_assert(supports_marks, "Does not support marks");
+ if (this->d.t.root.is_null()) {
+ return false;
+ }
+ const omt_node &node = this->d.t.nodes[this->d.t.root.get_index()];
+ return node.get_marks_below() || node.get_marked();
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+int omt<omtdata_t, omtdataout_t, supports_marks>::insert_at(
+ const omtdata_t &value, const uint32_t idx) {
+ barf_if_marked(*this);
+ if (idx > this->size()) {
+ return EINVAL;
+ }
+
+ this->maybe_resize_or_convert(this->size() + 1);
+ if (this->is_array && idx != this->d.a.num_values &&
+ (idx != 0 || this->d.a.start_idx == 0)) {
+ this->convert_to_tree();
+ }
+ if (this->is_array) {
+ if (idx == this->d.a.num_values) {
+ this->d.a.values[this->d.a.start_idx + this->d.a.num_values] = value;
+ } else {
+ this->d.a.values[--this->d.a.start_idx] = value;
+ }
+ this->d.a.num_values++;
+ } else {
+ subtree *rebalance_subtree = nullptr;
+ this->insert_internal(&this->d.t.root, value, idx, &rebalance_subtree);
+ if (rebalance_subtree != nullptr) {
+ this->rebalance(rebalance_subtree);
+ }
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+int omt<omtdata_t, omtdataout_t, supports_marks>::set_at(const omtdata_t &value,
+ const uint32_t idx) {
+ barf_if_marked(*this);
+ if (idx >= this->size()) {
+ return EINVAL;
+ }
+
+ if (this->is_array) {
+ this->set_at_internal_array(value, idx);
+ } else {
+ this->set_at_internal(this->d.t.root, value, idx);
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+int omt<omtdata_t, omtdataout_t, supports_marks>::delete_at(
+ const uint32_t idx) {
+ barf_if_marked(*this);
+ if (idx >= this->size()) {
+ return EINVAL;
+ }
+
+ this->maybe_resize_or_convert(this->size() - 1);
+ if (this->is_array && idx != 0 && idx != this->d.a.num_values - 1) {
+ this->convert_to_tree();
+ }
+ if (this->is_array) {
+ // Testing for 0 does not rule out it being the last entry.
+ // Test explicitly for num_values-1
+ if (idx != this->d.a.num_values - 1) {
+ this->d.a.start_idx++;
+ }
+ this->d.a.num_values--;
+ } else {
+ subtree *rebalance_subtree = nullptr;
+ this->delete_internal(&this->d.t.root, idx, nullptr, &rebalance_subtree);
+ if (rebalance_subtree != nullptr) {
+ this->rebalance(rebalance_subtree);
+ }
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate(
+ iterate_extra_t *const iterate_extra) const {
+ return this->iterate_on_range<iterate_extra_t, f>(0, this->size(),
+ iterate_extra);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_on_range(
+ const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) const {
+ if (right > this->size()) {
+ return EINVAL;
+ }
+ if (left == right) {
+ return 0;
+ }
+ if (this->is_array) {
+ return this->iterate_internal_array<iterate_extra_t, f>(left, right,
+ iterate_extra);
+ }
+ return this->iterate_internal<iterate_extra_t, f>(left, right, this->d.t.root,
+ 0, iterate_extra);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_and_mark_range(
+ const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) {
+ static_assert(supports_marks, "does not support marks");
+ if (right > this->size()) {
+ return EINVAL;
+ }
+ if (left == right) {
+ return 0;
+ }
+ paranoid_invariant(!this->is_array);
+ return this->iterate_and_mark_range_internal<iterate_extra_t, f>(
+ left, right, this->d.t.root, 0, iterate_extra);
+}
+
+// TODO: We can optimize this if we steal 3 bits. 1 bit: this node is
+// marked. 1 bit: left subtree has marks. 1 bit: right subtree has marks.
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_over_marked(
+ iterate_extra_t *const iterate_extra) const {
+ static_assert(supports_marks, "does not support marks");
+ paranoid_invariant(!this->is_array);
+ return this->iterate_over_marked_internal<iterate_extra_t, f>(
+ this->d.t.root, 0, iterate_extra);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::unmark(
+ const subtree &st, const uint32_t index,
+ GrowableArray<node_idx> *const indexes) {
+ if (st.is_null()) {
+ return;
+ }
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t index_root = index + this->nweight(n.left);
+
+ const bool below = n.get_marks_below();
+ if (below) {
+ this->unmark(n.left, index, indexes);
+ }
+ if (n.get_marked()) {
+ indexes->push(index_root);
+ }
+ n.clear_stolen_bits();
+ if (below) {
+ this->unmark(n.right, index_root + 1, indexes);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::delete_all_marked(void) {
+ static_assert(supports_marks, "does not support marks");
+ if (!this->has_marks()) {
+ return;
+ }
+ paranoid_invariant(!this->is_array);
+ GrowableArray<node_idx> marked_indexes;
+ marked_indexes.init();
+
+ // Remove all marks.
+ // We need to delete all the stolen bits before calling delete_at to
+ // prevent barfing.
+ this->unmark(this->d.t.root, 0, &marked_indexes);
+
+ for (uint32_t i = 0; i < marked_indexes.get_size(); i++) {
+ // Delete from left to right, shift by number already deleted.
+ // Alternative is delete from right to left.
+ int r = this->delete_at(marked_indexes.fetch_unchecked(i) - i);
+ lazy_assert_zero(r);
+ }
+ marked_indexes.deinit();
+ barf_if_marked(*this);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+uint32_t
+omt<omtdata_t, omtdataout_t, supports_marks>::verify_marks_consistent_internal(
+ const subtree &st, const bool UU(allow_marks)) const {
+ if (st.is_null()) {
+ return 0;
+ }
+ const omt_node &node = this->d.t.nodes[st.get_index()];
+ uint32_t num_marks =
+ verify_marks_consistent_internal(node.left, node.get_marks_below());
+ num_marks +=
+ verify_marks_consistent_internal(node.right, node.get_marks_below());
+ if (node.get_marks_below()) {
+ paranoid_invariant(allow_marks);
+ paranoid_invariant(num_marks > 0);
+ } else {
+ // redundant with invariant below, but nice to have explicitly
+ paranoid_invariant(num_marks == 0);
+ }
+ if (node.get_marked()) {
+ paranoid_invariant(allow_marks);
+ ++num_marks;
+ }
+ return num_marks;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::verify_marks_consistent(
+ void) const {
+ static_assert(supports_marks, "does not support marks");
+ paranoid_invariant(!this->is_array);
+ this->verify_marks_consistent_internal(this->d.t.root, true);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+void omt<omtdata_t, omtdataout_t, supports_marks>::iterate_ptr(
+ iterate_extra_t *const iterate_extra) {
+ if (this->is_array) {
+ this->iterate_ptr_internal_array<iterate_extra_t, f>(0, this->size(),
+ iterate_extra);
+ } else {
+ this->iterate_ptr_internal<iterate_extra_t, f>(
+ 0, this->size(), this->d.t.root, 0, iterate_extra);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+int omt<omtdata_t, omtdataout_t, supports_marks>::fetch(
+ const uint32_t idx, omtdataout_t *const value) const {
+ if (idx >= this->size()) {
+ return EINVAL;
+ }
+ if (this->is_array) {
+ this->fetch_internal_array(idx, value);
+ } else {
+ this->fetch_internal(this->d.t.root, idx, value);
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_zero(
+ const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ uint32_t tmp_index;
+ uint32_t *const child_idxp = (idxp != nullptr) ? idxp : &tmp_index;
+ int r;
+ if (this->is_array) {
+ r = this->find_internal_zero_array<omtcmp_t, h>(extra, value, child_idxp);
+ } else {
+ r = this->find_internal_zero<omtcmp_t, h>(this->d.t.root, extra, value,
+ child_idxp);
+ }
+ return r;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find(
+ const omtcmp_t &extra, int direction, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ uint32_t tmp_index;
+ uint32_t *const child_idxp = (idxp != nullptr) ? idxp : &tmp_index;
+ paranoid_invariant(direction != 0);
+ if (direction < 0) {
+ if (this->is_array) {
+ return this->find_internal_minus_array<omtcmp_t, h>(extra, value,
+ child_idxp);
+ } else {
+ return this->find_internal_minus<omtcmp_t, h>(this->d.t.root, extra,
+ value, child_idxp);
+ }
+ } else {
+ if (this->is_array) {
+ return this->find_internal_plus_array<omtcmp_t, h>(extra, value,
+ child_idxp);
+ } else {
+ return this->find_internal_plus<omtcmp_t, h>(this->d.t.root, extra, value,
+ child_idxp);
+ }
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+size_t omt<omtdata_t, omtdataout_t, supports_marks>::memory_size(void) {
+ if (this->is_array) {
+ return (sizeof *this) + this->capacity * (sizeof this->d.a.values[0]);
+ }
+ return (sizeof *this) + this->capacity * (sizeof this->d.t.nodes[0]);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create_internal_no_array(
+ const uint32_t new_capacity) {
+ this->is_array = true;
+ this->d.a.start_idx = 0;
+ this->d.a.num_values = 0;
+ this->d.a.values = nullptr;
+ this->capacity = new_capacity;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::create_internal(
+ const uint32_t new_capacity) {
+ this->create_internal_no_array(new_capacity);
+ XMALLOC_N(this->capacity, this->d.a.values);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+uint32_t omt<omtdata_t, omtdataout_t, supports_marks>::nweight(
+ const subtree &st) const {
+ if (st.is_null()) {
+ return 0;
+ } else {
+ return this->d.t.nodes[st.get_index()].weight;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+typename omt<omtdata_t, omtdataout_t, supports_marks>::node_idx
+omt<omtdata_t, omtdataout_t, supports_marks>::node_malloc(void) {
+ paranoid_invariant(this->d.t.free_idx < this->capacity);
+ omt_node &n = this->d.t.nodes[this->d.t.free_idx];
+ n.clear_stolen_bits();
+ return this->d.t.free_idx++;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::node_free(
+ const node_idx UU(idx)) {
+ paranoid_invariant(idx < this->capacity);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::maybe_resize_array(
+ const uint32_t n) {
+ const uint32_t new_size = n <= 2 ? 4 : 2 * n;
+ const uint32_t room = this->capacity - this->d.a.start_idx;
+
+ if (room < n || this->capacity / 2 >= new_size) {
+ omtdata_t *XMALLOC_N(new_size, tmp_values);
+ if (this->d.a.num_values) {
+ memcpy(tmp_values, &this->d.a.values[this->d.a.start_idx],
+ this->d.a.num_values * (sizeof tmp_values[0]));
+ }
+ this->d.a.start_idx = 0;
+ this->capacity = new_size;
+ toku_free(this->d.a.values);
+ this->d.a.values = tmp_values;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t,
+ supports_marks>::fill_array_with_subtree_values(omtdata_t *const array,
+ const subtree &st)
+ const {
+ if (st.is_null()) return;
+ const omt_node &tree = this->d.t.nodes[st.get_index()];
+ this->fill_array_with_subtree_values(&array[0], tree.left);
+ array[this->nweight(tree.left)] = tree.value;
+ this->fill_array_with_subtree_values(&array[this->nweight(tree.left) + 1],
+ tree.right);
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::convert_to_array(void) {
+ if (!this->is_array) {
+ const uint32_t num_values = this->size();
+ uint32_t new_size = 2 * num_values;
+ new_size = new_size < 4 ? 4 : new_size;
+
+ omtdata_t *XMALLOC_N(new_size, tmp_values);
+ this->fill_array_with_subtree_values(tmp_values, this->d.t.root);
+ toku_free(this->d.t.nodes);
+ this->is_array = true;
+ this->capacity = new_size;
+ this->d.a.num_values = num_values;
+ this->d.a.values = tmp_values;
+ this->d.a.start_idx = 0;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::rebuild_from_sorted_array(
+ subtree *const st, const omtdata_t *const values,
+ const uint32_t numvalues) {
+ if (numvalues == 0) {
+ st->set_to_null();
+ } else {
+ const uint32_t halfway = numvalues / 2;
+ const node_idx newidx = this->node_malloc();
+ omt_node *const newnode = &this->d.t.nodes[newidx];
+ newnode->weight = numvalues;
+ newnode->value = values[halfway];
+ st->set_index(newidx);
+ // update everything before the recursive calls so the second call
+ // can be a tail call.
+ this->rebuild_from_sorted_array(&newnode->left, &values[0], halfway);
+ this->rebuild_from_sorted_array(&newnode->right, &values[halfway + 1],
+ numvalues - (halfway + 1));
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::convert_to_tree(void) {
+ if (this->is_array) {
+ const uint32_t num_nodes = this->size();
+ uint32_t new_size = num_nodes * 2;
+ new_size = new_size < 4 ? 4 : new_size;
+
+ omt_node *XMALLOC_N(new_size, new_nodes);
+ omtdata_t *const values = this->d.a.values;
+ omtdata_t *const tmp_values = &values[this->d.a.start_idx];
+ this->is_array = false;
+ this->d.t.nodes = new_nodes;
+ this->capacity = new_size;
+ this->d.t.free_idx = 0;
+ this->d.t.root.set_to_null();
+ this->rebuild_from_sorted_array(&this->d.t.root, tmp_values, num_nodes);
+ toku_free(values);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::maybe_resize_or_convert(
+ const uint32_t n) {
+ if (this->is_array) {
+ this->maybe_resize_array(n);
+ } else {
+ const uint32_t new_size = n <= 2 ? 4 : 2 * n;
+ const uint32_t num_nodes = this->nweight(this->d.t.root);
+ if ((this->capacity / 2 >= new_size) ||
+ (this->d.t.free_idx >= this->capacity && num_nodes < n) ||
+ (this->capacity < n)) {
+ this->convert_to_array();
+ // if we had a free list, the "supports_marks" version could
+ // just resize, as it is now, we have to convert to and back
+ // from an array.
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+ }
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+bool omt<omtdata_t, omtdataout_t, supports_marks>::will_need_rebalance(
+ const subtree &st, const int leftmod, const int rightmod) const {
+ if (st.is_null()) {
+ return false;
+ }
+ const omt_node &n = this->d.t.nodes[st.get_index()];
+ // one of the 1's is for the root.
+ // the other is to take ceil(n/2)
+ const uint32_t weight_left = this->nweight(n.left) + leftmod;
+ const uint32_t weight_right = this->nweight(n.right) + rightmod;
+ return ((1 + weight_left < (1 + 1 + weight_right) / 2) ||
+ (1 + weight_right < (1 + 1 + weight_left) / 2));
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::insert_internal(
+ subtree *const subtreep, const omtdata_t &value, const uint32_t idx,
+ subtree **const rebalance_subtree) {
+ if (subtreep->is_null()) {
+ paranoid_invariant_zero(idx);
+ const node_idx newidx = this->node_malloc();
+ omt_node *const newnode = &this->d.t.nodes[newidx];
+ newnode->weight = 1;
+ newnode->left.set_to_null();
+ newnode->right.set_to_null();
+ newnode->value = value;
+ subtreep->set_index(newidx);
+ } else {
+ omt_node &n = this->d.t.nodes[subtreep->get_index()];
+ n.weight++;
+ if (idx <= this->nweight(n.left)) {
+ if (*rebalance_subtree == nullptr &&
+ this->will_need_rebalance(*subtreep, 1, 0)) {
+ *rebalance_subtree = subtreep;
+ }
+ this->insert_internal(&n.left, value, idx, rebalance_subtree);
+ } else {
+ if (*rebalance_subtree == nullptr &&
+ this->will_need_rebalance(*subtreep, 0, 1)) {
+ *rebalance_subtree = subtreep;
+ }
+ const uint32_t sub_index = idx - this->nweight(n.left) - 1;
+ this->insert_internal(&n.right, value, sub_index, rebalance_subtree);
+ }
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::set_at_internal_array(
+ const omtdata_t &value, const uint32_t idx) {
+ this->d.a.values[this->d.a.start_idx + idx] = value;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::set_at_internal(
+ const subtree &st, const omtdata_t &value, const uint32_t idx) {
+ paranoid_invariant(!st.is_null());
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t leftweight = this->nweight(n.left);
+ if (idx < leftweight) {
+ this->set_at_internal(n.left, value, idx);
+ } else if (idx == leftweight) {
+ n.value = value;
+ } else {
+ this->set_at_internal(n.right, value, idx - leftweight - 1);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::delete_internal(
+ subtree *const subtreep, const uint32_t idx, omt_node *const copyn,
+ subtree **const rebalance_subtree) {
+ paranoid_invariant_notnull(subtreep);
+ paranoid_invariant_notnull(rebalance_subtree);
+ paranoid_invariant(!subtreep->is_null());
+ omt_node &n = this->d.t.nodes[subtreep->get_index()];
+ const uint32_t leftweight = this->nweight(n.left);
+ if (idx < leftweight) {
+ n.weight--;
+ if (*rebalance_subtree == nullptr &&
+ this->will_need_rebalance(*subtreep, -1, 0)) {
+ *rebalance_subtree = subtreep;
+ }
+ this->delete_internal(&n.left, idx, copyn, rebalance_subtree);
+ } else if (idx == leftweight) {
+ if (n.left.is_null()) {
+ const uint32_t oldidx = subtreep->get_index();
+ *subtreep = n.right;
+ if (copyn != nullptr) {
+ copyn->value = n.value;
+ }
+ this->node_free(oldidx);
+ } else if (n.right.is_null()) {
+ const uint32_t oldidx = subtreep->get_index();
+ *subtreep = n.left;
+ if (copyn != nullptr) {
+ copyn->value = n.value;
+ }
+ this->node_free(oldidx);
+ } else {
+ if (*rebalance_subtree == nullptr &&
+ this->will_need_rebalance(*subtreep, 0, -1)) {
+ *rebalance_subtree = subtreep;
+ }
+ // don't need to copy up value, it's only used by this
+ // next call, and when that gets to the bottom there
+ // won't be any more recursion
+ n.weight--;
+ this->delete_internal(&n.right, 0, &n, rebalance_subtree);
+ }
+ } else {
+ n.weight--;
+ if (*rebalance_subtree == nullptr &&
+ this->will_need_rebalance(*subtreep, 0, -1)) {
+ *rebalance_subtree = subtreep;
+ }
+ this->delete_internal(&n.right, idx - leftweight - 1, copyn,
+ rebalance_subtree);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_internal_array(
+ const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) const {
+ int r;
+ for (uint32_t i = left; i < right; ++i) {
+ r = f(this->d.a.values[this->d.a.start_idx + i], i, iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+void omt<omtdata_t, omtdataout_t, supports_marks>::iterate_ptr_internal(
+ const uint32_t left, const uint32_t right, const subtree &st,
+ const uint32_t idx, iterate_extra_t *const iterate_extra) {
+ if (!st.is_null()) {
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t idx_root = idx + this->nweight(n.left);
+ if (left < idx_root) {
+ this->iterate_ptr_internal<iterate_extra_t, f>(left, right, n.left, idx,
+ iterate_extra);
+ }
+ if (left <= idx_root && idx_root < right) {
+ int r = f(&n.value, idx_root, iterate_extra);
+ lazy_assert_zero(r);
+ }
+ if (idx_root + 1 < right) {
+ this->iterate_ptr_internal<iterate_extra_t, f>(
+ left, right, n.right, idx_root + 1, iterate_extra);
+ }
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(omtdata_t *, const uint32_t, iterate_extra_t *const)>
+void omt<omtdata_t, omtdataout_t, supports_marks>::iterate_ptr_internal_array(
+ const uint32_t left, const uint32_t right,
+ iterate_extra_t *const iterate_extra) {
+ for (uint32_t i = left; i < right; ++i) {
+ int r = f(&this->d.a.values[this->d.a.start_idx + i], i, iterate_extra);
+ lazy_assert_zero(r);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_internal(
+ const uint32_t left, const uint32_t right, const subtree &st,
+ const uint32_t idx, iterate_extra_t *const iterate_extra) const {
+ if (st.is_null()) {
+ return 0;
+ }
+ int r;
+ const omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t idx_root = idx + this->nweight(n.left);
+ if (left < idx_root) {
+ r = this->iterate_internal<iterate_extra_t, f>(left, right, n.left, idx,
+ iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (left <= idx_root && idx_root < right) {
+ r = f(n.value, idx_root, iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (idx_root + 1 < right) {
+ return this->iterate_internal<iterate_extra_t, f>(
+ left, right, n.right, idx_root + 1, iterate_extra);
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::
+ iterate_and_mark_range_internal(const uint32_t left, const uint32_t right,
+ const subtree &st, const uint32_t idx,
+ iterate_extra_t *const iterate_extra) {
+ paranoid_invariant(!st.is_null());
+ int r;
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t idx_root = idx + this->nweight(n.left);
+ if (left < idx_root && !n.left.is_null()) {
+ n.set_marks_below_bit();
+ r = this->iterate_and_mark_range_internal<iterate_extra_t, f>(
+ left, right, n.left, idx, iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (left <= idx_root && idx_root < right) {
+ n.set_marked_bit();
+ r = f(n.value, idx_root, iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (idx_root + 1 < right && !n.right.is_null()) {
+ n.set_marks_below_bit();
+ return this->iterate_and_mark_range_internal<iterate_extra_t, f>(
+ left, right, n.right, idx_root + 1, iterate_extra);
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename iterate_extra_t,
+ int (*f)(const omtdata_t &, const uint32_t, iterate_extra_t *const)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::iterate_over_marked_internal(
+ const subtree &st, const uint32_t idx,
+ iterate_extra_t *const iterate_extra) const {
+ if (st.is_null()) {
+ return 0;
+ }
+ int r;
+ const omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t idx_root = idx + this->nweight(n.left);
+ if (n.get_marks_below()) {
+ r = this->iterate_over_marked_internal<iterate_extra_t, f>(n.left, idx,
+ iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (n.get_marked()) {
+ r = f(n.value, idx_root, iterate_extra);
+ if (r != 0) {
+ return r;
+ }
+ }
+ if (n.get_marks_below()) {
+ return this->iterate_over_marked_internal<iterate_extra_t, f>(
+ n.right, idx_root + 1, iterate_extra);
+ }
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::fetch_internal_array(
+ const uint32_t i, omtdataout_t *const value) const {
+ if (value != nullptr) {
+ copyout(value, &this->d.a.values[this->d.a.start_idx + i]);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::fetch_internal(
+ const subtree &st, const uint32_t i, omtdataout_t *const value) const {
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ const uint32_t leftweight = this->nweight(n.left);
+ if (i < leftweight) {
+ this->fetch_internal(n.left, i, value);
+ } else if (i == leftweight) {
+ if (value != nullptr) {
+ copyout(value, &n);
+ }
+ } else {
+ this->fetch_internal(n.right, i - leftweight - 1, value);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::fill_array_with_subtree_idxs(
+ node_idx *const array, const subtree &st) const {
+ if (!st.is_null()) {
+ const omt_node &tree = this->d.t.nodes[st.get_index()];
+ this->fill_array_with_subtree_idxs(&array[0], tree.left);
+ array[this->nweight(tree.left)] = st.get_index();
+ this->fill_array_with_subtree_idxs(&array[this->nweight(tree.left) + 1],
+ tree.right);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::rebuild_subtree_from_idxs(
+ subtree *const st, const node_idx *const idxs, const uint32_t numvalues) {
+ if (numvalues == 0) {
+ st->set_to_null();
+ } else {
+ uint32_t halfway = numvalues / 2;
+ st->set_index(idxs[halfway]);
+ // node_idx newidx = idxs[halfway];
+ omt_node &newnode = this->d.t.nodes[st->get_index()];
+ newnode.weight = numvalues;
+ // value is already in there.
+ this->rebuild_subtree_from_idxs(&newnode.left, &idxs[0], halfway);
+ this->rebuild_subtree_from_idxs(&newnode.right, &idxs[halfway + 1],
+ numvalues - (halfway + 1));
+ // n_idx = newidx;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::rebalance(
+ subtree *const st) {
+ node_idx idx = st->get_index();
+ if (idx == this->d.t.root.get_index()) {
+ // Try to convert to an array.
+ // If this fails, (malloc) nothing will have changed.
+ // In the failure case we continue on to the standard rebalance
+ // algorithm.
+ this->convert_to_array();
+ if (supports_marks) {
+ this->convert_to_tree();
+ }
+ } else {
+ const omt_node &n = this->d.t.nodes[idx];
+ node_idx *tmp_array;
+ size_t mem_needed = n.weight * (sizeof tmp_array[0]);
+ size_t mem_free =
+ (this->capacity - this->d.t.free_idx) * (sizeof this->d.t.nodes[0]);
+ bool malloced;
+ if (mem_needed <= mem_free) {
+ // There is sufficient free space at the end of the nodes array
+ // to hold enough node indexes to rebalance.
+ malloced = false;
+ tmp_array =
+ reinterpret_cast<node_idx *>(&this->d.t.nodes[this->d.t.free_idx]);
+ } else {
+ malloced = true;
+ XMALLOC_N(n.weight, tmp_array);
+ }
+ this->fill_array_with_subtree_idxs(tmp_array, *st);
+ this->rebuild_subtree_from_idxs(st, tmp_array, n.weight);
+ if (malloced) toku_free(tmp_array);
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::copyout(
+ omtdata_t *const out, const omt_node *const n) {
+ *out = n->value;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::copyout(
+ omtdata_t **const out, omt_node *const n) {
+ *out = &n->value;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::copyout(
+ omtdata_t *const out, const omtdata_t *const stored_value_ptr) {
+ *out = *stored_value_ptr;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+void omt<omtdata_t, omtdataout_t, supports_marks>::copyout(
+ omtdata_t **const out, omtdata_t *const stored_value_ptr) {
+ *out = stored_value_ptr;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_zero_array(
+ const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ uint32_t min = this->d.a.start_idx;
+ uint32_t limit = this->d.a.start_idx + this->d.a.num_values;
+ uint32_t best_pos = subtree::NODE_NULL;
+ uint32_t best_zero = subtree::NODE_NULL;
+
+ while (min != limit) {
+ uint32_t mid = (min + limit) / 2;
+ int hv = h(this->d.a.values[mid], extra);
+ if (hv < 0) {
+ min = mid + 1;
+ } else if (hv > 0) {
+ best_pos = mid;
+ limit = mid;
+ } else {
+ best_zero = mid;
+ limit = mid;
+ }
+ }
+ if (best_zero != subtree::NODE_NULL) {
+ // Found a zero
+ if (value != nullptr) {
+ copyout(value, &this->d.a.values[best_zero]);
+ }
+ *idxp = best_zero - this->d.a.start_idx;
+ return 0;
+ }
+ if (best_pos != subtree::NODE_NULL)
+ *idxp = best_pos - this->d.a.start_idx;
+ else
+ *idxp = this->d.a.num_values;
+ return DB_NOTFOUND;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_zero(
+ const subtree &st, const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ if (st.is_null()) {
+ *idxp = 0;
+ return DB_NOTFOUND;
+ }
+ omt_node &n = this->d.t.nodes[st.get_index()];
+ int hv = h(n.value, extra);
+ if (hv < 0) {
+ int r = this->find_internal_zero<omtcmp_t, h>(n.right, extra, value, idxp);
+ *idxp += this->nweight(n.left) + 1;
+ return r;
+ } else if (hv > 0) {
+ return this->find_internal_zero<omtcmp_t, h>(n.left, extra, value, idxp);
+ } else {
+ int r = this->find_internal_zero<omtcmp_t, h>(n.left, extra, value, idxp);
+ if (r == DB_NOTFOUND) {
+ *idxp = this->nweight(n.left);
+ if (value != nullptr) {
+ copyout(value, &n);
+ }
+ r = 0;
+ }
+ return r;
+ }
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_plus_array(
+ const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ uint32_t min = this->d.a.start_idx;
+ uint32_t limit = this->d.a.start_idx + this->d.a.num_values;
+ uint32_t best = subtree::NODE_NULL;
+
+ while (min != limit) {
+ const uint32_t mid = (min + limit) / 2;
+ const int hv = h(this->d.a.values[mid], extra);
+ if (hv > 0) {
+ best = mid;
+ limit = mid;
+ } else {
+ min = mid + 1;
+ }
+ }
+ if (best == subtree::NODE_NULL) {
+ return DB_NOTFOUND;
+ }
+ if (value != nullptr) {
+ copyout(value, &this->d.a.values[best]);
+ }
+ *idxp = best - this->d.a.start_idx;
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_plus(
+ const subtree &st, const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ if (st.is_null()) {
+ return DB_NOTFOUND;
+ }
+ omt_node *const n = &this->d.t.nodes[st.get_index()];
+ int hv = h(n->value, extra);
+ int r;
+ if (hv > 0) {
+ r = this->find_internal_plus<omtcmp_t, h>(n->left, extra, value, idxp);
+ if (r == DB_NOTFOUND) {
+ *idxp = this->nweight(n->left);
+ if (value != nullptr) {
+ copyout(value, n);
+ }
+ r = 0;
+ }
+ } else {
+ r = this->find_internal_plus<omtcmp_t, h>(n->right, extra, value, idxp);
+ if (r == 0) {
+ *idxp += this->nweight(n->left) + 1;
+ }
+ }
+ return r;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_minus_array(
+ const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ uint32_t min = this->d.a.start_idx;
+ uint32_t limit = this->d.a.start_idx + this->d.a.num_values;
+ uint32_t best = subtree::NODE_NULL;
+
+ while (min != limit) {
+ const uint32_t mid = (min + limit) / 2;
+ const int hv = h(this->d.a.values[mid], extra);
+ if (hv < 0) {
+ best = mid;
+ min = mid + 1;
+ } else {
+ limit = mid;
+ }
+ }
+ if (best == subtree::NODE_NULL) {
+ return DB_NOTFOUND;
+ }
+ if (value != nullptr) {
+ copyout(value, &this->d.a.values[best]);
+ }
+ *idxp = best - this->d.a.start_idx;
+ return 0;
+}
+
+template <typename omtdata_t, typename omtdataout_t, bool supports_marks>
+template <typename omtcmp_t, int (*h)(const omtdata_t &, const omtcmp_t &)>
+int omt<omtdata_t, omtdataout_t, supports_marks>::find_internal_minus(
+ const subtree &st, const omtcmp_t &extra, omtdataout_t *const value,
+ uint32_t *const idxp) const {
+ paranoid_invariant_notnull(idxp);
+ if (st.is_null()) {
+ return DB_NOTFOUND;
+ }
+ omt_node *const n = &this->d.t.nodes[st.get_index()];
+ int hv = h(n->value, extra);
+ if (hv < 0) {
+ int r =
+ this->find_internal_minus<omtcmp_t, h>(n->right, extra, value, idxp);
+ if (r == 0) {
+ *idxp += this->nweight(n->left) + 1;
+ } else if (r == DB_NOTFOUND) {
+ *idxp = this->nweight(n->left);
+ if (value != nullptr) {
+ copyout(value, n);
+ }
+ r = 0;
+ }
+ return r;
+ } else {
+ return this->find_internal_minus<omtcmp_t, h>(n->left, extra, value, idxp);
+ }
+}
+} // namespace toku
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/partitioned_counter.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/partitioned_counter.h
new file mode 100644
index 000000000..f20eeedf2
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/partitioned_counter.h
@@ -0,0 +1,165 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+// Overview: A partitioned_counter provides a counter that can be incremented
+// and the running sum can be read at any time.
+// We assume that increments are frequent, whereas reading is infrequent.
+// Implementation hint: Use thread-local storage so each thread increments its
+// own data. The increment does not require a lock or atomic operation.
+// Reading the data can be performed by iterating over the thread-local
+// versions, summing them up. The data structure also includes a sum for all
+// the threads that have died. Use a pthread_key to create the thread-local
+// versions. When a thread finishes, the system calls pthread_key destructor
+// which can add that thread's copy into the sum_of_dead counter.
+// Rationale: For statistics such as are found in engine status, we need a
+// counter that requires no cache misses to increment. We've seen significant
+// performance speedups by removing certain counters. Rather than removing
+// those statistics, we would like to just make the counter fast. We generally
+// increment the counters frequently, and want to fetch the values
+// infrequently. The counters are monotonic. The counters can be split into
+// many counters, which can be summed up at the end. We don't care if we get
+// slightly out-of-date counter sums when we read the counter. We don't care
+// if there is a race on reading the a counter
+// variable and incrementing.
+// See tests/test_partitioned_counter.c for some performance measurements.
+// Operations:
+// create_partitioned_counter Create a counter initialized to zero.
+// destroy_partitioned_counter Destroy it.
+// increment_partitioned_counter Increment it. This is the frequent
+// operation. read_partitioned_counter Get the current value. This is
+// infrequent.
+// See partitioned_counter.cc for the abstraction function and representation
+// invariant.
+//
+// The google style guide says to avoid using constructors, and it appears that
+// constructors may have broken all the tests, because they called
+// pthread_key_create before the key was actually created. So the google style
+// guide may have some wisdom there...
+//
+// This version does not use constructors, essentially reverrting to the google
+// C++ style guide.
+//
+
+// The old C interface. This required a bunch of explicit
+// ___attribute__((__destructor__)) functions to remember to destroy counters at
+// the end.
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+typedef struct partitioned_counter *PARTITIONED_COUNTER;
+PARTITIONED_COUNTER create_partitioned_counter(void);
+// Effect: Create a counter, initialized to zero.
+
+void destroy_partitioned_counter(PARTITIONED_COUNTER);
+// Effect: Destroy the counter. No operations on that counter are permitted
+// after this.
+
+void increment_partitioned_counter(PARTITIONED_COUNTER, uint64_t amount);
+// Effect: Increment the counter by amount.
+// Requires: No overflows. This is a 64-bit unsigned counter.
+
+uint64_t read_partitioned_counter(PARTITIONED_COUNTER)
+ __attribute__((__visibility__("default")));
+// Effect: Return the current value of the counter.
+
+void partitioned_counters_init(void);
+// Effect: Initialize any partitioned counters data structures that must be set
+// up before any partitioned counters run.
+
+void partitioned_counters_destroy(void);
+// Effect: Destroy any partitioned counters data structures.
+
+#if defined(__cplusplus)
+};
+#endif
+
+#if 0
+#include <pthread.h>
+
+#include "fttypes.h"
+
+// Used inside the PARTITIONED_COUNTER.
+struct linked_list_head {
+ struct linked_list_element *first;
+};
+
+
+class PARTITIONED_COUNTER {
+public:
+ PARTITIONED_COUNTER(void);
+ // Effect: Construct a counter, initialized to zero.
+
+ ~PARTITIONED_COUNTER(void);
+ // Effect: Destruct the counter.
+
+ void increment(uint64_t amount);
+ // Effect: Increment the counter by amount. This is a 64-bit unsigned counter, and if you overflow it, you will get overflowed results (that is mod 2^64).
+ // Requires: Don't use this from a static constructor or destructor.
+
+ uint64_t read(void);
+ // Effect: Read the sum.
+ // Requires: Don't use this from a static constructor or destructor.
+
+private:
+ uint64_t _sum_of_dead; // The sum of all thread-local counts from threads that have terminated.
+ pthread_key_t _key; // The pthread_key which gives us the hook to construct and destruct thread-local storage.
+ struct linked_list_head _ll_counter_head; // A linked list of all the thread-local information for this counter.
+
+ // This function is used to destroy the thread-local part of the state when a thread terminates.
+ // But it's not the destructor for the local part of the counter, it's a destructor on a "dummy" key just so that we get a notification when a thread ends.
+ friend void destroy_thread_local_part_of_partitioned_counters (void *);
+};
+#endif
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/status.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/status.h
new file mode 100644
index 000000000..3fd0095d0
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/lib/util/status.h
@@ -0,0 +1,76 @@
+/* -*- mode: C++; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+// vim: ft=cpp:expandtab:ts=8:sw=4:softtabstop=4:
+#ident "$Id$"
+/*======
+This file is part of PerconaFT.
+
+
+Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved.
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License, version 2,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ PerconaFT is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License, version 3,
+ as published by the Free Software Foundation.
+
+ PerconaFT is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with PerconaFT. If not, see <http://www.gnu.org/licenses/>.
+
+----------------------------------------
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+======= */
+
+#ident \
+ "Copyright (c) 2006, 2015, Percona and/or its affiliates. All rights reserved."
+
+#pragma once
+
+#include "partitioned_counter.h"
+// PORT2: #include <util/constexpr.h>
+
+#define TOKUFT_STATUS_INIT(array, k, c, t, l, inc) \
+ do { \
+ array.status[k].keyname = #k; \
+ array.status[k].columnname = #c; \
+ array.status[k].type = t; \
+ array.status[k].legend = l; \
+ constexpr_static_assert( \
+ strcmp(#c, "NULL") && strcmp(#c, "0"), \
+ "Use nullptr for no column name instead of NULL, 0, etc..."); \
+ constexpr_static_assert( \
+ (inc) == TOKU_ENGINE_STATUS || strcmp(#c, "nullptr"), \
+ "Missing column name."); \
+ array.status[k].include = \
+ static_cast<toku_engine_status_include_type>(inc); \
+ if (t == STATUS_PARCOUNT) { \
+ array.status[k].value.parcount = create_partitioned_counter(); \
+ } \
+ } while (0)
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.cc
new file mode 100644
index 000000000..531165dea
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.cc
@@ -0,0 +1,503 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+
+#include "utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <mutex>
+
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/transaction_db_mutex.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/hash.h"
+#include "util/thread_local.h"
+#include "utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+RangeLockManagerHandle* NewRangeLockManager(
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory) {
+ std::shared_ptr<TransactionDBMutexFactory> use_factory;
+
+ if (mutex_factory) {
+ use_factory = mutex_factory;
+ } else {
+ use_factory.reset(new TransactionDBMutexFactoryImpl());
+ }
+ return new RangeTreeLockManager(use_factory);
+}
+
+static const char SUFFIX_INFIMUM = 0x0;
+static const char SUFFIX_SUPREMUM = 0x1;
+
+// Convert Endpoint into an internal format used for storing it in locktree
+// (DBT structure is used for passing endpoints to locktree and getting back)
+void serialize_endpoint(const Endpoint& endp, std::string* buf) {
+ buf->push_back(endp.inf_suffix ? SUFFIX_SUPREMUM : SUFFIX_INFIMUM);
+ buf->append(endp.slice.data(), endp.slice.size());
+}
+
+// Decode the endpoint from the format it is stored in the locktree (DBT) to
+// the one used outside: either Endpoint or EndpointWithString
+template <typename EndpointStruct>
+void deserialize_endpoint(const DBT* dbt, EndpointStruct* endp) {
+ assert(dbt->size >= 1);
+ const char* dbt_data = (const char*)dbt->data;
+ char suffix = dbt_data[0];
+ assert(suffix == SUFFIX_INFIMUM || suffix == SUFFIX_SUPREMUM);
+ endp->inf_suffix = (suffix == SUFFIX_SUPREMUM);
+ endp->slice = decltype(EndpointStruct::slice)(dbt_data + 1, dbt->size - 1);
+}
+
+// Get a range lock on [start_key; end_key] range
+Status RangeTreeLockManager::TryLock(PessimisticTransaction* txn,
+ uint32_t column_family_id,
+ const Endpoint& start_endp,
+ const Endpoint& end_endp, Env*,
+ bool exclusive) {
+ toku::lock_request request;
+ request.create(mutex_factory_);
+ DBT start_key_dbt, end_key_dbt;
+
+ TEST_SYNC_POINT("RangeTreeLockManager::TryRangeLock:enter");
+ std::string start_key;
+ std::string end_key;
+ serialize_endpoint(start_endp, &start_key);
+ serialize_endpoint(end_endp, &end_key);
+
+ toku_fill_dbt(&start_key_dbt, start_key.data(), start_key.size());
+ toku_fill_dbt(&end_key_dbt, end_key.data(), end_key.size());
+
+ auto lt = GetLockTreeForCF(column_family_id);
+
+ // Put the key waited on into request's m_extra. See
+ // wait_callback_for_locktree for details.
+ std::string wait_key(start_endp.slice.data(), start_endp.slice.size());
+
+ request.set(lt.get(), (TXNID)txn, &start_key_dbt, &end_key_dbt,
+ exclusive ? toku::lock_request::WRITE : toku::lock_request::READ,
+ false /* not a big txn */, &wait_key);
+
+ // This is for "periodically wake up and check if the wait is killed" feature
+ // which we are not using.
+ uint64_t killed_time_msec = 0;
+ uint64_t wait_time_msec = txn->GetLockTimeout();
+
+ if (wait_time_msec == static_cast<uint64_t>(-1)) {
+ // The transaction has no wait timeout. lock_request::wait doesn't support
+ // this, it needs a number of milliseconds to wait. Pass it one year to
+ // be safe.
+ wait_time_msec = uint64_t(1000) * 60 * 60 * 24 * 365;
+ } else {
+ // convert microseconds to milliseconds
+ wait_time_msec = (wait_time_msec + 500) / 1000;
+ }
+
+ std::vector<RangeDeadlockInfo> di_path;
+ request.m_deadlock_cb = [&](TXNID txnid, bool is_exclusive,
+ const DBT* start_dbt, const DBT* end_dbt) {
+ EndpointWithString start;
+ EndpointWithString end;
+ deserialize_endpoint(start_dbt, &start);
+ deserialize_endpoint(end_dbt, &end);
+
+ di_path.push_back({txnid, column_family_id, is_exclusive, std::move(start),
+ std::move(end)});
+ };
+
+ request.start();
+
+ const int r = request.wait(wait_time_msec, killed_time_msec,
+ nullptr, // killed_callback
+ wait_callback_for_locktree, nullptr);
+
+ // Inform the txn that we are no longer waiting:
+ txn->ClearWaitingTxn();
+
+ request.destroy();
+ switch (r) {
+ case 0:
+ break; // fall through
+ case DB_LOCK_NOTGRANTED:
+ return Status::TimedOut(Status::SubCode::kLockTimeout);
+ case TOKUDB_OUT_OF_LOCKS:
+ return Status::Busy(Status::SubCode::kLockLimit);
+ case DB_LOCK_DEADLOCK: {
+ std::reverse(di_path.begin(), di_path.end());
+ dlock_buffer_.AddNewPath(
+ RangeDeadlockPath(di_path, request.get_start_time()));
+ return Status::Busy(Status::SubCode::kDeadlock);
+ }
+ default:
+ assert(0);
+ return Status::Busy(Status::SubCode::kLockLimit);
+ }
+
+ return Status::OK();
+}
+
+// Wait callback that locktree library will call to inform us about
+// the lock waits that are in progress.
+void wait_callback_for_locktree(void*, toku::lock_wait_infos* infos) {
+ TEST_SYNC_POINT("RangeTreeLockManager::TryRangeLock:EnterWaitingTxn");
+ for (auto wait_info : *infos) {
+ // As long as we hold the lock on the locktree's pending request queue
+ // this should be safe.
+ auto txn = (PessimisticTransaction*)wait_info.waiter;
+ auto cf_id = (ColumnFamilyId)wait_info.ltree->get_dict_id().dictid;
+
+ autovector<TransactionID> waitee_ids;
+ for (auto waitee : wait_info.waitees) {
+ waitee_ids.push_back(waitee);
+ }
+ txn->SetWaitingTxn(waitee_ids, cf_id, (std::string*)wait_info.m_extra);
+ }
+
+ // Here we can assume that the locktree code will now wait for some lock
+ TEST_SYNC_POINT("RangeTreeLockManager::TryRangeLock:WaitingTxn");
+}
+
+void RangeTreeLockManager::UnLock(PessimisticTransaction* txn,
+ ColumnFamilyId column_family_id,
+ const std::string& key, Env*) {
+ auto locktree = GetLockTreeForCF(column_family_id);
+ std::string endp_image;
+ serialize_endpoint({key.data(), key.size(), false}, &endp_image);
+
+ DBT key_dbt;
+ toku_fill_dbt(&key_dbt, endp_image.data(), endp_image.size());
+
+ toku::range_buffer range_buf;
+ range_buf.create();
+ range_buf.append(&key_dbt, &key_dbt);
+
+ locktree->release_locks((TXNID)txn, &range_buf);
+ range_buf.destroy();
+
+ toku::lock_request::retry_all_lock_requests(
+ locktree.get(), wait_callback_for_locktree, nullptr);
+}
+
+void RangeTreeLockManager::UnLock(PessimisticTransaction* txn,
+ const LockTracker& tracker, Env*) {
+ const RangeTreeLockTracker* range_tracker =
+ static_cast<const RangeTreeLockTracker*>(&tracker);
+
+ RangeTreeLockTracker* range_trx_tracker =
+ static_cast<RangeTreeLockTracker*>(&txn->GetTrackedLocks());
+ bool all_keys = (range_trx_tracker == range_tracker);
+
+ // tracked_locks_->range_list may hold nullptr if the transaction has never
+ // acquired any locks.
+ ((RangeTreeLockTracker*)range_tracker)->ReleaseLocks(this, txn, all_keys);
+}
+
+int RangeTreeLockManager::CompareDbtEndpoints(void* arg, const DBT* a_key,
+ const DBT* b_key) {
+ const char* a = (const char*)a_key->data;
+ const char* b = (const char*)b_key->data;
+
+ size_t a_len = a_key->size;
+ size_t b_len = b_key->size;
+
+ size_t min_len = std::min(a_len, b_len);
+
+ // Compare the values. The first byte encodes the endpoint type, its value
+ // is either SUFFIX_INFIMUM or SUFFIX_SUPREMUM.
+ Comparator* cmp = (Comparator*)arg;
+ int res = cmp->Compare(Slice(a + 1, min_len - 1), Slice(b + 1, min_len - 1));
+ if (!res) {
+ if (b_len > min_len) {
+ // a is shorter;
+ if (a[0] == SUFFIX_INFIMUM) {
+ return -1; //"a is smaller"
+ } else {
+ // a is considered padded with 0xFF:FF:FF:FF...
+ return 1; // "a" is bigger
+ }
+ } else if (a_len > min_len) {
+ // the opposite of the above: b is shorter.
+ if (b[0] == SUFFIX_INFIMUM) {
+ return 1; //"b is smaller"
+ } else {
+ // b is considered padded with 0xFF:FF:FF:FF...
+ return -1; // "b" is bigger
+ }
+ } else {
+ // the lengths are equal (and the key values, too)
+ if (a[0] < b[0]) {
+ return -1;
+ } else if (a[0] > b[0]) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+ } else {
+ return res;
+ }
+}
+
+namespace {
+void UnrefLockTreeMapsCache(void* ptr) {
+ // Called when a thread exits or a ThreadLocalPtr gets destroyed.
+ auto lock_tree_map_cache = static_cast<
+ std::unordered_map<ColumnFamilyId, std::shared_ptr<toku::locktree>>*>(
+ ptr);
+ delete lock_tree_map_cache;
+}
+} // anonymous namespace
+
+RangeTreeLockManager::RangeTreeLockManager(
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory)
+ : mutex_factory_(mutex_factory),
+ ltree_lookup_cache_(new ThreadLocalPtr(&UnrefLockTreeMapsCache)),
+ dlock_buffer_(10) {
+ ltm_.create(on_create, on_destroy, on_escalate, nullptr, mutex_factory_);
+}
+
+int RangeTreeLockManager::on_create(toku::locktree* lt, void* arg) {
+ // arg is a pointer to RangeTreeLockManager
+ lt->set_escalation_barrier_func(&OnEscalationBarrierCheck, arg);
+ return 0;
+}
+
+bool RangeTreeLockManager::OnEscalationBarrierCheck(const DBT* a, const DBT* b,
+ void* extra) {
+ Endpoint a_endp, b_endp;
+ deserialize_endpoint(a, &a_endp);
+ deserialize_endpoint(b, &b_endp);
+ auto self = static_cast<RangeTreeLockManager*>(extra);
+ return self->barrier_func_(a_endp, b_endp);
+}
+
+void RangeTreeLockManager::SetRangeDeadlockInfoBufferSize(
+ uint32_t target_size) {
+ dlock_buffer_.Resize(target_size);
+}
+
+void RangeTreeLockManager::Resize(uint32_t target_size) {
+ SetRangeDeadlockInfoBufferSize(target_size);
+}
+
+std::vector<RangeDeadlockPath>
+RangeTreeLockManager::GetRangeDeadlockInfoBuffer() {
+ return dlock_buffer_.PrepareBuffer();
+}
+
+std::vector<DeadlockPath> RangeTreeLockManager::GetDeadlockInfoBuffer() {
+ std::vector<DeadlockPath> res;
+ std::vector<RangeDeadlockPath> data = GetRangeDeadlockInfoBuffer();
+ // report left endpoints
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ std::vector<DeadlockInfo> path;
+
+ for (auto it2 = it->path.begin(); it2 != it->path.end(); ++it2) {
+ path.push_back(
+ {it2->m_txn_id, it2->m_cf_id, it2->m_exclusive, it2->m_start.slice});
+ }
+ res.push_back(DeadlockPath(path, it->deadlock_time));
+ }
+ return res;
+}
+
+// @brief Lock Escalation Callback function
+//
+// @param txnid Transaction whose locks got escalated
+// @param lt Lock Tree where escalation is happening
+// @param buffer Escalation result: list of locks that this transaction now
+// owns in this lock tree.
+// @param void* Callback context
+void RangeTreeLockManager::on_escalate(TXNID txnid, const toku::locktree* lt,
+ const toku::range_buffer& buffer,
+ void*) {
+ auto txn = (PessimisticTransaction*)txnid;
+ ((RangeTreeLockTracker*)&txn->GetTrackedLocks())->ReplaceLocks(lt, buffer);
+}
+
+RangeTreeLockManager::~RangeTreeLockManager() {
+ autovector<void*> local_caches;
+ ltree_lookup_cache_->Scrape(&local_caches, nullptr);
+ for (auto cache : local_caches) {
+ delete static_cast<LockTreeMap*>(cache);
+ }
+ ltree_map_.clear(); // this will call release_lt() for all locktrees
+ ltm_.destroy();
+}
+
+RangeLockManagerHandle::Counters RangeTreeLockManager::GetStatus() {
+ LTM_STATUS_S ltm_status_test;
+ ltm_.get_status(&ltm_status_test);
+ Counters res;
+
+ // Searching status variable by its string name is how Toku's unit tests
+ // do it (why didn't they make LTM_ESCALATION_COUNT constant visible?)
+ // lookup keyname in status
+ for (int i = 0; i < LTM_STATUS_S::LTM_STATUS_NUM_ROWS; i++) {
+ TOKU_ENGINE_STATUS_ROW status = &ltm_status_test.status[i];
+ if (strcmp(status->keyname, "LTM_ESCALATION_COUNT") == 0) {
+ res.escalation_count = status->value.num;
+ continue;
+ }
+ if (strcmp(status->keyname, "LTM_WAIT_COUNT") == 0) {
+ res.lock_wait_count = status->value.num;
+ continue;
+ }
+ if (strcmp(status->keyname, "LTM_SIZE_CURRENT") == 0) {
+ res.current_lock_memory = status->value.num;
+ }
+ }
+ return res;
+}
+
+std::shared_ptr<toku::locktree> RangeTreeLockManager::MakeLockTreePtr(
+ toku::locktree* lt) {
+ toku::locktree_manager* ltm = &ltm_;
+ return std::shared_ptr<toku::locktree>(
+ lt, [ltm](toku::locktree* p) { ltm->release_lt(p); });
+}
+
+void RangeTreeLockManager::AddColumnFamily(const ColumnFamilyHandle* cfh) {
+ uint32_t column_family_id = cfh->GetID();
+
+ InstrumentedMutexLock l(&ltree_map_mutex_);
+ if (ltree_map_.find(column_family_id) == ltree_map_.end()) {
+ DICTIONARY_ID dict_id = {.dictid = column_family_id};
+ toku::comparator cmp;
+ cmp.create(CompareDbtEndpoints, (void*)cfh->GetComparator());
+ toku::locktree* ltree =
+ ltm_.get_lt(dict_id, cmp,
+ /* on_create_extra*/ static_cast<void*>(this));
+ // This is ok to because get_lt has copied the comparator:
+ cmp.destroy();
+
+ ltree_map_.insert({column_family_id, MakeLockTreePtr(ltree)});
+ }
+}
+
+void RangeTreeLockManager::RemoveColumnFamily(const ColumnFamilyHandle* cfh) {
+ uint32_t column_family_id = cfh->GetID();
+ // Remove lock_map for this column family. Since the lock map is stored
+ // as a shared ptr, concurrent transactions can still keep using it
+ // until they release their references to it.
+
+ // TODO what if one drops a column family while transaction(s) still have
+ // locks in it?
+ // locktree uses column family'c Comparator* as the criteria to do tree
+ // ordering. If the comparator is gone, we won't even be able to remove the
+ // elements from the locktree.
+ // A possible solution might be to remove everything right now:
+ // - wait until everyone traversing the locktree are gone
+ // - remove everything from the locktree.
+ // - some transactions may have acquired locks in their LockTracker objects.
+ // Arrange something so we don't blow up when they try to release them.
+ // - ...
+ // This use case (drop column family while somebody is using it) doesn't seem
+ // the priority, though.
+
+ {
+ InstrumentedMutexLock l(&ltree_map_mutex_);
+
+ auto lock_maps_iter = ltree_map_.find(column_family_id);
+ assert(lock_maps_iter != ltree_map_.end());
+ ltree_map_.erase(lock_maps_iter);
+ } // lock_map_mutex_
+
+ autovector<void*> local_caches;
+ ltree_lookup_cache_->Scrape(&local_caches, nullptr);
+ for (auto cache : local_caches) {
+ delete static_cast<LockTreeMap*>(cache);
+ }
+}
+
+std::shared_ptr<toku::locktree> RangeTreeLockManager::GetLockTreeForCF(
+ ColumnFamilyId column_family_id) {
+ // First check thread-local cache
+ if (ltree_lookup_cache_->Get() == nullptr) {
+ ltree_lookup_cache_->Reset(new LockTreeMap());
+ }
+
+ auto ltree_map_cache = static_cast<LockTreeMap*>(ltree_lookup_cache_->Get());
+
+ auto it = ltree_map_cache->find(column_family_id);
+ if (it != ltree_map_cache->end()) {
+ // Found lock map for this column family.
+ return it->second;
+ }
+
+ // Not found in local cache, grab mutex and check shared LockMaps
+ InstrumentedMutexLock l(&ltree_map_mutex_);
+
+ it = ltree_map_.find(column_family_id);
+ if (it == ltree_map_.end()) {
+ return nullptr;
+ } else {
+ // Found lock map. Store in thread-local cache and return.
+ ltree_map_cache->insert({column_family_id, it->second});
+ return it->second;
+ }
+}
+
+struct LOCK_PRINT_CONTEXT {
+ RangeLockManagerHandle::RangeLockStatus* data; // Save locks here
+ uint32_t cfh_id; // Column Family whose tree we are traversing
+};
+
+// Report left endpoints of the acquired locks
+LockManager::PointLockStatus RangeTreeLockManager::GetPointLockStatus() {
+ PointLockStatus res;
+ LockManager::RangeLockStatus data = GetRangeLockStatus();
+ // report left endpoints
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ auto& val = it->second;
+ res.insert({it->first, {val.start.slice, val.ids, val.exclusive}});
+ }
+ return res;
+}
+
+static void push_into_lock_status_data(void* param, const DBT* left,
+ const DBT* right, TXNID txnid_arg,
+ bool is_shared, TxnidVector* owners) {
+ struct LOCK_PRINT_CONTEXT* ctx = (LOCK_PRINT_CONTEXT*)param;
+ struct RangeLockInfo info;
+
+ info.exclusive = !is_shared;
+
+ deserialize_endpoint(left, &info.start);
+ deserialize_endpoint(right, &info.end);
+
+ if (txnid_arg != TXNID_SHARED) {
+ info.ids.push_back(txnid_arg);
+ } else {
+ for (auto it : *owners) {
+ info.ids.push_back(it);
+ }
+ }
+ ctx->data->insert({ctx->cfh_id, info});
+}
+
+LockManager::RangeLockStatus RangeTreeLockManager::GetRangeLockStatus() {
+ LockManager::RangeLockStatus data;
+ {
+ InstrumentedMutexLock l(&ltree_map_mutex_);
+ for (auto it : ltree_map_) {
+ LOCK_PRINT_CONTEXT ctx = {&data, it.first};
+ it.second->dump_locks((void*)&ctx, push_into_lock_status_data);
+ }
+ }
+ return data;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h
new file mode 100644
index 000000000..e4236d600
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h
@@ -0,0 +1,137 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+
+// For DeadlockInfoBuffer:
+#include "util/thread_local.h"
+#include "utilities/transactions/lock/point/point_lock_manager.h"
+#include "utilities/transactions/lock/range/range_lock_manager.h"
+
+// Lock Tree library:
+#include "utilities/transactions/lock/range/range_tree/lib/locktree/lock_request.h"
+#include "utilities/transactions/lock/range/range_tree/lib/locktree/locktree.h"
+#include "utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+typedef DeadlockInfoBufferTempl<RangeDeadlockPath> RangeDeadlockInfoBuffer;
+
+// A Range Lock Manager that uses PerconaFT's locktree library
+class RangeTreeLockManager : public RangeLockManagerBase,
+ public RangeLockManagerHandle {
+ public:
+ LockManager* getLockManager() override { return this; }
+
+ void AddColumnFamily(const ColumnFamilyHandle* cfh) override;
+ void RemoveColumnFamily(const ColumnFamilyHandle* cfh) override;
+
+ void Resize(uint32_t) override;
+ std::vector<DeadlockPath> GetDeadlockInfoBuffer() override;
+
+ std::vector<RangeDeadlockPath> GetRangeDeadlockInfoBuffer() override;
+ void SetRangeDeadlockInfoBufferSize(uint32_t target_size) override;
+
+ // Get a lock on a range
+ // @note only exclusive locks are currently supported (requesting a
+ // non-exclusive lock will get an exclusive one)
+ using LockManager::TryLock;
+ Status TryLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const Endpoint& start_endp, const Endpoint& end_endp, Env* env,
+ bool exclusive) override;
+
+ void UnLock(PessimisticTransaction* txn, const LockTracker& tracker,
+ Env* env) override;
+ void UnLock(PessimisticTransaction* txn, ColumnFamilyId column_family_id,
+ const std::string& key, Env* env) override;
+ void UnLock(PessimisticTransaction*, ColumnFamilyId, const Endpoint&,
+ const Endpoint&, Env*) override {
+ // TODO: range unlock does nothing...
+ }
+
+ explicit RangeTreeLockManager(
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory);
+
+ ~RangeTreeLockManager() override;
+
+ int SetMaxLockMemory(size_t max_lock_memory) override {
+ return ltm_.set_max_lock_memory(max_lock_memory);
+ }
+
+ size_t GetMaxLockMemory() override { return ltm_.get_max_lock_memory(); }
+
+ Counters GetStatus() override;
+
+ bool IsPointLockSupported() const override {
+ // One could have acquired a point lock (it is reduced to range lock)
+ return true;
+ }
+
+ PointLockStatus GetPointLockStatus() override;
+
+ // This is from LockManager
+ LockManager::RangeLockStatus GetRangeLockStatus() override;
+
+ // This has the same meaning as GetRangeLockStatus but is from
+ // RangeLockManagerHandle
+ RangeLockManagerHandle::RangeLockStatus GetRangeLockStatusData() override {
+ return GetRangeLockStatus();
+ }
+
+ bool IsRangeLockSupported() const override { return true; }
+
+ const LockTrackerFactory& GetLockTrackerFactory() const override {
+ return RangeTreeLockTrackerFactory::Get();
+ }
+
+ // Get the locktree which stores locks for the Column Family with given cf_id
+ std::shared_ptr<toku::locktree> GetLockTreeForCF(ColumnFamilyId cf_id);
+
+ void SetEscalationBarrierFunc(EscalationBarrierFunc func) override {
+ barrier_func_ = func;
+ }
+
+ private:
+ toku::locktree_manager ltm_;
+
+ EscalationBarrierFunc barrier_func_ =
+ [](const Endpoint&, const Endpoint&) -> bool { return false; };
+
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory_;
+
+ // Map from cf_id to locktree*. Can only be accessed while holding the
+ // ltree_map_mutex_. Must use a custom deleter that calls ltm_.release_lt
+ using LockTreeMap =
+ std::unordered_map<ColumnFamilyId, std::shared_ptr<toku::locktree>>;
+ LockTreeMap ltree_map_;
+
+ InstrumentedMutex ltree_map_mutex_;
+
+ // Per-thread cache of ltree_map_.
+ // (uses the same approach as TransactionLockMgr::lock_maps_cache_)
+ std::unique_ptr<ThreadLocalPtr> ltree_lookup_cache_;
+
+ RangeDeadlockInfoBuffer dlock_buffer_;
+
+ std::shared_ptr<toku::locktree> MakeLockTreePtr(toku::locktree* lt);
+ static int CompareDbtEndpoints(void* arg, const DBT* a_key, const DBT* b_key);
+
+ // Callbacks
+ static int on_create(toku::locktree*, void*);
+ static void on_destroy(toku::locktree*) {}
+ static void on_escalate(TXNID txnid, const toku::locktree* lt,
+ const toku::range_buffer& buffer, void* extra);
+
+ static bool OnEscalationBarrierCheck(const DBT* a, const DBT* b, void* extra);
+};
+
+void serialize_endpoint(const Endpoint& endp, std::string* buf);
+void wait_callback_for_locktree(void* cdata, toku::lock_wait_infos* infos);
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.cc b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.cc
new file mode 100644
index 000000000..be1e1478b
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.cc
@@ -0,0 +1,156 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+#ifndef OS_WIN
+
+#include "utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h"
+
+#include "utilities/transactions/lock/range/range_tree/range_tree_lock_manager.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+RangeLockList *RangeTreeLockTracker::getOrCreateList() {
+ if (range_list_) return range_list_.get();
+
+ // Doesn't exist, create
+ range_list_.reset(new RangeLockList());
+ return range_list_.get();
+}
+
+void RangeTreeLockTracker::Track(const PointLockRequest &lock_req) {
+ DBT key_dbt;
+ std::string key;
+ serialize_endpoint(Endpoint(lock_req.key, false), &key);
+ toku_fill_dbt(&key_dbt, key.data(), key.size());
+ RangeLockList *rl = getOrCreateList();
+ rl->Append(lock_req.column_family_id, &key_dbt, &key_dbt);
+}
+
+void RangeTreeLockTracker::Track(const RangeLockRequest &lock_req) {
+ DBT start_dbt, end_dbt;
+ std::string start_key, end_key;
+
+ serialize_endpoint(lock_req.start_endp, &start_key);
+ serialize_endpoint(lock_req.end_endp, &end_key);
+
+ toku_fill_dbt(&start_dbt, start_key.data(), start_key.size());
+ toku_fill_dbt(&end_dbt, end_key.data(), end_key.size());
+
+ RangeLockList *rl = getOrCreateList();
+ rl->Append(lock_req.column_family_id, &start_dbt, &end_dbt);
+}
+
+PointLockStatus RangeTreeLockTracker::GetPointLockStatus(
+ ColumnFamilyId /*cf_id*/, const std::string & /*key*/) const {
+ // This function is not expected to be called as RangeTreeLockTracker::
+ // IsPointLockSupported() returns false. Return the status which indicates
+ // the point is not locked.
+ PointLockStatus p;
+ p.locked = false;
+ p.exclusive = true;
+ p.seq = 0;
+ return p;
+}
+
+void RangeTreeLockTracker::Clear() { range_list_.reset(); }
+
+void RangeLockList::Append(ColumnFamilyId cf_id, const DBT *left_key,
+ const DBT *right_key) {
+ MutexLock l(&mutex_);
+ // Only the transaction owner thread calls this function.
+ // The same thread does the lock release, so we can be certain nobody is
+ // releasing the locks concurrently.
+ assert(!releasing_locks_.load());
+ auto it = buffers_.find(cf_id);
+ if (it == buffers_.end()) {
+ // create a new one
+ it = buffers_.emplace(cf_id, std::make_shared<toku::range_buffer>()).first;
+ it->second->create();
+ }
+ it->second->append(left_key, right_key);
+}
+
+void RangeLockList::ReleaseLocks(RangeTreeLockManager *mgr,
+ PessimisticTransaction *txn,
+ bool all_trx_locks) {
+ {
+ MutexLock l(&mutex_);
+ // The lt->release_locks() call below will walk range_list->buffer_. We
+ // need to prevent lock escalation callback from replacing
+ // range_list->buffer_ while we are doing that.
+ //
+ // Additional complication here is internal mutex(es) in the locktree
+ // (let's call them latches):
+ // - Lock escalation first obtains latches on the lock tree
+ // - Then, it calls RangeTreeLockManager::on_escalate to replace
+ // transaction's range_list->buffer_. = Access to that buffer must be
+ // synchronized, so it will want to acquire the range_list->mutex_.
+ //
+ // While in this function we would want to do the reverse:
+ // - Acquire range_list->mutex_ to prevent access to the range_list.
+ // - Then, lt->release_locks() call will walk through the range_list
+ // - and acquire latches on parts of the lock tree to remove locks from
+ // it.
+ //
+ // How do we avoid the deadlock? The idea is that here we set
+ // releasing_locks_=true, and release the mutex.
+ // All other users of the range_list must:
+ // - Acquire the mutex, then check that releasing_locks_=false.
+ // (the code in this function doesnt do that as there's only one thread
+ // that releases transaction's locks)
+ releasing_locks_.store(true);
+ }
+
+ for (auto it : buffers_) {
+ // Don't try to call release_locks() if the buffer is empty! if we are
+ // not holding any locks, the lock tree might be in the STO-mode with
+ // another transaction, and our attempt to release an empty set of locks
+ // will cause an assertion failure.
+ if (it.second->get_num_ranges()) {
+ auto lt_ptr = mgr->GetLockTreeForCF(it.first);
+ toku::locktree *lt = lt_ptr.get();
+
+ lt->release_locks((TXNID)txn, it.second.get(), all_trx_locks);
+
+ it.second->destroy();
+ it.second->create();
+
+ toku::lock_request::retry_all_lock_requests(lt,
+ wait_callback_for_locktree);
+ }
+ }
+
+ Clear();
+ releasing_locks_.store(false);
+}
+
+void RangeLockList::ReplaceLocks(const toku::locktree *lt,
+ const toku::range_buffer &buffer) {
+ MutexLock l(&mutex_);
+ if (releasing_locks_.load()) {
+ // Do nothing. The transaction is releasing its locks, so it will not care
+ // about having a correct list of ranges. (In TokuDB,
+ // toku_db_txn_escalate_callback() makes use of this property, too)
+ return;
+ }
+
+ ColumnFamilyId cf_id = (ColumnFamilyId)lt->get_dict_id().dictid;
+
+ auto it = buffers_.find(cf_id);
+ it->second->destroy();
+ it->second->create();
+
+ toku::range_buffer::iterator iter(&buffer);
+ toku::range_buffer::iterator::record rec;
+ while (iter.current(&rec)) {
+ it->second->append(rec.get_left_key(), rec.get_right_key());
+ iter.next();
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // OS_WIN
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h
new file mode 100644
index 000000000..4ef48d252
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/lock/range/range_tree/range_tree_lock_tracker.h
@@ -0,0 +1,146 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "util/mutexlock.h"
+#include "utilities/transactions/lock/lock_tracker.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+
+// Range Locking:
+#include "lib/locktree/lock_request.h"
+#include "lib/locktree/locktree.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class RangeTreeLockManager;
+
+// Storage for locks that are currently held by a transaction.
+//
+// Locks are kept in toku::range_buffer because toku::locktree::release_locks()
+// accepts that as an argument.
+//
+// Note: the list of locks may differ slighly from the contents of the lock
+// tree, due to concurrency between lock acquisition, lock release, and lock
+// escalation. See MDEV-18227 and RangeTreeLockManager::UnLock for details.
+// This property is currently harmless.
+//
+// Append() and ReleaseLocks() are not thread-safe, as they are expected to be
+// called only by the owner transaction. ReplaceLocks() is safe to call from
+// other threads.
+class RangeLockList {
+ public:
+ ~RangeLockList() { Clear(); }
+
+ RangeLockList() : releasing_locks_(false) {}
+
+ void Append(ColumnFamilyId cf_id, const DBT* left_key, const DBT* right_key);
+ void ReleaseLocks(RangeTreeLockManager* mgr, PessimisticTransaction* txn,
+ bool all_trx_locks);
+ void ReplaceLocks(const toku::locktree* lt, const toku::range_buffer& buffer);
+
+ private:
+ void Clear() {
+ for (auto it : buffers_) {
+ it.second->destroy();
+ }
+ buffers_.clear();
+ }
+
+ std::unordered_map<ColumnFamilyId, std::shared_ptr<toku::range_buffer>>
+ buffers_;
+ port::Mutex mutex_;
+ std::atomic<bool> releasing_locks_;
+};
+
+// A LockTracker-based object that is used together with RangeTreeLockManager.
+class RangeTreeLockTracker : public LockTracker {
+ public:
+ RangeTreeLockTracker() : range_list_(nullptr) {}
+
+ RangeTreeLockTracker(const RangeTreeLockTracker&) = delete;
+ RangeTreeLockTracker& operator=(const RangeTreeLockTracker&) = delete;
+
+ void Track(const PointLockRequest&) override;
+ void Track(const RangeLockRequest&) override;
+
+ bool IsPointLockSupported() const override {
+ // This indicates that we don't implement GetPointLockStatus()
+ return false;
+ }
+ bool IsRangeLockSupported() const override { return true; }
+
+ // a Not-supported dummy implementation.
+ UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) override {
+ return UntrackStatus::NOT_TRACKED;
+ }
+
+ UntrackStatus Untrack(const PointLockRequest& /*lock_request*/) override {
+ return UntrackStatus::NOT_TRACKED;
+ }
+
+ // "If this method is not supported, leave it as a no-op."
+ void Merge(const LockTracker&) override {}
+
+ // "If this method is not supported, leave it as a no-op."
+ void Subtract(const LockTracker&) override {}
+
+ void Clear() override;
+
+ // "If this method is not supported, returns nullptr."
+ virtual LockTracker* GetTrackedLocksSinceSavePoint(
+ const LockTracker&) const override {
+ return nullptr;
+ }
+
+ PointLockStatus GetPointLockStatus(ColumnFamilyId column_family_id,
+ const std::string& key) const override;
+
+ // The return value is only used for tests
+ uint64_t GetNumPointLocks() const override { return 0; }
+
+ ColumnFamilyIterator* GetColumnFamilyIterator() const override {
+ return nullptr;
+ }
+
+ KeyIterator* GetKeyIterator(
+ ColumnFamilyId /*column_family_id*/) const override {
+ return nullptr;
+ }
+
+ void ReleaseLocks(RangeTreeLockManager* mgr, PessimisticTransaction* txn,
+ bool all_trx_locks) {
+ if (range_list_) range_list_->ReleaseLocks(mgr, txn, all_trx_locks);
+ }
+
+ void ReplaceLocks(const toku::locktree* lt,
+ const toku::range_buffer& buffer) {
+ // range_list_ cannot be NULL here
+ range_list_->ReplaceLocks(lt, buffer);
+ }
+
+ private:
+ RangeLockList* getOrCreateList();
+ std::unique_ptr<RangeLockList> range_list_;
+};
+
+class RangeTreeLockTrackerFactory : public LockTrackerFactory {
+ public:
+ static const RangeTreeLockTrackerFactory& Get() {
+ static const RangeTreeLockTrackerFactory instance;
+ return instance;
+ }
+
+ LockTracker* Create() const override { return new RangeTreeLockTracker(); }
+
+ private:
+ RangeTreeLockTrackerFactory() {}
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction.cc b/src/rocksdb/utilities/transactions/optimistic_transaction.cc
new file mode 100644
index 000000000..0ee0f28b6
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction.cc
@@ -0,0 +1,196 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/optimistic_transaction.h"
+
+#include <string>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/transactions/lock/point/point_lock_tracker.h"
+#include "utilities/transactions/optimistic_transaction.h"
+#include "utilities/transactions/optimistic_transaction_db_impl.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+OptimisticTransaction::OptimisticTransaction(
+ OptimisticTransactionDB* txn_db, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options)
+ : TransactionBaseImpl(txn_db->GetBaseDB(), write_options,
+ PointLockTrackerFactory::Get()),
+ txn_db_(txn_db) {
+ Initialize(txn_options);
+}
+
+void OptimisticTransaction::Initialize(
+ const OptimisticTransactionOptions& txn_options) {
+ if (txn_options.set_snapshot) {
+ SetSnapshot();
+ }
+}
+
+void OptimisticTransaction::Reinitialize(
+ OptimisticTransactionDB* txn_db, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options) {
+ TransactionBaseImpl::Reinitialize(txn_db->GetBaseDB(), write_options);
+ Initialize(txn_options);
+}
+
+OptimisticTransaction::~OptimisticTransaction() {}
+
+void OptimisticTransaction::Clear() { TransactionBaseImpl::Clear(); }
+
+Status OptimisticTransaction::Prepare() {
+ return Status::InvalidArgument(
+ "Two phase commit not supported for optimistic transactions.");
+}
+
+Status OptimisticTransaction::Commit() {
+ auto txn_db_impl = static_cast_with_check<OptimisticTransactionDBImpl,
+ OptimisticTransactionDB>(txn_db_);
+ assert(txn_db_impl);
+ switch (txn_db_impl->GetValidatePolicy()) {
+ case OccValidationPolicy::kValidateParallel:
+ return CommitWithParallelValidate();
+ case OccValidationPolicy::kValidateSerial:
+ return CommitWithSerialValidate();
+ default:
+ assert(0);
+ }
+ // unreachable, just void compiler complain
+ return Status::OK();
+}
+
+Status OptimisticTransaction::CommitWithSerialValidate() {
+ // Set up callback which will call CheckTransactionForConflicts() to
+ // check whether this transaction is safe to be committed.
+ OptimisticTransactionCallback callback(this);
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db_->GetRootDB());
+
+ Status s = db_impl->WriteWithCallback(
+ write_options_, GetWriteBatch()->GetWriteBatch(), &callback);
+
+ if (s.ok()) {
+ Clear();
+ }
+
+ return s;
+}
+
+Status OptimisticTransaction::CommitWithParallelValidate() {
+ auto txn_db_impl = static_cast_with_check<OptimisticTransactionDBImpl,
+ OptimisticTransactionDB>(txn_db_);
+ assert(txn_db_impl);
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db_->GetRootDB());
+ assert(db_impl);
+ const size_t space = txn_db_impl->GetLockBucketsSize();
+ std::set<size_t> lk_idxes;
+ std::vector<std::unique_lock<std::mutex>> lks;
+ std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
+ tracked_locks_->GetColumnFamilyIterator());
+ assert(cf_it != nullptr);
+ while (cf_it->HasNext()) {
+ ColumnFamilyId cf = cf_it->Next();
+ std::unique_ptr<LockTracker::KeyIterator> key_it(
+ tracked_locks_->GetKeyIterator(cf));
+ assert(key_it != nullptr);
+ while (key_it->HasNext()) {
+ const std::string& key = key_it->Next();
+ lk_idxes.insert(FastRange64(GetSliceNPHash64(key), space));
+ }
+ }
+ // NOTE: in a single txn, all bucket-locks are taken in ascending order.
+ // In this way, txns from different threads all obey this rule so that
+ // deadlock can be avoided.
+ for (auto v : lk_idxes) {
+ lks.emplace_back(txn_db_impl->LockBucket(v));
+ }
+
+ Status s = TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_,
+ true /* cache_only */);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = db_impl->Write(write_options_, GetWriteBatch()->GetWriteBatch());
+ if (s.ok()) {
+ Clear();
+ }
+
+ return s;
+}
+
+Status OptimisticTransaction::Rollback() {
+ Clear();
+ return Status::OK();
+}
+
+// Record this key so that we can check it for conflicts at commit time.
+//
+// 'exclusive' is unused for OptimisticTransaction.
+Status OptimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
+ const Slice& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ assert(!assume_tracked); // not supported
+ (void)assume_tracked;
+ if (!do_validate) {
+ return Status::OK();
+ }
+ uint32_t cfh_id = GetColumnFamilyID(column_family);
+
+ SetSnapshotIfNeeded();
+
+ SequenceNumber seq;
+ if (snapshot_) {
+ seq = snapshot_->GetSequenceNumber();
+ } else {
+ seq = db_->GetLatestSequenceNumber();
+ }
+
+ std::string key_str = key.ToString();
+
+ TrackKey(cfh_id, key_str, seq, read_only, exclusive);
+
+ // Always return OK. Confilct checking will happen at commit time.
+ return Status::OK();
+}
+
+// Returns OK if it is safe to commit this transaction. Returns Status::Busy
+// if there are read or write conflicts that would prevent us from committing OR
+// if we can not determine whether there would be any such conflicts.
+//
+// Should only be called on writer thread in order to avoid any race conditions
+// in detecting write conflicts.
+Status OptimisticTransaction::CheckTransactionForConflicts(DB* db) {
+ auto db_impl = static_cast_with_check<DBImpl>(db);
+
+ // Since we are on the write thread and do not want to block other writers,
+ // we will do a cache-only conflict check. This can result in TryAgain
+ // getting returned if there is not sufficient memtable history to check
+ // for conflicts.
+ return TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_,
+ true /* cache_only */);
+}
+
+Status OptimisticTransaction::SetName(const TransactionName& /* unused */) {
+ return Status::InvalidArgument("Optimistic transactions cannot be named.");
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction.h b/src/rocksdb/utilities/transactions/optimistic_transaction.h
new file mode 100644
index 000000000..de23233d5
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction.h
@@ -0,0 +1,101 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransaction : public TransactionBaseImpl {
+ public:
+ OptimisticTransaction(OptimisticTransactionDB* db,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options);
+ // No copying allowed
+ OptimisticTransaction(const OptimisticTransaction&) = delete;
+ void operator=(const OptimisticTransaction&) = delete;
+
+ virtual ~OptimisticTransaction();
+
+ void Reinitialize(OptimisticTransactionDB* txn_db,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options);
+
+ Status Prepare() override;
+
+ Status Commit() override;
+
+ Status Rollback() override;
+
+ Status SetName(const TransactionName& name) override;
+
+ protected:
+ Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false) override;
+
+ private:
+ ROCKSDB_FIELD_UNUSED OptimisticTransactionDB* const txn_db_;
+
+ friend class OptimisticTransactionCallback;
+
+ void Initialize(const OptimisticTransactionOptions& txn_options);
+
+ // Returns OK if it is safe to commit this transaction. Returns Status::Busy
+ // if there are read or write conflicts that would prevent us from committing
+ // OR if we can not determine whether there would be any such conflicts.
+ //
+ // Should only be called on writer thread.
+ Status CheckTransactionForConflicts(DB* db);
+
+ void Clear() override;
+
+ void UnlockGetForUpdate(ColumnFamilyHandle* /* unused */,
+ const Slice& /* unused */) override {
+ // Nothing to unlock.
+ }
+
+ Status CommitWithSerialValidate();
+
+ Status CommitWithParallelValidate();
+};
+
+// Used at commit time to trigger transaction validation
+class OptimisticTransactionCallback : public WriteCallback {
+ public:
+ explicit OptimisticTransactionCallback(OptimisticTransaction* txn)
+ : txn_(txn) {}
+
+ Status Callback(DB* db) override {
+ return txn_->CheckTransactionForConflicts(db);
+ }
+
+ bool AllowWriteBatching() override { return false; }
+
+ private:
+ OptimisticTransaction* txn_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc
new file mode 100644
index 000000000..bffb3d5ed
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/optimistic_transaction_db_impl.h"
+
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "utilities/transactions/optimistic_transaction.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Transaction* OptimisticTransactionDBImpl::BeginTransaction(
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options, Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new OptimisticTransaction(this, write_options, txn_options);
+ }
+}
+
+std::unique_lock<std::mutex> OptimisticTransactionDBImpl::LockBucket(
+ size_t idx) {
+ assert(idx < bucketed_locks_.size());
+ return std::unique_lock<std::mutex>(*bucketed_locks_[idx]);
+}
+
+Status OptimisticTransactionDB::Open(const Options& options,
+ const std::string& dbname,
+ OptimisticTransactionDB** dbptr) {
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = Open(db_options, dbname, column_families, &handles, dbptr);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // i can delete the handle since DBImpl is always holding a reference to
+ // default column family
+ delete handles[0];
+ }
+
+ return s;
+}
+
+Status OptimisticTransactionDB::Open(
+ const DBOptions& db_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ OptimisticTransactionDB** dbptr) {
+ return OptimisticTransactionDB::Open(db_options,
+ OptimisticTransactionDBOptions(), dbname,
+ column_families, handles, dbptr);
+}
+
+Status OptimisticTransactionDB::Open(
+ const DBOptions& db_options,
+ const OptimisticTransactionDBOptions& occ_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ OptimisticTransactionDB** dbptr) {
+ Status s;
+ DB* db;
+
+ std::vector<ColumnFamilyDescriptor> column_families_copy = column_families;
+
+ // Enable MemTable History if not already enabled
+ for (auto& column_family : column_families_copy) {
+ ColumnFamilyOptions* options = &column_family.options;
+
+ if (options->max_write_buffer_size_to_maintain == 0 &&
+ options->max_write_buffer_number_to_maintain == 0) {
+ // Setting to -1 will set the History size to
+ // max_write_buffer_number * write_buffer_size.
+ options->max_write_buffer_size_to_maintain = -1;
+ }
+ }
+
+ s = DB::Open(db_options, dbname, column_families_copy, handles, &db);
+
+ if (s.ok()) {
+ *dbptr = new OptimisticTransactionDBImpl(db, occ_options);
+ }
+
+ return s;
+}
+
+void OptimisticTransactionDBImpl::ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options) {
+ assert(dynamic_cast<OptimisticTransaction*>(txn) != nullptr);
+ auto txn_impl = reinterpret_cast<OptimisticTransaction*>(txn);
+
+ txn_impl->Reinitialize(this, write_options, txn_options);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h
new file mode 100644
index 000000000..88e86ea4a
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <mutex>
+#include <vector>
+
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransactionDBImpl : public OptimisticTransactionDB {
+ public:
+ explicit OptimisticTransactionDBImpl(
+ DB* db, const OptimisticTransactionDBOptions& occ_options,
+ bool take_ownership = true)
+ : OptimisticTransactionDB(db),
+ db_owner_(take_ownership),
+ validate_policy_(occ_options.validate_policy) {
+ if (validate_policy_ == OccValidationPolicy::kValidateParallel) {
+ uint32_t bucket_size = std::max(16u, occ_options.occ_lock_buckets);
+ bucketed_locks_.reserve(bucket_size);
+ for (size_t i = 0; i < bucket_size; ++i) {
+ bucketed_locks_.emplace_back(
+ std::unique_ptr<std::mutex>(new std::mutex));
+ }
+ }
+ }
+
+ ~OptimisticTransactionDBImpl() {
+ // Prevent this stackable from destroying
+ // base db
+ if (!db_owner_) {
+ db_ = nullptr;
+ }
+ }
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ // Transactional `DeleteRange()` is not yet supported.
+ using StackableDB::DeleteRange;
+ virtual Status DeleteRange(const WriteOptions&, ColumnFamilyHandle*,
+ const Slice&, const Slice&) override {
+ return Status::NotSupported();
+ }
+
+ // Range deletions also must not be snuck into `WriteBatch`es as they are
+ // incompatible with `OptimisticTransactionDB`.
+ virtual Status Write(const WriteOptions& write_opts,
+ WriteBatch* batch) override {
+ if (batch->HasDeleteRange()) {
+ return Status::NotSupported();
+ }
+ return OptimisticTransactionDB::Write(write_opts, batch);
+ }
+
+ size_t GetLockBucketsSize() const { return bucketed_locks_.size(); }
+
+ OccValidationPolicy GetValidatePolicy() const { return validate_policy_; }
+
+ std::unique_lock<std::mutex> LockBucket(size_t idx);
+
+ private:
+ // NOTE: used in validation phase. Each key is hashed into some
+ // bucket. We then take the lock in the hash value order to avoid deadlock.
+ std::vector<std::unique_ptr<std::mutex>> bucketed_locks_;
+
+ bool db_owner_;
+
+ const OccValidationPolicy validate_policy_;
+
+ void ReinitializeTransaction(Transaction* txn,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options =
+ OptimisticTransactionOptions());
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc b/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc
new file mode 100644
index 000000000..aa8192c32
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc
@@ -0,0 +1,1491 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "db/db_test_util.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/perf_context.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "rocksdb/utilities/transaction.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/transaction_test_util.h"
+#include "util/crc32c.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransactionTest
+ : public testing::Test,
+ public testing::WithParamInterface<OccValidationPolicy> {
+ public:
+ OptimisticTransactionDB* txn_db;
+ std::string dbname;
+ Options options;
+
+ OptimisticTransactionTest() {
+ options.create_if_missing = true;
+ options.max_write_buffer_number = 2;
+ options.max_write_buffer_size_to_maintain = 2 * Arena::kInlineSize;
+ options.merge_operator.reset(new TestPutOperator());
+ dbname = test::PerThreadDBPath("optimistic_transaction_testdb");
+
+ EXPECT_OK(DestroyDB(dbname, options));
+ Open();
+ }
+ ~OptimisticTransactionTest() override {
+ delete txn_db;
+ EXPECT_OK(DestroyDB(dbname, options));
+ }
+
+ void Reopen() {
+ delete txn_db;
+ txn_db = nullptr;
+ Open();
+ }
+
+ private:
+ void Open() {
+ ColumnFamilyOptions cf_options(options);
+ OptimisticTransactionDBOptions occ_opts;
+ occ_opts.validate_policy = GetParam();
+ std::vector<ColumnFamilyDescriptor> column_families;
+ std::vector<ColumnFamilyHandle*> handles;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ Status s =
+ OptimisticTransactionDB::Open(DBOptions(options), occ_opts, dbname,
+ column_families, &handles, &txn_db);
+
+ ASSERT_OK(s);
+ ASSERT_NE(txn_db, nullptr);
+ ASSERT_EQ(handles.size(), 1);
+ delete handles[0];
+ }
+};
+
+TEST_P(OptimisticTransactionTest, SuccessTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "foo", "bar"));
+ ASSERT_OK(txn_db->Put(write_options, "foo2", "bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ ASSERT_OK(txn->Put("foo", "bar2"));
+
+ // This Put outside of a transaction will conflict with the previous write
+ ASSERT_OK(txn_db->Put(write_options, "foo", "barz"));
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+ ASSERT_EQ(1, txn->GetNumKeys());
+
+ Status s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+ ASSERT_OK(txn_db->Get(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflictTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "foo", "bar"));
+ ASSERT_OK(txn_db->Put(write_options, "foo2", "bar"));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_NE(txn, nullptr);
+
+ // This Put outside of a transaction will conflict with a later write
+ ASSERT_OK(txn_db->Put(write_options, "foo", "barz"));
+
+ ASSERT_OK(txn->Put(
+ "foo", "bar2")); // Conflicts with write done after snapshot taken
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+
+ Status s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+ ASSERT_OK(txn_db->Get(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflictTest3) {
+ ASSERT_OK(txn_db->Put(WriteOptions(), "foo", "bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(WriteOptions());
+ ASSERT_NE(txn, nullptr);
+
+ std::string value;
+ ASSERT_OK(txn->GetForUpdate(ReadOptions(), "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn->Merge("foo", "bar3"));
+
+ // Merge outside of a transaction should conflict with the previous merge
+ ASSERT_OK(txn_db->Merge(WriteOptions(), "foo", "bar2"));
+ ASSERT_OK(txn_db->Get(ReadOptions(), "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ ASSERT_EQ(1, txn->GetNumKeys());
+
+ Status s = txn->Commit();
+ EXPECT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ ASSERT_OK(txn_db->Get(ReadOptions(), "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflict4) {
+ ASSERT_OK(txn_db->Put(WriteOptions(), "foo", "bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(WriteOptions());
+ ASSERT_NE(txn, nullptr);
+
+ std::string value;
+ ASSERT_OK(txn->GetForUpdate(ReadOptions(), "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn->Merge("foo", "bar3"));
+
+ // Range delete outside of a transaction should conflict with the previous
+ // merge inside txn
+ auto* dbimpl = static_cast_with_check<DBImpl>(txn_db->GetRootDB());
+ ColumnFamilyHandle* default_cf = dbimpl->DefaultColumnFamily();
+ ASSERT_OK(dbimpl->DeleteRange(WriteOptions(), default_cf, "foo", "foo1"));
+ Status s = txn_db->Get(ReadOptions(), "foo", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_EQ(1, txn->GetNumKeys());
+
+ s = txn->Commit();
+ EXPECT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ s = txn_db->Get(ReadOptions(), "foo", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, ReadConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "foo", "bar"));
+ ASSERT_OK(txn_db->Put(write_options, "foo2", "bar"));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_NE(txn, nullptr);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ // This Put outside of a transaction will conflict with the previous read
+ ASSERT_OK(txn_db->Put(write_options, "foo", "barz"));
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+
+ Status s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "barz");
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, TxnOnlyTest) {
+ // Test to make sure transactions work when there are no other writes in an
+ // empty db.
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ ASSERT_OK(txn->Put("x", "y"));
+
+ ASSERT_OK(txn->Commit());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, FlushTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a memtable to flush
+ ASSERT_OK(txn_db->Put(write_options, "dummy", "dummy"));
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ ASSERT_OK(txn_db->Flush(flush_ops));
+
+ // txn should commit since the flushed table is still in MemtableList History
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, FlushTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a MemTable to flush
+ ASSERT_OK(txn_db->Put(write_options, "dummy", "dummy"));
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ ASSERT_OK(txn_db->Flush(flush_ops));
+
+ // Put a random key so we have a MemTable to flush
+ ASSERT_OK(txn_db->Put(write_options, "dummy", "dummy2"));
+
+ // force a memtable flush
+ ASSERT_OK(txn_db->Flush(flush_ops));
+
+ ASSERT_OK(txn_db->Put(write_options, "dummy", "dummy3"));
+
+ // force a memtable flush
+ // Since our test db has max_write_buffer_number=2, this flush will cause
+ // the first memtable to get purged from the MemtableList history.
+ ASSERT_OK(txn_db->Flush(flush_ops));
+
+ Status s = txn->Commit();
+ // txn should not commit since MemTableList History is not large enough
+ ASSERT_TRUE(s.IsTryAgain());
+
+ ASSERT_OK(txn_db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+// Trigger the condition where some old memtables are skipped when doing
+// TransactionUtil::CheckKey(), and make sure the result is still correct.
+TEST_P(OptimisticTransactionTest, CheckKeySkipOldMemtable) {
+ const int kAttemptHistoryMemtable = 0;
+ const int kAttemptImmMemTable = 1;
+ for (int attempt = kAttemptHistoryMemtable; attempt <= kAttemptImmMemTable;
+ attempt++) {
+ Reopen();
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ ReadOptions snapshot_read_options;
+ ReadOptions snapshot_read_options2;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn != nullptr);
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn2 != nullptr);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ snapshot_read_options2.snapshot = txn2->GetSnapshot();
+ ASSERT_OK(txn2->GetForUpdate(snapshot_read_options2, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn2->Put(Slice("foo2"), Slice("bar2")));
+
+ // txn updates "foo" and txn2 updates "foo2", and now a write is
+ // issued for "foo", which conflicts with txn but not txn2
+ ASSERT_OK(txn_db->Put(write_options, "foo", "bar"));
+
+ if (attempt == kAttemptImmMemTable) {
+ // For the second attempt, hold flush from beginning. The memtable
+ // will be switched to immutable after calling TEST_SwitchMemtable()
+ // while CheckKey() is called.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"OptimisticTransactionTest.CheckKeySkipOldMemtable",
+ "FlushJob::Start"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ }
+
+ // force a memtable flush. The memtable should still be kept
+ FlushOptions flush_ops;
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_OK(txn_db->Flush(flush_ops));
+ } else {
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ DBImpl* db_impl = static_cast<DBImpl*>(txn_db->GetRootDB());
+ ASSERT_OK(db_impl->TEST_SwitchMemtable());
+ }
+ uint64_t num_imm_mems;
+ ASSERT_TRUE(txn_db->GetIntProperty(DB::Properties::kNumImmutableMemTable,
+ &num_imm_mems));
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(0, num_imm_mems);
+ } else {
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ ASSERT_EQ(1, num_imm_mems);
+ }
+
+ // Put something in active memtable
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo3"), Slice("bar")));
+
+ // Create txn3 after flushing, when this transaction is commited,
+ // only need to check the active memtable
+ Transaction* txn3 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn3 != nullptr);
+
+ // Commit both of txn and txn2. txn will conflict but txn2 will
+ // pass. In both ways, both memtables are queried.
+ SetPerfLevel(PerfLevel::kEnableCount);
+
+ get_perf_context()->Reset();
+ Status s = txn->Commit();
+ // We should have checked two memtables
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ // txn should fail because of conflict, even if the memtable
+ // has flushed, because it is still preserved in history.
+ ASSERT_TRUE(s.IsBusy());
+
+ get_perf_context()->Reset();
+ s = txn2->Commit();
+ // We should have checked two memtables
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ ASSERT_TRUE(s.ok());
+
+ ASSERT_OK(txn3->Put(Slice("foo2"), Slice("bar2")));
+ get_perf_context()->Reset();
+ s = txn3->Commit();
+ // txn3 is created after the active memtable is created, so that is the only
+ // memtable to check.
+ ASSERT_EQ(1, get_perf_context()->get_from_memtable_count);
+ ASSERT_TRUE(s.ok());
+
+ TEST_SYNC_POINT("OptimisticTransactionTest.CheckKeySkipOldMemtable");
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ SetPerfLevel(PerfLevel::kDisable);
+
+ delete txn;
+ delete txn2;
+ delete txn3;
+ }
+}
+
+TEST_P(OptimisticTransactionTest, NoSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "AAA", "bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ // Modify key after transaction start
+ ASSERT_OK(txn_db->Put(write_options, "AAA", "bar1"));
+
+ // Read and write without a snapshot
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn->Put("AAA", "bar2"));
+
+ // Should commit since read/write was done after data changed
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, MultipleSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "AAA", "bar"));
+ ASSERT_OK(txn_db->Put(write_options, "BBB", "bar"));
+ ASSERT_OK(txn_db->Put(write_options, "CCC", "bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ ASSERT_OK(txn_db->Put(write_options, "AAA", "bar1"));
+
+ // Read and write without a snapshot
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn->Put("AAA", "bar2"));
+
+ // Modify BBB before snapshot is taken
+ ASSERT_OK(txn_db->Put(write_options, "BBB", "bar1"));
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn->Put("BBB", "bar2"));
+
+ ASSERT_OK(txn_db->Put(write_options, "CCC", "bar1"));
+
+ // Set a new snapshot
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "CCC", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn->Put("CCC", "bar2"));
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar2");
+ ASSERT_OK(txn->GetForUpdate(read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar2");
+ ASSERT_OK(txn->GetForUpdate(read_options, "CCC", &value));
+ ASSERT_EQ(value, "bar2");
+
+ ASSERT_OK(txn_db->Get(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn_db->Get(read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar1");
+ ASSERT_OK(txn_db->Get(read_options, "CCC", &value));
+ ASSERT_EQ(value, "bar1");
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn_db->Get(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar2");
+ ASSERT_OK(txn_db->Get(read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar2");
+ ASSERT_OK(txn_db->Get(read_options, "CCC", &value));
+ ASSERT_EQ(value, "bar2");
+
+ // verify that we track multiple writes to the same key at different snapshots
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ // Potentially conflicting writes
+ ASSERT_OK(txn_db->Put(write_options, "ZZZ", "zzz"));
+ ASSERT_OK(txn_db->Put(write_options, "XXX", "xxx"));
+
+ txn->SetSnapshot();
+
+ OptimisticTransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ txn2->SetSnapshot();
+
+ // This should not conflict in txn since the snapshot is later than the
+ // previous write (spoiler alert: it will later conflict with txn2).
+ ASSERT_OK(txn->Put("ZZZ", "zzzz"));
+ ASSERT_OK(txn->Commit());
+
+ delete txn;
+
+ // This will conflict since the snapshot is earlier than another write to ZZZ
+ ASSERT_OK(txn2->Put("ZZZ", "xxxxx"));
+
+ Status s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn2;
+}
+
+TEST_P(OptimisticTransactionTest, ColumnFamiliesTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ ASSERT_OK(txn_db->CreateColumnFamily(cf_options, "CFA", &cfa));
+ ASSERT_OK(txn_db->CreateColumnFamily(cf_options, "CFB", &cfb));
+
+ delete cfa;
+ delete cfb;
+ delete txn_db;
+ txn_db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+ std::vector<ColumnFamilyHandle*> handles;
+ ASSERT_OK(OptimisticTransactionDB::Open(options, dbname, column_families,
+ &handles, &txn_db));
+ assert(txn_db != nullptr);
+ ASSERT_NE(txn_db, nullptr);
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // Write some data to the db
+ WriteBatch batch;
+ ASSERT_OK(batch.Put("foo", "foo"));
+ ASSERT_OK(batch.Put(handles[1], "AAA", "bar"));
+ ASSERT_OK(batch.Put(handles[1], "AAAZZZ", "bar"));
+ ASSERT_OK(txn_db->Write(write_options, &batch));
+ ASSERT_OK(txn_db->Delete(write_options, handles[1], "AAAZZZ"));
+
+ // These keys do no conflict with existing writes since they're in
+ // different column families
+ ASSERT_OK(txn->Delete("AAA"));
+ Status s =
+ txn->GetForUpdate(snapshot_read_options, handles[1], "foo", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ Slice key_slice("AAAZZZ");
+ Slice value_slices[2] = {Slice("bar"), Slice("bar")};
+ ASSERT_OK(txn->Put(handles[2], SliceParts(&key_slice, 1),
+ SliceParts(value_slices, 2)));
+
+ ASSERT_EQ(3, txn->GetNumKeys());
+
+ // Txn should commit
+ ASSERT_OK(txn->Commit());
+ s = txn_db->Get(read_options, "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn_db->Get(read_options, handles[2], "AAAZZZ", &value);
+ ASSERT_EQ(value, "barbar");
+
+ Slice key_slices[3] = {Slice("AAA"), Slice("ZZ"), Slice("Z")};
+ Slice value_slice("barbarbar");
+ // This write will cause a conflict with the earlier batch write
+ ASSERT_OK(txn2->Put(handles[1], SliceParts(key_slices, 3),
+ SliceParts(&value_slice, 1)));
+
+ ASSERT_OK(txn2->Delete(handles[2], "XXX"));
+ ASSERT_OK(txn2->Delete(handles[1], "XXX"));
+ s = txn2->GetForUpdate(snapshot_read_options, handles[1], "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Verify txn did not commit
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ s = txn_db->Get(read_options, handles[1], "AAAZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(value, "barbar");
+
+ delete txn;
+ delete txn2;
+
+ txn = txn_db->BeginTransaction(write_options, txn_options);
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_NE(txn, nullptr);
+
+ std::vector<ColumnFamilyHandle*> multiget_cfh = {handles[1], handles[2],
+ handles[0], handles[2]};
+ std::vector<Slice> multiget_keys = {"AAA", "AAAZZZ", "foo", "foo"};
+ std::vector<std::string> values(4);
+
+ std::vector<Status> results = txn->MultiGetForUpdate(
+ snapshot_read_options, multiget_cfh, multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+ ASSERT_OK(txn->Delete(handles[2], "ZZZ"));
+ ASSERT_OK(txn->Put(handles[2], "ZZZ", "YYY"));
+ ASSERT_OK(txn->Put(handles[2], "ZZZ", "YYYY"));
+ ASSERT_OK(txn->Delete(handles[2], "ZZZ"));
+ ASSERT_OK(txn->Put(handles[2], "AAAZZZ", "barbarbar"));
+
+ ASSERT_EQ(5, txn->GetNumKeys());
+
+ // Txn should commit
+ ASSERT_OK(txn->Commit());
+ s = txn_db->Get(read_options, handles[2], "ZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Put a key which will conflict with the next txn using the previous snapshot
+ ASSERT_OK(txn_db->Put(write_options, handles[2], "foo", "000"));
+
+ results = txn2->MultiGetForUpdate(snapshot_read_options, multiget_cfh,
+ multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+ // Verify Txn Did not Commit
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn_db->DropColumnFamily(handles[1]);
+ ASSERT_OK(s);
+ s = txn_db->DropColumnFamily(handles[2]);
+ ASSERT_OK(s);
+
+ delete txn;
+ delete txn2;
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(OptimisticTransactionTest, EmptyTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "aaa", "aaa"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn->Commit());
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn->GetForUpdate(read_options, "aaa", &value));
+ ASSERT_EQ(value, "aaa");
+
+ ASSERT_OK(txn->Commit());
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+ ASSERT_OK(txn->GetForUpdate(read_options, "aaa", &value));
+ ASSERT_EQ(value, "aaa");
+
+ ASSERT_OK(txn_db->Put(write_options, "aaa", "xxx"));
+ Status s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, PredicateManyPreceders) {
+ WriteOptions write_options;
+ ReadOptions read_options1, read_options2;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ std::vector<Slice> multiget_keys = {"1", "2", "3"};
+ std::vector<std::string> multiget_values;
+
+ std::vector<Status> results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[0].IsNotFound());
+ ASSERT_TRUE(results[1].IsNotFound());
+ ASSERT_TRUE(results[2].IsNotFound());
+
+ ASSERT_OK(txn2->Put("2", "x"));
+
+ ASSERT_OK(txn2->Commit());
+
+ multiget_values.clear();
+ results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[0].IsNotFound());
+ ASSERT_TRUE(results[1].IsNotFound());
+ ASSERT_TRUE(results[2].IsNotFound());
+
+ // should not commit since txn2 wrote a key txn has read
+ Status s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ ASSERT_OK(txn1->Put("4", "x"));
+
+ ASSERT_OK(txn2->Delete("4"));
+
+ // txn1 can commit since txn2's delete hasn't happened yet (it's just batched)
+ ASSERT_OK(txn1->Commit());
+
+ s = txn2->GetForUpdate(read_options2, "4", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // txn2 cannot commit since txn1 changed "4"
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(OptimisticTransactionTest, LostUpdate) {
+ WriteOptions write_options;
+ ReadOptions read_options, read_options1, read_options2;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ // Test 2 transactions writing to the same key in multiple orders and
+ // with/without snapshots
+
+ Transaction* txn1 = txn_db->BeginTransaction(write_options);
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->Put("1", "1"));
+ ASSERT_OK(txn2->Put("1", "2"));
+
+ ASSERT_OK(txn1->Commit());
+
+ Status s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn_options.set_snapshot = true;
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ ASSERT_OK(txn1->Put("1", "3"));
+ ASSERT_OK(txn2->Put("1", "4"));
+
+ ASSERT_OK(txn1->Commit());
+
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ ASSERT_OK(txn1->Put("1", "5"));
+ ASSERT_OK(txn1->Commit());
+
+ ASSERT_OK(txn2->Put("1", "6"));
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ ASSERT_OK(txn1->Put("1", "5"));
+ ASSERT_OK(txn1->Commit());
+
+ txn2->SetSnapshot();
+ ASSERT_OK(txn2->Put("1", "6"));
+ ASSERT_OK(txn2->Commit());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+ txn2 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->Put("1", "7"));
+ ASSERT_OK(txn1->Commit());
+
+ ASSERT_OK(txn2->Put("1", "8"));
+ ASSERT_OK(txn2->Commit());
+
+ delete txn1;
+ delete txn2;
+
+ ASSERT_OK(txn_db->Get(read_options, "1", &value));
+ ASSERT_EQ(value, "8");
+}
+
+TEST_P(OptimisticTransactionTest, UntrackedWrites) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ // Verify transaction rollback works for untracked keys.
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn->PutUntracked("untracked", "0"));
+ ASSERT_OK(txn->Rollback());
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn->Put("tracked", "1"));
+ ASSERT_OK(txn->PutUntracked("untracked", "1"));
+ ASSERT_OK(txn->MergeUntracked("untracked", "2"));
+ ASSERT_OK(txn->DeleteUntracked("untracked"));
+
+ // Write to the untracked key outside of the transaction and verify
+ // it doesn't prevent the transaction from committing.
+ ASSERT_OK(txn_db->Put(write_options, "untracked", "x"));
+
+ ASSERT_OK(txn->Commit());
+
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn->Put("tracked", "10"));
+ ASSERT_OK(txn->PutUntracked("untracked", "A"));
+
+ // Write to tracked key outside of the transaction and verify that the
+ // untracked keys are not written when the commit fails.
+ ASSERT_OK(txn_db->Delete(write_options, "tracked"));
+
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, IteratorTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ // Write some keys to the db
+ ASSERT_OK(txn_db->Put(write_options, "A", "a"));
+ ASSERT_OK(txn_db->Put(write_options, "G", "g"));
+ ASSERT_OK(txn_db->Put(write_options, "F", "f"));
+ ASSERT_OK(txn_db->Put(write_options, "C", "c"));
+ ASSERT_OK(txn_db->Put(write_options, "D", "d"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ // Write some keys in a txn
+ ASSERT_OK(txn->Put("B", "b"));
+ ASSERT_OK(txn->Put("H", "h"));
+ ASSERT_OK(txn->Delete("D"));
+ ASSERT_OK(txn->Put("E", "e"));
+
+ txn->SetSnapshot();
+ const Snapshot* snapshot = txn->GetSnapshot();
+
+ // Write some keys to the db after the snapshot
+ ASSERT_OK(txn_db->Put(write_options, "BB", "xx"));
+ ASSERT_OK(txn_db->Put(write_options, "C", "xx"));
+
+ read_options.snapshot = snapshot;
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+
+ // Read all keys via iter and lock them all
+ std::string results[] = {"a", "b", "c", "e", "f", "g", "h"};
+ for (int i = 0; i < 7; i++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(results[i], iter->value().ToString());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, iter->key(), nullptr));
+
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+
+ iter->Seek("G");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("g", iter->value().ToString());
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("f", iter->value().ToString());
+
+ iter->Seek("D");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("C");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("c", iter->value().ToString());
+
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a", iter->value().ToString());
+
+ iter->Seek("X");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("h", iter->value().ToString());
+
+ // key "C" was modified in the db after txn's snapshot. txn will not commit.
+ Status s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, DeleteRangeSupportTest) {
+ // `OptimisticTransactionDB` does not allow range deletion in any API.
+ ASSERT_TRUE(
+ txn_db
+ ->DeleteRange(WriteOptions(), txn_db->DefaultColumnFamily(), "a", "b")
+ .IsNotSupported());
+ WriteBatch wb;
+ ASSERT_OK(wb.DeleteRange("a", "b"));
+ ASSERT_NOK(txn_db->Write(WriteOptions(), &wb));
+}
+
+TEST_P(OptimisticTransactionTest, SavepointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ Status s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn->SetSavePoint(); // 1
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->Put("B", "b"));
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn_db->Get(read_options, "B", &value));
+ ASSERT_EQ("b", value);
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+ ASSERT_NE(txn, nullptr);
+
+ ASSERT_OK(txn->Put("A", "a"));
+ ASSERT_OK(txn->Put("B", "bb"));
+ ASSERT_OK(txn->Put("C", "c"));
+
+ txn->SetSavePoint(); // 2
+
+ ASSERT_OK(txn->Delete("B"));
+ ASSERT_OK(txn->Put("C", "cc"));
+ ASSERT_OK(txn->Put("D", "d"));
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 2
+
+ ASSERT_OK(txn->Get(read_options, "A", &value));
+ ASSERT_EQ("a", value);
+ ASSERT_OK(txn->Get(read_options, "B", &value));
+ ASSERT_EQ("bb", value);
+ ASSERT_OK(txn->Get(read_options, "C", &value));
+ ASSERT_EQ("c", value);
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->Put("A", "a"));
+ ASSERT_OK(txn->Put("E", "e"));
+
+ // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_OK(txn->Rollback());
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_OK(txn->Get(read_options, "B", &value));
+ ASSERT_EQ("b", value);
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->Put("A", "aa"));
+ ASSERT_OK(txn->Put("F", "f"));
+
+ txn->SetSavePoint(); // 3
+ txn->SetSavePoint(); // 4
+
+ ASSERT_OK(txn->Put("G", "g"));
+ ASSERT_OK(txn->Delete("F"));
+ ASSERT_OK(txn->Delete("B"));
+
+ ASSERT_OK(txn->Get(read_options, "A", &value));
+ ASSERT_EQ("aa", value);
+
+ s = txn->Get(read_options, "F", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 3
+
+ ASSERT_OK(txn->Get(read_options, "F", &value));
+ ASSERT_EQ("f", value);
+
+ s = txn->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(txn_db->Get(read_options, "F", &value));
+ ASSERT_EQ("f", value);
+
+ s = txn_db->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn_db->Get(read_options, "A", &value));
+ ASSERT_EQ("aa", value);
+
+ ASSERT_OK(txn_db->Get(read_options, "B", &value));
+ ASSERT_EQ("b", value);
+
+ s = txn_db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn_db->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn_db->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, UndoGetForUpdateTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ std::string value;
+
+ ASSERT_OK(txn_db->Put(write_options, "A", ""));
+
+ Transaction* txn1 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->UndoGetForUpdate("A");
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn1->Put("A", "a"));
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ Status s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->SetSavePoint();
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->SetSavePoint();
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+
+ txn1->SetSavePoint();
+ ASSERT_OK(txn1->GetForUpdate(read_options, "A", &value));
+ txn1->UndoGetForUpdate("A");
+
+ ASSERT_OK(txn1->RollbackToSavePoint());
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_OK(txn2->Put("A", "x"));
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+}
+
+namespace {
+Status OptimisticTransactionStressTestInserter(OptimisticTransactionDB* db,
+ const size_t num_transactions,
+ const size_t num_sets,
+ const size_t num_keys_per_set) {
+ size_t seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 _rand(seed);
+ WriteOptions write_options;
+ ReadOptions read_options;
+ OptimisticTransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+
+ RandomTransactionInserter inserter(&_rand, write_options, read_options,
+ num_keys_per_set,
+ static_cast<uint16_t>(num_sets));
+
+ for (size_t t = 0; t < num_transactions; t++) {
+ bool success = inserter.OptimisticTransactionDBInsert(db, txn_options);
+ if (!success) {
+ // unexpected failure
+ return inserter.GetLastStatus();
+ }
+ }
+
+ inserter.GetLastStatus().PermitUncheckedError();
+
+ // Make sure at least some of the transactions succeeded. It's ok if
+ // some failed due to write-conflicts.
+ if (inserter.GetFailureCount() > num_transactions / 2) {
+ return Status::TryAgain("Too many transactions failed! " +
+ std::to_string(inserter.GetFailureCount()) + " / " +
+ std::to_string(num_transactions));
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+TEST_P(OptimisticTransactionTest, OptimisticTransactionStressTest) {
+ const size_t num_threads = 4;
+ const size_t num_transactions_per_thread = 10000;
+ const size_t num_sets = 3;
+ const size_t num_keys_per_set = 100;
+ // Setting the key-space to be 100 keys should cause enough write-conflicts
+ // to make this test interesting.
+
+ std::vector<port::Thread> threads;
+
+ std::function<void()> call_inserter = [&] {
+ ASSERT_OK(OptimisticTransactionStressTestInserter(
+ txn_db, num_transactions_per_thread, num_sets, num_keys_per_set));
+ };
+
+ // Create N threads that use RandomTransactionInserter to write
+ // many transactions.
+ for (uint32_t i = 0; i < num_threads; i++) {
+ threads.emplace_back(call_inserter);
+ }
+
+ // Wait for all threads to run
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify that data is consistent
+ Status s = RandomTransactionInserter::Verify(txn_db, num_sets);
+ ASSERT_OK(s);
+}
+
+TEST_P(OptimisticTransactionTest, SequenceNumberAfterRecoverTest) {
+ WriteOptions write_options;
+ OptimisticTransactionOptions transaction_options;
+
+ Transaction* transaction(
+ txn_db->BeginTransaction(write_options, transaction_options));
+ Status s = transaction->Put("foo", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("foo2", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("foo3", "val");
+ ASSERT_OK(s);
+ s = transaction->Commit();
+ ASSERT_OK(s);
+ delete transaction;
+
+ Reopen();
+ transaction = txn_db->BeginTransaction(write_options, transaction_options);
+ s = transaction->Put("bar", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("bar2", "val");
+ ASSERT_OK(s);
+ s = transaction->Commit();
+ ASSERT_OK(s);
+
+ delete transaction;
+}
+
+TEST_P(OptimisticTransactionTest, TimestampedSnapshotMissingCommitTs) {
+ std::unique_ptr<Transaction> txn(txn_db->BeginTransaction(WriteOptions()));
+ ASSERT_OK(txn->Put("a", "v"));
+ Status s = txn->CommitAndTryCreateSnapshot();
+ ASSERT_TRUE(s.IsInvalidArgument());
+}
+
+TEST_P(OptimisticTransactionTest, TimestampedSnapshotSetCommitTs) {
+ std::unique_ptr<Transaction> txn(txn_db->BeginTransaction(WriteOptions()));
+ ASSERT_OK(txn->Put("a", "v"));
+ std::shared_ptr<const Snapshot> snapshot;
+ Status s = txn->CommitAndTryCreateSnapshot(nullptr, /*ts=*/100, &snapshot);
+ ASSERT_TRUE(s.IsNotSupported());
+}
+
+INSTANTIATE_TEST_CASE_P(
+ InstanceOccGroup, OptimisticTransactionTest,
+ testing::Values(OccValidationPolicy::kValidateSerial,
+ OccValidationPolicy::kValidateParallel));
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(
+ stderr,
+ "SKIPPED as optimistic_transaction is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction.cc b/src/rocksdb/utilities/transactions/pessimistic_transaction.cc
new file mode 100644
index 000000000..cb8fd3bb6
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction.cc
@@ -0,0 +1,1175 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/pessimistic_transaction.h"
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_util.h"
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+std::atomic<TransactionID> PessimisticTransaction::txn_id_counter_(1);
+
+TransactionID PessimisticTransaction::GenTxnID() {
+ return txn_id_counter_.fetch_add(1);
+}
+
+PessimisticTransaction::PessimisticTransaction(
+ TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options, const bool init)
+ : TransactionBaseImpl(
+ txn_db->GetRootDB(), write_options,
+ static_cast_with_check<PessimisticTransactionDB>(txn_db)
+ ->GetLockTrackerFactory()),
+ txn_db_impl_(nullptr),
+ expiration_time_(0),
+ txn_id_(0),
+ waiting_cf_id_(0),
+ waiting_key_(nullptr),
+ lock_timeout_(0),
+ deadlock_detect_(false),
+ deadlock_detect_depth_(0),
+ skip_concurrency_control_(false) {
+ txn_db_impl_ = static_cast_with_check<PessimisticTransactionDB>(txn_db);
+ db_impl_ = static_cast_with_check<DBImpl>(db_);
+ if (init) {
+ Initialize(txn_options);
+ }
+}
+
+void PessimisticTransaction::Initialize(const TransactionOptions& txn_options) {
+ // Range lock manager uses address of transaction object as TXNID
+ const TransactionDBOptions& db_options = txn_db_impl_->GetTxnDBOptions();
+ if (db_options.lock_mgr_handle &&
+ db_options.lock_mgr_handle->getLockManager()->IsRangeLockSupported()) {
+ txn_id_ = reinterpret_cast<TransactionID>(this);
+ } else {
+ txn_id_ = GenTxnID();
+ }
+
+ txn_state_ = STARTED;
+
+ deadlock_detect_ = txn_options.deadlock_detect;
+ deadlock_detect_depth_ = txn_options.deadlock_detect_depth;
+ write_batch_.SetMaxBytes(txn_options.max_write_batch_size);
+ skip_concurrency_control_ = txn_options.skip_concurrency_control;
+
+ lock_timeout_ = txn_options.lock_timeout * 1000;
+ if (lock_timeout_ < 0) {
+ // Lock timeout not set, use default
+ lock_timeout_ =
+ txn_db_impl_->GetTxnDBOptions().transaction_lock_timeout * 1000;
+ }
+
+ if (txn_options.expiration >= 0) {
+ expiration_time_ = start_time_ + txn_options.expiration * 1000;
+ } else {
+ expiration_time_ = 0;
+ }
+
+ if (txn_options.set_snapshot) {
+ SetSnapshot();
+ }
+
+ if (expiration_time_ > 0) {
+ txn_db_impl_->InsertExpirableTransaction(txn_id_, this);
+ }
+ use_only_the_last_commit_time_batch_for_recovery_ =
+ txn_options.use_only_the_last_commit_time_batch_for_recovery;
+ skip_prepare_ = txn_options.skip_prepare;
+
+ read_timestamp_ = kMaxTxnTimestamp;
+ commit_timestamp_ = kMaxTxnTimestamp;
+}
+
+PessimisticTransaction::~PessimisticTransaction() {
+ txn_db_impl_->UnLock(this, *tracked_locks_);
+ if (expiration_time_ > 0) {
+ txn_db_impl_->RemoveExpirableTransaction(txn_id_);
+ }
+ if (!name_.empty() && txn_state_ != COMMITTED) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+}
+
+void PessimisticTransaction::Clear() {
+ txn_db_impl_->UnLock(this, *tracked_locks_);
+ TransactionBaseImpl::Clear();
+}
+
+void PessimisticTransaction::Reinitialize(
+ TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options) {
+ if (!name_.empty() && txn_state_ != COMMITTED) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+ TransactionBaseImpl::Reinitialize(txn_db->GetRootDB(), write_options);
+ Initialize(txn_options);
+}
+
+bool PessimisticTransaction::IsExpired() const {
+ if (expiration_time_ > 0) {
+ if (dbimpl_->GetSystemClock()->NowMicros() >= expiration_time_) {
+ // Transaction is expired.
+ return true;
+ }
+ }
+
+ return false;
+}
+
+WriteCommittedTxn::WriteCommittedTxn(TransactionDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : PessimisticTransaction(txn_db, write_options, txn_options) {}
+
+Status WriteCommittedTxn::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value,
+ bool exclusive, const bool do_validate) {
+ return GetForUpdateImpl(read_options, column_family, key, value, exclusive,
+ do_validate);
+}
+
+Status WriteCommittedTxn::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ PinnableSlice* pinnable_val,
+ bool exclusive, const bool do_validate) {
+ return GetForUpdateImpl(read_options, column_family, key, pinnable_val,
+ exclusive, do_validate);
+}
+
+template <typename TValue>
+inline Status WriteCommittedTxn::GetForUpdateImpl(
+ const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, TValue* value, bool exclusive, const bool do_validate) {
+ column_family =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+ assert(column_family);
+ if (!read_options.timestamp) {
+ const Comparator* const ucmp = column_family->GetComparator();
+ assert(ucmp);
+ size_t ts_sz = ucmp->timestamp_size();
+ if (0 == ts_sz) {
+ return TransactionBaseImpl::GetForUpdate(read_options, column_family, key,
+ value, exclusive, do_validate);
+ }
+ } else {
+ Status s = db_impl_->FailIfTsMismatchCf(
+ column_family, *(read_options.timestamp), /*ts_for_read=*/true);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ if (!do_validate) {
+ return Status::InvalidArgument(
+ "If do_validate is false then GetForUpdate with read_timestamp is not "
+ "defined.");
+ } else if (kMaxTxnTimestamp == read_timestamp_) {
+ return Status::InvalidArgument("read_timestamp must be set for validation");
+ }
+
+ if (!read_options.timestamp) {
+ ReadOptions read_opts_copy = read_options;
+ char ts_buf[sizeof(kMaxTxnTimestamp)];
+ EncodeFixed64(ts_buf, read_timestamp_);
+ Slice ts(ts_buf, sizeof(ts_buf));
+ read_opts_copy.timestamp = &ts;
+ return TransactionBaseImpl::GetForUpdate(read_opts_copy, column_family, key,
+ value, exclusive, do_validate);
+ }
+ assert(read_options.timestamp);
+ const char* const ts_buf = read_options.timestamp->data();
+ assert(read_options.timestamp->size() == sizeof(kMaxTxnTimestamp));
+ TxnTimestamp ts = DecodeFixed64(ts_buf);
+ if (ts != read_timestamp_) {
+ return Status::InvalidArgument("Must read from the same read_timestamp");
+ }
+ return TransactionBaseImpl::GetForUpdate(read_options, column_family, key,
+ value, exclusive, do_validate);
+}
+
+Status WriteCommittedTxn::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, &value, this]() {
+ Status s =
+ GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ ++num_puts_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::Put(ColumnFamilyHandle* column_family,
+ const SliceParts& key, const SliceParts& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, &value, this]() {
+ Status s =
+ GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ ++num_puts_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::PutUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ return Operate(
+ column_family, key, /*do_validate=*/false,
+ /*assume_tracked=*/false, [column_family, &key, &value, this]() {
+ Status s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ ++num_puts_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::PutUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const SliceParts& value) {
+ return Operate(
+ column_family, key, /*do_validate=*/false,
+ /*assume_tracked=*/false, [column_family, &key, &value, this]() {
+ Status s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ ++num_puts_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key, const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, this]() {
+ Status s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, this]() {
+ Status s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ return Operate(column_family, key, /*do_validate=*/false,
+ /*assume_tracked=*/false, [column_family, &key, this]() {
+ Status s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) {
+ return Operate(column_family, key, /*do_validate=*/false,
+ /*assume_tracked=*/false, [column_family, &key, this]() {
+ Status s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, this]() {
+ Status s =
+ GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, this]() {
+ Status s =
+ GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::SingleDeleteUntracked(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ return Operate(column_family, key, /*do_validate=*/false,
+ /*assume_tracked=*/false, [column_family, &key, this]() {
+ Status s =
+ GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ ++num_deletes_;
+ }
+ return s;
+ });
+}
+
+Status WriteCommittedTxn::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ return Operate(column_family, key, do_validate, assume_tracked,
+ [column_family, &key, &value, this]() {
+ Status s =
+ GetBatchForWrite()->Merge(column_family, key, value);
+ if (s.ok()) {
+ ++num_merges_;
+ }
+ return s;
+ });
+}
+
+template <typename TKey, typename TOperation>
+Status WriteCommittedTxn::Operate(ColumnFamilyHandle* column_family,
+ const TKey& key, const bool do_validate,
+ const bool assume_tracked,
+ TOperation&& operation) {
+ Status s;
+ if constexpr (std::is_same_v<Slice, TKey>) {
+ s = TryLock(column_family, key, /*read_only=*/false, /*exclusive=*/true,
+ do_validate, assume_tracked);
+ } else if constexpr (std::is_same_v<SliceParts, TKey>) {
+ std::string key_buf;
+ Slice contiguous_key(key, &key_buf);
+ s = TryLock(column_family, contiguous_key, /*read_only=*/false,
+ /*exclusive=*/true, do_validate, assume_tracked);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ column_family =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+ assert(column_family);
+ const Comparator* const ucmp = column_family->GetComparator();
+ assert(ucmp);
+ size_t ts_sz = ucmp->timestamp_size();
+ if (ts_sz > 0) {
+ assert(ts_sz == sizeof(TxnTimestamp));
+ if (!IndexingEnabled()) {
+ cfs_with_ts_tracked_when_indexing_disabled_.insert(
+ column_family->GetID());
+ }
+ }
+ return operation();
+}
+
+Status WriteCommittedTxn::SetReadTimestampForValidation(TxnTimestamp ts) {
+ if (read_timestamp_ < kMaxTxnTimestamp && ts < read_timestamp_) {
+ return Status::InvalidArgument(
+ "Cannot decrease read timestamp for validation");
+ }
+ read_timestamp_ = ts;
+ return Status::OK();
+}
+
+Status WriteCommittedTxn::SetCommitTimestamp(TxnTimestamp ts) {
+ if (read_timestamp_ < kMaxTxnTimestamp && ts <= read_timestamp_) {
+ return Status::InvalidArgument(
+ "Cannot commit at timestamp smaller than or equal to read timestamp");
+ }
+ commit_timestamp_ = ts;
+ return Status::OK();
+}
+
+Status PessimisticTransaction::CommitBatch(WriteBatch* batch) {
+ if (batch && WriteBatchInternal::HasKeyWithTimestamp(*batch)) {
+ // CommitBatch() needs to lock the keys in the batch.
+ // However, the application also needs to specify the timestamp for the
+ // keys in batch before calling this API.
+ // This means timestamp order may violate the order of locking, thus
+ // violate the sequence number order for the same user key.
+ // Therefore, we disallow this operation for now.
+ return Status::NotSupported(
+ "Batch to commit includes timestamp assigned before locking");
+ }
+
+ std::unique_ptr<LockTracker> keys_to_unlock(lock_tracker_factory_.Create());
+ Status s = LockBatch(batch, keys_to_unlock.get());
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ bool can_commit = false;
+
+ if (IsExpired()) {
+ s = Status::Expired();
+ } else if (expiration_time_ > 0) {
+ TransactionState expected = STARTED;
+ can_commit = std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ AWAITING_COMMIT);
+ } else if (txn_state_ == STARTED) {
+ // lock stealing is not a concern
+ can_commit = true;
+ }
+
+ if (can_commit) {
+ txn_state_.store(AWAITING_COMMIT);
+ s = CommitBatchInternal(batch);
+ if (s.ok()) {
+ txn_state_.store(COMMITTED);
+ }
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ txn_db_impl_->UnLock(this, *keys_to_unlock);
+
+ return s;
+}
+
+Status PessimisticTransaction::Prepare() {
+ if (name_.empty()) {
+ return Status::InvalidArgument(
+ "Cannot prepare a transaction that has not been named.");
+ }
+
+ if (IsExpired()) {
+ return Status::Expired();
+ }
+
+ Status s;
+ bool can_prepare = false;
+
+ if (expiration_time_ > 0) {
+ // must concern ourselves with expiraton and/or lock stealing
+ // need to compare/exchange bc locks could be stolen under us here
+ TransactionState expected = STARTED;
+ can_prepare = std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ AWAITING_PREPARE);
+ } else if (txn_state_ == STARTED) {
+ // expiration and lock stealing is not possible
+ txn_state_.store(AWAITING_PREPARE);
+ can_prepare = true;
+ }
+
+ if (can_prepare) {
+ // transaction can't expire after preparation
+ expiration_time_ = 0;
+ assert(log_number_ == 0 ||
+ txn_db_impl_->GetTxnDBOptions().write_policy == WRITE_UNPREPARED);
+
+ s = PrepareInternal();
+ if (s.ok()) {
+ txn_state_.store(PREPARED);
+ }
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else if (txn_state_ == PREPARED) {
+ s = Status::InvalidArgument("Transaction has already been prepared.");
+ } else if (txn_state_ == COMMITTED) {
+ s = Status::InvalidArgument("Transaction has already been committed.");
+ } else if (txn_state_ == ROLLEDBACK) {
+ s = Status::InvalidArgument("Transaction has already been rolledback.");
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ return s;
+}
+
+Status WriteCommittedTxn::PrepareInternal() {
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ auto s = WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(),
+ name_);
+ assert(s.ok());
+ class MarkLogCallback : public PreReleaseCallback {
+ public:
+ MarkLogCallback(DBImpl* db, bool two_write_queues)
+ : db_(db), two_write_queues_(two_write_queues) {
+ (void)two_write_queues_; // to silence unused private field warning
+ }
+ virtual Status Callback(SequenceNumber, bool is_mem_disabled,
+ uint64_t log_number, size_t /*index*/,
+ size_t /*total*/) override {
+#ifdef NDEBUG
+ (void)is_mem_disabled;
+#endif
+ assert(log_number != 0);
+ assert(!two_write_queues_ || is_mem_disabled); // implies the 2nd queue
+ db_->logs_with_prep_tracker()->MarkLogAsContainingPrepSection(log_number);
+ return Status::OK();
+ }
+
+ private:
+ DBImpl* db_;
+ bool two_write_queues_;
+ } mark_log_callback(db_impl_,
+ db_impl_->immutable_db_options().two_write_queues);
+
+ WriteCallback* const kNoWriteCallback = nullptr;
+ const uint64_t kRefNoLog = 0;
+ const bool kDisableMemtable = true;
+ SequenceNumber* const KIgnoreSeqUsed = nullptr;
+ const size_t kNoBatchCount = 0;
+ s = db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(),
+ kNoWriteCallback, &log_number_, kRefNoLog,
+ kDisableMemtable, KIgnoreSeqUsed, kNoBatchCount,
+ &mark_log_callback);
+ return s;
+}
+
+Status PessimisticTransaction::Commit() {
+ bool commit_without_prepare = false;
+ bool commit_prepared = false;
+
+ if (IsExpired()) {
+ return Status::Expired();
+ }
+
+ if (expiration_time_ > 0) {
+ // we must atomicaly compare and exchange the state here because at
+ // this state in the transaction it is possible for another thread
+ // to change our state out from under us in the even that we expire and have
+ // our locks stolen. In this case the only valid state is STARTED because
+ // a state of PREPARED would have a cleared expiration_time_.
+ TransactionState expected = STARTED;
+ commit_without_prepare = std::atomic_compare_exchange_strong(
+ &txn_state_, &expected, AWAITING_COMMIT);
+ TEST_SYNC_POINT("TransactionTest::ExpirableTransactionDataRace:1");
+ } else if (txn_state_ == PREPARED) {
+ // expiration and lock stealing is not a concern
+ commit_prepared = true;
+ } else if (txn_state_ == STARTED) {
+ // expiration and lock stealing is not a concern
+ if (skip_prepare_) {
+ commit_without_prepare = true;
+ } else {
+ return Status::TxnNotPrepared();
+ }
+ }
+
+ Status s;
+ if (commit_without_prepare) {
+ assert(!commit_prepared);
+ if (WriteBatchInternal::Count(GetCommitTimeWriteBatch()) > 0) {
+ s = Status::InvalidArgument(
+ "Commit-time batch contains values that will not be committed.");
+ } else {
+ txn_state_.store(AWAITING_COMMIT);
+ if (log_number_ > 0) {
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ s = CommitWithoutPrepareInternal();
+ if (!name_.empty()) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+ Clear();
+ if (s.ok()) {
+ txn_state_.store(COMMITTED);
+ }
+ }
+ } else if (commit_prepared) {
+ txn_state_.store(AWAITING_COMMIT);
+
+ s = CommitInternal();
+
+ if (!s.ok()) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Commit write failed");
+ return s;
+ }
+
+ // FindObsoleteFiles must now look to the memtables
+ // to determine what prep logs must be kept around,
+ // not the prep section heap.
+ assert(log_number_ > 0);
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ txn_db_impl_->UnregisterTransaction(this);
+
+ Clear();
+ txn_state_.store(COMMITTED);
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else if (txn_state_ == COMMITTED) {
+ s = Status::InvalidArgument("Transaction has already been committed.");
+ } else if (txn_state_ == ROLLEDBACK) {
+ s = Status::InvalidArgument("Transaction has already been rolledback.");
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ return s;
+}
+
+Status WriteCommittedTxn::CommitWithoutPrepareInternal() {
+ WriteBatchWithIndex* wbwi = GetWriteBatch();
+ assert(wbwi);
+ WriteBatch* wb = wbwi->GetWriteBatch();
+ assert(wb);
+
+ const bool needs_ts = WriteBatchInternal::HasKeyWithTimestamp(*wb);
+ if (needs_ts && commit_timestamp_ == kMaxTxnTimestamp) {
+ return Status::InvalidArgument("Must assign a commit timestamp");
+ }
+
+ if (needs_ts) {
+ assert(commit_timestamp_ != kMaxTxnTimestamp);
+ char commit_ts_buf[sizeof(kMaxTxnTimestamp)];
+ EncodeFixed64(commit_ts_buf, commit_timestamp_);
+ Slice commit_ts(commit_ts_buf, sizeof(commit_ts_buf));
+
+ Status s =
+ wb->UpdateTimestamps(commit_ts, [wbwi, this](uint32_t cf) -> size_t {
+ auto cf_iter = cfs_with_ts_tracked_when_indexing_disabled_.find(cf);
+ if (cf_iter != cfs_with_ts_tracked_when_indexing_disabled_.end()) {
+ return sizeof(kMaxTxnTimestamp);
+ }
+ const Comparator* ucmp =
+ WriteBatchWithIndexInternal::GetUserComparator(*wbwi, cf);
+ return ucmp ? ucmp->timestamp_size()
+ : std::numeric_limits<uint64_t>::max();
+ });
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ uint64_t seq_used = kMaxSequenceNumber;
+ SnapshotCreationCallback snapshot_creation_cb(db_impl_, commit_timestamp_,
+ snapshot_notifier_, snapshot_);
+ PostMemTableCallback* post_mem_cb = nullptr;
+ if (snapshot_needed_) {
+ if (commit_timestamp_ == kMaxTxnTimestamp) {
+ return Status::InvalidArgument("Must set transaction commit timestamp");
+ } else {
+ post_mem_cb = &snapshot_creation_cb;
+ }
+ }
+ auto s = db_impl_->WriteImpl(write_options_, wb,
+ /*callback*/ nullptr, /*log_used*/ nullptr,
+ /*log_ref*/ 0, /*disable_memtable*/ false,
+ &seq_used, /*batch_cnt=*/0,
+ /*pre_release_callback=*/nullptr, post_mem_cb);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
+
+Status WriteCommittedTxn::CommitBatchInternal(WriteBatch* batch, size_t) {
+ uint64_t seq_used = kMaxSequenceNumber;
+ auto s = db_impl_->WriteImpl(write_options_, batch, /*callback*/ nullptr,
+ /*log_used*/ nullptr, /*log_ref*/ 0,
+ /*disable_memtable*/ false, &seq_used);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
+
+Status WriteCommittedTxn::CommitInternal() {
+ WriteBatchWithIndex* wbwi = GetWriteBatch();
+ assert(wbwi);
+ WriteBatch* wb = wbwi->GetWriteBatch();
+ assert(wb);
+
+ const bool needs_ts = WriteBatchInternal::HasKeyWithTimestamp(*wb);
+ if (needs_ts && commit_timestamp_ == kMaxTxnTimestamp) {
+ return Status::InvalidArgument("Must assign a commit timestamp");
+ }
+ // We take the commit-time batch and append the Commit marker.
+ // The Memtable will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+
+ Status s;
+ if (!needs_ts) {
+ s = WriteBatchInternal::MarkCommit(working_batch, name_);
+ } else {
+ assert(commit_timestamp_ != kMaxTxnTimestamp);
+ char commit_ts_buf[sizeof(kMaxTxnTimestamp)];
+ EncodeFixed64(commit_ts_buf, commit_timestamp_);
+ Slice commit_ts(commit_ts_buf, sizeof(commit_ts_buf));
+ s = WriteBatchInternal::MarkCommitWithTimestamp(working_batch, name_,
+ commit_ts);
+ if (s.ok()) {
+ s = wb->UpdateTimestamps(commit_ts, [wbwi, this](uint32_t cf) -> size_t {
+ if (cfs_with_ts_tracked_when_indexing_disabled_.find(cf) !=
+ cfs_with_ts_tracked_when_indexing_disabled_.end()) {
+ return sizeof(kMaxTxnTimestamp);
+ }
+ const Comparator* ucmp =
+ WriteBatchWithIndexInternal::GetUserComparator(*wbwi, cf);
+ return ucmp ? ucmp->timestamp_size()
+ : std::numeric_limits<uint64_t>::max();
+ });
+ }
+ }
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ // any operations appended to this working_batch will be ignored from WAL
+ working_batch->MarkWalTerminationPoint();
+
+ // insert prepared batch into Memtable only skipping WAL.
+ // Memtable will ignore BeginPrepare/EndPrepare markers
+ // in non recovery mode and simply insert the values
+ s = WriteBatchInternal::Append(working_batch, wb);
+ assert(s.ok());
+
+ uint64_t seq_used = kMaxSequenceNumber;
+ SnapshotCreationCallback snapshot_creation_cb(db_impl_, commit_timestamp_,
+ snapshot_notifier_, snapshot_);
+ PostMemTableCallback* post_mem_cb = nullptr;
+ if (snapshot_needed_) {
+ if (commit_timestamp_ == kMaxTxnTimestamp) {
+ s = Status::InvalidArgument("Must set transaction commit timestamp");
+ return s;
+ } else {
+ post_mem_cb = &snapshot_creation_cb;
+ }
+ }
+ s = db_impl_->WriteImpl(write_options_, working_batch, /*callback*/ nullptr,
+ /*log_used*/ nullptr, /*log_ref*/ log_number_,
+ /*disable_memtable*/ false, &seq_used,
+ /*batch_cnt=*/0, /*pre_release_callback=*/nullptr,
+ post_mem_cb);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
+
+Status PessimisticTransaction::Rollback() {
+ Status s;
+ if (txn_state_ == PREPARED) {
+ txn_state_.store(AWAITING_ROLLBACK);
+
+ s = RollbackInternal();
+
+ if (s.ok()) {
+ // we do not need to keep our prepared section around
+ assert(log_number_ > 0);
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ Clear();
+ txn_state_.store(ROLLEDBACK);
+ }
+ } else if (txn_state_ == STARTED) {
+ if (log_number_ > 0) {
+ assert(txn_db_impl_->GetTxnDBOptions().write_policy == WRITE_UNPREPARED);
+ assert(GetId() > 0);
+ s = RollbackInternal();
+
+ if (s.ok()) {
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ }
+ // prepare couldn't have taken place
+ Clear();
+ } else if (txn_state_ == COMMITTED) {
+ s = Status::InvalidArgument("This transaction has already been committed.");
+ } else {
+ s = Status::InvalidArgument(
+ "Two phase transaction is not in state for rollback.");
+ }
+
+ return s;
+}
+
+Status WriteCommittedTxn::RollbackInternal() {
+ WriteBatch rollback_marker;
+ auto s = WriteBatchInternal::MarkRollback(&rollback_marker, name_);
+ assert(s.ok());
+ s = db_impl_->WriteImpl(write_options_, &rollback_marker);
+ return s;
+}
+
+Status PessimisticTransaction::RollbackToSavePoint() {
+ if (txn_state_ != STARTED) {
+ return Status::InvalidArgument("Transaction is beyond state for rollback.");
+ }
+
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // Unlock any keys locked since last transaction
+ auto& save_point_tracker = *save_points_->top().new_locks_;
+ std::unique_ptr<LockTracker> t(
+ tracked_locks_->GetTrackedLocksSinceSavePoint(save_point_tracker));
+ if (t) {
+ txn_db_impl_->UnLock(this, *t);
+ }
+ }
+
+ return TransactionBaseImpl::RollbackToSavePoint();
+}
+
+// Lock all keys in this batch.
+// On success, caller should unlock keys_to_unlock
+Status PessimisticTransaction::LockBatch(WriteBatch* batch,
+ LockTracker* keys_to_unlock) {
+ if (!batch) {
+ return Status::InvalidArgument("batch is nullptr");
+ }
+
+ class Handler : public WriteBatch::Handler {
+ public:
+ // Sorted map of column_family_id to sorted set of keys.
+ // Since LockBatch() always locks keys in sorted order, it cannot deadlock
+ // with itself. We're not using a comparator here since it doesn't matter
+ // what the sorting is as long as it's consistent.
+ std::map<uint32_t, std::set<std::string>> keys_;
+
+ Handler() {}
+
+ void RecordKey(uint32_t column_family_id, const Slice& key) {
+ std::string key_str = key.ToString();
+
+ auto& cfh_keys = keys_[column_family_id];
+ auto iter = cfh_keys.find(key_str);
+ if (iter == cfh_keys.end()) {
+ // key not yet seen, store it.
+ cfh_keys.insert({std::move(key_str)});
+ }
+ }
+
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& /* unused */) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& /* unused */) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ };
+
+ // Iterating on this handler will add all keys in this batch into keys
+ Handler handler;
+ Status s = batch->Iterate(&handler);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // Attempt to lock all keys
+ for (const auto& cf_iter : handler.keys_) {
+ uint32_t cfh_id = cf_iter.first;
+ auto& cfh_keys = cf_iter.second;
+
+ for (const auto& key_iter : cfh_keys) {
+ const std::string& key = key_iter;
+
+ s = txn_db_impl_->TryLock(this, cfh_id, key, true /* exclusive */);
+ if (!s.ok()) {
+ break;
+ }
+ PointLockRequest r;
+ r.column_family_id = cfh_id;
+ r.key = key;
+ r.seq = kMaxSequenceNumber;
+ r.read_only = false;
+ r.exclusive = true;
+ keys_to_unlock->Track(r);
+ }
+
+ if (!s.ok()) {
+ break;
+ }
+ }
+
+ if (!s.ok()) {
+ txn_db_impl_->UnLock(this, *keys_to_unlock);
+ }
+
+ return s;
+}
+
+// Attempt to lock this key.
+// Returns OK if the key has been successfully locked. Non-ok, otherwise.
+// If check_shapshot is true and this transaction has a snapshot set,
+// this key will only be locked if there have been no writes to this key since
+// the snapshot time.
+Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
+ const Slice& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ assert(!assume_tracked || !do_validate);
+ Status s;
+ if (UNLIKELY(skip_concurrency_control_)) {
+ return s;
+ }
+ uint32_t cfh_id = GetColumnFamilyID(column_family);
+ std::string key_str = key.ToString();
+
+ PointLockStatus status;
+ bool lock_upgrade;
+ bool previously_locked;
+ if (tracked_locks_->IsPointLockSupported()) {
+ status = tracked_locks_->GetPointLockStatus(cfh_id, key_str);
+ previously_locked = status.locked;
+ lock_upgrade = previously_locked && exclusive && !status.exclusive;
+ } else {
+ // If the record is tracked, we can assume it was locked, too.
+ previously_locked = assume_tracked;
+ status.locked = false;
+ lock_upgrade = false;
+ }
+
+ // Lock this key if this transactions hasn't already locked it or we require
+ // an upgrade.
+ if (!previously_locked || lock_upgrade) {
+ s = txn_db_impl_->TryLock(this, cfh_id, key_str, exclusive);
+ }
+
+ const ColumnFamilyHandle* const cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+ assert(cfh);
+ const Comparator* const ucmp = cfh->GetComparator();
+ assert(ucmp);
+ size_t ts_sz = ucmp->timestamp_size();
+
+ SetSnapshotIfNeeded();
+
+ // Even though we do not care about doing conflict checking for this write,
+ // we still need to take a lock to make sure we do not cause a conflict with
+ // some other write. However, we do not need to check if there have been
+ // any writes since this transaction's snapshot.
+ // TODO(agiardullo): could optimize by supporting shared txn locks in the
+ // future.
+ SequenceNumber tracked_at_seq =
+ status.locked ? status.seq : kMaxSequenceNumber;
+ if (!do_validate || (snapshot_ == nullptr &&
+ (0 == ts_sz || kMaxTxnTimestamp == read_timestamp_))) {
+ if (assume_tracked && !previously_locked &&
+ tracked_locks_->IsPointLockSupported()) {
+ s = Status::InvalidArgument(
+ "assume_tracked is set but it is not tracked yet");
+ }
+ // Need to remember the earliest sequence number that we know that this
+ // key has not been modified after. This is useful if this same
+ // transaction later tries to lock this key again.
+ if (tracked_at_seq == kMaxSequenceNumber) {
+ // Since we haven't checked a snapshot, we only know this key has not
+ // been modified since after we locked it.
+ // Note: when last_seq_same_as_publish_seq_==false this is less than the
+ // latest allocated seq but it is ok since i) this is just a heuristic
+ // used only as a hint to avoid actual check for conflicts, ii) this would
+ // cause a false positive only if the snapthot is taken right after the
+ // lock, which would be an unusual sequence.
+ tracked_at_seq = db_->GetLatestSequenceNumber();
+ }
+ } else if (s.ok()) {
+ // If a snapshot is set, we need to make sure the key hasn't been modified
+ // since the snapshot. This must be done after we locked the key.
+ // If we already have validated an earilier snapshot it must has been
+ // reflected in tracked_at_seq and ValidateSnapshot will return OK.
+ s = ValidateSnapshot(column_family, key, &tracked_at_seq);
+
+ if (!s.ok()) {
+ // Failed to validate key
+ // Unlock key we just locked
+ if (lock_upgrade) {
+ s = txn_db_impl_->TryLock(this, cfh_id, key_str, false /* exclusive */);
+ assert(s.ok());
+ } else if (!previously_locked) {
+ txn_db_impl_->UnLock(this, cfh_id, key.ToString());
+ }
+ }
+ }
+
+ if (s.ok()) {
+ // We must track all the locked keys so that we can unlock them later. If
+ // the key is already locked, this func will update some stats on the
+ // tracked key. It could also update the tracked_at_seq if it is lower
+ // than the existing tracked key seq. These stats are necessary for
+ // RollbackToSavePoint to determine whether a key can be safely removed
+ // from tracked_keys_. Removal can only be done if a key was only locked
+ // during the current savepoint.
+ //
+ // Recall that if assume_tracked is true, we assume that TrackKey has been
+ // called previously since the last savepoint, with the same exclusive
+ // setting, and at a lower sequence number, so skipping here should be
+ // safe.
+ if (!assume_tracked) {
+ TrackKey(cfh_id, key_str, tracked_at_seq, read_only, exclusive);
+ } else {
+#ifndef NDEBUG
+ if (tracked_locks_->IsPointLockSupported()) {
+ PointLockStatus lock_status =
+ tracked_locks_->GetPointLockStatus(cfh_id, key_str);
+ assert(lock_status.locked);
+ assert(lock_status.seq <= tracked_at_seq);
+ assert(lock_status.exclusive == exclusive);
+ }
+#endif
+ }
+ }
+
+ return s;
+}
+
+Status PessimisticTransaction::GetRangeLock(ColumnFamilyHandle* column_family,
+ const Endpoint& start_endp,
+ const Endpoint& end_endp) {
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+ uint32_t cfh_id = GetColumnFamilyID(cfh);
+
+ Status s = txn_db_impl_->TryRangeLock(this, cfh_id, start_endp, end_endp);
+
+ if (s.ok()) {
+ RangeLockRequest req{cfh_id, start_endp, end_endp};
+ tracked_locks_->Track(req);
+ }
+ return s;
+}
+
+// Return OK() if this key has not been modified more recently than the
+// transaction snapshot_.
+// tracked_at_seq is the global seq at which we either locked the key or already
+// have done ValidateSnapshot.
+Status PessimisticTransaction::ValidateSnapshot(
+ ColumnFamilyHandle* column_family, const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ assert(snapshot_ || read_timestamp_ < kMaxTxnTimestamp);
+
+ SequenceNumber snap_seq = 0;
+ if (snapshot_) {
+ snap_seq = snapshot_->GetSequenceNumber();
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previous validated (or locked) at a sequence number
+ // earlier than the current snapshot's sequence number, we already know it
+ // has not been modified aftter snap_seq either.
+ return Status::OK();
+ }
+ } else {
+ snap_seq = db_impl_->GetLatestSequenceNumber();
+ }
+
+ // Otherwise we have either
+ // 1: tracked_at_seq == kMaxSequenceNumber, i.e., first time tracking the key
+ // 2: snap_seq < tracked_at_seq: last time we lock the key was via
+ // do_validate=false which means we had skipped ValidateSnapshot. In both
+ // cases we should do ValidateSnapshot now.
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ assert(cfh);
+ const Comparator* const ucmp = cfh->GetComparator();
+ assert(ucmp);
+ size_t ts_sz = ucmp->timestamp_size();
+ std::string ts_buf;
+ if (ts_sz > 0 && read_timestamp_ < kMaxTxnTimestamp) {
+ assert(ts_sz == sizeof(read_timestamp_));
+ PutFixed64(&ts_buf, read_timestamp_);
+ }
+
+ return TransactionUtil::CheckKeyForConflicts(
+ db_impl_, cfh, key.ToString(), snap_seq, ts_sz == 0 ? nullptr : &ts_buf,
+ false /* cache_only */);
+}
+
+bool PessimisticTransaction::TryStealingLocks() {
+ assert(IsExpired());
+ TransactionState expected = STARTED;
+ return std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ LOCKS_STOLEN);
+}
+
+void PessimisticTransaction::UnlockGetForUpdate(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ txn_db_impl_->UnLock(this, GetColumnFamilyID(column_family), key.ToString());
+}
+
+Status PessimisticTransaction::SetName(const TransactionName& name) {
+ Status s;
+ if (txn_state_ == STARTED) {
+ if (name_.length()) {
+ s = Status::InvalidArgument("Transaction has already been named.");
+ } else if (txn_db_impl_->GetTransactionByName(name) != nullptr) {
+ s = Status::InvalidArgument("Transaction name must be unique.");
+ } else if (name.length() < 1 || name.length() > 512) {
+ s = Status::InvalidArgument(
+ "Transaction name length must be between 1 and 512 chars.");
+ } else {
+ name_ = name;
+ txn_db_impl_->RegisterTransaction(this);
+ }
+ } else {
+ s = Status::InvalidArgument("Transaction is beyond state for naming.");
+ }
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction.h b/src/rocksdb/utilities/transactions/pessimistic_transaction.h
new file mode 100644
index 000000000..d43d1d3ac
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction.h
@@ -0,0 +1,313 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <mutex>
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class PessimisticTransactionDB;
+
+// A transaction under pessimistic concurrency control. This class implements
+// the locking API and interfaces with the lock manager as well as the
+// pessimistic transactional db.
+class PessimisticTransaction : public TransactionBaseImpl {
+ public:
+ PessimisticTransaction(TransactionDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ const bool init = true);
+ // No copying allowed
+ PessimisticTransaction(const PessimisticTransaction&) = delete;
+ void operator=(const PessimisticTransaction&) = delete;
+
+ ~PessimisticTransaction() override;
+
+ void Reinitialize(TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+
+ Status Prepare() override;
+
+ Status Commit() override;
+
+ // It is basically Commit without going through Prepare phase. The write batch
+ // is also directly provided instead of expecting txn to gradually batch the
+ // transactions writes to an internal write batch.
+ Status CommitBatch(WriteBatch* batch);
+
+ Status Rollback() override;
+
+ Status RollbackToSavePoint() override;
+
+ Status SetName(const TransactionName& name) override;
+
+ // Generate a new unique transaction identifier
+ static TransactionID GenTxnID();
+
+ TransactionID GetID() const override { return txn_id_; }
+
+ std::vector<TransactionID> GetWaitingTxns(uint32_t* column_family_id,
+ std::string* key) const override {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ std::vector<TransactionID> ids(waiting_txn_ids_.size());
+ if (key) *key = waiting_key_ ? *waiting_key_ : "";
+ if (column_family_id) *column_family_id = waiting_cf_id_;
+ std::copy(waiting_txn_ids_.begin(), waiting_txn_ids_.end(), ids.begin());
+ return ids;
+ }
+
+ void SetWaitingTxn(autovector<TransactionID> ids, uint32_t column_family_id,
+ const std::string* key) {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ waiting_txn_ids_ = ids;
+ waiting_cf_id_ = column_family_id;
+ waiting_key_ = key;
+ }
+
+ void ClearWaitingTxn() {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ waiting_txn_ids_.clear();
+ waiting_cf_id_ = 0;
+ waiting_key_ = nullptr;
+ }
+
+ // Returns the time (in microseconds according to Env->GetMicros())
+ // that this transaction will be expired. Returns 0 if this transaction does
+ // not expire.
+ uint64_t GetExpirationTime() const { return expiration_time_; }
+
+ // returns true if this transaction has an expiration_time and has expired.
+ bool IsExpired() const;
+
+ // Returns the number of microseconds a transaction can wait on acquiring a
+ // lock or -1 if there is no timeout.
+ int64_t GetLockTimeout() const { return lock_timeout_; }
+ void SetLockTimeout(int64_t timeout) override {
+ lock_timeout_ = timeout * 1000;
+ }
+
+ // Returns true if locks were stolen successfully, false otherwise.
+ bool TryStealingLocks();
+
+ bool IsDeadlockDetect() const override { return deadlock_detect_; }
+
+ int64_t GetDeadlockDetectDepth() const { return deadlock_detect_depth_; }
+
+ virtual Status GetRangeLock(ColumnFamilyHandle* column_family,
+ const Endpoint& start_key,
+ const Endpoint& end_key) override;
+
+ protected:
+ // Refer to
+ // TransactionOptions::use_only_the_last_commit_time_batch_for_recovery
+ bool use_only_the_last_commit_time_batch_for_recovery_ = false;
+ // Refer to
+ // TransactionOptions::skip_prepare
+ bool skip_prepare_ = false;
+
+ virtual Status PrepareInternal() = 0;
+
+ virtual Status CommitWithoutPrepareInternal() = 0;
+
+ // batch_cnt if non-zero is the number of sub-batches. A sub-batch is a batch
+ // with no duplicate keys. If zero, then the number of sub-batches is unknown.
+ virtual Status CommitBatchInternal(WriteBatch* batch,
+ size_t batch_cnt = 0) = 0;
+
+ virtual Status CommitInternal() = 0;
+
+ virtual Status RollbackInternal() = 0;
+
+ virtual void Initialize(const TransactionOptions& txn_options);
+
+ Status LockBatch(WriteBatch* batch, LockTracker* keys_to_unlock);
+
+ Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false) override;
+
+ void Clear() override;
+
+ PessimisticTransactionDB* txn_db_impl_;
+ DBImpl* db_impl_;
+
+ // If non-zero, this transaction should not be committed after this time (in
+ // microseconds according to Env->NowMicros())
+ uint64_t expiration_time_;
+
+ // Timestamp used by the transaction to perform all GetForUpdate.
+ // Use this timestamp for conflict checking.
+ // read_timestamp_ == kMaxTxnTimestamp means this transaction has not
+ // performed any GetForUpdate. It is possible that the transaction has
+ // performed blind writes or Get, though.
+ TxnTimestamp read_timestamp_{kMaxTxnTimestamp};
+ TxnTimestamp commit_timestamp_{kMaxTxnTimestamp};
+
+ private:
+ friend class TransactionTest_ValidateSnapshotTest_Test;
+ // Used to create unique ids for transactions.
+ static std::atomic<TransactionID> txn_id_counter_;
+
+ // Unique ID for this transaction
+ TransactionID txn_id_;
+
+ // IDs for the transactions that are blocking the current transaction.
+ //
+ // empty if current transaction is not waiting.
+ autovector<TransactionID> waiting_txn_ids_;
+
+ // The following two represents the (cf, key) that a transaction is waiting
+ // on.
+ //
+ // If waiting_key_ is not null, then the pointer should always point to
+ // a valid string object. The reason is that it is only non-null when the
+ // transaction is blocked in the PointLockManager::AcquireWithTimeout
+ // function. At that point, the key string object is one of the function
+ // parameters.
+ uint32_t waiting_cf_id_;
+ const std::string* waiting_key_;
+
+ // Mutex protecting waiting_txn_ids_, waiting_cf_id_ and waiting_key_.
+ mutable std::mutex wait_mutex_;
+
+ // Timeout in microseconds when locking a key or -1 if there is no timeout.
+ int64_t lock_timeout_;
+
+ // Whether to perform deadlock detection or not.
+ bool deadlock_detect_;
+
+ // Whether to perform deadlock detection or not.
+ int64_t deadlock_detect_depth_;
+
+ // Refer to TransactionOptions::skip_concurrency_control
+ bool skip_concurrency_control_;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq);
+
+ void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+};
+
+class WriteCommittedTxn : public PessimisticTransaction {
+ public:
+ WriteCommittedTxn(TransactionDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+ // No copying allowed
+ WriteCommittedTxn(const WriteCommittedTxn&) = delete;
+ void operator=(const WriteCommittedTxn&) = delete;
+
+ ~WriteCommittedTxn() override {}
+
+ using TransactionBaseImpl::GetForUpdate;
+ Status GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ std::string* value, bool exclusive,
+ const bool do_validate) override;
+ Status GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* pinnable_val, bool exclusive,
+ const bool do_validate) override;
+
+ using TransactionBaseImpl::Put;
+ // `key` does NOT include timestamp even when it's enabled.
+ Status Put(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+ Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::PutUntracked;
+ Status PutUntracked(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+ Status PutUntracked(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value) override;
+
+ using TransactionBaseImpl::Delete;
+ // `key` does NOT include timestamp even when it's enabled.
+ Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status Delete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::DeleteUntracked;
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) override;
+
+ using TransactionBaseImpl::SingleDelete;
+ // `key` does NOT include timestamp even when it's enabled.
+ Status SingleDelete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status SingleDelete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::SingleDeleteUntracked;
+ Status SingleDeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+
+ using TransactionBaseImpl::Merge;
+ Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+
+ Status SetReadTimestampForValidation(TxnTimestamp ts) override;
+ Status SetCommitTimestamp(TxnTimestamp ts) override;
+ TxnTimestamp GetCommitTimestamp() const override { return commit_timestamp_; }
+
+ private:
+ template <typename TValue>
+ Status GetForUpdateImpl(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ TValue* value, bool exclusive,
+ const bool do_validate);
+
+ template <typename TKey, typename TOperation>
+ Status Operate(ColumnFamilyHandle* column_family, const TKey& key,
+ const bool do_validate, const bool assume_tracked,
+ TOperation&& operation);
+
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+
+ Status CommitBatchInternal(WriteBatch* batch, size_t batch_cnt) override;
+
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+
+ // Column families that enable timestamps and whose data are written when
+ // indexing_enabled_ is false. If a key is written when indexing_enabled_ is
+ // true, then the corresponding column family is not added to cfs_with_ts
+ // even if it enables timestamp.
+ std::unordered_set<uint32_t> cfs_with_ts_tracked_when_indexing_disabled_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc
new file mode 100644
index 000000000..950ef8042
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc
@@ -0,0 +1,782 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/pessimistic_transaction_db.h"
+
+#include <cinttypes>
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+PessimisticTransactionDB::PessimisticTransactionDB(
+ DB* db, const TransactionDBOptions& txn_db_options)
+ : TransactionDB(db),
+ db_impl_(static_cast_with_check<DBImpl>(db)),
+ txn_db_options_(txn_db_options),
+ lock_manager_(NewLockManager(this, txn_db_options)) {
+ assert(db_impl_ != nullptr);
+ info_log_ = db_impl_->GetDBOptions().info_log;
+}
+
+// Support initiliazing PessimisticTransactionDB from a stackable db
+//
+// PessimisticTransactionDB
+// ^ ^
+// | |
+// | +
+// | StackableDB
+// | ^
+// | |
+// + +
+// DBImpl
+// ^
+// |(inherit)
+// +
+// DB
+//
+PessimisticTransactionDB::PessimisticTransactionDB(
+ StackableDB* db, const TransactionDBOptions& txn_db_options)
+ : TransactionDB(db),
+ db_impl_(static_cast_with_check<DBImpl>(db->GetRootDB())),
+ txn_db_options_(txn_db_options),
+ lock_manager_(NewLockManager(this, txn_db_options)) {
+ assert(db_impl_ != nullptr);
+}
+
+PessimisticTransactionDB::~PessimisticTransactionDB() {
+ while (!transactions_.empty()) {
+ delete transactions_.begin()->second;
+ // TODO(myabandeh): this seems to be an unsafe approach as it is not quite
+ // clear whether delete would also remove the entry from transactions_.
+ }
+}
+
+Status PessimisticTransactionDB::VerifyCFOptions(
+ const ColumnFamilyOptions& cf_options) {
+ const Comparator* const ucmp = cf_options.comparator;
+ assert(ucmp);
+ size_t ts_sz = ucmp->timestamp_size();
+ if (0 == ts_sz) {
+ return Status::OK();
+ }
+ if (ts_sz != sizeof(TxnTimestamp)) {
+ std::ostringstream oss;
+ oss << "Timestamp of transaction must have " << sizeof(TxnTimestamp)
+ << " bytes. CF comparator " << std::string(ucmp->Name())
+ << " timestamp size is " << ts_sz << " bytes";
+ return Status::InvalidArgument(oss.str());
+ }
+ if (txn_db_options_.write_policy != WRITE_COMMITTED) {
+ return Status::NotSupported("Only WriteCommittedTxn supports timestamp");
+ }
+ return Status::OK();
+}
+
+Status PessimisticTransactionDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ for (auto cf_ptr : handles) {
+ AddColumnFamily(cf_ptr);
+ }
+ // Verify cf options
+ for (auto handle : handles) {
+ ColumnFamilyDescriptor cfd;
+ Status s = handle->GetDescriptor(&cfd);
+ if (!s.ok()) {
+ return s;
+ }
+ s = VerifyCFOptions(cfd.options);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ // Re-enable compaction for the column families that initially had
+ // compaction enabled.
+ std::vector<ColumnFamilyHandle*> compaction_enabled_cf_handles;
+ compaction_enabled_cf_handles.reserve(compaction_enabled_cf_indices.size());
+ for (auto index : compaction_enabled_cf_indices) {
+ compaction_enabled_cf_handles.push_back(handles[index]);
+ }
+
+ Status s = EnableAutoCompaction(compaction_enabled_cf_handles);
+
+ // create 'real' transactions from recovered shell transactions
+ auto dbimpl = static_cast_with_check<DBImpl>(GetRootDB());
+ assert(dbimpl != nullptr);
+ auto rtrxs = dbimpl->recovered_transactions();
+
+ for (auto it = rtrxs.begin(); it != rtrxs.end(); ++it) {
+ auto recovered_trx = it->second;
+ assert(recovered_trx);
+ assert(recovered_trx->batches_.size() == 1);
+ const auto& seq = recovered_trx->batches_.begin()->first;
+ const auto& batch_info = recovered_trx->batches_.begin()->second;
+ assert(batch_info.log_number_);
+ assert(recovered_trx->name_.length());
+
+ WriteOptions w_options;
+ w_options.sync = true;
+ TransactionOptions t_options;
+ // This would help avoiding deadlock for keys that although exist in the WAL
+ // did not go through concurrency control. This includes the merge that
+ // MyRocks uses for auto-inc columns. It is safe to do so, since (i) if
+ // there is a conflict between the keys of two transactions that must be
+ // avoided, it is already avoided by the application, MyRocks, before the
+ // restart (ii) application, MyRocks, guarntees to rollback/commit the
+ // recovered transactions before new transactions start.
+ t_options.skip_concurrency_control = true;
+
+ Transaction* real_trx = BeginTransaction(w_options, t_options, nullptr);
+ assert(real_trx);
+ real_trx->SetLogNumber(batch_info.log_number_);
+ assert(seq != kMaxSequenceNumber);
+ if (GetTxnDBOptions().write_policy != WRITE_COMMITTED) {
+ real_trx->SetId(seq);
+ }
+
+ s = real_trx->SetName(recovered_trx->name_);
+ if (!s.ok()) {
+ break;
+ }
+
+ s = real_trx->RebuildFromWriteBatch(batch_info.batch_);
+ // WriteCommitted set this to to disable this check that is specific to
+ // WritePrepared txns
+ assert(batch_info.batch_cnt_ == 0 ||
+ real_trx->GetWriteBatch()->SubBatchCnt() == batch_info.batch_cnt_);
+ real_trx->SetState(Transaction::PREPARED);
+ if (!s.ok()) {
+ break;
+ }
+ }
+ if (s.ok()) {
+ dbimpl->DeleteAllRecoveredTransactions();
+ }
+ return s;
+}
+
+Transaction* WriteCommittedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WriteCommittedTxn(this, write_options, txn_options);
+ }
+}
+
+TransactionDBOptions PessimisticTransactionDB::ValidateTxnDBOptions(
+ const TransactionDBOptions& txn_db_options) {
+ TransactionDBOptions validated = txn_db_options;
+
+ if (txn_db_options.num_stripes == 0) {
+ validated.num_stripes = 1;
+ }
+
+ return validated;
+}
+
+Status TransactionDB::Open(const Options& options,
+ const TransactionDBOptions& txn_db_options,
+ const std::string& dbname, TransactionDB** dbptr) {
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = TransactionDB::Open(db_options, txn_db_options, dbname,
+ column_families, &handles, dbptr);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // i can delete the handle since DBImpl is always holding a reference to
+ // default column family
+ delete handles[0];
+ }
+
+ return s;
+}
+
+Status TransactionDB::Open(
+ const DBOptions& db_options, const TransactionDBOptions& txn_db_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles, TransactionDB** dbptr) {
+ Status s;
+ DB* db = nullptr;
+ if (txn_db_options.write_policy == WRITE_COMMITTED &&
+ db_options.unordered_write) {
+ return Status::NotSupported(
+ "WRITE_COMMITTED is incompatible with unordered_writes");
+ }
+ if (txn_db_options.write_policy == WRITE_UNPREPARED &&
+ db_options.unordered_write) {
+ // TODO(lth): support it
+ return Status::NotSupported(
+ "WRITE_UNPREPARED is currently incompatible with unordered_writes");
+ }
+ if (txn_db_options.write_policy == WRITE_PREPARED &&
+ db_options.unordered_write && !db_options.two_write_queues) {
+ return Status::NotSupported(
+ "WRITE_PREPARED is incompatible with unordered_writes if "
+ "two_write_queues is not enabled.");
+ }
+
+ std::vector<ColumnFamilyDescriptor> column_families_copy = column_families;
+ std::vector<size_t> compaction_enabled_cf_indices;
+ DBOptions db_options_2pc = db_options;
+ PrepareWrap(&db_options_2pc, &column_families_copy,
+ &compaction_enabled_cf_indices);
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ s = DBImpl::Open(db_options_2pc, dbname, column_families_copy, handles, &db,
+ use_seq_per_batch, use_batch_per_txn);
+ if (s.ok()) {
+ ROCKS_LOG_WARN(db->GetDBOptions().info_log,
+ "Transaction write_policy is %" PRId32,
+ static_cast<int>(txn_db_options.write_policy));
+ // if WrapDB return non-ok, db will be deleted in WrapDB() via
+ // ~StackableDB().
+ s = WrapDB(db, txn_db_options, compaction_enabled_cf_indices, *handles,
+ dbptr);
+ }
+ return s;
+}
+
+void TransactionDB::PrepareWrap(
+ DBOptions* db_options, std::vector<ColumnFamilyDescriptor>* column_families,
+ std::vector<size_t>* compaction_enabled_cf_indices) {
+ compaction_enabled_cf_indices->clear();
+
+ // Enable MemTable History if not already enabled
+ for (size_t i = 0; i < column_families->size(); i++) {
+ ColumnFamilyOptions* cf_options = &(*column_families)[i].options;
+
+ if (cf_options->max_write_buffer_size_to_maintain == 0 &&
+ cf_options->max_write_buffer_number_to_maintain == 0) {
+ // Setting to -1 will set the History size to
+ // max_write_buffer_number * write_buffer_size.
+ cf_options->max_write_buffer_size_to_maintain = -1;
+ }
+ if (!cf_options->disable_auto_compactions) {
+ // Disable compactions momentarily to prevent race with DB::Open
+ cf_options->disable_auto_compactions = true;
+ compaction_enabled_cf_indices->push_back(i);
+ }
+ }
+ db_options->allow_2pc = true;
+}
+
+namespace {
+template <typename DBType>
+Status WrapAnotherDBInternal(
+ DBType* db, const TransactionDBOptions& txn_db_options,
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles, TransactionDB** dbptr) {
+ assert(db != nullptr);
+ assert(dbptr != nullptr);
+ *dbptr = nullptr;
+ std::unique_ptr<PessimisticTransactionDB> txn_db;
+ // txn_db owns object pointed to by the raw db pointer.
+ switch (txn_db_options.write_policy) {
+ case WRITE_UNPREPARED:
+ txn_db.reset(new WriteUnpreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_PREPARED:
+ txn_db.reset(new WritePreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_COMMITTED:
+ default:
+ txn_db.reset(new WriteCommittedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ }
+ txn_db->UpdateCFComparatorMap(handles);
+ Status s = txn_db->Initialize(compaction_enabled_cf_indices, handles);
+ // In case of a failure at this point, db is deleted via the txn_db destructor
+ // and set to nullptr.
+ if (s.ok()) {
+ *dbptr = txn_db.release();
+ } else {
+ for (auto* h : handles) {
+ delete h;
+ }
+ // txn_db still owns db, and ~StackableDB() will be called when txn_db goes
+ // out of scope, deleting the input db pointer.
+ ROCKS_LOG_FATAL(db->GetDBOptions().info_log,
+ "Failed to initialize txn_db: %s", s.ToString().c_str());
+ }
+ return s;
+}
+} // namespace
+
+Status TransactionDB::WrapDB(
+ // make sure this db is already opened with memtable history enabled,
+ // auto compaction distabled and 2 phase commit enabled
+ DB* db, const TransactionDBOptions& txn_db_options,
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles, TransactionDB** dbptr) {
+ return WrapAnotherDBInternal(db, txn_db_options,
+ compaction_enabled_cf_indices, handles, dbptr);
+}
+
+Status TransactionDB::WrapStackableDB(
+ // make sure this stackable_db is already opened with memtable history
+ // enabled, auto compaction distabled and 2 phase commit enabled
+ StackableDB* db, const TransactionDBOptions& txn_db_options,
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles, TransactionDB** dbptr) {
+ return WrapAnotherDBInternal(db, txn_db_options,
+ compaction_enabled_cf_indices, handles, dbptr);
+}
+
+// Let LockManager know that this column family exists so it can
+// allocate a LockMap for it.
+void PessimisticTransactionDB::AddColumnFamily(
+ const ColumnFamilyHandle* handle) {
+ lock_manager_->AddColumnFamily(handle);
+}
+
+Status PessimisticTransactionDB::CreateColumnFamily(
+ const ColumnFamilyOptions& options, const std::string& column_family_name,
+ ColumnFamilyHandle** handle) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+ Status s = VerifyCFOptions(options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = db_->CreateColumnFamily(options, column_family_name, handle);
+ if (s.ok()) {
+ lock_manager_->AddColumnFamily(*handle);
+ UpdateCFComparatorMap(*handle);
+ }
+
+ return s;
+}
+
+Status PessimisticTransactionDB::CreateColumnFamilies(
+ const ColumnFamilyOptions& options,
+ const std::vector<std::string>& column_family_names,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+
+ Status s = VerifyCFOptions(options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = db_->CreateColumnFamilies(options, column_family_names, handles);
+ if (s.ok()) {
+ for (auto* handle : *handles) {
+ lock_manager_->AddColumnFamily(handle);
+ UpdateCFComparatorMap(handle);
+ }
+ }
+
+ return s;
+}
+
+Status PessimisticTransactionDB::CreateColumnFamilies(
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+
+ for (auto& cf_desc : column_families) {
+ Status s = VerifyCFOptions(cf_desc.options);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ Status s = db_->CreateColumnFamilies(column_families, handles);
+ if (s.ok()) {
+ for (auto* handle : *handles) {
+ lock_manager_->AddColumnFamily(handle);
+ UpdateCFComparatorMap(handle);
+ }
+ }
+
+ return s;
+}
+
+// Let LockManager know that it can deallocate the LockMap for this
+// column family.
+Status PessimisticTransactionDB::DropColumnFamily(
+ ColumnFamilyHandle* column_family) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+
+ Status s = db_->DropColumnFamily(column_family);
+ if (s.ok()) {
+ lock_manager_->RemoveColumnFamily(column_family);
+ }
+
+ return s;
+}
+
+Status PessimisticTransactionDB::DropColumnFamilies(
+ const std::vector<ColumnFamilyHandle*>& column_families) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+
+ Status s = db_->DropColumnFamilies(column_families);
+ if (s.ok()) {
+ for (auto* handle : column_families) {
+ lock_manager_->RemoveColumnFamily(handle);
+ }
+ }
+
+ return s;
+}
+
+Status PessimisticTransactionDB::TryLock(PessimisticTransaction* txn,
+ uint32_t cfh_id,
+ const std::string& key,
+ bool exclusive) {
+ return lock_manager_->TryLock(txn, cfh_id, key, GetEnv(), exclusive);
+}
+
+Status PessimisticTransactionDB::TryRangeLock(PessimisticTransaction* txn,
+ uint32_t cfh_id,
+ const Endpoint& start_endp,
+ const Endpoint& end_endp) {
+ return lock_manager_->TryLock(txn, cfh_id, start_endp, end_endp, GetEnv(),
+ /*exclusive=*/true);
+}
+
+void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn,
+ const LockTracker& keys) {
+ lock_manager_->UnLock(txn, keys, GetEnv());
+}
+
+void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn,
+ uint32_t cfh_id, const std::string& key) {
+ lock_manager_->UnLock(txn, cfh_id, key, GetEnv());
+}
+
+// Used when wrapping DB write operations in a transaction
+Transaction* PessimisticTransactionDB::BeginInternalTransaction(
+ const WriteOptions& options) {
+ TransactionOptions txn_options;
+ Transaction* txn = BeginTransaction(options, txn_options, nullptr);
+
+ // Use default timeout for non-transactional writes
+ txn->SetLockTimeout(txn_db_options_.default_lock_timeout);
+ return txn;
+}
+
+// All user Put, Merge, Delete, and Write requests must be intercepted to make
+// sure that they lock all keys that they are writing to avoid causing conflicts
+// with any concurrent transactions. The easiest way to do this is to wrap all
+// write operations in a transaction.
+//
+// Put(), Merge(), and Delete() only lock a single key per call. Write() will
+// sort its keys before locking them. This guarantees that TransactionDB write
+// methods cannot deadlock with each other (but still could deadlock with a
+// Transaction).
+Status PessimisticTransactionDB::Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& val) {
+ Status s = FailIfCfEnablesTs(this, column_family);
+ if (!s.ok()) {
+ return s;
+ }
+
+ Transaction* txn = BeginInternalTransaction(options);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do PutUntracked().
+ s = txn->PutUntracked(column_family, key, val);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::Delete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s = FailIfCfEnablesTs(this, column_family);
+ if (!s.ok()) {
+ return s;
+ }
+
+ Transaction* txn = BeginInternalTransaction(wopts);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // DeleteUntracked().
+ s = txn->DeleteUntracked(column_family, key);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::SingleDelete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s = FailIfCfEnablesTs(this, column_family);
+ if (!s.ok()) {
+ return s;
+ }
+
+ Transaction* txn = BeginInternalTransaction(wopts);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // SingleDeleteUntracked().
+ s = txn->SingleDeleteUntracked(column_family, key);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ Status s = FailIfCfEnablesTs(this, column_family);
+ if (!s.ok()) {
+ return s;
+ }
+
+ Transaction* txn = BeginInternalTransaction(options);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // MergeUntracked().
+ s = txn->MergeUntracked(column_family, key, value);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ return WriteWithConcurrencyControl(opts, updates);
+}
+
+Status WriteCommittedTxnDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ Status s = FailIfBatchHasTs(updates);
+ if (!s.ok()) {
+ return s;
+ }
+ if (txn_db_options_.skip_concurrency_control) {
+ return db_impl_->Write(opts, updates);
+ } else {
+ return WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WriteCommittedTxnDB::Write(
+ const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations, WriteBatch* updates) {
+ Status s = FailIfBatchHasTs(updates);
+ if (!s.ok()) {
+ return s;
+ }
+ if (optimizations.skip_concurrency_control) {
+ return db_impl_->Write(opts, updates);
+ } else {
+ return WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+void PessimisticTransactionDB::InsertExpirableTransaction(
+ TransactionID tx_id, PessimisticTransaction* tx) {
+ assert(tx->GetExpirationTime() > 0);
+ std::lock_guard<std::mutex> lock(map_mutex_);
+ expirable_transactions_map_.insert({tx_id, tx});
+}
+
+void PessimisticTransactionDB::RemoveExpirableTransaction(TransactionID tx_id) {
+ std::lock_guard<std::mutex> lock(map_mutex_);
+ expirable_transactions_map_.erase(tx_id);
+}
+
+bool PessimisticTransactionDB::TryStealingExpiredTransactionLocks(
+ TransactionID tx_id) {
+ std::lock_guard<std::mutex> lock(map_mutex_);
+
+ auto tx_it = expirable_transactions_map_.find(tx_id);
+ if (tx_it == expirable_transactions_map_.end()) {
+ return true;
+ }
+ PessimisticTransaction& tx = *(tx_it->second);
+ return tx.TryStealingLocks();
+}
+
+void PessimisticTransactionDB::ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const TransactionOptions& txn_options) {
+ auto txn_impl = static_cast_with_check<PessimisticTransaction>(txn);
+
+ txn_impl->Reinitialize(this, write_options, txn_options);
+}
+
+Transaction* PessimisticTransactionDB::GetTransactionByName(
+ const TransactionName& name) {
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ auto it = transactions_.find(name);
+ if (it == transactions_.end()) {
+ return nullptr;
+ } else {
+ return it->second;
+ }
+}
+
+void PessimisticTransactionDB::GetAllPreparedTransactions(
+ std::vector<Transaction*>* transv) {
+ assert(transv);
+ transv->clear();
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ for (auto it = transactions_.begin(); it != transactions_.end(); ++it) {
+ if (it->second->GetState() == Transaction::PREPARED) {
+ transv->push_back(it->second);
+ }
+ }
+}
+
+LockManager::PointLockStatus PessimisticTransactionDB::GetLockStatusData() {
+ return lock_manager_->GetPointLockStatus();
+}
+
+std::vector<DeadlockPath> PessimisticTransactionDB::GetDeadlockInfoBuffer() {
+ return lock_manager_->GetDeadlockInfoBuffer();
+}
+
+void PessimisticTransactionDB::SetDeadlockInfoBufferSize(uint32_t target_size) {
+ lock_manager_->Resize(target_size);
+}
+
+void PessimisticTransactionDB::RegisterTransaction(Transaction* txn) {
+ assert(txn);
+ assert(txn->GetName().length() > 0);
+ assert(GetTransactionByName(txn->GetName()) == nullptr);
+ assert(txn->GetState() == Transaction::STARTED);
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ transactions_[txn->GetName()] = txn;
+}
+
+void PessimisticTransactionDB::UnregisterTransaction(Transaction* txn) {
+ assert(txn);
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ auto it = transactions_.find(txn->GetName());
+ assert(it != transactions_.end());
+ transactions_.erase(it);
+}
+
+std::pair<Status, std::shared_ptr<const Snapshot>>
+PessimisticTransactionDB::CreateTimestampedSnapshot(TxnTimestamp ts) {
+ if (kMaxTxnTimestamp == ts) {
+ return std::make_pair(Status::InvalidArgument("invalid ts"), nullptr);
+ }
+ assert(db_impl_);
+ return db_impl_->CreateTimestampedSnapshot(kMaxSequenceNumber, ts);
+}
+
+std::shared_ptr<const Snapshot>
+PessimisticTransactionDB::GetTimestampedSnapshot(TxnTimestamp ts) const {
+ assert(db_impl_);
+ return db_impl_->GetTimestampedSnapshot(ts);
+}
+
+void PessimisticTransactionDB::ReleaseTimestampedSnapshotsOlderThan(
+ TxnTimestamp ts) {
+ assert(db_impl_);
+ db_impl_->ReleaseTimestampedSnapshotsOlderThan(ts);
+}
+
+Status PessimisticTransactionDB::GetTimestampedSnapshots(
+ TxnTimestamp ts_lb, TxnTimestamp ts_ub,
+ std::vector<std::shared_ptr<const Snapshot>>& timestamped_snapshots) const {
+ assert(db_impl_);
+ return db_impl_->GetTimestampedSnapshots(ts_lb, ts_ub, timestamped_snapshots);
+}
+
+Status SnapshotCreationCallback::operator()(SequenceNumber seq,
+ bool disable_memtable) {
+ assert(db_impl_);
+ assert(commit_ts_ != kMaxTxnTimestamp);
+
+ const bool two_write_queues =
+ db_impl_->immutable_db_options().two_write_queues;
+ assert(!two_write_queues || !disable_memtable);
+#ifdef NDEBUG
+ (void)two_write_queues;
+ (void)disable_memtable;
+#endif
+
+ const bool seq_per_batch = db_impl_->seq_per_batch();
+ if (!seq_per_batch) {
+ assert(db_impl_->GetLastPublishedSequence() <= seq);
+ } else {
+ assert(db_impl_->GetLastPublishedSequence() < seq);
+ }
+
+ // Create a snapshot which can also be used for write conflict checking.
+ auto ret = db_impl_->CreateTimestampedSnapshot(seq, commit_ts_);
+ snapshot_creation_status_ = ret.first;
+ snapshot_ = ret.second;
+ if (snapshot_creation_status_.ok()) {
+ assert(snapshot_);
+ } else {
+ assert(!snapshot_);
+ }
+ if (snapshot_ && snapshot_notifier_) {
+ snapshot_notifier_->SnapshotCreated(snapshot_.get());
+ }
+ return Status::OK();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h
new file mode 100644
index 000000000..25cd11054
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h
@@ -0,0 +1,318 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <mutex>
+#include <queue>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/db_iter.h"
+#include "db/read_callback.h"
+#include "db/snapshot_checker.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/lock/lock_manager.h"
+#include "utilities/transactions/lock/range/range_lock_manager.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/write_prepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class PessimisticTransactionDB : public TransactionDB {
+ public:
+ explicit PessimisticTransactionDB(DB* db,
+ const TransactionDBOptions& txn_db_options);
+
+ explicit PessimisticTransactionDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options);
+
+ virtual ~PessimisticTransactionDB();
+
+ virtual const Snapshot* GetSnapshot() override { return db_->GetSnapshot(); }
+
+ virtual Status Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles);
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override = 0;
+
+ using StackableDB::Put;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) override;
+
+ using StackableDB::Delete;
+ virtual Status Delete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+
+ using StackableDB::SingleDelete;
+ virtual Status SingleDelete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+
+ using StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+
+ using TransactionDB::Write;
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+ inline Status WriteWithConcurrencyControl(const WriteOptions& opts,
+ WriteBatch* updates) {
+ Status s;
+ if (opts.protection_bytes_per_key > 0) {
+ s = WriteBatchInternal::UpdateProtectionInfo(
+ updates, opts.protection_bytes_per_key);
+ }
+ if (s.ok()) {
+ // Need to lock all keys in this batch to prevent write conflicts with
+ // concurrent transactions.
+ Transaction* txn = BeginInternalTransaction(opts);
+ txn->DisableIndexing();
+
+ auto txn_impl = static_cast_with_check<PessimisticTransaction>(txn);
+
+ // Since commitBatch sorts the keys before locking, concurrent Write()
+ // operations will not cause a deadlock.
+ // In order to avoid a deadlock with a concurrent Transaction,
+ // Transactions should use a lock timeout.
+ s = txn_impl->CommitBatch(updates);
+
+ delete txn;
+ }
+
+ return s;
+ }
+
+ using StackableDB::CreateColumnFamily;
+ virtual Status CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) override;
+
+ Status CreateColumnFamilies(
+ const ColumnFamilyOptions& options,
+ const std::vector<std::string>& column_family_names,
+ std::vector<ColumnFamilyHandle*>* handles) override;
+
+ Status CreateColumnFamilies(
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles) override;
+
+ using StackableDB::DropColumnFamily;
+ virtual Status DropColumnFamily(ColumnFamilyHandle* column_family) override;
+
+ Status DropColumnFamilies(
+ const std::vector<ColumnFamilyHandle*>& column_families) override;
+
+ Status TryLock(PessimisticTransaction* txn, uint32_t cfh_id,
+ const std::string& key, bool exclusive);
+ Status TryRangeLock(PessimisticTransaction* txn, uint32_t cfh_id,
+ const Endpoint& start_endp, const Endpoint& end_endp);
+
+ void UnLock(PessimisticTransaction* txn, const LockTracker& keys);
+ void UnLock(PessimisticTransaction* txn, uint32_t cfh_id,
+ const std::string& key);
+
+ void AddColumnFamily(const ColumnFamilyHandle* handle);
+
+ static TransactionDBOptions ValidateTxnDBOptions(
+ const TransactionDBOptions& txn_db_options);
+
+ const TransactionDBOptions& GetTxnDBOptions() const {
+ return txn_db_options_;
+ }
+
+ void InsertExpirableTransaction(TransactionID tx_id,
+ PessimisticTransaction* tx);
+ void RemoveExpirableTransaction(TransactionID tx_id);
+
+ // If transaction is no longer available, locks can be stolen
+ // If transaction is available, try stealing locks directly from transaction
+ // It is the caller's responsibility to ensure that the referred transaction
+ // is expirable (GetExpirationTime() > 0) and that it is expired.
+ bool TryStealingExpiredTransactionLocks(TransactionID tx_id);
+
+ Transaction* GetTransactionByName(const TransactionName& name) override;
+
+ void RegisterTransaction(Transaction* txn);
+ void UnregisterTransaction(Transaction* txn);
+
+ // not thread safe. current use case is during recovery (single thread)
+ void GetAllPreparedTransactions(std::vector<Transaction*>* trans) override;
+
+ LockManager::PointLockStatus GetLockStatusData() override;
+
+ std::vector<DeadlockPath> GetDeadlockInfoBuffer() override;
+ void SetDeadlockInfoBufferSize(uint32_t target_size) override;
+
+ // The default implementation does nothing. The actual implementation is moved
+ // to the child classes that actually need this information. This was due to
+ // an odd performance drop we observed when the added std::atomic member to
+ // the base class even when the subclass do not read it in the fast path.
+ virtual void UpdateCFComparatorMap(const std::vector<ColumnFamilyHandle*>&) {}
+ virtual void UpdateCFComparatorMap(ColumnFamilyHandle*) {}
+
+ // Use the returned factory to create LockTrackers in transactions.
+ const LockTrackerFactory& GetLockTrackerFactory() const {
+ return lock_manager_->GetLockTrackerFactory();
+ }
+
+ std::pair<Status, std::shared_ptr<const Snapshot>> CreateTimestampedSnapshot(
+ TxnTimestamp ts) override;
+
+ std::shared_ptr<const Snapshot> GetTimestampedSnapshot(
+ TxnTimestamp ts) const override;
+
+ void ReleaseTimestampedSnapshotsOlderThan(TxnTimestamp ts) override;
+
+ Status GetTimestampedSnapshots(TxnTimestamp ts_lb, TxnTimestamp ts_ub,
+ std::vector<std::shared_ptr<const Snapshot>>&
+ timestamped_snapshots) const override;
+
+ protected:
+ DBImpl* db_impl_;
+ std::shared_ptr<Logger> info_log_;
+ const TransactionDBOptions txn_db_options_;
+
+ static Status FailIfBatchHasTs(const WriteBatch* wb);
+
+ static Status FailIfCfEnablesTs(const DB* db,
+ const ColumnFamilyHandle* column_family);
+
+ void ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const TransactionOptions& txn_options = TransactionOptions());
+
+ virtual Status VerifyCFOptions(const ColumnFamilyOptions& cf_options);
+
+ private:
+ friend class WritePreparedTxnDB;
+ friend class WritePreparedTxnDBMock;
+ friend class WriteUnpreparedTxn;
+ friend class TransactionTest_DoubleCrashInRecovery_Test;
+ friend class TransactionTest_DoubleEmptyWrite_Test;
+ friend class TransactionTest_DuplicateKeys_Test;
+ friend class TransactionTest_PersistentTwoPhaseTransactionTest_Test;
+ friend class TransactionTest_TwoPhaseDoubleRecoveryTest_Test;
+ friend class TransactionTest_TwoPhaseOutOfOrderDelete_Test;
+ friend class TransactionStressTest_TwoPhaseLongPrepareTest_Test;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+ friend class WriteUnpreparedTransactionTest_MarkLogWithPrepSection_Test;
+
+ Transaction* BeginInternalTransaction(const WriteOptions& options);
+
+ std::shared_ptr<LockManager> lock_manager_;
+
+ // Must be held when adding/dropping column families.
+ InstrumentedMutex column_family_mutex_;
+
+ // Used to ensure that no locks are stolen from an expirable transaction
+ // that has started a commit. Only transactions with an expiration time
+ // should be in this map.
+ std::mutex map_mutex_;
+ std::unordered_map<TransactionID, PessimisticTransaction*>
+ expirable_transactions_map_;
+
+ // map from name to two phase transaction instance
+ std::mutex name_map_mutex_;
+ std::unordered_map<TransactionName, Transaction*> transactions_;
+
+ // Signal that we are testing a crash scenario. Some asserts could be relaxed
+ // in such cases.
+ virtual void TEST_Crash() {}
+};
+
+// A PessimisticTransactionDB that writes the data to the DB after the commit.
+// In this way the DB only contains the committed data.
+class WriteCommittedTxnDB : public PessimisticTransactionDB {
+ public:
+ explicit WriteCommittedTxnDB(DB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options) {}
+
+ explicit WriteCommittedTxnDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options) {}
+
+ virtual ~WriteCommittedTxnDB() {}
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ // Optimized version of ::Write that makes use of skip_concurrency_control
+ // hint
+ using TransactionDB::Write;
+ virtual Status Write(const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations,
+ WriteBatch* updates) override;
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+};
+
+inline Status PessimisticTransactionDB::FailIfBatchHasTs(
+ const WriteBatch* batch) {
+ if (batch != nullptr && WriteBatchInternal::HasKeyWithTimestamp(*batch)) {
+ return Status::NotSupported(
+ "Writes with timestamp must go through transaction API instead of "
+ "TransactionDB.");
+ }
+ return Status::OK();
+}
+
+inline Status PessimisticTransactionDB::FailIfCfEnablesTs(
+ const DB* db, const ColumnFamilyHandle* column_family) {
+ assert(db);
+ column_family = column_family ? column_family : db->DefaultColumnFamily();
+ assert(column_family);
+ const Comparator* const ucmp = column_family->GetComparator();
+ assert(ucmp);
+ if (ucmp->timestamp_size() > 0) {
+ return Status::NotSupported(
+ "Write operation with user timestamp must go through the transaction "
+ "API instead of TransactionDB.");
+ }
+ return Status::OK();
+}
+
+class SnapshotCreationCallback : public PostMemTableCallback {
+ public:
+ explicit SnapshotCreationCallback(
+ DBImpl* dbi, TxnTimestamp commit_ts,
+ const std::shared_ptr<TransactionNotifier>& notifier,
+ std::shared_ptr<const Snapshot>& snapshot)
+ : db_impl_(dbi),
+ commit_ts_(commit_ts),
+ snapshot_notifier_(notifier),
+ snapshot_(snapshot) {
+ assert(db_impl_);
+ }
+
+ ~SnapshotCreationCallback() override {
+ snapshot_creation_status_.PermitUncheckedError();
+ }
+
+ Status operator()(SequenceNumber seq, bool disable_memtable) override;
+
+ private:
+ DBImpl* const db_impl_;
+ const TxnTimestamp commit_ts_;
+ std::shared_ptr<TransactionNotifier> snapshot_notifier_;
+ std::shared_ptr<const Snapshot>& snapshot_;
+
+ Status snapshot_creation_status_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/snapshot_checker.cc b/src/rocksdb/utilities/transactions/snapshot_checker.cc
new file mode 100644
index 000000000..76d16681a
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/snapshot_checker.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "db/snapshot_checker.h"
+
+#ifdef ROCKSDB_LITE
+#include <assert.h>
+#endif // ROCKSDB_LITE
+
+#include "port/lang.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifdef ROCKSDB_LITE
+WritePreparedSnapshotChecker::WritePreparedSnapshotChecker(
+ WritePreparedTxnDB* /*txn_db*/) {}
+
+SnapshotCheckerResult WritePreparedSnapshotChecker::CheckInSnapshot(
+ SequenceNumber /*sequence*/, SequenceNumber /*snapshot_sequence*/) const {
+ // Should never be called in LITE mode.
+ assert(false);
+ return SnapshotCheckerResult::kInSnapshot;
+}
+
+#else
+
+WritePreparedSnapshotChecker::WritePreparedSnapshotChecker(
+ WritePreparedTxnDB* txn_db)
+ : txn_db_(txn_db){};
+
+SnapshotCheckerResult WritePreparedSnapshotChecker::CheckInSnapshot(
+ SequenceNumber sequence, SequenceNumber snapshot_sequence) const {
+ bool snapshot_released = false;
+ // TODO(myabandeh): set min_uncommitted
+ bool in_snapshot = txn_db_->IsInSnapshot(
+ sequence, snapshot_sequence, kMinUnCommittedSeq, &snapshot_released);
+ if (snapshot_released) {
+ return SnapshotCheckerResult::kSnapshotReleased;
+ }
+ return in_snapshot ? SnapshotCheckerResult::kInSnapshot
+ : SnapshotCheckerResult::kNotInSnapshot;
+}
+
+#endif // ROCKSDB_LITE
+
+DisableGCSnapshotChecker* DisableGCSnapshotChecker::Instance() {
+ STATIC_AVOID_DESTRUCTION(DisableGCSnapshotChecker, instance);
+ return &instance;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/timestamped_snapshot_test.cc b/src/rocksdb/utilities/transactions/timestamped_snapshot_test.cc
new file mode 100644
index 000000000..e9b474415
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/timestamped_snapshot_test.cc
@@ -0,0 +1,466 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifdef ROCKSDB_LITE
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as Transactions are not supported in LITE mode\n");
+ return 0;
+}
+#else // ROCKSDB_LITE
+#include <cassert>
+
+#include "util/cast_util.h"
+#include "utilities/transactions/transaction_test.h"
+
+namespace ROCKSDB_NAMESPACE {
+INSTANTIATE_TEST_CASE_P(
+ Unsupported, TimestampedSnapshotWithTsSanityCheck,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite)));
+
+INSTANTIATE_TEST_CASE_P(WriteCommitted, TransactionTest,
+ ::testing::Combine(::testing::Bool(), ::testing::Bool(),
+ ::testing::Values(WRITE_COMMITTED),
+ ::testing::Values(kOrderedWrite)));
+
+namespace {
+// Not thread-safe. Caller needs to provide external synchronization.
+class TsCheckingTxnNotifier : public TransactionNotifier {
+ public:
+ explicit TsCheckingTxnNotifier() = default;
+
+ ~TsCheckingTxnNotifier() override {}
+
+ void SnapshotCreated(const Snapshot* new_snapshot) override {
+ assert(new_snapshot);
+ if (prev_snapshot_seq_ != kMaxSequenceNumber) {
+ assert(prev_snapshot_seq_ <= new_snapshot->GetSequenceNumber());
+ }
+ prev_snapshot_seq_ = new_snapshot->GetSequenceNumber();
+ if (prev_snapshot_ts_ != kMaxTxnTimestamp) {
+ assert(prev_snapshot_ts_ <= new_snapshot->GetTimestamp());
+ }
+ prev_snapshot_ts_ = new_snapshot->GetTimestamp();
+ }
+
+ TxnTimestamp prev_snapshot_ts() const { return prev_snapshot_ts_; }
+
+ private:
+ SequenceNumber prev_snapshot_seq_ = kMaxSequenceNumber;
+ TxnTimestamp prev_snapshot_ts_ = kMaxTxnTimestamp;
+};
+} // anonymous namespace
+
+TEST_P(TimestampedSnapshotWithTsSanityCheck, WithoutCommitTs) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ ASSERT_OK(txn->Prepare());
+ Status s = txn->CommitAndTryCreateSnapshot();
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_OK(txn->Rollback());
+
+ txn.reset(db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ s = txn->CommitAndTryCreateSnapshot();
+ ASSERT_TRUE(s.IsInvalidArgument());
+}
+
+TEST_P(TimestampedSnapshotWithTsSanityCheck, SetCommitTs) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ ASSERT_OK(txn->Prepare());
+ std::shared_ptr<const Snapshot> snapshot;
+ Status s = txn->CommitAndTryCreateSnapshot(nullptr, 10, &snapshot);
+ ASSERT_TRUE(s.IsNotSupported());
+ ASSERT_OK(txn->Rollback());
+
+ txn.reset(db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ s = txn->CommitAndTryCreateSnapshot(nullptr, 10, &snapshot);
+ ASSERT_TRUE(s.IsNotSupported());
+}
+
+TEST_P(TransactionTest, WithoutCommitTs) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ ASSERT_OK(txn->Prepare());
+ Status s = txn->CommitAndTryCreateSnapshot();
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_OK(txn->Rollback());
+
+ txn.reset(db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v"));
+ s = txn->CommitAndTryCreateSnapshot();
+ ASSERT_TRUE(s.IsInvalidArgument());
+}
+
+TEST_P(TransactionTest, ReuseExistingTxn) {
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ assert(txn);
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("a", "v1"));
+ ASSERT_OK(txn->Prepare());
+
+ auto notifier = std::make_shared<TsCheckingTxnNotifier>();
+ std::shared_ptr<const Snapshot> snapshot1;
+ Status s =
+ txn->CommitAndTryCreateSnapshot(notifier, /*commit_ts=*/100, &snapshot1);
+ ASSERT_OK(s);
+ ASSERT_EQ(100, snapshot1->GetTimestamp());
+
+ Transaction* txn1 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), txn);
+ assert(txn1 == txn);
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn->Put("a", "v2"));
+ ASSERT_OK(txn->Prepare());
+ std::shared_ptr<const Snapshot> snapshot2;
+ s = txn->CommitAndTryCreateSnapshot(notifier, /*commit_ts=*/110, &snapshot2);
+ ASSERT_OK(s);
+ ASSERT_EQ(110, snapshot2->GetTimestamp());
+ delete txn;
+
+ {
+ std::string value;
+ ReadOptions read_opts;
+ read_opts.snapshot = snapshot1.get();
+ ASSERT_OK(db->Get(read_opts, "a", &value));
+ ASSERT_EQ("v1", value);
+
+ read_opts.snapshot = snapshot2.get();
+ ASSERT_OK(db->Get(read_opts, "a", &value));
+ ASSERT_EQ("v2", value);
+ }
+}
+
+TEST_P(TransactionTest, CreateSnapshotWhenCommit) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn);
+
+ constexpr int batch_size = 10;
+ for (int i = 0; i < batch_size; ++i) {
+ ASSERT_OK(db->Put(WriteOptions(), "k" + std::to_string(i), "v0"));
+ }
+ const SequenceNumber seq0 = db->GetLatestSequenceNumber();
+ ASSERT_EQ(static_cast<SequenceNumber>(batch_size), seq0);
+
+ txn->SetSnapshot();
+ {
+ const Snapshot* const snapshot = txn->GetSnapshot();
+ assert(snapshot);
+ ASSERT_EQ(seq0, snapshot->GetSequenceNumber());
+ }
+
+ for (int i = 0; i < batch_size; ++i) {
+ ASSERT_OK(txn->Put("k" + std::to_string(i), "v1"));
+ }
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Prepare());
+
+ std::shared_ptr<const Snapshot> snapshot;
+ constexpr TxnTimestamp timestamp = 1;
+ auto notifier = std::make_shared<TsCheckingTxnNotifier>();
+ Status s = txn->CommitAndTryCreateSnapshot(notifier, timestamp, &snapshot);
+ ASSERT_OK(s);
+ ASSERT_LT(notifier->prev_snapshot_ts(), kMaxTxnTimestamp);
+ assert(snapshot);
+ ASSERT_EQ(timestamp, snapshot->GetTimestamp());
+ ASSERT_EQ(seq0 + batch_size, snapshot->GetSequenceNumber());
+ const Snapshot* const raw_snapshot_ptr = txn->GetSnapshot();
+ ASSERT_EQ(raw_snapshot_ptr, snapshot.get());
+ ASSERT_EQ(snapshot, txn->GetTimestampedSnapshot());
+
+ {
+ std::shared_ptr<const Snapshot> snapshot1 =
+ db->GetLatestTimestampedSnapshot();
+ ASSERT_EQ(snapshot, snapshot1);
+ }
+ {
+ std::shared_ptr<const Snapshot> snapshot1 =
+ db->GetTimestampedSnapshot(timestamp);
+ ASSERT_EQ(snapshot, snapshot1);
+ }
+ {
+ std::vector<std::shared_ptr<const Snapshot> > snapshots;
+ s = db->GetAllTimestampedSnapshots(snapshots);
+ ASSERT_OK(s);
+ ASSERT_EQ(std::vector<std::shared_ptr<const Snapshot> >{snapshot},
+ snapshots);
+ }
+}
+
+TEST_P(TransactionTest, CreateSnapshot) {
+ // First create a non-timestamped snapshot
+ ManagedSnapshot snapshot_guard(db);
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_OK(db->Put(WriteOptions(), "k" + std::to_string(i),
+ "v0_" + std::to_string(i)));
+ }
+ {
+ auto ret = db->CreateTimestampedSnapshot(kMaxTxnTimestamp);
+ ASSERT_TRUE(ret.first.IsInvalidArgument());
+ auto snapshot = ret.second;
+ ASSERT_EQ(nullptr, snapshot.get());
+ }
+ constexpr TxnTimestamp timestamp = 100;
+ Status s;
+ std::shared_ptr<const Snapshot> ts_snap0;
+ std::tie(s, ts_snap0) = db->CreateTimestampedSnapshot(timestamp);
+ ASSERT_OK(s);
+ assert(ts_snap0);
+ ASSERT_EQ(timestamp, ts_snap0->GetTimestamp());
+ for (int i = 0; i < 10; ++i) {
+ ASSERT_OK(db->Delete(WriteOptions(), "k" + std::to_string(i)));
+ }
+ {
+ ReadOptions read_opts;
+ read_opts.snapshot = ts_snap0.get();
+ for (int i = 0; i < 10; ++i) {
+ std::string value;
+ s = db->Get(read_opts, "k" + std::to_string(i), &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("v0_" + std::to_string(i), value);
+ }
+ }
+ {
+ std::shared_ptr<const Snapshot> snapshot =
+ db->GetLatestTimestampedSnapshot();
+ ASSERT_EQ(ts_snap0, snapshot);
+ }
+ {
+ std::shared_ptr<const Snapshot> snapshot =
+ db->GetTimestampedSnapshot(timestamp);
+ ASSERT_OK(s);
+ ASSERT_EQ(ts_snap0, snapshot);
+ }
+ {
+ std::vector<std::shared_ptr<const Snapshot> > snapshots;
+ s = db->GetAllTimestampedSnapshots(snapshots);
+ ASSERT_OK(s);
+ ASSERT_EQ(std::vector<std::shared_ptr<const Snapshot> >{ts_snap0},
+ snapshots);
+ }
+}
+
+TEST_P(TransactionTest, SequenceAndTsOrder) {
+ Status s;
+ std::shared_ptr<const Snapshot> snapshot;
+ std::tie(s, snapshot) = db->CreateTimestampedSnapshot(100);
+ ASSERT_OK(s);
+ assert(snapshot);
+ {
+ // Cannot request smaller timestamp for the new timestamped snapshot.
+ std::shared_ptr<const Snapshot> tmp_snapshot;
+ std::tie(s, tmp_snapshot) = db->CreateTimestampedSnapshot(50);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_EQ(nullptr, tmp_snapshot.get());
+ }
+
+ // If requesting a new timestamped snapshot with the same timestamp and
+ // sequence number, we avoid creating new snapshot object but reuse
+ // exisisting one.
+ std::shared_ptr<const Snapshot> snapshot1;
+ std::tie(s, snapshot1) = db->CreateTimestampedSnapshot(100);
+ ASSERT_OK(s);
+ ASSERT_EQ(snapshot.get(), snapshot1.get());
+
+ // If there is no write, but we request a larger timestamp, we still create
+ // a new snapshot object.
+ std::shared_ptr<const Snapshot> snapshot2;
+ std::tie(s, snapshot2) = db->CreateTimestampedSnapshot(200);
+ ASSERT_OK(s);
+ assert(snapshot2);
+ ASSERT_NE(snapshot.get(), snapshot2.get());
+ ASSERT_EQ(snapshot2->GetSequenceNumber(), snapshot->GetSequenceNumber());
+ ASSERT_EQ(200, snapshot2->GetTimestamp());
+
+ // Increase sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "v0"));
+ {
+ // We are requesting the same timestamp for a larger sequence number, thus
+ // we cannot create timestamped snapshot.
+ std::shared_ptr<const Snapshot> tmp_snapshot;
+ std::tie(s, tmp_snapshot) = db->CreateTimestampedSnapshot(200);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_EQ(nullptr, tmp_snapshot.get());
+ }
+ {
+ std::unique_ptr<Transaction> txn1(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn1->Put("bar", "v0"));
+ std::shared_ptr<const Snapshot> ss;
+ ASSERT_OK(txn1->CommitAndTryCreateSnapshot(nullptr, 200, &ss));
+ // Cannot create snapshot because requested timestamp is the same as the
+ // latest timestamped snapshot while sequence number is strictly higher.
+ ASSERT_EQ(nullptr, ss);
+ }
+ {
+ std::unique_ptr<Transaction> txn2(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn2->Put("bar", "v0"));
+ std::shared_ptr<const Snapshot> ss;
+ // Application should never do this. This is just to demonstrate error
+ // handling.
+ ASSERT_OK(txn2->CommitAndTryCreateSnapshot(nullptr, 100, &ss));
+ // Cannot create snapshot because requested timestamp is smaller than
+ // latest timestamped snapshot.
+ ASSERT_EQ(nullptr, ss);
+ }
+}
+
+TEST_P(TransactionTest, CloseDbWithSnapshots) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn->SetName("txn0"));
+ ASSERT_OK(txn->Put("foo", "v"));
+ ASSERT_OK(txn->Prepare());
+ std::shared_ptr<const Snapshot> snapshot;
+ constexpr TxnTimestamp timestamp = 121;
+ auto notifier = std::make_shared<TsCheckingTxnNotifier>();
+ ASSERT_OK(txn->CommitAndTryCreateSnapshot(notifier, timestamp, &snapshot));
+ assert(snapshot);
+ ASSERT_LT(notifier->prev_snapshot_ts(), kMaxTxnTimestamp);
+ ASSERT_EQ(timestamp, snapshot->GetTimestamp());
+ ASSERT_TRUE(db->Close().IsAborted());
+}
+
+TEST_P(TransactionTest, MultipleTimestampedSnapshots) {
+ auto* dbimpl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ assert(dbimpl);
+ const bool seq_per_batch = dbimpl->seq_per_batch();
+ // TODO: remove the following assert(!seq_per_batch) once timestamped snapshot
+ // is supported in write-prepared/write-unprepared transactions.
+ assert(!seq_per_batch);
+ constexpr size_t txn_size = 10;
+ constexpr TxnTimestamp ts_delta = 10;
+ constexpr size_t num_txns = 100;
+ std::vector<std::shared_ptr<const Snapshot> > snapshots(num_txns);
+ constexpr TxnTimestamp start_ts = 10000;
+ auto notifier = std::make_shared<TsCheckingTxnNotifier>();
+ for (size_t i = 0; i < num_txns; ++i) {
+ std::unique_ptr<Transaction> txn(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn->SetName("txn" + std::to_string(i)));
+ for (size_t j = 0; j < txn_size; ++j) {
+ ASSERT_OK(txn->Put("k" + std::to_string(j),
+ "v" + std::to_string(j) + "_" + std::to_string(i)));
+ }
+ if (0 == (i % 2)) {
+ ASSERT_OK(txn->Prepare());
+ }
+ ASSERT_OK(txn->CommitAndTryCreateSnapshot(notifier, start_ts + i * ts_delta,
+ &snapshots[i]));
+ assert(snapshots[i]);
+ ASSERT_LT(notifier->prev_snapshot_ts(), kMaxTxnTimestamp);
+ ASSERT_EQ(start_ts + i * ts_delta, snapshots[i]->GetTimestamp());
+ }
+
+ {
+ auto snapshot = db->GetTimestampedSnapshot(start_ts + 1);
+ ASSERT_EQ(nullptr, snapshot);
+ }
+
+ constexpr TxnTimestamp max_ts = start_ts + num_txns * ts_delta;
+ for (size_t i = 0; i < num_txns; ++i) {
+ auto snapshot = db->GetTimestampedSnapshot(start_ts + i * ts_delta);
+ ASSERT_EQ(snapshots[i], snapshot);
+
+ std::vector<std::shared_ptr<const Snapshot> > tmp_snapshots;
+ Status s = db->GetTimestampedSnapshots(max_ts, start_ts + i * ts_delta,
+ tmp_snapshots);
+ ASSERT_TRUE(s.IsInvalidArgument());
+ ASSERT_TRUE(tmp_snapshots.empty());
+
+ for (size_t j = i; j < num_txns; ++j) {
+ std::vector<std::shared_ptr<const Snapshot> > expected_snapshots(
+ snapshots.begin() + i, snapshots.begin() + j);
+ tmp_snapshots.clear();
+ s = db->GetTimestampedSnapshots(start_ts + i * ts_delta,
+ start_ts + j * ts_delta, tmp_snapshots);
+ if (i < j) {
+ ASSERT_OK(s);
+ } else {
+ ASSERT_TRUE(s.IsInvalidArgument());
+ }
+ ASSERT_EQ(expected_snapshots, tmp_snapshots);
+ }
+ }
+
+ {
+ std::vector<std::shared_ptr<const Snapshot> > tmp_snapshots;
+ const Status s = db->GetAllTimestampedSnapshots(tmp_snapshots);
+ ASSERT_OK(s);
+ ASSERT_EQ(snapshots, tmp_snapshots);
+
+ const std::shared_ptr<const Snapshot> latest_snapshot =
+ db->GetLatestTimestampedSnapshot();
+ ASSERT_EQ(snapshots.back(), latest_snapshot);
+ }
+
+ for (size_t i = 0; i <= num_txns; ++i) {
+ std::vector<std::shared_ptr<const Snapshot> > snapshots1(
+ snapshots.begin() + i, snapshots.end());
+ if (i > 0) {
+ auto snapshot1 =
+ db->GetTimestampedSnapshot(start_ts + (i - 1) * ts_delta);
+ assert(snapshot1);
+ ASSERT_EQ(start_ts + (i - 1) * ts_delta, snapshot1->GetTimestamp());
+ }
+
+ db->ReleaseTimestampedSnapshotsOlderThan(start_ts + i * ts_delta);
+
+ if (i > 0) {
+ auto snapshot1 =
+ db->GetTimestampedSnapshot(start_ts + (i - 1) * ts_delta);
+ ASSERT_EQ(nullptr, snapshot1);
+ }
+
+ std::vector<std::shared_ptr<const Snapshot> > tmp_snapshots;
+ const Status s = db->GetAllTimestampedSnapshots(tmp_snapshots);
+ ASSERT_OK(s);
+ ASSERT_EQ(snapshots1, tmp_snapshots);
+ }
+
+ // Even after released by db, the applications still hold reference to shared
+ // snapshots.
+ for (size_t i = 0; i < num_txns; ++i) {
+ assert(snapshots[i]);
+ ASSERT_EQ(start_ts + i * ts_delta, snapshots[i]->GetTimestamp());
+ }
+
+ snapshots.clear();
+ ASSERT_OK(db->Close());
+ delete db;
+ db = nullptr;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_base.cc b/src/rocksdb/utilities/transactions/transaction_base.cc
new file mode 100644
index 000000000..83fd94ac8
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_base.cc
@@ -0,0 +1,731 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_base.h"
+
+#include <cinttypes>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/transactions/lock/lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status Transaction::CommitAndTryCreateSnapshot(
+ std::shared_ptr<TransactionNotifier> notifier, TxnTimestamp ts,
+ std::shared_ptr<const Snapshot>* snapshot) {
+ if (snapshot) {
+ snapshot->reset();
+ }
+ TxnTimestamp commit_ts = GetCommitTimestamp();
+ if (commit_ts == kMaxTxnTimestamp) {
+ if (ts == kMaxTxnTimestamp) {
+ return Status::InvalidArgument("Commit timestamp unset");
+ } else {
+ const Status s = SetCommitTimestamp(ts);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ } else if (ts != kMaxTxnTimestamp) {
+ if (ts != commit_ts) {
+ // For now we treat this as error.
+ return Status::InvalidArgument("Different commit ts specified");
+ }
+ }
+ SetSnapshotOnNextOperation(notifier);
+ Status s = Commit();
+ if (!s.ok()) {
+ return s;
+ }
+ assert(s.ok());
+ // If we reach here, we must return ok status for this function.
+ std::shared_ptr<const Snapshot> new_snapshot = GetTimestampedSnapshot();
+
+ if (snapshot) {
+ *snapshot = new_snapshot;
+ }
+ return Status::OK();
+}
+
+TransactionBaseImpl::TransactionBaseImpl(
+ DB* db, const WriteOptions& write_options,
+ const LockTrackerFactory& lock_tracker_factory)
+ : db_(db),
+ dbimpl_(static_cast_with_check<DBImpl>(db)),
+ write_options_(write_options),
+ cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())),
+ lock_tracker_factory_(lock_tracker_factory),
+ start_time_(dbimpl_->GetSystemClock()->NowMicros()),
+ write_batch_(cmp_, 0, true, 0, write_options.protection_bytes_per_key),
+ tracked_locks_(lock_tracker_factory_.Create()),
+ commit_time_batch_(0 /* reserved_bytes */, 0 /* max_bytes */,
+ write_options.protection_bytes_per_key,
+ 0 /* default_cf_ts_sz */),
+ indexing_enabled_(true) {
+ assert(dynamic_cast<DBImpl*>(db_) != nullptr);
+ log_number_ = 0;
+ if (dbimpl_->allow_2pc()) {
+ InitWriteBatch();
+ }
+}
+
+TransactionBaseImpl::~TransactionBaseImpl() {
+ // Release snapshot if snapshot is set
+ SetSnapshotInternal(nullptr);
+}
+
+void TransactionBaseImpl::Clear() {
+ save_points_.reset(nullptr);
+ write_batch_.Clear();
+ commit_time_batch_.Clear();
+ tracked_locks_->Clear();
+ num_puts_ = 0;
+ num_deletes_ = 0;
+ num_merges_ = 0;
+
+ if (dbimpl_->allow_2pc()) {
+ InitWriteBatch();
+ }
+}
+
+void TransactionBaseImpl::Reinitialize(DB* db,
+ const WriteOptions& write_options) {
+ Clear();
+ ClearSnapshot();
+ id_ = 0;
+ db_ = db;
+ name_.clear();
+ log_number_ = 0;
+ write_options_ = write_options;
+ start_time_ = dbimpl_->GetSystemClock()->NowMicros();
+ indexing_enabled_ = true;
+ cmp_ = GetColumnFamilyUserComparator(db_->DefaultColumnFamily());
+ WriteBatchInternal::UpdateProtectionInfo(
+ write_batch_.GetWriteBatch(), write_options_.protection_bytes_per_key)
+ .PermitUncheckedError();
+ WriteBatchInternal::UpdateProtectionInfo(
+ &commit_time_batch_, write_options_.protection_bytes_per_key)
+ .PermitUncheckedError();
+}
+
+void TransactionBaseImpl::SetSnapshot() {
+ const Snapshot* snapshot = dbimpl_->GetSnapshotForWriteConflictBoundary();
+ SetSnapshotInternal(snapshot);
+}
+
+void TransactionBaseImpl::SetSnapshotInternal(const Snapshot* snapshot) {
+ // Set a custom deleter for the snapshot_ SharedPtr as the snapshot needs to
+ // be released, not deleted when it is no longer referenced.
+ snapshot_.reset(snapshot, std::bind(&TransactionBaseImpl::ReleaseSnapshot,
+ this, std::placeholders::_1, db_));
+ snapshot_needed_ = false;
+ snapshot_notifier_ = nullptr;
+}
+
+void TransactionBaseImpl::SetSnapshotOnNextOperation(
+ std::shared_ptr<TransactionNotifier> notifier) {
+ snapshot_needed_ = true;
+ snapshot_notifier_ = notifier;
+}
+
+void TransactionBaseImpl::SetSnapshotIfNeeded() {
+ if (snapshot_needed_) {
+ std::shared_ptr<TransactionNotifier> notifier = snapshot_notifier_;
+ SetSnapshot();
+ if (notifier != nullptr) {
+ notifier->SnapshotCreated(GetSnapshot());
+ }
+ }
+}
+
+Status TransactionBaseImpl::TryLock(ColumnFamilyHandle* column_family,
+ const SliceParts& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ size_t key_size = 0;
+ for (int i = 0; i < key.num_parts; ++i) {
+ key_size += key.parts[i].size();
+ }
+
+ std::string str;
+ str.reserve(key_size);
+
+ for (int i = 0; i < key.num_parts; ++i) {
+ str.append(key.parts[i].data(), key.parts[i].size());
+ }
+
+ return TryLock(column_family, str, read_only, exclusive, do_validate,
+ assume_tracked);
+}
+
+void TransactionBaseImpl::SetSavePoint() {
+ if (save_points_ == nullptr) {
+ save_points_.reset(
+ new std::stack<TransactionBaseImpl::SavePoint,
+ autovector<TransactionBaseImpl::SavePoint>>());
+ }
+ save_points_->emplace(snapshot_, snapshot_needed_, snapshot_notifier_,
+ num_puts_, num_deletes_, num_merges_,
+ lock_tracker_factory_);
+ write_batch_.SetSavePoint();
+}
+
+Status TransactionBaseImpl::RollbackToSavePoint() {
+ if (save_points_ != nullptr && save_points_->size() > 0) {
+ // Restore saved SavePoint
+ TransactionBaseImpl::SavePoint& save_point = save_points_->top();
+ snapshot_ = save_point.snapshot_;
+ snapshot_needed_ = save_point.snapshot_needed_;
+ snapshot_notifier_ = save_point.snapshot_notifier_;
+ num_puts_ = save_point.num_puts_;
+ num_deletes_ = save_point.num_deletes_;
+ num_merges_ = save_point.num_merges_;
+
+ // Rollback batch
+ Status s = write_batch_.RollbackToSavePoint();
+ assert(s.ok());
+
+ // Rollback any keys that were tracked since the last savepoint
+ tracked_locks_->Subtract(*save_point.new_locks_);
+
+ save_points_->pop();
+
+ return s;
+ } else {
+ assert(write_batch_.RollbackToSavePoint().IsNotFound());
+ return Status::NotFound();
+ }
+}
+
+Status TransactionBaseImpl::PopSavePoint() {
+ if (save_points_ == nullptr || save_points_->empty()) {
+ // No SavePoint yet.
+ assert(write_batch_.PopSavePoint().IsNotFound());
+ return Status::NotFound();
+ }
+
+ assert(!save_points_->empty());
+ // If there is another savepoint A below the current savepoint B, then A needs
+ // to inherit tracked_keys in B so that if we rollback to savepoint A, we
+ // remember to unlock keys in B. If there is no other savepoint below, then we
+ // can safely discard savepoint info.
+ if (save_points_->size() == 1) {
+ save_points_->pop();
+ } else {
+ TransactionBaseImpl::SavePoint top(lock_tracker_factory_);
+ std::swap(top, save_points_->top());
+ save_points_->pop();
+
+ save_points_->top().new_locks_->Merge(*top.new_locks_);
+ }
+
+ return write_batch_.PopSavePoint();
+}
+
+Status TransactionBaseImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s = Get(read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status TransactionBaseImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val) {
+ return write_batch_.GetFromBatchAndDB(db_, read_options, column_family, key,
+ pinnable_val);
+}
+
+Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value,
+ bool exclusive,
+ const bool do_validate) {
+ if (!do_validate && read_options.snapshot != nullptr) {
+ return Status::InvalidArgument(
+ "If do_validate is false then GetForUpdate with snapshot is not "
+ "defined.");
+ }
+ Status s =
+ TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
+
+ if (s.ok() && value != nullptr) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ s = Get(read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ }
+ return s;
+}
+
+Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ PinnableSlice* pinnable_val,
+ bool exclusive,
+ const bool do_validate) {
+ if (!do_validate && read_options.snapshot != nullptr) {
+ return Status::InvalidArgument(
+ "If do_validate is false then GetForUpdate with snapshot is not "
+ "defined.");
+ }
+ Status s =
+ TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
+
+ if (s.ok() && pinnable_val != nullptr) {
+ s = Get(read_options, column_family, key, pinnable_val);
+ }
+ return s;
+}
+
+std::vector<Status> TransactionBaseImpl::MultiGet(
+ const ReadOptions& read_options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ stat_list[i] = Get(read_options, column_family[i], keys[i], &(*values)[i]);
+ }
+
+ return stat_list;
+}
+
+void TransactionBaseImpl::MultiGet(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ write_batch_.MultiGetFromBatchAndDB(db_, read_options, column_family,
+ num_keys, keys, values, statuses,
+ sorted_input);
+}
+
+std::vector<Status> TransactionBaseImpl::MultiGetForUpdate(
+ const ReadOptions& read_options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ // Regardless of whether the MultiGet succeeded, track these keys.
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ // Lock all keys
+ for (size_t i = 0; i < num_keys; ++i) {
+ Status s = TryLock(column_family[i], keys[i], true /* read_only */,
+ true /* exclusive */);
+ if (!s.ok()) {
+ // Fail entire multiget if we cannot lock all keys
+ return std::vector<Status>(num_keys, s);
+ }
+ }
+
+ // TODO(agiardullo): optimize multiget?
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ stat_list[i] = Get(read_options, column_family[i], keys[i], &(*values)[i]);
+ }
+
+ return stat_list;
+}
+
+Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options) {
+ Iterator* db_iter = db_->NewIterator(read_options);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(db_->DefaultColumnFamily(), db_iter,
+ &read_options);
+}
+
+Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family) {
+ Iterator* db_iter = db_->NewIterator(read_options, column_family);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(column_family, db_iter,
+ &read_options);
+}
+
+Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
+ const SliceParts& key, const SliceParts& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Merge(column_family, key, value);
+ if (s.ok()) {
+ num_merges_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const SliceParts& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::MergeUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const Slice& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Merge(column_family, key, value);
+ if (s.ok()) {
+ num_merges_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDeleteUntracked(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+void TransactionBaseImpl::PutLogData(const Slice& blob) {
+ auto s = write_batch_.PutLogData(blob);
+ (void)s;
+ assert(s.ok());
+}
+
+WriteBatchWithIndex* TransactionBaseImpl::GetWriteBatch() {
+ return &write_batch_;
+}
+
+uint64_t TransactionBaseImpl::GetElapsedTime() const {
+ return (dbimpl_->GetSystemClock()->NowMicros() - start_time_) / 1000;
+}
+
+uint64_t TransactionBaseImpl::GetNumPuts() const { return num_puts_; }
+
+uint64_t TransactionBaseImpl::GetNumDeletes() const { return num_deletes_; }
+
+uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; }
+
+uint64_t TransactionBaseImpl::GetNumKeys() const {
+ return tracked_locks_->GetNumPointLocks();
+}
+
+void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key,
+ SequenceNumber seq, bool read_only,
+ bool exclusive) {
+ PointLockRequest r;
+ r.column_family_id = cfh_id;
+ r.key = key;
+ r.seq = seq;
+ r.read_only = read_only;
+ r.exclusive = exclusive;
+
+ // Update map of all tracked keys for this transaction
+ tracked_locks_->Track(r);
+
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // Update map of tracked keys in this SavePoint
+ save_points_->top().new_locks_->Track(r);
+ }
+}
+
+// Gets the write batch that should be used for Put/Merge/Deletes.
+//
+// Returns either a WriteBatch or WriteBatchWithIndex depending on whether
+// DisableIndexing() has been called.
+WriteBatchBase* TransactionBaseImpl::GetBatchForWrite() {
+ if (indexing_enabled_) {
+ // Use WriteBatchWithIndex
+ return &write_batch_;
+ } else {
+ // Don't use WriteBatchWithIndex. Return base WriteBatch.
+ return write_batch_.GetWriteBatch();
+ }
+}
+
+void TransactionBaseImpl::ReleaseSnapshot(const Snapshot* snapshot, DB* db) {
+ if (snapshot != nullptr) {
+ ROCKS_LOG_DETAILS(dbimpl_->immutable_db_options().info_log,
+ "ReleaseSnapshot %" PRIu64 " Set",
+ snapshot->GetSequenceNumber());
+ db->ReleaseSnapshot(snapshot);
+ }
+}
+
+void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ PointLockRequest r;
+ r.column_family_id = GetColumnFamilyID(column_family);
+ r.key = key.ToString();
+ r.read_only = true;
+
+ bool can_untrack = false;
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // If there is no GetForUpdate of the key in this save point,
+ // then cannot untrack from the global lock tracker.
+ UntrackStatus s = save_points_->top().new_locks_->Untrack(r);
+ can_untrack = (s != UntrackStatus::NOT_TRACKED);
+ } else {
+ // No save point, so can untrack from the global lock tracker.
+ can_untrack = true;
+ }
+
+ if (can_untrack) {
+ // If erased from the global tracker, then can unlock the key.
+ UntrackStatus s = tracked_locks_->Untrack(r);
+ bool can_unlock = (s == UntrackStatus::REMOVED);
+ if (can_unlock) {
+ UnlockGetForUpdate(column_family, key);
+ }
+ }
+}
+
+Status TransactionBaseImpl::RebuildFromWriteBatch(WriteBatch* src_batch) {
+ struct IndexedWriteBatchBuilder : public WriteBatch::Handler {
+ Transaction* txn_;
+ DBImpl* db_;
+ IndexedWriteBatchBuilder(Transaction* txn, DBImpl* db)
+ : txn_(txn), db_(db) {
+ assert(dynamic_cast<TransactionBaseImpl*>(txn_) != nullptr);
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& val) override {
+ return txn_->Put(db_->GetColumnFamilyHandle(cf), key, val);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return txn_->Delete(db_->GetColumnFamilyHandle(cf), key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return txn_->SingleDelete(db_->GetColumnFamilyHandle(cf), key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice& val) override {
+ return txn_->Merge(db_->GetColumnFamilyHandle(cf), key, val);
+ }
+
+ // this is used for reconstructing prepared transactions upon
+ // recovery. there should not be any meta markers in the batches
+ // we are processing.
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommitWithTimestamp(const Slice&, const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ IndexedWriteBatchBuilder copycat(this, dbimpl_);
+ return src_batch->Iterate(&copycat);
+}
+
+WriteBatch* TransactionBaseImpl::GetCommitTimeWriteBatch() {
+ return &commit_time_batch_;
+}
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_base.h b/src/rocksdb/utilities/transactions/transaction_base.h
new file mode 100644
index 000000000..1bcb20ca9
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_base.h
@@ -0,0 +1,384 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <stack>
+#include <string>
+#include <vector>
+
+#include "db/write_batch_internal.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/lock/lock_tracker.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionBaseImpl : public Transaction {
+ public:
+ TransactionBaseImpl(DB* db, const WriteOptions& write_options,
+ const LockTrackerFactory& lock_tracker_factory);
+
+ ~TransactionBaseImpl() override;
+
+ // Remove pending operations queued in this transaction.
+ virtual void Clear();
+
+ void Reinitialize(DB* db, const WriteOptions& write_options);
+
+ // Called before executing Put, Merge, Delete, and GetForUpdate. If TryLock
+ // returns non-OK, the Put/Merge/Delete/GetForUpdate will be failed.
+ // do_validate will be false if called from PutUntracked, DeleteUntracked,
+ // MergeUntracked, or GetForUpdate(do_validate=false)
+ virtual Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive,
+ const bool do_validate = true,
+ const bool assume_tracked = false) = 0;
+
+ void SetSavePoint() override;
+
+ Status RollbackToSavePoint() override;
+
+ Status PopSavePoint() override;
+
+ using Transaction::Get;
+ Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value) override;
+
+ Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) override;
+
+ Status Get(const ReadOptions& options, const Slice& key,
+ std::string* value) override {
+ return Get(options, db_->DefaultColumnFamily(), key, value);
+ }
+
+ using Transaction::GetForUpdate;
+ Status GetForUpdate(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ std::string* value, bool exclusive,
+ const bool do_validate) override;
+
+ Status GetForUpdate(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* pinnable_val, bool exclusive,
+ const bool do_validate) override;
+
+ Status GetForUpdate(const ReadOptions& options, const Slice& key,
+ std::string* value, bool exclusive,
+ const bool do_validate) override {
+ return GetForUpdate(options, db_->DefaultColumnFamily(), key, value,
+ exclusive, do_validate);
+ }
+
+ using Transaction::MultiGet;
+ std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ std::vector<Status> MultiGet(const ReadOptions& options,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ return MultiGet(options,
+ std::vector<ColumnFamilyHandle*>(
+ keys.size(), db_->DefaultColumnFamily()),
+ keys, values);
+ }
+
+ void MultiGet(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, const bool sorted_input = false) override;
+
+ using Transaction::MultiGetForUpdate;
+ std::vector<Status> MultiGetForUpdate(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ std::vector<Status> MultiGetForUpdate(
+ const ReadOptions& options, const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ return MultiGetForUpdate(options,
+ std::vector<ColumnFamilyHandle*>(
+ keys.size(), db_->DefaultColumnFamily()),
+ keys, values);
+ }
+
+ Iterator* GetIterator(const ReadOptions& read_options) override;
+ Iterator* GetIterator(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family) override;
+
+ Status Put(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+ Status Put(const Slice& key, const Slice& value) override {
+ return Put(nullptr, key, value);
+ }
+
+ Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value,
+ const bool assume_tracked = false) override;
+ Status Put(const SliceParts& key, const SliceParts& value) override {
+ return Put(nullptr, key, value);
+ }
+
+ Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+ Status Merge(const Slice& key, const Slice& value) override {
+ return Merge(nullptr, key, value);
+ }
+
+ Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status Delete(const Slice& key) override { return Delete(nullptr, key); }
+ Status Delete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+ Status Delete(const SliceParts& key) override { return Delete(nullptr, key); }
+
+ Status SingleDelete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status SingleDelete(const Slice& key) override {
+ return SingleDelete(nullptr, key);
+ }
+ Status SingleDelete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+ Status SingleDelete(const SliceParts& key) override {
+ return SingleDelete(nullptr, key);
+ }
+
+ Status PutUntracked(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+ Status PutUntracked(const Slice& key, const Slice& value) override {
+ return PutUntracked(nullptr, key, value);
+ }
+
+ Status PutUntracked(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value) override;
+ Status PutUntracked(const SliceParts& key, const SliceParts& value) override {
+ return PutUntracked(nullptr, key, value);
+ }
+
+ Status MergeUntracked(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+ Status MergeUntracked(const Slice& key, const Slice& value) override {
+ return MergeUntracked(nullptr, key, value);
+ }
+
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ Status DeleteUntracked(const Slice& key) override {
+ return DeleteUntracked(nullptr, key);
+ }
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) override;
+ Status DeleteUntracked(const SliceParts& key) override {
+ return DeleteUntracked(nullptr, key);
+ }
+
+ Status SingleDeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ Status SingleDeleteUntracked(const Slice& key) override {
+ return SingleDeleteUntracked(nullptr, key);
+ }
+
+ void PutLogData(const Slice& blob) override;
+
+ WriteBatchWithIndex* GetWriteBatch() override;
+
+ virtual void SetLockTimeout(int64_t /*timeout*/) override { /* Do nothing */
+ }
+
+ const Snapshot* GetSnapshot() const override {
+ // will return nullptr when there is no snapshot
+ return snapshot_.get();
+ }
+
+ std::shared_ptr<const Snapshot> GetTimestampedSnapshot() const override {
+ return snapshot_;
+ }
+
+ virtual void SetSnapshot() override;
+ void SetSnapshotOnNextOperation(
+ std::shared_ptr<TransactionNotifier> notifier = nullptr) override;
+
+ void ClearSnapshot() override {
+ snapshot_.reset();
+ snapshot_needed_ = false;
+ snapshot_notifier_ = nullptr;
+ }
+
+ void DisableIndexing() override { indexing_enabled_ = false; }
+
+ void EnableIndexing() override { indexing_enabled_ = true; }
+
+ bool IndexingEnabled() const { return indexing_enabled_; }
+
+ uint64_t GetElapsedTime() const override;
+
+ uint64_t GetNumPuts() const override;
+
+ uint64_t GetNumDeletes() const override;
+
+ uint64_t GetNumMerges() const override;
+
+ uint64_t GetNumKeys() const override;
+
+ void UndoGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ void UndoGetForUpdate(const Slice& key) override {
+ return UndoGetForUpdate(nullptr, key);
+ };
+
+ WriteOptions* GetWriteOptions() override { return &write_options_; }
+
+ void SetWriteOptions(const WriteOptions& write_options) override {
+ write_options_ = write_options;
+ }
+
+ // Used for memory management for snapshot_
+ void ReleaseSnapshot(const Snapshot* snapshot, DB* db);
+
+ // iterates over the given batch and makes the appropriate inserts.
+ // used for rebuilding prepared transactions after recovery.
+ virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
+
+ WriteBatch* GetCommitTimeWriteBatch() override;
+
+ LockTracker& GetTrackedLocks() { return *tracked_locks_; }
+
+ protected:
+ // Add a key to the list of tracked keys.
+ //
+ // seqno is the earliest seqno this key was involved with this transaction.
+ // readonly should be set to true if no data was written for this key
+ void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno,
+ bool readonly, bool exclusive);
+
+ // Called when UndoGetForUpdate determines that this key can be unlocked.
+ virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) = 0;
+
+ // Sets a snapshot if SetSnapshotOnNextOperation() has been called.
+ void SetSnapshotIfNeeded();
+
+ // Initialize write_batch_ for 2PC by inserting Noop.
+ inline void InitWriteBatch(bool clear = false) {
+ if (clear) {
+ write_batch_.Clear();
+ }
+ assert(write_batch_.GetDataSize() == WriteBatchInternal::kHeader);
+ auto s = WriteBatchInternal::InsertNoop(write_batch_.GetWriteBatch());
+ assert(s.ok());
+ }
+
+ WriteBatchBase* GetBatchForWrite();
+
+ DB* db_;
+ DBImpl* dbimpl_;
+
+ WriteOptions write_options_;
+
+ const Comparator* cmp_;
+
+ const LockTrackerFactory& lock_tracker_factory_;
+
+ // Stores that time the txn was constructed, in microseconds.
+ uint64_t start_time_;
+
+ // Stores the current snapshot that was set by SetSnapshot or null if
+ // no snapshot is currently set.
+ std::shared_ptr<const Snapshot> snapshot_;
+
+ // Count of various operations pending in this transaction
+ uint64_t num_puts_ = 0;
+ uint64_t num_deletes_ = 0;
+ uint64_t num_merges_ = 0;
+
+ struct SavePoint {
+ std::shared_ptr<const Snapshot> snapshot_;
+ bool snapshot_needed_ = false;
+ std::shared_ptr<TransactionNotifier> snapshot_notifier_;
+ uint64_t num_puts_ = 0;
+ uint64_t num_deletes_ = 0;
+ uint64_t num_merges_ = 0;
+
+ // Record all locks tracked since the last savepoint
+ std::shared_ptr<LockTracker> new_locks_;
+
+ SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed,
+ std::shared_ptr<TransactionNotifier> snapshot_notifier,
+ uint64_t num_puts, uint64_t num_deletes, uint64_t num_merges,
+ const LockTrackerFactory& lock_tracker_factory)
+ : snapshot_(snapshot),
+ snapshot_needed_(snapshot_needed),
+ snapshot_notifier_(snapshot_notifier),
+ num_puts_(num_puts),
+ num_deletes_(num_deletes),
+ num_merges_(num_merges),
+ new_locks_(lock_tracker_factory.Create()) {}
+
+ explicit SavePoint(const LockTrackerFactory& lock_tracker_factory)
+ : new_locks_(lock_tracker_factory.Create()) {}
+ };
+
+ // Records writes pending in this transaction
+ WriteBatchWithIndex write_batch_;
+
+ // For Pessimistic Transactions this is the set of acquired locks.
+ // Optimistic Transactions will keep note the requested locks (not actually
+ // locked), and do conflict checking until commit time based on the tracked
+ // lock requests.
+ std::unique_ptr<LockTracker> tracked_locks_;
+
+ // Stack of the Snapshot saved at each save point. Saved snapshots may be
+ // nullptr if there was no snapshot at the time SetSavePoint() was called.
+ std::unique_ptr<std::stack<TransactionBaseImpl::SavePoint,
+ autovector<TransactionBaseImpl::SavePoint>>>
+ save_points_;
+
+ private:
+ friend class WriteCommittedTxn;
+ friend class WritePreparedTxn;
+
+ // Extra data to be persisted with the commit. Note this is only used when
+ // prepare phase is not skipped.
+ WriteBatch commit_time_batch_;
+
+ // If true, future Put/Merge/Deletes will be indexed in the
+ // WriteBatchWithIndex.
+ // If false, future Put/Merge/Deletes will be inserted directly into the
+ // underlying WriteBatch and not indexed in the WriteBatchWithIndex.
+ bool indexing_enabled_;
+
+ // SetSnapshotOnNextOperation() has been called and the snapshot has not yet
+ // been reset.
+ bool snapshot_needed_ = false;
+
+ // SetSnapshotOnNextOperation() has been called and the caller would like
+ // a notification through the TransactionNotifier interface
+ std::shared_ptr<TransactionNotifier> snapshot_notifier_ = nullptr;
+
+ Status TryLock(ColumnFamilyHandle* column_family, const SliceParts& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false);
+
+ void SetSnapshotInternal(const Snapshot* snapshot);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc
new file mode 100644
index 000000000..345c4be90
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+
+#include "rocksdb/utilities/transaction_db_mutex.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionDBMutexImpl : public TransactionDBMutex {
+ public:
+ TransactionDBMutexImpl() {}
+ ~TransactionDBMutexImpl() override {}
+
+ Status Lock() override;
+
+ Status TryLockFor(int64_t timeout_time) override;
+
+ void UnLock() override { mutex_.unlock(); }
+
+ friend class TransactionDBCondVarImpl;
+
+ private:
+ std::mutex mutex_;
+};
+
+class TransactionDBCondVarImpl : public TransactionDBCondVar {
+ public:
+ TransactionDBCondVarImpl() {}
+ ~TransactionDBCondVarImpl() override {}
+
+ Status Wait(std::shared_ptr<TransactionDBMutex> mutex) override;
+
+ Status WaitFor(std::shared_ptr<TransactionDBMutex> mutex,
+ int64_t timeout_time) override;
+
+ void Notify() override { cv_.notify_one(); }
+
+ void NotifyAll() override { cv_.notify_all(); }
+
+ private:
+ std::condition_variable cv_;
+};
+
+std::shared_ptr<TransactionDBMutex>
+TransactionDBMutexFactoryImpl::AllocateMutex() {
+ return std::shared_ptr<TransactionDBMutex>(new TransactionDBMutexImpl());
+}
+
+std::shared_ptr<TransactionDBCondVar>
+TransactionDBMutexFactoryImpl::AllocateCondVar() {
+ return std::shared_ptr<TransactionDBCondVar>(new TransactionDBCondVarImpl());
+}
+
+Status TransactionDBMutexImpl::Lock() {
+ mutex_.lock();
+ return Status::OK();
+}
+
+Status TransactionDBMutexImpl::TryLockFor(int64_t timeout_time) {
+ bool locked = true;
+
+ if (timeout_time == 0) {
+ locked = mutex_.try_lock();
+ } else {
+ // Previously, this code used a std::timed_mutex. However, this was changed
+ // due to known bugs in gcc versions < 4.9.
+ // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54562
+ //
+ // Since this mutex isn't held for long and only a single mutex is ever
+ // held at a time, it is reasonable to ignore the lock timeout_time here
+ // and only check it when waiting on the condition_variable.
+ mutex_.lock();
+ }
+
+ if (!locked) {
+ // timeout acquiring mutex
+ return Status::TimedOut(Status::SubCode::kMutexTimeout);
+ }
+
+ return Status::OK();
+}
+
+Status TransactionDBCondVarImpl::Wait(
+ std::shared_ptr<TransactionDBMutex> mutex) {
+ auto mutex_impl = reinterpret_cast<TransactionDBMutexImpl*>(mutex.get());
+
+ std::unique_lock<std::mutex> lock(mutex_impl->mutex_, std::adopt_lock);
+ cv_.wait(lock);
+
+ // Make sure unique_lock doesn't unlock mutex when it destructs
+ lock.release();
+
+ return Status::OK();
+}
+
+Status TransactionDBCondVarImpl::WaitFor(
+ std::shared_ptr<TransactionDBMutex> mutex, int64_t timeout_time) {
+ Status s;
+
+ auto mutex_impl = reinterpret_cast<TransactionDBMutexImpl*>(mutex.get());
+ std::unique_lock<std::mutex> lock(mutex_impl->mutex_, std::adopt_lock);
+
+ if (timeout_time < 0) {
+ // If timeout is negative, do not use a timeout
+ cv_.wait(lock);
+ } else {
+ auto duration = std::chrono::microseconds(timeout_time);
+ auto cv_status = cv_.wait_for(lock, duration);
+
+ // Check if the wait stopped due to timing out.
+ if (cv_status == std::cv_status::timeout) {
+ s = Status::TimedOut(Status::SubCode::kMutexTimeout);
+ }
+ }
+
+ // Make sure unique_lock doesn't unlock mutex when it destructs
+ lock.release();
+
+ // CV was signaled, or we spuriously woke up (but didn't time out)
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h
new file mode 100644
index 000000000..fbee92832
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/transaction_db_mutex.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionDBMutex;
+class TransactionDBCondVar;
+
+// Default implementation of TransactionDBMutexFactory. May be overridden
+// by TransactionDBOptions.custom_mutex_factory.
+class TransactionDBMutexFactoryImpl : public TransactionDBMutexFactory {
+ public:
+ std::shared_ptr<TransactionDBMutex> AllocateMutex() override;
+ std::shared_ptr<TransactionDBCondVar> AllocateCondVar() override;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_test.cc b/src/rocksdb/utilities/transactions/transaction_test.cc
new file mode 100644
index 000000000..caf1566b9
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_test.cc
@@ -0,0 +1,6550 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_test.h"
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/perf_context.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/fault_injection_env.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+INSTANTIATE_TEST_CASE_P(
+ DBAsBaseDB, TransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite)));
+INSTANTIATE_TEST_CASE_P(
+ DBAsBaseDB, TransactionStressTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite)));
+INSTANTIATE_TEST_CASE_P(
+ StackableDBAsBaseDB, TransactionTest,
+ ::testing::Values(
+ std::make_tuple(true, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(true, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(true, true, WRITE_UNPREPARED, kOrderedWrite)));
+
+// MySQLStyleTransactionTest takes far too long for valgrind to run. Only do it
+// in full mode (`ROCKSDB_FULL_VALGRIND_RUN` compiler flag is set).
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+INSTANTIATE_TEST_CASE_P(
+ MySQLStyleTransactionTest, MySQLStyleTransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, true),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, false),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, true)));
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(TransactionTest, DoubleEmptyWrite) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+
+ WriteBatch batch;
+
+ ASSERT_OK(db->Write(write_options, &batch));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ // Also test committing empty transactions in 2PC
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+
+ // Also test that it works during recovery
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid2"));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0a")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+ txn0 = db->GetTransactionByName("xid2");
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+}
+
+TEST_P(TransactionTest, SuccessTest) {
+ ASSERT_OK(db->ResetStats());
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_LE(0, txn->GetID());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, SwitchMemtableDuringPrepareAndCommit_WC) {
+ const TxnDBWritePolicy write_policy = std::get<2>(GetParam());
+
+ if (write_policy != TxnDBWritePolicy::WRITE_COMMITTED) {
+ ROCKSDB_GTEST_BYPASS("Test applies to write-committed only");
+ return;
+ }
+
+ ASSERT_OK(db->Put(WriteOptions(), "key0", "value"));
+
+ TransactionOptions txn_opts;
+ txn_opts.use_only_the_last_commit_time_batch_for_recovery = true;
+ Transaction* txn = db->BeginTransaction(WriteOptions(), txn_opts);
+ assert(txn);
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "FlushJob::WriteLevel0Table", [&](void* arg) {
+ // db mutex not held.
+ auto* mems = reinterpret_cast<autovector<MemTable*>*>(arg);
+ assert(mems);
+ ASSERT_EQ(1, mems->size());
+ auto* ctwb = txn->GetCommitTimeWriteBatch();
+ ASSERT_OK(ctwb->Put("gtid", "123"));
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(txn->Put("key1", "value"));
+ ASSERT_OK(txn->SetName("txn1"));
+
+ ASSERT_OK(txn->Prepare());
+
+ auto dbimpl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(dbimpl->TEST_SwitchMemtable(nullptr));
+ ASSERT_OK(dbimpl->TEST_FlushMemTable(
+ /*wait=*/false, /*allow_write_stall=*/true, /*cfh=*/nullptr));
+
+ ASSERT_OK(dbimpl->TEST_WaitForFlushMemTable());
+
+ {
+ std::string value;
+ ASSERT_OK(db->Get(ReadOptions(), "key1", &value));
+ ASSERT_EQ("value", value);
+ }
+
+ delete db;
+ db = nullptr;
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ ASSERT_OK(s);
+ assert(db);
+
+ {
+ std::string value;
+ ASSERT_OK(db->Get(ReadOptions(), "gtid", &value));
+ ASSERT_EQ("123", value);
+
+ ASSERT_OK(db->Get(ReadOptions(), "key1", &value));
+ ASSERT_EQ("value", value);
+ }
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// The test clarifies the contract of do_validate and assume_tracked
+// in GetForUpdate and Put/Merge/Delete
+TEST_P(TransactionTest, AssumeExclusiveTracked) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1;
+ const bool EXCLUSIVE = true;
+ const bool DO_VALIDATE = true;
+ const bool ASSUME_LOCKED = true;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+ txn->SetSnapshot();
+
+ // commit a value after the snapshot is taken
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+
+ // By default write should fail to the commit after our snapshot
+ s = txn->GetForUpdate(read_options, "foo", &value, EXCLUSIVE);
+ ASSERT_TRUE(s.IsBusy());
+ // But the user could direct the db to skip validating the snapshot. The read
+ // value then should be the most recently committed
+ ASSERT_OK(
+ txn->GetForUpdate(read_options, "foo", &value, EXCLUSIVE, !DO_VALIDATE));
+ ASSERT_EQ(value, "bar");
+
+ // Although ValidateSnapshot is skipped the key must have still got locked
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // By default the write operations should fail due to the commit after the
+ // snapshot
+ s = txn->Put(Slice("foo"), Slice("bar1"));
+ ASSERT_TRUE(s.IsBusy());
+ s = txn->Put(db->DefaultColumnFamily(), Slice("foo"), Slice("bar1"),
+ !ASSUME_LOCKED);
+ ASSERT_TRUE(s.IsBusy());
+ // But the user could direct the db that it already assumes exclusive lock on
+ // the key due to the previous GetForUpdate call.
+ ASSERT_OK(txn->Put(db->DefaultColumnFamily(), Slice("foo"), Slice("bar1"),
+ ASSUME_LOCKED));
+ ASSERT_OK(txn->Merge(db->DefaultColumnFamily(), Slice("foo"), Slice("bar2"),
+ ASSUME_LOCKED));
+ ASSERT_OK(
+ txn->Delete(db->DefaultColumnFamily(), Slice("foo"), ASSUME_LOCKED));
+ ASSERT_OK(txn->SingleDelete(db->DefaultColumnFamily(), Slice("foo"),
+ ASSUME_LOCKED));
+
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+}
+
+// This test clarifies the contract of ValidateSnapshot
+TEST_P(TransactionTest, ValidateSnapshotTest) {
+ for (bool with_flush : {true}) {
+ for (bool with_2pc : {true}) {
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ assert(db != nullptr);
+ Transaction* txn1 =
+ db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn1);
+ ASSERT_OK(txn1->Put(Slice("foo"), Slice("bar1")));
+ if (with_2pc) {
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Prepare());
+ }
+
+ if (with_flush) {
+ auto db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+ // Make sure the flushed memtable is not kept in memory
+ int max_memtable_in_history =
+ std::max(
+ options.max_write_buffer_number,
+ static_cast<int>(options.max_write_buffer_size_to_maintain) /
+ static_cast<int>(options.write_buffer_size)) +
+ 1;
+ for (int i = 0; i < max_memtable_in_history; i++) {
+ ASSERT_OK(db->Put(write_options, Slice("key"), Slice("value")));
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+ }
+ }
+
+ Transaction* txn2 =
+ db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn2);
+ txn2->SetSnapshot();
+
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+
+ auto pes_txn2 = dynamic_cast<PessimisticTransaction*>(txn2);
+ // Test the simple case where the key is not tracked yet
+ auto trakced_seq = kMaxSequenceNumber;
+ auto s = pes_txn2->ValidateSnapshot(db->DefaultColumnFamily(), "foo",
+ &trakced_seq);
+ ASSERT_TRUE(s.IsBusy());
+ delete txn2;
+ }
+ }
+}
+
+TEST_P(TransactionTest, WaitingTxn) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ txn_options.lock_timeout = 1;
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ /* create second cf */
+ ColumnFamilyHandle* cfa;
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->Put(write_options, cfa, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ TransactionID id1 = txn1->GetID();
+ ASSERT_TRUE(txn1);
+ ASSERT_TRUE(txn2);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "PointLockManager::AcquireWithTimeout:WaitingTxn", [&](void* /*arg*/) {
+ std::string key;
+ uint32_t cf_id;
+ std::vector<TransactionID> wait = txn2->GetWaitingTxns(&cf_id, &key);
+ ASSERT_EQ(key, "foo");
+ ASSERT_EQ(wait.size(), 1);
+ ASSERT_EQ(wait[0], id1);
+ ASSERT_EQ(cf_id, 0U);
+ });
+
+ get_perf_context()->Reset();
+ // lock key in default cf
+ s = txn1->GetForUpdate(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 0);
+
+ // lock key in cfa
+ s = txn1->GetForUpdate(read_options, cfa, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 0);
+
+ auto lock_data = db->GetLockStatusData();
+ // Locked keys exist in both column family.
+ ASSERT_EQ(lock_data.size(), 2);
+
+ auto cf_iterator = lock_data.begin();
+
+ // The iterator points to an unordered_multimap
+ // thus the test can not assume any particular order.
+
+ // Column family is 1 or 0 (cfa).
+ if (cf_iterator->first != 1 && cf_iterator->first != 0) {
+ FAIL();
+ }
+ // The locked key is "foo" and is locked by txn1
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+ ASSERT_EQ(cf_iterator->second.ids.size(), 1);
+ ASSERT_EQ(cf_iterator->second.ids[0], txn1->GetID());
+
+ cf_iterator++;
+
+ // Column family is 0 (default) or 1.
+ if (cf_iterator->first != 1 && cf_iterator->first != 0) {
+ FAIL();
+ }
+ // The locked key is "foo" and is locked by txn1
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+ ASSERT_EQ(cf_iterator->second.ids.size(), 1);
+ ASSERT_EQ(cf_iterator->second.ids[0], txn1->GetID());
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ s = txn2->GetForUpdate(read_options, "foo", &value);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 1);
+ ASSERT_GE(get_perf_context()->key_lock_wait_time, 0);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ delete cfa;
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SharedLocks) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1;
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+ ASSERT_TRUE(txn2);
+ ASSERT_TRUE(txn3);
+
+ // Test shared access between txns
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn3->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ auto lock_data = db->GetLockStatusData();
+ ASSERT_EQ(lock_data.size(), 1);
+
+ auto cf_iterator = lock_data.begin();
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+
+ // We compare whether the set of txns locking this key is the same. To do
+ // this, we need to sort both vectors so that the comparison is done
+ // correctly.
+ std::vector<TransactionID> expected_txns = {txn1->GetID(), txn2->GetID(),
+ txn3->GetID()};
+ std::vector<TransactionID> lock_txns = cf_iterator->second.ids;
+ ASSERT_EQ(expected_txns, lock_txns);
+ ASSERT_FALSE(cf_iterator->second.exclusive);
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+ ASSERT_OK(txn3->Rollback());
+
+ // Test txn1 and txn2 sharing a lock and txn3 trying to obtain it.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn2->UndoGetForUpdate("foo");
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+ ASSERT_OK(txn3->Rollback());
+
+ // Test txn1 and txn2 sharing a lock and txn2 trying to upgrade lock.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn2->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+
+ // Test txn1 trying to downgrade its lock.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, true /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ // Should still fail after "downgrading".
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+
+ // Test txn1 holding an exclusive lock and txn2 trying to obtain shared
+ // access.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+ delete txn3;
+}
+
+TEST_P(TransactionTest, DeadlockCycleShared) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+
+ // Set up a wait for chain like this:
+ //
+ // Tn -> T(n*2)
+ // Tn -> T(n*2 + 1)
+ //
+ // So we have:
+ // T1 -> T2 -> T4 ...
+ // | |> T5 ...
+ // |> T3 -> T6 ...
+ // |> T7 ...
+ // up to T31, then T[16 - 31] -> T1.
+ // Note that Tn holds lock on floor(n / 2).
+
+ std::vector<Transaction*> txns(31);
+
+ for (uint32_t i = 0; i < 31; i++) {
+ txns[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns[i]);
+ auto s = txns[i]->GetForUpdate(read_options, std::to_string((i + 1) / 2),
+ nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "PointLockManager::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // We want the leaf transactions to block and hold everyone back.
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < 15; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s = txns[i]->GetForUpdate(read_options, std::to_string(i + 1),
+ nullptr, true /* exclusive */);
+ ASSERT_OK(s);
+ ASSERT_OK(txns[i]->Rollback());
+ delete txns[i];
+ };
+ threads.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints.load() != 15) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle T[16 - 31] -> T1
+ for (uint32_t i = 15; i < 31; i++) {
+ auto s =
+ txns[i]->GetForUpdate(read_options, "0", nullptr, true /* exclusive */);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ // Calculate next buffer len, plateau at 5 when 5 records are inserted.
+ const uint32_t curr_dlock_buffer_len_ =
+ (i - 14 > kInitialMaxDeadlocks) ? kInitialMaxDeadlocks : (i - 14);
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer.size(), curr_dlock_buffer_len_);
+ auto dlock_entry = dlock_buffer[0].path;
+ ASSERT_EQ(dlock_entry.size(), kInitialMaxDeadlocks);
+ int64_t pre_deadlock_time = dlock_buffer[0].deadlock_time;
+ int64_t cur_deadlock_time = 0;
+ for (auto const& dl_path_rec : dlock_buffer) {
+ cur_deadlock_time = dl_path_rec.deadlock_time;
+ ASSERT_NE(cur_deadlock_time, 0);
+ ASSERT_TRUE(cur_deadlock_time <= pre_deadlock_time);
+ pre_deadlock_time = cur_deadlock_time;
+ }
+
+ int64_t curr_waiting_key = 0;
+
+ // Offset of each txn id from the root of the shared dlock tree's txn id.
+ int64_t offset_root = dlock_entry[0].m_txn_id - 1;
+ // Offset of the final entry in the dlock path from the root's txn id.
+ TransactionID leaf_id =
+ dlock_entry[dlock_entry.size() - 1].m_txn_id - offset_root;
+
+ for (auto it = dlock_entry.rbegin(); it != dlock_entry.rend(); ++it) {
+ auto dl_node = *it;
+ ASSERT_EQ(dl_node.m_txn_id, offset_root + leaf_id);
+ ASSERT_EQ(dl_node.m_cf_id, 0U);
+ ASSERT_EQ(dl_node.m_waiting_key, std::to_string(curr_waiting_key));
+ ASSERT_EQ(dl_node.m_exclusive, true);
+
+ if (curr_waiting_key == 0) {
+ curr_waiting_key = leaf_id;
+ }
+ curr_waiting_key /= 2;
+ leaf_id /= 2;
+ }
+ }
+
+ // Rollback the leaf transaction.
+ for (uint32_t i = 15; i < 31; i++) {
+ ASSERT_OK(txns[i]->Rollback());
+ delete txns[i];
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Downsize the buffer and verify the 3 latest deadlocks are preserved.
+ auto dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(3);
+ auto dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 3);
+
+ for (uint32_t i = 0; i < dlock_buffer_after_resize.size(); i++) {
+ for (uint32_t j = 0; j < dlock_buffer_after_resize[i].path.size(); j++) {
+ ASSERT_EQ(dlock_buffer_after_resize[i].path[j].m_txn_id,
+ dlock_buffer_before_resize[i].path[j].m_txn_id);
+ }
+ }
+
+ // Upsize the buffer and verify the 3 latest dealocks are preserved.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(5);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 3);
+
+ for (uint32_t i = 0; i < dlock_buffer_before_resize.size(); i++) {
+ for (uint32_t j = 0; j < dlock_buffer_before_resize[i].path.size(); j++) {
+ ASSERT_EQ(dlock_buffer_after_resize[i].path[j].m_txn_id,
+ dlock_buffer_before_resize[i].path[j].m_txn_id);
+ }
+ }
+
+ // Downsize to 0 and verify the size is consistent.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(0);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 0);
+
+ // Upsize from 0 to verify the size is persistent.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(3);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 0);
+
+ // Contrived case of shared lock of cycle size 2 to verify that a shared
+ // lock causing a deadlock is correctly reported as "shared" in the buffer.
+ std::vector<Transaction*> txns_shared(2);
+
+ // Create a cycle of size 2.
+ for (uint32_t i = 0; i < 2; i++) {
+ txns_shared[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns_shared[i]);
+ auto s =
+ txns_shared[i]->GetForUpdate(read_options, std::to_string(i), nullptr);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints_shared(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "PointLockManager::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints_shared.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ std::vector<port::Thread> threads_shared;
+ for (uint32_t i = 0; i < 1; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s = txns_shared[i]->GetForUpdate(read_options, std::to_string(i + 1),
+ nullptr);
+ ASSERT_OK(s);
+ ASSERT_OK(txns_shared[i]->Rollback());
+ delete txns_shared[i];
+ };
+ threads_shared.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints_shared.load() != 1) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle T2 -> T1 with a shared lock.
+ auto s = txns_shared[1]->GetForUpdate(read_options, "0", nullptr, false);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+
+ // Verify the size of the buffer and the single path.
+ ASSERT_EQ(dlock_buffer.size(), 1);
+ ASSERT_EQ(dlock_buffer[0].path.size(), 2);
+
+ // Verify the exclusivity field of the transactions in the deadlock path.
+ ASSERT_TRUE(dlock_buffer[0].path[0].m_exclusive);
+ ASSERT_FALSE(dlock_buffer[0].path[1].m_exclusive);
+ ASSERT_OK(txns_shared[1]->Rollback());
+ delete txns_shared[1];
+
+ for (auto& t : threads_shared) {
+ t.join();
+ }
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+TEST_P(TransactionStressTest, DeadlockCycle) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ // offset by 2 from the max depth to test edge case
+ const uint32_t kMaxCycleLength = 52;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+
+ for (uint32_t len = 2; len < kMaxCycleLength; len++) {
+ // Set up a long wait for chain like this:
+ //
+ // T1 -> T2 -> T3 -> ... -> Tlen
+
+ std::vector<Transaction*> txns(len);
+
+ for (uint32_t i = 0; i < len; i++) {
+ txns[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns[i]);
+ auto s = txns[i]->GetForUpdate(read_options, std::to_string(i), nullptr);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "PointLockManager::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // We want the last transaction in the chain to block and hold everyone
+ // back.
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i + 1 < len; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s =
+ txns[i]->GetForUpdate(read_options, std::to_string(i + 1), nullptr);
+ ASSERT_OK(s);
+ ASSERT_OK(txns[i]->Rollback());
+ delete txns[i];
+ };
+ threads.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints.load() != len - 1) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle Tlen -> T1
+ auto s = txns[len - 1]->GetForUpdate(read_options, "0", nullptr);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ const uint32_t dlock_buffer_size_ = (len - 1 > 5) ? 5 : (len - 1);
+ uint32_t curr_waiting_key = 0;
+ TransactionID curr_txn_id = txns[0]->GetID();
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer.size(), dlock_buffer_size_);
+ uint32_t check_len = len;
+ bool check_limit_flag = false;
+
+ // Special case for a deadlock path that exceeds the maximum depth.
+ if (len > 50) {
+ check_len = 0;
+ check_limit_flag = true;
+ }
+ auto dlock_entry = dlock_buffer[0].path;
+ ASSERT_EQ(dlock_entry.size(), check_len);
+ ASSERT_EQ(dlock_buffer[0].limit_exceeded, check_limit_flag);
+
+ int64_t pre_deadlock_time = dlock_buffer[0].deadlock_time;
+ int64_t cur_deadlock_time = 0;
+ for (auto const& dl_path_rec : dlock_buffer) {
+ cur_deadlock_time = dl_path_rec.deadlock_time;
+ ASSERT_NE(cur_deadlock_time, 0);
+ ASSERT_TRUE(cur_deadlock_time <= pre_deadlock_time);
+ pre_deadlock_time = cur_deadlock_time;
+ }
+
+ // Iterates backwards over path verifying decreasing txn_ids.
+ for (auto it = dlock_entry.rbegin(); it != dlock_entry.rend(); ++it) {
+ auto dl_node = *it;
+ ASSERT_EQ(dl_node.m_txn_id, len + curr_txn_id - 1);
+ ASSERT_EQ(dl_node.m_cf_id, 0u);
+ ASSERT_EQ(dl_node.m_waiting_key, std::to_string(curr_waiting_key));
+ ASSERT_EQ(dl_node.m_exclusive, true);
+
+ curr_txn_id--;
+ if (curr_waiting_key == 0) {
+ curr_waiting_key = len;
+ }
+ curr_waiting_key--;
+ }
+
+ // Rollback the last transaction.
+ ASSERT_OK(txns[len - 1]->Rollback());
+ delete txns[len - 1];
+
+ for (auto& t : threads) {
+ t.join();
+ }
+ }
+}
+
+TEST_P(TransactionStressTest, DeadlockStress) {
+ const uint32_t NUM_TXN_THREADS = 10;
+ const uint32_t NUM_KEYS = 100;
+ const uint32_t NUM_ITERS = 1000;
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+ std::vector<std::string> keys;
+
+ for (uint32_t i = 0; i < NUM_KEYS; i++) {
+ ASSERT_OK(db->Put(write_options, Slice(std::to_string(i)), Slice("")));
+ keys.push_back(std::to_string(i));
+ }
+
+ size_t tid = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random rnd(static_cast<uint32_t>(tid));
+ std::function<void(uint32_t)> stress_thread = [&](uint32_t seed) {
+ std::default_random_engine g(seed);
+
+ Transaction* txn;
+ for (uint32_t i = 0; i < NUM_ITERS; i++) {
+ txn = db->BeginTransaction(write_options, txn_options);
+ auto random_keys = keys;
+ std::shuffle(random_keys.begin(), random_keys.end(), g);
+
+ // Lock keys in random order.
+ for (const auto& k : random_keys) {
+ // Lock mostly for shared access, but exclusive 1/4 of the time.
+ auto s =
+ txn->GetForUpdate(read_options, k, nullptr, txn->GetID() % 4 == 0);
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsDeadlock());
+ ASSERT_OK(txn->Rollback());
+ break;
+ }
+ }
+
+ delete txn;
+ }
+ };
+
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < NUM_TXN_THREADS; i++) {
+ threads.emplace_back(stress_thread, rnd.Next());
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(TransactionTest, CommitTimeBatchFailTest) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ ASSERT_OK(txn1->GetCommitTimeWriteBatch()->Put("cat", "dog"));
+
+ s = txn1->Put("foo", "bar");
+ ASSERT_OK(s);
+
+ // fails due to non-empty commit-time batch
+ s = txn1->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, LogMarkLeakTest) {
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ options.write_buffer_size = 1024;
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+ Random rnd(47);
+ std::vector<Transaction*> txns;
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ // At the beginning there should be no log containing prepare data
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ for (size_t i = 0; i < 100; i++) {
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + std::to_string(i)));
+ ASSERT_OK(txn->Put(Slice("foo" + std::to_string(i)), Slice("bar")));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_GT(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ if (rnd.OneIn(5)) {
+ txns.push_back(txn);
+ } else {
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+ }
+ for (auto txn : txns) {
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ // At the end there should be no log left containing prepare data
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ // Make sure that the underlying data structures are properly truncated and
+ // cause not leak
+ ASSERT_EQ(db_impl->TEST_PreparedSectionCompletedSize(), 0);
+ ASSERT_EQ(db_impl->TEST_LogsWithPrepSize(), 0);
+}
+
+TEST_P(TransactionTest, SimpleTwoPhaseTransactionTest) {
+ for (bool cwb4recovery : {true, false}) {
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = cwb4recovery;
+
+ std::string value;
+ Status s;
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // regular db put
+ s = db->Put(write_options, Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // regular db read
+ ASSERT_OK(db->Get(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar2");
+
+ // commit time put
+ if (cwb4recovery) {
+ ASSERT_OK(
+ txn->GetCommitTimeWriteBatch()->Put(Slice("gtid"), Slice("dogs")));
+ ASSERT_OK(
+ txn->GetCommitTimeWriteBatch()->Put(Slice("gtid2"), Slice("cats")));
+ }
+
+ // nothing has been prepped yet
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // data not im mem yet
+ s = db->Get(read_options, Slice("foo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(read_options, Slice("gtid"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // find trans in list of prepared transactions
+ std::vector<Transaction*> prepared_trans;
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 1);
+ ASSERT_EQ(prepared_trans.front()->GetName(), "xid");
+
+ auto log_containing_prep =
+ db_impl->TEST_FindMinLogContainingOutstandingPrep();
+ ASSERT_GT(log_containing_prep, 0);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // value is now available
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ // we already committed
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // no longer is prepared results
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 0);
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+
+ // heap should not care about prepared section anymore
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // but now our memtable should be referencing the prep section
+ ASSERT_GE(log_containing_prep, db_impl->MinLogNumberToKeep());
+ ASSERT_EQ(log_containing_prep,
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+ // After flush the recoverable state must be visible
+ if (cwb4recovery) {
+ s = db->Get(read_options, "gtid", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs");
+
+ s = db->Get(read_options, "gtid2", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "cats");
+ }
+
+ // after memtable flush we can now relese the log
+ ASSERT_GT(db_impl->MinLogNumberToKeep(), log_containing_prep);
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+
+ delete txn;
+
+ if (cwb4recovery) {
+ // kill and reopen to trigger recovery
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ s = db->Get(read_options, "gtid", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs");
+
+ s = db->Get(read_options, "gtid2", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "cats");
+ }
+ }
+}
+
+TEST_P(TransactionTest, TwoPhaseNameTest) {
+ Status s;
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn3);
+ delete txn3;
+
+ // cant prepare txn without name
+ s = txn1->Prepare();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // name too short
+ s = txn1->SetName("");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // name too long
+ s = txn1->SetName(std::string(513, 'x'));
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // valid set name
+ s = txn1->SetName("name1");
+ ASSERT_OK(s);
+
+ // cant have duplicate name
+ s = txn2->SetName("name1");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // shouldn't be able to prepare
+ s = txn2->Prepare();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // valid name set
+ s = txn2->SetName("name2");
+ ASSERT_OK(s);
+
+ // cant reset name
+ s = txn2->SetName("name3");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ ASSERT_EQ(txn1->GetName(), "name1");
+ ASSERT_EQ(txn2->GetName(), "name2");
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // can't rename after prepare
+ s = txn1->SetName("name4");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TwoPhaseEmptyWriteTest) {
+ for (bool cwb4recovery : {true, false}) {
+ for (bool test_with_empty_wal : {true, false}) {
+ if (!cwb4recovery && test_with_empty_wal) {
+ continue;
+ }
+ ASSERT_OK(ReOpen());
+ Status s;
+ std::string value;
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery =
+ cwb4recovery;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn1->SetName("joe");
+ ASSERT_OK(s);
+
+ s = txn2->SetName("bob");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+
+ if (cwb4recovery) {
+ ASSERT_OK(
+ txn2->GetCommitTimeWriteBatch()->Put(Slice("foo"), Slice("bar")));
+ }
+
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn2;
+ if (cwb4recovery) {
+ if (test_with_empty_wal) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+ // After flush the state must be visible
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ }
+ ASSERT_OK(db->FlushWAL(true));
+ // kill and reopen to trigger recovery
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ }
+ }
+ }
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+TEST_P(TransactionStressTest, TwoPhaseExpirationTest) {
+ Status s;
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ txn_options.expiration = 500; // 500ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->SetName("joe");
+ ASSERT_OK(s);
+ s = txn2->SetName("bob");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Prepare();
+ ASSERT_EQ(s, Status::Expired());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TwoPhaseRollbackTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("tfoo"), Slice("tbar"));
+ ASSERT_OK(s);
+
+ // value is readable form txn
+ s = txn->Get(read_options, Slice("tfoo"), &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "tbar");
+
+ // issue rollback
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+ // value is nolonger readable
+ s = txn->Get(read_options, Slice("tfoo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(txn->GetNumPuts(), 0);
+
+ // put new txn values
+ s = txn->Put(Slice("tfoo2"), Slice("tbar2"));
+ ASSERT_OK(s);
+
+ // new value is readable from txn
+ s = txn->Get(read_options, Slice("tfoo2"), &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "tbar2");
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // flush to next wal
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ // issue rollback (marker written to WAL)
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+ // value is nolonger readable
+ s = txn->Get(read_options, Slice("tfoo2"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(txn->GetNumPuts(), 0);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // try rollback again
+ s = txn->Rollback();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, PersistentTwoPhaseTransactionTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // txn read
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ // regular db put
+ s = db->Put(write_options, Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ // regular db read
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // nothing has been prepped yet
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // still not available to db
+ s = db->Get(read_options, Slice("foo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(db->FlushWAL(false));
+ delete txn;
+ // kill and reopen
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ // find trans in list of prepared transactions
+ std::vector<Transaction*> prepared_trans;
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 1);
+
+ txn = prepared_trans.front();
+ ASSERT_TRUE(txn);
+ ASSERT_EQ(txn->GetName(), "xid");
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // log has been marked
+ auto log_containing_prep =
+ db_impl->TEST_FindMinLogContainingOutstandingPrep();
+ ASSERT_GT(log_containing_prep, 0);
+
+ // value is readable from txn
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // value is now available
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ // we already committed
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // no longer is prepared results
+ prepared_trans.clear();
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 0);
+
+ // transaction should no longer be visible
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+
+ // heap should not care about prepared section anymore
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // but now our memtable should be referencing the prep section
+ ASSERT_EQ(log_containing_prep,
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ ASSERT_GE(log_containing_prep, db_impl->MinLogNumberToKeep());
+
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // Add a dummy record to memtable before a flush. Otherwise, the
+ // memtable will be empty and flush will be skipped.
+ s = db->Put(write_options, Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ // after memtable flush we can now release the log
+ ASSERT_GT(db_impl->MinLogNumberToKeep(), log_containing_prep);
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+
+ delete txn;
+
+ // deleting transaction should unregister transaction
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+// TODO this test needs to be updated with serial commits
+TEST_P(TransactionTest, DISABLED_TwoPhaseMultiThreadTest) {
+ // mix transaction writes and regular writes
+ const uint32_t NUM_TXN_THREADS = 50;
+ std::atomic<uint32_t> txn_thread_num(0);
+
+ std::function<void()> txn_write_thread = [&]() {
+ uint32_t id = txn_thread_num.fetch_add(1);
+
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1000000;
+ if (id % 2 == 0) {
+ txn_options.expiration = 1000000;
+ }
+ TransactionName name("xid_" + std::string(1, 'A' + static_cast<char>(id)));
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName(name));
+ for (int i = 0; i < 10; i++) {
+ std::string key(name + "_" + std::string(1, static_cast<char>('A' + i)));
+ ASSERT_OK(txn->Put(key, "val"));
+ }
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ };
+
+ // assure that all thread are in the same write group
+ std::atomic<uint32_t> t_wait_on_prepare(0);
+ std::atomic<uint32_t> t_wait_on_commit(0);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
+ auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
+
+ if (writer->ShouldWriteToWAL()) {
+ t_wait_on_prepare.fetch_add(1);
+ // wait for friends
+ while (t_wait_on_prepare.load() < NUM_TXN_THREADS) {
+ env->SleepForMicroseconds(10);
+ }
+ } else if (writer->ShouldWriteToMemtable()) {
+ t_wait_on_commit.fetch_add(1);
+ // wait for friends
+ while (t_wait_on_commit.load() < NUM_TXN_THREADS) {
+ env->SleepForMicroseconds(10);
+ }
+ } else {
+ FAIL();
+ }
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // do all the writes
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < NUM_TXN_THREADS; i++) {
+ threads.emplace_back(txn_write_thread);
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+ for (uint32_t t = 0; t < NUM_TXN_THREADS; t++) {
+ TransactionName name("xid_" + std::string(1, 'A' + static_cast<char>(t)));
+ for (int i = 0; i < 10; i++) {
+ std::string key(name + "_" + std::string(1, static_cast<char>('A' + i)));
+ s = db->Get(read_options, key, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "val");
+ }
+ }
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+TEST_P(TransactionStressTest, TwoPhaseLongPrepareTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("bob");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ for (int i = 0; i < 1000; i++) {
+ std::string key(i, 'k');
+ std::string val(1000, 'v');
+ assert(db != nullptr);
+ s = db->Put(write_options, key, val);
+ ASSERT_OK(s);
+
+ if (i % 29 == 0) {
+ // crash
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+ } else if (i % 37 == 0) {
+ // close
+ ReOpenNoDelete();
+ }
+ }
+
+ // commit old txn
+ txn = db->GetTransactionByName("bob");
+ ASSERT_TRUE(txn);
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // verify data txn data
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ // verify non txn data
+ for (int i = 0; i < 1000; i++) {
+ std::string key(i, 'k');
+ std::string val(1000, 'v');
+ s = db->Get(read_options, key, &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, val);
+ }
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, TwoPhaseSequenceTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo4"), Slice("bar4"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+
+ // value is now available
+ s = db->Get(read_options, "foo4", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar4");
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(TransactionTest, TwoPhaseDoubleRecoveryTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("a");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+
+ // commit old txn
+ assert(db != nullptr); // Make clang analyze happy.
+ txn = db->GetTransactionByName("a");
+ assert(txn != nullptr);
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("b");
+ ASSERT_OK(s);
+
+ s = txn->Put(Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+
+ // value is now available
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ s = db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar2");
+}
+
+TEST_P(TransactionTest, TwoPhaseLogRollingTest) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ Status s;
+ std::string v;
+ ColumnFamilyHandle *cfa, *cfb;
+
+ // Create 2 new column families
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ WriteOptions wopts;
+ wopts.disableWAL = false;
+ wopts.sync = true;
+
+ TransactionOptions topts1;
+ Transaction* txn1 = db->BeginTransaction(wopts, topts1);
+ s = txn1->SetName("xid1");
+ ASSERT_OK(s);
+
+ TransactionOptions topts2;
+ Transaction* txn2 = db->BeginTransaction(wopts, topts2);
+ s = txn2->SetName("xid2");
+ ASSERT_OK(s);
+
+ // transaction put in two column families
+ s = txn1->Put(cfa, "ka1", "va1");
+ ASSERT_OK(s);
+
+ // transaction put in two column families
+ s = txn2->Put(cfa, "ka2", "va2");
+ ASSERT_OK(s);
+ s = txn2->Put(cfb, "kb2", "vb2");
+ ASSERT_OK(s);
+
+ // write prep section to wal
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // our log should be in the heap
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_EQ(db_impl->TEST_LogfileNumber(), txn1->GetLastLogNumber());
+
+ // flush default cf to crate new log
+ s = db->Put(wopts, "foo", "bar");
+ ASSERT_OK(s);
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // make sure we are on a new log
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn1->GetLastLogNumber());
+
+ // put txn2 prep section in this log
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+ ASSERT_EQ(db_impl->TEST_LogfileNumber(), txn2->GetLastLogNumber());
+
+ // heap should still see first log
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+
+ // commit txn1
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ // heap should now show txn2s log
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn2->GetLogNumber());
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // we should see txn1s log refernced by the memtables
+ ASSERT_EQ(txn1->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // flush default cf to crate new log
+ s = db->Put(wopts, "foo", "bar2");
+ ASSERT_OK(s);
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // make sure we are on a new log
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn2->GetLastLogNumber());
+
+ // commit txn2
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ // heap should not show any logs
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // should show the first txn log
+ ASSERT_EQ(txn1->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // flush only cfa memtable
+ s = db_impl->TEST_FlushMemTable(true, false, cfa);
+ ASSERT_OK(s);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // should show the first txn log
+ ASSERT_EQ(txn2->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // flush only cfb memtable
+ s = db_impl->TEST_FlushMemTable(true, false, cfb);
+ ASSERT_OK(s);
+
+ // should show not dependency on logs
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ delete txn1;
+ delete txn2;
+ delete cfa;
+ delete cfb;
+}
+
+TEST_P(TransactionTest, TwoPhaseLogRollingTest2) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ Status s;
+ ColumnFamilyHandle *cfa, *cfb;
+
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ WriteOptions wopts;
+ wopts.disableWAL = false;
+ wopts.sync = true;
+
+ auto cfh_a = static_cast_with_check<ColumnFamilyHandleImpl>(cfa);
+ auto cfh_b = static_cast_with_check<ColumnFamilyHandleImpl>(cfb);
+
+ TransactionOptions topts1;
+ Transaction* txn1 = db->BeginTransaction(wopts, topts1);
+ s = txn1->SetName("xid1");
+ ASSERT_OK(s);
+ s = txn1->Put(cfa, "boys", "girls1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(wopts, topts1);
+ s = txn2->SetName("xid2");
+ ASSERT_OK(s);
+ s = txn2->Put(cfb, "up", "down1");
+ ASSERT_OK(s);
+
+ // prepre transaction in LOG A
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // prepre transaction in LOG A
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+
+ // regular put so that mem table can actually be flushed for log rolling
+ s = db->Put(wopts, "cats", "dogs1");
+ ASSERT_OK(s);
+
+ auto prepare_log_no = txn1->GetLastLogNumber();
+
+ // roll to LOG B
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // now we pause background work so that
+ // imm()s are not flushed before we can check their status
+ s = db_impl->PauseBackgroundWork();
+ ASSERT_OK(s);
+
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), prepare_log_no);
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // This cf is empty and should ref the latest log
+ ASSERT_GT(cfh_a->cfd()->GetLogNumber(), prepare_log_no);
+ ASSERT_EQ(cfh_a->cfd()->GetLogNumber(), db_impl->TEST_LogfileNumber());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // This cf is not flushed yet and should ref the log that has its data
+ ASSERT_EQ(cfh_a->cfd()->GetLogNumber(), prepare_log_no);
+ break;
+ default:
+ assert(false);
+ }
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+
+ // commit in LOG B
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(),
+ prepare_log_no);
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes memtable do not ref the prep sections
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+ break;
+ default:
+ assert(false);
+ }
+
+ ASSERT_TRUE(!db_impl->TEST_UnableToReleaseOldestLog());
+
+ // request a flush for all column families such that the earliest
+ // alive log file can be killed
+ ASSERT_OK(db_impl->TEST_SwitchWAL());
+ // log cannot be flushed because txn2 has not been commited
+ ASSERT_TRUE(!db_impl->TEST_IsLogGettingFlushed());
+ ASSERT_TRUE(db_impl->TEST_UnableToReleaseOldestLog());
+
+ // assert that cfa has a flush requested
+ ASSERT_TRUE(cfh_a->cfd()->imm()->HasFlushRequested());
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // cfb should not be flushed becuse it has no data from LOG A
+ ASSERT_TRUE(!cfh_b->cfd()->imm()->HasFlushRequested());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // cfb should be flushed becuse it has prepared data from LOG A
+ ASSERT_TRUE(cfh_b->cfd()->imm()->HasFlushRequested());
+ break;
+ default:
+ assert(false);
+ }
+
+ // cfb now has data from LOG A
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ ASSERT_OK(db_impl->TEST_SwitchWAL());
+ ASSERT_TRUE(!db_impl->TEST_UnableToReleaseOldestLog());
+
+ // we should see that cfb now has a flush requested
+ ASSERT_TRUE(cfh_b->cfd()->imm()->HasFlushRequested());
+
+ // all data in LOG A resides in a memtable that has been
+ // requested for a flush
+ ASSERT_TRUE(db_impl->TEST_IsLogGettingFlushed());
+
+ delete txn1;
+ delete txn2;
+ delete cfa;
+ delete cfb;
+}
+/*
+ * 1) use prepare to keep first log around to determine starting sequence
+ * during recovery.
+ * 2) insert many values, skipping wal, to increase seqid.
+ * 3) insert final value into wal
+ * 4) recover and see that final value was properly recovered - not
+ * hidden behind improperly summed sequence ids
+ */
+TEST_P(TransactionTest, TwoPhaseOutOfOrderDelete) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ WriteOptions wal_on, wal_off;
+ wal_on.sync = true;
+ wal_on.disableWAL = false;
+ wal_off.disableWAL = true;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(wal_on, txn_options);
+
+ s = txn1->SetName("1");
+ ASSERT_OK(s);
+
+ s = db->Put(wal_on, "first", "first");
+ ASSERT_OK(s);
+
+ s = txn1->Put(Slice("dummy"), Slice("dummy"));
+ ASSERT_OK(s);
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ s = db->Put(wal_off, "cats", "dogs1");
+ ASSERT_OK(s);
+ s = db->Put(wal_off, "cats", "dogs2");
+ ASSERT_OK(s);
+ s = db->Put(wal_off, "cats", "dogs3");
+ ASSERT_OK(s);
+
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ s = db->Put(wal_on, "cats", "dogs4");
+ ASSERT_OK(s);
+
+ ASSERT_OK(db->FlushWAL(false));
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+
+ s = db->Get(read_options, "first", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "first");
+
+ s = db->Get(read_options, "cats", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs4");
+}
+
+TEST_P(TransactionTest, FirstWriteTest) {
+ WriteOptions write_options;
+
+ // Test conflict checking against the very first write to a db.
+ // The transaction's snapshot will have seq 1 and the following write
+ // will have sequence 1.
+ Status s = db->Put(write_options, "A", "a");
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "b");
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FirstWriteTest2) {
+ WriteOptions write_options;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ // Test conflict checking against the very first write to a db.
+ // The transaction's snapshot is a seq 0 while the following write
+ // will have sequence 1.
+ Status s = db->Put(write_options, "A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "b");
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteOptionsTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = true;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_TRUE(txn->GetWriteOptions()->sync);
+
+ write_options.sync = false;
+ txn->SetWriteOptions(write_options);
+ ASSERT_FALSE(txn->GetWriteOptions()->sync);
+ ASSERT_TRUE(txn->GetWriteOptions()->disableWAL);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "foo", "A"));
+ ASSERT_OK(db->Put(write_options, "foo2", "B"));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("foo", "A2");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo2", "B2");
+ ASSERT_OK(s);
+
+ // This Put outside of a transaction will conflict with the previous write
+ s = db->Put(write_options, "foo", "xxx");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "A");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "A2");
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "B2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteConflictTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "foo", "bar"));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ // This Put outside of a transaction will conflict with a later write
+ s = db->Put(write_options, "foo", "barz");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo2", "X");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo",
+ "bar2"); // Conflicts with write done after snapshot taken
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Put("foo3", "Y");
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ ASSERT_EQ(2, txn->GetNumKeys());
+
+ s = txn->Commit();
+ ASSERT_OK(s); // Txn should commit, but only write foo2 and foo3
+
+ // Verify that transaction wrote foo2 and foo3 but not foo
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "X");
+
+ db->Get(read_options, "foo3", &value);
+ ASSERT_EQ(value, "Y");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ReadConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "foo", "bar"));
+ ASSERT_OK(db->Put(write_options, "foo2", "bar"));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ // This Put outside of a transaction will conflict with the previous read
+ s = db->Put(write_options, "foo", "barz");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, TxnOnlyTest) {
+ // Test to make sure transactions work when there are no other writes in an
+ // empty db.
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("x", "y");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FlushTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Put(Slice("foo"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a memtable to flush
+ s = db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ db->Flush(flush_ops);
+
+ s = txn->Commit();
+ // txn should commit since the flushed table is still in MemtableList History
+ ASSERT_OK(s);
+
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FlushTest2) {
+ const size_t num_tests = 3;
+
+ for (size_t n = 0; n < num_tests; n++) {
+ // Test different table factories
+ switch (n) {
+ case 0:
+ break;
+ case 1:
+ options.table_factory.reset(new mock::MockTableFactory());
+ break;
+ case 2: {
+ PlainTableOptions pt_opts;
+ pt_opts.hash_table_ratio = 0;
+ options.table_factory.reset(NewPlainTableFactory(pt_opts));
+ break;
+ }
+ }
+
+ Status s = ReOpen();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar2")));
+ ASSERT_OK(db->Put(write_options, Slice("foo3"), Slice("bar3")));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Put(Slice("foo"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+ // verify foo is locked by txn
+ s = db->Delete(write_options, "foo");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Put(write_options, "Z", "z");
+ ASSERT_OK(s);
+ s = db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "S", "s");
+ ASSERT_OK(s);
+ s = db->SingleDelete(write_options, "S");
+ ASSERT_OK(s);
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in memtable
+ ASSERT_TRUE(s.IsBusy());
+
+ // force a memtable flush
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // Put a random key so we have a MemTable to flush
+ s = db->Put(write_options, "dummy", "dummy2");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ s = db->Put(write_options, "dummy", "dummy3");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ // Since our test db has max_write_buffer_number=2, this flush will cause
+ // the first memtable to get purged from the MemtableList history.
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ s = txn->Put("X", "Y");
+ // Should succeed after verifying there is no write to X in SST file
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "zz");
+ // Should fail after encountering a write to Z in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->GetForUpdate(read_options, "foo2", &value);
+ // should succeed since key was written before txn started
+ ASSERT_OK(s);
+ // verify foo2 is locked by txn
+ s = db->Delete(write_options, "foo2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ // Write a bunch of keys to db to force a compaction
+ Random rnd(47);
+ for (int i = 0; i < 1000; i++) {
+ s = db->Put(write_options, std::to_string(i),
+ test::CompressibleString(&rnd, 0.8, 100, &value));
+ ASSERT_OK(s);
+ }
+
+ s = txn->Put("X", "yy");
+ // Should succeed after verifying there is no write to X in SST file
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "zzz");
+ // Should fail after encountering a write to Z in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->GetForUpdate(read_options, "foo3", &value);
+ // should succeed since key was written before txn started
+ ASSERT_OK(s);
+ // verify foo3 is locked by txn
+ s = db->Delete(write_options, "foo3");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(db_impl->TEST_WaitForCompact());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // Transaction should only write the keys that succeeded.
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("yy", value);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z", value);
+
+ delete txn;
+ }
+}
+
+TEST_P(TransactionTest, NoSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "AAA", "bar"));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Modify key after transaction start
+ ASSERT_OK(db->Put(write_options, "AAA", "bar1"));
+
+ // Read and write without a snap
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("AAA", "bar2");
+ ASSERT_OK(s);
+
+ // Should commit since read/write was done after data changed
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, MultipleSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "AAA", "bar"));
+ ASSERT_OK(db->Put(write_options, "BBB", "bar"));
+ ASSERT_OK(db->Put(write_options, "CCC", "bar"));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_OK(db->Put(write_options, "AAA", "bar1"));
+
+ // Read and write without a snapshot
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("AAA", "bar2");
+ ASSERT_OK(s);
+
+ // Modify BBB before snapshot is taken
+ ASSERT_OK(db->Put(write_options, "BBB", "bar1"));
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("BBB", "bar2");
+ ASSERT_OK(s);
+
+ ASSERT_OK(db->Put(write_options, "CCC", "bar1"));
+
+ // Set a new snapshot
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "CCC", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("CCC", "bar2");
+ ASSERT_OK(s);
+
+ s = txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+ // verify that we track multiple writes to the same key at different snapshots
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+
+ // Potentially conflicting writes
+ ASSERT_OK(db->Put(write_options, "ZZZ", "zzz"));
+ ASSERT_OK(db->Put(write_options, "XXX", "xxx"));
+
+ txn->SetSnapshot();
+
+ TransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ txn2->SetSnapshot();
+
+ // This should not conflict in txn since the snapshot is later than the
+ // previous write (spoiler alert: it will later conflict with txn2).
+ s = txn->Put("ZZZ", "zzzz");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // This will conflict since the snapshot is earlier than another write to ZZZ
+ s = txn2->Put("ZZZ", "xxxxx");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "ZZZ", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ delete txn2;
+}
+
+TEST_P(TransactionTest, ColumnFamiliesTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ delete cfa;
+ delete cfb;
+ delete db;
+ db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // Write some data to the db
+ WriteBatch batch;
+ ASSERT_OK(batch.Put("foo", "foo"));
+ ASSERT_OK(batch.Put(handles[1], "AAA", "bar"));
+ ASSERT_OK(batch.Put(handles[1], "AAAZZZ", "bar"));
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+ ASSERT_OK(db->Delete(write_options, handles[1], "AAAZZZ"));
+
+ // These keys do not conflict with existing writes since they're in
+ // different column families
+ s = txn->Delete("AAA");
+ ASSERT_OK(s);
+ s = txn->GetForUpdate(snapshot_read_options, handles[1], "foo", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ Slice key_slice("AAAZZZ");
+ Slice value_slices[2] = {Slice("bar"), Slice("bar")};
+ s = txn->Put(handles[2], SliceParts(&key_slice, 1),
+ SliceParts(value_slices, 2));
+ ASSERT_OK(s);
+ ASSERT_EQ(3, txn->GetNumKeys());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(read_options, handles[2], "AAAZZZ", &value);
+ ASSERT_EQ(value, "barbar");
+
+ Slice key_slices[3] = {Slice("AAA"), Slice("ZZ"), Slice("Z")};
+ Slice value_slice("barbarbar");
+
+ s = txn2->Delete(handles[2], "XXX");
+ ASSERT_OK(s);
+ s = txn2->Delete(handles[1], "XXX");
+ ASSERT_OK(s);
+
+ // This write will cause a conflict with the earlier batch write
+ s = txn2->Put(handles[1], SliceParts(key_slices, 3),
+ SliceParts(&value_slice, 1));
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ // In the above the latest change to AAAZZZ in handles[1] is delete.
+ s = db->Get(read_options, handles[1], "AAAZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ delete txn2;
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ std::vector<ColumnFamilyHandle*> multiget_cfh = {handles[1], handles[2],
+ handles[0], handles[2]};
+ std::vector<Slice> multiget_keys = {"AAA", "AAAZZZ", "foo", "foo"};
+ std::vector<std::string> values(4);
+ std::vector<Status> results = txn->MultiGetForUpdate(
+ snapshot_read_options, multiget_cfh, multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+ s = txn->SingleDelete(handles[2], "ZZZ");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "ZZZ", "YYY");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "ZZZ", "YYYY");
+ ASSERT_OK(s);
+ s = txn->Delete(handles[2], "ZZZ");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "AAAZZZ", "barbarbar");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumKeys());
+
+ // Txn should commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, handles[2], "ZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Put a key which will conflict with the next txn using the previous snapshot
+ ASSERT_OK(db->Put(write_options, handles[2], "foo", "000"));
+
+ results = txn2->MultiGetForUpdate(snapshot_read_options, multiget_cfh,
+ multiget_keys, &values);
+ // All results should fail since there was a conflict
+ ASSERT_TRUE(results[0].IsBusy());
+ ASSERT_TRUE(results[1].IsBusy());
+ ASSERT_TRUE(results[2].IsBusy());
+ ASSERT_TRUE(results[3].IsBusy());
+
+ s = db->Get(read_options, handles[2], "foo", &value);
+ ASSERT_EQ(value, "000");
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->DropColumnFamily(handles[1]);
+ ASSERT_OK(s);
+ s = db->DropColumnFamily(handles[2]);
+ ASSERT_OK(s);
+
+ delete txn;
+ delete txn2;
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionTest, MultiGetBatchedTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ ColumnFamilyHandle* cf;
+ ColumnFamilyOptions cf_options;
+
+ // Create a new column families
+ s = db->CreateColumnFamily(cf_options, "CF", &cf);
+ ASSERT_OK(s);
+
+ delete cf;
+ delete db;
+ db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ column_families.push_back(ColumnFamilyDescriptor("CF", cf_options));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ // Write some data to the db
+ WriteBatch batch;
+ ASSERT_OK(batch.Put(handles[1], "aaa", "val1"));
+ ASSERT_OK(batch.Put(handles[1], "bbb", "val2"));
+ ASSERT_OK(batch.Put(handles[1], "ccc", "val3"));
+ ASSERT_OK(batch.Put(handles[1], "ddd", "foo"));
+ ASSERT_OK(batch.Put(handles[1], "eee", "val5"));
+ ASSERT_OK(batch.Put(handles[1], "fff", "val6"));
+ ASSERT_OK(batch.Merge(handles[1], "ggg", "foo"));
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ // Write some data to the db
+ s = txn->Delete(handles[1], "bbb");
+ ASSERT_OK(s);
+ s = txn->Put(handles[1], "ccc", "val3_new");
+ ASSERT_OK(s);
+ s = txn->Merge(handles[1], "ddd", "bar");
+ ASSERT_OK(s);
+
+ std::vector<Slice> keys = {"aaa", "bbb", "ccc", "ddd", "eee", "fff", "ggg"};
+ std::vector<PinnableSlice> values(keys.size());
+ std::vector<Status> statuses(keys.size());
+
+ txn->MultiGet(snapshot_read_options, handles[1], keys.size(), keys.data(),
+ values.data(), statuses.data());
+ ASSERT_TRUE(statuses[0].ok());
+ ASSERT_EQ(values[0], "val1");
+ ASSERT_TRUE(statuses[1].IsNotFound());
+ ASSERT_TRUE(statuses[2].ok());
+ ASSERT_EQ(values[2], "val3_new");
+ ASSERT_TRUE(statuses[3].ok());
+ ASSERT_EQ(values[3], "foo,bar");
+ ASSERT_TRUE(statuses[4].ok());
+ ASSERT_EQ(values[4], "val5");
+ ASSERT_TRUE(statuses[5].ok());
+ ASSERT_EQ(values[5], "val6");
+ ASSERT_TRUE(statuses[6].ok());
+ ASSERT_EQ(values[6], "foo");
+ delete txn;
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+// This test calls WriteBatchWithIndex::MultiGetFromBatchAndDB with a large
+// number of keys, i.e greater than MultiGetContext::MAX_BATCH_SIZE, which is
+// is 32. This forces autovector allocations in the MultiGet code paths
+// to use std::vector in addition to stack allocations. The MultiGet keys
+// includes Merges, which are handled specially in MultiGetFromBatchAndDB by
+// allocating an autovector of MergeContexts
+TEST_P(TransactionTest, MultiGetLargeBatchedTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ ColumnFamilyHandle* cf;
+ ColumnFamilyOptions cf_options;
+
+ std::vector<std::string> key_str;
+ for (int i = 0; i < 100; ++i) {
+ key_str.emplace_back(std::to_string(i));
+ }
+ // Create a new column families
+ s = db->CreateColumnFamily(cf_options, "CF", &cf);
+ ASSERT_OK(s);
+
+ delete cf;
+ delete db;
+ db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ column_families.push_back(ColumnFamilyDescriptor("CF", cf_options));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ // Write some data to the db
+ WriteBatch batch;
+ for (int i = 0; i < 3 * MultiGetContext::MAX_BATCH_SIZE; ++i) {
+ std::string val = "val" + std::to_string(i);
+ ASSERT_OK(batch.Put(handles[1], key_str[i], val));
+ }
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+
+ WriteBatchWithIndex wb;
+ // Write some data to the db
+ s = wb.Delete(handles[1], std::to_string(1));
+ ASSERT_OK(s);
+ s = wb.Put(handles[1], std::to_string(2), "new_val" + std::to_string(2));
+ ASSERT_OK(s);
+ // Write a lot of merges so when we call MultiGetFromBatchAndDB later on,
+ // it is forced to use std::vector in ROCKSDB_NAMESPACE::autovector to
+ // allocate MergeContexts. The number of merges needs to be >
+ // MultiGetContext::MAX_BATCH_SIZE
+ for (int i = 8; i < MultiGetContext::MAX_BATCH_SIZE + 24; ++i) {
+ s = wb.Merge(handles[1], std::to_string(i), "merge");
+ ASSERT_OK(s);
+ }
+
+ // MultiGet a lot of keys in order to force std::vector reallocations
+ std::vector<Slice> keys;
+ for (int i = 0; i < MultiGetContext::MAX_BATCH_SIZE + 32; ++i) {
+ keys.emplace_back(key_str[i]);
+ }
+ std::vector<PinnableSlice> values(keys.size());
+ std::vector<Status> statuses(keys.size());
+
+ wb.MultiGetFromBatchAndDB(db, snapshot_read_options, handles[1], keys.size(),
+ keys.data(), values.data(), statuses.data(), false);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (i == 1) {
+ ASSERT_TRUE(statuses[1].IsNotFound());
+ } else if (i == 2) {
+ ASSERT_TRUE(statuses[2].ok());
+ ASSERT_EQ(values[2], "new_val" + std::to_string(2));
+ } else if (i >= 8 && i < 56) {
+ ASSERT_TRUE(statuses[i].ok());
+ ASSERT_EQ(values[i], "val" + std::to_string(i) + ",merge");
+ } else {
+ ASSERT_TRUE(statuses[i].ok());
+ if (values[i] != "val" + std::to_string(i)) {
+ ASSERT_EQ(values[i], "val" + std::to_string(i));
+ }
+ }
+ }
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionTest, MultiGetSnapshot) {
+ WriteOptions write_options;
+ TransactionOptions transaction_options;
+ Transaction* txn1 = db->BeginTransaction(write_options, transaction_options);
+
+ Slice key = "foo";
+
+ Status s = txn1->Put(key, "bar");
+ ASSERT_OK(s);
+
+ s = txn1->SetName("test");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // Get snapshot between prepare and commit
+ // Un-committed data should be invisible to other transactions
+ const Snapshot* s1 = db->GetSnapshot();
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ Transaction* txn2 = db->BeginTransaction(write_options, transaction_options);
+ ReadOptions read_options;
+ read_options.snapshot = s1;
+
+ std::vector<Slice> keys;
+ std::vector<PinnableSlice> values(1);
+ std::vector<Status> statuses(1);
+ keys.push_back(key);
+ auto cfd = db->DefaultColumnFamily();
+ txn2->MultiGet(read_options, cfd, 1, keys.data(), values.data(),
+ statuses.data());
+ ASSERT_TRUE(statuses[0].IsNotFound());
+ delete txn2;
+
+ db->ReleaseSnapshot(s1);
+}
+
+TEST_P(TransactionTest, ColumnFamiliesTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ ColumnFamilyHandle *one, *two;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "ONE", &one);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "TWO", &two);
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn1->Put(one, "X", "1");
+ ASSERT_OK(s);
+ s = txn1->Put(two, "X", "2");
+ ASSERT_OK(s);
+ s = txn1->Put("X", "0");
+ ASSERT_OK(s);
+
+ s = txn2->Put(one, "X", "11");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ // Drop first column family
+ s = db->DropColumnFamily(one);
+ ASSERT_OK(s);
+
+ // Should fail since column family was dropped.
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ txn1 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+
+ // Should fail since column family was dropped
+ s = txn1->Put(one, "X", "111");
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ s = txn1->Put(two, "X", "222");
+ ASSERT_OK(s);
+
+ s = txn1->Put("X", "000");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, two, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("222", value);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("000", value);
+
+ s = db->DropColumnFamily(two);
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ delete one;
+ delete two;
+}
+
+TEST_P(TransactionTest, EmptyTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ // Conflicts with previous GetForUpdate
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // transaction expired!
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+}
+
+TEST_P(TransactionTest, PredicateManyPreceders) {
+ WriteOptions write_options;
+ ReadOptions read_options1, read_options2;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ std::vector<Slice> multiget_keys = {"1", "2", "3"};
+ std::vector<std::string> multiget_values;
+
+ std::vector<Status> results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_EQ(results.size(), 3);
+ ASSERT_TRUE(results[0].IsNotFound());
+ ASSERT_TRUE(results[1].IsNotFound());
+ ASSERT_TRUE(results[2].IsNotFound());
+
+ s = txn2->Put("2", "x"); // Conflict's with txn1's MultiGetForUpdate
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(txn2->Rollback());
+
+ multiget_values.clear();
+ results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_EQ(results.size(), 3);
+ ASSERT_TRUE(results[0].IsNotFound());
+ ASSERT_TRUE(results[1].IsNotFound());
+ ASSERT_TRUE(results[2].IsNotFound());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("4", "x");
+ ASSERT_OK(s);
+
+ s = txn2->Delete("4"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options2, "4", &value);
+ ASSERT_TRUE(s.IsBusy());
+
+ ASSERT_OK(txn2->Rollback());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, LostUpdate) {
+ WriteOptions write_options;
+ ReadOptions read_options, read_options1, read_options2;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ // Test 2 transactions writing to the same key in multiple orders and
+ // with/without snapshots
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ s = txn1->Put("1", "1");
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "2"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("1", value);
+
+ delete txn1;
+ delete txn2;
+
+ txn_options.set_snapshot = true;
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "3");
+ ASSERT_OK(s);
+ s = txn2->Put("1", "4"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "5");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "6");
+ ASSERT_TRUE(s.IsBusy());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "7");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn2->SetSnapshot();
+ s = txn2->Put("1", "8");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("8", value);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = db->BeginTransaction(write_options);
+ txn2 = db->BeginTransaction(write_options);
+
+ s = txn1->Put("1", "9");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "10");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "10");
+}
+
+TEST_P(TransactionTest, UntrackedWrites) {
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ // TODO(lth): For WriteUnprepared, validate that untracked writes are
+ // not supported.
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ // Verify transaction rollback works for untracked keys.
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = txn->PutUntracked("untracked", "0");
+ ASSERT_OK(s);
+ ASSERT_OK(txn->Rollback());
+ s = db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = db->Put(write_options, "untracked", "x");
+ ASSERT_OK(s);
+
+ // Untracked writes should succeed even though key was written after snapshot
+ s = txn->PutUntracked("untracked", "1");
+ ASSERT_OK(s);
+ s = txn->MergeUntracked("untracked", "2");
+ ASSERT_OK(s);
+ s = txn->DeleteUntracked("untracked");
+ ASSERT_OK(s);
+
+ // Conflict
+ s = txn->Put("untracked", "3");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ExpiredTransaction) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ // Set txn expiration timeout to 0 microseconds (expires instantly)
+ txn_options.expiration = 0;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+
+ s = txn1->Put("Y", "1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ // txn2 should be able to write to X since txn1 has expired
+ s = txn2->Put("X", "2");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("2", value);
+
+ s = txn1->Put("Z", "1");
+ ASSERT_OK(s);
+
+ // txn1 should fail to commit since it is expired
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsExpired());
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, ReinitializeTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ // Set txn expiration timeout to 0 microseconds (expires instantly)
+ txn_options.expiration = 0;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ // Reinitialize transaction to no long expire
+ txn_options.expiration = -1;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->Put("Z", "z");
+ ASSERT_OK(s);
+
+ // Should commit since not expired
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->Put("Z", "zz");
+ ASSERT_OK(s);
+
+ // Reinitilize txn1 and verify that Z gets unlocked
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options, nullptr);
+ s = txn2->Put("Z", "zzz");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzz");
+
+ // Verify snapshots get reinitialized correctly
+ txn1->SetSnapshot();
+ s = txn1->Put("Z", "zzzz");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ const Snapshot* snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ txn_options.set_snapshot = true;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ snapshot = txn1->GetSnapshot();
+ ASSERT_TRUE(snapshot);
+
+ s = txn1->Put("Z", "a");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->Rollback());
+
+ s = txn1->Put("Y", "y");
+ ASSERT_OK(s);
+
+ txn_options.set_snapshot = false;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ s = txn1->Put("X", "x");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->SetName("name");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->SetName("name");
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, Rollback) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ ASSERT_OK(s);
+
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ // txn2 should not be able to write to X since txn1 has it locked
+ s = txn2->Put("X", "2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(txn1->Rollback());
+ delete txn1;
+
+ // txn2 should now be able to write to X
+ s = txn2->Put("X", "3");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ delete txn2;
+}
+
+TEST_P(TransactionTest, LockLimitTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ delete db;
+ db = nullptr;
+
+ // Open DB with a lock limit of 3
+ txn_db_options.max_num_locks = 3;
+ ASSERT_OK(ReOpen());
+ assert(db != nullptr);
+ ASSERT_OK(s);
+
+ // Create a txn and verify we can only lock up to 3 keys
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("X", "x");
+ ASSERT_OK(s);
+
+ s = txn->Put("Y", "y");
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "z");
+ ASSERT_OK(s);
+
+ // lock limit reached
+ s = txn->Put("W", "w");
+ ASSERT_TRUE(s.IsBusy());
+
+ // re-locking same key shouldn't put us over the limit
+ s = txn->Put("X", "xx");
+ ASSERT_OK(s);
+
+ s = txn->GetForUpdate(read_options, "W", &value);
+ ASSERT_TRUE(s.IsBusy());
+ s = txn->GetForUpdate(read_options, "V", &value);
+ ASSERT_TRUE(s.IsBusy());
+
+ // re-locking same key shouldn't put us over the limit
+ s = txn->GetForUpdate(read_options, "Y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("y", value);
+
+ s = txn->Get(read_options, "W", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // "X" currently locked
+ s = txn2->Put("X", "x");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // lock limit reached
+ s = txn2->Put("M", "m");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xx", value);
+
+ s = db->Get(read_options, "W", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Committing txn should release its locks and allow txn2 to proceed
+ s = txn2->Put("X", "x2");
+ ASSERT_OK(s);
+
+ s = txn2->Delete("X");
+ ASSERT_OK(s);
+
+ s = txn2->Put("M", "m");
+ ASSERT_OK(s);
+
+ s = txn2->Put("Z", "z2");
+ ASSERT_OK(s);
+
+ // lock limit reached
+ s = txn2->Delete("Y");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z2", value);
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("y", value);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, IteratorTest) {
+ // This test does writes without snapshot validation, and then tries to create
+ // iterator later, which is unsupported in write unprepared.
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ // Write some keys to the db
+ s = db->Put(write_options, "A", "a");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "G", "g");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "F", "f");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "c");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "D", "d");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Write some keys in a txn
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Put("H", "h");
+ ASSERT_OK(s);
+
+ s = txn->Delete("D");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ txn->SetSnapshot();
+ const Snapshot* snapshot = txn->GetSnapshot();
+
+ // Write some keys to the db after the snapshot
+ s = db->Put(write_options, "BB", "xx");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "xx");
+ ASSERT_OK(s);
+
+ read_options.snapshot = snapshot;
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+
+ // Read all keys via iter and lock them all
+ std::string results[] = {"a", "b", "c", "e", "f", "g", "h"};
+ for (int i = 0; i < 7; i++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(results[i], iter->value().ToString());
+
+ s = txn->GetForUpdate(read_options, iter->key(), nullptr);
+ if (i == 2) {
+ // "C" was modified after txn's snapshot
+ ASSERT_TRUE(s.IsBusy());
+ } else {
+ ASSERT_OK(s);
+ }
+
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+
+ iter->Seek("G");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("g", iter->value().ToString());
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("f", iter->value().ToString());
+
+ iter->Seek("D");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("C");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("c", iter->value().ToString());
+
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a", iter->value().ToString());
+
+ iter->Seek("X");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("h", iter->value().ToString());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(TransactionTest, DisableIndexingTest) {
+ // Skip this test for write unprepared. It does not solely rely on WBWI for
+ // read your own writes, so depending on whether batches are flushed or not,
+ // only some writes will be visible.
+ //
+ // Also, write unprepared does not support creating iterators if there has
+ // been txn->Put() without snapshot validation.
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn->DisableIndexing();
+
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+
+ iter->Seek("B");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ s = txn->Delete("A");
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn->EnableIndexing();
+
+ s = txn->Put("B", "bb");
+ ASSERT_OK(s);
+
+ iter->Seek("B");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("bb", iter->value().ToString());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bb", value);
+
+ s = txn->Put("A", "aa");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(TransactionTest, SavepointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn->SetSavePoint(); // 1
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("B", "bb");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "c");
+ ASSERT_OK(s);
+
+ txn->SetSavePoint(); // 2
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "cc");
+ ASSERT_OK(s);
+
+ s = txn->Put("D", "d");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumPuts());
+ ASSERT_EQ(1, txn->GetNumDeletes());
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 2
+
+ ASSERT_EQ(3, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bb", value);
+
+ s = txn->Get(read_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_OK(txn->Rollback());
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "aa");
+ ASSERT_OK(s);
+
+ s = txn->Put("F", "f");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(2, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ txn->SetSavePoint(); // 3
+ txn->SetSavePoint(); // 4
+
+ s = txn->Put("G", "g");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("F");
+ ASSERT_OK(s);
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = txn->Get(read_options, "F", &value);
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefinied behavior. So it is unclear what the
+ // result of fetching "F" should be. The current implementation will
+ // return NotFound in this case.
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_EQ(3, txn->GetNumPuts());
+ ASSERT_EQ(2, txn->GetNumDeletes());
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 3
+
+ ASSERT_EQ(2, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = txn->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = db->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, SavepointTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->Put("A", "");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 1
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn1->Put("C", "c");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 2
+
+ // Verify that "A" and "C" is still locked while "B" is not
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_OK(s);
+
+ s = txn1->Put("A", "aa");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "bb");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->Put("A", "aaa");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "bbb");
+ ASSERT_OK(s);
+ s = txn1->Put("C", "ccc");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 3
+
+ // Verify that "A", "B", "C" are still locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 1
+
+ // Verify that only "A" is locked
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c3po");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ // Verify "A" "C" "B" are no longer locked
+ s = txn2->Put("A", "a4");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b4");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c4");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SavepointTest3) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->PopSavePoint(); // No SavePoint present
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("A", "");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Still no SavePoint present
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn1->SetSavePoint(); // 1
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Remove 1
+ ASSERT_TRUE(txn1->RollbackToSavePoint().IsNotFound());
+
+ // Verify that "A" is still locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ delete txn2;
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+
+ s = txn1->Put("B", "b2");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Roll back to 2
+
+ s = txn1->PopSavePoint();
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ std::string value;
+
+ // tnx1 should have modified "A" to "a"
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ // tnx1 should have set "B" to just "b"
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(TransactionTest, SavepointTest4) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ txn1->SetSavePoint(); // 1
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Remove 2
+ ASSERT_OK(s);
+
+ // Verify that A/B still exists.
+ std::string value;
+ ASSERT_OK(txn1->Get(read_options, "A", &value));
+ ASSERT_EQ("a", value);
+
+ ASSERT_OK(txn1->Get(read_options, "B", &value));
+ ASSERT_EQ("b", value);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 1
+
+ // Verify that everything was rolled back.
+ s = txn1->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Nothing should be locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "");
+ ASSERT_OK(s);
+
+ s = txn2->Put("B", "");
+ ASSERT_OK(s);
+
+ delete txn2;
+ delete txn1;
+}
+
+TEST_P(TransactionTest, UndoGetForUpdateTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ txn1->UndoGetForUpdate("A");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ txn1->UndoGetForUpdate("A");
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Verify that A is locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->UndoGetForUpdate("A");
+
+ // Verify that A is now unlocked
+ s = txn2->Put("A", "a2");
+ ASSERT_OK(s);
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a2", value);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+
+ // Verify that A and B are still locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a4");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b4");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(txn1->Rollback());
+ delete txn1;
+
+ // Verify that A and B are no longer locked
+ s = txn2->Put("A", "a5");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b5");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ delete txn2;
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+ s = txn1->Put("B", "b5");
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify A,B,C are locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("X", "x6");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify A,B are locked and C is not
+ s = txn2->Put("A", "a6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c6");
+ ASSERT_OK(s);
+ s = txn2->Put("X", "x6");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify B is locked and A and C are not
+ s = txn2->Put("A", "a7");
+ ASSERT_OK(s);
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c7");
+ ASSERT_OK(s);
+ s = txn2->Put("X", "x7");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+}
+
+TEST_P(TransactionTest, UndoGetForUpdateTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "A", "");
+ ASSERT_OK(s);
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("F", "f");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 1
+
+ txn1->UndoGetForUpdate("A");
+
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("E", "e");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "E", &value);
+ ASSERT_OK(s);
+
+ s = txn1->GetForUpdate(read_options, "F", &value);
+ ASSERT_OK(s);
+
+ // Verify A,B,C,D,E,F are still locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f1");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("E");
+
+ // Verify A,B,D,E,F are still locked and C is not.
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("H", "h");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify A,B,D,E,F,H are still locked and C,G are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("H", "h3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // rollback to 2
+
+ // Verify A,B,D,E,F are still locked and C,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify A,B,E,F are still locked and C,D,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // rollback to 1
+
+ // Verify A,B,F are still locked and C,D,E,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("E", "e3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify F is still locked and A,B,C,D,E,G,H are not.
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("A", "a3");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("E", "e3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TimeoutTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ delete db;
+ db = nullptr;
+
+ // transaction writes have an infinite timeout,
+ // but we will override this when we start a txn
+ // db writes have infinite timeout
+ txn_db_options.transaction_lock_timeout = -1;
+ txn_db_options.default_lock_timeout = -1;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ assert(db != nullptr);
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ TransactionOptions txn_options0;
+ txn_options0.expiration = 100; // 100ms
+ txn_options0.lock_timeout = 50; // txn timeout no longer infinite
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options0);
+
+ s = txn1->GetForUpdate(read_options, "aaa", nullptr);
+ ASSERT_OK(s);
+
+ // Conflicts with previous GetForUpdate.
+ // Since db writes do not have a timeout, this should eventually succeed when
+ // the transaction expires.
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_OK(s);
+
+ ASSERT_GE(txn1->GetElapsedTime(),
+ static_cast<uint64_t>(txn_options0.expiration));
+
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsExpired()); // expired!
+
+ s = db->Get(read_options, "aaa", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xxx", value);
+
+ delete txn1;
+ delete db;
+
+ // transaction writes have 10ms timeout,
+ // db writes have infinite timeout
+ txn_db_options.transaction_lock_timeout = 50;
+ txn_db_options.default_lock_timeout = -1;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ TransactionOptions txn_options;
+ txn_options.expiration = 100; // 100ms
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->GetForUpdate(read_options, "aaa", nullptr);
+ ASSERT_OK(s);
+
+ // Conflicts with previous GetForUpdate.
+ // Since db writes do not have a timeout, this should eventually succeed when
+ // the transaction expires.
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_NOK(s); // expired!
+
+ s = db->Get(read_options, "aaa", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xxx", value);
+
+ delete txn1;
+ txn_options.expiration = 6000000; // 100 minutes
+ txn_options.lock_timeout = 1; // 1ms
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ txn1->SetLockTimeout(100);
+
+ TransactionOptions txn_options2;
+ txn_options2.expiration = 10; // 10ms
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options2);
+ ASSERT_OK(s);
+
+ s = txn2->Put("a", "2");
+ ASSERT_OK(s);
+
+ // txn1 has a lock timeout longer than txn2's expiration, so it will win
+ s = txn1->Delete("a");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ // txn2 should be expired out since txn1 waiting until its timeout expired.
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsExpired());
+
+ delete txn1;
+ delete txn2;
+ txn_options.expiration = 6000000; // 100 minutes
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ txn_options2.expiration = 100000000;
+ txn2 = db->BeginTransaction(write_options, txn_options2);
+
+ s = txn1->Delete("asdf");
+ ASSERT_OK(s);
+
+ // txn2 has a smaller lock timeout than txn1's expiration, so it will time out
+ s = txn2->Delete("asdf");
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("asdf", "asdf");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "asdf", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("asdf", value);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SingleDeleteTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn = db->BeginTransaction(write_options);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "a2");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("B");
+ ASSERT_OK(s);
+
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefinied behavior. So it is unclear what the
+ // result of fetching "A" should be. The current implementation will
+ // return NotFound in this case.
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn2->Put("B", "b");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefinied behavior. So it is unclear what the
+ // result of fetching "A" should be. The current implementation will
+ // return NotFound in this case.
+ s = db->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(TransactionTest, MergeTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn);
+
+ s = db->Put(write_options, "A", "a0");
+ ASSERT_OK(s);
+
+ s = txn->Merge("A", "1");
+ ASSERT_OK(s);
+
+ s = txn->Merge("A", "2");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a0,1,2", value);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Merge("A", "3");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a,3", value);
+
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // verify that txn has "A" locked
+ s = txn2->Merge("A", "4");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a,3", value);
+}
+
+TEST_P(TransactionTest, DeleteRangeSupportTest) {
+ // The `DeleteRange()` API is banned everywhere.
+ ASSERT_TRUE(
+ db->DeleteRange(WriteOptions(), db->DefaultColumnFamily(), "a", "b")
+ .IsNotSupported());
+
+ // But range deletions can be added via the `Write()` API by specifying the
+ // proper flags to promise there are no conflicts according to the DB type
+ // (see `TransactionDB::DeleteRange()` API doc for details).
+ for (bool skip_concurrency_control : {false, true}) {
+ for (bool skip_duplicate_key_check : {false, true}) {
+ ASSERT_OK(db->Put(WriteOptions(), "a", "val"));
+ WriteBatch wb;
+ ASSERT_OK(wb.DeleteRange("a", "b"));
+ TransactionDBWriteOptimizations flags;
+ flags.skip_concurrency_control = skip_concurrency_control;
+ flags.skip_duplicate_key_check = skip_duplicate_key_check;
+ Status s = db->Write(WriteOptions(), flags, &wb);
+ std::string value;
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ if (skip_concurrency_control) {
+ ASSERT_OK(s);
+ ASSERT_TRUE(db->Get(ReadOptions(), "a", &value).IsNotFound());
+ } else {
+ ASSERT_NOK(s);
+ ASSERT_OK(db->Get(ReadOptions(), "a", &value));
+ }
+ break;
+ case WRITE_PREPARED:
+ // Intentional fall-through
+ case WRITE_UNPREPARED:
+ if (skip_concurrency_control && skip_duplicate_key_check) {
+ ASSERT_OK(s);
+ ASSERT_TRUE(db->Get(ReadOptions(), "a", &value).IsNotFound());
+ } else {
+ ASSERT_NOK(s);
+ ASSERT_OK(db->Get(ReadOptions(), "a", &value));
+ }
+ break;
+ }
+ // Without any promises from the user, range deletion via other `Write()`
+ // APIs are still banned.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "val"));
+ ASSERT_NOK(db->Write(WriteOptions(), &wb));
+ ASSERT_OK(db->Get(ReadOptions(), "a", &value));
+ }
+ }
+}
+
+TEST_P(TransactionTest, DeferSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "A", "a0");
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshotOnNextOperation();
+ auto snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ // Should not conflict with txn2 since snapshot wasn't set until
+ // GetForUpdate was called.
+ ASSERT_OK(s);
+ ASSERT_EQ("a2", value);
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "B", "b0");
+ ASSERT_OK(s);
+
+ // Cannot lock B since it was written after the snapshot was set
+ s = txn1->Put("B", "b1");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a1", value);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b0", value);
+}
+
+TEST_P(TransactionTest, DeferSnapshotTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshot();
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "c0");
+ ASSERT_OK(s);
+ s = db->Put(write_options, "D", "d0");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+
+ txn1->SetSnapshotOnNextOperation();
+
+ s = txn1->Get(snapshot_read_options, "C", &value);
+ // Snapshot was set before C was written
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ // Snapshot was set before D was written
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Snapshot should not have changed yet.
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+
+ s = txn1->Get(snapshot_read_options, "C", &value);
+ // Snapshot was set before C was written
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ // Snapshot was set before D was written
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c0", value);
+
+ s = db->Put(write_options, "D", "d00");
+ ASSERT_OK(s);
+
+ // Snapshot is now set
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("d0", value);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+}
+
+TEST_P(TransactionTest, DeferSnapshotSavePointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSavePoint(); // 1
+
+ s = db->Put(write_options, "T", "1");
+ ASSERT_OK(s);
+
+ txn1->SetSnapshotOnNextOperation();
+
+ s = db->Put(write_options, "T", "2");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = db->Put(write_options, "T", "3");
+ ASSERT_OK(s);
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+
+ s = db->Put(write_options, "T", "4");
+ ASSERT_OK(s);
+
+ txn1->SetSnapshot();
+ txn1->SetSnapshotOnNextOperation();
+
+ txn1->SetSavePoint(); // 4
+
+ s = db->Put(write_options, "T", "5");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("4", value);
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 4
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("4", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 3
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ s = txn1->Get(read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 2
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_TRUE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 1
+ ASSERT_OK(s);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, SetSnapshotOnNextOperationWithNotification) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ class Notifier : public TransactionNotifier {
+ private:
+ const Snapshot** snapshot_ptr_;
+
+ public:
+ explicit Notifier(const Snapshot** snapshot_ptr)
+ : snapshot_ptr_(snapshot_ptr) {}
+
+ void SnapshotCreated(const Snapshot* newSnapshot) override {
+ *snapshot_ptr_ = newSnapshot;
+ }
+ };
+
+ std::shared_ptr<Notifier> notifier =
+ std::make_shared<Notifier>(&read_options.snapshot);
+ Status s;
+
+ s = db->Put(write_options, "B", "0");
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshotOnNextOperation(notifier);
+ ASSERT_FALSE(read_options.snapshot);
+
+ s = db->Put(write_options, "B", "1");
+ ASSERT_OK(s);
+
+ // A Get does not generate the snapshot
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_FALSE(read_options.snapshot);
+ ASSERT_EQ(value, "1");
+
+ // Any other operation does
+ s = txn1->Put("A", "0");
+ ASSERT_OK(s);
+
+ // Now change "B".
+ s = db->Put(write_options, "B", "2");
+ ASSERT_OK(s);
+
+ // The original value should still be read
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_TRUE(read_options.snapshot);
+ ASSERT_EQ(value, "1");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, ClearSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "foo", "0");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = db->Put(write_options, "foo", "1");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+
+ // No snapshot created yet
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "1");
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_TRUE(snapshot_read_options.snapshot);
+
+ s = db->Put(write_options, "foo", "2");
+ ASSERT_OK(s);
+
+ // Snapshot was created before change to '2'
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "1");
+
+ txn->ClearSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+
+ // Snapshot has now been cleared
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "2");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ToggleAutoCompactionTest) {
+ Status s;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ delete cfa;
+ delete cfb;
+ delete db;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+
+ ColumnFamilyOptions* cf_opt_default = &column_families[0].options;
+ ColumnFamilyOptions* cf_opt_cfa = &column_families[1].options;
+ ColumnFamilyOptions* cf_opt_cfb = &column_families[2].options;
+ cf_opt_default->disable_auto_compactions = false;
+ cf_opt_cfa->disable_auto_compactions = true;
+ cf_opt_cfb->disable_auto_compactions = false;
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, column_families,
+ &handles, &db);
+ ASSERT_OK(s);
+
+ auto cfh_default = static_cast_with_check<ColumnFamilyHandleImpl>(handles[0]);
+ auto opt_default = *cfh_default->cfd()->GetLatestMutableCFOptions();
+
+ auto cfh_a = static_cast_with_check<ColumnFamilyHandleImpl>(handles[1]);
+ auto opt_a = *cfh_a->cfd()->GetLatestMutableCFOptions();
+
+ auto cfh_b = static_cast_with_check<ColumnFamilyHandleImpl>(handles[2]);
+ auto opt_b = *cfh_b->cfd()->GetLatestMutableCFOptions();
+
+ ASSERT_EQ(opt_default.disable_auto_compactions, false);
+ ASSERT_EQ(opt_a.disable_auto_compactions, true);
+ ASSERT_EQ(opt_b.disable_auto_compactions, false);
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionStressTest, ExpiredTransactionDataRace1) {
+ // In this test, txn1 should succeed committing,
+ // as the callback is called after txn1 starts committing.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"TransactionTest::ExpirableTransactionDataRace:1"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionTest::ExpirableTransactionDataRace:1", [&](void* /*arg*/) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ // Force txn1 to expire
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(1500));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Status s;
+ s = txn2->Put("X", "2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ txn_options.expiration = 1000; // 1 second
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ Status s;
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ ReadOptions read_options;
+ std::string value;
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("1", value);
+
+ delete txn1;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+namespace {
+// cmt_delay_ms is the delay between prepare and commit
+// first_id is the id of the first transaction
+Status TransactionStressTestInserter(
+ TransactionDB* db, const size_t num_transactions, const size_t num_sets,
+ const size_t num_keys_per_set, Random64* rand,
+ const uint64_t cmt_delay_ms = 0, const uint64_t first_id = 0) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = true;
+
+ // Inside the inserter we might also retake the snapshot. We do both since two
+ // separte functions are engaged for each.
+ txn_options.set_snapshot = rand->OneIn(2);
+
+ RandomTransactionInserter inserter(
+ rand, write_options, read_options, num_keys_per_set,
+ static_cast<uint16_t>(num_sets), cmt_delay_ms, first_id);
+
+ for (size_t t = 0; t < num_transactions; t++) {
+ bool success = inserter.TransactionDBInsert(db, txn_options);
+ if (!success) {
+ // unexpected failure
+ return inserter.GetLastStatus();
+ }
+ }
+ inserter.GetLastStatus().PermitUncheckedError();
+
+ // Make sure at least some of the transactions succeeded. It's ok if
+ // some failed due to write-conflicts.
+ if (num_transactions != 1 &&
+ inserter.GetFailureCount() > num_transactions / 2) {
+ return Status::TryAgain("Too many transactions failed! " +
+ std::to_string(inserter.GetFailureCount()) + " / " +
+ std::to_string(num_transactions));
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+// Worker threads add a number to a key from each set of keys. The checker
+// threads verify that the sum of all keys in each set are equal.
+TEST_P(MySQLStyleTransactionTest, TransactionStressTest) {
+ // Small write buffer to trigger more compactions
+ options.write_buffer_size = 1024;
+ txn_db_options.rollback_deletion_type_callback =
+ [](TransactionDB*, ColumnFamilyHandle*, const Slice& key) {
+ return RandomTransactionInserter::RollbackDeletionTypeCallback(key);
+ };
+ ASSERT_OK(ReOpenNoDelete());
+ constexpr size_t num_workers = 4; // worker threads count
+ constexpr size_t num_checkers = 2; // checker threads count
+ constexpr size_t num_slow_checkers = 2; // checker threads emulating backups
+ constexpr size_t num_slow_workers = 1; // slow worker threads count
+ constexpr size_t num_transactions_per_thread = 1000;
+ constexpr uint16_t num_sets = 3;
+ constexpr size_t num_keys_per_set = 100;
+ // Setting the key-space to be 100 keys should cause enough write-conflicts
+ // to make this test interesting.
+
+ std::vector<port::Thread> threads;
+ std::atomic<uint32_t> finished = {0};
+ constexpr bool TAKE_SNAPSHOT = true;
+ uint64_t time_seed = env->NowMicros();
+ printf("time_seed is %" PRIu64 "\n", time_seed); // would help to reproduce
+
+ std::function<void()> call_inserter = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ ASSERT_OK(TransactionStressTestInserter(db, num_transactions_per_thread,
+ num_sets, num_keys_per_set, &rand));
+ finished++;
+ };
+ std::function<void()> call_checker = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ ASSERT_OK(RandomTransactionInserter::Verify(
+ db, num_sets, num_keys_per_set, TAKE_SNAPSHOT, &rand));
+ }
+ };
+ std::function<void()> call_slow_checker = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ uint64_t delay_ms = rand.Uniform(100) + 1;
+ Status s = RandomTransactionInserter::Verify(
+ db, num_sets, num_keys_per_set, TAKE_SNAPSHOT, &rand, delay_ms);
+ ASSERT_OK(s);
+ }
+ };
+ std::function<void()> call_slow_inserter = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ uint64_t id = 0;
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ uint64_t delay_ms = rand.Uniform(500) + 1;
+ ASSERT_OK(TransactionStressTestInserter(db, 1, num_sets, num_keys_per_set,
+ &rand, delay_ms, id++));
+ }
+ };
+
+ for (uint32_t i = 0; i < num_workers; i++) {
+ threads.emplace_back(call_inserter);
+ }
+ for (uint32_t i = 0; i < num_checkers; i++) {
+ threads.emplace_back(call_checker);
+ }
+ if (with_slow_threads_) {
+ for (uint32_t i = 0; i < num_slow_checkers; i++) {
+ threads.emplace_back(call_slow_checker);
+ }
+ for (uint32_t i = 0; i < num_slow_workers; i++) {
+ threads.emplace_back(call_slow_inserter);
+ }
+ }
+
+ // Wait for all threads to finish
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify that data is consistent
+ Status s = RandomTransactionInserter::Verify(db, num_sets, num_keys_per_set,
+ !TAKE_SNAPSHOT);
+ ASSERT_OK(s);
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(TransactionTest, MemoryLimitTest) {
+ TransactionOptions txn_options;
+ // Header (12 bytes) + NOOP (1 byte) + 2 * 8 bytes for data.
+ txn_options.max_write_batch_size = 29;
+ // Set threshold to unlimited so that the write batch does not get flushed,
+ // and can hit the memory limit.
+ txn_options.write_batch_flush_threshold = 0;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(WriteOptions(), txn_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_LE(0, txn->GetID());
+
+ s = txn->Put(Slice("a"), Slice("...."));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ s = txn->Put(Slice("b"), Slice("...."));
+ ASSERT_OK(s);
+ ASSERT_EQ(2, txn->GetNumPuts());
+
+ s = txn->Put(Slice("b"), Slice("...."));
+ ASSERT_TRUE(s.IsMemoryLimit());
+ ASSERT_EQ(2, txn->GetNumPuts());
+
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// This test clarifies the existing expectation from the sequence number
+// algorithm. It could detect mistakes in updating the code but it is not
+// necessarily the one acceptable way. If the algorithm is legitimately changed,
+// this unit test should be updated as well.
+TEST_P(TransactionStressTest, SeqAdvanceTest) {
+ // TODO(myabandeh): must be test with false before new releases
+ const bool short_test = true;
+ WriteOptions wopts;
+ FlushOptions fopt;
+
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Do the test with NUM_BRANCHES branches in it. Each run of a test takes some
+ // of the branches. This is the same as counting a binary number where i-th
+ // bit represents whether we take branch i in the represented by the number.
+ const size_t NUM_BRANCHES = short_test ? 6 : 10;
+ // Helper function that shows if the branch is to be taken in the run
+ // represented by the number n.
+ auto branch_do = [&](size_t n, size_t* branch) {
+ assert(*branch < NUM_BRANCHES);
+ const size_t filter = static_cast<size_t>(1) << *branch;
+ return n & filter;
+ };
+ const size_t max_n = static_cast<size_t>(1) << NUM_BRANCHES;
+ for (size_t n = 0; n < max_n; n++) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ size_t branch = 0;
+ auto seq = db_impl->GetLatestSequenceNumber();
+ exp_seq = seq;
+ TestTxn0(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ // Doing it twice might detect some bugs
+ TestTxn0(1);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ TestTxn1(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ TestTxn3(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ TestTxn4(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ TestTxn2(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ ASSERT_OK(ReOpen());
+ }
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+// Verify that the optimization would not compromize the correctness
+TEST_P(TransactionTest, Optimizations) {
+ size_t comb_cnt = size_t(1) << 2; // 2 is number of optimization vars
+ for (size_t new_comb = 0; new_comb < comb_cnt; new_comb++) {
+ TransactionDBWriteOptimizations optimizations;
+ optimizations.skip_concurrency_control = IsInCombination(0, new_comb);
+ optimizations.skip_duplicate_key_check = IsInCombination(1, new_comb);
+
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ WriteBatch batch;
+ ASSERT_OK(batch.Put(Slice("k"), Slice("v1")));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, db->DefaultColumnFamily(), "k", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("v1"));
+ }
+}
+
+// A comparator that uses only the first three bytes
+class ThreeBytewiseComparator : public Comparator {
+ public:
+ ThreeBytewiseComparator() {}
+ const char* Name() const override { return "test.ThreeBytewiseComparator"; }
+ int Compare(const Slice& a, const Slice& b) const override {
+ Slice na = Slice(a.data(), a.size() < 3 ? a.size() : 3);
+ Slice nb = Slice(b.data(), b.size() < 3 ? b.size() : 3);
+ return na.compare(nb);
+ }
+ bool Equal(const Slice& a, const Slice& b) const override {
+ Slice na = Slice(a.data(), a.size() < 3 ? a.size() : 3);
+ Slice nb = Slice(b.data(), b.size() < 3 ? b.size() : 3);
+ return na == nb;
+ }
+ // These methods below don't seem relevant to this test. Implement them if
+ // proven othersize.
+ void FindShortestSeparator(std::string* start,
+ const Slice& limit) const override {
+ const Comparator* bytewise_comp = BytewiseComparator();
+ bytewise_comp->FindShortestSeparator(start, limit);
+ }
+ void FindShortSuccessor(std::string* key) const override {
+ const Comparator* bytewise_comp = BytewiseComparator();
+ bytewise_comp->FindShortSuccessor(key);
+ }
+};
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+TEST_P(TransactionTest, GetWithoutSnapshot) {
+ WriteOptions write_options;
+ std::atomic<bool> finish = {false};
+ ASSERT_OK(db->Put(write_options, "key", "value"));
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ for (int i = 0; i < 100; i++) {
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put("key", "overridedvalue"));
+ ASSERT_OK(txn->Put("key", "value"));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ finish = true;
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ while (!finish) {
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value"));
+ }
+ });
+ commit_thread.join();
+ read_thread.join();
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+// Test that the transactional db can handle duplicate keys in the write batch
+TEST_P(TransactionTest, DuplicateKeys) {
+ ColumnFamilyOptions cf_options;
+ std::string cf_name = "two";
+ ColumnFamilyHandle* cf_handle = nullptr;
+ {
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ WriteBatch batch;
+ ASSERT_OK(batch.Put(Slice("key"), Slice("value")));
+ ASSERT_OK(batch.Put(Slice("key2"), Slice("value2")));
+ // duplicate the keys
+ ASSERT_OK(batch.Put(Slice("key"), Slice("value3")));
+ // duplicate the 2nd key. It should not be counted duplicate since a
+ // sub-patch is cut after the last duplicate.
+ ASSERT_OK(batch.Put(Slice("key2"), Slice("value4")));
+ // duplicate the keys but in a different cf. It should not be counted as
+ // duplicate keys
+ ASSERT_OK(batch.Put(cf_handle, Slice("key"), Slice("value5")));
+
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value3"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "key2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value4"));
+ s = db->Get(ropt, cf_handle, "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value5"));
+
+ delete cf_handle;
+ }
+
+ // Test with non-bytewise comparator
+ {
+ ASSERT_OK(ReOpen());
+ std::unique_ptr<const Comparator> comp_gc(new ThreeBytewiseComparator());
+ cf_options.comparator = comp_gc.get();
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ WriteBatch batch;
+ ASSERT_OK(batch.Put(cf_handle, Slice("key"), Slice("value")));
+ // The first three bytes are the same, do it must be counted as duplicate
+ ASSERT_OK(batch.Put(cf_handle, Slice("key2"), Slice("value2")));
+ // check for 2nd duplicate key in cf with non-default comparator
+ ASSERT_OK(batch.Put(cf_handle, Slice("key2b"), Slice("value2b")));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ // The value must be the most recent value for all the keys equal to "key",
+ // including "key2"
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, cf_handle, "key", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value2b"));
+
+ // Test duplicate keys with rollback
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(cf_handle, Slice("key3"), Slice("value3")));
+ ASSERT_OK(txn0->Merge(cf_handle, Slice("key4"), Slice("value4")));
+ ASSERT_OK(txn0->Rollback());
+ ASSERT_OK(db->Get(ropt, cf_handle, "key5", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value2b"));
+ delete txn0;
+
+ delete cf_handle;
+ cf_options.comparator = BytewiseComparator();
+ }
+
+ for (bool do_prepare : {true, false}) {
+ for (bool do_rollback : {true, false}) {
+ for (bool with_commit_batch : {true, false}) {
+ if (with_commit_batch && !do_prepare) {
+ continue;
+ }
+ if (with_commit_batch && do_rollback) {
+ continue;
+ }
+ ASSERT_OK(ReOpen());
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = true;
+ WriteOptions write_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ auto s = txn0->SetName("xid");
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0"), Slice("bar0a"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0"), Slice("bar0b"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo1"), Slice("bar1"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo2"), Slice("bar2a"));
+ ASSERT_OK(s);
+ // Repeat a key after the start of a sub-patch. This should not cause a
+ // duplicate in the most recent sub-patch and hence not creating a new
+ // sub-patch.
+ s = txn0->Put(Slice("foo0"), Slice("bar0c"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo2"), Slice("bar2b"));
+ ASSERT_OK(s);
+ // duplicate the keys but in a different cf. It should not be counted as
+ // duplicate.
+ s = txn0->Put(cf_handle, Slice("foo0"), Slice("bar0-cf1"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo4"), Slice("bar4"));
+ ASSERT_OK(s);
+ s = txn0->Delete(Slice("foo4"));
+ ASSERT_OK(s);
+ s = txn0->SingleDelete(Slice("foo4"));
+ ASSERT_OK(s);
+ if (do_prepare) {
+ s = txn0->Prepare();
+ ASSERT_OK(s);
+ }
+ if (do_rollback) {
+ // Test rolling back the batch with duplicates
+ s = txn0->Rollback();
+ ASSERT_OK(s);
+ } else {
+ if (with_commit_batch) {
+ assert(do_prepare);
+ auto cb = txn0->GetCommitTimeWriteBatch();
+ // duplicate a key in the original batch
+ // TODO(myabandeh): the behavior of GetCommitTimeWriteBatch
+ // conflicting with the prepared batch is currently undefined and
+ // gives different results in different implementations.
+
+ // s = cb->Put(Slice("foo0"), Slice("bar0d"));
+ // ASSERT_OK(s);
+ // add a new duplicate key
+ s = cb->Put(Slice("foo6"), Slice("bar6a"));
+ ASSERT_OK(s);
+ s = cb->Put(Slice("foo6"), Slice("bar6b"));
+ ASSERT_OK(s);
+ // add a duplicate key that is removed in the same batch
+ s = cb->Put(Slice("foo7"), Slice("bar7a"));
+ ASSERT_OK(s);
+ s = cb->Delete(Slice("foo7"));
+ ASSERT_OK(s);
+ }
+ s = txn0->Commit();
+ ASSERT_OK(s);
+ }
+ delete txn0;
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+
+ if (do_rollback) {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, cf_handle, "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo3", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo4", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ } else {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0c"));
+ s = db->Get(ropt, cf_handle, "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0-cf1"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar2a,bar2b"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo3", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar3,bar3"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo4", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ if (with_commit_batch) {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo6", &pinnable_val);
+ if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_COMMITTED) {
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar6b"));
+ } else {
+ ASSERT_TRUE(s.IsNotFound());
+ }
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo7", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+ }
+ delete cf_handle;
+ } // with_commit_batch
+ } // do_rollback
+ } // do_prepare
+
+ if (!options.unordered_write) {
+ // Also test with max_successive_merges > 0. max_successive_merges will not
+ // affect our algorithm for duplicate key insertion but we add the test to
+ // verify that.
+ cf_options.max_successive_merges = 2;
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpen());
+ db->CreateColumnFamily(cf_options, cf_name, &cf_handle);
+ WriteOptions write_options;
+ // Ensure one value for the key
+ ASSERT_OK(db->Put(write_options, cf_handle, Slice("key"), Slice("value")));
+ WriteBatch batch;
+ // Merge more than max_successive_merges times
+ ASSERT_OK(batch.Merge(cf_handle, Slice("key"), Slice("1")));
+ ASSERT_OK(batch.Merge(cf_handle, Slice("key"), Slice("2")));
+ ASSERT_OK(batch.Merge(cf_handle, Slice("key"), Slice("3")));
+ ASSERT_OK(batch.Merge(cf_handle, Slice("key"), Slice("4")));
+ ASSERT_OK(db->Write(write_options, &batch));
+ ReadOptions read_options;
+ std::string value;
+ ASSERT_OK(db->Get(read_options, cf_handle, "key", &value));
+ ASSERT_EQ(value, "value,1,2,3,4");
+ delete cf_handle;
+ }
+
+ {
+ // Test that the duplicate detection is not compromised after rolling back
+ // to a save point
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0a")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0b")));
+ txn0->SetSavePoint();
+ ASSERT_OK(txn0->RollbackToSavePoint());
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ }
+
+ // Test sucessfull recovery after a crash
+ {
+ ASSERT_OK(ReOpen());
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ ReadOptions ropt;
+ Transaction* txn0;
+ PinnableSlice pinnable_val;
+ Status s;
+
+ std::unique_ptr<const Comparator> comp_gc(new ThreeBytewiseComparator());
+ cf_options.comparator = comp_gc.get();
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ delete cf_handle;
+ std::vector<ColumnFamilyDescriptor> cfds{
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName,
+ ColumnFamilyOptions(options)),
+ ColumnFamilyDescriptor(cf_name, cf_options),
+ };
+ std::vector<ColumnFamilyHandle*> handles;
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+
+ assert(db != nullptr);
+ ASSERT_OK(db->Put(write_options, "foo0", "init"));
+ ASSERT_OK(db->Put(write_options, "foo1", "init"));
+ ASSERT_OK(db->Put(write_options, handles[1], "foo0", "init"));
+ ASSERT_OK(db->Put(write_options, handles[1], "foo1", "init"));
+
+ // one entry
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0a")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0a"));
+
+ // two entries, no duplicate
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("foo0"), Slice("bar0b")));
+ ASSERT_OK(txn0->Put(handles[1], Slice("fol1"), Slice("bar1b")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0b")));
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar1b")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ ASSERT_OK(static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "fol1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1b"));
+
+ // one duplicate with ::Put
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0c")));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey1"), Slice("bar1d")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0c")));
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar1c")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0d")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ ASSERT_OK(static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0d"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1c"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1d"));
+
+ // Duplicate with ::Put, ::Delete
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0e")));
+ ASSERT_OK(txn0->Delete(handles[1], Slice("key-nonkey1")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0e")));
+ ASSERT_OK(txn0->Delete(Slice("foo0")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ ASSERT_OK(static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Duplicate with ::Put, ::SingleDelete
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0g")));
+ ASSERT_OK(txn0->SingleDelete(handles[1], Slice("key-nonkey1")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0e")));
+ ASSERT_OK(txn0->SingleDelete(Slice("foo0")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ ASSERT_OK(static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Duplicate with ::Put, ::Merge
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar1i")));
+ ASSERT_OK(txn0->Merge(handles[1], Slice("key-nonkey1"), Slice("bar1j")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0f")));
+ ASSERT_OK(txn0->Merge(Slice("foo0"), Slice("bar0g")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ ASSERT_OK(static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0f,bar0g"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1i,bar1j"));
+
+ for (auto h : handles) {
+ delete h;
+ }
+ delete db;
+ db = nullptr;
+ }
+}
+
+// Test that the reseek optimization in iterators will not result in an infinite
+// loop if there are too many uncommitted entries before the snapshot.
+TEST_P(TransactionTest, ReseekOptimization) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ColumnFamilyDescriptor cfd;
+ ASSERT_OK(db->DefaultColumnFamily()->GetDescriptor(&cfd));
+ auto max_skip = cfd.options.max_sequential_skip_in_iterations;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo0"), Slice("initv")));
+
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ // Duplicate keys will result into separate sequence numbers in WritePrepared
+ // and WriteUnPrepared
+ for (size_t i = 0; i < 2 * max_skip; i++) {
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar")));
+ }
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("initv")));
+
+ ReadOptions read_options;
+ // To avoid loops
+ read_options.max_skippable_internal_keys = 10 * max_skip;
+ Iterator* iter = db->NewIterator(read_options);
+ ASSERT_OK(iter->status());
+ size_t cnt = 0;
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ iter->Next();
+ ASSERT_OK(iter->status());
+ cnt++;
+ }
+ ASSERT_EQ(cnt, 2);
+ cnt = 0;
+ iter->SeekToLast();
+ while (iter->Valid()) {
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ cnt++;
+ }
+ ASSERT_EQ(cnt, 2);
+ delete iter;
+ ASSERT_OK(txn0->Rollback());
+ delete txn0;
+}
+
+// After recovery in kPointInTimeRecovery mode, the corrupted log file remains
+// there. The new log files should be still read succesfully during recovery of
+// the 2nd crash.
+TEST_P(TransactionTest, DoubleCrashInRecovery) {
+ for (const bool manual_wal_flush : {false, true}) {
+ for (const bool write_after_recovery : {false, true}) {
+ options.wal_recovery_mode = WALRecoveryMode::kPointInTimeRecovery;
+ options.manual_wal_flush = manual_wal_flush;
+ ASSERT_OK(ReOpen());
+ std::string cf_name = "two";
+ ColumnFamilyOptions cf_options;
+ ColumnFamilyHandle* cf_handle = nullptr;
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+
+ // Add a prepare entry to prevent the older logs from being deleted.
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("foo-prepare"), Slice("bar-prepare")));
+ ASSERT_OK(txn->Prepare());
+
+ FlushOptions flush_ops;
+ ASSERT_OK(db->Flush(flush_ops));
+ // Now we have a log that cannot be deleted
+
+ ASSERT_OK(db->Put(write_options, cf_handle, "foo1", "bar1"));
+ // Flush only the 2nd cf
+ ASSERT_OK(db->Flush(flush_ops, cf_handle));
+
+ // The value is large enough to be touched by the corruption we ingest
+ // below.
+ std::string large_value(400, ' ');
+ // key/value not touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo2", "bar2"));
+ // key/value touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo3", large_value));
+ // key/value not touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo4", "bar4"));
+
+ ASSERT_OK(db->FlushWAL(true));
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ uint64_t wal_file_id = db_impl->TEST_LogfileNumber();
+ std::string fname = LogFileName(dbname, wal_file_id);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ delete txn;
+ delete cf_handle;
+ delete db;
+ db = nullptr;
+
+ // Corrupt the last log file in the middle, so that it is not corrupted
+ // in the tail.
+ std::string file_content;
+ ASSERT_OK(ReadFileToString(env, fname, &file_content));
+ file_content[400] = 'h';
+ file_content[401] = 'a';
+ ASSERT_OK(env->DeleteFile(fname));
+ ASSERT_OK(WriteStringToFile(env, file_content, fname, true));
+
+ // Recover from corruption
+ std::vector<ColumnFamilyHandle*> handles;
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(ColumnFamilyDescriptor(kDefaultColumnFamilyName,
+ ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("two", ColumnFamilyOptions()));
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ if (write_after_recovery) {
+ // Write data to the log right after the corrupted log
+ ASSERT_OK(db->Put(write_options, "foo5", large_value));
+ }
+
+ // Persist data written to WAL during recovery or by the last Put
+ ASSERT_OK(db->FlushWAL(true));
+ // 2nd crash to recover while having a valid log after the corrupted one.
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+ txn = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn != nullptr);
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ for (auto handle : handles) {
+ delete handle;
+ }
+ }
+ }
+}
+
+TEST_P(TransactionTest, CommitWithoutPrepare) {
+ {
+ // skip_prepare = false.
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ txn_options.skip_prepare = false;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn->Commit().IsTxnNotPrepared());
+ delete txn;
+ }
+
+ {
+ // skip_prepare = true.
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ txn_options.skip_prepare = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+}
+
+TEST_P(TransactionTest, OpenAndEnableU64Timestamp) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ assert(db);
+
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyOptions cf_opts;
+ cf_opts.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ {
+ ColumnFamilyHandle* cfh = nullptr;
+ const Status s = db->CreateColumnFamily(cf_opts, test_cf_name, &cfh);
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ ASSERT_OK(s);
+ delete cfh;
+ } else {
+ ASSERT_TRUE(s.IsNotSupported());
+ assert(!cfh);
+ }
+ }
+
+ // Bypass transaction db layer.
+ if (txn_db_options.write_policy != WRITE_COMMITTED) {
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ assert(db_impl);
+ ColumnFamilyHandle* cfh = nullptr;
+ ASSERT_OK(db_impl->CreateColumnFamily(cf_opts, test_cf_name, &cfh));
+ delete cfh;
+ }
+
+ {
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_opts);
+ std::vector<ColumnFamilyHandle*> handles;
+ const Status s = ReOpenNoDelete(cf_descs, &handles);
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ ASSERT_OK(s);
+ for (auto* h : handles) {
+ delete h;
+ }
+ } else {
+ ASSERT_TRUE(s.IsNotSupported());
+ }
+ }
+}
+
+TEST_P(TransactionTest, OpenAndEnableU32Timestamp) {
+ class DummyComparatorWithU32Ts : public Comparator {
+ public:
+ DummyComparatorWithU32Ts() : Comparator(sizeof(uint32_t)) {}
+ const char* Name() const override { return "DummyComparatorWithU32Ts"; }
+ void FindShortSuccessor(std::string*) const override {}
+ void FindShortestSeparator(std::string*, const Slice&) const override {}
+ int Compare(const Slice&, const Slice&) const override { return 0; }
+ };
+
+ std::unique_ptr<Comparator> dummy_ucmp(new DummyComparatorWithU32Ts());
+
+ ASSERT_OK(ReOpenNoDelete());
+
+ assert(db);
+
+ const std::string test_cf_name = "test_cf";
+
+ ColumnFamilyOptions cf_opts;
+ cf_opts.comparator = dummy_ucmp.get();
+ {
+ ColumnFamilyHandle* cfh = nullptr;
+ ASSERT_TRUE(db->CreateColumnFamily(cf_opts, test_cf_name, &cfh)
+ .IsInvalidArgument());
+ }
+
+ // Bypass transaction db layer.
+ {
+ ColumnFamilyHandle* cfh = nullptr;
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ assert(db_impl);
+ ASSERT_OK(db_impl->CreateColumnFamily(cf_opts, test_cf_name, &cfh));
+ delete cfh;
+ }
+
+ {
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_opts);
+ std::vector<ColumnFamilyHandle*> handles;
+ ASSERT_TRUE(ReOpenNoDelete(cf_descs, &handles).IsInvalidArgument());
+ }
+}
+
+TEST_P(TransactionTest, WriteWithBulkCreatedColumnFamilies) {
+ ColumnFamilyOptions cf_options;
+ WriteOptions write_options;
+
+ std::vector<std::string> cf_names;
+ std::vector<ColumnFamilyHandle*> cf_handles;
+
+ cf_names.push_back("test_cf");
+
+ ASSERT_OK(db->CreateColumnFamilies(cf_options, cf_names, &cf_handles));
+ ASSERT_OK(db->Put(write_options, cf_handles[0], "foo", "bar"));
+ ASSERT_OK(db->DropColumnFamilies(cf_handles));
+
+ for (auto* h : cf_handles) {
+ delete h;
+ }
+ cf_handles.clear();
+
+ std::vector<ColumnFamilyDescriptor> cf_descriptors;
+
+ cf_descriptors.emplace_back("test_cf", ColumnFamilyOptions());
+
+ ASSERT_OK(db->CreateColumnFamilies(cf_options, cf_names, &cf_handles));
+ ASSERT_OK(db->Put(write_options, cf_handles[0], "foo", "bar"));
+ ASSERT_OK(db->DropColumnFamilies(cf_handles));
+ for (auto* h : cf_handles) {
+ delete h;
+ }
+ cf_handles.clear();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_test.h b/src/rocksdb/utilities/transactions/transaction_test.h
new file mode 100644
index 000000000..0b86453a4
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_test.h
@@ -0,0 +1,578 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <algorithm>
+#include <cinttypes>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "db/db_test_util.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/fault_injection_env.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Return true if the ith bit is set in combination represented by comb
+bool IsInCombination(size_t i, size_t comb) { return comb & (size_t(1) << i); }
+
+enum WriteOrdering : bool { kOrderedWrite, kUnorderedWrite };
+
+class TransactionTestBase : public ::testing::Test {
+ public:
+ TransactionDB* db;
+ SpecialEnv special_env;
+ FaultInjectionTestEnv* env;
+ std::string dbname;
+ Options options;
+
+ TransactionDBOptions txn_db_options;
+ bool use_stackable_db_;
+
+ TransactionTestBase(bool use_stackable_db, bool two_write_queue,
+ TxnDBWritePolicy write_policy,
+ WriteOrdering write_ordering)
+ : db(nullptr),
+ special_env(Env::Default()),
+ env(nullptr),
+ use_stackable_db_(use_stackable_db) {
+ options.create_if_missing = true;
+ options.max_write_buffer_number = 2;
+ options.write_buffer_size = 4 * 1024;
+ options.unordered_write = write_ordering == kUnorderedWrite;
+ options.level0_file_num_compaction_trigger = 2;
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+ special_env.skip_fsync_ = true;
+ env = new FaultInjectionTestEnv(&special_env);
+ options.env = env;
+ options.two_write_queues = two_write_queue;
+ dbname = test::PerThreadDBPath("transaction_testdb");
+
+ EXPECT_OK(DestroyDB(dbname, options));
+ txn_db_options.transaction_lock_timeout = 0;
+ txn_db_options.default_lock_timeout = 0;
+ txn_db_options.write_policy = write_policy;
+ txn_db_options.rollback_merge_operands = true;
+ // This will stress write unprepared, by forcing write batch flush on every
+ // write.
+ txn_db_options.default_write_batch_flush_threshold = 1;
+ // Write unprepared requires all transactions to be named. This setting
+ // autogenerates the name so that existing tests can pass.
+ txn_db_options.autogenerate_name = true;
+ Status s;
+ if (use_stackable_db == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ EXPECT_OK(s);
+ }
+
+ ~TransactionTestBase() {
+ delete db;
+ db = nullptr;
+ // This is to skip the assert statement in FaultInjectionTestEnv. There
+ // seems to be a bug in btrfs that the makes readdir return recently
+ // unlink-ed files. By using the default fs we simply ignore errors resulted
+ // from attempting to delete such files in DestroyDB.
+ if (getenv("KEEP_DB") == nullptr) {
+ options.env = Env::Default();
+ EXPECT_OK(DestroyDB(dbname, options));
+ } else {
+ fprintf(stdout, "db is still in %s\n", dbname.c_str());
+ }
+ delete env;
+ }
+
+ Status ReOpenNoDelete() {
+ delete db;
+ db = nullptr;
+ env->AssertNoOpenFile();
+ env->DropUnsyncedFileData();
+ env->ResetState();
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ assert(!s.ok() || db != nullptr);
+ return s;
+ }
+
+ Status ReOpenNoDelete(std::vector<ColumnFamilyDescriptor>& cfs,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ for (auto h : *handles) {
+ delete h;
+ }
+ handles->clear();
+ delete db;
+ db = nullptr;
+ env->AssertNoOpenFile();
+ env->DropUnsyncedFileData();
+ env->ResetState();
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, cfs, handles,
+ &db);
+ } else {
+ s = OpenWithStackableDB(cfs, handles);
+ }
+ assert(!s.ok() || db != nullptr);
+ return s;
+ }
+
+ Status ReOpen() {
+ delete db;
+ db = nullptr;
+ DestroyDB(dbname, options);
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ assert(db != nullptr);
+ return s;
+ }
+
+ Status OpenWithStackableDB(std::vector<ColumnFamilyDescriptor>& cfs,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ std::vector<size_t> compaction_enabled_cf_indices;
+ TransactionDB::PrepareWrap(&options, &cfs, &compaction_enabled_cf_indices);
+ DB* root_db = nullptr;
+ Options options_copy(options);
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ Status s = DBImpl::Open(options_copy, dbname, cfs, handles, &root_db,
+ use_seq_per_batch, use_batch_per_txn);
+ auto stackable_db = std::make_unique<StackableDB>(root_db);
+ if (s.ok()) {
+ assert(root_db != nullptr);
+ // If WrapStackableDB() returns non-ok, then stackable_db is already
+ // deleted within WrapStackableDB().
+ s = TransactionDB::WrapStackableDB(stackable_db.release(), txn_db_options,
+ compaction_enabled_cf_indices,
+ *handles, &db);
+ }
+ return s;
+ }
+
+ Status OpenWithStackableDB() {
+ std::vector<size_t> compaction_enabled_cf_indices;
+ std::vector<ColumnFamilyDescriptor> column_families{ColumnFamilyDescriptor(
+ kDefaultColumnFamilyName, ColumnFamilyOptions(options))};
+
+ TransactionDB::PrepareWrap(&options, &column_families,
+ &compaction_enabled_cf_indices);
+ std::vector<ColumnFamilyHandle*> handles;
+ DB* root_db = nullptr;
+ Options options_copy(options);
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ Status s = DBImpl::Open(options_copy, dbname, column_families, &handles,
+ &root_db, use_seq_per_batch, use_batch_per_txn);
+ if (!s.ok()) {
+ delete root_db;
+ return s;
+ }
+ StackableDB* stackable_db = new StackableDB(root_db);
+ assert(root_db != nullptr);
+ assert(handles.size() == 1);
+ s = TransactionDB::WrapStackableDB(stackable_db, txn_db_options,
+ compaction_enabled_cf_indices, handles,
+ &db);
+ delete handles[0];
+ if (!s.ok()) {
+ delete stackable_db;
+ }
+ return s;
+ }
+
+ std::atomic<size_t> linked = {0};
+ std::atomic<size_t> exp_seq = {0};
+ std::atomic<size_t> commit_writes = {0};
+ std::atomic<size_t> expected_commits = {0};
+ // Without Prepare, the commit does not write to WAL
+ std::atomic<size_t> with_empty_commits = {0};
+ void TestTxn0(size_t index) {
+ // Test DB's internal txn. It involves no prepare phase nor a commit marker.
+ auto s = db->Put(WriteOptions(), "key" + std::to_string(index), "value");
+ ASSERT_OK(s);
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq++;
+ } else {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ }
+ with_empty_commits++;
+ }
+
+ void TestTxn1(size_t index) {
+ // Testing directly writing a write batch. Functionality-wise it is
+ // equivalent to commit without prepare.
+ WriteBatch wb;
+ auto istr = std::to_string(index);
+ ASSERT_OK(wb.Put("k1" + istr, "v1"));
+ ASSERT_OK(wb.Put("k2" + istr, "v2"));
+ ASSERT_OK(wb.Put("k3" + istr, "v3"));
+ auto s = db->Write(WriteOptions(), &wb);
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 3;
+ } else {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ }
+ ASSERT_OK(s);
+ with_empty_commits++;
+ }
+
+ void TestTxn2(size_t index) {
+ // Commit without prepare. It should write to DB without a commit marker.
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Commit());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 4;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 4;
+ // WriteUnprepared implements CommitWithoutPrepareInternal by simply
+ // calling Prepare then Commit. Consume one seq for the prepare.
+ exp_seq++;
+ }
+ delete txn;
+ with_empty_commits++;
+ }
+
+ void TestTxn3(size_t index) {
+ // A full 2pc txn that also involves a commit marker.
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Put(Slice("foo5" + istr), Slice("bar5")));
+ expected_commits++;
+ ASSERT_OK(txn->Prepare());
+ commit_writes++;
+ ASSERT_OK(txn->Commit());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 5;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ // Consume one seq per commit marker
+ exp_seq++;
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 5;
+ // Consume one seq per commit marker
+ exp_seq++;
+ }
+ delete txn;
+ }
+
+ void TestTxn4(size_t index) {
+ // A full 2pc txn that also involves a commit marker.
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Put(Slice("foo5" + istr), Slice("bar5")));
+ expected_commits++;
+ ASSERT_OK(txn->Prepare());
+ commit_writes++;
+ ASSERT_OK(txn->Rollback());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // No seq is consumed for deleting the txn buffer
+ exp_seq += 0;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ // Consume one seq per rollback batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for rollback commit
+ exp_seq++;
+ }
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 5;
+ // Consume one seq per rollback batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for rollback commit
+ exp_seq++;
+ }
+ }
+ delete txn;
+ }
+
+ // Test that we can change write policy after a clean shutdown (which would
+ // empty the WAL)
+ void CrossCompatibilityTest(TxnDBWritePolicy from_policy,
+ TxnDBWritePolicy to_policy, bool empty_wal) {
+ TransactionOptions txn_options;
+ ReadOptions read_options;
+ WriteOptions write_options;
+ uint32_t index = 0;
+ Random rnd(1103);
+ options.write_buffer_size = 1024; // To create more sst files
+ std::unordered_map<std::string, std::string> committed_kvs;
+ Transaction* txn;
+
+ txn_db_options.write_policy = from_policy;
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ options.unordered_write = false;
+ }
+ ASSERT_OK(ReOpen());
+
+ for (int i = 0; i < 1024; i++) {
+ auto istr = std::to_string(index);
+ auto k = Slice("foo-" + istr).ToString();
+ auto v = Slice("bar-" + istr).ToString();
+ // For test the duplicate keys
+ auto v2 = Slice("bar2-" + istr).ToString();
+ auto type = rnd.Uniform(4);
+ switch (type) {
+ case 0:
+ committed_kvs[k] = v;
+ ASSERT_OK(db->Put(write_options, k, v));
+ committed_kvs[k] = v2;
+ ASSERT_OK(db->Put(write_options, k, v2));
+ break;
+ case 1: {
+ WriteBatch wb;
+ committed_kvs[k] = v;
+ ASSERT_OK(wb.Put(k, v));
+ committed_kvs[k] = v2;
+ ASSERT_OK(wb.Put(k, v2));
+ ASSERT_OK(db->Write(write_options, &wb));
+
+ } break;
+ case 2:
+ case 3:
+ txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ committed_kvs[k] = v;
+ ASSERT_OK(txn->Put(k, v));
+ committed_kvs[k] = v2;
+ ASSERT_OK(txn->Put(k, v2));
+
+ if (type == 3) {
+ ASSERT_OK(txn->Prepare());
+ }
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ break;
+ default:
+ FAIL();
+ }
+
+ index++;
+ } // for i
+
+ txn_db_options.write_policy = to_policy;
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ options.unordered_write = false;
+ }
+ auto db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ // Before upgrade/downgrade the WAL must be emptied
+ if (empty_wal) {
+ ASSERT_OK(db_impl->TEST_FlushMemTable());
+ } else {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ }
+ auto s = ReOpenNoDelete();
+ if (empty_wal) {
+ ASSERT_OK(s);
+ } else {
+ // Test that we can detect the WAL that is produced by an incompatible
+ // WritePolicy and fail fast before mis-interpreting the WAL.
+ ASSERT_TRUE(s.IsNotSupported());
+ return;
+ }
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ // Check that WAL is empty
+ VectorLogPtr log_files;
+ ASSERT_OK(db_impl->GetSortedWalFiles(log_files));
+ ASSERT_EQ(0, log_files.size());
+
+ for (auto& kv : committed_kvs) {
+ std::string value;
+ s = db->Get(read_options, kv.first, &value);
+ if (s.IsNotFound()) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_OK(s);
+ if (kv.second != value) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_EQ(kv.second, value);
+ }
+ }
+};
+
+class TransactionTest
+ : public TransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering>> {
+ public:
+ TransactionTest()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())){};
+};
+
+class TransactionStressTest : public TransactionTest {};
+
+class MySQLStyleTransactionTest
+ : public TransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering, bool>> {
+ public:
+ MySQLStyleTransactionTest()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ with_slow_threads_(std::get<4>(GetParam())) {
+ if (with_slow_threads_ &&
+ (txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED)) {
+ // The corner case with slow threads involves the caches filling
+ // over which would not happen even with artifial delays. To help
+ // such cases to show up we lower the size of the cache-related data
+ // structures.
+ txn_db_options.wp_snapshot_cache_bits = 1;
+ txn_db_options.wp_commit_cache_bits = 10;
+ options.write_buffer_size = 1024;
+ EXPECT_OK(ReOpen());
+ }
+ };
+
+ protected:
+ // Also emulate slow threads by addin artiftial delays
+ const bool with_slow_threads_;
+};
+
+class WriteCommittedTxnWithTsTest
+ : public TransactionTestBase,
+ public ::testing::WithParamInterface<std::tuple<bool, bool, bool>> {
+ public:
+ WriteCommittedTxnWithTsTest()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ WRITE_COMMITTED, kOrderedWrite) {}
+ ~WriteCommittedTxnWithTsTest() override {
+ for (auto* h : handles_) {
+ delete h;
+ }
+ }
+
+ Status GetFromDb(ReadOptions read_opts, ColumnFamilyHandle* column_family,
+ const Slice& key, TxnTimestamp ts, std::string* value) {
+ std::string ts_buf;
+ PutFixed64(&ts_buf, ts);
+ Slice ts_slc = ts_buf;
+ read_opts.timestamp = &ts_slc;
+ assert(db);
+ return db->Get(read_opts, column_family, key, value);
+ }
+
+ Transaction* NewTxn(WriteOptions write_opts, TransactionOptions txn_opts) {
+ assert(db);
+ auto* txn = db->BeginTransaction(write_opts, txn_opts);
+ assert(txn);
+ const bool enable_indexing = std::get<2>(GetParam());
+ if (enable_indexing) {
+ txn->EnableIndexing();
+ } else {
+ txn->DisableIndexing();
+ }
+ return txn;
+ }
+
+ protected:
+ std::vector<ColumnFamilyHandle*> handles_{};
+};
+
+class TimestampedSnapshotWithTsSanityCheck
+ : public TransactionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering>> {
+ public:
+ explicit TimestampedSnapshotWithTsSanityCheck()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())) {}
+ ~TimestampedSnapshotWithTsSanityCheck() override {
+ for (auto* h : handles_) {
+ delete h;
+ }
+ }
+
+ protected:
+ std::vector<ColumnFamilyHandle*> handles_{};
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/transaction_util.cc b/src/rocksdb/utilities/transactions/transaction_util.cc
new file mode 100644
index 000000000..360edc8ec
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_util.cc
@@ -0,0 +1,206 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_util.h"
+
+#include <cinttypes>
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status TransactionUtil::CheckKeyForConflicts(
+ DBImpl* db_impl, ColumnFamilyHandle* column_family, const std::string& key,
+ SequenceNumber snap_seq, const std::string* const read_ts, bool cache_only,
+ ReadCallback* snap_checker, SequenceNumber min_uncommitted) {
+ Status result;
+
+ auto cfh = static_cast_with_check<ColumnFamilyHandleImpl>(column_family);
+ auto cfd = cfh->cfd();
+ SuperVersion* sv = db_impl->GetAndRefSuperVersion(cfd);
+
+ if (sv == nullptr) {
+ result = Status::InvalidArgument("Could not access column family " +
+ cfh->GetName());
+ }
+
+ if (result.ok()) {
+ SequenceNumber earliest_seq =
+ db_impl->GetEarliestMemTableSequenceNumber(sv, true);
+
+ result = CheckKey(db_impl, sv, earliest_seq, snap_seq, key, read_ts,
+ cache_only, snap_checker, min_uncommitted);
+
+ db_impl->ReturnAndCleanupSuperVersion(cfd, sv);
+ }
+
+ return result;
+}
+
+Status TransactionUtil::CheckKey(DBImpl* db_impl, SuperVersion* sv,
+ SequenceNumber earliest_seq,
+ SequenceNumber snap_seq,
+ const std::string& key,
+ const std::string* const read_ts,
+ bool cache_only, ReadCallback* snap_checker,
+ SequenceNumber min_uncommitted) {
+ // When `min_uncommitted` is provided, keys are not always committed
+ // in sequence number order, and `snap_checker` is used to check whether
+ // specific sequence number is in the database is visible to the transaction.
+ // So `snap_checker` must be provided.
+ assert(min_uncommitted == kMaxSequenceNumber || snap_checker != nullptr);
+
+ Status result;
+ bool need_to_read_sst = false;
+
+ // Since it would be too slow to check the SST files, we will only use
+ // the memtables to check whether there have been any recent writes
+ // to this key after it was accessed in this transaction. But if the
+ // Memtables do not contain a long enough history, we must fail the
+ // transaction.
+ if (earliest_seq == kMaxSequenceNumber) {
+ // The age of this memtable is unknown. Cannot rely on it to check
+ // for recent writes. This error shouldn't happen often in practice as
+ // the Memtable should have a valid earliest sequence number except in some
+ // corner cases (such as error cases during recovery).
+ need_to_read_sst = true;
+
+ if (cache_only) {
+ result = Status::TryAgain(
+ "Transaction could not check for conflicts as the MemTable does not "
+ "contain a long enough history to check write at SequenceNumber: ",
+ std::to_string(snap_seq));
+ }
+ } else if (snap_seq < earliest_seq || min_uncommitted <= earliest_seq) {
+ // Use <= for min_uncommitted since earliest_seq is actually the largest sec
+ // before this memtable was created
+ need_to_read_sst = true;
+
+ if (cache_only) {
+ // The age of this memtable is too new to use to check for recent
+ // writes.
+ char msg[300];
+ snprintf(msg, sizeof(msg),
+ "Transaction could not check for conflicts for operation at "
+ "SequenceNumber %" PRIu64
+ " as the MemTable only contains changes newer than "
+ "SequenceNumber %" PRIu64
+ ". Increasing the value of the "
+ "max_write_buffer_size_to_maintain option could reduce the "
+ "frequency "
+ "of this error.",
+ snap_seq, earliest_seq);
+ result = Status::TryAgain(msg);
+ }
+ }
+
+ if (result.ok()) {
+ SequenceNumber seq = kMaxSequenceNumber;
+ std::string timestamp;
+ bool found_record_for_key = false;
+
+ // When min_uncommitted == kMaxSequenceNumber, writes are committed in
+ // sequence number order, so only keys larger than `snap_seq` can cause
+ // conflict.
+ // When min_uncommitted != kMaxSequenceNumber, keys lower than
+ // min_uncommitted will not triggered conflicts, while keys larger than
+ // min_uncommitted might create conflicts, so we need to read them out
+ // from the DB, and call callback to snap_checker to determine. So only
+ // keys lower than min_uncommitted can be skipped.
+ SequenceNumber lower_bound_seq =
+ (min_uncommitted == kMaxSequenceNumber) ? snap_seq : min_uncommitted;
+ Status s = db_impl->GetLatestSequenceForKey(
+ sv, key, !need_to_read_sst, lower_bound_seq, &seq,
+ !read_ts ? nullptr : &timestamp, &found_record_for_key,
+ /*is_blob_index=*/nullptr);
+
+ if (!(s.ok() || s.IsNotFound() || s.IsMergeInProgress())) {
+ result = s;
+ } else if (found_record_for_key) {
+ bool write_conflict = snap_checker == nullptr
+ ? snap_seq < seq
+ : !snap_checker->IsVisible(seq);
+ // Perform conflict checking based on timestamp if applicable.
+ if (!write_conflict && read_ts != nullptr) {
+ ColumnFamilyData* cfd = sv->cfd;
+ assert(cfd);
+ const Comparator* const ucmp = cfd->user_comparator();
+ assert(ucmp);
+ assert(read_ts->size() == ucmp->timestamp_size());
+ assert(read_ts->size() == timestamp.size());
+ // Write conflict if *ts < timestamp.
+ write_conflict = ucmp->CompareTimestamp(*read_ts, timestamp) < 0;
+ }
+ if (write_conflict) {
+ result = Status::Busy();
+ }
+ }
+ }
+
+ return result;
+}
+
+Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl,
+ const LockTracker& tracker,
+ bool cache_only) {
+ Status result;
+
+ std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
+ tracker.GetColumnFamilyIterator());
+ assert(cf_it != nullptr);
+ while (cf_it->HasNext()) {
+ ColumnFamilyId cf = cf_it->Next();
+
+ SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf);
+ if (sv == nullptr) {
+ result = Status::InvalidArgument("Could not access column family " +
+ std::to_string(cf));
+ break;
+ }
+
+ SequenceNumber earliest_seq =
+ db_impl->GetEarliestMemTableSequenceNumber(sv, true);
+
+ // For each of the keys in this transaction, check to see if someone has
+ // written to this key since the start of the transaction.
+ std::unique_ptr<LockTracker::KeyIterator> key_it(
+ tracker.GetKeyIterator(cf));
+ assert(key_it != nullptr);
+ while (key_it->HasNext()) {
+ const std::string& key = key_it->Next();
+ PointLockStatus status = tracker.GetPointLockStatus(cf, key);
+ const SequenceNumber key_seq = status.seq;
+
+ // TODO: support timestamp-based conflict checking.
+ // CheckKeysForConflicts() is currently used only by optimistic
+ // transactions.
+ result = CheckKey(db_impl, sv, earliest_seq, key_seq, key,
+ /*read_ts=*/nullptr, cache_only);
+ if (!result.ok()) {
+ break;
+ }
+ }
+
+ db_impl->ReturnAndCleanupSuperVersion(cf, sv);
+
+ if (!result.ok()) {
+ break;
+ }
+ }
+
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_util.h b/src/rocksdb/utilities/transactions/transaction_util.h
new file mode 100644
index 000000000..a349ba87a
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_util.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <string>
+#include <unordered_map>
+
+#include "db/dbformat.h"
+#include "db/read_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "utilities/transactions/lock/lock_tracker.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class DBImpl;
+struct SuperVersion;
+class WriteBatchWithIndex;
+
+class TransactionUtil {
+ public:
+ // Verifies there have been no commits to this key in the db since this
+ // sequence number. If user-defined timestamp is enabled, then also check
+ // no commits to this key in the db since the given ts.
+ //
+ // If cache_only is true, then this function will not attempt to read any
+ // SST files. This will make it more likely this function will
+ // return an error if it is unable to determine if there are any conflicts.
+ //
+ // See comment of CheckKey() for explanation of `snap_seq`, `ts`,
+ // `snap_checker` and `min_uncommitted`.
+ //
+ // Returns OK on success, BUSY if there is a conflicting write, or other error
+ // status for any unexpected errors.
+ static Status CheckKeyForConflicts(
+ DBImpl* db_impl, ColumnFamilyHandle* column_family,
+ const std::string& key, SequenceNumber snap_seq,
+ const std::string* const ts, bool cache_only,
+ ReadCallback* snap_checker = nullptr,
+ SequenceNumber min_uncommitted = kMaxSequenceNumber);
+
+ // For each key,SequenceNumber pair tracked by the LockTracker, this function
+ // will verify there have been no writes to the key in the db since that
+ // sequence number.
+ //
+ // Returns OK on success, BUSY if there is a conflicting write, or other error
+ // status for any unexpected errors.
+ //
+ // REQUIRED:
+ // This function should only be called on the write thread or if the
+ // mutex is held.
+ // tracker must support point lock.
+ static Status CheckKeysForConflicts(DBImpl* db_impl,
+ const LockTracker& tracker,
+ bool cache_only);
+
+ private:
+ // If `snap_checker` == nullptr, writes are always commited in sequence number
+ // order. All sequence number <= `snap_seq` will not conflict with any
+ // write, and all keys > `snap_seq` of `key` will trigger conflict.
+ // If `snap_checker` != nullptr, writes may not commit in sequence number
+ // order. In this case `min_uncommitted` is a lower bound.
+ // seq < `min_uncommitted`: no conflict
+ // seq > `snap_seq`: applicable to conflict
+ // `min_uncommitted` <= seq <= `snap_seq`: call `snap_checker` to determine.
+ //
+ // If user-defined timestamp is enabled, a write conflict is detected if an
+ // operation for `key` with timestamp greater than `ts` exists.
+ static Status CheckKey(DBImpl* db_impl, SuperVersion* sv,
+ SequenceNumber earliest_seq, SequenceNumber snap_seq,
+ const std::string& key, const std::string* const ts,
+ bool cache_only, ReadCallback* snap_checker = nullptr,
+ SequenceNumber min_uncommitted = kMaxSequenceNumber);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_committed_transaction_ts_test.cc b/src/rocksdb/utilities/transactions/write_committed_transaction_ts_test.cc
new file mode 100644
index 000000000..94b8201f7
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_committed_transaction_ts_test.cc
@@ -0,0 +1,588 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "utilities/merge_operators.h"
+#ifndef ROCKSDB_LITE
+
+#include "test_util/testutil.h"
+#include "utilities/transactions/transaction_test.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+INSTANTIATE_TEST_CASE_P(
+ DBAsBaseDB, WriteCommittedTxnWithTsTest,
+ ::testing::Values(std::make_tuple(false, /*two_write_queue=*/false,
+ /*enable_indexing=*/false),
+ std::make_tuple(false, /*two_write_queue=*/true,
+ /*enable_indexing=*/false),
+ std::make_tuple(false, /*two_write_queue=*/false,
+ /*enable_indexing=*/true),
+ std::make_tuple(false, /*two_write_queue=*/true,
+ /*enable_indexing=*/true)));
+
+INSTANTIATE_TEST_CASE_P(
+ DBAsStackableDB, WriteCommittedTxnWithTsTest,
+ ::testing::Values(std::make_tuple(true, /*two_write_queue=*/false,
+ /*enable_indexing=*/false),
+ std::make_tuple(true, /*two_write_queue=*/true,
+ /*enable_indexing=*/false),
+ std::make_tuple(true, /*two_write_queue=*/false,
+ /*enable_indexing=*/true),
+ std::make_tuple(true, /*two_write_queue=*/true,
+ /*enable_indexing=*/true)));
+
+TEST_P(WriteCommittedTxnWithTsTest, SanityChecks) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_opts;
+ cf_opts.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_opts, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_opts);
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->Put(handles_[1], "foo", "value"));
+ ASSERT_TRUE(txn->Commit().IsInvalidArgument());
+
+ auto* pessimistic_txn =
+ static_cast_with_check<PessimisticTransaction>(txn.get());
+ ASSERT_TRUE(
+ pessimistic_txn->CommitBatch(/*batch=*/nullptr).IsInvalidArgument());
+
+ {
+ WriteBatchWithIndex* wbwi = txn->GetWriteBatch();
+ assert(wbwi);
+ WriteBatch* wb = wbwi->GetWriteBatch();
+ assert(wb);
+ // Write a key to the batch for nonexisting cf.
+ ASSERT_OK(WriteBatchInternal::Put(wb, /*column_family_id=*/10, /*key=*/"",
+ /*value=*/""));
+ }
+
+ ASSERT_OK(txn->SetCommitTimestamp(20));
+
+ ASSERT_TRUE(txn->Commit().IsInvalidArgument());
+ txn.reset();
+
+ std::unique_ptr<Transaction> txn1(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn1);
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Put(handles_[1], "foo", "value"));
+ {
+ WriteBatchWithIndex* wbwi = txn1->GetWriteBatch();
+ assert(wbwi);
+ WriteBatch* wb = wbwi->GetWriteBatch();
+ assert(wb);
+ // Write a key to the batch for non-existing cf.
+ ASSERT_OK(WriteBatchInternal::Put(wb, /*column_family_id=*/10, /*key=*/"",
+ /*value=*/""));
+ }
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn1->SetCommitTimestamp(21));
+ ASSERT_TRUE(txn1->Commit().IsInvalidArgument());
+ txn1.reset();
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, ReOpenWithTimestamp) {
+ options.merge_operator = MergeOperators::CreateUInt64AddOperator();
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_opts;
+ cf_opts.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_opts, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_opts);
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn0(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn0);
+ ASSERT_OK(txn0->Put(handles_[1], "foo", "value"));
+ ASSERT_OK(txn0->SetName("txn0"));
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_TRUE(txn0->Commit().IsInvalidArgument());
+ txn0.reset();
+
+ std::unique_ptr<Transaction> txn1(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn1);
+ ASSERT_OK(txn1->Put(handles_[1], "foo", "value1"));
+ {
+ std::string buf;
+ PutFixed64(&buf, 23);
+ ASSERT_OK(txn1->Put("id", buf));
+ ASSERT_OK(txn1->Merge("id", buf));
+ }
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn1->SetCommitTimestamp(/*ts=*/23));
+ ASSERT_OK(txn1->Commit());
+ txn1.reset();
+
+ {
+ std::string value;
+ const Status s =
+ GetFromDb(ReadOptions(), handles_[1], "foo", /*ts=*/23, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("value1", value);
+ }
+
+ {
+ std::string value;
+ const Status s = db->Get(ReadOptions(), handles_[0], "id", &value);
+ ASSERT_OK(s);
+ uint64_t ival = 0;
+ Slice value_slc = value;
+ bool result = GetFixed64(&value_slc, &ival);
+ assert(result);
+ ASSERT_EQ(46, ival);
+ }
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, RecoverFromWal) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_opts;
+ cf_opts.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_opts, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_opts);
+ options.avoid_flush_during_shutdown = true;
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn0(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn0);
+ ASSERT_OK(txn0->Put(handles_[1], "foo", "foo_value"));
+ ASSERT_OK(txn0->SetName("txn0"));
+ ASSERT_OK(txn0->Prepare());
+
+ WriteOptions write_opts;
+ write_opts.sync = true;
+ std::unique_ptr<Transaction> txn1(NewTxn(write_opts, TransactionOptions()));
+ assert(txn1);
+ ASSERT_OK(txn1->Put("bar", "bar_value_1"));
+ ASSERT_OK(txn1->Put(handles_[1], "bar", "bar_value_1"));
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn1->SetCommitTimestamp(/*ts=*/23));
+ ASSERT_OK(txn1->Commit());
+ txn1.reset();
+
+ std::unique_ptr<Transaction> txn2(NewTxn(write_opts, TransactionOptions()));
+ assert(txn2);
+ ASSERT_OK(txn2->Put("key1", "value_3"));
+ ASSERT_OK(txn2->Put(handles_[1], "key1", "value_3"));
+ ASSERT_OK(txn2->SetCommitTimestamp(/*ts=*/24));
+ ASSERT_OK(txn2->Commit());
+ txn2.reset();
+
+ txn0.reset();
+
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ {
+ std::string value;
+ Status s = GetFromDb(ReadOptions(), handles_[1], "foo", /*ts=*/23, &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(ReadOptions(), handles_[0], "bar", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bar_value_1", value);
+
+ value.clear();
+ s = GetFromDb(ReadOptions(), handles_[1], "bar", /*ts=*/23, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bar_value_1", value);
+
+ s = GetFromDb(ReadOptions(), handles_[1], "key1", /*ts=*/23, &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(ReadOptions(), handles_[0], "key1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("value_3", value);
+
+ s = GetFromDb(ReadOptions(), handles_[1], "key1", /*ts=*/24, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("value_3", value);
+ }
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, TransactionDbLevelApi) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_options;
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ cf_options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_options, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, cf_options);
+
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::string key_str = "tes_key";
+ std::string ts_str;
+ std::string value_str = "test_value";
+ PutFixed64(&ts_str, 100);
+ Slice value = value_str;
+
+ assert(db);
+ ASSERT_TRUE(
+ db->Put(WriteOptions(), handles_[1], "foo", "bar").IsNotSupported());
+ ASSERT_TRUE(db->Delete(WriteOptions(), handles_[1], "foo").IsNotSupported());
+ ASSERT_TRUE(
+ db->SingleDelete(WriteOptions(), handles_[1], "foo").IsNotSupported());
+ ASSERT_TRUE(
+ db->Merge(WriteOptions(), handles_[1], "foo", "+1").IsNotSupported());
+ WriteBatch wb1(/*reserved_bytes=*/0, /*max_bytes=*/0,
+ /*protection_bytes_per_key=*/0, /*default_cf_ts_sz=*/0);
+ ASSERT_OK(wb1.Put(handles_[1], key_str, ts_str, value));
+ ASSERT_TRUE(db->Write(WriteOptions(), &wb1).IsNotSupported());
+ ASSERT_TRUE(db->Write(WriteOptions(), TransactionDBWriteOptimizations(), &wb1)
+ .IsNotSupported());
+ auto* pessimistic_txn_db =
+ static_cast_with_check<PessimisticTransactionDB>(db);
+ assert(pessimistic_txn_db);
+ ASSERT_TRUE(
+ pessimistic_txn_db->WriteWithConcurrencyControl(WriteOptions(), &wb1)
+ .IsNotSupported());
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "bar", "value"));
+ ASSERT_OK(db->Delete(WriteOptions(), "bar"));
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "foo"));
+ ASSERT_OK(db->Put(WriteOptions(), "key", "value"));
+ ASSERT_OK(db->Merge(WriteOptions(), "key", "_more"));
+ WriteBatch wb2(/*reserved_bytes=*/0, /*max_bytes=*/0,
+ /*protection_bytes_per_key=*/0, /*default_cf_ts_sz=*/0);
+ ASSERT_OK(wb2.Put(key_str, value));
+ ASSERT_OK(db->Write(WriteOptions(), &wb2));
+ ASSERT_OK(db->Write(WriteOptions(), TransactionDBWriteOptimizations(), &wb2));
+ ASSERT_OK(
+ pessimistic_txn_db->WriteWithConcurrencyControl(WriteOptions(), &wb2));
+
+ std::unique_ptr<Transaction> txn(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn);
+
+ WriteBatch wb3(/*reserved_bytes=*/0, /*max_bytes=*/0,
+ /*protection_bytes_per_key=*/0, /*default_cf_ts_sz=*/0);
+
+ ASSERT_OK(wb3.Put(handles_[1], "key", "value"));
+ auto* pessimistic_txn =
+ static_cast_with_check<PessimisticTransaction>(txn.get());
+ assert(pessimistic_txn);
+ ASSERT_TRUE(pessimistic_txn->CommitBatch(&wb3).IsNotSupported());
+
+ txn.reset();
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, Merge) {
+ options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_options;
+ cf_options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_options, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, Options(DBOptions(), cf_options));
+ options.avoid_flush_during_shutdown = true;
+
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn);
+ ASSERT_OK(txn->Put(handles_[1], "foo", "bar"));
+ ASSERT_OK(txn->Merge(handles_[1], "foo", "1"));
+ ASSERT_OK(txn->SetCommitTimestamp(24));
+ ASSERT_OK(txn->Commit());
+ txn.reset();
+ {
+ std::string value;
+ const Status s =
+ GetFromDb(ReadOptions(), handles_[1], "foo", /*ts=*/24, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bar,1", value);
+ }
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, GetForUpdate) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_options;
+ cf_options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_options, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, Options(DBOptions(), cf_options));
+ options.avoid_flush_during_shutdown = true;
+
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn0(
+ NewTxn(WriteOptions(), TransactionOptions()));
+
+ std::unique_ptr<Transaction> txn1(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn1->Put(handles_[1], "key", "value1"));
+ ASSERT_OK(txn1->SetCommitTimestamp(24));
+ ASSERT_OK(txn1->Commit());
+ txn1.reset();
+
+ std::string value;
+ ASSERT_OK(txn0->SetReadTimestampForValidation(23));
+ ASSERT_TRUE(
+ txn0->GetForUpdate(ReadOptions(), handles_[1], "key", &value).IsBusy());
+ ASSERT_OK(txn0->Rollback());
+ txn0.reset();
+
+ std::unique_ptr<Transaction> txn2(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn2->SetReadTimestampForValidation(25));
+ ASSERT_OK(txn2->GetForUpdate(ReadOptions(), handles_[1], "key", &value));
+ ASSERT_OK(txn2->SetCommitTimestamp(26));
+ ASSERT_OK(txn2->Commit());
+ txn2.reset();
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, BlindWrite) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_options;
+ cf_options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_options, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, Options(DBOptions(), cf_options));
+ options.avoid_flush_during_shutdown = true;
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn0(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn0);
+ std::unique_ptr<Transaction> txn1(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn1);
+
+ {
+ std::string value;
+ ASSERT_OK(txn0->SetReadTimestampForValidation(100));
+ // Lock "key".
+ ASSERT_TRUE(txn0->GetForUpdate(ReadOptions(), handles_[1], "key", &value)
+ .IsNotFound());
+ }
+
+ ASSERT_OK(txn0->Put(handles_[1], "key", "value0"));
+ ASSERT_OK(txn0->SetCommitTimestamp(101));
+ ASSERT_OK(txn0->Commit());
+
+ ASSERT_OK(txn1->Put(handles_[1], "key", "value1"));
+ // In reality, caller needs to ensure commit_ts of txn1 is greater than the
+ // commit_ts of txn0, which is true for lock-based concurrency control.
+ ASSERT_OK(txn1->SetCommitTimestamp(102));
+ ASSERT_OK(txn1->Commit());
+
+ txn0.reset();
+ txn1.reset();
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, RefineReadTimestamp) {
+ ASSERT_OK(ReOpenNoDelete());
+
+ ColumnFamilyOptions cf_options;
+ cf_options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ const std::string test_cf_name = "test_cf";
+ ColumnFamilyHandle* cfh = nullptr;
+ assert(db);
+ ASSERT_OK(db->CreateColumnFamily(cf_options, test_cf_name, &cfh));
+ delete cfh;
+ cfh = nullptr;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ cf_descs.emplace_back(kDefaultColumnFamilyName, options);
+ cf_descs.emplace_back(test_cf_name, Options(DBOptions(), cf_options));
+ options.avoid_flush_during_shutdown = true;
+
+ ASSERT_OK(ReOpenNoDelete(cf_descs, &handles_));
+
+ std::unique_ptr<Transaction> txn0(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn0);
+
+ std::unique_ptr<Transaction> txn1(
+ NewTxn(WriteOptions(), TransactionOptions()));
+ assert(txn1);
+
+ {
+ ASSERT_OK(txn0->SetReadTimestampForValidation(100));
+ // Lock "key0", "key1", ..., "key4".
+ for (int i = 0; i < 5; ++i) {
+ std::string value;
+ ASSERT_TRUE(txn0->GetForUpdate(ReadOptions(), handles_[1],
+ "key" + std::to_string(i), &value)
+ .IsNotFound());
+ }
+ }
+ ASSERT_OK(txn1->Put(handles_[1], "key5", "value5_0"));
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn1->SetCommitTimestamp(101));
+ ASSERT_OK(txn1->Commit());
+ txn1.reset();
+
+ {
+ std::string value;
+ ASSERT_TRUE(txn0->GetForUpdate(ReadOptions(), handles_[1], "key5", &value)
+ .IsBusy());
+ ASSERT_OK(txn0->SetReadTimestampForValidation(102));
+ ASSERT_OK(txn0->GetForUpdate(ReadOptions(), handles_[1], "key5", &value));
+ ASSERT_EQ("value5_0", value);
+ }
+
+ for (int i = 0; i < 6; ++i) {
+ ASSERT_OK(txn0->Put(handles_[1], "key" + std::to_string(i),
+ "value" + std::to_string(i)));
+ }
+ ASSERT_OK(txn0->SetName("txn0"));
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_OK(txn0->SetCommitTimestamp(103));
+ ASSERT_OK(txn0->Commit());
+ txn0.reset();
+}
+
+TEST_P(WriteCommittedTxnWithTsTest, CheckKeysForConflicts) {
+ options.comparator = test::BytewiseComparatorWithU64TsWrapper();
+ ASSERT_OK(ReOpen());
+
+ std::unique_ptr<Transaction> txn1(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn1);
+
+ std::unique_ptr<Transaction> txn2(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn2);
+ ASSERT_OK(txn2->Put("foo", "v0"));
+ ASSERT_OK(txn2->SetCommitTimestamp(10));
+ ASSERT_OK(txn2->Commit());
+ txn2.reset();
+
+ // txn1 takes a snapshot after txn2 commits. The writes of txn2 have
+ // a smaller seqno than txn1's snapshot, thus should not affect conflict
+ // checking.
+ txn1->SetSnapshot();
+
+ std::unique_ptr<Transaction> txn3(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ assert(txn3);
+ ASSERT_OK(txn3->SetReadTimestampForValidation(20));
+ std::string dontcare;
+ ASSERT_OK(txn3->GetForUpdate(ReadOptions(), "foo", &dontcare));
+ ASSERT_OK(txn3->SingleDelete("foo"));
+ ASSERT_OK(txn3->SetName("txn3"));
+ ASSERT_OK(txn3->Prepare());
+ ASSERT_OK(txn3->SetCommitTimestamp(30));
+ // txn3 reads at ts=20 > txn2's commit timestamp, and commits at ts=30.
+ // txn3 can commit successfully, leaving a tombstone with ts=30.
+ ASSERT_OK(txn3->Commit());
+ txn3.reset();
+
+ bool called = false;
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "DBImpl::GetLatestSequenceForKey:mem", [&](void* arg) {
+ auto* const ts_ptr = reinterpret_cast<std::string*>(arg);
+ assert(ts_ptr);
+ Slice ts_slc = *ts_ptr;
+ uint64_t last_ts = 0;
+ ASSERT_TRUE(GetFixed64(&ts_slc, &last_ts));
+ ASSERT_EQ(30, last_ts);
+ called = true;
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // txn1's read timestamp is 25 < 30 (commit timestamp of txn3). Therefore,
+ // the tombstone written by txn3 causes the conflict checking to fail.
+ ASSERT_OK(txn1->SetReadTimestampForValidation(25));
+ ASSERT_TRUE(txn1->GetForUpdate(ReadOptions(), "foo", &dontcare).IsBusy());
+ ASSERT_TRUE(called);
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as Transactions not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc b/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc
new file mode 100644
index 000000000..86a9511a4
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc
@@ -0,0 +1,4078 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <cinttypes>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "db/dbformat.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/debug.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/fault_injection_env.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_test.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+using CommitEntry = WritePreparedTxnDB::CommitEntry;
+using CommitEntry64b = WritePreparedTxnDB::CommitEntry64b;
+using CommitEntry64bFormat = WritePreparedTxnDB::CommitEntry64bFormat;
+
+TEST(PreparedHeap, BasicsTest) {
+ WritePreparedTxnDB::PreparedHeap heap;
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(14l);
+ // Test with one element
+ ASSERT_EQ(14l, heap.top());
+ heap.push(24l);
+ heap.push(34l);
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.push(44l);
+ heap.push(54l);
+ heap.push(64l);
+ heap.push(74l);
+ heap.push(84l);
+ }
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.erase(24l);
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.erase(14l);
+ // Test that the new comes to the top after multiple erase
+ ASSERT_EQ(34l, heap.top());
+ heap.erase(34l);
+ // Test that the new comes to the top after single erase
+ ASSERT_EQ(44l, heap.top());
+ heap.erase(54l);
+ ASSERT_EQ(44l, heap.top());
+ heap.pop(); // pop 44l
+ // Test that the erased items are ignored after pop
+ ASSERT_EQ(64l, heap.top());
+ heap.erase(44l);
+ // Test that erasing an already popped item would work
+ ASSERT_EQ(64l, heap.top());
+ heap.erase(84l);
+ ASSERT_EQ(64l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(85l);
+ heap.push(86l);
+ heap.push(87l);
+ heap.push(88l);
+ heap.push(89l);
+ }
+ heap.erase(87l);
+ heap.erase(85l);
+ heap.erase(89l);
+ heap.erase(86l);
+ heap.erase(88l);
+ // Test top remains the same after a random order of many erases
+ ASSERT_EQ(64l, heap.top());
+ heap.pop();
+ // Test that pop works with a series of random pending erases
+ ASSERT_EQ(74l, heap.top());
+ ASSERT_FALSE(heap.empty());
+ heap.pop();
+ // Test that empty works
+ ASSERT_TRUE(heap.empty());
+}
+
+// This is a scenario reconstructed from a buggy trace. Test that the bug does
+// not resurface again.
+TEST(PreparedHeap, EmptyAtTheEnd) {
+ WritePreparedTxnDB::PreparedHeap heap;
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(40l);
+ }
+ ASSERT_EQ(40l, heap.top());
+ // Although not a recommended scenario, we must be resilient against erase
+ // without a prior push.
+ heap.erase(50l);
+ ASSERT_EQ(40l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(60l);
+ }
+ ASSERT_EQ(40l, heap.top());
+
+ heap.erase(60l);
+ ASSERT_EQ(40l, heap.top());
+ heap.erase(40l);
+ ASSERT_TRUE(heap.empty());
+
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(40l);
+ }
+ ASSERT_EQ(40l, heap.top());
+ heap.erase(50l);
+ ASSERT_EQ(40l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(60l);
+ }
+ ASSERT_EQ(40l, heap.top());
+
+ heap.erase(40l);
+ // Test that the erase has not emptied the heap (we had a bug doing that)
+ ASSERT_FALSE(heap.empty());
+ ASSERT_EQ(60l, heap.top());
+ heap.erase(60l);
+ ASSERT_TRUE(heap.empty());
+}
+
+// Generate random order of PreparedHeap access and test that the heap will be
+// successfully emptied at the end.
+TEST(PreparedHeap, Concurrent) {
+ const size_t t_cnt = 10;
+ ROCKSDB_NAMESPACE::port::Thread t[t_cnt + 1];
+ WritePreparedTxnDB::PreparedHeap heap;
+ port::RWMutex prepared_mutex;
+ std::atomic<size_t> last;
+
+ for (size_t n = 0; n < 100; n++) {
+ last = 0;
+ t[0] = ROCKSDB_NAMESPACE::port::Thread([&]() {
+ Random rnd(1103);
+ for (size_t seq = 1; seq <= t_cnt; seq++) {
+ // This is not recommended usage but we should be resilient against it.
+ bool skip_push = rnd.OneIn(5);
+ if (!skip_push) {
+ MutexLock ml(heap.push_pop_mutex());
+ std::this_thread::yield();
+ heap.push(seq);
+ last.store(seq);
+ }
+ }
+ });
+ for (size_t i = 1; i <= t_cnt; i++) {
+ t[i] =
+ ROCKSDB_NAMESPACE::port::Thread([&heap, &prepared_mutex, &last, i]() {
+ auto seq = i;
+ do {
+ std::this_thread::yield();
+ } while (last.load() < seq);
+ WriteLock wl(&prepared_mutex);
+ heap.erase(seq);
+ });
+ }
+ for (size_t i = 0; i <= t_cnt; i++) {
+ t[i].join();
+ }
+ ASSERT_TRUE(heap.empty());
+ }
+}
+
+// Test that WriteBatchWithIndex correctly counts the number of sub-batches
+TEST(WriteBatchWithIndex, SubBatchCnt) {
+ ColumnFamilyOptions cf_options;
+ std::string cf_name = "two";
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ const std::string dbname = test::PerThreadDBPath("transaction_testdb");
+ EXPECT_OK(DestroyDB(dbname, options));
+ ASSERT_OK(DB::Open(options, dbname, &db));
+ ColumnFamilyHandle* cf_handle = nullptr;
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ size_t batch_cnt = 1;
+ size_t save_points = 0;
+ std::vector<size_t> batch_cnt_at;
+ WriteBatchWithIndex batch(db->DefaultColumnFamily()->GetComparator(), 0, true,
+ 0);
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ ASSERT_OK(batch.Put(Slice("key"), Slice("value")));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ ASSERT_OK(batch.Put(Slice("key2"), Slice("value2")));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the keys
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ ASSERT_OK(batch.Put(Slice("key"), Slice("value3")));
+ batch_cnt++;
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the 2nd key. It should not be counted duplicate since a
+ // sub-patch is cut after the last duplicate.
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ ASSERT_OK(batch.Put(Slice("key2"), Slice("value4")));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the keys but in a different cf. It should not be counted as
+ // duplicate keys
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ ASSERT_OK(batch.Put(cf_handle, Slice("key"), Slice("value5")));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+
+ // Test that the number of sub-batches matches what we count with
+ // SubBatchCounter
+ std::map<uint32_t, const Comparator*> comparators;
+ comparators[0] = db->DefaultColumnFamily()->GetComparator();
+ comparators[cf_handle->GetID()] = cf_handle->GetComparator();
+ SubBatchCounter counter(comparators);
+ ASSERT_OK(batch.GetWriteBatch()->Iterate(&counter));
+ ASSERT_EQ(batch_cnt, counter.BatchCount());
+
+ // Test that RollbackToSavePoint will properly resets the number of
+ // sub-batches
+ for (size_t i = save_points; i > 0; i--) {
+ ASSERT_OK(batch.RollbackToSavePoint());
+ ASSERT_EQ(batch_cnt_at[i - 1], batch.SubBatchCnt());
+ }
+
+ // Test the count is right with random batches
+ {
+ const size_t TOTAL_KEYS = 20; // 20 ~= 10 to cause a few randoms
+ Random rnd(1131);
+ std::string keys[TOTAL_KEYS];
+ for (size_t k = 0; k < TOTAL_KEYS; k++) {
+ int len = static_cast<int>(rnd.Uniform(50));
+ keys[k] = test::RandomKey(&rnd, len);
+ }
+ for (size_t i = 0; i < 1000; i++) { // 1000 random batches
+ WriteBatchWithIndex rndbatch(db->DefaultColumnFamily()->GetComparator(),
+ 0, true, 0);
+ for (size_t k = 0; k < 10; k++) { // 10 key per batch
+ size_t ki = static_cast<size_t>(rnd.Uniform(TOTAL_KEYS));
+ Slice key = Slice(keys[ki]);
+ std::string tmp = rnd.RandomString(16);
+ Slice value = Slice(tmp);
+ ASSERT_OK(rndbatch.Put(key, value));
+ }
+ SubBatchCounter batch_counter(comparators);
+ ASSERT_OK(rndbatch.GetWriteBatch()->Iterate(&batch_counter));
+ ASSERT_EQ(rndbatch.SubBatchCnt(), batch_counter.BatchCount());
+ }
+ }
+
+ delete cf_handle;
+ delete db;
+}
+
+TEST(CommitEntry64b, BasicTest) {
+ const size_t INDEX_BITS = static_cast<size_t>(21);
+ const size_t INDEX_SIZE = static_cast<size_t>(1ull << INDEX_BITS);
+ const CommitEntry64bFormat FORMAT(static_cast<size_t>(INDEX_BITS));
+
+ // zero-initialized CommitEntry64b should indicate an empty entry
+ CommitEntry64b empty_entry64b;
+ uint64_t empty_index = 11ul;
+ CommitEntry empty_entry;
+ bool ok = empty_entry64b.Parse(empty_index, &empty_entry, FORMAT);
+ ASSERT_FALSE(ok);
+
+ // the zero entry is reserved for un-initialized entries
+ const size_t MAX_COMMIT = (1 << FORMAT.COMMIT_BITS) - 1 - 1;
+ // Samples over the numbers that are covered by that many index bits
+ std::array<uint64_t, 4> is = {{0, 1, INDEX_SIZE / 2 + 1, INDEX_SIZE - 1}};
+ // Samples over the numbers that are covered by that many commit bits
+ std::array<uint64_t, 4> ds = {{0, 1, MAX_COMMIT / 2 + 1, MAX_COMMIT}};
+ // Iterate over prepare numbers that have i) cover all bits of a sequence
+ // number, and ii) include some bits that fall into the range of index or
+ // commit bits
+ for (uint64_t base = 1; base < kMaxSequenceNumber; base *= 2) {
+ for (uint64_t i : is) {
+ for (uint64_t d : ds) {
+ uint64_t p = base + i + d;
+ for (uint64_t c : {p, p + d / 2, p + d}) {
+ uint64_t index = p % INDEX_SIZE;
+ CommitEntry before(p, c), after;
+ CommitEntry64b entry64b(before, FORMAT);
+ ok = entry64b.Parse(index, &after, FORMAT);
+ ASSERT_TRUE(ok);
+ if (!(before == after)) {
+ printf("base %" PRIu64 " i %" PRIu64 " d %" PRIu64 " p %" PRIu64
+ " c %" PRIu64 " index %" PRIu64 "\n",
+ base, i, d, p, c, index);
+ }
+ ASSERT_EQ(before, after);
+ }
+ }
+ }
+ }
+}
+
+class WritePreparedTxnDBMock : public WritePreparedTxnDB {
+ public:
+ WritePreparedTxnDBMock(DBImpl* db_impl, TransactionDBOptions& opt)
+ : WritePreparedTxnDB(db_impl, opt) {}
+ void SetDBSnapshots(const std::vector<SequenceNumber>& snapshots) {
+ snapshots_ = snapshots;
+ }
+ void TakeSnapshot(SequenceNumber seq) { snapshots_.push_back(seq); }
+
+ protected:
+ const std::vector<SequenceNumber> GetSnapshotListFromDB(
+ SequenceNumber /* unused */) override {
+ return snapshots_;
+ }
+
+ private:
+ std::vector<SequenceNumber> snapshots_;
+};
+
+class WritePreparedTransactionTestBase : public TransactionTestBase {
+ public:
+ WritePreparedTransactionTestBase(bool use_stackable_db, bool two_write_queue,
+ TxnDBWritePolicy write_policy,
+ WriteOrdering write_ordering)
+ : TransactionTestBase(use_stackable_db, two_write_queue, write_policy,
+ write_ordering){};
+
+ protected:
+ void UpdateTransactionDBOptions(size_t snapshot_cache_bits,
+ size_t commit_cache_bits) {
+ txn_db_options.wp_snapshot_cache_bits = snapshot_cache_bits;
+ txn_db_options.wp_commit_cache_bits = commit_cache_bits;
+ }
+ void UpdateTransactionDBOptions(size_t snapshot_cache_bits) {
+ txn_db_options.wp_snapshot_cache_bits = snapshot_cache_bits;
+ }
+ // If expect_update is set, check if it actually updated old_commit_map_. If
+ // it did not and yet suggested not to check the next snapshot, do the
+ // opposite to check if it was not a bad suggestion.
+ void MaybeUpdateOldCommitMapTestWithNext(uint64_t prepare, uint64_t commit,
+ uint64_t snapshot,
+ uint64_t next_snapshot,
+ bool expect_update) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // reset old_commit_map_empty_ so that its value indicate whether
+ // old_commit_map_ was updated
+ wp_db->old_commit_map_empty_ = true;
+ bool check_next = wp_db->MaybeUpdateOldCommitMap(prepare, commit, snapshot,
+ snapshot < next_snapshot);
+ if (expect_update == wp_db->old_commit_map_empty_) {
+ printf("prepare: %" PRIu64 " commit: %" PRIu64 " snapshot: %" PRIu64
+ " next: %" PRIu64 "\n",
+ prepare, commit, snapshot, next_snapshot);
+ }
+ EXPECT_EQ(!expect_update, wp_db->old_commit_map_empty_);
+ if (!check_next && wp_db->old_commit_map_empty_) {
+ // do the opposite to make sure it was not a bad suggestion
+ const bool dont_care_bool = true;
+ wp_db->MaybeUpdateOldCommitMap(prepare, commit, next_snapshot,
+ dont_care_bool);
+ if (!wp_db->old_commit_map_empty_) {
+ printf("prepare: %" PRIu64 " commit: %" PRIu64 " snapshot: %" PRIu64
+ " next: %" PRIu64 "\n",
+ prepare, commit, snapshot, next_snapshot);
+ }
+ EXPECT_TRUE(wp_db->old_commit_map_empty_);
+ }
+ }
+
+ // Test that a CheckAgainstSnapshots thread reading old_snapshots will not
+ // miss a snapshot because of a concurrent update by UpdateSnapshots that is
+ // writing new_snapshots. Both threads are broken at two points. The sync
+ // points to enforce them are specified by a1, a2, b1, and b2. CommitEntry
+ // entry is expected to be vital for one of the snapshots that is common
+ // between the old and new list of snapshots.
+ void SnapshotConcurrentAccessTestInternal(
+ WritePreparedTxnDB* wp_db,
+ const std::vector<SequenceNumber>& old_snapshots,
+ const std::vector<SequenceNumber>& new_snapshots, CommitEntry& entry,
+ SequenceNumber& version, size_t a1, size_t a2, size_t b1, size_t b2) {
+ // First reset the snapshot list
+ const std::vector<SequenceNumber> empty_snapshots;
+ wp_db->old_commit_map_empty_ = true;
+ wp_db->UpdateSnapshots(empty_snapshots, ++version);
+ // Then initialize it with the old_snapshots
+ wp_db->UpdateSnapshots(old_snapshots, ++version);
+
+ // Starting from the first thread, cut each thread at two points
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(a1),
+ "WritePreparedTxnDB::UpdateSnapshots:s:start"},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(b1),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(a1)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(a2),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(b1)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(b2),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(a2)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:end",
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(b2)},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_);
+ ROCKSDB_NAMESPACE::port::Thread t1(
+ [&]() { wp_db->UpdateSnapshots(new_snapshots, version); });
+ wp_db->CheckAgainstSnapshots(entry);
+ t1.join();
+ ASSERT_FALSE(wp_db->old_commit_map_empty_);
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ wp_db->old_commit_map_empty_ = true;
+ wp_db->UpdateSnapshots(empty_snapshots, ++version);
+ wp_db->UpdateSnapshots(old_snapshots, ++version);
+ // Starting from the second thread, cut each thread at two points
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(a1),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:start"},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(b1),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(a1)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(a2),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(b1)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(b2),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(a2)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:end",
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(b2)},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_);
+ ROCKSDB_NAMESPACE::port::Thread t1(
+ [&]() { wp_db->UpdateSnapshots(new_snapshots, version); });
+ wp_db->CheckAgainstSnapshots(entry);
+ t1.join();
+ ASSERT_FALSE(wp_db->old_commit_map_empty_);
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ }
+
+ // Verify value of keys.
+ void VerifyKeys(const std::unordered_map<std::string, std::string>& data,
+ const Snapshot* snapshot = nullptr) {
+ std::string value;
+ ReadOptions read_options;
+ read_options.snapshot = snapshot;
+ for (auto& kv : data) {
+ auto s = db->Get(read_options, kv.first, &value);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ if (kv.second != value) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_EQ(kv.second, value);
+ } else {
+ ASSERT_EQ(kv.second, "NOT_FOUND");
+ }
+
+ // Try with MultiGet API too
+ std::vector<std::string> values;
+ auto s_vec = db->MultiGet(read_options, {db->DefaultColumnFamily()},
+ {kv.first}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(kv.second == values[0]);
+ } else {
+ ASSERT_EQ(kv.second, "NOT_FOUND");
+ }
+ }
+ }
+
+ // Verify all versions of keys.
+ void VerifyInternalKeys(const std::vector<KeyVersion>& expected_versions) {
+ std::vector<KeyVersion> versions;
+ const size_t kMaxKeys = 100000;
+ ASSERT_OK(GetAllKeyVersions(db, expected_versions.front().user_key,
+ expected_versions.back().user_key, kMaxKeys,
+ &versions));
+ ASSERT_EQ(expected_versions.size(), versions.size());
+ for (size_t i = 0; i < versions.size(); i++) {
+ ASSERT_EQ(expected_versions[i].user_key, versions[i].user_key);
+ ASSERT_EQ(expected_versions[i].sequence, versions[i].sequence);
+ ASSERT_EQ(expected_versions[i].type, versions[i].type);
+ if (versions[i].type != kTypeDeletion &&
+ versions[i].type != kTypeSingleDeletion) {
+ ASSERT_EQ(expected_versions[i].value, versions[i].value);
+ }
+ // Range delete not supported.
+ ASSERT_NE(expected_versions[i].type, kTypeRangeDeletion);
+ }
+ }
+};
+
+class WritePreparedTransactionTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering>> {
+ public:
+ WritePreparedTransactionTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())){};
+};
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+class SnapshotConcurrentAccessTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<std::tuple<
+ bool, bool, TxnDBWritePolicy, WriteOrdering, size_t, size_t>> {
+ public:
+ SnapshotConcurrentAccessTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ split_id_(std::get<4>(GetParam())),
+ split_cnt_(std::get<5>(GetParam())){};
+
+ protected:
+ // A test is split into split_cnt_ tests, each identified with split_id_ where
+ // 0 <= split_id_ < split_cnt_
+ size_t split_id_;
+ size_t split_cnt_;
+};
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+class SeqAdvanceConcurrentTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<std::tuple<
+ bool, bool, TxnDBWritePolicy, WriteOrdering, size_t, size_t>> {
+ public:
+ SeqAdvanceConcurrentTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ split_id_(std::get<4>(GetParam())),
+ split_cnt_(std::get<5>(GetParam())) {
+ special_env.skip_fsync_ = true;
+ };
+
+ protected:
+ // A test is split into split_cnt_ tests, each identified with split_id_ where
+ // 0 <= split_id_ < split_cnt_
+ size_t split_id_;
+ size_t split_cnt_;
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WritePreparedTransaction, WritePreparedTransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite)));
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+INSTANTIATE_TEST_CASE_P(
+ TwoWriteQueues, SnapshotConcurrentAccessTest,
+ ::testing::Values(
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 0, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 1, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 2, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 3, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 4, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 5, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 6, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 7, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 8, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 9, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 10, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 11, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 12, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 13, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 14, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 15, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 16, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 17, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 18, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 19, 20),
+
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 0, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 1, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 2, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 3, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 4, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 5, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 6, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 7, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 8, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 9, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 10, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 11, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 12, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 13, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 14, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 15, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 16, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 17, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 18, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 19, 20)));
+
+INSTANTIATE_TEST_CASE_P(
+ OneWriteQueue, SnapshotConcurrentAccessTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 0, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 1, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 2, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 3, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 4, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 5, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 6, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 7, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 8, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 9, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 10, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 11, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 12, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 13, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 14, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 15, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 16, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 17, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 18, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 19, 20)));
+
+INSTANTIATE_TEST_CASE_P(
+ TwoWriteQueues, SeqAdvanceConcurrentTest,
+ ::testing::Values(
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 0, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 1, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 2, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 3, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 4, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 5, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 6, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 7, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 8, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 9, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 0, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 1, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 2, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 3, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 4, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 5, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 6, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 7, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 8, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 9, 10)));
+
+INSTANTIATE_TEST_CASE_P(
+ OneWriteQueue, SeqAdvanceConcurrentTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 0, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 1, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 2, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 3, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 4, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 5, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 6, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 7, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 8, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 9, 10)));
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(WritePreparedTransactionTest, CommitMap) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_NE(wp_db, nullptr);
+ ASSERT_NE(wp_db->db_impl_, nullptr);
+ size_t size = wp_db->COMMIT_CACHE_SIZE;
+ CommitEntry c = {5, 12}, e;
+ bool evicted = wp_db->AddCommitEntry(c.prep_seq % size, c, &e);
+ ASSERT_FALSE(evicted);
+
+ // Should be able to read the same value
+ CommitEntry64b dont_care;
+ bool found = wp_db->GetCommitEntry(c.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(c, e);
+ // Should be able to distinguish between overlapping entries
+ found = wp_db->GetCommitEntry((c.prep_seq + size) % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_NE(c.prep_seq + size, e.prep_seq);
+ // Should be able to detect non-existent entry
+ found = wp_db->GetCommitEntry((c.prep_seq + 1) % size, &dont_care, &e);
+ ASSERT_FALSE(found);
+
+ // Reject an invalid exchange
+ CommitEntry e2 = {c.prep_seq + size, c.commit_seq + size};
+ CommitEntry64b e2_64b(e2, wp_db->FORMAT);
+ bool exchanged = wp_db->ExchangeCommitEntry(e2.prep_seq % size, e2_64b, e);
+ ASSERT_FALSE(exchanged);
+ // check whether it did actually reject that
+ found = wp_db->GetCommitEntry(e2.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(c, e);
+
+ // Accept a valid exchange
+ CommitEntry64b c_64b(c, wp_db->FORMAT);
+ CommitEntry e3 = {c.prep_seq + size, c.commit_seq + size + 1};
+ exchanged = wp_db->ExchangeCommitEntry(c.prep_seq % size, c_64b, e3);
+ ASSERT_TRUE(exchanged);
+ // check whether it did actually accepted that
+ found = wp_db->GetCommitEntry(c.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(e3, e);
+
+ // Rewrite an entry
+ CommitEntry e4 = {e3.prep_seq + size, e3.commit_seq + size + 1};
+ evicted = wp_db->AddCommitEntry(e4.prep_seq % size, e4, &e);
+ ASSERT_TRUE(evicted);
+ ASSERT_EQ(e3, e);
+ found = wp_db->GetCommitEntry(e4.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(e4, e);
+}
+
+TEST_P(WritePreparedTransactionTest, MaybeUpdateOldCommitMap) {
+ // If prepare <= snapshot < commit we should keep the entry around since its
+ // nonexistence could be interpreted as committed in the snapshot while it is
+ // not true. We keep such entries around by adding them to the
+ // old_commit_map_.
+ uint64_t p /*prepare*/, c /*commit*/, s /*snapshot*/, ns /*next_snapshot*/;
+ p = 10l, c = 15l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ // If we do not expect the old commit map to be updated, try also with a next
+ // snapshot that is expected to update the old commit map. This would test
+ // that MaybeUpdateOldCommitMap would not prevent us from checking the next
+ // snapshot that must be checked.
+ p = 10l, c = 15l, s = 20l, ns = 11l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 10l, c = 20l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 10l, c = 20l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 20l, c = 20l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 20l, c = 20l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 10l, c = 25l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, true);
+
+ p = 20l, c = 25l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, true);
+
+ p = 21l, c = 25l, s = 20l, ns = 22l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 21l, c = 25l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+}
+
+// Trigger the condition where some old memtables are skipped when doing
+// TransactionUtil::CheckKey(), and make sure the result is still correct.
+TEST_P(WritePreparedTransactionTest, CheckKeySkipOldMemtable) {
+ const int kAttemptHistoryMemtable = 0;
+ const int kAttemptImmMemTable = 1;
+ for (int attempt = kAttemptHistoryMemtable; attempt <= kAttemptImmMemTable;
+ attempt++) {
+ options.max_write_buffer_number_to_maintain = 3;
+ ASSERT_OK(ReOpen());
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ string value;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn != nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2 != nullptr);
+ ASSERT_OK(txn2->SetName("txn2"));
+
+ // This transaction is created to cause potential conflict.
+ Transaction* txn_x = db->BeginTransaction(write_options);
+ ASSERT_OK(txn_x->SetName("txn_x"));
+ ASSERT_OK(txn_x->Put(Slice("foo"), Slice("bar3")));
+ ASSERT_OK(txn_x->Prepare());
+
+ // Create snapshots after the prepare, but there should still
+ // be a conflict when trying to read "foo".
+
+ if (attempt == kAttemptImmMemTable) {
+ // For the second attempt, hold flush from beginning. The memtable
+ // will be switched to immutable after calling TEST_SwitchMemtable()
+ // while CheckKey() is called.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTransactionTest.CheckKeySkipOldMemtable",
+ "FlushJob::Start"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ }
+
+ // force a memtable flush. The memtable should still be kept
+ FlushOptions flush_ops;
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_OK(db->Flush(flush_ops));
+ } else {
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ DBImpl* db_impl = static_cast<DBImpl*>(db->GetRootDB());
+ ASSERT_OK(db_impl->TEST_SwitchMemtable());
+ }
+ uint64_t num_imm_mems;
+ ASSERT_TRUE(db->GetIntProperty(DB::Properties::kNumImmutableMemTable,
+ &num_imm_mems));
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(0, num_imm_mems);
+ } else {
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ ASSERT_EQ(1, num_imm_mems);
+ }
+
+ // Put something in active memtable
+ ASSERT_OK(db->Put(write_options, Slice("foo3"), Slice("bar")));
+
+ // Create txn3 after flushing, but this transaction also needs to
+ // check all memtables because of they contains uncommitted data.
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn3 != nullptr);
+ ASSERT_OK(txn3->SetName("txn3"));
+
+ // Commit the pending write
+ ASSERT_OK(txn_x->Commit());
+
+ // Commit txn, txn2 and tx3. txn and tx3 will conflict but txn2 will
+ // pass. In all cases, both memtables are queried.
+ SetPerfLevel(PerfLevel::kEnableCount);
+ get_perf_context()->Reset();
+ ASSERT_TRUE(txn3->GetForUpdate(read_options, "foo", &value).IsBusy());
+ // We should have checked two memtables, active and either immutable
+ // or history memtable, depending on the test case.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+
+ get_perf_context()->Reset();
+ ASSERT_TRUE(txn->GetForUpdate(read_options, "foo", &value).IsBusy());
+ // We should have checked two memtables, active and either immutable
+ // or history memtable, depending on the test case.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+
+ get_perf_context()->Reset();
+ ASSERT_OK(txn2->GetForUpdate(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+ // We should have checked two memtables, and since there is no
+ // conflict, another Get() will be made and fetch the data from
+ // DB. If it is in immutable memtable, two extra memtable reads
+ // will be issued. If it is not (in history), only one will
+ // be made, which is to the active memtable.
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(3, get_perf_context()->get_from_memtable_count);
+ } else {
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ ASSERT_EQ(4, get_perf_context()->get_from_memtable_count);
+ }
+
+ Transaction* txn4 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn4 != nullptr);
+ ASSERT_OK(txn4->SetName("txn4"));
+ get_perf_context()->Reset();
+ ASSERT_OK(txn4->GetForUpdate(read_options, "foo", &value));
+ if (attempt == kAttemptHistoryMemtable) {
+ // Active memtable will be checked in snapshot validation and when
+ // getting the value.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ } else {
+ // Only active memtable will be checked in snapshot validation but
+ // both of active and immutable snapshot will be queried when
+ // getting the value.
+ ASSERT_EQ(attempt, kAttemptImmMemTable);
+ ASSERT_EQ(3, get_perf_context()->get_from_memtable_count);
+ }
+
+ ASSERT_OK(txn2->Commit());
+ ASSERT_OK(txn4->Commit());
+
+ TEST_SYNC_POINT("WritePreparedTransactionTest.CheckKeySkipOldMemtable");
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ SetPerfLevel(PerfLevel::kDisable);
+
+ delete txn;
+ delete txn2;
+ delete txn3;
+ delete txn4;
+ delete txn_x;
+ }
+}
+
+// Reproduce the bug with two snapshots with the same seuqence number and test
+// that the release of the first snapshot will not affect the reads by the other
+// snapshot
+TEST_P(WritePreparedTransactionTest, DoubleSnapshot) {
+ TransactionOptions txn_options;
+ Status s;
+
+ // Insert initial value
+ ASSERT_OK(db->Put(WriteOptions(), "key", "value1"));
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ Transaction* txn =
+ wp_db->BeginTransaction(WriteOptions(), txn_options, nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key", "value2"));
+ ASSERT_OK(txn->Prepare());
+ // Three snapshots with the same seq number
+ const Snapshot* snapshot0 = wp_db->GetSnapshot();
+ const Snapshot* snapshot1 = wp_db->GetSnapshot();
+ const Snapshot* snapshot2 = wp_db->GetSnapshot();
+ ASSERT_OK(txn->Commit());
+ SequenceNumber cache_size = wp_db->COMMIT_CACHE_SIZE;
+ SequenceNumber overlap_seq = txn->GetId() + cache_size;
+ delete txn;
+
+ // 4th snapshot with a larger seq
+ const Snapshot* snapshot3 = wp_db->GetSnapshot();
+ // Cause an eviction to advance max evicted seq number
+ // This also fetches the 4 snapshots from db since their seq is lower than the
+ // new max
+ wp_db->AddCommitted(overlap_seq, overlap_seq);
+
+ ReadOptions ropt;
+ // It should see the value before commit
+ ropt.snapshot = snapshot2;
+ PinnableSlice pinnable_val;
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ wp_db->ReleaseSnapshot(snapshot1);
+
+ // It should still see the value before commit
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ // Cause an eviction to advance max evicted seq number and trigger updating
+ // the snapshot list
+ overlap_seq += cache_size;
+ wp_db->AddCommitted(overlap_seq, overlap_seq);
+
+ // It should still see the value before commit
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ wp_db->ReleaseSnapshot(snapshot0);
+ wp_db->ReleaseSnapshot(snapshot2);
+ wp_db->ReleaseSnapshot(snapshot3);
+}
+
+size_t UniqueCnt(std::vector<SequenceNumber> vec) {
+ std::set<SequenceNumber> aset;
+ for (auto i : vec) {
+ aset.insert(i);
+ }
+ return aset.size();
+}
+// Test that the entries in old_commit_map_ get garbage collected properly
+TEST_P(WritePreparedTransactionTest, OldCommitMapGC) {
+ const size_t snapshot_cache_bits = 0;
+ const size_t commit_cache_bits = 0;
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+
+ SequenceNumber seq = 0;
+ // Take the first snapshot that overlaps with two txn
+ auto prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ auto prep_seq2 = ++seq;
+ wp_db->AddPrepared(prep_seq2);
+ auto snap_seq1 = seq;
+ wp_db->TakeSnapshot(snap_seq1);
+ auto commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+ auto commit_seq2 = ++seq;
+ wp_db->AddCommitted(prep_seq2, commit_seq2);
+ wp_db->RemovePrepared(prep_seq2);
+ // Take the 2nd and 3rd snapshot that overlap with the same txn
+ prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ auto snap_seq2 = seq;
+ wp_db->TakeSnapshot(snap_seq2);
+ seq++;
+ auto snap_seq3 = seq;
+ wp_db->TakeSnapshot(snap_seq3);
+ seq++;
+ commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+ // Make sure max_evicted_seq_ will be larger than 2nd snapshot by evicting the
+ // only item in the commit_cache_ via another commit.
+ prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+
+ // Verify that the evicted commit entries for all snapshots are in the
+ // old_commit_map_
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(3, wp_db->old_commit_map_.size());
+ ASSERT_EQ(2, UniqueCnt(wp_db->old_commit_map_[snap_seq1]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq2]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 2nd snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq2);
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(2, wp_db->old_commit_map_.size());
+ ASSERT_EQ(2, UniqueCnt(wp_db->old_commit_map_[snap_seq1]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 1st snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq1);
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(1, wp_db->old_commit_map_.size());
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 3rd snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq3);
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(0, wp_db->old_commit_map_.size());
+ }
+}
+
+TEST_P(WritePreparedTransactionTest, CheckAgainstSnapshots) {
+ std::vector<SequenceNumber> snapshots = {100l, 200l, 300l, 400l, 500l,
+ 600l, 700l, 800l, 900l};
+ const size_t snapshot_cache_bits = 2;
+ const uint64_t cache_size = 1ul << snapshot_cache_bits;
+ // Safety check to express the intended size in the test. Can be adjusted if
+ // the snapshots lists changed.
+ ASSERT_EQ((1ul << snapshot_cache_bits) * 2 + 1, snapshots.size());
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ SequenceNumber version = 1000l;
+ ASSERT_EQ(0, wp_db->snapshots_total_);
+ wp_db->UpdateSnapshots(snapshots, version);
+ ASSERT_EQ(snapshots.size(), wp_db->snapshots_total_);
+ // seq numbers are chosen so that we have two of them between each two
+ // snapshots. If the diff of two consecutive seq is more than 5, there is a
+ // snapshot between them.
+ std::vector<SequenceNumber> seqs = {50l, 55l, 150l, 155l, 250l, 255l, 350l,
+ 355l, 450l, 455l, 550l, 555l, 650l, 655l,
+ 750l, 755l, 850l, 855l, 950l, 955l};
+ ASSERT_GT(seqs.size(), 1);
+ for (size_t i = 0; i + 1 < seqs.size(); i++) {
+ wp_db->old_commit_map_empty_ = true; // reset
+ CommitEntry commit_entry = {seqs[i], seqs[i + 1]};
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ // Expect update if there is snapshot in between the prepare and commit
+ bool expect_update = commit_entry.commit_seq - commit_entry.prep_seq > 5 &&
+ commit_entry.commit_seq >= snapshots.front() &&
+ commit_entry.prep_seq <= snapshots.back();
+ ASSERT_EQ(expect_update, !wp_db->old_commit_map_empty_);
+ }
+
+ // Test that search will include multiple snapshot from snapshot cache
+ {
+ // exclude first and last item in the cache
+ CommitEntry commit_entry = {snapshots.front() + 1,
+ snapshots[cache_size - 1] - 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), cache_size - 2);
+ }
+
+ // Test that search will include multiple snapshot from old snapshots
+ {
+ // include two in the middle
+ CommitEntry commit_entry = {snapshots[cache_size] + 1,
+ snapshots[cache_size + 2] + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), 2);
+ }
+
+ // Test that search will include both snapshot cache and old snapshots
+ // Case 1: includes all in snapshot cache
+ {
+ CommitEntry commit_entry = {snapshots.front() - 1, snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size());
+ }
+
+ // Case 2: includes all snapshot caches except the smallest
+ {
+ CommitEntry commit_entry = {snapshots.front() + 1, snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size() - 1);
+ }
+
+ // Case 3: includes only the largest of snapshot cache
+ {
+ CommitEntry commit_entry = {snapshots[cache_size - 1] - 1,
+ snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size() - cache_size + 1);
+ }
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// Test that CheckAgainstSnapshots will not miss a live snapshot if it is run in
+// parallel with UpdateSnapshots.
+TEST_P(SnapshotConcurrentAccessTest, SnapshotConcurrentAccess) {
+ // We have a sync point in the method under test after checking each snapshot.
+ // If you increase the max number of snapshots in this test, more sync points
+ // in the methods must also be added.
+ const std::vector<SequenceNumber> snapshots = {10l, 20l, 30l, 40l, 50l,
+ 60l, 70l, 80l, 90l, 100l};
+ const size_t snapshot_cache_bits = 2;
+ // Safety check to express the intended size in the test. Can be adjusted if
+ // the snapshots lists changed.
+ ASSERT_EQ((1ul << snapshot_cache_bits) * 2 + 2, snapshots.size());
+ SequenceNumber version = 1000l;
+ // Choose the cache size so that the new snapshot list could replace all the
+ // existing items in the cache and also have some overflow.
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ const size_t extra = 2;
+ size_t loop_id = 0;
+ // Add up to extra items that do not fit into the cache
+ for (size_t old_size = 1; old_size <= wp_db->SNAPSHOT_CACHE_SIZE + extra;
+ old_size++) {
+ const std::vector<SequenceNumber> old_snapshots(
+ snapshots.begin(), snapshots.begin() + old_size);
+
+ // Each member of old snapshot might or might not appear in the new list. We
+ // create a common_snapshots for each combination.
+ size_t new_comb_cnt = size_t(1) << old_size;
+ for (size_t new_comb = 0; new_comb < new_comb_cnt; new_comb++, loop_id++) {
+ if (loop_id % split_cnt_ != split_id_) continue;
+ printf("."); // To signal progress
+ fflush(stdout);
+ std::vector<SequenceNumber> common_snapshots;
+ for (size_t i = 0; i < old_snapshots.size(); i++) {
+ if (IsInCombination(i, new_comb)) {
+ common_snapshots.push_back(old_snapshots[i]);
+ }
+ }
+ // And add some new snapshots to the common list
+ for (size_t added_snapshots = 0;
+ added_snapshots <= snapshots.size() - old_snapshots.size();
+ added_snapshots++) {
+ std::vector<SequenceNumber> new_snapshots = common_snapshots;
+ for (size_t i = 0; i < added_snapshots; i++) {
+ new_snapshots.push_back(snapshots[old_snapshots.size() + i]);
+ }
+ for (auto it = common_snapshots.begin(); it != common_snapshots.end();
+ ++it) {
+ auto snapshot = *it;
+ // Create a commit entry that is around the snapshot and thus should
+ // be not be discarded
+ CommitEntry entry = {static_cast<uint64_t>(snapshot - 1),
+ snapshot + 1};
+ // The critical part is when iterating the snapshot cache. Afterwards,
+ // we are operating under the lock
+ size_t a_range =
+ std::min(old_snapshots.size(), wp_db->SNAPSHOT_CACHE_SIZE) + 1;
+ size_t b_range =
+ std::min(new_snapshots.size(), wp_db->SNAPSHOT_CACHE_SIZE) + 1;
+ // Break each thread at two points
+ for (size_t a1 = 1; a1 <= a_range; a1++) {
+ for (size_t a2 = a1 + 1; a2 <= a_range; a2++) {
+ for (size_t b1 = 1; b1 <= b_range; b1++) {
+ for (size_t b2 = b1 + 1; b2 <= b_range; b2++) {
+ SnapshotConcurrentAccessTestInternal(
+ wp_db.get(), old_snapshots, new_snapshots, entry, version,
+ a1, a2, b1, b2);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ printf("\n");
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+// This test clarifies the contract of AdvanceMaxEvictedSeq method
+TEST_P(WritePreparedTransactionTest, AdvanceMaxEvictedSeqBasic) {
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+
+ // 1. Set the initial values for max, prepared, and snapshots
+ SequenceNumber zero_max = 0l;
+ // Set the initial list of prepared txns
+ const std::vector<SequenceNumber> initial_prepared = {10, 30, 50, 100,
+ 150, 200, 250};
+ for (auto p : initial_prepared) {
+ wp_db->AddPrepared(p);
+ }
+ // This updates the max value and also set old prepared
+ SequenceNumber init_max = 100;
+ wp_db->AdvanceMaxEvictedSeq(zero_max, init_max);
+ const std::vector<SequenceNumber> initial_snapshots = {20, 40};
+ wp_db->SetDBSnapshots(initial_snapshots);
+ // This will update the internal cache of snapshots from the DB
+ wp_db->UpdateSnapshots(initial_snapshots, init_max);
+
+ // 2. Invoke AdvanceMaxEvictedSeq
+ const std::vector<SequenceNumber> latest_snapshots = {20, 110, 220, 300};
+ wp_db->SetDBSnapshots(latest_snapshots);
+ SequenceNumber new_max = 200;
+ wp_db->AdvanceMaxEvictedSeq(init_max, new_max);
+
+ // 3. Verify that the state matches with AdvanceMaxEvictedSeq contract
+ // a. max should be updated to new_max
+ ASSERT_EQ(wp_db->max_evicted_seq_, new_max);
+ // b. delayed prepared should contain every txn <= max and prepared should
+ // only contain txns > max
+ auto it = initial_prepared.begin();
+ for (; it != initial_prepared.end() && *it <= new_max; ++it) {
+ ASSERT_EQ(1, wp_db->delayed_prepared_.erase(*it));
+ }
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ for (; it != initial_prepared.end() && !wp_db->prepared_txns_.empty();
+ ++it, wp_db->prepared_txns_.pop()) {
+ ASSERT_EQ(*it, wp_db->prepared_txns_.top());
+ }
+ ASSERT_TRUE(it == initial_prepared.end());
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ // c. snapshots should contain everything below new_max
+ auto sit = latest_snapshots.begin();
+ for (size_t i = 0; sit != latest_snapshots.end() && *sit <= new_max &&
+ i < wp_db->snapshots_total_;
+ sit++, i++) {
+ ASSERT_TRUE(i < wp_db->snapshots_total_);
+ // This test is in small scale and the list of snapshots are assumed to be
+ // within the cache size limit. This is just a safety check to double check
+ // that assumption.
+ ASSERT_TRUE(i < wp_db->SNAPSHOT_CACHE_SIZE);
+ ASSERT_EQ(*sit, wp_db->snapshot_cache_[i]);
+ }
+}
+
+// A new snapshot should always be always larger than max_evicted_seq_
+// Otherwise the snapshot does not go through AdvanceMaxEvictedSeq
+TEST_P(WritePreparedTransactionTest, NewSnapshotLargerThanMax) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ Transaction* txn0 = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value")));
+ ASSERT_OK(txn0->Commit());
+ const SequenceNumber seq = txn0->GetId(); // is also prepare seq
+ delete txn0;
+ std::vector<Transaction*> txns;
+ // Inc seq without committing anything
+ for (int i = 0; i < 10; i++) {
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn->SetName("xid" + std::to_string(i)));
+ ASSERT_OK(txn->Put(Slice("key" + std::to_string(i)), Slice("value")));
+ ASSERT_OK(txn->Prepare());
+ txns.push_back(txn);
+ }
+
+ // The new commit is seq + 10
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ auto snap = wp_db->GetSnapshot();
+ const SequenceNumber last_seq = snap->GetSequenceNumber();
+ wp_db->ReleaseSnapshot(snap);
+ ASSERT_LT(seq, last_seq);
+ // Otherwise our test is not effective
+ ASSERT_LT(last_seq - seq, wp_db->INC_STEP_FOR_MAX_EVICTED);
+
+ // Evict seq out of commit cache
+ const SequenceNumber overwrite_seq = seq + wp_db->COMMIT_CACHE_SIZE;
+ // Check that the next write could make max go beyond last
+ auto last_max = wp_db->max_evicted_seq_.load();
+ wp_db->AddCommitted(overwrite_seq, overwrite_seq);
+ // Check that eviction has advanced the max
+ ASSERT_LT(last_max, wp_db->max_evicted_seq_.load());
+ // Check that the new max has not advanced the last seq
+ ASSERT_LT(wp_db->max_evicted_seq_.load(), last_seq);
+ for (auto txn : txns) {
+ txn->Rollback();
+ delete txn;
+ }
+}
+
+// A new snapshot should always be always larger than max_evicted_seq_
+// In very rare cases max could be below last published seq. Test that
+// taking snapshot will wait for max to catch up.
+TEST_P(WritePreparedTransactionTest, MaxCatchupWithNewSnapshot) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ const int writes = 50;
+ const int batch_cnt = 4;
+ ROCKSDB_NAMESPACE::port::Thread t1([&]() {
+ for (int i = 0; i < writes; i++) {
+ WriteBatch batch;
+ // For duplicate keys cause 4 commit entries, each evicting an entry that
+ // is not published yet, thus causing max evicted seq go higher than last
+ // published.
+ for (int b = 0; b < batch_cnt; b++) {
+ ASSERT_OK(batch.Put("foo", "foo"));
+ }
+ ASSERT_OK(db->Write(woptions, &batch));
+ }
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread t2([&]() {
+ while (wp_db->max_evicted_seq_ == 0) { // wait for insert thread
+ std::this_thread::yield();
+ }
+ for (int i = 0; i < 10; i++) {
+ SequenceNumber max_lower_bound = wp_db->max_evicted_seq_;
+ auto snap = db->GetSnapshot();
+ if (snap->GetSequenceNumber() != 0) {
+ // Value of max_evicted_seq_ when snapshot was taken in unknown. We thus
+ // compare with the lower bound instead as an approximation.
+ ASSERT_LT(max_lower_bound, snap->GetSequenceNumber());
+ } // seq 0 is ok to be less than max since nothing is visible to it
+ db->ReleaseSnapshot(snap);
+ }
+ });
+
+ t1.join();
+ t2.join();
+
+ // Make sure that the test has worked and seq number has advanced as we
+ // thought
+ auto snap = db->GetSnapshot();
+ ASSERT_GT(snap->GetSequenceNumber(), batch_cnt * writes - 1);
+ db->ReleaseSnapshot(snap);
+}
+
+// Test that reads without snapshots would not hit an undefined state
+TEST_P(WritePreparedTransactionTest, MaxCatchupWithUnbackedSnapshot) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ const int writes = 50;
+ ROCKSDB_NAMESPACE::port::Thread t1([&]() {
+ for (int i = 0; i < writes; i++) {
+ WriteBatch batch;
+ ASSERT_OK(batch.Put("key", "foo"));
+ ASSERT_OK(db->Write(woptions, &batch));
+ }
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread t2([&]() {
+ while (wp_db->max_evicted_seq_ == 0) { // wait for insert thread
+ std::this_thread::yield();
+ }
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ TransactionOptions txn_options;
+ for (int i = 0; i < 10; i++) {
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ pinnable_val.Reset();
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ s = txn->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ pinnable_val.Reset();
+ std::vector<std::string> values;
+ auto s_vec =
+ txn->MultiGet(ropt, {db->DefaultColumnFamily()}, {"key"}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ Slice key("key");
+ txn->MultiGet(ropt, db->DefaultColumnFamily(), 1, &key, &pinnable_val, &s,
+ true);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ delete txn;
+ }
+ });
+
+ t1.join();
+ t2.join();
+
+ // Make sure that the test has worked and seq number has advanced as we
+ // thought
+ auto snap = db->GetSnapshot();
+ ASSERT_GT(snap->GetSequenceNumber(), writes - 1);
+ db->ReleaseSnapshot(snap);
+}
+
+// Check that old_commit_map_ cleanup works correctly if the snapshot equals
+// max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, CleanupSnapshotEqualToMax) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Insert something to increase seq
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ auto snap = db->GetSnapshot();
+ auto snap_seq = snap->GetSequenceNumber();
+ // Another insert should trigger eviction + load snapshot from db
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // This is the scenario that we check agaisnt
+ ASSERT_EQ(snap_seq, wp_db->max_evicted_seq_);
+ // old_commit_map_ now has some data that needs gc
+ ASSERT_EQ(1, wp_db->snapshots_total_);
+ ASSERT_EQ(1, wp_db->old_commit_map_.size());
+
+ db->ReleaseSnapshot(snap);
+
+ // Another insert should trigger eviction + load snapshot from db
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+
+ // the snapshot and related metadata must be properly garbage collected
+ ASSERT_EQ(0, wp_db->snapshots_total_);
+ ASSERT_TRUE(wp_db->snapshots_all_.empty());
+ ASSERT_EQ(0, wp_db->old_commit_map_.size());
+}
+
+TEST_P(WritePreparedTransactionTest, AdvanceSeqByOne) {
+ auto snap = db->GetSnapshot();
+ auto seq1 = snap->GetSequenceNumber();
+ db->ReleaseSnapshot(snap);
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ wp_db->AdvanceSeqByOne();
+
+ snap = db->GetSnapshot();
+ auto seq2 = snap->GetSequenceNumber();
+ db->ReleaseSnapshot(snap);
+
+ ASSERT_LT(seq1, seq2);
+}
+
+// Test that the txn Initilize calls the overridden functions
+TEST_P(WritePreparedTransactionTest, TxnInitialize) {
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ ASSERT_OK(db->Put(write_options, "key", "value"));
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value1")));
+ ASSERT_OK(txn0->Prepare());
+
+ // SetSnapshot is overridden to update min_uncommitted_
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ auto snap = txn1->GetSnapshot();
+ auto snap_impl = reinterpret_cast<const SnapshotImpl*>(snap);
+ // If ::Initialize calls the overriden SetSnapshot, min_uncommitted_ must be
+ // udpated
+ ASSERT_GT(snap_impl->min_uncommitted_, kMinUnCommittedSeq);
+
+ ASSERT_OK(txn0->Rollback());
+ ASSERT_OK(txn1->Rollback());
+ delete txn0;
+ delete txn1;
+}
+
+// This tests that transactions with duplicate keys perform correctly after max
+// is advancing their prepared sequence numbers. This will not be the case if
+// for example the txn does not add the prepared seq for the second sub-batch to
+// the PreparedHeap structure.
+TEST_P(WritePreparedTransactionTest, AdvanceMaxEvictedSeqWithDuplicates) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // disable commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value1")));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value2")));
+ ASSERT_OK(txn0->Prepare());
+
+ ASSERT_OK(db->Put(write_options, "key2", "value"));
+ // Will cause max advance due to disabled commit cache
+ ASSERT_OK(db->Put(write_options, "key3", "value"));
+
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ delete txn0;
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_OK(wp_db->db_impl_->FlushWAL(true));
+ wp_db->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_OK(txn0->Rollback());
+ delete txn0;
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// Stress SmallestUnCommittedSeq, which reads from both prepared_txns_ and
+// delayed_prepared_, when is run concurrently with advancing max_evicted_seq,
+// which moves prepared txns from prepared_txns_ to delayed_prepared_.
+TEST_P(WritePreparedTransactionTest, SmallestUnCommittedSeq) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // disable commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ std::vector<Transaction*> txns, committed_txns;
+
+ const int cnt = 100;
+ for (int i = 0; i < cnt; i++) {
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + std::to_string(i)));
+ auto key = "key1" + std::to_string(i);
+ auto value = "value1" + std::to_string(i);
+ ASSERT_OK(txn->Put(Slice(key), Slice(value)));
+ ASSERT_OK(txn->Prepare());
+ txns.push_back(txn);
+ }
+
+ port::Mutex mutex;
+ Random rnd(1103);
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ for (int i = 0; i < cnt; i++) {
+ uint32_t index = rnd.Uniform(cnt - i);
+ Transaction* txn;
+ {
+ MutexLock l(&mutex);
+ txn = txns[index];
+ txns.erase(txns.begin() + index);
+ }
+ // Since commit cache is practically disabled, commit results in immediate
+ // advance in max_evicted_seq_ and subsequently moving some prepared txns
+ // to delayed_prepared_.
+ ASSERT_OK(txn->Commit());
+ committed_txns.push_back(txn);
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ while (1) {
+ MutexLock l(&mutex);
+ if (txns.empty()) {
+ break;
+ }
+ auto min_uncommitted = wp_db->SmallestUnCommittedSeq();
+ ASSERT_LE(min_uncommitted, (*txns.begin())->GetId());
+ }
+ });
+
+ commit_thread.join();
+ read_thread.join();
+ for (auto txn : committed_txns) {
+ delete txn;
+ }
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_P(SeqAdvanceConcurrentTest, SeqAdvanceConcurrent) {
+ // Given the sequential run of txns, with this timeout we should never see a
+ // deadlock nor a timeout unless we have a key conflict, which should be
+ // almost infeasible.
+ txn_db_options.transaction_lock_timeout = 1000;
+ txn_db_options.default_lock_timeout = 1000;
+ ASSERT_OK(ReOpen());
+ FlushOptions fopt;
+
+ // Number of different txn types we use in this test
+ const size_t type_cnt = 5;
+ // The size of the first write group
+ // TODO(myabandeh): This should be increase for pre-release tests
+ const size_t first_group_size = 2;
+ // Total number of txns we run in each test
+ // TODO(myabandeh): This should be increase for pre-release tests
+ const size_t txn_cnt = first_group_size + 1;
+
+ size_t base[txn_cnt + 1] = {
+ 1,
+ };
+ for (size_t bi = 1; bi <= txn_cnt; bi++) {
+ base[bi] = base[bi - 1] * type_cnt;
+ }
+ const size_t max_n = static_cast<size_t>(std::pow(type_cnt, txn_cnt));
+ printf("Number of cases being tested is %" ROCKSDB_PRIszt "\n", max_n);
+ for (size_t n = 0; n < max_n; n++) {
+ if (n > 0) {
+ ASSERT_OK(ReOpen());
+ }
+
+ if (n % split_cnt_ != split_id_) continue;
+ if (n % 1000 == 0) {
+ printf("Tested %" ROCKSDB_PRIszt " cases so far\n", n);
+ }
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ auto seq = db_impl->TEST_GetLastVisibleSequence();
+ with_empty_commits = 0;
+ exp_seq = seq;
+ // This is increased before writing the batch for commit
+ commit_writes = 0;
+ // This is increased before txn starts linking if it expects to do a commit
+ // eventually
+ expected_commits = 0;
+ std::vector<port::Thread> threads;
+
+ linked.store(0, std::memory_order_release);
+ std::atomic<bool> batch_formed(false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::EnterAsBatchGroupLeader:End",
+ [&](void* /*arg*/) { batch_formed = true; });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::JoinBatchGroup:Wait", [&](void* /*arg*/) {
+ size_t orig_linked = linked.fetch_add(1, std::memory_order_acq_rel);
+ if (orig_linked == 0) {
+ // Wait until the others are linked too.
+ while (linked.load(std::memory_order_acquire) < first_group_size) {
+ }
+ } else if (orig_linked == first_group_size) {
+ // Make the 2nd batch of the rest of writes plus any followup
+ // commits from the first batch
+ while (linked.load(std::memory_order_acquire) <
+ txn_cnt + commit_writes) {
+ }
+ }
+ // Then we will have one or more batches consisting of follow-up
+ // commits from the 2nd batch. There is a bit of non-determinism here
+ // but it should be tolerable.
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ for (size_t bi = 0; bi < txn_cnt; bi++) {
+ // get the bi-th digit in number system based on type_cnt
+ size_t d = (n % base[bi + 1]) / base[bi];
+ switch (d) {
+ case 0:
+ threads.emplace_back(&TransactionTestBase::TestTxn0, this, bi);
+ break;
+ case 1:
+ threads.emplace_back(&TransactionTestBase::TestTxn1, this, bi);
+ break;
+ case 2:
+ threads.emplace_back(&TransactionTestBase::TestTxn2, this, bi);
+ break;
+ case 3:
+ threads.emplace_back(&TransactionTestBase::TestTxn3, this, bi);
+ break;
+ case 4:
+ threads.emplace_back(&TransactionTestBase::TestTxn3, this, bi);
+ break;
+ default:
+ FAIL();
+ }
+ // wait to be linked
+ while (linked.load(std::memory_order_acquire) <= bi) {
+ }
+ // after a queue of size first_group_size
+ if (bi + 1 == first_group_size) {
+ while (!batch_formed) {
+ }
+ // to make it more deterministic, wait until the commits are linked
+ while (linked.load(std::memory_order_acquire) <=
+ bi + expected_commits) {
+ }
+ }
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+ if (options.two_write_queues) {
+ // In this case none of the above scheduling tricks to deterministically
+ // form merged batches works because the writes go to separate queues.
+ // This would result in different write groups in each run of the test. We
+ // still keep the test since although non-deterministic and hard to debug,
+ // it is still useful to have.
+ // TODO(myabandeh): Add a deterministic unit test for two_write_queues
+ }
+
+ // Check if memtable inserts advanced seq number as expected
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Check if recovery preserves the last sequence number
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+
+ // Check if flush preserves the last sequence number
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+
+ // Check if recovery after flush preserves the last sequence number
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+ }
+}
+
+// Run a couple of different txns among them some uncommitted. Restart the db at
+// a couple points to check whether the list of uncommitted txns are recovered
+// properly.
+TEST_P(WritePreparedTransactionTest, BasicRecovery) {
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ TestTxn0(0);
+
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ size_t index = 1000;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ auto istr0 = std::to_string(index);
+ auto s = txn0->SetName("xid" + istr0);
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0" + istr0), Slice("bar0" + istr0));
+ ASSERT_OK(s);
+ s = txn0->Prepare();
+ ASSERT_OK(s);
+ auto prep_seq_0 = txn0->GetId();
+
+ TestTxn1(0);
+
+ index++;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ auto istr1 = std::to_string(index);
+ s = txn1->SetName("xid" + istr1);
+ ASSERT_OK(s);
+ s = txn1->Put(Slice("foo1" + istr1), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+ auto prep_seq_1 = txn1->GetId();
+
+ TestTxn2(0);
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ // Check the value is not committed before restart
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+
+ delete txn0;
+ delete txn1;
+ ASSERT_OK(wp_db->db_impl_->FlushWAL(true));
+ wp_db->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // After recovery, all the uncommitted txns (0 and 1) should be inserted into
+ // delayed_prepared_
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+ ASSERT_LE(prep_seq_0, wp_db->max_evicted_seq_);
+ ASSERT_LE(prep_seq_1, wp_db->max_evicted_seq_);
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_EQ(2, wp_db->delayed_prepared_.size());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(prep_seq_0) !=
+ wp_db->delayed_prepared_.end());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(prep_seq_1) !=
+ wp_db->delayed_prepared_.end());
+ }
+
+ // Check the value is still not committed after restart
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+
+ TestTxn3(0);
+
+ // Test that a recovered txns will be properly marked committed for the next
+ // recovery
+ txn1 = db->GetTransactionByName("xid" + istr1);
+ ASSERT_NE(txn1, nullptr);
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+
+ index++;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ auto istr2 = std::to_string(index);
+ s = txn2->SetName("xid" + istr2);
+ ASSERT_OK(s);
+ s = txn2->Put(Slice("foo2" + istr2), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+ auto prep_seq_2 = txn2->GetId();
+
+ delete txn2;
+ ASSERT_OK(wp_db->db_impl_->FlushWAL(true));
+ wp_db->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+
+ // 0 and 2 are prepared and 1 is committed
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_EQ(2, wp_db->delayed_prepared_.size());
+ const auto& end = wp_db->delayed_prepared_.end();
+ ASSERT_NE(wp_db->delayed_prepared_.find(prep_seq_0), end);
+ ASSERT_EQ(wp_db->delayed_prepared_.find(prep_seq_1), end);
+ ASSERT_NE(wp_db->delayed_prepared_.find(prep_seq_2), end);
+ }
+ ASSERT_LE(prep_seq_0, wp_db->max_evicted_seq_);
+ ASSERT_LE(prep_seq_2, wp_db->max_evicted_seq_);
+
+ // Commit all the remaining txns
+ txn0 = db->GetTransactionByName("xid" + istr0);
+ ASSERT_NE(txn0, nullptr);
+ ASSERT_OK(txn0->Commit());
+ txn2 = db->GetTransactionByName("xid" + istr2);
+ ASSERT_NE(txn2, nullptr);
+ ASSERT_OK(txn2->Commit());
+
+ // Check the value is committed after commit
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.ok());
+ ASSERT_TRUE(pinnable_val == ("bar0" + istr0));
+ pinnable_val.Reset();
+
+ delete txn0;
+ delete txn2;
+ ASSERT_OK(wp_db->db_impl_->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_empty_);
+
+ // Check the value is still committed after recovery
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.ok());
+ ASSERT_TRUE(pinnable_val == ("bar0" + istr0));
+ pinnable_val.Reset();
+}
+
+// After recovery the commit map is empty while the max is set. The code would
+// go through a different path which requires a separate test. Test that the
+// committed data before the restart is visible to all snapshots.
+TEST_P(WritePreparedTransactionTest, IsInSnapshotEmptyMap) {
+ for (bool end_with_prepare : {false, true}) {
+ ASSERT_OK(ReOpen());
+ WriteOptions woptions;
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ SequenceNumber prepare_seq = kMaxSequenceNumber;
+ if (end_with_prepare) {
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn->SetName("xid0"));
+ ASSERT_OK(txn->Prepare());
+ prepare_seq = txn->GetId();
+ delete txn;
+ }
+ dynamic_cast<WritePreparedTxnDB*>(db)->TEST_Crash();
+ auto db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_NE(wp_db, nullptr);
+ ASSERT_GT(wp_db->max_evicted_seq_, 0); // max after recovery
+ // Take a snapshot right after recovery
+ const Snapshot* snap = db->GetSnapshot();
+ auto snap_seq = snap->GetSequenceNumber();
+ ASSERT_GT(snap_seq, 0);
+
+ for (SequenceNumber seq = 0;
+ seq <= wp_db->max_evicted_seq_ && seq != prepare_seq; seq++) {
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq));
+ }
+ if (end_with_prepare) {
+ ASSERT_FALSE(wp_db->IsInSnapshot(prepare_seq, snap_seq));
+ }
+ // trivial check
+ ASSERT_FALSE(wp_db->IsInSnapshot(snap_seq + 1, snap_seq));
+
+ db->ReleaseSnapshot(snap);
+
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // Take a snapshot after some writes
+ snap = db->GetSnapshot();
+ snap_seq = snap->GetSequenceNumber();
+ for (SequenceNumber seq = 0;
+ seq <= wp_db->max_evicted_seq_ && seq != prepare_seq; seq++) {
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq));
+ }
+ if (end_with_prepare) {
+ ASSERT_FALSE(wp_db->IsInSnapshot(prepare_seq, snap_seq));
+ }
+ // trivial check
+ ASSERT_FALSE(wp_db->IsInSnapshot(snap_seq + 1, snap_seq));
+
+ db->ReleaseSnapshot(snap);
+ }
+}
+
+// Shows the contract of IsInSnapshot when called on invalid/released snapshots
+TEST_P(WritePreparedTransactionTest, IsInSnapshotReleased) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ WriteOptions woptions;
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // snap seq = 1
+ const Snapshot* snap1 = db->GetSnapshot();
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // snap seq = 3
+ const Snapshot* snap2 = db->GetSnapshot();
+ const SequenceNumber seq = 1;
+ // Evict seq out of commit cache
+ size_t overwrite_seq = wp_db->COMMIT_CACHE_SIZE + seq;
+ wp_db->AddCommitted(overwrite_seq, overwrite_seq);
+ SequenceNumber snap_seq;
+ uint64_t min_uncommitted = kMinUnCommittedSeq;
+ bool released;
+
+ released = false;
+ snap_seq = snap1->GetSequenceNumber();
+ ASSERT_LE(seq, snap_seq);
+ // Valid snapshot lower than max
+ ASSERT_LE(snap_seq, wp_db->max_evicted_seq_);
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_FALSE(released);
+
+ released = false;
+ snap_seq = snap1->GetSequenceNumber();
+ // Invaid snapshot lower than max
+ ASSERT_LE(snap_seq + 1, wp_db->max_evicted_seq_);
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ db->ReleaseSnapshot(snap1);
+
+ released = false;
+ // Released snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ // The release does not take affect until the next max advance
+ ASSERT_FALSE(released);
+
+ released = false;
+ // Invaid snapshot lower than max
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ // This make the snapshot release to reflect in txn db structures
+ wp_db->AdvanceMaxEvictedSeq(wp_db->max_evicted_seq_,
+ wp_db->max_evicted_seq_ + 1);
+
+ released = false;
+ // Released snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ released = false;
+ // Invaid snapshot lower than max
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ snap_seq = snap2->GetSequenceNumber();
+
+ released = false;
+ // Unreleased snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_FALSE(released);
+
+ db->ReleaseSnapshot(snap2);
+}
+
+// Test WritePreparedTxnDB's IsInSnapshot against different ordering of
+// snapshot, max_committed_seq_, prepared, and commit entries.
+TEST_P(WritePreparedTransactionTest, IsInSnapshot) {
+ WriteOptions wo;
+ // Use small commit cache to trigger lots of eviction and fast advance of
+ // max_evicted_seq_
+ const size_t commit_cache_bits = 3;
+ // Same for snapshot cache size
+ const size_t snapshot_cache_bits = 2;
+
+ // Take some preliminary snapshots first. This is to stress the data structure
+ // that holds the old snapshots as it will be designed to be efficient when
+ // only a few snapshots are below the max_evicted_seq_.
+ for (int max_snapshots = 1; max_snapshots < 20; max_snapshots++) {
+ // Leave some gap between the preliminary snapshots and the final snapshot
+ // that we check. This should test for also different overlapping scenarios
+ // between the last snapshot and the commits.
+ for (int max_gap = 1; max_gap < 10; max_gap++) {
+ // Since we do not actually write to db, we mock the seq as it would be
+ // increased by the db. The only exception is that we need db seq to
+ // advance for our snapshots. for which we apply a dummy put each time we
+ // increase our mock of seq.
+ uint64_t seq = 0;
+ // At each step we prepare a txn and then we commit it in the next txn.
+ // This emulates the consecutive transactions that write to the same key
+ uint64_t cur_txn = 0;
+ // Number of snapshots taken so far
+ int num_snapshots = 0;
+ // Number of gaps applied so far
+ int gap_cnt = 0;
+ // The final snapshot that we will inspect
+ uint64_t snapshot = 0;
+ bool found_committed = false;
+ // To stress the data structure that maintain prepared txns, at each cycle
+ // we add a new prepare txn. These do not mean to be committed for
+ // snapshot inspection.
+ std::set<uint64_t> prepared;
+ // We keep the list of txns committed before we take the last snapshot.
+ // These should be the only seq numbers that will be found in the snapshot
+ std::set<uint64_t> committed_before;
+ // The set of commit seq numbers to be excluded from IsInSnapshot queries
+ std::set<uint64_t> commit_seqs;
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ // We continue until max advances a bit beyond the snapshot.
+ while (!snapshot || wp_db->max_evicted_seq_ < snapshot + 100) {
+ // do prepare for a transaction
+ seq++;
+ wp_db->AddPrepared(seq);
+ prepared.insert(seq);
+
+ // If cur_txn is not started, do prepare for it.
+ if (!cur_txn) {
+ seq++;
+ cur_txn = seq;
+ wp_db->AddPrepared(cur_txn);
+ } else { // else commit it
+ seq++;
+ wp_db->AddCommitted(cur_txn, seq);
+ wp_db->RemovePrepared(cur_txn);
+ commit_seqs.insert(seq);
+ if (!snapshot) {
+ committed_before.insert(cur_txn);
+ }
+ cur_txn = 0;
+ }
+
+ if (num_snapshots < max_snapshots - 1) {
+ // Take preliminary snapshots
+ wp_db->TakeSnapshot(seq);
+ num_snapshots++;
+ } else if (gap_cnt < max_gap) {
+ // Wait for some gap before taking the final snapshot
+ gap_cnt++;
+ } else if (!snapshot) {
+ // Take the final snapshot if it is not already taken
+ snapshot = seq;
+ wp_db->TakeSnapshot(snapshot);
+ num_snapshots++;
+ }
+
+ // If the snapshot is taken, verify seq numbers visible to it. We redo
+ // it at each cycle to test that the system is still sound when
+ // max_evicted_seq_ advances.
+ if (snapshot) {
+ for (uint64_t s = 1;
+ s <= seq && commit_seqs.find(s) == commit_seqs.end(); s++) {
+ bool was_committed =
+ (committed_before.find(s) != committed_before.end());
+ bool is_in_snapshot = wp_db->IsInSnapshot(s, snapshot);
+ if (was_committed != is_in_snapshot) {
+ printf("max_snapshots %d max_gap %d seq %" PRIu64 " max %" PRIu64
+ " snapshot %" PRIu64
+ " gap_cnt %d num_snapshots %d s %" PRIu64 "\n",
+ max_snapshots, max_gap, seq,
+ wp_db->max_evicted_seq_.load(), snapshot, gap_cnt,
+ num_snapshots, s);
+ }
+ ASSERT_EQ(was_committed, is_in_snapshot);
+ found_committed = found_committed || is_in_snapshot;
+ }
+ }
+ }
+ // Safety check to make sure the test actually ran
+ ASSERT_TRUE(found_committed);
+ // As an extra check, check if prepared set will be properly empty after
+ // they are committed.
+ if (cur_txn) {
+ wp_db->AddCommitted(cur_txn, seq);
+ wp_db->RemovePrepared(cur_txn);
+ }
+ for (auto p : prepared) {
+ wp_db->AddCommitted(p, seq);
+ wp_db->RemovePrepared(p);
+ }
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ }
+ }
+}
+
+void ASSERT_SAME(ReadOptions roptions, TransactionDB* db, Status exp_s,
+ PinnableSlice& exp_v, Slice key) {
+ Status s;
+ PinnableSlice v;
+ s = db->Get(roptions, db->DefaultColumnFamily(), key, &v);
+ ASSERT_EQ(exp_s, s);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(exp_v == v);
+ }
+
+ // Try with MultiGet API too
+ std::vector<std::string> values;
+ auto s_vec =
+ db->MultiGet(roptions, {db->DefaultColumnFamily()}, {key}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_EQ(exp_s, s);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(exp_v == values[0]);
+ }
+}
+
+void ASSERT_SAME(TransactionDB* db, Status exp_s, PinnableSlice& exp_v,
+ Slice key) {
+ ASSERT_SAME(ReadOptions(), db, exp_s, exp_v, key);
+}
+
+TEST_P(WritePreparedTransactionTest, Rollback) {
+ ReadOptions roptions;
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ const size_t num_keys = 4;
+ const size_t num_values = 5;
+ for (size_t ikey = 1; ikey <= num_keys; ikey++) {
+ for (size_t ivalue = 0; ivalue < num_values; ivalue++) {
+ for (bool crash : {false, true}) {
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ std::string key_str = "key" + std::to_string(ikey);
+ switch (ivalue) {
+ case 0:
+ break;
+ case 1:
+ ASSERT_OK(db->Put(woptions, key_str, "initvalue1"));
+ break;
+ case 2:
+ ASSERT_OK(db->Merge(woptions, key_str, "initvalue2"));
+ break;
+ case 3:
+ ASSERT_OK(db->Delete(woptions, key_str));
+ break;
+ case 4:
+ ASSERT_OK(db->SingleDelete(woptions, key_str));
+ break;
+ default:
+ FAIL();
+ }
+
+ PinnableSlice v1;
+ auto s1 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key1"), &v1);
+ PinnableSlice v2;
+ auto s2 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key2"), &v2);
+ PinnableSlice v3;
+ auto s3 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key3"), &v3);
+ PinnableSlice v4;
+ auto s4 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key4"), &v4);
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ auto s = txn->SetName("xid0");
+ ASSERT_OK(s);
+ s = txn->Put(Slice("key1"), Slice("value1"));
+ ASSERT_OK(s);
+ s = txn->Merge(Slice("key2"), Slice("value2"));
+ ASSERT_OK(s);
+ s = txn->Delete(Slice("key3"));
+ ASSERT_OK(s);
+ s = txn->SingleDelete(Slice("key4"));
+ ASSERT_OK(s);
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_FALSE(wp_db->prepared_txns_.empty());
+ ASSERT_EQ(txn->GetId(), wp_db->prepared_txns_.top());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+
+ if (crash) {
+ delete txn;
+ auto db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(db_impl->FlushWAL(true));
+ dynamic_cast<WritePreparedTxnDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ ASSERT_NE(db, nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ txn = db->GetTransactionByName("xid0");
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(txn->GetId()) !=
+ wp_db->delayed_prepared_.end());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+ {
+ ASSERT_TRUE(wp_db->delayed_prepared_empty_);
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+ delete txn;
+ }
+ }
+ }
+}
+
+TEST_P(WritePreparedTransactionTest, DisableGCDuringRecovery) {
+ // Use large buffer to avoid memtable flush after 1024 insertions
+ options.write_buffer_size = 1024 * 1024;
+ ASSERT_OK(ReOpen());
+ std::vector<KeyVersion> versions;
+ uint64_t seq = 0;
+ for (uint64_t i = 1; i <= 1024; i++) {
+ std::string v = "bar" + std::to_string(i);
+ ASSERT_OK(db->Put(WriteOptions(), "foo", v));
+ VerifyKeys({{"foo", v}});
+ seq++; // one for the key/value
+ KeyVersion kv = {"foo", v, seq, kTypeValue};
+ if (options.two_write_queues) {
+ seq++; // one for the commit
+ }
+ versions.emplace_back(kv);
+ }
+ std::reverse(std::begin(versions), std::end(versions));
+ VerifyInternalKeys(versions);
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_OK(db_impl->FlushWAL(true));
+ // Use small buffer to ensure memtable flush during recovery
+ options.write_buffer_size = 1024;
+ ASSERT_OK(ReOpenNoDelete());
+ VerifyInternalKeys(versions);
+}
+
+TEST_P(WritePreparedTransactionTest, SequenceNumberZero) {
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "bar"));
+ VerifyKeys({{"foo", "bar"}});
+ const Snapshot* snapshot = db->GetSnapshot();
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // Compaction will output keys with sequence number 0, if it is visible to
+ // earliest snapshot. Make sure IsInSnapshot() report sequence number 0 is
+ // visible to any snapshot.
+ VerifyKeys({{"foo", "bar"}});
+ VerifyKeys({{"foo", "bar"}}, snapshot);
+ VerifyInternalKeys({{"foo", "bar", 0, kTypeValue}});
+ db->ReleaseSnapshot(snapshot);
+}
+
+// Compaction should not remove a key if it is not committed, and should
+// proceed with older versions of the key as-if the new version doesn't exist.
+TEST_P(WritePreparedTransactionTest, CompactionShouldKeepUncommittedKeys) {
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ // Snapshots to avoid keys get evicted.
+ std::vector<const Snapshot*> snapshots;
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+
+ auto add_key = [&](std::function<Status()> func) {
+ ASSERT_OK(func());
+ expected_seq++;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ snapshots.push_back(db->GetSnapshot());
+ };
+
+ // Each key here represent a standalone test case.
+ add_key([&]() { return db->Put(WriteOptions(), "key1", "value1_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key2", "value2_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key3", "value3_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key4", "value4_1"); });
+ add_key([&]() { return db->Merge(WriteOptions(), "key5", "value5_1"); });
+ add_key([&]() { return db->Merge(WriteOptions(), "key5", "value5_2"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key6", "value6_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key7", "value7_1"); });
+ ASSERT_OK(db->Flush(FlushOptions()));
+ add_key([&]() { return db->Delete(WriteOptions(), "key6"); });
+ add_key([&]() { return db->SingleDelete(WriteOptions(), "key7"); });
+
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1_2"));
+ ASSERT_OK(transaction->Delete("key2"));
+ ASSERT_OK(transaction->SingleDelete("key3"));
+ ASSERT_OK(transaction->Merge("key4", "value4_2"));
+ ASSERT_OK(transaction->Merge("key5", "value5_3"));
+ ASSERT_OK(transaction->Put("key6", "value6_2"));
+ ASSERT_OK(transaction->Put("key7", "value7_2"));
+ // Prepare but not commit.
+ ASSERT_OK(transaction->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ for (auto* s : snapshots) {
+ db->ReleaseSnapshot(s);
+ }
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({
+ {"key1", "value1_1"},
+ {"key2", "value2_1"},
+ {"key3", "value3_1"},
+ {"key4", "value4_1"},
+ {"key5", "value5_1,value5_2"},
+ {"key6", "NOT_FOUND"},
+ {"key7", "NOT_FOUND"},
+ });
+ VerifyInternalKeys({
+ {"key1", "value1_2", expected_seq, kTypeValue},
+ {"key1", "value1_1", 0, kTypeValue},
+ {"key2", "", expected_seq, kTypeDeletion},
+ {"key2", "value2_1", 0, kTypeValue},
+ {"key3", "", expected_seq, kTypeSingleDeletion},
+ {"key3", "value3_1", 0, kTypeValue},
+ {"key4", "value4_2", expected_seq, kTypeMerge},
+ {"key4", "value4_1", 0, kTypeValue},
+ {"key5", "value5_3", expected_seq, kTypeMerge},
+ {"key5", "value5_1,value5_2", 0, kTypeValue},
+ {"key6", "value6_2", expected_seq, kTypeValue},
+ {"key7", "value7_2", expected_seq, kTypeValue},
+ });
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({
+ {"key1", "value1_2"},
+ {"key2", "NOT_FOUND"},
+ {"key3", "NOT_FOUND"},
+ {"key4", "value4_1,value4_2"},
+ {"key5", "value5_1,value5_2,value5_3"},
+ {"key6", "value6_2"},
+ {"key7", "value7_2"},
+ });
+ delete transaction;
+}
+
+// Compaction should keep keys visible to a snapshot based on commit sequence,
+// not just prepare sequence.
+TEST_P(WritePreparedTransactionTest, CompactionShouldKeepSnapshotVisibleKeys) {
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+ auto* txn1 = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Put("key1", "value1_1"));
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ ASSERT_OK(txn1->Commit());
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ ASSERT_EQ(++expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ delete txn1;
+ // Take a snapshots to avoid keys get evicted before compaction.
+ const Snapshot* snapshot1 = db->GetSnapshot();
+ auto* txn2 = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn2->SetName("txn2"));
+ ASSERT_OK(txn2->Put("key2", "value2_1"));
+ ASSERT_OK(txn2->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ // txn1 commit before snapshot2 and it is visible to snapshot2.
+ // txn2 commit after snapshot2 and it is not visible.
+ const Snapshot* snapshot2 = db->GetSnapshot();
+ ASSERT_OK(txn2->Commit());
+ ASSERT_EQ(++expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ delete txn2;
+ // Take a snapshots to avoid keys get evicted before compaction.
+ const Snapshot* snapshot3 = db->GetSnapshot();
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1_2"));
+ expected_seq++; // 1 for write
+ SequenceNumber seq1 = expected_seq;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2_2"));
+ expected_seq++; // 1 for write
+ SequenceNumber seq2 = expected_seq;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ db->ReleaseSnapshot(snapshot1);
+ db->ReleaseSnapshot(snapshot3);
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({{"key1", "value1_2"}, {"key2", "value2_2"}});
+ VerifyKeys({{"key1", "value1_1"}, {"key2", "NOT_FOUND"}}, snapshot2);
+ VerifyInternalKeys({
+ {"key1", "value1_2", seq1, kTypeValue},
+ // "value1_1" is visible to snapshot2. Also keys at bottom level visible
+ // to earliest snapshot will output with seq = 0.
+ {"key1", "value1_1", 0, kTypeValue},
+ {"key2", "value2_2", seq2, kTypeValue},
+ });
+ db->ReleaseSnapshot(snapshot2);
+}
+
+TEST_P(WritePreparedTransactionTest, SmallestUncommittedOptimization) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // disable commit cache
+ for (bool has_recent_prepare : {true, false}) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Delete("key1"));
+ ASSERT_OK(transaction->Prepare());
+ // snapshot1 should get min_uncommitted from prepared_txns_ heap.
+ auto snapshot1 = db->GetSnapshot();
+ ASSERT_EQ(transaction->GetId(),
+ ((SnapshotImpl*)snapshot1)->min_uncommitted_);
+ // Add a commit to advance max_evicted_seq and move the prepared transaction
+ // into delayed_prepared_ set.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ Transaction* txn2 = nullptr;
+ if (has_recent_prepare) {
+ txn2 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn2->SetName("txn2"));
+ ASSERT_OK(txn2->Put("key3", "value3"));
+ ASSERT_OK(txn2->Prepare());
+ }
+ // snapshot2 should get min_uncommitted from delayed_prepared_ set.
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_EQ(transaction->GetId(),
+ ((SnapshotImpl*)snapshot1)->min_uncommitted_);
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ if (has_recent_prepare) {
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+ }
+ VerifyKeys({{"key1", "NOT_FOUND"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1"}}, snapshot2);
+ db->ReleaseSnapshot(snapshot1);
+ db->ReleaseSnapshot(snapshot2);
+ }
+}
+
+// Insert two values, v1 and v2, for a key. Between prepare and commit of v2
+// take two snapshots, s1 and s2. Release s1 during compaction.
+// Test to make sure compaction doesn't get confused and think s1 can see both
+// values, and thus compact out the older value by mistake.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1_1"));
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1_2"));
+ ASSERT_OK(transaction->Prepare());
+ auto snapshot1 = db->GetSnapshot();
+ // Increment sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ VerifyKeys({{"key1", "value1_2"}});
+ VerifyKeys({{"key1", "value1_1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1_1"}}, snapshot2);
+ // Add a flush to avoid compaction to fallback to trivial move.
+
+ // The callback might be called twice, record the calling state to
+ // prevent double calling.
+ bool callback_finished = false;
+ auto callback = [&](void*) {
+ if (callback_finished) {
+ return;
+ }
+ // Release snapshot1 after CompactionIterator init.
+ // CompactionIterator need to figure out the earliest snapshot
+ // that can see key1:value1_2 is kMaxSequenceNumber, not
+ // snapshot1 or snapshot2.
+ db->ReleaseSnapshot(snapshot1);
+ // Add some keys to advance max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "value3"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "value4"));
+ callback_finished = true;
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ VerifyKeys({{"key1", "value1_2"}});
+ VerifyKeys({{"key1", "value1_1"}}, snapshot2);
+ db->ReleaseSnapshot(snapshot2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Insert two values, v1 and v2, for a key. Take two snapshots, s1 and s2,
+// after committing v2. Release s1 during compaction, right after compaction
+// processes v2 and before processes v1. Test to make sure compaction doesn't
+// get confused and believe v1 and v2 are visible to different snapshot
+// (v1 by s2, v2 by s1) and refuse to compact out v1.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction2) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value2"));
+ SequenceNumber v2_seq = db->GetLatestSequenceNumber();
+ auto* s1 = db->GetSnapshot();
+ // Advance sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "dummy"));
+ auto* s2 = db->GetSnapshot();
+
+ int count_value = 0;
+ auto callback = [&](void* arg) {
+ auto* ikey = reinterpret_cast<ParsedInternalKey*>(arg);
+ if (ikey->user_key == "key1") {
+ count_value++;
+ if (count_value == 2) {
+ // Processing v1.
+ db->ReleaseSnapshot(s1);
+ // Add some keys to advance max_evicted_seq and update
+ // old_commit_map.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "dummy"));
+ }
+ }
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:ProcessKV",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // value1 should be compact out.
+ VerifyInternalKeys({{"key1", "value2", v2_seq, kTypeValue}});
+
+ // cleanup
+ db->ReleaseSnapshot(s2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Insert two values, v1 and v2, for a key. Insert another dummy key
+// so to evict the commit cache for v2, while v1 is still in commit cache.
+// Take two snapshots, s1 and s2. Release s1 during compaction.
+// Since commit cache for v2 is evicted, and old_commit_map don't have
+// s1 (it is released),
+// TODO(myabandeh): how can we be sure that the v2's commit info is evicted
+// (and not v1's)? Instead of putting a dummy, we can directly call
+// AddCommitted(v2_seq + cache_size, ...) to evict v2's entry from commit cache.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction3) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // commit cache size = 2
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Add a dummy key to evict v2 commit cache, but keep v1 commit cache.
+ // It also advance max_evicted_seq and can trigger old_commit_map cleanup.
+ auto add_dummy = [&]() {
+ auto* txn_dummy =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn_dummy->SetName("txn_dummy"));
+ ASSERT_OK(txn_dummy->Put("dummy", "dummy"));
+ ASSERT_OK(txn_dummy->Prepare());
+ ASSERT_OK(txn_dummy->Commit());
+ delete txn_dummy;
+ };
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key1", "value2"));
+ ASSERT_OK(txn->Prepare());
+ // TODO(myabandeh): replace it with GetId()?
+ auto v2_seq = db->GetLatestSequenceNumber();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ auto* s1 = db->GetSnapshot();
+ // Dummy key to advance sequence number.
+ add_dummy();
+ auto* s2 = db->GetSnapshot();
+
+ // The callback might be called twice, record the calling state to
+ // prevent double calling.
+ bool callback_finished = false;
+ auto callback = [&](void*) {
+ if (callback_finished) {
+ return;
+ }
+ db->ReleaseSnapshot(s1);
+ // Add some dummy entries to trigger s1 being cleanup from old_commit_map.
+ add_dummy();
+ add_dummy();
+ callback_finished = true;
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // value1 should be compact out.
+ VerifyInternalKeys({{"key1", "value2", v2_seq, kTypeValue}});
+
+ db->ReleaseSnapshot(s2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest, ReleaseEarliestSnapshotDuringCompaction) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ SequenceNumber put_seq = db->GetLatestSequenceNumber();
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Delete("key1"));
+ ASSERT_OK(transaction->Prepare());
+ SequenceNumber del_seq = db->GetLatestSequenceNumber();
+ auto snapshot1 = db->GetSnapshot();
+ // Increment sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ VerifyKeys({{"key1", "NOT_FOUND"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1"}}, snapshot2);
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto callback = [&](void* compaction) {
+ // Release snapshot1 after CompactionIterator init.
+ // CompactionIterator need to double check and find out snapshot2 is now
+ // the earliest existing snapshot.
+ if (compaction != nullptr) {
+ db->ReleaseSnapshot(snapshot1);
+ // Add some keys to advance max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "value3"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "value4"));
+ }
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // Only verify for key1. Both the put and delete for the key should be kept.
+ // Since the delete tombstone is not visible to snapshot2, we need to keep
+ // at least one version of the key, for write-conflict check.
+ VerifyInternalKeys({{"key1", "", del_seq, kTypeDeletion},
+ {"key1", "value1", put_seq, kTypeValue}});
+ db->ReleaseSnapshot(snapshot2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest,
+ ReleaseEarliestSnapshotDuringCompaction_WithSD) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "key", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value"));
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto* txn = db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr);
+ ASSERT_OK(txn->SingleDelete("key"));
+ ASSERT_OK(txn->Put("wow", "value"));
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ const bool two_write_queues = std::get<1>(GetParam());
+ if (two_write_queues) {
+ // In the case of two queues, commit another txn just to bump
+ // last_published_seq so that a subsequent GetSnapshot() call can return
+ // a snapshot with higher sequence.
+ auto* dummy_txn = db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr);
+ ASSERT_OK(dummy_txn->Put("haha", "value"));
+ ASSERT_OK(dummy_txn->Commit());
+ delete dummy_txn;
+ }
+ auto* snapshot = db->GetSnapshot();
+
+ ASSERT_OK(txn->Commit());
+ delete txn;
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:SingleDelete:1", [&](void* arg) {
+ if (!arg) {
+ return;
+ }
+ db->ReleaseSnapshot(snapshot);
+
+ // Advance max_evicted_seq
+ ASSERT_OK(db->Put(WriteOptions(), "bar", "value"));
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), /*begin=*/nullptr,
+ /*end=*/nullptr));
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest,
+ ReleaseEarliestSnapshotDuringCompaction_WithSD2) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "key", "value"));
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto* txn = db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr);
+ ASSERT_OK(txn->Put("bar", "value"));
+ ASSERT_OK(txn->SingleDelete("key"));
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ ASSERT_OK(txn->Commit());
+ delete txn;
+
+ ASSERT_OK(db->Put(WriteOptions(), "haha", "value"));
+
+ // Create a dummy transaction to take a snapshot for ww-conflict detection.
+ TransactionOptions txn_opts;
+ txn_opts.set_snapshot = true;
+ auto* dummy_txn =
+ db->BeginTransaction(WriteOptions(), txn_opts, /*old_txn=*/nullptr);
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:SingleDelete:2", [&](void* /*arg*/) {
+ ASSERT_OK(dummy_txn->Rollback());
+ delete dummy_txn;
+
+ ASSERT_OK(db->Put(WriteOptions(), "dontcare", "value"));
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Put(WriteOptions(), "haha2", "value"));
+ auto* snapshot = db->GetSnapshot();
+
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ db->ReleaseSnapshot(snapshot);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest,
+ ReleaseEarliestSnapshotDuringCompaction_WithDelete) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "a", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "c", "value"));
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto* txn = db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr);
+ ASSERT_OK(txn->Delete("b"));
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Prepare());
+
+ const bool two_write_queues = std::get<1>(GetParam());
+ if (two_write_queues) {
+ // In the case of two queues, commit another txn just to bump
+ // last_published_seq so that a subsequent GetSnapshot() call can return
+ // a snapshot with higher sequence.
+ auto* dummy_txn = db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr);
+ ASSERT_OK(dummy_txn->Put("haha", "value"));
+ ASSERT_OK(dummy_txn->Commit());
+ delete dummy_txn;
+ }
+ auto* snapshot1 = db->GetSnapshot();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ auto* snapshot2 = db->GetSnapshot();
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:BottommostDelete:1", [&](void* arg) {
+ if (!arg) {
+ return;
+ }
+ db->ReleaseSnapshot(snapshot1);
+
+ // Advance max_evicted_seq
+ ASSERT_OK(db->Put(WriteOptions(), "dummy1", "value"));
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), /*begin=*/nullptr,
+ /*end=*/nullptr));
+ db->ReleaseSnapshot(snapshot2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest,
+ ReleaseSnapshotBetweenSDAndPutDuringCompaction) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Create a dummy transaction to take a snapshot for ww-conflict detection.
+ TransactionOptions txn_opts;
+ txn_opts.set_snapshot = true;
+ auto* dummy_txn =
+ db->BeginTransaction(WriteOptions(), txn_opts, /*old_txn=*/nullptr);
+ // Increment seq
+ ASSERT_OK(db->Put(WriteOptions(), "bar", "value"));
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value"));
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "foo"));
+ auto* snapshot1 = db->GetSnapshot();
+ // Increment seq
+ ASSERT_OK(db->Put(WriteOptions(), "dontcare", "value"));
+ auto* snapshot2 = db->GetSnapshot();
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:KeepSDForWW", [&](void* /*arg*/) {
+ db->ReleaseSnapshot(snapshot1);
+
+ ASSERT_OK(db->Put(WriteOptions(), "dontcare2", "value2"));
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ db->ReleaseSnapshot(snapshot2);
+ ASSERT_OK(dummy_txn->Commit());
+ delete dummy_txn;
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest,
+ ReleaseEarliestWriteConflictSnapshot_SingleDelete) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "a", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "c", "value"));
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ {
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = 2;
+ ASSERT_OK(db->CompactRange(cro, /*begin=*/nullptr, /*end=*/nullptr));
+ }
+
+ std::unique_ptr<Transaction> txn;
+ txn.reset(db->BeginTransaction(WriteOptions(), TransactionOptions(),
+ /*old_txn=*/nullptr));
+ ASSERT_OK(txn->SetName("txn1"));
+ ASSERT_OK(txn->SingleDelete("b"));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+
+ auto* snapshot1 = db->GetSnapshot();
+
+ // Bump seq of the db by performing writes so that
+ // earliest_snapshot_ < earliest_write_conflict_snapshot_ in
+ // CompactionIterator.
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dontcare"));
+
+ // Create another snapshot for write conflict checking
+ std::unique_ptr<Transaction> txn2;
+ {
+ TransactionOptions txn_opts;
+ txn_opts.set_snapshot = true;
+ txn2.reset(
+ db->BeginTransaction(WriteOptions(), txn_opts, /*old_txn=*/nullptr));
+ }
+
+ // Bump seq so that the subsequent bg flush won't create a snapshot with the
+ // same seq as the previous snapshot for conflict checking.
+ ASSERT_OK(db->Put(WriteOptions(), "y", "dont"));
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:SingleDelete:1", [&](void* /*arg*/) {
+ // Rolling back txn2 should release its snapshot(for ww checking).
+ ASSERT_OK(txn2->Rollback());
+ txn2.reset();
+ // Advance max_evicted_seq
+ ASSERT_OK(db->Put(WriteOptions(), "x", "value"));
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), /*begin=*/nullptr,
+ /*end=*/nullptr));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ db->ReleaseSnapshot(snapshot1);
+}
+
+TEST_P(WritePreparedTransactionTest, ReleaseEarliestSnapshotAfterSeqZeroing) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ ASSERT_OK(db->Put(WriteOptions(), "a", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "c", "value"));
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ {
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = 2;
+ ASSERT_OK(db->CompactRange(cro, /*begin=*/nullptr, /*end=*/nullptr));
+ }
+
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "b"));
+
+ // Take a snapshot so that the SD won't be dropped during flush.
+ auto* tmp_snapshot = db->GetSnapshot();
+
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value2"));
+ auto* snapshot = db->GetSnapshot();
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ db->ReleaseSnapshot(tmp_snapshot);
+
+ // Bump the sequence so that the below bg compaction job's snapshot will be
+ // different from snapshot's sequence.
+ ASSERT_OK(db->Put(WriteOptions(), "z", "foo"));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::PrepareOutput:ZeroingSeq", [&](void* arg) {
+ const auto* const ikey =
+ reinterpret_cast<const ParsedInternalKey*>(arg);
+ assert(ikey);
+ if (ikey->user_key == "b") {
+ assert(ikey->type == kTypeValue);
+ db->ReleaseSnapshot(snapshot);
+
+ // Bump max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dontcare"));
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), /*begin=*/nullptr,
+ /*end=*/nullptr));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest, ReleaseEarliestSnapshotAfterSeqZeroing2) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Generate an L0 with only SD for one key "b".
+ ASSERT_OK(db->Put(WriteOptions(), "a", "value"));
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value"));
+ // Take a snapshot so that subsequent flush outputs the SD for "b".
+ auto* tmp_snapshot = db->GetSnapshot();
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "b"));
+ ASSERT_OK(db->Put(WriteOptions(), "c", "value"));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:SingleDelete:3", [&](void* arg) {
+ if (!arg) {
+ db->ReleaseSnapshot(tmp_snapshot);
+ // Bump max_evicted_seq
+ ASSERT_OK(db->Put(WriteOptions(), "x", "dontcare"));
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // Finish generating L0 with only SD for "b".
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Move the L0 to L2.
+ {
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = 2;
+ ASSERT_OK(db->CompactRange(cro, /*begin=*/nullptr, /*end=*/nullptr));
+ }
+
+ ASSERT_OK(db->Put(WriteOptions(), "b", "value1"));
+
+ auto* snapshot = db->GetSnapshot();
+
+ // Bump seq so that a subsequent flush/compaction job's snapshot is larger
+ // than the above snapshot's seq.
+ ASSERT_OK(db->Put(WriteOptions(), "x", "dontcare"));
+
+ // Generate a second L0.
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::PrepareOutput:ZeroingSeq", [&](void* arg) {
+ const auto* const ikey =
+ reinterpret_cast<const ParsedInternalKey*>(arg);
+ assert(ikey);
+ if (ikey->user_key == "b") {
+ assert(ikey->type == kTypeValue);
+ db->ReleaseSnapshot(snapshot);
+
+ // Bump max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dontcare"));
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), /*begin=*/nullptr,
+ /*end=*/nullptr));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Although the user-contract indicates that a SD can only be issued for a key
+// that exists and has not been overwritten, it is still possible for a Delete
+// to be present when write-prepared transaction is rolled back.
+TEST_P(WritePreparedTransactionTest, SingleDeleteAfterRollback) {
+ constexpr size_t kSnapshotCacheBits = 7; // same as default
+ constexpr size_t kCommitCacheBits = 0; // minimum commit cache
+ txn_db_options.rollback_deletion_type_callback =
+ [](TransactionDB*, ColumnFamilyHandle*, const Slice&) { return true; };
+ UpdateTransactionDBOptions(kSnapshotCacheBits, kCommitCacheBits);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Get a write conflict snapshot by creating a transaction with
+ // set_snapshot=true.
+ TransactionOptions txn_opts;
+ txn_opts.set_snapshot = true;
+ std::unique_ptr<Transaction> dummy_txn(
+ db->BeginTransaction(WriteOptions(), txn_opts));
+
+ std::unique_ptr<Transaction> txn0(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn0->Put("foo", "value"));
+ ASSERT_OK(txn0->SetName("xid0"));
+ ASSERT_OK(txn0->Prepare());
+
+ // Create an SST with only {"foo": "value"}.
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ // Insert a Delete to cancel out the prior Put by txn0.
+ ASSERT_OK(txn0->Rollback());
+ txn0.reset();
+
+ // Create a second SST.
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "value1"));
+
+ auto* snapshot = db->GetSnapshot();
+
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "foo"));
+
+ int count = 0;
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->SetCallBack(
+ "CompactionIterator::NextFromInput:SingleDelete:1", [&](void* arg) {
+ const auto* const c = reinterpret_cast<const Compaction*>(arg);
+ assert(!c);
+ // Trigger once only for SingleDelete during flush.
+ if (0 == count) {
+ ++count;
+ db->ReleaseSnapshot(snapshot);
+ // Bump max_evicted_seq
+ ASSERT_OK(db->Put(WriteOptions(), "x", "dontcare"));
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // Create a third SST containing a SD without its matching PUT.
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ DBImpl* dbimpl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ assert(dbimpl);
+ ASSERT_OK(dbimpl->TEST_CompactRange(
+ /*level=*/0, /*begin=*/nullptr, /*end=*/nullptr,
+ /*column_family=*/nullptr, /*disallow_trivial_mode=*/true));
+
+ SyncPoint::GetInstance()->DisableProcessing();
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Release the conflict-checking snapshot.
+ ASSERT_OK(dummy_txn->Rollback());
+}
+
+// A more complex test to verify compaction/flush should keep keys visible
+// to snapshots.
+TEST_P(WritePreparedTransactionTest,
+ CompactionKeepSnapshotVisibleKeysRandomized) {
+ constexpr size_t kNumTransactions = 10;
+ constexpr size_t kNumIterations = 1000;
+
+ std::vector<Transaction*> transactions(kNumTransactions, nullptr);
+ std::vector<size_t> versions(kNumTransactions, 0);
+ std::unordered_map<std::string, std::string> current_data;
+ std::vector<const Snapshot*> snapshots;
+ std::vector<std::unordered_map<std::string, std::string>> snapshot_data;
+
+ Random rnd(1103);
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ for (size_t i = 0; i < kNumTransactions; i++) {
+ std::string key = "key" + std::to_string(i);
+ std::string value = "value0";
+ ASSERT_OK(db->Put(WriteOptions(), key, value));
+ current_data[key] = value;
+ }
+ VerifyKeys(current_data);
+
+ for (size_t iter = 0; iter < kNumIterations; iter++) {
+ auto r = rnd.Next() % (kNumTransactions + 1);
+ if (r < kNumTransactions) {
+ std::string key = "key" + std::to_string(r);
+ if (transactions[r] == nullptr) {
+ std::string value = "value" + std::to_string(versions[r] + 1);
+ auto* txn = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn->SetName("txn" + std::to_string(r)));
+ ASSERT_OK(txn->Put(key, value));
+ ASSERT_OK(txn->Prepare());
+ transactions[r] = txn;
+ } else {
+ std::string value = "value" + std::to_string(++versions[r]);
+ ASSERT_OK(transactions[r]->Commit());
+ delete transactions[r];
+ transactions[r] = nullptr;
+ current_data[key] = value;
+ }
+ } else {
+ auto* snapshot = db->GetSnapshot();
+ VerifyKeys(current_data, snapshot);
+ snapshots.push_back(snapshot);
+ snapshot_data.push_back(current_data);
+ }
+ VerifyKeys(current_data);
+ }
+ // Take a last snapshot to test compaction with uncommitted prepared
+ // transaction.
+ snapshots.push_back(db->GetSnapshot());
+ snapshot_data.push_back(current_data);
+
+ ASSERT_EQ(snapshots.size(), snapshot_data.size());
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ ASSERT_OK(db->Flush(FlushOptions()));
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ // cleanup
+ for (size_t i = 0; i < kNumTransactions; i++) {
+ if (transactions[i] == nullptr) {
+ continue;
+ }
+ ASSERT_OK(transactions[i]->Commit());
+ delete transactions[i];
+ }
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ db->ReleaseSnapshot(snapshots[i]);
+ }
+}
+
+// Compaction should not apply the optimization to output key with sequence
+// number equal to 0 if the key is not visible to earliest snapshot, based on
+// commit sequence number.
+TEST_P(WritePreparedTransactionTest,
+ CompactionShouldKeepSequenceForUncommittedKeys) {
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1"));
+ ASSERT_OK(transaction->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ SequenceNumber seq1 = expected_seq;
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ expected_seq++; // one for data
+ if (options.two_write_queues) {
+ expected_seq++; // one for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // Dummy keys to avoid compaction trivially move files and get around actual
+ // compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({
+ {"key1", "NOT_FOUND"},
+ {"key2", "value2"},
+ });
+ VerifyInternalKeys({
+ // "key1" has not been committed. It keeps its sequence number.
+ {"key1", "value1", seq1, kTypeValue},
+ // "key2" is committed and output with seq = 0.
+ {"key2", "value2", 0, kTypeValue},
+ });
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({
+ {"key1", "value1"},
+ {"key2", "value2"},
+ });
+ delete transaction;
+}
+
+TEST_P(WritePreparedTransactionTest, CommitAndSnapshotDuringCompaction) {
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ const Snapshot* snapshot = nullptr;
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* txn = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key1", "value2"));
+ ASSERT_OK(txn->Prepare());
+
+ auto callback = [&](void*) {
+ // Snapshot is taken after compaction start. It should be taken into
+ // consideration for whether to compact out value1.
+ snapshot = db->GetSnapshot();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+ ASSERT_OK(db->Flush(FlushOptions()));
+ ASSERT_NE(nullptr, snapshot);
+ VerifyKeys({{"key1", "value2"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot);
+ db->ReleaseSnapshot(snapshot);
+}
+
+TEST_P(WritePreparedTransactionTest, Iterate) {
+ auto verify_state = [](Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+ };
+
+ auto verify_iter = [&](const std::string& expected_val) {
+ // Get iterator from a concurrent transaction and make sure it has the
+ // same view as an iterator from the DB.
+ auto* txn = db->BeginTransaction(WriteOptions());
+
+ for (int i = 0; i < 2; i++) {
+ Iterator* iter = (i == 0) ? db->NewIterator(ReadOptions())
+ : txn->GetIterator(ReadOptions());
+ // Seek
+ iter->Seek("foo");
+ verify_state(iter, "foo", expected_val);
+ // Next
+ iter->Seek("a");
+ verify_state(iter, "a", "va");
+ iter->Next();
+ verify_state(iter, "foo", expected_val);
+ // SeekForPrev
+ iter->SeekForPrev("y");
+ verify_state(iter, "foo", expected_val);
+ // Prev
+ iter->SeekForPrev("z");
+ verify_state(iter, "z", "vz");
+ iter->Prev();
+ verify_state(iter, "foo", expected_val);
+ delete iter;
+ }
+ delete txn;
+ };
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "v1"));
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("foo", "v2"));
+ ASSERT_OK(transaction->Prepare());
+ VerifyKeys({{"foo", "v1"}});
+ // dummy keys
+ ASSERT_OK(db->Put(WriteOptions(), "a", "va"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "vz"));
+ verify_iter("v1");
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({{"foo", "v2"}});
+ verify_iter("v2");
+ delete transaction;
+}
+
+TEST_P(WritePreparedTransactionTest, IteratorRefreshNotSupported) {
+ Iterator* iter = db->NewIterator(ReadOptions());
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Refresh().IsNotSupported());
+ delete iter;
+}
+
+// Committing an delayed prepared has two non-atomic steps: update commit cache,
+// remove seq from delayed_prepared_. The read in IsInSnapshot also involves two
+// non-atomic steps of checking these two data structures. This test breaks each
+// in the middle to ensure correctness in spite of non-atomic execution.
+// Note: This test is limitted to the case where snapshot is larger than the
+// max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, NonAtomicCommitOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ for (auto split_read : {true, false}) {
+ std::vector<bool> split_options = {false};
+ if (split_read) {
+ // Also test for break before mutex
+ split_options.push_back(true);
+ }
+ for (auto split_before_mutex : split_options) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ DBImpl* db_impl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ // Fill up the commit cache
+ std::string init_value("value1");
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key1"), Slice(init_value)));
+ }
+ // Prepare a transaction but do not commit it
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), Slice("value2")));
+ ASSERT_OK(txn->Prepare());
+ // Commit a bunch of entries to advance max evicted seq and make the
+ // prepared a delayed prepared
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ }
+ // The snapshot should not see the delayed prepared entry
+ auto snap = db->GetSnapshot();
+
+ if (split_read) {
+ if (split_before_mutex) {
+ // split before acquiring prepare_mutex_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:pause",
+ "AtomicCommitOfDelayedPrepared:Commit:before"},
+ {"AtomicCommitOfDelayedPrepared:Commit:after",
+ "WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:resume"}});
+ } else {
+ // split right after reading from the commit cache
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:pause",
+ "AtomicCommitOfDelayedPrepared:Commit:before"},
+ {"AtomicCommitOfDelayedPrepared:Commit:after",
+ "WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:resume"}});
+ }
+ } else { // split commit
+ // split right before removing from delayed_prepared_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::RemovePrepared:pause",
+ "AtomicCommitOfDelayedPrepared:Read:before"},
+ {"AtomicCommitOfDelayedPrepared:Read:after",
+ "WritePreparedTxnDB::RemovePrepared:resume"}});
+ }
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Commit:before");
+ ASSERT_OK(txn->Commit());
+ if (split_before_mutex) {
+ // Do bunch of inserts to evict the commit entry from the cache. This
+ // would prevent the 2nd look into commit cache under prepare_mutex_
+ // to see the commit entry.
+ auto seq = db_impl->TEST_GetLastVisibleSequence();
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < seq && tries < 50) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ }
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Commit:after");
+ delete txn;
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Read:before");
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should not see the commit of delayed prepared
+ ASSERT_TRUE(value == init_value);
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Read:after");
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ } // for split_before_mutex
+ } // for split_read
+}
+
+// When max evicted seq advances a prepared seq, it involves two updates: i)
+// adding prepared seq to delayed_prepared_, ii) updating max_evicted_seq_.
+// ::IsInSnapshot also reads these two values in a non-atomic way. This test
+// ensures correctness if the update occurs after ::IsInSnapshot reads
+// delayed_prepared_empty_ and before it reads max_evicted_seq_.
+// Note: this test focuses on read snapshot larger than max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, NonAtomicUpdateOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Fill up the commit cache
+ std::string init_value("value1");
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key1"), Slice(init_value)));
+ }
+ // Prepare a transaction but do not commit it
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), Slice("value2")));
+ ASSERT_OK(txn->Prepare());
+ // Create a gap between prepare seq and snapshot seq
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ // The snapshot should not see the delayed prepared entry
+ auto snap = db->GetSnapshot();
+ ASSERT_LT(txn->GetId(), snap->GetSequenceNumber());
+
+ // split right after reading delayed_prepared_empty_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:pause",
+ "AtomicUpdateOfDelayedPrepared:before"},
+ {"AtomicUpdateOfDelayedPrepared:after",
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:resume"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("AtomicUpdateOfDelayedPrepared:before");
+ // Commit a bunch of entries to advance max evicted seq and make the
+ // prepared a delayed prepared
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < txn->GetId() && tries < 50) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ // This is the case on which the test focuses
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ TEST_SYNC_POINT("AtomicUpdateOfDelayedPrepared:after");
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should not see the uncommitted value of delayed prepared
+ ASSERT_TRUE(value == init_value);
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Eviction from commit cache and update of max evicted seq are two non-atomic
+// steps. Similarly the read of max_evicted_seq_ in ::IsInSnapshot and reading
+// from commit cache are two non-atomic steps. This tests if the update occurs
+// after reading max_evicted_seq_ and before reading the commit cache.
+// Note: the test focuses on snapshot larger than max_evicted_seq_
+TEST_P(WritePreparedTransactionTest, NonAtomicUpdateOfMaxEvictedSeq) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Fill up the commit cache
+ std::string init_value("value1");
+ std::string last_value("value_final");
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key1"), Slice(init_value)));
+ }
+ // Do an uncommitted write to prevent min_uncommitted optimization
+ Transaction* txn1 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Put(Slice("key0"), last_value));
+ ASSERT_OK(txn1->Prepare());
+ // Do a write with prepare to get the prepare seq
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), last_value));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ // Create a gap between commit entry and snapshot seq
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ // The snapshot should see the last commit
+ auto snap = db->GetSnapshot();
+ ASSERT_LE(txn->GetId(), snap->GetSequenceNumber());
+
+ // split right after reading max_evicted_seq_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:pause",
+ "NonAtomicUpdateOfMaxEvictedSeq:before"},
+ {"NonAtomicUpdateOfMaxEvictedSeq:after",
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:resume"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("NonAtomicUpdateOfMaxEvictedSeq:before");
+ // Commit a bunch of entries to advance max evicted seq beyond txn->GetId()
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < txn->GetId() && tries < 50) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key3"), Slice("value3")));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ // This is the case on which the test focuses
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ TEST_SYNC_POINT("NonAtomicUpdateOfMaxEvictedSeq:after");
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should see the committed value of the evicted entry
+ ASSERT_TRUE(value == last_value);
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ delete txn;
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Test when we add a prepared seq when the max_evicted_seq_ already goes beyond
+// that. The test focuses on a race condition between AddPrepared and
+// AdvanceMaxEvictedSeq functions.
+TEST_P(WritePreparedTransactionTest, AddPreparedBeforeMax) {
+ if (!options.two_write_queues) {
+ // This test is only for two write queues
+ return;
+ }
+ const size_t snapshot_cache_bits = 7; // same as default
+ // 1 entry to advance max after the 2nd commit
+ const size_t commit_cache_bits = 0;
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ std::string some_value("value_some");
+ std::string uncommitted_value("value_uncommitted");
+ // Prepare two uncommitted transactions
+ Transaction* txn1 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Put(Slice("key1"), some_value));
+ ASSERT_OK(txn1->Prepare());
+ Transaction* txn2 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn2->SetName("xid2"));
+ ASSERT_OK(txn2->Put(Slice("key2"), some_value));
+ ASSERT_OK(txn2->Prepare());
+ // Start the txn here so the other thread could get its id
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key0"), uncommitted_value));
+ port::Mutex txn_mutex_;
+
+ // t1) Insert prepared entry, t2) commit other entries to advance max
+ // evicted sec and finish checking the existing prepared entries, t1)
+ // AddPrepared, t2) update max_evicted_seq_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"AddPreparedCallback::AddPrepared::begin:pause",
+ "AddPreparedBeforeMax::read_thread:start"},
+ {"AdvanceMaxEvictedSeq::update_max:pause",
+ "AddPreparedCallback::AddPrepared::begin:resume"},
+ {"AddPreparedCallback::AddPrepared::end",
+ "AdvanceMaxEvictedSeq::update_max:resume"},
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ txn_mutex_.Lock();
+ ASSERT_OK(txn->Prepare());
+ txn_mutex_.Unlock();
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ TEST_SYNC_POINT("AddPreparedBeforeMax::read_thread:start");
+ // Publish seq number with a commit
+ ASSERT_OK(txn1->Commit());
+ // Since the commit cache size is one the 2nd commit evict the 1st one and
+ // invokes AdcanceMaxEvictedSeq
+ ASSERT_OK(txn2->Commit());
+
+ ReadOptions roptions;
+ PinnableSlice value;
+ // The snapshot should not see the uncommitted value from write_thread
+ auto snap = db->GetSnapshot();
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ // This is the scenario that we test for
+ txn_mutex_.Lock();
+ ASSERT_GT(wp_db->max_evicted_seq_, txn->GetId());
+ txn_mutex_.Unlock();
+ roptions.snapshot = snap;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key0", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ write_thread.join();
+ delete txn1;
+ delete txn2;
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// When an old prepared entry gets committed, there is a gap between the time
+// that it is published and when it is cleaned up from old_prepared_. This test
+// stresses such cases.
+TEST_P(WritePreparedTransactionTest, CommitOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ for (const size_t commit_cache_bits : {0, 2, 3}) {
+ for (const size_t sub_batch_cnt : {1, 2, 3}) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ASSERT_OK(ReOpen());
+ std::atomic<const Snapshot*> snap = {nullptr};
+ std::atomic<SequenceNumber> exp_prepare = {0};
+ ROCKSDB_NAMESPACE::port::Thread callback_thread;
+ // Value is synchronized via snap
+ PinnableSlice value;
+ // Take a snapshot after publish and before RemovePrepared:Start
+ auto snap_callback = [&]() {
+ ASSERT_EQ(nullptr, snap.load());
+ snap.store(db->GetSnapshot());
+ ReadOptions roptions;
+ roptions.snapshot = snap.load();
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key2", &value);
+ ASSERT_OK(s);
+ };
+ auto callback = [&](void* param) {
+ SequenceNumber prep_seq = *((SequenceNumber*)param);
+ if (prep_seq == exp_prepare.load()) { // only for write_thread
+ // We need to spawn a thread to avoid deadlock since getting a
+ // snpashot might end up calling AdvanceSeqByOne which needs joining
+ // the write queue.
+ callback_thread = ROCKSDB_NAMESPACE::port::Thread(snap_callback);
+ TEST_SYNC_POINT("callback:end");
+ }
+ };
+ // Wait for the first snapshot be taken in GetSnapshotInternal. Although
+ // it might be updated before GetSnapshotInternal finishes but this should
+ // cover most of the cases.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::GetSnapshotInternal:first", "callback:end"},
+ });
+ SyncPoint::GetInstance()->SetCallBack("RemovePrepared:Start", callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+ // Thread to cause frequent evictions
+ ROCKSDB_NAMESPACE::port::Thread eviction_thread([&]() {
+ // Too many txns might cause commit_seq - prepare_seq in another thread
+ // to go beyond DELTA_UPPERBOUND
+ for (int i = 0; i < 25 * (1 << commit_cache_bits); i++) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key1"), Slice("value1")));
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ for (int i = 0; i < 25 * (1 << commit_cache_bits); i++) {
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ std::string val_str = "value" + std::to_string(i);
+ for (size_t b = 0; b < sub_batch_cnt; b++) {
+ ASSERT_OK(txn->Put(Slice("key2"), val_str));
+ }
+ ASSERT_OK(txn->Prepare());
+ // Let an eviction to kick in
+ std::this_thread::yield();
+
+ exp_prepare.store(txn->GetId());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ // Wait for the snapshot taking that is triggered by
+ // RemovePrepared:Start callback
+ callback_thread.join();
+
+ // Read with the snapshot taken before delayed_prepared_ cleanup
+ ReadOptions roptions;
+ roptions.snapshot = snap.load();
+ ASSERT_NE(nullptr, roptions.snapshot);
+ PinnableSlice value2;
+ auto s =
+ db->Get(roptions, db->DefaultColumnFamily(), "key2", &value2);
+ ASSERT_OK(s);
+ // It should see its own write
+ ASSERT_TRUE(val_str == value2);
+ // The value read by snapshot should not change
+ ASSERT_STREQ(value2.ToString().c_str(), value.ToString().c_str());
+
+ db->ReleaseSnapshot(roptions.snapshot);
+ snap.store(nullptr);
+ }
+ });
+ write_thread.join();
+ eviction_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+ }
+}
+
+// Test that updating the commit map will not affect the existing snapshots
+TEST_P(WritePreparedTransactionTest, AtomicCommit) {
+ for (bool skip_prepare : {true, false}) {
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::AddCommitted:start",
+ "AtomicCommit::GetSnapshot:start"},
+ {"AtomicCommit::Get:end",
+ "WritePreparedTxnDB::AddCommitted:start:pause"},
+ {"WritePreparedTxnDB::AddCommitted:end", "AtomicCommit::Get2:start"},
+ {"AtomicCommit::Get2:end",
+ "WritePreparedTxnDB::AddCommitted:end:pause:"},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ if (skip_prepare) {
+ ASSERT_OK(db->Put(WriteOptions(), Slice("key"), Slice("value")));
+ } else {
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key"), Slice("value")));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ TEST_SYNC_POINT("AtomicCommit::GetSnapshot:start");
+ roptions.snapshot = db->GetSnapshot();
+ PinnableSlice val;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key", &val);
+ TEST_SYNC_POINT("AtomicCommit::Get:end");
+ TEST_SYNC_POINT("AtomicCommit::Get2:start");
+ ASSERT_SAME(roptions, db, s, val, "key");
+ TEST_SYNC_POINT("AtomicCommit::Get2:end");
+ db->ReleaseSnapshot(roptions.snapshot);
+ });
+ read_thread.join();
+ write_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ }
+}
+
+TEST_P(WritePreparedTransactionTest, BasicRollbackDeletionTypeCb) {
+ options.level0_file_num_compaction_trigger = 2;
+ // Always use SingleDelete to rollback Put.
+ txn_db_options.rollback_deletion_type_callback =
+ [](TransactionDB*, ColumnFamilyHandle*, const Slice&) { return true; };
+
+ const auto write_to_db = [&]() {
+ assert(db);
+ std::unique_ptr<Transaction> txn0(
+ db->BeginTransaction(WriteOptions(), TransactionOptions()));
+ ASSERT_OK(txn0->SetName("txn0"));
+ ASSERT_OK(txn0->Put("a", "v0"));
+ ASSERT_OK(txn0->Prepare());
+
+ // Generate sst1: [PUT('a')]
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ {
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = options.num_levels - 1;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ ASSERT_OK(db->CompactRange(cro, /*begin=*/nullptr, /*end=*/nullptr));
+ }
+
+ ASSERT_OK(txn0->Rollback());
+ txn0.reset();
+
+ ASSERT_OK(db->Put(WriteOptions(), "a", "v1"));
+
+ ASSERT_OK(db->SingleDelete(WriteOptions(), "a"));
+ // Generate another SST with a SD to cover the oldest PUT('a')
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto* dbimpl = static_cast_with_check<DBImpl>(db->GetRootDB());
+ assert(dbimpl);
+ ASSERT_OK(dbimpl->TEST_WaitForCompact());
+
+ {
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ ASSERT_OK(db->CompactRange(cro, /*begin=*/nullptr, /*end=*/nullptr));
+ }
+
+ {
+ std::string value;
+ const Status s = db->Get(ReadOptions(), "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+ };
+
+ // Destroy and reopen
+ ASSERT_OK(ReOpen());
+ write_to_db();
+}
+
+// Test that we can change write policy from WriteCommitted to WritePrepared
+// after a clean shutdown (which would empty the WAL)
+TEST_P(WritePreparedTransactionTest, WP_WC_DBBackwardCompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_COMMITTED, WRITE_PREPARED, empty_wal);
+}
+
+// Test that we fail fast if WAL is not emptied between changing the write
+// policy from WriteCommitted to WritePrepared
+TEST_P(WritePreparedTransactionTest, WP_WC_WALBackwardIncompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_COMMITTED, WRITE_PREPARED, !empty_wal);
+}
+
+// Test that we can change write policy from WritePrepare back to WriteCommitted
+// after a clean shutdown (which would empty the WAL)
+TEST_P(WritePreparedTransactionTest, WC_WP_ForwardCompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_PREPARED, WRITE_COMMITTED, empty_wal);
+}
+
+// Test that we fail fast if WAL is not emptied between changing the write
+// policy from WriteCommitted to WritePrepared
+TEST_P(WritePreparedTransactionTest, WC_WP_WALForwardIncompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_PREPARED, WRITE_COMMITTED, !empty_wal);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ if (getenv("CIRCLECI")) {
+ // Looking for backtrace on "Resource temporarily unavailable" exceptions
+ ::testing::FLAGS_gtest_catch_exceptions = false;
+ }
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn.cc b/src/rocksdb/utilities/transactions/write_prepared_txn.cc
new file mode 100644
index 000000000..16b5cc1cb
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn.cc
@@ -0,0 +1,512 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn.h"
+
+#include <cinttypes>
+#include <map>
+#include <set>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+WritePreparedTxn::WritePreparedTxn(WritePreparedTxnDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : PessimisticTransaction(txn_db, write_options, txn_options, false),
+ wpt_db_(txn_db) {
+ // Call Initialize outside PessimisticTransaction constructor otherwise it
+ // would skip overridden functions in WritePreparedTxn since they are not
+ // defined yet in the constructor of PessimisticTransaction
+ Initialize(txn_options);
+}
+
+void WritePreparedTxn::Initialize(const TransactionOptions& txn_options) {
+ PessimisticTransaction::Initialize(txn_options);
+ prepare_batch_cnt_ = 0;
+}
+
+void WritePreparedTxn::MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wpt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(wpt_db_, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ write_batch_.MultiGetFromBatchAndDB(db_, options, column_family, num_keys,
+ keys, values, statuses, sorted_input,
+ &callback);
+ if (UNLIKELY(!callback.valid() ||
+ !wpt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ wpt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ for (size_t i = 0; i < num_keys; i++) {
+ statuses[i] = Status::TryAgain();
+ }
+ }
+}
+
+Status WritePreparedTxn::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wpt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(wpt_db_, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ Status res = write_batch_.GetFromBatchAndDB(db_, options, column_family, key,
+ pinnable_val, &callback);
+ const bool callback_valid =
+ callback.valid(); // NOTE: validity of callback must always be checked
+ // before it is destructed
+ if (res.ok()) {
+ if (!LIKELY(callback_valid &&
+ wpt_db_->ValidateSnapshot(callback.max_visible_seq(),
+ backed_by_snapshot))) {
+ wpt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ res = Status::TryAgain();
+ }
+ }
+
+ return res;
+}
+
+Iterator* WritePreparedTxn::GetIterator(const ReadOptions& options) {
+ // Make sure to get iterator from WritePrepareTxnDB, not the root db.
+ Iterator* db_iter = wpt_db_->NewIterator(options);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(db_iter);
+}
+
+Iterator* WritePreparedTxn::GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ // Make sure to get iterator from WritePrepareTxnDB, not the root db.
+ Iterator* db_iter = wpt_db_->NewIterator(options, column_family);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(column_family, db_iter);
+}
+
+Status WritePreparedTxn::PrepareInternal() {
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ const bool WRITE_AFTER_COMMIT = true;
+ const bool kFirstPrepareBatch = true;
+ auto s = WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(),
+ name_, !WRITE_AFTER_COMMIT);
+ assert(s.ok());
+ // For each duplicate key we account for a new sub-batch
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+ // Having AddPrepared in the PreReleaseCallback allows in-order addition of
+ // prepared entries to PreparedHeap and hence enables an optimization. Refer
+ // to SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, prepare_batch_cnt_,
+ db_impl_->immutable_db_options().two_write_queues, kFirstPrepareBatch);
+ const bool DISABLE_MEMTABLE = true;
+ uint64_t seq_used = kMaxSequenceNumber;
+ s = db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(),
+ /*callback*/ nullptr, &log_number_, /*log ref*/ 0,
+ !DISABLE_MEMTABLE, &seq_used, prepare_batch_cnt_,
+ &add_prepared_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ auto prepare_seq = seq_used;
+ SetId(prepare_seq);
+ return s;
+}
+
+Status WritePreparedTxn::CommitWithoutPrepareInternal() {
+ // For each duplicate key we account for a new sub-batch
+ const size_t batch_cnt = GetWriteBatch()->SubBatchCnt();
+ return CommitBatchInternal(GetWriteBatch()->GetWriteBatch(), batch_cnt);
+}
+
+Status WritePreparedTxn::CommitBatchInternal(WriteBatch* batch,
+ size_t batch_cnt) {
+ return wpt_db_->WriteInternal(write_options_, batch, batch_cnt, this);
+}
+
+Status WritePreparedTxn::CommitInternal() {
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitInternal prepare_seq: %" PRIu64, GetID());
+ // We take the commit-time batch and append the Commit marker.
+ // The Memtable will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+ const bool empty = working_batch->Count() == 0;
+ auto s = WriteBatchInternal::MarkCommit(working_batch, name_);
+ assert(s.ok());
+
+ const bool for_recovery = use_only_the_last_commit_time_batch_for_recovery_;
+ if (!empty) {
+ // When not writing to memtable, we can still cache the latest write batch.
+ // The cached batch will be written to memtable in WriteRecoverableState
+ // during FlushMemTable
+ if (for_recovery) {
+ WriteBatchInternal::SetAsLatestPersistentState(working_batch);
+ } else {
+ return Status::InvalidArgument(
+ "Commit-time-batch can only be used if "
+ "use_only_the_last_commit_time_batch_for_recovery is true");
+ }
+ }
+
+ auto prepare_seq = GetId();
+ const bool includes_data = !empty && !for_recovery;
+ assert(prepare_batch_cnt_);
+ size_t commit_batch_cnt = 0;
+ if (UNLIKELY(includes_data)) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Duplicate key overhead");
+ SubBatchCounter counter(*wpt_db_->GetCFComparatorMap());
+ s = working_batch->Iterate(&counter);
+ assert(s.ok());
+ commit_batch_cnt = counter.BatchCount();
+ }
+ const bool disable_memtable = !includes_data;
+ const bool do_one_write =
+ !db_impl_->immutable_db_options().two_write_queues || disable_memtable;
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, prepare_seq, prepare_batch_cnt_, commit_batch_cnt);
+ // This is to call AddPrepared on CommitTimeWriteBatch
+ const bool kFirstPrepareBatch = true;
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, commit_batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ uint64_t seq_used = kMaxSequenceNumber;
+ // Since the prepared batch is directly written to memtable, there is already
+ // a connection between the memtable and its WAL, so there is no need to
+ // redundantly reference the log that contains the prepared data.
+ const uint64_t zero_log_number = 0ull;
+ size_t batch_cnt = UNLIKELY(commit_batch_cnt) ? commit_batch_cnt : 1;
+ // If `two_write_queues && includes_data`, then `do_one_write` is false. The
+ // following `WriteImpl` will insert the data of the commit-time-batch into
+ // the database before updating the commit cache. Therefore, the data of the
+ // commmit-time-batch is considered uncommitted. Furthermore, since data of
+ // the commit-time-batch are not locked, it is possible for two uncommitted
+ // versions of the same key to co-exist for a (short) period of time until
+ // the commit cache is updated by the second write. If the two uncommitted
+ // keys are compacted to the bottommost level in the meantime, it is possible
+ // that compaction iterator will zero out the sequence numbers of both, thus
+ // violating the invariant that an SST does not have two identical internal
+ // keys. To prevent this situation, we should allow the usage of
+ // commit-time-batch only if the user sets
+ // TransactionOptions::use_only_the_last_commit_time_batch_for_recovery to
+ // true. See the comments about GetCommitTimeWriteBatch() in
+ // include/rocksdb/utilities/transaction.h.
+ s = db_impl_->WriteImpl(write_options_, working_batch, nullptr, nullptr,
+ zero_log_number, disable_memtable, &seq_used,
+ batch_cnt, pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ const SequenceNumber commit_batch_seq = seq_used;
+ if (LIKELY(do_one_write || !s.ok())) {
+ if (UNLIKELY(!db_impl_->immutable_db_options().two_write_queues &&
+ s.ok())) {
+ // Note: RemovePrepared should be called after WriteImpl that publishsed
+ // the seq. Otherwise SmallestUnCommittedSeq optimization breaks.
+ wpt_db_->RemovePrepared(prepare_seq, prepare_batch_cnt_);
+ } // else RemovePrepared is called from within PreReleaseCallback
+ if (UNLIKELY(!do_one_write)) {
+ assert(!s.ok());
+ // Cleanup the prepared entry we added with add_prepared_callback
+ wpt_db_->RemovePrepared(commit_batch_seq, commit_batch_cnt);
+ }
+ return s;
+ } // else do the 2nd write to publish seq
+ // Note: the 2nd write comes with a performance penality. So if we have too
+ // many of commits accompanied with ComitTimeWriteBatch and yet we cannot
+ // enable use_only_the_last_commit_time_batch_for_recovery_ optimization,
+ // two_write_queues should be disabled to avoid many additional writes here.
+ const size_t kZeroData = 0;
+ // Update commit map only from the 2nd queue
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map_with_aux_batch(
+ wpt_db_, db_impl_, prepare_seq, prepare_batch_cnt_, kZeroData,
+ commit_batch_seq, commit_batch_cnt);
+ WriteBatch empty_batch;
+ s = empty_batch.PutLogData(Slice());
+ assert(s.ok());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ s = WriteBatchInternal::InsertNoop(&empty_batch);
+ assert(s.ok());
+ const bool DISABLE_MEMTABLE = true;
+ const size_t ONE_BATCH = 1;
+ const uint64_t NO_REF_LOG = 0;
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_aux_batch);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ return s;
+}
+
+Status WritePreparedTxn::RollbackInternal() {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal prepare_seq: %" PRIu64, GetId());
+
+ assert(db_impl_);
+ assert(wpt_db_);
+
+ WriteBatch rollback_batch(0 /* reserved_bytes */, 0 /* max_bytes */,
+ write_options_.protection_bytes_per_key,
+ 0 /* default_cf_ts_sz */);
+ assert(GetId() != kMaxSequenceNumber);
+ assert(GetId() > 0);
+ auto cf_map_shared_ptr = wpt_db_->GetCFHandleMap();
+ auto cf_comp_map_shared_ptr = wpt_db_->GetCFComparatorMap();
+ auto read_at_seq = kMaxSequenceNumber;
+ ReadOptions roptions;
+ // to prevent callback's seq to be overrriden inside DBImpk::Get
+ roptions.snapshot = wpt_db_->GetMaxSnapshot();
+ struct RollbackWriteBatchBuilder : public WriteBatch::Handler {
+ DBImpl* const db_;
+ WritePreparedTxnDB* const wpt_db_;
+ WritePreparedTxnReadCallback callback_;
+ WriteBatch* rollback_batch_;
+ std::map<uint32_t, const Comparator*>& comparators_;
+ std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ bool rollback_merge_operands_;
+ ReadOptions roptions_;
+
+ RollbackWriteBatchBuilder(
+ DBImpl* db, WritePreparedTxnDB* wpt_db, SequenceNumber snap_seq,
+ WriteBatch* dst_batch,
+ std::map<uint32_t, const Comparator*>& comparators,
+ std::map<uint32_t, ColumnFamilyHandle*>& handles,
+ bool rollback_merge_operands, const ReadOptions& _roptions)
+ : db_(db),
+ wpt_db_(wpt_db),
+ callback_(wpt_db, snap_seq), // disable min_uncommitted optimization
+ rollback_batch_(dst_batch),
+ comparators_(comparators),
+ handles_(handles),
+ rollback_merge_operands_(rollback_merge_operands),
+ roptions_(_roptions) {}
+
+ Status Rollback(uint32_t cf, const Slice& key) {
+ Status s;
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+ }
+ auto it = cf_keys.insert(key);
+ // second is false if a element already existed.
+ if (it.second == false) {
+ return s;
+ }
+
+ PinnableSlice pinnable_val;
+ bool not_used;
+ auto cf_handle = handles_[cf];
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = &callback_;
+ s = db_->GetImpl(roptions_, key, get_impl_options);
+ assert(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ s = rollback_batch_->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ // There has been no readable value before txn. By adding a delete we
+ // make sure that there will be none afterwards either.
+ if (wpt_db_->ShouldRollbackWithSingleDelete(cf_handle, key)) {
+ s = rollback_batch_->SingleDelete(cf_handle, key);
+ } else {
+ s = rollback_batch_->Delete(cf_handle, key);
+ }
+ assert(s.ok());
+ } else {
+ // Unexpected status. Return it to the user.
+ }
+ return s;
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& /*val*/) override {
+ return Rollback(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ if (rollback_merge_operands_) {
+ return Rollback(cf, key);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status MarkNoop(bool) override { return Status::OK(); }
+ Status MarkBeginPrepare(bool) override { return Status::OK(); }
+ Status MarkEndPrepare(const Slice&) override { return Status::OK(); }
+ Status MarkCommit(const Slice&) override { return Status::OK(); }
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ protected:
+ Handler::OptionState WriteAfterCommit() const override {
+ return Handler::OptionState::kDisabled;
+ }
+ } rollback_handler(db_impl_, wpt_db_, read_at_seq, &rollback_batch,
+ *cf_comp_map_shared_ptr.get(), *cf_map_shared_ptr.get(),
+ wpt_db_->txn_db_options_.rollback_merge_operands,
+ roptions);
+ auto s = GetWriteBatch()->GetWriteBatch()->Iterate(&rollback_handler);
+ if (!s.ok()) {
+ return s;
+ }
+ // The Rollback marker will be used as a batch separator
+ s = WriteBatchInternal::MarkRollback(&rollback_batch, name_);
+ assert(s.ok());
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t NO_REF_LOG = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+ const size_t ONE_BATCH = 1;
+ const bool kFirstPrepareBatch = true;
+ // We commit the rolled back prepared batches. Although this is
+ // counter-intuitive, i) it is safe to do so, since the prepared batches are
+ // already canceled out by the rollback batch, ii) adding the commit entry to
+ // CommitCache will allow us to benefit from the existing mechanism in
+ // CommitCache that keeps an entry evicted due to max advance and yet overlaps
+ // with a live snapshot around so that the live snapshot properly skips the
+ // entry even if its prepare seq is lower than max_evicted_seq_.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, ONE_BATCH,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, GetId(), prepare_batch_cnt_, ONE_BATCH);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ // Note: the rollback batch does not need AddPrepared since it is written to
+ // DB in one shot. min_uncommitted still works since it requires capturing
+ // data that is written to DB but not yet committed, while
+ // the rollback batch commits with PreReleaseCallback.
+ s = db_impl_->WriteImpl(write_options_, &rollback_batch, nullptr, nullptr,
+ NO_REF_LOG, !DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ assert(!db_impl_->immutable_db_options().two_write_queues);
+ wpt_db_->RemovePrepared(GetId(), prepare_batch_cnt_);
+ return s;
+ } // else do the 2nd write for commit
+ uint64_t rollback_seq = seq_used;
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal 2nd write rollback_seq: %" PRIu64,
+ rollback_seq);
+ // Commit the batch by writing an empty batch to the queue that will release
+ // the commit sequence number to readers.
+ WritePreparedRollbackPreReleaseCallback update_commit_map_with_prepare(
+ wpt_db_, db_impl_, GetId(), rollback_seq, prepare_batch_cnt_);
+ WriteBatch empty_batch;
+ s = empty_batch.PutLogData(Slice());
+ assert(s.ok());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ s = WriteBatchInternal::InsertNoop(&empty_batch);
+ assert(s.ok());
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_prepare);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal (status=%s) commit: %" PRIu64,
+ s.ToString().c_str(), GetId());
+ // TODO(lth): For WriteUnPrepared that rollback is called frequently,
+ // RemovePrepared could be moved to the callback to reduce lock contention.
+ if (s.ok()) {
+ wpt_db_->RemovePrepared(GetId(), prepare_batch_cnt_);
+ }
+ // Note: RemovePrepared for prepared batch is called from within
+ // PreReleaseCallback
+ wpt_db_->RemovePrepared(rollback_seq, ONE_BATCH);
+
+ return s;
+}
+
+Status WritePreparedTxn::ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ assert(snapshot_);
+
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(snapshot_.get())
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = snapshot_->GetSequenceNumber();
+ // tracked_at_seq is either max or the last snapshot with which this key was
+ // trackeed so there is no need to apply the IsInSnapshot to this comparison
+ // here as tracked_at_seq is not a prepare seq.
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previous validated at a sequence number earlier
+ // than the curent snapshot's sequence number, we already know it has not
+ // been modified.
+ return Status::OK();
+ }
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ WritePreparedTxnReadCallback snap_checker(wpt_db_, snap_seq, min_uncommitted,
+ kBackedByDBSnapshot);
+ // TODO(yanqin): support user-defined timestamp
+ return TransactionUtil::CheckKeyForConflicts(
+ db_impl_, cfh, key.ToString(), snap_seq, /*ts=*/nullptr,
+ false /* cache_only */, &snap_checker, min_uncommitted);
+}
+
+void WritePreparedTxn::SetSnapshot() {
+ const bool kForWWConflictCheck = true;
+ SnapshotImpl* snapshot = wpt_db_->GetSnapshotInternal(kForWWConflictCheck);
+ SetSnapshotInternal(snapshot);
+}
+
+Status WritePreparedTxn::RebuildFromWriteBatch(WriteBatch* src_batch) {
+ auto ret = PessimisticTransaction::RebuildFromWriteBatch(src_batch);
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+ return ret;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn.h b/src/rocksdb/utilities/transactions/write_prepared_txn.h
new file mode 100644
index 000000000..30d9bdb99
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn.h
@@ -0,0 +1,119 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <mutex>
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WritePreparedTxnDB;
+
+// This impl could write to DB also uncommitted data and then later tell apart
+// committed data from uncommitted data. Uncommitted data could be after the
+// Prepare phase in 2PC (WritePreparedTxn) or before that
+// (WriteUnpreparedTxnImpl).
+class WritePreparedTxn : public PessimisticTransaction {
+ public:
+ WritePreparedTxn(WritePreparedTxnDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+ // No copying allowed
+ WritePreparedTxn(const WritePreparedTxn&) = delete;
+ void operator=(const WritePreparedTxn&) = delete;
+
+ virtual ~WritePreparedTxn() {}
+
+ // To make WAL commit markers visible, the snapshot will be based on the last
+ // seq in the WAL that is also published, LastPublishedSequence, as opposed to
+ // the last seq in the memtable.
+ using Transaction::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using Transaction::MultiGet;
+ virtual void MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input = false) override;
+
+ // Note: The behavior is undefined in presence of interleaved writes to the
+ // same transaction.
+ // To make WAL commit markers visible, the snapshot will be
+ // based on the last seq in the WAL that is also published,
+ // LastPublishedSequence, as opposed to the last seq in the memtable.
+ using Transaction::GetIterator;
+ virtual Iterator* GetIterator(const ReadOptions& options) override;
+ virtual Iterator* GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual void SetSnapshot() override;
+
+ protected:
+ void Initialize(const TransactionOptions& txn_options) override;
+ // Override the protected SetId to make it visible to the friend class
+ // WritePreparedTxnDB
+ inline void SetId(uint64_t id) override { Transaction::SetId(id); }
+
+ private:
+ friend class WritePreparedTransactionTest_BasicRecoveryTest_Test;
+ friend class WritePreparedTxnDB;
+ friend class WriteUnpreparedTxnDB;
+ friend class WriteUnpreparedTxn;
+
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+
+ Status CommitBatchInternal(WriteBatch* batch, size_t batch_cnt) override;
+
+ // Since the data is already written to memtables at the Prepare phase, the
+ // commit entails writing only a commit marker in the WAL. The sequence number
+ // of the commit marker is then the commit timestamp of the transaction. To
+ // make WAL commit markers visible, the snapshot will be based on the last seq
+ // in the WAL that is also published, LastPublishedSequence, as opposed to the
+ // last seq in the memtable.
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) override;
+
+ virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
+
+ WritePreparedTxnDB* wpt_db_;
+ // Number of sub-batches in prepare
+ size_t prepare_batch_cnt_ = 0;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc b/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc
new file mode 100644
index 000000000..595c3df8f
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc
@@ -0,0 +1,1030 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "db/arena_wrapped_db_iter.h"
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+// This function is for testing only. If it returns true, then all entries in
+// the commit cache will be evicted. Unit and/or stress tests (db_stress)
+// can implement this function and customize how frequently commit cache
+// eviction occurs.
+// TODO: remove this function once we can configure commit cache to be very
+// small so that eviction occurs very frequently. This requires the commit
+// cache entry to be able to encode prepare and commit sequence numbers so that
+// the commit sequence number does not have to be within a certain range of
+// prepare sequence number.
+extern "C" bool rocksdb_write_prepared_TEST_ShouldClearCommitCache(void)
+ __attribute__((__weak__));
+
+namespace ROCKSDB_NAMESPACE {
+
+Status WritePreparedTxnDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ auto dbimpl = static_cast_with_check<DBImpl>(GetRootDB());
+ assert(dbimpl != nullptr);
+ auto rtxns = dbimpl->recovered_transactions();
+ std::map<SequenceNumber, SequenceNumber> ordered_seq_cnt;
+ for (auto rtxn : rtxns) {
+ // There should only one batch for WritePrepared policy.
+ assert(rtxn.second->batches_.size() == 1);
+ const auto& seq = rtxn.second->batches_.begin()->first;
+ const auto& batch_info = rtxn.second->batches_.begin()->second;
+ auto cnt = batch_info.batch_cnt_ ? batch_info.batch_cnt_ : 1;
+ ordered_seq_cnt[seq] = cnt;
+ }
+ // AddPrepared must be called in order
+ for (auto seq_cnt : ordered_seq_cnt) {
+ auto seq = seq_cnt.first;
+ auto cnt = seq_cnt.second;
+ for (size_t i = 0; i < cnt; i++) {
+ AddPrepared(seq + i);
+ }
+ }
+ SequenceNumber prev_max = max_evicted_seq_;
+ SequenceNumber last_seq = db_impl_->GetLatestSequenceNumber();
+ AdvanceMaxEvictedSeq(prev_max, last_seq);
+ // Create a gap between max and the next snapshot. This simplifies the logic
+ // in IsInSnapshot by not having to consider the special case of max ==
+ // snapshot after recovery. This is tested in IsInSnapshotEmptyMapTest.
+ if (last_seq) {
+ db_impl_->versions_->SetLastAllocatedSequence(last_seq + 1);
+ db_impl_->versions_->SetLastSequence(last_seq + 1);
+ db_impl_->versions_->SetLastPublishedSequence(last_seq + 1);
+ }
+
+ db_impl_->SetSnapshotChecker(new WritePreparedSnapshotChecker(this));
+ // A callback to commit a single sub-batch
+ class CommitSubBatchPreReleaseCallback : public PreReleaseCallback {
+ public:
+ explicit CommitSubBatchPreReleaseCallback(WritePreparedTxnDB* db)
+ : db_(db) {}
+ Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)), uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ assert(!is_mem_disabled);
+ db_->AddCommitted(commit_seq, commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ };
+ db_impl_->SetRecoverableStatePreReleaseCallback(
+ new CommitSubBatchPreReleaseCallback(this));
+
+ auto s = PessimisticTransactionDB::Initialize(compaction_enabled_cf_indices,
+ handles);
+ return s;
+}
+
+Status WritePreparedTxnDB::VerifyCFOptions(
+ const ColumnFamilyOptions& cf_options) {
+ Status s = PessimisticTransactionDB::VerifyCFOptions(cf_options);
+ if (!s.ok()) {
+ return s;
+ }
+ if (!cf_options.memtable_factory->CanHandleDuplicatedKey()) {
+ return Status::InvalidArgument(
+ "memtable_factory->CanHandleDuplicatedKey() cannot be false with "
+ "WritePrpeared transactions");
+ }
+ return Status::OK();
+}
+
+Transaction* WritePreparedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WritePreparedTxn(this, write_options, txn_options);
+ }
+}
+
+Status WritePreparedTxnDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ if (txn_db_options_.skip_concurrency_control) {
+ // Skip locking the rows
+ const size_t UNKNOWN_BATCH_CNT = 0;
+ WritePreparedTxn* NO_TXN = nullptr;
+ return WriteInternal(opts, updates, UNKNOWN_BATCH_CNT, NO_TXN);
+ } else {
+ return PessimisticTransactionDB::WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WritePreparedTxnDB::Write(
+ const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations, WriteBatch* updates) {
+ if (optimizations.skip_concurrency_control) {
+ // Skip locking the rows
+ const size_t UNKNOWN_BATCH_CNT = 0;
+ const size_t ONE_BATCH_CNT = 1;
+ const size_t batch_cnt = optimizations.skip_duplicate_key_check
+ ? ONE_BATCH_CNT
+ : UNKNOWN_BATCH_CNT;
+ WritePreparedTxn* NO_TXN = nullptr;
+ return WriteInternal(opts, updates, batch_cnt, NO_TXN);
+ } else {
+ // TODO(myabandeh): Make use of skip_duplicate_key_check hint
+ // Fall back to unoptimized version
+ return PessimisticTransactionDB::WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WritePreparedTxnDB::WriteInternal(const WriteOptions& write_options_orig,
+ WriteBatch* batch, size_t batch_cnt,
+ WritePreparedTxn* txn) {
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitBatchInternal");
+ if (batch->Count() == 0) {
+ // Otherwise our 1 seq per batch logic will break since there is no seq
+ // increased for this batch.
+ return Status::OK();
+ }
+
+ if (write_options_orig.protection_bytes_per_key > 0) {
+ auto s = WriteBatchInternal::UpdateProtectionInfo(
+ batch, write_options_orig.protection_bytes_per_key);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ if (batch_cnt == 0) { // not provided, then compute it
+ // TODO(myabandeh): add an option to allow user skipping this cost
+ SubBatchCounter counter(*GetCFComparatorMap());
+ auto s = batch->Iterate(&counter);
+ if (!s.ok()) {
+ return s;
+ }
+ batch_cnt = counter.BatchCount();
+ WPRecordTick(TXN_DUPLICATE_KEY_OVERHEAD);
+ ROCKS_LOG_DETAILS(info_log_, "Duplicate key overhead: %" PRIu64 " batches",
+ static_cast<uint64_t>(batch_cnt));
+ }
+ assert(batch_cnt);
+
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ WriteOptions write_options(write_options_orig);
+ // In the absence of Prepare markers, use Noop as a batch separator
+ auto s = WriteBatchInternal::InsertNoop(batch);
+ assert(s.ok());
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t no_log_ref = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+ const size_t ZERO_PREPARES = 0;
+ const bool kSeperatePrepareCommitBatches = true;
+ // Since this is not 2pc, there is no need for AddPrepared but having it in
+ // the PreReleaseCallback enables an optimization. Refer to
+ // SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ this, db_impl_, batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues,
+ !kSeperatePrepareCommitBatches);
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ this, db_impl_, kMaxSequenceNumber, ZERO_PREPARES, batch_cnt);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ s = db_impl_->WriteImpl(write_options, batch, nullptr, nullptr, no_log_ref,
+ !DISABLE_MEMTABLE, &seq_used, batch_cnt,
+ pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ uint64_t prepare_seq = seq_used;
+ if (txn != nullptr) {
+ txn->SetId(prepare_seq);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ return s;
+ } // else do the 2nd write for commit
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitBatchInternal 2nd write prepare_seq: %" PRIu64,
+ prepare_seq);
+ // Commit the batch by writing an empty batch to the 2nd queue that will
+ // release the commit sequence number to readers.
+ const size_t ZERO_COMMITS = 0;
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map_with_prepare(
+ this, db_impl_, prepare_seq, batch_cnt, ZERO_COMMITS);
+ WriteBatch empty_batch;
+ write_options.disableWAL = true;
+ write_options.sync = false;
+ const size_t ONE_BATCH = 1; // Just to inc the seq
+ s = db_impl_->WriteImpl(write_options, &empty_batch, nullptr, nullptr,
+ no_log_ref, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_prepare);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ // Note: RemovePrepared is called from within PreReleaseCallback
+ return s;
+}
+
+Status WritePreparedTxnDB::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(this, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ bool* dont_care = nullptr;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = value;
+ get_impl_options.value_found = dont_care;
+ get_impl_options.callback = &callback;
+ auto res = db_impl_->GetImpl(options, key, get_impl_options);
+ if (LIKELY(callback.valid() && ValidateSnapshot(callback.max_visible_seq(),
+ backed_by_snapshot))) {
+ return res;
+ } else {
+ res.PermitUncheckedError();
+ WPRecordTick(TXN_GET_TRY_AGAIN);
+ return Status::TryAgain();
+ }
+}
+
+void WritePreparedTxnDB::UpdateCFComparatorMap(
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ auto cf_map = new std::map<uint32_t, const Comparator*>();
+ auto handle_map = new std::map<uint32_t, ColumnFamilyHandle*>();
+ for (auto h : handles) {
+ auto id = h->GetID();
+ const Comparator* comparator = h->GetComparator();
+ (*cf_map)[id] = comparator;
+ if (id != 0) {
+ (*handle_map)[id] = h;
+ } else {
+ // The pointer to the default cf handle in the handles will be deleted.
+ // Use the pointer maintained by the db instead.
+ (*handle_map)[id] = DefaultColumnFamily();
+ }
+ }
+ cf_map_.reset(cf_map);
+ handle_map_.reset(handle_map);
+}
+
+void WritePreparedTxnDB::UpdateCFComparatorMap(ColumnFamilyHandle* h) {
+ auto old_cf_map_ptr = cf_map_.get();
+ assert(old_cf_map_ptr);
+ auto cf_map = new std::map<uint32_t, const Comparator*>(*old_cf_map_ptr);
+ auto old_handle_map_ptr = handle_map_.get();
+ assert(old_handle_map_ptr);
+ auto handle_map =
+ new std::map<uint32_t, ColumnFamilyHandle*>(*old_handle_map_ptr);
+ auto id = h->GetID();
+ const Comparator* comparator = h->GetComparator();
+ (*cf_map)[id] = comparator;
+ (*handle_map)[id] = h;
+ cf_map_.reset(cf_map);
+ handle_map_.reset(handle_map);
+}
+
+std::vector<Status> WritePreparedTxnDB::MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ assert(values);
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ stat_list[i] = this->Get(options, column_family[i], keys[i], &(*values)[i]);
+ }
+ return stat_list;
+}
+
+// Struct to hold ownership of snapshot and read callback for iterator cleanup.
+struct WritePreparedTxnDB::IteratorState {
+ IteratorState(WritePreparedTxnDB* txn_db, SequenceNumber sequence,
+ std::shared_ptr<ManagedSnapshot> s,
+ SequenceNumber min_uncommitted)
+ : callback(txn_db, sequence, min_uncommitted, kBackedByDBSnapshot),
+ snapshot(s) {}
+
+ WritePreparedTxnReadCallback callback;
+ std::shared_ptr<ManagedSnapshot> snapshot;
+};
+
+namespace {
+static void CleanupWritePreparedTxnDBIterator(void* arg1, void* /*arg2*/) {
+ delete reinterpret_cast<WritePreparedTxnDB::IteratorState*>(arg1);
+}
+} // anonymous namespace
+
+Iterator* WritePreparedTxnDB::NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ constexpr bool expose_blob_index = false;
+ constexpr bool allow_refresh = false;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+ if (options.snapshot != nullptr) {
+ snapshot_seq = options.snapshot->GetSequenceNumber();
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(options.snapshot)
+ ->min_uncommitted_;
+ } else {
+ auto* snapshot = GetSnapshot();
+ // We take a snapshot to make sure that the related data in the commit map
+ // are not deleted.
+ snapshot_seq = snapshot->GetSequenceNumber();
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(snapshot)->min_uncommitted_;
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ }
+ assert(snapshot_seq != kMaxSequenceNumber);
+ auto* cfd =
+ static_cast_with_check<ColumnFamilyHandleImpl>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted);
+ auto* db_iter =
+ db_impl_->NewIteratorImpl(options, cfd, snapshot_seq, &state->callback,
+ expose_blob_index, allow_refresh);
+ db_iter->RegisterCleanup(CleanupWritePreparedTxnDBIterator, state, nullptr);
+ return db_iter;
+}
+
+Status WritePreparedTxnDB::NewIterators(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ std::vector<Iterator*>* iterators) {
+ constexpr bool expose_blob_index = false;
+ constexpr bool allow_refresh = false;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+ if (options.snapshot != nullptr) {
+ snapshot_seq = options.snapshot->GetSequenceNumber();
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(options.snapshot)
+ ->min_uncommitted_;
+ } else {
+ auto* snapshot = GetSnapshot();
+ // We take a snapshot to make sure that the related data in the commit map
+ // are not deleted.
+ snapshot_seq = snapshot->GetSequenceNumber();
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(snapshot)->min_uncommitted_;
+ }
+ iterators->clear();
+ iterators->reserve(column_families.size());
+ for (auto* column_family : column_families) {
+ auto* cfd =
+ static_cast_with_check<ColumnFamilyHandleImpl>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted);
+ auto* db_iter =
+ db_impl_->NewIteratorImpl(options, cfd, snapshot_seq, &state->callback,
+ expose_blob_index, allow_refresh);
+ db_iter->RegisterCleanup(CleanupWritePreparedTxnDBIterator, state, nullptr);
+ iterators->push_back(db_iter);
+ }
+ return Status::OK();
+}
+
+void WritePreparedTxnDB::Init(const TransactionDBOptions& txn_db_opts) {
+ // Adcance max_evicted_seq_ no more than 100 times before the cache wraps
+ // around.
+ INC_STEP_FOR_MAX_EVICTED =
+ std::max(COMMIT_CACHE_SIZE / 100, static_cast<size_t>(1));
+ snapshot_cache_ = std::unique_ptr<std::atomic<SequenceNumber>[]>(
+ new std::atomic<SequenceNumber>[SNAPSHOT_CACHE_SIZE] {});
+ commit_cache_ = std::unique_ptr<std::atomic<CommitEntry64b>[]>(
+ new std::atomic<CommitEntry64b>[COMMIT_CACHE_SIZE] {});
+ dummy_max_snapshot_.number_ = kMaxSequenceNumber;
+ rollback_deletion_type_callback_ =
+ txn_db_opts.rollback_deletion_type_callback;
+}
+
+void WritePreparedTxnDB::CheckPreparedAgainstMax(SequenceNumber new_max,
+ bool locked) {
+ // When max_evicted_seq_ advances, move older entries from prepared_txns_
+ // to delayed_prepared_. This guarantees that if a seq is lower than max,
+ // then it is not in prepared_txns_ and save an expensive, synchronized
+ // lookup from a shared set. delayed_prepared_ is expected to be empty in
+ // normal cases.
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "CheckPreparedAgainstMax prepared_txns_.empty() %d top: %" PRIu64,
+ prepared_txns_.empty(),
+ prepared_txns_.empty() ? 0 : prepared_txns_.top());
+ const SequenceNumber prepared_top = prepared_txns_.top();
+ const bool empty = prepared_top == kMaxSequenceNumber;
+ // Preliminary check to avoid the synchronization cost
+ if (!empty && prepared_top <= new_max) {
+ if (locked) {
+ // Needed to avoid double locking in pop().
+ prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ WriteLock wl(&prepared_mutex_);
+ // Need to fetch fresh values of ::top after mutex is acquired
+ while (!prepared_txns_.empty() && prepared_txns_.top() <= new_max) {
+ auto to_be_popped = prepared_txns_.top();
+ delayed_prepared_.insert(to_be_popped);
+ ROCKS_LOG_WARN(info_log_,
+ "prepared_mutex_ overhead %" PRIu64 " (prep=%" PRIu64
+ " new_max=%" PRIu64 ")",
+ static_cast<uint64_t>(delayed_prepared_.size()),
+ to_be_popped, new_max);
+ delayed_prepared_empty_.store(false, std::memory_order_release);
+ // Update prepared_txns_ after updating delayed_prepared_empty_ otherwise
+ // there will be a point in time that the entry is neither in
+ // prepared_txns_ nor in delayed_prepared_, which will not be checked if
+ // delayed_prepared_empty_ is false.
+ prepared_txns_.pop();
+ }
+ if (locked) {
+ prepared_txns_.push_pop_mutex()->Lock();
+ }
+ }
+}
+
+void WritePreparedTxnDB::AddPrepared(uint64_t seq, bool locked) {
+ ROCKS_LOG_DETAILS(info_log_, "Txn %" PRIu64 " Preparing with max %" PRIu64,
+ seq, max_evicted_seq_.load());
+ TEST_SYNC_POINT("AddPrepared::begin:pause");
+ TEST_SYNC_POINT("AddPrepared::begin:resume");
+ if (!locked) {
+ prepared_txns_.push_pop_mutex()->Lock();
+ }
+ prepared_txns_.push_pop_mutex()->AssertHeld();
+ prepared_txns_.push(seq);
+ auto new_max = future_max_evicted_seq_.load();
+ if (UNLIKELY(seq <= new_max)) {
+ // This should not happen in normal case
+ ROCKS_LOG_ERROR(
+ info_log_,
+ "Added prepare_seq is not larger than max_evicted_seq_: %" PRIu64
+ " <= %" PRIu64,
+ seq, new_max);
+ CheckPreparedAgainstMax(new_max, true /*locked*/);
+ }
+ if (!locked) {
+ prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ TEST_SYNC_POINT("AddPrepared::end");
+}
+
+void WritePreparedTxnDB::AddCommitted(uint64_t prepare_seq, uint64_t commit_seq,
+ uint8_t loop_cnt) {
+ ROCKS_LOG_DETAILS(info_log_, "Txn %" PRIu64 " Committing with %" PRIu64,
+ prepare_seq, commit_seq);
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:start:pause");
+ auto indexed_seq = prepare_seq % COMMIT_CACHE_SIZE;
+ CommitEntry64b evicted_64b;
+ CommitEntry evicted;
+ bool to_be_evicted = GetCommitEntry(indexed_seq, &evicted_64b, &evicted);
+ if (LIKELY(to_be_evicted)) {
+ assert(evicted.prep_seq != prepare_seq);
+ auto prev_max = max_evicted_seq_.load(std::memory_order_acquire);
+ ROCKS_LOG_DETAILS(info_log_,
+ "Evicting %" PRIu64 ",%" PRIu64 " with max %" PRIu64,
+ evicted.prep_seq, evicted.commit_seq, prev_max);
+ if (prev_max < evicted.commit_seq) {
+ auto last = db_impl_->GetLastPublishedSequence(); // could be 0
+ SequenceNumber max_evicted_seq;
+ if (LIKELY(evicted.commit_seq < last)) {
+ assert(last > 0);
+ // Inc max in larger steps to avoid frequent updates
+ max_evicted_seq =
+ std::min(evicted.commit_seq + INC_STEP_FOR_MAX_EVICTED, last - 1);
+ } else {
+ // legit when a commit entry in a write batch overwrite the previous one
+ max_evicted_seq = evicted.commit_seq;
+ }
+#ifdef OS_LINUX
+ if (rocksdb_write_prepared_TEST_ShouldClearCommitCache &&
+ rocksdb_write_prepared_TEST_ShouldClearCommitCache()) {
+ max_evicted_seq = last;
+ }
+#endif // OS_LINUX
+ ROCKS_LOG_DETAILS(info_log_,
+ "%lu Evicting %" PRIu64 ",%" PRIu64 " with max %" PRIu64
+ " => %lu",
+ prepare_seq, evicted.prep_seq, evicted.commit_seq,
+ prev_max, max_evicted_seq);
+ AdvanceMaxEvictedSeq(prev_max, max_evicted_seq);
+ }
+ if (UNLIKELY(!delayed_prepared_empty_.load(std::memory_order_acquire))) {
+ WriteLock wl(&prepared_mutex_);
+ auto dp_iter = delayed_prepared_.find(evicted.prep_seq);
+ if (dp_iter != delayed_prepared_.end()) {
+ // This is a rare case that txn is committed but prepared_txns_ is not
+ // cleaned up yet. Refer to delayed_prepared_commits_ definition for
+ // why it should be kept updated.
+ delayed_prepared_commits_[evicted.prep_seq] = evicted.commit_seq;
+ ROCKS_LOG_DEBUG(info_log_,
+ "delayed_prepared_commits_[%" PRIu64 "]=%" PRIu64,
+ evicted.prep_seq, evicted.commit_seq);
+ }
+ }
+ // After each eviction from commit cache, check if the commit entry should
+ // be kept around because it overlaps with a live snapshot.
+ CheckAgainstSnapshots(evicted);
+ }
+ bool succ =
+ ExchangeCommitEntry(indexed_seq, evicted_64b, {prepare_seq, commit_seq});
+ if (UNLIKELY(!succ)) {
+ ROCKS_LOG_ERROR(info_log_,
+ "ExchangeCommitEntry failed on [%" PRIu64 "] %" PRIu64
+ ",%" PRIu64 " retrying...",
+ indexed_seq, prepare_seq, commit_seq);
+ // A very rare event, in which the commit entry is updated before we do.
+ // Here we apply a very simple solution of retrying.
+ if (loop_cnt > 100) {
+ throw std::runtime_error("Infinite loop in AddCommitted!");
+ }
+ AddCommitted(prepare_seq, commit_seq, ++loop_cnt);
+ return;
+ }
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:end:pause");
+}
+
+void WritePreparedTxnDB::RemovePrepared(const uint64_t prepare_seq,
+ const size_t batch_cnt) {
+ TEST_SYNC_POINT_CALLBACK(
+ "RemovePrepared:Start",
+ const_cast<void*>(reinterpret_cast<const void*>(&prepare_seq)));
+ TEST_SYNC_POINT("WritePreparedTxnDB::RemovePrepared:pause");
+ TEST_SYNC_POINT("WritePreparedTxnDB::RemovePrepared:resume");
+ ROCKS_LOG_DETAILS(info_log_,
+ "RemovePrepared %" PRIu64 " cnt: %" ROCKSDB_PRIszt,
+ prepare_seq, batch_cnt);
+ WriteLock wl(&prepared_mutex_);
+ for (size_t i = 0; i < batch_cnt; i++) {
+ prepared_txns_.erase(prepare_seq + i);
+ bool was_empty = delayed_prepared_.empty();
+ if (!was_empty) {
+ delayed_prepared_.erase(prepare_seq + i);
+ auto it = delayed_prepared_commits_.find(prepare_seq + i);
+ if (it != delayed_prepared_commits_.end()) {
+ ROCKS_LOG_DETAILS(info_log_, "delayed_prepared_commits_.erase %" PRIu64,
+ prepare_seq + i);
+ delayed_prepared_commits_.erase(it);
+ }
+ bool is_empty = delayed_prepared_.empty();
+ if (was_empty != is_empty) {
+ delayed_prepared_empty_.store(is_empty, std::memory_order_release);
+ }
+ }
+ }
+}
+
+bool WritePreparedTxnDB::GetCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b* entry_64b,
+ CommitEntry* entry) const {
+ *entry_64b = commit_cache_[static_cast<size_t>(indexed_seq)].load(
+ std::memory_order_acquire);
+ bool valid = entry_64b->Parse(indexed_seq, entry, FORMAT);
+ return valid;
+}
+
+bool WritePreparedTxnDB::AddCommitEntry(const uint64_t indexed_seq,
+ const CommitEntry& new_entry,
+ CommitEntry* evicted_entry) {
+ CommitEntry64b new_entry_64b(new_entry, FORMAT);
+ CommitEntry64b evicted_entry_64b =
+ commit_cache_[static_cast<size_t>(indexed_seq)].exchange(
+ new_entry_64b, std::memory_order_acq_rel);
+ bool valid = evicted_entry_64b.Parse(indexed_seq, evicted_entry, FORMAT);
+ return valid;
+}
+
+bool WritePreparedTxnDB::ExchangeCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b& expected_entry_64b,
+ const CommitEntry& new_entry) {
+ auto& atomic_entry = commit_cache_[static_cast<size_t>(indexed_seq)];
+ CommitEntry64b new_entry_64b(new_entry, FORMAT);
+ bool succ = atomic_entry.compare_exchange_strong(
+ expected_entry_64b, new_entry_64b, std::memory_order_acq_rel,
+ std::memory_order_acquire);
+ return succ;
+}
+
+void WritePreparedTxnDB::AdvanceMaxEvictedSeq(const SequenceNumber& prev_max,
+ const SequenceNumber& new_max) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "AdvanceMaxEvictedSeq overhead %" PRIu64 " => %" PRIu64,
+ prev_max, new_max);
+ // Declare the intention before getting snapshot from the DB. This helps a
+ // concurrent GetSnapshot to wait to catch up with future_max_evicted_seq_ if
+ // it has not already. Otherwise the new snapshot is when we ask DB for
+ // snapshots smaller than future max.
+ auto updated_future_max = prev_max;
+ while (updated_future_max < new_max &&
+ !future_max_evicted_seq_.compare_exchange_weak(
+ updated_future_max, new_max, std::memory_order_acq_rel,
+ std::memory_order_relaxed)) {
+ };
+
+ CheckPreparedAgainstMax(new_max, false /*locked*/);
+
+ // With each change to max_evicted_seq_ fetch the live snapshots behind it.
+ // We use max as the version of snapshots to identify how fresh are the
+ // snapshot list. This works because the snapshots are between 0 and
+ // max, so the larger the max, the more complete they are.
+ SequenceNumber new_snapshots_version = new_max;
+ std::vector<SequenceNumber> snapshots;
+ bool update_snapshots = false;
+ if (new_snapshots_version > snapshots_version_) {
+ // This is to avoid updating the snapshots_ if it already updated
+ // with a more recent vesion by a concrrent thread
+ update_snapshots = true;
+ // We only care about snapshots lower then max
+ snapshots = GetSnapshotListFromDB(new_max);
+ }
+ if (update_snapshots) {
+ UpdateSnapshots(snapshots, new_snapshots_version);
+ if (!snapshots.empty()) {
+ WriteLock wl(&old_commit_map_mutex_);
+ for (auto snap : snapshots) {
+ // This allows IsInSnapshot to tell apart the reads from in valid
+ // snapshots from the reads from committed values in valid snapshots.
+ old_commit_map_[snap];
+ }
+ old_commit_map_empty_.store(false, std::memory_order_release);
+ }
+ }
+ auto updated_prev_max = prev_max;
+ TEST_SYNC_POINT("AdvanceMaxEvictedSeq::update_max:pause");
+ TEST_SYNC_POINT("AdvanceMaxEvictedSeq::update_max:resume");
+ while (updated_prev_max < new_max &&
+ !max_evicted_seq_.compare_exchange_weak(updated_prev_max, new_max,
+ std::memory_order_acq_rel,
+ std::memory_order_relaxed)) {
+ };
+}
+
+const Snapshot* WritePreparedTxnDB::GetSnapshot() {
+ const bool kForWWConflictCheck = true;
+ return GetSnapshotInternal(!kForWWConflictCheck);
+}
+
+SnapshotImpl* WritePreparedTxnDB::GetSnapshotInternal(
+ bool for_ww_conflict_check) {
+ // Note: for this optimization setting the last sequence number and obtaining
+ // the smallest uncommitted seq should be done atomically. However to avoid
+ // the mutex overhead, we call SmallestUnCommittedSeq BEFORE taking the
+ // snapshot. Since we always updated the list of unprepared seq (via
+ // AddPrepared) AFTER the last sequence is updated, this guarantees that the
+ // smallest uncommitted seq that we pair with the snapshot is smaller or equal
+ // the value that would be obtained otherwise atomically. That is ok since
+ // this optimization works as long as min_uncommitted is less than or equal
+ // than the smallest uncommitted seq when the snapshot was taken.
+ auto min_uncommitted = WritePreparedTxnDB::SmallestUnCommittedSeq();
+ SnapshotImpl* snap_impl = db_impl_->GetSnapshotImpl(for_ww_conflict_check);
+ TEST_SYNC_POINT("WritePreparedTxnDB::GetSnapshotInternal:first");
+ assert(snap_impl);
+ SequenceNumber snap_seq = snap_impl->GetSequenceNumber();
+ // Note: Check against future_max_evicted_seq_ (in contrast with
+ // max_evicted_seq_) in case there is a concurrent AdvanceMaxEvictedSeq.
+ if (UNLIKELY(snap_seq != 0 && snap_seq <= future_max_evicted_seq_)) {
+ // There is a very rare case in which the commit entry evicts another commit
+ // entry that is not published yet thus advancing max evicted seq beyond the
+ // last published seq. This case is not likely in real-world setup so we
+ // handle it with a few retries.
+ size_t retry = 0;
+ SequenceNumber max;
+ while ((max = future_max_evicted_seq_.load()) != 0 &&
+ snap_impl->GetSequenceNumber() <= max && retry < 100) {
+ ROCKS_LOG_WARN(info_log_,
+ "GetSnapshot snap: %" PRIu64 " max: %" PRIu64
+ " retry %" ROCKSDB_PRIszt,
+ snap_impl->GetSequenceNumber(), max, retry);
+ ReleaseSnapshot(snap_impl);
+ // Wait for last visible seq to catch up with max, and also go beyond it
+ // by one.
+ AdvanceSeqByOne();
+ snap_impl = db_impl_->GetSnapshotImpl(for_ww_conflict_check);
+ assert(snap_impl);
+ retry++;
+ }
+ assert(snap_impl->GetSequenceNumber() > max);
+ if (snap_impl->GetSequenceNumber() <= max) {
+ throw std::runtime_error(
+ "Snapshot seq " + std::to_string(snap_impl->GetSequenceNumber()) +
+ " after " + std::to_string(retry) +
+ " retries is still less than futre_max_evicted_seq_" +
+ std::to_string(max));
+ }
+ }
+ EnhanceSnapshot(snap_impl, min_uncommitted);
+ ROCKS_LOG_DETAILS(
+ db_impl_->immutable_db_options().info_log,
+ "GetSnapshot %" PRIu64 " ww:%" PRIi32 " min_uncommitted: %" PRIu64,
+ snap_impl->GetSequenceNumber(), for_ww_conflict_check, min_uncommitted);
+ TEST_SYNC_POINT("WritePreparedTxnDB::GetSnapshotInternal:end");
+ return snap_impl;
+}
+
+void WritePreparedTxnDB::AdvanceSeqByOne() {
+ // Inserting an empty value will i) let the max evicted entry to be
+ // published, i.e., max == last_published, increase the last published to
+ // be one beyond max, i.e., max < last_published.
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ Transaction* txn0 = BeginTransaction(woptions, txn_options, nullptr);
+ std::hash<std::thread::id> hasher;
+ char name[64];
+ snprintf(name, 64, "txn%" ROCKSDB_PRIszt, hasher(std::this_thread::get_id()));
+ assert(strlen(name) < 64 - 1);
+ Status s = txn0->SetName(name);
+ assert(s.ok());
+ if (s.ok()) {
+ // Without prepare it would simply skip the commit
+ s = txn0->Prepare();
+ }
+ assert(s.ok());
+ if (s.ok()) {
+ s = txn0->Commit();
+ }
+ assert(s.ok());
+ delete txn0;
+}
+
+const std::vector<SequenceNumber> WritePreparedTxnDB::GetSnapshotListFromDB(
+ SequenceNumber max) {
+ ROCKS_LOG_DETAILS(info_log_, "GetSnapshotListFromDB with max %" PRIu64, max);
+ InstrumentedMutexLock dblock(db_impl_->mutex());
+ db_impl_->mutex()->AssertHeld();
+ return db_impl_->snapshots().GetAll(nullptr, max);
+}
+
+void WritePreparedTxnDB::ReleaseSnapshotInternal(
+ const SequenceNumber snap_seq) {
+ // TODO(myabandeh): relax should enough since the synchronizatin is already
+ // done by snapshots_mutex_ under which this function is called.
+ if (snap_seq <= max_evicted_seq_.load(std::memory_order_acquire)) {
+ // Then this is a rare case that transaction did not finish before max
+ // advances. It is expected for a few read-only backup snapshots. For such
+ // snapshots we might have kept around a couple of entries in the
+ // old_commit_map_. Check and do garbage collection if that is the case.
+ bool need_gc = false;
+ {
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_, "old_commit_map_mutex_ overhead for %" PRIu64,
+ snap_seq);
+ ReadLock rl(&old_commit_map_mutex_);
+ auto prep_set_entry = old_commit_map_.find(snap_seq);
+ need_gc = prep_set_entry != old_commit_map_.end();
+ }
+ if (need_gc) {
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_, "old_commit_map_mutex_ overhead for %" PRIu64,
+ snap_seq);
+ WriteLock wl(&old_commit_map_mutex_);
+ old_commit_map_.erase(snap_seq);
+ old_commit_map_empty_.store(old_commit_map_.empty(),
+ std::memory_order_release);
+ }
+ }
+}
+
+void WritePreparedTxnDB::CleanupReleasedSnapshots(
+ const std::vector<SequenceNumber>& new_snapshots,
+ const std::vector<SequenceNumber>& old_snapshots) {
+ auto newi = new_snapshots.begin();
+ auto oldi = old_snapshots.begin();
+ for (; newi != new_snapshots.end() && oldi != old_snapshots.end();) {
+ assert(*newi >= *oldi); // cannot have new snapshots with lower seq
+ if (*newi == *oldi) { // still not released
+ auto value = *newi;
+ while (newi != new_snapshots.end() && *newi == value) {
+ newi++;
+ }
+ while (oldi != old_snapshots.end() && *oldi == value) {
+ oldi++;
+ }
+ } else {
+ assert(*newi > *oldi); // *oldi is released
+ ReleaseSnapshotInternal(*oldi);
+ oldi++;
+ }
+ }
+ // Everything remained in old_snapshots is released and must be cleaned up
+ for (; oldi != old_snapshots.end(); oldi++) {
+ ReleaseSnapshotInternal(*oldi);
+ }
+}
+
+void WritePreparedTxnDB::UpdateSnapshots(
+ const std::vector<SequenceNumber>& snapshots,
+ const SequenceNumber& version) {
+ ROCKS_LOG_DETAILS(info_log_, "UpdateSnapshots with version %" PRIu64,
+ version);
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:start");
+#ifndef NDEBUG
+ size_t sync_i = 0;
+#endif
+ ROCKS_LOG_DETAILS(info_log_, "snapshots_mutex_ overhead");
+ WriteLock wl(&snapshots_mutex_);
+ snapshots_version_ = version;
+ // We update the list concurrently with the readers.
+ // Both new and old lists are sorted and the new list is subset of the
+ // previous list plus some new items. Thus if a snapshot repeats in
+ // both new and old lists, it will appear upper in the new list. So if
+ // we simply insert the new snapshots in order, if an overwritten item
+ // is still valid in the new list is either written to the same place in
+ // the array or it is written in a higher palce before it gets
+ // overwritten by another item. This guarantess a reader that reads the
+ // list bottom-up will eventaully see a snapshot that repeats in the
+ // update, either before it gets overwritten by the writer or
+ // afterwards.
+ size_t i = 0;
+ auto it = snapshots.begin();
+ for (; it != snapshots.end() && i < SNAPSHOT_CACHE_SIZE; ++it, ++i) {
+ snapshot_cache_[i].store(*it, std::memory_order_release);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:", ++sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:", sync_i);
+ }
+#ifndef NDEBUG
+ // Release the remaining sync points since they are useless given that the
+ // reader would also use lock to access snapshots
+ for (++sync_i; sync_i <= 10; ++sync_i) {
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:", sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:", sync_i);
+ }
+#endif
+ snapshots_.clear();
+ for (; it != snapshots.end(); ++it) {
+ // Insert them to a vector that is less efficient to access
+ // concurrently
+ snapshots_.push_back(*it);
+ }
+ // Update the size at the end. Otherwise a parallel reader might read
+ // items that are not set yet.
+ snapshots_total_.store(snapshots.size(), std::memory_order_release);
+
+ // Note: this must be done after the snapshots data structures are updated
+ // with the new list of snapshots.
+ CleanupReleasedSnapshots(snapshots, snapshots_all_);
+ snapshots_all_ = snapshots;
+
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:end");
+}
+
+void WritePreparedTxnDB::CheckAgainstSnapshots(const CommitEntry& evicted) {
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:start");
+#ifndef NDEBUG
+ size_t sync_i = 0;
+#endif
+ // First check the snapshot cache that is efficient for concurrent access
+ auto cnt = snapshots_total_.load(std::memory_order_acquire);
+ // The list might get updated concurrently as we are reading from it. The
+ // reader should be able to read all the snapshots that are still valid
+ // after the update. Since the survived snapshots are written in a higher
+ // place before gets overwritten the reader that reads bottom-up will
+ // eventully see it.
+ const bool next_is_larger = true;
+ // We will set to true if the border line snapshot suggests that.
+ bool search_larger_list = false;
+ size_t ip1 = std::min(cnt, SNAPSHOT_CACHE_SIZE);
+ for (; 0 < ip1; ip1--) {
+ SequenceNumber snapshot_seq =
+ snapshot_cache_[ip1 - 1].load(std::memory_order_acquire);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:",
+ ++sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:", sync_i);
+ if (ip1 == SNAPSHOT_CACHE_SIZE) { // border line snapshot
+ // snapshot_seq < commit_seq => larger_snapshot_seq <= commit_seq
+ // then later also continue the search to larger snapshots
+ search_larger_list = snapshot_seq < evicted.commit_seq;
+ }
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq, !next_is_larger)) {
+ break;
+ }
+ }
+#ifndef NDEBUG
+ // Release the remaining sync points before accquiring the lock
+ for (++sync_i; sync_i <= 10; ++sync_i) {
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:", sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:", sync_i);
+ }
+#endif
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:end");
+ if (UNLIKELY(SNAPSHOT_CACHE_SIZE < cnt && search_larger_list)) {
+ // Then access the less efficient list of snapshots_
+ WPRecordTick(TXN_SNAPSHOT_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_,
+ "snapshots_mutex_ overhead for <%" PRIu64 ",%" PRIu64
+ "> with %" ROCKSDB_PRIszt " snapshots",
+ evicted.prep_seq, evicted.commit_seq, cnt);
+ ReadLock rl(&snapshots_mutex_);
+ // Items could have moved from the snapshots_ to snapshot_cache_ before
+ // accquiring the lock. To make sure that we do not miss a valid snapshot,
+ // read snapshot_cache_ again while holding the lock.
+ for (size_t i = 0; i < SNAPSHOT_CACHE_SIZE; i++) {
+ SequenceNumber snapshot_seq =
+ snapshot_cache_[i].load(std::memory_order_acquire);
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq, next_is_larger)) {
+ break;
+ }
+ }
+ for (auto snapshot_seq_2 : snapshots_) {
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq_2, next_is_larger)) {
+ break;
+ }
+ }
+ }
+}
+
+bool WritePreparedTxnDB::MaybeUpdateOldCommitMap(
+ const uint64_t& prep_seq, const uint64_t& commit_seq,
+ const uint64_t& snapshot_seq, const bool next_is_larger = true) {
+ // If we do not store an entry in old_commit_map_ we assume it is committed in
+ // all snapshots. If commit_seq <= snapshot_seq, it is considered already in
+ // the snapshot so we need not to keep the entry around for this snapshot.
+ if (commit_seq <= snapshot_seq) {
+ // continue the search if the next snapshot could be smaller than commit_seq
+ return !next_is_larger;
+ }
+ // then snapshot_seq < commit_seq
+ if (prep_seq <= snapshot_seq) { // overlapping range
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_,
+ "old_commit_map_mutex_ overhead for %" PRIu64
+ " commit entry: <%" PRIu64 ",%" PRIu64 ">",
+ snapshot_seq, prep_seq, commit_seq);
+ WriteLock wl(&old_commit_map_mutex_);
+ old_commit_map_empty_.store(false, std::memory_order_release);
+ auto& vec = old_commit_map_[snapshot_seq];
+ vec.insert(std::upper_bound(vec.begin(), vec.end(), prep_seq), prep_seq);
+ // We need to store it once for each overlapping snapshot. Returning true to
+ // continue the search if there is more overlapping snapshot.
+ return true;
+ }
+ // continue the search if the next snapshot could be larger than prep_seq
+ return next_is_larger;
+}
+
+WritePreparedTxnDB::~WritePreparedTxnDB() {
+ // At this point there could be running compaction/flush holding a
+ // SnapshotChecker, which holds a pointer back to WritePreparedTxnDB.
+ // Make sure those jobs finished before destructing WritePreparedTxnDB.
+ if (!db_impl_->shutting_down_) {
+ db_impl_->CancelAllBackgroundWork(true /*wait*/);
+ }
+}
+
+void SubBatchCounter::InitWithComp(const uint32_t cf) {
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+}
+
+void SubBatchCounter::AddKey(const uint32_t cf, const Slice& key) {
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ InitWithComp(cf);
+ }
+ auto it = cf_keys.insert(key);
+ if (it.second == false) { // second is false if a element already existed.
+ batches_++;
+ keys_.clear();
+ InitWithComp(cf);
+ keys_[cf].insert(key);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn_db.h b/src/rocksdb/utilities/transactions/write_prepared_txn_db.h
new file mode 100644
index 000000000..25a382473
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn_db.h
@@ -0,0 +1,1125 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <cinttypes>
+#include <mutex>
+#include <queue>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/db_iter.h"
+#include "db/pre_release_callback.h"
+#include "db/read_callback.h"
+#include "db/snapshot_checker.h"
+#include "logging/logging.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "util/set_comparator.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/write_prepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+enum SnapshotBackup : bool { kUnbackedByDBSnapshot, kBackedByDBSnapshot };
+
+// A PessimisticTransactionDB that writes data to DB after prepare phase of 2PC.
+// In this way some data in the DB might not be committed. The DB provides
+// mechanisms to tell such data apart from committed data.
+class WritePreparedTxnDB : public PessimisticTransactionDB {
+ public:
+ explicit WritePreparedTxnDB(DB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options),
+ SNAPSHOT_CACHE_BITS(txn_db_options.wp_snapshot_cache_bits),
+ SNAPSHOT_CACHE_SIZE(static_cast<size_t>(1ull << SNAPSHOT_CACHE_BITS)),
+ COMMIT_CACHE_BITS(txn_db_options.wp_commit_cache_bits),
+ COMMIT_CACHE_SIZE(static_cast<size_t>(1ull << COMMIT_CACHE_BITS)),
+ FORMAT(COMMIT_CACHE_BITS) {
+ Init(txn_db_options);
+ }
+
+ explicit WritePreparedTxnDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options),
+ SNAPSHOT_CACHE_BITS(txn_db_options.wp_snapshot_cache_bits),
+ SNAPSHOT_CACHE_SIZE(static_cast<size_t>(1ull << SNAPSHOT_CACHE_BITS)),
+ COMMIT_CACHE_BITS(txn_db_options.wp_commit_cache_bits),
+ COMMIT_CACHE_SIZE(static_cast<size_t>(1ull << COMMIT_CACHE_BITS)),
+ FORMAT(COMMIT_CACHE_BITS) {
+ Init(txn_db_options);
+ }
+
+ virtual ~WritePreparedTxnDB();
+
+ virtual Status Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ using TransactionDB::Write;
+ Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ // Optimized version of ::Write that receives more optimization request such
+ // as skip_concurrency_control.
+ using PessimisticTransactionDB::Write;
+ Status Write(const WriteOptions& opts, const TransactionDBWriteOptimizations&,
+ WriteBatch* updates) override;
+
+ // Write the batch to the underlying DB and mark it as committed. Could be
+ // used by both directly from TxnDB or through a transaction.
+ Status WriteInternal(const WriteOptions& write_options, WriteBatch* batch,
+ size_t batch_cnt, WritePreparedTxn* txn);
+
+ using DB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using DB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ using DB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ using DB::NewIterators;
+ virtual Status NewIterators(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ std::vector<Iterator*>* iterators) override;
+
+ // Check whether the transaction that wrote the value with sequence number seq
+ // is visible to the snapshot with sequence number snapshot_seq.
+ // Returns true if commit_seq <= snapshot_seq
+ // If the snapshot_seq is already released and snapshot_seq <= max, sets
+ // *snap_released to true and returns true as well.
+ inline bool IsInSnapshot(uint64_t prep_seq, uint64_t snapshot_seq,
+ uint64_t min_uncommitted = kMinUnCommittedSeq,
+ bool* snap_released = nullptr) const {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " min_uncommitted %" PRIu64,
+ prep_seq, snapshot_seq, min_uncommitted);
+ assert(min_uncommitted >= kMinUnCommittedSeq);
+ // Caller is responsible to initialize snap_released.
+ assert(snap_released == nullptr || *snap_released == false);
+ // Here we try to infer the return value without looking into prepare list.
+ // This would help avoiding synchronization over a shared map.
+ // TODO(myabandeh): optimize this. This sequence of checks must be correct
+ // but not necessary efficient
+ if (prep_seq == 0) {
+ // Compaction will output keys to bottom-level with sequence number 0 if
+ // it is visible to the earliest snapshot.
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ if (snapshot_seq < prep_seq) {
+ // snapshot_seq < prep_seq <= commit_seq => snapshot_seq < commit_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+ if (prep_seq < min_uncommitted) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32
+ " because of min_uncommitted %" PRIu64,
+ prep_seq, snapshot_seq, 1, min_uncommitted);
+ return true;
+ }
+ // Commit of delayed prepared has two non-atomic steps: add to commit cache,
+ // remove from delayed prepared. Our reads from these two is also
+ // non-atomic. By looking into commit cache first thus we might not find the
+ // prep_seq neither in commit cache not in delayed_prepared_. To fix that i)
+ // we check if there was any delayed prepared BEFORE looking into commit
+ // cache, ii) if there was, we complete the search steps to be these: i)
+ // commit cache, ii) delayed prepared, commit cache again. In this way if
+ // the first query to commit cache missed the commit, the 2nd will catch it.
+ bool was_empty;
+ SequenceNumber max_evicted_seq_lb, max_evicted_seq_ub;
+ CommitEntry64b dont_care;
+ auto indexed_seq = prep_seq % COMMIT_CACHE_SIZE;
+ size_t repeats = 0;
+ do {
+ repeats++;
+ assert(repeats < 100);
+ if (UNLIKELY(repeats >= 100)) {
+ throw std::runtime_error(
+ "The read was intrupted 100 times by update to max_evicted_seq_. "
+ "This is unexpected in all setups");
+ }
+ max_evicted_seq_lb = max_evicted_seq_.load(std::memory_order_acquire);
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:resume");
+ was_empty = delayed_prepared_empty_.load(std::memory_order_acquire);
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:resume");
+ CommitEntry cached;
+ bool exist = GetCommitEntry(indexed_seq, &dont_care, &cached);
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:pause");
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:resume");
+ if (exist && prep_seq == cached.prep_seq) {
+ // It is committed and also not evicted from commit cache
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, cached.commit_seq <= snapshot_seq);
+ return cached.commit_seq <= snapshot_seq;
+ }
+ // else it could be committed but not inserted in the map which could
+ // happen after recovery, or it could be committed and evicted by another
+ // commit, or never committed.
+
+ // At this point we don't know if it was committed or it is still prepared
+ max_evicted_seq_ub = max_evicted_seq_.load(std::memory_order_acquire);
+ if (UNLIKELY(max_evicted_seq_lb != max_evicted_seq_ub)) {
+ continue;
+ }
+ // Note: max_evicted_seq_ when we did GetCommitEntry <= max_evicted_seq_ub
+ if (max_evicted_seq_ub < prep_seq) {
+ // Not evicted from cache and also not present, so must be still
+ // prepared
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:resume");
+ if (!was_empty) {
+ // We should not normally reach here
+ WPRecordTick(TXN_PREPARE_MUTEX_OVERHEAD);
+ ReadLock rl(&prepared_mutex_);
+ ROCKS_LOG_WARN(
+ info_log_, "prepared_mutex_ overhead %" PRIu64 " for %" PRIu64,
+ static_cast<uint64_t>(delayed_prepared_.size()), prep_seq);
+ if (delayed_prepared_.find(prep_seq) != delayed_prepared_.end()) {
+ // This is the order: 1) delayed_prepared_commits_ update, 2) publish
+ // 3) delayed_prepared_ clean up. So check if it is the case of a late
+ // clenaup.
+ auto it = delayed_prepared_commits_.find(prep_seq);
+ if (it == delayed_prepared_commits_.end()) {
+ // Then it is not committed yet
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ } else {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " commit: %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, it->second,
+ snapshot_seq <= it->second);
+ return it->second <= snapshot_seq;
+ }
+ } else {
+ // 2nd query to commit cache. Refer to was_empty comment above.
+ exist = GetCommitEntry(indexed_seq, &dont_care, &cached);
+ if (exist && prep_seq == cached.prep_seq) {
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, cached.commit_seq <= snapshot_seq);
+ return cached.commit_seq <= snapshot_seq;
+ }
+ max_evicted_seq_ub = max_evicted_seq_.load(std::memory_order_acquire);
+ }
+ }
+ } while (UNLIKELY(max_evicted_seq_lb != max_evicted_seq_ub));
+ // When advancing max_evicted_seq_, we move older entires from prepared to
+ // delayed_prepared_. Also we move evicted entries from commit cache to
+ // old_commit_map_ if it overlaps with any snapshot. Since prep_seq <=
+ // max_evicted_seq_, we have three cases: i) in delayed_prepared_, ii) in
+ // old_commit_map_, iii) committed with no conflict with any snapshot. Case
+ // (i) delayed_prepared_ is checked above
+ if (max_evicted_seq_ub < snapshot_seq) { // then (ii) cannot be the case
+ // only (iii) is the case: committed
+ // commit_seq <= max_evicted_seq_ < snapshot_seq => commit_seq <
+ // snapshot_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ // else (ii) might be the case: check the commit data saved for this
+ // snapshot. If there was no overlapping commit entry, then it is committed
+ // with a commit_seq lower than any live snapshot, including snapshot_seq.
+ if (old_commit_map_empty_.load(std::memory_order_acquire)) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32 " released=1",
+ prep_seq, snapshot_seq, 0);
+ assert(snap_released);
+ // This snapshot is not valid anymore. We cannot tell if prep_seq is
+ // committed before or after the snapshot. Return true but also set
+ // snap_released to true.
+ *snap_released = true;
+ return true;
+ }
+ {
+ // We should not normally reach here unless sapshot_seq is old. This is a
+ // rare case and it is ok to pay the cost of mutex ReadLock for such old,
+ // reading transactions.
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ReadLock rl(&old_commit_map_mutex_);
+ auto prep_set_entry = old_commit_map_.find(snapshot_seq);
+ bool found = prep_set_entry != old_commit_map_.end();
+ if (found) {
+ auto& vec = prep_set_entry->second;
+ found = std::binary_search(vec.begin(), vec.end(), prep_seq);
+ } else {
+ // coming from compaction
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32 " released=1",
+ prep_seq, snapshot_seq, 0);
+ // This snapshot is not valid anymore. We cannot tell if prep_seq is
+ // committed before or after the snapshot. Return true but also set
+ // snap_released to true.
+ assert(snap_released);
+ *snap_released = true;
+ return true;
+ }
+
+ if (!found) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ }
+ // (ii) it the case: it is committed but after the snapshot_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+
+ // Add the transaction with prepare sequence seq to the prepared list.
+ // Note: must be called serially with increasing seq on each call.
+ // locked is true if prepared_mutex_ is already locked.
+ void AddPrepared(uint64_t seq, bool locked = false);
+ // Check if any of the prepared txns are less than new max_evicted_seq_. Must
+ // be called with prepared_mutex_ write locked.
+ void CheckPreparedAgainstMax(SequenceNumber new_max, bool locked);
+ // Remove the transaction with prepare sequence seq from the prepared list
+ void RemovePrepared(const uint64_t seq, const size_t batch_cnt = 1);
+ // Add the transaction with prepare sequence prepare_seq and commit sequence
+ // commit_seq to the commit map. loop_cnt is to detect infinite loops.
+ // Note: must be called serially.
+ void AddCommitted(uint64_t prepare_seq, uint64_t commit_seq,
+ uint8_t loop_cnt = 0);
+
+ struct CommitEntry {
+ uint64_t prep_seq;
+ uint64_t commit_seq;
+ CommitEntry() : prep_seq(0), commit_seq(0) {}
+ CommitEntry(uint64_t ps, uint64_t cs) : prep_seq(ps), commit_seq(cs) {}
+ bool operator==(const CommitEntry& rhs) const {
+ return prep_seq == rhs.prep_seq && commit_seq == rhs.commit_seq;
+ }
+ };
+
+ struct CommitEntry64bFormat {
+ explicit CommitEntry64bFormat(size_t index_bits)
+ : INDEX_BITS(index_bits),
+ PREP_BITS(static_cast<size_t>(64 - PAD_BITS - INDEX_BITS)),
+ COMMIT_BITS(static_cast<size_t>(64 - PREP_BITS)),
+ COMMIT_FILTER(static_cast<uint64_t>((1ull << COMMIT_BITS) - 1)),
+ DELTA_UPPERBOUND(static_cast<uint64_t>((1ull << COMMIT_BITS))) {}
+ // Number of higher bits of a sequence number that is not used. They are
+ // used to encode the value type, ...
+ const size_t PAD_BITS = static_cast<size_t>(8);
+ // Number of lower bits from prepare seq that can be skipped as they are
+ // implied by the index of the entry in the array
+ const size_t INDEX_BITS;
+ // Number of bits we use to encode the prepare seq
+ const size_t PREP_BITS;
+ // Number of bits we use to encode the commit seq.
+ const size_t COMMIT_BITS;
+ // Filter to encode/decode commit seq
+ const uint64_t COMMIT_FILTER;
+ // The value of commit_seq - prepare_seq + 1 must be less than this bound
+ const uint64_t DELTA_UPPERBOUND;
+ };
+
+ // Prepare Seq (64 bits) = PAD ... PAD PREP PREP ... PREP INDEX INDEX ...
+ // INDEX Delta Seq (64 bits) = 0 0 0 0 0 0 0 0 0 0 0 0 DELTA DELTA ...
+ // DELTA DELTA Encoded Value = PREP PREP .... PREP PREP DELTA DELTA
+ // ... DELTA DELTA PAD: first bits of a seq that is reserved for tagging and
+ // hence ignored PREP/INDEX: the used bits in a prepare seq number INDEX: the
+ // bits that do not have to be encoded (will be provided externally) DELTA:
+ // prep seq - commit seq + 1 Number of DELTA bits should be equal to number of
+ // index bits + PADs
+ struct CommitEntry64b {
+ constexpr CommitEntry64b() noexcept : rep_(0) {}
+
+ CommitEntry64b(const CommitEntry& entry, const CommitEntry64bFormat& format)
+ : CommitEntry64b(entry.prep_seq, entry.commit_seq, format) {}
+
+ CommitEntry64b(const uint64_t ps, const uint64_t cs,
+ const CommitEntry64bFormat& format) {
+ assert(ps < static_cast<uint64_t>(
+ (1ull << (format.PREP_BITS + format.INDEX_BITS))));
+ assert(ps <= cs);
+ uint64_t delta = cs - ps + 1; // make initialized delta always >= 1
+ // zero is reserved for uninitialized entries
+ assert(0 < delta);
+ assert(delta < format.DELTA_UPPERBOUND);
+ if (delta >= format.DELTA_UPPERBOUND) {
+ throw std::runtime_error(
+ "commit_seq >> prepare_seq. The allowed distance is " +
+ std::to_string(format.DELTA_UPPERBOUND) + " commit_seq is " +
+ std::to_string(cs) + " prepare_seq is " + std::to_string(ps));
+ }
+ rep_ = (ps << format.PAD_BITS) & ~format.COMMIT_FILTER;
+ rep_ = rep_ | delta;
+ }
+
+ // Return false if the entry is empty
+ bool Parse(const uint64_t indexed_seq, CommitEntry* entry,
+ const CommitEntry64bFormat& format) {
+ uint64_t delta = rep_ & format.COMMIT_FILTER;
+ // zero is reserved for uninitialized entries
+ assert(delta < static_cast<uint64_t>((1ull << format.COMMIT_BITS)));
+ if (delta == 0) {
+ return false; // initialized entry would have non-zero delta
+ }
+
+ assert(indexed_seq < static_cast<uint64_t>((1ull << format.INDEX_BITS)));
+ uint64_t prep_up = rep_ & ~format.COMMIT_FILTER;
+ prep_up >>= format.PAD_BITS;
+ const uint64_t& prep_low = indexed_seq;
+ entry->prep_seq = prep_up | prep_low;
+
+ entry->commit_seq = entry->prep_seq + delta - 1;
+ return true;
+ }
+
+ private:
+ uint64_t rep_;
+ };
+
+ // Struct to hold ownership of snapshot and read callback for cleanup.
+ struct IteratorState;
+
+ std::shared_ptr<std::map<uint32_t, const Comparator*>> GetCFComparatorMap() {
+ return cf_map_;
+ }
+ std::shared_ptr<std::map<uint32_t, ColumnFamilyHandle*>> GetCFHandleMap() {
+ return handle_map_;
+ }
+ void UpdateCFComparatorMap(
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+ void UpdateCFComparatorMap(ColumnFamilyHandle* handle) override;
+
+ virtual const Snapshot* GetSnapshot() override;
+ SnapshotImpl* GetSnapshotInternal(bool for_ww_conflict_check);
+
+ protected:
+ virtual Status VerifyCFOptions(
+ const ColumnFamilyOptions& cf_options) override;
+ // Assign the min and max sequence numbers for reading from the db. A seq >
+ // max is not valid, and a seq < min is valid, and a min <= seq < max requires
+ // further checking. Normally max is defined by the snapshot and min is by
+ // minimum uncommitted seq.
+ inline SnapshotBackup AssignMinMaxSeqs(const Snapshot* snapshot,
+ SequenceNumber* min,
+ SequenceNumber* max);
+ // Validate is a snapshot sequence number is still valid based on the latest
+ // db status. backed_by_snapshot specifies if the number is baked by an actual
+ // snapshot object. order specified the memory order with which we load the
+ // atomic variables: relax is enough for the default since we care about last
+ // value seen by same thread.
+ inline bool ValidateSnapshot(
+ const SequenceNumber snap_seq, const SnapshotBackup backed_by_snapshot,
+ std::memory_order order = std::memory_order_relaxed);
+ // Get a dummy snapshot that refers to kMaxSequenceNumber
+ Snapshot* GetMaxSnapshot() { return &dummy_max_snapshot_; }
+
+ bool ShouldRollbackWithSingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ return rollback_deletion_type_callback_
+ ? rollback_deletion_type_callback_(this, column_family, key)
+ : false;
+ }
+
+ std::function<bool(TransactionDB*, ColumnFamilyHandle*, const Slice&)>
+ rollback_deletion_type_callback_;
+
+ private:
+ friend class AddPreparedCallback;
+ friend class PreparedHeap_BasicsTest_Test;
+ friend class PreparedHeap_Concurrent_Test;
+ friend class PreparedHeap_EmptyAtTheEnd_Test;
+ friend class SnapshotConcurrentAccessTest_SnapshotConcurrentAccess_Test;
+ friend class WritePreparedCommitEntryPreReleaseCallback;
+ friend class WritePreparedTransactionTestBase;
+ friend class WritePreparedTxn;
+ friend class WritePreparedTxnDBMock;
+ friend class WritePreparedTransactionTest_AddPreparedBeforeMax_Test;
+ friend class WritePreparedTransactionTest_AdvanceMaxEvictedSeqBasic_Test;
+ friend class
+ WritePreparedTransactionTest_AdvanceMaxEvictedSeqWithDuplicates_Test;
+ friend class WritePreparedTransactionTest_AdvanceSeqByOne_Test;
+ friend class WritePreparedTransactionTest_BasicRecovery_Test;
+ friend class WritePreparedTransactionTest_CheckAgainstSnapshots_Test;
+ friend class WritePreparedTransactionTest_CleanupSnapshotEqualToMax_Test;
+ friend class WritePreparedTransactionTest_ConflictDetectionAfterRecovery_Test;
+ friend class WritePreparedTransactionTest_CommitMap_Test;
+ friend class WritePreparedTransactionTest_DoubleSnapshot_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshotEmptyMap_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshotReleased_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshot_Test;
+ friend class WritePreparedTransactionTest_NewSnapshotLargerThanMax_Test;
+ friend class WritePreparedTransactionTest_MaxCatchupWithNewSnapshot_Test;
+ friend class WritePreparedTransactionTest_MaxCatchupWithUnbackedSnapshot_Test;
+ friend class
+ WritePreparedTransactionTest_NonAtomicCommitOfDelayedPrepared_Test;
+ friend class
+ WritePreparedTransactionTest_NonAtomicUpdateOfDelayedPrepared_Test;
+ friend class WritePreparedTransactionTest_NonAtomicUpdateOfMaxEvictedSeq_Test;
+ friend class WritePreparedTransactionTest_OldCommitMapGC_Test;
+ friend class WritePreparedTransactionTest_Rollback_Test;
+ friend class WritePreparedTransactionTest_SmallestUnCommittedSeq_Test;
+ friend class WriteUnpreparedTxn;
+ friend class WriteUnpreparedTxnDB;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+ friend class MultiOpsTxnsStressTest;
+
+ void Init(const TransactionDBOptions& txn_db_opts);
+
+ void WPRecordTick(uint32_t ticker_type) const {
+ RecordTick(db_impl_->immutable_db_options_.statistics.get(), ticker_type);
+ }
+
+ // A heap with the amortized O(1) complexity for erase. It uses one extra heap
+ // to keep track of erased entries that are not yet on top of the main heap.
+ class PreparedHeap {
+ // The mutex is required for push and pop from PreparedHeap. ::erase will
+ // use external synchronization via prepared_mutex_.
+ port::Mutex push_pop_mutex_;
+ std::deque<uint64_t> heap_;
+ std::priority_queue<uint64_t, std::vector<uint64_t>, std::greater<uint64_t>>
+ erased_heap_;
+ std::atomic<uint64_t> heap_top_ = {kMaxSequenceNumber};
+ // True when testing crash recovery
+ bool TEST_CRASH_ = false;
+ friend class WritePreparedTxnDB;
+
+ public:
+ ~PreparedHeap() {
+ if (!TEST_CRASH_) {
+ assert(heap_.empty());
+ assert(erased_heap_.empty());
+ }
+ }
+ port::Mutex* push_pop_mutex() { return &push_pop_mutex_; }
+
+ inline bool empty() { return top() == kMaxSequenceNumber; }
+ // Returns kMaxSequenceNumber if empty() and the smallest otherwise.
+ inline uint64_t top() { return heap_top_.load(std::memory_order_acquire); }
+ inline void push(uint64_t v) {
+ push_pop_mutex_.AssertHeld();
+ if (heap_.empty()) {
+ heap_top_.store(v, std::memory_order_release);
+ } else {
+ assert(heap_top_.load() < v);
+ }
+ heap_.push_back(v);
+ }
+ void pop(bool locked = false) {
+ if (!locked) {
+ push_pop_mutex()->Lock();
+ }
+ push_pop_mutex_.AssertHeld();
+ heap_.pop_front();
+ while (!heap_.empty() && !erased_heap_.empty() &&
+ // heap_.top() > erased_heap_.top() could happen if we have erased
+ // a non-existent entry. Ideally the user should not do that but we
+ // should be resilient against it.
+ heap_.front() >= erased_heap_.top()) {
+ if (heap_.front() == erased_heap_.top()) {
+ heap_.pop_front();
+ }
+ uint64_t erased __attribute__((__unused__));
+ erased = erased_heap_.top();
+ erased_heap_.pop();
+ // No duplicate prepare sequence numbers
+ assert(erased_heap_.empty() || erased_heap_.top() != erased);
+ }
+ while (heap_.empty() && !erased_heap_.empty()) {
+ erased_heap_.pop();
+ }
+ heap_top_.store(!heap_.empty() ? heap_.front() : kMaxSequenceNumber,
+ std::memory_order_release);
+ if (!locked) {
+ push_pop_mutex()->Unlock();
+ }
+ }
+ // Concurrrent calls needs external synchronization. It is safe to be called
+ // concurrent to push and pop though.
+ void erase(uint64_t seq) {
+ if (!empty()) {
+ auto top_seq = top();
+ if (seq < top_seq) {
+ // Already popped, ignore it.
+ } else if (top_seq == seq) {
+ pop();
+#ifndef NDEBUG
+ MutexLock ml(push_pop_mutex());
+ assert(heap_.empty() || heap_.front() != seq);
+#endif
+ } else { // top() > seq
+ // Down the heap, remember to pop it later
+ erased_heap_.push(seq);
+ }
+ }
+ }
+ };
+
+ void TEST_Crash() override { prepared_txns_.TEST_CRASH_ = true; }
+
+ // Get the commit entry with index indexed_seq from the commit table. It
+ // returns true if such entry exists.
+ bool GetCommitEntry(const uint64_t indexed_seq, CommitEntry64b* entry_64b,
+ CommitEntry* entry) const;
+
+ // Rewrite the entry with the index indexed_seq in the commit table with the
+ // commit entry <prep_seq, commit_seq>. If the rewrite results into eviction,
+ // sets the evicted_entry and returns true.
+ bool AddCommitEntry(const uint64_t indexed_seq, const CommitEntry& new_entry,
+ CommitEntry* evicted_entry);
+
+ // Rewrite the entry with the index indexed_seq in the commit table with the
+ // commit entry new_entry only if the existing entry matches the
+ // expected_entry. Returns false otherwise.
+ bool ExchangeCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b& expected_entry,
+ const CommitEntry& new_entry);
+
+ // Increase max_evicted_seq_ from the previous value prev_max to the new
+ // value. This also involves taking care of prepared txns that are not
+ // committed before new_max, as well as updating the list of live snapshots at
+ // the time of updating the max. Thread-safety: this function can be called
+ // concurrently. The concurrent invocations of this function is equivalent to
+ // a serial invocation in which the last invocation is the one with the
+ // largest new_max value.
+ void AdvanceMaxEvictedSeq(const SequenceNumber& prev_max,
+ const SequenceNumber& new_max);
+
+ inline SequenceNumber SmallestUnCommittedSeq() {
+ // Note: We have two lists to look into, but for performance reasons they
+ // are not read atomically. Since CheckPreparedAgainstMax copies the entry
+ // to delayed_prepared_ before removing it from prepared_txns_, to ensure
+ // that a prepared entry will not go unmissed, we look into them in opposite
+ // order: first read prepared_txns_ and then delayed_prepared_.
+
+ // This must be called before calling ::top. This is because the concurrent
+ // thread would call ::RemovePrepared before updating
+ // GetLatestSequenceNumber(). Reading then in opposite order here guarantees
+ // that the ::top that we read would be lower the ::top if we had otherwise
+ // update/read them atomically.
+ auto next_prepare = db_impl_->GetLatestSequenceNumber() + 1;
+ auto min_prepare = prepared_txns_.top();
+ // Since we update the prepare_heap always from the main write queue via
+ // PreReleaseCallback, the prepared_txns_.top() indicates the smallest
+ // prepared data in 2pc transactions. For non-2pc transactions that are
+ // written in two steps, we also update prepared_txns_ at the first step
+ // (via the same mechanism) so that their uncommitted data is reflected in
+ // SmallestUnCommittedSeq.
+ if (!delayed_prepared_empty_.load()) {
+ ReadLock rl(&prepared_mutex_);
+ if (!delayed_prepared_.empty()) {
+ return *delayed_prepared_.begin();
+ }
+ }
+ bool empty = min_prepare == kMaxSequenceNumber;
+ if (empty) {
+ // Since GetLatestSequenceNumber is updated
+ // after prepared_txns_ are, the value of GetLatestSequenceNumber would
+ // reflect any uncommitted data that is not added to prepared_txns_ yet.
+ // Otherwise, if there is no concurrent txn, this value simply reflects
+ // that latest value in the memtable.
+ return next_prepare;
+ } else {
+ return std::min(min_prepare, next_prepare);
+ }
+ }
+
+ // Enhance the snapshot object by recording in it the smallest uncommitted seq
+ inline void EnhanceSnapshot(SnapshotImpl* snapshot,
+ SequenceNumber min_uncommitted) {
+ assert(snapshot);
+ assert(min_uncommitted <= snapshot->number_ + 1);
+ snapshot->min_uncommitted_ = min_uncommitted;
+ }
+
+ virtual const std::vector<SequenceNumber> GetSnapshotListFromDB(
+ SequenceNumber max);
+
+ // Will be called by the public ReleaseSnapshot method. Does the maintenance
+ // internal to WritePreparedTxnDB
+ void ReleaseSnapshotInternal(const SequenceNumber snap_seq);
+
+ // Update the list of snapshots corresponding to the soon-to-be-updated
+ // max_evicted_seq_. Thread-safety: this function can be called concurrently.
+ // The concurrent invocations of this function is equivalent to a serial
+ // invocation in which the last invocation is the one with the largest
+ // version value.
+ void UpdateSnapshots(const std::vector<SequenceNumber>& snapshots,
+ const SequenceNumber& version);
+ // Check the new list of new snapshots against the old one to see if any of
+ // the snapshots are released and to do the cleanup for the released snapshot.
+ void CleanupReleasedSnapshots(
+ const std::vector<SequenceNumber>& new_snapshots,
+ const std::vector<SequenceNumber>& old_snapshots);
+
+ // Check an evicted entry against live snapshots to see if it should be kept
+ // around or it can be safely discarded (and hence assume committed for all
+ // snapshots). Thread-safety: this function can be called concurrently. If it
+ // is called concurrently with multiple UpdateSnapshots, the result is the
+ // same as checking the intersection of the snapshot list before updates with
+ // the snapshot list of all the concurrent updates.
+ void CheckAgainstSnapshots(const CommitEntry& evicted);
+
+ // Add a new entry to old_commit_map_ if prep_seq <= snapshot_seq <
+ // commit_seq. Return false if checking the next snapshot(s) is not needed.
+ // This is the case if none of the next snapshots could satisfy the condition.
+ // next_is_larger: the next snapshot will be a larger value
+ bool MaybeUpdateOldCommitMap(const uint64_t& prep_seq,
+ const uint64_t& commit_seq,
+ const uint64_t& snapshot_seq,
+ const bool next_is_larger);
+
+ // A trick to increase the last visible sequence number by one and also wait
+ // for the in-flight commits to be visible.
+ void AdvanceSeqByOne();
+
+ // The list of live snapshots at the last time that max_evicted_seq_ advanced.
+ // The list stored into two data structures: in snapshot_cache_ that is
+ // efficient for concurrent reads, and in snapshots_ if the data does not fit
+ // into snapshot_cache_. The total number of snapshots in the two lists
+ std::atomic<size_t> snapshots_total_ = {};
+ // The list sorted in ascending order. Thread-safety for writes is provided
+ // with snapshots_mutex_ and concurrent reads are safe due to std::atomic for
+ // each entry. In x86_64 architecture such reads are compiled to simple read
+ // instructions.
+ const size_t SNAPSHOT_CACHE_BITS;
+ const size_t SNAPSHOT_CACHE_SIZE;
+ std::unique_ptr<std::atomic<SequenceNumber>[]> snapshot_cache_;
+ // 2nd list for storing snapshots. The list sorted in ascending order.
+ // Thread-safety is provided with snapshots_mutex_.
+ std::vector<SequenceNumber> snapshots_;
+ // The list of all snapshots: snapshots_ + snapshot_cache_. This list although
+ // redundant but simplifies CleanupOldSnapshots implementation.
+ // Thread-safety is provided with snapshots_mutex_.
+ std::vector<SequenceNumber> snapshots_all_;
+ // The version of the latest list of snapshots. This can be used to avoid
+ // rewriting a list that is concurrently updated with a more recent version.
+ SequenceNumber snapshots_version_ = 0;
+
+ // A heap of prepared transactions. Thread-safety is provided with
+ // prepared_mutex_.
+ PreparedHeap prepared_txns_;
+ const size_t COMMIT_CACHE_BITS;
+ const size_t COMMIT_CACHE_SIZE;
+ const CommitEntry64bFormat FORMAT;
+ // commit_cache_ must be initialized to zero to tell apart an empty index from
+ // a filled one. Thread-safety is provided with commit_cache_mutex_.
+ std::unique_ptr<std::atomic<CommitEntry64b>[]> commit_cache_;
+ // The largest evicted *commit* sequence number from the commit_cache_. If a
+ // seq is smaller than max_evicted_seq_ is might or might not be present in
+ // commit_cache_. So commit_cache_ must first be checked before consulting
+ // with max_evicted_seq_.
+ std::atomic<uint64_t> max_evicted_seq_ = {};
+ // Order: 1) update future_max_evicted_seq_ = new_max, 2)
+ // GetSnapshotListFromDB(new_max), max_evicted_seq_ = new_max. Since
+ // GetSnapshotInternal guarantess that the snapshot seq is larger than
+ // future_max_evicted_seq_, this guarantes that if a snapshot is not larger
+ // than max has already being looked at via a GetSnapshotListFromDB(new_max).
+ std::atomic<uint64_t> future_max_evicted_seq_ = {};
+ // Advance max_evicted_seq_ by this value each time it needs an update. The
+ // larger the value, the less frequent advances we would have. We do not want
+ // it to be too large either as it would cause stalls by doing too much
+ // maintenance work under the lock.
+ size_t INC_STEP_FOR_MAX_EVICTED = 1;
+ // A map from old snapshots (expected to be used by a few read-only txns) to
+ // prepared sequence number of the evicted entries from commit_cache_ that
+ // overlaps with such snapshot. These are the prepared sequence numbers that
+ // the snapshot, to which they are mapped, cannot assume to be committed just
+ // because it is no longer in the commit_cache_. The vector must be sorted
+ // after each update.
+ // Thread-safety is provided with old_commit_map_mutex_.
+ std::map<SequenceNumber, std::vector<SequenceNumber>> old_commit_map_;
+ // A set of long-running prepared transactions that are not finished by the
+ // time max_evicted_seq_ advances their sequence number. This is expected to
+ // be empty normally. Thread-safety is provided with prepared_mutex_.
+ std::set<uint64_t> delayed_prepared_;
+ // Commit of a delayed prepared: 1) update commit cache, 2) update
+ // delayed_prepared_commits_, 3) publish seq, 3) clean up delayed_prepared_.
+ // delayed_prepared_commits_ will help us tell apart the unprepared txns from
+ // the ones that are committed but not cleaned up yet.
+ std::unordered_map<SequenceNumber, SequenceNumber> delayed_prepared_commits_;
+ // Update when delayed_prepared_.empty() changes. Expected to be true
+ // normally.
+ std::atomic<bool> delayed_prepared_empty_ = {true};
+ // Update when old_commit_map_.empty() changes. Expected to be true normally.
+ std::atomic<bool> old_commit_map_empty_ = {true};
+ mutable port::RWMutex prepared_mutex_;
+ mutable port::RWMutex old_commit_map_mutex_;
+ mutable port::RWMutex commit_cache_mutex_;
+ mutable port::RWMutex snapshots_mutex_;
+ // A cache of the cf comparators
+ // Thread safety: since it is a const it is safe to read it concurrently
+ std::shared_ptr<std::map<uint32_t, const Comparator*>> cf_map_;
+ // A cache of the cf handles
+ // Thread safety: since the handle is read-only object it is a const it is
+ // safe to read it concurrently
+ std::shared_ptr<std::map<uint32_t, ColumnFamilyHandle*>> handle_map_;
+ // A dummy snapshot object that refers to kMaxSequenceNumber
+ SnapshotImpl dummy_max_snapshot_;
+};
+
+class WritePreparedTxnReadCallback : public ReadCallback {
+ public:
+ WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot)
+ : ReadCallback(snapshot),
+ db_(db),
+ backed_by_snapshot_(kBackedByDBSnapshot) {}
+ WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot,
+ SequenceNumber min_uncommitted,
+ SnapshotBackup backed_by_snapshot)
+ : ReadCallback(snapshot, min_uncommitted),
+ db_(db),
+ backed_by_snapshot_(backed_by_snapshot) {
+ (void)backed_by_snapshot_; // to silence unused private field warning
+ }
+
+ virtual ~WritePreparedTxnReadCallback() {
+ // If it is not backed by snapshot, the caller must check validity
+ assert(valid_checked_ || backed_by_snapshot_ == kBackedByDBSnapshot);
+ }
+
+ // Will be called to see if the seq number visible; if not it moves on to
+ // the next seq number.
+ inline virtual bool IsVisibleFullCheck(SequenceNumber seq) override {
+ auto snapshot = max_visible_seq_;
+ bool snap_released = false;
+ auto ret =
+ db_->IsInSnapshot(seq, snapshot, min_uncommitted_, &snap_released);
+ assert(!snap_released || backed_by_snapshot_ == kUnbackedByDBSnapshot);
+ snap_released_ |= snap_released;
+ return ret;
+ }
+
+ inline bool valid() {
+ valid_checked_ = true;
+ return snap_released_ == false;
+ }
+
+ // TODO(myabandeh): override Refresh when Iterator::Refresh is supported
+ private:
+ WritePreparedTxnDB* db_;
+ // Whether max_visible_seq_ is backed by a snapshot
+ const SnapshotBackup backed_by_snapshot_;
+ bool snap_released_ = false;
+ // Safety check to ensure that the caller has checked invalid statuses
+ bool valid_checked_ = false;
+};
+
+class AddPreparedCallback : public PreReleaseCallback {
+ public:
+ AddPreparedCallback(WritePreparedTxnDB* db, DBImpl* db_impl,
+ size_t sub_batch_cnt, bool two_write_queues,
+ bool first_prepare_batch)
+ : db_(db),
+ db_impl_(db_impl),
+ sub_batch_cnt_(sub_batch_cnt),
+ two_write_queues_(two_write_queues),
+ first_prepare_batch_(first_prepare_batch) {
+ (void)two_write_queues_; // to silence unused private field warning
+ }
+ virtual Status Callback(SequenceNumber prepare_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t log_number, size_t index,
+ size_t total) override {
+ assert(index < total);
+ // To reduce the cost of lock acquisition competing with the concurrent
+ // prepare requests, lock on the first callback and unlock on the last.
+ const bool do_lock = !two_write_queues_ || index == 0;
+ const bool do_unlock = !two_write_queues_ || index + 1 == total;
+ // Always Prepare from the main queue
+ assert(!two_write_queues_ || !is_mem_disabled); // implies the 1st queue
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::begin:pause");
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::begin:resume");
+ if (do_lock) {
+ db_->prepared_txns_.push_pop_mutex()->Lock();
+ }
+ const bool kLocked = true;
+ for (size_t i = 0; i < sub_batch_cnt_; i++) {
+ db_->AddPrepared(prepare_seq + i, kLocked);
+ }
+ if (do_unlock) {
+ db_->prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::end");
+ if (first_prepare_batch_) {
+ assert(log_number != 0);
+ db_impl_->logs_with_prep_tracker()->MarkLogAsContainingPrepSection(
+ log_number);
+ }
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ size_t sub_batch_cnt_;
+ bool two_write_queues_;
+ // It is 2PC and this is the first prepare batch. Always the case in 2PC
+ // unless it is WriteUnPrepared.
+ bool first_prepare_batch_;
+};
+
+class WritePreparedCommitEntryPreReleaseCallback : public PreReleaseCallback {
+ public:
+ // includes_data indicates that the commit also writes non-empty
+ // CommitTimeWriteBatch to memtable, which needs to be committed separately.
+ WritePreparedCommitEntryPreReleaseCallback(
+ WritePreparedTxnDB* db, DBImpl* db_impl, SequenceNumber prep_seq,
+ size_t prep_batch_cnt, size_t data_batch_cnt = 0,
+ SequenceNumber aux_seq = kMaxSequenceNumber, size_t aux_batch_cnt = 0)
+ : db_(db),
+ db_impl_(db_impl),
+ prep_seq_(prep_seq),
+ prep_batch_cnt_(prep_batch_cnt),
+ data_batch_cnt_(data_batch_cnt),
+ includes_data_(data_batch_cnt_ > 0),
+ aux_seq_(aux_seq),
+ aux_batch_cnt_(aux_batch_cnt),
+ includes_aux_batch_(aux_batch_cnt > 0) {
+ assert((prep_batch_cnt_ > 0) != (prep_seq == kMaxSequenceNumber)); // xor
+ assert(prep_batch_cnt_ > 0 || data_batch_cnt_ > 0);
+ assert((aux_batch_cnt_ > 0) != (aux_seq == kMaxSequenceNumber)); // xor
+ }
+
+ virtual Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t, size_t /*index*/,
+ size_t /*total*/) override {
+ // Always commit from the 2nd queue
+ assert(!db_impl_->immutable_db_options().two_write_queues ||
+ is_mem_disabled);
+ assert(includes_data_ || prep_seq_ != kMaxSequenceNumber);
+ // Data batch is what accompanied with the commit marker and affects the
+ // last seq in the commit batch.
+ const uint64_t last_commit_seq = LIKELY(data_batch_cnt_ <= 1)
+ ? commit_seq
+ : commit_seq + data_batch_cnt_ - 1;
+ if (prep_seq_ != kMaxSequenceNumber) {
+ for (size_t i = 0; i < prep_batch_cnt_; i++) {
+ db_->AddCommitted(prep_seq_ + i, last_commit_seq);
+ }
+ } // else there was no prepare phase
+ if (includes_aux_batch_) {
+ for (size_t i = 0; i < aux_batch_cnt_; i++) {
+ db_->AddCommitted(aux_seq_ + i, last_commit_seq);
+ }
+ }
+ if (includes_data_) {
+ assert(data_batch_cnt_);
+ // Commit the data that is accompanied with the commit request
+ for (size_t i = 0; i < data_batch_cnt_; i++) {
+ // For commit seq of each batch use the commit seq of the last batch.
+ // This would make debugging easier by having all the batches having
+ // the same sequence number.
+ db_->AddCommitted(commit_seq + i, last_commit_seq);
+ }
+ }
+ if (db_impl_->immutable_db_options().two_write_queues) {
+ assert(is_mem_disabled); // implies the 2nd queue
+ // Publish the sequence number. We can do that here assuming the callback
+ // is invoked only from one write queue, which would guarantee that the
+ // publish sequence numbers will be in order, i.e., once a seq is
+ // published all the seq prior to that are also publishable.
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ // Note RemovePrepared should be called after publishing the seq.
+ // Otherwise SmallestUnCommittedSeq optimization breaks.
+ if (prep_seq_ != kMaxSequenceNumber) {
+ db_->RemovePrepared(prep_seq_, prep_batch_cnt_);
+ } // else there was no prepare phase
+ if (includes_aux_batch_) {
+ db_->RemovePrepared(aux_seq_, aux_batch_cnt_);
+ }
+ }
+ // else SequenceNumber that is updated as part of the write already does the
+ // publishing
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ // kMaxSequenceNumber if there was no prepare phase
+ SequenceNumber prep_seq_;
+ size_t prep_batch_cnt_;
+ size_t data_batch_cnt_;
+ // Data here is the batch that is written with the commit marker, either
+ // because it is commit without prepare or commit has a CommitTimeWriteBatch.
+ bool includes_data_;
+ // Auxiliary batch (if there is any) is a batch that is written before, but
+ // gets the same commit seq as prepare batch or data batch. This is used in
+ // two write queues where the CommitTimeWriteBatch becomes the aux batch and
+ // we do a separate write to actually commit everything.
+ SequenceNumber aux_seq_;
+ size_t aux_batch_cnt_;
+ bool includes_aux_batch_;
+};
+
+// For two_write_queues commit both the aborted batch and the cleanup batch and
+// then published the seq
+class WritePreparedRollbackPreReleaseCallback : public PreReleaseCallback {
+ public:
+ WritePreparedRollbackPreReleaseCallback(WritePreparedTxnDB* db,
+ DBImpl* db_impl,
+ SequenceNumber prep_seq,
+ SequenceNumber rollback_seq,
+ size_t prep_batch_cnt)
+ : db_(db),
+ db_impl_(db_impl),
+ prep_seq_(prep_seq),
+ rollback_seq_(rollback_seq),
+ prep_batch_cnt_(prep_batch_cnt) {
+ assert(prep_seq != kMaxSequenceNumber);
+ assert(rollback_seq != kMaxSequenceNumber);
+ assert(prep_batch_cnt_ > 0);
+ }
+
+ Status Callback(SequenceNumber commit_seq, bool is_mem_disabled, uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ // Always commit from the 2nd queue
+ assert(is_mem_disabled); // implies the 2nd queue
+ assert(db_impl_->immutable_db_options().two_write_queues);
+#ifdef NDEBUG
+ (void)is_mem_disabled;
+#endif
+ const uint64_t last_commit_seq = commit_seq;
+ db_->AddCommitted(rollback_seq_, last_commit_seq);
+ for (size_t i = 0; i < prep_batch_cnt_; i++) {
+ db_->AddCommitted(prep_seq_ + i, last_commit_seq);
+ }
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ SequenceNumber prep_seq_;
+ SequenceNumber rollback_seq_;
+ size_t prep_batch_cnt_;
+};
+
+// Count the number of sub-batches inside a batch. A sub-batch does not have
+// duplicate keys.
+struct SubBatchCounter : public WriteBatch::Handler {
+ explicit SubBatchCounter(std::map<uint32_t, const Comparator*>& comparators)
+ : comparators_(comparators), batches_(1) {}
+ std::map<uint32_t, const Comparator*>& comparators_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ size_t batches_;
+ size_t BatchCount() { return batches_; }
+ void AddKey(const uint32_t cf, const Slice& key);
+ void InitWithComp(const uint32_t cf);
+ Status MarkNoop(bool) override { return Status::OK(); }
+ Status MarkEndPrepare(const Slice&) override { return Status::OK(); }
+ Status MarkCommit(const Slice&) override { return Status::OK(); }
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status MarkBeginPrepare(bool) override { return Status::OK(); }
+ Status MarkRollback(const Slice&) override { return Status::OK(); }
+ Handler::OptionState WriteAfterCommit() const override {
+ return Handler::OptionState::kDisabled;
+ }
+};
+
+SnapshotBackup WritePreparedTxnDB::AssignMinMaxSeqs(const Snapshot* snapshot,
+ SequenceNumber* min,
+ SequenceNumber* max) {
+ if (snapshot != nullptr) {
+ *min =
+ static_cast_with_check<const SnapshotImpl>(snapshot)->min_uncommitted_;
+ *max = static_cast_with_check<const SnapshotImpl>(snapshot)->number_;
+ // A duplicate of the check in EnhanceSnapshot().
+ assert(*min <= *max + 1);
+ return kBackedByDBSnapshot;
+ } else {
+ *min = SmallestUnCommittedSeq();
+ *max = 0; // to be assigned later after sv is referenced.
+ return kUnbackedByDBSnapshot;
+ }
+}
+
+bool WritePreparedTxnDB::ValidateSnapshot(
+ const SequenceNumber snap_seq, const SnapshotBackup backed_by_snapshot,
+ std::memory_order order) {
+ if (backed_by_snapshot == kBackedByDBSnapshot) {
+ return true;
+ } else {
+ SequenceNumber max = max_evicted_seq_.load(order);
+ // Validate that max has not advanced the snapshot seq that is not backed
+ // by a real snapshot. This is a very rare case that should not happen in
+ // real workloads.
+ if (UNLIKELY(snap_seq <= max && snap_seq != 0)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc b/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc
new file mode 100644
index 000000000..6c8c62e0e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc
@@ -0,0 +1,790 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_test.h"
+#include "utilities/transactions/write_unprepared_txn.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTransactionTestBase : public TransactionTestBase {
+ public:
+ WriteUnpreparedTransactionTestBase(bool use_stackable_db,
+ bool two_write_queue,
+ TxnDBWritePolicy write_policy)
+ : TransactionTestBase(use_stackable_db, two_write_queue, write_policy,
+ kOrderedWrite) {}
+};
+
+class WriteUnpreparedTransactionTest
+ : public WriteUnpreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy>> {
+ public:
+ WriteUnpreparedTransactionTest()
+ : WriteUnpreparedTransactionTestBase(std::get<0>(GetParam()),
+ std::get<1>(GetParam()),
+ std::get<2>(GetParam())) {}
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WriteUnpreparedTransactionTest, WriteUnpreparedTransactionTest,
+ ::testing::Values(std::make_tuple(false, false, WRITE_UNPREPARED),
+ std::make_tuple(false, true, WRITE_UNPREPARED)));
+
+enum StressAction { NO_SNAPSHOT, RO_SNAPSHOT, REFRESH_SNAPSHOT };
+class WriteUnpreparedStressTest : public WriteUnpreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, StressAction>> {
+ public:
+ WriteUnpreparedStressTest()
+ : WriteUnpreparedTransactionTestBase(false, std::get<0>(GetParam()),
+ WRITE_UNPREPARED),
+ action_(std::get<1>(GetParam())) {}
+ StressAction action_;
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WriteUnpreparedStressTest, WriteUnpreparedStressTest,
+ ::testing::Values(std::make_tuple(false, NO_SNAPSHOT),
+ std::make_tuple(false, RO_SNAPSHOT),
+ std::make_tuple(false, REFRESH_SNAPSHOT),
+ std::make_tuple(true, NO_SNAPSHOT),
+ std::make_tuple(true, RO_SNAPSHOT),
+ std::make_tuple(true, REFRESH_SNAPSHOT)));
+
+TEST_P(WriteUnpreparedTransactionTest, ReadYourOwnWrite) {
+ // The following tests checks whether reading your own write for
+ // a transaction works for write unprepared, when there are uncommitted
+ // values written into DB.
+ auto verify_state = [](Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+ };
+
+ // Test always reseeking vs never reseeking.
+ for (uint64_t max_skip : {0, std::numeric_limits<int>::max()}) {
+ options.max_sequential_skip_in_iterations = max_skip;
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ TransactionOptions txn_options;
+ WriteOptions woptions;
+ ReadOptions roptions;
+
+ ASSERT_OK(db->Put(woptions, "a", ""));
+ ASSERT_OK(db->Put(woptions, "b", ""));
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ txn->SetSnapshot();
+
+ for (int i = 0; i < 5; i++) {
+ std::string stored_value = "v" + std::to_string(i);
+ ASSERT_OK(txn->Put("a", stored_value));
+ ASSERT_OK(txn->Put("b", stored_value));
+ ASSERT_OK(wup_txn->FlushWriteBatchToDB(false));
+
+ // Test Get()
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, stored_value);
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, stored_value);
+
+ // Test Next()
+ auto iter = txn->GetIterator(roptions);
+ iter->Seek("a");
+ verify_state(iter, "a", stored_value);
+
+ iter->Next();
+ verify_state(iter, "b", stored_value);
+
+ iter->SeekToFirst();
+ verify_state(iter, "a", stored_value);
+
+ iter->Next();
+ verify_state(iter, "b", stored_value);
+
+ delete iter;
+
+ // Test Prev()
+ iter = txn->GetIterator(roptions);
+ iter->SeekForPrev("b");
+ verify_state(iter, "b", stored_value);
+
+ iter->Prev();
+ verify_state(iter, "a", stored_value);
+
+ iter->SeekToLast();
+ verify_state(iter, "b", stored_value);
+
+ iter->Prev();
+ verify_state(iter, "a", stored_value);
+
+ delete iter;
+ }
+
+ delete txn;
+ }
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+TEST_P(WriteUnpreparedStressTest, ReadYourOwnWriteStress) {
+ // This is a stress test where different threads are writing random keys, and
+ // then before committing or aborting the transaction, it validates to see
+ // that it can read the keys it wrote, and the keys it did not write respect
+ // the snapshot. To avoid row lock contention (and simply stressing the
+ // locking system), each thread is mostly only writing to its own set of keys.
+ const uint32_t kNumIter = 1000;
+ const uint32_t kNumThreads = 10;
+ const uint32_t kNumKeys = 5;
+
+ // Test with
+ // 1. no snapshots set
+ // 2. snapshot set on ReadOptions
+ // 3. snapshot set, and refreshing after every write.
+ StressAction a = action_;
+ WriteOptions write_options;
+ txn_db_options.transaction_lock_timeout = -1;
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ std::vector<std::string> keys;
+ for (uint32_t k = 0; k < kNumKeys * kNumThreads; k++) {
+ keys.push_back("k" + std::to_string(k));
+ }
+ RandomShuffle(keys.begin(), keys.end());
+
+ // This counter will act as a "sequence number" to help us validate
+ // visibility logic with snapshots. If we had direct access to the seqno of
+ // snapshots and key/values, then we should directly compare those instead.
+ std::atomic<int64_t> counter(0);
+
+ std::function<void(uint32_t)> stress_thread = [&](int id) {
+ size_t tid = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rnd(static_cast<uint32_t>(tid));
+
+ Transaction* txn;
+ TransactionOptions txn_options;
+ // batch_size of 1 causes writes to DB for every marker.
+ txn_options.write_batch_flush_threshold = 1;
+ ReadOptions read_options;
+
+ for (uint32_t i = 0; i < kNumIter; i++) {
+ std::set<std::string> owned_keys(keys.begin() + id * kNumKeys,
+ keys.begin() + (id + 1) * kNumKeys);
+ // Add unowned keys to make the workload more interesting, but this
+ // increases row lock contention, so just do it sometimes.
+ if (rnd.OneIn(2)) {
+ owned_keys.insert(keys[rnd.Uniform(kNumKeys * kNumThreads)]);
+ }
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName(std::to_string(id)));
+ txn->SetSnapshot();
+ if (a >= RO_SNAPSHOT) {
+ read_options.snapshot = txn->GetSnapshot();
+ ASSERT_TRUE(read_options.snapshot != nullptr);
+ }
+
+ uint64_t buf[2];
+ buf[0] = id;
+
+ // When scanning through the database, make sure that all unprepared
+ // keys have value >= snapshot and all other keys have value < snapshot.
+ int64_t snapshot_num = counter.fetch_add(1);
+
+ Status s;
+ for (const auto& key : owned_keys) {
+ buf[1] = counter.fetch_add(1);
+ s = txn->Put(key, Slice((const char*)buf, sizeof(buf)));
+ if (!s.ok()) {
+ break;
+ }
+ if (a == REFRESH_SNAPSHOT) {
+ txn->SetSnapshot();
+ read_options.snapshot = txn->GetSnapshot();
+ snapshot_num = counter.fetch_add(1);
+ }
+ }
+
+ // Failure is possible due to snapshot validation. In this case,
+ // rollback and move onto next iteration.
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsBusy());
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+ continue;
+ }
+
+ auto verify_key = [&owned_keys, &a, &id, &snapshot_num](
+ const std::string& key, const std::string& value) {
+ if (owned_keys.count(key) > 0) {
+ ASSERT_EQ(value.size(), 16);
+
+ // Since this key is part of owned_keys, then this key must be
+ // unprepared by this transaction identified by 'id'
+ ASSERT_EQ(((int64_t*)value.c_str())[0], id);
+ if (a == REFRESH_SNAPSHOT) {
+ // If refresh snapshot is true, then the snapshot is refreshed
+ // after every Put(), meaning that the current snapshot in
+ // snapshot_num must be greater than the "seqno" of any keys
+ // written by the current transaction.
+ ASSERT_LT(((int64_t*)value.c_str())[1], snapshot_num);
+ } else {
+ // If refresh snapshot is not on, then the snapshot was taken at
+ // the beginning of the transaction, meaning all writes must come
+ // after snapshot_num
+ ASSERT_GT(((int64_t*)value.c_str())[1], snapshot_num);
+ }
+ } else if (a >= RO_SNAPSHOT) {
+ // If this is not an unprepared key, just assert that the key
+ // "seqno" is smaller than the snapshot seqno.
+ ASSERT_EQ(value.size(), 16);
+ ASSERT_LT(((int64_t*)value.c_str())[1], snapshot_num);
+ }
+ };
+
+ // Validate Get()/Next()/Prev(). Do only one of them to save time, and
+ // reduce lock contention.
+ switch (rnd.Uniform(3)) {
+ case 0: // Validate Get()
+ {
+ for (const auto& key : keys) {
+ std::string value;
+ s = txn->Get(read_options, Slice(key), &value);
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(owned_keys.count(key), 0);
+ } else {
+ verify_key(key, value);
+ }
+ }
+ break;
+ }
+ case 1: // Validate Next()
+ {
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ verify_key(iter->key().ToString(), iter->value().ToString());
+ }
+ ASSERT_OK(iter->status());
+ delete iter;
+ break;
+ }
+ case 2: // Validate Prev()
+ {
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ for (iter->SeekToLast(); iter->Valid(); iter->Prev()) {
+ verify_key(iter->key().ToString(), iter->value().ToString());
+ }
+ ASSERT_OK(iter->status());
+ delete iter;
+ break;
+ }
+ default:
+ FAIL();
+ }
+
+ if (rnd.OneIn(2)) {
+ ASSERT_OK(txn->Commit());
+ } else {
+ ASSERT_OK(txn->Rollback());
+ }
+ delete txn;
+ }
+ };
+
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < kNumThreads; i++) {
+ threads.emplace_back(stress_thread, i);
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+// This tests how write unprepared behaves during recovery when the DB crashes
+// after a transaction has either been unprepared or prepared, and tests if
+// the changes are correctly applied for prepared transactions if we decide to
+// rollback/commit.
+TEST_P(WriteUnpreparedTransactionTest, RecoveryTest) {
+ WriteOptions write_options;
+ write_options.disableWAL = false;
+ TransactionOptions txn_options;
+ std::vector<Transaction*> prepared_trans;
+ WriteUnpreparedTxnDB* wup_db;
+ options.disable_auto_compactions = true;
+
+ enum Action { UNPREPARED, ROLLBACK, COMMIT };
+
+ // batch_size of 1 causes writes to DB for every marker.
+ for (size_t batch_size : {1, 1000000}) {
+ txn_options.write_batch_flush_threshold = batch_size;
+ for (bool empty : {true, false}) {
+ for (Action a : {UNPREPARED, ROLLBACK, COMMIT}) {
+ for (int num_batches = 1; num_batches < 10; num_batches++) {
+ // Reset database.
+ prepared_trans.clear();
+ ASSERT_OK(ReOpen());
+ wup_db = dynamic_cast<WriteUnpreparedTxnDB*>(db);
+ if (!empty) {
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), "k" + std::to_string(i),
+ "before value" + std::to_string(i)));
+ }
+ }
+
+ // Write num_batches unprepared batches.
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ ASSERT_OK(txn->SetName("xid"));
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_OK(
+ txn->Put("k" + std::to_string(i), "value" + std::to_string(i)));
+ if (txn_options.write_batch_flush_threshold == 1) {
+ // WriteUnprepared will check write_batch_flush_threshold and
+ // possibly flush before appending to the write batch. No flush
+ // will happen at the first write because the batch is still
+ // empty, so after k puts, there should be k-1 flushed batches.
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), i);
+ } else {
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), 0);
+ }
+ }
+ if (a == UNPREPARED) {
+ // This is done to prevent the destructor from rolling back the
+ // transaction for us, since we want to pretend we crashed and
+ // test that recovery does the rollback.
+ wup_txn->unprep_seqs_.clear();
+ } else {
+ ASSERT_OK(txn->Prepare());
+ }
+ delete txn;
+
+ // Crash and run recovery code paths.
+ ASSERT_OK(wup_db->db_impl_->FlushWAL(true));
+ wup_db->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), a == UNPREPARED ? 0 : 1);
+ if (a == ROLLBACK) {
+ ASSERT_OK(prepared_trans[0]->Rollback());
+ delete prepared_trans[0];
+ } else if (a == COMMIT) {
+ ASSERT_OK(prepared_trans[0]->Commit());
+ delete prepared_trans[0];
+ }
+
+ Iterator* iter = db->NewIterator(ReadOptions());
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+ // Check that DB has before values.
+ if (!empty || a == COMMIT) {
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->key().ToString(), "k" + std::to_string(i));
+ if (a == COMMIT) {
+ ASSERT_EQ(iter->value().ToString(),
+ "value" + std::to_string(i));
+ } else {
+ ASSERT_EQ(iter->value().ToString(),
+ "before value" + std::to_string(i));
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_FALSE(iter->Valid());
+ ASSERT_OK(iter->status());
+ delete iter;
+ }
+ }
+ }
+ }
+}
+
+// Basic test to see that unprepared batch gets written to DB when batch size
+// is exceeded. It also does some basic checks to see if commit/rollback works
+// as expected for write unprepared.
+TEST_P(WriteUnpreparedTransactionTest, UnpreparedBatch) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ const int kNumKeys = 10;
+
+ // batch_size of 1 causes writes to DB for every marker.
+ for (size_t batch_size : {1, 1000000}) {
+ txn_options.write_batch_flush_threshold = batch_size;
+ for (bool prepare : {false, true}) {
+ for (bool commit : {false, true}) {
+ ASSERT_OK(ReOpen());
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ ASSERT_OK(txn->SetName("xid"));
+
+ for (int i = 0; i < kNumKeys; i++) {
+ ASSERT_OK(txn->Put("k" + std::to_string(i), "v" + std::to_string(i)));
+ if (txn_options.write_batch_flush_threshold == 1) {
+ // WriteUnprepared will check write_batch_flush_threshold and
+ // possibly flush before appending to the write batch. No flush will
+ // happen at the first write because the batch is still empty, so
+ // after k puts, there should be k-1 flushed batches.
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), i);
+ } else {
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), 0);
+ }
+ }
+
+ if (prepare) {
+ ASSERT_OK(txn->Prepare());
+ }
+
+ Iterator* iter = db->NewIterator(ReadOptions());
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+ assert(!iter->Valid());
+ ASSERT_FALSE(iter->Valid());
+ ASSERT_OK(iter->status());
+ delete iter;
+
+ if (commit) {
+ ASSERT_OK(txn->Commit());
+ } else {
+ ASSERT_OK(txn->Rollback());
+ }
+ delete txn;
+
+ iter = db->NewIterator(ReadOptions());
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+
+ for (int i = 0; i < (commit ? kNumKeys : 0); i++) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->key().ToString(), "k" + std::to_string(i));
+ ASSERT_EQ(iter->value().ToString(), "v" + std::to_string(i));
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+ ASSERT_OK(iter->status());
+ delete iter;
+ }
+ }
+ }
+}
+
+// Test whether logs containing unprepared/prepared batches are kept even
+// after memtable finishes flushing, and whether they are removed when
+// transaction commits/aborts.
+//
+// TODO(lth): Merge with TransactionTest/TwoPhaseLogRollingTest tests.
+TEST_P(WriteUnpreparedTransactionTest, MarkLogWithPrepSection) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ // batch_size of 1 causes writes to DB for every marker.
+ txn_options.write_batch_flush_threshold = 1;
+ const int kNumKeys = 10;
+
+ WriteOptions wopts;
+ wopts.sync = true;
+
+ for (bool prepare : {false, true}) {
+ for (bool commit : {false, true}) {
+ ASSERT_OK(ReOpen());
+ auto wup_db = dynamic_cast<WriteUnpreparedTxnDB*>(db);
+ auto db_impl = wup_db->db_impl_;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn1->SetName("xid1"));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn2->SetName("xid2"));
+
+ // Spread this transaction across multiple log files.
+ for (int i = 0; i < kNumKeys; i++) {
+ ASSERT_OK(txn1->Put("k1" + std::to_string(i), "v" + std::to_string(i)));
+ if (i >= kNumKeys / 2) {
+ ASSERT_OK(
+ txn2->Put("k2" + std::to_string(i), "v" + std::to_string(i)));
+ }
+
+ if (i > 0) {
+ ASSERT_OK(db_impl->TEST_SwitchWAL());
+ }
+ }
+
+ ASSERT_GT(txn1->GetLogNumber(), 0);
+ ASSERT_GT(txn2->GetLogNumber(), 0);
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn1->GetLogNumber());
+
+ if (prepare) {
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn2->Prepare());
+ }
+
+ ASSERT_GE(db_impl->TEST_LogfileNumber(), txn1->GetLogNumber());
+ ASSERT_GE(db_impl->TEST_LogfileNumber(), txn2->GetLogNumber());
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ if (commit) {
+ ASSERT_OK(txn1->Commit());
+ } else {
+ ASSERT_OK(txn1->Rollback());
+ }
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn2->GetLogNumber());
+
+ if (commit) {
+ ASSERT_OK(txn2->Commit());
+ } else {
+ ASSERT_OK(txn2->Rollback());
+ }
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ delete txn1;
+ delete txn2;
+ }
+ }
+}
+
+TEST_P(WriteUnpreparedTransactionTest, NoSnapshotWrite) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+
+ // Do some writes with no snapshot
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Put("c", "c"));
+
+ // Test that it is still possible to create iterators after writes with no
+ // snapshot, if iterator snapshot is fresh enough.
+ ReadOptions roptions;
+ auto iter = txn->GetIterator(roptions);
+ ASSERT_OK(iter->status());
+ int keys = 0;
+ for (iter->SeekToLast(); iter->Valid(); iter->Prev(), keys++) {
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(iter->key().ToString(), iter->value().ToString());
+ }
+ ASSERT_EQ(keys, 3);
+ ASSERT_OK(iter->status());
+
+ delete iter;
+ delete txn;
+}
+
+// Test whether write to a transaction while iterating is supported.
+TEST_P(WriteUnpreparedTransactionTest, IterateAndWrite) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ enum Action { DO_DELETE, DO_UPDATE };
+
+ for (Action a : {DO_DELETE, DO_UPDATE}) {
+ for (int i = 0; i < 100; i++) {
+ ASSERT_OK(db->Put(woptions, std::to_string(i), std::to_string(i)));
+ }
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ // write_batch_ now contains 1 key.
+ ASSERT_OK(txn->Put("9", "a"));
+
+ ReadOptions roptions;
+ auto iter = txn->GetIterator(roptions);
+ ASSERT_OK(iter->status());
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ ASSERT_OK(iter->status());
+ if (iter->key() == "9") {
+ ASSERT_EQ(iter->value().ToString(), "a");
+ } else {
+ ASSERT_EQ(iter->key().ToString(), iter->value().ToString());
+ }
+
+ if (a == DO_DELETE) {
+ ASSERT_OK(txn->Delete(iter->key()));
+ } else {
+ ASSERT_OK(txn->Put(iter->key(), "b"));
+ }
+ }
+ ASSERT_OK(iter->status());
+
+ delete iter;
+ ASSERT_OK(txn->Commit());
+
+ iter = db->NewIterator(roptions);
+ ASSERT_OK(iter->status());
+ if (a == DO_DELETE) {
+ // Check that db is empty.
+ iter->SeekToFirst();
+ ASSERT_FALSE(iter->Valid());
+ } else {
+ int keys = 0;
+ // Check that all values are updated to b.
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next(), keys++) {
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(iter->value().ToString(), "b");
+ }
+ ASSERT_EQ(keys, 100);
+ }
+ ASSERT_OK(iter->status());
+
+ delete iter;
+ delete txn;
+ }
+}
+
+// Test that using an iterator after transaction clear is not supported
+TEST_P(WriteUnpreparedTransactionTest, IterateAfterClear) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ enum Action { kCommit, kRollback };
+
+ for (Action a : {kCommit, kRollback}) {
+ for (int i = 0; i < 100; i++) {
+ ASSERT_OK(db->Put(woptions, std::to_string(i), std::to_string(i)));
+ }
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn->Put("9", "a"));
+
+ ReadOptions roptions;
+ auto iter1 = txn->GetIterator(roptions);
+ auto iter2 = txn->GetIterator(roptions);
+ iter1->SeekToFirst();
+ iter2->Seek("9");
+
+ // Check that iterators are valid before transaction finishes.
+ ASSERT_TRUE(iter1->Valid());
+ ASSERT_TRUE(iter2->Valid());
+ ASSERT_OK(iter1->status());
+ ASSERT_OK(iter2->status());
+
+ if (a == kCommit) {
+ ASSERT_OK(txn->Commit());
+ } else {
+ ASSERT_OK(txn->Rollback());
+ }
+
+ // Check that iterators are invalidated after transaction finishes.
+ ASSERT_FALSE(iter1->Valid());
+ ASSERT_FALSE(iter2->Valid());
+ ASSERT_TRUE(iter1->status().IsInvalidArgument());
+ ASSERT_TRUE(iter2->status().IsInvalidArgument());
+
+ delete iter1;
+ delete iter2;
+ delete txn;
+ }
+}
+
+TEST_P(WriteUnpreparedTransactionTest, SavePoint) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ txn->SetSavePoint();
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Commit());
+
+ ReadOptions roptions;
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, "b");
+ delete txn;
+}
+
+TEST_P(WriteUnpreparedTransactionTest, UntrackedKeys) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ auto wb = txn->GetWriteBatch()->GetWriteBatch();
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(wb->Put("a_untrack", "a_untrack"));
+ txn->SetSavePoint();
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Put("b_untrack", "b_untrack"));
+
+ ReadOptions roptions;
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "a_untrack", &value));
+ ASSERT_EQ(value, "a_untrack");
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, "b");
+ ASSERT_OK(txn->Get(roptions, "b_untrack", &value));
+ ASSERT_EQ(value, "b_untrack");
+
+ // b and b_untrack should be rolled back.
+ ASSERT_OK(txn->RollbackToSavePoint());
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "a_untrack", &value));
+ ASSERT_EQ(value, "a_untrack");
+ auto s = txn->Get(roptions, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Everything should be rolled back.
+ ASSERT_OK(txn->Rollback());
+ s = txn->Get(roptions, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "a_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn.cc b/src/rocksdb/utilities/transactions/write_unprepared_txn.cc
new file mode 100644
index 000000000..6e04d3344
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn.cc
@@ -0,0 +1,1053 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_unprepared_txn.h"
+
+#include "db/db_impl/db_impl.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+bool WriteUnpreparedTxnReadCallback::IsVisibleFullCheck(SequenceNumber seq) {
+ // Since unprep_seqs maps prep_seq => prepare_batch_cnt, to check if seq is
+ // in unprep_seqs, we have to check if seq is equal to prep_seq or any of
+ // the prepare_batch_cnt seq nums after it.
+ //
+ // TODO(lth): Can be optimized with std::lower_bound if unprep_seqs is
+ // large.
+ for (const auto& it : unprep_seqs_) {
+ if (it.first <= seq && seq < it.first + it.second) {
+ return true;
+ }
+ }
+
+ bool snap_released = false;
+ auto ret =
+ db_->IsInSnapshot(seq, wup_snapshot_, min_uncommitted_, &snap_released);
+ assert(!snap_released || backed_by_snapshot_ == kUnbackedByDBSnapshot);
+ snap_released_ |= snap_released;
+ return ret;
+}
+
+WriteUnpreparedTxn::WriteUnpreparedTxn(WriteUnpreparedTxnDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : WritePreparedTxn(txn_db, write_options, txn_options),
+ wupt_db_(txn_db),
+ last_log_number_(0),
+ recovered_txn_(false),
+ largest_validated_seq_(0) {
+ if (txn_options.write_batch_flush_threshold < 0) {
+ write_batch_flush_threshold_ =
+ txn_db_impl_->GetTxnDBOptions().default_write_batch_flush_threshold;
+ } else {
+ write_batch_flush_threshold_ = txn_options.write_batch_flush_threshold;
+ }
+}
+
+WriteUnpreparedTxn::~WriteUnpreparedTxn() {
+ if (!unprep_seqs_.empty()) {
+ assert(log_number_ > 0);
+ assert(GetId() > 0);
+ assert(!name_.empty());
+
+ // We should rollback regardless of GetState, but some unit tests that
+ // test crash recovery run the destructor assuming that rollback does not
+ // happen, so that rollback during recovery can be exercised.
+ if (GetState() == STARTED || GetState() == LOCKS_STOLEN) {
+ auto s = RollbackInternal();
+ assert(s.ok());
+ if (!s.ok()) {
+ ROCKS_LOG_FATAL(
+ wupt_db_->info_log_,
+ "Rollback of WriteUnprepared transaction failed in destructor: %s",
+ s.ToString().c_str());
+ }
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ }
+
+ // Clear the tracked locks so that ~PessimisticTransaction does not
+ // try to unlock keys for recovered transactions.
+ if (recovered_txn_) {
+ tracked_locks_->Clear();
+ }
+}
+
+void WriteUnpreparedTxn::Initialize(const TransactionOptions& txn_options) {
+ PessimisticTransaction::Initialize(txn_options);
+ if (txn_options.write_batch_flush_threshold < 0) {
+ write_batch_flush_threshold_ =
+ txn_db_impl_->GetTxnDBOptions().default_write_batch_flush_threshold;
+ } else {
+ write_batch_flush_threshold_ = txn_options.write_batch_flush_threshold;
+ }
+
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ recovered_txn_ = false;
+ largest_validated_seq_ = 0;
+ assert(active_iterators_.empty());
+ active_iterators_.clear();
+ untracked_keys_.clear();
+}
+
+Status WriteUnpreparedTxn::HandleWrite(std::function<Status()> do_write) {
+ Status s;
+ if (active_iterators_.empty()) {
+ s = MaybeFlushWriteBatchToDB();
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ s = do_write();
+ if (s.ok()) {
+ if (snapshot_) {
+ largest_validated_seq_ =
+ std::max(largest_validated_seq_, snapshot_->GetSequenceNumber());
+ } else {
+ // TODO(lth): We should use the same number as tracked_at_seq in TryLock,
+ // because what is actually being tracked is the sequence number at which
+ // this key was locked at.
+ largest_validated_seq_ = db_impl_->GetLastPublishedSequence();
+ }
+ }
+ return s;
+}
+
+Status WriteUnpreparedTxn::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Put(column_family, key, value, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Put(ColumnFamilyHandle* column_family,
+ const SliceParts& key, const SliceParts& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Put(column_family, key, value, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Merge(column_family, key, value,
+ assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key, const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Delete(column_family, key, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Delete(column_family, key, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::SingleDelete(column_family, key,
+ assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::SingleDelete(column_family, key,
+ assume_tracked);
+ });
+}
+
+// WriteUnpreparedTxn::RebuildFromWriteBatch is only called on recovery. For
+// WriteUnprepared, the write batches have already been written into the
+// database during WAL replay, so all we have to do is just to "retrack" the key
+// so that rollbacks are possible.
+//
+// Calling TryLock instead of TrackKey is also possible, but as an optimization,
+// recovered transactions do not hold locks on their keys. This follows the
+// implementation in PessimisticTransactionDB::Initialize where we set
+// skip_concurrency_control to true.
+Status WriteUnpreparedTxn::RebuildFromWriteBatch(WriteBatch* wb) {
+ struct TrackKeyHandler : public WriteBatch::Handler {
+ WriteUnpreparedTxn* txn_;
+ bool rollback_merge_operands_;
+
+ TrackKeyHandler(WriteUnpreparedTxn* txn, bool rollback_merge_operands)
+ : txn_(txn), rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ if (rollback_merge_operands_) {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ }
+ return Status::OK();
+ }
+
+ // Recovered batches do not contain 2PC markers.
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkNoop(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ TrackKeyHandler handler(this,
+ wupt_db_->txn_db_options_.rollback_merge_operands);
+ return wb->Iterate(&handler);
+}
+
+Status WriteUnpreparedTxn::MaybeFlushWriteBatchToDB() {
+ const bool kPrepared = true;
+ Status s;
+ if (write_batch_flush_threshold_ > 0 &&
+ write_batch_.GetWriteBatch()->Count() > 0 &&
+ write_batch_.GetDataSize() >
+ static_cast<size_t>(write_batch_flush_threshold_)) {
+ assert(GetState() != PREPARED);
+ s = FlushWriteBatchToDB(!kPrepared);
+ }
+ return s;
+}
+
+Status WriteUnpreparedTxn::FlushWriteBatchToDB(bool prepared) {
+ // If the current write batch contains savepoints, then some special handling
+ // is required so that RollbackToSavepoint can work.
+ //
+ // RollbackToSavepoint is not supported after Prepare() is called, so only do
+ // this for unprepared batches.
+ if (!prepared && unflushed_save_points_ != nullptr &&
+ !unflushed_save_points_->empty()) {
+ return FlushWriteBatchWithSavePointToDB();
+ }
+
+ return FlushWriteBatchToDBInternal(prepared);
+}
+
+Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) {
+ if (name_.empty()) {
+ assert(!prepared);
+#ifndef NDEBUG
+ static std::atomic_ullong autogen_id{0};
+ // To avoid changing all tests to call SetName, just autogenerate one.
+ if (wupt_db_->txn_db_options_.autogenerate_name) {
+ auto s = SetName(std::string("autoxid") +
+ std::to_string(autogen_id.fetch_add(1)));
+ assert(s.ok());
+ } else
+#endif
+ {
+ return Status::InvalidArgument("Cannot write to DB without SetName.");
+ }
+ }
+
+ struct UntrackedKeyHandler : public WriteBatch::Handler {
+ WriteUnpreparedTxn* txn_;
+ bool rollback_merge_operands_;
+
+ UntrackedKeyHandler(WriteUnpreparedTxn* txn, bool rollback_merge_operands)
+ : txn_(txn), rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status AddUntrackedKey(uint32_t cf, const Slice& key) {
+ auto str = key.ToString();
+ PointLockStatus lock_status =
+ txn_->tracked_locks_->GetPointLockStatus(cf, str);
+ if (!lock_status.locked) {
+ txn_->untracked_keys_[cf].push_back(str);
+ }
+ return Status::OK();
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ if (rollback_merge_operands_) {
+ return AddUntrackedKey(cf, key);
+ }
+ return Status::OK();
+ }
+
+ // The only expected 2PC marker is the initial Noop marker.
+ Status MarkNoop(bool empty_batch) override {
+ return empty_batch ? Status::OK() : Status::InvalidArgument();
+ }
+
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ UntrackedKeyHandler handler(
+ this, wupt_db_->txn_db_options_.rollback_merge_operands);
+ auto s = GetWriteBatch()->GetWriteBatch()->Iterate(&handler);
+ assert(s.ok());
+
+ // TODO(lth): Reduce duplicate code with WritePrepared prepare logic.
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ const bool WRITE_AFTER_COMMIT = true;
+ const bool first_prepare_batch = log_number_ == 0;
+ // MarkEndPrepare will change Noop marker to the appropriate marker.
+ s = WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(),
+ name_, !WRITE_AFTER_COMMIT, !prepared);
+ assert(s.ok());
+ // For each duplicate key we account for a new sub-batch
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+ // AddPrepared better to be called in the pre-release callback otherwise there
+ // is a non-zero chance of max advancing prepare_seq and readers assume the
+ // data as committed.
+ // Also having it in the PreReleaseCallback allows in-order addition of
+ // prepared entries to PreparedHeap and hence enables an optimization. Refer
+ // to SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, prepare_batch_cnt_,
+ db_impl_->immutable_db_options().two_write_queues, first_prepare_batch);
+ const bool DISABLE_MEMTABLE = true;
+ uint64_t seq_used = kMaxSequenceNumber;
+ // log_number_ should refer to the oldest log containing uncommitted data
+ // from the current transaction. This means that if log_number_ is set,
+ // WriteImpl should not overwrite that value, so set log_used to nullptr if
+ // log_number_ is already set.
+ s = db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(),
+ /*callback*/ nullptr, &last_log_number_,
+ /*log ref*/ 0, !DISABLE_MEMTABLE, &seq_used,
+ prepare_batch_cnt_, &add_prepared_callback);
+ if (log_number_ == 0) {
+ log_number_ = last_log_number_;
+ }
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ auto prepare_seq = seq_used;
+
+ // Only call SetId if it hasn't been set yet.
+ if (GetId() == 0) {
+ SetId(prepare_seq);
+ }
+ // unprep_seqs_ will also contain prepared seqnos since they are treated in
+ // the same way in the prepare/commit callbacks. See the comment on the
+ // definition of unprep_seqs_.
+ unprep_seqs_[prepare_seq] = prepare_batch_cnt_;
+
+ // Reset transaction state.
+ if (!prepared) {
+ prepare_batch_cnt_ = 0;
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+ }
+
+ return s;
+}
+
+Status WriteUnpreparedTxn::FlushWriteBatchWithSavePointToDB() {
+ assert(unflushed_save_points_ != nullptr &&
+ unflushed_save_points_->size() > 0);
+ assert(save_points_ != nullptr && save_points_->size() > 0);
+ assert(save_points_->size() >= unflushed_save_points_->size());
+
+ // Handler class for creating an unprepared batch from a savepoint.
+ struct SavePointBatchHandler : public WriteBatch::Handler {
+ WriteBatchWithIndex* wb_;
+ const std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+
+ SavePointBatchHandler(
+ WriteBatchWithIndex* wb,
+ const std::map<uint32_t, ColumnFamilyHandle*>& handles)
+ : wb_(wb), handles_(handles) {}
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& value) override {
+ return wb_->Put(handles_.at(cf), key, value);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return wb_->Delete(handles_.at(cf), key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return wb_->SingleDelete(handles_.at(cf), key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice& value) override {
+ return wb_->Merge(handles_.at(cf), key, value);
+ }
+
+ // The only expected 2PC marker is the initial Noop marker.
+ Status MarkNoop(bool empty_batch) override {
+ return empty_batch ? Status::OK() : Status::InvalidArgument();
+ }
+
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ // The comparator of the default cf is passed in, similar to the
+ // initialization of TransactionBaseImpl::write_batch_. This comparator is
+ // only used if the write batch encounters an invalid cf id, and falls back to
+ // this comparator.
+ WriteBatchWithIndex wb(wpt_db_->DefaultColumnFamily()->GetComparator(), 0,
+ true, 0, write_options_.protection_bytes_per_key);
+ // Swap with write_batch_ so that wb contains the complete write batch. The
+ // actual write batch that will be flushed to DB will be built in
+ // write_batch_, and will be read by FlushWriteBatchToDBInternal.
+ std::swap(wb, write_batch_);
+ TransactionBaseImpl::InitWriteBatch();
+
+ size_t prev_boundary = WriteBatchInternal::kHeader;
+ const bool kPrepared = true;
+ for (size_t i = 0; i < unflushed_save_points_->size() + 1; i++) {
+ bool trailing_batch = i == unflushed_save_points_->size();
+ SavePointBatchHandler sp_handler(&write_batch_,
+ *wupt_db_->GetCFHandleMap().get());
+ size_t curr_boundary = trailing_batch ? wb.GetWriteBatch()->GetDataSize()
+ : (*unflushed_save_points_)[i];
+
+ // Construct the partial write batch up to the savepoint.
+ //
+ // Theoretically, a memcpy between the write batches should be sufficient
+ // since the rewriting into the batch should produce the exact same byte
+ // representation. Rebuilding the WriteBatchWithIndex index is still
+ // necessary though, and would imply doing two passes over the batch though.
+ Status s = WriteBatchInternal::Iterate(wb.GetWriteBatch(), &sp_handler,
+ prev_boundary, curr_boundary);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (write_batch_.GetWriteBatch()->Count() > 0) {
+ // Flush the write batch.
+ s = FlushWriteBatchToDBInternal(!kPrepared);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ if (!trailing_batch) {
+ if (flushed_save_points_ == nullptr) {
+ flushed_save_points_.reset(
+ new autovector<WriteUnpreparedTxn::SavePoint>());
+ }
+ flushed_save_points_->emplace_back(
+ unprep_seqs_, new ManagedSnapshot(db_impl_, wupt_db_->GetSnapshot()));
+ }
+
+ prev_boundary = curr_boundary;
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+ }
+
+ unflushed_save_points_->clear();
+ return Status::OK();
+}
+
+Status WriteUnpreparedTxn::PrepareInternal() {
+ const bool kPrepared = true;
+ return FlushWriteBatchToDB(kPrepared);
+}
+
+Status WriteUnpreparedTxn::CommitWithoutPrepareInternal() {
+ if (unprep_seqs_.empty()) {
+ assert(log_number_ == 0);
+ assert(GetId() == 0);
+ return WritePreparedTxn::CommitWithoutPrepareInternal();
+ }
+
+ // TODO(lth): We should optimize commit without prepare to not perform
+ // a prepare under the hood.
+ auto s = PrepareInternal();
+ if (!s.ok()) {
+ return s;
+ }
+ return CommitInternal();
+}
+
+Status WriteUnpreparedTxn::CommitInternal() {
+ // TODO(lth): Reduce duplicate code with WritePrepared commit logic.
+
+ // We take the commit-time batch and append the Commit marker. The Memtable
+ // will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+ const bool empty = working_batch->Count() == 0;
+ auto s = WriteBatchInternal::MarkCommit(working_batch, name_);
+ assert(s.ok());
+
+ const bool for_recovery = use_only_the_last_commit_time_batch_for_recovery_;
+ if (!empty) {
+ // When not writing to memtable, we can still cache the latest write batch.
+ // The cached batch will be written to memtable in WriteRecoverableState
+ // during FlushMemTable
+ if (for_recovery) {
+ WriteBatchInternal::SetAsLatestPersistentState(working_batch);
+ } else {
+ return Status::InvalidArgument(
+ "Commit-time-batch can only be used if "
+ "use_only_the_last_commit_time_batch_for_recovery is true");
+ }
+ }
+
+ const bool includes_data = !empty && !for_recovery;
+ size_t commit_batch_cnt = 0;
+ if (UNLIKELY(includes_data)) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Duplicate key overhead");
+ SubBatchCounter counter(*wpt_db_->GetCFComparatorMap());
+ s = working_batch->Iterate(&counter);
+ assert(s.ok());
+ commit_batch_cnt = counter.BatchCount();
+ }
+ const bool disable_memtable = !includes_data;
+ const bool do_one_write =
+ !db_impl_->immutable_db_options().two_write_queues || disable_memtable;
+
+ WriteUnpreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, unprep_seqs_, commit_batch_cnt);
+ const bool kFirstPrepareBatch = true;
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, commit_batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ uint64_t seq_used = kMaxSequenceNumber;
+ // Since the prepared batch is directly written to memtable, there is
+ // already a connection between the memtable and its WAL, so there is no
+ // need to redundantly reference the log that contains the prepared data.
+ const uint64_t zero_log_number = 0ull;
+ size_t batch_cnt = UNLIKELY(commit_batch_cnt) ? commit_batch_cnt : 1;
+ s = db_impl_->WriteImpl(write_options_, working_batch, nullptr, nullptr,
+ zero_log_number, disable_memtable, &seq_used,
+ batch_cnt, pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ const SequenceNumber commit_batch_seq = seq_used;
+ if (LIKELY(do_one_write || !s.ok())) {
+ if (LIKELY(s.ok())) {
+ // Note RemovePrepared should be called after WriteImpl that publishsed
+ // the seq. Otherwise SmallestUnCommittedSeq optimization breaks.
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ }
+ if (UNLIKELY(!do_one_write)) {
+ wpt_db_->RemovePrepared(commit_batch_seq, commit_batch_cnt);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+ } // else do the 2nd write to publish seq
+
+ // Populate unprep_seqs_ with commit_batch_seq, since we treat data in the
+ // commit write batch as just another "unprepared" batch. This will also
+ // update the unprep_seqs_ in the update_commit_map callback.
+ unprep_seqs_[commit_batch_seq] = commit_batch_cnt;
+ WriteUnpreparedCommitEntryPreReleaseCallback
+ update_commit_map_with_commit_batch(wpt_db_, db_impl_, unprep_seqs_, 0);
+
+ // Note: the 2nd write comes with a performance penality. So if we have too
+ // many of commits accompanied with ComitTimeWriteBatch and yet we cannot
+ // enable use_only_the_last_commit_time_batch_for_recovery_ optimization,
+ // two_write_queues should be disabled to avoid many additional writes here.
+
+ // Update commit map only from the 2nd queue
+ WriteBatch empty_batch;
+ s = empty_batch.PutLogData(Slice());
+ assert(s.ok());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ s = WriteBatchInternal::InsertNoop(&empty_batch);
+ assert(s.ok());
+ const bool DISABLE_MEMTABLE = true;
+ const size_t ONE_BATCH = 1;
+ const uint64_t NO_REF_LOG = 0;
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_commit_batch);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ // Note RemovePrepared should be called after WriteImpl that publishsed the
+ // seq. Otherwise SmallestUnCommittedSeq optimization breaks.
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+}
+
+Status WriteUnpreparedTxn::WriteRollbackKeys(
+ const LockTracker& lock_tracker, WriteBatchWithIndex* rollback_batch,
+ ReadCallback* callback, const ReadOptions& roptions) {
+ // This assertion can be removed when range lock is supported.
+ assert(lock_tracker.IsPointLockSupported());
+ const auto& cf_map = *wupt_db_->GetCFHandleMap();
+ auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) {
+ const auto& cf_handle = cf_map.at(cfid);
+ PinnableSlice pinnable_val;
+ bool not_used;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = callback;
+ auto s = db_impl_->GetImpl(roptions, key, get_impl_options);
+
+ if (s.ok()) {
+ s = rollback_batch->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ if (wupt_db_->ShouldRollbackWithSingleDelete(cf_handle, key)) {
+ s = rollback_batch->SingleDelete(cf_handle, key);
+ } else {
+ s = rollback_batch->Delete(cf_handle, key);
+ }
+ assert(s.ok());
+ } else {
+ return s;
+ }
+
+ return Status::OK();
+ };
+
+ std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
+ lock_tracker.GetColumnFamilyIterator());
+ assert(cf_it != nullptr);
+ while (cf_it->HasNext()) {
+ ColumnFamilyId cf = cf_it->Next();
+ std::unique_ptr<LockTracker::KeyIterator> key_it(
+ lock_tracker.GetKeyIterator(cf));
+ assert(key_it != nullptr);
+ while (key_it->HasNext()) {
+ const std::string& key = key_it->Next();
+ auto s = WriteRollbackKey(key, cf);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+
+ for (const auto& cfkey : untracked_keys_) {
+ const auto cfid = cfkey.first;
+ const auto& keys = cfkey.second;
+ for (const auto& key : keys) {
+ auto s = WriteRollbackKey(key, cfid);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status WriteUnpreparedTxn::RollbackInternal() {
+ // TODO(lth): Reduce duplicate code with WritePrepared rollback logic.
+ WriteBatchWithIndex rollback_batch(
+ wpt_db_->DefaultColumnFamily()->GetComparator(), 0, true, 0,
+ write_options_.protection_bytes_per_key);
+ assert(GetId() != kMaxSequenceNumber);
+ assert(GetId() > 0);
+ Status s;
+ auto read_at_seq = kMaxSequenceNumber;
+ ReadOptions roptions;
+ // to prevent callback's seq to be overrriden inside DBImpk::Get
+ roptions.snapshot = wpt_db_->GetMaxSnapshot();
+ // Note that we do not use WriteUnpreparedTxnReadCallback because we do not
+ // need to read our own writes when reading prior versions of the key for
+ // rollback.
+ WritePreparedTxnReadCallback callback(wpt_db_, read_at_seq);
+ // TODO(lth): We write rollback batch all in a single batch here, but this
+ // should be subdivded into multiple batches as well. In phase 2, when key
+ // sets are read from WAL, this will happen naturally.
+ s = WriteRollbackKeys(*tracked_locks_, &rollback_batch, &callback, roptions);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // The Rollback marker will be used as a batch separator
+ s = WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_);
+ assert(s.ok());
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t NO_REF_LOG = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+ // Rollback batch may contain duplicate keys, because tracked_keys_ is not
+ // comparator aware.
+ auto rollback_batch_cnt = rollback_batch.SubBatchCnt();
+ // We commit the rolled back prepared batches. Although this is
+ // counter-intuitive, i) it is safe to do so, since the prepared batches are
+ // already canceled out by the rollback batch, ii) adding the commit entry to
+ // CommitCache will allow us to benefit from the existing mechanism in
+ // CommitCache that keeps an entry evicted due to max advance and yet overlaps
+ // with a live snapshot around so that the live snapshot properly skips the
+ // entry even if its prepare seq is lower than max_evicted_seq_.
+ //
+ // TODO(lth): RollbackInternal is conceptually very similar to
+ // CommitInternal, with the rollback batch simply taking on the role of
+ // CommitTimeWriteBatch. We should be able to merge the two code paths.
+ WriteUnpreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, unprep_seqs_, rollback_batch_cnt);
+ // Note: the rollback batch does not need AddPrepared since it is written to
+ // DB in one shot. min_uncommitted still works since it requires capturing
+ // data that is written to DB but not yet committed, while the rollback
+ // batch commits with PreReleaseCallback.
+ s = db_impl_->WriteImpl(write_options_, rollback_batch.GetWriteBatch(),
+ nullptr, nullptr, NO_REF_LOG, !DISABLE_MEMTABLE,
+ &seq_used, rollback_batch_cnt,
+ do_one_write ? &update_commit_map : nullptr);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+ } // else do the 2nd write for commit
+
+ uint64_t& prepare_seq = seq_used;
+ // Populate unprep_seqs_ with rollback_batch_cnt, since we treat data in the
+ // rollback write batch as just another "unprepared" batch. This will also
+ // update the unprep_seqs_ in the update_commit_map callback.
+ unprep_seqs_[prepare_seq] = rollback_batch_cnt;
+ WriteUnpreparedCommitEntryPreReleaseCallback
+ update_commit_map_with_rollback_batch(wpt_db_, db_impl_, unprep_seqs_, 0);
+
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal 2nd write prepare_seq: %" PRIu64,
+ prepare_seq);
+ WriteBatch empty_batch;
+ const size_t ONE_BATCH = 1;
+ s = empty_batch.PutLogData(Slice());
+ assert(s.ok());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ s = WriteBatchInternal::InsertNoop(&empty_batch);
+ assert(s.ok());
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_rollback_batch);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ // Mark the txn as rolled back
+ if (s.ok()) {
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ }
+
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+}
+
+void WriteUnpreparedTxn::Clear() {
+ if (!recovered_txn_) {
+ txn_db_impl_->UnLock(this, *tracked_locks_);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ recovered_txn_ = false;
+ largest_validated_seq_ = 0;
+ for (auto& it : active_iterators_) {
+ auto bdit = static_cast<BaseDeltaIterator*>(it);
+ bdit->Invalidate(Status::InvalidArgument(
+ "Cannot use iterator after transaction has finished"));
+ }
+ active_iterators_.clear();
+ untracked_keys_.clear();
+ TransactionBaseImpl::Clear();
+}
+
+void WriteUnpreparedTxn::SetSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ PessimisticTransaction::SetSavePoint();
+ if (unflushed_save_points_ == nullptr) {
+ unflushed_save_points_.reset(new autovector<size_t>());
+ }
+ unflushed_save_points_->push_back(write_batch_.GetDataSize());
+}
+
+Status WriteUnpreparedTxn::RollbackToSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ if (unflushed_save_points_ != nullptr && unflushed_save_points_->size() > 0) {
+ Status s = PessimisticTransaction::RollbackToSavePoint();
+ assert(!s.IsNotFound());
+ unflushed_save_points_->pop_back();
+ return s;
+ }
+
+ if (flushed_save_points_ != nullptr && !flushed_save_points_->empty()) {
+ return RollbackToSavePointInternal();
+ }
+
+ return Status::NotFound();
+}
+
+Status WriteUnpreparedTxn::RollbackToSavePointInternal() {
+ Status s;
+
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+
+ assert(flushed_save_points_->size() > 0);
+ WriteUnpreparedTxn::SavePoint& top = flushed_save_points_->back();
+
+ assert(save_points_ != nullptr && save_points_->size() > 0);
+ const LockTracker& tracked_keys = *save_points_->top().new_locks_;
+
+ ReadOptions roptions;
+ roptions.snapshot = top.snapshot_->snapshot();
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(roptions.snapshot)
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = roptions.snapshot->GetSequenceNumber();
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ top.unprep_seqs_,
+ kBackedByDBSnapshot);
+ s = WriteRollbackKeys(tracked_keys, &write_batch_, &callback, roptions);
+ if (!s.ok()) {
+ return s;
+ }
+
+ const bool kPrepared = true;
+ s = FlushWriteBatchToDBInternal(!kPrepared);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // PessimisticTransaction::RollbackToSavePoint will call also call
+ // RollbackToSavepoint on write_batch_. However, write_batch_ is empty and has
+ // no savepoints because this savepoint has already been flushed. Work around
+ // this by setting a fake savepoint.
+ write_batch_.SetSavePoint();
+ s = PessimisticTransaction::RollbackToSavePoint();
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+
+ flushed_save_points_->pop_back();
+ return s;
+}
+
+Status WriteUnpreparedTxn::PopSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ if (unflushed_save_points_ != nullptr && unflushed_save_points_->size() > 0) {
+ Status s = PessimisticTransaction::PopSavePoint();
+ assert(!s.IsNotFound());
+ unflushed_save_points_->pop_back();
+ return s;
+ }
+
+ if (flushed_save_points_ != nullptr && !flushed_save_points_->empty()) {
+ // PessimisticTransaction::PopSavePoint will call also call PopSavePoint on
+ // write_batch_. However, write_batch_ is empty and has no savepoints
+ // because this savepoint has already been flushed. Work around this by
+ // setting a fake savepoint.
+ write_batch_.SetSavePoint();
+ Status s = PessimisticTransaction::PopSavePoint();
+ assert(!s.IsNotFound());
+ flushed_save_points_->pop_back();
+ return s;
+ }
+
+ return Status::NotFound();
+}
+
+void WriteUnpreparedTxn::MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wupt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ unprep_seqs_, backed_by_snapshot);
+ write_batch_.MultiGetFromBatchAndDB(db_, options, column_family, num_keys,
+ keys, values, statuses, sorted_input,
+ &callback);
+ if (UNLIKELY(!callback.valid() ||
+ !wupt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ wupt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ for (size_t i = 0; i < num_keys; i++) {
+ statuses[i] = Status::TryAgain();
+ }
+ }
+}
+
+Status WriteUnpreparedTxn::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wupt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ unprep_seqs_, backed_by_snapshot);
+ auto res = write_batch_.GetFromBatchAndDB(db_, options, column_family, key,
+ value, &callback);
+ if (LIKELY(callback.valid() &&
+ wupt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ return res;
+ } else {
+ res.PermitUncheckedError();
+ wupt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ return Status::TryAgain();
+ }
+}
+
+namespace {
+static void CleanupWriteUnpreparedWBWIIterator(void* arg1, void* arg2) {
+ auto txn = reinterpret_cast<WriteUnpreparedTxn*>(arg1);
+ auto iter = reinterpret_cast<Iterator*>(arg2);
+ txn->RemoveActiveIterator(iter);
+}
+} // anonymous namespace
+
+Iterator* WriteUnpreparedTxn::GetIterator(const ReadOptions& options) {
+ return GetIterator(options, wupt_db_->DefaultColumnFamily());
+}
+
+Iterator* WriteUnpreparedTxn::GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ // Make sure to get iterator from WriteUnprepareTxnDB, not the root db.
+ Iterator* db_iter = wupt_db_->NewIterator(options, column_family, this);
+ assert(db_iter);
+
+ auto iter = write_batch_.NewIteratorWithBase(column_family, db_iter);
+ active_iterators_.push_back(iter);
+ iter->RegisterCleanup(CleanupWriteUnpreparedWBWIIterator, this, iter);
+ return iter;
+}
+
+Status WriteUnpreparedTxn::ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ // TODO(lth): Reduce duplicate code with WritePrepared ValidateSnapshot logic.
+ assert(snapshot_);
+
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(snapshot_.get())
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = snapshot_->GetSequenceNumber();
+ // tracked_at_seq is either max or the last snapshot with which this key was
+ // trackeed so there is no need to apply the IsInSnapshot to this comparison
+ // here as tracked_at_seq is not a prepare seq.
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previous validated at a sequence number earlier
+ // than the curent snapshot's sequence number, we already know it has not
+ // been modified.
+ return Status::OK();
+ }
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ WriteUnpreparedTxnReadCallback snap_checker(
+ wupt_db_, snap_seq, min_uncommitted, unprep_seqs_, kBackedByDBSnapshot);
+ // TODO(yanqin): Support user-defined timestamp.
+ return TransactionUtil::CheckKeyForConflicts(
+ db_impl_, cfh, key.ToString(), snap_seq, /*ts=*/nullptr,
+ false /* cache_only */, &snap_checker, min_uncommitted);
+}
+
+const std::map<SequenceNumber, size_t>&
+WriteUnpreparedTxn::GetUnpreparedSequenceNumbers() {
+ return unprep_seqs_;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn.h b/src/rocksdb/utilities/transactions/write_unprepared_txn.h
new file mode 100644
index 000000000..5a3227f4e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn.h
@@ -0,0 +1,341 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <set>
+
+#include "utilities/transactions/write_prepared_txn.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTxnDB;
+class WriteUnpreparedTxn;
+
+// WriteUnprepared transactions needs to be able to read their own uncommitted
+// writes, and supporting this requires some careful consideration. Because
+// writes in the current transaction may be flushed to DB already, we cannot
+// rely on the contents of WriteBatchWithIndex to determine whether a key should
+// be visible or not, so we have to remember to check the DB for any uncommitted
+// keys that should be visible to us. First, we will need to change the seek to
+// snapshot logic, to seek to max_visible_seq = max(snap_seq, max_unprep_seq).
+// Any key greater than max_visible_seq should not be visible because they
+// cannot be unprepared by the current transaction and they are not in its
+// snapshot.
+//
+// When we seek to max_visible_seq, one of these cases will happen:
+// 1. We hit a unprepared key from the current transaction.
+// 2. We hit a unprepared key from the another transaction.
+// 3. We hit a committed key with snap_seq < seq < max_unprep_seq.
+// 4. We hit a committed key with seq <= snap_seq.
+//
+// IsVisibleFullCheck handles all cases correctly.
+//
+// Other notes:
+// Note that max_visible_seq is only calculated once at iterator construction
+// time, meaning if the same transaction is adding more unprep seqs through
+// writes during iteration, these newer writes may not be visible. This is not a
+// problem for MySQL though because it avoids modifying the index as it is
+// scanning through it to avoid the Halloween Problem. Instead, it scans the
+// index once up front, and modifies based on a temporary copy.
+//
+// In DBIter, there is a "reseek" optimization if the iterator skips over too
+// many keys. However, this assumes that the reseek seeks exactly to the
+// required key. In write unprepared, even after seeking directly to
+// max_visible_seq, some iteration may be required before hitting a visible key,
+// and special precautions must be taken to avoid performing another reseek,
+// leading to an infinite loop.
+//
+class WriteUnpreparedTxnReadCallback : public ReadCallback {
+ public:
+ WriteUnpreparedTxnReadCallback(
+ WritePreparedTxnDB* db, SequenceNumber snapshot,
+ SequenceNumber min_uncommitted,
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ SnapshotBackup backed_by_snapshot)
+ // Pass our last uncommitted seq as the snapshot to the parent class to
+ // ensure that the parent will not prematurely filter out own writes. We
+ // will do the exact comparison against snapshots in IsVisibleFullCheck
+ // override.
+ : ReadCallback(CalcMaxVisibleSeq(unprep_seqs, snapshot), min_uncommitted),
+ db_(db),
+ unprep_seqs_(unprep_seqs),
+ wup_snapshot_(snapshot),
+ backed_by_snapshot_(backed_by_snapshot) {
+ (void)backed_by_snapshot_; // to silence unused private field warning
+ }
+
+ virtual ~WriteUnpreparedTxnReadCallback() {
+ // If it is not backed by snapshot, the caller must check validity
+ assert(valid_checked_ || backed_by_snapshot_ == kBackedByDBSnapshot);
+ }
+
+ virtual bool IsVisibleFullCheck(SequenceNumber seq) override;
+
+ inline bool valid() {
+ valid_checked_ = true;
+ return snap_released_ == false;
+ }
+
+ void Refresh(SequenceNumber seq) override {
+ max_visible_seq_ = std::max(max_visible_seq_, seq);
+ wup_snapshot_ = seq;
+ }
+
+ static SequenceNumber CalcMaxVisibleSeq(
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ SequenceNumber snapshot_seq) {
+ SequenceNumber max_unprepared = 0;
+ if (unprep_seqs.size()) {
+ max_unprepared =
+ unprep_seqs.rbegin()->first + unprep_seqs.rbegin()->second - 1;
+ }
+ return std::max(max_unprepared, snapshot_seq);
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ const std::map<SequenceNumber, size_t>& unprep_seqs_;
+ SequenceNumber wup_snapshot_;
+ // Whether max_visible_seq_ is backed by a snapshot
+ const SnapshotBackup backed_by_snapshot_;
+ bool snap_released_ = false;
+ // Safety check to ensure that the caller has checked invalid statuses
+ bool valid_checked_ = false;
+};
+
+class WriteUnpreparedTxn : public WritePreparedTxn {
+ public:
+ WriteUnpreparedTxn(WriteUnpreparedTxnDB* db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+
+ virtual ~WriteUnpreparedTxn();
+
+ using TransactionBaseImpl::Put;
+ virtual Status Put(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value,
+ const bool assume_tracked = false) override;
+ virtual Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::Merge;
+ virtual Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::Delete;
+ virtual Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ virtual Status Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::SingleDelete;
+ virtual Status SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked = false) override;
+ virtual Status SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ // In WriteUnprepared, untracked writes will break snapshot validation logic.
+ // Snapshot validation will only check the largest sequence number of a key to
+ // see if it was committed or not. However, an untracked unprepared write will
+ // hide smaller committed sequence numbers.
+ //
+ // TODO(lth): Investigate whether it is worth having snapshot validation
+ // validate all values larger than snap_seq. Otherwise, we should return
+ // Status::NotSupported for untracked writes.
+
+ virtual Status RebuildFromWriteBatch(WriteBatch*) override;
+
+ virtual uint64_t GetLastLogNumber() const override {
+ return last_log_number_;
+ }
+
+ void RemoveActiveIterator(Iterator* iter) {
+ active_iterators_.erase(
+ std::remove(active_iterators_.begin(), active_iterators_.end(), iter),
+ active_iterators_.end());
+ }
+
+ protected:
+ void Initialize(const TransactionOptions& txn_options) override;
+
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+
+ void Clear() override;
+
+ void SetSavePoint() override;
+ Status RollbackToSavePoint() override;
+ Status PopSavePoint() override;
+
+ // Get and GetIterator needs to be overridden so that a ReadCallback to
+ // handle read-your-own-write is used.
+ using Transaction::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using Transaction::MultiGet;
+ virtual void MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input = false) override;
+
+ using Transaction::GetIterator;
+ virtual Iterator* GetIterator(const ReadOptions& options) override;
+ virtual Iterator* GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) override;
+
+ private:
+ friend class WriteUnpreparedTransactionTest_ReadYourOwnWrite_Test;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+ friend class WriteUnpreparedTransactionTest_UnpreparedBatch_Test;
+ friend class WriteUnpreparedTxnDB;
+
+ const std::map<SequenceNumber, size_t>& GetUnpreparedSequenceNumbers();
+ Status WriteRollbackKeys(const LockTracker& tracked_keys,
+ WriteBatchWithIndex* rollback_batch,
+ ReadCallback* callback, const ReadOptions& roptions);
+
+ Status MaybeFlushWriteBatchToDB();
+ Status FlushWriteBatchToDB(bool prepared);
+ Status FlushWriteBatchToDBInternal(bool prepared);
+ Status FlushWriteBatchWithSavePointToDB();
+ Status RollbackToSavePointInternal();
+ Status HandleWrite(std::function<Status()> do_write);
+
+ // For write unprepared, we check on every writebatch append to see if
+ // write_batch_flush_threshold_ has been exceeded, and then call
+ // FlushWriteBatchToDB if so. This logic is encapsulated in
+ // MaybeFlushWriteBatchToDB.
+ int64_t write_batch_flush_threshold_;
+ WriteUnpreparedTxnDB* wupt_db_;
+
+ // Ordered list of unprep_seq sequence numbers that we have already written
+ // to DB.
+ //
+ // This maps unprep_seq => prepare_batch_cnt for each unprepared batch
+ // written by this transaction.
+ //
+ // Note that this contains both prepared and unprepared batches, since they
+ // are treated similarily in prepare heap/commit map, so it simplifies the
+ // commit callbacks.
+ std::map<SequenceNumber, size_t> unprep_seqs_;
+
+ uint64_t last_log_number_;
+
+ // Recovered transactions have tracked_keys_ populated, but are not actually
+ // locked for efficiency reasons. For recovered transactions, skip unlocking
+ // keys when transaction ends.
+ bool recovered_txn_;
+
+ // Track the largest sequence number at which we performed snapshot
+ // validation. If snapshot validation was skipped because no snapshot was set,
+ // then this is set to GetLastPublishedSequence. This value is useful because
+ // it means that for keys that have unprepared seqnos, we can guarantee that
+ // no committed keys by other transactions can exist between
+ // largest_validated_seq_ and max_unprep_seq. See
+ // WriteUnpreparedTxnDB::NewIterator for an explanation for why this is
+ // necessary for iterator Prev().
+ //
+ // Currently this value only increases during the lifetime of a transaction,
+ // but in some cases, we should be able to restore the previously largest
+ // value when calling RollbackToSavepoint.
+ SequenceNumber largest_validated_seq_;
+
+ struct SavePoint {
+ // Record of unprep_seqs_ at this savepoint. The set of unprep_seq is
+ // used during RollbackToSavepoint to determine visibility when restoring
+ // old values.
+ //
+ // TODO(lth): Since all unprep_seqs_ sets further down the stack must be
+ // subsets, this can potentially be deduplicated by just storing set
+ // difference. Investigate if this is worth it.
+ std::map<SequenceNumber, size_t> unprep_seqs_;
+
+ // This snapshot will be used to read keys at this savepoint if we call
+ // RollbackToSavePoint.
+ std::unique_ptr<ManagedSnapshot> snapshot_;
+
+ SavePoint(const std::map<SequenceNumber, size_t>& seqs,
+ ManagedSnapshot* snapshot)
+ : unprep_seqs_(seqs), snapshot_(snapshot){};
+ };
+
+ // We have 3 data structures holding savepoint information:
+ // 1. TransactionBaseImpl::save_points_
+ // 2. WriteUnpreparedTxn::flushed_save_points_
+ // 3. WriteUnpreparecTxn::unflushed_save_points_
+ //
+ // TransactionBaseImpl::save_points_ holds information about all write
+ // batches, including the current in-memory write_batch_, or unprepared
+ // batches that have been written out. Its responsibility is just to track
+ // which keys have been modified in every savepoint.
+ //
+ // WriteUnpreparedTxn::flushed_save_points_ holds information about savepoints
+ // set on unprepared batches that have already flushed. It holds the snapshot
+ // and unprep_seqs at that savepoint, so that the rollback process can
+ // determine which keys were visible at that point in time.
+ //
+ // WriteUnpreparecTxn::unflushed_save_points_ holds information about
+ // savepoints on the current in-memory write_batch_. It simply records the
+ // size of the write batch at every savepoint.
+ //
+ // TODO(lth): Remove the redundancy between save_point_boundaries_ and
+ // write_batch_.save_points_.
+ //
+ // Based on this information, here are some invariants:
+ // size(unflushed_save_points_) = size(write_batch_.save_points_)
+ // size(flushed_save_points_) + size(unflushed_save_points_)
+ // = size(save_points_)
+ //
+ std::unique_ptr<autovector<WriteUnpreparedTxn::SavePoint>>
+ flushed_save_points_;
+ std::unique_ptr<autovector<size_t>> unflushed_save_points_;
+
+ // It is currently unsafe to flush a write batch if there are active iterators
+ // created from this transaction. This is because we use WriteBatchWithIndex
+ // to do merging reads from the DB and the write batch. If we flush the write
+ // batch, it is possible that the delta iterator on the iterator will point to
+ // invalid memory.
+ std::vector<Iterator*> active_iterators_;
+
+ // Untracked keys that we have to rollback.
+ //
+ // TODO(lth): Currently we we do not record untracked keys per-savepoint.
+ // This means that when rolling back to savepoints, we have to check all
+ // keys in the current transaction for rollback. Note that this is only
+ // inefficient, but still correct because we take a snapshot at every
+ // savepoint, and we will use that snapshot to construct the rollback batch.
+ // The rollback batch will then contain a reissue of the same marker.
+ //
+ // A more optimal solution would be to only check keys changed since the
+ // last savepoint. Also, it may make sense to merge this into tracked_keys_
+ // and differentiate between tracked but not locked keys to avoid having two
+ // very similar data structures.
+ using KeySet = std::unordered_map<uint32_t, std::vector<std::string>>;
+ KeySet untracked_keys_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc
new file mode 100644
index 000000000..2ed2d5c59
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc
@@ -0,0 +1,473 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+#include "db/arena_wrapped_db_iter.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Instead of reconstructing a Transaction object, and calling rollback on it,
+// we can be more efficient with RollbackRecoveredTransaction by skipping
+// unnecessary steps (eg. updating CommitMap, reconstructing keyset)
+Status WriteUnpreparedTxnDB::RollbackRecoveredTransaction(
+ const DBImpl::RecoveredTransaction* rtxn) {
+ // TODO(lth): Reduce duplicate code with WritePrepared rollback logic.
+ assert(rtxn->unprepared_);
+ auto cf_map_shared_ptr = WritePreparedTxnDB::GetCFHandleMap();
+ auto cf_comp_map_shared_ptr = WritePreparedTxnDB::GetCFComparatorMap();
+ // In theory we could write with disableWAL = true during recovery, and
+ // assume that if we crash again during recovery, we can just replay from
+ // the very beginning. Unfortunately, the XIDs from the application may not
+ // necessarily be unique across restarts, potentially leading to situations
+ // like this:
+ //
+ // BEGIN_PREPARE(unprepared) Put(a) END_PREPARE(xid = 1)
+ // -- crash and recover with Put(a) rolled back as it was not prepared
+ // BEGIN_PREPARE(prepared) Put(b) END_PREPARE(xid = 1)
+ // COMMIT(xid = 1)
+ // -- crash and recover with both a, b
+ //
+ // We could just write the rollback marker, but then we would have to extend
+ // MemTableInserter during recovery to actually do writes into the DB
+ // instead of just dropping the in-memory write batch.
+ //
+ WriteOptions w_options;
+
+ class InvalidSnapshotReadCallback : public ReadCallback {
+ public:
+ InvalidSnapshotReadCallback(SequenceNumber snapshot)
+ : ReadCallback(snapshot) {}
+
+ inline bool IsVisibleFullCheck(SequenceNumber) override {
+ // The seq provided as snapshot is the seq right before we have locked and
+ // wrote to it, so whatever is there, it is committed.
+ return true;
+ }
+
+ // Ignore the refresh request since we are confident that our snapshot seq
+ // is not going to be affected by concurrent compactions (not enabled yet.)
+ void Refresh(SequenceNumber) override {}
+ };
+
+ // Iterate starting with largest sequence number.
+ for (auto it = rtxn->batches_.rbegin(); it != rtxn->batches_.rend(); ++it) {
+ auto last_visible_txn = it->first - 1;
+ const auto& batch = it->second.batch_;
+ WriteBatch rollback_batch(0 /* reserved_bytes */, 0 /* max_bytes */,
+ w_options.protection_bytes_per_key,
+ 0 /* default_cf_ts_sz */);
+
+ struct RollbackWriteBatchBuilder : public WriteBatch::Handler {
+ DBImpl* db_;
+ ReadOptions roptions;
+ InvalidSnapshotReadCallback callback;
+ WriteBatch* rollback_batch_;
+ std::map<uint32_t, const Comparator*>& comparators_;
+ std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ bool rollback_merge_operands_;
+ RollbackWriteBatchBuilder(
+ DBImpl* db, SequenceNumber snap_seq, WriteBatch* dst_batch,
+ std::map<uint32_t, const Comparator*>& comparators,
+ std::map<uint32_t, ColumnFamilyHandle*>& handles,
+ bool rollback_merge_operands)
+ : db_(db),
+ callback(snap_seq),
+ // disable min_uncommitted optimization
+ rollback_batch_(dst_batch),
+ comparators_(comparators),
+ handles_(handles),
+ rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status Rollback(uint32_t cf, const Slice& key) {
+ Status s;
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+ }
+ auto res = cf_keys.insert(key);
+ if (res.second ==
+ false) { // second is false if a element already existed.
+ return s;
+ }
+
+ PinnableSlice pinnable_val;
+ bool not_used;
+ auto cf_handle = handles_[cf];
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = &callback;
+ s = db_->GetImpl(roptions, key, get_impl_options);
+ assert(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ s = rollback_batch_->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ // There has been no readable value before txn. By adding a delete we
+ // make sure that there will be none afterwards either.
+ s = rollback_batch_->Delete(cf_handle, key);
+ assert(s.ok());
+ } else {
+ // Unexpected status. Return it to the user.
+ }
+ return s;
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ return Rollback(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ if (rollback_merge_operands_) {
+ return Rollback(cf, key);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ // Recovered batches do not contain 2PC markers.
+ Status MarkNoop(bool) override { return Status::InvalidArgument(); }
+ Status MarkBeginPrepare(bool) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ } rollback_handler(db_impl_, last_visible_txn, &rollback_batch,
+ *cf_comp_map_shared_ptr.get(), *cf_map_shared_ptr.get(),
+ txn_db_options_.rollback_merge_operands);
+
+ auto s = batch->Iterate(&rollback_handler);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // The Rollback marker will be used as a batch separator
+ s = WriteBatchInternal::MarkRollback(&rollback_batch, rtxn->name_);
+ if (!s.ok()) {
+ return s;
+ }
+
+ const uint64_t kNoLogRef = 0;
+ const bool kDisableMemtable = true;
+ const size_t kOneBatch = 1;
+ uint64_t seq_used = kMaxSequenceNumber;
+ s = db_impl_->WriteImpl(w_options, &rollback_batch, nullptr, nullptr,
+ kNoLogRef, !kDisableMemtable, &seq_used, kOneBatch);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // If two_write_queues, we must manually release the sequence number to
+ // readers.
+ if (db_impl_->immutable_db_options().two_write_queues) {
+ db_impl_->SetLastPublishedSequence(seq_used);
+ }
+ }
+
+ return Status::OK();
+}
+
+Status WriteUnpreparedTxnDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ // TODO(lth): Reduce code duplication in this function.
+ auto dbimpl = static_cast_with_check<DBImpl>(GetRootDB());
+ assert(dbimpl != nullptr);
+
+ db_impl_->SetSnapshotChecker(new WritePreparedSnapshotChecker(this));
+ // A callback to commit a single sub-batch
+ class CommitSubBatchPreReleaseCallback : public PreReleaseCallback {
+ public:
+ explicit CommitSubBatchPreReleaseCallback(WritePreparedTxnDB* db)
+ : db_(db) {}
+ Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)), uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ assert(!is_mem_disabled);
+ db_->AddCommitted(commit_seq, commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ };
+ db_impl_->SetRecoverableStatePreReleaseCallback(
+ new CommitSubBatchPreReleaseCallback(this));
+
+ // PessimisticTransactionDB::Initialize
+ for (auto cf_ptr : handles) {
+ AddColumnFamily(cf_ptr);
+ }
+ // Verify cf options
+ for (auto handle : handles) {
+ ColumnFamilyDescriptor cfd;
+ Status s = handle->GetDescriptor(&cfd);
+ if (!s.ok()) {
+ return s;
+ }
+ s = VerifyCFOptions(cfd.options);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ // Re-enable compaction for the column families that initially had
+ // compaction enabled.
+ std::vector<ColumnFamilyHandle*> compaction_enabled_cf_handles;
+ compaction_enabled_cf_handles.reserve(compaction_enabled_cf_indices.size());
+ for (auto index : compaction_enabled_cf_indices) {
+ compaction_enabled_cf_handles.push_back(handles[index]);
+ }
+
+ // create 'real' transactions from recovered shell transactions
+ auto rtxns = dbimpl->recovered_transactions();
+ std::map<SequenceNumber, SequenceNumber> ordered_seq_cnt;
+ for (auto rtxn : rtxns) {
+ auto recovered_trx = rtxn.second;
+ assert(recovered_trx);
+ assert(recovered_trx->batches_.size() >= 1);
+ assert(recovered_trx->name_.length());
+
+ // We can only rollback transactions after AdvanceMaxEvictedSeq is called,
+ // but AddPrepared must occur before AdvanceMaxEvictedSeq, which is why
+ // two iterations is required.
+ if (recovered_trx->unprepared_) {
+ continue;
+ }
+
+ WriteOptions w_options;
+ w_options.sync = true;
+ TransactionOptions t_options;
+
+ auto first_log_number = recovered_trx->batches_.begin()->second.log_number_;
+ auto first_seq = recovered_trx->batches_.begin()->first;
+ auto last_prepare_batch_cnt =
+ recovered_trx->batches_.begin()->second.batch_cnt_;
+
+ Transaction* real_trx = BeginTransaction(w_options, t_options, nullptr);
+ assert(real_trx);
+ auto wupt = static_cast_with_check<WriteUnpreparedTxn>(real_trx);
+ wupt->recovered_txn_ = true;
+
+ real_trx->SetLogNumber(first_log_number);
+ real_trx->SetId(first_seq);
+ Status s = real_trx->SetName(recovered_trx->name_);
+ if (!s.ok()) {
+ return s;
+ }
+ wupt->prepare_batch_cnt_ = last_prepare_batch_cnt;
+
+ for (auto batch : recovered_trx->batches_) {
+ const auto& seq = batch.first;
+ const auto& batch_info = batch.second;
+ auto cnt = batch_info.batch_cnt_ ? batch_info.batch_cnt_ : 1;
+ assert(batch_info.log_number_);
+
+ ordered_seq_cnt[seq] = cnt;
+ assert(wupt->unprep_seqs_.count(seq) == 0);
+ wupt->unprep_seqs_[seq] = cnt;
+
+ s = wupt->RebuildFromWriteBatch(batch_info.batch_);
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ const bool kClear = true;
+ wupt->InitWriteBatch(kClear);
+
+ real_trx->SetState(Transaction::PREPARED);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ // AddPrepared must be called in order
+ for (auto seq_cnt : ordered_seq_cnt) {
+ auto seq = seq_cnt.first;
+ auto cnt = seq_cnt.second;
+ for (size_t i = 0; i < cnt; i++) {
+ AddPrepared(seq + i);
+ }
+ }
+
+ SequenceNumber prev_max = max_evicted_seq_;
+ SequenceNumber last_seq = db_impl_->GetLatestSequenceNumber();
+ AdvanceMaxEvictedSeq(prev_max, last_seq);
+ // Create a gap between max and the next snapshot. This simplifies the logic
+ // in IsInSnapshot by not having to consider the special case of max ==
+ // snapshot after recovery. This is tested in IsInSnapshotEmptyMapTest.
+ if (last_seq) {
+ db_impl_->versions_->SetLastAllocatedSequence(last_seq + 1);
+ db_impl_->versions_->SetLastSequence(last_seq + 1);
+ db_impl_->versions_->SetLastPublishedSequence(last_seq + 1);
+ }
+
+ Status s;
+ // Rollback unprepared transactions.
+ for (auto rtxn : rtxns) {
+ auto recovered_trx = rtxn.second;
+ if (recovered_trx->unprepared_) {
+ s = RollbackRecoveredTransaction(recovered_trx);
+ if (!s.ok()) {
+ return s;
+ }
+ continue;
+ }
+ }
+
+ if (s.ok()) {
+ dbimpl->DeleteAllRecoveredTransactions();
+
+ // Compaction should start only after max_evicted_seq_ is set AND recovered
+ // transactions are either added to PrepareHeap or rolled back.
+ s = EnableAutoCompaction(compaction_enabled_cf_handles);
+ }
+
+ return s;
+}
+
+Transaction* WriteUnpreparedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WriteUnpreparedTxn(this, write_options, txn_options);
+ }
+}
+
+// Struct to hold ownership of snapshot and read callback for iterator cleanup.
+struct WriteUnpreparedTxnDB::IteratorState {
+ IteratorState(WritePreparedTxnDB* txn_db, SequenceNumber sequence,
+ std::shared_ptr<ManagedSnapshot> s,
+ SequenceNumber min_uncommitted, WriteUnpreparedTxn* txn)
+ : callback(txn_db, sequence, min_uncommitted, txn->unprep_seqs_,
+ kBackedByDBSnapshot),
+ snapshot(s) {}
+ SequenceNumber MaxVisibleSeq() { return callback.max_visible_seq(); }
+
+ WriteUnpreparedTxnReadCallback callback;
+ std::shared_ptr<ManagedSnapshot> snapshot;
+};
+
+namespace {
+static void CleanupWriteUnpreparedTxnDBIterator(void* arg1, void* /*arg2*/) {
+ delete reinterpret_cast<WriteUnpreparedTxnDB::IteratorState*>(arg1);
+}
+} // anonymous namespace
+
+Iterator* WriteUnpreparedTxnDB::NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ WriteUnpreparedTxn* txn) {
+ // TODO(lth): Refactor so that this logic is shared with WritePrepared.
+ constexpr bool expose_blob_index = false;
+ constexpr bool allow_refresh = false;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+
+ // Currently, the Prev() iterator logic does not work well without snapshot
+ // validation. The logic simply iterates through values of a key in
+ // ascending seqno order, stopping at the first non-visible value and
+ // returning the last visible value.
+ //
+ // For example, if snapshot sequence is 3, and we have the following keys:
+ // foo: v1 1
+ // foo: v2 2
+ // foo: v3 3
+ // foo: v4 4
+ // foo: v5 5
+ //
+ // Then 1, 2, 3 will be visible, but 4 will be non-visible, so we return v3,
+ // which is the last visible value.
+ //
+ // For unprepared transactions, if we have snap_seq = 3, but the current
+ // transaction has unprep_seq 5, then returning the first non-visible value
+ // would be incorrect, as we should return v5, and not v3. The problem is that
+ // there are committed values at snapshot_seq < commit_seq < unprep_seq.
+ //
+ // Snapshot validation can prevent this problem by ensuring that no committed
+ // values exist at snapshot_seq < commit_seq, and thus any value with a
+ // sequence number greater than snapshot_seq must be unprepared values. For
+ // example, if the transaction had a snapshot at 3, then snapshot validation
+ // would be performed during the Put(v5) call. It would find v4, and the Put
+ // would fail with snapshot validation failure.
+ //
+ // TODO(lth): Improve Prev() logic to continue iterating until
+ // max_visible_seq, and then return the last visible value, so that this
+ // restriction can be lifted.
+ const Snapshot* snapshot = nullptr;
+ if (options.snapshot == nullptr) {
+ snapshot = GetSnapshot();
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ } else {
+ snapshot = options.snapshot;
+ }
+
+ snapshot_seq = snapshot->GetSequenceNumber();
+ assert(snapshot_seq != kMaxSequenceNumber);
+ // Iteration is safe as long as largest_validated_seq <= snapshot_seq. We are
+ // guaranteed that for keys that were modified by this transaction (and thus
+ // might have unprepared values), no committed values exist at
+ // largest_validated_seq < commit_seq (or the contrapositive: any committed
+ // value must exist at commit_seq <= largest_validated_seq). This implies
+ // that commit_seq <= largest_validated_seq <= snapshot_seq or commit_seq <=
+ // snapshot_seq. As explained above, the problem with Prev() only happens when
+ // snapshot_seq < commit_seq.
+ //
+ // For keys that were not modified by this transaction, largest_validated_seq_
+ // is meaningless, and Prev() should just work with the existing visibility
+ // logic.
+ if (txn->largest_validated_seq_ > snapshot->GetSequenceNumber() &&
+ !txn->unprep_seqs_.empty()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "WriteUnprepared iterator creation failed since the "
+ "transaction has performed unvalidated writes");
+ return nullptr;
+ }
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl>(snapshot)->min_uncommitted_;
+
+ auto* cfd =
+ static_cast_with_check<ColumnFamilyHandleImpl>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted, txn);
+ auto* db_iter = db_impl_->NewIteratorImpl(
+ options, cfd, state->MaxVisibleSeq(), &state->callback, expose_blob_index,
+ allow_refresh);
+ db_iter->RegisterCleanup(CleanupWriteUnpreparedTxnDBIterator, state, nullptr);
+ return db_iter;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h
new file mode 100644
index 000000000..c40e96d49
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h
@@ -0,0 +1,108 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn_db.h"
+#include "utilities/transactions/write_unprepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTxn;
+
+class WriteUnpreparedTxnDB : public WritePreparedTxnDB {
+ public:
+ using WritePreparedTxnDB::WritePreparedTxnDB;
+
+ Status Initialize(const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ // Struct to hold ownership of snapshot and read callback for cleanup.
+ struct IteratorState;
+
+ using WritePreparedTxnDB::NewIterator;
+ Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ WriteUnpreparedTxn* txn);
+
+ private:
+ Status RollbackRecoveredTransaction(const DBImpl::RecoveredTransaction* rtxn);
+};
+
+class WriteUnpreparedCommitEntryPreReleaseCallback : public PreReleaseCallback {
+ // TODO(lth): Reduce code duplication with
+ // WritePreparedCommitEntryPreReleaseCallback
+ public:
+ // includes_data indicates that the commit also writes non-empty
+ // CommitTimeWriteBatch to memtable, which needs to be committed separately.
+ WriteUnpreparedCommitEntryPreReleaseCallback(
+ WritePreparedTxnDB* db, DBImpl* db_impl,
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ size_t data_batch_cnt = 0, bool publish_seq = true)
+ : db_(db),
+ db_impl_(db_impl),
+ unprep_seqs_(unprep_seqs),
+ data_batch_cnt_(data_batch_cnt),
+ includes_data_(data_batch_cnt_ > 0),
+ publish_seq_(publish_seq) {
+ assert(unprep_seqs.size() > 0);
+ }
+
+ virtual Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t, size_t /*index*/,
+ size_t /*total*/) override {
+ const uint64_t last_commit_seq = LIKELY(data_batch_cnt_ <= 1)
+ ? commit_seq
+ : commit_seq + data_batch_cnt_ - 1;
+ // Recall that unprep_seqs maps (un)prepared_seq => prepare_batch_cnt.
+ for (const auto& s : unprep_seqs_) {
+ for (size_t i = 0; i < s.second; i++) {
+ db_->AddCommitted(s.first + i, last_commit_seq);
+ }
+ }
+
+ if (includes_data_) {
+ assert(data_batch_cnt_);
+ // Commit the data that is accompanied with the commit request
+ for (size_t i = 0; i < data_batch_cnt_; i++) {
+ // For commit seq of each batch use the commit seq of the last batch.
+ // This would make debugging easier by having all the batches having
+ // the same sequence number.
+ db_->AddCommitted(commit_seq + i, last_commit_seq);
+ }
+ }
+ if (db_impl_->immutable_db_options().two_write_queues && publish_seq_) {
+ assert(is_mem_disabled); // implies the 2nd queue
+ // Publish the sequence number. We can do that here assuming the callback
+ // is invoked only from one write queue, which would guarantee that the
+ // publish sequence numbers will be in order, i.e., once a seq is
+ // published all the seq prior to that are also publishable.
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ }
+ // else SequenceNumber that is updated as part of the write already does the
+ // publishing
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ const std::map<SequenceNumber, size_t>& unprep_seqs_;
+ size_t data_batch_cnt_;
+ // Either because it is commit without prepare or it has a
+ // CommitTimeWriteBatch
+ bool includes_data_;
+ // Should the callback also publishes the commit seq number
+ bool publish_seq_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/ttl/db_ttl_impl.cc b/src/rocksdb/utilities/ttl/db_ttl_impl.cc
new file mode 100644
index 000000000..6ec9d87b0
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/db_ttl_impl.cc
@@ -0,0 +1,609 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#ifndef ROCKSDB_LITE
+
+#include "utilities/ttl/db_ttl_impl.h"
+
+#include "db/write_batch_internal.h"
+#include "file/filename.h"
+#include "logging/logging.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/env.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/system_clock.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+static std::unordered_map<std::string, OptionTypeInfo> ttl_merge_op_type_info =
+ {{"user_operator",
+ OptionTypeInfo::AsCustomSharedPtr<MergeOperator>(
+ 0, OptionVerificationType::kByName, OptionTypeFlags::kNone)}};
+
+TtlMergeOperator::TtlMergeOperator(
+ const std::shared_ptr<MergeOperator>& merge_op, SystemClock* clock)
+ : user_merge_op_(merge_op), clock_(clock) {
+ RegisterOptions("TtlMergeOptions", &user_merge_op_, &ttl_merge_op_type_info);
+}
+
+bool TtlMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ const uint32_t ts_len = DBWithTTLImpl::kTSLength;
+ if (merge_in.existing_value && merge_in.existing_value->size() < ts_len) {
+ ROCKS_LOG_ERROR(merge_in.logger,
+ "Error: Could not remove timestamp from existing value.");
+ return false;
+ }
+
+ // Extract time-stamp from each operand to be passed to user_merge_op_
+ std::vector<Slice> operands_without_ts;
+ for (const auto& operand : merge_in.operand_list) {
+ if (operand.size() < ts_len) {
+ ROCKS_LOG_ERROR(merge_in.logger,
+ "Error: Could not remove timestamp from operand value.");
+ return false;
+ }
+ operands_without_ts.push_back(operand);
+ operands_without_ts.back().remove_suffix(ts_len);
+ }
+
+ // Apply the user merge operator (store result in *new_value)
+ bool good = true;
+ MergeOperationOutput user_merge_out(merge_out->new_value,
+ merge_out->existing_operand);
+ if (merge_in.existing_value) {
+ Slice existing_value_without_ts(merge_in.existing_value->data(),
+ merge_in.existing_value->size() - ts_len);
+ good = user_merge_op_->FullMergeV2(
+ MergeOperationInput(merge_in.key, &existing_value_without_ts,
+ operands_without_ts, merge_in.logger),
+ &user_merge_out);
+ } else {
+ good = user_merge_op_->FullMergeV2(
+ MergeOperationInput(merge_in.key, nullptr, operands_without_ts,
+ merge_in.logger),
+ &user_merge_out);
+ }
+
+ // Return false if the user merge operator returned false
+ if (!good) {
+ return false;
+ }
+
+ if (merge_out->existing_operand.data()) {
+ merge_out->new_value.assign(merge_out->existing_operand.data(),
+ merge_out->existing_operand.size());
+ merge_out->existing_operand = Slice(nullptr, 0);
+ }
+
+ // Augment the *new_value with the ttl time-stamp
+ int64_t curtime;
+ if (!clock_->GetCurrentTime(&curtime).ok()) {
+ ROCKS_LOG_ERROR(
+ merge_in.logger,
+ "Error: Could not get current time to be attached internally "
+ "to the new value.");
+ return false;
+ } else {
+ char ts_string[ts_len];
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ merge_out->new_value.append(ts_string, ts_len);
+ return true;
+ }
+}
+
+bool TtlMergeOperator::PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* logger) const {
+ const uint32_t ts_len = DBWithTTLImpl::kTSLength;
+ std::deque<Slice> operands_without_ts;
+
+ for (const auto& operand : operand_list) {
+ if (operand.size() < ts_len) {
+ ROCKS_LOG_ERROR(logger, "Error: Could not remove timestamp from value.");
+ return false;
+ }
+
+ operands_without_ts.push_back(
+ Slice(operand.data(), operand.size() - ts_len));
+ }
+
+ // Apply the user partial-merge operator (store result in *new_value)
+ assert(new_value);
+ if (!user_merge_op_->PartialMergeMulti(key, operands_without_ts, new_value,
+ logger)) {
+ return false;
+ }
+
+ // Augment the *new_value with the ttl time-stamp
+ int64_t curtime;
+ if (!clock_->GetCurrentTime(&curtime).ok()) {
+ ROCKS_LOG_ERROR(
+ logger,
+ "Error: Could not get current time to be attached internally "
+ "to the new value.");
+ return false;
+ } else {
+ char ts_string[ts_len];
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ new_value->append(ts_string, ts_len);
+ return true;
+ }
+}
+
+Status TtlMergeOperator::PrepareOptions(const ConfigOptions& config_options) {
+ if (clock_ == nullptr) {
+ clock_ = config_options.env->GetSystemClock().get();
+ }
+ return MergeOperator::PrepareOptions(config_options);
+}
+
+Status TtlMergeOperator::ValidateOptions(
+ const DBOptions& db_opts, const ColumnFamilyOptions& cf_opts) const {
+ if (user_merge_op_ == nullptr) {
+ return Status::InvalidArgument(
+ "UserMergeOperator required by TtlMergeOperator");
+ } else if (clock_ == nullptr) {
+ return Status::InvalidArgument("SystemClock required by TtlMergeOperator");
+ } else {
+ return MergeOperator::ValidateOptions(db_opts, cf_opts);
+ }
+}
+
+void DBWithTTLImpl::SanitizeOptions(int32_t ttl, ColumnFamilyOptions* options,
+ SystemClock* clock) {
+ if (options->compaction_filter) {
+ options->compaction_filter =
+ new TtlCompactionFilter(ttl, clock, options->compaction_filter);
+ } else {
+ options->compaction_filter_factory =
+ std::shared_ptr<CompactionFilterFactory>(new TtlCompactionFilterFactory(
+ ttl, clock, options->compaction_filter_factory));
+ }
+
+ if (options->merge_operator) {
+ options->merge_operator.reset(
+ new TtlMergeOperator(options->merge_operator, clock));
+ }
+}
+
+static std::unordered_map<std::string, OptionTypeInfo> ttl_type_info = {
+ {"ttl", {0, OptionType::kInt32T}},
+};
+
+static std::unordered_map<std::string, OptionTypeInfo> ttl_cff_type_info = {
+ {"user_filter_factory",
+ OptionTypeInfo::AsCustomSharedPtr<CompactionFilterFactory>(
+ 0, OptionVerificationType::kByNameAllowFromNull,
+ OptionTypeFlags::kNone)}};
+static std::unordered_map<std::string, OptionTypeInfo> user_cf_type_info = {
+ {"user_filter",
+ OptionTypeInfo::AsCustomRawPtr<const CompactionFilter>(
+ 0, OptionVerificationType::kByName, OptionTypeFlags::kAllowNull)}};
+
+TtlCompactionFilter::TtlCompactionFilter(
+ int32_t ttl, SystemClock* clock, const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter> _user_comp_filter_from_factory)
+ : LayeredCompactionFilterBase(_user_comp_filter,
+ std::move(_user_comp_filter_from_factory)),
+ ttl_(ttl),
+ clock_(clock) {
+ RegisterOptions("TTL", &ttl_, &ttl_type_info);
+ RegisterOptions("UserFilter", &user_comp_filter_, &user_cf_type_info);
+}
+
+bool TtlCompactionFilter::Filter(int level, const Slice& key,
+ const Slice& old_val, std::string* new_val,
+ bool* value_changed) const {
+ if (DBWithTTLImpl::IsStale(old_val, ttl_, clock_)) {
+ return true;
+ }
+ if (user_comp_filter() == nullptr) {
+ return false;
+ }
+ assert(old_val.size() >= DBWithTTLImpl::kTSLength);
+ Slice old_val_without_ts(old_val.data(),
+ old_val.size() - DBWithTTLImpl::kTSLength);
+ if (user_comp_filter()->Filter(level, key, old_val_without_ts, new_val,
+ value_changed)) {
+ return true;
+ }
+ if (*value_changed) {
+ new_val->append(old_val.data() + old_val.size() - DBWithTTLImpl::kTSLength,
+ DBWithTTLImpl::kTSLength);
+ }
+ return false;
+}
+
+Status TtlCompactionFilter::PrepareOptions(
+ const ConfigOptions& config_options) {
+ if (clock_ == nullptr) {
+ clock_ = config_options.env->GetSystemClock().get();
+ }
+ return LayeredCompactionFilterBase::PrepareOptions(config_options);
+}
+
+Status TtlCompactionFilter::ValidateOptions(
+ const DBOptions& db_opts, const ColumnFamilyOptions& cf_opts) const {
+ if (clock_ == nullptr) {
+ return Status::InvalidArgument(
+ "SystemClock required by TtlCompactionFilter");
+ } else {
+ return LayeredCompactionFilterBase::ValidateOptions(db_opts, cf_opts);
+ }
+}
+
+TtlCompactionFilterFactory::TtlCompactionFilterFactory(
+ int32_t ttl, SystemClock* clock,
+ std::shared_ptr<CompactionFilterFactory> comp_filter_factory)
+ : ttl_(ttl), clock_(clock), user_comp_filter_factory_(comp_filter_factory) {
+ RegisterOptions("UserOptions", &user_comp_filter_factory_,
+ &ttl_cff_type_info);
+ RegisterOptions("TTL", &ttl_, &ttl_type_info);
+}
+
+std::unique_ptr<CompactionFilter>
+TtlCompactionFilterFactory::CreateCompactionFilter(
+ const CompactionFilter::Context& context) {
+ std::unique_ptr<const CompactionFilter> user_comp_filter_from_factory =
+ nullptr;
+ if (user_comp_filter_factory_) {
+ user_comp_filter_from_factory =
+ user_comp_filter_factory_->CreateCompactionFilter(context);
+ }
+
+ return std::unique_ptr<TtlCompactionFilter>(new TtlCompactionFilter(
+ ttl_, clock_, nullptr, std::move(user_comp_filter_from_factory)));
+}
+
+Status TtlCompactionFilterFactory::PrepareOptions(
+ const ConfigOptions& config_options) {
+ if (clock_ == nullptr) {
+ clock_ = config_options.env->GetSystemClock().get();
+ }
+ return CompactionFilterFactory::PrepareOptions(config_options);
+}
+
+Status TtlCompactionFilterFactory::ValidateOptions(
+ const DBOptions& db_opts, const ColumnFamilyOptions& cf_opts) const {
+ if (clock_ == nullptr) {
+ return Status::InvalidArgument(
+ "SystemClock required by TtlCompactionFilterFactory");
+ } else {
+ return CompactionFilterFactory::ValidateOptions(db_opts, cf_opts);
+ }
+}
+
+int RegisterTtlObjects(ObjectLibrary& library, const std::string& /*arg*/) {
+ library.AddFactory<MergeOperator>(
+ TtlMergeOperator::kClassName(),
+ [](const std::string& /*uri*/, std::unique_ptr<MergeOperator>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new TtlMergeOperator(nullptr, nullptr));
+ return guard->get();
+ });
+ library.AddFactory<CompactionFilterFactory>(
+ TtlCompactionFilterFactory::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilterFactory>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new TtlCompactionFilterFactory(0, nullptr, nullptr));
+ return guard->get();
+ });
+ library.AddFactory<CompactionFilter>(
+ TtlCompactionFilter::kClassName(),
+ [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilter>* /*guard*/,
+ std::string* /* errmsg */) {
+ return new TtlCompactionFilter(0, nullptr, nullptr);
+ });
+ size_t num_types;
+ return static_cast<int>(library.GetFactoryCount(&num_types));
+}
+// Open the db inside DBWithTTLImpl because options needs pointer to its ttl
+DBWithTTLImpl::DBWithTTLImpl(DB* db) : DBWithTTL(db), closed_(false) {}
+
+DBWithTTLImpl::~DBWithTTLImpl() {
+ if (!closed_) {
+ Close().PermitUncheckedError();
+ }
+}
+
+Status DBWithTTLImpl::Close() {
+ Status ret = Status::OK();
+ if (!closed_) {
+ Options default_options = GetOptions();
+ // Need to stop background compaction before getting rid of the filter
+ CancelAllBackgroundWork(db_, /* wait = */ true);
+ ret = db_->Close();
+ delete default_options.compaction_filter;
+ closed_ = true;
+ }
+ return ret;
+}
+
+void DBWithTTLImpl::RegisterTtlClasses() {
+ static std::once_flag once;
+ std::call_once(once, [&]() {
+ ObjectRegistry::Default()->AddLibrary("TTL", RegisterTtlObjects, "");
+ });
+}
+
+Status DBWithTTL::Open(const Options& options, const std::string& dbname,
+ DBWithTTL** dbptr, int32_t ttl, bool read_only) {
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = DBWithTTL::Open(db_options, dbname, column_families, &handles,
+ dbptr, {ttl}, read_only);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // i can delete the handle since DBImpl is always holding a reference to
+ // default column family
+ delete handles[0];
+ }
+ return s;
+}
+
+Status DBWithTTL::Open(
+ const DBOptions& db_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles, DBWithTTL** dbptr,
+ const std::vector<int32_t>& ttls, bool read_only) {
+ DBWithTTLImpl::RegisterTtlClasses();
+ if (ttls.size() != column_families.size()) {
+ return Status::InvalidArgument(
+ "ttls size has to be the same as number of column families");
+ }
+
+ SystemClock* clock = (db_options.env == nullptr)
+ ? SystemClock::Default().get()
+ : db_options.env->GetSystemClock().get();
+
+ std::vector<ColumnFamilyDescriptor> column_families_sanitized =
+ column_families;
+ for (size_t i = 0; i < column_families_sanitized.size(); ++i) {
+ DBWithTTLImpl::SanitizeOptions(
+ ttls[i], &column_families_sanitized[i].options, clock);
+ }
+ DB* db;
+
+ Status st;
+ if (read_only) {
+ st = DB::OpenForReadOnly(db_options, dbname, column_families_sanitized,
+ handles, &db);
+ } else {
+ st = DB::Open(db_options, dbname, column_families_sanitized, handles, &db);
+ }
+ if (st.ok()) {
+ *dbptr = new DBWithTTLImpl(db);
+ } else {
+ *dbptr = nullptr;
+ }
+ return st;
+}
+
+Status DBWithTTLImpl::CreateColumnFamilyWithTtl(
+ const ColumnFamilyOptions& options, const std::string& column_family_name,
+ ColumnFamilyHandle** handle, int ttl) {
+ RegisterTtlClasses();
+ ColumnFamilyOptions sanitized_options = options;
+ DBWithTTLImpl::SanitizeOptions(ttl, &sanitized_options,
+ GetEnv()->GetSystemClock().get());
+
+ return DBWithTTL::CreateColumnFamily(sanitized_options, column_family_name,
+ handle);
+}
+
+Status DBWithTTLImpl::CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) {
+ return CreateColumnFamilyWithTtl(options, column_family_name, handle, 0);
+}
+
+// Appends the current timestamp to the string.
+// Returns false if could not get the current_time, true if append succeeds
+Status DBWithTTLImpl::AppendTS(const Slice& val, std::string* val_with_ts,
+ SystemClock* clock) {
+ val_with_ts->reserve(kTSLength + val.size());
+ char ts_string[kTSLength];
+ int64_t curtime;
+ Status st = clock->GetCurrentTime(&curtime);
+ if (!st.ok()) {
+ return st;
+ }
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ val_with_ts->append(val.data(), val.size());
+ val_with_ts->append(ts_string, kTSLength);
+ return st;
+}
+
+// Returns corruption if the length of the string is lesser than timestamp, or
+// timestamp refers to a time lesser than ttl-feature release time
+Status DBWithTTLImpl::SanityCheckTimestamp(const Slice& str) {
+ if (str.size() < kTSLength) {
+ return Status::Corruption("Error: value's length less than timestamp's\n");
+ }
+ // Checks that TS is not lesser than kMinTimestamp
+ // Gaurds against corruption & normal database opened incorrectly in ttl mode
+ int32_t timestamp_value = DecodeFixed32(str.data() + str.size() - kTSLength);
+ if (timestamp_value < kMinTimestamp) {
+ return Status::Corruption("Error: Timestamp < ttl feature release time!\n");
+ }
+ return Status::OK();
+}
+
+// Checks if the string is stale or not according to TTl provided
+bool DBWithTTLImpl::IsStale(const Slice& value, int32_t ttl,
+ SystemClock* clock) {
+ if (ttl <= 0) { // Data is fresh if TTL is non-positive
+ return false;
+ }
+ int64_t curtime;
+ if (!clock->GetCurrentTime(&curtime).ok()) {
+ return false; // Treat the data as fresh if could not get current time
+ }
+ int32_t timestamp_value =
+ DecodeFixed32(value.data() + value.size() - kTSLength);
+ return (timestamp_value + ttl) < curtime;
+}
+
+// Strips the TS from the end of the slice
+Status DBWithTTLImpl::StripTS(PinnableSlice* pinnable_val) {
+ if (pinnable_val->size() < kTSLength) {
+ return Status::Corruption("Bad timestamp in key-value");
+ }
+ // Erasing characters which hold the TS
+ pinnable_val->remove_suffix(kTSLength);
+ return Status::OK();
+}
+
+// Strips the TS from the end of the string
+Status DBWithTTLImpl::StripTS(std::string* str) {
+ if (str->length() < kTSLength) {
+ return Status::Corruption("Bad timestamp in key-value");
+ }
+ // Erasing characters which hold the TS
+ str->erase(str->length() - kTSLength, kTSLength);
+ return Status::OK();
+}
+
+Status DBWithTTLImpl::Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) {
+ WriteBatch batch;
+ Status st = batch.Put(column_family, key, val);
+ if (st.ok()) {
+ st = Write(options, &batch);
+ }
+ return st;
+}
+
+Status DBWithTTLImpl::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) {
+ Status st = db_->Get(options, column_family, key, value);
+ if (!st.ok()) {
+ return st;
+ }
+ st = SanityCheckTimestamp(*value);
+ if (!st.ok()) {
+ return st;
+ }
+ return StripTS(value);
+}
+
+std::vector<Status> DBWithTTLImpl::MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ auto statuses = db_->MultiGet(options, column_family, keys, values);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (!statuses[i].ok()) {
+ continue;
+ }
+ statuses[i] = SanityCheckTimestamp((*values)[i]);
+ if (!statuses[i].ok()) {
+ continue;
+ }
+ statuses[i] = StripTS(&(*values)[i]);
+ }
+ return statuses;
+}
+
+bool DBWithTTLImpl::KeyMayExist(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value,
+ bool* value_found) {
+ bool ret = db_->KeyMayExist(options, column_family, key, value, value_found);
+ if (ret && value != nullptr && value_found != nullptr && *value_found) {
+ if (!SanityCheckTimestamp(*value).ok() || !StripTS(value).ok()) {
+ return false;
+ }
+ }
+ return ret;
+}
+
+Status DBWithTTLImpl::Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) {
+ WriteBatch batch;
+ Status st = batch.Merge(column_family, key, value);
+ if (st.ok()) {
+ st = Write(options, &batch);
+ }
+ return st;
+}
+
+Status DBWithTTLImpl::Write(const WriteOptions& opts, WriteBatch* updates) {
+ class Handler : public WriteBatch::Handler {
+ public:
+ explicit Handler(SystemClock* clock) : clock_(clock) {}
+ WriteBatch updates_ttl;
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ std::string value_with_ts;
+ Status st = AppendTS(value, &value_with_ts, clock_);
+ if (!st.ok()) {
+ return st;
+ }
+ return WriteBatchInternal::Put(&updates_ttl, column_family_id, key,
+ value_with_ts);
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ std::string value_with_ts;
+ Status st = AppendTS(value, &value_with_ts, clock_);
+ if (!st.ok()) {
+ return st;
+ }
+ return WriteBatchInternal::Merge(&updates_ttl, column_family_id, key,
+ value_with_ts);
+ }
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ return WriteBatchInternal::Delete(&updates_ttl, column_family_id, key);
+ }
+ Status DeleteRangeCF(uint32_t column_family_id, const Slice& begin_key,
+ const Slice& end_key) override {
+ return WriteBatchInternal::DeleteRange(&updates_ttl, column_family_id,
+ begin_key, end_key);
+ }
+ void LogData(const Slice& blob) override { updates_ttl.PutLogData(blob); }
+
+ private:
+ SystemClock* clock_;
+ };
+ Handler handler(GetEnv()->GetSystemClock().get());
+ Status st = updates->Iterate(&handler);
+ if (!st.ok()) {
+ return st;
+ } else {
+ return db_->Write(opts, &(handler.updates_ttl));
+ }
+}
+
+Iterator* DBWithTTLImpl::NewIterator(const ReadOptions& opts,
+ ColumnFamilyHandle* column_family) {
+ return new TtlIterator(db_->NewIterator(opts, column_family));
+}
+
+void DBWithTTLImpl::SetTtl(ColumnFamilyHandle* h, int32_t ttl) {
+ std::shared_ptr<TtlCompactionFilterFactory> filter;
+ Options opts;
+ opts = GetOptions(h);
+ filter = std::static_pointer_cast<TtlCompactionFilterFactory>(
+ opts.compaction_filter_factory);
+ if (!filter) return;
+ filter->SetTtl(ttl);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/ttl/db_ttl_impl.h b/src/rocksdb/utilities/ttl/db_ttl_impl.h
new file mode 100644
index 000000000..dd67a6ddc
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/db_ttl_impl.h
@@ -0,0 +1,245 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/db.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/system_clock.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "utilities/compaction_filters/layered_compaction_filter_base.h"
+
+#ifdef _WIN32
+// Windows API macro interference
+#undef GetCurrentTime
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+struct ConfigOptions;
+class ObjectLibrary;
+class ObjectRegistry;
+class DBWithTTLImpl : public DBWithTTL {
+ public:
+ static void SanitizeOptions(int32_t ttl, ColumnFamilyOptions* options,
+ SystemClock* clock);
+
+ static void RegisterTtlClasses();
+ explicit DBWithTTLImpl(DB* db);
+
+ virtual ~DBWithTTLImpl();
+
+ virtual Status Close() override;
+
+ Status CreateColumnFamilyWithTtl(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle,
+ int ttl) override;
+
+ Status CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) override;
+
+ using StackableDB::Put;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) override;
+
+ using StackableDB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using StackableDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ using StackableDB::KeyMayExist;
+ virtual bool KeyMayExist(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ std::string* value,
+ bool* value_found = nullptr) override;
+
+ using StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ using StackableDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& opts,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual DB* GetBaseDB() override { return db_; }
+
+ static bool IsStale(const Slice& value, int32_t ttl, SystemClock* clock);
+
+ static Status AppendTS(const Slice& val, std::string* val_with_ts,
+ SystemClock* clock);
+
+ static Status SanityCheckTimestamp(const Slice& str);
+
+ static Status StripTS(std::string* str);
+
+ static Status StripTS(PinnableSlice* str);
+
+ static const uint32_t kTSLength = sizeof(int32_t); // size of timestamp
+
+ static const int32_t kMinTimestamp = 1368146402; // 05/09/2013:5:40PM GMT-8
+
+ static const int32_t kMaxTimestamp = 2147483647; // 01/18/2038:7:14PM GMT-8
+
+ void SetTtl(int32_t ttl) override { SetTtl(DefaultColumnFamily(), ttl); }
+
+ void SetTtl(ColumnFamilyHandle* h, int32_t ttl) override;
+
+ private:
+ // remember whether the Close completes or not
+ bool closed_;
+};
+
+class TtlIterator : public Iterator {
+ public:
+ explicit TtlIterator(Iterator* iter) : iter_(iter) { assert(iter_); }
+
+ ~TtlIterator() { delete iter_; }
+
+ bool Valid() const override { return iter_->Valid(); }
+
+ void SeekToFirst() override { iter_->SeekToFirst(); }
+
+ void SeekToLast() override { iter_->SeekToLast(); }
+
+ void Seek(const Slice& target) override { iter_->Seek(target); }
+
+ void SeekForPrev(const Slice& target) override { iter_->SeekForPrev(target); }
+
+ void Next() override { iter_->Next(); }
+
+ void Prev() override { iter_->Prev(); }
+
+ Slice key() const override { return iter_->key(); }
+
+ int32_t ttl_timestamp() const {
+ return DecodeFixed32(iter_->value().data() + iter_->value().size() -
+ DBWithTTLImpl::kTSLength);
+ }
+
+ Slice value() const override {
+ // TODO: handle timestamp corruption like in general iterator semantics
+ assert(DBWithTTLImpl::SanityCheckTimestamp(iter_->value()).ok());
+ Slice trimmed_value = iter_->value();
+ trimmed_value.size_ -= DBWithTTLImpl::kTSLength;
+ return trimmed_value;
+ }
+
+ Status status() const override { return iter_->status(); }
+
+ private:
+ Iterator* iter_;
+};
+
+class TtlCompactionFilter : public LayeredCompactionFilterBase {
+ public:
+ TtlCompactionFilter(int32_t ttl, SystemClock* clock,
+ const CompactionFilter* _user_comp_filter,
+ std::unique_ptr<const CompactionFilter>
+ _user_comp_filter_from_factory = nullptr);
+
+ virtual bool Filter(int level, const Slice& key, const Slice& old_val,
+ std::string* new_val, bool* value_changed) const override;
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "TtlCompactionFilter"; }
+ bool IsInstanceOf(const std::string& name) const override {
+ if (name == "Delete By TTL") {
+ return true;
+ } else {
+ return LayeredCompactionFilterBase::IsInstanceOf(name);
+ }
+ }
+
+ Status PrepareOptions(const ConfigOptions& config_options) override;
+ Status ValidateOptions(const DBOptions& db_opts,
+ const ColumnFamilyOptions& cf_opts) const override;
+
+ private:
+ int32_t ttl_;
+ SystemClock* clock_;
+};
+
+class TtlCompactionFilterFactory : public CompactionFilterFactory {
+ public:
+ TtlCompactionFilterFactory(
+ int32_t ttl, SystemClock* clock,
+ std::shared_ptr<CompactionFilterFactory> comp_filter_factory);
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& context) override;
+ void SetTtl(int32_t ttl) { ttl_ = ttl; }
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "TtlCompactionFilterFactory"; }
+ Status PrepareOptions(const ConfigOptions& config_options) override;
+ Status ValidateOptions(const DBOptions& db_opts,
+ const ColumnFamilyOptions& cf_opts) const override;
+ const Customizable* Inner() const override {
+ return user_comp_filter_factory_.get();
+ }
+
+ private:
+ int32_t ttl_;
+ SystemClock* clock_;
+ std::shared_ptr<CompactionFilterFactory> user_comp_filter_factory_;
+};
+
+class TtlMergeOperator : public MergeOperator {
+ public:
+ explicit TtlMergeOperator(const std::shared_ptr<MergeOperator>& merge_op,
+ SystemClock* clock);
+
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const override;
+
+ static const char* kClassName() { return "TtlMergeOperator"; }
+
+ const char* Name() const override { return kClassName(); }
+ bool IsInstanceOf(const std::string& name) const override {
+ if (name == "Merge By TTL") {
+ return true;
+ } else {
+ return MergeOperator::IsInstanceOf(name);
+ }
+ }
+
+ Status PrepareOptions(const ConfigOptions& config_options) override;
+ Status ValidateOptions(const DBOptions& db_opts,
+ const ColumnFamilyOptions& cf_opts) const override;
+ const Customizable* Inner() const override { return user_merge_op_.get(); }
+
+ private:
+ std::shared_ptr<MergeOperator> user_merge_op_;
+ SystemClock* clock_;
+};
+extern "C" {
+int RegisterTtlObjects(ObjectLibrary& library, const std::string& /*arg*/);
+} // extern "C"
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/ttl/ttl_test.cc b/src/rocksdb/utilities/ttl/ttl_test.cc
new file mode 100644
index 000000000..a42e0acb4
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/ttl_test.cc
@@ -0,0 +1,912 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include <map>
+#include <memory>
+
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "rocksdb/utilities/object_registry.h"
+#include "test_util/testharness.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators/bytesxor.h"
+#include "utilities/ttl/db_ttl_impl.h"
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+using KVMap = std::map<std::string, std::string>;
+
+enum BatchOperation { OP_PUT = 0, OP_DELETE = 1 };
+} // namespace
+
+class SpecialTimeEnv : public EnvWrapper {
+ public:
+ explicit SpecialTimeEnv(Env* base) : EnvWrapper(base) {
+ EXPECT_OK(base->GetCurrentTime(&current_time_));
+ }
+ const char* Name() const override { return "SpecialTimeEnv"; }
+ void Sleep(int64_t sleep_time) { current_time_ += sleep_time; }
+ Status GetCurrentTime(int64_t* current_time) override {
+ *current_time = current_time_;
+ return Status::OK();
+ }
+
+ private:
+ int64_t current_time_ = 0;
+};
+
+class TtlTest : public testing::Test {
+ public:
+ TtlTest() {
+ env_.reset(new SpecialTimeEnv(Env::Default()));
+ dbname_ = test::PerThreadDBPath("db_ttl");
+ options_.create_if_missing = true;
+ options_.env = env_.get();
+ // ensure that compaction is kicked in to always strip timestamp from kvs
+ options_.max_compaction_bytes = 1;
+ // compaction should take place always from level0 for determinism
+ db_ttl_ = nullptr;
+ EXPECT_OK(DestroyDB(dbname_, Options()));
+ }
+
+ ~TtlTest() override {
+ CloseTtl();
+ EXPECT_OK(DestroyDB(dbname_, Options()));
+ }
+
+ // Open database with TTL support when TTL not provided with db_ttl_ pointer
+ void OpenTtl() {
+ ASSERT_TRUE(db_ttl_ ==
+ nullptr); // db should be closed before opening again
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_));
+ }
+
+ // Open database with TTL support when TTL provided with db_ttl_ pointer
+ void OpenTtl(int32_t ttl) {
+ ASSERT_TRUE(db_ttl_ == nullptr);
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_, ttl));
+ }
+
+ // Open with TestFilter compaction filter
+ void OpenTtlWithTestCompaction(int32_t ttl) {
+ options_.compaction_filter_factory =
+ std::shared_ptr<CompactionFilterFactory>(
+ new TestFilterFactory(kSampleSize_, kNewValue_));
+ OpenTtl(ttl);
+ }
+
+ // Open database with TTL support in read_only mode
+ void OpenReadOnlyTtl(int32_t ttl) {
+ ASSERT_TRUE(db_ttl_ == nullptr);
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_, ttl, true));
+ }
+
+ // Call db_ttl_->Close() before delete db_ttl_
+ void CloseTtl() { CloseTtlHelper(true); }
+
+ // No db_ttl_->Close() before delete db_ttl_
+ void CloseTtlNoDBClose() { CloseTtlHelper(false); }
+
+ void CloseTtlHelper(bool close_db) {
+ if (db_ttl_ != nullptr) {
+ if (close_db) {
+ EXPECT_OK(db_ttl_->Close());
+ }
+ delete db_ttl_;
+ db_ttl_ = nullptr;
+ }
+ }
+
+ // Populates and returns a kv-map
+ void MakeKVMap(int64_t num_entries) {
+ kvmap_.clear();
+ int digits = 1;
+ for (int64_t dummy = num_entries; dummy /= 10; ++digits) {
+ }
+ int digits_in_i = 1;
+ for (int64_t i = 0; i < num_entries; i++) {
+ std::string key = "key";
+ std::string value = "value";
+ if (i % 10 == 0) {
+ digits_in_i++;
+ }
+ for (int j = digits_in_i; j < digits; j++) {
+ key.append("0");
+ value.append("0");
+ }
+ AppendNumberTo(&key, i);
+ AppendNumberTo(&value, i);
+ kvmap_[key] = value;
+ }
+ ASSERT_EQ(static_cast<int64_t>(kvmap_.size()),
+ num_entries); // check all insertions done
+ }
+
+ // Makes a write-batch with key-vals from kvmap_ and 'Write''s it
+ void MakePutWriteBatch(const BatchOperation* batch_ops, int64_t num_ops) {
+ ASSERT_LE(num_ops, static_cast<int64_t>(kvmap_.size()));
+ static WriteOptions wopts;
+ static FlushOptions flush_opts;
+ WriteBatch batch;
+ kv_it_ = kvmap_.begin();
+ for (int64_t i = 0; i < num_ops && kv_it_ != kvmap_.end(); i++, ++kv_it_) {
+ switch (batch_ops[i]) {
+ case OP_PUT:
+ ASSERT_OK(batch.Put(kv_it_->first, kv_it_->second));
+ break;
+ case OP_DELETE:
+ ASSERT_OK(batch.Delete(kv_it_->first));
+ break;
+ default:
+ FAIL();
+ }
+ }
+ ASSERT_OK(db_ttl_->Write(wopts, &batch));
+ ASSERT_OK(db_ttl_->Flush(flush_opts));
+ }
+
+ // Puts num_entries starting from start_pos_map from kvmap_ into the database
+ void PutValues(int64_t start_pos_map, int64_t num_entries, bool flush = true,
+ ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+ ASSERT_LE(start_pos_map + num_entries, static_cast<int64_t>(kvmap_.size()));
+ static WriteOptions wopts;
+ static FlushOptions flush_opts;
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, start_pos_map);
+ for (int64_t i = 0; kv_it_ != kvmap_.end() && i < num_entries;
+ i++, ++kv_it_) {
+ ASSERT_OK(cf == nullptr
+ ? db_ttl_->Put(wopts, kv_it_->first, kv_it_->second)
+ : db_ttl_->Put(wopts, cf, kv_it_->first, kv_it_->second));
+ }
+ // Put a mock kv at the end because CompactionFilter doesn't delete last key
+ ASSERT_OK(cf == nullptr ? db_ttl_->Put(wopts, "keymock", "valuemock")
+ : db_ttl_->Put(wopts, cf, "keymock", "valuemock"));
+ if (flush) {
+ if (cf == nullptr) {
+ ASSERT_OK(db_ttl_->Flush(flush_opts));
+ } else {
+ ASSERT_OK(db_ttl_->Flush(flush_opts, cf));
+ }
+ }
+ }
+
+ // Runs a manual compaction
+ Status ManualCompact(ColumnFamilyHandle* cf = nullptr) {
+ assert(db_ttl_);
+ if (cf == nullptr) {
+ return db_ttl_->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ } else {
+ return db_ttl_->CompactRange(CompactRangeOptions(), cf, nullptr, nullptr);
+ }
+ }
+
+ // Runs a DeleteRange
+ void MakeDeleteRange(std::string start, std::string end,
+ ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+ static WriteOptions wops;
+ WriteBatch wb;
+ ASSERT_OK(cf == nullptr
+ ? wb.DeleteRange(db_ttl_->DefaultColumnFamily(), start, end)
+ : wb.DeleteRange(cf, start, end));
+ ASSERT_OK(db_ttl_->Write(wops, &wb));
+ }
+
+ // checks the whole kvmap_ to return correct values using KeyMayExist
+ void SimpleKeyMayExistCheck() {
+ static ReadOptions ropts;
+ bool value_found;
+ std::string val;
+ for (auto& kv : kvmap_) {
+ bool ret = db_ttl_->KeyMayExist(ropts, kv.first, &val, &value_found);
+ if (ret == false || value_found == false) {
+ fprintf(stderr,
+ "KeyMayExist could not find key=%s in the database but"
+ " should have\n",
+ kv.first.c_str());
+ FAIL();
+ } else if (val.compare(kv.second) != 0) {
+ fprintf(stderr,
+ " value for key=%s present in database is %s but"
+ " should be %s\n",
+ kv.first.c_str(), val.c_str(), kv.second.c_str());
+ FAIL();
+ }
+ }
+ }
+
+ // checks the whole kvmap_ to return correct values using MultiGet
+ void SimpleMultiGetTest() {
+ static ReadOptions ropts;
+ std::vector<Slice> keys;
+ std::vector<std::string> values;
+
+ for (auto& kv : kvmap_) {
+ keys.emplace_back(kv.first);
+ }
+
+ auto statuses = db_ttl_->MultiGet(ropts, keys, &values);
+ size_t i = 0;
+ for (auto& kv : kvmap_) {
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], kv.second);
+ ++i;
+ }
+ }
+
+ void CompactCheck(int64_t st_pos, int64_t span, bool check = true,
+ bool test_compaction_change = false,
+ ColumnFamilyHandle* cf = nullptr) {
+ static ReadOptions ropts;
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, st_pos);
+ std::string v;
+ for (int64_t i = 0; kv_it_ != kvmap_.end() && i < span; i++, ++kv_it_) {
+ Status s = (cf == nullptr) ? db_ttl_->Get(ropts, kv_it_->first, &v)
+ : db_ttl_->Get(ropts, cf, kv_it_->first, &v);
+ if (s.ok() != check) {
+ fprintf(stderr, "key=%s ", kv_it_->first.c_str());
+ if (!s.ok()) {
+ fprintf(stderr, "is absent from db but was expected to be present\n");
+ } else {
+ fprintf(stderr, "is present in db but was expected to be absent\n");
+ }
+ FAIL();
+ } else if (s.ok()) {
+ if (test_compaction_change && v.compare(kNewValue_) != 0) {
+ fprintf(stderr,
+ " value for key=%s present in database is %s but "
+ " should be %s\n",
+ kv_it_->first.c_str(), v.c_str(), kNewValue_.c_str());
+ FAIL();
+ } else if (!test_compaction_change && v.compare(kv_it_->second) != 0) {
+ fprintf(stderr,
+ " value for key=%s present in database is %s but "
+ " should be %s\n",
+ kv_it_->first.c_str(), v.c_str(), kv_it_->second.c_str());
+ FAIL();
+ }
+ }
+ }
+ }
+ // Sleeps for slp_tim then runs a manual compaction
+ // Checks span starting from st_pos from kvmap_ in the db and
+ // Gets should return true if check is true and false otherwise
+ // Also checks that value that we got is the same as inserted; and =kNewValue
+ // if test_compaction_change is true
+ void SleepCompactCheck(int slp_tim, int64_t st_pos, int64_t span,
+ bool check = true, bool test_compaction_change = false,
+ ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+
+ env_->Sleep(slp_tim);
+ ASSERT_OK(ManualCompact(cf));
+ CompactCheck(st_pos, span, check, test_compaction_change, cf);
+ }
+
+ // Similar as SleepCompactCheck but uses TtlIterator to read from db
+ void SleepCompactCheckIter(int slp, int st_pos, int64_t span,
+ bool check = true) {
+ ASSERT_TRUE(db_ttl_);
+ env_->Sleep(slp);
+ ASSERT_OK(ManualCompact());
+ static ReadOptions ropts;
+ Iterator* dbiter = db_ttl_->NewIterator(ropts);
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, st_pos);
+
+ dbiter->Seek(kv_it_->first);
+ if (!check) {
+ if (dbiter->Valid()) {
+ ASSERT_NE(dbiter->value().compare(kv_it_->second), 0);
+ }
+ } else { // dbiter should have found out kvmap_[st_pos]
+ for (int64_t i = st_pos; kv_it_ != kvmap_.end() && i < st_pos + span;
+ i++, ++kv_it_) {
+ ASSERT_TRUE(dbiter->Valid());
+ ASSERT_EQ(dbiter->value().compare(kv_it_->second), 0);
+ dbiter->Next();
+ }
+ }
+ ASSERT_OK(dbiter->status());
+ delete dbiter;
+ }
+
+ // Set ttl on open db
+ void SetTtl(int32_t ttl, ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+ cf == nullptr ? db_ttl_->SetTtl(ttl) : db_ttl_->SetTtl(cf, ttl);
+ }
+
+ class TestFilter : public CompactionFilter {
+ public:
+ TestFilter(const int64_t kSampleSize, const std::string& kNewValue)
+ : kSampleSize_(kSampleSize), kNewValue_(kNewValue) {}
+
+ // Works on keys of the form "key<number>"
+ // Drops key if number at the end of key is in [0, kSampleSize_/3),
+ // Keeps key if it is in [kSampleSize_/3, 2*kSampleSize_/3),
+ // Change value if it is in [2*kSampleSize_/3, kSampleSize_)
+ // Eg. kSampleSize_=6. Drop:key0-1...Keep:key2-3...Change:key4-5...
+ bool Filter(int /*level*/, const Slice& key, const Slice& /*value*/,
+ std::string* new_value, bool* value_changed) const override {
+ assert(new_value != nullptr);
+
+ std::string search_str = "0123456789";
+ std::string key_string = key.ToString();
+ size_t pos = key_string.find_first_of(search_str);
+ int num_key_end;
+ if (pos != std::string::npos) {
+ auto key_substr = key_string.substr(pos, key.size() - pos);
+#ifndef CYGWIN
+ num_key_end = std::stoi(key_substr);
+#else
+ num_key_end = std::strtol(key_substr.c_str(), 0, 10);
+#endif
+
+ } else {
+ return false; // Keep keys not matching the format "key<NUMBER>"
+ }
+
+ int64_t partition = kSampleSize_ / 3;
+ if (num_key_end < partition) {
+ return true;
+ } else if (num_key_end < partition * 2) {
+ return false;
+ } else {
+ *new_value = kNewValue_;
+ *value_changed = true;
+ return false;
+ }
+ }
+
+ const char* Name() const override { return "TestFilter"; }
+
+ private:
+ const int64_t kSampleSize_;
+ const std::string kNewValue_;
+ };
+
+ class TestFilterFactory : public CompactionFilterFactory {
+ public:
+ TestFilterFactory(const int64_t kSampleSize, const std::string& kNewValue)
+ : kSampleSize_(kSampleSize), kNewValue_(kNewValue) {}
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(
+ new TestFilter(kSampleSize_, kNewValue_));
+ }
+
+ const char* Name() const override { return "TestFilterFactory"; }
+
+ private:
+ const int64_t kSampleSize_;
+ const std::string kNewValue_;
+ };
+
+ // Choose carefully so that Put, Gets & Compaction complete in 1 second buffer
+ static const int64_t kSampleSize_ = 100;
+ std::string dbname_;
+ DBWithTTL* db_ttl_;
+ std::unique_ptr<SpecialTimeEnv> env_;
+
+ private:
+ Options options_;
+ KVMap kvmap_;
+ KVMap::iterator kv_it_;
+ const std::string kNewValue_ = "new_value";
+ std::unique_ptr<CompactionFilter> test_comp_filter_;
+}; // class TtlTest
+
+// If TTL is non positive or not provided, the behaviour is TTL = infinity
+// This test opens the db 3 times with such default behavior and inserts a
+// bunch of kvs each time. All kvs should accumulate in the db till the end
+// Partitions the sample-size provided into 3 sets over boundary1 and boundary2
+TEST_F(TtlTest, NoEffect) {
+ MakeKVMap(kSampleSize_);
+ int64_t boundary1 = kSampleSize_ / 3;
+ int64_t boundary2 = 2 * boundary1;
+
+ OpenTtl();
+ PutValues(0, boundary1); // T=0: Set1 never deleted
+ SleepCompactCheck(1, 0, boundary1); // T=1: Set1 still there
+ CloseTtl();
+
+ OpenTtl(0);
+ PutValues(boundary1, boundary2 - boundary1); // T=1: Set2 never deleted
+ SleepCompactCheck(1, 0, boundary2); // T=2: Sets1 & 2 still there
+ CloseTtl();
+
+ OpenTtl(-1);
+ PutValues(boundary2, kSampleSize_ - boundary2); // T=3: Set3 never deleted
+ SleepCompactCheck(1, 0, kSampleSize_, true); // T=4: Sets 1,2,3 still there
+ CloseTtl();
+}
+
+// Rerun the NoEffect test with a different version of CloseTtl
+// function, where db is directly deleted without close.
+TEST_F(TtlTest, DestructWithoutClose) {
+ MakeKVMap(kSampleSize_);
+ int64_t boundary1 = kSampleSize_ / 3;
+ int64_t boundary2 = 2 * boundary1;
+
+ OpenTtl();
+ PutValues(0, boundary1); // T=0: Set1 never deleted
+ SleepCompactCheck(1, 0, boundary1); // T=1: Set1 still there
+ CloseTtlNoDBClose();
+
+ OpenTtl(0);
+ PutValues(boundary1, boundary2 - boundary1); // T=1: Set2 never deleted
+ SleepCompactCheck(1, 0, boundary2); // T=2: Sets1 & 2 still there
+ CloseTtlNoDBClose();
+
+ OpenTtl(-1);
+ PutValues(boundary2, kSampleSize_ - boundary2); // T=3: Set3 never deleted
+ SleepCompactCheck(1, 0, kSampleSize_, true); // T=4: Sets 1,2,3 still there
+ CloseTtlNoDBClose();
+}
+
+// Puts a set of values and checks its presence using Get during ttl
+TEST_F(TtlTest, PresentDuringTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2); // T=0:Open the db with ttl = 2
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=2
+ SleepCompactCheck(1, 0, kSampleSize_,
+ true); // T=1:Set1 should still be there
+ CloseTtl();
+}
+
+// Puts a set of values and checks its absence using Get after ttl
+TEST_F(TtlTest, AbsentAfterTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db with ttl = 2
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=2
+ SleepCompactCheck(2, 0, kSampleSize_, false); // T=2:Set1 should not be there
+ CloseTtl();
+}
+
+// Resets the timestamp of a set of kvs by updating them and checks that they
+// are not deleted according to the old timestamp
+TEST_F(TtlTest, ResetTimestamp) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(3);
+ PutValues(0, kSampleSize_); // T=0: Insert Set1. Delete at t=3
+ env_->Sleep(2); // T=2
+ PutValues(0, kSampleSize_); // T=2: Insert Set1. Delete at t=5
+ SleepCompactCheck(2, 0, kSampleSize_); // T=4: Set1 should still be there
+ CloseTtl();
+}
+
+// Similar to PresentDuringTTL but uses Iterator
+TEST_F(TtlTest, IterPresentDuringTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=2
+ SleepCompactCheckIter(1, 0, kSampleSize_); // T=1: Set should be there
+ CloseTtl();
+}
+
+// Similar to AbsentAfterTTL but uses Iterator
+TEST_F(TtlTest, IterAbsentAfterTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ SleepCompactCheckIter(2, 0, kSampleSize_, false); // T=2: Should not be there
+ CloseTtl();
+}
+
+// Checks presence while opening the same db more than once with the same ttl
+// Note: The second open will open the same db
+TEST_F(TtlTest, MultiOpenSamePresent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=2
+ CloseTtl();
+
+ OpenTtl(2); // T=0. Delete at t=2
+ SleepCompactCheck(1, 0, kSampleSize_); // T=1: Set should be there
+ CloseTtl();
+}
+
+// Checks absence while opening the same db more than once with the same ttl
+// Note: The second open will open the same db
+TEST_F(TtlTest, MultiOpenSameAbsent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ CloseTtl();
+
+ OpenTtl(1); // T=0.Delete at t=1
+ SleepCompactCheck(2, 0, kSampleSize_, false); // T=2: Set should not be there
+ CloseTtl();
+}
+
+// Checks presence while opening the same db more than once with bigger ttl
+TEST_F(TtlTest, MultiOpenDifferent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ CloseTtl();
+
+ OpenTtl(3); // T=0: Set deleted at t=3
+ SleepCompactCheck(2, 0, kSampleSize_); // T=2: Set should be there
+ CloseTtl();
+}
+
+// Checks presence during ttl in read_only mode
+TEST_F(TtlTest, ReadOnlyPresentForever) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db normally
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=1
+ CloseTtl();
+
+ OpenReadOnlyTtl(1);
+ ASSERT_TRUE(db_ttl_);
+
+ env_->Sleep(2);
+ Status s = ManualCompact(); // T=2:Set1 should still be there
+ ASSERT_TRUE(s.IsNotSupported());
+ CompactCheck(0, kSampleSize_);
+ CloseTtl();
+}
+
+// Checks whether WriteBatch works well with TTL
+// Puts all kvs in kvmap_ in a batch and writes first, then deletes first half
+TEST_F(TtlTest, WriteBatchTest) {
+ MakeKVMap(kSampleSize_);
+ BatchOperation batch_ops[kSampleSize_];
+ for (int i = 0; i < kSampleSize_; i++) {
+ batch_ops[i] = OP_PUT;
+ }
+
+ OpenTtl(2);
+ MakePutWriteBatch(batch_ops, kSampleSize_);
+ for (int i = 0; i < kSampleSize_ / 2; i++) {
+ batch_ops[i] = OP_DELETE;
+ }
+ MakePutWriteBatch(batch_ops, kSampleSize_ / 2);
+ SleepCompactCheck(0, 0, kSampleSize_ / 2, false);
+ SleepCompactCheck(0, kSampleSize_ / 2, kSampleSize_ - kSampleSize_ / 2);
+ CloseTtl();
+}
+
+// Checks user's compaction filter for correctness with TTL logic
+TEST_F(TtlTest, CompactionFilter) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtlWithTestCompaction(1);
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=1
+ // T=2: TTL logic takes precedence over TestFilter:-Set1 should not be there
+ SleepCompactCheck(2, 0, kSampleSize_, false);
+ CloseTtl();
+
+ OpenTtlWithTestCompaction(3);
+ PutValues(0, kSampleSize_); // T=0:Insert Set1.
+ int64_t partition = kSampleSize_ / 3;
+ SleepCompactCheck(1, 0, partition, false); // Part dropped
+ SleepCompactCheck(0, partition, partition); // Part kept
+ SleepCompactCheck(0, 2 * partition, partition, true, true); // Part changed
+ CloseTtl();
+}
+
+// Insert some key-values which KeyMayExist should be able to get and check that
+// values returned are fine
+TEST_F(TtlTest, KeyMayExist) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl();
+ PutValues(0, kSampleSize_, false);
+
+ SimpleKeyMayExistCheck();
+
+ CloseTtl();
+}
+
+TEST_F(TtlTest, MultiGetTest) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl();
+ PutValues(0, kSampleSize_, false);
+
+ SimpleMultiGetTest();
+
+ CloseTtl();
+}
+
+TEST_F(TtlTest, ColumnFamiliesTest) {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ options.env = env_.get();
+
+ DB::Open(options, dbname_, &db);
+ ColumnFamilyHandle* handle;
+ ASSERT_OK(db->CreateColumnFamily(ColumnFamilyOptions(options),
+ "ttl_column_family", &handle));
+
+ delete handle;
+ delete db;
+
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(ColumnFamilyDescriptor(
+ kDefaultColumnFamilyName, ColumnFamilyOptions(options)));
+ column_families.push_back(ColumnFamilyDescriptor(
+ "ttl_column_family", ColumnFamilyOptions(options)));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ ASSERT_OK(DBWithTTL::Open(DBOptions(options), dbname_, column_families,
+ &handles, &db_ttl_, {3, 5}, false));
+ ASSERT_EQ(handles.size(), 2U);
+ ColumnFamilyHandle* new_handle;
+ ASSERT_OK(db_ttl_->CreateColumnFamilyWithTtl(options, "ttl_column_family_2",
+ &new_handle, 2));
+ handles.push_back(new_handle);
+
+ MakeKVMap(kSampleSize_);
+ PutValues(0, kSampleSize_, false, handles[0]);
+ PutValues(0, kSampleSize_, false, handles[1]);
+ PutValues(0, kSampleSize_, false, handles[2]);
+
+ // everything should be there after 1 second
+ SleepCompactCheck(1, 0, kSampleSize_, true, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[2]);
+
+ // only column family 1 should be alive after 4 seconds
+ SleepCompactCheck(3, 0, kSampleSize_, false, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[2]);
+
+ // nothing should be there after 6 seconds
+ SleepCompactCheck(2, 0, kSampleSize_, false, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[2]);
+
+ for (auto h : handles) {
+ delete h;
+ }
+ delete db_ttl_;
+ db_ttl_ = nullptr;
+}
+
+// Puts a set of values and checks its absence using Get after ttl
+TEST_F(TtlTest, ChangeTtlOnOpenDb) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db with ttl = 2
+ SetTtl(3);
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=2
+ SleepCompactCheck(2, 0, kSampleSize_, true); // T=2:Set1 should be there
+ CloseTtl();
+}
+
+// Test DeleteRange for DBWithTtl
+TEST_F(TtlTest, DeleteRangeTest) {
+ OpenTtl();
+ ASSERT_OK(db_ttl_->Put(WriteOptions(), "a", "val"));
+ MakeDeleteRange("a", "b");
+ ASSERT_OK(db_ttl_->Put(WriteOptions(), "c", "val"));
+ MakeDeleteRange("b", "d");
+ ASSERT_OK(db_ttl_->Put(WriteOptions(), "e", "val"));
+ MakeDeleteRange("d", "e");
+ // first iteration verifies query correctness in memtable, second verifies
+ // query correctness for a single SST file
+ for (int i = 0; i < 2; i++) {
+ if (i > 0) {
+ ASSERT_OK(db_ttl_->Flush(FlushOptions()));
+ }
+ std::string value;
+ ASSERT_TRUE(db_ttl_->Get(ReadOptions(), "a", &value).IsNotFound());
+ ASSERT_TRUE(db_ttl_->Get(ReadOptions(), "c", &value).IsNotFound());
+ ASSERT_OK(db_ttl_->Get(ReadOptions(), "e", &value));
+ }
+ CloseTtl();
+}
+
+class DummyFilter : public CompactionFilter {
+ public:
+ bool Filter(int /*level*/, const Slice& /*key*/, const Slice& /*value*/,
+ std::string* /*new_value*/,
+ bool* /*value_changed*/) const override {
+ return false;
+ }
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "DummyFilter"; }
+};
+
+class DummyFilterFactory : public CompactionFilterFactory {
+ public:
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "DummyFilterFactory"; }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context&) override {
+ std::unique_ptr<CompactionFilter> f(new DummyFilter());
+ return f;
+ }
+};
+
+static int RegisterTestObjects(ObjectLibrary& library,
+ const std::string& /*arg*/) {
+ library.AddFactory<CompactionFilter>(
+ "DummyFilter", [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilter>* /*guard*/,
+ std::string* /* errmsg */) {
+ static DummyFilter dummy;
+ return &dummy;
+ });
+ library.AddFactory<CompactionFilterFactory>(
+ "DummyFilterFactory", [](const std::string& /*uri*/,
+ std::unique_ptr<CompactionFilterFactory>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new DummyFilterFactory());
+ return guard->get();
+ });
+ return 2;
+}
+
+class TtlOptionsTest : public testing::Test {
+ public:
+ TtlOptionsTest() {
+ config_options_.registry->AddLibrary("RegisterTtlObjects",
+ RegisterTtlObjects, "");
+ config_options_.registry->AddLibrary("RegisterTtlTestObjects",
+ RegisterTestObjects, "");
+ }
+ ConfigOptions config_options_;
+};
+
+TEST_F(TtlOptionsTest, LoadTtlCompactionFilter) {
+ const CompactionFilter* filter = nullptr;
+
+ ASSERT_OK(CompactionFilter::CreateFromString(
+ config_options_, TtlCompactionFilter::kClassName(), &filter));
+ ASSERT_NE(filter, nullptr);
+ ASSERT_STREQ(filter->Name(), TtlCompactionFilter::kClassName());
+ auto ttl = filter->GetOptions<int32_t>("TTL");
+ ASSERT_NE(ttl, nullptr);
+ ASSERT_EQ(*ttl, 0);
+ ASSERT_OK(filter->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+ delete filter;
+ filter = nullptr;
+
+ ASSERT_OK(CompactionFilter::CreateFromString(
+ config_options_, "id=TtlCompactionFilter; ttl=123", &filter));
+ ASSERT_NE(filter, nullptr);
+ ttl = filter->GetOptions<int32_t>("TTL");
+ ASSERT_NE(ttl, nullptr);
+ ASSERT_EQ(*ttl, 123);
+ ASSERT_OK(filter->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+ delete filter;
+ filter = nullptr;
+
+ ASSERT_OK(CompactionFilter::CreateFromString(
+ config_options_,
+ "id=TtlCompactionFilter; ttl=456; user_filter=DummyFilter;", &filter));
+ ASSERT_NE(filter, nullptr);
+ auto inner = filter->CheckedCast<DummyFilter>();
+ ASSERT_NE(inner, nullptr);
+ ASSERT_OK(filter->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+ std::string mismatch;
+ std::string opts_str = filter->ToString(config_options_);
+ const CompactionFilter* copy = nullptr;
+ ASSERT_OK(
+ CompactionFilter::CreateFromString(config_options_, opts_str, &copy));
+ ASSERT_TRUE(filter->AreEquivalent(config_options_, copy, &mismatch));
+ delete filter;
+ delete copy;
+}
+
+TEST_F(TtlOptionsTest, LoadTtlCompactionFilterFactory) {
+ std::shared_ptr<CompactionFilterFactory> cff;
+
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(
+ config_options_, TtlCompactionFilterFactory::kClassName(), &cff));
+ ASSERT_NE(cff.get(), nullptr);
+ ASSERT_STREQ(cff->Name(), TtlCompactionFilterFactory::kClassName());
+ auto ttl = cff->GetOptions<int32_t>("TTL");
+ ASSERT_NE(ttl, nullptr);
+ ASSERT_EQ(*ttl, 0);
+ ASSERT_OK(cff->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(
+ config_options_, "id=TtlCompactionFilterFactory; ttl=123", &cff));
+ ASSERT_NE(cff.get(), nullptr);
+ ASSERT_STREQ(cff->Name(), TtlCompactionFilterFactory::kClassName());
+ ttl = cff->GetOptions<int32_t>("TTL");
+ ASSERT_NE(ttl, nullptr);
+ ASSERT_EQ(*ttl, 123);
+ ASSERT_OK(cff->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(
+ config_options_,
+ "id=TtlCompactionFilterFactory; ttl=456; "
+ "user_filter_factory=DummyFilterFactory;",
+ &cff));
+ ASSERT_NE(cff.get(), nullptr);
+ auto filter = cff->CreateCompactionFilter(CompactionFilter::Context());
+ ASSERT_NE(filter.get(), nullptr);
+ auto ttlf = filter->CheckedCast<TtlCompactionFilter>();
+ ASSERT_EQ(filter.get(), ttlf);
+ auto user = filter->CheckedCast<DummyFilter>();
+ ASSERT_NE(user, nullptr);
+ ASSERT_OK(cff->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+
+ std::string opts_str = cff->ToString(config_options_);
+ std::string mismatch;
+ std::shared_ptr<CompactionFilterFactory> copy;
+ ASSERT_OK(CompactionFilterFactory::CreateFromString(config_options_, opts_str,
+ &copy));
+ ASSERT_TRUE(cff->AreEquivalent(config_options_, copy.get(), &mismatch));
+}
+
+TEST_F(TtlOptionsTest, LoadTtlMergeOperator) {
+ std::shared_ptr<MergeOperator> mo;
+
+ config_options_.invoke_prepare_options = false;
+ ASSERT_OK(MergeOperator::CreateFromString(
+ config_options_, TtlMergeOperator::kClassName(), &mo));
+ ASSERT_NE(mo.get(), nullptr);
+ ASSERT_STREQ(mo->Name(), TtlMergeOperator::kClassName());
+ ASSERT_NOK(mo->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+
+ config_options_.invoke_prepare_options = true;
+ ASSERT_OK(MergeOperator::CreateFromString(
+ config_options_, "id=TtlMergeOperator; user_operator=bytesxor", &mo));
+ ASSERT_NE(mo.get(), nullptr);
+ ASSERT_STREQ(mo->Name(), TtlMergeOperator::kClassName());
+ ASSERT_OK(mo->ValidateOptions(DBOptions(), ColumnFamilyOptions()));
+ auto ttl_mo = mo->CheckedCast<TtlMergeOperator>();
+ ASSERT_EQ(mo.get(), ttl_mo);
+ auto user = ttl_mo->CheckedCast<BytesXOROperator>();
+ ASSERT_NE(user, nullptr);
+
+ std::string mismatch;
+ std::string opts_str = mo->ToString(config_options_);
+ std::shared_ptr<MergeOperator> copy;
+ ASSERT_OK(MergeOperator::CreateFromString(config_options_, opts_str, &copy));
+ ASSERT_TRUE(mo->AreEquivalent(config_options_, copy.get(), &mismatch));
+}
+} // namespace ROCKSDB_NAMESPACE
+
+// A black-box test for the ttl wrapper around rocksdb
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as DBWithTTL is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/util_merge_operators_test.cc b/src/rocksdb/utilities/util_merge_operators_test.cc
new file mode 100644
index 000000000..fed6f1a75
--- /dev/null
+++ b/src/rocksdb/utilities/util_merge_operators_test.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class UtilMergeOperatorTest : public testing::Test {
+ public:
+ UtilMergeOperatorTest() {}
+
+ std::string FullMergeV2(std::string existing_value,
+ std::vector<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ Slice result_operand(nullptr, 0);
+
+ Slice existing_value_slice(existing_value);
+ std::vector<Slice> operands_slice(operands.begin(), operands.end());
+
+ const MergeOperator::MergeOperationInput merge_in(
+ key, &existing_value_slice, operands_slice, nullptr);
+ MergeOperator::MergeOperationOutput merge_out(result, result_operand);
+ merge_operator_->FullMergeV2(merge_in, &merge_out);
+
+ if (result_operand.data()) {
+ result.assign(result_operand.data(), result_operand.size());
+ }
+ return result;
+ }
+
+ std::string FullMergeV2(std::vector<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ Slice result_operand(nullptr, 0);
+
+ std::vector<Slice> operands_slice(operands.begin(), operands.end());
+
+ const MergeOperator::MergeOperationInput merge_in(key, nullptr,
+ operands_slice, nullptr);
+ MergeOperator::MergeOperationOutput merge_out(result, result_operand);
+ merge_operator_->FullMergeV2(merge_in, &merge_out);
+
+ if (result_operand.data()) {
+ result.assign(result_operand.data(), result_operand.size());
+ }
+ return result;
+ }
+
+ std::string PartialMerge(std::string left, std::string right,
+ std::string key = "") {
+ std::string result;
+
+ merge_operator_->PartialMerge(key, left, right, &result, nullptr);
+ return result;
+ }
+
+ std::string PartialMergeMulti(std::deque<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ std::deque<Slice> operands_slice(operands.begin(), operands.end());
+
+ merge_operator_->PartialMergeMulti(key, operands_slice, &result, nullptr);
+ return result;
+ }
+
+ protected:
+ std::shared_ptr<MergeOperator> merge_operator_;
+};
+
+TEST_F(UtilMergeOperatorTest, MaxMergeOperator) {
+ merge_operator_ = MergeOperators::CreateMaxOperator();
+
+ EXPECT_EQ("B", FullMergeV2("B", {"A"}));
+ EXPECT_EQ("B", FullMergeV2("A", {"B"}));
+ EXPECT_EQ("", FullMergeV2({"", "", ""}));
+ EXPECT_EQ("A", FullMergeV2({"A"}));
+ EXPECT_EQ("ABC", FullMergeV2({"ABC"}));
+ EXPECT_EQ("Z", FullMergeV2({"ABC", "Z", "C", "AXX"}));
+ EXPECT_EQ("ZZZ", FullMergeV2({"ABC", "CC", "Z", "ZZZ"}));
+ EXPECT_EQ("a", FullMergeV2("a", {"ABC", "CC", "Z", "ZZZ"}));
+
+ EXPECT_EQ("z", PartialMergeMulti({"a", "z", "efqfqwgwew", "aaz", "hhhhh"}));
+
+ EXPECT_EQ("b", PartialMerge("a", "b"));
+ EXPECT_EQ("z", PartialMerge("z", "azzz"));
+ EXPECT_EQ("a", PartialMerge("a", ""));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/wal_filter.cc b/src/rocksdb/utilities/wal_filter.cc
new file mode 100644
index 000000000..98bba3610
--- /dev/null
+++ b/src/rocksdb/utilities/wal_filter.cc
@@ -0,0 +1,23 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/wal_filter.h"
+
+#include <memory>
+
+#include "rocksdb/convenience.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/customizable_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+Status WalFilter::CreateFromString(const ConfigOptions& config_options,
+ const std::string& value,
+ WalFilter** filter) {
+ Status s =
+ LoadStaticObject<WalFilter>(config_options, value, nullptr, filter);
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc
new file mode 100644
index 000000000..408243b3f
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc
@@ -0,0 +1,695 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/write_batch_with_index.h"
+
+#include <memory>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "db/merge_context.h"
+#include "db/merge_helper.h"
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "options/db_options.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/iterator.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+namespace ROCKSDB_NAMESPACE {
+struct WriteBatchWithIndex::Rep {
+ explicit Rep(const Comparator* index_comparator, size_t reserved_bytes = 0,
+ size_t max_bytes = 0, bool _overwrite_key = false,
+ size_t protection_bytes_per_key = 0)
+ : write_batch(reserved_bytes, max_bytes, protection_bytes_per_key,
+ index_comparator ? index_comparator->timestamp_size() : 0),
+ comparator(index_comparator, &write_batch),
+ skip_list(comparator, &arena),
+ overwrite_key(_overwrite_key),
+ last_entry_offset(0),
+ last_sub_batch_offset(0),
+ sub_batch_cnt(1) {}
+ ReadableWriteBatch write_batch;
+ WriteBatchEntryComparator comparator;
+ Arena arena;
+ WriteBatchEntrySkipList skip_list;
+ bool overwrite_key;
+ size_t last_entry_offset;
+ // The starting offset of the last sub-batch. A sub-batch starts right before
+ // inserting a key that is a duplicate of a key in the last sub-batch. Zero,
+ // the default, means that no duplicate key is detected so far.
+ size_t last_sub_batch_offset;
+ // Total number of sub-batches in the write batch. Default is 1.
+ size_t sub_batch_cnt;
+
+ // Remember current offset of internal write batch, which is used as
+ // the starting offset of the next record.
+ void SetLastEntryOffset() { last_entry_offset = write_batch.GetDataSize(); }
+
+ // In overwrite mode, find the existing entry for the same key and update it
+ // to point to the current entry.
+ // Return true if the key is found and updated.
+ bool UpdateExistingEntry(ColumnFamilyHandle* column_family, const Slice& key,
+ WriteType type);
+ bool UpdateExistingEntryWithCfId(uint32_t column_family_id, const Slice& key,
+ WriteType type);
+
+ // Add the recent entry to the update.
+ // In overwrite mode, if key already exists in the index, update it.
+ void AddOrUpdateIndex(ColumnFamilyHandle* column_family, const Slice& key,
+ WriteType type);
+ void AddOrUpdateIndex(const Slice& key, WriteType type);
+
+ // Allocate an index entry pointing to the last entry in the write batch and
+ // put it to skip list.
+ void AddNewEntry(uint32_t column_family_id);
+
+ // Clear all updates buffered in this batch.
+ void Clear();
+ void ClearIndex();
+
+ // Rebuild index by reading all records from the batch.
+ // Returns non-ok status on corruption.
+ Status ReBuildIndex();
+};
+
+bool WriteBatchWithIndex::Rep::UpdateExistingEntry(
+ ColumnFamilyHandle* column_family, const Slice& key, WriteType type) {
+ uint32_t cf_id = GetColumnFamilyID(column_family);
+ return UpdateExistingEntryWithCfId(cf_id, key, type);
+}
+
+bool WriteBatchWithIndex::Rep::UpdateExistingEntryWithCfId(
+ uint32_t column_family_id, const Slice& key, WriteType type) {
+ if (!overwrite_key) {
+ return false;
+ }
+
+ WBWIIteratorImpl iter(column_family_id, &skip_list, &write_batch,
+ &comparator);
+ iter.Seek(key);
+ if (!iter.Valid()) {
+ return false;
+ } else if (!iter.MatchesKey(column_family_id, key)) {
+ return false;
+ } else {
+ // Move to the end of this key (NextKey-Prev)
+ iter.NextKey(); // Move to the next key
+ if (iter.Valid()) {
+ iter.Prev(); // Move back one entry
+ } else {
+ iter.SeekToLast();
+ }
+ }
+ WriteBatchIndexEntry* non_const_entry =
+ const_cast<WriteBatchIndexEntry*>(iter.GetRawEntry());
+ if (LIKELY(last_sub_batch_offset <= non_const_entry->offset)) {
+ last_sub_batch_offset = last_entry_offset;
+ sub_batch_cnt++;
+ }
+ if (type == kMergeRecord) {
+ return false;
+ } else {
+ non_const_entry->offset = last_entry_offset;
+ return true;
+ }
+}
+
+void WriteBatchWithIndex::Rep::AddOrUpdateIndex(
+ ColumnFamilyHandle* column_family, const Slice& key, WriteType type) {
+ if (!UpdateExistingEntry(column_family, key, type)) {
+ uint32_t cf_id = GetColumnFamilyID(column_family);
+ const auto* cf_cmp = GetColumnFamilyUserComparator(column_family);
+ if (cf_cmp != nullptr) {
+ comparator.SetComparatorForCF(cf_id, cf_cmp);
+ }
+ AddNewEntry(cf_id);
+ }
+}
+
+void WriteBatchWithIndex::Rep::AddOrUpdateIndex(const Slice& key,
+ WriteType type) {
+ if (!UpdateExistingEntryWithCfId(0, key, type)) {
+ AddNewEntry(0);
+ }
+}
+
+void WriteBatchWithIndex::Rep::AddNewEntry(uint32_t column_family_id) {
+ const std::string& wb_data = write_batch.Data();
+ Slice entry_ptr = Slice(wb_data.data() + last_entry_offset,
+ wb_data.size() - last_entry_offset);
+ // Extract key
+ Slice key;
+ bool success =
+ ReadKeyFromWriteBatchEntry(&entry_ptr, &key, column_family_id != 0);
+#ifdef NDEBUG
+ (void)success;
+#endif
+ assert(success);
+
+ const Comparator* const ucmp = comparator.GetComparator(column_family_id);
+ size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
+
+ if (ts_sz > 0) {
+ key.remove_suffix(ts_sz);
+ }
+
+ auto* mem = arena.Allocate(sizeof(WriteBatchIndexEntry));
+ auto* index_entry =
+ new (mem) WriteBatchIndexEntry(last_entry_offset, column_family_id,
+ key.data() - wb_data.data(), key.size());
+ skip_list.Insert(index_entry);
+}
+
+void WriteBatchWithIndex::Rep::Clear() {
+ write_batch.Clear();
+ ClearIndex();
+}
+
+void WriteBatchWithIndex::Rep::ClearIndex() {
+ skip_list.~WriteBatchEntrySkipList();
+ arena.~Arena();
+ new (&arena) Arena();
+ new (&skip_list) WriteBatchEntrySkipList(comparator, &arena);
+ last_entry_offset = 0;
+ last_sub_batch_offset = 0;
+ sub_batch_cnt = 1;
+}
+
+Status WriteBatchWithIndex::Rep::ReBuildIndex() {
+ Status s;
+
+ ClearIndex();
+
+ if (write_batch.Count() == 0) {
+ // Nothing to re-index
+ return s;
+ }
+
+ size_t offset = WriteBatchInternal::GetFirstOffset(&write_batch);
+
+ Slice input(write_batch.Data());
+ input.remove_prefix(offset);
+
+ // Loop through all entries in Rep and add each one to the index
+ uint32_t found = 0;
+ while (s.ok() && !input.empty()) {
+ Slice key, value, blob, xid;
+ uint32_t column_family_id = 0; // default
+ char tag = 0;
+
+ // set offset of current entry for call to AddNewEntry()
+ last_entry_offset = input.data() - write_batch.Data().data();
+
+ s = ReadRecordFromWriteBatch(&input, &tag, &column_family_id, &key, &value,
+ &blob, &xid);
+ if (!s.ok()) {
+ break;
+ }
+
+ switch (tag) {
+ case kTypeColumnFamilyValue:
+ case kTypeValue:
+ found++;
+ if (!UpdateExistingEntryWithCfId(column_family_id, key, kPutRecord)) {
+ AddNewEntry(column_family_id);
+ }
+ break;
+ case kTypeColumnFamilyDeletion:
+ case kTypeDeletion:
+ found++;
+ if (!UpdateExistingEntryWithCfId(column_family_id, key,
+ kDeleteRecord)) {
+ AddNewEntry(column_family_id);
+ }
+ break;
+ case kTypeColumnFamilySingleDeletion:
+ case kTypeSingleDeletion:
+ found++;
+ if (!UpdateExistingEntryWithCfId(column_family_id, key,
+ kSingleDeleteRecord)) {
+ AddNewEntry(column_family_id);
+ }
+ break;
+ case kTypeColumnFamilyMerge:
+ case kTypeMerge:
+ found++;
+ if (!UpdateExistingEntryWithCfId(column_family_id, key, kMergeRecord)) {
+ AddNewEntry(column_family_id);
+ }
+ break;
+ case kTypeLogData:
+ case kTypeBeginPrepareXID:
+ case kTypeBeginPersistedPrepareXID:
+ case kTypeBeginUnprepareXID:
+ case kTypeEndPrepareXID:
+ case kTypeCommitXID:
+ case kTypeCommitXIDAndTimestamp:
+ case kTypeRollbackXID:
+ case kTypeNoop:
+ break;
+ default:
+ return Status::Corruption(
+ "unknown WriteBatch tag in ReBuildIndex",
+ std::to_string(static_cast<unsigned int>(tag)));
+ }
+ }
+
+ if (s.ok() && found != write_batch.Count()) {
+ s = Status::Corruption("WriteBatch has wrong count");
+ }
+
+ return s;
+}
+
+WriteBatchWithIndex::WriteBatchWithIndex(
+ const Comparator* default_index_comparator, size_t reserved_bytes,
+ bool overwrite_key, size_t max_bytes, size_t protection_bytes_per_key)
+ : rep(new Rep(default_index_comparator, reserved_bytes, max_bytes,
+ overwrite_key, protection_bytes_per_key)) {}
+
+WriteBatchWithIndex::~WriteBatchWithIndex() {}
+
+WriteBatchWithIndex::WriteBatchWithIndex(WriteBatchWithIndex&&) = default;
+
+WriteBatchWithIndex& WriteBatchWithIndex::operator=(WriteBatchWithIndex&&) =
+ default;
+
+WriteBatch* WriteBatchWithIndex::GetWriteBatch() { return &rep->write_batch; }
+
+size_t WriteBatchWithIndex::SubBatchCnt() { return rep->sub_batch_cnt; }
+
+WBWIIterator* WriteBatchWithIndex::NewIterator() {
+ return new WBWIIteratorImpl(0, &(rep->skip_list), &rep->write_batch,
+ &(rep->comparator));
+}
+
+WBWIIterator* WriteBatchWithIndex::NewIterator(
+ ColumnFamilyHandle* column_family) {
+ return new WBWIIteratorImpl(GetColumnFamilyID(column_family),
+ &(rep->skip_list), &rep->write_batch,
+ &(rep->comparator));
+}
+
+Iterator* WriteBatchWithIndex::NewIteratorWithBase(
+ ColumnFamilyHandle* column_family, Iterator* base_iterator,
+ const ReadOptions* read_options) {
+ auto wbwiii =
+ new WBWIIteratorImpl(GetColumnFamilyID(column_family), &(rep->skip_list),
+ &rep->write_batch, &rep->comparator);
+ return new BaseDeltaIterator(column_family, base_iterator, wbwiii,
+ GetColumnFamilyUserComparator(column_family),
+ read_options);
+}
+
+Iterator* WriteBatchWithIndex::NewIteratorWithBase(Iterator* base_iterator) {
+ // default column family's comparator
+ auto wbwiii = new WBWIIteratorImpl(0, &(rep->skip_list), &rep->write_batch,
+ &rep->comparator);
+ return new BaseDeltaIterator(nullptr, base_iterator, wbwiii,
+ rep->comparator.default_comparator());
+}
+
+Status WriteBatchWithIndex::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Put(column_family, key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key, kPutRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Put(const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Put(key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key, kPutRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Put(ColumnFamilyHandle* column_family,
+ const Slice& /*key*/, const Slice& /*ts*/,
+ const Slice& /*value*/) {
+ if (!column_family) {
+ return Status::InvalidArgument("column family handle cannot be nullptr");
+ }
+ // TODO: support WBWI::Put() with timestamp.
+ return Status::NotSupported();
+}
+
+Status WriteBatchWithIndex::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Delete(column_family, key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key, kDeleteRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Delete(const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Delete(key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key, kDeleteRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Delete(ColumnFamilyHandle* column_family,
+ const Slice& /*key*/, const Slice& /*ts*/) {
+ if (!column_family) {
+ return Status::InvalidArgument("column family handle cannot be nullptr");
+ }
+ // TODO: support WBWI::Delete() with timestamp.
+ return Status::NotSupported();
+}
+
+Status WriteBatchWithIndex::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.SingleDelete(column_family, key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key, kSingleDeleteRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::SingleDelete(const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.SingleDelete(key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key, kSingleDeleteRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& /*key*/,
+ const Slice& /*ts*/) {
+ if (!column_family) {
+ return Status::InvalidArgument("column family handle cannot be nullptr");
+ }
+ // TODO: support WBWI::SingleDelete() with timestamp.
+ return Status::NotSupported();
+}
+
+Status WriteBatchWithIndex::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Merge(column_family, key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key, kMergeRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Merge(const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Merge(key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key, kMergeRecord);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::PutLogData(const Slice& blob) {
+ return rep->write_batch.PutLogData(blob);
+}
+
+void WriteBatchWithIndex::Clear() { rep->Clear(); }
+
+Status WriteBatchWithIndex::GetFromBatch(ColumnFamilyHandle* column_family,
+ const DBOptions& options,
+ const Slice& key, std::string* value) {
+ Status s;
+ WriteBatchWithIndexInternal wbwii(&options, column_family);
+ auto result = wbwii.GetFromBatch(this, key, value, &s);
+
+ switch (result) {
+ case WBWIIteratorImpl::kFound:
+ case WBWIIteratorImpl::kError:
+ // use returned status
+ break;
+ case WBWIIteratorImpl::kDeleted:
+ case WBWIIteratorImpl::kNotFound:
+ s = Status::NotFound();
+ break;
+ case WBWIIteratorImpl::kMergeInProgress:
+ s = Status::MergeInProgress();
+ break;
+ default:
+ assert(false);
+ }
+
+ return s;
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ const Slice& key,
+ std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s = GetFromBatchAndDB(db, read_options, db->DefaultColumnFamily(), key,
+ &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ const Slice& key,
+ PinnableSlice* pinnable_val) {
+ return GetFromBatchAndDB(db, read_options, db->DefaultColumnFamily(), key,
+ pinnable_val);
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s =
+ GetFromBatchAndDB(db, read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ PinnableSlice* pinnable_val) {
+ return GetFromBatchAndDB(db, read_options, column_family, key, pinnable_val,
+ nullptr);
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val, ReadCallback* callback) {
+ const Comparator* const ucmp = rep->comparator.GetComparator(column_family);
+ size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
+ if (ts_sz > 0 && !read_options.timestamp) {
+ return Status::InvalidArgument("Must specify timestamp");
+ }
+
+ Status s;
+ WriteBatchWithIndexInternal wbwii(db, column_family);
+
+ // Since the lifetime of the WriteBatch is the same as that of the transaction
+ // we cannot pin it as otherwise the returned value will not be available
+ // after the transaction finishes.
+ std::string& batch_value = *pinnable_val->GetSelf();
+ auto result = wbwii.GetFromBatch(this, key, &batch_value, &s);
+
+ if (result == WBWIIteratorImpl::kFound) {
+ pinnable_val->PinSelf();
+ return s;
+ } else if (!s.ok() || result == WBWIIteratorImpl::kError) {
+ return s;
+ } else if (result == WBWIIteratorImpl::kDeleted) {
+ return Status::NotFound();
+ }
+ assert(result == WBWIIteratorImpl::kMergeInProgress ||
+ result == WBWIIteratorImpl::kNotFound);
+
+ // Did not find key in batch OR could not resolve Merges. Try DB.
+ if (!callback) {
+ s = db->Get(read_options, column_family, key, pinnable_val);
+ } else {
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = pinnable_val;
+ get_impl_options.callback = callback;
+ s = static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->GetImpl(read_options, key, get_impl_options);
+ }
+
+ if (s.ok() || s.IsNotFound()) { // DB Get Succeeded
+ if (result == WBWIIteratorImpl::kMergeInProgress) {
+ // Merge result from DB with merges in Batch
+ std::string merge_result;
+ if (s.ok()) {
+ s = wbwii.MergeKey(key, pinnable_val, &merge_result);
+ } else { // Key not present in db (s.IsNotFound())
+ s = wbwii.MergeKey(key, nullptr, &merge_result);
+ }
+ if (s.ok()) {
+ pinnable_val->Reset();
+ *pinnable_val->GetSelf() = std::move(merge_result);
+ pinnable_val->PinSelf();
+ }
+ }
+ }
+
+ return s;
+}
+
+void WriteBatchWithIndex::MultiGetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, bool sorted_input) {
+ MultiGetFromBatchAndDB(db, read_options, column_family, num_keys, keys,
+ values, statuses, sorted_input, nullptr);
+}
+
+void WriteBatchWithIndex::MultiGetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, bool sorted_input, ReadCallback* callback) {
+ const Comparator* const ucmp = rep->comparator.GetComparator(column_family);
+ size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
+ if (ts_sz > 0 && !read_options.timestamp) {
+ for (size_t i = 0; i < num_keys; ++i) {
+ statuses[i] = Status::InvalidArgument("Must specify timestamp");
+ }
+ return;
+ }
+
+ WriteBatchWithIndexInternal wbwii(db, column_family);
+
+ autovector<KeyContext, MultiGetContext::MAX_BATCH_SIZE> key_context;
+ autovector<KeyContext*, MultiGetContext::MAX_BATCH_SIZE> sorted_keys;
+ // To hold merges from the write batch
+ autovector<std::pair<WBWIIteratorImpl::Result, MergeContext>,
+ MultiGetContext::MAX_BATCH_SIZE>
+ merges;
+ // Since the lifetime of the WriteBatch is the same as that of the transaction
+ // we cannot pin it as otherwise the returned value will not be available
+ // after the transaction finishes.
+ for (size_t i = 0; i < num_keys; ++i) {
+ MergeContext merge_context;
+ std::string batch_value;
+ Status* s = &statuses[i];
+ PinnableSlice* pinnable_val = &values[i];
+ pinnable_val->Reset();
+ auto result =
+ wbwii.GetFromBatch(this, keys[i], &merge_context, &batch_value, s);
+
+ if (result == WBWIIteratorImpl::kFound) {
+ *pinnable_val->GetSelf() = std::move(batch_value);
+ pinnable_val->PinSelf();
+ continue;
+ }
+ if (result == WBWIIteratorImpl::kDeleted) {
+ *s = Status::NotFound();
+ continue;
+ }
+ if (result == WBWIIteratorImpl::kError) {
+ continue;
+ }
+ assert(result == WBWIIteratorImpl::kMergeInProgress ||
+ result == WBWIIteratorImpl::kNotFound);
+ key_context.emplace_back(column_family, keys[i], &values[i],
+ /*timestamp*/ nullptr, &statuses[i]);
+ merges.emplace_back(result, std::move(merge_context));
+ }
+
+ for (KeyContext& key : key_context) {
+ sorted_keys.emplace_back(&key);
+ }
+
+ // Did not find key in batch OR could not resolve Merges. Try DB.
+ static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->PrepareMultiGetKeys(key_context.size(), sorted_input, &sorted_keys);
+ static_cast_with_check<DBImpl>(db->GetRootDB())
+ ->MultiGetWithCallback(read_options, column_family, callback,
+ &sorted_keys);
+
+ for (auto iter = key_context.begin(); iter != key_context.end(); ++iter) {
+ KeyContext& key = *iter;
+ if (key.s->ok() || key.s->IsNotFound()) { // DB Get Succeeded
+ size_t index = iter - key_context.begin();
+ std::pair<WBWIIteratorImpl::Result, MergeContext>& merge_result =
+ merges[index];
+ if (merge_result.first == WBWIIteratorImpl::kMergeInProgress) {
+ std::string merged_value;
+ // Merge result from DB with merges in Batch
+ if (key.s->ok()) {
+ *key.s = wbwii.MergeKey(*key.key, iter->value, merge_result.second,
+ &merged_value);
+ } else { // Key not present in db (s.IsNotFound())
+ *key.s = wbwii.MergeKey(*key.key, nullptr, merge_result.second,
+ &merged_value);
+ }
+ if (key.s->ok()) {
+ key.value->Reset();
+ *key.value->GetSelf() = std::move(merged_value);
+ key.value->PinSelf();
+ }
+ }
+ }
+ }
+}
+
+void WriteBatchWithIndex::SetSavePoint() { rep->write_batch.SetSavePoint(); }
+
+Status WriteBatchWithIndex::RollbackToSavePoint() {
+ Status s = rep->write_batch.RollbackToSavePoint();
+
+ if (s.ok()) {
+ rep->sub_batch_cnt = 1;
+ rep->last_sub_batch_offset = 0;
+ s = rep->ReBuildIndex();
+ }
+
+ return s;
+}
+
+Status WriteBatchWithIndex::PopSavePoint() {
+ return rep->write_batch.PopSavePoint();
+}
+
+void WriteBatchWithIndex::SetMaxBytes(size_t max_bytes) {
+ rep->write_batch.SetMaxBytes(max_bytes);
+}
+
+size_t WriteBatchWithIndex::GetDataSize() const {
+ return rep->write_batch.GetDataSize();
+}
+
+const Comparator* WriteBatchWithIndexInternal::GetUserComparator(
+ const WriteBatchWithIndex& wbwi, uint32_t cf_id) {
+ const WriteBatchEntryComparator& ucmps = wbwi.rep->comparator;
+ return ucmps.GetComparator(cf_id);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc
new file mode 100644
index 000000000..3c9205bf7
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc
@@ -0,0 +1,735 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "db/merge_context.h"
+#include "db/merge_helper.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/cast_util.h"
+#include "util/coding.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+BaseDeltaIterator::BaseDeltaIterator(ColumnFamilyHandle* column_family,
+ Iterator* base_iterator,
+ WBWIIteratorImpl* delta_iterator,
+ const Comparator* comparator,
+ const ReadOptions* read_options)
+ : forward_(true),
+ current_at_base_(true),
+ equal_keys_(false),
+ status_(Status::OK()),
+ base_iterator_(base_iterator),
+ delta_iterator_(delta_iterator),
+ comparator_(comparator),
+ iterate_upper_bound_(read_options ? read_options->iterate_upper_bound
+ : nullptr) {
+ assert(comparator_);
+ wbwii_.reset(new WriteBatchWithIndexInternal(column_family));
+}
+
+bool BaseDeltaIterator::Valid() const {
+ return status_.ok() ? (current_at_base_ ? BaseValid() : DeltaValid()) : false;
+}
+
+void BaseDeltaIterator::SeekToFirst() {
+ forward_ = true;
+ base_iterator_->SeekToFirst();
+ delta_iterator_->SeekToFirst();
+ UpdateCurrent();
+}
+
+void BaseDeltaIterator::SeekToLast() {
+ forward_ = false;
+ base_iterator_->SeekToLast();
+ delta_iterator_->SeekToLast();
+ UpdateCurrent();
+}
+
+void BaseDeltaIterator::Seek(const Slice& k) {
+ forward_ = true;
+ base_iterator_->Seek(k);
+ delta_iterator_->Seek(k);
+ UpdateCurrent();
+}
+
+void BaseDeltaIterator::SeekForPrev(const Slice& k) {
+ forward_ = false;
+ base_iterator_->SeekForPrev(k);
+ delta_iterator_->SeekForPrev(k);
+ UpdateCurrent();
+}
+
+void BaseDeltaIterator::Next() {
+ if (!Valid()) {
+ status_ = Status::NotSupported("Next() on invalid iterator");
+ return;
+ }
+
+ if (!forward_) {
+ // Need to change direction
+ // if our direction was backward and we're not equal, we have two states:
+ // * both iterators are valid: we're already in a good state (current
+ // shows to smaller)
+ // * only one iterator is valid: we need to advance that iterator
+ forward_ = true;
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ assert(DeltaValid());
+ base_iterator_->SeekToFirst();
+ } else if (!DeltaValid()) {
+ delta_iterator_->SeekToFirst();
+ } else if (current_at_base_) {
+ // Change delta from larger than base to smaller
+ AdvanceDelta();
+ } else {
+ // Change base from larger than delta to smaller
+ AdvanceBase();
+ }
+ if (DeltaValid() && BaseValid()) {
+ if (0 == comparator_->CompareWithoutTimestamp(
+ delta_iterator_->Entry().key, /*a_has_ts=*/false,
+ base_iterator_->key(), /*b_has_ts=*/false)) {
+ equal_keys_ = true;
+ }
+ }
+ }
+ Advance();
+}
+
+void BaseDeltaIterator::Prev() {
+ if (!Valid()) {
+ status_ = Status::NotSupported("Prev() on invalid iterator");
+ return;
+ }
+
+ if (forward_) {
+ // Need to change direction
+ // if our direction was backward and we're not equal, we have two states:
+ // * both iterators are valid: we're already in a good state (current
+ // shows to smaller)
+ // * only one iterator is valid: we need to advance that iterator
+ forward_ = false;
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ assert(DeltaValid());
+ base_iterator_->SeekToLast();
+ } else if (!DeltaValid()) {
+ delta_iterator_->SeekToLast();
+ } else if (current_at_base_) {
+ // Change delta from less advanced than base to more advanced
+ AdvanceDelta();
+ } else {
+ // Change base from less advanced than delta to more advanced
+ AdvanceBase();
+ }
+ if (DeltaValid() && BaseValid()) {
+ if (0 == comparator_->CompareWithoutTimestamp(
+ delta_iterator_->Entry().key, /*a_has_ts=*/false,
+ base_iterator_->key(), /*b_has_ts=*/false)) {
+ equal_keys_ = true;
+ }
+ }
+ }
+
+ Advance();
+}
+
+Slice BaseDeltaIterator::key() const {
+ return current_at_base_ ? base_iterator_->key()
+ : delta_iterator_->Entry().key;
+}
+
+Slice BaseDeltaIterator::value() const {
+ if (current_at_base_) {
+ return base_iterator_->value();
+ } else {
+ WriteEntry delta_entry = delta_iterator_->Entry();
+ if (wbwii_->GetNumOperands() == 0) {
+ return delta_entry.value;
+ } else if (delta_entry.type == kDeleteRecord ||
+ delta_entry.type == kSingleDeleteRecord) {
+ status_ =
+ wbwii_->MergeKey(delta_entry.key, nullptr, merge_result_.GetSelf());
+ } else if (delta_entry.type == kPutRecord) {
+ status_ = wbwii_->MergeKey(delta_entry.key, &delta_entry.value,
+ merge_result_.GetSelf());
+ } else if (delta_entry.type == kMergeRecord) {
+ if (equal_keys_) {
+ Slice base_value = base_iterator_->value();
+ status_ = wbwii_->MergeKey(delta_entry.key, &base_value,
+ merge_result_.GetSelf());
+ } else {
+ status_ =
+ wbwii_->MergeKey(delta_entry.key, nullptr, merge_result_.GetSelf());
+ }
+ }
+ merge_result_.PinSelf();
+ return merge_result_;
+ }
+}
+
+Status BaseDeltaIterator::status() const {
+ if (!status_.ok()) {
+ return status_;
+ }
+ if (!base_iterator_->status().ok()) {
+ return base_iterator_->status();
+ }
+ return delta_iterator_->status();
+}
+
+void BaseDeltaIterator::Invalidate(Status s) { status_ = s; }
+
+void BaseDeltaIterator::AssertInvariants() {
+#ifndef NDEBUG
+ bool not_ok = false;
+ if (!base_iterator_->status().ok()) {
+ assert(!base_iterator_->Valid());
+ not_ok = true;
+ }
+ if (!delta_iterator_->status().ok()) {
+ assert(!delta_iterator_->Valid());
+ not_ok = true;
+ }
+ if (not_ok) {
+ assert(!Valid());
+ assert(!status().ok());
+ return;
+ }
+
+ if (!Valid()) {
+ return;
+ }
+ if (!BaseValid()) {
+ assert(!current_at_base_ && delta_iterator_->Valid());
+ return;
+ }
+ if (!DeltaValid()) {
+ assert(current_at_base_ && base_iterator_->Valid());
+ return;
+ }
+ // we don't support those yet
+ assert(delta_iterator_->Entry().type != kMergeRecord &&
+ delta_iterator_->Entry().type != kLogDataRecord);
+ int compare = comparator_->CompareWithoutTimestamp(
+ delta_iterator_->Entry().key, /*a_has_ts=*/false, base_iterator_->key(),
+ /*b_has_ts=*/false);
+ if (forward_) {
+ // current_at_base -> compare < 0
+ assert(!current_at_base_ || compare < 0);
+ // !current_at_base -> compare <= 0
+ assert(current_at_base_ && compare >= 0);
+ } else {
+ // current_at_base -> compare > 0
+ assert(!current_at_base_ || compare > 0);
+ // !current_at_base -> compare <= 0
+ assert(current_at_base_ && compare <= 0);
+ }
+ // equal_keys_ <=> compare == 0
+ assert((equal_keys_ || compare != 0) && (!equal_keys_ || compare == 0));
+#endif
+}
+
+void BaseDeltaIterator::Advance() {
+ if (equal_keys_) {
+ assert(BaseValid() && DeltaValid());
+ AdvanceBase();
+ AdvanceDelta();
+ } else {
+ if (current_at_base_) {
+ assert(BaseValid());
+ AdvanceBase();
+ } else {
+ assert(DeltaValid());
+ AdvanceDelta();
+ }
+ }
+ UpdateCurrent();
+}
+
+void BaseDeltaIterator::AdvanceDelta() {
+ if (forward_) {
+ delta_iterator_->NextKey();
+ } else {
+ delta_iterator_->PrevKey();
+ }
+}
+void BaseDeltaIterator::AdvanceBase() {
+ if (forward_) {
+ base_iterator_->Next();
+ } else {
+ base_iterator_->Prev();
+ }
+}
+
+bool BaseDeltaIterator::BaseValid() const { return base_iterator_->Valid(); }
+bool BaseDeltaIterator::DeltaValid() const { return delta_iterator_->Valid(); }
+void BaseDeltaIterator::UpdateCurrent() {
+// Suppress false positive clang analyzer warnings.
+#ifndef __clang_analyzer__
+ status_ = Status::OK();
+ while (true) {
+ auto delta_result = WBWIIteratorImpl::kNotFound;
+ WriteEntry delta_entry;
+ if (DeltaValid()) {
+ assert(delta_iterator_->status().ok());
+ delta_result =
+ delta_iterator_->FindLatestUpdate(wbwii_->GetMergeContext());
+ delta_entry = delta_iterator_->Entry();
+ } else if (!delta_iterator_->status().ok()) {
+ // Expose the error status and stop.
+ current_at_base_ = false;
+ return;
+ }
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ if (!base_iterator_->status().ok()) {
+ // Expose the error status and stop.
+ current_at_base_ = true;
+ return;
+ }
+
+ // Base has finished.
+ if (!DeltaValid()) {
+ // Finished
+ return;
+ }
+ if (iterate_upper_bound_) {
+ if (comparator_->CompareWithoutTimestamp(
+ delta_entry.key, /*a_has_ts=*/false, *iterate_upper_bound_,
+ /*b_has_ts=*/false) >= 0) {
+ // out of upper bound -> finished.
+ return;
+ }
+ }
+ if (delta_result == WBWIIteratorImpl::kDeleted &&
+ wbwii_->GetNumOperands() == 0) {
+ AdvanceDelta();
+ } else {
+ current_at_base_ = false;
+ return;
+ }
+ } else if (!DeltaValid()) {
+ // Delta has finished.
+ current_at_base_ = true;
+ return;
+ } else {
+ int compare =
+ (forward_ ? 1 : -1) * comparator_->CompareWithoutTimestamp(
+ delta_entry.key, /*a_has_ts=*/false,
+ base_iterator_->key(), /*b_has_ts=*/false);
+ if (compare <= 0) { // delta bigger or equal
+ if (compare == 0) {
+ equal_keys_ = true;
+ }
+ if (delta_result != WBWIIteratorImpl::kDeleted ||
+ wbwii_->GetNumOperands() > 0) {
+ current_at_base_ = false;
+ return;
+ }
+ // Delta is less advanced and is delete.
+ AdvanceDelta();
+ if (equal_keys_) {
+ AdvanceBase();
+ }
+ } else {
+ current_at_base_ = true;
+ return;
+ }
+ }
+ }
+
+ AssertInvariants();
+#endif // __clang_analyzer__
+}
+
+void WBWIIteratorImpl::AdvanceKey(bool forward) {
+ if (Valid()) {
+ Slice key = Entry().key;
+ do {
+ if (forward) {
+ Next();
+ } else {
+ Prev();
+ }
+ } while (MatchesKey(column_family_id_, key));
+ }
+}
+
+void WBWIIteratorImpl::NextKey() { AdvanceKey(true); }
+
+void WBWIIteratorImpl::PrevKey() {
+ AdvanceKey(false); // Move to the tail of the previous key
+ if (Valid()) {
+ AdvanceKey(false); // Move back another key. Now we are at the start of
+ // the previous key
+ if (Valid()) { // Still a valid
+ Next(); // Move forward one onto this key
+ } else {
+ SeekToFirst(); // Not valid, move to the start
+ }
+ }
+}
+
+WBWIIteratorImpl::Result WBWIIteratorImpl::FindLatestUpdate(
+ MergeContext* merge_context) {
+ if (Valid()) {
+ Slice key = Entry().key;
+ return FindLatestUpdate(key, merge_context);
+ } else {
+ merge_context->Clear(); // Clear any entries in the MergeContext
+ return WBWIIteratorImpl::kNotFound;
+ }
+}
+
+WBWIIteratorImpl::Result WBWIIteratorImpl::FindLatestUpdate(
+ const Slice& key, MergeContext* merge_context) {
+ Result result = WBWIIteratorImpl::kNotFound;
+ merge_context->Clear(); // Clear any entries in the MergeContext
+ // TODO(agiardullo): consider adding support for reverse iteration
+ if (!Valid()) {
+ return result;
+ } else if (comparator_->CompareKey(column_family_id_, Entry().key, key) !=
+ 0) {
+ return result;
+ } else {
+ // We want to iterate in the reverse order that the writes were added to the
+ // batch. Since we don't have a reverse iterator, we must seek past the
+ // end. We do this by seeking to the next key, and then back one step
+ NextKey();
+ if (Valid()) {
+ Prev();
+ } else {
+ SeekToLast();
+ }
+
+ // We are at the end of the iterator for this key. Search backwards for the
+ // last Put or Delete, accumulating merges along the way.
+ while (Valid()) {
+ const WriteEntry entry = Entry();
+ if (comparator_->CompareKey(column_family_id_, entry.key, key) != 0) {
+ break; // Unexpected error or we've reached a different next key
+ }
+
+ switch (entry.type) {
+ case kPutRecord:
+ return WBWIIteratorImpl::kFound;
+ case kDeleteRecord:
+ return WBWIIteratorImpl::kDeleted;
+ case kSingleDeleteRecord:
+ return WBWIIteratorImpl::kDeleted;
+ case kMergeRecord:
+ result = WBWIIteratorImpl::kMergeInProgress;
+ merge_context->PushOperand(entry.value);
+ break;
+ case kLogDataRecord:
+ break; // ignore
+ case kXIDRecord:
+ break; // ignore
+ default:
+ return WBWIIteratorImpl::kError;
+ } // end switch statement
+ Prev();
+ } // End while Valid()
+ // At this point, we have been through the whole list and found no Puts or
+ // Deletes. The iterator points to the previous key. Move the iterator back
+ // onto this one.
+ if (Valid()) {
+ Next();
+ } else {
+ SeekToFirst();
+ }
+ }
+ return result;
+}
+
+Status ReadableWriteBatch::GetEntryFromDataOffset(size_t data_offset,
+ WriteType* type, Slice* Key,
+ Slice* value, Slice* blob,
+ Slice* xid) const {
+ if (type == nullptr || Key == nullptr || value == nullptr ||
+ blob == nullptr || xid == nullptr) {
+ return Status::InvalidArgument("Output parameters cannot be null");
+ }
+
+ if (data_offset == GetDataSize()) {
+ // reached end of batch.
+ return Status::NotFound();
+ }
+
+ if (data_offset > GetDataSize()) {
+ return Status::InvalidArgument("data offset exceed write batch size");
+ }
+ Slice input = Slice(rep_.data() + data_offset, rep_.size() - data_offset);
+ char tag;
+ uint32_t column_family;
+ Status s = ReadRecordFromWriteBatch(&input, &tag, &column_family, Key, value,
+ blob, xid);
+ if (!s.ok()) {
+ return s;
+ }
+
+ switch (tag) {
+ case kTypeColumnFamilyValue:
+ case kTypeValue:
+ *type = kPutRecord;
+ break;
+ case kTypeColumnFamilyDeletion:
+ case kTypeDeletion:
+ *type = kDeleteRecord;
+ break;
+ case kTypeColumnFamilySingleDeletion:
+ case kTypeSingleDeletion:
+ *type = kSingleDeleteRecord;
+ break;
+ case kTypeColumnFamilyRangeDeletion:
+ case kTypeRangeDeletion:
+ *type = kDeleteRangeRecord;
+ break;
+ case kTypeColumnFamilyMerge:
+ case kTypeMerge:
+ *type = kMergeRecord;
+ break;
+ case kTypeLogData:
+ *type = kLogDataRecord;
+ break;
+ case kTypeNoop:
+ case kTypeBeginPrepareXID:
+ case kTypeBeginPersistedPrepareXID:
+ case kTypeBeginUnprepareXID:
+ case kTypeEndPrepareXID:
+ case kTypeCommitXID:
+ case kTypeRollbackXID:
+ *type = kXIDRecord;
+ break;
+ default:
+ return Status::Corruption("unknown WriteBatch tag ",
+ std::to_string(static_cast<unsigned int>(tag)));
+ }
+ return Status::OK();
+}
+
+// If both of `entry1` and `entry2` point to real entry in write batch, we
+// compare the entries as following:
+// 1. first compare the column family, the one with larger CF will be larger;
+// 2. Inside the same CF, we first decode the entry to find the key of the entry
+// and the entry with larger key will be larger;
+// 3. If two entries are of the same CF and key, the one with larger offset
+// will be larger.
+// Some times either `entry1` or `entry2` is dummy entry, which is actually
+// a search key. In this case, in step 2, we don't go ahead and decode the
+// entry but use the value in WriteBatchIndexEntry::search_key.
+// One special case is WriteBatchIndexEntry::key_size is kFlagMinInCf.
+// This indicate that we are going to seek to the first of the column family.
+// Once we see this, this entry will be smaller than all the real entries of
+// the column family.
+int WriteBatchEntryComparator::operator()(
+ const WriteBatchIndexEntry* entry1,
+ const WriteBatchIndexEntry* entry2) const {
+ if (entry1->column_family > entry2->column_family) {
+ return 1;
+ } else if (entry1->column_family < entry2->column_family) {
+ return -1;
+ }
+
+ // Deal with special case of seeking to the beginning of a column family
+ if (entry1->is_min_in_cf()) {
+ return -1;
+ } else if (entry2->is_min_in_cf()) {
+ return 1;
+ }
+
+ Slice key1, key2;
+ if (entry1->search_key == nullptr) {
+ key1 = Slice(write_batch_->Data().data() + entry1->key_offset,
+ entry1->key_size);
+ } else {
+ key1 = *(entry1->search_key);
+ }
+ if (entry2->search_key == nullptr) {
+ key2 = Slice(write_batch_->Data().data() + entry2->key_offset,
+ entry2->key_size);
+ } else {
+ key2 = *(entry2->search_key);
+ }
+
+ int cmp = CompareKey(entry1->column_family, key1, key2);
+ if (cmp != 0) {
+ return cmp;
+ } else if (entry1->offset > entry2->offset) {
+ return 1;
+ } else if (entry1->offset < entry2->offset) {
+ return -1;
+ }
+ return 0;
+}
+
+int WriteBatchEntryComparator::CompareKey(uint32_t column_family,
+ const Slice& key1,
+ const Slice& key2) const {
+ if (column_family < cf_comparators_.size() &&
+ cf_comparators_[column_family] != nullptr) {
+ return cf_comparators_[column_family]->CompareWithoutTimestamp(
+ key1, /*a_has_ts=*/false, key2, /*b_has_ts=*/false);
+ } else {
+ return default_comparator_->CompareWithoutTimestamp(
+ key1, /*a_has_ts=*/false, key2, /*b_has_ts=*/false);
+ }
+}
+
+const Comparator* WriteBatchEntryComparator::GetComparator(
+ const ColumnFamilyHandle* column_family) const {
+ return column_family ? column_family->GetComparator() : default_comparator_;
+}
+
+const Comparator* WriteBatchEntryComparator::GetComparator(
+ uint32_t column_family) const {
+ if (column_family < cf_comparators_.size() &&
+ cf_comparators_[column_family]) {
+ return cf_comparators_[column_family];
+ }
+ return default_comparator_;
+}
+
+WriteEntry WBWIIteratorImpl::Entry() const {
+ WriteEntry ret;
+ Slice blob, xid;
+ const WriteBatchIndexEntry* iter_entry = skip_list_iter_.key();
+ // this is guaranteed with Valid()
+ assert(iter_entry != nullptr &&
+ iter_entry->column_family == column_family_id_);
+ auto s = write_batch_->GetEntryFromDataOffset(
+ iter_entry->offset, &ret.type, &ret.key, &ret.value, &blob, &xid);
+ assert(s.ok());
+ assert(ret.type == kPutRecord || ret.type == kDeleteRecord ||
+ ret.type == kSingleDeleteRecord || ret.type == kDeleteRangeRecord ||
+ ret.type == kMergeRecord);
+ // Make sure entry.key does not include user-defined timestamp.
+ const Comparator* const ucmp = comparator_->GetComparator(column_family_id_);
+ size_t ts_sz = ucmp->timestamp_size();
+ if (ts_sz > 0) {
+ ret.key = StripTimestampFromUserKey(ret.key, ts_sz);
+ }
+ return ret;
+}
+
+bool WBWIIteratorImpl::MatchesKey(uint32_t cf_id, const Slice& key) {
+ if (Valid()) {
+ return comparator_->CompareKey(cf_id, key, Entry().key) == 0;
+ } else {
+ return false;
+ }
+}
+
+WriteBatchWithIndexInternal::WriteBatchWithIndexInternal(
+ ColumnFamilyHandle* column_family)
+ : db_(nullptr), db_options_(nullptr), column_family_(column_family) {}
+
+WriteBatchWithIndexInternal::WriteBatchWithIndexInternal(
+ DB* db, ColumnFamilyHandle* column_family)
+ : db_(db), db_options_(nullptr), column_family_(column_family) {
+ if (db_ != nullptr && column_family_ == nullptr) {
+ column_family_ = db_->DefaultColumnFamily();
+ }
+}
+
+WriteBatchWithIndexInternal::WriteBatchWithIndexInternal(
+ const DBOptions* db_options, ColumnFamilyHandle* column_family)
+ : db_(nullptr), db_options_(db_options), column_family_(column_family) {}
+
+Status WriteBatchWithIndexInternal::MergeKey(const Slice& key,
+ const Slice* value,
+ const MergeContext& context,
+ std::string* result) const {
+ if (column_family_ != nullptr) {
+ auto cfh = static_cast_with_check<ColumnFamilyHandleImpl>(column_family_);
+ const auto merge_operator = cfh->cfd()->ioptions()->merge_operator.get();
+ if (merge_operator == nullptr) {
+ return Status::InvalidArgument(
+ "Merge_operator must be set for column_family");
+ } else if (db_ != nullptr) {
+ const ImmutableDBOptions& immutable_db_options =
+ static_cast_with_check<DBImpl>(db_->GetRootDB())
+ ->immutable_db_options();
+ Statistics* statistics = immutable_db_options.statistics.get();
+ Logger* logger = immutable_db_options.info_log.get();
+ SystemClock* clock = immutable_db_options.clock;
+ return MergeHelper::TimedFullMerge(
+ merge_operator, key, value, context.GetOperands(), result, logger,
+ statistics, clock, /* result_operand */ nullptr,
+ /* update_num_ops_stats */ false);
+ } else if (db_options_ != nullptr) {
+ Statistics* statistics = db_options_->statistics.get();
+ Env* env = db_options_->env;
+ Logger* logger = db_options_->info_log.get();
+ SystemClock* clock = env->GetSystemClock().get();
+ return MergeHelper::TimedFullMerge(
+ merge_operator, key, value, context.GetOperands(), result, logger,
+ statistics, clock, /* result_operand */ nullptr,
+ /* update_num_ops_stats */ false);
+ } else {
+ const auto cf_opts = cfh->cfd()->ioptions();
+ return MergeHelper::TimedFullMerge(
+ merge_operator, key, value, context.GetOperands(), result,
+ cf_opts->logger, cf_opts->stats, cf_opts->clock,
+ /* result_operand */ nullptr, /* update_num_ops_stats */ false);
+ }
+ } else {
+ return Status::InvalidArgument("Must provide a column_family");
+ }
+}
+
+WBWIIteratorImpl::Result WriteBatchWithIndexInternal::GetFromBatch(
+ WriteBatchWithIndex* batch, const Slice& key, MergeContext* context,
+ std::string* value, Status* s) {
+ *s = Status::OK();
+
+ std::unique_ptr<WBWIIteratorImpl> iter(
+ static_cast_with_check<WBWIIteratorImpl>(
+ batch->NewIterator(column_family_)));
+
+ // Search the iterator for this key, and updates/merges to it.
+ iter->Seek(key);
+ auto result = iter->FindLatestUpdate(key, context);
+ if (result == WBWIIteratorImpl::kError) {
+ (*s) = Status::Corruption("Unexpected entry in WriteBatchWithIndex:",
+ std::to_string(iter->Entry().type));
+ return result;
+ } else if (result == WBWIIteratorImpl::kNotFound) {
+ return result;
+ } else if (result == WBWIIteratorImpl::Result::kFound) { // PUT
+ Slice entry_value = iter->Entry().value;
+ if (context->GetNumOperands() > 0) {
+ *s = MergeKey(key, &entry_value, *context, value);
+ if (!s->ok()) {
+ result = WBWIIteratorImpl::Result::kError;
+ }
+ } else {
+ value->assign(entry_value.data(), entry_value.size());
+ }
+ } else if (result == WBWIIteratorImpl::kDeleted) {
+ if (context->GetNumOperands() > 0) {
+ *s = MergeKey(key, nullptr, *context, value);
+ if (s->ok()) {
+ result = WBWIIteratorImpl::Result::kFound;
+ } else {
+ result = WBWIIteratorImpl::Result::kError;
+ }
+ }
+ }
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h
new file mode 100644
index 000000000..edabc95bc
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h
@@ -0,0 +1,344 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "db/merge_context.h"
+#include "memtable/skiplist.h"
+#include "options/db_options.h"
+#include "port/port.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MergeContext;
+class WBWIIteratorImpl;
+class WriteBatchWithIndexInternal;
+struct Options;
+
+// when direction == forward
+// * current_at_base_ <=> base_iterator > delta_iterator
+// when direction == backwards
+// * current_at_base_ <=> base_iterator < delta_iterator
+// always:
+// * equal_keys_ <=> base_iterator == delta_iterator
+class BaseDeltaIterator : public Iterator {
+ public:
+ BaseDeltaIterator(ColumnFamilyHandle* column_family, Iterator* base_iterator,
+ WBWIIteratorImpl* delta_iterator,
+ const Comparator* comparator,
+ const ReadOptions* read_options = nullptr);
+
+ ~BaseDeltaIterator() override {}
+
+ bool Valid() const override;
+ void SeekToFirst() override;
+ void SeekToLast() override;
+ void Seek(const Slice& k) override;
+ void SeekForPrev(const Slice& k) override;
+ void Next() override;
+ void Prev() override;
+ Slice key() const override;
+ Slice value() const override;
+ Status status() const override;
+ void Invalidate(Status s);
+
+ private:
+ void AssertInvariants();
+ void Advance();
+ void AdvanceDelta();
+ void AdvanceBase();
+ bool BaseValid() const;
+ bool DeltaValid() const;
+ void UpdateCurrent();
+
+ std::unique_ptr<WriteBatchWithIndexInternal> wbwii_;
+ bool forward_;
+ bool current_at_base_;
+ bool equal_keys_;
+ mutable Status status_;
+ std::unique_ptr<Iterator> base_iterator_;
+ std::unique_ptr<WBWIIteratorImpl> delta_iterator_;
+ const Comparator* comparator_; // not owned
+ const Slice* iterate_upper_bound_;
+ mutable PinnableSlice merge_result_;
+};
+
+// Key used by skip list, as the binary searchable index of WriteBatchWithIndex.
+struct WriteBatchIndexEntry {
+ WriteBatchIndexEntry(size_t o, uint32_t c, size_t ko, size_t ksz)
+ : offset(o),
+ column_family(c),
+ key_offset(ko),
+ key_size(ksz),
+ search_key(nullptr) {}
+ // Create a dummy entry as the search key. This index entry won't be backed
+ // by an entry from the write batch, but a pointer to the search key. Or a
+ // special flag of offset can indicate we are seek to first.
+ // @_search_key: the search key
+ // @_column_family: column family
+ // @is_forward_direction: true for Seek(). False for SeekForPrev()
+ // @is_seek_to_first: true if we seek to the beginning of the column family
+ // _search_key should be null in this case.
+ WriteBatchIndexEntry(const Slice* _search_key, uint32_t _column_family,
+ bool is_forward_direction, bool is_seek_to_first)
+ // For SeekForPrev(), we need to make the dummy entry larger than any
+ // entry who has the same search key. Otherwise, we'll miss those entries.
+ : offset(is_forward_direction ? 0 : std::numeric_limits<size_t>::max()),
+ column_family(_column_family),
+ key_offset(0),
+ key_size(is_seek_to_first ? kFlagMinInCf : 0),
+ search_key(_search_key) {
+ assert(_search_key != nullptr || is_seek_to_first);
+ }
+
+ // If this flag appears in the key_size, it indicates a
+ // key that is smaller than any other entry for the same column family.
+ static const size_t kFlagMinInCf = std::numeric_limits<size_t>::max();
+
+ bool is_min_in_cf() const {
+ assert(key_size != kFlagMinInCf ||
+ (key_offset == 0 && search_key == nullptr));
+ return key_size == kFlagMinInCf;
+ }
+
+ // offset of an entry in write batch's string buffer. If this is a dummy
+ // lookup key, in which case search_key != nullptr, offset is set to either
+ // 0 or max, only for comparison purpose. Because when entries have the same
+ // key, the entry with larger offset is larger, offset = 0 will make a seek
+ // key small or equal than all the entries with the seek key, so that Seek()
+ // will find all the entries of the same key. Similarly, offset = MAX will
+ // make the entry just larger than all entries with the search key so
+ // SeekForPrev() will see all the keys with the same key.
+ size_t offset;
+ uint32_t column_family; // column family of the entry.
+ size_t key_offset; // offset of the key in write batch's string buffer.
+ size_t key_size; // size of the key. kFlagMinInCf indicates
+ // that this is a dummy look up entry for
+ // SeekToFirst() to the beginning of the column
+ // family. We use the flag here to save a boolean
+ // in the struct.
+
+ const Slice* search_key; // if not null, instead of reading keys from
+ // write batch, use it to compare. This is used
+ // for lookup key.
+};
+
+class ReadableWriteBatch : public WriteBatch {
+ public:
+ explicit ReadableWriteBatch(size_t reserved_bytes = 0, size_t max_bytes = 0,
+ size_t protection_bytes_per_key = 0,
+ size_t default_cf_ts_sz = 0)
+ : WriteBatch(reserved_bytes, max_bytes, protection_bytes_per_key,
+ default_cf_ts_sz) {}
+ // Retrieve some information from a write entry in the write batch, given
+ // the start offset of the write entry.
+ Status GetEntryFromDataOffset(size_t data_offset, WriteType* type, Slice* Key,
+ Slice* value, Slice* blob, Slice* xid) const;
+};
+
+class WriteBatchEntryComparator {
+ public:
+ WriteBatchEntryComparator(const Comparator* _default_comparator,
+ const ReadableWriteBatch* write_batch)
+ : default_comparator_(_default_comparator), write_batch_(write_batch) {}
+ // Compare a and b. Return a negative value if a is less than b, 0 if they
+ // are equal, and a positive value if a is greater than b
+ int operator()(const WriteBatchIndexEntry* entry1,
+ const WriteBatchIndexEntry* entry2) const;
+
+ int CompareKey(uint32_t column_family, const Slice& key1,
+ const Slice& key2) const;
+
+ void SetComparatorForCF(uint32_t column_family_id,
+ const Comparator* comparator) {
+ if (column_family_id >= cf_comparators_.size()) {
+ cf_comparators_.resize(column_family_id + 1, nullptr);
+ }
+ cf_comparators_[column_family_id] = comparator;
+ }
+
+ const Comparator* default_comparator() { return default_comparator_; }
+
+ const Comparator* GetComparator(
+ const ColumnFamilyHandle* column_family) const;
+
+ const Comparator* GetComparator(uint32_t column_family) const;
+
+ private:
+ const Comparator* const default_comparator_;
+ std::vector<const Comparator*> cf_comparators_;
+ const ReadableWriteBatch* const write_batch_;
+};
+
+using WriteBatchEntrySkipList =
+ SkipList<WriteBatchIndexEntry*, const WriteBatchEntryComparator&>;
+
+class WBWIIteratorImpl : public WBWIIterator {
+ public:
+ enum Result : uint8_t {
+ kFound,
+ kDeleted,
+ kNotFound,
+ kMergeInProgress,
+ kError
+ };
+ WBWIIteratorImpl(uint32_t column_family_id,
+ WriteBatchEntrySkipList* skip_list,
+ const ReadableWriteBatch* write_batch,
+ WriteBatchEntryComparator* comparator)
+ : column_family_id_(column_family_id),
+ skip_list_iter_(skip_list),
+ write_batch_(write_batch),
+ comparator_(comparator) {}
+
+ ~WBWIIteratorImpl() override {}
+
+ bool Valid() const override {
+ if (!skip_list_iter_.Valid()) {
+ return false;
+ }
+ const WriteBatchIndexEntry* iter_entry = skip_list_iter_.key();
+ return (iter_entry != nullptr &&
+ iter_entry->column_family == column_family_id_);
+ }
+
+ void SeekToFirst() override {
+ WriteBatchIndexEntry search_entry(
+ nullptr /* search_key */, column_family_id_,
+ true /* is_forward_direction */, true /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ }
+
+ void SeekToLast() override {
+ WriteBatchIndexEntry search_entry(
+ nullptr /* search_key */, column_family_id_ + 1,
+ true /* is_forward_direction */, true /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ if (!skip_list_iter_.Valid()) {
+ skip_list_iter_.SeekToLast();
+ } else {
+ skip_list_iter_.Prev();
+ }
+ }
+
+ void Seek(const Slice& key) override {
+ WriteBatchIndexEntry search_entry(&key, column_family_id_,
+ true /* is_forward_direction */,
+ false /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ }
+
+ void SeekForPrev(const Slice& key) override {
+ WriteBatchIndexEntry search_entry(&key, column_family_id_,
+ false /* is_forward_direction */,
+ false /* is_seek_to_first */);
+ skip_list_iter_.SeekForPrev(&search_entry);
+ }
+
+ void Next() override { skip_list_iter_.Next(); }
+
+ void Prev() override { skip_list_iter_.Prev(); }
+
+ WriteEntry Entry() const override;
+
+ Status status() const override {
+ // this is in-memory data structure, so the only way status can be non-ok is
+ // through memory corruption
+ return Status::OK();
+ }
+
+ const WriteBatchIndexEntry* GetRawEntry() const {
+ return skip_list_iter_.key();
+ }
+
+ bool MatchesKey(uint32_t cf_id, const Slice& key);
+
+ // Moves the iterator to first entry of the previous key.
+ void PrevKey();
+ // Moves the iterator to first entry of the next key.
+ void NextKey();
+
+ // Moves the iterator to the Update (Put or Delete) for the current key
+ // If there are no Put/Delete, the Iterator will point to the first entry for
+ // this key
+ // @return kFound if a Put was found for the key
+ // @return kDeleted if a delete was found for the key
+ // @return kMergeInProgress if only merges were fouund for the key
+ // @return kError if an unsupported operation was found for the key
+ // @return kNotFound if no operations were found for this key
+ //
+ Result FindLatestUpdate(const Slice& key, MergeContext* merge_context);
+ Result FindLatestUpdate(MergeContext* merge_context);
+
+ protected:
+ void AdvanceKey(bool forward);
+
+ private:
+ uint32_t column_family_id_;
+ WriteBatchEntrySkipList::Iterator skip_list_iter_;
+ const ReadableWriteBatch* write_batch_;
+ WriteBatchEntryComparator* comparator_;
+};
+
+class WriteBatchWithIndexInternal {
+ public:
+ static const Comparator* GetUserComparator(const WriteBatchWithIndex& wbwi,
+ uint32_t cf_id);
+
+ // For GetFromBatchAndDB or similar
+ explicit WriteBatchWithIndexInternal(DB* db,
+ ColumnFamilyHandle* column_family);
+ // For GetFromBatchAndDB or similar
+ explicit WriteBatchWithIndexInternal(ColumnFamilyHandle* column_family);
+ // For GetFromBatch or similar
+ explicit WriteBatchWithIndexInternal(const DBOptions* db_options,
+ ColumnFamilyHandle* column_family);
+
+ // If batch contains a value for key, store it in *value and return kFound.
+ // If batch contains a deletion for key, return Deleted.
+ // If batch contains Merge operations as the most recent entry for a key,
+ // and the merge process does not stop (not reaching a value or delete),
+ // prepend the current merge operands to *operands,
+ // and return kMergeInProgress
+ // If batch does not contain this key, return kNotFound
+ // Else, return kError on error with error Status stored in *s.
+ WBWIIteratorImpl::Result GetFromBatch(WriteBatchWithIndex* batch,
+ const Slice& key, std::string* value,
+ Status* s) {
+ return GetFromBatch(batch, key, &merge_context_, value, s);
+ }
+ WBWIIteratorImpl::Result GetFromBatch(WriteBatchWithIndex* batch,
+ const Slice& key,
+ MergeContext* merge_context,
+ std::string* value, Status* s);
+ Status MergeKey(const Slice& key, const Slice* value,
+ std::string* result) const {
+ return MergeKey(key, value, merge_context_, result);
+ }
+ Status MergeKey(const Slice& key, const Slice* value,
+ const MergeContext& context, std::string* result) const;
+ size_t GetNumOperands() const { return merge_context_.GetNumOperands(); }
+ MergeContext* GetMergeContext() { return &merge_context_; }
+ Slice GetOperand(int index) const { return merge_context_.GetOperand(index); }
+
+ private:
+ DB* db_;
+ const DBOptions* db_options_;
+ ColumnFamilyHandle* column_family_;
+ MergeContext merge_context_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc
new file mode 100644
index 000000000..350dcc881
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc
@@ -0,0 +1,2419 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/write_batch_with_index.h"
+
+#include <map>
+#include <memory>
+
+#include "db/column_family.h"
+#include "port/stack_trace.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+class ColumnFamilyHandleImplDummy : public ColumnFamilyHandleImpl {
+ public:
+ explicit ColumnFamilyHandleImplDummy(int id, const Comparator* comparator)
+ : ColumnFamilyHandleImpl(nullptr, nullptr, nullptr),
+ id_(id),
+ comparator_(comparator) {}
+ uint32_t GetID() const override { return id_; }
+ const Comparator* GetComparator() const override { return comparator_; }
+
+ private:
+ uint32_t id_;
+ const Comparator* comparator_;
+};
+
+struct Entry {
+ std::string key;
+ std::string value;
+ WriteType type;
+};
+
+struct TestHandler : public WriteBatch::Handler {
+ std::map<uint32_t, std::vector<Entry>> seen;
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = value.ToString();
+ e.type = kPutRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = value.ToString();
+ e.type = kMergeRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+ void LogData(const Slice& /*blob*/) override {}
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = "";
+ e.type = kDeleteRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+};
+
+using KVMap = std::map<std::string, std::string>;
+
+class KVIter : public Iterator {
+ public:
+ explicit KVIter(const KVMap* map) : map_(map), iter_(map_->end()) {}
+ bool Valid() const override { return iter_ != map_->end(); }
+ void SeekToFirst() override { iter_ = map_->begin(); }
+ void SeekToLast() override {
+ if (map_->empty()) {
+ iter_ = map_->end();
+ } else {
+ iter_ = map_->find(map_->rbegin()->first);
+ }
+ }
+ void Seek(const Slice& k) override {
+ iter_ = map_->lower_bound(k.ToString());
+ }
+ void SeekForPrev(const Slice& k) override {
+ iter_ = map_->upper_bound(k.ToString());
+ Prev();
+ }
+ void Next() override { ++iter_; }
+ void Prev() override {
+ if (iter_ == map_->begin()) {
+ iter_ = map_->end();
+ return;
+ }
+ --iter_;
+ }
+ Slice key() const override { return iter_->first; }
+ Slice value() const override { return iter_->second; }
+ Status status() const override { return Status::OK(); }
+
+ private:
+ const KVMap* const map_;
+ KVMap::const_iterator iter_;
+};
+
+static std::string PrintContents(WriteBatchWithIndex* batch,
+ ColumnFamilyHandle* column_family,
+ bool hex = false) {
+ std::string result;
+
+ WBWIIterator* iter;
+ if (column_family == nullptr) {
+ iter = batch->NewIterator();
+ } else {
+ iter = batch->NewIterator(column_family);
+ }
+
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ WriteEntry e = iter->Entry();
+
+ if (e.type == kPutRecord) {
+ result.append("PUT(");
+ result.append(e.key.ToString(hex));
+ result.append("):");
+ result.append(e.value.ToString(hex));
+ } else if (e.type == kMergeRecord) {
+ result.append("MERGE(");
+ result.append(e.key.ToString(hex));
+ result.append("):");
+ result.append(e.value.ToString(hex));
+ } else if (e.type == kSingleDeleteRecord) {
+ result.append("SINGLE-DEL(");
+ result.append(e.key.ToString(hex));
+ result.append(")");
+ } else {
+ assert(e.type == kDeleteRecord);
+ result.append("DEL(");
+ result.append(e.key.ToString(hex));
+ result.append(")");
+ }
+
+ result.append(",");
+ iter->Next();
+ }
+
+ delete iter;
+ return result;
+}
+
+static std::string PrintContents(WriteBatchWithIndex* batch, KVMap* base_map,
+ ColumnFamilyHandle* column_family) {
+ std::string result;
+
+ Iterator* iter;
+ if (column_family == nullptr) {
+ iter = batch->NewIteratorWithBase(new KVIter(base_map));
+ } else {
+ iter = batch->NewIteratorWithBase(column_family, new KVIter(base_map));
+ }
+
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ assert(iter->status().ok());
+
+ Slice key = iter->key();
+ Slice value = iter->value();
+
+ result.append(key.ToString());
+ result.append(":");
+ result.append(value.ToString());
+ result.append(",");
+
+ iter->Next();
+ }
+
+ delete iter;
+ return result;
+}
+
+void AssertIter(Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+}
+
+void AssertItersMatch(Iterator* iter1, Iterator* iter2) {
+ ASSERT_EQ(iter1->Valid(), iter2->Valid());
+ if (iter1->Valid()) {
+ ASSERT_EQ(iter1->key().ToString(), iter2->key().ToString());
+ ASSERT_EQ(iter1->value().ToString(), iter2->value().ToString());
+ }
+}
+
+void AssertItersEqual(Iterator* iter1, Iterator* iter2) {
+ iter1->SeekToFirst();
+ iter2->SeekToFirst();
+ while (iter1->Valid()) {
+ ASSERT_EQ(iter1->Valid(), iter2->Valid());
+ ASSERT_EQ(iter1->key().ToString(), iter2->key().ToString());
+ ASSERT_EQ(iter1->value().ToString(), iter2->value().ToString());
+ iter1->Next();
+ iter2->Next();
+ }
+ ASSERT_EQ(iter1->Valid(), iter2->Valid());
+}
+
+void AssertIterEqual(WBWIIteratorImpl* wbwii,
+ const std::vector<std::string>& keys) {
+ wbwii->SeekToFirst();
+ for (auto k : keys) {
+ ASSERT_TRUE(wbwii->Valid());
+ ASSERT_EQ(wbwii->Entry().key, k);
+ wbwii->NextKey();
+ }
+ ASSERT_FALSE(wbwii->Valid());
+ wbwii->SeekToLast();
+ for (auto kit = keys.rbegin(); kit != keys.rend(); ++kit) {
+ ASSERT_TRUE(wbwii->Valid());
+ ASSERT_EQ(wbwii->Entry().key, *kit);
+ wbwii->PrevKey();
+ }
+ ASSERT_FALSE(wbwii->Valid());
+}
+} // namespace
+
+class WBWIBaseTest : public testing::Test {
+ public:
+ explicit WBWIBaseTest(bool overwrite) : db_(nullptr) {
+ options_.merge_operator =
+ MergeOperators::CreateFromStringId("stringappend");
+ options_.create_if_missing = true;
+ dbname_ = test::PerThreadDBPath("write_batch_with_index_test");
+ EXPECT_OK(DestroyDB(dbname_, options_));
+ batch_.reset(new WriteBatchWithIndex(BytewiseComparator(), 20, overwrite));
+ }
+
+ virtual ~WBWIBaseTest() {
+ if (db_ != nullptr) {
+ ReleaseSnapshot();
+ delete db_;
+ EXPECT_OK(DestroyDB(dbname_, options_));
+ }
+ }
+
+ std::string AddToBatch(ColumnFamilyHandle* cf, const std::string& key) {
+ std::string result;
+ for (size_t i = 0; i < key.size(); i++) {
+ if (key[i] == 'd') {
+ batch_->Delete(cf, key);
+ result = "";
+ } else if (key[i] == 'p') {
+ result = key + std::to_string(i);
+ batch_->Put(cf, key, result);
+ } else if (key[i] == 'm') {
+ std::string value = key + std::to_string(i);
+ batch_->Merge(cf, key, value);
+ if (result.empty()) {
+ result = value;
+ } else {
+ result = result + "," + value;
+ }
+ }
+ }
+ return result;
+ }
+
+ virtual Status OpenDB() { return DB::Open(options_, dbname_, &db_); }
+
+ void ReleaseSnapshot() {
+ if (read_opts_.snapshot != nullptr) {
+ EXPECT_NE(db_, nullptr);
+ db_->ReleaseSnapshot(read_opts_.snapshot);
+ read_opts_.snapshot = nullptr;
+ }
+ }
+
+ public:
+ DB* db_;
+ std::string dbname_;
+ Options options_;
+ WriteOptions write_opts_;
+ ReadOptions read_opts_;
+ std::unique_ptr<WriteBatchWithIndex> batch_;
+};
+
+class WBWIKeepTest : public WBWIBaseTest {
+ public:
+ WBWIKeepTest() : WBWIBaseTest(false) {}
+};
+
+class WBWIOverwriteTest : public WBWIBaseTest {
+ public:
+ WBWIOverwriteTest() : WBWIBaseTest(true) {}
+};
+class WriteBatchWithIndexTest : public WBWIBaseTest,
+ public testing::WithParamInterface<bool> {
+ public:
+ WriteBatchWithIndexTest() : WBWIBaseTest(GetParam()) {}
+};
+
+void TestValueAsSecondaryIndexHelper(std::vector<Entry> entries,
+ WriteBatchWithIndex* batch) {
+ // In this test, we insert <key, value> to column family `data`, and
+ // <value, key> to column family `index`. Then iterator them in order
+ // and seek them by key.
+
+ // Sort entries by key
+ std::map<std::string, std::vector<Entry*>> data_map;
+ // Sort entries by value
+ std::map<std::string, std::vector<Entry*>> index_map;
+ for (auto& e : entries) {
+ data_map[e.key].push_back(&e);
+ index_map[e.value].push_back(&e);
+ }
+
+ ColumnFamilyHandleImplDummy data(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy index(8, BytewiseComparator());
+ for (auto& e : entries) {
+ if (e.type == kPutRecord) {
+ ASSERT_OK(batch->Put(&data, e.key, e.value));
+ ASSERT_OK(batch->Put(&index, e.value, e.key));
+ } else if (e.type == kMergeRecord) {
+ ASSERT_OK(batch->Merge(&data, e.key, e.value));
+ ASSERT_OK(batch->Put(&index, e.value, e.key));
+ } else {
+ assert(e.type == kDeleteRecord);
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+ iter->Seek(e.key);
+ ASSERT_OK(iter->status());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(e.key, write_entry.key.ToString());
+ ASSERT_EQ(e.value, write_entry.value.ToString());
+ ASSERT_OK(batch->Delete(&data, e.key));
+ ASSERT_OK(batch->Put(&index, e.value, ""));
+ }
+ }
+
+ // Iterator all keys
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+ for (int seek_to_first : {0, 1}) {
+ if (seek_to_first) {
+ iter->SeekToFirst();
+ } else {
+ iter->Seek("");
+ }
+ for (auto pair : data_map) {
+ for (auto v : pair.second) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair.first, write_entry.key.ToString());
+ ASSERT_EQ(v->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(v->value, write_entry.value.ToString());
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+ iter->SeekToLast();
+ for (auto pair = data_map.rbegin(); pair != data_map.rend(); ++pair) {
+ for (auto v = pair->second.rbegin(); v != pair->second.rend(); v++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ((*v)->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ((*v)->value, write_entry.value.ToString());
+ }
+ iter->Prev();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ // Iterator all indexes
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&index));
+ for (int seek_to_first : {0, 1}) {
+ if (seek_to_first) {
+ iter->SeekToFirst();
+ } else {
+ iter->Seek("");
+ }
+ for (auto pair : index_map) {
+ for (auto v : pair.second) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair.first, write_entry.key.ToString());
+ if (v->type != kDeleteRecord) {
+ ASSERT_EQ(v->key, write_entry.value.ToString());
+ ASSERT_EQ(v->value, write_entry.key.ToString());
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ iter->SeekToLast();
+ for (auto pair = index_map.rbegin(); pair != index_map.rend(); ++pair) {
+ for (auto v = pair->second.rbegin(); v != pair->second.rend(); v++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ if ((*v)->type != kDeleteRecord) {
+ ASSERT_EQ((*v)->key, write_entry.value.ToString());
+ ASSERT_EQ((*v)->value, write_entry.key.ToString());
+ }
+ iter->Prev();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ // Seek to every key
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+
+ // Seek the keys one by one in reverse order
+ for (auto pair = data_map.rbegin(); pair != data_map.rend(); ++pair) {
+ iter->Seek(pair->first);
+ ASSERT_OK(iter->status());
+ for (auto v : pair->second) {
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ(v->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(v->value, write_entry.value.ToString());
+ }
+ iter->Next();
+ ASSERT_OK(iter->status());
+ }
+ }
+ }
+
+ // Seek to every index
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&index));
+
+ // Seek the keys one by one in reverse order
+ for (auto pair = index_map.rbegin(); pair != index_map.rend(); ++pair) {
+ iter->Seek(pair->first);
+ ASSERT_OK(iter->status());
+ for (auto v : pair->second) {
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ(v->value, write_entry.key.ToString());
+ if (v->type != kDeleteRecord) {
+ ASSERT_EQ(v->key, write_entry.value.ToString());
+ }
+ iter->Next();
+ ASSERT_OK(iter->status());
+ }
+ }
+ }
+
+ // Verify WriteBatch can be iterated
+ TestHandler handler;
+ ASSERT_OK(batch->GetWriteBatch()->Iterate(&handler));
+
+ // Verify data column family
+ {
+ ASSERT_EQ(entries.size(), handler.seen[data.GetID()].size());
+ size_t i = 0;
+ for (auto e : handler.seen[data.GetID()]) {
+ auto write_entry = entries[i++];
+ ASSERT_EQ(e.type, write_entry.type);
+ ASSERT_EQ(e.key, write_entry.key);
+ if (e.type != kDeleteRecord) {
+ ASSERT_EQ(e.value, write_entry.value);
+ }
+ }
+ }
+
+ // Verify index column family
+ {
+ ASSERT_EQ(entries.size(), handler.seen[index.GetID()].size());
+ size_t i = 0;
+ for (auto e : handler.seen[index.GetID()]) {
+ auto write_entry = entries[i++];
+ ASSERT_EQ(e.key, write_entry.value);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(e.value, write_entry.key);
+ }
+ }
+ }
+}
+
+TEST_F(WBWIKeepTest, TestValueAsSecondaryIndex) {
+ Entry entries[] = {
+ {"aaa", "0005", kPutRecord}, {"b", "0002", kPutRecord},
+ {"cdd", "0002", kMergeRecord}, {"aab", "00001", kPutRecord},
+ {"cc", "00005", kPutRecord}, {"cdd", "0002", kPutRecord},
+ {"aab", "0003", kPutRecord}, {"cc", "00005", kDeleteRecord},
+ };
+ std::vector<Entry> entries_list(entries, entries + 8);
+
+ batch_.reset(new WriteBatchWithIndex(nullptr, 20, false));
+
+ TestValueAsSecondaryIndexHelper(entries_list, batch_.get());
+
+ // Clear batch and re-run test with new values
+ batch_->Clear();
+
+ Entry new_entries[] = {
+ {"aaa", "0005", kPutRecord}, {"e", "0002", kPutRecord},
+ {"add", "0002", kMergeRecord}, {"aab", "00001", kPutRecord},
+ {"zz", "00005", kPutRecord}, {"add", "0002", kPutRecord},
+ {"aab", "0003", kPutRecord}, {"zz", "00005", kDeleteRecord},
+ };
+
+ entries_list = std::vector<Entry>(new_entries, new_entries + 8);
+
+ TestValueAsSecondaryIndexHelper(entries_list, batch_.get());
+}
+
+TEST_P(WriteBatchWithIndexTest, TestComparatorForCF) {
+ ColumnFamilyHandleImplDummy cf1(6, nullptr);
+ ColumnFamilyHandleImplDummy reverse_cf(66, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(88, BytewiseComparator());
+
+ ASSERT_OK(batch_->Put(&cf1, "ddd", ""));
+ ASSERT_OK(batch_->Put(&cf2, "aaa", ""));
+ ASSERT_OK(batch_->Put(&cf2, "eee", ""));
+ ASSERT_OK(batch_->Put(&cf1, "ccc", ""));
+ ASSERT_OK(batch_->Put(&reverse_cf, "a11", ""));
+ ASSERT_OK(batch_->Put(&cf1, "bbb", ""));
+
+ Slice key_slices[] = {"a", "3", "3"};
+ Slice value_slice = "";
+ ASSERT_OK(batch_->Put(&reverse_cf, SliceParts(key_slices, 3),
+ SliceParts(&value_slice, 1)));
+ ASSERT_OK(batch_->Put(&reverse_cf, "a22", ""));
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&cf1));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("bbb", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ccc", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ddd", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&cf2));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&reverse_cf));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("z");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a22", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("a22");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a22", iter->Entry().key.ToString());
+
+ iter->Seek("a13");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ }
+}
+
+TEST_F(WBWIOverwriteTest, TestOverwriteKey) {
+ ColumnFamilyHandleImplDummy cf1(6, nullptr);
+ ColumnFamilyHandleImplDummy reverse_cf(66, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(88, BytewiseComparator());
+
+ ASSERT_OK(batch_->Merge(&cf1, "ddd", ""));
+ ASSERT_OK(batch_->Put(&cf1, "ddd", ""));
+ ASSERT_OK(batch_->Delete(&cf1, "ddd"));
+ ASSERT_OK(batch_->Put(&cf2, "aaa", ""));
+ ASSERT_OK(batch_->Delete(&cf2, "aaa"));
+ ASSERT_OK(batch_->Put(&cf2, "aaa", "aaa"));
+ ASSERT_OK(batch_->Put(&cf2, "eee", "eee"));
+ ASSERT_OK(batch_->Put(&cf1, "ccc", ""));
+ ASSERT_OK(batch_->Put(&reverse_cf, "a11", ""));
+ ASSERT_OK(batch_->Delete(&cf1, "ccc"));
+ ASSERT_OK(batch_->Put(&reverse_cf, "a33", "a33"));
+ ASSERT_OK(batch_->Put(&reverse_cf, "a11", "a11"));
+ Slice slices[] = {"a", "3", "3"};
+ ASSERT_OK(batch_->Delete(&reverse_cf, SliceParts(slices, 3)));
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&cf1));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ccc", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ddd", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&cf2));
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ ASSERT_EQ("eee", iter->Entry().value.ToString());
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ ASSERT_EQ("aaa", iter->Entry().value.ToString());
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToFirst();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ ASSERT_EQ("aaa", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ ASSERT_EQ("eee", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator(&reverse_cf));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("z");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ ASSERT_EQ("a11", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ ASSERT_EQ("a11", iter->Entry().value.ToString());
+ iter->Prev();
+
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Prev();
+ ASSERT_TRUE(!iter->Valid());
+ }
+}
+
+TEST_P(WriteBatchWithIndexTest, TestWBWIIterator) {
+ ColumnFamilyHandleImplDummy cf1(1, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, BytewiseComparator());
+ ASSERT_OK(batch_->Put(&cf1, "a", "a1"));
+ ASSERT_OK(batch_->Put(&cf1, "c", "c1"));
+ ASSERT_OK(batch_->Put(&cf1, "c", "c2"));
+ ASSERT_OK(batch_->Put(&cf1, "e", "e1"));
+ ASSERT_OK(batch_->Put(&cf1, "e", "e2"));
+ ASSERT_OK(batch_->Put(&cf1, "e", "e3"));
+ std::unique_ptr<WBWIIteratorImpl> iter1(
+ static_cast<WBWIIteratorImpl*>(batch_->NewIterator(&cf1)));
+ std::unique_ptr<WBWIIteratorImpl> iter2(
+ static_cast<WBWIIteratorImpl*>(batch_->NewIterator(&cf2)));
+ AssertIterEqual(iter1.get(), {"a", "c", "e"});
+ AssertIterEqual(iter2.get(), {});
+ ASSERT_OK(batch_->Put(&cf2, "a", "a2"));
+ ASSERT_OK(batch_->Merge(&cf2, "b", "b1"));
+ ASSERT_OK(batch_->Merge(&cf2, "b", "b2"));
+ ASSERT_OK(batch_->Delete(&cf2, "d"));
+ ASSERT_OK(batch_->Merge(&cf2, "d", "d2"));
+ ASSERT_OK(batch_->Merge(&cf2, "d", "d3"));
+ ASSERT_OK(batch_->Delete(&cf2, "f"));
+ AssertIterEqual(iter1.get(), {"a", "c", "e"});
+ AssertIterEqual(iter2.get(), {"a", "b", "d", "f"});
+}
+
+TEST_P(WriteBatchWithIndexTest, TestRandomIteraratorWithBase) {
+ std::vector<std::string> source_strings = {"a", "b", "c", "d", "e",
+ "f", "g", "h", "i", "j"};
+ for (int rand_seed = 301; rand_seed < 366; rand_seed++) {
+ Random rnd(rand_seed);
+
+ ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf3(8, BytewiseComparator());
+ batch_->Clear();
+
+ if (rand_seed % 2 == 0) {
+ ASSERT_OK(batch_->Put(&cf2, "zoo", "bar"));
+ }
+ if (rand_seed % 4 == 1) {
+ ASSERT_OK(batch_->Put(&cf3, "zoo", "bar"));
+ }
+
+ KVMap map;
+ KVMap merged_map;
+ for (auto key : source_strings) {
+ std::string value = key + key;
+ int type = rnd.Uniform(6);
+ switch (type) {
+ case 0:
+ // only base has it
+ map[key] = value;
+ merged_map[key] = value;
+ break;
+ case 1:
+ // only delta has it
+ ASSERT_OK(batch_->Put(&cf1, key, value));
+ map[key] = value;
+ merged_map[key] = value;
+ break;
+ case 2:
+ // both has it. Delta should win
+ ASSERT_OK(batch_->Put(&cf1, key, value));
+ map[key] = "wrong_value";
+ merged_map[key] = value;
+ break;
+ case 3:
+ // both has it. Delta is delete
+ ASSERT_OK(batch_->Delete(&cf1, key));
+ map[key] = "wrong_value";
+ break;
+ case 4:
+ // only delta has it. Delta is delete
+ ASSERT_OK(batch_->Delete(&cf1, key));
+ map[key] = "wrong_value";
+ break;
+ default:
+ // Neither iterator has it.
+ break;
+ }
+ }
+
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&map)));
+ std::unique_ptr<Iterator> result_iter(new KVIter(&merged_map));
+
+ bool is_valid = false;
+ for (int i = 0; i < 128; i++) {
+ // Random walk and make sure iter and result_iter returns the
+ // same key and value
+ int type = rnd.Uniform(6);
+ ASSERT_OK(iter->status());
+ switch (type) {
+ case 0:
+ // Seek to First
+ iter->SeekToFirst();
+ result_iter->SeekToFirst();
+ break;
+ case 1:
+ // Seek to last
+ iter->SeekToLast();
+ result_iter->SeekToLast();
+ break;
+ case 2: {
+ // Seek to random key
+ auto key_idx = rnd.Uniform(static_cast<int>(source_strings.size()));
+ auto key = source_strings[key_idx];
+ iter->Seek(key);
+ result_iter->Seek(key);
+ break;
+ }
+ case 3: {
+ // SeekForPrev to random key
+ auto key_idx = rnd.Uniform(static_cast<int>(source_strings.size()));
+ auto key = source_strings[key_idx];
+ iter->SeekForPrev(key);
+ result_iter->SeekForPrev(key);
+ break;
+ }
+ case 4:
+ // Next
+ if (is_valid) {
+ iter->Next();
+ result_iter->Next();
+ } else {
+ continue;
+ }
+ break;
+ default:
+ assert(type == 5);
+ // Prev
+ if (is_valid) {
+ iter->Prev();
+ result_iter->Prev();
+ } else {
+ continue;
+ }
+ break;
+ }
+ AssertItersMatch(iter.get(), result_iter.get());
+ is_valid = iter->Valid();
+ }
+
+ ASSERT_OK(iter->status());
+ }
+}
+
+TEST_P(WriteBatchWithIndexTest, TestIteraratorWithBase) {
+ ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, BytewiseComparator());
+ {
+ KVMap map;
+ map["a"] = "aa";
+ map["c"] = "cc";
+ map["e"] = "ee";
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "e", "ee");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "e", "ee");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Seek("a");
+ AssertIter(iter.get(), "a", "aa");
+ }
+
+ // Test the case that there is one element in the write batch
+ ASSERT_OK(batch_->Put(&cf2, "zoo", "bar"));
+ ASSERT_OK(batch_->Put(&cf1, "a", "aa"));
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ ASSERT_OK(batch_->Delete(&cf1, "b"));
+ ASSERT_OK(batch_->Put(&cf1, "c", "cc"));
+ ASSERT_OK(batch_->Put(&cf1, "d", "dd"));
+ ASSERT_OK(batch_->Delete(&cf1, "e"));
+
+ {
+ KVMap map;
+ map["b"] = "";
+ map["cc"] = "cccc";
+ map["f"] = "ff";
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Next();
+ AssertIter(iter.get(), "f", "ff");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "f", "ff");
+ iter->Prev();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Prev();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("c");
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Seek("cb");
+ AssertIter(iter.get(), "cc", "cccc");
+
+ iter->Seek("cc");
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Seek("e");
+ AssertIter(iter.get(), "f", "ff");
+
+ iter->Prev();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Next();
+ AssertIter(iter.get(), "f", "ff");
+ }
+
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("aa");
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Seek("ca");
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ }
+}
+
+TEST_P(WriteBatchWithIndexTest, TestIteraratorWithBaseReverseCmp) {
+ ColumnFamilyHandleImplDummy cf1(6, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, ReverseBytewiseComparator());
+
+ // Test the case that there is one element in the write batch
+ ASSERT_OK(batch_->Put(&cf2, "zoo", "bar"));
+ ASSERT_OK(batch_->Put(&cf1, "a", "aa"));
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ ASSERT_OK(batch_->Put(&cf1, "c", "cc"));
+ {
+ KVMap map;
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Seek("a");
+ AssertIter(iter.get(), "a", "aa");
+ }
+
+ // default column family
+ ASSERT_OK(batch_->Put("a", "b"));
+ {
+ KVMap map;
+ map["b"] = "";
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "b");
+ iter->Next();
+ AssertIter(iter.get(), "b", "");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "b", "");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "b");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "b", "");
+
+ iter->Prev();
+ AssertIter(iter.get(), "a", "b");
+
+ iter->Seek("0");
+ AssertIter(iter.get(), "a", "b");
+ }
+}
+
+TEST_P(WriteBatchWithIndexTest, TestGetFromBatch) {
+ Options options;
+ Status s;
+ std::string value;
+
+ s = batch_->GetFromBatch(options_, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Put("a", "a"));
+ ASSERT_OK(batch_->Put("b", "b"));
+ ASSERT_OK(batch_->Put("c", "c"));
+ ASSERT_OK(batch_->Put("a", "z"));
+ ASSERT_OK(batch_->Delete("c"));
+ ASSERT_OK(batch_->Delete("d"));
+ ASSERT_OK(batch_->Delete("e"));
+ ASSERT_OK(batch_->Put("e", "e"));
+
+ s = batch_->GetFromBatch(options_, "b", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = batch_->GetFromBatch(options_, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z", value);
+
+ s = batch_->GetFromBatch(options_, "c", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch_->GetFromBatch(options_, "d", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch_->GetFromBatch(options_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch_->GetFromBatch(options_, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e", value);
+
+ ASSERT_OK(batch_->Merge("z", "z"));
+
+ s = batch_->GetFromBatch(options_, "z", &value);
+ ASSERT_NOK(s); // No merge operator specified.
+
+ s = batch_->GetFromBatch(options_, "b", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+}
+
+TEST_P(WriteBatchWithIndexTest, TestGetFromBatchMerge) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ ColumnFamilyHandle* column_family = db_->DefaultColumnFamily();
+ std::string value;
+
+ s = batch_->GetFromBatch(options_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Put("x", "X"));
+ std::string expected = "X";
+
+ for (int i = 0; i < 5; i++) {
+ ASSERT_OK(batch_->Merge("x", std::to_string(i)));
+ expected = expected + "," + std::to_string(i);
+
+ if (i % 2 == 0) {
+ ASSERT_OK(batch_->Put("y", std::to_string(i / 2)));
+ }
+
+ ASSERT_OK(batch_->Merge("z", "z"));
+
+ s = batch_->GetFromBatch(column_family, options_, "x", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(expected, value);
+
+ s = batch_->GetFromBatch(column_family, options_, "y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(std::to_string(i / 2), value);
+
+ s = batch_->GetFromBatch(column_family, options_, "z", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+ }
+}
+
+TEST_F(WBWIOverwriteTest, TestGetFromBatchMerge2) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ ColumnFamilyHandle* column_family = db_->DefaultColumnFamily();
+ std::string value;
+
+ s = batch_->GetFromBatch(column_family, options_, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Put(column_family, "X", "x"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x", value);
+
+ ASSERT_OK(batch_->Put(column_family, "X", "x2"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x2", value);
+
+ ASSERT_OK(batch_->Merge(column_family, "X", "aaa"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x2,aaa", value);
+
+ ASSERT_OK(batch_->Merge(column_family, "X", "bbb"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x2,aaa,bbb", value);
+
+ ASSERT_OK(batch_->Put(column_family, "X", "x3"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x3", value);
+
+ ASSERT_OK(batch_->Merge(column_family, "X", "ccc"));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("x3,ccc", value);
+
+ ASSERT_OK(batch_->Delete(column_family, "X"));
+ s = batch_->GetFromBatch(column_family, options_, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch_->Merge(column_family, "X", "ddd");
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "X", &value));
+ ASSERT_EQ("ddd", value);
+}
+
+TEST_P(WriteBatchWithIndexTest, TestGetFromBatchAndDB) {
+ ASSERT_OK(OpenDB());
+
+ std::string value;
+
+ ASSERT_OK(db_->Put(write_opts_, "a", "a"));
+ ASSERT_OK(db_->Put(write_opts_, "b", "b"));
+ ASSERT_OK(db_->Put(write_opts_, "c", "c"));
+
+ ASSERT_OK(batch_->Put("a", "batch_->a"));
+ ASSERT_OK(batch_->Delete("b"));
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value));
+ ASSERT_EQ("batch_->a", value);
+
+ Status s = batch_->GetFromBatchAndDB(db_, read_opts_, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "c", &value));
+ ASSERT_EQ("c", value);
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(db_->Delete(write_opts_, "x"));
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(WriteBatchWithIndexTest, TestGetFromBatchAndDBMerge) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ std::string value;
+
+ ASSERT_OK(db_->Put(write_opts_, "a", "a0"));
+ ASSERT_OK(db_->Put(write_opts_, "b", "b0"));
+ ASSERT_OK(db_->Merge(write_opts_, "b", "b1"));
+ ASSERT_OK(db_->Merge(write_opts_, "c", "c0"));
+ ASSERT_OK(db_->Merge(write_opts_, "d", "d0"));
+
+ ASSERT_OK(batch_->Merge("a", "a1"));
+ ASSERT_OK(batch_->Merge("a", "a2"));
+ ASSERT_OK(batch_->Merge("b", "b2"));
+ ASSERT_OK(batch_->Merge("d", "d1"));
+ ASSERT_OK(batch_->Merge("e", "e0"));
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value));
+ ASSERT_EQ("a0,a1,a2", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "b", &value));
+ ASSERT_EQ("b0,b1,b2", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "c", &value));
+ ASSERT_EQ("c0", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "d", &value));
+ ASSERT_EQ("d0,d1", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "e", &value));
+ ASSERT_EQ("e0", value);
+
+ ASSERT_OK(db_->Delete(write_opts_, "x"));
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ const Snapshot* snapshot = db_->GetSnapshot();
+ ReadOptions snapshot_read_options;
+ snapshot_read_options.snapshot = snapshot;
+
+ ASSERT_OK(db_->Delete(write_opts_, "a"));
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value));
+ ASSERT_EQ("a1,a2", value);
+
+ ASSERT_OK(
+ s = batch_->GetFromBatchAndDB(db_, snapshot_read_options, "a", &value));
+ ASSERT_EQ("a0,a1,a2", value);
+
+ ASSERT_OK(batch_->Delete("a"));
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch_->GetFromBatchAndDB(db_, snapshot_read_options, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(s = db_->Merge(write_opts_, "c", "c1"));
+
+ ASSERT_OK(s = batch_->GetFromBatchAndDB(db_, read_opts_, "c", &value));
+ ASSERT_EQ("c0,c1", value);
+
+ ASSERT_OK(
+ s = batch_->GetFromBatchAndDB(db_, snapshot_read_options, "c", &value));
+ ASSERT_EQ("c0", value);
+
+ ASSERT_OK(db_->Put(write_opts_, "e", "e1"));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "e", &value));
+ ASSERT_EQ("e1,e0", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, snapshot_read_options, "e", &value));
+ ASSERT_EQ("e0", value);
+
+ ASSERT_OK(s = db_->Delete(write_opts_, "e"));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "e", &value));
+ ASSERT_EQ("e0", value);
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, snapshot_read_options, "e", &value));
+ ASSERT_EQ("e0", value);
+
+ db_->ReleaseSnapshot(snapshot);
+}
+
+TEST_F(WBWIOverwriteTest, TestGetFromBatchAndDBMerge2) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ std::string value;
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Merge("A", "xxx"));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value));
+ ASSERT_EQ(value, "xxx");
+
+ ASSERT_OK(batch_->Merge("A", "yyy"));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value));
+ ASSERT_EQ(value, "xxx,yyy");
+
+ ASSERT_OK(db_->Put(write_opts_, "A", "a0"));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value));
+ ASSERT_EQ(value, "a0,xxx,yyy");
+
+ ASSERT_OK(batch_->Delete("A"));
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(WriteBatchWithIndexTest, TestGetFromBatchAndDBMerge3) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ FlushOptions flush_options;
+ std::string value;
+
+ ASSERT_OK(db_->Put(write_opts_, "A", "1"));
+ ASSERT_OK(db_->Flush(flush_options, db_->DefaultColumnFamily()));
+ ASSERT_OK(batch_->Merge("A", "2"));
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "A", &value));
+ ASSERT_EQ(value, "1,2");
+}
+
+TEST_P(WriteBatchWithIndexTest, TestPinnedGetFromBatchAndDB) {
+ Status s = OpenDB();
+ ASSERT_OK(s);
+
+ PinnableSlice value;
+
+ ASSERT_OK(db_->Put(write_opts_, "a", "a0"));
+ ASSERT_OK(db_->Put(write_opts_, "b", "b0"));
+ ASSERT_OK(db_->Merge(write_opts_, "b", "b1"));
+ ASSERT_OK(db_->Merge(write_opts_, "c", "c0"));
+ ASSERT_OK(db_->Merge(write_opts_, "d", "d0"));
+ ASSERT_OK(batch_->Merge("a", "a1"));
+ ASSERT_OK(batch_->Merge("a", "a2"));
+ ASSERT_OK(batch_->Merge("b", "b2"));
+ ASSERT_OK(batch_->Merge("d", "d1"));
+ ASSERT_OK(batch_->Merge("e", "e0"));
+
+ for (int i = 0; i < 2; i++) {
+ if (i == 1) {
+ // Do it again with a flushed DB...
+ ASSERT_OK(db_->Flush(FlushOptions(), db_->DefaultColumnFamily()));
+ }
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value));
+ ASSERT_EQ("a0,a1,a2", value.ToString());
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "b", &value));
+ ASSERT_EQ("b0,b1,b2", value.ToString());
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "c", &value));
+ ASSERT_EQ("c0", value.ToString());
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "d", &value));
+ ASSERT_EQ("d0,d1", value.ToString());
+
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "e", &value));
+ ASSERT_EQ("e0", value.ToString());
+ ASSERT_OK(db_->Delete(write_opts_, "x"));
+
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+}
+
+void AssertKey(std::string key, WBWIIterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->Entry().key.ToString());
+}
+
+void AssertValue(std::string value, WBWIIterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(value, iter->Entry().value.ToString());
+}
+
+// Tests that we can write to the WBWI while we iterate (from a single thread).
+// iteration should see the newest writes
+TEST_F(WBWIOverwriteTest, MutateWhileIteratingCorrectnessTest) {
+ for (char c = 'a'; c <= 'z'; ++c) {
+ ASSERT_OK(batch_->Put(std::string(1, c), std::string(1, c)));
+ }
+
+ std::unique_ptr<WBWIIterator> iter(batch_->NewIterator());
+ iter->Seek("k");
+ AssertKey("k", iter.get());
+ iter->Next();
+ AssertKey("l", iter.get());
+ ASSERT_OK(batch_->Put("ab", "cc"));
+ iter->Next();
+ AssertKey("m", iter.get());
+ ASSERT_OK(batch_->Put("mm", "kk"));
+ iter->Next();
+ AssertKey("mm", iter.get());
+ AssertValue("kk", iter.get());
+ ASSERT_OK(batch_->Delete("mm"));
+
+ iter->Next();
+ AssertKey("n", iter.get());
+ iter->Prev();
+ AssertKey("mm", iter.get());
+ ASSERT_EQ(kDeleteRecord, iter->Entry().type);
+
+ iter->Seek("ab");
+ AssertKey("ab", iter.get());
+ ASSERT_OK(batch_->Delete("x"));
+ iter->Seek("x");
+ AssertKey("x", iter.get());
+ ASSERT_EQ(kDeleteRecord, iter->Entry().type);
+ iter->Prev();
+ AssertKey("w", iter.get());
+}
+
+void AssertIterKey(std::string key, Iterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->key().ToString());
+}
+
+void AssertIterValue(std::string value, Iterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(value, iter->value().ToString());
+}
+
+// same thing as above, but testing IteratorWithBase
+TEST_F(WBWIOverwriteTest, MutateWhileIteratingBaseCorrectnessTest) {
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+ for (char c = 'a'; c <= 'z'; ++c) {
+ ASSERT_OK(batch_->Put(std::string(1, c), std::string(1, c)));
+ }
+
+ KVMap map;
+ map["aa"] = "aa";
+ map["cc"] = "cc";
+ map["ee"] = "ee";
+ map["em"] = "me";
+
+ std::unique_ptr<Iterator> iter(batch_->NewIteratorWithBase(new KVIter(&map)));
+ iter->Seek("k");
+ AssertIterKey("k", iter.get());
+ iter->Next();
+ AssertIterKey("l", iter.get());
+ ASSERT_OK(batch_->Put("ab", "cc"));
+ iter->Next();
+ AssertIterKey("m", iter.get());
+ ASSERT_OK(batch_->Put("mm", "kk"));
+ iter->Next();
+ AssertIterKey("mm", iter.get());
+ AssertIterValue("kk", iter.get());
+ ASSERT_OK(batch_->Delete("mm"));
+ iter->Next();
+ AssertIterKey("n", iter.get());
+ iter->Prev();
+ // "mm" is deleted, so we're back at "m"
+ AssertIterKey("m", iter.get());
+
+ iter->Seek("ab");
+ AssertIterKey("ab", iter.get());
+ iter->Prev();
+ AssertIterKey("aa", iter.get());
+ iter->Prev();
+ AssertIterKey("a", iter.get());
+ ASSERT_OK(batch_->Delete("aa"));
+ iter->Next();
+ AssertIterKey("ab", iter.get());
+ iter->Prev();
+ AssertIterKey("a", iter.get());
+
+ ASSERT_OK(batch_->Delete("x"));
+ iter->Seek("x");
+ AssertIterKey("y", iter.get());
+ iter->Next();
+ AssertIterKey("z", iter.get());
+ iter->Prev();
+ iter->Prev();
+ AssertIterKey("w", iter.get());
+
+ ASSERT_OK(batch_->Delete("e"));
+ iter->Seek("e");
+ AssertIterKey("ee", iter.get());
+ AssertIterValue("ee", iter.get());
+ ASSERT_OK(batch_->Put("ee", "xx"));
+ // still the same value
+ AssertIterValue("ee", iter.get());
+ iter->Next();
+ AssertIterKey("em", iter.get());
+ iter->Prev();
+ // new value
+ AssertIterValue("xx", iter.get());
+
+ ASSERT_OK(iter->status());
+}
+
+// stress testing mutations with IteratorWithBase
+TEST_F(WBWIOverwriteTest, MutateWhileIteratingBaseStressTest) {
+ for (char c = 'a'; c <= 'z'; ++c) {
+ ASSERT_OK(batch_->Put(std::string(1, c), std::string(1, c)));
+ }
+
+ KVMap map;
+ for (char c = 'a'; c <= 'z'; ++c) {
+ map[std::string(2, c)] = std::string(2, c);
+ }
+
+ std::unique_ptr<Iterator> iter(batch_->NewIteratorWithBase(new KVIter(&map)));
+
+ Random rnd(301);
+ for (int i = 0; i < 1000000; ++i) {
+ int random = rnd.Uniform(8);
+ char c = static_cast<char>(rnd.Uniform(26) + 'a');
+ switch (random) {
+ case 0:
+ ASSERT_OK(batch_->Put(std::string(1, c), "xxx"));
+ break;
+ case 1:
+ ASSERT_OK(batch_->Put(std::string(2, c), "xxx"));
+ break;
+ case 2:
+ ASSERT_OK(batch_->Delete(std::string(1, c)));
+ break;
+ case 3:
+ ASSERT_OK(batch_->Delete(std::string(2, c)));
+ break;
+ case 4:
+ iter->Seek(std::string(1, c));
+ break;
+ case 5:
+ iter->Seek(std::string(2, c));
+ break;
+ case 6:
+ if (iter->Valid()) {
+ iter->Next();
+ }
+ break;
+ case 7:
+ if (iter->Valid()) {
+ iter->Prev();
+ }
+ break;
+ default:
+ assert(false);
+ }
+ }
+ ASSERT_OK(iter->status());
+}
+
+TEST_P(WriteBatchWithIndexTest, TestNewIteratorWithBaseFromWbwi) {
+ ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator());
+ KVMap map;
+ map["a"] = "aa";
+ map["c"] = "cc";
+ map["e"] = "ee";
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&map)));
+ ASSERT_NE(nullptr, iter);
+ iter->SeekToFirst();
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_OK(iter->status());
+}
+
+TEST_P(WriteBatchWithIndexTest, SavePointTest) {
+ ColumnFamilyHandleImplDummy cf1(1, BytewiseComparator());
+ KVMap empty_map;
+ std::unique_ptr<Iterator> cf0_iter(
+ batch_->NewIteratorWithBase(new KVIter(&empty_map)));
+ std::unique_ptr<Iterator> cf1_iter(
+ batch_->NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+ Status s;
+ KVMap kvm_cf0_0 = {{"A", "aa"}, {"B", "b"}};
+ KVMap kvm_cf1_0 = {{"A", "a1"}, {"C", "c1"}, {"E", "e1"}};
+ KVIter kvi_cf0_0(&kvm_cf0_0);
+ KVIter kvi_cf1_0(&kvm_cf1_0);
+
+ ASSERT_OK(batch_->Put("A", "a"));
+ ASSERT_OK(batch_->Put("B", "b"));
+ ASSERT_OK(batch_->Put("A", "aa"));
+ ASSERT_OK(batch_->Put(&cf1, "A", "a1"));
+ ASSERT_OK(batch_->Delete(&cf1, "B"));
+ ASSERT_OK(batch_->Put(&cf1, "C", "c1"));
+ ASSERT_OK(batch_->Put(&cf1, "E", "e1"));
+
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_0);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_0);
+ batch_->SetSavePoint(); // 1
+
+ KVMap kvm_cf0_1 = {{"B", "bb"}, {"C", "cc"}};
+ KVMap kvm_cf1_1 = {{"B", "b1"}, {"C", "c1"}};
+ KVIter kvi_cf0_1(&kvm_cf0_1);
+ KVIter kvi_cf1_1(&kvm_cf1_1);
+
+ ASSERT_OK(batch_->Put("C", "cc"));
+ ASSERT_OK(batch_->Put("B", "bb"));
+ ASSERT_OK(batch_->Delete("A"));
+ ASSERT_OK(batch_->Put(&cf1, "B", "b1"));
+ ASSERT_OK(batch_->Delete(&cf1, "A"));
+ ASSERT_OK(batch_->SingleDelete(&cf1, "E"));
+ batch_->SetSavePoint(); // 2
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_1);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_1);
+
+ KVMap kvm_cf0_2 = {{"A", "xxx"}, {"C", "cc"}};
+ KVMap kvm_cf1_2 = {{"B", "b2"}};
+ KVIter kvi_cf0_2(&kvm_cf0_2);
+ KVIter kvi_cf1_2(&kvm_cf1_2);
+
+ ASSERT_OK(batch_->Put("A", "aaa"));
+ ASSERT_OK(batch_->Put("A", "xxx"));
+ ASSERT_OK(batch_->Delete("B"));
+ ASSERT_OK(batch_->Put(&cf1, "B", "b2"));
+ ASSERT_OK(batch_->Delete(&cf1, "C"));
+ batch_->SetSavePoint(); // 3
+ batch_->SetSavePoint(); // 4
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_2);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_2);
+
+ KVMap kvm_cf0_4 = {{"A", "xxx"}, {"C", "cc"}};
+ KVMap kvm_cf1_4 = {{"B", "b2"}};
+ KVIter kvi_cf0_4(&kvm_cf0_4);
+ KVIter kvi_cf1_4(&kvm_cf1_4);
+ ASSERT_OK(batch_->SingleDelete("D"));
+ ASSERT_OK(batch_->Delete(&cf1, "D"));
+ ASSERT_OK(batch_->Delete(&cf1, "E"));
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_4);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_4);
+
+ ASSERT_OK(batch_->RollbackToSavePoint()); // rollback to 4
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_2);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_2);
+
+ ASSERT_OK(batch_->RollbackToSavePoint()); // rollback to 3
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_2);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_2);
+
+ ASSERT_OK(batch_->RollbackToSavePoint()); // rollback to 2
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_1);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_1);
+
+ batch_->SetSavePoint(); // 5
+ ASSERT_OK(batch_->Put("X", "x"));
+
+ KVMap kvm_cf0_5 = {{"B", "bb"}, {"C", "cc"}, {"X", "x"}};
+ KVIter kvi_cf0_5(&kvm_cf0_5);
+ KVIter kvi_cf1_5(&kvm_cf1_1);
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_5);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_5);
+
+ ASSERT_OK(batch_->RollbackToSavePoint()); // rollback to 5
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_1);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_1);
+
+ ASSERT_OK(batch_->RollbackToSavePoint()); // rollback to 1
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_0);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_0);
+
+ s = batch_->RollbackToSavePoint(); // no savepoint found
+ ASSERT_TRUE(s.IsNotFound());
+ AssertItersEqual(cf0_iter.get(), &kvi_cf0_0);
+ AssertItersEqual(cf1_iter.get(), &kvi_cf1_0);
+
+ batch_->SetSavePoint(); // 6
+
+ batch_->Clear();
+ ASSERT_EQ("", PrintContents(batch_.get(), nullptr));
+ ASSERT_EQ("", PrintContents(batch_.get(), &cf1));
+
+ s = batch_->RollbackToSavePoint(); // rollback to 6
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(WriteBatchWithIndexTest, SingleDeleteTest) {
+ Status s;
+ std::string value;
+
+ ASSERT_OK(batch_->SingleDelete("A"));
+
+ s = batch_->GetFromBatch(options_, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatch(options_, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch_->Clear();
+ ASSERT_OK(batch_->Put("A", "a"));
+ ASSERT_OK(batch_->Put("A", "a2"));
+ ASSERT_OK(batch_->Put("B", "b"));
+ ASSERT_OK(batch_->SingleDelete("A"));
+
+ s = batch_->GetFromBatch(options_, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatch(options_, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ ASSERT_OK(batch_->Put("C", "c"));
+ ASSERT_OK(batch_->Put("A", "a3"));
+ ASSERT_OK(batch_->Delete("B"));
+ ASSERT_OK(batch_->SingleDelete("B"));
+ ASSERT_OK(batch_->SingleDelete("C"));
+
+ s = batch_->GetFromBatch(options_, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a3", value);
+ s = batch_->GetFromBatch(options_, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatch(options_, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatch(options_, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Put("B", "b4"));
+ ASSERT_OK(batch_->Put("C", "c4"));
+ ASSERT_OK(batch_->Put("D", "d4"));
+ ASSERT_OK(batch_->SingleDelete("D"));
+ ASSERT_OK(batch_->SingleDelete("D"));
+ ASSERT_OK(batch_->Delete("A"));
+
+ s = batch_->GetFromBatch(options_, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatch(options_, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b4", value);
+ s = batch_->GetFromBatch(options_, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c4", value);
+ s = batch_->GetFromBatch(options_, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(WriteBatchWithIndexTest, SingleDeleteDeltaIterTest) {
+ std::string value;
+ ASSERT_OK(batch_->Put("A", "a"));
+ ASSERT_OK(batch_->Put("A", "a2"));
+ ASSERT_OK(batch_->Put("B", "b"));
+ ASSERT_OK(batch_->SingleDelete("A"));
+ ASSERT_OK(batch_->Delete("B"));
+
+ KVMap map;
+ value = PrintContents(batch_.get(), &map, nullptr);
+ ASSERT_EQ("", value);
+
+ map["A"] = "aa";
+ map["C"] = "cc";
+ map["D"] = "dd";
+
+ ASSERT_OK(batch_->SingleDelete("B"));
+ ASSERT_OK(batch_->SingleDelete("C"));
+ ASSERT_OK(batch_->SingleDelete("Z"));
+
+ value = PrintContents(batch_.get(), &map, nullptr);
+ ASSERT_EQ("D:dd,", value);
+
+ ASSERT_OK(batch_->Put("A", "a3"));
+ ASSERT_OK(batch_->Put("B", "b3"));
+ ASSERT_OK(batch_->SingleDelete("A"));
+ ASSERT_OK(batch_->SingleDelete("A"));
+ ASSERT_OK(batch_->SingleDelete("D"));
+ ASSERT_OK(batch_->SingleDelete("D"));
+ ASSERT_OK(batch_->Delete("D"));
+
+ map["E"] = "ee";
+
+ value = PrintContents(batch_.get(), &map, nullptr);
+ ASSERT_EQ("B:b3,E:ee,", value);
+}
+
+TEST_P(WriteBatchWithIndexTest, MultiGetTest) {
+ // MultiGet a lot of keys in order to force std::vector reallocations
+ std::vector<std::string> keys;
+ for (int i = 0; i < 100; ++i) {
+ keys.emplace_back(std::to_string(i));
+ }
+
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ // Write some data to the db for the even numbered keys
+ {
+ WriteBatch wb;
+ for (size_t i = 0; i < keys.size(); i += 2) {
+ std::string val = "val" + std::to_string(i);
+ ASSERT_OK(wb.Put(cf0, keys[i], val));
+ }
+ ASSERT_OK(db_->Write(write_opts_, &wb));
+ for (size_t i = 0; i < keys.size(); i += 2) {
+ std::string value;
+ ASSERT_OK(db_->Get(read_opts_, cf0, keys[i], &value));
+ }
+ }
+
+ // Write some data to the batch
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if ((i % 5) == 0) {
+ ASSERT_OK(batch_->Delete(cf0, keys[i]));
+ } else if ((i % 7) == 0) {
+ std::string val = "new" + std::to_string(i);
+ ASSERT_OK(batch_->Put(cf0, keys[i], val));
+ }
+ if (i > 0 && (i % 3) == 0) {
+ ASSERT_OK(batch_->Merge(cf0, keys[i], "merge"));
+ }
+ }
+
+ std::vector<Slice> key_slices;
+ for (size_t i = 0; i < keys.size(); ++i) {
+ key_slices.emplace_back(keys[i]);
+ }
+ std::vector<PinnableSlice> values(keys.size());
+ std::vector<Status> statuses(keys.size());
+
+ batch_->MultiGetFromBatchAndDB(db_, read_opts_, cf0, key_slices.size(),
+ key_slices.data(), values.data(),
+ statuses.data(), false);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (i == 0) {
+ ASSERT_TRUE(statuses[i].IsNotFound());
+ } else if ((i % 3) == 0) {
+ ASSERT_OK(statuses[i]);
+ if ((i % 5) == 0) { // Merge after Delete
+ ASSERT_EQ(values[i], "merge");
+ } else if ((i % 7) == 0) { // Merge after Put
+ std::string val = "new" + std::to_string(i);
+ ASSERT_EQ(values[i], val + ",merge");
+ } else if ((i % 2) == 0) {
+ std::string val = "val" + std::to_string(i);
+ ASSERT_EQ(values[i], val + ",merge");
+ } else {
+ ASSERT_EQ(values[i], "merge");
+ }
+ } else if ((i % 5) == 0) {
+ ASSERT_TRUE(statuses[i].IsNotFound());
+ } else if ((i % 7) == 0) {
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], "new" + std::to_string(i));
+ } else if ((i % 2) == 0) {
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], "val" + std::to_string(i));
+ } else {
+ ASSERT_TRUE(statuses[i].IsNotFound());
+ }
+ }
+}
+TEST_P(WriteBatchWithIndexTest, MultiGetTest2) {
+ // MultiGet a lot of keys in order to force std::vector reallocations
+ const int num_keys = 700;
+ const int keys_per_pass = 100;
+ std::vector<std::string> keys;
+ for (size_t i = 0; i < num_keys; ++i) {
+ keys.emplace_back(std::to_string(i));
+ }
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ // Keys 0- 99 have a PUT in the batch but not DB
+ // Keys 100-199 have a PUT in the DB
+ // Keys 200-299 Have a PUT/DELETE
+ // Keys 300-399 Have a PUT/DELETE/MERGE
+ // Keys 400-499 have a PUT/MERGE
+ // Keys 500-599 have a MERGE only
+ // Keys 600-699 were never written
+ {
+ WriteBatch wb;
+ for (size_t i = 100; i < 500; i++) {
+ std::string val = std::to_string(i);
+ ASSERT_OK(wb.Put(cf0, keys[i], val));
+ }
+ ASSERT_OK(db_->Write(write_opts_, &wb));
+ }
+ ASSERT_OK(db_->Flush(FlushOptions(), cf0));
+ for (size_t i = 0; i < 100; i++) {
+ ASSERT_OK(batch_->Put(cf0, keys[i], keys[i]));
+ }
+ for (size_t i = 200; i < 400; i++) {
+ ASSERT_OK(batch_->Delete(cf0, keys[i]));
+ }
+ for (size_t i = 300; i < 600; i++) {
+ std::string val = std::to_string(i) + "m";
+ ASSERT_OK(batch_->Merge(cf0, keys[i], val));
+ }
+
+ Random rnd(301);
+ std::vector<PinnableSlice> values(keys_per_pass);
+ std::vector<Status> statuses(keys_per_pass);
+ for (int pass = 0; pass < 40; pass++) {
+ std::vector<Slice> key_slices;
+ for (size_t i = 0; i < keys_per_pass; i++) {
+ int random = rnd.Uniform(num_keys);
+ key_slices.emplace_back(keys[random]);
+ }
+ batch_->MultiGetFromBatchAndDB(db_, read_opts_, cf0, keys_per_pass,
+ key_slices.data(), values.data(),
+ statuses.data(), false);
+ for (size_t i = 0; i < keys_per_pass; i++) {
+ int key = ParseInt(key_slices[i].ToString());
+ switch (key / 100) {
+ case 0: // 0-99 PUT only
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], key_slices[i].ToString());
+ break;
+ case 1: // 100-199 PUT only
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], key_slices[i].ToString());
+ break;
+ case 2: // 200-299 Deleted
+ ASSERT_TRUE(statuses[i].IsNotFound());
+ break;
+ case 3: // 300-399 Delete+Merge
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], key_slices[i].ToString() + "m");
+ break;
+ case 4: // 400-400 Put+ Merge
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], key_slices[i].ToString() + "," +
+ key_slices[i].ToString() + "m");
+ break;
+ case 5: // Merge only
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], key_slices[i].ToString() + "m");
+ break;
+ case 6: // Never written
+ ASSERT_TRUE(statuses[i].IsNotFound());
+ break;
+ default:
+ assert(false);
+ } // end switch
+ } // End for each key
+ } // end for passes
+}
+
+// This test has merges, but the merge does not play into the final result
+TEST_P(WriteBatchWithIndexTest, FakeMergeWithIteratorTest) {
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ // The map we are starting with
+ KVMap input = {
+ {"odm", "odm0"},
+ {"omd", "omd0"},
+ {"omp", "omp0"},
+ };
+ KVMap result = {
+ {"odm", "odm2"}, // Orig, Delete, Merge
+ {"mp", "mp1"}, // Merge, Put
+ {"omp", "omp2"}, // Origi, Merge, Put
+ {"mmp", "mmp2"} // Merge, Merge, Put
+ };
+
+ for (auto& iter : result) {
+ EXPECT_EQ(AddToBatch(cf0, iter.first), iter.second);
+ }
+ AddToBatch(cf0, "md"); // Merge, Delete
+ AddToBatch(cf0, "mmd"); // Merge, Merge, Delete
+ AddToBatch(cf0, "omd"); // Orig, Merge, Delete
+
+ KVIter kvi(&result);
+ // First try just the batch
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(cf0, new KVIter(&input)));
+ AssertItersEqual(iter.get(), &kvi);
+}
+
+TEST_P(WriteBatchWithIndexTest, IteratorMergeTest) {
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ KVMap result = {
+ {"m", "m0"}, // Merge
+ {"mm", "mm0,mm1"}, // Merge, Merge
+ {"dm", "dm1"}, // Delete, Merge
+ {"dmm", "dmm1,dmm2"}, // Delete, Merge, Merge
+ {"mdm", "mdm2"}, // Merge, Delete, Merge
+ {"mpm", "mpm1,mpm2"}, // Merge, Put, Merge
+ {"pm", "pm0,pm1"}, // Put, Merge
+ {"pmm", "pmm0,pmm1,pmm2"}, // Put, Merge, Merge
+ };
+
+ for (auto& iter : result) {
+ EXPECT_EQ(AddToBatch(cf0, iter.first), iter.second);
+ }
+
+ KVIter kvi(&result);
+ // First try just the batch
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(cf0, new KVIter(&empty_map)));
+ AssertItersEqual(iter.get(), &kvi);
+}
+
+TEST_P(WriteBatchWithIndexTest, IteratorMergeTestWithOrig) {
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+ KVMap original;
+ KVMap results = {
+ {"m", "om,m0"}, // Merge
+ {"mm", "omm,mm0,mm1"}, // Merge, Merge
+ {"dm", "dm1"}, // Delete, Merge
+ {"dmm", "dmm1,dmm2"}, // Delete, Merge, Merge
+ {"mdm", "mdm2"}, // Merge, Delete, Merge
+ {"mpm", "mpm1,mpm2"}, // Merge, Put, Merge
+ {"pm", "pm0,pm1"}, // Put, Merge
+ {"pmm", "pmm0,pmm1,pmm2"}, // Put, Merge, Merge
+ };
+
+ for (auto& iter : results) {
+ AddToBatch(cf0, iter.first);
+ original[iter.first] = "o" + iter.first;
+ }
+
+ KVIter kvi(&results);
+ // First try just the batch
+ std::unique_ptr<Iterator> iter(
+ batch_->NewIteratorWithBase(cf0, new KVIter(&original)));
+ AssertItersEqual(iter.get(), &kvi);
+}
+
+TEST_P(WriteBatchWithIndexTest, GetFromBatchAfterMerge) {
+ std::string value;
+ Status s;
+
+ ASSERT_OK(OpenDB());
+ ASSERT_OK(db_->Put(write_opts_, "o", "aa"));
+ batch_->Merge("o", "bb"); // Merging bb under key "o"
+ batch_->Merge("m", "cc"); // Merging bc under key "m"
+ s = batch_->GetFromBatch(options_, "m", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+ s = batch_->GetFromBatch(options_, "o", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+
+ ASSERT_OK(db_->Write(write_opts_, batch_->GetWriteBatch()));
+ ASSERT_OK(db_->Get(read_opts_, "o", &value));
+ ASSERT_EQ(value, "aa,bb");
+ ASSERT_OK(db_->Get(read_opts_, "m", &value));
+ ASSERT_EQ(value, "cc");
+}
+
+TEST_P(WriteBatchWithIndexTest, GetFromBatchAndDBAfterMerge) {
+ std::string value;
+
+ ASSERT_OK(OpenDB());
+ ASSERT_OK(db_->Put(write_opts_, "o", "aa"));
+ ASSERT_OK(batch_->Merge("o", "bb")); // Merging bb under key "o"
+ ASSERT_OK(batch_->Merge("m", "cc")); // Merging bc under key "m"
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "o", &value));
+ ASSERT_EQ(value, "aa,bb");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "m", &value));
+ ASSERT_EQ(value, "cc");
+}
+
+TEST_F(WBWIKeepTest, GetAfterPut) {
+ std::string value;
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ ASSERT_OK(db_->Put(write_opts_, "key", "orig"));
+
+ ASSERT_OK(batch_->Put("key", "aa")); // Writing aa under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "aa");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "aa");
+
+ ASSERT_OK(batch_->Merge("key", "bb")); // Merging bb under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "aa,bb");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "aa,bb");
+
+ ASSERT_OK(batch_->Merge("key", "cc")); // Merging cc under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "aa,bb,cc");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "aa,bb,cc");
+}
+
+TEST_P(WriteBatchWithIndexTest, GetAfterMergePut) {
+ std::string value;
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+ ASSERT_OK(db_->Put(write_opts_, "key", "orig"));
+
+ ASSERT_OK(batch_->Merge("key", "aa")); // Merging aa under key
+ Status s = batch_->GetFromBatch(cf0, options_, "key", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "orig,aa");
+
+ ASSERT_OK(batch_->Merge("key", "bb")); // Merging bb under key
+ s = batch_->GetFromBatch(cf0, options_, "key", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "orig,aa,bb");
+
+ ASSERT_OK(batch_->Put("key", "cc")); // Writing cc under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "cc");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "cc");
+
+ ASSERT_OK(batch_->Merge("key", "dd")); // Merging dd under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "cc,dd");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "cc,dd");
+}
+
+TEST_P(WriteBatchWithIndexTest, GetAfterMergeDelete) {
+ std::string value;
+ ASSERT_OK(OpenDB());
+ ColumnFamilyHandle* cf0 = db_->DefaultColumnFamily();
+
+ ASSERT_OK(batch_->Merge("key", "aa")); // Merging aa under key
+ Status s = batch_->GetFromBatch(cf0, options_, "key", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "aa");
+
+ ASSERT_OK(batch_->Merge("key", "bb")); // Merging bb under key
+ s = batch_->GetFromBatch(cf0, options_, "key", &value);
+ ASSERT_EQ(s.code(), Status::Code::kMergeInProgress);
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "aa,bb");
+
+ ASSERT_OK(batch_->Delete("key")); // Delete key from batch
+ s = batch_->GetFromBatch(cf0, options_, "key", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(batch_->Merge("key", "cc")); // Merging cc under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "cc");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "cc");
+ ASSERT_OK(batch_->Merge("key", "dd")); // Merging dd under key
+ ASSERT_OK(batch_->GetFromBatch(cf0, options_, "key", &value));
+ ASSERT_EQ(value, "cc,dd");
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "key", &value));
+ ASSERT_EQ(value, "cc,dd");
+}
+
+TEST_F(WBWIOverwriteTest, TestBadMergeOperator) {
+ class FailingMergeOperator : public MergeOperator {
+ public:
+ FailingMergeOperator() {}
+
+ bool FullMergeV2(const MergeOperationInput& /*merge_in*/,
+ MergeOperationOutput* /*merge_out*/) const override {
+ return false;
+ }
+
+ const char* Name() const override { return "Failing"; }
+ };
+ options_.merge_operator.reset(new FailingMergeOperator());
+ ASSERT_OK(OpenDB());
+
+ ColumnFamilyHandle* column_family = db_->DefaultColumnFamily();
+ std::string value;
+
+ ASSERT_OK(db_->Put(write_opts_, "a", "a0"));
+ ASSERT_OK(batch_->Put("b", "b0"));
+
+ ASSERT_OK(batch_->Merge("a", "a1"));
+ ASSERT_NOK(batch_->GetFromBatchAndDB(db_, read_opts_, "a", &value));
+ ASSERT_NOK(batch_->GetFromBatch(column_family, options_, "a", &value));
+ ASSERT_OK(batch_->GetFromBatchAndDB(db_, read_opts_, "b", &value));
+ ASSERT_OK(batch_->GetFromBatch(column_family, options_, "b", &value));
+}
+
+TEST_P(WriteBatchWithIndexTest, ColumnFamilyWithTimestamp) {
+ ColumnFamilyHandleImplDummy cf2(2,
+ test::BytewiseComparatorWithU64TsWrapper());
+
+ // Sanity checks
+ ASSERT_TRUE(batch_->Put(&cf2, "key", "ts", "value").IsNotSupported());
+ ASSERT_TRUE(batch_->Put(/*column_family=*/nullptr, "key", "ts", "value")
+ .IsInvalidArgument());
+ ASSERT_TRUE(batch_->Delete(&cf2, "key", "ts").IsNotSupported());
+ ASSERT_TRUE(batch_->Delete(/*column_family=*/nullptr, "key", "ts")
+ .IsInvalidArgument());
+ ASSERT_TRUE(batch_->SingleDelete(&cf2, "key", "ts").IsNotSupported());
+ ASSERT_TRUE(batch_->SingleDelete(/*column_family=*/nullptr, "key", "ts")
+ .IsInvalidArgument());
+ {
+ std::string value;
+ ASSERT_TRUE(batch_
+ ->GetFromBatchAndDB(
+ /*db=*/nullptr, ReadOptions(), &cf2, "key", &value)
+ .IsInvalidArgument());
+ }
+ {
+ constexpr size_t num_keys = 2;
+ std::array<Slice, num_keys> keys{{Slice(), Slice()}};
+ std::array<PinnableSlice, num_keys> pinnable_vals{
+ {PinnableSlice(), PinnableSlice()}};
+ std::array<Status, num_keys> statuses{{Status(), Status()}};
+ constexpr bool sorted_input = false;
+ batch_->MultiGetFromBatchAndDB(/*db=*/nullptr, ReadOptions(), &cf2,
+ num_keys, keys.data(), pinnable_vals.data(),
+ statuses.data(), sorted_input);
+ for (const auto& s : statuses) {
+ ASSERT_TRUE(s.IsInvalidArgument());
+ }
+ }
+
+ constexpr uint32_t kMaxKey = 10;
+
+ const auto ts_sz_lookup = [&cf2](uint32_t id) {
+ if (cf2.GetID() == id) {
+ return sizeof(uint64_t);
+ } else {
+ return std::numeric_limits<size_t>::max();
+ }
+ };
+
+ // Put keys
+ for (uint32_t i = 0; i < kMaxKey; ++i) {
+ std::string key;
+ PutFixed32(&key, i);
+ Status s = batch_->Put(&cf2, key, "value" + std::to_string(i));
+ ASSERT_OK(s);
+ }
+
+ WriteBatch* wb = batch_->GetWriteBatch();
+ assert(wb);
+ ASSERT_OK(
+ wb->UpdateTimestamps(std::string(sizeof(uint64_t), '\0'), ts_sz_lookup));
+
+ // Point lookup
+ for (uint32_t i = 0; i < kMaxKey; ++i) {
+ std::string value;
+ std::string key;
+ PutFixed32(&key, i);
+ Status s = batch_->GetFromBatch(&cf2, Options(), key, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("value" + std::to_string(i), value);
+ }
+
+ // Iterator
+ {
+ std::unique_ptr<WBWIIterator> it(batch_->NewIterator(&cf2));
+ uint32_t start = 0;
+ for (it->SeekToFirst(); it->Valid(); it->Next(), ++start) {
+ std::string key;
+ PutFixed32(&key, start);
+ ASSERT_OK(it->status());
+ ASSERT_EQ(key, it->Entry().key);
+ ASSERT_EQ("value" + std::to_string(start), it->Entry().value);
+ ASSERT_EQ(WriteType::kPutRecord, it->Entry().type);
+ }
+ ASSERT_EQ(kMaxKey, start);
+ }
+
+ // Delete the keys with Delete() or SingleDelete()
+ for (uint32_t i = 0; i < kMaxKey; ++i) {
+ std::string key;
+ PutFixed32(&key, i);
+ Status s;
+ if (0 == (i % 2)) {
+ s = batch_->Delete(&cf2, key);
+ } else {
+ s = batch_->SingleDelete(&cf2, key);
+ }
+ ASSERT_OK(s);
+ }
+
+ ASSERT_OK(wb->UpdateTimestamps(std::string(sizeof(uint64_t), '\xfe'),
+ ts_sz_lookup));
+
+ for (uint32_t i = 0; i < kMaxKey; ++i) {
+ std::string value;
+ std::string key;
+ PutFixed32(&key, i);
+ Status s = batch_->GetFromBatch(&cf2, Options(), key, &value);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+
+ // Iterator
+ {
+ const bool overwrite = GetParam();
+ std::unique_ptr<WBWIIterator> it(batch_->NewIterator(&cf2));
+ uint32_t start = 0;
+ for (it->SeekToFirst(); it->Valid(); it->Next(), ++start) {
+ std::string key;
+ PutFixed32(&key, start);
+ ASSERT_EQ(key, it->Entry().key);
+ if (!overwrite) {
+ ASSERT_EQ(WriteType::kPutRecord, it->Entry().type);
+ it->Next();
+ ASSERT_TRUE(it->Valid());
+ }
+ if (0 == (start % 2)) {
+ ASSERT_EQ(WriteType::kDeleteRecord, it->Entry().type);
+ } else {
+ ASSERT_EQ(WriteType::kSingleDeleteRecord, it->Entry().type);
+ }
+ }
+ }
+}
+
+TEST_P(WriteBatchWithIndexTest, IndexNoTs) {
+ const Comparator* const ucmp = test::BytewiseComparatorWithU64TsWrapper();
+ ColumnFamilyHandleImplDummy cf(1, ucmp);
+ WriteBatchWithIndex wbwi;
+ ASSERT_OK(wbwi.Put(&cf, "a", "a0"));
+ ASSERT_OK(wbwi.Put(&cf, "a", "a1"));
+ {
+ std::string ts;
+ PutFixed64(&ts, 10000);
+ ASSERT_OK(wbwi.GetWriteBatch()->UpdateTimestamps(
+ ts, [](uint32_t cf_id) { return cf_id == 1 ? 8 : 0; }));
+ }
+ {
+ std::string value;
+ Status s = wbwi.GetFromBatch(&cf, options_, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a1", value);
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(WBWI, WriteBatchWithIndexTest, testing::Bool());
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main() {
+ fprintf(stderr, "SKIPPED\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE