Diffstat (limited to 'src/rocksdb/utilities')
-rw-r--r--  src/rocksdb/utilities/backupable/backupable_db.cc  1989
-rw-r--r--  src/rocksdb/utilities/backupable/backupable_db_test.cc  1863
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_compaction_filter.cc  329
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_compaction_filter.h  168
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db.cc  102
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db.h  261
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_gc_stats.h  52
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_impl.cc  2116
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_impl.h  495
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc  109
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_iterator.h  147
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_listener.h  66
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_db_test.cc  1992
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_dump_tool.cc  278
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_dump_tool.h  57
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_file.cc  320
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_file.h  252
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_format.cc  149
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_format.h  133
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_reader.cc  105
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_reader.h  82
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_writer.cc  139
-rw-r--r--  src/rocksdb/utilities/blob_db/blob_log_writer.h  94
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc  47
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h  42
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_format_test.cc  367
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_functional_test.cc  311
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc  112
-rw-r--r--  src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc  188
-rw-r--r--  src/rocksdb/utilities/cassandra/format.cc  390
-rw-r--r--  src/rocksdb/utilities/cassandra/format.h  197
-rw-r--r--  src/rocksdb/utilities/cassandra/merge_operator.cc  67
-rw-r--r--  src/rocksdb/utilities/cassandra/merge_operator.h  44
-rw-r--r--  src/rocksdb/utilities/cassandra/serialize.h  75
-rw-r--r--  src/rocksdb/utilities/cassandra/test_utils.cc  75
-rw-r--r--  src/rocksdb/utilities/cassandra/test_utils.h  46
-rw-r--r--  src/rocksdb/utilities/checkpoint/checkpoint_impl.cc  516
-rw-r--r--  src/rocksdb/utilities/checkpoint/checkpoint_impl.h  79
-rw-r--r--  src/rocksdb/utilities/checkpoint/checkpoint_test.cc  829
-rw-r--r--  src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc  29
-rw-r--r--  src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h  27
-rw-r--r--  src/rocksdb/utilities/convenience/info_log_finder.cc  25
-rw-r--r--  src/rocksdb/utilities/debug.cc  80
-rw-r--r--  src/rocksdb/utilities/env_librados.cc  1497
-rw-r--r--  src/rocksdb/utilities/env_librados.md  122
-rw-r--r--  src/rocksdb/utilities/env_librados_test.cc  1146
-rw-r--r--  src/rocksdb/utilities/env_mirror.cc  262
-rw-r--r--  src/rocksdb/utilities/env_mirror_test.cc  223
-rw-r--r--  src/rocksdb/utilities/env_timed.cc  145
-rw-r--r--  src/rocksdb/utilities/env_timed_test.cc  44
-rw-r--r--  src/rocksdb/utilities/leveldb_options/leveldb_options.cc  56
-rw-r--r--  src/rocksdb/utilities/memory/memory_test.cc  278
-rw-r--r--  src/rocksdb/utilities/memory/memory_util.cc  52
-rw-r--r--  src/rocksdb/utilities/merge_operators.h  55
-rw-r--r--  src/rocksdb/utilities/merge_operators/bytesxor.cc  59
-rw-r--r--  src/rocksdb/utilities/merge_operators/bytesxor.h  39
-rw-r--r--  src/rocksdb/utilities/merge_operators/max.cc  77
-rw-r--r--  src/rocksdb/utilities/merge_operators/put.cc  83
-rw-r--r--  src/rocksdb/utilities/merge_operators/sortlist.cc  100
-rw-r--r--  src/rocksdb/utilities/merge_operators/sortlist.h  38
-rw-r--r--  src/rocksdb/utilities/merge_operators/string_append/stringappend.cc  59
-rw-r--r--  src/rocksdb/utilities/merge_operators/string_append/stringappend.h  31
-rw-r--r--  src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc  117
-rw-r--r--  src/rocksdb/utilities/merge_operators/string_append/stringappend2.h  49
-rw-r--r--  src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc  601
-rw-r--r--  src/rocksdb/utilities/merge_operators/uint64add.cc  69
-rw-r--r--  src/rocksdb/utilities/object_registry.cc  87
-rw-r--r--  src/rocksdb/utilities/object_registry_test.cc  174
-rw-r--r--  src/rocksdb/utilities/option_change_migration/option_change_migration.cc  168
-rw-r--r--  src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc  425
-rw-r--r--  src/rocksdb/utilities/options/options_util.cc  114
-rw-r--r--  src/rocksdb/utilities/options/options_util_test.cc  363
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier.cc  425
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier.h  156
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc  608
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h  296
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h  127
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc  86
-rw-r--r--  src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h  125
-rw-r--r--  src/rocksdb/utilities/persistent_cache/hash_table.h  238
-rw-r--r--  src/rocksdb/utilities/persistent_cache/hash_table_bench.cc  308
-rw-r--r--  src/rocksdb/utilities/persistent_cache/hash_table_evictable.h  168
-rw-r--r--  src/rocksdb/utilities/persistent_cache/hash_table_test.cc  160
-rw-r--r--  src/rocksdb/utilities/persistent_cache/lrulist.h  174
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc  360
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc  474
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_test.h  285
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc  163
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h  336
-rw-r--r--  src/rocksdb/utilities/persistent_cache/persistent_cache_util.h  67
-rw-r--r--  src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc  138
-rw-r--r--  src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h  142
-rw-r--r--  src/rocksdb/utilities/simulator_cache/cache_simulator.cc  274
-rw-r--r--  src/rocksdb/utilities/simulator_cache/cache_simulator.h  231
-rw-r--r--  src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc  494
-rw-r--r--  src/rocksdb/utilities/simulator_cache/sim_cache.cc  354
-rw-r--r--  src/rocksdb/utilities/simulator_cache/sim_cache_test.cc  225
-rw-r--r--  src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc  90
-rw-r--r--  src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h  72
-rw-r--r--  src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc  178
-rw-r--r--  src/rocksdb/utilities/trace/file_trace_reader_writer.cc  123
-rw-r--r--  src/rocksdb/utilities/trace/file_trace_reader_writer.h  48
-rw-r--r--  src/rocksdb/utilities/transactions/optimistic_transaction.cc  187
-rw-r--r--  src/rocksdb/utilities/transactions/optimistic_transaction.h  101
-rw-r--r--  src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc  111
-rw-r--r--  src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h  71
-rw-r--r--  src/rocksdb/utilities/transactions/optimistic_transaction_test.cc  1535
-rw-r--r--  src/rocksdb/utilities/transactions/pessimistic_transaction.cc  723
-rw-r--r--  src/rocksdb/utilities/transactions/pessimistic_transaction.h  225
-rw-r--r--  src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc  632
-rw-r--r--  src/rocksdb/utilities/transactions/pessimistic_transaction_db.h  220
-rw-r--r--  src/rocksdb/utilities/transactions/snapshot_checker.cc  49
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_base.cc  837
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_base.h  374
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc  135
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h  26
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_lock_mgr.cc  745
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_lock_mgr.h  158
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_test.cc  6224
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_test.h  517
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_util.cc  182
-rw-r--r--  src/rocksdb/utilities/transactions/transaction_util.h  103
-rw-r--r--  src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc  3524
-rw-r--r--  src/rocksdb/utilities/transactions/write_prepared_txn.cc  473
-rw-r--r--  src/rocksdb/utilities/transactions/write_prepared_txn.h  119
-rw-r--r--  src/rocksdb/utilities/transactions/write_prepared_txn_db.cc  998
-rw-r--r--  src/rocksdb/utilities/transactions/write_prepared_txn_db.h  1111
-rw-r--r--  src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc  727
-rw-r--r--  src/rocksdb/utilities/transactions/write_unprepared_txn.cc  999
-rw-r--r--  src/rocksdb/utilities/transactions/write_unprepared_txn.h  341
-rw-r--r--  src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc  468
-rw-r--r--  src/rocksdb/utilities/transactions/write_unprepared_txn_db.h  148
-rw-r--r--  src/rocksdb/utilities/ttl/db_ttl_impl.cc  335
-rw-r--r--  src/rocksdb/utilities/ttl/db_ttl_impl.h  361
-rw-r--r--  src/rocksdb/utilities/ttl/ttl_test.cc  693
-rw-r--r--  src/rocksdb/utilities/util_merge_operators_test.cc  99
-rw-r--r--  src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc  1065
-rw-r--r--  src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc  288
-rw-r--r--  src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h  145
-rw-r--r--  src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc  1846
140 files changed, 55197 insertions, 0 deletions
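
The bulk of these insertions is the new BackupEngine implementation in backupable_db.cc. For orientation, below is a minimal usage sketch of the public API that this implementation backs (BackupEngine::Open, CreateNewBackupWithMetadata, PurgeOldBackups, GetBackupInfo, VerifyBackup, RestoreDBFromLatestBackup). It is not part of the diff: it assumes the declarations from the public rocksdb/utilities/backupable_db.h header (including a BackupableDBOptions constructor taking the backup directory) and the standard DB::Open API; the paths and the "build=1234" metadata string are placeholders.

// Hedged sketch (not part of the diff): drives the BackupEngine API
// implemented in backupable_db.cc below. Paths and metadata are placeholders.
#include <cassert>
#include <vector>
#include "rocksdb/db.h"
#include "rocksdb/utilities/backupable_db.h"

int main() {
  rocksdb::DB* db = nullptr;
  rocksdb::Options db_options;
  db_options.create_if_missing = true;
  assert(rocksdb::DB::Open(db_options, "/tmp/rocksdb_data", &db).ok());

  // Open (or create) the backup directory managed by BackupEngine.
  rocksdb::BackupableDBOptions backup_options("/tmp/rocksdb_backups");
  rocksdb::BackupEngine* backup_engine = nullptr;
  assert(rocksdb::BackupEngine::Open(rocksdb::Env::Default(), backup_options,
                                     &backup_engine).ok());

  // Create a backup; app metadata is capped at 1MB (kMaxAppMetaSize below).
  assert(backup_engine->CreateNewBackupWithMetadata(
             db, "build=1234" /* app_metadata */,
             true /* flush_before_backup */).ok());

  // Keep only the two newest backups, then verify the most recent one.
  backup_engine->PurgeOldBackups(2);
  std::vector<rocksdb::BackupInfo> backups;
  backup_engine->GetBackupInfo(&backups);  // chronological order, latest last
  assert(!backups.empty());
  assert(backup_engine->VerifyBackup(backups.back().backup_id).ok());

  delete db;

  // Restore into a fresh directory (db_dir and wal_dir may be the same path).
  assert(backup_engine->RestoreDBFromLatestBackup("/tmp/rocksdb_restored",
                                                  "/tmp/rocksdb_restored").ok());
  delete backup_engine;
  return 0;
}

Note that, as shown in RestoreDBFromBackup below, restore deletes existing files in the target db_dir and wal_dir (except WAL files when keep_log_files is set), so it should not be pointed at a live database directory.
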
diff --git a/src/rocksdb/utilities/backupable/backupable_db.cc b/src/rocksdb/utilities/backupable/backupable_db.cc
new file mode 100644
index 000000000..0ca67670b
--- /dev/null
+++ b/src/rocksdb/utilities/backupable/backupable_db.cc
@@ -0,0 +1,1989 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include <stdlib.h>
+#include <algorithm>
+#include <atomic>
+#include <cinttypes>
+#include <functional>
+#include <future>
+#include <limits>
+#include <map>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "env/composite_env_wrapper.h"
+#include "file/filename.h"
+#include "file/sequence_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "logging/logging.h"
+#include "port/port.h"
+#include "rocksdb/rate_limiter.h"
+#include "rocksdb/transaction_log.h"
+#include "rocksdb/utilities/backupable_db.h"
+#include "test_util/sync_point.h"
+#include "util/channel.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/string_util.h"
+#include "utilities/checkpoint/checkpoint_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+void BackupStatistics::IncrementNumberSuccessBackup() {
+ number_success_backup++;
+}
+void BackupStatistics::IncrementNumberFailBackup() {
+ number_fail_backup++;
+}
+
+uint32_t BackupStatistics::GetNumberSuccessBackup() const {
+ return number_success_backup;
+}
+uint32_t BackupStatistics::GetNumberFailBackup() const {
+ return number_fail_backup;
+}
+
+std::string BackupStatistics::ToString() const {
+ char result[50];
+ snprintf(result, sizeof(result), "# success backup: %u, # fail backup: %u",
+ GetNumberSuccessBackup(), GetNumberFailBackup());
+ return result;
+}
+
+void BackupableDBOptions::Dump(Logger* logger) const {
+ ROCKS_LOG_INFO(logger, " Options.backup_dir: %s",
+ backup_dir.c_str());
+ ROCKS_LOG_INFO(logger, " Options.backup_env: %p", backup_env);
+ ROCKS_LOG_INFO(logger, " Options.share_table_files: %d",
+ static_cast<int>(share_table_files));
+ ROCKS_LOG_INFO(logger, " Options.info_log: %p", info_log);
+ ROCKS_LOG_INFO(logger, " Options.sync: %d",
+ static_cast<int>(sync));
+ ROCKS_LOG_INFO(logger, " Options.destroy_old_data: %d",
+ static_cast<int>(destroy_old_data));
+ ROCKS_LOG_INFO(logger, " Options.backup_log_files: %d",
+ static_cast<int>(backup_log_files));
+ ROCKS_LOG_INFO(logger, " Options.backup_rate_limit: %" PRIu64,
+ backup_rate_limit);
+ ROCKS_LOG_INFO(logger, " Options.restore_rate_limit: %" PRIu64,
+ restore_rate_limit);
+ ROCKS_LOG_INFO(logger, "Options.max_background_operations: %d",
+ max_background_operations);
+}
+
+// -------- BackupEngineImpl class ---------
+class BackupEngineImpl : public BackupEngine {
+ public:
+ BackupEngineImpl(Env* db_env, const BackupableDBOptions& options,
+ bool read_only = false);
+ ~BackupEngineImpl() override;
+ Status CreateNewBackupWithMetadata(DB* db, const std::string& app_metadata,
+ bool flush_before_backup = false,
+ std::function<void()> progress_callback =
+ []() {}) override;
+ Status PurgeOldBackups(uint32_t num_backups_to_keep) override;
+ Status DeleteBackup(BackupID backup_id) override;
+ void StopBackup() override {
+ stop_backup_.store(true, std::memory_order_release);
+ }
+ Status GarbageCollect() override;
+
+ // The returned BackupInfos are in chronological order, which means the
+ // latest backup comes last.
+ void GetBackupInfo(std::vector<BackupInfo>* backup_info) override;
+ void GetCorruptedBackups(std::vector<BackupID>* corrupt_backup_ids) override;
+ Status RestoreDBFromBackup(
+ BackupID backup_id, const std::string& db_dir, const std::string& wal_dir,
+ const RestoreOptions& restore_options = RestoreOptions()) override;
+ Status RestoreDBFromLatestBackup(
+ const std::string& db_dir, const std::string& wal_dir,
+ const RestoreOptions& restore_options = RestoreOptions()) override {
+ return RestoreDBFromBackup(latest_valid_backup_id_, db_dir, wal_dir,
+ restore_options);
+ }
+
+ Status VerifyBackup(BackupID backup_id) override;
+
+ Status Initialize();
+
+ private:
+ void DeleteChildren(const std::string& dir, uint32_t file_type_filter = 0);
+ Status DeleteBackupInternal(BackupID backup_id);
+
+ // Extends the "result" map with pathname->size mappings for the contents of
+ // "dir" in "env". Pathnames are prefixed with "dir".
+ Status InsertPathnameToSizeBytes(
+ const std::string& dir, Env* env,
+ std::unordered_map<std::string, uint64_t>* result);
+
+ struct FileInfo {
+ FileInfo(const std::string& fname, uint64_t sz, uint32_t checksum)
+ : refs(0), filename(fname), size(sz), checksum_value(checksum) {}
+
+ FileInfo(const FileInfo&) = delete;
+ FileInfo& operator=(const FileInfo&) = delete;
+
+ int refs;
+ const std::string filename;
+ const uint64_t size;
+ const uint32_t checksum_value;
+ };
+
+ class BackupMeta {
+ public:
+ BackupMeta(
+ const std::string& meta_filename, const std::string& meta_tmp_filename,
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>>* file_infos,
+ Env* env)
+ : timestamp_(0),
+ sequence_number_(0),
+ size_(0),
+ meta_filename_(meta_filename),
+ meta_tmp_filename_(meta_tmp_filename),
+ file_infos_(file_infos),
+ env_(env) {}
+
+ BackupMeta(const BackupMeta&) = delete;
+ BackupMeta& operator=(const BackupMeta&) = delete;
+
+ ~BackupMeta() {}
+
+ void RecordTimestamp() {
+ env_->GetCurrentTime(&timestamp_);
+ }
+ int64_t GetTimestamp() const {
+ return timestamp_;
+ }
+ uint64_t GetSize() const {
+ return size_;
+ }
+ uint32_t GetNumberFiles() { return static_cast<uint32_t>(files_.size()); }
+ void SetSequenceNumber(uint64_t sequence_number) {
+ sequence_number_ = sequence_number;
+ }
+ uint64_t GetSequenceNumber() {
+ return sequence_number_;
+ }
+
+ const std::string& GetAppMetadata() const { return app_metadata_; }
+
+ void SetAppMetadata(const std::string& app_metadata) {
+ app_metadata_ = app_metadata;
+ }
+
+ Status AddFile(std::shared_ptr<FileInfo> file_info);
+
+ Status Delete(bool delete_meta = true);
+
+ bool Empty() {
+ return files_.empty();
+ }
+
+ std::shared_ptr<FileInfo> GetFile(const std::string& filename) const {
+ auto it = file_infos_->find(filename);
+ if (it == file_infos_->end())
+ return nullptr;
+ return it->second;
+ }
+
+ const std::vector<std::shared_ptr<FileInfo>>& GetFiles() {
+ return files_;
+ }
+
+ // @param abs_path_to_size Pre-fetched file sizes (bytes).
+ Status LoadFromFile(
+ const std::string& backup_dir,
+ const std::unordered_map<std::string, uint64_t>& abs_path_to_size);
+ Status StoreToFile(bool sync);
+
+ std::string GetInfoString() {
+ std::ostringstream ss;
+ ss << "Timestamp: " << timestamp_ << std::endl;
+ char human_size[16];
+ AppendHumanBytes(size_, human_size, sizeof(human_size));
+ ss << "Size: " << human_size << std::endl;
+ ss << "Files:" << std::endl;
+ for (const auto& file : files_) {
+ AppendHumanBytes(file->size, human_size, sizeof(human_size));
+ ss << file->filename << ", size " << human_size << ", refs "
+ << file->refs << std::endl;
+ }
+ return ss.str();
+ }
+
+ private:
+ int64_t timestamp_;
+ // sequence number is only approximate, should not be used
+ // by clients
+ uint64_t sequence_number_;
+ uint64_t size_;
+ std::string app_metadata_;
+ std::string const meta_filename_;
+ std::string const meta_tmp_filename_;
+ // files with relative paths (without "/" prefix!!)
+ std::vector<std::shared_ptr<FileInfo>> files_;
+ std::unordered_map<std::string, std::shared_ptr<FileInfo>>* file_infos_;
+ Env* env_;
+
+ static const size_t max_backup_meta_file_size_ = 10 * 1024 * 1024; // 10MB
+ }; // BackupMeta
+
+ inline std::string GetAbsolutePath(
+ const std::string &relative_path = "") const {
+ assert(relative_path.size() == 0 || relative_path[0] != '/');
+ return options_.backup_dir + "/" + relative_path;
+ }
+ inline std::string GetPrivateDirRel() const {
+ return "private";
+ }
+ inline std::string GetSharedChecksumDirRel() const {
+ return "shared_checksum";
+ }
+ inline std::string GetPrivateFileRel(BackupID backup_id,
+ bool tmp = false,
+ const std::string& file = "") const {
+ assert(file.size() == 0 || file[0] != '/');
+ return GetPrivateDirRel() + "/" + ROCKSDB_NAMESPACE::ToString(backup_id) +
+ (tmp ? ".tmp" : "") + "/" + file;
+ }
+ inline std::string GetSharedFileRel(const std::string& file = "",
+ bool tmp = false) const {
+ assert(file.size() == 0 || file[0] != '/');
+ return std::string("shared/") + (tmp ? "." : "") + file +
+ (tmp ? ".tmp" : "");
+ }
+ inline std::string GetSharedFileWithChecksumRel(const std::string& file = "",
+ bool tmp = false) const {
+ assert(file.size() == 0 || file[0] != '/');
+ return GetSharedChecksumDirRel() + "/" + (tmp ? "." : "") + file +
+ (tmp ? ".tmp" : "");
+ }
+ inline std::string GetSharedFileWithChecksum(const std::string& file,
+ const uint32_t checksum_value,
+ const uint64_t file_size) const {
+ assert(file.size() == 0 || file[0] != '/');
+ std::string file_copy = file;
+ return file_copy.insert(file_copy.find_last_of('.'),
+ "_" + ROCKSDB_NAMESPACE::ToString(checksum_value) +
+ "_" + ROCKSDB_NAMESPACE::ToString(file_size));
+ }
+ inline std::string GetFileFromChecksumFile(const std::string& file) const {
+ assert(file.size() == 0 || file[0] != '/');
+ std::string file_copy = file;
+ size_t first_underscore = file_copy.find_first_of('_');
+ return file_copy.erase(first_underscore,
+ file_copy.find_last_of('.') - first_underscore);
+ }
+ inline std::string GetBackupMetaDir() const {
+ return GetAbsolutePath("meta");
+ }
+ inline std::string GetBackupMetaFile(BackupID backup_id, bool tmp) const {
+ return GetBackupMetaDir() + "/" + (tmp ? "." : "") +
+ ROCKSDB_NAMESPACE::ToString(backup_id) + (tmp ? ".tmp" : "");
+ }
+
+ // If size_limit == 0, there is no size limit, copy everything.
+ //
+ // Exactly one of src and contents must be non-empty.
+ //
+ // @param src If non-empty, the file is copied from this pathname.
+ // @param contents If non-empty, the file will be created with these contents.
+ Status CopyOrCreateFile(const std::string& src, const std::string& dst,
+ const std::string& contents, Env* src_env,
+ Env* dst_env, const EnvOptions& src_env_options,
+ bool sync, RateLimiter* rate_limiter,
+ uint64_t* size = nullptr,
+ uint32_t* checksum_value = nullptr,
+ uint64_t size_limit = 0,
+ std::function<void()> progress_callback = []() {});
+
+ Status CalculateChecksum(const std::string& src, Env* src_env,
+ const EnvOptions& src_env_options,
+ uint64_t size_limit, uint32_t* checksum_value);
+
+ struct CopyOrCreateResult {
+ uint64_t size;
+ uint32_t checksum_value;
+ Status status;
+ };
+
+ // Exactly one of src_path and contents must be non-empty. If src_path is
+ // non-empty, the file is copied from this pathname. Otherwise, if contents is
+ // non-empty, the file will be created at dst_path with these contents.
+ struct CopyOrCreateWorkItem {
+ std::string src_path;
+ std::string dst_path;
+ std::string contents;
+ Env* src_env;
+ Env* dst_env;
+ EnvOptions src_env_options;
+ bool sync;
+ RateLimiter* rate_limiter;
+ uint64_t size_limit;
+ std::promise<CopyOrCreateResult> result;
+ std::function<void()> progress_callback;
+
+ CopyOrCreateWorkItem()
+ : src_path(""),
+ dst_path(""),
+ contents(""),
+ src_env(nullptr),
+ dst_env(nullptr),
+ src_env_options(),
+ sync(false),
+ rate_limiter(nullptr),
+ size_limit(0) {}
+
+ CopyOrCreateWorkItem(const CopyOrCreateWorkItem&) = delete;
+ CopyOrCreateWorkItem& operator=(const CopyOrCreateWorkItem&) = delete;
+
+ CopyOrCreateWorkItem(CopyOrCreateWorkItem&& o) ROCKSDB_NOEXCEPT {
+ *this = std::move(o);
+ }
+
+ CopyOrCreateWorkItem& operator=(CopyOrCreateWorkItem&& o) ROCKSDB_NOEXCEPT {
+ src_path = std::move(o.src_path);
+ dst_path = std::move(o.dst_path);
+ contents = std::move(o.contents);
+ src_env = o.src_env;
+ dst_env = o.dst_env;
+ src_env_options = std::move(o.src_env_options);
+ sync = o.sync;
+ rate_limiter = o.rate_limiter;
+ size_limit = o.size_limit;
+ result = std::move(o.result);
+ progress_callback = std::move(o.progress_callback);
+ return *this;
+ }
+
+ CopyOrCreateWorkItem(std::string _src_path, std::string _dst_path,
+ std::string _contents, Env* _src_env, Env* _dst_env,
+ EnvOptions _src_env_options, bool _sync,
+ RateLimiter* _rate_limiter, uint64_t _size_limit,
+ std::function<void()> _progress_callback = []() {})
+ : src_path(std::move(_src_path)),
+ dst_path(std::move(_dst_path)),
+ contents(std::move(_contents)),
+ src_env(_src_env),
+ dst_env(_dst_env),
+ src_env_options(std::move(_src_env_options)),
+ sync(_sync),
+ rate_limiter(_rate_limiter),
+ size_limit(_size_limit),
+ progress_callback(_progress_callback) {}
+ };
+
+ struct BackupAfterCopyOrCreateWorkItem {
+ std::future<CopyOrCreateResult> result;
+ bool shared;
+ bool needed_to_copy;
+ Env* backup_env;
+ std::string dst_path_tmp;
+ std::string dst_path;
+ std::string dst_relative;
+ BackupAfterCopyOrCreateWorkItem()
+ : shared(false),
+ needed_to_copy(false),
+ backup_env(nullptr),
+ dst_path_tmp(""),
+ dst_path(""),
+ dst_relative("") {}
+
+ BackupAfterCopyOrCreateWorkItem(BackupAfterCopyOrCreateWorkItem&& o)
+ ROCKSDB_NOEXCEPT {
+ *this = std::move(o);
+ }
+
+ BackupAfterCopyOrCreateWorkItem& operator=(
+ BackupAfterCopyOrCreateWorkItem&& o) ROCKSDB_NOEXCEPT {
+ result = std::move(o.result);
+ shared = o.shared;
+ needed_to_copy = o.needed_to_copy;
+ backup_env = o.backup_env;
+ dst_path_tmp = std::move(o.dst_path_tmp);
+ dst_path = std::move(o.dst_path);
+ dst_relative = std::move(o.dst_relative);
+ return *this;
+ }
+
+ BackupAfterCopyOrCreateWorkItem(std::future<CopyOrCreateResult>&& _result,
+ bool _shared, bool _needed_to_copy,
+ Env* _backup_env, std::string _dst_path_tmp,
+ std::string _dst_path,
+ std::string _dst_relative)
+ : result(std::move(_result)),
+ shared(_shared),
+ needed_to_copy(_needed_to_copy),
+ backup_env(_backup_env),
+ dst_path_tmp(std::move(_dst_path_tmp)),
+ dst_path(std::move(_dst_path)),
+ dst_relative(std::move(_dst_relative)) {}
+ };
+
+ struct RestoreAfterCopyOrCreateWorkItem {
+ std::future<CopyOrCreateResult> result;
+ uint32_t checksum_value;
+ RestoreAfterCopyOrCreateWorkItem()
+ : checksum_value(0) {}
+ RestoreAfterCopyOrCreateWorkItem(std::future<CopyOrCreateResult>&& _result,
+ uint32_t _checksum_value)
+ : result(std::move(_result)), checksum_value(_checksum_value) {}
+ RestoreAfterCopyOrCreateWorkItem(RestoreAfterCopyOrCreateWorkItem&& o)
+ ROCKSDB_NOEXCEPT {
+ *this = std::move(o);
+ }
+
+ RestoreAfterCopyOrCreateWorkItem& operator=(
+ RestoreAfterCopyOrCreateWorkItem&& o) ROCKSDB_NOEXCEPT {
+ result = std::move(o.result);
+ checksum_value = o.checksum_value;
+ return *this;
+ }
+ };
+
+ bool initialized_;
+ std::mutex byte_report_mutex_;
+ channel<CopyOrCreateWorkItem> files_to_copy_or_create_;
+ std::vector<port::Thread> threads_;
+ // Certain operations like PurgeOldBackups and DeleteBackup will trigger
+ // automatic GarbageCollect (true) unless we've already done one in this
+ // session and have not failed to delete backup files since then (false).
+ bool might_need_garbage_collect_ = true;
+
+ // Adds a file to the backup work queue to be copied or created if it doesn't
+ // already exist.
+ //
+ // Exactly one of src_dir and contents must be non-empty.
+ //
+ // @param src_dir If non-empty, the file in this directory named fname will be
+ // copied.
+ // @param fname Name of destination file and, in case of copy, source file.
+ // @param contents If non-empty, the file will be created with these contents.
+ Status AddBackupFileWorkItem(
+ std::unordered_set<std::string>& live_dst_paths,
+ std::vector<BackupAfterCopyOrCreateWorkItem>& backup_items_to_finish,
+ BackupID backup_id, bool shared, const std::string& src_dir,
+ const std::string& fname, // starts with "/"
+ const EnvOptions& src_env_options, RateLimiter* rate_limiter,
+ uint64_t size_bytes, uint64_t size_limit = 0,
+ bool shared_checksum = false,
+ std::function<void()> progress_callback = []() {},
+ const std::string& contents = std::string());
+
+ // backup state data
+ BackupID latest_backup_id_;
+ BackupID latest_valid_backup_id_;
+ std::map<BackupID, std::unique_ptr<BackupMeta>> backups_;
+ std::map<BackupID, std::pair<Status, std::unique_ptr<BackupMeta>>>
+ corrupt_backups_;
+ std::unordered_map<std::string,
+ std::shared_ptr<FileInfo>> backuped_file_infos_;
+ std::atomic<bool> stop_backup_;
+
+ // options data
+ BackupableDBOptions options_;
+ Env* db_env_;
+ Env* backup_env_;
+
+ // directories
+ std::unique_ptr<Directory> backup_directory_;
+ std::unique_ptr<Directory> shared_directory_;
+ std::unique_ptr<Directory> meta_directory_;
+ std::unique_ptr<Directory> private_directory_;
+
+ static const size_t kDefaultCopyFileBufferSize = 5 * 1024 * 1024LL; // 5MB
+ size_t copy_file_buffer_size_;
+ bool read_only_;
+ BackupStatistics backup_statistics_;
+ static const size_t kMaxAppMetaSize = 1024 * 1024; // 1MB
+};
+
+Status BackupEngine::Open(Env* env, const BackupableDBOptions& options,
+ BackupEngine** backup_engine_ptr) {
+ std::unique_ptr<BackupEngineImpl> backup_engine(
+ new BackupEngineImpl(env, options));
+ auto s = backup_engine->Initialize();
+ if (!s.ok()) {
+ *backup_engine_ptr = nullptr;
+ return s;
+ }
+ *backup_engine_ptr = backup_engine.release();
+ return Status::OK();
+}
+
+BackupEngineImpl::BackupEngineImpl(Env* db_env,
+ const BackupableDBOptions& options,
+ bool read_only)
+ : initialized_(false),
+ latest_backup_id_(0),
+ latest_valid_backup_id_(0),
+ stop_backup_(false),
+ options_(options),
+ db_env_(db_env),
+ backup_env_(options.backup_env != nullptr ? options.backup_env : db_env_),
+ copy_file_buffer_size_(kDefaultCopyFileBufferSize),
+ read_only_(read_only) {
+ if (options_.backup_rate_limiter == nullptr &&
+ options_.backup_rate_limit > 0) {
+ options_.backup_rate_limiter.reset(
+ NewGenericRateLimiter(options_.backup_rate_limit));
+ }
+ if (options_.restore_rate_limiter == nullptr &&
+ options_.restore_rate_limit > 0) {
+ options_.restore_rate_limiter.reset(
+ NewGenericRateLimiter(options_.restore_rate_limit));
+ }
+}
+
+BackupEngineImpl::~BackupEngineImpl() {
+ files_to_copy_or_create_.sendEof();
+ for (auto& t : threads_) {
+ t.join();
+ }
+ LogFlush(options_.info_log);
+}
+
+Status BackupEngineImpl::Initialize() {
+ assert(!initialized_);
+ initialized_ = true;
+ if (read_only_) {
+ ROCKS_LOG_INFO(options_.info_log, "Starting read_only backup engine");
+ }
+ options_.Dump(options_.info_log);
+
+ if (!read_only_) {
+ // we might need to clean up from previous crash or I/O errors
+ might_need_garbage_collect_ = true;
+
+ if (options_.max_valid_backups_to_open != port::kMaxInt32) {
+ options_.max_valid_backups_to_open = port::kMaxInt32;
+ ROCKS_LOG_WARN(
+ options_.info_log,
+ "`max_valid_backups_to_open` is not set to the default value. Ignoring "
+ "its value since BackupEngine is not read-only.");
+ }
+
+ // gather the list of directories that we need to create
+ std::vector<std::pair<std::string, std::unique_ptr<Directory>*>>
+ directories;
+ directories.emplace_back(GetAbsolutePath(), &backup_directory_);
+ if (options_.share_table_files) {
+ if (options_.share_files_with_checksum) {
+ directories.emplace_back(
+ GetAbsolutePath(GetSharedFileWithChecksumRel()),
+ &shared_directory_);
+ } else {
+ directories.emplace_back(GetAbsolutePath(GetSharedFileRel()),
+ &shared_directory_);
+ }
+ }
+ directories.emplace_back(GetAbsolutePath(GetPrivateDirRel()),
+ &private_directory_);
+ directories.emplace_back(GetBackupMetaDir(), &meta_directory_);
+ // create all the dirs we need
+ for (const auto& d : directories) {
+ auto s = backup_env_->CreateDirIfMissing(d.first);
+ if (s.ok()) {
+ s = backup_env_->NewDirectory(d.first, d.second);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+
+ std::vector<std::string> backup_meta_files;
+ {
+ auto s = backup_env_->GetChildren(GetBackupMetaDir(), &backup_meta_files);
+ if (s.IsNotFound()) {
+ return Status::NotFound(GetBackupMetaDir() + " is missing");
+ } else if (!s.ok()) {
+ return s;
+ }
+ }
+ // create backups_ structure
+ for (auto& file : backup_meta_files) {
+ if (file == "." || file == "..") {
+ continue;
+ }
+ ROCKS_LOG_INFO(options_.info_log, "Detected backup %s", file.c_str());
+ BackupID backup_id = 0;
+ sscanf(file.c_str(), "%u", &backup_id);
+ if (backup_id == 0 || file != ROCKSDB_NAMESPACE::ToString(backup_id)) {
+ if (!read_only_) {
+ // invalid file name, delete that
+ auto s = backup_env_->DeleteFile(GetBackupMetaDir() + "/" + file);
+ ROCKS_LOG_INFO(options_.info_log,
+ "Unrecognized meta file %s, deleting -- %s",
+ file.c_str(), s.ToString().c_str());
+ }
+ continue;
+ }
+ assert(backups_.find(backup_id) == backups_.end());
+ backups_.insert(std::make_pair(
+ backup_id, std::unique_ptr<BackupMeta>(new BackupMeta(
+ GetBackupMetaFile(backup_id, false /* tmp */),
+ GetBackupMetaFile(backup_id, true /* tmp */),
+ &backuped_file_infos_, backup_env_))));
+ }
+
+ latest_backup_id_ = 0;
+ latest_valid_backup_id_ = 0;
+ if (options_.destroy_old_data) { // Destroy old data
+ assert(!read_only_);
+ ROCKS_LOG_INFO(
+ options_.info_log,
+ "Backup Engine started with destroy_old_data == true, deleting all "
+ "backups");
+ auto s = PurgeOldBackups(0);
+ if (s.ok()) {
+ s = GarbageCollect();
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ } else { // Load data from storage
+ std::unordered_map<std::string, uint64_t> abs_path_to_size;
+ for (const auto& rel_dir :
+ {GetSharedFileRel(), GetSharedFileWithChecksumRel()}) {
+ const auto abs_dir = GetAbsolutePath(rel_dir);
+ InsertPathnameToSizeBytes(abs_dir, backup_env_, &abs_path_to_size);
+ }
+ // load the backups if any, until valid_backups_to_open of the latest
+ // non-corrupted backups have been successfully opened.
+ int valid_backups_to_open = options_.max_valid_backups_to_open;
+ for (auto backup_iter = backups_.rbegin();
+ backup_iter != backups_.rend();
+ ++backup_iter) {
+ assert(latest_backup_id_ == 0 || latest_backup_id_ > backup_iter->first);
+ if (latest_backup_id_ == 0) {
+ latest_backup_id_ = backup_iter->first;
+ }
+ if (valid_backups_to_open == 0) {
+ break;
+ }
+
+ InsertPathnameToSizeBytes(
+ GetAbsolutePath(GetPrivateFileRel(backup_iter->first)), backup_env_,
+ &abs_path_to_size);
+ Status s = backup_iter->second->LoadFromFile(options_.backup_dir,
+ abs_path_to_size);
+ if (s.IsCorruption()) {
+ ROCKS_LOG_INFO(options_.info_log, "Backup %u corrupted -- %s",
+ backup_iter->first, s.ToString().c_str());
+ corrupt_backups_.insert(
+ std::make_pair(backup_iter->first,
+ std::make_pair(s, std::move(backup_iter->second))));
+ } else if (!s.ok()) {
+ // Distinguish corruption errors from errors in the backup Env.
+ // Errors in the backup Env (i.e., this code path) will cause Open() to
+ // fail, whereas corruption errors would not cause Open() failures.
+ return s;
+ } else {
+ ROCKS_LOG_INFO(options_.info_log, "Loading backup %" PRIu32 " OK:\n%s",
+ backup_iter->first,
+ backup_iter->second->GetInfoString().c_str());
+ assert(latest_valid_backup_id_ == 0 ||
+ latest_valid_backup_id_ > backup_iter->first);
+ if (latest_valid_backup_id_ == 0) {
+ latest_valid_backup_id_ = backup_iter->first;
+ }
+ --valid_backups_to_open;
+ }
+ }
+
+ for (const auto& corrupt : corrupt_backups_) {
+ backups_.erase(backups_.find(corrupt.first));
+ }
+ // erase the backups before max_valid_backups_to_open
+ int num_unopened_backups;
+ if (options_.max_valid_backups_to_open == 0) {
+ num_unopened_backups = 0;
+ } else {
+ num_unopened_backups =
+ std::max(0, static_cast<int>(backups_.size()) -
+ options_.max_valid_backups_to_open);
+ }
+ for (int i = 0; i < num_unopened_backups; ++i) {
+ assert(backups_.begin()->second->Empty());
+ backups_.erase(backups_.begin());
+ }
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Latest backup is %u", latest_backup_id_);
+ ROCKS_LOG_INFO(options_.info_log, "Latest valid backup is %u",
+ latest_valid_backup_id_);
+
+  // set up threads to perform copies from files_to_copy_or_create_ in the
+  // background
+ for (int t = 0; t < options_.max_background_operations; t++) {
+ threads_.emplace_back([this]() {
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
+#if __GLIBC_PREREQ(2, 12)
+ pthread_setname_np(pthread_self(), "backup_engine");
+#endif
+#endif
+ CopyOrCreateWorkItem work_item;
+ while (files_to_copy_or_create_.read(work_item)) {
+ CopyOrCreateResult result;
+ result.status = CopyOrCreateFile(
+ work_item.src_path, work_item.dst_path, work_item.contents,
+ work_item.src_env, work_item.dst_env, work_item.src_env_options,
+ work_item.sync, work_item.rate_limiter, &result.size,
+ &result.checksum_value, work_item.size_limit,
+ work_item.progress_callback);
+ work_item.result.set_value(std::move(result));
+ }
+ });
+ }
+ ROCKS_LOG_INFO(options_.info_log, "Initialized BackupEngine");
+
+ return Status::OK();
+}
+
+Status BackupEngineImpl::CreateNewBackupWithMetadata(
+ DB* db, const std::string& app_metadata, bool flush_before_backup,
+ std::function<void()> progress_callback) {
+ assert(initialized_);
+ assert(!read_only_);
+ if (app_metadata.size() > kMaxAppMetaSize) {
+ return Status::InvalidArgument("App metadata too large");
+ }
+
+ BackupID new_backup_id = latest_backup_id_ + 1;
+
+ assert(backups_.find(new_backup_id) == backups_.end());
+
+ auto private_dir = GetAbsolutePath(GetPrivateFileRel(new_backup_id));
+ Status s = backup_env_->FileExists(private_dir);
+ if (s.ok()) {
+ // maybe last backup failed and left partial state behind, clean it up.
+ // need to do this before updating backups_ such that a private dir
+ // named after new_backup_id will be cleaned up.
+ // (If an incomplete new backup is followed by an incomplete delete
+ // of the latest full backup, then there could be more than one next
+ // id with a private dir, the last thing to be deleted in delete
+ // backup, but all will be cleaned up with a GarbageCollect.)
+ s = GarbageCollect();
+ } else if (s.IsNotFound()) {
+ // normal case, the new backup's private dir doesn't exist yet
+ s = Status::OK();
+ }
+
+ auto ret = backups_.insert(std::make_pair(
+ new_backup_id, std::unique_ptr<BackupMeta>(new BackupMeta(
+ GetBackupMetaFile(new_backup_id, false /* tmp */),
+ GetBackupMetaFile(new_backup_id, true /* tmp */),
+ &backuped_file_infos_, backup_env_))));
+ assert(ret.second == true);
+ auto& new_backup = ret.first->second;
+ new_backup->RecordTimestamp();
+ new_backup->SetAppMetadata(app_metadata);
+
+ auto start_backup = backup_env_->NowMicros();
+
+ ROCKS_LOG_INFO(options_.info_log,
+ "Started the backup process -- creating backup %u",
+ new_backup_id);
+ if (s.ok()) {
+ s = backup_env_->CreateDir(private_dir);
+ }
+
+ RateLimiter* rate_limiter = options_.backup_rate_limiter.get();
+ if (rate_limiter) {
+ copy_file_buffer_size_ = static_cast<size_t>(rate_limiter->GetSingleBurstBytes());
+ }
+
+ // A set into which we will insert the dst_paths that are calculated for live
+ // files and live WAL files.
+  // This is used to check whether a live file shares a dst_path with another
+ // live file.
+ std::unordered_set<std::string> live_dst_paths;
+
+ std::vector<BackupAfterCopyOrCreateWorkItem> backup_items_to_finish;
+ // Add a CopyOrCreateWorkItem to the channel for each live file
+ db->DisableFileDeletions();
+ if (s.ok()) {
+ CheckpointImpl checkpoint(db);
+ uint64_t sequence_number = 0;
+ DBOptions db_options = db->GetDBOptions();
+ EnvOptions src_raw_env_options(db_options);
+ s = checkpoint.CreateCustomCheckpoint(
+ db_options,
+ [&](const std::string& /*src_dirname*/, const std::string& /*fname*/,
+ FileType) {
+ // custom checkpoint will switch to calling copy_file_cb after it sees
+ // NotSupported returned from link_file_cb.
+ return Status::NotSupported();
+ } /* link_file_cb */,
+ [&](const std::string& src_dirname, const std::string& fname,
+ uint64_t size_limit_bytes, FileType type) {
+ if (type == kLogFile && !options_.backup_log_files) {
+ return Status::OK();
+ }
+ Log(options_.info_log, "add file for backup %s", fname.c_str());
+ uint64_t size_bytes = 0;
+ Status st;
+ if (type == kTableFile) {
+ st = db_env_->GetFileSize(src_dirname + fname, &size_bytes);
+ }
+ EnvOptions src_env_options;
+ switch (type) {
+ case kLogFile:
+ src_env_options =
+ db_env_->OptimizeForLogRead(src_raw_env_options);
+ break;
+ case kTableFile:
+ src_env_options = db_env_->OptimizeForCompactionTableRead(
+ src_raw_env_options, ImmutableDBOptions(db_options));
+ break;
+ case kDescriptorFile:
+ src_env_options =
+ db_env_->OptimizeForManifestRead(src_raw_env_options);
+ break;
+ default:
+ // Other backed up files (like options file) are not read by live
+ // DB, so don't need to worry about avoiding mixing buffered and
+ // direct I/O. Just use plain defaults.
+ src_env_options = src_raw_env_options;
+ break;
+ }
+ if (st.ok()) {
+ st = AddBackupFileWorkItem(
+ live_dst_paths, backup_items_to_finish, new_backup_id,
+ options_.share_table_files && type == kTableFile, src_dirname,
+ fname, src_env_options, rate_limiter, size_bytes,
+ size_limit_bytes,
+ options_.share_files_with_checksum && type == kTableFile,
+ progress_callback);
+ }
+ return st;
+ } /* copy_file_cb */,
+ [&](const std::string& fname, const std::string& contents, FileType) {
+ Log(options_.info_log, "add file for backup %s", fname.c_str());
+ return AddBackupFileWorkItem(
+ live_dst_paths, backup_items_to_finish, new_backup_id,
+ false /* shared */, "" /* src_dir */, fname,
+ EnvOptions() /* src_env_options */, rate_limiter, contents.size(),
+ 0 /* size_limit */, false /* shared_checksum */,
+ progress_callback, contents);
+ } /* create_file_cb */,
+ &sequence_number, flush_before_backup ? 0 : port::kMaxUint64);
+ if (s.ok()) {
+ new_backup->SetSequenceNumber(sequence_number);
+ }
+ }
+ ROCKS_LOG_INFO(options_.info_log, "add files for backup done, wait finish.");
+ Status item_status;
+ for (auto& item : backup_items_to_finish) {
+ item.result.wait();
+ auto result = item.result.get();
+ item_status = result.status;
+ if (item_status.ok() && item.shared && item.needed_to_copy) {
+ item_status = item.backup_env->RenameFile(item.dst_path_tmp,
+ item.dst_path);
+ }
+ if (item_status.ok()) {
+ item_status = new_backup.get()->AddFile(
+ std::make_shared<FileInfo>(item.dst_relative,
+ result.size,
+ result.checksum_value));
+ }
+ if (!item_status.ok()) {
+ s = item_status;
+ }
+ }
+
+ // we copied all the files, enable file deletions
+ db->EnableFileDeletions(false);
+
+ auto backup_time = backup_env_->NowMicros() - start_backup;
+
+ if (s.ok()) {
+ // persist the backup metadata on the disk
+ s = new_backup->StoreToFile(options_.sync);
+ }
+ if (s.ok() && options_.sync) {
+ std::unique_ptr<Directory> backup_private_directory;
+ backup_env_->NewDirectory(
+ GetAbsolutePath(GetPrivateFileRel(new_backup_id, false)),
+ &backup_private_directory);
+ if (backup_private_directory != nullptr) {
+ s = backup_private_directory->Fsync();
+ }
+ if (s.ok() && private_directory_ != nullptr) {
+ s = private_directory_->Fsync();
+ }
+ if (s.ok() && meta_directory_ != nullptr) {
+ s = meta_directory_->Fsync();
+ }
+ if (s.ok() && shared_directory_ != nullptr) {
+ s = shared_directory_->Fsync();
+ }
+ if (s.ok() && backup_directory_ != nullptr) {
+ s = backup_directory_->Fsync();
+ }
+ }
+
+ if (s.ok()) {
+ backup_statistics_.IncrementNumberSuccessBackup();
+ }
+ if (!s.ok()) {
+ backup_statistics_.IncrementNumberFailBackup();
+ // clean all the files we might have created
+ ROCKS_LOG_INFO(options_.info_log, "Backup failed -- %s",
+ s.ToString().c_str());
+ ROCKS_LOG_INFO(options_.info_log, "Backup Statistics %s\n",
+ backup_statistics_.ToString().c_str());
+ // delete files that we might have already written
+ might_need_garbage_collect_ = true;
+ DeleteBackup(new_backup_id);
+ return s;
+ }
+
+ // here we know that we succeeded and installed the new backup
+ // in the LATEST_BACKUP file
+ latest_backup_id_ = new_backup_id;
+ latest_valid_backup_id_ = new_backup_id;
+ ROCKS_LOG_INFO(options_.info_log, "Backup DONE. All is good");
+
+  // backup_speed is in MB/s: bytes / (1.048576 * microseconds)
+ double backup_speed = new_backup->GetSize() / (1.048576 * backup_time);
+ ROCKS_LOG_INFO(options_.info_log, "Backup number of files: %u",
+ new_backup->GetNumberFiles());
+ char human_size[16];
+ AppendHumanBytes(new_backup->GetSize(), human_size, sizeof(human_size));
+ ROCKS_LOG_INFO(options_.info_log, "Backup size: %s", human_size);
+ ROCKS_LOG_INFO(options_.info_log, "Backup time: %" PRIu64 " microseconds",
+ backup_time);
+ ROCKS_LOG_INFO(options_.info_log, "Backup speed: %.3f MB/s", backup_speed);
+ ROCKS_LOG_INFO(options_.info_log, "Backup Statistics %s",
+ backup_statistics_.ToString().c_str());
+ return s;
+}
+
+Status BackupEngineImpl::PurgeOldBackups(uint32_t num_backups_to_keep) {
+ assert(initialized_);
+ assert(!read_only_);
+
+ // Best effort deletion even with errors
+ Status overall_status = Status::OK();
+
+ ROCKS_LOG_INFO(options_.info_log, "Purging old backups, keeping %u",
+ num_backups_to_keep);
+ std::vector<BackupID> to_delete;
+ auto itr = backups_.begin();
+ while ((backups_.size() - to_delete.size()) > num_backups_to_keep) {
+ to_delete.push_back(itr->first);
+ itr++;
+ }
+ for (auto backup_id : to_delete) {
+ auto s = DeleteBackupInternal(backup_id);
+ if (!s.ok()) {
+ overall_status = s;
+ }
+ }
+ // Clean up after any incomplete backup deletion, potentially from
+ // earlier session.
+ if (might_need_garbage_collect_) {
+ auto s = GarbageCollect();
+ if (!s.ok() && overall_status.ok()) {
+ overall_status = s;
+ }
+ }
+ return overall_status;
+}
+
+Status BackupEngineImpl::DeleteBackup(BackupID backup_id) {
+ auto s1 = DeleteBackupInternal(backup_id);
+ auto s2 = Status::OK();
+
+ // Clean up after any incomplete backup deletion, potentially from
+ // earlier session.
+ if (might_need_garbage_collect_) {
+ s2 = GarbageCollect();
+ }
+
+ if (!s1.ok()) {
+ return s1;
+ } else {
+ return s2;
+ }
+}
+
+// Does not auto-GarbageCollect
+Status BackupEngineImpl::DeleteBackupInternal(BackupID backup_id) {
+ assert(initialized_);
+ assert(!read_only_);
+
+ ROCKS_LOG_INFO(options_.info_log, "Deleting backup %u", backup_id);
+ auto backup = backups_.find(backup_id);
+ if (backup != backups_.end()) {
+ auto s = backup->second->Delete();
+ if (!s.ok()) {
+ return s;
+ }
+ backups_.erase(backup);
+ } else {
+ auto corrupt = corrupt_backups_.find(backup_id);
+ if (corrupt == corrupt_backups_.end()) {
+ return Status::NotFound("Backup not found");
+ }
+ auto s = corrupt->second.second->Delete();
+ if (!s.ok()) {
+ return s;
+ }
+ corrupt_backups_.erase(corrupt);
+ }
+
+ // After removing meta file, best effort deletion even with errors.
+ // (Don't delete other files if we can't delete the meta file right
+ // now.)
+ std::vector<std::string> to_delete;
+ for (auto& itr : backuped_file_infos_) {
+ if (itr.second->refs == 0) {
+ Status s = backup_env_->DeleteFile(GetAbsolutePath(itr.first));
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s", itr.first.c_str(),
+ s.ToString().c_str());
+ to_delete.push_back(itr.first);
+ if (!s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ }
+ for (auto& td : to_delete) {
+ backuped_file_infos_.erase(td);
+ }
+
+ // take care of private dirs -- GarbageCollect() will take care of them
+ // if they are not empty
+ std::string private_dir = GetPrivateFileRel(backup_id);
+ Status s = backup_env_->DeleteDir(GetAbsolutePath(private_dir));
+ ROCKS_LOG_INFO(options_.info_log, "Deleting private dir %s -- %s",
+ private_dir.c_str(), s.ToString().c_str());
+ if (!s.ok()) {
+ // Full gc or trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ return Status::OK();
+}
+
+void BackupEngineImpl::GetBackupInfo(std::vector<BackupInfo>* backup_info) {
+ assert(initialized_);
+ backup_info->reserve(backups_.size());
+ for (auto& backup : backups_) {
+ if (!backup.second->Empty()) {
+ backup_info->push_back(BackupInfo(
+ backup.first, backup.second->GetTimestamp(), backup.second->GetSize(),
+ backup.second->GetNumberFiles(), backup.second->GetAppMetadata()));
+ }
+ }
+}
+
+void
+BackupEngineImpl::GetCorruptedBackups(
+ std::vector<BackupID>* corrupt_backup_ids) {
+ assert(initialized_);
+ corrupt_backup_ids->reserve(corrupt_backups_.size());
+ for (auto& backup : corrupt_backups_) {
+ corrupt_backup_ids->push_back(backup.first);
+ }
+}
+
+Status BackupEngineImpl::RestoreDBFromBackup(
+ BackupID backup_id, const std::string& db_dir, const std::string& wal_dir,
+ const RestoreOptions& restore_options) {
+ assert(initialized_);
+ auto corrupt_itr = corrupt_backups_.find(backup_id);
+ if (corrupt_itr != corrupt_backups_.end()) {
+ return corrupt_itr->second.first;
+ }
+ auto backup_itr = backups_.find(backup_id);
+ if (backup_itr == backups_.end()) {
+ return Status::NotFound("Backup not found");
+ }
+ auto& backup = backup_itr->second;
+ if (backup->Empty()) {
+ return Status::NotFound("Backup not found");
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring backup id %u\n", backup_id);
+ ROCKS_LOG_INFO(options_.info_log, "keep_log_files: %d\n",
+ static_cast<int>(restore_options.keep_log_files));
+
+ // just in case. Ignore errors
+ db_env_->CreateDirIfMissing(db_dir);
+ db_env_->CreateDirIfMissing(wal_dir);
+
+ if (restore_options.keep_log_files) {
+ // delete files in db_dir, but keep all the log files
+ DeleteChildren(db_dir, 1 << kLogFile);
+ // move all the files from archive dir to wal_dir
+ std::string archive_dir = ArchivalDirectory(wal_dir);
+ std::vector<std::string> archive_files;
+ db_env_->GetChildren(archive_dir, &archive_files); // ignore errors
+ for (const auto& f : archive_files) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && type == kLogFile) {
+ ROCKS_LOG_INFO(options_.info_log,
+ "Moving log file from archive/ to wal_dir: %s",
+ f.c_str());
+ Status s =
+ db_env_->RenameFile(archive_dir + "/" + f, wal_dir + "/" + f);
+ if (!s.ok()) {
+ // if we can't move log file from archive_dir to wal_dir,
+ // we should fail, since it might mean data loss
+ return s;
+ }
+ }
+ }
+ } else {
+ DeleteChildren(wal_dir);
+ DeleteChildren(ArchivalDirectory(wal_dir));
+ DeleteChildren(db_dir);
+ }
+
+ RateLimiter* rate_limiter = options_.restore_rate_limiter.get();
+ if (rate_limiter) {
+ copy_file_buffer_size_ = static_cast<size_t>(rate_limiter->GetSingleBurstBytes());
+ }
+ Status s;
+ std::vector<RestoreAfterCopyOrCreateWorkItem> restore_items_to_finish;
+ for (const auto& file_info : backup->GetFiles()) {
+ const std::string &file = file_info->filename;
+ std::string dst;
+ // 1. extract the filename
+ size_t slash = file.find_last_of('/');
+ // file will either be shared/<file>, shared_checksum/<file_crc32_size>
+ // or private/<number>/<file>
+ assert(slash != std::string::npos);
+ dst = file.substr(slash + 1);
+
+ // if the file was in shared_checksum, extract the real file name
+ // in this case the file is <number>_<checksum>_<size>.<type>
+ if (file.substr(0, slash) == GetSharedChecksumDirRel()) {
+ dst = GetFileFromChecksumFile(dst);
+ }
+
+ // 2. find the filetype
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(dst, &number, &type);
+ if (!ok) {
+ return Status::Corruption("Backup corrupted");
+ }
+ // 3. Construct the final path
+ // kLogFile lives in wal_dir and all the rest live in db_dir
+ dst = ((type == kLogFile) ? wal_dir : db_dir) +
+ "/" + dst;
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring %s to %s\n", file.c_str(),
+ dst.c_str());
+ CopyOrCreateWorkItem copy_or_create_work_item(
+ GetAbsolutePath(file), dst, "" /* contents */, backup_env_, db_env_,
+ EnvOptions() /* src_env_options */, false, rate_limiter,
+ 0 /* size_limit */);
+ RestoreAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ copy_or_create_work_item.result.get_future(),
+ file_info->checksum_value);
+ files_to_copy_or_create_.write(std::move(copy_or_create_work_item));
+ restore_items_to_finish.push_back(
+ std::move(after_copy_or_create_work_item));
+ }
+ Status item_status;
+ for (auto& item : restore_items_to_finish) {
+ item.result.wait();
+ auto result = item.result.get();
+ item_status = result.status;
+ // Note: It is possible that both of the following bad-status cases occur
+ // during copying. But, we only return one status.
+ if (!item_status.ok()) {
+ s = item_status;
+ break;
+ } else if (item.checksum_value != result.checksum_value) {
+ s = Status::Corruption("Checksum check failed");
+ break;
+ }
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Restoring done -- %s\n",
+ s.ToString().c_str());
+ return s;
+}
+
+Status BackupEngineImpl::VerifyBackup(BackupID backup_id) {
+ assert(initialized_);
+ auto corrupt_itr = corrupt_backups_.find(backup_id);
+ if (corrupt_itr != corrupt_backups_.end()) {
+ return corrupt_itr->second.first;
+ }
+
+ auto backup_itr = backups_.find(backup_id);
+ if (backup_itr == backups_.end()) {
+ return Status::NotFound();
+ }
+
+ auto& backup = backup_itr->second;
+ if (backup->Empty()) {
+ return Status::NotFound();
+ }
+
+ ROCKS_LOG_INFO(options_.info_log, "Verifying backup id %u\n", backup_id);
+
+ std::unordered_map<std::string, uint64_t> curr_abs_path_to_size;
+ for (const auto& rel_dir : {GetPrivateFileRel(backup_id), GetSharedFileRel(),
+ GetSharedFileWithChecksumRel()}) {
+ const auto abs_dir = GetAbsolutePath(rel_dir);
+ InsertPathnameToSizeBytes(abs_dir, backup_env_, &curr_abs_path_to_size);
+ }
+
+ for (const auto& file_info : backup->GetFiles()) {
+ const auto abs_path = GetAbsolutePath(file_info->filename);
+ if (curr_abs_path_to_size.find(abs_path) == curr_abs_path_to_size.end()) {
+ return Status::NotFound("File missing: " + abs_path);
+ }
+ if (file_info->size != curr_abs_path_to_size[abs_path]) {
+ return Status::Corruption("File corrupted: " + abs_path);
+ }
+ }
+ return Status::OK();
+}
+
+Status BackupEngineImpl::CopyOrCreateFile(
+ const std::string& src, const std::string& dst, const std::string& contents,
+ Env* src_env, Env* dst_env, const EnvOptions& src_env_options, bool sync,
+ RateLimiter* rate_limiter, uint64_t* size, uint32_t* checksum_value,
+ uint64_t size_limit, std::function<void()> progress_callback) {
+ assert(src.empty() != contents.empty());
+ Status s;
+ std::unique_ptr<WritableFile> dst_file;
+ std::unique_ptr<SequentialFile> src_file;
+ EnvOptions dst_env_options;
+ dst_env_options.use_mmap_writes = false;
+ // TODO:(gzh) maybe use direct reads/writes here if possible
+ if (size != nullptr) {
+ *size = 0;
+ }
+ if (checksum_value != nullptr) {
+ *checksum_value = 0;
+ }
+
+  // Check if size limit is set. If not, set it to a very big number.
+ if (size_limit == 0) {
+ size_limit = std::numeric_limits<uint64_t>::max();
+ }
+
+ s = dst_env->NewWritableFile(dst, &dst_file, dst_env_options);
+ if (s.ok() && !src.empty()) {
+ s = src_env->NewSequentialFile(src, &src_file, src_env_options);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<WritableFileWriter> dest_writer(new WritableFileWriter(
+ NewLegacyWritableFileWrapper(std::move(dst_file)), dst, dst_env_options));
+ std::unique_ptr<SequentialFileReader> src_reader;
+ std::unique_ptr<char[]> buf;
+ if (!src.empty()) {
+ src_reader.reset(new SequentialFileReader(
+ NewLegacySequentialFileWrapper(src_file), src));
+ buf.reset(new char[copy_file_buffer_size_]);
+ }
+
+ Slice data;
+ uint64_t processed_buffer_size = 0;
+ do {
+ if (stop_backup_.load(std::memory_order_acquire)) {
+ return Status::Incomplete("Backup stopped");
+ }
+ if (!src.empty()) {
+ size_t buffer_to_read = (copy_file_buffer_size_ < size_limit)
+ ? copy_file_buffer_size_
+ : static_cast<size_t>(size_limit);
+ s = src_reader->Read(buffer_to_read, &data, buf.get());
+ processed_buffer_size += buffer_to_read;
+ } else {
+ data = contents;
+ }
+ size_limit -= data.size();
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (size != nullptr) {
+ *size += data.size();
+ }
+ if (checksum_value != nullptr) {
+ *checksum_value =
+ crc32c::Extend(*checksum_value, data.data(), data.size());
+ }
+ s = dest_writer->Append(data);
+ if (rate_limiter != nullptr) {
+ rate_limiter->Request(data.size(), Env::IO_LOW, nullptr /* stats */,
+ RateLimiter::OpType::kWrite);
+ }
+ if (processed_buffer_size > options_.callback_trigger_interval_size) {
+ processed_buffer_size -= options_.callback_trigger_interval_size;
+ std::lock_guard<std::mutex> lock(byte_report_mutex_);
+ progress_callback();
+ }
+ } while (s.ok() && contents.empty() && data.size() > 0 && size_limit > 0);
+
+ if (s.ok() && sync) {
+ s = dest_writer->Sync(false);
+ }
+ if (s.ok()) {
+ s = dest_writer->Close();
+ }
+ return s;
+}
+
+// fname will always start with "/"
+Status BackupEngineImpl::AddBackupFileWorkItem(
+ std::unordered_set<std::string>& live_dst_paths,
+ std::vector<BackupAfterCopyOrCreateWorkItem>& backup_items_to_finish,
+ BackupID backup_id, bool shared, const std::string& src_dir,
+ const std::string& fname, const EnvOptions& src_env_options,
+ RateLimiter* rate_limiter, uint64_t size_bytes, uint64_t size_limit,
+ bool shared_checksum, std::function<void()> progress_callback,
+ const std::string& contents) {
+ assert(!fname.empty() && fname[0] == '/');
+ assert(contents.empty() != src_dir.empty());
+
+ std::string dst_relative = fname.substr(1);
+ std::string dst_relative_tmp;
+ Status s;
+ uint32_t checksum_value = 0;
+
+ if (shared && shared_checksum) {
+ // add checksum and file length to the file name
+ s = CalculateChecksum(src_dir + fname, db_env_, src_env_options, size_limit,
+ &checksum_value);
+ if (!s.ok()) {
+ return s;
+ }
+ if (size_bytes == port::kMaxUint64) {
+ return Status::NotFound("File missing: " + src_dir + fname);
+ }
+ dst_relative =
+ GetSharedFileWithChecksum(dst_relative, checksum_value, size_bytes);
+ dst_relative_tmp = GetSharedFileWithChecksumRel(dst_relative, true);
+ dst_relative = GetSharedFileWithChecksumRel(dst_relative, false);
+ } else if (shared) {
+ dst_relative_tmp = GetSharedFileRel(dst_relative, true);
+ dst_relative = GetSharedFileRel(dst_relative, false);
+ } else {
+ dst_relative = GetPrivateFileRel(backup_id, false, dst_relative);
+ }
+
+ // We copy into `temp_dest_path` and, once finished, rename it to
+ // `final_dest_path`. This allows files to atomically appear at
+ // `final_dest_path`. We can copy directly to the final path when atomicity
+ // is unnecessary, like for files in private backup directories.
+ const std::string* copy_dest_path;
+ std::string temp_dest_path;
+ std::string final_dest_path = GetAbsolutePath(dst_relative);
+ if (!dst_relative_tmp.empty()) {
+ temp_dest_path = GetAbsolutePath(dst_relative_tmp);
+ copy_dest_path = &temp_dest_path;
+ } else {
+ copy_dest_path = &final_dest_path;
+ }
+
+ // if it's shared, we also need to check if it exists -- if it does, no need
+ // to copy it again.
+ bool need_to_copy = true;
+ // true if final_dest_path is the same path as another live file
+ const bool same_path =
+ live_dst_paths.find(final_dest_path) != live_dst_paths.end();
+
+ bool file_exists = false;
+ if (shared && !same_path) {
+ Status exist = backup_env_->FileExists(final_dest_path);
+ if (exist.ok()) {
+ file_exists = true;
+ } else if (exist.IsNotFound()) {
+ file_exists = false;
+ } else {
+      assert(exist.IsIOError());
+ return exist;
+ }
+ }
+
+ if (!contents.empty()) {
+ need_to_copy = false;
+ } else if (shared && (same_path || file_exists)) {
+ need_to_copy = false;
+ if (shared_checksum) {
+ ROCKS_LOG_INFO(options_.info_log,
+ "%s already present, with checksum %u and size %" PRIu64,
+ fname.c_str(), checksum_value, size_bytes);
+ } else if (backuped_file_infos_.find(dst_relative) ==
+ backuped_file_infos_.end() && !same_path) {
+ // file already exists, but it's not referenced by any backup. overwrite
+ // the file
+ ROCKS_LOG_INFO(
+ options_.info_log,
+ "%s already present, but not referenced by any backup. We will "
+ "overwrite the file.",
+ fname.c_str());
+ need_to_copy = true;
+ backup_env_->DeleteFile(final_dest_path);
+ } else {
+ // the file is present and referenced by a backup
+ ROCKS_LOG_INFO(options_.info_log,
+ "%s already present, calculate checksum", fname.c_str());
+ s = CalculateChecksum(src_dir + fname, db_env_, src_env_options,
+ size_limit, &checksum_value);
+ }
+ }
+ live_dst_paths.insert(final_dest_path);
+
+ if (!contents.empty() || need_to_copy) {
+ ROCKS_LOG_INFO(options_.info_log, "Copying %s to %s", fname.c_str(),
+ copy_dest_path->c_str());
+ CopyOrCreateWorkItem copy_or_create_work_item(
+ src_dir.empty() ? "" : src_dir + fname, *copy_dest_path, contents,
+ db_env_, backup_env_, src_env_options, options_.sync, rate_limiter,
+ size_limit, progress_callback);
+ BackupAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ copy_or_create_work_item.result.get_future(), shared, need_to_copy,
+ backup_env_, temp_dest_path, final_dest_path, dst_relative);
+ files_to_copy_or_create_.write(std::move(copy_or_create_work_item));
+ backup_items_to_finish.push_back(std::move(after_copy_or_create_work_item));
+ } else {
+ std::promise<CopyOrCreateResult> promise_result;
+ BackupAfterCopyOrCreateWorkItem after_copy_or_create_work_item(
+ promise_result.get_future(), shared, need_to_copy, backup_env_,
+ temp_dest_path, final_dest_path, dst_relative);
+ backup_items_to_finish.push_back(std::move(after_copy_or_create_work_item));
+ CopyOrCreateResult result;
+ result.status = s;
+ result.size = size_bytes;
+ result.checksum_value = checksum_value;
+ promise_result.set_value(std::move(result));
+ }
+ return s;
+}
+
+Status BackupEngineImpl::CalculateChecksum(const std::string& src, Env* src_env,
+ const EnvOptions& src_env_options,
+ uint64_t size_limit,
+ uint32_t* checksum_value) {
+ *checksum_value = 0;
+ if (size_limit == 0) {
+ size_limit = std::numeric_limits<uint64_t>::max();
+ }
+
+ std::unique_ptr<SequentialFile> src_file;
+ Status s = src_env->NewSequentialFile(src, &src_file, src_env_options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<SequentialFileReader> src_reader(
+ new SequentialFileReader(NewLegacySequentialFileWrapper(src_file), src));
+ std::unique_ptr<char[]> buf(new char[copy_file_buffer_size_]);
+ Slice data;
+
+ do {
+ if (stop_backup_.load(std::memory_order_acquire)) {
+ return Status::Incomplete("Backup stopped");
+ }
+ size_t buffer_to_read = (copy_file_buffer_size_ < size_limit) ?
+ copy_file_buffer_size_ : static_cast<size_t>(size_limit);
+ s = src_reader->Read(buffer_to_read, &data, buf.get());
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ size_limit -= data.size();
+ *checksum_value = crc32c::Extend(*checksum_value, data.data(), data.size());
+ } while (data.size() > 0 && size_limit > 0);
+
+ return s;
+}
+
+void BackupEngineImpl::DeleteChildren(const std::string& dir,
+ uint32_t file_type_filter) {
+ std::vector<std::string> children;
+ db_env_->GetChildren(dir, &children); // ignore errors
+
+ for (const auto& f : children) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && (file_type_filter & (1 << type))) {
+ // don't delete this file
+ continue;
+ }
+ db_env_->DeleteFile(dir + "/" + f); // ignore errors
+ }
+}
+
+Status BackupEngineImpl::InsertPathnameToSizeBytes(
+ const std::string& dir, Env* env,
+ std::unordered_map<std::string, uint64_t>* result) {
+ assert(result != nullptr);
+ std::vector<Env::FileAttributes> files_attrs;
+ Status status = env->FileExists(dir);
+ if (status.ok()) {
+ status = env->GetChildrenFileAttributes(dir, &files_attrs);
+ } else if (status.IsNotFound()) {
+    // Inserting no entries can be considered success
+ status = Status::OK();
+ }
+ const bool slash_needed = dir.empty() || dir.back() != '/';
+ for (const auto& file_attrs : files_attrs) {
+ result->emplace(dir + (slash_needed ? "/" : "") + file_attrs.name,
+ file_attrs.size_bytes);
+ }
+ return status;
+}
+
+Status BackupEngineImpl::GarbageCollect() {
+ assert(!read_only_);
+
+ // We will make a best effort to remove all garbage even in the presence
+ // of inconsistencies or I/O failures that inhibit finding garbage.
+ Status overall_status = Status::OK();
+ // If all goes well, we don't need another auto-GC this session
+ might_need_garbage_collect_ = false;
+
+ ROCKS_LOG_INFO(options_.info_log, "Starting garbage collection");
+
+ // delete obsolete shared files
+ for (bool with_checksum : {false, true}) {
+ std::vector<std::string> shared_children;
+ {
+ std::string shared_path;
+ if (with_checksum) {
+ shared_path = GetAbsolutePath(GetSharedFileWithChecksumRel());
+ } else {
+ shared_path = GetAbsolutePath(GetSharedFileRel());
+ }
+ auto s = backup_env_->FileExists(shared_path);
+ if (s.ok()) {
+ s = backup_env_->GetChildren(shared_path, &shared_children);
+ } else if (s.IsNotFound()) {
+ s = Status::OK();
+ }
+ if (!s.ok()) {
+ overall_status = s;
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ for (auto& child : shared_children) {
+ if (child == "." || child == "..") {
+ continue;
+ }
+ std::string rel_fname;
+ if (with_checksum) {
+ rel_fname = GetSharedFileWithChecksumRel(child);
+ } else {
+ rel_fname = GetSharedFileRel(child);
+ }
+ auto child_itr = backuped_file_infos_.find(rel_fname);
+ // if it's not refcounted, delete it
+ if (child_itr == backuped_file_infos_.end() ||
+ child_itr->second->refs == 0) {
+ // this might be a directory, but DeleteFile will just fail in that
+ // case, so we're good
+ Status s = backup_env_->DeleteFile(GetAbsolutePath(rel_fname));
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s",
+ rel_fname.c_str(), s.ToString().c_str());
+ backuped_file_infos_.erase(rel_fname);
+ if (!s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ }
+ }
+
+ // delete obsolete private files
+ std::vector<std::string> private_children;
+ {
+ auto s = backup_env_->GetChildren(GetAbsolutePath(GetPrivateDirRel()),
+ &private_children);
+ if (!s.ok()) {
+ overall_status = s;
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ for (auto& child : private_children) {
+ if (child == "." || child == "..") {
+ continue;
+ }
+
+ BackupID backup_id = 0;
+ bool tmp_dir = child.find(".tmp") != std::string::npos;
+ sscanf(child.c_str(), "%u", &backup_id);
+ if (!tmp_dir && // if it's tmp_dir, delete it
+ (backup_id == 0 || backups_.find(backup_id) != backups_.end())) {
+ // it's either not a number or it's still alive. continue
+ continue;
+ }
+ // here we have to delete the dir and all its children
+ std::string full_private_path =
+ GetAbsolutePath(GetPrivateFileRel(backup_id));
+ std::vector<std::string> subchildren;
+ backup_env_->GetChildren(full_private_path, &subchildren);
+ for (auto& subchild : subchildren) {
+ if (subchild == "." || subchild == "..") {
+ continue;
+ }
+ Status s = backup_env_->DeleteFile(full_private_path + subchild);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting %s -- %s",
+ (full_private_path + subchild).c_str(),
+ s.ToString().c_str());
+ if (!s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+ // finally delete the private dir
+ Status s = backup_env_->DeleteDir(full_private_path);
+ ROCKS_LOG_INFO(options_.info_log, "Deleting dir %s -- %s",
+ full_private_path.c_str(), s.ToString().c_str());
+ if (!s.ok()) {
+ // Trying again later might work
+ might_need_garbage_collect_ = true;
+ }
+ }
+
+ assert(overall_status.ok() || might_need_garbage_collect_);
+ return overall_status;
+}
+
+// ------- BackupMeta class --------
+
+Status BackupEngineImpl::BackupMeta::AddFile(
+ std::shared_ptr<FileInfo> file_info) {
+ auto itr = file_infos_->find(file_info->filename);
+ if (itr == file_infos_->end()) {
+ auto ret = file_infos_->insert({file_info->filename, file_info});
+ if (ret.second) {
+ itr = ret.first;
+ itr->second->refs = 1;
+ } else {
+ // if this happens, something is seriously wrong
+ return Status::Corruption("In memory metadata insertion error");
+ }
+ } else {
+ if (itr->second->checksum_value != file_info->checksum_value) {
+ return Status::Corruption(
+ "Checksum mismatch for existing backup file. Delete old backups and "
+ "try again.");
+ }
+ ++itr->second->refs; // increase refcount if already present
+ }
+
+ size_ += file_info->size;
+ files_.push_back(itr->second);
+
+ return Status::OK();
+}
+
+Status BackupEngineImpl::BackupMeta::Delete(bool delete_meta) {
+ Status s;
+ for (const auto& file : files_) {
+ --file->refs; // decrease refcount
+ }
+ files_.clear();
+ // delete meta file
+ if (delete_meta) {
+ s = env_->FileExists(meta_filename_);
+ if (s.ok()) {
+ s = env_->DeleteFile(meta_filename_);
+ } else if (s.IsNotFound()) {
+ s = Status::OK(); // nothing to delete
+ }
+ }
+ timestamp_ = 0;
+ return s;
+}
+
+Slice kMetaDataPrefix("metadata ");
+
+// each backup meta file is of the format:
+// <timestamp>
+// <seq number>
+// <metadata(literal string)> <metadata> (optional)
+// <number of files>
+// <file1> <crc32(literal string)> <crc32_value>
+// <file2> <crc32(literal string)> <crc32_value>
+// ...
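+// As a sketch only (all values below are made up), a meta file in this
+// format might look like:
+//   1571729634
+//   4738291
+//   metadata 6261636B7570
+//   2
+//   shared/00010.sst crc32 3119742541
+//   private/1/MANIFEST-000008 crc32 1094796461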
+Status BackupEngineImpl::BackupMeta::LoadFromFile(
+ const std::string& backup_dir,
+ const std::unordered_map<std::string, uint64_t>& abs_path_to_size) {
+ assert(Empty());
+ Status s;
+ std::unique_ptr<SequentialFile> backup_meta_file;
+ s = env_->NewSequentialFile(meta_filename_, &backup_meta_file, EnvOptions());
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<SequentialFileReader> backup_meta_reader(
+ new SequentialFileReader(NewLegacySequentialFileWrapper(backup_meta_file),
+ meta_filename_));
+ std::unique_ptr<char[]> buf(new char[max_backup_meta_file_size_ + 1]);
+ Slice data;
+ s = backup_meta_reader->Read(max_backup_meta_file_size_, &data, buf.get());
+
+ if (!s.ok() || data.size() == max_backup_meta_file_size_) {
+ return s.ok() ? Status::Corruption("File size too big") : s;
+ }
+ buf[data.size()] = 0;
+
+ uint32_t num_files = 0;
+ char *next;
+ timestamp_ = strtoull(data.data(), &next, 10);
+ data.remove_prefix(next - data.data() + 1); // +1 for '\n'
+ sequence_number_ = strtoull(data.data(), &next, 10);
+ data.remove_prefix(next - data.data() + 1); // +1 for '\n'
+
+ if (data.starts_with(kMetaDataPrefix)) {
+ // app metadata present
+ data.remove_prefix(kMetaDataPrefix.size());
+ Slice hex_encoded_metadata = GetSliceUntil(&data, '\n');
+ bool decode_success = hex_encoded_metadata.DecodeHex(&app_metadata_);
+ if (!decode_success) {
+ return Status::Corruption(
+ "Failed to decode stored hex encoded app metadata");
+ }
+ }
+
+ num_files = static_cast<uint32_t>(strtoul(data.data(), &next, 10));
+ data.remove_prefix(next - data.data() + 1); // +1 for '\n'
+
+ std::vector<std::shared_ptr<FileInfo>> files;
+
+ Slice checksum_prefix("crc32 ");
+
+ for (uint32_t i = 0; s.ok() && i < num_files; ++i) {
+ auto line = GetSliceUntil(&data, '\n');
+ std::string filename = GetSliceUntil(&line, ' ').ToString();
+
+ uint64_t size;
+ const std::shared_ptr<FileInfo> file_info = GetFile(filename);
+ if (file_info) {
+ size = file_info->size;
+ } else {
+ std::string abs_path = backup_dir + "/" + filename;
+ try {
+ size = abs_path_to_size.at(abs_path);
+ } catch (std::out_of_range&) {
+ return Status::Corruption("Size missing for pathname: " + abs_path);
+ }
+ }
+
+ if (line.empty()) {
+ return Status::Corruption("File checksum is missing for " + filename +
+ " in " + meta_filename_);
+ }
+
+ uint32_t checksum_value = 0;
+ if (line.starts_with(checksum_prefix)) {
+ line.remove_prefix(checksum_prefix.size());
+ checksum_value = static_cast<uint32_t>(
+ strtoul(line.data(), nullptr, 10));
+ if (line != ROCKSDB_NAMESPACE::ToString(checksum_value)) {
+ return Status::Corruption("Invalid checksum value for " + filename +
+ " in " + meta_filename_);
+ }
+ } else {
+ return Status::Corruption("Unknown checksum type for " + filename +
+ " in " + meta_filename_);
+ }
+
+ files.emplace_back(new FileInfo(filename, size, checksum_value));
+ }
+
+ if (s.ok() && data.size() > 0) {
+ // file has to be read completely. if not, we count it as corruption
+ s = Status::Corruption("Tailing data in backup meta file in " +
+ meta_filename_);
+ }
+
+ if (s.ok()) {
+ files_.reserve(files.size());
+ for (const auto& file_info : files) {
+ s = AddFile(file_info);
+ if (!s.ok()) {
+ break;
+ }
+ }
+ }
+
+ return s;
+}
+
+Status BackupEngineImpl::BackupMeta::StoreToFile(bool sync) {
+ Status s;
+ std::unique_ptr<WritableFile> backup_meta_file;
+ EnvOptions env_options;
+ env_options.use_mmap_writes = false;
+ env_options.use_direct_writes = false;
+ s = env_->NewWritableFile(meta_tmp_filename_, &backup_meta_file, env_options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<char[]> buf(new char[max_backup_meta_file_size_]);
+ size_t len = 0, buf_size = max_backup_meta_file_size_;
+ len += snprintf(buf.get(), buf_size, "%" PRId64 "\n", timestamp_);
+ len += snprintf(buf.get() + len, buf_size - len, "%" PRIu64 "\n",
+ sequence_number_);
+ if (!app_metadata_.empty()) {
+ std::string hex_encoded_metadata =
+ Slice(app_metadata_).ToString(/* hex */ true);
+
+ // +1 to accommodate newline character
+    size_t hex_meta_strlen = kMetaDataPrefix.ToString().length() +
+                             hex_encoded_metadata.length() + 1;
+    if (hex_meta_strlen >= buf_size) {
+      return Status::Corruption("Buffer too small to fit backup metadata");
+    } else if (len + hex_meta_strlen >= buf_size) {
+ backup_meta_file->Append(Slice(buf.get(), len));
+ buf.reset();
+ std::unique_ptr<char[]> new_reset_buf(
+ new char[max_backup_meta_file_size_]);
+ buf.swap(new_reset_buf);
+ len = 0;
+ }
+ len += snprintf(buf.get() + len, buf_size - len, "%s%s\n",
+ kMetaDataPrefix.ToString().c_str(),
+ hex_encoded_metadata.c_str());
+ }
+
+ char writelen_temp[19];
+ if (len + snprintf(writelen_temp, sizeof(writelen_temp),
+ "%" ROCKSDB_PRIszt "\n", files_.size()) >= buf_size) {
+ backup_meta_file->Append(Slice(buf.get(), len));
+ buf.reset();
+ std::unique_ptr<char[]> new_reset_buf(new char[max_backup_meta_file_size_]);
+ buf.swap(new_reset_buf);
+ len = 0;
+ }
+ {
+ const char *const_write = writelen_temp;
+ len += snprintf(buf.get() + len, buf_size - len, "%s", const_write);
+ }
+
+ for (const auto& file : files_) {
+ // use crc32 for now, switch to something else if needed
+
+ size_t newlen = len + file->filename.length() + snprintf(writelen_temp,
+ sizeof(writelen_temp), " crc32 %u\n", file->checksum_value);
+ const char *const_write = writelen_temp;
+ if (newlen >= buf_size) {
+ backup_meta_file->Append(Slice(buf.get(), len));
+ buf.reset();
+ std::unique_ptr<char[]> new_reset_buf(
+ new char[max_backup_meta_file_size_]);
+ buf.swap(new_reset_buf);
+ len = 0;
+ }
+ len += snprintf(buf.get() + len, buf_size - len, "%s%s",
+ file->filename.c_str(), const_write);
+ }
+
+ s = backup_meta_file->Append(Slice(buf.get(), len));
+ if (s.ok() && sync) {
+ s = backup_meta_file->Sync();
+ }
+ if (s.ok()) {
+ s = backup_meta_file->Close();
+ }
+ if (s.ok()) {
+ s = env_->RenameFile(meta_tmp_filename_, meta_filename_);
+ }
+ return s;
+}
+
+// -------- BackupEngineReadOnlyImpl ---------
+class BackupEngineReadOnlyImpl : public BackupEngineReadOnly {
+ public:
+ BackupEngineReadOnlyImpl(Env* db_env, const BackupableDBOptions& options)
+ : backup_engine_(new BackupEngineImpl(db_env, options, true)) {}
+
+ ~BackupEngineReadOnlyImpl() override {}
+
+ // The returned BackupInfos are in chronological order, which means the
+ // latest backup comes last.
+ void GetBackupInfo(std::vector<BackupInfo>* backup_info) override {
+ backup_engine_->GetBackupInfo(backup_info);
+ }
+
+ void GetCorruptedBackups(std::vector<BackupID>* corrupt_backup_ids) override {
+ backup_engine_->GetCorruptedBackups(corrupt_backup_ids);
+ }
+
+ Status RestoreDBFromBackup(
+ BackupID backup_id, const std::string& db_dir, const std::string& wal_dir,
+ const RestoreOptions& restore_options = RestoreOptions()) override {
+ return backup_engine_->RestoreDBFromBackup(backup_id, db_dir, wal_dir,
+ restore_options);
+ }
+
+ Status RestoreDBFromLatestBackup(
+ const std::string& db_dir, const std::string& wal_dir,
+ const RestoreOptions& restore_options = RestoreOptions()) override {
+ return backup_engine_->RestoreDBFromLatestBackup(db_dir, wal_dir,
+ restore_options);
+ }
+
+ Status VerifyBackup(BackupID backup_id) override {
+ return backup_engine_->VerifyBackup(backup_id);
+ }
+
+ Status Initialize() { return backup_engine_->Initialize(); }
+
+ private:
+ std::unique_ptr<BackupEngineImpl> backup_engine_;
+};
+
+Status BackupEngineReadOnly::Open(Env* env, const BackupableDBOptions& options,
+ BackupEngineReadOnly** backup_engine_ptr) {
+ if (options.destroy_old_data) {
+ return Status::InvalidArgument(
+ "Can't destroy old data with ReadOnly BackupEngine");
+ }
+ std::unique_ptr<BackupEngineReadOnlyImpl> backup_engine(
+ new BackupEngineReadOnlyImpl(env, options));
+ auto s = backup_engine->Initialize();
+ if (!s.ok()) {
+ *backup_engine_ptr = nullptr;
+ return s;
+ }
+ *backup_engine_ptr = backup_engine.release();
+ return Status::OK();
+}
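+
+// A minimal usage sketch (paths and option values here are illustrative
+// only):
+//
+//   BackupEngineReadOnly* engine = nullptr;
+//   Status s = BackupEngineReadOnly::Open(
+//       Env::Default(), BackupableDBOptions("/path/to/backups"), &engine);
+//   if (s.ok()) {
+//     s = engine->RestoreDBFromLatestBackup("/path/to/db", "/path/to/wal");
+//     delete engine;
+//   }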
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/backupable/backupable_db_test.cc b/src/rocksdb/utilities/backupable/backupable_db_test.cc
new file mode 100644
index 000000000..efdb34b30
--- /dev/null
+++ b/src/rocksdb/utilities/backupable/backupable_db_test.cc
@@ -0,0 +1,1863 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#if !defined(ROCKSDB_LITE) && !defined(OS_WIN)
+
+#include <algorithm>
+#include <limits>
+#include <string>
+#include <utility>
+
+#include "db/db_impl/db_impl.h"
+#include "env/env_chroot.h"
+#include "file/filename.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/rate_limiter.h"
+#include "rocksdb/transaction_log.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/backupable_db.h"
+#include "rocksdb/utilities/options_util.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/stderr_logger.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+class DummyDB : public StackableDB {
+ public:
+ /* implicit */
+ DummyDB(const Options& options, const std::string& dbname)
+ : StackableDB(nullptr), options_(options), dbname_(dbname),
+ deletions_enabled_(true), sequence_number_(0) {}
+
+ SequenceNumber GetLatestSequenceNumber() const override {
+ return ++sequence_number_;
+ }
+
+ const std::string& GetName() const override { return dbname_; }
+
+ Env* GetEnv() const override { return options_.env; }
+
+ using DB::GetOptions;
+ Options GetOptions(ColumnFamilyHandle* /*column_family*/) const override {
+ return options_;
+ }
+
+ DBOptions GetDBOptions() const override { return DBOptions(options_); }
+
+ Status EnableFileDeletions(bool /*force*/) override {
+ EXPECT_TRUE(!deletions_enabled_);
+ deletions_enabled_ = true;
+ return Status::OK();
+ }
+
+ Status DisableFileDeletions() override {
+ EXPECT_TRUE(deletions_enabled_);
+ deletions_enabled_ = false;
+ return Status::OK();
+ }
+
+ Status GetLiveFiles(std::vector<std::string>& vec, uint64_t* mfs,
+ bool /*flush_memtable*/ = true) override {
+ EXPECT_TRUE(!deletions_enabled_);
+ vec = live_files_;
+ *mfs = 100;
+ return Status::OK();
+ }
+
+ ColumnFamilyHandle* DefaultColumnFamily() const override { return nullptr; }
+
+ class DummyLogFile : public LogFile {
+ public:
+ /* implicit */
+ DummyLogFile(const std::string& path, bool alive = true)
+ : path_(path), alive_(alive) {}
+
+ std::string PathName() const override { return path_; }
+
+ uint64_t LogNumber() const override {
+ // what business do you have calling this method?
+ ADD_FAILURE();
+ return 0;
+ }
+
+ WalFileType Type() const override {
+ return alive_ ? kAliveLogFile : kArchivedLogFile;
+ }
+
+ SequenceNumber StartSequence() const override {
+ // this seqnum guarantees the dummy file will be included in the backup
+ // as long as it is alive.
+ return kMaxSequenceNumber;
+ }
+
+ uint64_t SizeFileBytes() const override { return 0; }
+
+ private:
+ std::string path_;
+ bool alive_;
+ }; // DummyLogFile
+
+ Status GetSortedWalFiles(VectorLogPtr& files) override {
+ EXPECT_TRUE(!deletions_enabled_);
+ files.resize(wal_files_.size());
+ for (size_t i = 0; i < files.size(); ++i) {
+ files[i].reset(
+ new DummyLogFile(wal_files_[i].first, wal_files_[i].second));
+ }
+ return Status::OK();
+ }
+
+  // To avoid FlushWAL being called on the stacked db, which is nullptr
+ Status FlushWAL(bool /*sync*/) override { return Status::OK(); }
+
+ std::vector<std::string> live_files_;
+ // pair<filename, alive?>
+ std::vector<std::pair<std::string, bool>> wal_files_;
+ private:
+ Options options_;
+ std::string dbname_;
+ bool deletions_enabled_;
+ mutable SequenceNumber sequence_number_;
+}; // DummyDB
+
+class TestEnv : public EnvWrapper {
+ public:
+ explicit TestEnv(Env* t) : EnvWrapper(t) {}
+
+ class DummySequentialFile : public SequentialFile {
+ public:
+ explicit DummySequentialFile(bool fail_reads)
+ : SequentialFile(), rnd_(5), fail_reads_(fail_reads) {}
+ Status Read(size_t n, Slice* result, char* scratch) override {
+ if (fail_reads_) {
+ return Status::IOError();
+ }
+ size_t read_size = (n > size_left) ? size_left : n;
+ for (size_t i = 0; i < read_size; ++i) {
+ scratch[i] = rnd_.Next() & 255;
+ }
+ *result = Slice(scratch, read_size);
+ size_left -= read_size;
+ return Status::OK();
+ }
+
+ Status Skip(uint64_t n) override {
+      size_left = (n > size_left) ? 0 : size_left - n;
+ return Status::OK();
+ }
+
+ private:
+ size_t size_left = 200;
+ Random rnd_;
+ bool fail_reads_;
+ };
+
+ Status NewSequentialFile(const std::string& f,
+ std::unique_ptr<SequentialFile>* r,
+ const EnvOptions& options) override {
+ MutexLock l(&mutex_);
+ if (dummy_sequential_file_) {
+ r->reset(
+ new TestEnv::DummySequentialFile(dummy_sequential_file_fail_reads_));
+ return Status::OK();
+ } else {
+ Status s = EnvWrapper::NewSequentialFile(f, r, options);
+ if (s.ok()) {
+ if ((*r)->use_direct_io()) {
+ ++num_direct_seq_readers_;
+ }
+ ++num_seq_readers_;
+ }
+ return s;
+ }
+ }
+
+ Status NewWritableFile(const std::string& f, std::unique_ptr<WritableFile>* r,
+ const EnvOptions& options) override {
+ MutexLock l(&mutex_);
+ written_files_.push_back(f);
+ if (limit_written_files_ <= 0) {
+ return Status::NotSupported("Sorry, can't do this");
+ }
+ limit_written_files_--;
+ Status s = EnvWrapper::NewWritableFile(f, r, options);
+ if (s.ok()) {
+ if ((*r)->use_direct_io()) {
+ ++num_direct_writers_;
+ }
+ ++num_writers_;
+ }
+ return s;
+ }
+
+ Status NewRandomAccessFile(const std::string& fname,
+ std::unique_ptr<RandomAccessFile>* result,
+ const EnvOptions& options) override {
+ MutexLock l(&mutex_);
+ Status s = EnvWrapper::NewRandomAccessFile(fname, result, options);
+ if (s.ok()) {
+ if ((*result)->use_direct_io()) {
+ ++num_direct_rand_readers_;
+ }
+ ++num_rand_readers_;
+ }
+ return s;
+ }
+
+ Status DeleteFile(const std::string& fname) override {
+ MutexLock l(&mutex_);
+ if (fail_delete_files_) {
+ return Status::IOError();
+ }
+ EXPECT_GT(limit_delete_files_, 0U);
+ limit_delete_files_--;
+ return EnvWrapper::DeleteFile(fname);
+ }
+
+ Status DeleteDir(const std::string& dirname) override {
+ MutexLock l(&mutex_);
+ if (fail_delete_files_) {
+ return Status::IOError();
+ }
+ return EnvWrapper::DeleteDir(dirname);
+ }
+
+ void AssertWrittenFiles(std::vector<std::string>& should_have_written) {
+ MutexLock l(&mutex_);
+ std::sort(should_have_written.begin(), should_have_written.end());
+ std::sort(written_files_.begin(), written_files_.end());
+
+ ASSERT_EQ(should_have_written, written_files_);
+ }
+
+ void ClearWrittenFiles() {
+ MutexLock l(&mutex_);
+ written_files_.clear();
+ }
+
+ void SetLimitWrittenFiles(uint64_t limit) {
+ MutexLock l(&mutex_);
+ limit_written_files_ = limit;
+ }
+
+ void SetLimitDeleteFiles(uint64_t limit) {
+ MutexLock l(&mutex_);
+ limit_delete_files_ = limit;
+ }
+
+ void SetDeleteFileFailure(bool fail) {
+ MutexLock l(&mutex_);
+ fail_delete_files_ = fail;
+ }
+
+ void SetDummySequentialFile(bool dummy_sequential_file) {
+ MutexLock l(&mutex_);
+ dummy_sequential_file_ = dummy_sequential_file;
+ }
+ void SetDummySequentialFileFailReads(bool dummy_sequential_file_fail_reads) {
+ MutexLock l(&mutex_);
+ dummy_sequential_file_fail_reads_ = dummy_sequential_file_fail_reads;
+ }
+
+ void SetGetChildrenFailure(bool fail) { get_children_failure_ = fail; }
+ Status GetChildren(const std::string& dir,
+ std::vector<std::string>* r) override {
+ if (get_children_failure_) {
+ return Status::IOError("SimulatedFailure");
+ }
+ return EnvWrapper::GetChildren(dir, r);
+ }
+
+ // Some test cases do not actually create the test files (e.g., see
+ // DummyDB::live_files_) - for those cases, we mock those files' attributes
+ // so CreateNewBackup() can get their attributes.
+ void SetFilenamesForMockedAttrs(const std::vector<std::string>& filenames) {
+ filenames_for_mocked_attrs_ = filenames;
+ }
+ Status GetChildrenFileAttributes(
+ const std::string& dir, std::vector<Env::FileAttributes>* r) override {
+ if (filenames_for_mocked_attrs_.size() > 0) {
+ for (const auto& filename : filenames_for_mocked_attrs_) {
+ r->push_back({dir + filename, 10 /* size_bytes */});
+ }
+ return Status::OK();
+ }
+ return EnvWrapper::GetChildrenFileAttributes(dir, r);
+ }
+ Status GetFileSize(const std::string& path, uint64_t* size_bytes) override {
+ if (filenames_for_mocked_attrs_.size() > 0) {
+ auto fname = path.substr(path.find_last_of('/'));
+ auto filename_iter = std::find(filenames_for_mocked_attrs_.begin(),
+ filenames_for_mocked_attrs_.end(), fname);
+ if (filename_iter != filenames_for_mocked_attrs_.end()) {
+ *size_bytes = 10;
+ return Status::OK();
+ }
+ return Status::NotFound(fname);
+ }
+ return EnvWrapper::GetFileSize(path, size_bytes);
+ }
+
+ void SetCreateDirIfMissingFailure(bool fail) {
+ create_dir_if_missing_failure_ = fail;
+ }
+ Status CreateDirIfMissing(const std::string& d) override {
+ if (create_dir_if_missing_failure_) {
+ return Status::IOError("SimulatedFailure");
+ }
+ return EnvWrapper::CreateDirIfMissing(d);
+ }
+
+ void SetNewDirectoryFailure(bool fail) { new_directory_failure_ = fail; }
+ Status NewDirectory(const std::string& name,
+ std::unique_ptr<Directory>* result) override {
+ if (new_directory_failure_) {
+ return Status::IOError("SimulatedFailure");
+ }
+ return EnvWrapper::NewDirectory(name, result);
+ }
+
+ void ClearFileOpenCounters() {
+ MutexLock l(&mutex_);
+ num_rand_readers_ = 0;
+ num_direct_rand_readers_ = 0;
+ num_seq_readers_ = 0;
+ num_direct_seq_readers_ = 0;
+ num_writers_ = 0;
+ num_direct_writers_ = 0;
+ }
+
+ int num_rand_readers() { return num_rand_readers_; }
+ int num_direct_rand_readers() { return num_direct_rand_readers_; }
+ int num_seq_readers() { return num_seq_readers_; }
+ int num_direct_seq_readers() { return num_direct_seq_readers_; }
+ int num_writers() { return num_writers_; }
+ int num_direct_writers() { return num_direct_writers_; }
+
+ private:
+ port::Mutex mutex_;
+ bool dummy_sequential_file_ = false;
+ bool dummy_sequential_file_fail_reads_ = false;
+ std::vector<std::string> written_files_;
+ std::vector<std::string> filenames_for_mocked_attrs_;
+ uint64_t limit_written_files_ = 1000000;
+ uint64_t limit_delete_files_ = 1000000;
+ bool fail_delete_files_ = false;
+
+ bool get_children_failure_ = false;
+ bool create_dir_if_missing_failure_ = false;
+ bool new_directory_failure_ = false;
+
+ // Keeps track of how many files of each type were successfully opened, and
+ // out of those, how many were opened with direct I/O.
+ std::atomic<int> num_rand_readers_;
+ std::atomic<int> num_direct_rand_readers_;
+ std::atomic<int> num_seq_readers_;
+ std::atomic<int> num_direct_seq_readers_;
+ std::atomic<int> num_writers_;
+ std::atomic<int> num_direct_writers_;
+}; // TestEnv
+
+class FileManager : public EnvWrapper {
+ public:
+ explicit FileManager(Env* t) : EnvWrapper(t), rnd_(5) {}
+
+ Status DeleteRandomFileInDir(const std::string& dir) {
+ std::vector<std::string> children;
+ GetChildren(dir, &children);
+ if (children.size() <= 2) { // . and ..
+ return Status::NotFound("");
+ }
+ while (true) {
+ int i = rnd_.Next() % children.size();
+ if (children[i] != "." && children[i] != "..") {
+ return DeleteFile(dir + "/" + children[i]);
+ }
+ }
+ // should never get here
+ assert(false);
+ return Status::NotFound("");
+ }
+
+ Status AppendToRandomFileInDir(const std::string& dir,
+ const std::string& data) {
+ std::vector<std::string> children;
+ GetChildren(dir, &children);
+ if (children.size() <= 2) {
+ return Status::NotFound("");
+ }
+ while (true) {
+ int i = rnd_.Next() % children.size();
+ if (children[i] != "." && children[i] != "..") {
+ return WriteToFile(dir + "/" + children[i], data);
+ }
+ }
+ // should never get here
+ assert(false);
+ return Status::NotFound("");
+ }
+
+ Status CorruptFile(const std::string& fname, uint64_t bytes_to_corrupt) {
+ std::string file_contents;
+ Status s = ReadFileToString(this, fname, &file_contents);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+
+ for (uint64_t i = 0; i < bytes_to_corrupt; ++i) {
+ std::string tmp;
+ test::RandomString(&rnd_, 1, &tmp);
+ file_contents[rnd_.Next() % file_contents.size()] = tmp[0];
+ }
+ return WriteToFile(fname, file_contents);
+ }
+
+ Status CorruptChecksum(const std::string& fname, bool appear_valid) {
+ std::string metadata;
+ Status s = ReadFileToString(this, fname, &metadata);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DeleteFile(fname);
+ if (!s.ok()) {
+ return s;
+ }
+
+ auto pos = metadata.find("private");
+ if (pos == std::string::npos) {
+ return Status::Corruption("private file is expected");
+ }
+ pos = metadata.find(" crc32 ", pos + 6);
+ if (pos == std::string::npos) {
+ return Status::Corruption("checksum not found");
+ }
+
+ if (metadata.size() < pos + 7) {
+ return Status::Corruption("bad CRC32 checksum value");
+ }
+
+ if (appear_valid) {
+ if (metadata[pos + 8] == '\n') {
+ // single digit value, safe to insert one more digit
+ metadata.insert(pos + 8, 1, '0');
+ } else {
+ metadata.erase(pos + 8, 1);
+ }
+ } else {
+ metadata[pos + 7] = 'a';
+ }
+
+ return WriteToFile(fname, metadata);
+ }
+
+ Status WriteToFile(const std::string& fname, const std::string& data) {
+ std::unique_ptr<WritableFile> file;
+ EnvOptions env_options;
+ env_options.use_mmap_writes = false;
+ Status s = EnvWrapper::NewWritableFile(fname, &file, env_options);
+ if (!s.ok()) {
+ return s;
+ }
+ return file->Append(Slice(data));
+ }
+
+ private:
+ Random rnd_;
+}; // FileManager
+
+// utility functions
+static size_t FillDB(DB* db, int from, int to) {
+ size_t bytes_written = 0;
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + ToString(i);
+ std::string value = "testvalue" + ToString(i);
+ bytes_written += key.size() + value.size();
+
+ EXPECT_OK(db->Put(WriteOptions(), Slice(key), Slice(value)));
+ }
+ return bytes_written;
+}
+
+static void AssertExists(DB* db, int from, int to) {
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + ToString(i);
+ std::string value;
+ Status s = db->Get(ReadOptions(), Slice(key), &value);
+ ASSERT_EQ(value, "testvalue" + ToString(i));
+ }
+}
+
+static void AssertEmpty(DB* db, int from, int to) {
+ for (int i = from; i < to; ++i) {
+ std::string key = "testkey" + ToString(i);
+ std::string value = "testvalue" + ToString(i);
+
+ Status s = db->Get(ReadOptions(), Slice(key), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+}
+
+class BackupableDBTest : public testing::Test {
+ public:
+ enum ShareOption {
+ kNoShare,
+ kShareNoChecksum,
+ kShareWithChecksum,
+ };
+
+ const std::vector<ShareOption> kAllShareOptions = {
+ kNoShare, kShareNoChecksum, kShareWithChecksum};
+
+ BackupableDBTest() {
+ // set up files
+ std::string db_chroot = test::PerThreadDBPath("backupable_db");
+ std::string backup_chroot = test::PerThreadDBPath("backupable_db_backup");
+ Env::Default()->CreateDir(db_chroot);
+ Env::Default()->CreateDir(backup_chroot);
+ dbname_ = "/tempdb";
+ backupdir_ = "/tempbk";
+
+ // set up envs
+ db_chroot_env_.reset(NewChrootEnv(Env::Default(), db_chroot));
+ backup_chroot_env_.reset(NewChrootEnv(Env::Default(), backup_chroot));
+ test_db_env_.reset(new TestEnv(db_chroot_env_.get()));
+ test_backup_env_.reset(new TestEnv(backup_chroot_env_.get()));
+ file_manager_.reset(new FileManager(backup_chroot_env_.get()));
+
+ // set up db options
+ options_.create_if_missing = true;
+ options_.paranoid_checks = true;
+ options_.write_buffer_size = 1 << 17; // 128KB
+ options_.env = test_db_env_.get();
+ options_.wal_dir = dbname_;
+
+ // Create logger
+ DBOptions logger_options;
+ logger_options.env = db_chroot_env_.get();
+ CreateLoggerFromOptions(dbname_, logger_options, &logger_);
+
+ // set up backup db options
+ backupable_options_.reset(new BackupableDBOptions(
+ backupdir_, test_backup_env_.get(), true, logger_.get(), true));
+
+ // most tests will use multi-threaded backups
+ backupable_options_->max_background_operations = 7;
+
+ // delete old files in db
+ DestroyDB(dbname_, options_);
+ }
+
+ DB* OpenDB() {
+ DB* db;
+ EXPECT_OK(DB::Open(options_, dbname_, &db));
+ return db;
+ }
+
+ void OpenDBAndBackupEngine(bool destroy_old_data = false, bool dummy = false,
+ ShareOption shared_option = kShareNoChecksum) {
+ // reset all the defaults
+ test_backup_env_->SetLimitWrittenFiles(1000000);
+ test_db_env_->SetLimitWrittenFiles(1000000);
+ test_db_env_->SetDummySequentialFile(dummy);
+
+ DB* db;
+ if (dummy) {
+ dummy_db_ = new DummyDB(options_, dbname_);
+ db = dummy_db_;
+ } else {
+ ASSERT_OK(DB::Open(options_, dbname_, &db));
+ }
+ db_.reset(db);
+ backupable_options_->destroy_old_data = destroy_old_data;
+ backupable_options_->share_table_files = shared_option != kNoShare;
+ backupable_options_->share_files_with_checksum =
+ shared_option == kShareWithChecksum;
+ BackupEngine* backup_engine;
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ backup_engine_.reset(backup_engine);
+ }
+
+ void CloseDBAndBackupEngine() {
+ db_.reset();
+ backup_engine_.reset();
+ }
+
+ void OpenBackupEngine() {
+ backupable_options_->destroy_old_data = false;
+ BackupEngine* backup_engine;
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ backup_engine_.reset(backup_engine);
+ }
+
+ void CloseBackupEngine() { backup_engine_.reset(nullptr); }
+
+  // restores backup backup_id and asserts the existence of keys in
+  // [start_exist, end_exist> and the non-existence of keys in
+  // [end_exist, end>
+ //
+ // if backup_id == 0, it means restore from latest
+ // if end == 0, don't check AssertEmpty
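+  //
+  // For example (values arbitrary), AssertBackupConsistency(0, 0, 100, 200)
+  // restores the latest backup, then asserts that keys [0, 100) exist and
+  // keys [100, 200) do not.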
+ void AssertBackupConsistency(BackupID backup_id, uint32_t start_exist,
+ uint32_t end_exist, uint32_t end = 0,
+ bool keep_log_files = false) {
+ RestoreOptions restore_options(keep_log_files);
+ bool opened_backup_engine = false;
+ if (backup_engine_.get() == nullptr) {
+ opened_backup_engine = true;
+ OpenBackupEngine();
+ }
+ if (backup_id > 0) {
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(backup_id, dbname_, dbname_,
+ restore_options));
+ } else {
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_,
+ restore_options));
+ }
+ DB* db = OpenDB();
+ AssertExists(db, start_exist, end_exist);
+ if (end != 0) {
+ AssertEmpty(db, end_exist, end);
+ }
+ delete db;
+ if (opened_backup_engine) {
+ CloseBackupEngine();
+ }
+ }
+
+ void DeleteLogFiles() {
+ std::vector<std::string> delete_logs;
+ db_chroot_env_->GetChildren(dbname_, &delete_logs);
+ for (auto f : delete_logs) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(f, &number, &type);
+ if (ok && type == kLogFile) {
+ db_chroot_env_->DeleteFile(dbname_ + "/" + f);
+ }
+ }
+ }
+
+ // files
+ std::string dbname_;
+ std::string backupdir_;
+
+  // logger_ must be declared above backup_engine_ so that the engine's
+  // destructor, which uses a raw pointer to the logger, runs before the
+  // logger is destroyed.
+ std::shared_ptr<Logger> logger_;
+
+ // envs
+ std::unique_ptr<Env> db_chroot_env_;
+ std::unique_ptr<Env> backup_chroot_env_;
+ std::unique_ptr<TestEnv> test_db_env_;
+ std::unique_ptr<TestEnv> test_backup_env_;
+ std::unique_ptr<FileManager> file_manager_;
+
+ // all the dbs!
+ DummyDB* dummy_db_; // BackupableDB owns dummy_db_
+ std::unique_ptr<DB> db_;
+ std::unique_ptr<BackupEngine> backup_engine_;
+
+ // options
+ Options options_;
+
+ protected:
+ std::unique_ptr<BackupableDBOptions> backupable_options_;
+}; // BackupableDBTest
+
+void AppendPath(const std::string& path, std::vector<std::string>& v) {
+ for (auto& f : v) {
+ f = path + f;
+ }
+}
+
+class BackupableDBTestWithParam : public BackupableDBTest,
+ public testing::WithParamInterface<bool> {
+ public:
+ BackupableDBTestWithParam() {
+ backupable_options_->share_files_with_checksum = GetParam();
+ }
+};
+
+// This test verifies that the VerifyBackup method correctly identifies
+// invalid backups.
+TEST_P(BackupableDBTestWithParam, VerifyBackup) {
+ const int keys_iteration = 5000;
+ Random rnd(6);
+ Status s;
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ // ---------- case 1. - valid backup -----------
+ ASSERT_TRUE(backup_engine_->VerifyBackup(1).ok());
+
+  // ---------- case 2. - delete a file -----------
+ file_manager_->DeleteRandomFileInDir(backupdir_ + "/private/1");
+ ASSERT_TRUE(backup_engine_->VerifyBackup(1).IsNotFound());
+
+ // ---------- case 3. - corrupt a file -----------
+ std::string append_data = "Corrupting a random file";
+ file_manager_->AppendToRandomFileInDir(backupdir_ + "/private/2",
+ append_data);
+ ASSERT_TRUE(backup_engine_->VerifyBackup(2).IsCorruption());
+
+ // ---------- case 4. - invalid backup -----------
+ ASSERT_TRUE(backup_engine_->VerifyBackup(6).IsNotFound());
+ CloseDBAndBackupEngine();
+}
+
+// open DB, write, close DB, backup, restore, repeat
+TEST_P(BackupableDBTestWithParam, OfflineIntegrationTest) {
+ // has to be a big number, so that it triggers the memtable flush
+ const int keys_iteration = 5000;
+ const int max_key = keys_iteration * 4 + 10;
+ // first iter -- flush before backup
+ // second iter -- don't flush before backup
+ for (int iter = 0; iter < 2; ++iter) {
+ // delete old data
+ DestroyDB(dbname_, options_);
+ bool destroy_data = true;
+
+ // every iteration --
+ // 1. insert new data in the DB
+ // 2. backup the DB
+ // 3. destroy the db
+ // 4. restore the db, check everything is still there
+ for (int i = 0; i < 5; ++i) {
+      // in the last iteration, put a smaller amount of data
+ int fill_up_to = std::min(keys_iteration * (i + 1), max_key);
+ // ---- insert new data and back up ----
+ OpenDBAndBackupEngine(destroy_data);
+ destroy_data = false;
+ FillDB(db_.get(), keys_iteration * i, fill_up_to);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), iter == 0));
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+
+ // ---- make sure it's empty ----
+ DB* db = OpenDB();
+ AssertEmpty(db, 0, fill_up_to);
+ delete db;
+
+ // ---- restore the DB ----
+ OpenBackupEngine();
+ if (i >= 3) { // test purge old backups
+ // when i == 4, purge to only 1 backup
+ // when i == 3, purge to 2 backups
+ ASSERT_OK(backup_engine_->PurgeOldBackups(5 - i));
+ }
+ // ---- make sure the data is there ---
+ AssertBackupConsistency(0, 0, fill_up_to, max_key);
+ CloseBackupEngine();
+ }
+ }
+}
+
+// open DB, write, backup, write, backup, close, restore
+TEST_P(BackupableDBTestWithParam, OnlineIntegrationTest) {
+ // has to be a big number, so that it triggers the memtable flush
+ const int keys_iteration = 5000;
+ const int max_key = keys_iteration * 4 + 10;
+ Random rnd(7);
+ // delete old data
+ DestroyDB(dbname_, options_);
+
+ OpenDBAndBackupEngine(true);
+ // write some data, backup, repeat
+ for (int i = 0; i < 5; ++i) {
+ if (i == 4) {
+ // delete backup number 2, online delete!
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ }
+ // in last iteration, put smaller amount of data,
+ // so that backups can share sst files
+ int fill_up_to = std::min(keys_iteration * (i + 1), max_key);
+ FillDB(db_.get(), keys_iteration * i, fill_up_to);
+ // we should get consistent results with flush_before_backup
+ // set to both true and false
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+ // close and destroy
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+
+ // ---- make sure it's empty ----
+ DB* db = OpenDB();
+ AssertEmpty(db, 0, max_key);
+ delete db;
+
+ // ---- restore every backup and verify all the data is there ----
+ OpenBackupEngine();
+ for (int i = 1; i <= 5; ++i) {
+ if (i == 2) {
+ // we deleted backup 2
+ Status s = backup_engine_->RestoreDBFromBackup(2, dbname_, dbname_);
+ ASSERT_TRUE(!s.ok());
+ } else {
+ int fill_up_to = std::min(keys_iteration * i, max_key);
+ AssertBackupConsistency(i, 0, fill_up_to, max_key);
+ }
+ }
+
+ // delete some backups -- this should leave only backups 3 and 5 alive
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->PurgeOldBackups(2));
+
+ std::vector<BackupInfo> backup_info;
+ backup_engine_->GetBackupInfo(&backup_info);
+ ASSERT_EQ(2UL, backup_info.size());
+
+ // check backup 3
+ AssertBackupConsistency(3, 0, 3 * keys_iteration, max_key);
+ // check backup 5
+ AssertBackupConsistency(5, 0, max_key);
+
+ CloseBackupEngine();
+}
+
+INSTANTIATE_TEST_CASE_P(BackupableDBTestWithParam, BackupableDBTestWithParam,
+ ::testing::Bool());
+
+// this will make sure that backup does not copy the same file twice
+TEST_F(BackupableDBTest, NoDoubleCopy_And_AutoGC) {
+ OpenDBAndBackupEngine(true, true);
+
+ // should write 5 DB files + one meta file
+ test_backup_env_->SetLimitWrittenFiles(7);
+ test_backup_env_->ClearWrittenFiles();
+ test_db_env_->SetLimitWrittenFiles(0);
+ dummy_db_->live_files_ = {"/00010.sst", "/00011.sst", "/CURRENT",
+ "/MANIFEST-01"};
+ dummy_db_->wal_files_ = {{"/00011.log", true}, {"/00012.log", false}};
+ test_db_env_->SetFilenamesForMockedAttrs(dummy_db_->live_files_);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ std::vector<std::string> should_have_written = {
+ "/shared/.00010.sst.tmp", "/shared/.00011.sst.tmp", "/private/1/CURRENT",
+ "/private/1/MANIFEST-01", "/private/1/00011.log", "/meta/.1.tmp"};
+ AppendPath(backupdir_, should_have_written);
+ test_backup_env_->AssertWrittenFiles(should_have_written);
+
+ char db_number = '1';
+
+ for (std::string other_sst : {"00015.sst", "00017.sst", "00019.sst"}) {
+ // should write 4 new DB files + one meta file
+ // should not write/copy 00010.sst, since it's already there!
+ test_backup_env_->SetLimitWrittenFiles(6);
+ test_backup_env_->ClearWrittenFiles();
+
+ dummy_db_->live_files_ = {"/00010.sst", "/" + other_sst, "/CURRENT",
+ "/MANIFEST-01"};
+ dummy_db_->wal_files_ = {{"/00011.log", true}, {"/00012.log", false}};
+ test_db_env_->SetFilenamesForMockedAttrs(dummy_db_->live_files_);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ // should not open 00010.sst - it's already there
+
+ ++db_number;
+ std::string private_dir = std::string("/private/") + db_number;
+ should_have_written = {
+ "/shared/." + other_sst + ".tmp", private_dir + "/CURRENT",
+ private_dir + "/MANIFEST-01", private_dir + "/00011.log",
+ std::string("/meta/.") + db_number + ".tmp"};
+ AppendPath(backupdir_, should_have_written);
+ test_backup_env_->AssertWrittenFiles(should_have_written);
+ }
+
+ ASSERT_OK(backup_engine_->DeleteBackup(1));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00010.sst"));
+
+ // 00011.sst was only in backup 1, should be deleted
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00011.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00015.sst"));
+
+ // MANIFEST file size should be only 100
+ uint64_t size = 0;
+ test_backup_env_->GetFileSize(backupdir_ + "/private/2/MANIFEST-01", &size);
+ ASSERT_EQ(100UL, size);
+ test_backup_env_->GetFileSize(backupdir_ + "/shared/00015.sst", &size);
+ ASSERT_EQ(200UL, size);
+
+ CloseBackupEngine();
+
+ //
+ // Now simulate incomplete delete by removing just meta
+ //
+ ASSERT_OK(test_backup_env_->DeleteFile(backupdir_ + "/meta/2"));
+
+ OpenBackupEngine();
+
+ // 1 appears to be removed, so
+ // 2 non-corrupt and 0 corrupt seen
+ std::vector<BackupInfo> backup_info;
+ std::vector<BackupID> corrupt_backup_ids;
+ backup_engine_->GetBackupInfo(&backup_info);
+ backup_engine_->GetCorruptedBackups(&corrupt_backup_ids);
+ ASSERT_EQ(2UL, backup_info.size());
+ ASSERT_EQ(0UL, corrupt_backup_ids.size());
+
+ // Keep the two we see, but this should suffice to purge unreferenced
+ // shared files from incomplete delete.
+ ASSERT_OK(backup_engine_->PurgeOldBackups(2));
+
+ // Make sure dangling sst file has been removed (somewhere along this
+ // process). GarbageCollect should not be needed.
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00015.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00017.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00019.sst"));
+
+ // Now actually purge a good one
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1));
+
+ ASSERT_EQ(Status::NotFound(),
+ test_backup_env_->FileExists(backupdir_ + "/shared/00017.sst"));
+ ASSERT_OK(test_backup_env_->FileExists(backupdir_ + "/shared/00019.sst"));
+
+ CloseDBAndBackupEngine();
+}
+
+// test various kind of corruptions that may happen:
+// 1. Not able to write a file for backup - that backup should fail,
+// everything else should work
+// 2. Corrupted backup meta file or missing backuped file - we should
+// not be able to open that backup, but all other backups should be
+// fine
+// 3. Corrupted checksum value - if the checksum is not a valid uint32_t,
+// opening the backup engine should fail; otherwise, the checksum mismatch
+// aborts the restore process.
+TEST_F(BackupableDBTest, CorruptionsTest) {
+ const int keys_iteration = 5000;
+ Random rnd(6);
+ Status s;
+
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+
+ // ---------- case 1. - fail a write -----------
+ // try creating backup 6, but fail a write
+ FillDB(db_.get(), keys_iteration * 5, keys_iteration * 6);
+ test_backup_env_->SetLimitWrittenFiles(2);
+ // should fail
+ s = backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2));
+ ASSERT_TRUE(!s.ok());
+ test_backup_env_->SetLimitWrittenFiles(1000000);
+ // latest backup should have all the keys
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration * 5, keys_iteration * 6);
+
+ // --------- case 2. corrupted backup meta or missing backuped file ----
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/5", 3));
+ // since 5 meta is now corrupted, latest backup should be 4
+ AssertBackupConsistency(0, 0, keys_iteration * 4, keys_iteration * 5);
+ OpenBackupEngine();
+ s = backup_engine_->RestoreDBFromBackup(5, dbname_, dbname_);
+ ASSERT_TRUE(!s.ok());
+ CloseBackupEngine();
+ ASSERT_OK(file_manager_->DeleteRandomFileInDir(backupdir_ + "/private/4"));
+ // 4 is corrupted, 3 is the latest backup now
+ AssertBackupConsistency(0, 0, keys_iteration * 3, keys_iteration * 5);
+ OpenBackupEngine();
+ s = backup_engine_->RestoreDBFromBackup(4, dbname_, dbname_);
+ CloseBackupEngine();
+ ASSERT_TRUE(!s.ok());
+
+ // --------- case 3. corrupted checksum value ----
+ ASSERT_OK(file_manager_->CorruptChecksum(backupdir_ + "/meta/3", false));
+  // the checksum of backup 3 is an invalid value; this is detected when the
+  // backup engine is opened, and we revert to the previous backup
+  // automatically
+ AssertBackupConsistency(0, 0, keys_iteration * 2, keys_iteration * 5);
+ // checksum of the backup 2 appears to be valid, this can cause checksum
+ // mismatch and abort restore process
+ ASSERT_OK(file_manager_->CorruptChecksum(backupdir_ + "/meta/2", true));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ OpenBackupEngine();
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ s = backup_engine_->RestoreDBFromBackup(2, dbname_, dbname_);
+ ASSERT_TRUE(!s.ok());
+
+ // make sure that no corrupt backups have actually been deleted!
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/1"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/2"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/3"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/4"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/1"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/2"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/3"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/4"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/5"));
+
+ // delete the corrupt backups and then make sure they're actually deleted
+ ASSERT_OK(backup_engine_->DeleteBackup(5));
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->DeleteBackup(3));
+ ASSERT_OK(backup_engine_->DeleteBackup(2));
+ // Should not be needed anymore with auto-GC on DeleteBackup
+ //(void)backup_engine_->GarbageCollect();
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/5"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/4"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/4"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/3"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/3"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/meta/2"));
+ ASSERT_EQ(Status::NotFound(),
+ file_manager_->FileExists(backupdir_ + "/private/2"));
+
+ CloseBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration * 1, keys_iteration * 5);
+
+ // new backup should be 2!
+ OpenDBAndBackupEngine();
+ FillDB(db_.get(), keys_iteration * 1, keys_iteration * 2);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(2, 0, keys_iteration * 2, keys_iteration * 5);
+}
+
+TEST_F(BackupableDBTest, InterruptCreationTest) {
+ // Interrupt backup creation by failing new writes and failing cleanup of the
+ // partial state. Then verify a subsequent backup can still succeed.
+ const int keys_iteration = 5000;
+ Random rnd(6);
+
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0, keys_iteration);
+ test_backup_env_->SetLimitWrittenFiles(2);
+ test_backup_env_->SetDeleteFileFailure(true);
+ // should fail creation
+ ASSERT_FALSE(
+ backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)).ok());
+ CloseDBAndBackupEngine();
+ // should also fail cleanup so the tmp directory stays behind
+ ASSERT_OK(backup_chroot_env_->FileExists(backupdir_ + "/private/1/"));
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */);
+ test_backup_env_->SetLimitWrittenFiles(1000000);
+ test_backup_env_->SetDeleteFileFailure(false);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ // latest backup should have all the keys
+ CloseDBAndBackupEngine();
+ AssertBackupConsistency(0, 0, keys_iteration);
+}
+
+inline std::string OptionsPath(std::string ret, int backupID) {
+ ret += "/private/";
+ ret += std::to_string(backupID);
+ ret += "/";
+ return ret;
+}
+
+// Backup the LATEST options file to
+// "<backup_dir>/private/<backup_id>/OPTIONS<number>"
+
+TEST_F(BackupableDBTest, BackupOptions) {
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i < 5; i++) {
+ std::string name;
+ std::vector<std::string> filenames;
+ // Must reset() before reset(OpenDB()) again.
+    // Calling OpenDB() while *db_ still exists would cause a LOCK issue
+ db_.reset();
+ db_.reset(OpenDB());
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ ROCKSDB_NAMESPACE::GetLatestOptionsFileName(db_->GetName(), options_.env,
+ &name);
+ ASSERT_OK(file_manager_->FileExists(OptionsPath(backupdir_, i) + name));
+ backup_chroot_env_->GetChildren(OptionsPath(backupdir_, i), &filenames);
+ for (auto fn : filenames) {
+ if (fn.compare(0, 7, "OPTIONS") == 0) {
+ ASSERT_EQ(name, fn);
+ }
+ }
+ }
+
+ CloseDBAndBackupEngine();
+}
+
+TEST_F(BackupableDBTest, SetOptionsBackupRaceCondition) {
+ OpenDBAndBackupEngine(true);
+ SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "BackupableDBTest::SetOptionsBackupRaceCondition:BeforeSetOptions"},
+ {"BackupableDBTest::SetOptionsBackupRaceCondition:AfterSetOptions",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread setoptions_thread{[this]() {
+ TEST_SYNC_POINT(
+ "BackupableDBTest::SetOptionsBackupRaceCondition:BeforeSetOptions");
+ DBImpl* dbi = static_cast<DBImpl*>(db_.get());
+ // Change arbitrary option to trigger OPTIONS file deletion
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "false"}}));
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "true"}}));
+ ASSERT_OK(dbi->SetOptions(dbi->DefaultColumnFamily(),
+ {{"paranoid_file_checks", "false"}}));
+ TEST_SYNC_POINT(
+ "BackupableDBTest::SetOptionsBackupRaceCondition:AfterSetOptions");
+ }};
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ setoptions_thread.join();
+ CloseDBAndBackupEngine();
+}
+
+// This test verifies we don't delete the latest backup when read-only option is
+// set
+TEST_F(BackupableDBTest, NoDeleteWithReadOnly) {
+ const int keys_iteration = 5000;
+ Random rnd(6);
+ Status s;
+
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(rnd.Next() % 2)));
+ }
+ CloseDBAndBackupEngine();
+ ASSERT_OK(file_manager_->WriteToFile(backupdir_ + "/LATEST_BACKUP", "4"));
+
+ backupable_options_->destroy_old_data = false;
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(backup_chroot_env_.get(),
+ *backupable_options_,
+ &read_only_backup_engine));
+
+ // assert that data from backup 5 is still here (even though LATEST_BACKUP
+ // says 4 is latest)
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/meta/5"));
+ ASSERT_OK(file_manager_->FileExists(backupdir_ + "/private/5"));
+
+ // Behavior change: We now ignore LATEST_BACKUP contents. This means that
+ // we should have 5 backups, even if LATEST_BACKUP says 4.
+ std::vector<BackupInfo> backup_info;
+ read_only_backup_engine->GetBackupInfo(&backup_info);
+ ASSERT_EQ(5UL, backup_info.size());
+ delete read_only_backup_engine;
+}
+
+TEST_F(BackupableDBTest, FailOverwritingBackups) {
+ options_.write_buffer_size = 1024 * 1024 * 1024; // 1GB
+ options_.disable_auto_compactions = true;
+
+ // create backups 1, 2, 3, 4, 5
+ OpenDBAndBackupEngine(true);
+ for (int i = 0; i < 5; ++i) {
+ CloseDBAndBackupEngine();
+ DeleteLogFiles();
+ OpenDBAndBackupEngine(false);
+ FillDB(db_.get(), 100 * i, 100 * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ // restore 3
+ OpenBackupEngine();
+ ASSERT_OK(backup_engine_->RestoreDBFromBackup(3, dbname_, dbname_));
+ CloseBackupEngine();
+
+ OpenDBAndBackupEngine(false);
+ FillDB(db_.get(), 0, 300);
+ Status s = backup_engine_->CreateNewBackup(db_.get(), true);
+ // The new backup fails because its new table files
+ // clash with the old table files from backups 4 and 5.
+ // (Since write_buffer_size is huge, we can be sure that
+ // each backup generates only one SST file, and that the
+ // file generated by the new backup gets the same name as
+ // the SST file generated by backup 4.)
+ ASSERT_TRUE(s.IsCorruption());
+ ASSERT_OK(backup_engine_->DeleteBackup(4));
+ ASSERT_OK(backup_engine_->DeleteBackup(5));
+ // now, the backup can succeed
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+}
+
+TEST_F(BackupableDBTest, NoShareTableFiles) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true, false, kNoShare);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(i % 2)));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+}
+
+// Verify that you can backup and restore with share_files_with_checksum on
+TEST_F(BackupableDBTest, ShareTableFilesWithChecksums) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true, false, kShareWithChecksum);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), !!(i % 2)));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+}
+
+// Verify that you can backup and restore using share_files_with_checksum set to
+// false and then transition this option to true
+TEST_F(BackupableDBTest, ShareTableFilesWithChecksumsTransition) {
+ const int keys_iteration = 5000;
+ // set share_files_with_checksum to false
+ OpenDBAndBackupEngine(true, false, kShareNoChecksum);
+ for (int i = 0; i < 5; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < 5; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 6);
+ }
+
+ // set share_files_with_checksum to true and do some more backups
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false,
+ kShareWithChecksum);
+ for (int i = 5; i < 10; ++i) {
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ }
+ CloseDBAndBackupEngine();
+
+ // Verify first (about to delete)
+ AssertBackupConsistency(1, 0, keys_iteration, keys_iteration * 11);
+
+ // For an extra challenge, make sure that GarbageCollect / DeleteBackup
+ // is OK even if we open without share_table_files
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false, kNoShare);
+ backup_engine_->DeleteBackup(1);
+ backup_engine_->GarbageCollect();
+ CloseDBAndBackupEngine();
+
+ // Verify rest (not deleted)
+ for (int i = 1; i < 10; ++i) {
+ AssertBackupConsistency(i + 1, 0, keys_iteration * (i + 1),
+ keys_iteration * 11);
+ }
+}
+
+// This test simulates cleaning up after aborted or incomplete creation
+// of a new backup.
+TEST_F(BackupableDBTest, DeleteTmpFiles) {
+ for (int cleanup_fn : {1, 2, 3, 4}) {
+ for (ShareOption shared_option : kAllShareOptions) {
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false /* dummy */,
+ shared_option);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ BackupID next_id = 1;
+ BackupID oldest_id = std::numeric_limits<BackupID>::max();
+ {
+ std::vector<BackupInfo> backup_info;
+ backup_engine_->GetBackupInfo(&backup_info);
+ for (const auto& bi : backup_info) {
+ next_id = std::max(next_id, bi.backup_id + 1);
+ oldest_id = std::min(oldest_id, bi.backup_id);
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ // An aborted or incomplete new backup always lands in the next
+ // backup id (and possibly later ids as well).
+ std::string next_private = "private/" + std::to_string(next_id);
+
+ // NOTE: both shared and shared_checksum should be cleaned up
+ // regardless of how the backup engine is opened.
+ std::vector<std::string> tmp_files_and_dirs;
+ for (const auto& dir_and_file : {
+ std::make_pair(std::string("shared"),
+ std::string(".00006.sst.tmp")),
+ std::make_pair(std::string("shared_checksum"),
+ std::string(".00007.sst.tmp")),
+ std::make_pair(next_private, std::string("00003.sst")),
+ }) {
+ std::string dir = backupdir_ + "/" + dir_and_file.first;
+ file_manager_->CreateDir(dir);
+ ASSERT_OK(file_manager_->FileExists(dir));
+
+ std::string file = dir + "/" + dir_and_file.second;
+ file_manager_->WriteToFile(file, "tmp");
+ ASSERT_OK(file_manager_->FileExists(file));
+
+ tmp_files_and_dirs.push_back(file);
+ }
+ if (cleanup_fn != /*CreateNewBackup*/ 4) {
+ // This exists after CreateNewBackup because it's deleted then
+ // re-created.
+ tmp_files_and_dirs.push_back(backupdir_ + "/" + next_private);
+ }
+
+ OpenDBAndBackupEngine(false /* destroy_old_data */, false /* dummy */,
+ shared_option);
+ // Need to call one of these explicitly to delete tmp files
+ switch (cleanup_fn) {
+ case 1:
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ break;
+ case 2:
+ ASSERT_OK(backup_engine_->DeleteBackup(oldest_id));
+ break;
+ case 3:
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1));
+ break;
+ case 4:
+ // Does a garbage collect if it sees that next private dir exists
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get()));
+ break;
+ default:
+ assert(false);
+ }
+ CloseDBAndBackupEngine();
+ for (std::string file_or_dir : tmp_files_and_dirs) {
+ if (file_manager_->FileExists(file_or_dir) != Status::NotFound()) {
+ FAIL() << file_or_dir << " was expected to be deleted." << cleanup_fn;
+ }
+ }
+ }
+ }
+}
+
+TEST_F(BackupableDBTest, KeepLogFiles) {
+ backupable_options_->backup_log_files = false;
+ // basically infinite
+ options_.WAL_ttl_seconds = 24 * 60 * 60;
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ FillDB(db_.get(), 100, 200);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ FillDB(db_.get(), 200, 300);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ FillDB(db_.get(), 300, 400);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ FillDB(db_.get(), 400, 500);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ CloseDBAndBackupEngine();
+
+ // all data should be there if we call with keep_log_files = true
+ AssertBackupConsistency(0, 0, 500, 600, true);
+}
+
+TEST_F(BackupableDBTest, RateLimiting) {
+ size_t const kMicrosPerSec = 1000 * 1000LL;
+ uint64_t const MB = 1024 * 1024;
+
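+ // Each pair is (backup rate limit, restore rate limit), in bytes per second.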
+ const std::vector<std::pair<uint64_t, uint64_t>> limits(
+ {{1 * MB, 5 * MB}, {2 * MB, 3 * MB}});
+
+ std::shared_ptr<RateLimiter> backupThrottler(NewGenericRateLimiter(1));
+ std::shared_ptr<RateLimiter> restoreThrottler(NewGenericRateLimiter(1));
+
+ for (bool makeThrottler : {false, true}) {
+ if (makeThrottler) {
+ backupable_options_->backup_rate_limiter = backupThrottler;
+ backupable_options_->restore_rate_limiter = restoreThrottler;
+ }
+ // iter 0 -- single threaded
+ // iter 1 -- multi threaded
+ for (int iter = 0; iter < 2; ++iter) {
+ for (const auto& limit : limits) {
+ // destroy old data
+ DestroyDB(dbname_, Options());
+ if (makeThrottler) {
+ backupThrottler->SetBytesPerSecond(limit.first);
+ restoreThrottler->SetBytesPerSecond(limit.second);
+ } else {
+ backupable_options_->backup_rate_limit = limit.first;
+ backupable_options_->restore_rate_limit = limit.second;
+ }
+ backupable_options_->max_background_operations = (iter == 0) ? 1 : 10;
+ options_.compression = kNoCompression;
+ OpenDBAndBackupEngine(true);
+ size_t bytes_written = FillDB(db_.get(), 0, 100000);
+
+ auto start_backup = db_chroot_env_->NowMicros();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+ auto backup_time = db_chroot_env_->NowMicros() - start_backup;
+ auto rate_limited_backup_time =
+ (bytes_written * kMicrosPerSec) / limit.first;
+ ASSERT_GT(backup_time, 0.8 * rate_limited_backup_time);
+
+ CloseDBAndBackupEngine();
+
+ OpenBackupEngine();
+ auto start_restore = db_chroot_env_->NowMicros();
+ ASSERT_OK(backup_engine_->RestoreDBFromLatestBackup(dbname_, dbname_));
+ auto restore_time = db_chroot_env_->NowMicros() - start_restore;
+ CloseBackupEngine();
+ auto rate_limited_restore_time =
+ (bytes_written * kMicrosPerSec) / limit.second;
+ ASSERT_GT(restore_time, 0.8 * rate_limited_restore_time);
+
+ AssertBackupConsistency(0, 0, 100000, 100010);
+ }
+ }
+ }
+}
+
+TEST_F(BackupableDBTest, ReadOnlyBackupEngine) {
+ DestroyDB(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ FillDB(db_.get(), 100, 200);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+
+ backupable_options_->destroy_old_data = false;
+ test_backup_env_->ClearWrittenFiles();
+ test_backup_env_->SetLimitDeleteFiles(0);
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(
+ db_chroot_env_.get(), *backupable_options_, &read_only_backup_engine));
+ std::vector<BackupInfo> backup_info;
+ read_only_backup_engine->GetBackupInfo(&backup_info);
+ ASSERT_EQ(backup_info.size(), 2U);
+
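+ // Pass keep_log_files == false so existing WAL files are not preserved.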
+ RestoreOptions restore_options(false);
+ ASSERT_OK(read_only_backup_engine->RestoreDBFromLatestBackup(
+ dbname_, dbname_, restore_options));
+ delete read_only_backup_engine;
+ std::vector<std::string> should_have_written;
+ test_backup_env_->AssertWrittenFiles(should_have_written);
+
+ DB* db = OpenDB();
+ AssertExists(db, 0, 200);
+ delete db;
+}
+
+TEST_F(BackupableDBTest, ProgressCallbackDuringBackup) {
+ DestroyDB(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ bool is_callback_invoked = false;
+ ASSERT_OK(backup_engine_->CreateNewBackup(
+ db_.get(), true,
+ [&is_callback_invoked]() { is_callback_invoked = true; }));
+
+ ASSERT_TRUE(is_callback_invoked);
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+}
+
+TEST_F(BackupableDBTest, GarbageCollectionBeforeBackup) {
+ DestroyDB(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+
+ backup_chroot_env_->CreateDirIfMissing(backupdir_ + "/shared");
+ std::string file_five = backupdir_ + "/shared/000007.sst";
+ std::string file_five_contents = "I'm not really a sst file";
+ // This depends on the fact that 000007.sst is the first SST file created by the DB.
+ ASSERT_OK(file_manager_->WriteToFile(file_five, file_five_contents));
+
+ FillDB(db_.get(), 0, 100);
+ // backup overwrites file 000007.sst
+ ASSERT_TRUE(backup_engine_->CreateNewBackup(db_.get(), true).ok());
+
+ std::string new_file_five_contents;
+ ASSERT_OK(ReadFileToString(backup_chroot_env_.get(), file_five,
+ &new_file_five_contents));
+ // file 000007.sst was overwritten
+ ASSERT_TRUE(new_file_five_contents != file_five_contents);
+
+ CloseDBAndBackupEngine();
+
+ AssertBackupConsistency(0, 0, 100);
+}
+
+// Test that we properly propagate Env failures
+TEST_F(BackupableDBTest, EnvFailures) {
+ BackupEngine* backup_engine;
+
+ // get children failure
+ {
+ test_backup_env_->SetGetChildrenFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ test_backup_env_->SetGetChildrenFailure(false);
+ }
+
+ // created dir failure
+ {
+ test_backup_env_->SetCreateDirIfMissingFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ test_backup_env_->SetCreateDirIfMissingFailure(false);
+ }
+
+ // new directory failure
+ {
+ test_backup_env_->SetNewDirectoryFailure(true);
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ test_backup_env_->SetNewDirectoryFailure(false);
+ }
+
+ // Read from meta-file failure
+ {
+ DestroyDB(dbname_, options_);
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+ ASSERT_TRUE(backup_engine_->CreateNewBackup(db_.get(), true).ok());
+ CloseDBAndBackupEngine();
+ test_backup_env_->SetDummySequentialFile(true);
+ test_backup_env_->SetDummySequentialFileFailReads(true);
+ backupable_options_->destroy_old_data = false;
+ ASSERT_NOK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ test_backup_env_->SetDummySequentialFile(false);
+ test_backup_env_->SetDummySequentialFileFailReads(false);
+ }
+
+ // no failure
+ {
+ ASSERT_OK(BackupEngine::Open(test_db_env_.get(), *backupable_options_,
+ &backup_engine));
+ delete backup_engine;
+ }
+}
+
+// Verify manifest can roll while a backup is being created with the old
+// manifest.
+TEST_F(BackupableDBTest, ChangeManifestDuringBackupCreation) {
+ DestroyDB(dbname_, options_);
+ options_.max_manifest_file_size = 0; // always rollover manifest for file add
+ OpenDBAndBackupEngine(true);
+ FillDB(db_.get(), 0, 100);
+
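+ // Force the manifest roll triggered by the concurrent flush to happen
+ // between the checkpoint's SavedLiveFiles1 and SavedLiveFiles2 sync points.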
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "VersionSet::LogAndApply:WriteManifest"},
+ {"VersionSet::LogAndApply:WriteManifestDone",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread flush_thread{
+ [this]() { ASSERT_OK(db_->Flush(FlushOptions())); }};
+
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), false));
+
+ flush_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ // The last manifest roll would've already been cleaned up by the full scan
+ // that happens when CreateNewBackup invokes EnableFileDeletions. We need to
+ // trigger another roll to verify non-full scan purges stale manifests.
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db_.get());
+ std::string prev_manifest_path =
+ DescriptorFileName(dbname_, db_impl->TEST_Current_Manifest_FileNo());
+ FillDB(db_.get(), 0, 100);
+ ASSERT_OK(db_chroot_env_->FileExists(prev_manifest_path));
+ ASSERT_OK(db_->Flush(FlushOptions()));
+ ASSERT_TRUE(db_chroot_env_->FileExists(prev_manifest_path).IsNotFound());
+
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+ AssertBackupConsistency(0, 0, 100);
+}
+
+// see https://github.com/facebook/rocksdb/issues/921
+TEST_F(BackupableDBTest, Issue921Test) {
+ BackupEngine* backup_engine;
+ backupable_options_->share_table_files = false;
+ backup_chroot_env_->CreateDirIfMissing(backupable_options_->backup_dir);
+ backupable_options_->backup_dir += "/new_dir";
+ ASSERT_OK(BackupEngine::Open(backup_chroot_env_.get(), *backupable_options_,
+ &backup_engine));
+
+ delete backup_engine;
+}
+
+TEST_F(BackupableDBTest, BackupWithMetadata) {
+ const int keys_iteration = 5000;
+ OpenDBAndBackupEngine(true);
+ // create five backups
+ for (int i = 0; i < 5; ++i) {
+ const std::string metadata = std::to_string(i);
+ FillDB(db_.get(), keys_iteration * i, keys_iteration * (i + 1));
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), metadata, true));
+ }
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(5, backup_infos.size());
+ for (int i = 0; i < 5; i++) {
+ ASSERT_EQ(std::to_string(i), backup_infos[i].app_metadata);
+ }
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+}
+
+TEST_F(BackupableDBTest, BinaryMetadata) {
+ OpenDBAndBackupEngine(true);
+ std::string binaryMetadata = "abc\ndef";
+ binaryMetadata.push_back('\0');
+ binaryMetadata.append("ghi");
+ ASSERT_OK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), binaryMetadata));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ ASSERT_EQ(binaryMetadata, backup_infos[0].app_metadata);
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+}
+
+TEST_F(BackupableDBTest, MetadataTooLarge) {
+ OpenDBAndBackupEngine(true);
+ std::string largeMetadata(1024 * 1024 + 1, 0);
+ ASSERT_NOK(
+ backup_engine_->CreateNewBackupWithMetadata(db_.get(), largeMetadata));
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+}
+
+TEST_F(BackupableDBTest, LimitBackupsOpened) {
+ // Verify the specified max backups are opened, including skipping over
+ // corrupted backups.
+ //
+ // Setup:
+ // - backups 1, 2, and 4 are valid
+ // - backup 3 is corrupt
+ // - max_valid_backups_to_open == 2
+ //
+ // Expectation: the engine opens backups 4 and 2 since those are latest two
+ // non-corrupt backups.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i <= 4; ++i) {
+ FillDB(db_.get(), kNumKeys * i, kNumKeys * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ if (i == 3) {
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/3", 3));
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ backupable_options_->max_valid_backups_to_open = 2;
+ backupable_options_->destroy_old_data = false;
+ BackupEngineReadOnly* read_only_backup_engine;
+ ASSERT_OK(BackupEngineReadOnly::Open(backup_chroot_env_.get(),
+ *backupable_options_,
+ &read_only_backup_engine));
+
+ std::vector<BackupInfo> backup_infos;
+ read_only_backup_engine->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(2, backup_infos.size());
+ ASSERT_EQ(2, backup_infos[0].backup_id);
+ ASSERT_EQ(4, backup_infos[1].backup_id);
+ delete read_only_backup_engine;
+}
+
+TEST_F(BackupableDBTest, IgnoreLimitBackupsOpenedWhenNotReadOnly) {
+ // Verify the specified max_valid_backups_to_open is ignored if the engine
+ // is not read-only.
+ //
+ // Setup:
+ // - backups 1, 2, and 4 are valid
+ // - backup 3 is corrupt
+ // - max_valid_backups_to_open == 2
+ //
+ // Expectation: the engine opens backups 4, 2, and 1 since those are latest
+ // non-corrupt backups, by ignoring max_valid_backups_to_open == 2.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true);
+ for (int i = 1; i <= 4; ++i) {
+ FillDB(db_.get(), kNumKeys * i, kNumKeys * (i + 1));
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ if (i == 3) {
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/3", 3));
+ }
+ }
+ CloseDBAndBackupEngine();
+
+ backupable_options_->max_valid_backups_to_open = 2;
+ OpenDBAndBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(3, backup_infos.size());
+ ASSERT_EQ(1, backup_infos[0].backup_id);
+ ASSERT_EQ(2, backup_infos[1].backup_id);
+ ASSERT_EQ(4, backup_infos[2].backup_id);
+ CloseDBAndBackupEngine();
+ DestroyDB(dbname_, options_);
+}
+
+TEST_F(BackupableDBTest, CreateWhenLatestBackupCorrupted) {
+ // we should pick an ID greater than corrupted backups' IDs so creation can
+ // succeed even when latest backup is corrupted.
+ const int kNumKeys = 5000;
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ FillDB(db_.get(), 0 /* from */, kNumKeys);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ true /* flush_before_backup */));
+ ASSERT_OK(file_manager_->CorruptFile(backupdir_ + "/meta/1",
+ 3 /* bytes_to_corrupt */));
+ CloseDBAndBackupEngine();
+
+ OpenDBAndBackupEngine();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ true /* flush_before_backup */));
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(1, backup_infos.size());
+ ASSERT_EQ(2, backup_infos[0].backup_id);
+}
+
+TEST_F(BackupableDBTest, WriteOnlyEngineNoSharedFileDeletion) {
+ // Verifies a write-only BackupEngine does not delete files belonging to valid
+ // backups when GarbageCollect, PurgeOldBackups, or DeleteBackup are called.
+ const int kNumKeys = 5000;
+ for (int i = 0; i < 3; ++i) {
+ OpenDBAndBackupEngine(i == 0 /* destroy_old_data */);
+ FillDB(db_.get(), i * kNumKeys, (i + 1) * kNumKeys);
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(), true));
+ CloseDBAndBackupEngine();
+
+ backupable_options_->max_valid_backups_to_open = 0;
+ OpenDBAndBackupEngine();
+ switch (i) {
+ case 0:
+ ASSERT_OK(backup_engine_->GarbageCollect());
+ break;
+ case 1:
+ ASSERT_OK(backup_engine_->PurgeOldBackups(1 /* num_backups_to_keep */));
+ break;
+ case 2:
+ ASSERT_OK(backup_engine_->DeleteBackup(2 /* backup_id */));
+ break;
+ default:
+ assert(false);
+ }
+ CloseDBAndBackupEngine();
+
+ backupable_options_->max_valid_backups_to_open = port::kMaxInt32;
+ AssertBackupConsistency(i + 1, 0, (i + 1) * kNumKeys);
+ }
+}
+
+TEST_P(BackupableDBTestWithParam, BackupUsingDirectIO) {
+ // Tests direct I/O on the backup engine's reads and writes on the DB env and
+ // backup env
+ // We use ChrootEnv underneath so the below line checks for direct I/O support
+ // in the chroot directory, not the true filesystem root.
+ if (!test::IsDirectIOSupported(test_db_env_.get(), "/")) {
+ return;
+ }
+ const int kNumKeysPerBackup = 100;
+ const int kNumBackups = 3;
+ options_.use_direct_reads = true;
+ OpenDBAndBackupEngine(true /* destroy_old_data */);
+ for (int i = 0; i < kNumBackups; ++i) {
+ FillDB(db_.get(), i * kNumKeysPerBackup /* from */,
+ (i + 1) * kNumKeysPerBackup /* to */);
+ ASSERT_OK(db_->Flush(FlushOptions()));
+
+ // Clear the file open counters and then do a bunch of backup engine ops.
+ // For all ops, files should be opened in direct mode.
+ test_backup_env_->ClearFileOpenCounters();
+ test_db_env_->ClearFileOpenCounters();
+ CloseBackupEngine();
+ OpenBackupEngine();
+ ASSERT_OK(backup_engine_->CreateNewBackup(db_.get(),
+ false /* flush_before_backup */));
+ ASSERT_OK(backup_engine_->VerifyBackup(i + 1));
+ CloseBackupEngine();
+ OpenBackupEngine();
+ std::vector<BackupInfo> backup_infos;
+ backup_engine_->GetBackupInfo(&backup_infos);
+ ASSERT_EQ(static_cast<size_t>(i + 1), backup_infos.size());
+
+ // Verify backup engine always opened files with direct I/O
+ ASSERT_EQ(0, test_db_env_->num_writers());
+ ASSERT_EQ(0, test_db_env_->num_rand_readers());
+ ASSERT_GT(test_db_env_->num_direct_seq_readers(), 0);
+ // Currently the DB doesn't support reading WALs or manifest with direct
+ // I/O, so subtract two.
+ ASSERT_EQ(test_db_env_->num_seq_readers() - 2,
+ test_db_env_->num_direct_seq_readers());
+ ASSERT_EQ(0, test_db_env_->num_rand_readers());
+ }
+ CloseDBAndBackupEngine();
+
+ for (int i = 0; i < kNumBackups; ++i) {
+ AssertBackupConsistency(i + 1 /* backup_id */,
+ i * kNumKeysPerBackup /* start_exist */,
+ (i + 1) * kNumKeysPerBackup /* end_exist */,
+ (i + 2) * kNumKeysPerBackup /* end */);
+ }
+}
+
+} // anon namespace
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as BackupableDB is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !defined(ROCKSDB_LITE) && !defined(OS_WIN)
diff --git a/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc b/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc
new file mode 100644
index 000000000..5900f0926
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_compaction_filter.cc
@@ -0,0 +1,329 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_compaction_filter.h"
+#include "db/dbformat.h"
+
+#include <cinttypes>
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+CompactionFilter::Decision BlobIndexCompactionFilterBase::FilterV2(
+ int /*level*/, const Slice& key, ValueType value_type, const Slice& value,
+ std::string* /*new_value*/, std::string* /*skip_until*/) const {
+ if (value_type != kBlobIndex) {
+ return Decision::kKeep;
+ }
+ BlobIndex blob_index;
+ Status s = blob_index.DecodeFrom(value);
+ if (!s.ok()) {
+ // Unable to decode blob index. Keeping the value.
+ return Decision::kKeep;
+ }
+ if (blob_index.HasTTL() && blob_index.expiration() <= current_time_) {
+ // Expired
+ expired_count_++;
+ expired_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ if (!blob_index.IsInlined() &&
+ blob_index.file_number() < context_.next_file_number &&
+ context_.current_blob_files.count(blob_index.file_number()) == 0) {
+ // The corresponding blob file is gone (most likely evicted by FIFO eviction).
+ evicted_count_++;
+ evicted_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ if (context_.fifo_eviction_seq > 0 && blob_index.HasTTL() &&
+ blob_index.expiration() < context_.evict_expiration_up_to) {
+ // Hack: The internal key is passed to BlobIndexCompactionFilter so that it
+ // can extract the sequence number.
+ ParsedInternalKey ikey;
+ bool ok = ParseInternalKey(key, &ikey);
+ // Remove keys that could have been removed by the last FIFO eviction.
+ // If parsing the key fails, ignore the error and continue.
+ if (ok && ikey.sequence < context_.fifo_eviction_seq) {
+ evicted_count_++;
+ evicted_size_ += key.size() + value.size();
+ return Decision::kRemove;
+ }
+ }
+ return Decision::kKeep;
+}
+
+BlobIndexCompactionFilterGC::~BlobIndexCompactionFilterGC() {
+ if (blob_file_) {
+ CloseAndRegisterNewBlobFile();
+ }
+
+ assert(context_gc_.blob_db_impl);
+
+ ROCKS_LOG_INFO(context_gc_.blob_db_impl->db_options_.info_log,
+ "GC pass finished %s: encountered %" PRIu64 " blobs (%" PRIu64
+ " bytes), relocated %" PRIu64 " blobs (%" PRIu64
+ " bytes), created %" PRIu64 " new blob file(s)",
+ !gc_stats_.HasError() ? "successfully" : "with failure",
+ gc_stats_.AllBlobs(), gc_stats_.AllBytes(),
+ gc_stats_.RelocatedBlobs(), gc_stats_.RelocatedBytes(),
+ gc_stats_.NewFiles());
+
+ RecordTick(statistics(), BLOB_DB_GC_NUM_KEYS_RELOCATED,
+ gc_stats_.RelocatedBlobs());
+ RecordTick(statistics(), BLOB_DB_GC_BYTES_RELOCATED,
+ gc_stats_.RelocatedBytes());
+ RecordTick(statistics(), BLOB_DB_GC_NUM_NEW_FILES, gc_stats_.NewFiles());
+ RecordTick(statistics(), BLOB_DB_GC_FAILURES, gc_stats_.HasError());
+}
+
+CompactionFilter::BlobDecision BlobIndexCompactionFilterGC::PrepareBlobOutput(
+ const Slice& key, const Slice& existing_value,
+ std::string* new_value) const {
+ assert(new_value);
+
+ const BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ (void)blob_db_impl;
+
+ assert(blob_db_impl);
+ assert(blob_db_impl->bdb_options_.enable_garbage_collection);
+
+ BlobIndex blob_index;
+ const Status s = blob_index.DecodeFrom(existing_value);
+ if (!s.ok()) {
+ gc_stats_.SetError();
+ return BlobDecision::kCorruption;
+ }
+
+ if (blob_index.IsInlined()) {
+ gc_stats_.AddBlob(blob_index.value().size());
+
+ return BlobDecision::kKeep;
+ }
+
+ gc_stats_.AddBlob(blob_index.size());
+
+ if (blob_index.HasTTL()) {
+ return BlobDecision::kKeep;
+ }
+
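+ // Only blobs from files strictly below the GC cutoff file number are
+ // eligible for relocation; blobs in newer files are kept as-is.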
+ if (blob_index.file_number() >= context_gc_.cutoff_file_number) {
+ return BlobDecision::kKeep;
+ }
+
+ // Note: each compaction generates its own blob files, which, depending on the
+ // workload, might result in many small blob files. The total number of files
+ // is bounded though (determined by the number of compactions and the blob
+ // file size option).
+ if (!OpenNewBlobFileIfNeeded()) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ PinnableSlice blob;
+ CompressionType compression_type = kNoCompression;
+ if (!ReadBlobFromOldFile(key, blob_index, &blob, &compression_type)) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ uint64_t new_blob_file_number = 0;
+ uint64_t new_blob_offset = 0;
+ if (!WriteBlobToNewFile(key, blob, &new_blob_file_number, &new_blob_offset)) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ if (!CloseAndRegisterNewBlobFileIfNeeded()) {
+ gc_stats_.SetError();
+ return BlobDecision::kIOError;
+ }
+
+ BlobIndex::EncodeBlob(new_value, new_blob_file_number, new_blob_offset,
+ blob.size(), compression_type);
+
+ gc_stats_.AddRelocatedBlob(blob_index.size());
+
+ return BlobDecision::kChangeValue;
+}
+
+bool BlobIndexCompactionFilterGC::OpenNewBlobFileIfNeeded() const {
+ if (blob_file_) {
+ assert(writer_);
+ return true;
+ }
+
+ BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+
+ const Status s = blob_db_impl->CreateBlobFileAndWriter(
+ /* has_ttl */ false, ExpirationRange(), "GC", &blob_file_, &writer_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(blob_db_impl->db_options_.info_log,
+ "Error opening new blob file during GC, status: %s",
+ s.ToString().c_str());
+
+ return false;
+ }
+
+ assert(blob_file_);
+ assert(writer_);
+
+ gc_stats_.AddNewFile();
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterGC::ReadBlobFromOldFile(
+ const Slice& key, const BlobIndex& blob_index, PinnableSlice* blob,
+ CompressionType* compression_type) const {
+ BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+
+ const Status s = blob_db_impl->GetRawBlobFromFile(
+ key, blob_index.file_number(), blob_index.offset(), blob_index.size(),
+ blob, compression_type);
+
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(blob_db_impl->db_options_.info_log,
+ "Error reading blob during GC, key: %s (%s), status: %s",
+ key.ToString(/* output_hex */ true).c_str(),
+ blob_index.DebugString(/* output_hex */ true).c_str(),
+ s.ToString().c_str());
+
+ return false;
+ }
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterGC::WriteBlobToNewFile(
+ const Slice& key, const Slice& blob, uint64_t* new_blob_file_number,
+ uint64_t* new_blob_offset) const {
+ assert(new_blob_file_number);
+ assert(new_blob_offset);
+
+ assert(blob_file_);
+ *new_blob_file_number = blob_file_->BlobFileNumber();
+
+ assert(writer_);
+ uint64_t new_key_offset = 0;
+ const Status s = writer_->AddRecord(key, blob, kNoExpiration, &new_key_offset,
+ new_blob_offset);
+
+ if (!s.ok()) {
+ const BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+
+ ROCKS_LOG_ERROR(
+ blob_db_impl->db_options_.info_log,
+ "Error writing blob to new file %s during GC, key: %s, status: %s",
+ blob_file_->PathName().c_str(),
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+ return false;
+ }
+
+ const uint64_t new_size =
+ BlobLogRecord::kHeaderSize + key.size() + blob.size();
+ blob_file_->BlobRecordAdded(new_size);
+
+ BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+
+ blob_db_impl->total_blob_size_ += new_size;
+
+ return true;
+}
+
+bool BlobIndexCompactionFilterGC::CloseAndRegisterNewBlobFileIfNeeded() const {
+ const BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+
+ assert(blob_file_);
+ if (blob_file_->GetFileSize() < blob_db_impl->bdb_options_.blob_file_size) {
+ return true;
+ }
+
+ return CloseAndRegisterNewBlobFile();
+}
+
+bool BlobIndexCompactionFilterGC::CloseAndRegisterNewBlobFile() const {
+ BlobDBImpl* const blob_db_impl = context_gc_.blob_db_impl;
+ assert(blob_db_impl);
+ assert(blob_file_);
+
+ Status s;
+
+ {
+ WriteLock wl(&blob_db_impl->mutex_);
+
+ s = blob_db_impl->CloseBlobFile(blob_file_);
+
+ // Note: we delay registering the new blob file until it's closed to
+ // prevent FIFO eviction from processing it during the GC run.
+ blob_db_impl->RegisterBlobFile(blob_file_);
+ }
+
+ assert(blob_file_->Immutable());
+
+ if (!s.ok()) {
+ // Log the failure before releasing blob_file_, since the message needs its
+ // path.
+ ROCKS_LOG_ERROR(blob_db_impl->db_options_.info_log,
+ "Error closing new blob file %s during GC, status: %s",
+ blob_file_->PathName().c_str(), s.ToString().c_str());
+ blob_file_.reset();
+
+ return false;
+ }
+
+ blob_file_.reset();
+
+ return true;
+}
+
+std::unique_ptr<CompactionFilter>
+BlobIndexCompactionFilterFactory::CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) {
+ assert(env());
+
+ int64_t current_time = 0;
+ Status s = env()->GetCurrentTime(&current_time);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ assert(current_time >= 0);
+
+ assert(blob_db_impl());
+
+ BlobCompactionContext context;
+ blob_db_impl()->GetCompactionContext(&context);
+
+ return std::unique_ptr<CompactionFilter>(new BlobIndexCompactionFilter(
+ std::move(context), current_time, statistics()));
+}
+
+std::unique_ptr<CompactionFilter>
+BlobIndexCompactionFilterFactoryGC::CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) {
+ assert(env());
+
+ int64_t current_time = 0;
+ Status s = env()->GetCurrentTime(&current_time);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ assert(current_time >= 0);
+
+ assert(blob_db_impl());
+
+ BlobCompactionContext context;
+ BlobCompactionContextGC context_gc;
+ blob_db_impl()->GetCompactionContext(&context, &context_gc);
+
+ return std::unique_ptr<CompactionFilter>(new BlobIndexCompactionFilterGC(
+ std::move(context), std::move(context_gc), current_time, statistics()));
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_compaction_filter.h b/src/rocksdb/utilities/blob_db/blob_compaction_filter.h
new file mode 100644
index 000000000..409df26ac
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_compaction_filter.h
@@ -0,0 +1,168 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <unordered_set>
+
+#include "db/blob_index.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/env.h"
+#include "utilities/blob_db/blob_db_gc_stats.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+struct BlobCompactionContext {
+ uint64_t next_file_number = 0;
+ std::unordered_set<uint64_t> current_blob_files;
+ SequenceNumber fifo_eviction_seq = 0;
+ uint64_t evict_expiration_up_to = 0;
+};
+
+struct BlobCompactionContextGC {
+ BlobDBImpl* blob_db_impl = nullptr;
+ uint64_t cutoff_file_number = 0;
+};
+
+// Compaction filter that deletes expired blob indexes from the base DB.
+// Comes in two varieties, one for the non-GC case and one for the GC case.
+class BlobIndexCompactionFilterBase : public CompactionFilter {
+ public:
+ BlobIndexCompactionFilterBase(BlobCompactionContext&& context,
+ uint64_t current_time, Statistics* stats)
+ : context_(std::move(context)),
+ current_time_(current_time),
+ statistics_(stats) {}
+
+ ~BlobIndexCompactionFilterBase() override {
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EXPIRED_COUNT, expired_count_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EXPIRED_SIZE, expired_size_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EVICTED_COUNT, evicted_count_);
+ RecordTick(statistics_, BLOB_DB_BLOB_INDEX_EVICTED_SIZE, evicted_size_);
+ }
+
+ // Filter expired blob indexes regardless of snapshots.
+ bool IgnoreSnapshots() const override { return true; }
+
+ Decision FilterV2(int /*level*/, const Slice& key, ValueType value_type,
+ const Slice& value, std::string* /*new_value*/,
+ std::string* /*skip_until*/) const override;
+
+ protected:
+ Statistics* statistics() const { return statistics_; }
+
+ private:
+ BlobCompactionContext context_;
+ const uint64_t current_time_;
+ Statistics* statistics_;
+ // It is safe not to use std::atomic here since the compaction filter,
+ // created from a compaction filter factory, will not be called from
+ // multiple threads.
+ mutable uint64_t expired_count_ = 0;
+ mutable uint64_t expired_size_ = 0;
+ mutable uint64_t evicted_count_ = 0;
+ mutable uint64_t evicted_size_ = 0;
+};
+
+class BlobIndexCompactionFilter : public BlobIndexCompactionFilterBase {
+ public:
+ BlobIndexCompactionFilter(BlobCompactionContext&& context,
+ uint64_t current_time, Statistics* stats)
+ : BlobIndexCompactionFilterBase(std::move(context), current_time, stats) {
+ }
+
+ const char* Name() const override { return "BlobIndexCompactionFilter"; }
+};
+
+class BlobIndexCompactionFilterGC : public BlobIndexCompactionFilterBase {
+ public:
+ BlobIndexCompactionFilterGC(BlobCompactionContext&& context,
+ BlobCompactionContextGC&& context_gc,
+ uint64_t current_time, Statistics* stats)
+ : BlobIndexCompactionFilterBase(std::move(context), current_time, stats),
+ context_gc_(std::move(context_gc)) {}
+
+ ~BlobIndexCompactionFilterGC() override;
+
+ const char* Name() const override { return "BlobIndexCompactionFilterGC"; }
+
+ BlobDecision PrepareBlobOutput(const Slice& key, const Slice& existing_value,
+ std::string* new_value) const override;
+
+ private:
+ bool OpenNewBlobFileIfNeeded() const;
+ bool ReadBlobFromOldFile(const Slice& key, const BlobIndex& blob_index,
+ PinnableSlice* blob,
+ CompressionType* compression_type) const;
+ bool WriteBlobToNewFile(const Slice& key, const Slice& blob,
+ uint64_t* new_blob_file_number,
+ uint64_t* new_blob_offset) const;
+ bool CloseAndRegisterNewBlobFileIfNeeded() const;
+ bool CloseAndRegisterNewBlobFile() const;
+
+ private:
+ BlobCompactionContextGC context_gc_;
+ mutable std::shared_ptr<BlobFile> blob_file_;
+ mutable std::shared_ptr<Writer> writer_;
+ mutable BlobDBGarbageCollectionStats gc_stats_;
+};
+
+// Compaction filter factory; similarly to the filters above, it comes
+// in two flavors, one that creates filters that support GC, and one
+// that creates non-GC filters.
+class BlobIndexCompactionFilterFactoryBase : public CompactionFilterFactory {
+ public:
+ BlobIndexCompactionFilterFactoryBase(BlobDBImpl* _blob_db_impl, Env* _env,
+ Statistics* _statistics)
+ : blob_db_impl_(_blob_db_impl), env_(_env), statistics_(_statistics) {}
+
+ protected:
+ BlobDBImpl* blob_db_impl() const { return blob_db_impl_; }
+ Env* env() const { return env_; }
+ Statistics* statistics() const { return statistics_; }
+
+ private:
+ BlobDBImpl* blob_db_impl_;
+ Env* env_;
+ Statistics* statistics_;
+};
+
+class BlobIndexCompactionFilterFactory
+ : public BlobIndexCompactionFilterFactoryBase {
+ public:
+ BlobIndexCompactionFilterFactory(BlobDBImpl* _blob_db_impl, Env* _env,
+ Statistics* _statistics)
+ : BlobIndexCompactionFilterFactoryBase(_blob_db_impl, _env, _statistics) {
+ }
+
+ const char* Name() const override {
+ return "BlobIndexCompactionFilterFactory";
+ }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override;
+};
+
+class BlobIndexCompactionFilterFactoryGC
+ : public BlobIndexCompactionFilterFactoryBase {
+ public:
+ BlobIndexCompactionFilterFactoryGC(BlobDBImpl* _blob_db_impl, Env* _env,
+ Statistics* _statistics)
+ : BlobIndexCompactionFilterFactoryBase(_blob_db_impl, _env, _statistics) {
+ }
+
+ const char* Name() const override {
+ return "BlobIndexCompactionFilterFactoryGC";
+ }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db.cc b/src/rocksdb/utilities/blob_db/blob_db.cc
new file mode 100644
index 000000000..f568ecd1a
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db.h"
+
+#include <cinttypes>
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Status BlobDB::Open(const Options& options, const BlobDBOptions& bdb_options,
+ const std::string& dbname, BlobDB** blob_db) {
+ *blob_db = nullptr;
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = BlobDB::Open(db_options, bdb_options, dbname, column_families,
+ &handles, blob_db);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // We can delete the handle since DBImpl always holds a reference to the
+ // default column family.
+ delete handles[0];
+ }
+ return s;
+}
+
+Status BlobDB::Open(const DBOptions& db_options,
+ const BlobDBOptions& bdb_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ BlobDB** blob_db) {
+ if (column_families.size() != 1 ||
+ column_families[0].name != kDefaultColumnFamilyName) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+
+ BlobDBImpl* blob_db_impl = new BlobDBImpl(dbname, bdb_options, db_options,
+ column_families[0].options);
+ Status s = blob_db_impl->Open(handles);
+ if (s.ok()) {
+ *blob_db = static_cast<BlobDB*>(blob_db_impl);
+ } else {
+ delete blob_db_impl;
+ *blob_db = nullptr;
+ }
+ return s;
+}
+
+BlobDB::BlobDB() : StackableDB(nullptr) {}
+
+void BlobDBOptions::Dump(Logger* log) const {
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.blob_dir: %s",
+ blob_dir.c_str());
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.path_relative: %d",
+ path_relative);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.is_fifo: %d",
+ is_fifo);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.max_db_size: %" PRIu64,
+ max_db_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.ttl_range_secs: %" PRIu64,
+ ttl_range_secs);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.min_blob_size: %" PRIu64,
+ min_blob_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.bytes_per_sync: %" PRIu64,
+ bytes_per_sync);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.blob_file_size: %" PRIu64,
+ blob_file_size);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.compression: %d",
+ static_cast<int>(compression));
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.enable_garbage_collection: %d",
+ enable_garbage_collection);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.garbage_collection_cutoff: %f",
+ garbage_collection_cutoff);
+ ROCKS_LOG_HEADER(
+ log, " BlobDBOptions.disable_background_tasks: %d",
+ disable_background_tasks);
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif
diff --git a/src/rocksdb/utilities/blob_db/blob_db.h b/src/rocksdb/utilities/blob_db/blob_db.h
new file mode 100644
index 000000000..72a580433
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db.h
@@ -0,0 +1,261 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <string>
+#include <vector>
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/stackable_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+// A wrapped database which puts the values of KV pairs into a separate log
+// and stores the location within that log in the underlying DB.
+//
+// The factory needs to be moved to include/rocksdb/utilities to allow
+// users to use blob DB.
+
+struct BlobDBOptions {
+ // name of the directory under main db, where blobs will be stored.
+ // default is "blob_dir"
+ std::string blob_dir = "blob_dir";
+
+ // whether the blob_dir path is relative or absolute.
+ bool path_relative = true;
+
+ // When max_db_size is reached, evict blob files to free up space
+ // instead of returnning NoSpace error on write. Blob files will be
+ // evicted from oldest to newest, based on file creation time.
+ bool is_fifo = false;
+
+ // Maximum size of the database (including SST files and blob files).
+ //
+ // Default: 0 (no limits)
+ uint64_t max_db_size = 0;
+
+ // A new blob file bucket is opened for each ttl_range_secs interval. So if
+ // ttl_range_secs is 600 seconds (10 minutes) and the first bucket starts at
+ // 1471542000, then the blob buckets will be:
+ // first bucket: 1471542000 - 1471542600
+ // second bucket: 1471542600 - 1471543200
+ // and so on
+ uint64_t ttl_range_secs = 3600;
+
+ // The size threshold for storing a value in the blob log. Values smaller
+ // than this threshold will be inlined in the base DB together with the key.
+ uint64_t min_blob_size = 0;
+
+ // Allows OS to incrementally sync blob files to disk for every
+ // bytes_per_sync bytes written. Users shouldn't rely on it for
+ // persistence guarantees.
+ uint64_t bytes_per_sync = 512 * 1024;
+
+ // The target size of each blob file. A file becomes immutable
+ // after it exceeds that size.
+ uint64_t blob_file_size = 256 * 1024 * 1024;
+
+ // What compression type to use for blobs.
+ CompressionType compression = kNoCompression;
+
+ // If enabled, BlobDB cleans up stale blobs in non-TTL files during compaction
+ // by rewriting the remaining live blobs to new files.
+ bool enable_garbage_collection = false;
+
+ // The cutoff in terms of blob file age for garbage collection. Blobs in
+ // the oldest N non-TTL blob files will be rewritten when encountered during
+ // compaction, where N = garbage_collection_cutoff * number_of_non_TTL_files.
+ double garbage_collection_cutoff = 0.25;
+
+ // Disable all background jobs. Used for testing only.
+ bool disable_background_tasks = false;
+
+ void Dump(Logger* log) const;
+};
+
+class BlobDB : public StackableDB {
+ public:
+ using ROCKSDB_NAMESPACE::StackableDB::Put;
+ virtual Status Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) override = 0;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return Put(options, key, value);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Delete;
+ virtual Status Delete(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ assert(db_ != nullptr);
+ return db_->Delete(options, column_family, key);
+ }
+
+ virtual Status PutWithTTL(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t ttl) = 0;
+ virtual Status PutWithTTL(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, uint64_t ttl) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return PutWithTTL(options, key, value, ttl);
+ }
+
+ // Put with expiration. A key with expiration time equal to
+ // std::numeric_limits<uint64_t>::max() never expires.
+ virtual Status PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) = 0;
+ virtual Status PutUntil(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, uint64_t expiration) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ return PutUntil(options, key, value, expiration);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override = 0;
+
+ // Get value and expiration.
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) = 0;
+ virtual Status Get(const ReadOptions& options, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ return Get(options, DefaultColumnFamily(), key, value, expiration);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override = 0;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ for (auto column_family : column_families) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return std::vector<Status>(
+ column_families.size(),
+ Status::NotSupported(
+ "Blob DB doesn't support non-default column family."));
+ }
+ }
+ return MultiGet(options, keys, values);
+ }
+ virtual void MultiGet(const ReadOptions& /*options*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const size_t num_keys, const Slice* /*keys*/,
+ PinnableSlice* /*values*/, Status* statuses,
+ const bool /*sorted_input*/ = false) override {
+ for (size_t i = 0; i < num_keys; ++i) {
+ statuses[i] = Status::NotSupported(
+ "Blob DB doesn't support batched MultiGet");
+ }
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::SingleDelete;
+ virtual Status SingleDelete(const WriteOptions& /*wopts*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const Slice& /*key*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& /*options*/,
+ ColumnFamilyHandle* /*column_family*/,
+ const Slice& /*key*/, const Slice& /*value*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ virtual Status Write(const WriteOptions& opts,
+ WriteBatch* updates) override = 0;
+ using ROCKSDB_NAMESPACE::StackableDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& options) override = 0;
+ virtual Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ // Blob DB doesn't support non-default column family.
+ return nullptr;
+ }
+ return NewIterator(options);
+ }
+
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override = 0;
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ ColumnFamilyHandle* column_family,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+
+ return CompactFiles(compact_options, input_file_names, output_level,
+ output_path_id, output_file_names, compaction_job_info);
+ }
+
+ using ROCKSDB_NAMESPACE::StackableDB::Close;
+ virtual Status Close() override = 0;
+
+ // Opening blob db.
+ static Status Open(const Options& options, const BlobDBOptions& bdb_options,
+ const std::string& dbname, BlobDB** blob_db);
+
+ static Status Open(const DBOptions& db_options,
+ const BlobDBOptions& bdb_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ BlobDB** blob_db);
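+
+ // Example usage (illustrative sketch; the path and option values below are
+ // placeholders, not defaults):
+ //
+ //   Options options;
+ //   options.create_if_missing = true;
+ //   BlobDBOptions bdb_options;
+ //   bdb_options.min_blob_size = 4096;
+ //   BlobDB* blob_db = nullptr;
+ //   Status s = BlobDB::Open(options, bdb_options, "/path/to/db", &blob_db);
+ //   if (s.ok()) {
+ //     s = blob_db->Put(WriteOptions(), "key", "value");
+ //     delete blob_db;
+ //   }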
+
+ virtual BlobDBOptions GetBlobDBOptions() const = 0;
+
+ virtual Status SyncBlobFiles() = 0;
+
+ virtual ~BlobDB() {}
+
+ protected:
+ explicit BlobDB();
+};
+
+// Destroy the content of the database.
+Status DestroyBlobDB(const std::string& dbname, const Options& options,
+ const BlobDBOptions& bdb_options);
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h b/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h
new file mode 100644
index 000000000..1e6e4a25d
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_gc_stats.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+/**
+ * Statistics related to a single garbage collection pass (i.e. a single
+ * (sub)compaction).
+ */
+class BlobDBGarbageCollectionStats {
+ public:
+ uint64_t AllBlobs() const { return all_blobs_; }
+ uint64_t AllBytes() const { return all_bytes_; }
+ uint64_t RelocatedBlobs() const { return relocated_blobs_; }
+ uint64_t RelocatedBytes() const { return relocated_bytes_; }
+ uint64_t NewFiles() const { return new_files_; }
+ bool HasError() const { return error_; }
+
+ void AddBlob(uint64_t size) {
+ ++all_blobs_;
+ all_bytes_ += size;
+ }
+
+ void AddRelocatedBlob(uint64_t size) {
+ ++relocated_blobs_;
+ relocated_bytes_ += size;
+ }
+
+ void AddNewFile() { ++new_files_; }
+
+ void SetError() { error_ = true; }
+
+ private:
+ uint64_t all_blobs_ = 0;
+ uint64_t all_bytes_ = 0;
+ uint64_t relocated_blobs_ = 0;
+ uint64_t relocated_bytes_ = 0;
+ uint64_t new_files_ = 0;
+ bool error_ = false;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl.cc b/src/rocksdb/utilities/blob_db/blob_db_impl.cc
new file mode 100644
index 000000000..5f2ca2498
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl.cc
@@ -0,0 +1,2116 @@
+
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db_impl.h"
+#include <algorithm>
+#include <cinttypes>
+#include <iomanip>
+#include <memory>
+#include <sstream>
+
+#include "db/blob_index.h"
+#include "db/db_impl/db_impl.h"
+#include "db/write_batch_internal.h"
+#include "env/composite_env_wrapper.h"
+#include "file/file_util.h"
+#include "file/filename.h"
+#include "file/random_access_file_reader.h"
+#include "file/sst_file_manager_impl.h"
+#include "file/writable_file_writer.h"
+#include "logging/logging.h"
+#include "monitoring/instrumented_mutex.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/env.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/utilities/stackable_db.h"
+#include "rocksdb/utilities/transaction.h"
+#include "table/block_based/block.h"
+#include "table/block_based/block_based_table_builder.h"
+#include "table/block_based/block_builder.h"
+#include "table/meta_blocks.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/stop_watch.h"
+#include "util/timer_queue.h"
+#include "utilities/blob_db/blob_compaction_filter.h"
+#include "utilities/blob_db/blob_db_iterator.h"
+#include "utilities/blob_db/blob_db_listener.h"
+
+namespace {
+int kBlockBasedTableVersionFormat = 2;
+} // end namespace
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+bool BlobFileComparator::operator()(
+ const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const {
+ return lhs->BlobFileNumber() > rhs->BlobFileNumber();
+}
+
+bool BlobFileComparatorTTL::operator()(
+ const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const {
+ assert(lhs->HasTTL() && rhs->HasTTL());
+ if (lhs->expiration_range_.first < rhs->expiration_range_.first) {
+ return true;
+ }
+ if (lhs->expiration_range_.first > rhs->expiration_range_.first) {
+ return false;
+ }
+ return lhs->BlobFileNumber() < rhs->BlobFileNumber();
+}
+
+BlobDBImpl::BlobDBImpl(const std::string& dbname,
+ const BlobDBOptions& blob_db_options,
+ const DBOptions& db_options,
+ const ColumnFamilyOptions& cf_options)
+ : BlobDB(),
+ dbname_(dbname),
+ db_impl_(nullptr),
+ env_(db_options.env),
+ bdb_options_(blob_db_options),
+ db_options_(db_options),
+ cf_options_(cf_options),
+ env_options_(db_options),
+ statistics_(db_options_.statistics.get()),
+ next_file_number_(1),
+ flush_sequence_(0),
+ closed_(true),
+ open_file_count_(0),
+ total_blob_size_(0),
+ live_sst_size_(0),
+ fifo_eviction_seq_(0),
+ evict_expiration_up_to_(0),
+ debug_level_(0) {
+ blob_dir_ = (bdb_options_.path_relative)
+ ? dbname + "/" + bdb_options_.blob_dir
+ : bdb_options_.blob_dir;
+ env_options_.bytes_per_sync = blob_db_options.bytes_per_sync;
+}
+
+BlobDBImpl::~BlobDBImpl() {
+ tqueue_.shutdown();
+ // CancelAllBackgroundWork(db_, true);
+ Status s __attribute__((__unused__)) = Close();
+ assert(s.ok());
+}
+
+Status BlobDBImpl::Close() {
+ if (closed_) {
+ return Status::OK();
+ }
+ closed_ = true;
+
+  // Close the base DB before BlobDBImpl is destroyed to stop the event
+  // listener and compaction filter calls.
+  Status s = db_->Close();
+  // Delete db_ even if Close() failed.
+  delete db_;
+  // Reset the pointers so that StackableDB does not delete them again.
+ db_ = nullptr;
+ db_impl_ = nullptr;
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = SyncBlobFiles();
+ return s;
+}
+
+BlobDBOptions BlobDBImpl::GetBlobDBOptions() const { return bdb_options_; }
+
+Status BlobDBImpl::Open(std::vector<ColumnFamilyHandle*>* handles) {
+ assert(handles != nullptr);
+ assert(db_ == nullptr);
+
+ if (blob_dir_.empty()) {
+ return Status::NotSupported("No blob directory in options");
+ }
+
+ if (cf_options_.compaction_filter != nullptr ||
+ cf_options_.compaction_filter_factory != nullptr) {
+ return Status::NotSupported("Blob DB doesn't support compaction filter.");
+ }
+
+ if (bdb_options_.garbage_collection_cutoff < 0.0 ||
+ bdb_options_.garbage_collection_cutoff > 1.0) {
+ return Status::InvalidArgument(
+ "Garbage collection cutoff must be in the interval [0.0, 1.0]");
+ }
+
+ // Temporarily disable compactions in the base DB during open; save the user
+ // defined value beforehand so we can restore it once BlobDB is initialized.
+ // Note: this is only needed if garbage collection is enabled.
+ const bool disable_auto_compactions = cf_options_.disable_auto_compactions;
+
+ if (bdb_options_.enable_garbage_collection) {
+ cf_options_.disable_auto_compactions = true;
+ }
+
+ Status s;
+
+ // Create info log.
+ if (db_options_.info_log == nullptr) {
+ s = CreateLoggerFromOptions(dbname_, db_options_, &db_options_.info_log);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Opening BlobDB...");
+
+ // Open blob directory.
+ s = env_->CreateDirIfMissing(blob_dir_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to create blob_dir %s, status: %s",
+ blob_dir_.c_str(), s.ToString().c_str());
+ }
+ s = env_->NewDirectory(blob_dir_, &dir_ent_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to open blob_dir %s, status: %s", blob_dir_.c_str(),
+ s.ToString().c_str());
+ return s;
+ }
+
+ // Open blob files.
+ s = OpenAllBlobFiles();
+ if (!s.ok()) {
+ return s;
+ }
+
+ // Update options
+ if (bdb_options_.enable_garbage_collection) {
+ db_options_.listeners.push_back(std::make_shared<BlobDBListenerGC>(this));
+ cf_options_.compaction_filter_factory =
+ std::make_shared<BlobIndexCompactionFilterFactoryGC>(this, env_,
+ statistics_);
+ } else {
+ db_options_.listeners.push_back(std::make_shared<BlobDBListener>(this));
+ cf_options_.compaction_filter_factory =
+ std::make_shared<BlobIndexCompactionFilterFactory>(this, env_,
+ statistics_);
+ }
+
+ // Open base db.
+ ColumnFamilyDescriptor cf_descriptor(kDefaultColumnFamilyName, cf_options_);
+ s = DB::Open(db_options_, dbname_, {cf_descriptor}, handles, &db_);
+ if (!s.ok()) {
+ return s;
+ }
+ db_impl_ = static_cast_with_check<DBImpl, DB>(db_->GetRootDB());
+
+ // Initialize SST file <-> oldest blob file mapping if garbage collection
+ // is enabled.
+ if (bdb_options_.enable_garbage_collection) {
+ std::vector<LiveFileMetaData> live_files;
+ db_->GetLiveFilesMetaData(&live_files);
+
+ InitializeBlobFileToSstMapping(live_files);
+
+ MarkUnreferencedBlobFilesObsoleteDuringOpen();
+
+ if (!disable_auto_compactions) {
+ s = db_->EnableAutoCompaction(*handles);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to enable automatic compactions during open, status: %s",
+ s.ToString().c_str());
+ return s;
+ }
+ }
+ }
+
+ // Add trash files in blob dir to file delete scheduler.
+ SstFileManagerImpl* sfm = static_cast<SstFileManagerImpl*>(
+ db_impl_->immutable_db_options().sst_file_manager.get());
+ DeleteScheduler::CleanupDirectory(env_, sfm, blob_dir_);
+
+ UpdateLiveSSTSize();
+
+ // Start background jobs.
+ if (!bdb_options_.disable_background_tasks) {
+ StartBackgroundTasks();
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "BlobDB pointer %p", this);
+ bdb_options_.Dump(db_options_.info_log.get());
+ closed_ = false;
+ return s;
+}
+
+void BlobDBImpl::StartBackgroundTasks() {
+  // Schedule periodic background tasks by binding member functions to this
+  // object.
+ tqueue_.add(
+ kReclaimOpenFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::ReclaimOpenFiles, this, std::placeholders::_1));
+ tqueue_.add(
+ kDeleteObsoleteFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::DeleteObsoleteFiles, this, std::placeholders::_1));
+ tqueue_.add(kSanityCheckPeriodMillisecs,
+ std::bind(&BlobDBImpl::SanityCheck, this, std::placeholders::_1));
+ tqueue_.add(
+ kEvictExpiredFilesPeriodMillisecs,
+ std::bind(&BlobDBImpl::EvictExpiredFiles, this, std::placeholders::_1));
+}
+
+Status BlobDBImpl::GetAllBlobFiles(std::set<uint64_t>* file_numbers) {
+ assert(file_numbers != nullptr);
+ std::vector<std::string> all_files;
+ Status s = env_->GetChildren(blob_dir_, &all_files);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to get list of blob files, status: %s",
+ s.ToString().c_str());
+ return s;
+ }
+
+ for (const auto& file_name : all_files) {
+ uint64_t file_number;
+ FileType type;
+ bool success = ParseFileName(file_name, &file_number, &type);
+ if (success && type == kBlobFile) {
+ file_numbers->insert(file_number);
+ } else {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Skipping file in blob directory: %s", file_name.c_str());
+ }
+ }
+
+ return s;
+}
+
+Status BlobDBImpl::OpenAllBlobFiles() {
+ std::set<uint64_t> file_numbers;
+ Status s = GetAllBlobFiles(&file_numbers);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (!file_numbers.empty()) {
+ next_file_number_.store(*file_numbers.rbegin() + 1);
+ }
+
+ std::ostringstream blob_file_oss;
+ std::ostringstream live_imm_oss;
+ std::ostringstream obsolete_file_oss;
+
+ for (auto& file_number : file_numbers) {
+ std::shared_ptr<BlobFile> blob_file = std::make_shared<BlobFile>(
+ this, blob_dir_, file_number, db_options_.info_log.get());
+ blob_file->MarkImmutable(/* sequence */ 0);
+
+ // Read file header and footer
+ Status read_metadata_status = blob_file->ReadMetadata(env_, env_options_);
+ if (read_metadata_status.IsCorruption()) {
+ // Remove incomplete file.
+ if (!obsolete_files_.empty()) {
+ obsolete_file_oss << ", ";
+ }
+ obsolete_file_oss << file_number;
+
+ ObsoleteBlobFile(blob_file, 0 /*obsolete_seq*/, false /*update_size*/);
+ continue;
+ } else if (!read_metadata_status.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Unable to read metadata of blob file %" PRIu64
+ ", status: '%s'",
+ file_number, read_metadata_status.ToString().c_str());
+ return read_metadata_status;
+ }
+
+ total_blob_size_ += blob_file->GetFileSize();
+
+ if (!blob_files_.empty()) {
+ blob_file_oss << ", ";
+ }
+ blob_file_oss << file_number;
+
+ blob_files_[file_number] = blob_file;
+
+ if (!blob_file->HasTTL()) {
+ if (!live_imm_non_ttl_blob_files_.empty()) {
+ live_imm_oss << ", ";
+ }
+ live_imm_oss << file_number;
+
+ live_imm_non_ttl_blob_files_[file_number] = blob_file;
+ }
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Found %" ROCKSDB_PRIszt " blob files: %s", blob_files_.size(),
+ blob_file_oss.str().c_str());
+ ROCKS_LOG_INFO(
+ db_options_.info_log, "Found %" ROCKSDB_PRIszt " non-TTL blob files: %s",
+ live_imm_non_ttl_blob_files_.size(), live_imm_oss.str().c_str());
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Found %" ROCKSDB_PRIszt
+ " incomplete or corrupted blob files: %s",
+ obsolete_files_.size(), obsolete_file_oss.str().c_str());
+ return s;
+}
+
+template <typename Linker>
+void BlobDBImpl::LinkSstToBlobFileImpl(uint64_t sst_file_number,
+ uint64_t blob_file_number,
+ Linker linker) {
+ assert(bdb_options_.enable_garbage_collection);
+ assert(blob_file_number != kInvalidBlobFileNumber);
+
+ auto it = blob_files_.find(blob_file_number);
+ if (it == blob_files_.end()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Blob file %" PRIu64
+ " not found while trying to link "
+ "SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+ return;
+ }
+
+ BlobFile* const blob_file = it->second.get();
+ assert(blob_file);
+
+ linker(blob_file, sst_file_number);
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " linked to SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+}
+
+void BlobDBImpl::LinkSstToBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ auto linker = [](BlobFile* blob_file, uint64_t sst_file) {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->LinkSstFile(sst_file);
+ };
+
+ LinkSstToBlobFileImpl(sst_file_number, blob_file_number, linker);
+}
+
+void BlobDBImpl::LinkSstToBlobFileNoLock(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ auto linker = [](BlobFile* blob_file, uint64_t sst_file) {
+ blob_file->LinkSstFile(sst_file);
+ };
+
+ LinkSstToBlobFileImpl(sst_file_number, blob_file_number, linker);
+}
+
+void BlobDBImpl::UnlinkSstFromBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number) {
+ assert(bdb_options_.enable_garbage_collection);
+ assert(blob_file_number != kInvalidBlobFileNumber);
+
+ auto it = blob_files_.find(blob_file_number);
+ if (it == blob_files_.end()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Blob file %" PRIu64
+ " not found while trying to unlink "
+ "SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+ return;
+ }
+
+ BlobFile* const blob_file = it->second.get();
+ assert(blob_file);
+
+ {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->UnlinkSstFile(sst_file_number);
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " unlinked from SST file %" PRIu64,
+ blob_file_number, sst_file_number);
+}
+
+void BlobDBImpl::InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ for (const auto& live_file : live_files) {
+ const uint64_t sst_file_number = live_file.file_number;
+ const uint64_t blob_file_number = live_file.oldest_blob_file_number;
+
+ if (blob_file_number == kInvalidBlobFileNumber) {
+ continue;
+ }
+
+ LinkSstToBlobFileNoLock(sst_file_number, blob_file_number);
+ }
+}
+
+void BlobDBImpl::ProcessFlushJobInfo(const FlushJobInfo& info) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ WriteLock lock(&mutex_);
+
+ if (info.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(info.file_number, info.oldest_blob_file_number);
+ }
+
+ assert(flush_sequence_ < info.largest_seqno);
+ flush_sequence_ = info.largest_seqno;
+
+ MarkUnreferencedBlobFilesObsolete();
+}
+
+void BlobDBImpl::ProcessCompactionJobInfo(const CompactionJobInfo& info) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ if (!info.status.ok()) {
+ return;
+ }
+
+ // Note: the same SST file may appear in both the input and the output
+ // file list in case of a trivial move. We walk through the two lists
+ // below in a fashion that's similar to merge sort to detect this.
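+  //
+  // For example (derived from the logic below, not from an actual trace):
+  // with sorted inputs {10, 11, 12} and sorted outputs {11, 13}, file 10 is
+  // unlinked from its oldest blob file, file 11 appears in both lists (a
+  // trivial move) and is left untouched, file 12 is unlinked, and file 13 is
+  // linked to its oldest blob file.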
+
+ auto cmp = [](const CompactionFileInfo& lhs, const CompactionFileInfo& rhs) {
+ return lhs.file_number < rhs.file_number;
+ };
+
+ auto inputs = info.input_file_infos;
+ auto iit = inputs.begin();
+ const auto iit_end = inputs.end();
+
+ std::sort(iit, iit_end, cmp);
+
+ auto outputs = info.output_file_infos;
+ auto oit = outputs.begin();
+ const auto oit_end = outputs.end();
+
+ std::sort(oit, oit_end, cmp);
+
+ WriteLock lock(&mutex_);
+
+ while (iit != iit_end && oit != oit_end) {
+ const auto& input = *iit;
+ const auto& output = *oit;
+
+ if (input.file_number == output.file_number) {
+ ++iit;
+ ++oit;
+ } else if (input.file_number < output.file_number) {
+ if (input.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ UnlinkSstFromBlobFile(input.file_number, input.oldest_blob_file_number);
+ }
+
+ ++iit;
+ } else {
+ assert(output.file_number < input.file_number);
+
+ if (output.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(output.file_number, output.oldest_blob_file_number);
+ }
+
+ ++oit;
+ }
+ }
+
+ while (iit != iit_end) {
+ const auto& input = *iit;
+
+ if (input.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ UnlinkSstFromBlobFile(input.file_number, input.oldest_blob_file_number);
+ }
+
+ ++iit;
+ }
+
+ while (oit != oit_end) {
+ const auto& output = *oit;
+
+ if (output.oldest_blob_file_number != kInvalidBlobFileNumber) {
+ LinkSstToBlobFile(output.file_number, output.oldest_blob_file_number);
+ }
+
+ ++oit;
+ }
+
+ MarkUnreferencedBlobFilesObsolete();
+}
+
+bool BlobDBImpl::MarkBlobFileObsoleteIfNeeded(
+ const std::shared_ptr<BlobFile>& blob_file, SequenceNumber obsolete_seq) {
+ assert(blob_file);
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+ assert(bdb_options_.enable_garbage_collection);
+
+ // Note: FIFO eviction could have marked this file obsolete already.
+ if (blob_file->Obsolete()) {
+ return true;
+ }
+
+ // We cannot mark this file (or any higher-numbered files for that matter)
+ // obsolete if it is referenced by any memtables or SSTs. We keep track of
+ // the SSTs explicitly. To account for memtables, we keep track of the highest
+ // sequence number received in flush notifications, and we do not mark the
+ // blob file obsolete if there are still unflushed memtables from before
+ // the time the blob file was closed.
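+  //
+  // Illustrative example (hypothetical numbers): if the blob file became
+  // immutable at sequence 150 but the largest sequence covered by a flush so
+  // far is 120, an unflushed memtable may still reference blobs in this file,
+  // so it cannot be marked obsolete yet.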
+ if (blob_file->GetImmutableSequence() > flush_sequence_ ||
+ !blob_file->GetLinkedSstFiles().empty()) {
+ return false;
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Blob file %" PRIu64 " is no longer needed, marking obsolete",
+ blob_file->BlobFileNumber());
+
+ ObsoleteBlobFile(blob_file, obsolete_seq, /* update_size */ true);
+ return true;
+}
+
+template <class Functor>
+void BlobDBImpl::MarkUnreferencedBlobFilesObsoleteImpl(Functor mark_if_needed) {
+ assert(bdb_options_.enable_garbage_collection);
+
+ // Iterate through all live immutable non-TTL blob files, and mark them
+ // obsolete assuming no SST files or memtables rely on the blobs in them.
+ // Note: we need to stop as soon as we find a blob file that has any
+ // linked SSTs (or one potentially referenced by memtables).
+
+ uint64_t obsoleted_files = 0;
+
+ auto it = live_imm_non_ttl_blob_files_.begin();
+ while (it != live_imm_non_ttl_blob_files_.end()) {
+ const auto& blob_file = it->second;
+ assert(blob_file);
+ assert(blob_file->BlobFileNumber() == it->first);
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+
+ // Small optimization: Obsolete() does an atomic read, so we can do
+ // this check without taking a lock on the blob file's mutex.
+ if (blob_file->Obsolete()) {
+ it = live_imm_non_ttl_blob_files_.erase(it);
+ continue;
+ }
+
+ if (!mark_if_needed(blob_file)) {
+ break;
+ }
+
+ it = live_imm_non_ttl_blob_files_.erase(it);
+
+ ++obsoleted_files;
+ }
+
+ if (obsoleted_files > 0) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "%" PRIu64 " blob file(s) marked obsolete by GC",
+ obsoleted_files);
+ RecordTick(statistics_, BLOB_DB_GC_NUM_FILES, obsoleted_files);
+ }
+}
+
+void BlobDBImpl::MarkUnreferencedBlobFilesObsolete() {
+ const SequenceNumber obsolete_seq = GetLatestSequenceNumber();
+
+ MarkUnreferencedBlobFilesObsoleteImpl(
+ [=](const std::shared_ptr<BlobFile>& blob_file) {
+ WriteLock file_lock(&blob_file->mutex_);
+ return MarkBlobFileObsoleteIfNeeded(blob_file, obsolete_seq);
+ });
+}
+
+void BlobDBImpl::MarkUnreferencedBlobFilesObsoleteDuringOpen() {
+ MarkUnreferencedBlobFilesObsoleteImpl(
+ [=](const std::shared_ptr<BlobFile>& blob_file) {
+ return MarkBlobFileObsoleteIfNeeded(blob_file, /* obsolete_seq */ 0);
+ });
+}
+
+void BlobDBImpl::CloseRandomAccessLocked(
+ const std::shared_ptr<BlobFile>& bfile) {
+ bfile->CloseRandomAccessLocked();
+ open_file_count_--;
+}
+
+Status BlobDBImpl::GetBlobFileReader(
+ const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<RandomAccessFileReader>* reader) {
+ assert(reader != nullptr);
+ bool fresh_open = false;
+ Status s = blob_file->GetReader(env_, env_options_, reader, &fresh_open);
+ if (s.ok() && fresh_open) {
+ assert(*reader != nullptr);
+ open_file_count_++;
+ }
+ return s;
+}
+
+std::shared_ptr<BlobFile> BlobDBImpl::NewBlobFile(
+ bool has_ttl, const ExpirationRange& expiration_range,
+ const std::string& reason) {
+ assert(has_ttl == (expiration_range.first || expiration_range.second));
+
+ uint64_t file_num = next_file_number_++;
+
+ const uint32_t column_family_id =
+ static_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->GetID();
+ auto blob_file = std::make_shared<BlobFile>(
+ this, blob_dir_, file_num, db_options_.info_log.get(), column_family_id,
+ bdb_options_.compression, has_ttl, expiration_range);
+
+ ROCKS_LOG_DEBUG(db_options_.info_log, "New blob file created: %s reason='%s'",
+ blob_file->PathName().c_str(), reason.c_str());
+ LogFlush(db_options_.info_log);
+
+ return blob_file;
+}
+
+void BlobDBImpl::RegisterBlobFile(std::shared_ptr<BlobFile> blob_file) {
+ const uint64_t blob_file_number = blob_file->BlobFileNumber();
+
+ auto it = blob_files_.lower_bound(blob_file_number);
+ assert(it == blob_files_.end() || it->first != blob_file_number);
+
+ blob_files_.insert(it,
+ std::map<uint64_t, std::shared_ptr<BlobFile>>::value_type(
+ blob_file_number, std::move(blob_file)));
+}
+
+Status BlobDBImpl::CreateWriterLocked(const std::shared_ptr<BlobFile>& bfile) {
+ std::string fpath(bfile->PathName());
+ std::unique_ptr<WritableFile> wfile;
+
+ Status s = env_->ReopenWritableFile(fpath, &wfile, env_options_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to open blob file for write: %s status: '%s'"
+ " exists: '%s'",
+ fpath.c_str(), s.ToString().c_str(),
+ env_->FileExists(fpath).ToString().c_str());
+ return s;
+ }
+
+ std::unique_ptr<WritableFileWriter> fwriter;
+ fwriter.reset(new WritableFileWriter(
+ NewLegacyWritableFileWrapper(std::move(wfile)), fpath, env_options_));
+
+ uint64_t boffset = bfile->GetFileSize();
+ if (debug_level_ >= 2 && boffset) {
+ ROCKS_LOG_DEBUG(db_options_.info_log,
+ "Open blob file: %s with offset: %" PRIu64, fpath.c_str(),
+ boffset);
+ }
+
+ Writer::ElemType et = Writer::kEtNone;
+ if (bfile->file_size_ == BlobLogHeader::kSize) {
+ et = Writer::kEtFileHdr;
+ } else if (bfile->file_size_ > BlobLogHeader::kSize) {
+ et = Writer::kEtRecord;
+ } else if (bfile->file_size_) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "Open blob file: %s with wrong size: %" PRIu64,
+ fpath.c_str(), boffset);
+ return Status::Corruption("Invalid blob file size");
+ }
+
+ bfile->log_writer_ = std::make_shared<Writer>(
+ std::move(fwriter), env_, statistics_, bfile->file_number_,
+ bdb_options_.bytes_per_sync, db_options_.use_fsync, boffset);
+ bfile->log_writer_->last_elem_type_ = et;
+
+ return s;
+}
+
+std::shared_ptr<BlobFile> BlobDBImpl::FindBlobFileLocked(
+ uint64_t expiration) const {
+ if (open_ttl_files_.empty()) {
+ return nullptr;
+ }
+
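+  // open_ttl_files_ is ordered by the start of the expiration range and then
+  // by file number (see BlobFileComparatorTTL). We probe it with a dummy file
+  // whose range starts at the requested expiration and return the open file
+  // whose expiration range contains it, if any.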
+ std::shared_ptr<BlobFile> tmp = std::make_shared<BlobFile>();
+ tmp->SetHasTTL(true);
+ tmp->expiration_range_ = std::make_pair(expiration, 0);
+ tmp->file_number_ = std::numeric_limits<uint64_t>::max();
+
+ auto citr = open_ttl_files_.equal_range(tmp);
+ if (citr.first == open_ttl_files_.end()) {
+ assert(citr.second == open_ttl_files_.end());
+
+ std::shared_ptr<BlobFile> check = *(open_ttl_files_.rbegin());
+ return (check->expiration_range_.second <= expiration) ? nullptr : check;
+ }
+
+ if (citr.first != citr.second) {
+ return *(citr.first);
+ }
+
+ auto finditr = citr.second;
+ if (finditr != open_ttl_files_.begin()) {
+ --finditr;
+ }
+
+ bool b2 = (*finditr)->expiration_range_.second <= expiration;
+ bool b1 = (*finditr)->expiration_range_.first > expiration;
+
+ return (b1 || b2) ? nullptr : (*finditr);
+}
+
+Status BlobDBImpl::CheckOrCreateWriterLocked(
+ const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<Writer>* writer) {
+ assert(writer != nullptr);
+ *writer = blob_file->GetWriter();
+ if (*writer != nullptr) {
+ return Status::OK();
+ }
+ Status s = CreateWriterLocked(blob_file);
+ if (s.ok()) {
+ *writer = blob_file->GetWriter();
+ }
+ return s;
+}
+
+Status BlobDBImpl::CreateBlobFileAndWriter(
+ bool has_ttl, const ExpirationRange& expiration_range,
+ const std::string& reason, std::shared_ptr<BlobFile>* blob_file,
+ std::shared_ptr<Writer>* writer) {
+ assert(has_ttl == (expiration_range.first || expiration_range.second));
+ assert(blob_file);
+ assert(writer);
+
+ *blob_file = NewBlobFile(has_ttl, expiration_range, reason);
+ assert(*blob_file);
+
+  // The new file is not visible to other threads yet, hence no lock is needed.
+ Status s = CheckOrCreateWriterLocked(*blob_file, writer);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to get writer for blob file: %s, error: %s",
+ (*blob_file)->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ assert(*writer);
+
+ s = (*writer)->WriteHeader((*blob_file)->header_);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to write header to new blob file: %s"
+ " status: '%s'",
+ (*blob_file)->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ (*blob_file)->SetFileSize(BlobLogHeader::kSize);
+ total_blob_size_ += BlobLogHeader::kSize;
+
+ return s;
+}
+
+Status BlobDBImpl::SelectBlobFile(std::shared_ptr<BlobFile>* blob_file) {
+ assert(blob_file);
+
+ {
+ ReadLock rl(&mutex_);
+
+ if (open_non_ttl_file_) {
+ assert(!open_non_ttl_file_->Immutable());
+ *blob_file = open_non_ttl_file_;
+ return Status::OK();
+ }
+ }
+
+ // Check again
+ WriteLock wl(&mutex_);
+
+ if (open_non_ttl_file_) {
+ assert(!open_non_ttl_file_->Immutable());
+ *blob_file = open_non_ttl_file_;
+ return Status::OK();
+ }
+
+ std::shared_ptr<Writer> writer;
+ const Status s = CreateBlobFileAndWriter(
+ /* has_ttl */ false, ExpirationRange(),
+ /* reason */ "SelectBlobFile", blob_file, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ RegisterBlobFile(*blob_file);
+ open_non_ttl_file_ = *blob_file;
+
+ return s;
+}
+
+Status BlobDBImpl::SelectBlobFileTTL(uint64_t expiration,
+ std::shared_ptr<BlobFile>* blob_file) {
+ assert(blob_file);
+ assert(expiration != kNoExpiration);
+
+ {
+ ReadLock rl(&mutex_);
+
+ *blob_file = FindBlobFileLocked(expiration);
+ if (*blob_file != nullptr) {
+ assert(!(*blob_file)->Immutable());
+ return Status::OK();
+ }
+ }
+
+ // Check again
+ WriteLock wl(&mutex_);
+
+ *blob_file = FindBlobFileLocked(expiration);
+ if (*blob_file != nullptr) {
+ assert(!(*blob_file)->Immutable());
+ return Status::OK();
+ }
+
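+  // Each TTL blob file covers a fixed-width expiration window of
+  // ttl_range_secs. For example (hypothetical values), with
+  // ttl_range_secs = 3600 and expiration = 10000, the new file covers the
+  // range [7200, 10800).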
+ const uint64_t exp_low =
+ (expiration / bdb_options_.ttl_range_secs) * bdb_options_.ttl_range_secs;
+ const uint64_t exp_high = exp_low + bdb_options_.ttl_range_secs;
+ const ExpirationRange expiration_range(exp_low, exp_high);
+
+ std::ostringstream oss;
+ oss << "SelectBlobFileTTL range: [" << exp_low << ',' << exp_high << ')';
+
+ std::shared_ptr<Writer> writer;
+ const Status s =
+ CreateBlobFileAndWriter(/* has_ttl */ true, expiration_range,
+ /* reason */ oss.str(), blob_file, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ RegisterBlobFile(*blob_file);
+ open_ttl_files_.insert(*blob_file);
+
+ return s;
+}
+
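+// BlobInserter walks a user WriteBatch through the WriteBatch::Handler
+// interface and builds a rewritten batch in which puts are turned into either
+// inline values or blob index entries (see PutBlobValue); the rewritten batch
+// is what ultimately gets written to the base DB.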
+class BlobDBImpl::BlobInserter : public WriteBatch::Handler {
+ private:
+ const WriteOptions& options_;
+ BlobDBImpl* blob_db_impl_;
+ uint32_t default_cf_id_;
+ WriteBatch batch_;
+
+ public:
+ BlobInserter(const WriteOptions& options, BlobDBImpl* blob_db_impl,
+ uint32_t default_cf_id)
+ : options_(options),
+ blob_db_impl_(blob_db_impl),
+ default_cf_id_(default_cf_id) {}
+
+ WriteBatch* batch() { return &batch_; }
+
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = blob_db_impl_->PutBlobValue(options_, key, value, kNoExpiration,
+ &batch_);
+ return s;
+ }
+
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = WriteBatchInternal::Delete(&batch_, column_family_id, key);
+ return s;
+ }
+
+ virtual Status DeleteRange(uint32_t column_family_id, const Slice& begin_key,
+ const Slice& end_key) {
+ if (column_family_id != default_cf_id_) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+ Status s = WriteBatchInternal::DeleteRange(&batch_, column_family_id,
+ begin_key, end_key);
+ return s;
+ }
+
+ Status SingleDeleteCF(uint32_t /*column_family_id*/,
+ const Slice& /*key*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ Status MergeCF(uint32_t /*column_family_id*/, const Slice& /*key*/,
+ const Slice& /*value*/) override {
+ return Status::NotSupported("Not supported operation in blob db.");
+ }
+
+ void LogData(const Slice& blob) override { batch_.PutLogData(blob); }
+};
+
+Status BlobDBImpl::Write(const WriteOptions& options, WriteBatch* updates) {
+ StopWatch write_sw(env_, statistics_, BLOB_DB_WRITE_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_WRITE);
+ uint32_t default_cf_id =
+ reinterpret_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->GetID();
+ Status s;
+ BlobInserter blob_inserter(options, this, default_cf_id);
+ {
+    // Release write_mutex_ before the DB write to avoid a race condition with
+    // the flush-begin listener, which also requires write_mutex_ to sync
+    // blob files.
+ MutexLock l(&write_mutex_);
+ s = updates->Iterate(&blob_inserter);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ return db_->Write(options, blob_inserter.batch());
+}
+
+Status BlobDBImpl::Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) {
+ return PutUntil(options, key, value, kNoExpiration);
+}
+
+Status BlobDBImpl::PutWithTTL(const WriteOptions& options,
+ const Slice& key, const Slice& value,
+ uint64_t ttl) {
+ uint64_t now = EpochNow();
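+  // Convert the relative TTL into an absolute expiration time, saturating at
+  // kNoExpiration on overflow (e.g., now = 1000 and ttl = 60 yields an
+  // expiration of 1060; hypothetical values).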
+ uint64_t expiration = kNoExpiration - now > ttl ? now + ttl : kNoExpiration;
+ return PutUntil(options, key, value, expiration);
+}
+
+Status BlobDBImpl::PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) {
+ StopWatch write_sw(env_, statistics_, BLOB_DB_WRITE_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_PUT);
+ Status s;
+ WriteBatch batch;
+ {
+    // Release write_mutex_ before the DB write to avoid a race condition with
+    // the flush-begin listener, which also requires write_mutex_ to sync
+    // blob files.
+ MutexLock l(&write_mutex_);
+ s = PutBlobValue(options, key, value, expiration, &batch);
+ }
+ if (s.ok()) {
+ s = db_->Write(options, &batch);
+ }
+ return s;
+}
+
+Status BlobDBImpl::PutBlobValue(const WriteOptions& /*options*/,
+ const Slice& key, const Slice& value,
+ uint64_t expiration, WriteBatch* batch) {
+ write_mutex_.AssertHeld();
+ Status s;
+ std::string index_entry;
+ uint32_t column_family_id =
+ reinterpret_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->GetID();
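+  // Values smaller than min_blob_size stay in the LSM tree: either as a plain
+  // put, or as an inlined blob index when they carry a TTL. Larger values are
+  // appended to a blob file and only the index entry goes into the LSM tree.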
+ if (value.size() < bdb_options_.min_blob_size) {
+ if (expiration == kNoExpiration) {
+ // Put as normal value
+ s = batch->Put(key, value);
+ RecordTick(statistics_, BLOB_DB_WRITE_INLINED);
+ } else {
+ // Inlined with TTL
+ BlobIndex::EncodeInlinedTTL(&index_entry, expiration, value);
+ s = WriteBatchInternal::PutBlobIndex(batch, column_family_id, key,
+ index_entry);
+ RecordTick(statistics_, BLOB_DB_WRITE_INLINED_TTL);
+ }
+ } else {
+ std::string compression_output;
+ Slice value_compressed = GetCompressedSlice(value, &compression_output);
+
+ std::string headerbuf;
+ Writer::ConstructBlobHeader(&headerbuf, key, value_compressed, expiration);
+
+      // Check the DB size limit before selecting a blob file. Since
+      // CheckSizeAndEvictBlobFiles() can close blob files, it needs to be
+      // done before calling SelectBlobFile().
+ s = CheckSizeAndEvictBlobFiles(headerbuf.size() + key.size() +
+ value_compressed.size());
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::shared_ptr<BlobFile> blob_file;
+ if (expiration != kNoExpiration) {
+ s = SelectBlobFileTTL(expiration, &blob_file);
+ } else {
+ s = SelectBlobFile(&blob_file);
+ }
+ if (s.ok()) {
+ assert(blob_file != nullptr);
+ assert(blob_file->GetCompressionType() == bdb_options_.compression);
+ s = AppendBlob(blob_file, headerbuf, key, value_compressed, expiration,
+ &index_entry);
+ }
+ if (s.ok()) {
+ if (expiration != kNoExpiration) {
+ WriteLock file_lock(&blob_file->mutex_);
+ blob_file->ExtendExpirationRange(expiration);
+ }
+ s = CloseBlobFileIfNeeded(blob_file);
+ }
+ if (s.ok()) {
+ s = WriteBatchInternal::PutBlobIndex(batch, column_family_id, key,
+ index_entry);
+ }
+ if (s.ok()) {
+ if (expiration == kNoExpiration) {
+ RecordTick(statistics_, BLOB_DB_WRITE_BLOB);
+ } else {
+ RecordTick(statistics_, BLOB_DB_WRITE_BLOB_TTL);
+ }
+ } else {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to append blob to FILE: %s: KEY: %s VALSZ: %" ROCKSDB_PRIszt
+ " status: '%s' blob_file: '%s'",
+ blob_file->PathName().c_str(), key.ToString().c_str(), value.size(),
+ s.ToString().c_str(), blob_file->DumpState().c_str());
+ }
+ }
+
+ RecordTick(statistics_, BLOB_DB_NUM_KEYS_WRITTEN);
+ RecordTick(statistics_, BLOB_DB_BYTES_WRITTEN, key.size() + value.size());
+ RecordInHistogram(statistics_, BLOB_DB_KEY_SIZE, key.size());
+ RecordInHistogram(statistics_, BLOB_DB_VALUE_SIZE, value.size());
+
+ return s;
+}
+
+Slice BlobDBImpl::GetCompressedSlice(const Slice& raw,
+ std::string* compression_output) const {
+ if (bdb_options_.compression == kNoCompression) {
+ return raw;
+ }
+ StopWatch compression_sw(env_, statistics_, BLOB_DB_COMPRESSION_MICROS);
+ CompressionType type = bdb_options_.compression;
+ CompressionOptions opts;
+ CompressionContext context(type);
+ CompressionInfo info(opts, context, CompressionDict::GetEmptyDict(), type,
+ 0 /* sample_for_compression */);
+ CompressBlock(raw, info, &type, kBlockBasedTableVersionFormat, false,
+ compression_output, nullptr, nullptr);
+ return *compression_output;
+}
+
+Status BlobDBImpl::CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id, std::vector<std::string>* const output_file_names,
+ CompactionJobInfo* compaction_job_info) {
+ // Note: we need CompactionJobInfo to be able to track updates to the
+ // blob file <-> SST mappings, so we provide one if the user hasn't,
+ // assuming that GC is enabled.
+ CompactionJobInfo info{};
+ if (bdb_options_.enable_garbage_collection && !compaction_job_info) {
+ compaction_job_info = &info;
+ }
+
+ const Status s =
+ db_->CompactFiles(compact_options, input_file_names, output_level,
+ output_path_id, output_file_names, compaction_job_info);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (bdb_options_.enable_garbage_collection) {
+ assert(compaction_job_info);
+ ProcessCompactionJobInfo(*compaction_job_info);
+ }
+
+ return s;
+}
+
+void BlobDBImpl::GetCompactionContextCommon(
+ BlobCompactionContext* context) const {
+ assert(context);
+
+ context->next_file_number = next_file_number_.load();
+ context->current_blob_files.clear();
+ for (auto& p : blob_files_) {
+ context->current_blob_files.insert(p.first);
+ }
+ context->fifo_eviction_seq = fifo_eviction_seq_;
+ context->evict_expiration_up_to = evict_expiration_up_to_;
+}
+
+void BlobDBImpl::GetCompactionContext(BlobCompactionContext* context) {
+ assert(context);
+
+ ReadLock l(&mutex_);
+ GetCompactionContextCommon(context);
+}
+
+void BlobDBImpl::GetCompactionContext(BlobCompactionContext* context,
+ BlobCompactionContextGC* context_gc) {
+ assert(context);
+ assert(context_gc);
+
+ ReadLock l(&mutex_);
+ GetCompactionContextCommon(context);
+
+ context_gc->blob_db_impl = this;
+
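+  // The GC cutoff is expressed as a fraction of the live immutable non-TTL
+  // blob files. As a hypothetical example, with 10 such files and
+  // garbage_collection_cutoff = 0.3, the iterator below is advanced past the
+  // 3 oldest files, and only blobs residing in files older than the resulting
+  // cutoff file number are considered for garbage collection.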
+ if (!live_imm_non_ttl_blob_files_.empty()) {
+ auto it = live_imm_non_ttl_blob_files_.begin();
+ std::advance(it, bdb_options_.garbage_collection_cutoff *
+ live_imm_non_ttl_blob_files_.size());
+ context_gc->cutoff_file_number = it != live_imm_non_ttl_blob_files_.end()
+ ? it->first
+ : std::numeric_limits<uint64_t>::max();
+ }
+}
+
+void BlobDBImpl::UpdateLiveSSTSize() {
+ uint64_t live_sst_size = 0;
+ bool ok = GetIntProperty(DB::Properties::kLiveSstFilesSize, &live_sst_size);
+ if (ok) {
+ live_sst_size_.store(live_sst_size);
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Updated total SST file size: %" PRIu64 " bytes.",
+ live_sst_size);
+ } else {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Failed to update total SST file size after flush or compaction.");
+ }
+ {
+ // Trigger FIFO eviction if needed.
+ MutexLock l(&write_mutex_);
+ Status s = CheckSizeAndEvictBlobFiles(0, true /*force*/);
+ if (s.IsNoSpace()) {
+ ROCKS_LOG_WARN(db_options_.info_log,
+ "DB grow out-of-space after SST size updated. Current live"
+ " SST size: %" PRIu64
+ " , current blob files size: %" PRIu64 ".",
+ live_sst_size_.load(), total_blob_size_.load());
+ }
+ }
+}
+
+Status BlobDBImpl::CheckSizeAndEvictBlobFiles(uint64_t blob_size,
+ bool force_evict) {
+ write_mutex_.AssertHeld();
+
+ uint64_t live_sst_size = live_sst_size_.load();
+ if (bdb_options_.max_db_size == 0 ||
+ live_sst_size + total_blob_size_.load() + blob_size <=
+ bdb_options_.max_db_size) {
+ return Status::OK();
+ }
+
+ if (bdb_options_.is_fifo == false ||
+ (!force_evict && live_sst_size + blob_size > bdb_options_.max_db_size)) {
+    // FIFO eviction is disabled, or there would be no space for the new blob
+    // even if we evicted all blob files.
+ return Status::NoSpace(
+ "Write failed, as writing it would exceed max_db_size limit.");
+ }
+
+ std::vector<std::shared_ptr<BlobFile>> candidate_files;
+ CopyBlobFiles(&candidate_files);
+ std::sort(candidate_files.begin(), candidate_files.end(),
+ BlobFileComparator());
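+  // BlobFileComparator orders the candidates by descending file number, so
+  // popping from the back of the vector below evicts the oldest blob file
+  // first (FIFO order).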
+ fifo_eviction_seq_ = GetLatestSequenceNumber();
+
+ WriteLock l(&mutex_);
+
+ while (!candidate_files.empty() &&
+ live_sst_size + total_blob_size_.load() + blob_size >
+ bdb_options_.max_db_size) {
+ std::shared_ptr<BlobFile> blob_file = candidate_files.back();
+ candidate_files.pop_back();
+ WriteLock file_lock(&blob_file->mutex_);
+ if (blob_file->Obsolete()) {
+ // File already obsoleted by someone else.
+ assert(blob_file->Immutable());
+ continue;
+ }
+ // FIFO eviction can evict open blob files.
+ if (!blob_file->Immutable()) {
+ Status s = CloseBlobFile(blob_file);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ assert(blob_file->Immutable());
+ auto expiration_range = blob_file->GetExpirationRange();
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Evict oldest blob file since DB out of space. Current "
+ "live SST file size: %" PRIu64 ", total blob size: %" PRIu64
+ ", max db size: %" PRIu64 ", evicted blob file #%" PRIu64
+ ".",
+ live_sst_size, total_blob_size_.load(),
+ bdb_options_.max_db_size, blob_file->BlobFileNumber());
+ ObsoleteBlobFile(blob_file, fifo_eviction_seq_, true /*update_size*/);
+ evict_expiration_up_to_ = expiration_range.first;
+ RecordTick(statistics_, BLOB_DB_FIFO_NUM_FILES_EVICTED);
+ RecordTick(statistics_, BLOB_DB_FIFO_NUM_KEYS_EVICTED,
+ blob_file->BlobCount());
+ RecordTick(statistics_, BLOB_DB_FIFO_BYTES_EVICTED,
+ blob_file->GetFileSize());
+ TEST_SYNC_POINT("BlobDBImpl::EvictOldestBlobFile:Evicted");
+ }
+ if (live_sst_size + total_blob_size_.load() + blob_size >
+ bdb_options_.max_db_size) {
+ return Status::NoSpace(
+ "Write failed, as writing it would exceed max_db_size limit.");
+ }
+ return Status::OK();
+}
+
+Status BlobDBImpl::AppendBlob(const std::shared_ptr<BlobFile>& bfile,
+ const std::string& headerbuf, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ std::string* index_entry) {
+ Status s;
+ uint64_t blob_offset = 0;
+ uint64_t key_offset = 0;
+ {
+ WriteLock lockbfile_w(&bfile->mutex_);
+ std::shared_ptr<Writer> writer;
+ s = CheckOrCreateWriterLocked(bfile, &writer);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // write the blob to the blob log.
+ s = writer->EmitPhysicalRecord(headerbuf, key, value, &key_offset,
+ &blob_offset);
+ }
+
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Invalid status in AppendBlob: %s status: '%s'",
+ bfile->PathName().c_str(), s.ToString().c_str());
+ return s;
+ }
+
+ uint64_t size_put = headerbuf.size() + key.size() + value.size();
+ bfile->BlobRecordAdded(size_put);
+ total_blob_size_ += size_put;
+
+ if (expiration == kNoExpiration) {
+ BlobIndex::EncodeBlob(index_entry, bfile->BlobFileNumber(), blob_offset,
+ value.size(), bdb_options_.compression);
+ } else {
+ BlobIndex::EncodeBlobTTL(index_entry, expiration, bfile->BlobFileNumber(),
+ blob_offset, value.size(),
+ bdb_options_.compression);
+ }
+
+ return s;
+}
+
+std::vector<Status> BlobDBImpl::MultiGet(
+ const ReadOptions& read_options,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ StopWatch multiget_sw(env_, statistics_, BLOB_DB_MULTIGET_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_MULTIGET);
+  // Get a snapshot to prevent blob files from being deleted between
+  // fetching the index entry and reading from the file.
+ ReadOptions ro(read_options);
+ bool snapshot_created = SetSnapshotIfNeeded(&ro);
+
+ std::vector<Status> statuses;
+ statuses.reserve(keys.size());
+ values->clear();
+ values->reserve(keys.size());
+ PinnableSlice value;
+ for (size_t i = 0; i < keys.size(); i++) {
+ statuses.push_back(Get(ro, DefaultColumnFamily(), keys[i], &value));
+ values->push_back(value.ToString());
+ value.Reset();
+ }
+ if (snapshot_created) {
+ db_->ReleaseSnapshot(ro.snapshot);
+ }
+ return statuses;
+}
+
+bool BlobDBImpl::SetSnapshotIfNeeded(ReadOptions* read_options) {
+ assert(read_options != nullptr);
+ if (read_options->snapshot != nullptr) {
+ return false;
+ }
+ read_options->snapshot = db_->GetSnapshot();
+ return true;
+}
+
+Status BlobDBImpl::GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value, uint64_t* expiration) {
+ assert(value);
+
+ BlobIndex blob_index;
+ Status s = blob_index.DecodeFrom(index_entry);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (blob_index.HasTTL() && blob_index.expiration() <= EpochNow()) {
+ return Status::NotFound("Key expired");
+ }
+
+ if (expiration != nullptr) {
+ if (blob_index.HasTTL()) {
+ *expiration = blob_index.expiration();
+ } else {
+ *expiration = kNoExpiration;
+ }
+ }
+
+ if (blob_index.IsInlined()) {
+ // TODO(yiwu): If index_entry is a PinnableSlice, we can also pin the same
+ // memory buffer to avoid extra copy.
+ value->PinSelf(blob_index.value());
+ return Status::OK();
+ }
+
+ CompressionType compression_type = kNoCompression;
+ s = GetRawBlobFromFile(key, blob_index.file_number(), blob_index.offset(),
+ blob_index.size(), value, &compression_type);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (compression_type != kNoCompression) {
+ BlockContents contents;
+ auto cfh = static_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily());
+
+ {
+ StopWatch decompression_sw(env_, statistics_,
+ BLOB_DB_DECOMPRESSION_MICROS);
+ UncompressionContext context(compression_type);
+ UncompressionInfo info(context, UncompressionDict::GetEmptyDict(),
+ compression_type);
+ s = UncompressBlockContentsForCompressionType(
+ info, value->data(), value->size(), &contents,
+ kBlockBasedTableVersionFormat, *(cfh->cfd()->ioptions()));
+ }
+
+ if (!s.ok()) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Uncompression error during blob read from file: %" PRIu64
+ " blob_offset: %" PRIu64 " blob_size: %" PRIu64
+ " key: %s status: '%s'",
+ blob_index.file_number(), blob_index.offset(), blob_index.size(),
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+ }
+
+ return Status::Corruption("Unable to uncompress blob.");
+ }
+
+ value->PinSelf(contents.data);
+ }
+
+ return Status::OK();
+}
+
+Status BlobDBImpl::GetRawBlobFromFile(const Slice& key, uint64_t file_number,
+ uint64_t offset, uint64_t size,
+ PinnableSlice* value,
+ CompressionType* compression_type) {
+ assert(value);
+ assert(compression_type);
+ assert(*compression_type == kNoCompression);
+
+ if (!size) {
+ value->PinSelf("");
+ return Status::OK();
+ }
+
+  // The offset has to have a certain minimum value, since we will later read
+  // the CRC from the blob header, which also needs to be at a valid offset.
+ if (offset <
+ (BlobLogHeader::kSize + BlobLogRecord::kHeaderSize + key.size())) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Invalid blob index file_number: %" PRIu64
+ " blob_offset: %" PRIu64 " blob_size: %" PRIu64
+ " key: %s",
+ file_number, offset, size,
+ key.ToString(/* output_hex */ true).c_str());
+ }
+
+ return Status::NotFound("Invalid blob offset");
+ }
+
+ std::shared_ptr<BlobFile> blob_file;
+
+ {
+ ReadLock rl(&mutex_);
+ auto it = blob_files_.find(file_number);
+
+ // file was deleted
+ if (it == blob_files_.end()) {
+ return Status::NotFound("Blob Not Found as blob file missing");
+ }
+
+ blob_file = it->second;
+ }
+
+ *compression_type = blob_file->GetCompressionType();
+
+ // takes locks when called
+ std::shared_ptr<RandomAccessFileReader> reader;
+ Status s = GetBlobFileReader(blob_file, &reader);
+ if (!s.ok()) {
+ return s;
+ }
+
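+  // The stored blob index points at the value portion of the record, while the
+  // on-disk layout read here is [CRC (4 bytes)][key][value]. Step back over
+  // the key and the CRC to get the offset and size of that partial record.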
+ assert(offset >= key.size() + sizeof(uint32_t));
+ const uint64_t record_offset = offset - key.size() - sizeof(uint32_t);
+ const uint64_t record_size = sizeof(uint32_t) + key.size() + size;
+
+ // Allocate the buffer. This is safe in C++11
+ std::string buffer_str(static_cast<size_t>(record_size), static_cast<char>(0));
+ char* buffer = &buffer_str[0];
+
+  // A partial blob record contains the checksum, key, and value.
+ Slice blob_record;
+
+ {
+ StopWatch read_sw(env_, statistics_, BLOB_DB_BLOB_FILE_READ_MICROS);
+ s = reader->Read(record_offset, static_cast<size_t>(record_size), &blob_record, buffer);
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_READ, blob_record.size());
+ }
+
+ if (!s.ok()) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Failed to read blob from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key_size: %" ROCKSDB_PRIszt ", status: '%s'",
+ file_number, offset, size, key.size(), s.ToString().c_str());
+ return s;
+ }
+
+ if (blob_record.size() != record_size) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Failed to read blob from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key_size: %" ROCKSDB_PRIszt
+ ", read %" ROCKSDB_PRIszt " bytes, expected %" PRIu64 " bytes",
+ file_number, offset, size, key.size(), blob_record.size(), record_size);
+
+ return Status::Corruption("Failed to retrieve blob from blob index.");
+ }
+
+ Slice crc_slice(blob_record.data(), sizeof(uint32_t));
+ Slice blob_value(blob_record.data() + sizeof(uint32_t) + key.size(),
+ static_cast<size_t>(size));
+
+ uint32_t crc_exp = 0;
+ if (!GetFixed32(&crc_slice, &crc_exp)) {
+ ROCKS_LOG_DEBUG(
+ db_options_.info_log,
+ "Unable to decode CRC from blob file %" PRIu64 ", blob_offset: %" PRIu64
+ ", blob_size: %" PRIu64 ", key size: %" ROCKSDB_PRIszt ", status: '%s'",
+ file_number, offset, size, key.size(), s.ToString().c_str());
+ return Status::Corruption("Unable to decode checksum.");
+ }
+
+ uint32_t crc = crc32c::Value(blob_record.data() + sizeof(uint32_t),
+ blob_record.size() - sizeof(uint32_t));
+ crc = crc32c::Mask(crc); // Adjust for storage
+ if (crc != crc_exp) {
+ if (debug_level_ >= 2) {
+ ROCKS_LOG_ERROR(
+ db_options_.info_log,
+ "Blob crc mismatch file: %" PRIu64 " blob_offset: %" PRIu64
+ " blob_size: %" PRIu64 " key: %s status: '%s'",
+ file_number, offset, size,
+ key.ToString(/* output_hex */ true).c_str(), s.ToString().c_str());
+ }
+
+ return Status::Corruption("Corruption. Blob CRC mismatch");
+ }
+
+ value->PinSelf(blob_value);
+
+ return Status::OK();
+}
+
+Status BlobDBImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) {
+ return Get(read_options, column_family, key, value, nullptr /*expiration*/);
+}
+
+Status BlobDBImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ StopWatch get_sw(env_, statistics_, BLOB_DB_GET_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_GET);
+ return GetImpl(read_options, column_family, key, value, expiration);
+}
+
+Status BlobDBImpl::GetImpl(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration) {
+ if (column_family->GetID() != DefaultColumnFamily()->GetID()) {
+ return Status::NotSupported(
+ "Blob DB doesn't support non-default column family.");
+ }
+  // Get a snapshot to prevent blob files from being deleted between
+  // fetching the index entry and reading from the file.
+ // TODO(yiwu): For Get() retry if file not found would be a simpler strategy.
+ ReadOptions ro(read_options);
+ bool snapshot_created = SetSnapshotIfNeeded(&ro);
+
+ PinnableSlice index_entry;
+ Status s;
+ bool is_blob_index = false;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = &index_entry;
+ get_impl_options.is_blob_index = &is_blob_index;
+ s = db_impl_->GetImpl(ro, key, get_impl_options);
+ if (expiration != nullptr) {
+ *expiration = kNoExpiration;
+ }
+ RecordTick(statistics_, BLOB_DB_NUM_KEYS_READ);
+ if (s.ok()) {
+ if (is_blob_index) {
+ s = GetBlobValue(key, index_entry, value, expiration);
+ } else {
+ // The index entry is the value itself in this case.
+ value->PinSelf(index_entry);
+ }
+ RecordTick(statistics_, BLOB_DB_BYTES_READ, value->size());
+ }
+ if (snapshot_created) {
+ db_->ReleaseSnapshot(ro.snapshot);
+ }
+ return s;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::SanityCheck(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ ReadLock rl(&mutex_);
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Starting Sanity Check");
+ ROCKS_LOG_INFO(db_options_.info_log, "Number of files %" ROCKSDB_PRIszt,
+ blob_files_.size());
+ ROCKS_LOG_INFO(db_options_.info_log, "Number of open files %" ROCKSDB_PRIszt,
+ open_ttl_files_.size());
+
+ for (const auto& blob_file : open_ttl_files_) {
+ (void)blob_file;
+ assert(!blob_file->Immutable());
+ }
+
+ for (const auto& pair : live_imm_non_ttl_blob_files_) {
+ const auto& blob_file = pair.second;
+ (void)blob_file;
+ assert(!blob_file->HasTTL());
+ assert(blob_file->Immutable());
+ }
+
+ uint64_t now = EpochNow();
+
+ for (auto blob_file_pair : blob_files_) {
+ auto blob_file = blob_file_pair.second;
+ char buf[1000];
+ int pos = snprintf(buf, sizeof(buf),
+ "Blob file %" PRIu64 ", size %" PRIu64
+ ", blob count %" PRIu64 ", immutable %d",
+ blob_file->BlobFileNumber(), blob_file->GetFileSize(),
+ blob_file->BlobCount(), blob_file->Immutable());
+ if (blob_file->HasTTL()) {
+ ExpirationRange expiration_range;
+
+ {
+ ReadLock file_lock(&blob_file->mutex_);
+ expiration_range = blob_file->GetExpirationRange();
+ }
+
+ pos += snprintf(buf + pos, sizeof(buf) - pos,
+ ", expiration range (%" PRIu64 ", %" PRIu64 ")",
+ expiration_range.first, expiration_range.second);
+ if (!blob_file->Obsolete()) {
+ pos += snprintf(buf + pos, sizeof(buf) - pos,
+ ", expire in %" PRIu64 " seconds",
+ expiration_range.second - now);
+ }
+ }
+ if (blob_file->Obsolete()) {
+ pos += snprintf(buf + pos, sizeof(buf) - pos, ", obsolete at %" PRIu64,
+ blob_file->GetObsoleteSequence());
+ }
+ snprintf(buf + pos, sizeof(buf) - pos, ".");
+ ROCKS_LOG_INFO(db_options_.info_log, "%s", buf);
+ }
+
+ // reschedule
+ return std::make_pair(true, -1);
+}
+
+Status BlobDBImpl::CloseBlobFile(std::shared_ptr<BlobFile> bfile) {
+ assert(bfile);
+ assert(!bfile->Immutable());
+ assert(!bfile->Obsolete());
+
+ if (bfile->HasTTL() || bfile == open_non_ttl_file_) {
+ write_mutex_.AssertHeld();
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Closing blob file %" PRIu64 ". Path: %s",
+ bfile->BlobFileNumber(), bfile->PathName().c_str());
+
+ const SequenceNumber sequence = GetLatestSequenceNumber();
+
+ const Status s = bfile->WriteFooterAndCloseLocked(sequence);
+
+ if (s.ok()) {
+ total_blob_size_ += BlobLogFooter::kSize;
+ } else {
+ bfile->MarkImmutable(sequence);
+
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to close blob file %" PRIu64 "with error: %s",
+ bfile->BlobFileNumber(), s.ToString().c_str());
+ }
+
+ if (bfile->HasTTL()) {
+ size_t erased __attribute__((__unused__));
+ erased = open_ttl_files_.erase(bfile);
+ } else {
+ if (bfile == open_non_ttl_file_) {
+ open_non_ttl_file_ = nullptr;
+ }
+
+ const uint64_t blob_file_number = bfile->BlobFileNumber();
+ auto it = live_imm_non_ttl_blob_files_.lower_bound(blob_file_number);
+ assert(it == live_imm_non_ttl_blob_files_.end() ||
+ it->first != blob_file_number);
+ live_imm_non_ttl_blob_files_.insert(
+ it, std::map<uint64_t, std::shared_ptr<BlobFile>>::value_type(
+ blob_file_number, bfile));
+ }
+
+ return s;
+}
+
+Status BlobDBImpl::CloseBlobFileIfNeeded(std::shared_ptr<BlobFile>& bfile) {
+ write_mutex_.AssertHeld();
+
+ // atomic read
+ if (bfile->GetFileSize() < bdb_options_.blob_file_size) {
+ return Status::OK();
+ }
+
+ WriteLock lock(&mutex_);
+ WriteLock file_lock(&bfile->mutex_);
+
+ assert(!bfile->Obsolete() || bfile->Immutable());
+ if (bfile->Immutable()) {
+ return Status::OK();
+ }
+
+ return CloseBlobFile(bfile);
+}
+
+void BlobDBImpl::ObsoleteBlobFile(std::shared_ptr<BlobFile> blob_file,
+ SequenceNumber obsolete_seq,
+ bool update_size) {
+ assert(blob_file->Immutable());
+ assert(!blob_file->Obsolete());
+
+ // Should hold write lock of mutex_ or during DB open.
+ blob_file->MarkObsolete(obsolete_seq);
+ obsolete_files_.push_back(blob_file);
+ assert(total_blob_size_.load() >= blob_file->GetFileSize());
+ if (update_size) {
+ total_blob_size_ -= blob_file->GetFileSize();
+ }
+}
+
+bool BlobDBImpl::VisibleToActiveSnapshot(
+ const std::shared_ptr<BlobFile>& bfile) {
+ assert(bfile->Obsolete());
+
+  // We check whether the oldest snapshot is no less than the last sequence
+  // by the time the blob file became obsolete. If so, the blob file is not
+  // visible to any existing snapshot.
+  //
+  // If we kept track of the earliest sequence of the keys in the blob file,
+  // we could instead check whether there is a snapshot that falls in the range
+  // [earliest_sequence, obsolete_sequence). But doing so would make the
+  // implementation more complicated.
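+  //
+  // Hypothetical example: if the file was marked obsolete at sequence 100 and
+  // the oldest live snapshot is at sequence 90, that snapshot may still read
+  // blobs from this file, so it remains "visible"; once the oldest snapshot
+  // reaches 100 or later, the file can safely be deleted.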
+ SequenceNumber obsolete_sequence = bfile->GetObsoleteSequence();
+ SequenceNumber oldest_snapshot = kMaxSequenceNumber;
+ {
+ // Need to lock DBImpl mutex before access snapshot list.
+ InstrumentedMutexLock l(db_impl_->mutex());
+ auto& snapshots = db_impl_->snapshots();
+ if (!snapshots.empty()) {
+ oldest_snapshot = snapshots.oldest()->GetSequenceNumber();
+ }
+ }
+ bool visible = oldest_snapshot < obsolete_sequence;
+ if (visible) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Obsolete blob file %" PRIu64 " (obsolete at %" PRIu64
+ ") visible to oldest snapshot %" PRIu64 ".",
+ bfile->BlobFileNumber(), obsolete_sequence, oldest_snapshot);
+ }
+ return visible;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::EvictExpiredFiles(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:0");
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:1");
+
+ std::vector<std::shared_ptr<BlobFile>> process_files;
+ uint64_t now = EpochNow();
+ {
+ ReadLock rl(&mutex_);
+ for (auto p : blob_files_) {
+ auto& blob_file = p.second;
+ ReadLock file_lock(&blob_file->mutex_);
+ if (blob_file->HasTTL() && !blob_file->Obsolete() &&
+ blob_file->GetExpirationRange().second <= now) {
+ process_files.push_back(blob_file);
+ }
+ }
+ }
+
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:2");
+ TEST_SYNC_POINT("BlobDBImpl::EvictExpiredFiles:3");
+ TEST_SYNC_POINT_CALLBACK("BlobDBImpl::EvictExpiredFiles:cb", nullptr);
+
+ SequenceNumber seq = GetLatestSequenceNumber();
+ {
+ MutexLock l(&write_mutex_);
+ WriteLock lock(&mutex_);
+ for (auto& blob_file : process_files) {
+ WriteLock file_lock(&blob_file->mutex_);
+
+ // Need to double check if the file is obsolete.
+ if (blob_file->Obsolete()) {
+ assert(blob_file->Immutable());
+ continue;
+ }
+
+ if (!blob_file->Immutable()) {
+ CloseBlobFile(blob_file);
+ }
+
+ assert(blob_file->Immutable());
+
+ ObsoleteBlobFile(blob_file, seq, true /*update_size*/);
+ }
+ }
+
+ return std::make_pair(true, -1);
+}
+
+Status BlobDBImpl::SyncBlobFiles() {
+ MutexLock l(&write_mutex_);
+
+ std::vector<std::shared_ptr<BlobFile>> process_files;
+ {
+ ReadLock rl(&mutex_);
+ for (auto fitr : open_ttl_files_) {
+ process_files.push_back(fitr);
+ }
+ if (open_non_ttl_file_ != nullptr) {
+ process_files.push_back(open_non_ttl_file_);
+ }
+ }
+
+ Status s;
+ for (auto& blob_file : process_files) {
+ s = blob_file->Fsync();
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to sync blob file %" PRIu64 ", status: %s",
+ blob_file->BlobFileNumber(), s.ToString().c_str());
+ return s;
+ }
+ }
+
+ s = dir_ent_->Fsync();
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "Failed to sync blob directory, status: %s",
+ s.ToString().c_str());
+ }
+ return s;
+}
+
+std::pair<bool, int64_t> BlobDBImpl::ReclaimOpenFiles(bool aborted) {
+ if (aborted) return std::make_pair(false, -1);
+
+ if (open_file_count_.load() < kOpenFilesTrigger) {
+ return std::make_pair(true, -1);
+ }
+
+ // in the future, we should sort by last_access_
+ // instead of closing every file
+ ReadLock rl(&mutex_);
+ for (auto const& ent : blob_files_) {
+ auto bfile = ent.second;
+ if (bfile->last_access_.load() == -1) continue;
+
+ WriteLock lockbfile_w(&bfile->mutex_);
+ CloseRandomAccessLocked(bfile);
+ }
+
+ return std::make_pair(true, -1);
+}
+
+std::pair<bool, int64_t> BlobDBImpl::DeleteObsoleteFiles(bool aborted) {
+ if (aborted) {
+ return std::make_pair(false, -1);
+ }
+
+ MutexLock delete_file_lock(&delete_file_mutex_);
+ if (disable_file_deletions_ > 0) {
+ return std::make_pair(true, -1);
+ }
+
+ std::list<std::shared_ptr<BlobFile>> tobsolete;
+ {
+ WriteLock wl(&mutex_);
+ if (obsolete_files_.empty()) {
+ return std::make_pair(true, -1);
+ }
+ tobsolete.swap(obsolete_files_);
+ }
+
+ bool file_deleted = false;
+ for (auto iter = tobsolete.begin(); iter != tobsolete.end();) {
+ auto bfile = *iter;
+ {
+ ReadLock lockbfile_r(&bfile->mutex_);
+ if (VisibleToActiveSnapshot(bfile)) {
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Could not delete file due to snapshot failure %s",
+ bfile->PathName().c_str());
+ ++iter;
+ continue;
+ }
+ }
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Will delete file due to snapshot success %s",
+ bfile->PathName().c_str());
+
+ {
+ WriteLock wl(&mutex_);
+ blob_files_.erase(bfile->BlobFileNumber());
+ }
+
+ Status s = DeleteDBFile(&(db_impl_->immutable_db_options()),
+ bfile->PathName(), blob_dir_, true,
+ /*force_fg=*/false);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log,
+ "File failed to be deleted as obsolete %s",
+ bfile->PathName().c_str());
+ ++iter;
+ continue;
+ }
+
+ file_deleted = true;
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "File deleted as obsolete from blob dir %s",
+ bfile->PathName().c_str());
+
+ iter = tobsolete.erase(iter);
+ }
+
+  // The directory contents changed; fsync it.
+ if (file_deleted) {
+ Status s = dir_ent_->Fsync();
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(db_options_.info_log, "Failed to sync dir %s: %s",
+ blob_dir_.c_str(), s.ToString().c_str());
+ }
+ }
+
+  // Put files back into the obsolete list if deletion failed for some reason.
+ if (!tobsolete.empty()) {
+ WriteLock wl(&mutex_);
+ for (auto bfile : tobsolete) {
+ blob_files_.insert(std::make_pair(bfile->BlobFileNumber(), bfile));
+ obsolete_files_.push_front(bfile);
+ }
+ }
+
+ return std::make_pair(!aborted, -1);
+}
+
+void BlobDBImpl::CopyBlobFiles(
+ std::vector<std::shared_ptr<BlobFile>>* bfiles_copy) {
+ ReadLock rl(&mutex_);
+ for (auto const& p : blob_files_) {
+ bfiles_copy->push_back(p.second);
+ }
+}
+
+Iterator* BlobDBImpl::NewIterator(const ReadOptions& read_options) {
+ auto* cfd =
+ reinterpret_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily())->cfd();
+ // Get a snapshot to avoid the blob file being deleted between the time we
+ // fetch the index entry and the time we read the value from the file.
+ ManagedSnapshot* own_snapshot = nullptr;
+ const Snapshot* snapshot = read_options.snapshot;
+ if (snapshot == nullptr) {
+ own_snapshot = new ManagedSnapshot(db_);
+ snapshot = own_snapshot->snapshot();
+ }
+ auto* iter = db_impl_->NewIteratorImpl(
+ read_options, cfd, snapshot->GetSequenceNumber(),
+ nullptr /*read_callback*/, true /*allow_blob*/);
+ return new BlobDBIterator(own_snapshot, iter, this, env_, statistics_);
+}
+
+Status DestroyBlobDB(const std::string& dbname, const Options& options,
+ const BlobDBOptions& bdb_options) {
+ const ImmutableDBOptions soptions(SanitizeOptions(dbname, options));
+ Env* env = soptions.env;
+
+ Status status;
+ std::string blobdir;
+ blobdir = (bdb_options.path_relative) ? dbname + "/" + bdb_options.blob_dir
+ : bdb_options.blob_dir;
+
+ std::vector<std::string> filenames;
+ env->GetChildren(blobdir, &filenames);
+
+ for (const auto& f : filenames) {
+ uint64_t number;
+ FileType type;
+ if (ParseFileName(f, &number, &type) && type == kBlobFile) {
+ Status del = DeleteDBFile(&soptions, blobdir + "/" + f, blobdir, true,
+ /*force_fg=*/false);
+ if (status.ok() && !del.ok()) {
+ status = del;
+ }
+ }
+ }
+ env->DeleteDir(blobdir);
+
+ Status destroy = DestroyDB(dbname, options);
+ if (status.ok() && !destroy.ok()) {
+ status = destroy;
+ }
+
+ return status;
+}
+
+#ifndef NDEBUG
+Status BlobDBImpl::TEST_GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value) {
+ return GetBlobValue(key, index_entry, value);
+}
+
+void BlobDBImpl::TEST_AddDummyBlobFile(uint64_t blob_file_number,
+ SequenceNumber immutable_sequence) {
+ auto blob_file = std::make_shared<BlobFile>(this, blob_dir_, blob_file_number,
+ db_options_.info_log.get());
+ blob_file->MarkImmutable(immutable_sequence);
+
+ blob_files_[blob_file_number] = blob_file;
+ live_imm_non_ttl_blob_files_[blob_file_number] = blob_file;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetBlobFiles() const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> blob_files;
+ for (auto& p : blob_files_) {
+ blob_files.emplace_back(p.second);
+ }
+ return blob_files;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetLiveImmNonTTLFiles()
+ const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> live_imm_non_ttl_files;
+ for (const auto& pair : live_imm_non_ttl_blob_files_) {
+ live_imm_non_ttl_files.emplace_back(pair.second);
+ }
+ return live_imm_non_ttl_files;
+}
+
+std::vector<std::shared_ptr<BlobFile>> BlobDBImpl::TEST_GetObsoleteFiles()
+ const {
+ ReadLock l(&mutex_);
+ std::vector<std::shared_ptr<BlobFile>> obsolete_files;
+ for (auto& bfile : obsolete_files_) {
+ obsolete_files.emplace_back(bfile);
+ }
+ return obsolete_files;
+}
+
+void BlobDBImpl::TEST_DeleteObsoleteFiles() {
+ DeleteObsoleteFiles(false /*abort*/);
+}
+
+Status BlobDBImpl::TEST_CloseBlobFile(std::shared_ptr<BlobFile>& bfile) {
+ MutexLock l(&write_mutex_);
+ WriteLock lock(&mutex_);
+ WriteLock file_lock(&bfile->mutex_);
+
+ return CloseBlobFile(bfile);
+}
+
+void BlobDBImpl::TEST_ObsoleteBlobFile(std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq,
+ bool update_size) {
+ return ObsoleteBlobFile(blob_file, obsolete_seq, update_size);
+}
+
+void BlobDBImpl::TEST_EvictExpiredFiles() {
+ EvictExpiredFiles(false /*abort*/);
+}
+
+uint64_t BlobDBImpl::TEST_live_sst_size() { return live_sst_size_.load(); }
+
+void BlobDBImpl::TEST_InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files) {
+ InitializeBlobFileToSstMapping(live_files);
+}
+
+void BlobDBImpl::TEST_ProcessFlushJobInfo(const FlushJobInfo& info) {
+ ProcessFlushJobInfo(info);
+}
+
+void BlobDBImpl::TEST_ProcessCompactionJobInfo(const CompactionJobInfo& info) {
+ ProcessCompactionJobInfo(info);
+}
+
+#endif // !NDEBUG
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
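For orientation, a minimal sketch of how the DestroyBlobDB() entry point above is typically paired with BlobDB::Open(); the path and option values here are illustrative and error handling is omitted.

    #include <string>

    #include "utilities/blob_db/blob_db.h"

    using ROCKSDB_NAMESPACE::Options;
    using ROCKSDB_NAMESPACE::WriteOptions;
    using ROCKSDB_NAMESPACE::blob_db::BlobDB;
    using ROCKSDB_NAMESPACE::blob_db::BlobDBOptions;
    using ROCKSDB_NAMESPACE::blob_db::DestroyBlobDB;

    int main() {
      const std::string dbname = "/tmp/blob_db_example";  // illustrative path
      Options options;
      options.create_if_missing = true;
      BlobDBOptions bdb_options;
      bdb_options.min_blob_size = 0;  // store every value in a blob file

      BlobDB* db = nullptr;
      BlobDB::Open(options, bdb_options, dbname, &db);
      db->Put(WriteOptions(), "key", "value");
      delete db;  // closes the DB, as in the tests further below

      // Removes the blob directory as well as the base DB.
      DestroyBlobDB(dbname, options, bdb_options);
      return 0;
    }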
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl.h b/src/rocksdb/utilities/blob_db/blob_db_impl.h
new file mode 100644
index 000000000..c1e649cc5
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl.h
@@ -0,0 +1,495 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <condition_variable>
+#include <limits>
+#include <list>
+#include <memory>
+#include <set>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "db/db_iter.h"
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/db.h"
+#include "rocksdb/listener.h"
+#include "rocksdb/options.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/wal_filter.h"
+#include "util/mutexlock.h"
+#include "util/timer_queue.h"
+#include "utilities/blob_db/blob_db.h"
+#include "utilities/blob_db/blob_file.h"
+#include "utilities/blob_db/blob_log_format.h"
+#include "utilities/blob_db/blob_log_reader.h"
+#include "utilities/blob_db/blob_log_writer.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class DBImpl;
+class ColumnFamilyHandle;
+class ColumnFamilyData;
+struct FlushJobInfo;
+
+namespace blob_db {
+
+struct BlobCompactionContext;
+struct BlobCompactionContextGC;
+class BlobDBImpl;
+class BlobFile;
+
+// Comparator to sort TTL-aware blob files based on the lower end of their
+// TTL range.
+struct BlobFileComparatorTTL {
+ bool operator()(const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const;
+};
+
+struct BlobFileComparator {
+ bool operator()(const std::shared_ptr<BlobFile>& lhs,
+ const std::shared_ptr<BlobFile>& rhs) const;
+};
+
+/**
+ * The implementation class for BlobDB. It manages the blob logs, which
+ * are sequentially written files. Blob logs can be of the TTL or non-TTL
+ * varieties; the former are cleaned up when they expire, while the latter
+ * are (optionally) garbage collected.
+ */
+class BlobDBImpl : public BlobDB {
+ friend class BlobFile;
+ friend class BlobDBIterator;
+ friend class BlobDBListener;
+ friend class BlobDBListenerGC;
+ friend class BlobIndexCompactionFilterGC;
+
+ public:
+ // deletions check period
+ static constexpr uint32_t kDeleteCheckPeriodMillisecs = 2 * 1000;
+
+ // sanity check task
+ static constexpr uint32_t kSanityCheckPeriodMillisecs = 20 * 60 * 1000;
+
+ // how many random access open files can we tolerate
+ static constexpr uint32_t kOpenFilesTrigger = 100;
+
+ // how often to schedule reclaim open files.
+ static constexpr uint32_t kReclaimOpenFilesPeriodMillisecs = 1 * 1000;
+
+ // how often to schedule the deletion of obsolete files
+ static constexpr uint32_t kDeleteObsoleteFilesPeriodMillisecs = 10 * 1000;
+
+ // how often to schedule expired files eviction.
+ static constexpr uint32_t kEvictExpiredFilesPeriodMillisecs = 10 * 1000;
+
+ // when should oldest file be evicted:
+ // on reaching 90% of blob_dir_size
+ static constexpr double kEvictOldestFileAtSize = 0.9;
+
+ using BlobDB::Put;
+ Status Put(const WriteOptions& options, const Slice& key,
+ const Slice& value) override;
+
+ using BlobDB::Get;
+ Status Get(const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) override;
+
+ Status Get(const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value,
+ uint64_t* expiration) override;
+
+ using BlobDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& read_options) override;
+
+ using BlobDB::NewIterators;
+ virtual Status NewIterators(
+ const ReadOptions& /*read_options*/,
+ const std::vector<ColumnFamilyHandle*>& /*column_families*/,
+ std::vector<Iterator*>* /*iterators*/) override {
+ return Status::NotSupported("Not implemented");
+ }
+
+ using BlobDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& read_options,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ virtual Status Close() override;
+
+ using BlobDB::PutWithTTL;
+ Status PutWithTTL(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t ttl) override;
+
+ using BlobDB::PutUntil;
+ Status PutUntil(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration) override;
+
+ using BlobDB::CompactFiles;
+ Status CompactFiles(
+ const CompactionOptions& compact_options,
+ const std::vector<std::string>& input_file_names, const int output_level,
+ const int output_path_id = -1,
+ std::vector<std::string>* const output_file_names = nullptr,
+ CompactionJobInfo* compaction_job_info = nullptr) override;
+
+ BlobDBOptions GetBlobDBOptions() const override;
+
+ BlobDBImpl(const std::string& dbname, const BlobDBOptions& bdb_options,
+ const DBOptions& db_options,
+ const ColumnFamilyOptions& cf_options);
+
+ virtual Status DisableFileDeletions() override;
+
+ virtual Status EnableFileDeletions(bool force) override;
+
+ virtual Status GetLiveFiles(std::vector<std::string>&,
+ uint64_t* manifest_file_size,
+ bool flush_memtable = true) override;
+ virtual void GetLiveFilesMetaData(std::vector<LiveFileMetaData>*) override;
+
+ ~BlobDBImpl();
+
+ Status Open(std::vector<ColumnFamilyHandle*>* handles);
+
+ Status SyncBlobFiles() override;
+
+ // Common part of the two GetCompactionContext methods below.
+ // REQUIRES: read lock on mutex_
+ void GetCompactionContextCommon(BlobCompactionContext* context) const;
+
+ void GetCompactionContext(BlobCompactionContext* context);
+ void GetCompactionContext(BlobCompactionContext* context,
+ BlobCompactionContextGC* context_gc);
+
+#ifndef NDEBUG
+ Status TEST_GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value);
+
+ void TEST_AddDummyBlobFile(uint64_t blob_file_number,
+ SequenceNumber immutable_sequence);
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetBlobFiles() const;
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetLiveImmNonTTLFiles() const;
+
+ std::vector<std::shared_ptr<BlobFile>> TEST_GetObsoleteFiles() const;
+
+ Status TEST_CloseBlobFile(std::shared_ptr<BlobFile>& bfile);
+
+ void TEST_ObsoleteBlobFile(std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq = 0,
+ bool update_size = true);
+
+ void TEST_EvictExpiredFiles();
+
+ void TEST_DeleteObsoleteFiles();
+
+ uint64_t TEST_live_sst_size();
+
+ const std::string& TEST_blob_dir() const { return blob_dir_; }
+
+ void TEST_InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files);
+
+ void TEST_ProcessFlushJobInfo(const FlushJobInfo& info);
+
+ void TEST_ProcessCompactionJobInfo(const CompactionJobInfo& info);
+
+#endif // !NDEBUG
+
+ private:
+ class BlobInserter;
+
+ // Create a snapshot if there isn't one in read options.
+ // Return true if a snapshot is created.
+ bool SetSnapshotIfNeeded(ReadOptions* read_options);
+
+ Status GetImpl(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value, uint64_t* expiration = nullptr);
+
+ Status GetBlobValue(const Slice& key, const Slice& index_entry,
+ PinnableSlice* value, uint64_t* expiration = nullptr);
+
+ Status GetRawBlobFromFile(const Slice& key, uint64_t file_number,
+ uint64_t offset, uint64_t size,
+ PinnableSlice* value,
+ CompressionType* compression_type);
+
+ Slice GetCompressedSlice(const Slice& raw,
+ std::string* compression_output) const;
+
+ // Close a file by appending a footer, and remove it from the open files list.
+ // REQUIRES: lock held on write_mutex_, write lock held on both the db mutex_
+ // and the blob file's mutex_. If called on a blob file which is visible only
+ // to a single thread (like in the case of new files written during GC), the
+ // locks on write_mutex_ and the blob file's mutex_ can be avoided.
+ Status CloseBlobFile(std::shared_ptr<BlobFile> bfile);
+
+ // Close a file if its size exceeds blob_file_size
+ // REQUIRES: lock held on write_mutex_.
+ Status CloseBlobFileIfNeeded(std::shared_ptr<BlobFile>& bfile);
+
+ // Mark a file as obsolete and move it to the obsolete file list.
+ //
+ // REQUIRES: write lock on mutex_ held, or called during DB open.
+ void ObsoleteBlobFile(std::shared_ptr<BlobFile> blob_file,
+ SequenceNumber obsolete_seq, bool update_size);
+
+ Status PutBlobValue(const WriteOptions& options, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ WriteBatch* batch);
+
+ Status AppendBlob(const std::shared_ptr<BlobFile>& bfile,
+ const std::string& headerbuf, const Slice& key,
+ const Slice& value, uint64_t expiration,
+ std::string* index_entry);
+
+ // Create a new blob file and associated writer.
+ Status CreateBlobFileAndWriter(bool has_ttl,
+ const ExpirationRange& expiration_range,
+ const std::string& reason,
+ std::shared_ptr<BlobFile>* blob_file,
+ std::shared_ptr<Writer>* writer);
+
+ // Get the open non-TTL blob log file, or create a new one if no such file
+ // exists.
+ Status SelectBlobFile(std::shared_ptr<BlobFile>* blob_file);
+
+ // Get the open TTL blob log file for a certain expiration, or create a new
+ // one if no such file exists.
+ Status SelectBlobFileTTL(uint64_t expiration,
+ std::shared_ptr<BlobFile>* blob_file);
+
+ std::shared_ptr<BlobFile> FindBlobFileLocked(uint64_t expiration) const;
+
+ // periodic sanity check that runs a collection of consistency checks
+ std::pair<bool, int64_t> SanityCheck(bool aborted);
+
+ // Delete files that have been marked obsolete (either because of TTL
+ // or GC). Check whether any snapshots exist which still refer to them.
+ std::pair<bool, int64_t> DeleteObsoleteFiles(bool aborted);
+
+ // periodically check whether the TTLs of open blob files have expired;
+ // if expired, close the sequential writer and make the file immutable
+ std::pair<bool, int64_t> EvictExpiredFiles(bool aborted);
+
+ // if the number of open files approaches the ULIMIT, this task will
+ // close random access readers, which are otherwise kept around for
+ // efficiency
+ std::pair<bool, int64_t> ReclaimOpenFiles(bool aborted);
+
+ std::pair<bool, int64_t> RemoveTimerQ(TimerQueue* tq, bool aborted);
+
+ // Adds the background tasks to the timer queue
+ void StartBackgroundTasks();
+
+ // add a new Blob File
+ std::shared_ptr<BlobFile> NewBlobFile(bool has_ttl,
+ const ExpirationRange& expiration_range,
+ const std::string& reason);
+
+ // Register a new blob file.
+ // REQUIRES: write lock on mutex_.
+ void RegisterBlobFile(std::shared_ptr<BlobFile> blob_file);
+
+ // collect all the blob log files from the blob directory
+ Status GetAllBlobFiles(std::set<uint64_t>* file_numbers);
+
+ // Open all blob files found in blob_dir.
+ Status OpenAllBlobFiles();
+
+ // Link an SST to a blob file. Comes in locking and non-locking varieties
+ // (the latter is used during Open).
+ template <typename Linker>
+ void LinkSstToBlobFileImpl(uint64_t sst_file_number,
+ uint64_t blob_file_number, Linker linker);
+
+ void LinkSstToBlobFile(uint64_t sst_file_number, uint64_t blob_file_number);
+
+ void LinkSstToBlobFileNoLock(uint64_t sst_file_number,
+ uint64_t blob_file_number);
+
+ // Unlink an SST from a blob file.
+ void UnlinkSstFromBlobFile(uint64_t sst_file_number,
+ uint64_t blob_file_number);
+
+ // Initialize the mapping between blob files and SSTs during Open.
+ void InitializeBlobFileToSstMapping(
+ const std::vector<LiveFileMetaData>& live_files);
+
+ // Update the mapping between blob files and SSTs after a flush and mark
+ // any unneeded blob files obsolete.
+ void ProcessFlushJobInfo(const FlushJobInfo& info);
+
+ // Update the mapping between blob files and SSTs after a compaction and
+ // mark any unneeded blob files obsolete.
+ void ProcessCompactionJobInfo(const CompactionJobInfo& info);
+
+ // Mark an immutable non-TTL blob file obsolete assuming it has no more SSTs
+ // linked to it, and all memtables from before the blob file became immutable
+ // have been flushed. Note: should only be called if the condition holds for
+ // all lower-numbered non-TTL blob files as well.
+ bool MarkBlobFileObsoleteIfNeeded(const std::shared_ptr<BlobFile>& blob_file,
+ SequenceNumber obsolete_seq);
+
+ // Mark all immutable non-TTL blob files that aren't needed by any SSTs as
+ // obsolete. Comes in two varieties; the version used during Open need not
+ // worry about locking or snapshots.
+ template <class Functor>
+ void MarkUnreferencedBlobFilesObsoleteImpl(Functor mark_if_needed);
+
+ void MarkUnreferencedBlobFilesObsolete();
+ void MarkUnreferencedBlobFilesObsoleteDuringOpen();
+
+ void UpdateLiveSSTSize();
+
+ Status GetBlobFileReader(const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<RandomAccessFileReader>* reader);
+
+ // Close the random access reader opened above.
+ // REQUIRES: write lock held on the file's mutex_.
+ void CloseRandomAccessLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ // Create a sequential (append) writer for this blob file.
+ // REQUIRES: write lock held on the file's mutex_.
+ Status CreateWriterLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ // Return a Writer object for the file. If a writer is not already
+ // present, create one. Needs the write mutex to be held.
+ Status CheckOrCreateWriterLocked(const std::shared_ptr<BlobFile>& blob_file,
+ std::shared_ptr<Writer>* writer);
+
+ // Check whether any active snapshot still references the blobs in the
+ // file.
+ bool VisibleToActiveSnapshot(const std::shared_ptr<BlobFile>& file);
+ bool FileDeleteOk_SnapshotCheckLocked(const std::shared_ptr<BlobFile>& bfile);
+
+ void CopyBlobFiles(std::vector<std::shared_ptr<BlobFile>>* bfiles_copy);
+
+ uint64_t EpochNow() { return env_->NowMicros() / 1000000; }
+
+ // Check whether inserting a new blob will make the DB grow out of space.
+ // If is_fifo = true, FIFO eviction will be triggered to make room for the
+ // new blob. If force_evict = true, FIFO eviction will evict blob files
+ // even if eviction will not make enough room for the new blob.
+ Status CheckSizeAndEvictBlobFiles(uint64_t blob_size,
+ bool force_evict = false);
+
+ // name of the database directory
+ std::string dbname_;
+
+ // the base DB
+ DBImpl* db_impl_;
+ Env* env_;
+
+ // the options that govern the behavior of Blob Storage
+ BlobDBOptions bdb_options_;
+ DBOptions db_options_;
+ ColumnFamilyOptions cf_options_;
+ EnvOptions env_options_;
+
+ // Raw pointer of statistic. db_options_ has a std::shared_ptr to hold
+ // ownership.
+ Statistics* statistics_;
+
+ // by default this is "blob_dir" under dbname_
+ // but can be configured
+ std::string blob_dir_;
+
+ // pointer to directory
+ std::unique_ptr<Directory> dir_ent_;
+
+ // Read Write Mutex, which protects all the data structures
+ // HEAVILY TRAFFICKED
+ mutable port::RWMutex mutex_;
+
+ // Writers have to hold write_mutex_ before writing.
+ mutable port::Mutex write_mutex_;
+
+ // counter for blob file number
+ std::atomic<uint64_t> next_file_number_;
+
+ // in-memory metadata of all the blob files
+ std::map<uint64_t, std::shared_ptr<BlobFile>> blob_files_;
+
+ // All live immutable non-TTL blob files.
+ std::map<uint64_t, std::shared_ptr<BlobFile>> live_imm_non_ttl_blob_files_;
+
+ // The largest sequence number that has been flushed.
+ SequenceNumber flush_sequence_;
+
+ // opened non-TTL blob file.
+ std::shared_ptr<BlobFile> open_non_ttl_file_;
+
+ // all the blob files which are currently being appended to, based
+ // on the variety of incoming TTLs
+ std::set<std::shared_ptr<BlobFile>, BlobFileComparatorTTL> open_ttl_files_;
+
+ // Flag to check whether Close() has been called on this DB
+ bool closed_;
+
+ // timer based queue to execute tasks
+ TimerQueue tqueue_;
+
+ // number of files opened for random access/GET; the counter is used to
+ // monitor and close excess random access files.
+ std::atomic<uint32_t> open_file_count_;
+
+ // Total size of all live blob files (i.e. excluding obsolete files).
+ std::atomic<uint64_t> total_blob_size_;
+
+ // total size of SST files.
+ std::atomic<uint64_t> live_sst_size_;
+
+ // Latest FIFO eviction timestamp
+ //
+ // REQUIRES: access with mutex_ lock held.
+ uint64_t fifo_eviction_seq_;
+
+ // The expiration up to which latest FIFO eviction evicts.
+ //
+ // REQUIRES: access with mutex_ lock held.
+ uint64_t evict_expiration_up_to_;
+
+ std::list<std::shared_ptr<BlobFile>> obsolete_files_;
+
+ // DeleteObsoleteFiles, DisableFileDeletions and EnableFileDeletions block
+ // on this mutex to avoid contention.
+ //
+ // While DeleteObsoleteFiles holds both mutex_ and delete_file_mutex_, note
+ // the difference: mutex_ only needs to be held while accessing the
+ // data structures, whereas delete_file_mutex_ needs to be held for the whole
+ // duration of DeleteObsoleteFiles to avoid running it simultaneously with
+ // DisableFileDeletions.
+ //
+ // If both mutex_ and delete_file_mutex_ need to be held, it is advised to
+ // acquire delete_file_mutex_ first to avoid deadlock.
+ mutable port::Mutex delete_file_mutex_;
+
+ // Each call of DisableFileDeletions will increase disable_file_deletions_
+ // by 1. EnableFileDeletions will either decrease the count by 1 or reset
+ // it to zero, depending on the force flag.
+ //
+ // REQUIRES: access with delete_file_mutex_ held.
+ int disable_file_deletions_ = 0;
+
+ uint32_t debug_level_;
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
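To make the TTL-related declarations above concrete, here is a minimal usage sketch of the PutWithTTL/PutUntil/Get-with-expiration interface; the keys, values, and timestamps are illustrative only.

    #include <cstdint>

    #include "utilities/blob_db/blob_db.h"

    using ROCKSDB_NAMESPACE::PinnableSlice;
    using ROCKSDB_NAMESPACE::ReadOptions;
    using ROCKSDB_NAMESPACE::Status;
    using ROCKSDB_NAMESPACE::WriteOptions;
    using ROCKSDB_NAMESPACE::blob_db::BlobDB;

    void TtlExample(BlobDB* db) {
      // A value that expires 60 seconds from now (relative TTL) ...
      db->PutWithTTL(WriteOptions(), "session", "state", /*ttl=*/60);
      // ... and one that expires at an absolute time (seconds since epoch).
      db->PutUntil(WriteOptions(), "token", "abc", /*expiration=*/1893456000);

      // Read a value back together with its expiration, if any.
      PinnableSlice value;
      uint64_t expiration = 0;
      Status s = db->Get(ReadOptions(), db->DefaultColumnFamily(), "session",
                         &value, &expiration);
      // On success, `expiration` is the absolute expiry time, or kNoExpiration
      // for keys written without a TTL.
      (void)s;
    }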
diff --git a/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc b/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc
new file mode 100644
index 000000000..168c7ce9d
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_impl_filesnapshot.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_db_impl.h"
+
+#include "file/filename.h"
+#include "logging/logging.h"
+#include "util/mutexlock.h"
+
+// BlobDBImpl methods to get snapshot of files, e.g. for replication.
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Status BlobDBImpl::DisableFileDeletions() {
+ // Disable base DB file deletions.
+ Status s = db_impl_->DisableFileDeletions();
+ if (!s.ok()) {
+ return s;
+ }
+
+ int count = 0;
+ {
+ // Hold delete_file_mutex_ to make sure no DeleteObsoleteFiles job
+ // is running.
+ MutexLock l(&delete_file_mutex_);
+ count = ++disable_file_deletions_;
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log,
+ "Disalbed blob file deletions. count: %d", count);
+ return Status::OK();
+}
+
+Status BlobDBImpl::EnableFileDeletions(bool force) {
+ // Enable base DB file deletions.
+ Status s = db_impl_->EnableFileDeletions(force);
+ if (!s.ok()) {
+ return s;
+ }
+
+ int count = 0;
+ {
+ MutexLock l(&delete_file_mutex_);
+ if (force) {
+ disable_file_deletions_ = 0;
+ } else if (disable_file_deletions_ > 0) {
+ count = --disable_file_deletions_;
+ }
+ assert(count >= 0);
+ }
+
+ ROCKS_LOG_INFO(db_options_.info_log, "Enabled blob file deletions. count: %d",
+ count);
+ // Consider triggering DeleteObsoleteFiles once after file deletions are
+ // re-enabled, if we make the DeleteObsoleteFiles re-run interval configurable.
+ return Status::OK();
+}
+
+Status BlobDBImpl::GetLiveFiles(std::vector<std::string>& ret,
+ uint64_t* manifest_file_size,
+ bool flush_memtable) {
+ if (!bdb_options_.path_relative) {
+ return Status::NotSupported(
+ "Not able to get relative blob file path from absolute blob_dir.");
+ }
+ // Hold a lock in the beginning to avoid updates to base DB during the call
+ ReadLock rl(&mutex_);
+ Status s = db_->GetLiveFiles(ret, manifest_file_size, flush_memtable);
+ if (!s.ok()) {
+ return s;
+ }
+ ret.reserve(ret.size() + blob_files_.size());
+ for (auto bfile_pair : blob_files_) {
+ auto blob_file = bfile_pair.second;
+ // Path should be relative to db_name, but begin with slash.
+ ret.emplace_back(
+ BlobFileName("", bdb_options_.blob_dir, blob_file->BlobFileNumber()));
+ }
+ return Status::OK();
+}
+
+void BlobDBImpl::GetLiveFilesMetaData(std::vector<LiveFileMetaData>* metadata) {
+ // Path should be relative to db_name.
+ assert(bdb_options_.path_relative);
+ // Hold a lock in the beginning to avoid updates to base DB during the call
+ ReadLock rl(&mutex_);
+ db_->GetLiveFilesMetaData(metadata);
+ for (auto bfile_pair : blob_files_) {
+ auto blob_file = bfile_pair.second;
+ LiveFileMetaData filemetadata;
+ filemetadata.size = static_cast<size_t>(blob_file->GetFileSize());
+ const uint64_t file_number = blob_file->BlobFileNumber();
+ // Path should be relative to db_name, but begin with slash.
+ filemetadata.name = BlobFileName("", bdb_options_.blob_dir, file_number);
+ filemetadata.file_number = file_number;
+ auto cfh = reinterpret_cast<ColumnFamilyHandleImpl*>(DefaultColumnFamily());
+ filemetadata.column_family_name = cfh->GetName();
+ metadata->emplace_back(filemetadata);
+ }
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
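A rough sketch of how a replication or backup tool might drive the file-snapshot methods above; error handling is simplified and the function name is illustrative.

    #include <string>
    #include <vector>

    #include "utilities/blob_db/blob_db.h"

    using ROCKSDB_NAMESPACE::Status;
    using ROCKSDB_NAMESPACE::blob_db::BlobDB;

    Status CollectLiveFiles(BlobDB* db, std::vector<std::string>* files) {
      // Stop base DB and blob file deletions while the file list is collected.
      Status s = db->DisableFileDeletions();
      if (!s.ok()) {
        return s;
      }

      uint64_t manifest_file_size = 0;
      s = db->GetLiveFiles(*files, &manifest_file_size, /*flush_memtable=*/true);
      // The returned names are relative to the DB directory and include the
      // blob files; this requires BlobDBOptions::path_relative as enforced above.

      // Re-enable deletions regardless of the outcome of GetLiveFiles().
      Status enable = db->EnableFileDeletions(/*force=*/false);
      return s.ok() ? enable : s;
    }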
diff --git a/src/rocksdb/utilities/blob_db/blob_db_iterator.h b/src/rocksdb/utilities/blob_db/blob_db_iterator.h
new file mode 100644
index 000000000..af07117eb
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_iterator.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "db/arena_wrapped_db_iter.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/iterator.h"
+#include "util/stop_watch.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+using ROCKSDB_NAMESPACE::ManagedSnapshot;
+
+class BlobDBIterator : public Iterator {
+ public:
+ BlobDBIterator(ManagedSnapshot* snapshot, ArenaWrappedDBIter* iter,
+ BlobDBImpl* blob_db, Env* env, Statistics* statistics)
+ : snapshot_(snapshot),
+ iter_(iter),
+ blob_db_(blob_db),
+ env_(env),
+ statistics_(statistics) {}
+
+ virtual ~BlobDBIterator() = default;
+
+ bool Valid() const override {
+ if (!iter_->Valid()) {
+ return false;
+ }
+ return status_.ok();
+ }
+
+ Status status() const override {
+ if (!iter_->status().ok()) {
+ return iter_->status();
+ }
+ return status_;
+ }
+
+ void SeekToFirst() override {
+ StopWatch seek_sw(env_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekToFirst();
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void SeekToLast() override {
+ StopWatch seek_sw(env_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekToLast();
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ void Seek(const Slice& target) override {
+ StopWatch seek_sw(env_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->Seek(target);
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void SeekForPrev(const Slice& target) override {
+ StopWatch seek_sw(env_, statistics_, BLOB_DB_SEEK_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_SEEK);
+ iter_->SeekForPrev(target);
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ void Next() override {
+ assert(Valid());
+ StopWatch next_sw(env_, statistics_, BLOB_DB_NEXT_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_NEXT);
+ iter_->Next();
+ while (UpdateBlobValue()) {
+ iter_->Next();
+ }
+ }
+
+ void Prev() override {
+ assert(Valid());
+ StopWatch prev_sw(env_, statistics_, BLOB_DB_PREV_MICROS);
+ RecordTick(statistics_, BLOB_DB_NUM_PREV);
+ iter_->Prev();
+ while (UpdateBlobValue()) {
+ iter_->Prev();
+ }
+ }
+
+ Slice key() const override {
+ assert(Valid());
+ return iter_->key();
+ }
+
+ Slice value() const override {
+ assert(Valid());
+ if (!iter_->IsBlob()) {
+ return iter_->value();
+ }
+ return value_;
+ }
+
+ // Iterator::Refresh() not supported.
+
+ private:
+ // Return true if caller should continue to next value.
+ bool UpdateBlobValue() {
+ value_.Reset();
+ status_ = Status::OK();
+ if (iter_->Valid() && iter_->status().ok() && iter_->IsBlob()) {
+ Status s = blob_db_->GetBlobValue(iter_->key(), iter_->value(), &value_);
+ if (s.IsNotFound()) {
+ return true;
+ } else {
+ if (!s.ok()) {
+ status_ = s;
+ }
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+
+ std::unique_ptr<ManagedSnapshot> snapshot_;
+ std::unique_ptr<ArenaWrappedDBIter> iter_;
+ BlobDBImpl* blob_db_;
+ Env* env_;
+ Statistics* statistics_;
+ Status status_;
+ PinnableSlice value_;
+};
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
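Since BlobDBIterator resolves blob references transparently, client code can scan a BlobDB like a plain RocksDB instance; a minimal sketch (names illustrative):

    #include <cassert>
    #include <memory>

    #include "utilities/blob_db/blob_db.h"

    using ROCKSDB_NAMESPACE::Iterator;
    using ROCKSDB_NAMESPACE::ReadOptions;
    using ROCKSDB_NAMESPACE::Slice;
    using ROCKSDB_NAMESPACE::blob_db::BlobDB;

    void ScanAll(BlobDB* db) {
      std::unique_ptr<Iterator> it(db->NewIterator(ReadOptions()));
      for (it->SeekToFirst(); it->Valid(); it->Next()) {
        Slice key = it->key();
        Slice value = it->value();  // already the resolved blob value
        (void)key;
        (void)value;
      }
      // status() surfaces blob resolution errors in addition to base DB errors.
      assert(it->status().ok());
    }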
diff --git a/src/rocksdb/utilities/blob_db/blob_db_listener.h b/src/rocksdb/utilities/blob_db/blob_db_listener.h
new file mode 100644
index 000000000..c26d7bd27
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_listener.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+
+#include "rocksdb/listener.h"
+#include "util/mutexlock.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBListener : public EventListener {
+ public:
+ explicit BlobDBListener(BlobDBImpl* blob_db_impl)
+ : blob_db_impl_(blob_db_impl) {}
+
+ void OnFlushBegin(DB* /*db*/, const FlushJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->SyncBlobFiles();
+ }
+
+ void OnFlushCompleted(DB* /*db*/, const FlushJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->UpdateLiveSSTSize();
+ }
+
+ void OnCompactionCompleted(DB* /*db*/,
+ const CompactionJobInfo& /*info*/) override {
+ assert(blob_db_impl_ != nullptr);
+ blob_db_impl_->UpdateLiveSSTSize();
+ }
+
+ protected:
+ BlobDBImpl* blob_db_impl_;
+};
+
+class BlobDBListenerGC : public BlobDBListener {
+ public:
+ explicit BlobDBListenerGC(BlobDBImpl* blob_db_impl)
+ : BlobDBListener(blob_db_impl) {}
+
+ void OnFlushCompleted(DB* db, const FlushJobInfo& info) override {
+ BlobDBListener::OnFlushCompleted(db, info);
+
+ assert(blob_db_impl_);
+ blob_db_impl_->ProcessFlushJobInfo(info);
+ }
+
+ void OnCompactionCompleted(DB* db, const CompactionJobInfo& info) override {
+ BlobDBListener::OnCompactionCompleted(db, info);
+
+ assert(blob_db_impl_);
+ blob_db_impl_->ProcessCompactionJobInfo(info);
+ }
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
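The listener above uses the standard EventListener hooks; for comparison, a minimal application-side listener that merely counts flushes might look like the sketch below (registration via DBOptions::listeners is assumed to be the usual RocksDB mechanism).

    #include <atomic>
    #include <cstdint>

    #include "rocksdb/listener.h"

    using ROCKSDB_NAMESPACE::DB;
    using ROCKSDB_NAMESPACE::EventListener;
    using ROCKSDB_NAMESPACE::FlushJobInfo;

    class FlushCounter : public EventListener {
     public:
      void OnFlushCompleted(DB* /*db*/, const FlushJobInfo& /*info*/) override {
        flushes_.fetch_add(1, std::memory_order_relaxed);
      }

      uint64_t flushes() const {
        return flushes_.load(std::memory_order_relaxed);
      }

     private:
      std::atomic<uint64_t> flushes_{0};
    };

    // Registration (illustrative):
    //   options.listeners.push_back(std::make_shared<FlushCounter>());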
diff --git a/src/rocksdb/utilities/blob_db/blob_db_test.cc b/src/rocksdb/utilities/blob_db/blob_db_test.cc
new file mode 100644
index 000000000..9fee52e8c
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_db_test.cc
@@ -0,0 +1,1992 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include <iomanip>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "db/blob_index.h"
+#include "db/db_test_util.h"
+#include "env/composite_env_wrapper.h"
+#include "file/file_util.h"
+#include "file/sst_file_manager_impl.h"
+#include "port/port.h"
+#include "rocksdb/utilities/debug.h"
+#include "test_util/fault_injection_test_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "util/cast_util.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/blob_db/blob_db.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBTest : public testing::Test {
+ public:
+ const int kMaxBlobSize = 1 << 14;
+
+ struct BlobIndexVersion {
+ BlobIndexVersion() = default;
+ BlobIndexVersion(std::string _user_key, uint64_t _file_number,
+ uint64_t _expiration, SequenceNumber _sequence,
+ ValueType _type)
+ : user_key(std::move(_user_key)),
+ file_number(_file_number),
+ expiration(_expiration),
+ sequence(_sequence),
+ type(_type) {}
+
+ std::string user_key;
+ uint64_t file_number = kInvalidBlobFileNumber;
+ uint64_t expiration = kNoExpiration;
+ SequenceNumber sequence = 0;
+ ValueType type = kTypeValue;
+ };
+
+ BlobDBTest()
+ : dbname_(test::PerThreadDBPath("blob_db_test")),
+ mock_env_(new MockTimeEnv(Env::Default())),
+ fault_injection_env_(new FaultInjectionTestEnv(Env::Default())),
+ blob_db_(nullptr) {
+ Status s = DestroyBlobDB(dbname_, Options(), BlobDBOptions());
+ assert(s.ok());
+ }
+
+ ~BlobDBTest() override {
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+ Destroy();
+ }
+
+ Status TryOpen(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ options.create_if_missing = true;
+ return BlobDB::Open(options, bdb_options, dbname_, &blob_db_);
+ }
+
+ void Open(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ ASSERT_OK(TryOpen(bdb_options, options));
+ }
+
+ void Reopen(BlobDBOptions bdb_options = BlobDBOptions(),
+ Options options = Options()) {
+ assert(blob_db_ != nullptr);
+ delete blob_db_;
+ blob_db_ = nullptr;
+ Open(bdb_options, options);
+ }
+
+ void Close() {
+ assert(blob_db_ != nullptr);
+ delete blob_db_;
+ blob_db_ = nullptr;
+ }
+
+ void Destroy() {
+ if (blob_db_) {
+ Options options = blob_db_->GetOptions();
+ BlobDBOptions bdb_options = blob_db_->GetBlobDBOptions();
+ delete blob_db_;
+ blob_db_ = nullptr;
+ ASSERT_OK(DestroyBlobDB(dbname_, options, bdb_options));
+ }
+ }
+
+ BlobDBImpl *blob_db_impl() {
+ return reinterpret_cast<BlobDBImpl *>(blob_db_);
+ }
+
+ Status Put(const Slice &key, const Slice &value,
+ std::map<std::string, std::string> *data = nullptr) {
+ Status s = blob_db_->Put(WriteOptions(), key, value);
+ if (data != nullptr) {
+ (*data)[key.ToString()] = value.ToString();
+ }
+ return s;
+ }
+
+ void Delete(const std::string &key,
+ std::map<std::string, std::string> *data = nullptr) {
+ ASSERT_OK(blob_db_->Delete(WriteOptions(), key));
+ if (data != nullptr) {
+ data->erase(key);
+ }
+ }
+
+ Status PutWithTTL(const Slice &key, const Slice &value, uint64_t ttl,
+ std::map<std::string, std::string> *data = nullptr) {
+ Status s = blob_db_->PutWithTTL(WriteOptions(), key, value, ttl);
+ if (data != nullptr) {
+ (*data)[key.ToString()] = value.ToString();
+ }
+ return s;
+ }
+
+ Status PutUntil(const Slice &key, const Slice &value, uint64_t expiration) {
+ return blob_db_->PutUntil(WriteOptions(), key, value, expiration);
+ }
+
+ void PutRandomWithTTL(const std::string &key, uint64_t ttl, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = test::RandomHumanReadableString(rnd, len);
+ ASSERT_OK(
+ blob_db_->PutWithTTL(WriteOptions(), Slice(key), Slice(value), ttl));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandomUntil(const std::string &key, uint64_t expiration, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = test::RandomHumanReadableString(rnd, len);
+ ASSERT_OK(blob_db_->PutUntil(WriteOptions(), Slice(key), Slice(value),
+ expiration));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandom(const std::string &key, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ PutRandom(blob_db_, key, rnd, data);
+ }
+
+ void PutRandom(DB *db, const std::string &key, Random *rnd,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = test::RandomHumanReadableString(rnd, len);
+ ASSERT_OK(db->Put(WriteOptions(), Slice(key), Slice(value)));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ void PutRandomToWriteBatch(
+ const std::string &key, Random *rnd, WriteBatch *batch,
+ std::map<std::string, std::string> *data = nullptr) {
+ int len = rnd->Next() % kMaxBlobSize + 1;
+ std::string value = test::RandomHumanReadableString(rnd, len);
+ ASSERT_OK(batch->Put(key, value));
+ if (data != nullptr) {
+ (*data)[key] = value;
+ }
+ }
+
+ // Verify that the blob db contains the expected data and nothing more.
+ void VerifyDB(const std::map<std::string, std::string> &data) {
+ VerifyDB(blob_db_, data);
+ }
+
+ void VerifyDB(DB *db, const std::map<std::string, std::string> &data) {
+ // Verify normal Get
+ auto* cfh = db->DefaultColumnFamily();
+ for (auto &p : data) {
+ PinnableSlice value_slice;
+ ASSERT_OK(db->Get(ReadOptions(), cfh, p.first, &value_slice));
+ ASSERT_EQ(p.second, value_slice.ToString());
+ std::string value;
+ ASSERT_OK(db->Get(ReadOptions(), cfh, p.first, &value));
+ ASSERT_EQ(p.second, value);
+ }
+
+ // Verify iterators
+ Iterator *iter = db->NewIterator(ReadOptions());
+ iter->SeekToFirst();
+ for (auto &p : data) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(p.first, iter->key().ToString());
+ ASSERT_EQ(p.second, iter->value().ToString());
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+ ASSERT_OK(iter->status());
+ delete iter;
+ }
+
+ void VerifyBaseDB(
+ const std::map<std::string, KeyVersion> &expected_versions) {
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ DB *db = blob_db_->GetRootDB();
+ const size_t kMaxKeys = 10000;
+ std::vector<KeyVersion> versions;
+ GetAllKeyVersions(db, "", "", kMaxKeys, &versions);
+ ASSERT_EQ(expected_versions.size(), versions.size());
+ size_t i = 0;
+ for (auto &key_version : expected_versions) {
+ const KeyVersion &expected_version = key_version.second;
+ ASSERT_EQ(expected_version.user_key, versions[i].user_key);
+ ASSERT_EQ(expected_version.sequence, versions[i].sequence);
+ ASSERT_EQ(expected_version.type, versions[i].type);
+ if (versions[i].type == kTypeValue) {
+ ASSERT_EQ(expected_version.value, versions[i].value);
+ } else {
+ ASSERT_EQ(kTypeBlobIndex, versions[i].type);
+ PinnableSlice value;
+ ASSERT_OK(bdb_impl->TEST_GetBlobValue(versions[i].user_key,
+ versions[i].value, &value));
+ ASSERT_EQ(expected_version.value, value.ToString());
+ }
+ i++;
+ }
+ }
+
+ void VerifyBaseDBBlobIndex(
+ const std::map<std::string, BlobIndexVersion> &expected_versions) {
+ const size_t kMaxKeys = 10000;
+ std::vector<KeyVersion> versions;
+ ASSERT_OK(
+ GetAllKeyVersions(blob_db_->GetRootDB(), "", "", kMaxKeys, &versions));
+ ASSERT_EQ(versions.size(), expected_versions.size());
+
+ size_t i = 0;
+ for (const auto &expected_pair : expected_versions) {
+ const BlobIndexVersion &expected_version = expected_pair.second;
+
+ ASSERT_EQ(versions[i].user_key, expected_version.user_key);
+ ASSERT_EQ(versions[i].sequence, expected_version.sequence);
+ ASSERT_EQ(versions[i].type, expected_version.type);
+ if (versions[i].type != kTypeBlobIndex) {
+ ASSERT_EQ(kInvalidBlobFileNumber, expected_version.file_number);
+ ASSERT_EQ(kNoExpiration, expected_version.expiration);
+
+ ++i;
+ continue;
+ }
+
+ BlobIndex blob_index;
+ ASSERT_OK(blob_index.DecodeFrom(versions[i].value));
+
+ const uint64_t file_number = !blob_index.IsInlined()
+ ? blob_index.file_number()
+ : kInvalidBlobFileNumber;
+ ASSERT_EQ(file_number, expected_version.file_number);
+
+ const uint64_t expiration =
+ blob_index.HasTTL() ? blob_index.expiration() : kNoExpiration;
+ ASSERT_EQ(expiration, expected_version.expiration);
+
+ ++i;
+ }
+ }
+
+ void InsertBlobs() {
+ WriteOptions wo;
+ std::string value;
+
+ Random rnd(301);
+ for (size_t i = 0; i < 100000; i++) {
+ uint64_t ttl = rnd.Next() % 86400;
+ PutRandomWithTTL("key" + ToString(i % 500), ttl, &rnd, nullptr);
+ }
+
+ for (size_t i = 0; i < 10; i++) {
+ Delete("key" + ToString(i % 500));
+ }
+ }
+
+ const std::string dbname_;
+ std::unique_ptr<MockTimeEnv> mock_env_;
+ std::unique_ptr<FaultInjectionTestEnv> fault_injection_env_;
+ BlobDB *blob_db_;
+}; // class BlobDBTest
+
+TEST_F(BlobDBTest, Put) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + ToString(i), &rnd, &data);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, PutWithTTL) {
+ Random rnd(301);
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 1000;
+ bdb_options.min_blob_size = 0;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ mock_env_->set_current_time(50);
+ for (size_t i = 0; i < 100; i++) {
+ uint64_t ttl = rnd.Next() % 100;
+ PutRandomWithTTL("key" + ToString(i), ttl, &rnd,
+ (ttl <= 50 ? nullptr : &data));
+ }
+ mock_env_->set_current_time(100);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->HasTTL());
+ ASSERT_OK(bdb_impl->TEST_CloseBlobFile(blob_files[0]));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, PutUntil) {
+ Random rnd(301);
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 1000;
+ bdb_options.min_blob_size = 0;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ mock_env_->set_current_time(50);
+ for (size_t i = 0; i < 100; i++) {
+ uint64_t expiration = rnd.Next() % 100 + 50;
+ PutRandomUntil("key" + ToString(i), expiration, &rnd,
+ (expiration <= 100 ? nullptr : &data));
+ }
+ mock_env_->set_current_time(100);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->HasTTL());
+ ASSERT_OK(bdb_impl->TEST_CloseBlobFile(blob_files[0]));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, StackableDBGet) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + ToString(i), &rnd, &data);
+ }
+ for (size_t i = 0; i < 100; i++) {
+ StackableDB *db = blob_db_;
+ ColumnFamilyHandle *column_family = db->DefaultColumnFamily();
+ std::string key = "key" + ToString(i);
+ PinnableSlice pinnable_value;
+ ASSERT_OK(db->Get(ReadOptions(), column_family, key, &pinnable_value));
+ std::string string_value;
+ ASSERT_OK(db->Get(ReadOptions(), column_family, key, &string_value));
+ ASSERT_EQ(string_value, pinnable_value.ToString());
+ ASSERT_EQ(string_value, data[key]);
+ }
+}
+
+TEST_F(BlobDBTest, GetExpiration) {
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.disable_background_tasks = true;
+ mock_env_->set_current_time(100);
+ Open(bdb_options, options);
+ Put("key1", "value1");
+ PutWithTTL("key2", "value2", 200);
+ PinnableSlice value;
+ uint64_t expiration;
+ ASSERT_OK(blob_db_->Get(ReadOptions(), "key1", &value, &expiration));
+ ASSERT_EQ("value1", value.ToString());
+ ASSERT_EQ(kNoExpiration, expiration);
+ ASSERT_OK(blob_db_->Get(ReadOptions(), "key2", &value, &expiration));
+ ASSERT_EQ("value2", value.ToString());
+ ASSERT_EQ(300 /* = 100 + 200 */, expiration);
+}
+
+TEST_F(BlobDBTest, GetIOError) {
+ Options options;
+ options.env = fault_injection_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0; // Make sure values are written to blob files
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ ColumnFamilyHandle *column_family = blob_db_->DefaultColumnFamily();
+ PinnableSlice value;
+ ASSERT_OK(Put("foo", "bar"));
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ Status s = blob_db_->Get(ReadOptions(), column_family, "foo", &value);
+ ASSERT_TRUE(s.IsIOError());
+ // Reactivate file system to allow test to close DB.
+ fault_injection_env_->SetFilesystemActive(true);
+}
+
+TEST_F(BlobDBTest, PutIOError) {
+ Options options;
+ options.env = fault_injection_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0; // Make sure values are written to blob files
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options, options);
+ fault_injection_env_->SetFilesystemActive(false, Status::IOError());
+ ASSERT_TRUE(Put("foo", "v1").IsIOError());
+ fault_injection_env_->SetFilesystemActive(true, Status::IOError());
+ ASSERT_OK(Put("bar", "v1"));
+}
+
+TEST_F(BlobDBTest, WriteBatch) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ WriteBatch batch;
+ for (size_t j = 0; j < 10; j++) {
+ PutRandomToWriteBatch("key" + ToString(j * 100 + i), &rnd, &batch, &data);
+ }
+ blob_db_->Write(WriteOptions(), &batch);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, Delete) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + ToString(i), &rnd, &data);
+ }
+ for (size_t i = 0; i < 100; i += 5) {
+ Delete("key" + ToString(i), &data);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, DeleteBatch) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + ToString(i), &rnd);
+ }
+ WriteBatch batch;
+ for (size_t i = 0; i < 100; i++) {
+ batch.Delete("key" + ToString(i));
+ }
+ ASSERT_OK(blob_db_->Write(WriteOptions(), &batch));
+ // DB should be empty.
+ VerifyDB({});
+}
+
+TEST_F(BlobDBTest, Override) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (int i = 0; i < 10000; i++) {
+ PutRandom("key" + ToString(i), &rnd, nullptr);
+ }
+ // override all the keys
+ for (int i = 0; i < 10000; i++) {
+ PutRandom("key" + ToString(i), &rnd, &data);
+ }
+ VerifyDB(data);
+}
+
+#ifdef SNAPPY
+TEST_F(BlobDBTest, Compression) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = CompressionType::kSnappyCompression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("put-key" + ToString(i), &rnd, &data);
+ }
+ for (int i = 0; i < 100; i++) {
+ WriteBatch batch;
+ for (size_t j = 0; j < 10; j++) {
+ PutRandomToWriteBatch("write-batch-key" + ToString(j * 100 + i), &rnd,
+ &batch, &data);
+ }
+ blob_db_->Write(WriteOptions(), &batch);
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, DecompressAfterReopen) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ bdb_options.compression = CompressionType::kSnappyCompression;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("put-key" + ToString(i), &rnd, &data);
+ }
+ VerifyDB(data);
+ bdb_options.compression = CompressionType::kNoCompression;
+ Reopen(bdb_options);
+ VerifyDB(data);
+}
+#endif
+
+TEST_F(BlobDBTest, MultipleWriters) {
+ Open(BlobDBOptions());
+
+ std::vector<port::Thread> workers;
+ std::vector<std::map<std::string, std::string>> data_set(10);
+ for (uint32_t i = 0; i < 10; i++)
+ workers.push_back(port::Thread(
+ [&](uint32_t id) {
+ Random rnd(301 + id);
+ for (int j = 0; j < 100; j++) {
+ std::string key = "key" + ToString(id) + "_" + ToString(j);
+ if (id < 5) {
+ PutRandom(key, &rnd, &data_set[id]);
+ } else {
+ WriteBatch batch;
+ PutRandomToWriteBatch(key, &rnd, &batch, &data_set[id]);
+ blob_db_->Write(WriteOptions(), &batch);
+ }
+ }
+ },
+ i));
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 10; i++) {
+ workers[i].join();
+ data.insert(data_set[i].begin(), data_set[i].end());
+ }
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, SstFileManager) {
+ // run the same test for Get(), MultiGet() and Iterator each.
+ std::shared_ptr<SstFileManager> sst_file_manager(
+ NewSstFileManager(mock_env_.get()));
+ sst_file_manager->SetDeleteRateBytesPerSecond(1);
+ SstFileManagerImpl *sfm =
+ static_cast<SstFileManagerImpl *>(sst_file_manager.get());
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ Options db_options;
+
+ int files_scheduled_to_delete = 0;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "SstFileManagerImpl::ScheduleFileDeletion", [&](void *arg) {
+ assert(arg);
+ const std::string *const file_path =
+ static_cast<const std::string *>(arg);
+ if (file_path->find(".blob") != std::string::npos) {
+ ++files_scheduled_to_delete;
+ }
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+ db_options.sst_file_manager = sst_file_manager;
+
+ Open(bdb_options, db_options);
+
+ // Create one obsolete file and clean it up.
+ blob_db_->Put(WriteOptions(), "foo", "bar");
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ std::shared_ptr<BlobFile> bfile = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(bfile));
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+
+ // Even if SstFileManager is not set, the DB creates a dummy one.
+ ASSERT_EQ(1, files_scheduled_to_delete);
+ Destroy();
+ // Make sure that DestroyBlobDB() also goes through delete scheduler.
+ ASSERT_EQ(2, files_scheduled_to_delete);
+ SyncPoint::GetInstance()->DisableProcessing();
+ sfm->WaitForEmptyTrash();
+}
+
+TEST_F(BlobDBTest, SstFileManagerRestart) {
+ int files_scheduled_to_delete = 0;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "SstFileManagerImpl::ScheduleFileDeletion", [&](void *arg) {
+ assert(arg);
+ const std::string *const file_path =
+ static_cast<const std::string *>(arg);
+ if (file_path->find(".blob") != std::string::npos) {
+ ++files_scheduled_to_delete;
+ }
+ });
+
+ // run the same test for Get(), MultiGet() and Iterator each.
+ std::shared_ptr<SstFileManager> sst_file_manager(
+ NewSstFileManager(mock_env_.get()));
+ sst_file_manager->SetDeleteRateBytesPerSecond(1);
+ SstFileManagerImpl *sfm =
+ static_cast<SstFileManagerImpl *>(sst_file_manager.get());
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ Options db_options;
+
+ SyncPoint::GetInstance()->EnableProcessing();
+ db_options.sst_file_manager = sst_file_manager;
+
+ Open(bdb_options, db_options);
+ std::string blob_dir = blob_db_impl()->TEST_blob_dir();
+ blob_db_->Put(WriteOptions(), "foo", "bar");
+ Close();
+
+ // Create 3 dummy trash files under the blob_dir
+ LegacyFileSystemWrapper fs(db_options.env);
+ CreateFile(&fs, blob_dir + "/000666.blob.trash", "", false);
+ CreateFile(&fs, blob_dir + "/000888.blob.trash", "", true);
+ CreateFile(&fs, blob_dir + "/something_not_match.trash", "", false);
+
+ // Make sure that reopening the DB rescans the existing trash files
+ Open(bdb_options, db_options);
+ ASSERT_EQ(files_scheduled_to_delete, 2);
+
+ sfm->WaitForEmptyTrash();
+
+ // There should be exactly one file under the blob dir now.
+ std::vector<std::string> all_files;
+ ASSERT_OK(db_options.env->GetChildren(blob_dir, &all_files));
+ int nfiles = 0;
+ for (const auto &f : all_files) {
+ assert(!f.empty());
+ if (f[0] == '.') {
+ continue;
+ }
+ nfiles++;
+ }
+ ASSERT_EQ(nfiles, 1);
+
+ SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_F(BlobDBTest, SnapshotAndGarbageCollection) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+
+ // i = when to take snapshot
+ for (int i = 0; i < 4; i++) {
+ Destroy();
+ Open(bdb_options);
+
+ const Snapshot *snapshot = nullptr;
+
+ // First file
+ ASSERT_OK(Put("key1", "value"));
+ if (i == 0) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[0]));
+
+ // Second file
+ ASSERT_OK(Put("key2", "value"));
+ if (i == 1) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ auto bfile = blob_files[1];
+ ASSERT_FALSE(bfile->Immutable());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(bfile));
+
+ // Third file
+ ASSERT_OK(Put("key3", "value"));
+ if (i == 2) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_TRUE(bfile->Obsolete());
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(),
+ bfile->GetObsoleteSequence());
+
+ Delete("key2");
+ if (i == 3) {
+ snapshot = blob_db_->GetSnapshot();
+ }
+
+ ASSERT_EQ(4, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+
+ if (i >= 2) {
+ // The snapshot shouldn't see data in bfile
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_->ReleaseSnapshot(snapshot);
+ } else {
+ // The snapshot will see data in bfile, so the file shouldn't be deleted
+ ASSERT_EQ(4, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_->ReleaseSnapshot(snapshot);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ }
+ }
+}
+
+TEST_F(BlobDBTest, ColumnFamilyNotSupported) {
+ Options options;
+ options.env = mock_env_.get();
+ mock_env_->set_current_time(0);
+ Open(BlobDBOptions(), options);
+ ColumnFamilyHandle *default_handle = blob_db_->DefaultColumnFamily();
+ ColumnFamilyHandle *handle = nullptr;
+ std::string value;
+ std::vector<std::string> values;
+ // The call simply passes through to the base db. It should succeed.
+ ASSERT_OK(
+ blob_db_->CreateColumnFamily(ColumnFamilyOptions(), "foo", &handle));
+ ASSERT_TRUE(blob_db_->Put(WriteOptions(), handle, "k", "v").IsNotSupported());
+ ASSERT_TRUE(blob_db_->PutWithTTL(WriteOptions(), handle, "k", "v", 60)
+ .IsNotSupported());
+ ASSERT_TRUE(blob_db_->PutUntil(WriteOptions(), handle, "k", "v", 100)
+ .IsNotSupported());
+ WriteBatch batch;
+ batch.Put("k1", "v1");
+ batch.Put(handle, "k2", "v2");
+ ASSERT_TRUE(blob_db_->Write(WriteOptions(), &batch).IsNotSupported());
+ ASSERT_TRUE(blob_db_->Get(ReadOptions(), "k1", &value).IsNotFound());
+ ASSERT_TRUE(
+ blob_db_->Get(ReadOptions(), handle, "k", &value).IsNotSupported());
+ auto statuses = blob_db_->MultiGet(ReadOptions(), {default_handle, handle},
+ {"k1", "k2"}, &values);
+ ASSERT_EQ(2, statuses.size());
+ ASSERT_TRUE(statuses[0].IsNotSupported());
+ ASSERT_TRUE(statuses[1].IsNotSupported());
+ ASSERT_EQ(nullptr, blob_db_->NewIterator(ReadOptions(), handle));
+ delete handle;
+}
+
+TEST_F(BlobDBTest, GetLiveFilesMetaData) {
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.blob_dir = "blob_dir";
+ bdb_options.path_relative = true;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (size_t i = 0; i < 100; i++) {
+ PutRandom("key" + ToString(i), &rnd, &data);
+ }
+ std::vector<LiveFileMetaData> metadata;
+ blob_db_->GetLiveFilesMetaData(&metadata);
+ ASSERT_EQ(1U, metadata.size());
+ // Path should be relative to db_name, but begin with slash.
+ std::string filename = "/blob_dir/000001.blob";
+ ASSERT_EQ(filename, metadata[0].name);
+ ASSERT_EQ(1, metadata[0].file_number);
+ ASSERT_EQ("default", metadata[0].column_family_name);
+ std::vector<std::string> livefile;
+ uint64_t mfs;
+ ASSERT_OK(blob_db_->GetLiveFiles(livefile, &mfs, false));
+ ASSERT_EQ(4U, livefile.size());
+ ASSERT_EQ(filename, livefile[3]);
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, MigrateFromPlainRocksDB) {
+ constexpr size_t kNumKey = 20;
+ constexpr size_t kNumIteration = 10;
+ Random rnd(301);
+ std::map<std::string, std::string> data;
+ std::vector<bool> is_blob(kNumKey, false);
+
+ // Write to plain rocksdb.
+ Options options;
+ options.create_if_missing = true;
+ DB *db = nullptr;
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ for (size_t i = 0; i < kNumIteration; i++) {
+ auto key_index = rnd.Next() % kNumKey;
+ std::string key = "key" + ToString(key_index);
+ PutRandom(db, key, &rnd, &data);
+ }
+ VerifyDB(db, data);
+ delete db;
+ db = nullptr;
+
+ // Open as blob db. Verify it can read existing data.
+ Open();
+ VerifyDB(blob_db_, data);
+ for (size_t i = 0; i < kNumIteration; i++) {
+ auto key_index = rnd.Next() % kNumKey;
+ std::string key = "key" + ToString(key_index);
+ is_blob[key_index] = true;
+ PutRandom(blob_db_, key, &rnd, &data);
+ }
+ VerifyDB(blob_db_, data);
+ delete blob_db_;
+ blob_db_ = nullptr;
+
+ // Verify plain db return error for keys written by blob db.
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ std::string value;
+ for (size_t i = 0; i < kNumKey; i++) {
+ std::string key = "key" + ToString(i);
+ Status s = db->Get(ReadOptions(), key, &value);
+ if (data.count(key) == 0) {
+ ASSERT_TRUE(s.IsNotFound());
+ } else if (is_blob[i]) {
+ ASSERT_TRUE(s.IsNotSupported());
+ } else {
+ ASSERT_OK(s);
+ ASSERT_EQ(data[key], value);
+ }
+ }
+ delete db;
+}
+
+// Test to verify that a NoSpace IOError Status is returned on reaching
+// max_db_size limit.
+TEST_F(BlobDBTest, OutOfSpace) {
+ // Use mock env to stop wall clock.
+ Options options;
+ options.env = mock_env_.get();
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 200;
+ bdb_options.is_fifo = false;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ // Each stored blob has an overhead of about 42 bytes currently.
+ // So a small key + a 100 byte blob should take up ~150 bytes in the db.
+ std::string value(100, 'v');
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key1", value, 60));
+
+  // Putting another blob should fail as adding it would exceed the max_db_size
+ // limit.
+ Status s = blob_db_->PutWithTTL(WriteOptions(), "key2", value, 60);
+ ASSERT_TRUE(s.IsIOError());
+ ASSERT_TRUE(s.IsNoSpace());
+}
+
+TEST_F(BlobDBTest, FIFOEviction) {
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 200;
+ bdb_options.blob_file_size = 100;
+ bdb_options.is_fifo = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ std::atomic<int> evict_count{0};
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictOldestBlobFile:Evicted",
+ [&](void *) { evict_count++; });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // Each stored blob has an overhead of 32 bytes currently.
+ // So a 100 byte blob should take up 132 bytes.
+ std::string value(100, 'v');
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key1", value, 10));
+ VerifyDB({{"key1", value}});
+
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+
+  // Adding another 100-byte blob would take the total size to 264 bytes
+  // (2*132), which exceeds max_db_size and triggers FIFO eviction.
+ ASSERT_OK(blob_db_->PutWithTTL(WriteOptions(), "key2", value, 60));
+ ASSERT_EQ(1, evict_count);
+  // key1 will still be visible until the corresponding file is deleted.
+ VerifyDB({{"key1", value}, {"key2", value}});
+
+ // Adding another 100 bytes blob without TTL.
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "key3", value));
+ ASSERT_EQ(2, evict_count);
+  // key1 and key2 will still be visible until the corresponding files are deleted.
+ VerifyDB({{"key1", value}, {"key2", value}, {"key3", value}});
+
+ // The fourth blob file, without TTL.
+ ASSERT_OK(blob_db_->Put(WriteOptions(), "key4", value));
+ ASSERT_EQ(3, evict_count);
+ VerifyDB(
+ {{"key1", value}, {"key2", value}, {"key3", value}, {"key4", value}});
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(4, blob_files.size());
+ ASSERT_TRUE(blob_files[0]->Obsolete());
+ ASSERT_TRUE(blob_files[1]->Obsolete());
+ ASSERT_TRUE(blob_files[2]->Obsolete());
+ ASSERT_FALSE(blob_files[3]->Obsolete());
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(3, obsolete_files.size());
+ ASSERT_EQ(blob_files[0], obsolete_files[0]);
+ ASSERT_EQ(blob_files[1], obsolete_files[1]);
+ ASSERT_EQ(blob_files[2], obsolete_files[2]);
+
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_TRUE(obsolete_files.empty());
+ VerifyDB({{"key4", value}});
+}
+
+TEST_F(BlobDBTest, FIFOEviction_NoOldestFileToEvict) {
+ Options options;
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 1000;
+ bdb_options.blob_file_size = 5000;
+ bdb_options.is_fifo = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ std::atomic<int> evict_count{0};
+ SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictOldestBlobFile:Evicted",
+ [&](void *) { evict_count++; });
+ SyncPoint::GetInstance()->EnableProcessing();
+
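+  // A single 2000-byte value cannot fit within max_db_size (1000 bytes), and
+  // there is no older blob file to evict, so the write fails with NoSpace.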
+ std::string value(2000, 'v');
+ ASSERT_TRUE(Put("foo", std::string(2000, 'v')).IsNoSpace());
+ ASSERT_EQ(0, evict_count);
+}
+
+TEST_F(BlobDBTest, FIFOEviction_NoEnoughBlobFilesToEvict) {
+ BlobDBOptions bdb_options;
+ bdb_options.is_fifo = true;
+ bdb_options.min_blob_size = 100;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ options.env = mock_env_.get();
+ options.disable_auto_compactions = true;
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ Open(bdb_options, options);
+
+ ASSERT_EQ(0, blob_db_impl()->TEST_live_sst_size());
+ std::string small_value(50, 'v');
+ std::map<std::string, std::string> data;
+  // Insert some data into the LSM tree to make sure FIFO eviction takes SST
+  // file size into account.
+ for (int i = 0; i < 1000; i++) {
+ ASSERT_OK(Put("key" + ToString(i), small_value, &data));
+ }
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ uint64_t live_sst_size = 0;
+ ASSERT_TRUE(blob_db_->GetIntProperty(DB::Properties::kTotalSstFilesSize,
+ &live_sst_size));
+ ASSERT_TRUE(live_sst_size > 0);
+ ASSERT_EQ(live_sst_size, blob_db_impl()->TEST_live_sst_size());
+
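+  // Reopen with a budget that leaves room for roughly one ~1KB blob on top of
+  // the live SST size, so writing a second large blob forces eviction of the
+  // first, and a 2KB blob cannot fit at all.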
+ bdb_options.max_db_size = live_sst_size + 2000;
+ Reopen(bdb_options, options);
+ ASSERT_EQ(live_sst_size, blob_db_impl()->TEST_live_sst_size());
+
+ std::string value_1k(1000, 'v');
+ ASSERT_OK(PutWithTTL("large_key1", value_1k, 60, &data));
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB(data);
+ // large_key2 evicts large_key1
+ ASSERT_OK(PutWithTTL("large_key2", value_1k, 60, &data));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ data.erase("large_key1");
+ VerifyDB(data);
+  // large_key3 would not have enough space even if large_key2 were evicted,
+  // so the write returns a no-space error instead of evicting anything.
+ std::string value_2k(2000, 'v');
+ ASSERT_TRUE(PutWithTTL("large_key3", value_2k, 60).IsNoSpace());
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ // Verify large_key2 still exists.
+ VerifyDB(data);
+}
+
+// Test that flush or compaction triggers FIFO eviction, since both update the
+// total SST file size.
+TEST_F(BlobDBTest, FIFOEviction_TriggerOnSSTSizeChange) {
+ BlobDBOptions bdb_options;
+ bdb_options.max_db_size = 1000;
+ bdb_options.is_fifo = true;
+ bdb_options.min_blob_size = 100;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ options.env = mock_env_.get();
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ options.compression = kNoCompression;
+ Open(bdb_options, options);
+
+ std::string value(800, 'v');
+ ASSERT_OK(PutWithTTL("large_key", value, 60));
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB({{"large_key", value}});
+
+ // Insert some small keys and flush to bring DB out of space.
+ std::map<std::string, std::string> data;
+ for (int i = 0; i < 10; i++) {
+ ASSERT_OK(Put("key" + ToString(i), "v", &data));
+ }
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+
+ // Verify large_key is deleted by FIFO eviction.
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ VerifyDB(data);
+}
+
+TEST_F(BlobDBTest, InlineSmallValues) {
+ constexpr uint64_t kMaxExpiration = 1000;
+ Random rnd(301);
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = kMaxExpiration;
+ bdb_options.min_blob_size = 100;
+ bdb_options.blob_file_size = 256 * 1000 * 1000;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ mock_env_->set_current_time(0);
+ Open(bdb_options, options);
+ std::map<std::string, std::string> data;
+ std::map<std::string, KeyVersion> versions;
+ for (size_t i = 0; i < 1000; i++) {
+ bool is_small_value = rnd.Next() % 2;
+ bool has_ttl = rnd.Next() % 2;
+ uint64_t expiration = rnd.Next() % kMaxExpiration;
+ int len = is_small_value ? 50 : 200;
+ std::string key = "key" + ToString(i);
+ std::string value = test::RandomHumanReadableString(&rnd, len);
+ std::string blob_index;
+ data[key] = value;
+ SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+ if (!has_ttl) {
+ ASSERT_OK(blob_db_->Put(WriteOptions(), key, value));
+ } else {
+ ASSERT_OK(blob_db_->PutUntil(WriteOptions(), key, value, expiration));
+ }
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
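+    // Small non-TTL values are expected to be inlined into the base DB as
+    // plain values; large values and anything with a TTL go through a blob
+    // index instead.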
+ versions[key] =
+ KeyVersion(key, value, sequence,
+ (is_small_value && !has_ttl) ? kTypeValue : kTypeBlobIndex);
+ }
+ VerifyDB(data);
+ VerifyBaseDB(versions);
+ auto *bdb_impl = static_cast<BlobDBImpl *>(blob_db_);
+ auto blob_files = bdb_impl->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ std::shared_ptr<BlobFile> non_ttl_file;
+ std::shared_ptr<BlobFile> ttl_file;
+ if (blob_files[0]->HasTTL()) {
+ ttl_file = blob_files[0];
+ non_ttl_file = blob_files[1];
+ } else {
+ non_ttl_file = blob_files[0];
+ ttl_file = blob_files[1];
+ }
+ ASSERT_FALSE(non_ttl_file->HasTTL());
+ ASSERT_TRUE(ttl_file->HasTTL());
+}
+
+TEST_F(BlobDBTest, CompactionFilterNotSupported) {
+ class TestCompactionFilter : public CompactionFilter {
+ const char *Name() const override { return "TestCompactionFilter"; }
+ };
+ class TestCompactionFilterFactory : public CompactionFilterFactory {
+ const char *Name() const override { return "TestCompactionFilterFactory"; }
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context & /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(new TestCompactionFilter());
+ }
+ };
+ for (int i = 0; i < 2; i++) {
+ Options options;
+ if (i == 0) {
+ options.compaction_filter = new TestCompactionFilter();
+ } else {
+ options.compaction_filter_factory.reset(
+ new TestCompactionFilterFactory());
+ }
+ ASSERT_TRUE(TryOpen(BlobDBOptions(), options).IsNotSupported());
+ delete options.compaction_filter;
+ }
+}
+
+// Test that the compaction filter removes any expired blob index.
+TEST_F(BlobDBTest, FilterExpiredBlobIndex) {
+ constexpr size_t kNumKeys = 100;
+ constexpr size_t kNumPuts = 1000;
+ constexpr uint64_t kMaxExpiration = 1000;
+ constexpr uint64_t kCompactTime = 500;
+ constexpr uint64_t kMinBlobSize = 100;
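+  // Expirations are drawn from [0, kMaxExpiration); compacting at kCompactTime
+  // should therefore filter out roughly half of the TTL entries.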
+ Random rnd(301);
+ mock_env_->set_current_time(0);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = kMinBlobSize;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, std::string> data_after_compact;
+ for (size_t i = 0; i < kNumPuts; i++) {
+ bool is_small_value = rnd.Next() % 2;
+ bool has_ttl = rnd.Next() % 2;
+ uint64_t expiration = rnd.Next() % kMaxExpiration;
+ int len = is_small_value ? 10 : 200;
+ std::string key = "key" + ToString(rnd.Next() % kNumKeys);
+ std::string value = test::RandomHumanReadableString(&rnd, len);
+ if (!has_ttl) {
+ if (is_small_value) {
+ std::string blob_entry;
+ BlobIndex::EncodeInlinedTTL(&blob_entry, expiration, value);
+        // Fake a blob index with TTL and see how it is handled.
+ ASSERT_GT(kMinBlobSize, blob_entry.size());
+ value = blob_entry;
+ }
+ ASSERT_OK(Put(key, value));
+ data_after_compact[key] = value;
+ } else {
+ ASSERT_OK(PutUntil(key, value, expiration));
+ if (expiration <= kCompactTime) {
+ data_after_compact.erase(key);
+ } else {
+ data_after_compact[key] = value;
+ }
+ }
+ data[key] = value;
+ }
+ VerifyDB(data);
+
+ mock_env_->set_current_time(kCompactTime);
+  // Take a snapshot before compaction. Make sure expired blob indexes are
+  // filtered regardless of the snapshot.
+ const Snapshot *snapshot = blob_db_->GetSnapshot();
+ // Issue manual compaction to trigger compaction filter.
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ blob_db_->ReleaseSnapshot(snapshot);
+  // Verify that expired blob indexes are filtered.
+ std::vector<KeyVersion> versions;
+ const size_t kMaxKeys = 10000;
+  ASSERT_OK(GetAllKeyVersions(blob_db_, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(data_after_compact.size(), versions.size());
+ for (auto &version : versions) {
+ ASSERT_TRUE(data_after_compact.count(version.user_key) > 0);
+ }
+ VerifyDB(data_after_compact);
+}
+
+// Test that the compaction filter removes any blob index whose corresponding
+// blob file has been removed.
+TEST_F(BlobDBTest, FilterFileNotAvailable) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.disable_auto_compactions = true;
+ Open(bdb_options, options);
+
+ ASSERT_OK(Put("foo", "v1"));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ ASSERT_EQ(1, blob_files[0]->BlobFileNumber());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[0]));
+
+ ASSERT_OK(Put("bar", "v2"));
+ blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(2, blob_files.size());
+ ASSERT_EQ(2, blob_files[1]->BlobFileNumber());
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_files[1]));
+
+ const size_t kMaxKeys = 10000;
+
+ DB *base_db = blob_db_->GetRootDB();
+ std::vector<KeyVersion> versions;
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(2, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ ASSERT_EQ("foo", versions[1].user_key);
+ VerifyDB({{"bar", "v2"}, {"foo", "v1"}});
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(2, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ ASSERT_EQ("foo", versions[1].user_key);
+ VerifyDB({{"bar", "v2"}, {"foo", "v1"}});
+
+  // Remove the first blob file and compact. foo should be removed from the
+  // base db.
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_files[0]);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(1, versions.size());
+ ASSERT_EQ("bar", versions[0].user_key);
+ VerifyDB({{"bar", "v2"}});
+
+  // Remove the second blob file and compact. bar should be removed from the
+  // base db.
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_files[1]);
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ ASSERT_OK(GetAllKeyVersions(base_db, "", "", kMaxKeys, &versions));
+ ASSERT_EQ(0, versions.size());
+ VerifyDB({});
+}
+
+// Test that the compaction filter drops any inlined TTL keys that would have
+// been removed by the last FIFO eviction had they been stored out-of-line.
+TEST_F(BlobDBTest, FilterForFIFOEviction) {
+ Random rnd(215);
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 100;
+ bdb_options.ttl_range_secs = 60;
+ bdb_options.max_db_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ // Use mock env to stop wall clock.
+ mock_env_->set_current_time(0);
+ options.env = mock_env_.get();
+ auto statistics = CreateDBStatistics();
+ options.statistics = statistics;
+ options.disable_auto_compactions = true;
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, std::string> data_after_compact;
+ // Insert some small values that will be inlined.
+ for (int i = 0; i < 1000; i++) {
+ std::string key = "key" + ToString(i);
+ std::string value = test::RandomHumanReadableString(&rnd, 50);
+ uint64_t ttl = rnd.Next() % 120 + 1;
+ ASSERT_OK(PutWithTTL(key, value, ttl, &data));
+ if (ttl >= 60) {
+ data_after_compact[key] = value;
+ }
+ }
+ uint64_t num_keys_to_evict = data.size() - data_after_compact.size();
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ uint64_t live_sst_size = blob_db_impl()->TEST_live_sst_size();
+ ASSERT_GT(live_sst_size, 0);
+ VerifyDB(data);
+
+ bdb_options.max_db_size = live_sst_size + 30000;
+ bdb_options.is_fifo = true;
+ Reopen(bdb_options, options);
+ VerifyDB(data);
+
+  // Put two large values, each in a different blob file.
+ std::string large_value(10000, 'v');
+ ASSERT_OK(PutWithTTL("large_key1", large_value, 90));
+ ASSERT_OK(PutWithTTL("large_key2", large_value, 150));
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ data["large_key1"] = large_value;
+ data["large_key2"] = large_value;
+ VerifyDB(data);
+
+ // Put a third large value which will bring the DB out of space.
+ // FIFO eviction will evict the file of large_key1.
+ ASSERT_OK(PutWithTTL("large_key3", large_value, 150));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ ASSERT_EQ(2, blob_db_impl()->TEST_GetBlobFiles().size());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ data.erase("large_key1");
+ data["large_key3"] = large_value;
+ VerifyDB(data);
+
+  // Put some more small values. These shouldn't be evicted by the compaction
+  // filter since they are inserted after the FIFO eviction.
+ ASSERT_OK(PutWithTTL("foo", "v", 30, &data_after_compact));
+ ASSERT_OK(PutWithTTL("bar", "v", 30, &data_after_compact));
+
+  // FIFO eviction doesn't trigger again since there is enough room for the
+  // flush.
+ ASSERT_OK(blob_db_->Flush(FlushOptions()));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+
+  // Manually compact and check that the compaction filter evicts those keys
+  // with expiration < 60.
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+  // All keys with expiration < 60, plus large_key1, are filtered by the
+  // compaction filter.
+ ASSERT_EQ(num_keys_to_evict + 1,
+ statistics->getTickerCount(BLOB_DB_BLOB_INDEX_EVICTED_COUNT));
+ ASSERT_EQ(1, statistics->getTickerCount(BLOB_DB_FIFO_NUM_FILES_EVICTED));
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ data_after_compact["large_key2"] = large_value;
+ data_after_compact["large_key3"] = large_value;
+ VerifyDB(data_after_compact);
+}
+
+TEST_F(BlobDBTest, GarbageCollection) {
+ constexpr size_t kNumPuts = 1 << 10;
+
+ constexpr uint64_t kExpiration = 1000;
+ constexpr uint64_t kCompactTime = 500;
+
+ constexpr uint64_t kKeySize = 7; // "key" + 4 digits
+
+ constexpr uint64_t kSmallValueSize = 1 << 6;
+ constexpr uint64_t kLargeValueSize = 1 << 8;
+ constexpr uint64_t kMinBlobSize = 1 << 7;
+ static_assert(kSmallValueSize < kMinBlobSize, "");
+ static_assert(kLargeValueSize > kMinBlobSize, "");
+
+ constexpr size_t kBlobsPerFile = 8;
+ constexpr size_t kNumBlobFiles = kNumPuts / kBlobsPerFile;
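+  // Size each blob file to hold exactly kBlobsPerFile large records, so the
+  // kNumPuts non-TTL writes below fill kNumBlobFiles blob files.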
+ constexpr uint64_t kBlobFileSize =
+ BlobLogHeader::kSize +
+ (BlobLogRecord::kHeaderSize + kKeySize + kLargeValueSize) * kBlobsPerFile;
+
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = kMinBlobSize;
+ bdb_options.blob_file_size = kBlobFileSize;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 0.25;
+ bdb_options.disable_background_tasks = true;
+
+ Options options;
+ options.env = mock_env_.get();
+ options.statistics = CreateDBStatistics();
+
+ Open(bdb_options, options);
+
+ std::map<std::string, std::string> data;
+ std::map<std::string, KeyVersion> blob_value_versions;
+ std::map<std::string, BlobIndexVersion> blob_index_versions;
+
+ Random rnd(301);
+
+ // Add a bunch of large non-TTL values. These will be written to non-TTL
+ // blob files and will be subject to GC.
+ for (size_t i = 0; i < kNumPuts; ++i) {
+ std::ostringstream oss;
+ oss << "key" << std::setw(4) << std::setfill('0') << i;
+
+ const std::string key(oss.str());
+ const std::string value(
+ test::RandomHumanReadableString(&rnd, kLargeValueSize));
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] =
+ BlobIndexVersion(key, /* file_number */ (i >> 3) + 1, kNoExpiration,
+ sequence, kTypeBlobIndex);
+ }
+
+ // Add some small and/or TTL values that will be ignored during GC.
+  // First, add a large TTL value that will be written to its own TTL blob
+  // file.
+ {
+ const std::string key("key2000");
+ const std::string value(
+ test::RandomHumanReadableString(&rnd, kLargeValueSize));
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(PutUntil(key, value, kExpiration));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] =
+ BlobIndexVersion(key, /* file_number */ kNumBlobFiles + 1, kExpiration,
+ sequence, kTypeBlobIndex);
+ }
+
+ // Now add a small TTL value (which will be inlined).
+ {
+ const std::string key("key3000");
+ const std::string value(
+ test::RandomHumanReadableString(&rnd, kSmallValueSize));
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(PutUntil(key, value, kExpiration));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeBlobIndex);
+ blob_index_versions[key] = BlobIndexVersion(
+ key, kInvalidBlobFileNumber, kExpiration, sequence, kTypeBlobIndex);
+ }
+
+ // Finally, add a small non-TTL value (which will be stored as a regular
+ // value).
+ {
+ const std::string key("key4000");
+ const std::string value(
+ test::RandomHumanReadableString(&rnd, kSmallValueSize));
+ const SequenceNumber sequence = blob_db_->GetLatestSequenceNumber() + 1;
+
+ ASSERT_OK(Put(key, value));
+ ASSERT_EQ(blob_db_->GetLatestSequenceNumber(), sequence);
+
+ data[key] = value;
+ blob_value_versions[key] = KeyVersion(key, value, sequence, kTypeValue);
+ blob_index_versions[key] = BlobIndexVersion(
+ key, kInvalidBlobFileNumber, kNoExpiration, sequence, kTypeValue);
+ }
+
+ VerifyDB(data);
+ VerifyBaseDB(blob_value_versions);
+ VerifyBaseDBBlobIndex(blob_index_versions);
+
+ // At this point, we should have 128 immutable non-TTL files with file numbers
+ // 1..128.
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), kNumBlobFiles);
+ for (size_t i = 0; i < kNumBlobFiles; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ ASSERT_EQ(live_imm_files[i]->GetFileSize(),
+ kBlobFileSize + BlobLogFooter::kSize);
+ }
+ }
+
+ mock_env_->set_current_time(kCompactTime);
+
+ ASSERT_OK(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+
+ // We expect the data to remain the same and the blobs from the oldest N files
+ // to be moved to new files. Sequence numbers get zeroed out during the
+ // compaction.
+ VerifyDB(data);
+
+ for (auto &pair : blob_value_versions) {
+ KeyVersion &version = pair.second;
+ version.sequence = 0;
+ }
+
+ VerifyBaseDB(blob_value_versions);
+
+ const uint64_t cutoff = static_cast<uint64_t>(
+ bdb_options.garbage_collection_cutoff * kNumBlobFiles);
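+  // Blob files with numbers up to the cutoff were garbage collected; their
+  // blobs should now live in newly written files numbered after the TTL file.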
+ for (auto &pair : blob_index_versions) {
+ BlobIndexVersion &version = pair.second;
+
+ version.sequence = 0;
+
+ if (version.file_number == kInvalidBlobFileNumber) {
+ continue;
+ }
+
+ if (version.file_number > cutoff) {
+ continue;
+ }
+
+ version.file_number += kNumBlobFiles + 1;
+ }
+
+ VerifyBaseDBBlobIndex(blob_index_versions);
+
+ const Statistics *const statistics = options.statistics.get();
+ assert(statistics);
+
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_FILES), cutoff);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_NEW_FILES), cutoff);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_FAILURES), 0);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_KEYS_RELOCATED),
+ cutoff * kBlobsPerFile);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_BYTES_RELOCATED),
+ cutoff * kBlobsPerFile * kLargeValueSize);
+
+ // At this point, we should have 128 immutable non-TTL files with file numbers
+ // 33..128 and 130..161. (129 was taken by the TTL blob file.)
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), kNumBlobFiles);
+ for (size_t i = 0; i < kNumBlobFiles; ++i) {
+ uint64_t expected_file_number = i + cutoff + 1;
+ if (expected_file_number > kNumBlobFiles) {
+ ++expected_file_number;
+ }
+
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), expected_file_number);
+ ASSERT_EQ(live_imm_files[i]->GetFileSize(),
+ kBlobFileSize + BlobLogFooter::kSize);
+ }
+ }
+}
+
+TEST_F(BlobDBTest, GarbageCollectionFailure) {
+ BlobDBOptions bdb_options;
+ bdb_options.min_blob_size = 0;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.garbage_collection_cutoff = 1.0;
+ bdb_options.disable_background_tasks = true;
+
+ Options db_options;
+ db_options.statistics = CreateDBStatistics();
+
+ Open(bdb_options, db_options);
+
+ // Write a couple of valid blobs.
+ Put("foo", "bar");
+ Put("dead", "beef");
+
+ // Write a fake blob reference into the base DB that cannot be parsed.
+ WriteBatch batch;
+ ASSERT_OK(WriteBatchInternal::PutBlobIndex(
+ &batch, blob_db_->DefaultColumnFamily()->GetID(), "key",
+ "not a valid blob index"));
+ ASSERT_OK(blob_db_->GetRootDB()->Write(WriteOptions(), &batch));
+
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(blob_files.size(), 1);
+ auto blob_file = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_file));
+
+ ASSERT_TRUE(blob_db_->CompactRange(CompactRangeOptions(), nullptr, nullptr)
+ .IsCorruption());
+
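+  // The two valid blobs ("bar" and "beef", 7 bytes in total) are relocated to
+  // a single new blob file before the unparsable index is encountered, which
+  // is reflected in the GC counters below.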
+ const Statistics *const statistics = db_options.statistics.get();
+ assert(statistics);
+
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_FILES), 0);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_NEW_FILES), 1);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_FAILURES), 1);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_NUM_KEYS_RELOCATED), 2);
+ ASSERT_EQ(statistics->getTickerCount(BLOB_DB_GC_BYTES_RELOCATED), 7);
+}
+
+// File should be evicted after expiration.
+TEST_F(BlobDBTest, EvictExpiredFile) {
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 100;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = true;
+ Options options;
+ options.env = mock_env_.get();
+ Open(bdb_options, options);
+ mock_env_->set_current_time(50);
+ std::map<std::string, std::string> data;
+ ASSERT_OK(PutWithTTL("foo", "bar", 100, &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_FALSE(blob_file->Immutable());
+ ASSERT_FALSE(blob_file->Obsolete());
+ VerifyDB(data);
+ mock_env_->set_current_time(250);
+  // The key should have expired by now.
+ blob_db_impl()->TEST_EvictExpiredFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ ASSERT_TRUE(blob_file->Immutable());
+ ASSERT_TRUE(blob_file->Obsolete());
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetObsoleteFiles().size());
+  // Make sure we don't return a garbage value after the blob file has been
+  // evicted while the blob index still exists in the LSM tree.
+ std::string val = "";
+ ASSERT_TRUE(blob_db_->Get(ReadOptions(), "foo", &val).IsNotFound());
+ ASSERT_EQ("", val);
+}
+
+TEST_F(BlobDBTest, DisableFileDeletions) {
+ BlobDBOptions bdb_options;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+ std::map<std::string, std::string> data;
+ for (bool force : {true, false}) {
+ ASSERT_OK(Put("foo", "v", &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_OK(blob_db_impl()->TEST_CloseBlobFile(blob_file));
+ blob_db_impl()->TEST_ObsoleteBlobFile(blob_file);
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ // Call DisableFileDeletions twice.
+ ASSERT_OK(blob_db_->DisableFileDeletions());
+ ASSERT_OK(blob_db_->DisableFileDeletions());
+ // File deletions should be disabled.
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB(data);
+    // Enable file deletions once. If force=true, file deletions are enabled
+    // immediately; otherwise they need to be enabled a second time.
+ ASSERT_OK(blob_db_->EnableFileDeletions(force));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ if (!force) {
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(1, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB(data);
+ // Call EnableFileDeletions a second time.
+ ASSERT_OK(blob_db_->EnableFileDeletions(false));
+ blob_db_impl()->TEST_DeleteObsoleteFiles();
+ }
+    // Regardless of the value of `force`, the file should be deleted by now.
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetBlobFiles().size());
+ ASSERT_EQ(0, blob_db_impl()->TEST_GetObsoleteFiles().size());
+ VerifyDB({});
+ }
+}
+
+TEST_F(BlobDBTest, MaintainBlobFileToSstMapping) {
+ BlobDBOptions bdb_options;
+ bdb_options.enable_garbage_collection = true;
+ bdb_options.disable_background_tasks = true;
+ Open(bdb_options);
+
+ // Register some dummy blob files.
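+  // Each file's immutable sequence number records when it was closed; a blob
+  // file that loses all its linked SSTs can only be marked obsolete once that
+  // sequence number has been flushed.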
+ blob_db_impl()->TEST_AddDummyBlobFile(1, /* immutable_sequence */ 200);
+ blob_db_impl()->TEST_AddDummyBlobFile(2, /* immutable_sequence */ 300);
+ blob_db_impl()->TEST_AddDummyBlobFile(3, /* immutable_sequence */ 400);
+ blob_db_impl()->TEST_AddDummyBlobFile(4, /* immutable_sequence */ 500);
+ blob_db_impl()->TEST_AddDummyBlobFile(5, /* immutable_sequence */ 600);
+
+ // Initialize the blob <-> SST file mapping. First, add some SST files with
+ // blob file references, then some without.
+ std::vector<LiveFileMetaData> live_files;
+
+ for (uint64_t i = 1; i <= 10; ++i) {
+ LiveFileMetaData live_file;
+ live_file.file_number = i;
+ live_file.oldest_blob_file_number = ((i - 1) % 5) + 1;
+
+ live_files.emplace_back(live_file);
+ }
+
+ for (uint64_t i = 11; i <= 20; ++i) {
+ LiveFileMetaData live_file;
+ live_file.file_number = i;
+
+ live_files.emplace_back(live_file);
+ }
+
+ blob_db_impl()->TEST_InitializeBlobFileToSstMapping(live_files);
+
+ // Check that the blob <-> SST mappings have been correctly initialized.
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+
+ ASSERT_EQ(blob_files.size(), 5);
+
+ {
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ {
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a flush where the SST does not reference any blob files.
+ {
+ FlushJobInfo info{};
+ info.file_number = 21;
+ info.smallest_seqno = 1;
+ info.largest_seqno = 100;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a flush where the SST references a blob file.
+ {
+ FlushJobInfo info{};
+ info.file_number = 22;
+ info.oldest_blob_file_number = 5;
+ info.smallest_seqno = 101;
+ info.largest_seqno = 200;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {1, 6}, {2, 7}, {3, 8}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{false, false, false, false,
+ false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 5);
+ for (size_t i = 0; i < 5; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 1);
+ }
+
+ ASSERT_TRUE(blob_db_impl()->TEST_GetObsoleteFiles().empty());
+ }
+
+ // Simulate a compaction. Some inputs and outputs have blob file references,
+ // some don't. There is also a trivial move (which means the SST appears on
+ // both the input and the output list). Blob file 1 loses all its linked SSTs,
+ // and since it got marked immutable at sequence number 200 which has already
+ // been flushed, it can be marked obsolete.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 1, 1});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 2, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 6, 1});
+ info.input_file_infos.emplace_back(
+ CompactionFileInfo{1, 11, kInvalidBlobFileNumber});
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 23, 3});
+ info.output_file_infos.emplace_back(
+ CompactionFileInfo{2, 24, kInvalidBlobFileNumber});
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {7}, {3, 8, 23}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate a failed compaction. No mappings should be updated.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 7, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 25, 3});
+ info.status = Status::Corruption();
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {7}, {3, 8, 23}, {4, 9}, {5, 10, 22}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate another compaction. Blob file 2 loses all its linked SSTs
+ // but since it got marked immutable at sequence number 300 which hasn't
+ // been flushed yet, it cannot be marked obsolete at this point.
+ {
+ CompactionJobInfo info{};
+ info.input_file_infos.emplace_back(CompactionFileInfo{1, 7, 2});
+ info.input_file_infos.emplace_back(CompactionFileInfo{2, 22, 5});
+ info.output_file_infos.emplace_back(CompactionFileInfo{2, 25, 3});
+
+ blob_db_impl()->TEST_ProcessCompactionJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {}, {3, 8, 23, 25}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{true, false, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 4);
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 2);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 1);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ }
+
+ // Simulate a flush with largest sequence number 300. This will make it
+ // possible to mark blob file 2 obsolete.
+ {
+ FlushJobInfo info{};
+ info.file_number = 26;
+ info.smallest_seqno = 201;
+ info.largest_seqno = 300;
+
+ blob_db_impl()->TEST_ProcessFlushJobInfo(info);
+
+ const std::vector<std::unordered_set<uint64_t>> expected_sst_files{
+ {}, {}, {3, 8, 23, 25}, {4, 9}, {5, 10}};
+ const std::vector<bool> expected_obsolete{true, true, false, false, false};
+ for (size_t i = 0; i < 5; ++i) {
+ const auto &blob_file = blob_files[i];
+ ASSERT_EQ(blob_file->GetLinkedSstFiles(), expected_sst_files[i]);
+ ASSERT_EQ(blob_file->Obsolete(), expected_obsolete[i]);
+ }
+
+ auto live_imm_files = blob_db_impl()->TEST_GetLiveImmNonTTLFiles();
+ ASSERT_EQ(live_imm_files.size(), 3);
+ for (size_t i = 0; i < 3; ++i) {
+ ASSERT_EQ(live_imm_files[i]->BlobFileNumber(), i + 3);
+ }
+
+ auto obsolete_files = blob_db_impl()->TEST_GetObsoleteFiles();
+ ASSERT_EQ(obsolete_files.size(), 2);
+ ASSERT_EQ(obsolete_files[0]->BlobFileNumber(), 1);
+ ASSERT_EQ(obsolete_files[1]->BlobFileNumber(), 2);
+ }
+}
+
+TEST_F(BlobDBTest, ShutdownWait) {
+ BlobDBOptions bdb_options;
+ bdb_options.ttl_range_secs = 100;
+ bdb_options.min_blob_size = 0;
+ bdb_options.disable_background_tasks = false;
+ Options options;
+ options.env = mock_env_.get();
+
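+  // Use sync point dependencies to interleave this thread with the background
+  // EvictExpiredFiles() task, so that Close() overlaps with its execution.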
+ SyncPoint::GetInstance()->LoadDependency({
+ {"BlobDBImpl::EvictExpiredFiles:0", "BlobDBTest.ShutdownWait:0"},
+ {"BlobDBTest.ShutdownWait:1", "BlobDBImpl::EvictExpiredFiles:1"},
+ {"BlobDBImpl::EvictExpiredFiles:2", "BlobDBTest.ShutdownWait:2"},
+ {"BlobDBTest.ShutdownWait:3", "BlobDBImpl::EvictExpiredFiles:3"},
+ });
+ // Force all tasks to be scheduled immediately.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TimeQueue::Add:item.end", [&](void *arg) {
+ std::chrono::steady_clock::time_point *tp =
+ static_cast<std::chrono::steady_clock::time_point *>(arg);
+ *tp =
+ std::chrono::steady_clock::now() - std::chrono::milliseconds(10000);
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "BlobDBImpl::EvictExpiredFiles:cb", [&](void * /*arg*/) {
+        // Sleep 3 ms to increase the chance of a data race.
+        // The sync points above arrange for EvictExpiredFiles()
+        // to run concurrently with ~BlobDBImpl(). ~BlobDBImpl()
+        // is supposed to wait for all background tasks to shut
+        // down before doing anything else. To let the same test
+        // reproduce a bug in that waiting logic, we wait a little
+        // here so that TSAN can catch the data race.
+        // We should improve the test if we find a better way.
+ Env::Default()->SleepForMicroseconds(3000);
+ });
+
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ Open(bdb_options, options);
+ mock_env_->set_current_time(50);
+ std::map<std::string, std::string> data;
+ ASSERT_OK(PutWithTTL("foo", "bar", 100, &data));
+ auto blob_files = blob_db_impl()->TEST_GetBlobFiles();
+ ASSERT_EQ(1, blob_files.size());
+ auto blob_file = blob_files[0];
+ ASSERT_FALSE(blob_file->Immutable());
+ ASSERT_FALSE(blob_file->Obsolete());
+ VerifyDB(data);
+
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:0");
+ mock_env_->set_current_time(250);
+  // The key should have expired by now.
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:1");
+
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:2");
+ TEST_SYNC_POINT("BlobDBTest.ShutdownWait:3");
+ Close();
+
+ SyncPoint::GetInstance()->DisableProcessing();
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+// A black-box test for BlobDB built on top of RocksDB.
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as BlobDB is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_dump_tool.cc b/src/rocksdb/utilities/blob_db/blob_dump_tool.cc
new file mode 100644
index 000000000..58f26128f
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_dump_tool.cc
@@ -0,0 +1,278 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_dump_tool.h"
+#include <stdio.h>
+#include <cinttypes>
+#include <iostream>
+#include <memory>
+#include <string>
+#include "env/composite_env_wrapper.h"
+#include "file/random_access_file_reader.h"
+#include "file/readahead_raf.h"
+#include "port/port.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/env.h"
+#include "table/format.h"
+#include "util/coding.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+BlobDumpTool::BlobDumpTool()
+ : reader_(nullptr), buffer_(nullptr), buffer_size_(0) {}
+
+Status BlobDumpTool::Run(const std::string& filename, DisplayType show_key,
+ DisplayType show_blob,
+ DisplayType show_uncompressed_blob,
+ bool show_summary) {
+ constexpr size_t kReadaheadSize = 2 * 1024 * 1024;
+ Status s;
+ Env* env = Env::Default();
+ s = env->FileExists(filename);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t file_size = 0;
+ s = env->GetFileSize(filename, &file_size);
+ if (!s.ok()) {
+ return s;
+ }
+ std::unique_ptr<RandomAccessFile> file;
+ s = env->NewRandomAccessFile(filename, &file, EnvOptions());
+ if (!s.ok()) {
+ return s;
+ }
+ file = NewReadaheadRandomAccessFile(std::move(file), kReadaheadSize);
+ if (file_size == 0) {
+ return Status::Corruption("File is empty.");
+ }
+ reader_.reset(new RandomAccessFileReader(
+ NewLegacyRandomAccessFileWrapper(file), filename));
+ uint64_t offset = 0;
+ uint64_t footer_offset = 0;
+ CompressionType compression = kNoCompression;
+ s = DumpBlobLogHeader(&offset, &compression);
+ if (!s.ok()) {
+ return s;
+ }
+ s = DumpBlobLogFooter(file_size, &footer_offset);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t total_records = 0;
+ uint64_t total_key_size = 0;
+ uint64_t total_blob_size = 0;
+ uint64_t total_uncompressed_blob_size = 0;
+ if (show_key != DisplayType::kNone || show_summary) {
+ while (offset < footer_offset) {
+ s = DumpRecord(show_key, show_blob, show_uncompressed_blob, show_summary,
+ compression, &offset, &total_records, &total_key_size,
+ &total_blob_size, &total_uncompressed_blob_size);
+ if (!s.ok()) {
+ break;
+ }
+ }
+ }
+ if (show_summary) {
+ fprintf(stdout, "Summary:\n");
+ fprintf(stdout, " total records: %" PRIu64 "\n", total_records);
+ fprintf(stdout, " total key size: %" PRIu64 "\n", total_key_size);
+ fprintf(stdout, " total blob size: %" PRIu64 "\n", total_blob_size);
+ if (compression != kNoCompression) {
+ fprintf(stdout, " total raw blob size: %" PRIu64 "\n",
+ total_uncompressed_blob_size);
+ }
+ }
+ return s;
+}
+
+Status BlobDumpTool::Read(uint64_t offset, size_t size, Slice* result) {
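+  // Lazily grow the scratch buffer by doubling so it can be reused across
+  // successive reads without reallocating for every record.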
+ if (buffer_size_ < size) {
+ if (buffer_size_ == 0) {
+ buffer_size_ = 4096;
+ }
+ while (buffer_size_ < size) {
+ buffer_size_ *= 2;
+ }
+ buffer_.reset(new char[buffer_size_]);
+ }
+ Status s = reader_->Read(offset, size, result, buffer_.get());
+ if (!s.ok()) {
+ return s;
+ }
+ if (result->size() != size) {
+ return Status::Corruption("Reach the end of the file unexpectedly.");
+ }
+ return s;
+}
+
+Status BlobDumpTool::DumpBlobLogHeader(uint64_t* offset,
+ CompressionType* compression) {
+ Slice slice;
+ Status s = Read(0, BlobLogHeader::kSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogHeader header;
+ s = header.DecodeFrom(slice);
+ if (!s.ok()) {
+ return s;
+ }
+ fprintf(stdout, "Blob log header:\n");
+ fprintf(stdout, " Version : %" PRIu32 "\n", header.version);
+ fprintf(stdout, " Column Family ID : %" PRIu32 "\n",
+ header.column_family_id);
+ std::string compression_str;
+ if (!GetStringFromCompressionType(&compression_str, header.compression)
+ .ok()) {
+ compression_str = "Unrecongnized compression type (" +
+ ToString((int)header.compression) + ")";
+ }
+ fprintf(stdout, " Compression : %s\n", compression_str.c_str());
+ fprintf(stdout, " Expiration range : %s\n",
+ GetString(header.expiration_range).c_str());
+ *offset = BlobLogHeader::kSize;
+ *compression = header.compression;
+ return s;
+}
+
+Status BlobDumpTool::DumpBlobLogFooter(uint64_t file_size,
+ uint64_t* footer_offset) {
+ auto no_footer = [&]() {
+ *footer_offset = file_size;
+ fprintf(stdout, "No blob log footer.\n");
+ return Status::OK();
+ };
+ if (file_size < BlobLogHeader::kSize + BlobLogFooter::kSize) {
+ return no_footer();
+ }
+ Slice slice;
+ *footer_offset = file_size - BlobLogFooter::kSize;
+ Status s = Read(*footer_offset, BlobLogFooter::kSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogFooter footer;
+ s = footer.DecodeFrom(slice);
+ if (!s.ok()) {
+ return no_footer();
+ }
+ fprintf(stdout, "Blob log footer:\n");
+ fprintf(stdout, " Blob count : %" PRIu64 "\n", footer.blob_count);
+ fprintf(stdout, " Expiration Range : %s\n",
+ GetString(footer.expiration_range).c_str());
+ return s;
+}
+
+Status BlobDumpTool::DumpRecord(DisplayType show_key, DisplayType show_blob,
+ DisplayType show_uncompressed_blob,
+ bool show_summary, CompressionType compression,
+ uint64_t* offset, uint64_t* total_records,
+ uint64_t* total_key_size,
+ uint64_t* total_blob_size,
+ uint64_t* total_uncompressed_blob_size) {
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, "Read record with offset 0x%" PRIx64 " (%" PRIu64 "):\n",
+ *offset, *offset);
+ }
+ Slice slice;
+ Status s = Read(*offset, BlobLogRecord::kHeaderSize, &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ BlobLogRecord record;
+ s = record.DecodeHeaderFrom(slice);
+ if (!s.ok()) {
+ return s;
+ }
+ uint64_t key_size = record.key_size;
+ uint64_t value_size = record.value_size;
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, " key size : %" PRIu64 "\n", key_size);
+ fprintf(stdout, " value size : %" PRIu64 "\n", value_size);
+ fprintf(stdout, " expiration : %" PRIu64 "\n", record.expiration);
+ }
+ *offset += BlobLogRecord::kHeaderSize;
+ s = Read(*offset, static_cast<size_t>(key_size + value_size), &slice);
+ if (!s.ok()) {
+ return s;
+ }
+ // Decompress value
+ std::string uncompressed_value;
+ if (compression != kNoCompression &&
+ (show_uncompressed_blob != DisplayType::kNone || show_summary)) {
+ BlockContents contents;
+ UncompressionContext context(compression);
+ UncompressionInfo info(context, UncompressionDict::GetEmptyDict(),
+ compression);
+ s = UncompressBlockContentsForCompressionType(
+ info, slice.data() + key_size, static_cast<size_t>(value_size),
+ &contents, 2 /*compress_format_version*/,
+ ImmutableCFOptions(Options()));
+ if (!s.ok()) {
+ return s;
+ }
+ uncompressed_value = contents.data.ToString();
+ }
+ if (show_key != DisplayType::kNone) {
+ fprintf(stdout, " key : ");
+ DumpSlice(Slice(slice.data(), static_cast<size_t>(key_size)), show_key);
+ if (show_blob != DisplayType::kNone) {
+ fprintf(stdout, " blob : ");
+      DumpSlice(Slice(slice.data() + static_cast<size_t>(key_size),
+                      static_cast<size_t>(value_size)),
+                show_blob);
+ }
+ if (show_uncompressed_blob != DisplayType::kNone) {
+ fprintf(stdout, " raw blob : ");
+ DumpSlice(Slice(uncompressed_value), show_uncompressed_blob);
+ }
+ }
+ *offset += key_size + value_size;
+ *total_records += 1;
+ *total_key_size += key_size;
+ *total_blob_size += value_size;
+ *total_uncompressed_blob_size += uncompressed_value.size();
+ return s;
+}
+
+void BlobDumpTool::DumpSlice(const Slice s, DisplayType type) {
+ if (type == DisplayType::kRaw) {
+ fprintf(stdout, "%s\n", s.ToString().c_str());
+ } else if (type == DisplayType::kHex) {
+ fprintf(stdout, "%s\n", s.ToString(true /*hex*/).c_str());
+ } else if (type == DisplayType::kDetail) {
+ char buf[100];
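+    // Each output line shows up to 16 bytes: the hex digits occupy three
+    // columns per byte starting at offset 15, and the printable-ASCII view
+    // starts at offset 65; the leading columns indent continuation lines.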
+ for (size_t i = 0; i < s.size(); i += 16) {
+ memset(buf, 0, sizeof(buf));
+ for (size_t j = 0; j < 16 && i + j < s.size(); j++) {
+ unsigned char c = s[i + j];
+ snprintf(buf + j * 3 + 15, 2, "%x", c >> 4);
+ snprintf(buf + j * 3 + 16, 2, "%x", c & 0xf);
+ snprintf(buf + j + 65, 2, "%c", (0x20 <= c && c <= 0x7e) ? c : '.');
+ }
+ for (size_t p = 0; p < sizeof(buf) - 1; p++) {
+ if (buf[p] == 0) {
+ buf[p] = ' ';
+ }
+ }
+ fprintf(stdout, "%s\n", i == 0 ? buf + 15 : buf);
+ }
+ }
+}
+
+template <class T>
+std::string BlobDumpTool::GetString(std::pair<T, T> p) {
+ if (p.first == 0 && p.second == 0) {
+ return "nil";
+ }
+ return "(" + ToString(p.first) + ", " + ToString(p.second) + ")";
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_dump_tool.h b/src/rocksdb/utilities/blob_db/blob_dump_tool.h
new file mode 100644
index 000000000..498543ada
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_dump_tool.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+#include <string>
+#include <utility>
+#include "file/random_access_file_reader.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "utilities/blob_db/blob_log_format.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDumpTool {
+ public:
+ enum class DisplayType {
+ kNone,
+ kRaw,
+ kHex,
+ kDetail,
+ };
+
+ BlobDumpTool();
+
+ Status Run(const std::string& filename, DisplayType show_key,
+ DisplayType show_blob, DisplayType show_uncompressed_blob,
+ bool show_summary);
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> reader_;
+ std::unique_ptr<char[]> buffer_;
+ size_t buffer_size_;
+
+ Status Read(uint64_t offset, size_t size, Slice* result);
+ Status DumpBlobLogHeader(uint64_t* offset, CompressionType* compression);
+ Status DumpBlobLogFooter(uint64_t file_size, uint64_t* footer_offset);
+ Status DumpRecord(DisplayType show_key, DisplayType show_blob,
+ DisplayType show_uncompressed_blob, bool show_summary,
+ CompressionType compression, uint64_t* offset,
+ uint64_t* total_records, uint64_t* total_key_size,
+ uint64_t* total_blob_size,
+ uint64_t* total_uncompressed_blob_size);
+ void DumpSlice(const Slice s, DisplayType type);
+
+ template <class T>
+ std::string GetString(std::pair<T, T> p);
+};
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_file.cc b/src/rocksdb/utilities/blob_db/blob_file.cc
new file mode 100644
index 000000000..f32e29529
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_file.cc
@@ -0,0 +1,320 @@
+
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+#include "utilities/blob_db/blob_file.h"
+
+#include <stdio.h>
+#include <cinttypes>
+
+#include <algorithm>
+#include <memory>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "db/dbformat.h"
+#include "env/composite_env_wrapper.h"
+#include "file/filename.h"
+#include "file/readahead_raf.h"
+#include "logging/logging.h"
+#include "utilities/blob_db/blob_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace blob_db {
+
+BlobFile::BlobFile(const BlobDBImpl* p, const std::string& bdir, uint64_t fn,
+ Logger* info_log)
+ : parent_(p), path_to_dir_(bdir), file_number_(fn), info_log_(info_log) {}
+
+BlobFile::BlobFile(const BlobDBImpl* p, const std::string& bdir, uint64_t fn,
+ Logger* info_log, uint32_t column_family_id,
+ CompressionType compression, bool has_ttl,
+ const ExpirationRange& expiration_range)
+ : parent_(p),
+ path_to_dir_(bdir),
+ file_number_(fn),
+ info_log_(info_log),
+ column_family_id_(column_family_id),
+ compression_(compression),
+ has_ttl_(has_ttl),
+ expiration_range_(expiration_range),
+ header_(column_family_id, compression, has_ttl, expiration_range),
+ header_valid_(true) {}
+
+BlobFile::~BlobFile() {
+ if (obsolete_) {
+ std::string pn(PathName());
+ Status s = Env::Default()->DeleteFile(PathName());
+ if (!s.ok()) {
+ // ROCKS_LOG_INFO(db_options_.info_log,
+ // "File could not be deleted %s", pn.c_str());
+ }
+ }
+}
+
+uint32_t BlobFile::GetColumnFamilyId() const { return column_family_id_; }
+
+std::string BlobFile::PathName() const {
+ return BlobFileName(path_to_dir_, file_number_);
+}
+
+std::shared_ptr<Reader> BlobFile::OpenRandomAccessReader(
+ Env* env, const DBOptions& db_options,
+ const EnvOptions& env_options) const {
+ constexpr size_t kReadaheadSize = 2 * 1024 * 1024;
+ std::unique_ptr<RandomAccessFile> sfile;
+ std::string path_name(PathName());
+ Status s = env->NewRandomAccessFile(path_name, &sfile, env_options);
+ if (!s.ok()) {
+ // report something here.
+ return nullptr;
+ }
+ sfile = NewReadaheadRandomAccessFile(std::move(sfile), kReadaheadSize);
+
+ std::unique_ptr<RandomAccessFileReader> sfile_reader;
+ sfile_reader.reset(new RandomAccessFileReader(
+ NewLegacyRandomAccessFileWrapper(sfile), path_name));
+
+ std::shared_ptr<Reader> log_reader = std::make_shared<Reader>(
+ std::move(sfile_reader), db_options.env, db_options.statistics.get());
+
+ return log_reader;
+}
+
+std::string BlobFile::DumpState() const {
+ char str[1000];
+ snprintf(
+ str, sizeof(str),
+ "path: %s fn: %" PRIu64 " blob_count: %" PRIu64 " file_size: %" PRIu64
+ " closed: %d obsolete: %d expiration_range: (%" PRIu64 ", %" PRIu64
+ "), writer: %d reader: %d",
+ path_to_dir_.c_str(), file_number_, blob_count_.load(), file_size_.load(),
+ closed_.load(), obsolete_.load(), expiration_range_.first,
+ expiration_range_.second, (!!log_writer_), (!!ra_file_reader_));
+ return str;
+}
+
+void BlobFile::MarkObsolete(SequenceNumber sequence) {
+ assert(Immutable());
+ obsolete_sequence_ = sequence;
+ obsolete_.store(true);
+}
+
+bool BlobFile::NeedsFsync(bool hard, uint64_t bytes_per_sync) const {
+ assert(last_fsync_ <= file_size_);
+ return (hard) ? file_size_ > last_fsync_
+ : (file_size_ - last_fsync_) >= bytes_per_sync;
+}
+
+Status BlobFile::WriteFooterAndCloseLocked(SequenceNumber sequence) {
+ BlobLogFooter footer;
+ footer.blob_count = blob_count_;
+ if (HasTTL()) {
+ footer.expiration_range = expiration_range_;
+ }
+
+  // This will close the file and reset the writable file pointer.
+ Status s = log_writer_->AppendFooter(footer);
+ if (s.ok()) {
+ closed_ = true;
+ immutable_sequence_ = sequence;
+ file_size_ += BlobLogFooter::kSize;
+ }
+ // delete the sequential writer
+ log_writer_.reset();
+ return s;
+}
+
+Status BlobFile::ReadFooter(BlobLogFooter* bf) {
+ if (file_size_ < (BlobLogHeader::kSize + BlobLogFooter::kSize)) {
+ return Status::IOError("File does not have footer", PathName());
+ }
+
+ uint64_t footer_offset = file_size_ - BlobLogFooter::kSize;
+ // assume that ra_file_reader_ is valid before we enter this
+ assert(ra_file_reader_);
+
+ Slice result;
+ char scratch[BlobLogFooter::kSize + 10];
+ Status s = ra_file_reader_->Read(footer_offset, BlobLogFooter::kSize, &result,
+ scratch);
+ if (!s.ok()) return s;
+ if (result.size() != BlobLogFooter::kSize) {
+ // should not happen
+ return Status::IOError("EOF reached before footer");
+ }
+
+ s = bf->DecodeFrom(result);
+ return s;
+}
+
+Status BlobFile::SetFromFooterLocked(const BlobLogFooter& footer) {
+ // assume that file has been fully fsync'd
+ last_fsync_.store(file_size_);
+ blob_count_ = footer.blob_count;
+ expiration_range_ = footer.expiration_range;
+ closed_ = true;
+ return Status::OK();
+}
+
+Status BlobFile::Fsync() {
+ Status s;
+ if (log_writer_.get()) {
+ s = log_writer_->Sync();
+ last_fsync_.store(file_size_.load());
+ }
+ return s;
+}
+
+void BlobFile::CloseRandomAccessLocked() {
+ ra_file_reader_.reset();
+ last_access_ = -1;
+}
+
+Status BlobFile::GetReader(Env* env, const EnvOptions& env_options,
+ std::shared_ptr<RandomAccessFileReader>* reader,
+ bool* fresh_open) {
+ assert(reader != nullptr);
+ assert(fresh_open != nullptr);
+ *fresh_open = false;
+ int64_t current_time = 0;
+ env->GetCurrentTime(&current_time);
+ last_access_.store(current_time);
+ Status s;
+
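+  // Fast path: if a random-access reader is already open, reuse it while
+  // holding only the shared (read) lock; otherwise fall through and open one
+  // under the exclusive (write) lock below.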
+ {
+ ReadLock lockbfile_r(&mutex_);
+ if (ra_file_reader_) {
+ *reader = ra_file_reader_;
+ return s;
+ }
+ }
+
+ WriteLock lockbfile_w(&mutex_);
+ // Double check.
+ if (ra_file_reader_) {
+ *reader = ra_file_reader_;
+ return s;
+ }
+
+ std::unique_ptr<RandomAccessFile> rfile;
+ s = env->NewRandomAccessFile(PathName(), &rfile, env_options);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to open blob file for random-read: %s status: '%s'"
+ " exists: '%s'",
+ PathName().c_str(), s.ToString().c_str(),
+ env->FileExists(PathName()).ToString().c_str());
+ return s;
+ }
+
+ ra_file_reader_ = std::make_shared<RandomAccessFileReader>(
+ NewLegacyRandomAccessFileWrapper(rfile), PathName());
+ *reader = ra_file_reader_;
+ *fresh_open = true;
+ return s;
+}
+
+Status BlobFile::ReadMetadata(Env* env, const EnvOptions& env_options) {
+ assert(Immutable());
+ // Get file size.
+ uint64_t file_size = 0;
+ Status s = env->GetFileSize(PathName(), &file_size);
+ if (s.ok()) {
+ file_size_ = file_size;
+ } else {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to get size of blob file %" PRIu64
+ ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ if (file_size < BlobLogHeader::kSize) {
+ ROCKS_LOG_ERROR(info_log_,
+                    "Incomplete blob file %" PRIu64
+ ", size: %" PRIu64,
+ file_number_, file_size);
+ return Status::Corruption("Incomplete blob file header.");
+ }
+
+ // Create file reader.
+ std::unique_ptr<RandomAccessFile> file;
+ s = env->NewRandomAccessFile(PathName(), &file, env_options);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to open blob file %" PRIu64 ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ std::unique_ptr<RandomAccessFileReader> file_reader(
+ new RandomAccessFileReader(NewLegacyRandomAccessFileWrapper(file),
+ PathName()));
+
+ // Read file header.
+ char header_buf[BlobLogHeader::kSize];
+ Slice header_slice;
+ s = file_reader->Read(0, BlobLogHeader::kSize, &header_slice, header_buf);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to read header of blob file %" PRIu64
+ ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ BlobLogHeader header;
+ s = header.DecodeFrom(header_slice);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to decode header of blob file %" PRIu64
+ ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ column_family_id_ = header.column_family_id;
+ compression_ = header.compression;
+ has_ttl_ = header.has_ttl;
+ if (has_ttl_) {
+ expiration_range_ = header.expiration_range;
+ }
+ header_valid_ = true;
+
+ // Read file footer.
+ if (file_size_ < BlobLogHeader::kSize + BlobLogFooter::kSize) {
+ // OK not to have footer.
+ assert(!footer_valid_);
+ return Status::OK();
+ }
+ char footer_buf[BlobLogFooter::kSize];
+ Slice footer_slice;
+ s = file_reader->Read(file_size - BlobLogFooter::kSize, BlobLogFooter::kSize,
+ &footer_slice, footer_buf);
+ if (!s.ok()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "Failed to read footer of blob file %" PRIu64
+ ", status: %s",
+ file_number_, s.ToString().c_str());
+ return s;
+ }
+ BlobLogFooter footer;
+ s = footer.DecodeFrom(footer_slice);
+ if (!s.ok()) {
+ // OK not to have footer.
+ assert(!footer_valid_);
+ return Status::OK();
+ }
+ blob_count_ = footer.blob_count;
+ if (has_ttl_) {
+ assert(header.expiration_range.first <= footer.expiration_range.first);
+ assert(header.expiration_range.second >= footer.expiration_range.second);
+ expiration_range_ = footer.expiration_range;
+ }
+ footer_valid_ = true;
+ return Status::OK();
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_file.h b/src/rocksdb/utilities/blob_db/blob_file.h
new file mode 100644
index 000000000..17d39b542
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_file.h
@@ -0,0 +1,252 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <limits>
+#include <memory>
+#include <unordered_set>
+
+#include "file/random_access_file_reader.h"
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "rocksdb/options.h"
+#include "utilities/blob_db/blob_log_format.h"
+#include "utilities/blob_db/blob_log_reader.h"
+#include "utilities/blob_db/blob_log_writer.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+class BlobDBImpl;
+
+class BlobFile {
+ friend class BlobDBImpl;
+ friend struct BlobFileComparator;
+ friend struct BlobFileComparatorTTL;
+ friend class BlobIndexCompactionFilterGC;
+
+ private:
+ // access to parent
+ const BlobDBImpl* parent_{nullptr};
+
+ // path to blob directory
+ std::string path_to_dir_;
+
+ // the id of the file.
+ // the above 2 are created during file creation and never changed
+ // after that
+ uint64_t file_number_{0};
+
+ // The file numbers of the SST files whose oldest blob file reference
+ // points to this blob file.
+ std::unordered_set<uint64_t> linked_sst_files_;
+
+ // Info log.
+ Logger* info_log_{nullptr};
+
+ // Column family id.
+ uint32_t column_family_id_{std::numeric_limits<uint32_t>::max()};
+
+ // Compression type of blobs in the file
+ CompressionType compression_{kNoCompression};
+
+  // If true, the keys in this file all have a TTL. Otherwise none of the keys
+  // have a TTL.
+ bool has_ttl_{false};
+
+ // TTL range of blobs in the file.
+ ExpirationRange expiration_range_;
+
+ // number of blobs in the file
+ std::atomic<uint64_t> blob_count_{0};
+
+ // size of the file
+ std::atomic<uint64_t> file_size_{0};
+
+ BlobLogHeader header_;
+
+  // closed_ = true implies the file is no longer mutable:
+  // no more blobs will be appended and the footer has been written out.
+ std::atomic<bool> closed_{false};
+
+ // The latest sequence number when the file was closed/made immutable.
+ SequenceNumber immutable_sequence_{0};
+
+ // Whether the file was marked obsolete (due to either TTL or GC).
+  // Even when obsolete_ is set, iterator/snapshot checks are still needed
+  // before the file can actually be deleted.
+ std::atomic<bool> obsolete_{false};
+
+  // The last sequence number at the time the file was marked obsolete.
+ // Data in this file is visible to a snapshot taken before the sequence.
+ SequenceNumber obsolete_sequence_{0};
+
+ // Sequential/Append writer for blobs
+ std::shared_ptr<Writer> log_writer_;
+
+ // random access file reader for GET calls
+ std::shared_ptr<RandomAccessFileReader> ra_file_reader_;
+
+ // This Read-Write mutex is per file specific and protects
+ // all the datastructures
+ mutable port::RWMutex mutex_;
+
+ // time when the random access reader was last created.
+ std::atomic<std::int64_t> last_access_{-1};
+
+ // last time file was fsync'd/fdatasyncd
+ std::atomic<uint64_t> last_fsync_{0};
+
+ bool header_valid_{false};
+
+ bool footer_valid_{false};
+
+ public:
+ BlobFile() = default;
+
+ BlobFile(const BlobDBImpl* parent, const std::string& bdir, uint64_t fnum,
+ Logger* info_log);
+
+ BlobFile(const BlobDBImpl* parent, const std::string& bdir, uint64_t fnum,
+ Logger* info_log, uint32_t column_family_id,
+ CompressionType compression, bool has_ttl,
+ const ExpirationRange& expiration_range);
+
+ ~BlobFile();
+
+ uint32_t GetColumnFamilyId() const;
+
+ // Returns log file's absolute pathname.
+ std::string PathName() const;
+
+  // Primary identifier for the blob file.
+  // Once the file is created, this never changes.
+ uint64_t BlobFileNumber() const { return file_number_; }
+
+ // Get the set of SST files whose oldest blob file reference points to
+ // this file.
+ const std::unordered_set<uint64_t>& GetLinkedSstFiles() const {
+ return linked_sst_files_;
+ }
+
+ // Link an SST file whose oldest blob file reference points to this file.
+ void LinkSstFile(uint64_t sst_file_number) {
+ assert(linked_sst_files_.find(sst_file_number) == linked_sst_files_.end());
+ linked_sst_files_.insert(sst_file_number);
+ }
+
+ // Unlink an SST file whose oldest blob file reference points to this file.
+ void UnlinkSstFile(uint64_t sst_file_number) {
+ auto it = linked_sst_files_.find(sst_file_number);
+ assert(it != linked_sst_files_.end());
+ linked_sst_files_.erase(it);
+ }
+
+  // The following functions are atomic and don't need a read lock.
+ uint64_t BlobCount() const {
+ return blob_count_.load(std::memory_order_acquire);
+ }
+
+ std::string DumpState() const;
+
+  // Returns true if the file is no longer accepting appends.
+ bool Immutable() const { return closed_.load(); }
+
+ // Mark the file as immutable.
+ // REQUIRES: write lock held, or access from single thread (on DB open).
+ void MarkImmutable(SequenceNumber sequence) {
+ closed_ = true;
+ immutable_sequence_ = sequence;
+ }
+
+ SequenceNumber GetImmutableSequence() const {
+ assert(Immutable());
+ return immutable_sequence_;
+ }
+
+ // Whether the file was marked obsolete (due to either TTL or GC).
+ bool Obsolete() const {
+ assert(Immutable() || !obsolete_.load());
+ return obsolete_.load();
+ }
+
+  // Mark the file as obsolete (due to either TTL or GC). The file is not
+  // visible to snapshots with a sequence greater than or equal to the given
+  // sequence.
+ void MarkObsolete(SequenceNumber sequence);
+
+ SequenceNumber GetObsoleteSequence() const {
+ assert(Obsolete());
+ return obsolete_sequence_;
+ }
+
+ // we will assume this is atomic
+ bool NeedsFsync(bool hard, uint64_t bytes_per_sync) const;
+
+ Status Fsync();
+
+ uint64_t GetFileSize() const {
+ return file_size_.load(std::memory_order_acquire);
+ }
+
+  // All Get functions that are not atomic need a ReadLock on the mutex.
+
+ ExpirationRange GetExpirationRange() const { return expiration_range_; }
+
+ void ExtendExpirationRange(uint64_t expiration) {
+ expiration_range_.first = std::min(expiration_range_.first, expiration);
+ expiration_range_.second = std::max(expiration_range_.second, expiration);
+ }
+
+ bool HasTTL() const { return has_ttl_; }
+
+ void SetHasTTL(bool has_ttl) { has_ttl_ = has_ttl; }
+
+ CompressionType GetCompressionType() const { return compression_; }
+
+ std::shared_ptr<Writer> GetWriter() const { return log_writer_; }
+
+  // Read the blob file header and footer. Return Corruption if the file
+  // header is malformed or incomplete. If the footer is malformed or
+  // incomplete, set footer_valid_ to false and return Status::OK.
+ Status ReadMetadata(Env* env, const EnvOptions& env_options);
+
+ Status GetReader(Env* env, const EnvOptions& env_options,
+ std::shared_ptr<RandomAccessFileReader>* reader,
+ bool* fresh_open);
+
+ private:
+ std::shared_ptr<Reader> OpenRandomAccessReader(
+ Env* env, const DBOptions& db_options,
+ const EnvOptions& env_options) const;
+
+ Status ReadFooter(BlobLogFooter* footer);
+
+ Status WriteFooterAndCloseLocked(SequenceNumber sequence);
+
+ void CloseRandomAccessLocked();
+
+  // This is used when reading only the footer of a
+  // previously closed file.
+ Status SetFromFooterLocked(const BlobLogFooter& footer);
+
+ void set_expiration_range(const ExpirationRange& expiration_range) {
+ expiration_range_ = expiration_range;
+ }
+
+ // The following functions are atomic, and don't need locks
+ void SetFileSize(uint64_t fs) { file_size_ = fs; }
+
+ void SetBlobCount(uint64_t bc) { blob_count_ = bc; }
+
+ void BlobRecordAdded(uint64_t record_size) {
+ ++blob_count_;
+ file_size_ += record_size;
+ }
+};
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_format.cc b/src/rocksdb/utilities/blob_db/blob_log_format.cc
new file mode 100644
index 000000000..64894ca7b
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_format.cc
@@ -0,0 +1,149 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_log_format.h"
+
+#include "util/coding.h"
+#include "util/crc32c.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+void BlobLogHeader::EncodeTo(std::string* dst) {
+ assert(dst != nullptr);
+ dst->clear();
+ dst->reserve(BlobLogHeader::kSize);
+ PutFixed32(dst, kMagicNumber);
+ PutFixed32(dst, version);
+ PutFixed32(dst, column_family_id);
+ unsigned char flags = (has_ttl ? 1 : 0);
+ dst->push_back(flags);
+ dst->push_back(compression);
+ PutFixed64(dst, expiration_range.first);
+ PutFixed64(dst, expiration_range.second);
+}
+
+Status BlobLogHeader::DecodeFrom(Slice src) {
+ static const std::string kErrorMessage =
+ "Error while decoding blob log header";
+ if (src.size() != BlobLogHeader::kSize) {
+ return Status::Corruption(kErrorMessage,
+ "Unexpected blob file header size");
+ }
+ uint32_t magic_number;
+ unsigned char flags;
+ if (!GetFixed32(&src, &magic_number) || !GetFixed32(&src, &version) ||
+ !GetFixed32(&src, &column_family_id)) {
+ return Status::Corruption(
+ kErrorMessage,
+ "Error decoding magic number, version and column family id");
+ }
+ if (magic_number != kMagicNumber) {
+ return Status::Corruption(kErrorMessage, "Magic number mismatch");
+ }
+ if (version != kVersion1) {
+ return Status::Corruption(kErrorMessage, "Unknown header version");
+ }
+ flags = src.data()[0];
+ compression = static_cast<CompressionType>(src.data()[1]);
+ has_ttl = (flags & 1) == 1;
+ src.remove_prefix(2);
+ if (!GetFixed64(&src, &expiration_range.first) ||
+ !GetFixed64(&src, &expiration_range.second)) {
+ return Status::Corruption(kErrorMessage, "Error decoding expiration range");
+ }
+ return Status::OK();
+}
+
+void BlobLogFooter::EncodeTo(std::string* dst) {
+ assert(dst != nullptr);
+ dst->clear();
+ dst->reserve(BlobLogFooter::kSize);
+ PutFixed32(dst, kMagicNumber);
+ PutFixed64(dst, blob_count);
+ PutFixed64(dst, expiration_range.first);
+ PutFixed64(dst, expiration_range.second);
+ crc = crc32c::Value(dst->c_str(), dst->size());
+ crc = crc32c::Mask(crc);
+ PutFixed32(dst, crc);
+}
+
+Status BlobLogFooter::DecodeFrom(Slice src) {
+ static const std::string kErrorMessage =
+ "Error while decoding blob log footer";
+ if (src.size() != BlobLogFooter::kSize) {
+ return Status::Corruption(kErrorMessage,
+ "Unexpected blob file footer size");
+ }
+ uint32_t src_crc = 0;
+ src_crc = crc32c::Value(src.data(), BlobLogFooter::kSize - sizeof(uint32_t));
+ src_crc = crc32c::Mask(src_crc);
+ uint32_t magic_number = 0;
+ if (!GetFixed32(&src, &magic_number) || !GetFixed64(&src, &blob_count) ||
+ !GetFixed64(&src, &expiration_range.first) ||
+ !GetFixed64(&src, &expiration_range.second) || !GetFixed32(&src, &crc)) {
+ return Status::Corruption(kErrorMessage, "Error decoding content");
+ }
+ if (magic_number != kMagicNumber) {
+ return Status::Corruption(kErrorMessage, "Magic number mismatch");
+ }
+ if (src_crc != crc) {
+ return Status::Corruption(kErrorMessage, "CRC mismatch");
+ }
+ return Status::OK();
+}
+
+void BlobLogRecord::EncodeHeaderTo(std::string* dst) {
+ assert(dst != nullptr);
+ dst->clear();
+ dst->reserve(BlobLogRecord::kHeaderSize + key.size() + value.size());
+ PutFixed64(dst, key.size());
+ PutFixed64(dst, value.size());
+ PutFixed64(dst, expiration);
+ header_crc = crc32c::Value(dst->c_str(), dst->size());
+ header_crc = crc32c::Mask(header_crc);
+ PutFixed32(dst, header_crc);
+ blob_crc = crc32c::Value(key.data(), key.size());
+ blob_crc = crc32c::Extend(blob_crc, value.data(), value.size());
+ blob_crc = crc32c::Mask(blob_crc);
+ PutFixed32(dst, blob_crc);
+}
+
+Status BlobLogRecord::DecodeHeaderFrom(Slice src) {
+ static const std::string kErrorMessage = "Error while decoding blob record";
+ if (src.size() != BlobLogRecord::kHeaderSize) {
+ return Status::Corruption(kErrorMessage,
+ "Unexpected blob record header size");
+ }
+ uint32_t src_crc = 0;
+ src_crc = crc32c::Value(src.data(), BlobLogRecord::kHeaderSize - 8);
+ src_crc = crc32c::Mask(src_crc);
+ if (!GetFixed64(&src, &key_size) || !GetFixed64(&src, &value_size) ||
+ !GetFixed64(&src, &expiration) || !GetFixed32(&src, &header_crc) ||
+ !GetFixed32(&src, &blob_crc)) {
+ return Status::Corruption(kErrorMessage, "Error decoding content");
+ }
+ if (src_crc != header_crc) {
+ return Status::Corruption(kErrorMessage, "Header CRC mismatch");
+ }
+ return Status::OK();
+}
+
+Status BlobLogRecord::CheckBlobCRC() const {
+ uint32_t expected_crc = 0;
+ expected_crc = crc32c::Value(key.data(), key.size());
+ expected_crc = crc32c::Extend(expected_crc, value.data(), value.size());
+ expected_crc = crc32c::Mask(expected_crc);
+ if (expected_crc != blob_crc) {
+ return Status::Corruption("Blob CRC mismatch");
+ }
+ return Status::OK();
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_format.h b/src/rocksdb/utilities/blob_db/blob_log_format.h
new file mode 100644
index 000000000..26cdf6e71
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_format.h
@@ -0,0 +1,133 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Log format information shared by reader and writer.
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <limits>
+#include <memory>
+#include <utility>
+
+#include "rocksdb/options.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+constexpr uint32_t kMagicNumber = 2395959; // 0x00248f37
+constexpr uint32_t kVersion1 = 1;
+constexpr uint64_t kNoExpiration = std::numeric_limits<uint64_t>::max();
+
+using ExpirationRange = std::pair<uint64_t, uint64_t>;
+
+// Format of blob log file header (30 bytes):
+//
+// +--------------+---------+---------+-------+-------------+-------------------+
+// | magic number | version | cf id | flags | compression | expiration range |
+// +--------------+---------+---------+-------+-------------+-------------------+
+// | Fixed32 | Fixed32 | Fixed32 | char | char | Fixed64 Fixed64 |
+// +--------------+---------+---------+-------+-------------+-------------------+
+//
+// List of flags:
+// has_ttl: Whether the file contains TTL data.
+//
+// Expiration range in the header is a rough range based on
+// blob_db_options.ttl_range_secs.
+struct BlobLogHeader {
+ static constexpr size_t kSize = 30;
+
+ BlobLogHeader() = default;
+ BlobLogHeader(uint32_t _column_family_id, CompressionType _compression,
+ bool _has_ttl, const ExpirationRange& _expiration_range)
+ : column_family_id(_column_family_id),
+ compression(_compression),
+ has_ttl(_has_ttl),
+ expiration_range(_expiration_range) {}
+
+ uint32_t version = kVersion1;
+ uint32_t column_family_id = 0;
+ CompressionType compression = kNoCompression;
+ bool has_ttl = false;
+ ExpirationRange expiration_range;
+
+ void EncodeTo(std::string* dst);
+
+ Status DecodeFrom(Slice slice);
+};
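+
+// A minimal usage sketch (an illustrative example added to this listing; the
+// helper name is not part of the upstream header): round-trip a BlobLogHeader
+// through its fixed 30-byte on-disk encoding. All values are arbitrary.
+inline Status ExampleHeaderRoundTrip() {
+  BlobLogHeader header(/*column_family_id=*/1, kNoCompression,
+                       /*has_ttl=*/true, ExpirationRange(100, 200));
+  std::string encoded;
+  header.EncodeTo(&encoded);  // Always produces BlobLogHeader::kSize bytes.
+  BlobLogHeader decoded;
+  return decoded.DecodeFrom(Slice(encoded));  // OK for a well-formed header.
+}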
+
+// Format of blob log file footer (32 bytes):
+//
+// +--------------+------------+-------------------+------------+
+// | magic number | blob count | expiration range | footer CRC |
+// +--------------+------------+-------------------+------------+
+// | Fixed32 | Fixed64 | Fixed64 + Fixed64 | Fixed32 |
+// +--------------+------------+-------------------+------------+
+//
+// The footer will be present only when the blob file is properly closed.
+//
+// Unlike the same field in file header, expiration range in the footer is the
+// range of smallest and largest expiration of the data in this file.
+struct BlobLogFooter {
+ static constexpr size_t kSize = 32;
+
+ uint64_t blob_count = 0;
+ ExpirationRange expiration_range = std::make_pair(0, 0);
+ uint32_t crc = 0;
+
+ void EncodeTo(std::string* dst);
+
+ Status DecodeFrom(Slice slice);
+};
+
+// Blob record format (32 bytes header + key + value):
+//
+// +------------+--------------+------------+------------+----------+---------+-----------+
+// | key length | value length | expiration | header CRC | blob CRC | key | value |
+// +------------+--------------+------------+------------+----------+---------+-----------+
+// | Fixed64 | Fixed64 | Fixed64 | Fixed32 | Fixed32 | key len | value len |
+// +------------+--------------+------------+------------+----------+---------+-----------+
+//
+// If the file has has_ttl = false, the expiration field is always 0 and the
+// blob has no expiration.
+//
+// Also note that if compression is used, the stored value is the compressed
+// value and the value length is the compressed length.
+//
+// Header CRC is the checksum of (key_len + val_len + expiration), while
+// blob CRC is the checksum of (key + value).
+//
+// We could use variable-length encoding (Varint64) to save more space, but it
+// would make the reader more complicated.
+struct BlobLogRecord {
+  // The header includes all fields up to and including the blob CRC.
+ static constexpr size_t kHeaderSize = 32;
+
+ uint64_t key_size = 0;
+ uint64_t value_size = 0;
+ uint64_t expiration = 0;
+ uint32_t header_crc = 0;
+ uint32_t blob_crc = 0;
+ Slice key;
+ Slice value;
+ std::unique_ptr<char[]> key_buf;
+ std::unique_ptr<char[]> value_buf;
+
+ uint64_t record_size() const { return kHeaderSize + key_size + value_size; }
+
+ void EncodeHeaderTo(std::string* dst);
+
+ Status DecodeHeaderFrom(Slice src);
+
+ Status CheckBlobCRC() const;
+};
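+
+// A hedged layout sketch (the helper name below is an example added to this
+// listing, not part of the upstream header): a full on-disk record is the
+// encoded kHeaderSize-byte header followed by the raw key and value bytes.
+inline std::string ExampleEncodeFullRecord(const Slice& key, const Slice& value,
+                                           uint64_t expiration) {
+  BlobLogRecord record;
+  record.key = key;
+  record.value = value;
+  record.expiration = expiration;
+  std::string out;
+  record.EncodeHeaderTo(&out);               // Sizes, expiration and the two CRCs.
+  out.append(key.data(), key.size());        // Key bytes follow the header.
+  out.append(value.data(), value.size());    // Then the (possibly compressed) value.
+  // out.size() == BlobLogRecord::kHeaderSize + key.size() + value.size().
+  return out;
+}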
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_reader.cc b/src/rocksdb/utilities/blob_db/blob_log_reader.cc
new file mode 100644
index 000000000..1a4b5ac81
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_reader.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_log_reader.h"
+
+#include <algorithm>
+
+#include "file/random_access_file_reader.h"
+#include "monitoring/statistics.h"
+#include "util/stop_watch.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Reader::Reader(std::unique_ptr<RandomAccessFileReader>&& file_reader, Env* env,
+ Statistics* statistics)
+ : file_(std::move(file_reader)),
+ env_(env),
+ statistics_(statistics),
+ buffer_(),
+ next_byte_(0) {}
+
+Status Reader::ReadSlice(uint64_t size, Slice* slice, char* buf) {
+ StopWatch read_sw(env_, statistics_, BLOB_DB_BLOB_FILE_READ_MICROS);
+ Status s = file_->Read(next_byte_, static_cast<size_t>(size), slice, buf);
+ next_byte_ += size;
+ if (!s.ok()) {
+ return s;
+ }
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_READ, slice->size());
+ if (slice->size() != size) {
+ return Status::Corruption("EOF reached while reading record");
+ }
+ return s;
+}
+
+Status Reader::ReadHeader(BlobLogHeader* header) {
+ assert(file_.get() != nullptr);
+ assert(next_byte_ == 0);
+ Status s = ReadSlice(BlobLogHeader::kSize, &buffer_, header_buf_);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (buffer_.size() != BlobLogHeader::kSize) {
+ return Status::Corruption("EOF reached before file header");
+ }
+
+ return header->DecodeFrom(buffer_);
+}
+
+Status Reader::ReadRecord(BlobLogRecord* record, ReadLevel level,
+ uint64_t* blob_offset) {
+ Status s = ReadSlice(BlobLogRecord::kHeaderSize, &buffer_, header_buf_);
+ if (!s.ok()) {
+ return s;
+ }
+ if (buffer_.size() != BlobLogRecord::kHeaderSize) {
+ return Status::Corruption("EOF reached before record header");
+ }
+
+ s = record->DecodeHeaderFrom(buffer_);
+ if (!s.ok()) {
+ return s;
+ }
+
+ uint64_t kb_size = record->key_size + record->value_size;
+ if (blob_offset != nullptr) {
+ *blob_offset = next_byte_ + record->key_size;
+ }
+
+ switch (level) {
+ case kReadHeader:
+ next_byte_ += kb_size;
+ break;
+
+ case kReadHeaderKey:
+ record->key_buf.reset(new char[record->key_size]);
+ s = ReadSlice(record->key_size, &record->key, record->key_buf.get());
+ next_byte_ += record->value_size;
+ break;
+
+ case kReadHeaderKeyBlob:
+ record->key_buf.reset(new char[record->key_size]);
+ s = ReadSlice(record->key_size, &record->key, record->key_buf.get());
+ if (s.ok()) {
+ record->value_buf.reset(new char[record->value_size]);
+ s = ReadSlice(record->value_size, &record->value,
+ record->value_buf.get());
+ }
+ if (s.ok()) {
+ s = record->CheckBlobCRC();
+ }
+ break;
+ }
+ return s;
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_reader.h b/src/rocksdb/utilities/blob_db/blob_log_reader.h
new file mode 100644
index 000000000..45fda284a
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_reader.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <memory>
+#include <string>
+
+#include "file/random_access_file_reader.h"
+#include "rocksdb/env.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/status.h"
+#include "utilities/blob_db/blob_log_format.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SequentialFileReader;
+class Logger;
+
+namespace blob_db {
+
+/**
+ * Reader is a general purpose blob log stream reader. The actual job of
+ * reading from the device is delegated to the RandomAccessFileReader passed
+ * in at construction time.
+ *
+ * Please see Writer and blob_log_format.h for details on the file and record
+ * layout.
+ */
+class Reader {
+ public:
+ enum ReadLevel {
+ kReadHeader,
+ kReadHeaderKey,
+ kReadHeaderKeyBlob,
+ };
+
+ // Create a reader that will return log records from "*file".
+ // "*file" must remain live while this Reader is in use.
+ Reader(std::unique_ptr<RandomAccessFileReader>&& file_reader, Env* env,
+ Statistics* statistics);
+ // No copying allowed
+ Reader(const Reader&) = delete;
+ Reader& operator=(const Reader&) = delete;
+
+ ~Reader() = default;
+
+ Status ReadHeader(BlobLogHeader* header);
+
+  // Read the next record into *record. Returns OK on success, and a non-OK
+  // status if an error occurs or the end of the input is reached before a
+  // full record could be read. The contents filled in *record are only valid
+  // until the next mutating operation on this reader.
+  // If blob_offset is non-null, return the offset of the blob through it.
+ Status ReadRecord(BlobLogRecord* record, ReadLevel level = kReadHeader,
+ uint64_t* blob_offset = nullptr);
+
+ void ResetNextByte() { next_byte_ = 0; }
+
+ uint64_t GetNextByte() const { return next_byte_; }
+
+ private:
+ Status ReadSlice(uint64_t size, Slice* slice, char* buf);
+
+ const std::unique_ptr<RandomAccessFileReader> file_;
+ Env* env_;
+ Statistics* statistics_;
+
+ Slice buffer_;
+ char header_buf_[BlobLogRecord::kHeaderSize];
+
+ // which byte to read next. For asserting proper usage
+ uint64_t next_byte_;
+};
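+
+// A hedged usage sketch (example only; the helper name is not part of the
+// upstream reader): read the file header, then the first record including its
+// key and blob payload. Error handling is trimmed for brevity.
+inline Status ExampleReadFirstRecord(
+    std::unique_ptr<RandomAccessFileReader>&& file_reader, Env* env) {
+  Reader reader(std::move(file_reader), env, /*statistics=*/nullptr);
+  BlobLogHeader header;
+  Status s = reader.ReadHeader(&header);
+  if (!s.ok()) {
+    return s;
+  }
+  BlobLogRecord record;
+  uint64_t blob_offset = 0;
+  // kReadHeaderKeyBlob pulls the record header, key and value into memory and
+  // verifies the blob CRC; record.key/record.value stay valid only until the
+  // next ReadRecord() call.
+  return reader.ReadRecord(&record, Reader::kReadHeaderKeyBlob, &blob_offset);
+}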
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_writer.cc b/src/rocksdb/utilities/blob_db/blob_log_writer.cc
new file mode 100644
index 000000000..2fe92263b
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_writer.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/blob_db/blob_log_writer.h"
+
+#include <cstdint>
+#include <string>
+
+#include "file/writable_file_writer.h"
+#include "monitoring/statistics.h"
+#include "rocksdb/env.h"
+#include "util/coding.h"
+#include "util/stop_watch.h"
+#include "utilities/blob_db/blob_log_format.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace blob_db {
+
+Writer::Writer(std::unique_ptr<WritableFileWriter>&& dest, Env* env,
+ Statistics* statistics, uint64_t log_number, uint64_t bpsync,
+ bool use_fs, uint64_t boffset)
+ : dest_(std::move(dest)),
+ env_(env),
+ statistics_(statistics),
+ log_number_(log_number),
+ block_offset_(boffset),
+ bytes_per_sync_(bpsync),
+ next_sync_offset_(0),
+ use_fsync_(use_fs),
+ last_elem_type_(kEtNone) {}
+
+Status Writer::Sync() {
+ StopWatch sync_sw(env_, statistics_, BLOB_DB_BLOB_FILE_SYNC_MICROS);
+ Status s = dest_->Sync(use_fsync_);
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_SYNCED);
+ return s;
+}
+
+Status Writer::WriteHeader(BlobLogHeader& header) {
+ assert(block_offset_ == 0);
+ assert(last_elem_type_ == kEtNone);
+ std::string str;
+ header.EncodeTo(&str);
+
+ Status s = dest_->Append(Slice(str));
+ if (s.ok()) {
+ block_offset_ += str.size();
+ s = dest_->Flush();
+ }
+ last_elem_type_ = kEtFileHdr;
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_WRITTEN,
+ BlobLogHeader::kSize);
+ return s;
+}
+
+Status Writer::AppendFooter(BlobLogFooter& footer) {
+ assert(block_offset_ != 0);
+ assert(last_elem_type_ == kEtFileHdr || last_elem_type_ == kEtRecord);
+
+ std::string str;
+ footer.EncodeTo(&str);
+
+ Status s = dest_->Append(Slice(str));
+ if (s.ok()) {
+ block_offset_ += str.size();
+ s = dest_->Close();
+ dest_.reset();
+ }
+
+ last_elem_type_ = kEtFileFooter;
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_WRITTEN,
+ BlobLogFooter::kSize);
+ return s;
+}
+
+Status Writer::AddRecord(const Slice& key, const Slice& val,
+ uint64_t expiration, uint64_t* key_offset,
+ uint64_t* blob_offset) {
+ assert(block_offset_ != 0);
+ assert(last_elem_type_ == kEtFileHdr || last_elem_type_ == kEtRecord);
+
+ std::string buf;
+ ConstructBlobHeader(&buf, key, val, expiration);
+
+ Status s = EmitPhysicalRecord(buf, key, val, key_offset, blob_offset);
+ return s;
+}
+
+Status Writer::AddRecord(const Slice& key, const Slice& val,
+ uint64_t* key_offset, uint64_t* blob_offset) {
+ assert(block_offset_ != 0);
+ assert(last_elem_type_ == kEtFileHdr || last_elem_type_ == kEtRecord);
+
+ std::string buf;
+ ConstructBlobHeader(&buf, key, val, 0);
+
+ Status s = EmitPhysicalRecord(buf, key, val, key_offset, blob_offset);
+ return s;
+}
+
+void Writer::ConstructBlobHeader(std::string* buf, const Slice& key,
+ const Slice& val, uint64_t expiration) {
+ BlobLogRecord record;
+ record.key = key;
+ record.value = val;
+ record.expiration = expiration;
+ record.EncodeHeaderTo(buf);
+}
+
+Status Writer::EmitPhysicalRecord(const std::string& headerbuf,
+ const Slice& key, const Slice& val,
+ uint64_t* key_offset, uint64_t* blob_offset) {
+ StopWatch write_sw(env_, statistics_, BLOB_DB_BLOB_FILE_WRITE_MICROS);
+ Status s = dest_->Append(Slice(headerbuf));
+ if (s.ok()) {
+ s = dest_->Append(key);
+ }
+ if (s.ok()) {
+ s = dest_->Append(val);
+ }
+ if (s.ok()) {
+ s = dest_->Flush();
+ }
+
+ *key_offset = block_offset_ + BlobLogRecord::kHeaderSize;
+ *blob_offset = *key_offset + key.size();
+ block_offset_ = *blob_offset + val.size();
+ last_elem_type_ = kEtRecord;
+ RecordTick(statistics_, BLOB_DB_BLOB_FILE_BYTES_WRITTEN,
+ BlobLogRecord::kHeaderSize + key.size() + val.size());
+ return s;
+}
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/blob_db/blob_log_writer.h b/src/rocksdb/utilities/blob_db/blob_log_writer.h
new file mode 100644
index 000000000..29dbd00f1
--- /dev/null
+++ b/src/rocksdb/utilities/blob_db/blob_log_writer.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "rocksdb/env.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/statistics.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "utilities/blob_db/blob_log_format.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WritableFileWriter;
+
+namespace blob_db {
+
+/**
+ * Writer is the blob log stream writer. It provides an append-only
+ * abstraction for writing blob data.
+ *
+ * Look at blob_log_format.h for the details of the record formats.
+ */
+
+class Writer {
+ public:
+ // Create a writer that will append data to "*dest".
+ // "*dest" must be initially empty.
+ // "*dest" must remain live while this Writer is in use.
+ Writer(std::unique_ptr<WritableFileWriter>&& dest, Env* env,
+ Statistics* statistics, uint64_t log_number, uint64_t bpsync,
+ bool use_fsync, uint64_t boffset = 0);
+ // No copying allowed
+ Writer(const Writer&) = delete;
+ Writer& operator=(const Writer&) = delete;
+
+ ~Writer() = default;
+
+ static void ConstructBlobHeader(std::string* buf, const Slice& key,
+ const Slice& val, uint64_t expiration);
+
+ Status AddRecord(const Slice& key, const Slice& val, uint64_t* key_offset,
+ uint64_t* blob_offset);
+
+ Status AddRecord(const Slice& key, const Slice& val, uint64_t expiration,
+ uint64_t* key_offset, uint64_t* blob_offset);
+
+ Status EmitPhysicalRecord(const std::string& headerbuf, const Slice& key,
+ const Slice& val, uint64_t* key_offset,
+ uint64_t* blob_offset);
+
+ Status AppendFooter(BlobLogFooter& footer);
+
+ Status WriteHeader(BlobLogHeader& header);
+
+ WritableFileWriter* file() { return dest_.get(); }
+
+ const WritableFileWriter* file() const { return dest_.get(); }
+
+ uint64_t get_log_number() const { return log_number_; }
+
+ bool ShouldSync() const { return block_offset_ > next_sync_offset_; }
+
+ Status Sync();
+
+ void ResetSyncPointer() { next_sync_offset_ += bytes_per_sync_; }
+
+ private:
+ std::unique_ptr<WritableFileWriter> dest_;
+ Env* env_;
+ Statistics* statistics_;
+ uint64_t log_number_;
+ uint64_t block_offset_; // Current offset in block
+ uint64_t bytes_per_sync_;
+ uint64_t next_sync_offset_;
+ bool use_fsync_;
+
+ public:
+ enum ElemType { kEtNone, kEtFileHdr, kEtRecord, kEtFileFooter };
+ ElemType last_elem_type_;
+};
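+
+// A hedged usage sketch (example only; the helper name is not part of the
+// upstream writer): the expected call sequence on an already-constructed
+// Writer is WriteHeader(), any number of AddRecord() calls, and finally
+// AppendFooter(), which writes the footer and closes the file.
+inline Status ExampleWriteOneRecord(Writer& writer, const Slice& key,
+                                    const Slice& value) {
+  BlobLogHeader header;  // Default header: column family 0, no compression.
+  Status s = writer.WriteHeader(header);
+  if (s.ok()) {
+    uint64_t key_offset = 0;
+    uint64_t blob_offset = 0;
+    s = writer.AddRecord(key, value, &key_offset, &blob_offset);
+  }
+  if (s.ok()) {
+    BlobLogFooter footer;
+    footer.blob_count = 1;
+    s = writer.AppendFooter(footer);
+  }
+  return s;
+}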
+
+} // namespace blob_db
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc
new file mode 100644
index 000000000..f0a00e4d1
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/cassandra/cassandra_compaction_filter.h"
+#include <string>
+#include "rocksdb/slice.h"
+#include "utilities/cassandra/format.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+const char* CassandraCompactionFilter::Name() const {
+ return "CassandraCompactionFilter";
+}
+
+CompactionFilter::Decision CassandraCompactionFilter::FilterV2(
+ int /*level*/, const Slice& /*key*/, ValueType value_type,
+ const Slice& existing_value, std::string* new_value,
+ std::string* /*skip_until*/) const {
+ bool value_changed = false;
+ RowValue row_value = RowValue::Deserialize(
+ existing_value.data(), existing_value.size());
+ RowValue compacted =
+ purge_ttl_on_expiration_
+ ? row_value.RemoveExpiredColumns(&value_changed)
+ : row_value.ConvertExpiredColumnsToTombstones(&value_changed);
+
+ if (value_type == ValueType::kValue) {
+ compacted = compacted.RemoveTombstones(gc_grace_period_in_seconds_);
+ }
+
+ if(compacted.Empty()) {
+ return Decision::kRemove;
+ }
+
+ if (value_changed) {
+ compacted.Serialize(new_value);
+ return Decision::kChangeValue;
+ }
+
+ return Decision::kKeep;
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h
new file mode 100644
index 000000000..ac2588106
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_compaction_filter.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <string>
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+/**
+ * Compaction filter for removing expired Cassandra data with TTL.
+ * If the option `purge_ttl_on_expiration` is set to true, expired data is
+ * purged directly. Otherwise expired data is first converted to tombstones
+ * and then eventually removed after the gc grace period.
+ * `purge_ttl_on_expiration` should only be enabled when all writes use the
+ * same TTL setting, otherwise it could bring old data back.
+ *
+ * The compaction filter is also in charge of removing tombstones that have
+ * been promoted to the kValue type after a series of merges in compaction.
+ */
+class CassandraCompactionFilter : public CompactionFilter {
+public:
+ explicit CassandraCompactionFilter(bool purge_ttl_on_expiration,
+ int32_t gc_grace_period_in_seconds)
+ : purge_ttl_on_expiration_(purge_ttl_on_expiration),
+ gc_grace_period_in_seconds_(gc_grace_period_in_seconds) {}
+
+ const char* Name() const override;
+ virtual Decision FilterV2(int level, const Slice& key, ValueType value_type,
+ const Slice& existing_value, std::string* new_value,
+ std::string* skip_until) const override;
+
+private:
+ bool purge_ttl_on_expiration_;
+ int32_t gc_grace_period_in_seconds_;
+};
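+
+// A minimal wiring sketch (illustrative comment; it assumes rocksdb/options.h
+// and utilities/cassandra/merge_operator.h are available in the including
+// translation unit, and the gc grace period value is arbitrary):
+//
+//   ColumnFamilyOptions cf_opts;
+//   cf_opts.merge_operator.reset(
+//       new CassandraValueMergeOperator(/*gc_grace_period_in_seconds=*/864000));
+//   static CassandraCompactionFilter filter(/*purge_ttl_on_expiration=*/false,
+//                                           /*gc_grace_period_in_seconds=*/864000);
+//   cf_opts.compaction_filter = &filter;  // filter must outlive the DB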
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/cassandra_format_test.cc b/src/rocksdb/utilities/cassandra/cassandra_format_test.cc
new file mode 100644
index 000000000..a8e6ad3f1
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_format_test.cc
@@ -0,0 +1,367 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <cstring>
+#include <memory>
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/serialize.h"
+#include "utilities/cassandra/test_utils.h"
+
+using namespace ROCKSDB_NAMESPACE::cassandra;
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+TEST(ColumnTest, Column) {
+ char data[4] = {'d', 'a', 't', 'a'};
+ int8_t mask = 0;
+ int8_t index = 1;
+ int64_t timestamp = 1494022807044;
+ Column c = Column(mask, index, timestamp, sizeof(data), data);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), timestamp);
+ EXPECT_EQ(c.Size(), 14 + sizeof(data));
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(data, dest.c_str() + offset, sizeof(data)) == 0);
+
+ // Verify the deserialization.
+ std::string saved_dest = dest;
+ std::shared_ptr<Column> c1 = Column::Deserialize(saved_dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), timestamp);
+ EXPECT_EQ(c1->Size(), 14 + sizeof(data));
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) == 0);
+
+ // Verify the ColumnBase::Deserialization.
+ saved_dest = dest;
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(saved_dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2, c.Size())
+ == 0);
+}
+
+TEST(ExpiringColumnTest, ExpiringColumn) {
+ char data[4] = {'d', 'a', 't', 'a'};
+ int8_t mask = ColumnTypeMask::EXPIRATION_MASK;
+ int8_t index = 3;
+ int64_t timestamp = 1494022807044;
+ int32_t ttl = 3600;
+ ExpiringColumn c = ExpiringColumn(mask, index, timestamp,
+ sizeof(data), data, ttl);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), timestamp);
+ EXPECT_EQ(c.Size(), 18 + sizeof(data));
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(data, dest.c_str() + offset, sizeof(data)) == 0);
+ offset += sizeof(data);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), ttl);
+
+ // Verify the deserialization.
+ std::string saved_dest = dest;
+ std::shared_ptr<ExpiringColumn> c1 =
+ ExpiringColumn::Deserialize(saved_dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), timestamp);
+ EXPECT_EQ(c1->Size(), 18 + sizeof(data));
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) == 0);
+
+ // Verify the ColumnBase::Deserialization.
+ saved_dest = dest;
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(saved_dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2, c.Size())
+ == 0);
+}
+
+TEST(TombstoneTest, TombstoneCollectable) {
+ int32_t now = (int32_t)time(nullptr);
+ int32_t gc_grace_seconds = 16440;
+ int32_t time_delta_seconds = 10;
+ EXPECT_TRUE(Tombstone(ColumnTypeMask::DELETION_MASK, 0,
+ now - gc_grace_seconds - time_delta_seconds,
+ ToMicroSeconds(now - gc_grace_seconds - time_delta_seconds))
+ .Collectable(gc_grace_seconds));
+ EXPECT_FALSE(Tombstone(ColumnTypeMask::DELETION_MASK, 0,
+ now - gc_grace_seconds + time_delta_seconds,
+ ToMicroSeconds(now - gc_grace_seconds + time_delta_seconds))
+ .Collectable(gc_grace_seconds));
+}
+
+TEST(TombstoneTest, Tombstone) {
+ int8_t mask = ColumnTypeMask::DELETION_MASK;
+ int8_t index = 2;
+ int32_t local_deletion_time = 1494022807;
+ int64_t marked_for_delete_at = 1494022807044;
+ Tombstone c = Tombstone(mask, index, local_deletion_time,
+ marked_for_delete_at);
+
+ EXPECT_EQ(c.Index(), index);
+ EXPECT_EQ(c.Timestamp(), marked_for_delete_at);
+ EXPECT_EQ(c.Size(), 14);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(c.Size() * 2);
+ c.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), c.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), mask);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), marked_for_delete_at);
+
+ // Verify the deserialization.
+ std::shared_ptr<Tombstone> c1 = Tombstone::Deserialize(dest.c_str(), 0);
+ EXPECT_EQ(c1->Index(), index);
+ EXPECT_EQ(c1->Timestamp(), marked_for_delete_at);
+ EXPECT_EQ(c1->Size(), 14);
+
+ c1->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str(), dest.c_str() + c.Size(), c.Size()) == 0);
+
+ // Verify the ColumnBase::Deserialization.
+ std::shared_ptr<ColumnBase> c2 =
+ ColumnBase::Deserialize(dest.c_str(), c.Size());
+ c2->Serialize(&dest);
+ EXPECT_EQ(dest.size(), 3 * c.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str() + c.Size(), dest.c_str() + c.Size() * 2, c.Size())
+ == 0);
+}
+
+TEST(RowValueTest, RowTombstone) {
+ int32_t local_deletion_time = 1494022807;
+ int64_t marked_for_delete_at = 1494022807044;
+ RowValue r = RowValue(local_deletion_time, marked_for_delete_at);
+
+ EXPECT_EQ(r.Size(), 12);
+ EXPECT_EQ(r.IsTombstone(), true);
+ EXPECT_EQ(r.LastModifiedTime(), marked_for_delete_at);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(r.Size() * 2);
+ r.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), r.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), marked_for_delete_at);
+
+ // Verify the deserialization.
+ RowValue r1 = RowValue::Deserialize(dest.c_str(), r.Size());
+ EXPECT_EQ(r1.Size(), 12);
+ EXPECT_EQ(r1.IsTombstone(), true);
+ EXPECT_EQ(r1.LastModifiedTime(), marked_for_delete_at);
+
+ r1.Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * r.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str(), dest.c_str() + r.Size(), r.Size()) == 0);
+}
+
+TEST(RowValueTest, RowWithColumns) {
+ std::vector<std::shared_ptr<ColumnBase>> columns;
+ int64_t last_modified_time = 1494022807048;
+ std::size_t columns_data_size = 0;
+
+ char e_data[5] = {'e', 'd', 'a', 't', 'a'};
+ int8_t e_index = 0;
+ int64_t e_timestamp = 1494022807044;
+ int32_t e_ttl = 3600;
+ columns.push_back(std::shared_ptr<ExpiringColumn>(
+ new ExpiringColumn(ColumnTypeMask::EXPIRATION_MASK, e_index,
+ e_timestamp, sizeof(e_data), e_data, e_ttl)));
+ columns_data_size += columns[0]->Size();
+
+ char c_data[4] = {'d', 'a', 't', 'a'};
+ int8_t c_index = 1;
+ int64_t c_timestamp = 1494022807048;
+ columns.push_back(std::shared_ptr<Column>(
+ new Column(0, c_index, c_timestamp, sizeof(c_data), c_data)));
+ columns_data_size += columns[1]->Size();
+
+ int8_t t_index = 2;
+ int32_t t_local_deletion_time = 1494022801;
+ int64_t t_marked_for_delete_at = 1494022807043;
+ columns.push_back(std::shared_ptr<Tombstone>(
+ new Tombstone(ColumnTypeMask::DELETION_MASK,
+ t_index, t_local_deletion_time, t_marked_for_delete_at)));
+ columns_data_size += columns[2]->Size();
+
+ RowValue r = RowValue(std::move(columns), last_modified_time);
+
+ EXPECT_EQ(r.Size(), columns_data_size + 12);
+ EXPECT_EQ(r.IsTombstone(), false);
+ EXPECT_EQ(r.LastModifiedTime(), last_modified_time);
+
+ // Verify the serialization.
+ std::string dest;
+ dest.reserve(r.Size() * 2);
+ r.Serialize(&dest);
+
+ EXPECT_EQ(dest.size(), r.Size());
+ std::size_t offset = 0;
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset),
+ std::numeric_limits<int32_t>::max());
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset),
+ std::numeric_limits<int64_t>::min());
+ offset += sizeof(int64_t);
+
+ // Column0: ExpiringColumn
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset),
+ ColumnTypeMask::EXPIRATION_MASK);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), e_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), e_timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(e_data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(e_data, dest.c_str() + offset, sizeof(e_data)) == 0);
+ offset += sizeof(e_data);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), e_ttl);
+ offset += sizeof(int32_t);
+
+ // Column1: Column
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), 0);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), c_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), c_timestamp);
+ offset += sizeof(int64_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), sizeof(c_data));
+ offset += sizeof(int32_t);
+ EXPECT_TRUE(std::memcmp(c_data, dest.c_str() + offset, sizeof(c_data)) == 0);
+ offset += sizeof(c_data);
+
+ // Column2: Tombstone
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset),
+ ColumnTypeMask::DELETION_MASK);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int8_t>(dest.c_str(), offset), t_index);
+ offset += sizeof(int8_t);
+ EXPECT_EQ(Deserialize<int32_t>(dest.c_str(), offset), t_local_deletion_time);
+ offset += sizeof(int32_t);
+ EXPECT_EQ(Deserialize<int64_t>(dest.c_str(), offset), t_marked_for_delete_at);
+
+ // Verify the deserialization.
+ RowValue r1 = RowValue::Deserialize(dest.c_str(), r.Size());
+ EXPECT_EQ(r1.Size(), columns_data_size + 12);
+ EXPECT_EQ(r1.IsTombstone(), false);
+ EXPECT_EQ(r1.LastModifiedTime(), last_modified_time);
+
+ r1.Serialize(&dest);
+ EXPECT_EQ(dest.size(), 2 * r.Size());
+ EXPECT_TRUE(
+ std::memcmp(dest.c_str(), dest.c_str() + r.Size(), r.Size()) == 0);
+}
+
+TEST(RowValueTest, PurgeTtlShouldRemoveAllColumnsExpired) {
+ int64_t now = time(nullptr);
+
+ auto row_value = CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now)),
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now - kTtl - 10)), //expired
+ CreateTestColumnSpec(kExpiringColumn, 2, ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))
+ });
+
+ bool changed = false;
+ auto purged = row_value.RemoveExpiredColumns(&changed);
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(purged.columns_.size(), 3);
+ VerifyRowValueColumns(purged.columns_, 0, kColumn, 0, ToMicroSeconds(now));
+ VerifyRowValueColumns(purged.columns_, 1, kExpiringColumn, 2, ToMicroSeconds(now));
+ VerifyRowValueColumns(purged.columns_, 2, kTombstone, 3, ToMicroSeconds(now));
+
+ purged.RemoveExpiredColumns(&changed);
+ EXPECT_FALSE(changed);
+}
+
+TEST(RowValueTest, ExpireTtlShouldConvertExpiredColumnsToTombstones) {
+ int64_t now = time(nullptr);
+
+ auto row_value = CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now)),
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now - kTtl - 10)), //expired
+ CreateTestColumnSpec(kExpiringColumn, 2, ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))
+ });
+
+ bool changed = false;
+ auto compacted = row_value.ConvertExpiredColumnsToTombstones(&changed);
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(compacted.columns_.size(), 4);
+ VerifyRowValueColumns(compacted.columns_, 0, kColumn, 0, ToMicroSeconds(now));
+ VerifyRowValueColumns(compacted.columns_, 1, kTombstone, 1, ToMicroSeconds(now - 10));
+ VerifyRowValueColumns(compacted.columns_, 2, kExpiringColumn, 2, ToMicroSeconds(now));
+ VerifyRowValueColumns(compacted.columns_, 3, kTombstone, 3, ToMicroSeconds(now));
+
+ compacted.ConvertExpiredColumnsToTombstones(&changed);
+ EXPECT_FALSE(changed);
+}
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc b/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc
new file mode 100644
index 000000000..501988423
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_functional_test.cc
@@ -0,0 +1,311 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <iostream>
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/cassandra/cassandra_compaction_filter.h"
+#include "utilities/cassandra/merge_operator.h"
+#include "utilities/cassandra/test_utils.h"
+#include "utilities/merge_operators.h"
+
+using namespace ROCKSDB_NAMESPACE;
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+// Path to the database on file system
+const std::string kDbName = test::PerThreadDBPath("cassandra_functional_test");
+
+class CassandraStore {
+ public:
+ explicit CassandraStore(std::shared_ptr<DB> db)
+ : db_(db), write_option_(), get_option_() {
+ assert(db);
+ }
+
+ bool Append(const std::string& key, const RowValue& val){
+ std::string result;
+ val.Serialize(&result);
+ Slice valSlice(result.data(), result.size());
+ auto s = db_->Merge(write_option_, key, valSlice);
+
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ bool Put(const std::string& key, const RowValue& val) {
+ std::string result;
+ val.Serialize(&result);
+ Slice valSlice(result.data(), result.size());
+ auto s = db_->Put(write_option_, key, valSlice);
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ void Flush() {
+ dbfull()->TEST_FlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ }
+
+ void Compact() {
+ dbfull()->TEST_CompactRange(
+ 0, nullptr, nullptr, db_->DefaultColumnFamily());
+ }
+
+ std::tuple<bool, RowValue> Get(const std::string& key){
+ std::string result;
+ auto s = db_->Get(get_option_, key, &result);
+
+ if (s.ok()) {
+ return std::make_tuple(true,
+ RowValue::Deserialize(result.data(),
+ result.size()));
+ }
+
+ if (!s.IsNotFound()) {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ }
+
+ return std::make_tuple(false, RowValue(0, 0));
+ }
+
+ private:
+ std::shared_ptr<DB> db_;
+ WriteOptions write_option_;
+ ReadOptions get_option_;
+
+ DBImpl* dbfull() { return reinterpret_cast<DBImpl*>(db_.get()); }
+};
+
+class TestCompactionFilterFactory : public CompactionFilterFactory {
+public:
+ explicit TestCompactionFilterFactory(bool purge_ttl_on_expiration,
+ int32_t gc_grace_period_in_seconds)
+ : purge_ttl_on_expiration_(purge_ttl_on_expiration),
+ gc_grace_period_in_seconds_(gc_grace_period_in_seconds) {}
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(new CassandraCompactionFilter(
+ purge_ttl_on_expiration_, gc_grace_period_in_seconds_));
+ }
+
+ const char* Name() const override { return "TestCompactionFilterFactory"; }
+
+private:
+ bool purge_ttl_on_expiration_;
+ int32_t gc_grace_period_in_seconds_;
+};
+
+
+// The class for unit-testing
+class CassandraFunctionalTest : public testing::Test {
+public:
+ CassandraFunctionalTest() {
+ DestroyDB(kDbName, Options()); // Start each test with a fresh DB
+ }
+
+ std::shared_ptr<DB> OpenDb() {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ options.merge_operator.reset(new CassandraValueMergeOperator(gc_grace_period_in_seconds_));
+ auto* cf_factory = new TestCompactionFilterFactory(
+ purge_ttl_on_expiration_, gc_grace_period_in_seconds_);
+ options.compaction_filter_factory.reset(cf_factory);
+ EXPECT_OK(DB::Open(options, kDbName, &db));
+ return std::shared_ptr<DB>(db);
+ }
+
+ bool purge_ttl_on_expiration_ = false;
+ int32_t gc_grace_period_in_seconds_ = 100;
+};
+
+// THE TEST CASES BEGIN HERE
+
+TEST_F(CassandraFunctionalTest, SimpleMergeTest) {
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now + 8)),
+ CreateTestColumnSpec(kExpiringColumn, 2, ToMicroSeconds(now + 5)),
+ }));
+ store.Append("k1",CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now + 2)),
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kTombstone, 2, ToMicroSeconds(now + 7)),
+ CreateTestColumnSpec(kExpiringColumn, 7, ToMicroSeconds(now + 17)),
+ }));
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now + 6)),
+ CreateTestColumnSpec(kTombstone, 1, ToMicroSeconds(now + 5)),
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now + 4)),
+ CreateTestColumnSpec(kTombstone, 11, ToMicroSeconds(now + 11)),
+ }));
+
+ auto ret = store.Get("k1");
+
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.columns_.size(), 5);
+ VerifyRowValueColumns(merged.columns_, 0, kExpiringColumn, 0, ToMicroSeconds(now + 6));
+ VerifyRowValueColumns(merged.columns_, 1, kColumn, 1, ToMicroSeconds(now + 8));
+ VerifyRowValueColumns(merged.columns_, 2, kTombstone, 2, ToMicroSeconds(now + 7));
+ VerifyRowValueColumns(merged.columns_, 3, kExpiringColumn, 7, ToMicroSeconds(now + 17));
+ VerifyRowValueColumns(merged.columns_, 4, kTombstone, 11, ToMicroSeconds(now + 11));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldConvertExpiredColumnsToTombstone) {
+ CassandraStore store(OpenDb());
+  int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 20)), //expired
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now - kTtl + 10)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))
+ }));
+
+ store.Flush();
+
+ store.Append("k1",CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 10)), //expired
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now))
+ }));
+
+ store.Flush();
+ store.Compact();
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.columns_.size(), 4);
+ VerifyRowValueColumns(merged.columns_, 0, kTombstone, 0, ToMicroSeconds(now - 10));
+ VerifyRowValueColumns(merged.columns_, 1, kExpiringColumn, 1, ToMicroSeconds(now - kTtl + 10));
+ VerifyRowValueColumns(merged.columns_, 2, kColumn, 2, ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.columns_, 3, kTombstone, 3, ToMicroSeconds(now));
+}
+
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldPurgeExpiredColumnsIfPurgeTtlIsOn) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 20)), //expired
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now)), // not expired
+ CreateTestColumnSpec(kTombstone, 3, ToMicroSeconds(now))
+ }));
+
+ store.Flush();
+
+ store.Append("k1",CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 10)), //expired
+ CreateTestColumnSpec(kColumn, 2, ToMicroSeconds(now))
+ }));
+
+ store.Flush();
+ store.Compact();
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& merged = std::get<1>(ret);
+ EXPECT_EQ(merged.columns_.size(), 3);
+ VerifyRowValueColumns(merged.columns_, 0, kExpiringColumn, 1, ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.columns_, 1, kColumn, 2, ToMicroSeconds(now));
+ VerifyRowValueColumns(merged.columns_, 2, kTombstone, 3, ToMicroSeconds(now));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldRemoveRowWhenAllColumnsExpiredIfPurgeTtlIsOn) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 20)),
+ CreateTestColumnSpec(kExpiringColumn, 1, ToMicroSeconds(now - kTtl - 20)),
+ }));
+
+ store.Flush();
+
+ store.Append("k1",CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, ToMicroSeconds(now - kTtl - 10)),
+ }));
+
+ store.Flush();
+ store.Compact();
+ ASSERT_FALSE(std::get<0>(store.Get("k1")));
+}
+
+TEST_F(CassandraFunctionalTest,
+ CompactionShouldRemoveTombstoneExceedingGCGracePeriod) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Append("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, ToMicroSeconds(now - gc_grace_period_in_seconds_ - 1)),
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now))
+ }));
+
+ store.Append("k2", CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, ToMicroSeconds(now))
+ }));
+
+ store.Flush();
+
+ store.Append("k1",CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 1, ToMicroSeconds(now)),
+ }));
+
+ store.Flush();
+ store.Compact();
+
+ auto ret = store.Get("k1");
+ ASSERT_TRUE(std::get<0>(ret));
+ RowValue& gced = std::get<1>(ret);
+ EXPECT_EQ(gced.columns_.size(), 1);
+ VerifyRowValueColumns(gced.columns_, 0, kColumn, 1, ToMicroSeconds(now));
+}
+
+TEST_F(CassandraFunctionalTest, CompactionShouldRemoveTombstoneFromPut) {
+ purge_ttl_on_expiration_ = true;
+ CassandraStore store(OpenDb());
+ int64_t now = time(nullptr);
+
+ store.Put("k1", CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, ToMicroSeconds(now - gc_grace_period_in_seconds_ - 1)),
+ }));
+
+ store.Flush();
+ store.Compact();
+ ASSERT_FALSE(std::get<0>(store.Get("k1")));
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc b/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc
new file mode 100644
index 000000000..9e9ff1494
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_row_merge_test.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/test_utils.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+TEST(RowValueMergeTest, Merge) {
+ std::vector<RowValue> row_values;
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kTombstone, 0, 5),
+ CreateTestColumnSpec(kColumn, 1, 8),
+ CreateTestColumnSpec(kExpiringColumn, 2, 5),
+ })
+ );
+
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, 2),
+ CreateTestColumnSpec(kExpiringColumn, 1, 5),
+ CreateTestColumnSpec(kTombstone, 2, 7),
+ CreateTestColumnSpec(kExpiringColumn, 7, 17),
+ })
+ );
+
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kExpiringColumn, 0, 6),
+ CreateTestColumnSpec(kTombstone, 1, 5),
+ CreateTestColumnSpec(kColumn, 2, 4),
+ CreateTestColumnSpec(kTombstone, 11, 11),
+ })
+ );
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ EXPECT_FALSE(merged.IsTombstone());
+ EXPECT_EQ(merged.columns_.size(), 5);
+ VerifyRowValueColumns(merged.columns_, 0, kExpiringColumn, 0, 6);
+ VerifyRowValueColumns(merged.columns_, 1, kColumn, 1, 8);
+ VerifyRowValueColumns(merged.columns_, 2, kTombstone, 2, 7);
+ VerifyRowValueColumns(merged.columns_, 3, kExpiringColumn, 7, 17);
+ VerifyRowValueColumns(merged.columns_, 4, kTombstone, 11, 11);
+}
+
+TEST(RowValueMergeTest, MergeWithRowTombstone) {
+ std::vector<RowValue> row_values;
+
+ // A row tombstone.
+ row_values.push_back(
+ CreateRowTombstone(11)
+ );
+
+  // This row's timestamp is smaller than the tombstone's.
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 0, 5),
+ CreateTestColumnSpec(kColumn, 1, 6),
+ })
+ );
+
+  // Some of this row's columns have timestamps smaller than the tombstone's,
+  // some larger.
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 2, 10),
+ CreateTestColumnSpec(kColumn, 3, 12),
+ })
+ );
+
+  // All of this row's column timestamps are larger than the tombstone's.
+ row_values.push_back(
+ CreateTestRowValue({
+ CreateTestColumnSpec(kColumn, 4, 13),
+ CreateTestColumnSpec(kColumn, 5, 14),
+ })
+ );
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ EXPECT_FALSE(merged.IsTombstone());
+ EXPECT_EQ(merged.columns_.size(), 3);
+ VerifyRowValueColumns(merged.columns_, 0, kColumn, 3, 12);
+ VerifyRowValueColumns(merged.columns_, 1, kColumn, 4, 13);
+ VerifyRowValueColumns(merged.columns_, 2, kColumn, 5, 14);
+
+ // If the tombstone's timestamp is the latest, then it returns a
+ // row tombstone.
+ row_values.push_back(
+ CreateRowTombstone(15)
+ );
+
+ row_values.push_back(
+ CreateRowTombstone(17)
+ );
+
+ merged = RowValue::Merge(std::move(row_values));
+ EXPECT_TRUE(merged.IsTombstone());
+ EXPECT_EQ(merged.LastModifiedTime(), 17);
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc b/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc
new file mode 100644
index 000000000..491540bfe
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/cassandra_serialize_test.cc
@@ -0,0 +1,188 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_util/testharness.h"
+#include "utilities/cassandra/serialize.h"
+
+using namespace ROCKSDB_NAMESPACE::cassandra;
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+TEST(SerializeTest, SerializeI64) {
+ std::string dest;
+ Serialize<int64_t>(0, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(1, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x01'}),
+ dest);
+
+
+ dest.clear();
+ Serialize<int64_t>(-1, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(9223372036854775807, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x7f', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff', '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int64_t>(-9223372036854775807, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x80', '\x00', '\x00', '\x00', '\x00', '\x00', '\x00', '\x01'}),
+ dest);
+}
+
+TEST(SerializeTest, DeserializeI64) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int64_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(-9223372036854775807, &dest);
+ EXPECT_EQ(-9223372036854775807, Deserialize<int64_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int64_t>(9223372036854775807, &dest);
+ EXPECT_EQ(9223372036854775807, Deserialize<int64_t>(dest.c_str(), offset));
+}
+
+TEST(SerializeTest, SerializeI32) {
+ std::string dest;
+ Serialize<int32_t>(0, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x00', '\x00', '\x00', '\x00'}),
+ dest);
+
+ dest.clear();
+ Serialize<int32_t>(1, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x00', '\x00', '\x00', '\x01'}),
+ dest);
+
+
+ dest.clear();
+ Serialize<int32_t>(-1, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\xff', '\xff', '\xff', '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int32_t>(2147483647, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x7f', '\xff', '\xff', '\xff'}),
+ dest);
+
+ dest.clear();
+ Serialize<int32_t>(-2147483648LL, &dest);
+ EXPECT_EQ(
+ std::string(
+ {'\x80', '\x00', '\x00', '\x00'}),
+ dest);
+}
+
+TEST(SerializeTest, DeserializeI32) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int32_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(2147483647, &dest);
+ EXPECT_EQ(2147483647, Deserialize<int32_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int32_t>(-2147483648LL, &dest);
+ EXPECT_EQ(-2147483648LL, Deserialize<int32_t>(dest.c_str(), offset));
+}
+
+TEST(SerializeTest, SerializeI8) {
+ std::string dest;
+ Serialize<int8_t>(0, &dest);
+ EXPECT_EQ(std::string({'\x00'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(1, &dest);
+ EXPECT_EQ(std::string({'\x01'}), dest);
+
+
+ dest.clear();
+ Serialize<int8_t>(-1, &dest);
+ EXPECT_EQ(std::string({'\xff'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(127, &dest);
+ EXPECT_EQ(std::string({'\x7f'}), dest);
+
+ dest.clear();
+ Serialize<int8_t>(-128, &dest);
+ EXPECT_EQ(std::string({'\x80'}), dest);
+}
+
+TEST(SerializeTest, DeserializeI8) {
+ std::string dest;
+ std::size_t offset = dest.size();
+ Serialize<int8_t>(0, &dest);
+ EXPECT_EQ(0, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(1, &dest);
+ EXPECT_EQ(1, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(-1, &dest);
+ EXPECT_EQ(-1, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(127, &dest);
+ EXPECT_EQ(127, Deserialize<int8_t>(dest.c_str(), offset));
+
+ offset = dest.size();
+ Serialize<int8_t>(-128, &dest);
+ EXPECT_EQ(-128, Deserialize<int8_t>(dest.c_str(), offset));
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/cassandra/format.cc b/src/rocksdb/utilities/cassandra/format.cc
new file mode 100644
index 000000000..a767f41e7
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/format.cc
@@ -0,0 +1,390 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "format.h"
+
+#include <algorithm>
+#include <map>
+#include <memory>
+
+#include "utilities/cassandra/serialize.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+namespace {
+const int32_t kDefaultLocalDeletionTime =
+ std::numeric_limits<int32_t>::max();
+const int64_t kDefaultMarkedForDeleteAt =
+ std::numeric_limits<int64_t>::min();
+}
+
+ColumnBase::ColumnBase(int8_t mask, int8_t index)
+ : mask_(mask), index_(index) {}
+
+std::size_t ColumnBase::Size() const {
+ return sizeof(mask_) + sizeof(index_);
+}
+
+int8_t ColumnBase::Mask() const {
+ return mask_;
+}
+
+int8_t ColumnBase::Index() const {
+ return index_;
+}
+
+void ColumnBase::Serialize(std::string* dest) const {
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int8_t>(mask_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int8_t>(index_, dest);
+}
+
+std::shared_ptr<ColumnBase> ColumnBase::Deserialize(const char* src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ if ((mask & ColumnTypeMask::DELETION_MASK) != 0) {
+ return Tombstone::Deserialize(src, offset);
+ } else if ((mask & ColumnTypeMask::EXPIRATION_MASK) != 0) {
+ return ExpiringColumn::Deserialize(src, offset);
+ } else {
+ return Column::Deserialize(src, offset);
+ }
+}
+
+Column::Column(
+ int8_t mask,
+ int8_t index,
+ int64_t timestamp,
+ int32_t value_size,
+ const char* value
+) : ColumnBase(mask, index), timestamp_(timestamp),
+ value_size_(value_size), value_(value) {}
+
+int64_t Column::Timestamp() const {
+ return timestamp_;
+}
+
+std::size_t Column::Size() const {
+ return ColumnBase::Size() + sizeof(timestamp_) + sizeof(value_size_)
+ + value_size_;
+}
+
+void Column::Serialize(std::string* dest) const {
+ ColumnBase::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(timestamp_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(value_size_, dest);
+ dest->append(value_, value_size_);
+}
+
+std::shared_ptr<Column> Column::Deserialize(const char *src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int64_t timestamp =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(timestamp);
+ int32_t value_size =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(value_size);
+ return std::make_shared<Column>(
+ mask, index, timestamp, value_size, src + offset);
+}
+
+ExpiringColumn::ExpiringColumn(
+ int8_t mask,
+ int8_t index,
+ int64_t timestamp,
+ int32_t value_size,
+ const char* value,
+ int32_t ttl
+) : Column(mask, index, timestamp, value_size, value),
+ ttl_(ttl) {}
+
+std::size_t ExpiringColumn::Size() const {
+ return Column::Size() + sizeof(ttl_);
+}
+
+void ExpiringColumn::Serialize(std::string* dest) const {
+ Column::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(ttl_, dest);
+}
+
+std::chrono::time_point<std::chrono::system_clock> ExpiringColumn::TimePoint() const {
+ return std::chrono::time_point<std::chrono::system_clock>(std::chrono::microseconds(Timestamp()));
+}
+
+std::chrono::seconds ExpiringColumn::Ttl() const {
+ return std::chrono::seconds(ttl_);
+}
+
+bool ExpiringColumn::Expired() const {
+ return TimePoint() + Ttl() < std::chrono::system_clock::now();
+}
+
+std::shared_ptr<Tombstone> ExpiringColumn::ToTombstone() const {
+ auto expired_at = (TimePoint() + Ttl()).time_since_epoch();
+ int32_t local_deletion_time = static_cast<int32_t>(
+ std::chrono::duration_cast<std::chrono::seconds>(expired_at).count());
+ int64_t marked_for_delete_at =
+ std::chrono::duration_cast<std::chrono::microseconds>(expired_at).count();
+ return std::make_shared<Tombstone>(
+ static_cast<int8_t>(ColumnTypeMask::DELETION_MASK),
+ Index(),
+ local_deletion_time,
+ marked_for_delete_at);
+}
+
+std::shared_ptr<ExpiringColumn> ExpiringColumn::Deserialize(
+ const char *src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int64_t timestamp =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(timestamp);
+ int32_t value_size =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(value_size);
+ const char* value = src + offset;
+ offset += value_size;
+ int32_t ttl = ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ return std::make_shared<ExpiringColumn>(
+ mask, index, timestamp, value_size, value, ttl);
+}
+
+Tombstone::Tombstone(
+ int8_t mask,
+ int8_t index,
+ int32_t local_deletion_time,
+ int64_t marked_for_delete_at
+) : ColumnBase(mask, index), local_deletion_time_(local_deletion_time),
+ marked_for_delete_at_(marked_for_delete_at) {}
+
+int64_t Tombstone::Timestamp() const {
+ return marked_for_delete_at_;
+}
+
+std::size_t Tombstone::Size() const {
+ return ColumnBase::Size() + sizeof(local_deletion_time_)
+ + sizeof(marked_for_delete_at_);
+}
+
+void Tombstone::Serialize(std::string* dest) const {
+ ColumnBase::Serialize(dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(local_deletion_time_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(marked_for_delete_at_, dest);
+}
+
+bool Tombstone::Collectable(int32_t gc_grace_period_in_seconds) const {
+ auto local_deleted_at = std::chrono::time_point<std::chrono::system_clock>(
+ std::chrono::seconds(local_deletion_time_));
+ auto gc_grace_period = std::chrono::seconds(gc_grace_period_in_seconds);
+ return local_deleted_at + gc_grace_period < std::chrono::system_clock::now();
+}
+
+std::shared_ptr<Tombstone> Tombstone::Deserialize(const char *src,
+ std::size_t offset) {
+ int8_t mask = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(mask);
+ int8_t index = ROCKSDB_NAMESPACE::cassandra::Deserialize<int8_t>(src, offset);
+ offset += sizeof(index);
+ int32_t local_deletion_time =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(int32_t);
+ int64_t marked_for_delete_at =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ return std::make_shared<Tombstone>(
+ mask, index, local_deletion_time, marked_for_delete_at);
+}
+
+RowValue::RowValue(int32_t local_deletion_time, int64_t marked_for_delete_at)
+ : local_deletion_time_(local_deletion_time),
+ marked_for_delete_at_(marked_for_delete_at), columns_(),
+ last_modified_time_(0) {}
+
+RowValue::RowValue(Columns columns,
+ int64_t last_modified_time)
+ : local_deletion_time_(kDefaultLocalDeletionTime),
+ marked_for_delete_at_(kDefaultMarkedForDeleteAt),
+ columns_(std::move(columns)), last_modified_time_(last_modified_time) {}
+
+std::size_t RowValue::Size() const {
+ std::size_t size = sizeof(local_deletion_time_)
+ + sizeof(marked_for_delete_at_);
+ for (const auto& column : columns_) {
+ size += column -> Size();
+ }
+ return size;
+}
+
+int64_t RowValue::LastModifiedTime() const {
+ if (IsTombstone()) {
+ return marked_for_delete_at_;
+ } else {
+ return last_modified_time_;
+ }
+}
+
+bool RowValue::IsTombstone() const {
+ return marked_for_delete_at_ > kDefaultMarkedForDeleteAt;
+}
+
+void RowValue::Serialize(std::string* dest) const {
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int32_t>(local_deletion_time_, dest);
+ ROCKSDB_NAMESPACE::cassandra::Serialize<int64_t>(marked_for_delete_at_, dest);
+ for (const auto& column : columns_) {
+ column -> Serialize(dest);
+ }
+}
+
+RowValue RowValue::RemoveExpiredColumns(bool* changed) const {
+ *changed = false;
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if(column->Mask() == ColumnTypeMask::EXPIRATION_MASK) {
+ std::shared_ptr<ExpiringColumn> expiring_column =
+ std::static_pointer_cast<ExpiringColumn>(column);
+
+ if(expiring_column->Expired()){
+ *changed = true;
+ continue;
+ }
+ }
+
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+RowValue RowValue::ConvertExpiredColumnsToTombstones(bool* changed) const {
+ *changed = false;
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if(column->Mask() == ColumnTypeMask::EXPIRATION_MASK) {
+ std::shared_ptr<ExpiringColumn> expiring_column =
+ std::static_pointer_cast<ExpiringColumn>(column);
+
+ if(expiring_column->Expired()) {
+ std::shared_ptr<Tombstone> tombstone = expiring_column->ToTombstone();
+ new_columns.push_back(tombstone);
+ *changed = true;
+ continue;
+ }
+ }
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+RowValue RowValue::RemoveTombstones(int32_t gc_grace_period) const {
+ Columns new_columns;
+ for (auto& column : columns_) {
+ if (column->Mask() == ColumnTypeMask::DELETION_MASK) {
+ std::shared_ptr<Tombstone> tombstone =
+ std::static_pointer_cast<Tombstone>(column);
+
+ if (tombstone->Collectable(gc_grace_period)) {
+ continue;
+ }
+ }
+
+ new_columns.push_back(column);
+ }
+ return RowValue(std::move(new_columns), last_modified_time_);
+}
+
+bool RowValue::Empty() const {
+ return columns_.empty();
+}
+
+RowValue RowValue::Deserialize(const char *src, std::size_t size) {
+ std::size_t offset = 0;
+ assert(size >= sizeof(local_deletion_time_) + sizeof(marked_for_delete_at_));
+ int32_t local_deletion_time =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int32_t>(src, offset);
+ offset += sizeof(int32_t);
+ int64_t marked_for_delete_at =
+ ROCKSDB_NAMESPACE::cassandra::Deserialize<int64_t>(src, offset);
+ offset += sizeof(int64_t);
+ if (offset == size) {
+ return RowValue(local_deletion_time, marked_for_delete_at);
+ }
+
+ assert(local_deletion_time == kDefaultLocalDeletionTime);
+ assert(marked_for_delete_at == kDefaultMarkedForDeleteAt);
+ Columns columns;
+ int64_t last_modified_time = 0;
+ while (offset < size) {
+ auto c = ColumnBase::Deserialize(src, offset);
+ offset += c -> Size();
+ assert(offset <= size);
+ last_modified_time = std::max(last_modified_time, c -> Timestamp());
+ columns.push_back(std::move(c));
+ }
+
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+// Merge multiple row values into one.
+// For columns with the same index across rows, we pick the one with the
+// latest timestamp. We also take row tombstones into consideration by
+// iterating over the rows in reverse timestamp order and stopping once we
+// hit the first row tombstone.
+RowValue RowValue::Merge(std::vector<RowValue>&& values) {
+ assert(values.size() > 0);
+ if (values.size() == 1) {
+ return std::move(values[0]);
+ }
+
+  // Sort the rows by last modified time (descending), so we can stop
+  // merging once we hit a row tombstone.
+ std::sort(values.begin(), values.end(),
+ [](const RowValue& r1, const RowValue& r2) {
+ return r1.LastModifiedTime() > r2.LastModifiedTime();
+ });
+
+ std::map<int8_t, std::shared_ptr<ColumnBase>> merged_columns;
+ int64_t tombstone_timestamp = 0;
+
+ for (auto& value : values) {
+ if (value.IsTombstone()) {
+ if (merged_columns.size() == 0) {
+ return std::move(value);
+ }
+ tombstone_timestamp = value.LastModifiedTime();
+ break;
+ }
+ for (auto& column : value.columns_) {
+ int8_t index = column->Index();
+ if (merged_columns.find(index) == merged_columns.end()) {
+ merged_columns[index] = column;
+ } else {
+ if (column->Timestamp() > merged_columns[index]->Timestamp()) {
+ merged_columns[index] = column;
+ }
+ }
+ }
+ }
+
+ int64_t last_modified_time = 0;
+ Columns columns;
+ for (auto& pair: merged_columns) {
+    // A row's last_modified_time may exceed the row tombstone_timestamp even
+    // though some of its columns carry timestamps earlier than the tombstone,
+    // so we need to filter those columns out.
+ if (pair.second->Timestamp() <= tombstone_timestamp) {
+ continue;
+ }
+ last_modified_time = std::max(last_modified_time, pair.second->Timestamp());
+ columns.push_back(std::move(pair.second));
+ }
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
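As a quick illustration of the merge semantics implemented above, here is a minimal sketch (not part of this change) that exercises RowValue::Merge through the test helpers this patch adds in utilities/cassandra/test_utils.h. Timestamps are plain integers, and only the public LastModifiedTime() accessor is checked because columns_ is private outside the friended tests.

    #include <cassert>
    #include <utility>
    #include <vector>

    #include "utilities/cassandra/format.h"
    #include "utilities/cassandra/test_utils.h"

    using namespace ROCKSDB_NAMESPACE::cassandra;

    void MergeSketch() {
      std::vector<RowValue> rows;
      // Two versions of the same row: column index 0 appears in both.
      rows.push_back(CreateTestRowValue({CreateTestColumnSpec(kColumn, 0, 5),
                                         CreateTestColumnSpec(kColumn, 1, 8)}));
      rows.push_back(CreateTestRowValue({CreateTestColumnSpec(kColumn, 0, 7)}));

      // For index 0 the column with timestamp 7 wins; index 1 is kept as-is.
      RowValue merged = RowValue::Merge(std::move(rows));
      assert(!merged.IsTombstone());
      assert(merged.LastModifiedTime() == 8);  // max timestamp of kept columns
    }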
diff --git a/src/rocksdb/utilities/cassandra/format.h b/src/rocksdb/utilities/cassandra/format.h
new file mode 100644
index 000000000..3f9b433c7
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/format.h
@@ -0,0 +1,197 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/**
+ * The encoding of Cassandra Row Value.
+ *
+ * A Cassandra row value is either a row tombstone or contains multiple
+ * columns; it has the following fields:
+ *
+ * struct row_value {
+ *   int32_t local_deletion_time;  // Time in seconds when the row is deleted,
+ *                                 // only used for Cassandra tombstone gc.
+ *   int64_t marked_for_delete_at; // Microsecond timestamp at which this row
+ *                                 // was marked as deleted.
+ *   struct column_base columns[]; // For a non-tombstone row, all columns
+ *                                 // are stored here.
+ * }
+ *
+ * If local_deletion_time and marked_for_delete_at are set, then this is
+ * a row tombstone; otherwise it contains multiple columns.
+ *
+ * There are three types of columns: normal column, expiring column and
+ * column tombstone, which have the following fields:
+ *
+ * // Identify the type of the column.
+ * enum mask {
+ * DELETION_MASK = 0x01,
+ * EXPIRATION_MASK = 0x02,
+ * };
+ *
+ * struct column {
+ * int8_t mask = 0;
+ * int8_t index;
+ * int64_t timestamp;
+ * int32_t value_length;
+ * char value[value_length];
+ * }
+ *
+ * struct expiring_column {
+ * int8_t mask = mask.EXPIRATION_MASK;
+ * int8_t index;
+ * int64_t timestamp;
+ * int32_t value_length;
+ * char value[value_length];
+ * int32_t ttl;
+ * }
+ *
+ * struct tombstone_column {
+ * int8_t mask = mask.DELETION_MASK;
+ * int8_t index;
+ * int32_t local_deletion_time; // Similar to row_value's field.
+ * int64_t marked_for_delete_at;
+ * }
+ */
+
+#pragma once
+#include <chrono>
+#include <memory>
+#include <vector>
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+// Identify the type of the column.
+enum ColumnTypeMask {
+ DELETION_MASK = 0x01,
+ EXPIRATION_MASK = 0x02,
+};
+
+
+class ColumnBase {
+public:
+ ColumnBase(int8_t mask, int8_t index);
+ virtual ~ColumnBase() = default;
+
+ virtual int64_t Timestamp() const = 0;
+ virtual int8_t Mask() const;
+ virtual int8_t Index() const;
+ virtual std::size_t Size() const;
+ virtual void Serialize(std::string* dest) const;
+ static std::shared_ptr<ColumnBase> Deserialize(const char* src,
+ std::size_t offset);
+
+private:
+ int8_t mask_;
+ int8_t index_;
+};
+
+class Column : public ColumnBase {
+public:
+ Column(int8_t mask, int8_t index, int64_t timestamp,
+ int32_t value_size, const char* value);
+
+ virtual int64_t Timestamp() const override;
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ static std::shared_ptr<Column> Deserialize(const char* src,
+ std::size_t offset);
+
+private:
+ int64_t timestamp_;
+ int32_t value_size_;
+ const char* value_;
+};
+
+class Tombstone : public ColumnBase {
+public:
+ Tombstone(int8_t mask, int8_t index,
+ int32_t local_deletion_time, int64_t marked_for_delete_at);
+
+ virtual int64_t Timestamp() const override;
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ bool Collectable(int32_t gc_grace_period) const;
+ static std::shared_ptr<Tombstone> Deserialize(const char* src,
+ std::size_t offset);
+
+private:
+ int32_t local_deletion_time_;
+ int64_t marked_for_delete_at_;
+};
+
+class ExpiringColumn : public Column {
+public:
+ ExpiringColumn(int8_t mask, int8_t index, int64_t timestamp,
+ int32_t value_size, const char* value, int32_t ttl);
+
+ virtual std::size_t Size() const override;
+ virtual void Serialize(std::string* dest) const override;
+ bool Expired() const;
+ std::shared_ptr<Tombstone> ToTombstone() const;
+
+ static std::shared_ptr<ExpiringColumn> Deserialize(const char* src,
+ std::size_t offset);
+
+private:
+ int32_t ttl_;
+ std::chrono::time_point<std::chrono::system_clock> TimePoint() const;
+ std::chrono::seconds Ttl() const;
+};
+
+typedef std::vector<std::shared_ptr<ColumnBase>> Columns;
+
+class RowValue {
+public:
+ // Create a Row Tombstone.
+ RowValue(int32_t local_deletion_time, int64_t marked_for_delete_at);
+ // Create a Row containing columns.
+ RowValue(Columns columns,
+ int64_t last_modified_time);
+ RowValue(const RowValue& /*that*/) = delete;
+ RowValue(RowValue&& /*that*/) noexcept = default;
+ RowValue& operator=(const RowValue& /*that*/) = delete;
+ RowValue& operator=(RowValue&& /*that*/) = default;
+
+  std::size_t Size() const;
+ bool IsTombstone() const;
+ // For Tombstone this returns the marked_for_delete_at_,
+ // otherwise it returns the max timestamp of containing columns.
+ int64_t LastModifiedTime() const;
+ void Serialize(std::string* dest) const;
+ RowValue RemoveExpiredColumns(bool* changed) const;
+ RowValue ConvertExpiredColumnsToTombstones(bool* changed) const;
+ RowValue RemoveTombstones(int32_t gc_grace_period) const;
+ bool Empty() const;
+
+ static RowValue Deserialize(const char* src, std::size_t size);
+ // Merge multiple rows according to their timestamp.
+ static RowValue Merge(std::vector<RowValue>&& values);
+
+private:
+ int32_t local_deletion_time_;
+ int64_t marked_for_delete_at_;
+ Columns columns_;
+ int64_t last_modified_time_;
+
+ FRIEND_TEST(RowValueTest, PurgeTtlShouldRemvoeAllColumnsExpired);
+ FRIEND_TEST(RowValueTest, ExpireTtlShouldConvertExpiredColumnsToTombstones);
+ FRIEND_TEST(RowValueMergeTest, Merge);
+ FRIEND_TEST(RowValueMergeTest, MergeWithRowTombstone);
+ FRIEND_TEST(CassandraFunctionalTest, SimpleMergeTest);
+ FRIEND_TEST(
+ CassandraFunctionalTest, CompactionShouldConvertExpiredColumnsToTombstone);
+ FRIEND_TEST(
+ CassandraFunctionalTest, CompactionShouldPurgeExpiredColumnsIfPurgeTtlIsOn);
+ FRIEND_TEST(
+      CassandraFunctionalTest, CompactionShouldRemoveRowWhenAllColumnsExpiredIfPurgeTtlIsOn);
+ FRIEND_TEST(CassandraFunctionalTest,
+ CompactionShouldRemoveTombstoneExceedingGCGracePeriod);
+};
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
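To make the documented layout concrete, the following is a small sketch (assumed to be compiled inside the RocksDB source tree so the header above resolves) that serializes a single normal column and checks that the byte count matches the fields described in the comment: 1-byte mask, 1-byte index, 8-byte timestamp, 4-byte value length, then the value bytes.

    #include <cassert>
    #include <string>

    #include "utilities/cassandra/format.h"

    using namespace ROCKSDB_NAMESPACE::cassandra;

    void LayoutSketch() {
      const char value[] = {'d', 'a', 't', 'a'};
      Column column(0 /* mask */, 0 /* index */, 1000 /* timestamp */,
                    sizeof(value), value);

      std::string buf;
      column.Serialize(&buf);

      // mask(1) + index(1) + timestamp(8) + value_length(4) + value(4) = 18.
      assert(column.Size() == 18);
      assert(buf.size() == column.Size());
    }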
diff --git a/src/rocksdb/utilities/cassandra/merge_operator.cc b/src/rocksdb/utilities/cassandra/merge_operator.cc
new file mode 100644
index 000000000..82fe5d661
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/merge_operator.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "merge_operator.h"
+
+#include <memory>
+#include <assert.h>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/merge_operator.h"
+#include "utilities/merge_operators.h"
+#include "utilities/cassandra/format.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+// Implementation for the merge operation (merges two Cassandra values)
+bool CassandraValueMergeOperator::FullMergeV2(
+ const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ // Clear the *new_value for writing.
+ merge_out->new_value.clear();
+ std::vector<RowValue> row_values;
+ if (merge_in.existing_value) {
+ row_values.push_back(
+ RowValue::Deserialize(merge_in.existing_value->data(),
+ merge_in.existing_value->size()));
+ }
+
+ for (auto& operand : merge_in.operand_list) {
+ row_values.push_back(RowValue::Deserialize(operand.data(), operand.size()));
+ }
+
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ merged = merged.RemoveTombstones(gc_grace_period_in_seconds_);
+ merge_out->new_value.reserve(merged.Size());
+ merged.Serialize(&(merge_out->new_value));
+
+ return true;
+}
+
+bool CassandraValueMergeOperator::PartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* /*logger*/) const {
+ // Clear the *new_value for writing.
+ assert(new_value);
+ new_value->clear();
+
+ std::vector<RowValue> row_values;
+ for (auto& operand : operand_list) {
+ row_values.push_back(RowValue::Deserialize(operand.data(), operand.size()));
+ }
+ RowValue merged = RowValue::Merge(std::move(row_values));
+ new_value->reserve(merged.Size());
+ merged.Serialize(new_value);
+ return true;
+}
+
+const char* CassandraValueMergeOperator::Name() const {
+ return "CassandraValueMergeOperator";
+}
+
+} // namespace cassandra
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/merge_operator.h b/src/rocksdb/utilities/cassandra/merge_operator.h
new file mode 100644
index 000000000..b5bf7c520
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/merge_operator.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+
+/**
+ * A MergeOperator for rocksdb that implements Cassandra row value merge.
+ */
+class CassandraValueMergeOperator : public MergeOperator {
+public:
+ explicit CassandraValueMergeOperator(int32_t gc_grace_period_in_seconds,
+ size_t operands_limit = 0)
+ : gc_grace_period_in_seconds_(gc_grace_period_in_seconds),
+ operands_limit_(operands_limit) {}
+
+ virtual bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ virtual bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* logger) const override;
+
+ virtual const char* Name() const override;
+
+ virtual bool AllowSingleOperand() const override { return true; }
+
+ virtual bool ShouldMerge(const std::vector<Slice>& operands) const override {
+ return operands_limit_ > 0 && operands.size() >= operands_limit_;
+ }
+
+private:
+ int32_t gc_grace_period_in_seconds_;
+ size_t operands_limit_;
+};
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
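For reference, a minimal sketch (not part of this change) of how the operator gets wired into DB options, mirroring the OpenDb() helper in cassandra_functional_test.cc above. The function name, the path argument and the 100-second gc grace period are placeholder values; operands_limit is left at its default of 0 so ShouldMerge() never requests an early merge.

    #include <memory>
    #include <string>

    #include "rocksdb/db.h"
    #include "rocksdb/options.h"
    #include "utilities/cassandra/merge_operator.h"

    using ROCKSDB_NAMESPACE::DB;
    using ROCKSDB_NAMESPACE::Options;
    using ROCKSDB_NAMESPACE::Status;
    using ROCKSDB_NAMESPACE::cassandra::CassandraValueMergeOperator;

    Status OpenCassandraStyleDb(const std::string& path, DB** db) {
      Options options;
      options.create_if_missing = true;
      // Rows written as serialized RowValue blobs are combined by FullMergeV2
      // and PartialMergeMulti; column tombstones past the 100 second gc grace
      // period are dropped during full merges.
      options.merge_operator = std::make_shared<CassandraValueMergeOperator>(100);
      return DB::Open(options, path, db);
    }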
diff --git a/src/rocksdb/utilities/cassandra/serialize.h b/src/rocksdb/utilities/cassandra/serialize.h
new file mode 100644
index 000000000..cd980ade0
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/serialize.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+/**
+ * Helper functions which serialize and deserialize integers
+ * into bytes in big endian.
+ */
+
+#pragma once
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+namespace {
+const int64_t kCharMask = 0xFFLL;
+const int32_t kBitsPerByte = 8;
+}
+
+template<typename T>
+void Serialize(T val, std::string* dest);
+
+template<typename T>
+T Deserialize(const char* src, std::size_t offset=0);
+
+// Specializations
+template<>
+inline void Serialize<int8_t>(int8_t t, std::string* dest) {
+ dest->append(1, static_cast<char>(t & kCharMask));
+}
+
+template<>
+inline void Serialize<int32_t>(int32_t t, std::string* dest) {
+ for (unsigned long i = 0; i < sizeof(int32_t); i++) {
+ dest->append(1, static_cast<char>(
+ (t >> (sizeof(int32_t) - 1 - i) * kBitsPerByte) & kCharMask));
+ }
+}
+
+template<>
+inline void Serialize<int64_t>(int64_t t, std::string* dest) {
+ for (unsigned long i = 0; i < sizeof(int64_t); i++) {
+ dest->append(
+ 1, static_cast<char>(
+ (t >> (sizeof(int64_t) - 1 - i) * kBitsPerByte) & kCharMask));
+ }
+}
+
+template<>
+inline int8_t Deserialize<int8_t>(const char* src, std::size_t offset) {
+ return static_cast<int8_t>(src[offset]);
+}
+
+template<>
+inline int32_t Deserialize<int32_t>(const char* src, std::size_t offset) {
+ int32_t result = 0;
+ for (unsigned long i = 0; i < sizeof(int32_t); i++) {
+ result |= static_cast<int32_t>(static_cast<unsigned char>(src[offset + i]))
+ << ((sizeof(int32_t) - 1 - i) * kBitsPerByte);
+ }
+ return result;
+}
+
+template<>
+inline int64_t Deserialize<int64_t>(const char* src, std::size_t offset) {
+ int64_t result = 0;
+ for (unsigned long i = 0; i < sizeof(int64_t); i++) {
+ result |= static_cast<int64_t>(static_cast<unsigned char>(src[offset + i]))
+ << ((sizeof(int64_t) - 1 - i) * kBitsPerByte);
+ }
+ return result;
+}
+
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
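A short sketch (again assuming it is built within the RocksDB tree) showing the big-endian byte order these helpers produce, and that Deserialize is the inverse of Serialize:

    #include <cassert>
    #include <cstdint>
    #include <string>

    #include "utilities/cassandra/serialize.h"

    using namespace ROCKSDB_NAMESPACE::cassandra;

    void EndiannessSketch() {
      std::string buf;
      Serialize<int32_t>(0x01020304, &buf);

      // Most significant byte is written first.
      assert(buf == std::string({'\x01', '\x02', '\x03', '\x04'}));
      assert(Deserialize<int32_t>(buf.data()) == 0x01020304);
    }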
diff --git a/src/rocksdb/utilities/cassandra/test_utils.cc b/src/rocksdb/utilities/cassandra/test_utils.cc
new file mode 100644
index 000000000..47919bf62
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/test_utils.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_utils.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+const char kData[] = {'d', 'a', 't', 'a'};
+const char kExpiringData[] = {'e', 'd', 'a', 't', 'a'};
+const int32_t kTtl = 86400;
+const int8_t kColumn = 0;
+const int8_t kTombstone = 1;
+const int8_t kExpiringColumn = 2;
+
+std::shared_ptr<ColumnBase> CreateTestColumn(int8_t mask,
+ int8_t index,
+ int64_t timestamp) {
+ if ((mask & ColumnTypeMask::DELETION_MASK) != 0) {
+ return std::shared_ptr<Tombstone>(
+ new Tombstone(mask, index, ToSeconds(timestamp), timestamp));
+ } else if ((mask & ColumnTypeMask::EXPIRATION_MASK) != 0) {
+ return std::shared_ptr<ExpiringColumn>(new ExpiringColumn(
+ mask, index, timestamp, sizeof(kExpiringData), kExpiringData, kTtl));
+ } else {
+ return std::shared_ptr<Column>(
+ new Column(mask, index, timestamp, sizeof(kData), kData));
+ }
+}
+
+std::tuple<int8_t, int8_t, int64_t> CreateTestColumnSpec(int8_t mask,
+ int8_t index,
+ int64_t timestamp) {
+ return std::make_tuple(mask, index, timestamp);
+}
+
+RowValue CreateTestRowValue(
+ std::vector<std::tuple<int8_t, int8_t, int64_t>> column_specs) {
+ std::vector<std::shared_ptr<ColumnBase>> columns;
+ int64_t last_modified_time = 0;
+ for (auto spec: column_specs) {
+ auto c = CreateTestColumn(std::get<0>(spec), std::get<1>(spec),
+ std::get<2>(spec));
+ last_modified_time = std::max(last_modified_time, c -> Timestamp());
+ columns.push_back(std::move(c));
+ }
+ return RowValue(std::move(columns), last_modified_time);
+}
+
+RowValue CreateRowTombstone(int64_t timestamp) {
+ return RowValue(ToSeconds(timestamp), timestamp);
+}
+
+void VerifyRowValueColumns(
+ std::vector<std::shared_ptr<ColumnBase>> &columns,
+ std::size_t index_of_vector,
+ int8_t expected_mask,
+ int8_t expected_index,
+ int64_t expected_timestamp
+) {
+ EXPECT_EQ(expected_timestamp, columns[index_of_vector]->Timestamp());
+ EXPECT_EQ(expected_mask, columns[index_of_vector]->Mask());
+ EXPECT_EQ(expected_index, columns[index_of_vector]->Index());
+}
+
+int64_t ToMicroSeconds(int64_t seconds) {
+ return seconds * (int64_t)1000000;
+}
+
+int32_t ToSeconds(int64_t microseconds) {
+ return (int32_t)(microseconds / (int64_t)1000000);
+}
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/cassandra/test_utils.h b/src/rocksdb/utilities/cassandra/test_utils.h
new file mode 100644
index 000000000..235b35a02
--- /dev/null
+++ b/src/rocksdb/utilities/cassandra/test_utils.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <memory>
+#include "test_util/testharness.h"
+#include "utilities/cassandra/format.h"
+#include "utilities/cassandra/serialize.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace cassandra {
+extern const char kData[];
+extern const char kExpiringData[];
+extern const int32_t kTtl;
+extern const int8_t kColumn;
+extern const int8_t kTombstone;
+extern const int8_t kExpiringColumn;
+
+
+std::shared_ptr<ColumnBase> CreateTestColumn(int8_t mask,
+ int8_t index,
+ int64_t timestamp);
+
+std::tuple<int8_t, int8_t, int64_t> CreateTestColumnSpec(int8_t mask,
+ int8_t index,
+ int64_t timestamp);
+
+RowValue CreateTestRowValue(
+ std::vector<std::tuple<int8_t, int8_t, int64_t>> column_specs);
+
+RowValue CreateRowTombstone(int64_t timestamp);
+
+void VerifyRowValueColumns(
+ std::vector<std::shared_ptr<ColumnBase>> &columns,
+ std::size_t index_of_vector,
+ int8_t expected_mask,
+ int8_t expected_index,
+ int64_t expected_timestamp
+);
+
+int64_t ToMicroSeconds(int64_t seconds);
+int32_t ToSeconds(int64_t microseconds);
+} // namespace cassandra
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc b/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc
new file mode 100644
index 000000000..98e609609
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_impl.cc
@@ -0,0 +1,516 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 Facebook.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/checkpoint/checkpoint_impl.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <string>
+#include <vector>
+
+#include "db/wal_manager.h"
+#include "file/file_util.h"
+#include "file/filename.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/metadata.h"
+#include "rocksdb/transaction_log.h"
+#include "rocksdb/utilities/checkpoint.h"
+#include "test_util/sync_point.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status Checkpoint::Create(DB* db, Checkpoint** checkpoint_ptr) {
+ *checkpoint_ptr = new CheckpointImpl(db);
+ return Status::OK();
+}
+
+Status Checkpoint::CreateCheckpoint(const std::string& /*checkpoint_dir*/,
+ uint64_t /*log_size_for_flush*/) {
+ return Status::NotSupported("");
+}
+
+void CheckpointImpl::CleanStagingDirectory(
+ const std::string& full_private_path, Logger* info_log) {
+ std::vector<std::string> subchildren;
+ Status s = db_->GetEnv()->FileExists(full_private_path);
+ if (s.IsNotFound()) {
+ return;
+ }
+ ROCKS_LOG_INFO(info_log, "File exists %s -- %s",
+ full_private_path.c_str(), s.ToString().c_str());
+ db_->GetEnv()->GetChildren(full_private_path, &subchildren);
+ for (auto& subchild : subchildren) {
+ std::string subchild_path = full_private_path + "/" + subchild;
+ s = db_->GetEnv()->DeleteFile(subchild_path);
+ ROCKS_LOG_INFO(info_log, "Delete file %s -- %s",
+ subchild_path.c_str(), s.ToString().c_str());
+ }
+ // finally delete the private dir
+ s = db_->GetEnv()->DeleteDir(full_private_path);
+ ROCKS_LOG_INFO(info_log, "Delete dir %s -- %s",
+ full_private_path.c_str(), s.ToString().c_str());
+}
+
+Status Checkpoint::ExportColumnFamily(
+ ColumnFamilyHandle* /*handle*/, const std::string& /*export_dir*/,
+ ExportImportFilesMetaData** /*metadata*/) {
+ return Status::NotSupported("");
+}
+
+// Builds an openable snapshot of RocksDB
+Status CheckpointImpl::CreateCheckpoint(const std::string& checkpoint_dir,
+ uint64_t log_size_for_flush) {
+ DBOptions db_options = db_->GetDBOptions();
+
+ Status s = db_->GetEnv()->FileExists(checkpoint_dir);
+ if (s.ok()) {
+ return Status::InvalidArgument("Directory exists");
+ } else if (!s.IsNotFound()) {
+ assert(s.IsIOError());
+ return s;
+ }
+
+ ROCKS_LOG_INFO(
+ db_options.info_log,
+ "Started the snapshot process -- creating snapshot in directory %s",
+ checkpoint_dir.c_str());
+
+ size_t final_nonslash_idx = checkpoint_dir.find_last_not_of('/');
+ if (final_nonslash_idx == std::string::npos) {
+    // npos means the string is empty or all slashes. A non-empty, all-slash
+    // path would name the root directory, which cannot be the case here
+    // because we verified above that the directory doesn't exist.
+ assert(checkpoint_dir.empty());
+ return Status::InvalidArgument("invalid checkpoint directory name");
+ }
+
+ std::string full_private_path =
+ checkpoint_dir.substr(0, final_nonslash_idx + 1) + ".tmp";
+ ROCKS_LOG_INFO(
+ db_options.info_log,
+ "Snapshot process -- using temporary directory %s",
+ full_private_path.c_str());
+ CleanStagingDirectory(full_private_path, db_options.info_log.get());
+ // create snapshot directory
+ s = db_->GetEnv()->CreateDir(full_private_path);
+ uint64_t sequence_number = 0;
+ if (s.ok()) {
+ db_->DisableFileDeletions();
+ s = CreateCustomCheckpoint(
+ db_options,
+ [&](const std::string& src_dirname, const std::string& fname,
+ FileType) {
+ ROCKS_LOG_INFO(db_options.info_log, "Hard Linking %s", fname.c_str());
+ return db_->GetFileSystem()->LinkFile(src_dirname + fname,
+ full_private_path + fname,
+ IOOptions(), nullptr);
+ } /* link_file_cb */,
+ [&](const std::string& src_dirname, const std::string& fname,
+ uint64_t size_limit_bytes, FileType) {
+ ROCKS_LOG_INFO(db_options.info_log, "Copying %s", fname.c_str());
+ return CopyFile(db_->GetFileSystem(), src_dirname + fname,
+ full_private_path + fname, size_limit_bytes,
+ db_options.use_fsync);
+ } /* copy_file_cb */,
+ [&](const std::string& fname, const std::string& contents, FileType) {
+ ROCKS_LOG_INFO(db_options.info_log, "Creating %s", fname.c_str());
+ return CreateFile(db_->GetFileSystem(), full_private_path + fname,
+ contents, db_options.use_fsync);
+ } /* create_file_cb */,
+ &sequence_number, log_size_for_flush);
+ // we copied all the files, enable file deletions
+ db_->EnableFileDeletions(false);
+ }
+
+ if (s.ok()) {
+ // move tmp private backup to real snapshot directory
+ s = db_->GetEnv()->RenameFile(full_private_path, checkpoint_dir);
+ }
+ if (s.ok()) {
+ std::unique_ptr<Directory> checkpoint_directory;
+ db_->GetEnv()->NewDirectory(checkpoint_dir, &checkpoint_directory);
+ if (checkpoint_directory != nullptr) {
+ s = checkpoint_directory->Fsync();
+ }
+ }
+
+ if (s.ok()) {
+ // here we know that we succeeded and installed the new snapshot
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot DONE. All is good");
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot sequence number: %" PRIu64,
+ sequence_number);
+ } else {
+ // clean all the files we might have created
+ ROCKS_LOG_INFO(db_options.info_log, "Snapshot failed -- %s",
+ s.ToString().c_str());
+ CleanStagingDirectory(full_private_path, db_options.info_log.get());
+ }
+ return s;
+}
+
+Status CheckpointImpl::CreateCustomCheckpoint(
+ const DBOptions& db_options,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname, FileType type)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname,
+ uint64_t size_limit_bytes, FileType type)>
+ copy_file_cb,
+ std::function<Status(const std::string& fname, const std::string& contents,
+ FileType type)>
+ create_file_cb,
+ uint64_t* sequence_number, uint64_t log_size_for_flush) {
+ Status s;
+ std::vector<std::string> live_files;
+ uint64_t manifest_file_size = 0;
+ uint64_t min_log_num = port::kMaxUint64;
+ *sequence_number = db_->GetLatestSequenceNumber();
+ bool same_fs = true;
+ VectorLogPtr live_wal_files;
+
+ bool flush_memtable = true;
+ if (s.ok()) {
+ if (!db_options.allow_2pc) {
+ if (log_size_for_flush == port::kMaxUint64) {
+ flush_memtable = false;
+ } else if (log_size_for_flush > 0) {
+        // If outstanding log files are small, we skip the flush.
+ s = db_->GetSortedWalFiles(live_wal_files);
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ // Don't flush column families if total log size is smaller than
+ // log_size_for_flush. We copy the log files instead.
+ // We may be able to cover 2PC case too.
+ uint64_t total_wal_size = 0;
+ for (auto& wal : live_wal_files) {
+ total_wal_size += wal->SizeFileBytes();
+ }
+ if (total_wal_size < log_size_for_flush) {
+ flush_memtable = false;
+ }
+ live_wal_files.clear();
+ }
+ }
+
+ // this will return live_files prefixed with "/"
+ s = db_->GetLiveFiles(live_files, &manifest_file_size, flush_memtable);
+
+ if (s.ok() && db_options.allow_2pc) {
+ // If 2PC is enabled, we need to get minimum log number after the flush.
+ // Need to refetch the live files to recapture the snapshot.
+ if (!db_->GetIntProperty(DB::Properties::kMinLogNumberToKeep,
+ &min_log_num)) {
+ return Status::InvalidArgument(
+            "2PC enabled but cannot find the min log number to keep.");
+ }
+ // We need to refetch live files with flush to handle this case:
+ // A previous 000001.log contains the prepare record of transaction tnx1.
+ // The current log file is 000002.log, and sequence_number points to this
+ // file.
+ // After calling GetLiveFiles(), 000003.log is created.
+ // Then tnx1 is committed. The commit record is written to 000003.log.
+ // Now we fetch min_log_num, which will be 3.
+ // Then only 000002.log and 000003.log will be copied, and 000001.log will
+ // be skipped. 000003.log contains commit message of tnx1, but we don't
+ // have respective prepare record for it.
+ // In order to avoid this situation, we need to force flush to make sure
+ // all transactions committed before getting min_log_num will be flushed
+ // to SST files.
+ // We cannot get min_log_num before calling the GetLiveFiles() for the
+ // first time, because if we do that, all the logs files will be included,
+ // far more than needed.
+ s = db_->GetLiveFiles(live_files, &manifest_file_size, flush_memtable);
+ }
+
+ TEST_SYNC_POINT("CheckpointImpl::CreateCheckpoint:SavedLiveFiles1");
+ TEST_SYNC_POINT("CheckpointImpl::CreateCheckpoint:SavedLiveFiles2");
+ db_->FlushWAL(false /* sync */);
+ }
+ // if we have more than one column family, we need to also get WAL files
+ if (s.ok()) {
+ s = db_->GetSortedWalFiles(live_wal_files);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+
+ size_t wal_size = live_wal_files.size();
+
+ // copy/hard link live_files
+ std::string manifest_fname, current_fname;
+ for (size_t i = 0; s.ok() && i < live_files.size(); ++i) {
+ uint64_t number;
+ FileType type;
+ bool ok = ParseFileName(live_files[i], &number, &type);
+ if (!ok) {
+ s = Status::Corruption("Can't parse file name. This is very bad");
+ break;
+ }
+ // we should only get sst, options, manifest and current files here
+ assert(type == kTableFile || type == kDescriptorFile ||
+ type == kCurrentFile || type == kOptionsFile);
+ assert(live_files[i].size() > 0 && live_files[i][0] == '/');
+ if (type == kCurrentFile) {
+ // We will craft the current file manually to ensure it's consistent with
+ // the manifest number. This is necessary because current's file contents
+ // can change during checkpoint creation.
+ current_fname = live_files[i];
+ continue;
+ } else if (type == kDescriptorFile) {
+ manifest_fname = live_files[i];
+ }
+ std::string src_fname = live_files[i];
+
+ // rules:
+ // * if it's kTableFile, then it's shared
+ // * if it's kDescriptorFile, limit the size to manifest_file_size
+ // * always copy if cross-device link
+ if ((type == kTableFile) && same_fs) {
+ s = link_file_cb(db_->GetName(), src_fname, type);
+ if (s.IsNotSupported()) {
+ same_fs = false;
+ s = Status::OK();
+ }
+ }
+ if ((type != kTableFile) || (!same_fs)) {
+ s = copy_file_cb(db_->GetName(), src_fname,
+ (type == kDescriptorFile) ? manifest_file_size : 0,
+ type);
+ }
+ }
+ if (s.ok() && !current_fname.empty() && !manifest_fname.empty()) {
+ create_file_cb(current_fname, manifest_fname.substr(1) + "\n",
+ kCurrentFile);
+ }
+ ROCKS_LOG_INFO(db_options.info_log, "Number of log files %" ROCKSDB_PRIszt,
+ live_wal_files.size());
+
+  // Link WAL files. Copy the exact size of the last one because it is the
+  // only one that has changes after the last flush.
+ for (size_t i = 0; s.ok() && i < wal_size; ++i) {
+ if ((live_wal_files[i]->Type() == kAliveLogFile) &&
+ (!flush_memtable ||
+ live_wal_files[i]->StartSequence() >= *sequence_number ||
+ live_wal_files[i]->LogNumber() >= min_log_num)) {
+ if (i + 1 == wal_size) {
+ s = copy_file_cb(db_options.wal_dir, live_wal_files[i]->PathName(),
+ live_wal_files[i]->SizeFileBytes(), kLogFile);
+ break;
+ }
+ if (same_fs) {
+ // we only care about live log files
+ s = link_file_cb(db_options.wal_dir, live_wal_files[i]->PathName(),
+ kLogFile);
+ if (s.IsNotSupported()) {
+ same_fs = false;
+ s = Status::OK();
+ }
+ }
+ if (!same_fs) {
+ s = copy_file_cb(db_options.wal_dir, live_wal_files[i]->PathName(), 0,
+ kLogFile);
+ }
+ }
+ }
+
+ return s;
+}
+
+// Exports all live SST files of a specified Column Family onto export_dir,
+// returning SST file information in metadata.
+Status CheckpointImpl::ExportColumnFamily(
+ ColumnFamilyHandle* handle, const std::string& export_dir,
+ ExportImportFilesMetaData** metadata) {
+ auto cfh = reinterpret_cast<ColumnFamilyHandleImpl*>(handle);
+ const auto cf_name = cfh->GetName();
+ const auto db_options = db_->GetDBOptions();
+
+ assert(metadata != nullptr);
+ assert(*metadata == nullptr);
+ auto s = db_->GetEnv()->FileExists(export_dir);
+ if (s.ok()) {
+ return Status::InvalidArgument("Specified export_dir exists");
+ } else if (!s.IsNotFound()) {
+ assert(s.IsIOError());
+ return s;
+ }
+
+ const auto final_nonslash_idx = export_dir.find_last_not_of('/');
+ if (final_nonslash_idx == std::string::npos) {
+ return Status::InvalidArgument("Specified export_dir invalid");
+ }
+ ROCKS_LOG_INFO(db_options.info_log,
+ "[%s] export column family onto export directory %s",
+ cf_name.c_str(), export_dir.c_str());
+
+ // Create a temporary export directory.
+ const auto tmp_export_dir =
+ export_dir.substr(0, final_nonslash_idx + 1) + ".tmp";
+ s = db_->GetEnv()->CreateDir(tmp_export_dir);
+
+ if (s.ok()) {
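+    // Flush the memtable so that all of the column family's data is captured
+    // in SST files before the export.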
+ s = db_->Flush(ROCKSDB_NAMESPACE::FlushOptions(), handle);
+ }
+
+ ColumnFamilyMetaData db_metadata;
+ if (s.ok()) {
+ // Export live sst files with file deletions disabled.
+ s = db_->DisableFileDeletions();
+ if (s.ok()) {
+ db_->GetColumnFamilyMetaData(handle, &db_metadata);
+
+ s = ExportFilesInMetaData(
+ db_options, db_metadata,
+ [&](const std::string& src_dirname, const std::string& fname) {
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] HardLinking %s",
+ cf_name.c_str(), fname.c_str());
+ return db_->GetEnv()->LinkFile(src_dirname + fname,
+ tmp_export_dir + fname);
+ } /*link_file_cb*/,
+ [&](const std::string& src_dirname, const std::string& fname) {
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Copying %s",
+ cf_name.c_str(), fname.c_str());
+ return CopyFile(db_->GetFileSystem(), src_dirname + fname,
+ tmp_export_dir + fname, 0, db_options.use_fsync);
+ } /*copy_file_cb*/);
+
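+      // Re-enable file deletions regardless of the export outcome; surface
+      // its status only if the export itself succeeded.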
+ const auto enable_status = db_->EnableFileDeletions(false /*force*/);
+ if (s.ok()) {
+ s = enable_status;
+ }
+ }
+ }
+
+ auto moved_to_user_specified_dir = false;
+ if (s.ok()) {
+ // Move temporary export directory to the actual export directory.
+ s = db_->GetEnv()->RenameFile(tmp_export_dir, export_dir);
+ }
+
+ if (s.ok()) {
+ // Fsync export directory.
+ moved_to_user_specified_dir = true;
+ std::unique_ptr<Directory> dir_ptr;
+ s = db_->GetEnv()->NewDirectory(export_dir, &dir_ptr);
+ if (s.ok()) {
+ assert(dir_ptr != nullptr);
+ s = dir_ptr->Fsync();
+ }
+ }
+
+ if (s.ok()) {
+ // Export of files succeeded. Fill in the metadata information.
+ auto result_metadata = new ExportImportFilesMetaData();
+ result_metadata->db_comparator_name = handle->GetComparator()->Name();
+ for (const auto& level_metadata : db_metadata.levels) {
+ for (const auto& file_metadata : level_metadata.files) {
+ LiveFileMetaData live_file_metadata;
+ live_file_metadata.size = file_metadata.size;
+ live_file_metadata.name = std::move(file_metadata.name);
+ live_file_metadata.file_number = file_metadata.file_number;
+ live_file_metadata.db_path = export_dir;
+ live_file_metadata.smallest_seqno = file_metadata.smallest_seqno;
+ live_file_metadata.largest_seqno = file_metadata.largest_seqno;
+ live_file_metadata.smallestkey = std::move(file_metadata.smallestkey);
+ live_file_metadata.largestkey = std::move(file_metadata.largestkey);
+ live_file_metadata.oldest_blob_file_number =
+ file_metadata.oldest_blob_file_number;
+ live_file_metadata.level = level_metadata.level;
+ result_metadata->files.push_back(live_file_metadata);
+ }
+ *metadata = result_metadata;
+ }
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Export succeeded.",
+ cf_name.c_str());
+ } else {
+ // Failure: Clean up all the files/directories created.
+ ROCKS_LOG_INFO(db_options.info_log, "[%s] Export failed. %s",
+ cf_name.c_str(), s.ToString().c_str());
+ std::vector<std::string> subchildren;
+ const auto cleanup_dir =
+ moved_to_user_specified_dir ? export_dir : tmp_export_dir;
+ db_->GetEnv()->GetChildren(cleanup_dir, &subchildren);
+ for (const auto& subchild : subchildren) {
+ const auto subchild_path = cleanup_dir + "/" + subchild;
+ const auto status = db_->GetEnv()->DeleteFile(subchild_path);
+ if (!status.ok()) {
+ ROCKS_LOG_WARN(db_options.info_log, "Failed to cleanup file %s: %s",
+ subchild_path.c_str(), status.ToString().c_str());
+ }
+ }
+ const auto status = db_->GetEnv()->DeleteDir(cleanup_dir);
+ if (!status.ok()) {
+ ROCKS_LOG_WARN(db_options.info_log, "Failed to cleanup dir %s: %s",
+ cleanup_dir.c_str(), status.ToString().c_str());
+ }
+ }
+ return s;
+}
+
+Status CheckpointImpl::ExportFilesInMetaData(
+ const DBOptions& db_options, const ColumnFamilyMetaData& metadata,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& src_fname)>
+ copy_file_cb) {
+ Status s;
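+  // Try hard links first; fall back to copying on the first kNotSupported
+  // error (e.g. a cross-device link).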
+ auto hardlink_file = true;
+
+ // Copy/hard link files in metadata.
+ size_t num_files = 0;
+ for (const auto& level_metadata : metadata.levels) {
+ for (const auto& file_metadata : level_metadata.files) {
+ uint64_t number;
+ FileType type;
+ const auto ok = ParseFileName(file_metadata.name, &number, &type);
+ if (!ok) {
+ s = Status::Corruption("Could not parse file name");
+ break;
+ }
+
+ // We should only get sst files here.
+ assert(type == kTableFile);
+ assert(file_metadata.size > 0 && file_metadata.name[0] == '/');
+ const auto src_fname = file_metadata.name;
+ ++num_files;
+
+ if (hardlink_file) {
+ s = link_file_cb(db_->GetName(), src_fname);
+ if (num_files == 1 && s.IsNotSupported()) {
+ // Fallback to copy if link failed due to cross-device directories.
+ hardlink_file = false;
+ s = Status::OK();
+ }
+ }
+ if (!hardlink_file) {
+ s = copy_file_cb(db_->GetName(), src_fname);
+ }
+ if (!s.ok()) {
+ break;
+ }
+ }
+ }
+ ROCKS_LOG_INFO(db_options.info_log, "Number of table files %" ROCKSDB_PRIszt,
+ num_files);
+
+ return s;
+}
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_impl.h b/src/rocksdb/utilities/checkpoint/checkpoint_impl.h
new file mode 100644
index 000000000..81ee8320b
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_impl.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/checkpoint.h"
+
+#include <string>
+#include "file/filename.h"
+#include "rocksdb/db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class CheckpointImpl : public Checkpoint {
+ public:
+ // Creates a Checkpoint object to be used for creating openable snapshots
+ explicit CheckpointImpl(DB* db) : db_(db) {}
+
+  // Builds an openable snapshot of RocksDB. It accepts an output directory
+  // on the same disk and, under that directory, creates
+  // (1) hard-linked SST files pointing to the existing live SST files
+  //     (SST files are copied if the output directory is on a different
+  //     filesystem), and
+  // (2) copies of the MANIFEST and other required files.
+  // The directory should not already exist and will be created by this API.
+  // The directory will be an absolute path.
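+  //
+  // A minimal usage sketch (mirroring the tests in checkpoint_test.cc; the
+  // path below is illustrative):
+  //   Checkpoint* checkpoint = nullptr;
+  //   Status s = Checkpoint::Create(db, &checkpoint);
+  //   if (s.ok()) {
+  //     s = checkpoint->CreateCheckpoint("/path/to/checkpoint_dir");
+  //   }
+  //   delete checkpoint;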
+ using Checkpoint::CreateCheckpoint;
+ virtual Status CreateCheckpoint(const std::string& checkpoint_dir,
+ uint64_t log_size_for_flush) override;
+
+  // Exports all live SST files of a specified Column Family onto export_dir
+  // and returns SST file information in metadata.
+ // - SST files will be created as hard links when the directory specified
+ // is in the same partition as the db directory, copied otherwise.
+ // - export_dir should not already exist and will be created by this API.
+ // - Always triggers a flush.
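+  //
+  // A minimal usage sketch (mirroring the tests in checkpoint_test.cc; the
+  // export path below is illustrative):
+  //   Checkpoint* checkpoint = nullptr;
+  //   ExportImportFilesMetaData* metadata = nullptr;
+  //   Status s = Checkpoint::Create(db, &checkpoint);
+  //   if (s.ok()) {
+  //     s = checkpoint->ExportColumnFamily(db->DefaultColumnFamily(),
+  //                                        "/path/to/export_dir", &metadata);
+  //   }
+  //   delete metadata;
+  //   delete checkpoint;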
+ using Checkpoint::ExportColumnFamily;
+ virtual Status ExportColumnFamily(
+ ColumnFamilyHandle* handle, const std::string& export_dir,
+ ExportImportFilesMetaData** metadata) override;
+
+ // Checkpoint logic can be customized by providing callbacks for link, copy,
+ // or create.
+ Status CreateCustomCheckpoint(
+ const DBOptions& db_options,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname, FileType type)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname, uint64_t size_limit_bytes,
+ FileType type)>
+ copy_file_cb,
+ std::function<Status(const std::string& fname,
+ const std::string& contents, FileType type)>
+ create_file_cb,
+ uint64_t* sequence_number, uint64_t log_size_for_flush);
+
+ private:
+ void CleanStagingDirectory(const std::string& path, Logger* info_log);
+
+ // Export logic customization by providing callbacks for link or copy.
+ Status ExportFilesInMetaData(
+ const DBOptions& db_options, const ColumnFamilyMetaData& metadata,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname)>
+ link_file_cb,
+ std::function<Status(const std::string& src_dirname,
+ const std::string& fname)>
+ copy_file_cb);
+
+ private:
+ DB* db_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/checkpoint/checkpoint_test.cc b/src/rocksdb/utilities/checkpoint/checkpoint_test.cc
new file mode 100644
index 000000000..1a31c40ff
--- /dev/null
+++ b/src/rocksdb/utilities/checkpoint/checkpoint_test.cc
@@ -0,0 +1,829 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// Sync points prevent us from building and running tests in release mode
+#ifndef ROCKSDB_LITE
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+#include <iostream>
+#include <thread>
+#include <utility>
+#include "db/db_impl/db_impl.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/utilities/checkpoint.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/fault_injection_test_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+
+namespace ROCKSDB_NAMESPACE {
+class CheckpointTest : public testing::Test {
+ protected:
+ // Sequence of option configurations to try
+ enum OptionConfig {
+ kDefault = 0,
+ };
+ int option_config_;
+
+ public:
+ std::string dbname_;
+ std::string alternative_wal_dir_;
+ Env* env_;
+ DB* db_;
+ Options last_options_;
+ std::vector<ColumnFamilyHandle*> handles_;
+ std::string snapshot_name_;
+ std::string export_path_;
+ ColumnFamilyHandle* cfh_reverse_comp_;
+ ExportImportFilesMetaData* metadata_;
+
+ CheckpointTest() : env_(Env::Default()) {
+ env_->SetBackgroundThreads(1, Env::LOW);
+ env_->SetBackgroundThreads(1, Env::HIGH);
+ dbname_ = test::PerThreadDBPath(env_, "checkpoint_test");
+ alternative_wal_dir_ = dbname_ + "/wal";
+ auto options = CurrentOptions();
+ auto delete_options = options;
+ delete_options.wal_dir = alternative_wal_dir_;
+ EXPECT_OK(DestroyDB(dbname_, delete_options));
+    // Destroy the DB again in case the alternative WAL dir was not used.
+ EXPECT_OK(DestroyDB(dbname_, options));
+ db_ = nullptr;
+ snapshot_name_ = test::PerThreadDBPath(env_, "snapshot");
+ std::string snapshot_tmp_name = snapshot_name_ + ".tmp";
+ EXPECT_OK(DestroyDB(snapshot_name_, options));
+ env_->DeleteDir(snapshot_name_);
+ EXPECT_OK(DestroyDB(snapshot_tmp_name, options));
+ env_->DeleteDir(snapshot_tmp_name);
+ Reopen(options);
+ export_path_ = test::TmpDir(env_) + "/export";
+ test::DestroyDir(env_, export_path_);
+ cfh_reverse_comp_ = nullptr;
+ metadata_ = nullptr;
+ }
+
+ ~CheckpointTest() override {
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ if (cfh_reverse_comp_) {
+ EXPECT_OK(db_->DestroyColumnFamilyHandle(cfh_reverse_comp_));
+ cfh_reverse_comp_ = nullptr;
+ }
+ if (metadata_) {
+ delete metadata_;
+ metadata_ = nullptr;
+ }
+ Close();
+ Options options;
+ options.db_paths.emplace_back(dbname_, 0);
+ options.db_paths.emplace_back(dbname_ + "_2", 0);
+ options.db_paths.emplace_back(dbname_ + "_3", 0);
+ options.db_paths.emplace_back(dbname_ + "_4", 0);
+ EXPECT_OK(DestroyDB(dbname_, options));
+ EXPECT_OK(DestroyDB(snapshot_name_, options));
+ test::DestroyDir(env_, export_path_);
+ }
+
+ // Return the current option configuration.
+ Options CurrentOptions() {
+ Options options;
+ options.env = env_;
+ options.create_if_missing = true;
+ return options;
+ }
+
+ void CreateColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ ColumnFamilyOptions cf_opts(options);
+ size_t cfi = handles_.size();
+ handles_.resize(cfi + cfs.size());
+ for (auto cf : cfs) {
+ ASSERT_OK(db_->CreateColumnFamily(cf_opts, cf, &handles_[cfi++]));
+ }
+ }
+
+ void CreateAndReopenWithCF(const std::vector<std::string>& cfs,
+ const Options& options) {
+ CreateColumnFamilies(cfs, options);
+ std::vector<std::string> cfs_plus_default = cfs;
+ cfs_plus_default.insert(cfs_plus_default.begin(), kDefaultColumnFamilyName);
+ ReopenWithColumnFamilies(cfs_plus_default, options);
+ }
+
+ void ReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const std::vector<Options>& options) {
+ ASSERT_OK(TryReopenWithColumnFamilies(cfs, options));
+ }
+
+ void ReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ ASSERT_OK(TryReopenWithColumnFamilies(cfs, options));
+ }
+
+ Status TryReopenWithColumnFamilies(
+ const std::vector<std::string>& cfs,
+ const std::vector<Options>& options) {
+ Close();
+ EXPECT_EQ(cfs.size(), options.size());
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options[i]));
+ }
+ DBOptions db_opts = DBOptions(options[0]);
+ return DB::Open(db_opts, dbname_, column_families, &handles_, &db_);
+ }
+
+ Status TryReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ Close();
+ std::vector<Options> v_opts(cfs.size(), options);
+ return TryReopenWithColumnFamilies(cfs, v_opts);
+ }
+
+ void Reopen(const Options& options) {
+ ASSERT_OK(TryReopen(options));
+ }
+
+ void CompactAll() {
+ for (auto h : handles_) {
+ ASSERT_OK(db_->CompactRange(CompactRangeOptions(), h, nullptr, nullptr));
+ }
+ }
+
+ void Close() {
+ for (auto h : handles_) {
+ delete h;
+ }
+ handles_.clear();
+ delete db_;
+ db_ = nullptr;
+ }
+
+ void DestroyAndReopen(const Options& options) {
+ // Destroy using last options
+ Destroy(last_options_);
+ ASSERT_OK(TryReopen(options));
+ }
+
+ void Destroy(const Options& options) {
+ Close();
+ ASSERT_OK(DestroyDB(dbname_, options));
+ }
+
+ Status ReadOnlyReopen(const Options& options) {
+ return DB::OpenForReadOnly(options, dbname_, &db_);
+ }
+
+ Status ReadOnlyReopenWithColumnFamilies(const std::vector<std::string>& cfs,
+ const Options& options) {
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (const auto& cf : cfs) {
+ column_families.emplace_back(cf, options);
+ }
+ return DB::OpenForReadOnly(options, dbname_, column_families, &handles_,
+ &db_);
+ }
+
+ Status TryReopen(const Options& options) {
+ Close();
+ last_options_ = options;
+ return DB::Open(options, dbname_, &db_);
+ }
+
+ Status Flush(int cf = 0) {
+ if (cf == 0) {
+ return db_->Flush(FlushOptions());
+ } else {
+ return db_->Flush(FlushOptions(), handles_[cf]);
+ }
+ }
+
+ Status Put(const Slice& k, const Slice& v, WriteOptions wo = WriteOptions()) {
+ return db_->Put(wo, k, v);
+ }
+
+ Status Put(int cf, const Slice& k, const Slice& v,
+ WriteOptions wo = WriteOptions()) {
+ return db_->Put(wo, handles_[cf], k, v);
+ }
+
+ Status Delete(const std::string& k) {
+ return db_->Delete(WriteOptions(), k);
+ }
+
+ Status Delete(int cf, const std::string& k) {
+ return db_->Delete(WriteOptions(), handles_[cf], k);
+ }
+
+ std::string Get(const std::string& k, const Snapshot* snapshot = nullptr) {
+ ReadOptions options;
+ options.verify_checksums = true;
+ options.snapshot = snapshot;
+ std::string result;
+ Status s = db_->Get(options, k, &result);
+ if (s.IsNotFound()) {
+ result = "NOT_FOUND";
+ } else if (!s.ok()) {
+ result = s.ToString();
+ }
+ return result;
+ }
+
+ std::string Get(int cf, const std::string& k,
+ const Snapshot* snapshot = nullptr) {
+ ReadOptions options;
+ options.verify_checksums = true;
+ options.snapshot = snapshot;
+ std::string result;
+ Status s = db_->Get(options, handles_[cf], k, &result);
+ if (s.IsNotFound()) {
+ result = "NOT_FOUND";
+ } else if (!s.ok()) {
+ result = s.ToString();
+ }
+ return result;
+ }
+};
+
+TEST_F(CheckpointTest, GetSnapshotLink) {
+ for (uint64_t log_size_for_flush : {0, 1000000}) {
+ Options options;
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ Checkpoint* checkpoint;
+
+ options = CurrentOptions();
+ delete db_;
+ db_ = nullptr;
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Create a database
+ Status s;
+ options.create_if_missing = true;
+ ASSERT_OK(DB::Open(options, dbname_, &db_));
+ std::string key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+ // Take a snapshot
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_, log_size_for_flush));
+ ASSERT_OK(Put(key, "v2"));
+ ASSERT_EQ("v2", Get(key));
+ ASSERT_OK(Flush());
+ ASSERT_EQ("v2", Get(key));
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, key, &result));
+ ASSERT_EQ("v1", result);
+ delete snapshotDB;
+ snapshotDB = nullptr;
+ delete db_;
+ db_ = nullptr;
+
+ // Destroy original DB
+ ASSERT_OK(DestroyDB(dbname_, options));
+
+ // Open snapshot and verify contents
+ options.create_if_missing = false;
+ dbname_ = snapshot_name_;
+ ASSERT_OK(DB::Open(options, dbname_, &db_));
+ ASSERT_EQ("v1", Get(key));
+ delete db_;
+ db_ = nullptr;
+ ASSERT_OK(DestroyDB(dbname_, options));
+ delete checkpoint;
+
+ // Restore DB name
+ dbname_ = test::PerThreadDBPath(env_, "db_test");
+ }
+}
+
+TEST_F(CheckpointTest, ExportColumnFamilyWithLinks) {
+ // Create a database
+ Status s;
+ auto options = CurrentOptions();
+ options.create_if_missing = true;
+ CreateAndReopenWithCF({}, options);
+
+ // Helper to verify the number of files in metadata and export dir
+ auto verify_files_exported = [&](const ExportImportFilesMetaData& metadata,
+ int num_files_expected) {
+ ASSERT_EQ(metadata.files.size(), num_files_expected);
+ std::vector<std::string> subchildren;
+ env_->GetChildren(export_path_, &subchildren);
+ int num_children = 0;
+ for (const auto& child : subchildren) {
+ if (child != "." && child != "..") {
+ ++num_children;
+ }
+ }
+ ASSERT_EQ(num_children, num_files_expected);
+ };
+
+ // Test DefaultColumnFamily
+ {
+ const auto key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export the Tables and verify
+ ASSERT_OK(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_));
+ verify_files_exported(*metadata_, 1);
+ ASSERT_EQ(metadata_->db_comparator_name, options.comparator->Name());
+ test::DestroyDir(env_, export_path_);
+ delete metadata_;
+ metadata_ = nullptr;
+
+ // Check again after compaction
+ CompactAll();
+ ASSERT_OK(Put(key, "v2"));
+ ASSERT_OK(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_));
+ verify_files_exported(*metadata_, 2);
+ ASSERT_EQ(metadata_->db_comparator_name, options.comparator->Name());
+ test::DestroyDir(env_, export_path_);
+ delete metadata_;
+ metadata_ = nullptr;
+ delete checkpoint;
+ }
+
+ // Test non default column family with non default comparator
+ {
+ auto cf_options = CurrentOptions();
+ cf_options.comparator = ReverseBytewiseComparator();
+ ASSERT_OK(db_->CreateColumnFamily(cf_options, "yoyo", &cfh_reverse_comp_));
+
+ const auto key = std::string("foo");
+ ASSERT_OK(db_->Put(WriteOptions(), cfh_reverse_comp_, key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export the Tables and verify
+ ASSERT_OK(checkpoint->ExportColumnFamily(cfh_reverse_comp_, export_path_,
+ &metadata_));
+ verify_files_exported(*metadata_, 1);
+ ASSERT_EQ(metadata_->db_comparator_name,
+ ReverseBytewiseComparator()->Name());
+ delete checkpoint;
+ }
+}
+
+TEST_F(CheckpointTest, ExportColumnFamilyNegativeTest) {
+ // Create a database
+ Status s;
+ auto options = CurrentOptions();
+ options.create_if_missing = true;
+ CreateAndReopenWithCF({}, options);
+
+ const auto key = std::string("foo");
+ ASSERT_OK(Put(key, "v1"));
+
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+
+ // Export onto existing directory
+ env_->CreateDirIfMissing(export_path_);
+ ASSERT_EQ(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_),
+ Status::InvalidArgument("Specified export_dir exists"));
+ test::DestroyDir(env_, export_path_);
+
+ // Export with invalid directory specification
+ export_path_ = "";
+ ASSERT_EQ(checkpoint->ExportColumnFamily(db_->DefaultColumnFamily(),
+ export_path_, &metadata_),
+ Status::InvalidArgument("Specified export_dir invalid"));
+ delete checkpoint;
+}
+
+TEST_F(CheckpointTest, CheckpointCF) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"one", "two", "three", "four", "five"}, options);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointTest::CheckpointCF:2", "DBImpl::GetLiveFiles:2"},
+ {"DBImpl::GetLiveFiles:1", "CheckpointTest::CheckpointCF:1"}});
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(Put(0, "Default", "Default"));
+ ASSERT_OK(Put(1, "one", "one"));
+ ASSERT_OK(Put(2, "two", "two"));
+ ASSERT_OK(Put(3, "three", "three"));
+ ASSERT_OK(Put(4, "four", "four"));
+ ASSERT_OK(Put(5, "five", "five"));
+
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ std::vector<ColumnFamilyHandle*> cphandles;
+
+ Status s;
+ // Take a snapshot
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT("CheckpointTest::CheckpointCF:1");
+ ASSERT_OK(Put(0, "Default", "Default1"));
+ ASSERT_OK(Put(1, "one", "eleven"));
+ ASSERT_OK(Put(2, "two", "twelve"));
+ ASSERT_OK(Put(3, "three", "thirteen"));
+ ASSERT_OK(Put(4, "four", "fourteen"));
+ ASSERT_OK(Put(5, "five", "fifteen"));
+ TEST_SYNC_POINT("CheckpointTest::CheckpointCF:2");
+ t.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ASSERT_OK(Put(1, "one", "twentyone"));
+ ASSERT_OK(Put(2, "two", "twentytwo"));
+ ASSERT_OK(Put(3, "three", "twentythree"));
+ ASSERT_OK(Put(4, "four", "twentyfour"));
+ ASSERT_OK(Put(5, "five", "twentyfive"));
+ ASSERT_OK(Flush());
+
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ std::vector<std::string> cfs;
+ cfs= {kDefaultColumnFamilyName, "one", "two", "three", "four", "five"};
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options));
+ }
+ ASSERT_OK(DB::Open(options, snapshot_name_,
+ column_families, &cphandles, &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[0], "Default", &result));
+ ASSERT_EQ("Default1", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[1], "one", &result));
+ ASSERT_EQ("eleven", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[2], "two", &result));
+ for (auto h : cphandles) {
+ delete h;
+ }
+ cphandles.clear();
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CheckpointCFNoFlush) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"one", "two", "three", "four", "five"}, options);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(Put(0, "Default", "Default"));
+ ASSERT_OK(Put(1, "one", "one"));
+ Flush();
+ ASSERT_OK(Put(2, "two", "two"));
+
+ DB* snapshotDB;
+ ReadOptions roptions;
+ std::string result;
+ std::vector<ColumnFamilyHandle*> cphandles;
+
+ Status s;
+ // Take a snapshot
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "DBImpl::BackgroundCallFlush:start", [&](void* /*arg*/) {
+ // Flush should never trigger.
+ FAIL();
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_, 1000000));
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ delete checkpoint;
+ ASSERT_OK(Put(1, "one", "two"));
+ ASSERT_OK(Flush(1));
+ ASSERT_OK(Put(2, "two", "twentytwo"));
+ Close();
+ EXPECT_OK(DestroyDB(dbname_, options));
+
+ // Open snapshot and verify contents while DB is running
+ options.create_if_missing = false;
+ std::vector<std::string> cfs;
+ cfs = {kDefaultColumnFamilyName, "one", "two", "three", "four", "five"};
+ std::vector<ColumnFamilyDescriptor> column_families;
+ for (size_t i = 0; i < cfs.size(); ++i) {
+ column_families.push_back(ColumnFamilyDescriptor(cfs[i], options));
+ }
+ ASSERT_OK(DB::Open(options, snapshot_name_, column_families, &cphandles,
+ &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[0], "Default", &result));
+ ASSERT_EQ("Default", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[1], "one", &result));
+ ASSERT_EQ("one", result);
+ ASSERT_OK(snapshotDB->Get(roptions, cphandles[2], "two", &result));
+ ASSERT_EQ("two", result);
+ for (auto h : cphandles) {
+ delete h;
+ }
+ cphandles.clear();
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CurrentFileModifiedWhileCheckpointing) {
+ Options options = CurrentOptions();
+ options.max_manifest_file_size = 0; // always rollover manifest for file add
+ Reopen(options);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {// Get past the flush in the checkpoint thread before adding any keys to
+ // the db so the checkpoint thread won't hit the WriteManifest
+ // syncpoints.
+ {"DBImpl::GetLiveFiles:1",
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing:PrePut"},
+ // Roll the manifest during checkpointing right after live files are
+ // snapshotted.
+ {"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "VersionSet::LogAndApply:WriteManifest"},
+ {"VersionSet::LogAndApply:WriteManifestDone",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing:PrePut");
+ ASSERT_OK(Put("Default", "Default1"));
+ ASSERT_OK(Flush());
+ t.join();
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ DB* snapshotDB;
+ // Successful Open() implies that CURRENT pointed to the manifest in the
+ // checkpoint.
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshotDB));
+ delete snapshotDB;
+ snapshotDB = nullptr;
+}
+
+TEST_F(CheckpointTest, CurrentFileModifiedWhileCheckpointing2PC) {
+ Close();
+ const std::string dbname = test::PerThreadDBPath("transaction_testdb");
+ ASSERT_OK(DestroyDB(dbname, CurrentOptions()));
+ env_->DeleteDir(dbname);
+
+ Options options = CurrentOptions();
+ options.allow_2pc = true;
+ // allow_2pc is implicitly set with tx prepare
+ // options.allow_2pc = true;
+ TransactionDBOptions txn_db_options;
+ TransactionDB* txdb;
+ Status s = TransactionDB::Open(options, txn_db_options, dbname, &txdb);
+ assert(s.ok());
+ ColumnFamilyHandle* cfa;
+ ColumnFamilyHandle* cfb;
+ ColumnFamilyOptions cf_options;
+ ASSERT_OK(txdb->CreateColumnFamily(cf_options, "CFA", &cfa));
+
+ WriteOptions write_options;
+ // Insert something into CFB so lots of log files will be kept
+ // before creating the checkpoint.
+ ASSERT_OK(txdb->CreateColumnFamily(cf_options, "CFB", &cfb));
+ ASSERT_OK(txdb->Put(write_options, cfb, "", ""));
+
+ ReadOptions read_options;
+ std::string value;
+ TransactionOptions txn_options;
+ Transaction* txn = txdb->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+ ASSERT_EQ(txdb->GetTransactionByName("xid"), txn);
+
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ s = txn->Put(cfa, Slice("foocfa"), Slice("barcfa"));
+ ASSERT_OK(s);
+  // Write the prepare record into the middle of the first WAL, then flush
+  // the WALs many times.
+ for (int i = 1; i <= 100000; i++) {
+ Transaction* tx = txdb->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(tx->SetName("x"));
+ ASSERT_OK(tx->Put(Slice(std::to_string(i)), Slice("val")));
+ ASSERT_OK(tx->Put(cfa, Slice("aaa"), Slice("111")));
+ ASSERT_OK(tx->Prepare());
+ ASSERT_OK(tx->Commit());
+ if (i % 10000 == 0) {
+ txdb->Flush(FlushOptions());
+ }
+ if (i == 88888) {
+ ASSERT_OK(txn->Prepare());
+ }
+ delete tx;
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"CheckpointImpl::CreateCheckpoint:SavedLiveFiles1",
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PreCommit"},
+ {"CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PostCommit",
+ "CheckpointImpl::CreateCheckpoint:SavedLiveFiles2"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread t([&]() {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(txdb, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ });
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PreCommit");
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ TEST_SYNC_POINT(
+ "CheckpointTest::CurrentFileModifiedWhileCheckpointing2PC:PostCommit");
+ t.join();
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+  // No more than three log files should exist.
+ std::vector<std::string> files;
+ env_->GetChildren(snapshot_name_, &files);
+ int num_log_files = 0;
+ for (auto& file : files) {
+ uint64_t num;
+ FileType type;
+ WalFileType log_type;
+ if (ParseFileName(file, &num, &type, &log_type) && type == kLogFile) {
+ num_log_files++;
+ }
+ }
+  // One flush after prepare + one outstanding file before checkpoint + one
+  // log file generated after checkpoint.
+ ASSERT_LE(num_log_files, 3);
+
+ TransactionDB* snapshotDB;
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+ std::vector<ROCKSDB_NAMESPACE::ColumnFamilyHandle*> cf_handles;
+ ASSERT_OK(TransactionDB::Open(options, txn_db_options, snapshot_name_,
+ column_families, &cf_handles, &snapshotDB));
+ ASSERT_OK(snapshotDB->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(snapshotDB->Get(read_options, cf_handles[1], "foocfa", &value));
+ ASSERT_EQ(value, "barcfa");
+
+ delete cfa;
+ delete cfb;
+ delete cf_handles[0];
+ delete cf_handles[1];
+ delete cf_handles[2];
+ delete snapshotDB;
+ snapshotDB = nullptr;
+ delete txdb;
+}
+
+TEST_F(CheckpointTest, CheckpointInvalidDirectoryName) {
+ for (std::string checkpoint_dir : {"", "/", "////"}) {
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+    ASSERT_TRUE(
+        checkpoint->CreateCheckpoint(checkpoint_dir).IsInvalidArgument());
+ delete checkpoint;
+ }
+}
+
+TEST_F(CheckpointTest, CheckpointWithParallelWrites) {
+ // When run with TSAN, this exposes the data race fixed in
+ // https://github.com/facebook/rocksdb/pull/3603
+ ASSERT_OK(Put("key1", "val1"));
+ port::Thread thread([this]() { ASSERT_OK(Put("key2", "val2")); });
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ thread.join();
+}
+
+TEST_F(CheckpointTest, CheckpointWithUnsyncedDataDropped) {
+ Options options = CurrentOptions();
+ std::unique_ptr<FaultInjectionTestEnv> env(new FaultInjectionTestEnv(env_));
+ options.env = env.get();
+ Reopen(options);
+ ASSERT_OK(Put("key1", "val1"));
+ Checkpoint* checkpoint;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
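+  // Simulate a crash: drop any file data that was not explicitly synced.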
+ env->DropUnsyncedFileData();
+
+ // make sure it's openable even though whatever data that wasn't synced got
+ // dropped.
+ options.env = env_;
+ DB* snapshot_db;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshot_db));
+ ReadOptions read_opts;
+ std::string get_result;
+ ASSERT_OK(snapshot_db->Get(read_opts, "key1", &get_result));
+ ASSERT_EQ("val1", get_result);
+ delete snapshot_db;
+ delete db_;
+ db_ = nullptr;
+}
+
+TEST_F(CheckpointTest, CheckpointReadOnlyDB) {
+ ASSERT_OK(Put("foo", "foo_value"));
+ ASSERT_OK(Flush());
+ Close();
+ Options options = CurrentOptions();
+ ASSERT_OK(ReadOnlyReopen(options));
+ Checkpoint* checkpoint = nullptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ checkpoint = nullptr;
+ Close();
+ DB* snapshot_db = nullptr;
+ ASSERT_OK(DB::Open(options, snapshot_name_, &snapshot_db));
+ ReadOptions read_opts;
+ std::string get_result;
+ ASSERT_OK(snapshot_db->Get(read_opts, "foo", &get_result));
+ ASSERT_EQ("foo_value", get_result);
+ delete snapshot_db;
+}
+
+TEST_F(CheckpointTest, CheckpointReadOnlyDBWithMultipleColumnFamilies) {
+ Options options = CurrentOptions();
+ CreateAndReopenWithCF({"pikachu", "eevee"}, options);
+ for (int i = 0; i != 3; ++i) {
+ ASSERT_OK(Put(i, "foo", "foo_value"));
+ ASSERT_OK(Flush(i));
+ }
+ Close();
+ Status s = ReadOnlyReopenWithColumnFamilies(
+ {kDefaultColumnFamilyName, "pikachu", "eevee"}, options);
+ ASSERT_OK(s);
+ Checkpoint* checkpoint = nullptr;
+ ASSERT_OK(Checkpoint::Create(db_, &checkpoint));
+ ASSERT_OK(checkpoint->CreateCheckpoint(snapshot_name_));
+ delete checkpoint;
+ checkpoint = nullptr;
+ Close();
+
+ std::vector<ColumnFamilyDescriptor> column_families{
+ {kDefaultColumnFamilyName, options},
+ {"pikachu", options},
+ {"eevee", options}};
+ DB* snapshot_db = nullptr;
+ std::vector<ColumnFamilyHandle*> snapshot_handles;
+ s = DB::Open(options, snapshot_name_, column_families, &snapshot_handles,
+ &snapshot_db);
+ ASSERT_OK(s);
+ ReadOptions read_opts;
+ for (int i = 0; i != 3; ++i) {
+ std::string get_result;
+ s = snapshot_db->Get(read_opts, snapshot_handles[i], "foo", &get_result);
+ ASSERT_OK(s);
+ ASSERT_EQ("foo_value", get_result);
+ }
+
+ for (auto snapshot_h : snapshot_handles) {
+ delete snapshot_h;
+ }
+ snapshot_handles.clear();
+ delete snapshot_db;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as Checkpoint is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc
new file mode 100644
index 000000000..c97eef41d
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <string>
+
+#include "rocksdb/slice.h"
+#include "utilities/compaction_filters/remove_emptyvalue_compactionfilter.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+const char* RemoveEmptyValueCompactionFilter::Name() const {
+ return "RemoveEmptyValueCompactionFilter";
+}
+
+bool RemoveEmptyValueCompactionFilter::Filter(int /*level*/,
+ const Slice& /*key*/,
+ const Slice& existing_value,
+ std::string* /*new_value*/,
+ bool* /*value_changed*/) const {
+ // remove kv pairs that have empty values
+ return existing_value.empty();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h
new file mode 100644
index 000000000..f5dbec900
--- /dev/null
+++ b/src/rocksdb/utilities/compaction_filters/remove_emptyvalue_compactionfilter.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#pragma once
+
+#include <string>
+
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class RemoveEmptyValueCompactionFilter : public CompactionFilter {
+ public:
+ const char* Name() const override;
+ bool Filter(int level,
+ const Slice& key,
+ const Slice& existing_value,
+ std::string* new_value,
+ bool* value_changed) const override;
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/convenience/info_log_finder.cc b/src/rocksdb/utilities/convenience/info_log_finder.cc
new file mode 100644
index 000000000..980262f22
--- /dev/null
+++ b/src/rocksdb/utilities/convenience/info_log_finder.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2012 Facebook.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "rocksdb/utilities/info_log_finder.h"
+#include "file/filename.h"
+#include "rocksdb/env.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status GetInfoLogList(DB* db, std::vector<std::string>* info_log_list) {
+ if (!db) {
+ return Status::InvalidArgument("DB pointer is not valid");
+ }
+ std::string parent_path;
+ const Options& options = db->GetOptions();
+ return GetInfoLogFiles(options.env, options.db_log_dir, db->GetName(),
+ &parent_path, info_log_list);
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/debug.cc b/src/rocksdb/utilities/debug.cc
new file mode 100644
index 000000000..b51f9da0d
--- /dev/null
+++ b/src/rocksdb/utilities/debug.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/debug.h"
+
+#include "db/db_impl/db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status GetAllKeyVersions(DB* db, Slice begin_key, Slice end_key,
+ size_t max_num_ikeys,
+ std::vector<KeyVersion>* key_versions) {
+ if (nullptr == db) {
+ return Status::InvalidArgument("db cannot be null.");
+ }
+ return GetAllKeyVersions(db, db->DefaultColumnFamily(), begin_key, end_key,
+ max_num_ikeys, key_versions);
+}
+
+Status GetAllKeyVersions(DB* db, ColumnFamilyHandle* cfh, Slice begin_key,
+ Slice end_key, size_t max_num_ikeys,
+ std::vector<KeyVersion>* key_versions) {
+ if (nullptr == db) {
+ return Status::InvalidArgument("db cannot be null.");
+ }
+ if (nullptr == cfh) {
+ return Status::InvalidArgument("Column family handle cannot be null.");
+ }
+ if (nullptr == key_versions) {
+ return Status::InvalidArgument("key_versions cannot be null.");
+ }
+ key_versions->clear();
+
+ DBImpl* idb = static_cast<DBImpl*>(db->GetRootDB());
+ auto icmp = InternalKeyComparator(idb->GetOptions(cfh).comparator);
+ ReadRangeDelAggregator range_del_agg(&icmp,
+ kMaxSequenceNumber /* upper_bound */);
+ Arena arena;
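+  // Iterate over raw internal keys so that every version of each user key
+  // (including older sequence numbers and tombstones) is visible.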
+ ScopedArenaIterator iter(idb->NewInternalIterator(&arena, &range_del_agg,
+ kMaxSequenceNumber, cfh));
+
+ if (!begin_key.empty()) {
+ InternalKey ikey;
+ ikey.SetMinPossibleForUserKey(begin_key);
+ iter->Seek(ikey.Encode());
+ } else {
+ iter->SeekToFirst();
+ }
+
+ size_t num_keys = 0;
+ for (; iter->Valid(); iter->Next()) {
+ ParsedInternalKey ikey;
+ if (!ParseInternalKey(iter->key(), &ikey)) {
+ return Status::Corruption("Internal Key [" + iter->key().ToString() +
+ "] parse error!");
+ }
+
+ if (!end_key.empty() &&
+ icmp.user_comparator()->Compare(ikey.user_key, end_key) > 0) {
+ break;
+ }
+
+ key_versions->emplace_back(ikey.user_key.ToString() /* _user_key */,
+ iter->value().ToString() /* _value */,
+ ikey.sequence /* _sequence */,
+ static_cast<int>(ikey.type) /* _type */);
+ if (++num_keys >= max_num_ikeys) {
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/env_librados.cc b/src/rocksdb/utilities/env_librados.cc
new file mode 100644
index 000000000..5842edbc7
--- /dev/null
+++ b/src/rocksdb/utilities/env_librados.cc
@@ -0,0 +1,1497 @@
+// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
+// vim: ts=8 sw=2 smarttab
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include "rocksdb/utilities/env_librados.h"
+#include "util/random.h"
+#include <mutex>
+#include <cstdlib>
+
+namespace ROCKSDB_NAMESPACE {
+/* GLOBAL DEFINE */
+// #define DEBUG
+#ifdef DEBUG
+#include <cstdio>
+#include <sys/syscall.h>
+#include <unistd.h>
+#define LOG_DEBUG(...) do{\
+ printf("[%ld:%s:%i:%s]", syscall(SYS_gettid), __FILE__, __LINE__, __FUNCTION__);\
+ printf(__VA_ARGS__);\
+ }while(0)
+#else
+#define LOG_DEBUG(...)
+#endif
+
+/* GLOBAL CONSTANT */
+const char *default_db_name = "default_envlibrados_db";
+const char *default_pool_name = "default_envlibrados_pool";
+const char *default_config_path = "CEPH_CONFIG_PATH"; // the env variable name of the Ceph config file
+// maximum number of dirs/files that the fs can store
+const int MAX_ITEMS_IN_FS = 1 << 30;
+// root dir tag
+const std::string ROOT_DIR_KEY = "/";
+const std::string DIR_ID_VALUE = "<DIR>";
+
+/**
+ * @brief convert error code to status
+ * @details Convert internal linux error code to Status
+ *
+ * @param r [description]
+ * @return [description]
+ */
+Status err_to_status(int r)
+{
+ switch (r) {
+ case 0:
+ return Status::OK();
+ case -ENOENT:
+ return Status::IOError();
+ case -ENODATA:
+ case -ENOTDIR:
+ return Status::NotFound(Status::kNone);
+ case -EINVAL:
+ return Status::InvalidArgument(Status::kNone);
+ case -EIO:
+ return Status::IOError(Status::kNone);
+ default:
+ // FIXME :(
+ assert(0 == "unrecognized error code");
+ return Status::NotSupported(Status::kNone);
+ }
+}
+
+/**
+ * @brief split file path into dir path and file name
+ * @details
+ * Because rocksdb only need a 2-level structure (dir/file), all input path will be shortened to dir/file format
+ * For example:
+ * b/c => dir '/b', file 'c'
+ * /a/b/c => dir '/b', file 'c'
+ *
+ * @param fn [description]
+ * @param dir [description]
+ * @param file [description]
+ */
+void split(const std::string &fn, std::string *dir, std::string *file) {
+ LOG_DEBUG("[IN]%s\n", fn.c_str());
+ int pos = fn.size() - 1;
+ while ('/' == fn[pos]) --pos;
+ size_t fstart = fn.rfind('/', pos);
+ *file = fn.substr(fstart + 1, pos - fstart);
+
+ pos = fstart;
+ while (pos >= 0 && '/' == fn[pos]) --pos;
+
+ if (pos < 0) {
+ *dir = "/";
+ } else {
+ size_t dstart = fn.rfind('/', pos);
+ *dir = fn.substr(dstart + 1, pos - dstart);
+ *dir = std::string("/") + *dir;
+ }
+
+ LOG_DEBUG("[OUT]%s | %s\n", dir->c_str(), file->c_str());
+}
+
+// A file abstraction for reading sequentially through a file
+class LibradosSequentialFile : public SequentialFile {
+ librados::IoCtx * _io_ctx;
+ std::string _fid;
+ std::string _hint;
+ int _offset;
+public:
+ LibradosSequentialFile(librados::IoCtx * io_ctx, std::string fid, std::string hint):
+ _io_ctx(io_ctx), _fid(fid), _hint(hint), _offset(0) {}
+
+ ~LibradosSequentialFile() {}
+
+ /**
+ * @brief read file
+ * @details
+ * Read up to "n" bytes from the file. "scratch[0..n-1]" may be
+ * written by this routine. Sets "*result" to the data that was
+ * read (including if fewer than "n" bytes were successfully read).
+ * May set "*result" to point at data in "scratch[0..n-1]", so
+ * "scratch[0..n-1]" must be live when "*result" is used.
+ * If an error was encountered, returns a non-OK status.
+ *
+ * REQUIRES: External synchronization
+ *
+ * @param n [description]
+ * @param result [description]
+ * @param scratch [description]
+ * @return [description]
+ */
+ Status Read(size_t n, Slice* result, char* scratch) {
+ LOG_DEBUG("[IN]%i\n", (int)n);
+ librados::bufferlist buffer;
+ Status s;
+ int r = _io_ctx->read(_fid, buffer, n, _offset);
+ if (r >= 0) {
+ buffer.begin().copy(r, scratch);
+ *result = Slice(scratch, r);
+ _offset += r;
+ s = Status::OK();
+ } else {
+ s = err_to_status(r);
+ if (s == Status::IOError()) {
+ *result = Slice();
+ s = Status::OK();
+ }
+ }
+ LOG_DEBUG("[OUT]%s, %i, %s\n", s.ToString().c_str(), (int)r, buffer.c_str());
+ return s;
+ }
+
+ /**
+ * @brief skip "n" bytes from the file
+ * @details
+ * Skip "n" bytes from the file. This is guaranteed to be no
+   * slower than reading the same data, but may be faster.
+ *
+ * If end of file is reached, skipping will stop at the end of the
+ * file, and Skip will return OK.
+ *
+ * REQUIRES: External synchronization
+ *
+ * @param n [description]
+ * @return [description]
+ */
+ Status Skip(uint64_t n) {
+ _offset += n;
+ return Status::OK();
+ }
+
+ /**
+ * @brief noop
+ * @details
+   * rocksdb has its own caching capabilities that we should be able to use,
+ * without relying on a cache here. This can safely be a no-op.
+ *
+ * @param offset [description]
+ * @param length [description]
+ *
+ * @return [description]
+ */
+ Status InvalidateCache(size_t offset, size_t length) {
+ return Status::OK();
+ }
+};
+
+// A file abstraction for randomly reading the contents of a file.
+class LibradosRandomAccessFile : public RandomAccessFile {
+ librados::IoCtx * _io_ctx;
+ std::string _fid;
+ std::string _hint;
+public:
+ LibradosRandomAccessFile(librados::IoCtx * io_ctx, std::string fid, std::string hint):
+ _io_ctx(io_ctx), _fid(fid), _hint(hint) {}
+
+ ~LibradosRandomAccessFile() {}
+
+ /**
+ * @brief read file
+ * @details similar to LibradosSequentialFile::Read
+ *
+ * @param offset [description]
+ * @param n [description]
+ * @param result [description]
+ * @param scratch [description]
+ * @return [description]
+ */
+ Status Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const {
+ LOG_DEBUG("[IN]%i\n", (int)n);
+ librados::bufferlist buffer;
+ Status s;
+ int r = _io_ctx->read(_fid, buffer, n, offset);
+ if (r >= 0) {
+ buffer.begin().copy(r, scratch);
+ *result = Slice(scratch, r);
+ s = Status::OK();
+ } else {
+ s = err_to_status(r);
+ if (s == Status::IOError()) {
+ *result = Slice();
+ s = Status::OK();
+ }
+ }
+ LOG_DEBUG("[OUT]%s, %i, %s\n", s.ToString().c_str(), (int)r, buffer.c_str());
+ return s;
+ }
+
+ /**
+   * @brief Get a unique id for this file
+   * @details Get a unique id for the file; the id is guaranteed to differ between files
+   *
+   * @param id [description]
+   * @param max_size max size of id, it should be larger than 16
+ *
+ * @return [description]
+ */
+ size_t GetUniqueId(char* id, size_t max_size) const {
+    // All fids share the same db_id prefix, so we ignore the db_id prefix
+ size_t s = std::min(max_size, _fid.size());
+ strncpy(id, _fid.c_str() + (_fid.size() - s), s);
+ id[s - 1] = '\0';
+ return s;
+ };
+
+ //enum AccessPattern { NORMAL, RANDOM, SEQUENTIAL, WILLNEED, DONTNEED };
+ void Hint(AccessPattern pattern) {
+ /* Do nothing */
+ }
+
+ /**
+ * @brief noop
+ * @details [long description]
+ *
+ * @param offset [description]
+ * @param length [description]
+ *
+ * @return [description]
+ */
+ Status InvalidateCache(size_t offset, size_t length) {
+ return Status::OK();
+ }
+};
+
+
+// A file abstraction for sequential writing. The implementation
+// must provide buffering since callers may append small fragments
+// at a time to the file.
+class LibradosWritableFile : public WritableFile {
+ librados::IoCtx * _io_ctx;
+ std::string _fid;
+ std::string _hint;
+ const EnvLibrados * const _env;
+
+ std::mutex _mutex; // used to protect modification of all following variables
+ librados::bufferlist _buffer; // write buffer
+ uint64_t _buffer_size; // write buffer size
+ uint64_t _file_size; // this file size doesn't include buffer size
+
+ /**
+ * @brief assuming caller holds lock
+ * @details [long description]
+ * @return [description]
+ */
+ int _SyncLocked() {
+ // 1. sync append data to RADOS
+ int r = _io_ctx->append(_fid, _buffer, _buffer_size);
+ assert(r >= 0);
+
+ // 2. update local variables
+ if (0 == r) {
+ _buffer.clear();
+ _file_size += _buffer_size;
+ _buffer_size = 0;
+ }
+
+ return r;
+ }
+
+ public:
+ LibradosWritableFile(librados::IoCtx* io_ctx, std::string fid,
+ std::string hint, const EnvLibrados* const env,
+ const EnvOptions& options)
+ : WritableFile(options),
+ _io_ctx(io_ctx),
+ _fid(fid),
+ _hint(hint),
+ _env(env),
+ _buffer(),
+ _buffer_size(0),
+ _file_size(0) {
+ int ret = _io_ctx->stat(_fid, &_file_size, nullptr);
+
+ // if file not exist
+ if (ret < 0) {
+ _file_size = 0;
+ }
+ }
+
+ ~LibradosWritableFile() {
+    // sync before closing the writable file
+ Sync();
+ }
+
+ /**
+ * @brief append data to file
+ * @details
+   * Append will save all written data in the buffer until the buffer size
+   * reaches the buffer max size. Then, it will write the buffer into RADOS.
+ *
+ * @param data [description]
+ * @return [description]
+ */
+ Status Append(const Slice& data) {
+ // append buffer
+ LOG_DEBUG("[IN] %i | %s\n", (int)data.size(), data.data());
+ int r = 0;
+
+ std::lock_guard<std::mutex> lock(_mutex);
+ _buffer.append(data.data(), data.size());
+ _buffer_size += data.size();
+
+ if (_buffer_size > _env->_write_buffer_size) {
+ r = _SyncLocked();
+ }
+
+ LOG_DEBUG("[OUT] %i\n", r);
+ return err_to_status(r);
+ }
+
+ /**
+ * @brief not supported
+ * @details [long description]
+ * @return [description]
+ */
+ Status PositionedAppend(
+ const Slice& /* data */,
+ uint64_t /* offset */) {
+ return Status::NotSupported();
+ }
+
+ /**
+ * @brief truncate file to assigned size
+ * @details [long description]
+ *
+ * @param size [description]
+ * @return [description]
+ */
+ Status Truncate(uint64_t size) {
+ LOG_DEBUG("[IN]%lld|%lld|%lld\n", (long long)size, (long long)_file_size, (long long)_buffer_size);
+ int r = 0;
+
+ std::lock_guard<std::mutex> lock(_mutex);
+ if (_file_size > size) {
+ r = _io_ctx->trunc(_fid, size);
+
+ if (r == 0) {
+ _buffer.clear();
+ _buffer_size = 0;
+ _file_size = size;
+ }
+ } else if (_file_size == size) {
+ _buffer.clear();
+ _buffer_size = 0;
+ } else {
+ librados::bufferlist tmp;
+ tmp.claim(_buffer);
+ _buffer.substr_of(tmp, 0, size - _file_size);
+ _buffer_size = size - _file_size;
+ }
+
+ LOG_DEBUG("[OUT] %i\n", r);
+ return err_to_status(r);
+ }
+
+ /**
+ * @brief close file
+ * @details [long description]
+ * @return [description]
+ */
+ Status Close() {
+ LOG_DEBUG("%s | %lld | %lld\n", _hint.c_str(), (long long)_buffer_size, (long long)_file_size);
+ return Sync();
+ }
+
+ /**
+   * @brief flush file
+   * @details initiate an aio write and do not wait for it to complete
+ *
+ * @return [description]
+ */
+ Status Flush() {
+ librados::AioCompletion *write_completion = librados::Rados::aio_create_completion();
+ int r = 0;
+
+ std::lock_guard<std::mutex> lock(_mutex);
+ r = _io_ctx->aio_append(_fid, write_completion, _buffer, _buffer_size);
+
+ if (0 == r) {
+ _file_size += _buffer_size;
+ _buffer.clear();
+ _buffer_size = 0;
+ }
+
+ write_completion->release();
+
+ return err_to_status(r);
+ }
+
+ /**
+ * @brief write buffer data to rados
+ * @details initiate an aio write and wait for result
+ * @return [description]
+ */
+ Status Sync() { // sync data
+ int r = 0;
+
+ std::lock_guard<std::mutex> lock(_mutex);
+ if (_buffer_size > 0) {
+ r = _SyncLocked();
+ }
+
+ return err_to_status(r);
+ }
+
+ /**
+ * @brief [brief description]
+ * @details [long description]
+   * @return true if Sync() and Fsync() are safe to call concurrently with Append() and Flush().
+ */
+ bool IsSyncThreadSafe() const {
+ return true;
+ }
+
+ /**
+   * @brief Indicates to the upper layers whether the current WritableFile implementation uses direct IO.
+ * @details [long description]
+ * @return [description]
+ */
+ bool use_direct_io() const {
+ return false;
+ }
+
+ /**
+ * @brief Get file size
+ * @details
+ * This API will use cached file_size.
+ * @return [description]
+ */
+ uint64_t GetFileSize() {
+ LOG_DEBUG("%lld|%lld\n", (long long)_buffer_size, (long long)_file_size);
+
+ std::lock_guard<std::mutex> lock(_mutex);
+    uint64_t file_size = _file_size + _buffer_size;
+
+ return file_size;
+ }
+
+ /**
+ * @brief For documentation, refer to RandomAccessFile::GetUniqueId()
+ * @details [long description]
+ *
+ * @param id [description]
+ * @param max_size [description]
+ *
+ * @return [description]
+ */
+ size_t GetUniqueId(char* id, size_t max_size) const {
+    // All fids share the same db_id prefix, so we ignore the db_id prefix
+ size_t s = std::min(max_size, _fid.size());
+ strncpy(id, _fid.c_str() + (_fid.size() - s), s);
+ id[s - 1] = '\0';
+ return s;
+ }
+
+ /**
+ * @brief noop
+ * @details [long description]
+ *
+ * @param offset [description]
+ * @param length [description]
+ *
+ * @return [description]
+ */
+ Status InvalidateCache(size_t offset, size_t length) {
+ return Status::OK();
+ }
+
+ using WritableFile::RangeSync;
+ /**
+ * @brief No RangeSync support, just call Sync()
+ * @details [long description]
+ *
+ * @param offset [description]
+ * @param nbytes [description]
+ *
+ * @return [description]
+ */
+ Status RangeSync(off_t offset, off_t nbytes) {
+ return Sync();
+ }
+
+protected:
+ using WritableFile::Allocate;
+ /**
+ * @brief noop
+ * @details [long description]
+ *
+ * @param offset [description]
+ * @param len [description]
+ *
+ * @return [description]
+ */
+ Status Allocate(off_t offset, off_t len) {
+ return Status::OK();
+ }
+};
+
+
+// Directory object represents collection of files and implements
+// filesystem operations that can be executed on directories.
+class LibradosDirectory : public Directory {
+ librados::IoCtx * _io_ctx;
+ std::string _fid;
+public:
+ explicit LibradosDirectory(librados::IoCtx * io_ctx, std::string fid):
+ _io_ctx(io_ctx), _fid(fid) {}
+
+ // Fsync directory. Can be called concurrently from multiple threads.
+ Status Fsync() {
+ return Status::OK();
+ }
+};
+
+// Identifies a locked file.
+// This is an exclusive lock and cannot be nested by the same thread
+class LibradosFileLock : public FileLock {
+ librados::IoCtx * _io_ctx;
+ const std::string _obj_name;
+ const std::string _lock_name;
+ const std::string _cookie;
+ int lock_state;
+public:
+ LibradosFileLock(
+ librados::IoCtx * io_ctx,
+ const std::string obj_name):
+ _io_ctx(io_ctx),
+ _obj_name(obj_name),
+ _lock_name("lock_name"),
+ _cookie("cookie") {
+
+ // TODO: the lock will never expire. It may cause problems if the process crashes or exits abnormally.
+ while (!_io_ctx->lock_exclusive(
+ _obj_name,
+ _lock_name,
+ _cookie,
+ "description", nullptr, 0));
+ }
+
+ ~LibradosFileLock() {
+ _io_ctx->unlock(_obj_name, _lock_name, _cookie);
+ }
+};
+
+
+// --------------------
+// --- EnvLibrados ----
+// --------------------
+/**
+ * @brief EnvLibrados ctor
+ * @details [long description]
+ *
+ * @param db_name unique database name
+ * @param config_path the configuration file path for rados
+ */
+EnvLibrados::EnvLibrados(const std::string& db_name,
+ const std::string& config_path,
+ const std::string& db_pool)
+ : EnvLibrados("client.admin",
+ "ceph",
+ 0,
+ db_name,
+ config_path,
+ db_pool,
+ "/wal",
+ db_pool,
+ 1 << 20) {}
+
+/**
+ * @brief EnvLibrados ctor
+ * @details [long description]
+ *
+ * @param client_name the first 3 parameters are for RADOS client init
+ * @param cluster_name
+ * @param flags
+ * @param db_name unique database name, used as the db_id key
+ * @param config_path the configuration file path for rados
+ * @param db_pool the pool for db data
+ * @param wal_pool the pool for WAL data
+ * @param write_buffer_size WritableFile buffer max size
+ */
+EnvLibrados::EnvLibrados(const std::string& client_name,
+ const std::string& cluster_name,
+ const uint64_t flags,
+ const std::string& db_name,
+ const std::string& config_path,
+ const std::string& db_pool,
+ const std::string& wal_dir,
+ const std::string& wal_pool,
+ const uint64_t write_buffer_size)
+ : EnvWrapper(Env::Default()),
+ _client_name(client_name),
+ _cluster_name(cluster_name),
+ _flags(flags),
+ _db_name(db_name),
+ _config_path(config_path),
+ _db_pool_name(db_pool),
+ _wal_dir(wal_dir),
+ _wal_pool_name(wal_pool),
+ _write_buffer_size(write_buffer_size) {
+ int ret = 0;
+
+ // 1. create a Rados object and initialize it
+ ret = _rados.init2(_client_name.c_str(), _cluster_name.c_str(), _flags); // just use the client.admin keyring
+ if (ret < 0) { // let's handle any error that might have come back
+ std::cerr << "couldn't initialize rados! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ goto out;
+ }
+
+ // 2. read configure file
+ ret = _rados.conf_read_file(_config_path.c_str());
+ if (ret < 0) {
+ // This could fail if the config file is malformed, but it'd be hard.
+ std::cerr << "failed to parse config file " << _config_path
+ << "! error" << ret << std::endl;
+ ret = EXIT_FAILURE;
+ goto out;
+ }
+
+ // 3. we actually connect to the cluster
+ ret = _rados.connect();
+ if (ret < 0) {
+ std::cerr << "couldn't connect to cluster! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ goto out;
+ }
+
+ // 4. create db_pool if not exist
+ ret = _rados.pool_create(_db_pool_name.c_str());
+ if (ret < 0 && ret != -EEXIST && ret != -EPERM) {
+ std::cerr << "couldn't create pool! error " << ret << std::endl;
+ goto out;
+ }
+
+ // 5. create db_pool_ioctx
+ ret = _rados.ioctx_create(_db_pool_name.c_str(), _db_pool_ioctx);
+ if (ret < 0) {
+ std::cerr << "couldn't set up ioctx! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ goto out;
+ }
+
+ // 6. create wal_pool if not exist
+ ret = _rados.pool_create(_wal_pool_name.c_str());
+ if (ret < 0 && ret != -EEXIST && ret != -EPERM) {
+ std::cerr << "couldn't create pool! error " << ret << std::endl;
+ goto out;
+ }
+
+ // 7. create wal_pool_ioctx
+ ret = _rados.ioctx_create(_wal_pool_name.c_str(), _wal_pool_ioctx);
+ if (ret < 0) {
+ std::cerr << "couldn't set up ioctx! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ goto out;
+ }
+
+ // 8. add root dir
+ _AddFid(ROOT_DIR_KEY, DIR_ID_VALUE);
+
+out:
+ LOG_DEBUG("rados connect result code : %i\n", ret);
+}
+
+/****************************************************
+ private functions to handle fid operations.
+ Dirs also have fids, but their value is DIR_ID_VALUE
+****************************************************/
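+
+/*
+  Metadata layout sketch (the file name below is illustrative):
+    - one RADOS object, named after the db, keeps an omap of <file path, fid> pairs, e.g.
+        ROOT_DIR_KEY       -> DIR_ID_VALUE
+        "/db/000001.log"   -> "<db_name>.<unique id>"   (see _CreateFid)
+    - regular file contents are stored in RADOS objects named by their fid,
+      while directories exist only as omap entries whose value is DIR_ID_VALUE.
+*/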
+
+/**
+ * @brief generate a new fid
+ * @details [long description]
+ * @return [description]
+ */
+std::string EnvLibrados::_CreateFid() {
+ return _db_name + "." + GenerateUniqueId();
+}
+
+/**
+ * @brief get fid
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @param fid [description]
+ *
+ * @return
+ * Status::OK()
+ * Status::NotFound()
+ */
+Status EnvLibrados::_GetFid(
+ const std::string &fname,
+ std::string& fid) {
+ std::set<std::string> keys;
+ std::map<std::string, librados::bufferlist> kvs;
+ keys.insert(fname);
+ int r = _db_pool_ioctx.omap_get_vals_by_keys(_db_name, keys, &kvs);
+
+ if (0 == r && 0 == kvs.size()) {
+ return Status::NotFound();
+ } else if (0 == r && 0 != kvs.size()) {
+ fid.assign(kvs[fname].c_str(), kvs[fname].length());
+ return Status::OK();
+ } else {
+ return err_to_status(r);
+ }
+}
+
+/**
+ * @brief rename fid
+ * @details Only modifies the object in rados once,
+ * so this rename operation is atomic in terms of rados
+ *
+ * @param old_fname [description]
+ * @param new_fname [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::_RenameFid(const std::string& old_fname,
+ const std::string& new_fname) {
+ std::string fid;
+ Status s = _GetFid(old_fname, fid);
+
+ if (Status::OK() != s) {
+ return s;
+ }
+
+ librados::bufferlist bl;
+ std::set<std::string> keys;
+ std::map<std::string, librados::bufferlist> kvs;
+ librados::ObjectWriteOperation o;
+ bl.append(fid);
+ keys.insert(old_fname);
+ kvs[new_fname] = bl;
+ o.omap_rm_keys(keys);
+ o.omap_set(kvs);
+ int r = _db_pool_ioctx.operate(_db_name, &o);
+ return err_to_status(r);
+}
+
+/**
+ * @brief add <file path, fid> to the metadata object. It may overwrite an existing key.
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @param fid [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::_AddFid(
+ const std::string& fname,
+ const std::string& fid) {
+ std::map<std::string, librados::bufferlist> kvs;
+ librados::bufferlist value;
+ value.append(fid);
+ kvs[fname] = value;
+ int r = _db_pool_ioctx.omap_set(_db_name, kvs);
+ return err_to_status(r);
+}
+
+/**
+ * @brief return the subfile names of dir.
+ * @details
+ * RocksDB has a 2-level structure, so all keys
+ * that have dir as a prefix are subfiles of dir,
+ * and we can just return those file names.
+ *
+ * @param dir [description]
+ * @param result [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::_GetSubFnames(
+ const std::string& dir,
+ std::vector<std::string> * result
+) {
+ std::string start_after(dir);
+ std::string filter_prefix(dir);
+ std::map<std::string, librados::bufferlist> kvs;
+ _db_pool_ioctx.omap_get_vals(_db_name,
+ start_after, filter_prefix,
+ MAX_ITEMS_IN_FS, &kvs);
+
+ result->clear();
+ for (auto i = kvs.begin(); i != kvs.end(); i++) {
+ result->push_back(i->first.substr(dir.size() + 1));
+ }
+ return Status::OK();
+}
+
+/**
+ * @brief delete key fname from metadata object
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @return [description]
+ */
+Status EnvLibrados::_DelFid(
+ const std::string& fname) {
+ std::set<std::string> keys;
+ keys.insert(fname);
+ int r = _db_pool_ioctx.omap_rm_keys(_db_name, keys);
+ return err_to_status(r);
+}
+
+/**
+ * @brief get the matching IoCtx from _prefix_pool_map
+ * @details [long description]
+ *
+ * @param fpath [description]
+ * @return [description]
+ *
+ */
+librados::IoCtx* EnvLibrados::_GetIoctx(const std::string& fpath) {
+ auto is_prefix = [](const std::string & s1, const std::string & s2) {
+ auto it1 = s1.begin(), it2 = s2.begin();
+ while (it1 != s1.end() && it2 != s2.end() && *it1 == *it2) ++it1, ++it2;
+ return it1 == s1.end();
+ };
+
+ if (is_prefix(_wal_dir, fpath)) {
+ return &_wal_pool_ioctx;
+ } else {
+ return &_db_pool_ioctx;
+ }
+}
+
+/************************************************************
+ public functions
+************************************************************/
+/**
+ * @brief generate unique id
+ * @details Combine system time and random number.
+ * @return [description]
+ */
+std::string EnvLibrados::GenerateUniqueId() {
+ Random64 r(time(nullptr));
+ uint64_t random_uuid_portion =
+ r.Uniform(std::numeric_limits<uint64_t>::max());
+ uint64_t nanos_uuid_portion = NowNanos();
+ char uuid2[200];
+ snprintf(uuid2,
+ 200,
+ "%16lx-%16lx",
+ (unsigned long)nanos_uuid_portion,
+ (unsigned long)random_uuid_portion);
+ return uuid2;
+}
+
+/**
+ * @brief create a new sequential read file handler
+ * @details it will check the existence of fname
+ *
+ * @param fname [description]
+ * @param result [description]
+ * @param options [description]
+ * @return [description]
+ */
+Status EnvLibrados::NewSequentialFile(
+ const std::string& fname,
+ std::unique_ptr<SequentialFile>* result,
+ const EnvOptions& options)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string dir, file, fid;
+ split(fname, &dir, &file);
+ Status s;
+ std::string fpath = dir + "/" + file;
+ do {
+ s = _GetFid(dir, fid);
+
+ if (!s.ok() || fid != DIR_ID_VALUE) {
+ if (fid != DIR_ID_VALUE) s = Status::IOError();
+ break;
+ }
+
+ s = _GetFid(fpath, fid);
+
+ if (Status::NotFound() == s) {
+ s = Status::IOError();
+ errno = ENOENT;
+ break;
+ }
+
+ result->reset(new LibradosSequentialFile(_GetIoctx(fpath), fid, fpath));
+ s = Status::OK();
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief create a new random access file handler
+ * @details it will check the existence of fname
+ *
+ * @param fname [description]
+ * @param result [description]
+ * @param options [description]
+ * @return [description]
+ */
+Status EnvLibrados::NewRandomAccessFile(
+ const std::string& fname,
+ std::unique_ptr<RandomAccessFile>* result,
+ const EnvOptions& options)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string dir, file, fid;
+ split(fname, &dir, &file);
+ Status s;
+ std::string fpath = dir + "/" + file;
+ do {
+ s = _GetFid(dir, fid);
+
+ if (!s.ok() || fid != DIR_ID_VALUE) {
+ s = Status::IOError();
+ break;
+ }
+
+ s = _GetFid(fpath, fid);
+
+ if (Status::NotFound() == s) {
+ s = Status::IOError();
+ errno = ENOENT;
+ break;
+ }
+
+ result->reset(new LibradosRandomAccessFile(_GetIoctx(fpath), fid, fpath));
+ s = Status::OK();
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief create a new write file handler
+ * @details it will check the existence of fname
+ *
+ * @param fname [description]
+ * @param result [description]
+ * @param options [description]
+ * @return [description]
+ */
+Status EnvLibrados::NewWritableFile(
+ const std::string& fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& options)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string dir, file, fid;
+ split(fname, &dir, &file);
+ Status s;
+ std::string fpath = dir + "/" + file;
+
+ do {
+ // 1. check if dir exist
+ s = _GetFid(dir, fid);
+ if (!s.ok()) {
+ break;
+ }
+
+ if (fid != DIR_ID_VALUE) {
+ s = Status::IOError();
+ break;
+ }
+
+ // 2. check if file exist.
+ // 2.1 exist, use it
+ // 2.2 not exist, create it
+ s = _GetFid(fpath, fid);
+ if (Status::NotFound() == s) {
+ fid = _CreateFid();
+ _AddFid(fpath, fid);
+ }
+
+ result->reset(
+ new LibradosWritableFile(_GetIoctx(fpath), fid, fpath, this, options));
+ s = Status::OK();
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief reuse write file handler
+ * @details
+ * This function will rename old_fname to new_fname,
+ * then return the handler of new_fname
+ *
+ * @param new_fname [description]
+ * @param old_fname [description]
+ * @param result [description]
+ * @param options [description]
+ * @return [description]
+ */
+Status EnvLibrados::ReuseWritableFile(
+ const std::string& new_fname,
+ const std::string& old_fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& options)
+{
+ LOG_DEBUG("[IN]%s => %s\n", old_fname.c_str(), new_fname.c_str());
+ std::string src_fid, tmp_fid, src_dir, src_file, dst_dir, dst_file;
+ split(old_fname, &src_dir, &src_file);
+ split(new_fname, &dst_dir, &dst_file);
+
+ std::string src_fpath = src_dir + "/" + src_file;
+ std::string dst_fpath = dst_dir + "/" + dst_file;
+ Status r = Status::OK();
+ do {
+ r = _RenameFid(src_fpath,
+ dst_fpath);
+ if (!r.ok()) {
+ break;
+ }
+
+ // look up the fid that now maps to the new name before opening it
+ r = _GetFid(dst_fpath, src_fid);
+ if (!r.ok()) {
+ break;
+ }
+
+ result->reset(new LibradosWritableFile(_GetIoctx(dst_fpath), src_fid,
+ dst_fpath, this, options));
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", r.ToString().c_str());
+ return r;
+}
+
+/**
+ * @brief create a new directory handler
+ * @details [long description]
+ *
+ * @param name [description]
+ * @param result [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::NewDirectory(
+ const std::string& name,
+ std::unique_ptr<Directory>* result)
+{
+ LOG_DEBUG("[IN]%s\n", name.c_str());
+ std::string fid, dir, file;
+ /* just want to get dir name */
+ split(name + "/tmp", &dir, &file);
+ Status s;
+
+ do {
+ s = _GetFid(dir, fid);
+
+ if (!s.ok() || DIR_ID_VALUE != fid) {
+ s = Status::IOError(name, strerror(ENOENT));
+ break;
+ }
+
+ if (Status::NotFound() == s) {
+ s = _AddFid(dir, DIR_ID_VALUE);
+ if (!s.ok()) break;
+ } else if (!s.ok()) {
+ break;
+ }
+
+ result->reset(new LibradosDirectory(_GetIoctx(dir), dir));
+ s = Status::OK();
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief check if fname exists
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @return [description]
+ */
+Status EnvLibrados::FileExists(const std::string& fname)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string fid, dir, file;
+ split(fname, &dir, &file);
+ Status s = _GetFid(dir + "/" + file, fid);
+
+ if (s.ok() && fid != DIR_ID_VALUE) {
+ s = Status::OK();
+ }
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief get subfile name of dir_in
+ * @details [long description]
+ *
+ * @param dir_in [description]
+ * @param result [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::GetChildren(
+ const std::string& dir_in,
+ std::vector<std::string>* result)
+{
+ LOG_DEBUG("[IN]%s\n", dir_in.c_str());
+ std::string fid, dir, file;
+ split(dir_in + "/temp", &dir, &file);
+ Status s;
+
+ do {
+ s = _GetFid(dir, fid);
+ if (!s.ok()) {
+ break;
+ }
+
+ if (fid != DIR_ID_VALUE) {
+ s = Status::IOError();
+ break;
+ }
+
+ s = _GetSubFnames(dir, result);
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief delete fname
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @return [description]
+ */
+Status EnvLibrados::DeleteFile(const std::string& fname)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string fid, dir, file;
+ split(fname, &dir, &file);
+ Status s = _GetFid(dir + "/" + file, fid);
+
+ if (s.ok() && DIR_ID_VALUE != fid) {
+ s = _DelFid(dir + "/" + file);
+ } else {
+ s = Status::NotFound();
+ }
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief create new dir
+ * @details [long description]
+ *
+ * @param dirname [description]
+ * @return [description]
+ */
+Status EnvLibrados::CreateDir(const std::string& dirname)
+{
+ LOG_DEBUG("[IN]%s\n", dirname.c_str());
+ std::string fid, dir, file;
+ split(dirname + "/temp", &dir, &file);
+ Status s = _GetFid(dir + "/" + file, fid);
+
+ do {
+ if (Status::NotFound() != s && fid != DIR_ID_VALUE) {
+ break;
+ } else if (Status::OK() == s && fid == DIR_ID_VALUE) {
+ break;
+ }
+
+ s = _AddFid(dir, DIR_ID_VALUE);
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief create dir if missing
+ * @details [long description]
+ *
+ * @param dirname [description]
+ * @return [description]
+ */
+Status EnvLibrados::CreateDirIfMissing(const std::string& dirname)
+{
+ LOG_DEBUG("[IN]%s\n", dirname.c_str());
+ std::string fid, dir, file;
+ split(dirname + "/temp", &dir, &file);
+ Status s = Status::OK();
+
+ do {
+ s = _GetFid(dir, fid);
+ if (Status::NotFound() != s) {
+ break;
+ }
+
+ s = _AddFid(dir, DIR_ID_VALUE);
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief delete dir
+ * @details
+ *
+ * @param dirname [description]
+ * @return [description]
+ */
+Status EnvLibrados::DeleteDir(const std::string& dirname)
+{
+ LOG_DEBUG("[IN]%s\n", dirname.c_str());
+ std::string fid, dir, file;
+ split(dirname + "/temp", &dir, &file);
+ Status s = Status::OK();
+
+ s = _GetFid(dir, fid);
+
+ if (s.ok() && DIR_ID_VALUE == fid) {
+ std::vector<std::string> subs;
+ s = _GetSubFnames(dir, &subs);
+ // if subfiles exist, can't delete dir
+ if (subs.size() > 0) {
+ s = Status::IOError();
+ } else {
+ s = _DelFid(dir);
+ }
+ } else {
+ s = Status::NotFound();
+ }
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief return file size
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @param file_size [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::GetFileSize(
+ const std::string& fname,
+ uint64_t* file_size)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string fid, dir, file;
+ split(fname, &dir, &file);
+ time_t mtime;
+ Status s;
+
+ do {
+ std::string fpath = dir + "/" + file;
+ s = _GetFid(fpath, fid);
+
+ if (!s.ok()) {
+ break;
+ }
+
+ int ret = _GetIoctx(fpath)->stat(fid, file_size, &mtime);
+ if (ret < 0) {
+ LOG_DEBUG("%i\n", ret);
+ if (-ENOENT == ret) {
+ *file_size = 0;
+ s = Status::OK();
+ } else {
+ s = err_to_status(ret);
+ }
+ } else {
+ s = Status::OK();
+ }
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s|%lld\n", s.ToString().c_str(), (long long)*file_size);
+ return s;
+}
+
+/**
+ * @brief get file modification time
+ * @details [long description]
+ *
+ * @param fname [description]
+ * @param file_mtime [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::GetFileModificationTime(const std::string& fname,
+ uint64_t* file_mtime)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string fid, dir, file;
+ split(fname, &dir, &file);
+ time_t mtime;
+ uint64_t file_size;
+ Status s = Status::OK();
+ do {
+ std::string fpath = dir + "/" + file;
+ s = _GetFid(dir + "/" + file, fid);
+
+ if (!s.ok()) {
+ break;
+ }
+
+ int ret = _GetIoctx(fpath)->stat(fid, &file_size, &mtime);
+ if (ret < 0) {
+ if (Status::NotFound() == err_to_status(ret)) {
+ *file_mtime = static_cast<uint64_t>(mtime);
+ s = Status::OK();
+ } else {
+ s = err_to_status(ret);
+ }
+ } else {
+ *file_mtime = static_cast<uint64_t>(mtime);
+ s = Status::OK();
+ }
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief rename file
+ * @details
+ *
+ * @param src [description]
+ * @param target_in [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::RenameFile(
+ const std::string& src,
+ const std::string& target_in)
+{
+ LOG_DEBUG("[IN]%s => %s\n", src.c_str(), target_in.c_str());
+ std::string src_fid, tmp_fid, src_dir, src_file, dst_dir, dst_file;
+ split(src, &src_dir, &src_file);
+ split(target_in, &dst_dir, &dst_file);
+
+ auto s = _RenameFid(src_dir + "/" + src_file,
+ dst_dir + "/" + dst_file);
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief not supported
+ * @details [long description]
+ *
+ * @param src [description]
+ * @param target_in [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::LinkFile(
+ const std::string& src,
+ const std::string& target_in)
+{
+ LOG_DEBUG("[IO]%s => %s\n", src.c_str(), target_in.c_str());
+ return Status::NotSupported();
+}
+
+/**
+ * @brief lock file. create if missing.
+ * @details [long description]
+ *
+ * LockFile is used to prevent other instances of RocksDB
+ * from opening the database at the same time. In the RocksDB source code,
+ * the invocations of LockFile are at the following locations:
+ *
+ * ./db/db_impl.cc:1159: s = env_->LockFile(LockFileName(dbname_), &db_lock_); // DBImpl::Recover
+ * ./db/db_impl.cc:5839: Status result = env->LockFile(lockname, &lock); // Status DestroyDB
+ *
+ * That is, RocksDB calls LockFile during DB recovery and DB destruction.
+ * @param fname [description]
+ * @param lock [description]
+ *
+ * @return [description]
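+ *
+ * A minimal usage sketch (the path below is illustrative):
+ *
+ *   FileLock* lock = nullptr;
+ *   Status s = env->LockFile("/db_dir/LOCK", &lock);
+ *   if (s.ok()) {
+ *     // ... exclusive access to the database ...
+ *     s = env->UnlockFile(lock);
+ *   }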
+ */
+Status EnvLibrados::LockFile(
+ const std::string& fname,
+ FileLock** lock)
+{
+ LOG_DEBUG("[IN]%s\n", fname.c_str());
+ std::string fid, dir, file;
+ split(fname, &dir, &file);
+ Status s = Status::OK();
+
+ do {
+ std::string fpath = dir + "/" + file;
+ s = _GetFid(fpath, fid);
+
+ if (Status::OK() != s &&
+ Status::NotFound() != s) {
+ break;
+ } else if (Status::NotFound() == s) {
+ s = _AddFid(fpath, _CreateFid());
+ if (!s.ok()) {
+ break;
+ }
+ } else if (Status::OK() == s && DIR_ID_VALUE == fid) {
+ s = Status::IOError();
+ break;
+ }
+
+ *lock = new LibradosFileLock(_GetIoctx(fpath), fpath);
+ } while (0);
+
+ LOG_DEBUG("[OUT]%s\n", s.ToString().c_str());
+ return s;
+}
+
+/**
+ * @brief unlock file
+ * @details [long description]
+ *
+ * @param lock [description]
+ * @return [description]
+ */
+Status EnvLibrados::UnlockFile(FileLock* lock)
+{
+ LOG_DEBUG("[IO]%p\n", lock);
+ if (nullptr != lock) {
+ delete lock;
+ }
+ return Status::OK();
+}
+
+
+/**
+ * @brief not supported
+ * @details [long description]
+ *
+ * @param db_path [description]
+ * @param output_path [description]
+ *
+ * @return [description]
+ */
+Status EnvLibrados::GetAbsolutePath(
+ const std::string& db_path,
+ std::string* output_path)
+{
+ LOG_DEBUG("[IO]%s\n", db_path.c_str());
+ return Status::NotSupported();
+}
+
+/**
+ * @brief Get default EnvLibrados
+ * @details [long description]
+ * @return [description]
+ */
+EnvLibrados* EnvLibrados::Default() {
+ static EnvLibrados default_env(default_db_name,
+ std::getenv(default_config_path),
+ default_pool_name);
+ return &default_env;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/env_librados.md b/src/rocksdb/utilities/env_librados.md
new file mode 100644
index 000000000..45a2a7bad
--- /dev/null
+++ b/src/rocksdb/utilities/env_librados.md
@@ -0,0 +1,122 @@
+# Introduction to EnvLibrados
+EnvLibrados is a customized RocksDB Env that uses RADOS as the backend file system of RocksDB. It overrides all file-system-related APIs of the default Env. The easiest way to use it is as follows:
+```c++
+std::string db_name = "test_db";
+std::string config_path = "path/to/ceph/config";
+std::string db_pool = db_name + "_pool";
+DB* db;
+Options options;
+// options.env expects an Env*; the EnvLibrados instance must outlive the DB
+options.env = new EnvLibrados(db_name, config_path, db_pool);
+Status s = DB::Open(options, kDBPath, &db);
+...
+```
+EnvLibrados will then forward all file read/write operations to the RADOS cluster specified by config_path. The db pool is conventionally named db_name+"_pool", as in the example above.
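+
+If the pool does not already exist, EnvLibrados attempts to create it when the Env is constructed. Below is a minimal sketch for checking that a pool is present, based on the librados connection sequence used in env_librados_test.cc; the "admin" keyring id and the `PoolExists` helper name are illustrative:
+```c++
+#include <rados/librados.hpp>
+#include <string>
+
+// Returns true if `pool` exists in the cluster described by `config_path`.
+bool PoolExists(const std::string& config_path, const std::string& pool) {
+  librados::Rados rados;
+  if (rados.init("admin") < 0) return false;                 // use the client.admin keyring
+  if (rados.conf_read_file(config_path.c_str()) < 0) return false;
+  if (rados.connect() < 0) return false;
+  int64_t pool_id = rados.pool_lookup(pool.c_str());         // negative value means not found or error
+  rados.shutdown();
+  return pool_id >= 0;
+}
+```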
+
+# Options for EnvLibrados
+There are several options that users can set for EnvLibrados.
+- write_buffer_size. The max buffer size for WritableFile. After the buffer reaches write_buffer_size, EnvLibrados syncs the buffer content to RADOS and then clears the buffer.
+- db_pool. Instead of using the default pool, users can set their own db pool name.
+- wal_dir. The dir for WAL files. Because RocksDB only has a 2-level structure (dir_name/file_name), the format of wal_dir is "/dir_name" (it CAN'T be "/dir1/dir2"). The default wal_dir is "/wal".
+- wal_pool. The corresponding pool name for WAL files. The default value is db_name+"_wal_pool".
+
+An example of setting these options looks like the following:
+```c++
+std::string db_name = "test_db";
+std::string config_path = "path/to/ceph/config";
+std::string db_pool = db_name + "_pool";
+std::string wal_dir = "/wal";
+std::string wal_pool = db_name + "_wal_pool";
+uint64_t write_buffer_size = 1 << 20;
+// the full constructor also takes the RADOS client name, cluster name and flags
+EnvLibrados* env_ = new EnvLibrados("client.admin", "ceph", 0, db_name, config_path,
+                                    db_pool, wal_dir, wal_pool, write_buffer_size);
+
+DB* db;
+Options options;
+options.env = env_;
+// The last level dir name should match the dir name in prefix_pool_map
+options.wal_dir = "/tmp/wal";
+
+// open DB
+Status s = DB::Open(options, kDBPath, &db);
+...
+```
+
+# Performance Test
+## Compile
+Check this [link](https://github.com/facebook/rocksdb/blob/master/INSTALL.md) to install the dependencies of RocksDB. Then you can compile it by running `$ make env_librados_test ROCKSDB_USE_LIBRADOS=1` under `rocksdb/`. The configuration file used by env_librados_test is `../ceph/src/ceph.conf`. For Ubuntu 14.04, just run the following commands:
+```bash
+$ sudo apt-get install libgflags-dev
+$ sudo apt-get install libsnappy-dev
+$ sudo apt-get install zlib1g-dev
+$ sudo apt-get install libbz2-dev
+$ make env_librados_test ROCKSDB_USE_LIBRADOS=1
+```
+
+## Test Result
+My test environment is Ubuntu 14.04 in VirtualBox with 8 cores and 8 GB of RAM. The following are the test results.
+
+1. Write (1<<20) keys in random order. The write time under the default env is around 10s, while the write time under EnvLibrados varies from 10s to 30s.
+
+2. Write (1<<20) keys in sequential order. The write time under the default env drops to around 1s, but the write time under EnvLibrados does not change.
+
+3. Read (1<<16) keys from (1<<20) keys in random order. The read time under both Envs is roughly the same, around 1.8s.
+
+# MyRocks Test
+## Compile Ceph
+See [link](http://docs.ceph.com/docs/master/install/build-ceph/)
+
+## Start RADOS
+
+```bash
+cd ceph-path/src
+( ( ./stop.sh; rm -rf dev/*; CEPH_NUM_OSD=3 ./vstart.sh --short --localhost -n
+-x -d ; ) ) 2>&1
+```
+
+## Compile MySQL
+
+```bash
+sudo apt-get update
+sudo apt-get install g++ cmake libbz2-dev libaio-dev bison \
+zlib1g-dev libsnappy-dev
+sudo apt-get install libgflags-dev libreadline6-dev libncurses5-dev \
+libssl-dev liblz4-dev gdb git
+
+git clone https://github.com/facebook/mysql-5.6.git
+cd mysql-5.6
+git submodule init
+git submodule update
+cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo -DWITH_SSL=system \
+-DWITH_ZLIB=bundled -DMYSQL_MAINTAINER_MODE=0 -DENABLED_LOCAL_INFILE=1 -DROCKSDB_USE_LIBRADOS=1
+make install -j8
+```
+
+Check this [link](https://github.com/facebook/mysql-5.6/wiki/Build-Steps) for latest compile steps.
+
+## Configure MySQL
+The following are the steps to configure MySQL.
+
+```bash
+mkdir -p /etc/mysql
+mkdir -p /var/lib/mysql
+mkdir -p /etc/mysql/conf.d
+echo -e '[mysqld_safe]\nsyslog' > /etc/mysql/conf.d/mysqld_safe_syslog.cnf
+cp /usr/share/mysql/my-medium.cnf /etc/mysql/my.cnf
+sed -i 's#.*datadir.*#datadir = /var/lib/mysql#g' /etc/mysql/my.cnf
+chown mysql:mysql -R /var/lib/mysql
+
+mysql_install_db --user=mysql --ldata=/var/lib/mysql/
+export CEPH_CONFIG_PATH="path/of/ceph/config/file"
+mysqld_safe -user=mysql --skip-innodb --rocksdb --default-storage-engine=rocksdb --default-tmp-storage-engine=MyISAM &
+mysqladmin -u root password
+mysql -u root -p
+```
+
+Check this [link](https://gist.github.com/shichao-an/f5639ecd551496ac2d70) for detailed information.
+
+```sql
+show databases;
+create database testdb;
+use testdb;
+show tables;
+CREATE TABLE tbl (id INT AUTO_INCREMENT primary key, str VARCHAR(32));
+insert into tbl values (1, "val2");
+select * from tbl;
+```
diff --git a/src/rocksdb/utilities/env_librados_test.cc b/src/rocksdb/utilities/env_librados_test.cc
new file mode 100644
index 000000000..d5167acc0
--- /dev/null
+++ b/src/rocksdb/utilities/env_librados_test.cc
@@ -0,0 +1,1146 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2016, Red Hat, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/env_librados.h"
+#include <rados/librados.hpp>
+#include "env/mock_env.h"
+#include "test_util/testharness.h"
+
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/options.h"
+#include "util/random.h"
+#include <chrono>
+#include <ostream>
+#include "rocksdb/utilities/transaction_db.h"
+
+class Timer {
+ typedef std::chrono::high_resolution_clock high_resolution_clock;
+ typedef std::chrono::milliseconds milliseconds;
+public:
+ explicit Timer(bool run = false)
+ {
+ if (run)
+ Reset();
+ }
+ void Reset()
+ {
+ _start = high_resolution_clock::now();
+ }
+ milliseconds Elapsed() const
+ {
+ return std::chrono::duration_cast<milliseconds>(high_resolution_clock::now() - _start);
+ }
+ template <typename T, typename Traits>
+ friend std::basic_ostream<T, Traits>& operator<<(std::basic_ostream<T, Traits>& out, const Timer& timer)
+ {
+ return out << timer.Elapsed().count();
+ }
+private:
+ high_resolution_clock::time_point _start;
+};
+
+namespace ROCKSDB_NAMESPACE {
+
+class EnvLibradosTest : public testing::Test {
+public:
+ // we will use all of these below
+ const std::string db_name = "env_librados_test_db";
+ const std::string db_pool = db_name + "_pool";
+ const char *keyring = "admin";
+ const char *config = "../ceph/src/ceph.conf";
+
+ EnvLibrados* env_;
+ const EnvOptions soptions_;
+
+ EnvLibradosTest()
+ : env_(new EnvLibrados(db_name, config, db_pool)) {
+ }
+ ~EnvLibradosTest() {
+ delete env_;
+ librados::Rados rados;
+ int ret = 0;
+ do {
+ ret = rados.init("admin"); // just use the client.admin keyring
+ if (ret < 0) { // let's handle any error that might have come back
+ std::cerr << "couldn't initialize rados! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ ret = rados.conf_read_file(config);
+ if (ret < 0) {
+ // This could fail if the config file is malformed, but it'd be hard.
+ std::cerr << "failed to parse config file " << config
+ << "! error" << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ /*
+ * next, we actually connect to the cluster
+ */
+
+ ret = rados.connect();
+ if (ret < 0) {
+ std::cerr << "couldn't connect to cluster! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ /*
+ * And now we're done, so let's remove our pool and then
+ * shut down the connection gracefully.
+ */
+ int delete_ret = rados.pool_delete(db_pool.c_str());
+ if (delete_ret < 0) {
+ // be careful not to
+ std::cerr << "We failed to delete our test pool!" << db_pool << delete_ret << std::endl;
+ ret = EXIT_FAILURE;
+ }
+ } while (0);
+ }
+};
+
+TEST_F(EnvLibradosTest, Basics) {
+ uint64_t file_size;
+ std::unique_ptr<WritableFile> writable_file;
+ std::vector<std::string> children;
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+ // Check that the directory is empty.
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/non_existent"));
+ ASSERT_TRUE(!env_->GetFileSize("/dir/non_existent", &file_size).ok());
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+
+ // Create a file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ writable_file.reset();
+
+ // Check that the file exists.
+ ASSERT_OK(env_->FileExists("/dir/f"));
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(0U, file_size);
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+
+ // Write to the file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("abc"));
+ writable_file.reset();
+
+
+ // Check for expected size.
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+
+ // Check that renaming works.
+ ASSERT_TRUE(!env_->RenameFile("/dir/non_existent", "/dir/g").ok());
+ ASSERT_OK(env_->RenameFile("/dir/f", "/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/f"));
+ ASSERT_OK(env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that opening non-existent file fails.
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ ASSERT_TRUE(
+ !env_->NewSequentialFile("/dir/non_existent", &seq_file, soptions_).ok());
+ ASSERT_TRUE(!seq_file);
+ ASSERT_TRUE(!env_->NewRandomAccessFile("/dir/non_existent", &rand_file,
+ soptions_).ok());
+ ASSERT_TRUE(!rand_file);
+
+ // Check that deleting works.
+ ASSERT_TRUE(!env_->DeleteFile("/dir/non_existent").ok());
+ ASSERT_OK(env_->DeleteFile("/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+ ASSERT_OK(env_->DeleteDir("/dir"));
+}
+
+TEST_F(EnvLibradosTest, ReadWrite) {
+ std::unique_ptr<WritableFile> writable_file;
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ Slice result;
+ char scratch[100];
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("hello "));
+ ASSERT_OK(writable_file->Append("world"));
+ writable_file.reset();
+
+ // Read sequentially.
+ ASSERT_OK(env_->NewSequentialFile("/dir/f", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(seq_file->Skip(1));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Try reading past EOF.
+ ASSERT_EQ(0U, result.size());
+ ASSERT_OK(seq_file->Skip(100)); // Try to skip past end of file.
+ ASSERT_OK(seq_file->Read(1000, &result, scratch));
+ ASSERT_EQ(0U, result.size());
+
+ // Random reads.
+ ASSERT_OK(env_->NewRandomAccessFile("/dir/f", &rand_file, soptions_));
+ ASSERT_OK(rand_file->Read(6, 5, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(rand_file->Read(0, 5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(rand_file->Read(10, 100, &result, scratch)); // Read "d".
+ ASSERT_EQ(0, result.compare("d"));
+
+ // Too high offset.
+ ASSERT_OK(rand_file->Read(1000, 5, &result, scratch));
+}
+
+TEST_F(EnvLibradosTest, Locks) {
+ FileLock* lock = nullptr;
+ std::unique_ptr<WritableFile> writable_file;
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(env_->LockFile("some file", &lock));
+ ASSERT_OK(env_->UnlockFile(lock));
+
+ ASSERT_OK(env_->LockFile("/dir/f", &lock));
+ ASSERT_OK(env_->UnlockFile(lock));
+}
+
+TEST_F(EnvLibradosTest, Misc) {
+ std::string test_dir;
+ ASSERT_OK(env_->GetTestDirectory(&test_dir));
+ ASSERT_TRUE(!test_dir.empty());
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_TRUE(!env_->NewWritableFile("/a/b", &writable_file, soptions_).ok());
+
+ ASSERT_OK(env_->NewWritableFile("/a", &writable_file, soptions_));
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(writable_file->Sync());
+ ASSERT_OK(writable_file->Flush());
+ ASSERT_OK(writable_file->Close());
+ writable_file.reset();
+}
+
+TEST_F(EnvLibradosTest, LargeWrite) {
+ const size_t kWriteSize = 300 * 1024;
+ char* scratch = new char[kWriteSize * 2];
+
+ std::string write_data;
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ write_data.append(1, 'h');
+ }
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->CreateDir("/dir"));
+ ASSERT_OK(env_->NewWritableFile("/dir/g", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("foo"));
+ ASSERT_OK(writable_file->Append(write_data));
+ writable_file.reset();
+
+ std::unique_ptr<SequentialFile> seq_file;
+ Slice result;
+ ASSERT_OK(env_->NewSequentialFile("/dir/g", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(3, &result, scratch)); // Read "foo".
+ ASSERT_EQ(0, result.compare("foo"));
+
+ size_t read = 0;
+ std::string read_data;
+ while (read < kWriteSize) {
+ ASSERT_OK(seq_file->Read(kWriteSize - read, &result, scratch));
+ read_data.append(result.data(), result.size());
+ read += result.size();
+ }
+ ASSERT_TRUE(write_data == read_data);
+ delete[] scratch;
+}
+
+TEST_F(EnvLibradosTest, FrequentlySmallWrite) {
+ const size_t kWriteSize = 1 << 10;
+ char* scratch = new char[kWriteSize * 2];
+
+ std::string write_data;
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ write_data.append(1, 'h');
+ }
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->CreateDir("/dir"));
+ ASSERT_OK(env_->NewWritableFile("/dir/g", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("foo"));
+
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ ASSERT_OK(writable_file->Append("h"));
+ }
+ writable_file.reset();
+
+ std::unique_ptr<SequentialFile> seq_file;
+ Slice result;
+ ASSERT_OK(env_->NewSequentialFile("/dir/g", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(3, &result, scratch)); // Read "foo".
+ ASSERT_EQ(0, result.compare("foo"));
+
+ size_t read = 0;
+ std::string read_data;
+ while (read < kWriteSize) {
+ ASSERT_OK(seq_file->Read(kWriteSize - read, &result, scratch));
+ read_data.append(result.data(), result.size());
+ read += result.size();
+ }
+ ASSERT_TRUE(write_data == read_data);
+ delete[] scratch;
+}
+
+TEST_F(EnvLibradosTest, Truncate) {
+ const size_t kWriteSize = 300 * 1024;
+ const size_t truncSize = 1024;
+ std::string write_data;
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ write_data.append(1, 'h');
+ }
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->CreateDir("/dir"));
+ ASSERT_OK(env_->NewWritableFile("/dir/g", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append(write_data));
+ ASSERT_EQ(writable_file->GetFileSize(), kWriteSize);
+ ASSERT_OK(writable_file->Truncate(truncSize));
+ ASSERT_EQ(writable_file->GetFileSize(), truncSize);
+ writable_file.reset();
+}
+
+TEST_F(EnvLibradosTest, DBBasics) {
+ std::string kDBPath = "/tmp/DBBasics";
+ DB* db;
+ Options options;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options.IncreaseParallelism();
+ options.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options.create_if_missing = true;
+ options.env = env_;
+
+ // open DB
+ Status s = DB::Open(options, kDBPath, &db);
+ assert(s.ok());
+
+ // Put key-value
+ s = db->Put(WriteOptions(), "key1", "value");
+ assert(s.ok());
+ std::string value;
+ // get value
+ s = db->Get(ReadOptions(), "key1", &value);
+ assert(s.ok());
+ assert(value == "value");
+
+ // atomically apply a set of updates
+ {
+ WriteBatch batch;
+ batch.Delete("key1");
+ batch.Put("key2", value);
+ s = db->Write(WriteOptions(), &batch);
+ }
+
+ s = db->Get(ReadOptions(), "key1", &value);
+ assert(s.IsNotFound());
+
+ db->Get(ReadOptions(), "key2", &value);
+ assert(value == "value");
+
+ delete db;
+}
+
+TEST_F(EnvLibradosTest, DBLoadKeysInRandomOrder) {
+ char key[20] = {0}, value[20] = {0};
+ int max_loop = 1 << 10;
+ Timer timer(false);
+ std::cout << "Test size : loop(" << max_loop << ")" << std::endl;
+ /**********************************
+ use default env
+ ***********************************/
+ std::string kDBPath1 = "/tmp/DBLoadKeysInRandomOrder1";
+ DB* db1;
+ Options options1;
+ // Optimize Rocksdb. This is the easiest way to get RocksDB to perform well
+ options1.IncreaseParallelism();
+ options1.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options1.create_if_missing = true;
+
+ // open DB
+ Status s1 = DB::Open(options1, kDBPath1, &db1);
+ assert(s1.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r1(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ // Put key-value
+ s1 = db1->Put(WriteOptions(), key, value);
+ assert(s1.ok());
+ }
+ std::cout << "Time by default : " << timer << "ms" << std::endl;
+ delete db1;
+
+ /**********************************
+ use librados env
+ ***********************************/
+ std::string kDBPath2 = "/tmp/DBLoadKeysInRandomOrder2";
+ DB* db2;
+ Options options2;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options2.IncreaseParallelism();
+ options2.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options2.create_if_missing = true;
+ options2.env = env_;
+
+ // open DB
+ Status s2 = DB::Open(options2, kDBPath2, &db2);
+ assert(s2.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r2(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ // Put key-value
+ s2 = db2->Put(WriteOptions(), key, value);
+ assert(s2.ok());
+ }
+ std::cout << "Time by librados : " << timer << "ms" << std::endl;
+ delete db2;
+}
+
+TEST_F(EnvLibradosTest, DBBulkLoadKeysInRandomOrder) {
+ char key[20] = {0}, value[20] = {0};
+ int max_loop = 1 << 6;
+ int bulk_size = 1 << 15;
+ Timer timer(false);
+ std::cout << "Test size : loop(" << max_loop << "); bulk_size(" << bulk_size << ")" << std::endl;
+ /**********************************
+ use default env
+ ***********************************/
+ std::string kDBPath1 = "/tmp/DBBulkLoadKeysInRandomOrder1";
+ DB* db1;
+ Options options1;
+ // Optimize Rocksdb. This is the easiest way to get RocksDB to perform well
+ options1.IncreaseParallelism();
+ options1.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options1.create_if_missing = true;
+
+ // open DB
+ Status s1 = DB::Open(options1, kDBPath1, &db1);
+ assert(s1.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r1(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s1 = db1->Write(WriteOptions(), &batch);
+ assert(s1.ok());
+ }
+ std::cout << "Time by default : " << timer << "ms" << std::endl;
+ delete db1;
+
+ /**********************************
+ use librados env
+ ***********************************/
+ std::string kDBPath2 = "/tmp/DBBulkLoadKeysInRandomOrder2";
+ DB* db2;
+ Options options2;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options2.IncreaseParallelism();
+ options2.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options2.create_if_missing = true;
+ options2.env = env_;
+
+ // open DB
+ Status s2 = DB::Open(options2, kDBPath2, &db2);
+ assert(s2.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r2(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s2 = db2->Write(WriteOptions(), &batch);
+ assert(s2.ok());
+ }
+ std::cout << "Time by librados : " << timer << "ms" << std::endl;
+ delete db2;
+}
+
+TEST_F(EnvLibradosTest, DBBulkLoadKeysInSequentialOrder) {
+ char key[20] = {0}, value[20] = {0};
+ int max_loop = 1 << 6;
+ int bulk_size = 1 << 15;
+ Timer timer(false);
+ std::cout << "Test size : loop(" << max_loop << "); bulk_size(" << bulk_size << ")" << std::endl;
+ /**********************************
+ use default env
+ ***********************************/
+ std::string kDBPath1 = "/tmp/DBBulkLoadKeysInSequentialOrder1";
+ DB* db1;
+ Options options1;
+ // Optimize Rocksdb. This is the easiest way to get RocksDB to perform well
+ options1.IncreaseParallelism();
+ options1.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options1.create_if_missing = true;
+
+ // open DB
+ Status s1 = DB::Open(options1, kDBPath1, &db1);
+ assert(s1.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r1(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%019lld",
+ (long long)(i * bulk_size + j));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s1 = db1->Write(WriteOptions(), &batch);
+ assert(s1.ok());
+ }
+ std::cout << "Time by default : " << timer << "ms" << std::endl;
+ delete db1;
+
+ /**********************************
+ use librados env
+ ***********************************/
+ std::string kDBPath2 = "/tmp/DBBulkLoadKeysInSequentialOrder2";
+ DB* db2;
+ Options options2;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options2.IncreaseParallelism();
+ options2.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options2.create_if_missing = true;
+ options2.env = env_;
+
+ // open DB
+ Status s2 = DB::Open(options2, kDBPath2, &db2);
+ assert(s2.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r2(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s2 = db2->Write(WriteOptions(), &batch);
+ assert(s2.ok());
+ }
+ std::cout << "Time by librados : " << timer << "ms" << std::endl;
+ delete db2;
+}
+
+TEST_F(EnvLibradosTest, DBRandomRead) {
+ char key[20] = {0}, value[20] = {0};
+ int max_loop = 1 << 6;
+ int bulk_size = 1 << 10;
+ int read_loop = 1 << 20;
+ Timer timer(false);
+ std::cout << "Test size : keys_num(" << max_loop << ", " << bulk_size << "); read_loop(" << read_loop << ")" << std::endl;
+ /**********************************
+ use default env
+ ***********************************/
+ std::string kDBPath1 = "/tmp/DBRandomRead1";
+ DB* db1;
+ Options options1;
+ // Optimize Rocksdb. This is the easiest way to get RocksDB to perform well
+ options1.IncreaseParallelism();
+ options1.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options1.create_if_missing = true;
+
+ // open DB
+ Status s1 = DB::Open(options1, kDBPath1, &db1);
+ assert(s1.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r1(time(nullptr));
+
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%019lld",
+ (long long)(i * bulk_size + j));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s1 = db1->Write(WriteOptions(), &batch);
+ assert(s1.ok());
+ }
+ timer.Reset();
+ int base1 = 0, offset1 = 0;
+ for (int i = 0; i < read_loop; ++i) {
+ base1 = r1.Uniform(max_loop);
+ offset1 = r1.Uniform(bulk_size);
+ std::string value1;
+ snprintf(key,
+ 20,
+ "%019lld",
+ (long long)(base1 * bulk_size + offset1));
+ s1 = db1->Get(ReadOptions(), key, &value1);
+ assert(s1.ok());
+ }
+ std::cout << "Time by default : " << timer << "ms" << std::endl;
+ delete db1;
+
+ /**********************************
+ use librados env
+ ***********************************/
+ std::string kDBPath2 = "/tmp/DBRandomRead2";
+ DB* db2;
+ Options options2;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options2.IncreaseParallelism();
+ options2.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options2.create_if_missing = true;
+ options2.env = env_;
+
+ // open DB
+ Status s2 = DB::Open(options2, kDBPath2, &db2);
+ assert(s2.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r2(time(nullptr));
+
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%019lld",
+ (long long)(i * bulk_size + j));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s2 = db2->Write(WriteOptions(), &batch);
+ assert(s2.ok());
+ }
+
+ timer.Reset();
+ int base2 = 0, offset2 = 0;
+ for (int i = 0; i < read_loop; ++i) {
+ base2 = r2.Uniform(max_loop);
+ offset2 = r2.Uniform(bulk_size);
+ std::string value2;
+ snprintf(key,
+ 20,
+ "%019lld",
+ (long long)(base2 * bulk_size + offset2));
+ s2 = db2->Get(ReadOptions(), key, &value2);
+ if (!s2.ok()) {
+ std::cout << s2.ToString() << std::endl;
+ }
+ assert(s2.ok());
+ }
+ std::cout << "Time by librados : " << timer << "ms" << std::endl;
+ delete db2;
+}
+
+class EnvLibradosMutipoolTest : public testing::Test {
+public:
+ // we will use all of these below
+ const std::string client_name = "client.admin";
+ const std::string cluster_name = "ceph";
+ const uint64_t flags = 0;
+ const std::string db_name = "env_librados_test_db";
+ const std::string db_pool = db_name + "_pool";
+ const std::string wal_dir = "/wal";
+ const std::string wal_pool = db_name + "_wal_pool";
+ const size_t write_buffer_size = 1 << 20;
+ const char *keyring = "admin";
+ const char *config = "../ceph/src/ceph.conf";
+
+ EnvLibrados* env_;
+ const EnvOptions soptions_;
+
+ EnvLibradosMutipoolTest() {
+ env_ = new EnvLibrados(client_name, cluster_name, flags, db_name, config, db_pool, wal_dir, wal_pool, write_buffer_size);
+ }
+ ~EnvLibradosMutipoolTest() {
+ delete env_;
+ librados::Rados rados;
+ int ret = 0;
+ do {
+ ret = rados.init("admin"); // just use the client.admin keyring
+ if (ret < 0) { // let's handle any error that might have come back
+ std::cerr << "couldn't initialize rados! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ ret = rados.conf_read_file(config);
+ if (ret < 0) {
+ // This could fail if the config file is malformed, but it'd be hard.
+ std::cerr << "failed to parse config file " << config
+ << "! error" << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ /*
+ * next, we actually connect to the cluster
+ */
+
+ ret = rados.connect();
+ if (ret < 0) {
+ std::cerr << "couldn't connect to cluster! error " << ret << std::endl;
+ ret = EXIT_FAILURE;
+ break;
+ }
+
+ /*
+ * And now we're done, so let's remove our pool and then
+ * shut down the connection gracefully.
+ */
+ int delete_ret = rados.pool_delete(db_pool.c_str());
+ if (delete_ret < 0) {
+ // be careful not to
+ std::cerr << "We failed to delete our test pool!" << db_pool << delete_ret << std::endl;
+ ret = EXIT_FAILURE;
+ }
+ delete_ret = rados.pool_delete(wal_pool.c_str());
+ if (delete_ret < 0) {
+ // be careful not to
+ std::cerr << "We failed to delete our test pool!" << wal_pool << delete_ret << std::endl;
+ ret = EXIT_FAILURE;
+ }
+ } while (0);
+ }
+};
+
+TEST_F(EnvLibradosMutipoolTest, Basics) {
+ uint64_t file_size;
+ std::unique_ptr<WritableFile> writable_file;
+ std::vector<std::string> children;
+ std::vector<std::string> v = {"/tmp/dir1", "/tmp/dir2", "/tmp/dir3", "/tmp/dir4", "dir"};
+
+ for (size_t i = 0; i < v.size(); ++i) {
+ std::string dir = v[i];
+ std::string dir_non_existent = dir + "/non_existent";
+ std::string dir_f = dir + "/f";
+ std::string dir_g = dir + "/g";
+
+ ASSERT_OK(env_->CreateDir(dir.c_str()));
+ // Check that the directory is empty.
+ ASSERT_EQ(Status::NotFound(), env_->FileExists(dir_non_existent.c_str()));
+ ASSERT_TRUE(!env_->GetFileSize(dir_non_existent.c_str(), &file_size).ok());
+ ASSERT_OK(env_->GetChildren(dir.c_str(), &children));
+ ASSERT_EQ(0U, children.size());
+
+ // Create a file.
+ ASSERT_OK(env_->NewWritableFile(dir_f.c_str(), &writable_file, soptions_));
+ writable_file.reset();
+
+ // Check that the file exists.
+ ASSERT_OK(env_->FileExists(dir_f.c_str()));
+ ASSERT_OK(env_->GetFileSize(dir_f.c_str(), &file_size));
+ ASSERT_EQ(0U, file_size);
+ ASSERT_OK(env_->GetChildren(dir.c_str(), &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+
+ // Write to the file.
+ ASSERT_OK(env_->NewWritableFile(dir_f.c_str(), &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("abc"));
+ writable_file.reset();
+
+
+ // Check for expected size.
+ ASSERT_OK(env_->GetFileSize(dir_f.c_str(), &file_size));
+ ASSERT_EQ(3U, file_size);
+
+
+ // Check that renaming works.
+ ASSERT_TRUE(!env_->RenameFile(dir_non_existent.c_str(), dir_g.c_str()).ok());
+ ASSERT_OK(env_->RenameFile(dir_f.c_str(), dir_g.c_str()));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists(dir_f.c_str()));
+ ASSERT_OK(env_->FileExists(dir_g.c_str()));
+ ASSERT_OK(env_->GetFileSize(dir_g.c_str(), &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that opening non-existent file fails.
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ ASSERT_TRUE(
+ !env_->NewSequentialFile(dir_non_existent.c_str(), &seq_file, soptions_).ok());
+ ASSERT_TRUE(!seq_file);
+ ASSERT_TRUE(!env_->NewRandomAccessFile(dir_non_existent.c_str(), &rand_file,
+ soptions_).ok());
+ ASSERT_TRUE(!rand_file);
+
+ // Check that deleting works.
+ ASSERT_TRUE(!env_->DeleteFile(dir_non_existent.c_str()).ok());
+ ASSERT_OK(env_->DeleteFile(dir_g.c_str()));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists(dir_g.c_str()));
+ ASSERT_OK(env_->GetChildren(dir.c_str(), &children));
+ ASSERT_EQ(0U, children.size());
+ ASSERT_OK(env_->DeleteDir(dir.c_str()));
+ }
+}
+
+TEST_F(EnvLibradosMutipoolTest, DBBasics) {
+ std::string kDBPath = "/tmp/DBBasics";
+ std::string walPath = "/tmp/wal";
+ DB* db;
+ Options options;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options.IncreaseParallelism();
+ options.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options.create_if_missing = true;
+ options.env = env_;
+ options.wal_dir = walPath;
+
+ // open DB
+ Status s = DB::Open(options, kDBPath, &db);
+ assert(s.ok());
+
+ // Put key-value
+ s = db->Put(WriteOptions(), "key1", "value");
+ assert(s.ok());
+ std::string value;
+ // get value
+ s = db->Get(ReadOptions(), "key1", &value);
+ assert(s.ok());
+ assert(value == "value");
+
+ // atomically apply a set of updates
+ {
+ WriteBatch batch;
+ batch.Delete("key1");
+ batch.Put("key2", value);
+ s = db->Write(WriteOptions(), &batch);
+ }
+
+ s = db->Get(ReadOptions(), "key1", &value);
+ assert(s.IsNotFound());
+
+ db->Get(ReadOptions(), "key2", &value);
+ assert(value == "value");
+
+ delete db;
+}
+
+TEST_F(EnvLibradosMutipoolTest, DBBulkLoadKeysInRandomOrder) {
+ char key[20] = {0}, value[20] = {0};
+ int max_loop = 1 << 6;
+ int bulk_size = 1 << 15;
+ Timer timer(false);
+ std::cout << "Test size : loop(" << max_loop << "); bulk_size(" << bulk_size << ")" << std::endl;
+ /**********************************
+ use default env
+ ***********************************/
+ std::string kDBPath1 = "/tmp/DBBulkLoadKeysInRandomOrder1";
+ std::string walPath = "/tmp/wal";
+ DB* db1;
+ Options options1;
+ // Optimize Rocksdb. This is the easiest way to get RocksDB to perform well
+ options1.IncreaseParallelism();
+ options1.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options1.create_if_missing = true;
+
+ // open DB
+ Status s1 = DB::Open(options1, kDBPath1, &db1);
+ assert(s1.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r1(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r1.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s1 = db1->Write(WriteOptions(), &batch);
+ assert(s1.ok());
+ }
+ std::cout << "Time by default : " << timer << "ms" << std::endl;
+ delete db1;
+
+ /**********************************
+ use librados env
+ ***********************************/
+ std::string kDBPath2 = "/tmp/DBBulkLoadKeysInRandomOrder2";
+ DB* db2;
+ Options options2;
+ // Optimize RocksDB. This is the easiest way to get RocksDB to perform well
+ options2.IncreaseParallelism();
+ options2.OptimizeLevelStyleCompaction();
+ // create the DB if it's not already present
+ options2.create_if_missing = true;
+ options2.env = env_;
+ options2.wal_dir = walPath;
+
+ // open DB
+ Status s2 = DB::Open(options2, kDBPath2, &db2);
+ if (!s2.ok()) {
+ std::cerr << s2.ToString() << std::endl;
+ }
+ assert(s2.ok());
+
+ ROCKSDB_NAMESPACE::Random64 r2(time(nullptr));
+
+ timer.Reset();
+ for (int i = 0; i < max_loop; ++i) {
+ WriteBatch batch;
+ for (int j = 0; j < bulk_size; ++j) {
+ snprintf(key,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ snprintf(value,
+ 20,
+ "%16lx",
+ (unsigned long)r2.Uniform(std::numeric_limits<uint64_t>::max()));
+ batch.Put(key, value);
+ }
+ s2 = db2->Write(WriteOptions(), &batch);
+ assert(s2.ok());
+ }
+ std::cout << "Time by librados : " << timer << "ms" << std::endl;
+ delete db2;
+}
+
+TEST_F(EnvLibradosMutipoolTest, DBTransactionDB) {
+ std::string kDBPath = "/tmp/DBTransactionDB";
+ // open DB
+ Options options;
+ TransactionDBOptions txn_db_options;
+ options.create_if_missing = true;
+ options.env = env_;
+ TransactionDB* txn_db;
+
+ Status s = TransactionDB::Open(options, txn_db_options, kDBPath, &txn_db);
+ assert(s.ok());
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+
+ ////////////////////////////////////////////////////////
+ //
+ // Simple Transaction Example ("Read Committed")
+ //
+ ////////////////////////////////////////////////////////
+
+ // Start a transaction
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ assert(txn);
+
+ // Read a key in this transaction
+ s = txn->Get(read_options, "abc", &value);
+ assert(s.IsNotFound());
+
+ // Write a key in this transaction
+ s = txn->Put("abc", "def");
+ assert(s.ok());
+
+ // Read a key OUTSIDE this transaction. Does not affect txn.
+ s = txn_db->Get(read_options, "abc", &value);
+
+ // Write a key OUTSIDE of this transaction.
+ // Does not affect txn since this is an unrelated key. If we wrote key 'abc'
+ // here, the transaction would fail to commit.
+ s = txn_db->Put(write_options, "xyz", "zzz");
+
+ // Commit transaction
+ s = txn->Commit();
+ assert(s.ok());
+ delete txn;
+
+ ////////////////////////////////////////////////////////
+ //
+ // "Repeatable Read" (Snapshot Isolation) Example
+ // -- Using a single Snapshot
+ //
+ ////////////////////////////////////////////////////////
+
+ // Set a snapshot at start of transaction by setting set_snapshot=true
+ txn_options.set_snapshot = true;
+ txn = txn_db->BeginTransaction(write_options, txn_options);
+
+ const Snapshot* snapshot = txn->GetSnapshot();
+
+ // Write a key OUTSIDE of transaction
+ s = txn_db->Put(write_options, "abc", "xyz");
+ assert(s.ok());
+
+ // Attempt to read a key using the snapshot. This will fail since
+ // the previous write outside this txn conflicts with this read.
+ read_options.snapshot = snapshot;
+ s = txn->GetForUpdate(read_options, "abc", &value);
+ assert(s.IsBusy());
+
+ txn->Rollback();
+
+ delete txn;
+ // Clear snapshot from read options since it is no longer valid
+ read_options.snapshot = nullptr;
+ snapshot = nullptr;
+
+ ////////////////////////////////////////////////////////
+ //
+ // "Read Committed" (Monotonic Atomic Views) Example
+ // --Using multiple Snapshots
+ //
+ ////////////////////////////////////////////////////////
+
+ // In this example, we set the snapshot multiple times. This is probably
+ // only necessary if you have very strict isolation requirements to
+ // implement.
+
+ // Set a snapshot at start of transaction
+ txn_options.set_snapshot = true;
+ txn = txn_db->BeginTransaction(write_options, txn_options);
+
+ // Do some reads and writes to key "x"
+ read_options.snapshot = txn_db->GetSnapshot();
+ s = txn->Get(read_options, "x", &value);
+ txn->Put("x", "x");
+
+ // Do a write outside of the transaction to key "y"
+ s = txn_db->Put(write_options, "y", "y");
+
+ // Set a new snapshot in the transaction
+ txn->SetSnapshot();
+ txn->SetSavePoint();
+ read_options.snapshot = txn_db->GetSnapshot();
+
+ // Do some reads and writes to key "y"
+ // Since the snapshot was advanced, the write done outside of the
+ // transaction does not conflict.
+ s = txn->GetForUpdate(read_options, "y", &value);
+ txn->Put("y", "y");
+
+ // Decide we want to revert the last write from this transaction.
+ txn->RollbackToSavePoint();
+
+ // Commit.
+ s = txn->Commit();
+ assert(s.ok());
+ delete txn;
+ // Clear snapshot from read options since it is no longer valid
+ read_options.snapshot = nullptr;
+
+ // Cleanup
+ delete txn_db;
+ DestroyDB(kDBPath, options);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int argc, char** argv) {
+ fprintf(stderr, "SKIPPED as EnvLibrados is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/env_mirror.cc b/src/rocksdb/utilities/env_mirror.cc
new file mode 100644
index 000000000..dbb5e8021
--- /dev/null
+++ b/src/rocksdb/utilities/env_mirror.cc
@@ -0,0 +1,262 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2015, Red Hat, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/env_mirror.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// An implementation of Env that mirrors all work over two backend
+// Envs. This is useful for debugging purposes.
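+// SequentialFileMirror forwards every call to both underlying files and
+// asserts that the statuses (and, for reads, the returned bytes) agree.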
+class SequentialFileMirror : public SequentialFile {
+ public:
+ std::unique_ptr<SequentialFile> a_, b_;
+ std::string fname;
+ explicit SequentialFileMirror(std::string f) : fname(f) {}
+
+ Status Read(size_t n, Slice* result, char* scratch) override {
+ Slice aslice;
+ Status as = a_->Read(n, &aslice, scratch);
+ if (as == Status::OK()) {
+ char* bscratch = new char[n];
+ Slice bslice;
+ size_t off = 0;
+ size_t left = aslice.size();
+ while (left) {
+ Status bs = b_->Read(left, &bslice, bscratch);
+ assert(as == bs);
+ assert(memcmp(bscratch, scratch + off, bslice.size()) == 0);
+ off += bslice.size();
+ left -= bslice.size();
+ }
+ delete[] bscratch;
+ *result = aslice;
+ } else {
+ Status bs = b_->Read(n, result, scratch);
+ assert(as == bs);
+ }
+ return as;
+ }
+
+ Status Skip(uint64_t n) override {
+ Status as = a_->Skip(n);
+ Status bs = b_->Skip(n);
+ assert(as == bs);
+ return as;
+ }
+ Status InvalidateCache(size_t offset, size_t length) override {
+ Status as = a_->InvalidateCache(offset, length);
+ Status bs = b_->InvalidateCache(offset, length);
+ assert(as == bs);
+ return as;
+ };
+};
+
+class RandomAccessFileMirror : public RandomAccessFile {
+ public:
+ std::unique_ptr<RandomAccessFile> a_, b_;
+ std::string fname;
+ explicit RandomAccessFileMirror(std::string f) : fname(f) {}
+
+ Status Read(uint64_t offset, size_t n, Slice* result,
+ char* scratch) const override {
+ Status as = a_->Read(offset, n, result, scratch);
+ if (as == Status::OK()) {
+ char* bscratch = new char[n];
+ Slice bslice;
+ size_t off = 0;
+ size_t left = result->size();
+ while (left) {
+ Status bs = b_->Read(offset + off, left, &bslice, bscratch);
+ assert(as == bs);
+ assert(memcmp(bscratch, scratch + off, bslice.size()) == 0);
+ off += bslice.size();
+ left -= bslice.size();
+ }
+ delete[] bscratch;
+ } else {
+ Status bs = b_->Read(offset, n, result, scratch);
+ assert(as == bs);
+ }
+ return as;
+ }
+
+ size_t GetUniqueId(char* id, size_t max_size) const override {
+ // NOTE: not verified
+ return a_->GetUniqueId(id, max_size);
+ }
+};
+
+class WritableFileMirror : public WritableFile {
+ public:
+ std::unique_ptr<WritableFile> a_, b_;
+ std::string fname;
+ explicit WritableFileMirror(std::string f, const EnvOptions& options)
+ : WritableFile(options), fname(f) {}
+
+ Status Append(const Slice& data) override {
+ Status as = a_->Append(data);
+ Status bs = b_->Append(data);
+ assert(as == bs);
+ return as;
+ }
+ Status PositionedAppend(const Slice& data, uint64_t offset) override {
+ Status as = a_->PositionedAppend(data, offset);
+ Status bs = b_->PositionedAppend(data, offset);
+ assert(as == bs);
+ return as;
+ }
+ Status Truncate(uint64_t size) override {
+ Status as = a_->Truncate(size);
+ Status bs = b_->Truncate(size);
+ assert(as == bs);
+ return as;
+ }
+ Status Close() override {
+ Status as = a_->Close();
+ Status bs = b_->Close();
+ assert(as == bs);
+ return as;
+ }
+ Status Flush() override {
+ Status as = a_->Flush();
+ Status bs = b_->Flush();
+ assert(as == bs);
+ return as;
+ }
+ Status Sync() override {
+ Status as = a_->Sync();
+ Status bs = b_->Sync();
+ assert(as == bs);
+ return as;
+ }
+ Status Fsync() override {
+ Status as = a_->Fsync();
+ Status bs = b_->Fsync();
+ assert(as == bs);
+ return as;
+ }
+ bool IsSyncThreadSafe() const override {
+ bool as = a_->IsSyncThreadSafe();
+ assert(as == b_->IsSyncThreadSafe());
+ return as;
+ }
+ void SetIOPriority(Env::IOPriority pri) override {
+ a_->SetIOPriority(pri);
+ b_->SetIOPriority(pri);
+ }
+ Env::IOPriority GetIOPriority() override {
+ // NOTE: we don't verify this one
+ return a_->GetIOPriority();
+ }
+ uint64_t GetFileSize() override {
+ uint64_t as = a_->GetFileSize();
+ assert(as == b_->GetFileSize());
+ return as;
+ }
+ void GetPreallocationStatus(size_t* block_size,
+ size_t* last_allocated_block) override {
+ // NOTE: we don't verify this one
+ return a_->GetPreallocationStatus(block_size, last_allocated_block);
+ }
+ size_t GetUniqueId(char* id, size_t max_size) const override {
+ // NOTE: we don't verify this one
+ return a_->GetUniqueId(id, max_size);
+ }
+ Status InvalidateCache(size_t offset, size_t length) override {
+ Status as = a_->InvalidateCache(offset, length);
+ Status bs = b_->InvalidateCache(offset, length);
+ assert(as == bs);
+ return as;
+ }
+
+ protected:
+ Status Allocate(uint64_t offset, uint64_t length) override {
+ Status as = a_->Allocate(offset, length);
+ Status bs = b_->Allocate(offset, length);
+ assert(as == bs);
+ return as;
+ }
+ Status RangeSync(uint64_t offset, uint64_t nbytes) override {
+ Status as = a_->RangeSync(offset, nbytes);
+ Status bs = b_->RangeSync(offset, nbytes);
+ assert(as == bs);
+ return as;
+ }
+};
+
+Status EnvMirror::NewSequentialFile(const std::string& f,
+ std::unique_ptr<SequentialFile>* r,
+ const EnvOptions& options) {
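+ // Paths under /proc/ are pseudo-files that exist only on the real
+ // filesystem, so they are served from the first backend alone rather than
+ // being mirrored.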
+ if (f.find("/proc/") == 0) {
+ return a_->NewSequentialFile(f, r, options);
+ }
+ SequentialFileMirror* mf = new SequentialFileMirror(f);
+ Status as = a_->NewSequentialFile(f, &mf->a_, options);
+ Status bs = b_->NewSequentialFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::NewRandomAccessFile(const std::string& f,
+ std::unique_ptr<RandomAccessFile>* r,
+ const EnvOptions& options) {
+ if (f.find("/proc/") == 0) {
+ return a_->NewRandomAccessFile(f, r, options);
+ }
+ RandomAccessFileMirror* mf = new RandomAccessFileMirror(f);
+ Status as = a_->NewRandomAccessFile(f, &mf->a_, options);
+ Status bs = b_->NewRandomAccessFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::NewWritableFile(const std::string& f,
+ std::unique_ptr<WritableFile>* r,
+ const EnvOptions& options) {
+ if (f.find("/proc/") == 0) return a_->NewWritableFile(f, r, options);
+ WritableFileMirror* mf = new WritableFileMirror(f, options);
+ Status as = a_->NewWritableFile(f, &mf->a_, options);
+ Status bs = b_->NewWritableFile(f, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+Status EnvMirror::ReuseWritableFile(const std::string& fname,
+ const std::string& old_fname,
+ std::unique_ptr<WritableFile>* r,
+ const EnvOptions& options) {
+ if (fname.find("/proc/") == 0)
+ return a_->ReuseWritableFile(fname, old_fname, r, options);
+ WritableFileMirror* mf = new WritableFileMirror(fname, options);
+ Status as = a_->ReuseWritableFile(fname, old_fname, &mf->a_, options);
+ Status bs = b_->ReuseWritableFile(fname, old_fname, &mf->b_, options);
+ assert(as == bs);
+ if (as.ok())
+ r->reset(mf);
+ else
+ delete mf;
+ return as;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif
diff --git a/src/rocksdb/utilities/env_mirror_test.cc b/src/rocksdb/utilities/env_mirror_test.cc
new file mode 100644
index 000000000..0be9d7db2
--- /dev/null
+++ b/src/rocksdb/utilities/env_mirror_test.cc
@@ -0,0 +1,223 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2015, Red Hat, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/env_mirror.h"
+#include "env/mock_env.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class EnvMirrorTest : public testing::Test {
+ public:
+ Env* default_;
+ MockEnv* a_, *b_;
+ EnvMirror* env_;
+ const EnvOptions soptions_;
+
+ EnvMirrorTest()
+ : default_(Env::Default()),
+ a_(new MockEnv(default_)),
+ b_(new MockEnv(default_)),
+ env_(new EnvMirror(a_, b_)) {}
+ ~EnvMirrorTest() {
+ delete env_;
+ delete a_;
+ delete b_;
+ }
+};
+
+TEST_F(EnvMirrorTest, Basics) {
+ uint64_t file_size;
+ std::unique_ptr<WritableFile> writable_file;
+ std::vector<std::string> children;
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ // Check that the directory is empty.
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/non_existent"));
+ ASSERT_TRUE(!env_->GetFileSize("/dir/non_existent", &file_size).ok());
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+
+ // Create a file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ writable_file.reset();
+
+ // Check that the file exists.
+ ASSERT_OK(env_->FileExists("/dir/f"));
+ ASSERT_OK(a_->FileExists("/dir/f"));
+ ASSERT_OK(b_->FileExists("/dir/f"));
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(0U, file_size);
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+ ASSERT_OK(a_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+ ASSERT_OK(b_->GetChildren("/dir", &children));
+ ASSERT_EQ(1U, children.size());
+ ASSERT_EQ("f", children[0]);
+
+ // Write to the file.
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("abc"));
+ writable_file.reset();
+
+ // Check for expected size.
+ ASSERT_OK(env_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(a_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(b_->GetFileSize("/dir/f", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that renaming works.
+ ASSERT_TRUE(!env_->RenameFile("/dir/non_existent", "/dir/g").ok());
+ ASSERT_OK(env_->RenameFile("/dir/f", "/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/f"));
+ ASSERT_OK(env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(a_->FileExists("/dir/g"));
+ ASSERT_OK(a_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+ ASSERT_OK(b_->FileExists("/dir/g"));
+ ASSERT_OK(b_->GetFileSize("/dir/g", &file_size));
+ ASSERT_EQ(3U, file_size);
+
+ // Check that opening non-existent file fails.
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ ASSERT_TRUE(
+ !env_->NewSequentialFile("/dir/non_existent", &seq_file, soptions_).ok());
+ ASSERT_TRUE(!seq_file);
+ ASSERT_TRUE(!env_->NewRandomAccessFile("/dir/non_existent", &rand_file,
+ soptions_).ok());
+ ASSERT_TRUE(!rand_file);
+
+ // Check that deleting works.
+ ASSERT_TRUE(!env_->DeleteFile("/dir/non_existent").ok());
+ ASSERT_OK(env_->DeleteFile("/dir/g"));
+ ASSERT_EQ(Status::NotFound(), env_->FileExists("/dir/g"));
+ ASSERT_OK(env_->GetChildren("/dir", &children));
+ ASSERT_EQ(0U, children.size());
+ ASSERT_OK(env_->DeleteDir("/dir"));
+}
+
+TEST_F(EnvMirrorTest, ReadWrite) {
+ std::unique_ptr<WritableFile> writable_file;
+ std::unique_ptr<SequentialFile> seq_file;
+ std::unique_ptr<RandomAccessFile> rand_file;
+ Slice result;
+ char scratch[100];
+
+ ASSERT_OK(env_->CreateDir("/dir"));
+
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("hello "));
+ ASSERT_OK(writable_file->Append("world"));
+ writable_file.reset();
+
+ // Read sequentially.
+ ASSERT_OK(env_->NewSequentialFile("/dir/f", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(seq_file->Skip(1));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(seq_file->Read(1000, &result, scratch)); // Try reading past EOF.
+ ASSERT_EQ(0U, result.size());
+ ASSERT_OK(seq_file->Skip(100)); // Try to skip past end of file.
+ ASSERT_OK(seq_file->Read(1000, &result, scratch));
+ ASSERT_EQ(0U, result.size());
+
+ // Random reads.
+ ASSERT_OK(env_->NewRandomAccessFile("/dir/f", &rand_file, soptions_));
+ ASSERT_OK(rand_file->Read(6, 5, &result, scratch)); // Read "world".
+ ASSERT_EQ(0, result.compare("world"));
+ ASSERT_OK(rand_file->Read(0, 5, &result, scratch)); // Read "hello".
+ ASSERT_EQ(0, result.compare("hello"));
+ ASSERT_OK(rand_file->Read(10, 100, &result, scratch)); // Read "d".
+ ASSERT_EQ(0, result.compare("d"));
+
+ // Too high offset.
+ ASSERT_TRUE(!rand_file->Read(1000, 5, &result, scratch).ok());
+}
+
+TEST_F(EnvMirrorTest, Locks) {
+ FileLock* lock;
+
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(env_->LockFile("some file", &lock));
+ ASSERT_OK(env_->UnlockFile(lock));
+}
+
+TEST_F(EnvMirrorTest, Misc) {
+ std::string test_dir;
+ ASSERT_OK(env_->GetTestDirectory(&test_dir));
+ ASSERT_TRUE(!test_dir.empty());
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->NewWritableFile("/a/b", &writable_file, soptions_));
+
+ // These are no-ops, but we test they return success.
+ ASSERT_OK(writable_file->Sync());
+ ASSERT_OK(writable_file->Flush());
+ ASSERT_OK(writable_file->Close());
+ writable_file.reset();
+}
+
+TEST_F(EnvMirrorTest, LargeWrite) {
+ const size_t kWriteSize = 300 * 1024;
+ char* scratch = new char[kWriteSize * 2];
+
+ std::string write_data;
+ for (size_t i = 0; i < kWriteSize; ++i) {
+ write_data.append(1, static_cast<char>(i));
+ }
+
+ std::unique_ptr<WritableFile> writable_file;
+ ASSERT_OK(env_->NewWritableFile("/dir/f", &writable_file, soptions_));
+ ASSERT_OK(writable_file->Append("foo"));
+ ASSERT_OK(writable_file->Append(write_data));
+ writable_file.reset();
+
+ std::unique_ptr<SequentialFile> seq_file;
+ Slice result;
+ ASSERT_OK(env_->NewSequentialFile("/dir/f", &seq_file, soptions_));
+ ASSERT_OK(seq_file->Read(3, &result, scratch)); // Read "foo".
+ ASSERT_EQ(0, result.compare("foo"));
+
+ size_t read = 0;
+ std::string read_data;
+ while (read < kWriteSize) {
+ ASSERT_OK(seq_file->Read(kWriteSize - read, &result, scratch));
+ read_data.append(result.data(), result.size());
+ read += result.size();
+ }
+ ASSERT_TRUE(write_data == read_data);
+ delete[] scratch;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int argc, char** argv) {
+ fprintf(stderr, "SKIPPED as EnvMirror is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/env_timed.cc b/src/rocksdb/utilities/env_timed.cc
new file mode 100644
index 000000000..fc6627da2
--- /dev/null
+++ b/src/rocksdb/utilities/env_timed.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/env.h"
+#include "rocksdb/status.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifndef ROCKSDB_LITE
+
+// An environment that measures function call times for filesystem
+// operations, reporting results to variables in PerfContext.
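+//
+// A minimal usage sketch (mirroring env_timed_test.cc; the counters are only
+// populated while the perf level enables timing):
+//   SetPerfLevel(PerfLevel::kEnableTime);
+//   std::unique_ptr<Env> timed_env(NewTimedEnv(Env::Default()));
+//   // use timed_env via Options::env, then inspect
+//   // get_perf_context()->env_new_writable_file_nanos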
+class TimedEnv : public EnvWrapper {
+ public:
+ explicit TimedEnv(Env* base_env) : EnvWrapper(base_env) {}
+
+ Status NewSequentialFile(const std::string& fname,
+ std::unique_ptr<SequentialFile>* result,
+ const EnvOptions& options) override {
+ PERF_TIMER_GUARD(env_new_sequential_file_nanos);
+ return EnvWrapper::NewSequentialFile(fname, result, options);
+ }
+
+ Status NewRandomAccessFile(const std::string& fname,
+ std::unique_ptr<RandomAccessFile>* result,
+ const EnvOptions& options) override {
+ PERF_TIMER_GUARD(env_new_random_access_file_nanos);
+ return EnvWrapper::NewRandomAccessFile(fname, result, options);
+ }
+
+ Status NewWritableFile(const std::string& fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& options) override {
+ PERF_TIMER_GUARD(env_new_writable_file_nanos);
+ return EnvWrapper::NewWritableFile(fname, result, options);
+ }
+
+ Status ReuseWritableFile(const std::string& fname,
+ const std::string& old_fname,
+ std::unique_ptr<WritableFile>* result,
+ const EnvOptions& options) override {
+ PERF_TIMER_GUARD(env_reuse_writable_file_nanos);
+ return EnvWrapper::ReuseWritableFile(fname, old_fname, result, options);
+ }
+
+ Status NewRandomRWFile(const std::string& fname,
+ std::unique_ptr<RandomRWFile>* result,
+ const EnvOptions& options) override {
+ PERF_TIMER_GUARD(env_new_random_rw_file_nanos);
+ return EnvWrapper::NewRandomRWFile(fname, result, options);
+ }
+
+ Status NewDirectory(const std::string& name,
+ std::unique_ptr<Directory>* result) override {
+ PERF_TIMER_GUARD(env_new_directory_nanos);
+ return EnvWrapper::NewDirectory(name, result);
+ }
+
+ Status FileExists(const std::string& fname) override {
+ PERF_TIMER_GUARD(env_file_exists_nanos);
+ return EnvWrapper::FileExists(fname);
+ }
+
+ Status GetChildren(const std::string& dir,
+ std::vector<std::string>* result) override {
+ PERF_TIMER_GUARD(env_get_children_nanos);
+ return EnvWrapper::GetChildren(dir, result);
+ }
+
+ Status GetChildrenFileAttributes(
+ const std::string& dir, std::vector<FileAttributes>* result) override {
+ PERF_TIMER_GUARD(env_get_children_file_attributes_nanos);
+ return EnvWrapper::GetChildrenFileAttributes(dir, result);
+ }
+
+ Status DeleteFile(const std::string& fname) override {
+ PERF_TIMER_GUARD(env_delete_file_nanos);
+ return EnvWrapper::DeleteFile(fname);
+ }
+
+ Status CreateDir(const std::string& dirname) override {
+ PERF_TIMER_GUARD(env_create_dir_nanos);
+ return EnvWrapper::CreateDir(dirname);
+ }
+
+ Status CreateDirIfMissing(const std::string& dirname) override {
+ PERF_TIMER_GUARD(env_create_dir_if_missing_nanos);
+ return EnvWrapper::CreateDirIfMissing(dirname);
+ }
+
+ Status DeleteDir(const std::string& dirname) override {
+ PERF_TIMER_GUARD(env_delete_dir_nanos);
+ return EnvWrapper::DeleteDir(dirname);
+ }
+
+ Status GetFileSize(const std::string& fname, uint64_t* file_size) override {
+ PERF_TIMER_GUARD(env_get_file_size_nanos);
+ return EnvWrapper::GetFileSize(fname, file_size);
+ }
+
+ Status GetFileModificationTime(const std::string& fname,
+ uint64_t* file_mtime) override {
+ PERF_TIMER_GUARD(env_get_file_modification_time_nanos);
+ return EnvWrapper::GetFileModificationTime(fname, file_mtime);
+ }
+
+ Status RenameFile(const std::string& src, const std::string& dst) override {
+ PERF_TIMER_GUARD(env_rename_file_nanos);
+ return EnvWrapper::RenameFile(src, dst);
+ }
+
+ Status LinkFile(const std::string& src, const std::string& dst) override {
+ PERF_TIMER_GUARD(env_link_file_nanos);
+ return EnvWrapper::LinkFile(src, dst);
+ }
+
+ Status LockFile(const std::string& fname, FileLock** lock) override {
+ PERF_TIMER_GUARD(env_lock_file_nanos);
+ return EnvWrapper::LockFile(fname, lock);
+ }
+
+ Status UnlockFile(FileLock* lock) override {
+ PERF_TIMER_GUARD(env_unlock_file_nanos);
+ return EnvWrapper::UnlockFile(lock);
+ }
+
+ Status NewLogger(const std::string& fname,
+ std::shared_ptr<Logger>* result) override {
+ PERF_TIMER_GUARD(env_new_logger_nanos);
+ return EnvWrapper::NewLogger(fname, result);
+ }
+};
+
+Env* NewTimedEnv(Env* base_env) { return new TimedEnv(base_env); }
+
+#else // ROCKSDB_LITE
+
+Env* NewTimedEnv(Env* /*base_env*/) { return nullptr; }
+
+#endif // !ROCKSDB_LITE
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/env_timed_test.cc b/src/rocksdb/utilities/env_timed_test.cc
new file mode 100644
index 000000000..f1695185e
--- /dev/null
+++ b/src/rocksdb/utilities/env_timed_test.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/env.h"
+#include "rocksdb/perf_context.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TimedEnvTest : public testing::Test {
+};
+
+TEST_F(TimedEnvTest, BasicTest) {
+ SetPerfLevel(PerfLevel::kEnableTime);
+ ASSERT_EQ(0, get_perf_context()->env_new_writable_file_nanos);
+
+ std::unique_ptr<Env> mem_env(NewMemEnv(Env::Default()));
+ std::unique_ptr<Env> timed_env(NewTimedEnv(mem_env.get()));
+ std::unique_ptr<WritableFile> writable_file;
+ timed_env->NewWritableFile("f", &writable_file, EnvOptions());
+
+ ASSERT_GT(get_perf_context()->env_new_writable_file_nanos, 0);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else // ROCKSDB_LITE
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as TimedEnv is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/leveldb_options/leveldb_options.cc b/src/rocksdb/utilities/leveldb_options/leveldb_options.cc
new file mode 100644
index 000000000..5698b21ce
--- /dev/null
+++ b/src/rocksdb/utilities/leveldb_options/leveldb_options.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/utilities/leveldb_options.h"
+#include "rocksdb/cache.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/env.h"
+#include "rocksdb/filter_policy.h"
+#include "rocksdb/options.h"
+#include "rocksdb/table.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+LevelDBOptions::LevelDBOptions()
+ : comparator(BytewiseComparator()),
+ create_if_missing(false),
+ error_if_exists(false),
+ paranoid_checks(false),
+ env(Env::Default()),
+ info_log(nullptr),
+ write_buffer_size(4 << 20),
+ max_open_files(1000),
+ block_cache(nullptr),
+ block_size(4096),
+ block_restart_interval(16),
+ compression(kSnappyCompression),
+ filter_policy(nullptr) {}
+
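+// Note that the raw info_log, block_cache and filter_policy pointers from
+// LevelDBOptions are adopted into shared_ptrs below, so the returned Options
+// takes ownership of those objects.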
+Options ConvertOptions(const LevelDBOptions& leveldb_options) {
+ Options options = Options();
+ options.create_if_missing = leveldb_options.create_if_missing;
+ options.error_if_exists = leveldb_options.error_if_exists;
+ options.paranoid_checks = leveldb_options.paranoid_checks;
+ options.env = leveldb_options.env;
+ options.info_log.reset(leveldb_options.info_log);
+ options.write_buffer_size = leveldb_options.write_buffer_size;
+ options.max_open_files = leveldb_options.max_open_files;
+ options.compression = leveldb_options.compression;
+
+ BlockBasedTableOptions table_options;
+ table_options.block_cache.reset(leveldb_options.block_cache);
+ table_options.block_size = leveldb_options.block_size;
+ table_options.block_restart_interval = leveldb_options.block_restart_interval;
+ table_options.filter_policy.reset(leveldb_options.filter_policy);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+
+ return options;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/memory/memory_test.cc b/src/rocksdb/utilities/memory/memory_test.cc
new file mode 100644
index 000000000..9e253df44
--- /dev/null
+++ b/src/rocksdb/utilities/memory/memory_test.cc
@@ -0,0 +1,278 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/cache.h"
+#include "rocksdb/table.h"
+#include "rocksdb/utilities/memory_util.h"
+#include "rocksdb/utilities/stackable_db.h"
+#include "table/block_based/block_based_table_factory.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MemoryTest : public testing::Test {
+ public:
+ MemoryTest() : kDbDir(test::PerThreadDBPath("memory_test")), rnd_(301) {
+ assert(Env::Default()->CreateDirIfMissing(kDbDir).ok());
+ }
+
+ std::string GetDBName(int id) { return kDbDir + "db_" + ToString(id); }
+
+ std::string RandomString(int len) {
+ std::string r;
+ test::RandomString(&rnd_, len, &r);
+ return r;
+ }
+
+ void UpdateUsagesHistory(const std::vector<DB*>& dbs) {
+ std::map<MemoryUtil::UsageType, uint64_t> usage_by_type;
+ ASSERT_OK(GetApproximateMemoryUsageByType(dbs, &usage_by_type));
+ for (int i = 0; i < MemoryUtil::kNumUsageTypes; ++i) {
+ usage_history_[i].push_back(
+ usage_by_type[static_cast<MemoryUtil::UsageType>(i)]);
+ }
+ }
+
+ void GetCachePointersFromTableFactory(
+ const TableFactory* factory,
+ std::unordered_set<const Cache*>* cache_set) {
+ const BlockBasedTableFactory* bbtf =
+ dynamic_cast<const BlockBasedTableFactory*>(factory);
+ if (bbtf != nullptr) {
+ const auto bbt_opts = bbtf->table_options();
+ cache_set->insert(bbt_opts.block_cache.get());
+ cache_set->insert(bbt_opts.block_cache_compressed.get());
+ }
+ }
+
+ void GetCachePointers(const std::vector<DB*>& dbs,
+ std::unordered_set<const Cache*>* cache_set) {
+ cache_set->clear();
+
+ for (auto* db : dbs) {
+ assert(db);
+
+ // Cache from DBImpl
+ StackableDB* sdb = dynamic_cast<StackableDB*>(db);
+ DBImpl* db_impl = dynamic_cast<DBImpl*>(sdb ? sdb->GetBaseDB() : db);
+ if (db_impl != nullptr) {
+ cache_set->insert(db_impl->TEST_table_cache());
+ }
+
+ // Cache from DBOptions
+ cache_set->insert(db->GetDBOptions().row_cache.get());
+
+ // Cache from table factories
+ std::unordered_map<std::string, const ImmutableCFOptions*> iopts_map;
+ if (db_impl != nullptr) {
+ ASSERT_OK(db_impl->TEST_GetAllImmutableCFOptions(&iopts_map));
+ }
+ for (auto pair : iopts_map) {
+ GetCachePointersFromTableFactory(pair.second->table_factory, cache_set);
+ }
+ }
+ }
+
+ Status GetApproximateMemoryUsageByType(
+ const std::vector<DB*>& dbs,
+ std::map<MemoryUtil::UsageType, uint64_t>* usage_by_type) {
+ std::unordered_set<const Cache*> cache_set;
+ GetCachePointers(dbs, &cache_set);
+
+ return MemoryUtil::GetApproximateMemoryUsageByType(dbs, cache_set,
+ usage_by_type);
+ }
+
+ const std::string kDbDir;
+ Random rnd_;
+ std::vector<uint64_t> usage_history_[MemoryUtil::kNumUsageTypes];
+};
+
+TEST_F(MemoryTest, SharedBlockCacheTotal) {
+ std::vector<DB*> dbs;
+ std::vector<uint64_t> usage_by_type;
+ const int kNumDBs = 10;
+ const int kKeySize = 100;
+ const int kValueSize = 500;
+ Options opt;
+ opt.create_if_missing = true;
+ opt.write_buffer_size = kKeySize + kValueSize;
+ opt.max_write_buffer_number = 10;
+ opt.min_write_buffer_number_to_merge = 10;
+ opt.disable_auto_compactions = true;
+ BlockBasedTableOptions bbt_opts;
+ bbt_opts.block_cache = NewLRUCache(4096 * 1000 * 10);
+ for (int i = 0; i < kNumDBs; ++i) {
+ DestroyDB(GetDBName(i), opt);
+ DB* db = nullptr;
+ ASSERT_OK(DB::Open(opt, GetDBName(i), &db));
+ dbs.push_back(db);
+ }
+
+ std::vector<std::string> keys_by_db[kNumDBs];
+
+ // Fill one memtable per Put to make the memtables use more memory.
+ for (int p = 0; p < opt.min_write_buffer_number_to_merge / 2; ++p) {
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (int j = 0; j < 100; ++j) {
+ keys_by_db[i].emplace_back(RandomString(kKeySize));
+ dbs[i]->Put(WriteOptions(), keys_by_db[i].back(),
+ RandomString(kValueSize));
+ }
+ dbs[i]->Flush(FlushOptions());
+ }
+ }
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto& key : keys_by_db[i]) {
+ std::string value;
+ dbs[i]->Get(ReadOptions(), key, &value);
+ }
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = 1; i < usage_history_[MemoryUtil::kMemTableTotal].size();
+ ++i) {
+ // Expect EQ as we didn't flush more memtables.
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+ for (int i = 0; i < kNumDBs; ++i) {
+ delete dbs[i];
+ }
+}
+
+TEST_F(MemoryTest, MemTableAndTableReadersTotal) {
+ std::vector<DB*> dbs;
+ std::vector<uint64_t> usage_by_type;
+ std::vector<std::vector<ColumnFamilyHandle*>> vec_handles;
+ const int kNumDBs = 10;
+ const int kKeySize = 100;
+ const int kValueSize = 500;
+ Options opt;
+ opt.create_if_missing = true;
+ opt.create_missing_column_families = true;
+ opt.write_buffer_size = kKeySize + kValueSize;
+ opt.max_write_buffer_number = 10;
+ opt.min_write_buffer_number_to_merge = 10;
+ opt.disable_auto_compactions = true;
+
+ std::vector<ColumnFamilyDescriptor> cf_descs = {
+ {kDefaultColumnFamilyName, ColumnFamilyOptions(opt)},
+ {"one", ColumnFamilyOptions(opt)},
+ {"two", ColumnFamilyOptions(opt)},
+ };
+
+ for (int i = 0; i < kNumDBs; ++i) {
+ DestroyDB(GetDBName(i), opt);
+ std::vector<ColumnFamilyHandle*> handles;
+ dbs.emplace_back();
+ vec_handles.emplace_back();
+ ASSERT_OK(DB::Open(DBOptions(opt), GetDBName(i), cf_descs,
+ &vec_handles.back(), &dbs.back()));
+ }
+
+ // Fill one memtable per Put to make the memtables use more memory.
+ for (int p = 0; p < opt.min_write_buffer_number_to_merge / 2; ++p) {
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto* handle : vec_handles[i]) {
+ dbs[i]->Put(WriteOptions(), handle, RandomString(kKeySize),
+ RandomString(kValueSize));
+ UpdateUsagesHistory(dbs);
+ }
+ }
+ }
+ // Expect the usage history is monotonically increasing
+ for (size_t i = 1; i < usage_history_[MemoryUtil::kMemTableTotal].size();
+ ++i) {
+ ASSERT_GT(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ ASSERT_GT(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+
+ size_t usage_check_point = usage_history_[MemoryUtil::kMemTableTotal].size();
+ std::vector<Iterator*> iters;
+
+ // Create an iterator and flush all memtables for each db
+ for (int i = 0; i < kNumDBs; ++i) {
+ iters.push_back(dbs[i]->NewIterator(ReadOptions()));
+ dbs[i]->Flush(FlushOptions());
+
+ for (int j = 0; j < 100; ++j) {
+ std::string value;
+ dbs[i]->Get(ReadOptions(), RandomString(kKeySize), &value);
+ }
+
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = usage_check_point;
+ i < usage_history_[MemoryUtil::kMemTableTotal].size(); ++i) {
+ // Since the memtables are pinned by the iterators, we don't expect
+ // their total memory usage to decrease.
+ ASSERT_GE(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ // Expect the un-flushed memtable usage history from this check point
+ // onward to be monotonically decreasing.
+ ASSERT_LT(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ // Expect the usage history of the table readers increases
+ // as we flush tables.
+ ASSERT_GT(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ ASSERT_GT(usage_history_[MemoryUtil::kCacheTotal][i],
+ usage_history_[MemoryUtil::kCacheTotal][i - 1]);
+ }
+ usage_check_point = usage_history_[MemoryUtil::kMemTableTotal].size();
+ for (int i = 0; i < kNumDBs; ++i) {
+ delete iters[i];
+ UpdateUsagesHistory(dbs);
+ }
+ for (size_t i = usage_check_point;
+ i < usage_history_[MemoryUtil::kMemTableTotal].size(); ++i) {
+ // Expect the usage of all memtables decreasing as we delete iterators.
+ ASSERT_LT(usage_history_[MemoryUtil::kMemTableTotal][i],
+ usage_history_[MemoryUtil::kMemTableTotal][i - 1]);
+ // Since the memory usage of un-flushed memtables is only affected
+ // by Put and flush, we expect EQ here as we only delete iterators.
+ ASSERT_EQ(usage_history_[MemoryUtil::kMemTableUnFlushed][i],
+ usage_history_[MemoryUtil::kMemTableUnFlushed][i - 1]);
+ // Expect EQ as we didn't flush more memtables.
+ ASSERT_EQ(usage_history_[MemoryUtil::kTableReadersTotal][i],
+ usage_history_[MemoryUtil::kTableReadersTotal][i - 1]);
+ }
+
+ for (int i = 0; i < kNumDBs; ++i) {
+ for (auto* handle : vec_handles[i]) {
+ delete handle;
+ }
+ delete dbs[i];
+ }
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+#if !(defined NDEBUG) || !defined(OS_WIN)
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+#else
+ return 0;
+#endif
+}
+
+#else
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ printf("Skipped in RocksDBLite as utilities are not supported.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/memory/memory_util.cc b/src/rocksdb/utilities/memory/memory_util.cc
new file mode 100644
index 000000000..13c81aec4
--- /dev/null
+++ b/src/rocksdb/utilities/memory/memory_util.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/memory_util.h"
+
+#include "db/db_impl/db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
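+// Aggregates per-DB memtable and table-reader usage from DB properties and
+// adds the usage of every distinct cache in cache_set:
+//   kSizeAllMemTables        -> kMemTableTotal
+//   kCurSizeAllMemTables     -> kMemTableUnFlushed
+//   kEstimateTableReadersMem -> kTableReadersTotal
+//   Cache::GetUsage()        -> kCacheTotal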
+Status MemoryUtil::GetApproximateMemoryUsageByType(
+ const std::vector<DB*>& dbs,
+ const std::unordered_set<const Cache*> cache_set,
+ std::map<MemoryUtil::UsageType, uint64_t>* usage_by_type) {
+ usage_by_type->clear();
+
+ // MemTable
+ for (auto* db : dbs) {
+ uint64_t usage = 0;
+ if (db->GetAggregatedIntProperty(DB::Properties::kSizeAllMemTables,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kMemTableTotal] += usage;
+ }
+ if (db->GetAggregatedIntProperty(DB::Properties::kCurSizeAllMemTables,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kMemTableUnFlushed] += usage;
+ }
+ }
+
+ // Table Readers
+ for (auto* db : dbs) {
+ uint64_t usage = 0;
+ if (db->GetAggregatedIntProperty(DB::Properties::kEstimateTableReadersMem,
+ &usage)) {
+ (*usage_by_type)[MemoryUtil::kTableReadersTotal] += usage;
+ }
+ }
+
+ // Cache
+ for (const auto* cache : cache_set) {
+ if (cache != nullptr) {
+ (*usage_by_type)[MemoryUtil::kCacheTotal] += cache->GetUsage();
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/merge_operators.h b/src/rocksdb/utilities/merge_operators.h
new file mode 100644
index 000000000..018d097b1
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+#include "rocksdb/merge_operator.h"
+
+#include <stdio.h>
+
+#include <memory>
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+class MergeOperators {
+ public:
+ static std::shared_ptr<MergeOperator> CreatePutOperator();
+ static std::shared_ptr<MergeOperator> CreateDeprecatedPutOperator();
+ static std::shared_ptr<MergeOperator> CreateUInt64AddOperator();
+ static std::shared_ptr<MergeOperator> CreateStringAppendOperator();
+ static std::shared_ptr<MergeOperator> CreateStringAppendOperator(char delim_char);
+ static std::shared_ptr<MergeOperator> CreateStringAppendTESTOperator();
+ static std::shared_ptr<MergeOperator> CreateMaxOperator();
+ static std::shared_ptr<MergeOperator> CreateBytesXOROperator();
+ static std::shared_ptr<MergeOperator> CreateSortOperator();
+
+ // Will return a different merge operator depending on the string.
+ // TODO: Hook the "name" up to the actual Name() of the MergeOperators?
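+ // A hedged usage sketch: the returned operator is typically installed on
+ // the (column family) options, e.g.
+ //   options.merge_operator = MergeOperators::CreateFromStringId("uint64add");
+ // nullptr is returned for an empty or unrecognized name.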
+ static std::shared_ptr<MergeOperator> CreateFromStringId(
+ const std::string& name) {
+ if (name == "put") {
+ return CreatePutOperator();
+ } else if (name == "put_v1") {
+ return CreateDeprecatedPutOperator();
+ } else if (name == "uint64add") {
+ return CreateUInt64AddOperator();
+ } else if (name == "stringappend") {
+ return CreateStringAppendOperator();
+ } else if (name == "stringappendtest") {
+ return CreateStringAppendTESTOperator();
+ } else if (name == "max") {
+ return CreateMaxOperator();
+ } else if (name == "bytesxor") {
+ return CreateBytesXOROperator();
+ } else if (name == "sortlist") {
+ return CreateSortOperator();
+ } else {
+ // Empty or unknown, just return nullptr
+ return nullptr;
+ }
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/bytesxor.cc b/src/rocksdb/utilities/merge_operators/bytesxor.cc
new file mode 100644
index 000000000..859affb5e
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/bytesxor.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <algorithm>
+#include <string>
+
+#include "utilities/merge_operators/bytesxor.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateBytesXOROperator() {
+ return std::make_shared<BytesXOROperator>();
+}
+
+bool BytesXOROperator::Merge(const Slice& /*key*/,
+ const Slice* existing_value,
+ const Slice& value,
+ std::string* new_value,
+ Logger* /*logger*/) const {
+ XOR(existing_value, value, new_value);
+ return true;
+}
+
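+// XORs the overlapping prefix byte by byte; if the operands differ in length,
+// the tail of the longer operand is appended to the result unchanged.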
+void BytesXOROperator::XOR(const Slice* existing_value,
+ const Slice& value, std::string* new_value) const {
+ if (!existing_value) {
+ new_value->clear();
+ new_value->assign(value.data(), value.size());
+ return;
+ }
+
+ size_t min_size = std::min(existing_value->size(), value.size());
+ size_t max_size = std::max(existing_value->size(), value.size());
+
+ new_value->clear();
+ new_value->reserve(max_size);
+
+ const char* existing_value_data = existing_value->data();
+ const char* value_data = value.data();
+
+ for (size_t i = 0; i < min_size; i++) {
+ new_value->push_back(existing_value_data[i] ^ value_data[i]);
+ }
+
+ if (existing_value->size() == max_size) {
+ for (size_t i = min_size; i < max_size; i++) {
+ new_value->push_back(existing_value_data[i]);
+ }
+ } else {
+ assert(value.size() == max_size);
+ for (size_t i = min_size; i < max_size; i++) {
+ new_value->push_back(value_data[i]);
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/bytesxor.h b/src/rocksdb/utilities/merge_operators/bytesxor.h
new file mode 100644
index 000000000..ab0c5aecc
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/bytesxor.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include "rocksdb/env.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A 'model' merge operator that XORs two (same-sized) arrays of bytes.
+// Implemented as an AssociativeMergeOperator for simplicity and example.
+class BytesXOROperator : public AssociativeMergeOperator {
+ public:
+ // XORs the two byte arrays one byte at a time and stores the result in
+ // new_value. The result is as long as the longer of the two operands.
+ virtual bool Merge(const Slice& key,
+ const Slice* existing_value,
+ const Slice& value,
+ std::string* new_value,
+ Logger* logger) const override;
+
+ virtual const char* Name() const override {
+ return "BytesXOR";
+ }
+
+ void XOR(const Slice* existing_value, const Slice& value,
+ std::string* new_value) const;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/max.cc b/src/rocksdb/utilities/merge_operators/max.cc
new file mode 100644
index 000000000..2270c1f03
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/max.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/merge_operators.h"
+
+using ROCKSDB_NAMESPACE::Logger;
+using ROCKSDB_NAMESPACE::MergeOperator;
+using ROCKSDB_NAMESPACE::Slice;
+
+namespace { // anonymous namespace
+
+// Merge operator that picks the maximum operand. Comparison is based on
+// Slice::compare.
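+// FullMergeV2 reports the winner through merge_out->existing_operand, which
+// lets the caller reference an existing slice instead of copying it into
+// new_value.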
+class MaxOperator : public MergeOperator {
+ public:
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override {
+ Slice& max = merge_out->existing_operand;
+ if (merge_in.existing_value) {
+ max = Slice(merge_in.existing_value->data(),
+ merge_in.existing_value->size());
+ } else if (max.data() == nullptr) {
+ max = Slice();
+ }
+
+ for (const auto& op : merge_in.operand_list) {
+ if (max.compare(op) < 0) {
+ max = op;
+ }
+ }
+
+ return true;
+ }
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override {
+ if (left_operand.compare(right_operand) >= 0) {
+ new_value->assign(left_operand.data(), left_operand.size());
+ } else {
+ new_value->assign(right_operand.data(), right_operand.size());
+ }
+ return true;
+ }
+
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const override {
+ Slice max;
+ for (const auto& operand : operand_list) {
+ if (max.compare(operand) < 0) {
+ max = operand;
+ }
+ }
+
+ new_value->assign(max.data(), max.size());
+ return true;
+ }
+
+ const char* Name() const override { return "MaxOperator"; }
+};
+
+} // end of anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateMaxOperator() {
+ return std::make_shared<MaxOperator>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/put.cc b/src/rocksdb/utilities/merge_operators/put.cc
new file mode 100644
index 000000000..901d69e94
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/put.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+#include "rocksdb/slice.h"
+#include "rocksdb/merge_operator.h"
+#include "utilities/merge_operators.h"
+
+using namespace ROCKSDB_NAMESPACE;
+
+namespace { // anonymous namespace
+
+// A merge operator that mimics Put semantics
+// Since this merge-operator will not be used in production,
+// it is implemented as a non-associative merge operator to illustrate the
+// new interface and for testing purposes. (That is, we inherit from
+// the MergeOperator class rather than the AssociativeMergeOperator
+// which would be simpler in this case).
+//
+// From the client-perspective, semantics are the same.
+class PutOperator : public MergeOperator {
+ public:
+ bool FullMerge(const Slice& /*key*/, const Slice* /*existing_value*/,
+ const std::deque<std::string>& operand_sequence,
+ std::string* new_value, Logger* /*logger*/) const override {
+ // Put basically only looks at the current/latest value
+ assert(!operand_sequence.empty());
+ assert(new_value != nullptr);
+ new_value->assign(operand_sequence.back());
+ return true;
+ }
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& /*left_operand*/,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override {
+ new_value->assign(right_operand.data(), right_operand.size());
+ return true;
+ }
+
+ using MergeOperator::PartialMergeMulti;
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const override {
+ new_value->assign(operand_list.back().data(), operand_list.back().size());
+ return true;
+ }
+
+ const char* Name() const override { return "PutOperator"; }
+};
+
+class PutOperatorV2 : public PutOperator {
+ bool FullMerge(const Slice& /*key*/, const Slice* /*existing_value*/,
+ const std::deque<std::string>& /*operand_sequence*/,
+ std::string* /*new_value*/,
+ Logger* /*logger*/) const override {
+ assert(false);
+ return false;
+ }
+
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override {
+ // Put basically only looks at the current/latest value
+ assert(!merge_in.operand_list.empty());
+ merge_out->existing_operand = merge_in.operand_list.back();
+ return true;
+ }
+};
+
+} // end of anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateDeprecatedPutOperator() {
+ return std::make_shared<PutOperator>();
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreatePutOperator() {
+ return std::make_shared<PutOperatorV2>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/sortlist.cc b/src/rocksdb/utilities/merge_operators/sortlist.cc
new file mode 100644
index 000000000..b6bd65b36
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/sortlist.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#include "utilities/merge_operators/sortlist.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "utilities/merge_operators.h"
+
+using ROCKSDB_NAMESPACE::Logger;
+using ROCKSDB_NAMESPACE::MergeOperator;
+using ROCKSDB_NAMESPACE::Slice;
+
+namespace ROCKSDB_NAMESPACE {
+
+bool SortList::FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ std::vector<int> left;
+ for (Slice slice : merge_in.operand_list) {
+ std::vector<int> right;
+ MakeVector(right, slice);
+ left = Merge(left, right);
+ }
+ for (int i = 0; i < static_cast<int>(left.size()) - 1; i++) {
+ merge_out->new_value.append(std::to_string(left[i])).append(",");
+ }
+ merge_out->new_value.append(std::to_string(left.back()));
+ return true;
+}
+
+bool SortList::PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const {
+ std::vector<int> left;
+ std::vector<int> right;
+ MakeVector(left, left_operand);
+ MakeVector(right, right_operand);
+ left = Merge(left, right);
+ for (int i = 0; i < static_cast<int>(left.size()) - 1; i++) {
+ new_value->append(std::to_string(left[i])).append(",");
+ }
+ new_value->append(std::to_string(left.back()));
+ return true;
+}
+
+bool SortList::PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const {
+ (void)operand_list;
+ (void)new_value;
+ return true;
+}
+
+const char* SortList::Name() const { return "MergeSortOperator"; }
+
+void SortList::MakeVector(std::vector<int>& operand, Slice slice) const {
+ do {
+ const char* begin = slice.data_;
+ while (*slice.data_ != ',' && *slice.data_) slice.data_++;
+ operand.push_back(std::stoi(std::string(begin, slice.data_)));
+ } while (0 != *slice.data_++);
+}
+
+std::vector<int> SortList::Merge(std::vector<int>& left,
+ std::vector<int>& right) const {
+ // Fill the resultant vector with sorted results from both vectors
+ std::vector<int> result;
+ unsigned left_it = 0, right_it = 0;
+
+ while (left_it < left.size() && right_it < right.size()) {
+ // If the left value is smaller than the right it goes next
+ // into the resultant vector
+ if (left[left_it] < right[right_it]) {
+ result.push_back(left[left_it]);
+ left_it++;
+ } else {
+ result.push_back(right[right_it]);
+ right_it++;
+ }
+ }
+
+ // Push the remaining data from both vectors onto the resultant
+ while (left_it < left.size()) {
+ result.push_back(left[left_it]);
+ left_it++;
+ }
+
+ while (right_it < right.size()) {
+ result.push_back(right[right_it]);
+ right_it++;
+ }
+
+ return result;
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateSortOperator() {
+ return std::make_shared<SortList>();
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/sortlist.h b/src/rocksdb/utilities/merge_operators/sortlist.h
new file mode 100644
index 000000000..5e08bd583
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/sortlist.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+// A MergeOperator for RocksDB that implements Merge Sort.
+// It is built using the MergeOperator interface. The operator works by taking
+// an input that contains one or more merge operands, where each operand is a
+// sorted list of ints, and merging them into one large sorted list.
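+//
+// A hedged sketch of the behaviour: with this operator installed as the
+// merge_operator of a column family, merging the operands "2,5" and "1,4"
+// for a key yields "1,2,4,5" on the next read.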
+#pragma once
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SortList : public MergeOperator {
+ public:
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ bool PartialMerge(const Slice& /*key*/, const Slice& left_operand,
+ const Slice& right_operand, std::string* new_value,
+ Logger* /*logger*/) const override;
+
+ bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const override;
+
+ const char* Name() const override;
+
+ void MakeVector(std::vector<int>& operand, Slice slice) const;
+
+ private:
+ std::vector<int> Merge(std::vector<int>& left, std::vector<int>& right) const;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc
new file mode 100644
index 000000000..534f7a566
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend.cc
@@ -0,0 +1,59 @@
+/**
+ * A MergeOperator for RocksDB that implements string append.
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#include "stringappend.h"
+
+#include <memory>
+#include <assert.h>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/merge_operator.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Constructor: also specify the delimiter character.
+StringAppendOperator::StringAppendOperator(char delim_char)
+ : delim_(delim_char) {
+}
+
+// Implementation for the merge operation (concatenates two strings)
+bool StringAppendOperator::Merge(const Slice& /*key*/,
+ const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* /*logger*/) const {
+ // Clear the *new_value for writing.
+ assert(new_value);
+ new_value->clear();
+
+ if (!existing_value) {
+ // No existing_value. Set *new_value = value
+ new_value->assign(value.data(),value.size());
+ } else {
+ // Generic append (existing_value != null).
+ // Reserve *new_value to correct size, and apply concatenation.
+ new_value->reserve(existing_value->size() + 1 + value.size());
+ new_value->assign(existing_value->data(),existing_value->size());
+ new_value->append(1,delim_);
+ new_value->append(value.data(), value.size());
+ }
+
+ return true;
+}
+
+const char* StringAppendOperator::Name() const {
+ return "StringAppendOperator";
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateStringAppendOperator() {
+ return std::make_shared<StringAppendOperator>(',');
+}
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateStringAppendOperator(char delim_char) {
+ return std::make_shared<StringAppendOperator>(delim_char);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend.h b/src/rocksdb/utilities/merge_operators/string_append/stringappend.h
new file mode 100644
index 000000000..388612f1e
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend.h
@@ -0,0 +1,31 @@
+/**
+ * A MergeOperator for rocksdb that implements string append.
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
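+
+// Usage sketch (illustrative; mirrors the factory in stringappend.cc and the
+// tests in stringappend_test.cc):
+//
+//   Options options;
+//   options.merge_operator.reset(new StringAppendOperator(','));
+//   // or: options.merge_operator = MergeOperators::CreateStringAppendOperator();
+//   DB* db;
+//   DB::Open(options, "/tmp/stringappend_example", &db);
+//   db->Merge(WriteOptions(), "key", "v1");
+//   db->Merge(WriteOptions(), "key", "v2");  // Get("key") now yields "v1,v2"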
+
+#pragma once
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class StringAppendOperator : public AssociativeMergeOperator {
+ public:
+ // Constructor: specify delimiter
+ explicit StringAppendOperator(char delim_char);
+
+ virtual bool Merge(const Slice& key,
+ const Slice* existing_value,
+ const Slice& value,
+ std::string* new_value,
+ Logger* logger) const override;
+
+ virtual const char* Name() const override;
+
+ private:
+ char delim_; // The delimiter is inserted between elements
+
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc
new file mode 100644
index 000000000..b8c676ee5
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.cc
@@ -0,0 +1,117 @@
+/**
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#include "stringappend2.h"
+
+#include <memory>
+#include <string>
+#include <assert.h>
+
+#include "rocksdb/slice.h"
+#include "rocksdb/merge_operator.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Constructor: also specify the delimiter character.
+StringAppendTESTOperator::StringAppendTESTOperator(char delim_char)
+ : delim_(delim_char) {
+}
+
+// Implementation for the merge operation (concatenates two strings)
+bool StringAppendTESTOperator::FullMergeV2(
+ const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ // Clear the *new_value for writing.
+ merge_out->new_value.clear();
+
+ if (merge_in.existing_value == nullptr && merge_in.operand_list.size() == 1) {
+ // Only one operand
+ merge_out->existing_operand = merge_in.operand_list.back();
+ return true;
+ }
+
+ // Compute the space needed for the final result.
+ size_t numBytes = 0;
+ for (auto it = merge_in.operand_list.begin();
+ it != merge_in.operand_list.end(); ++it) {
+ numBytes += it->size() + 1; // Plus 1 for the delimiter
+ }
+
+ // Only print the delimiter after the first entry has been printed
+ bool printDelim = false;
+
+ // Prepend the *existing_value if one exists.
+ if (merge_in.existing_value) {
+ merge_out->new_value.reserve(numBytes + merge_in.existing_value->size());
+ merge_out->new_value.append(merge_in.existing_value->data(),
+ merge_in.existing_value->size());
+ printDelim = true;
+ } else if (numBytes) {
+ merge_out->new_value.reserve(
+ numBytes - 1); // Minus 1 since we have one less delimiter
+ }
+
+ // Concatenate the sequence of strings (and add a delimiter between each)
+ for (auto it = merge_in.operand_list.begin();
+ it != merge_in.operand_list.end(); ++it) {
+ if (printDelim) {
+ merge_out->new_value.append(1, delim_);
+ }
+ merge_out->new_value.append(it->data(), it->size());
+ printDelim = true;
+ }
+
+ return true;
+}
+
+bool StringAppendTESTOperator::PartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& /*operand_list*/,
+ std::string* /*new_value*/, Logger* /*logger*/) const {
+ return false;
+}
+
+// A version of PartialMerge that actually performs "partial merging".
+// Use this to simulate the exact behaviour of the StringAppendOperator.
+bool StringAppendTESTOperator::_AssocPartialMergeMulti(
+ const Slice& /*key*/, const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* /*logger*/) const {
+ // Clear the *new_value for writing
+ assert(new_value);
+ new_value->clear();
+ assert(operand_list.size() >= 2);
+
+ // Generic append
+ // Determine and reserve correct size for *new_value.
+ size_t size = 0;
+ for (const auto& operand : operand_list) {
+ size += operand.size();
+ }
+ size += operand_list.size() - 1; // Delimiters
+ new_value->reserve(size);
+
+ // Apply concatenation
+ new_value->assign(operand_list.front().data(), operand_list.front().size());
+
+ for (std::deque<Slice>::const_iterator it = operand_list.begin() + 1;
+ it != operand_list.end(); ++it) {
+ new_value->append(1, delim_);
+ new_value->append(it->data(), it->size());
+ }
+
+ return true;
+}
+
+const char* StringAppendTESTOperator::Name() const {
+ return "StringAppendTESTOperator";
+}
+
+
+std::shared_ptr<MergeOperator>
+MergeOperators::CreateStringAppendTESTOperator() {
+ return std::make_shared<StringAppendTESTOperator>(',');
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h
new file mode 100644
index 000000000..452164d8e
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend2.h
@@ -0,0 +1,49 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+/**
+ * A TEST MergeOperator for rocksdb that implements string append.
+ * It is built using the MergeOperator interface rather than the simpler
+ * AssociativeMergeOperator interface. This is useful for testing/benchmarking.
+ * While the two operators are semantically the same, all production code
+ * should use the StringAppendOperator defined in stringappend.{h,cc}. The
+ * operator defined in the present file is primarily for testing.
+ *
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook
+ */
+
+#pragma once
+#include <deque>
+#include <string>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class StringAppendTESTOperator : public MergeOperator {
+ public:
+ // Constructor with delimiter
+ explicit StringAppendTESTOperator(char delim_char);
+
+ virtual bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ virtual bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const
+ override;
+
+ virtual const char* Name() const override;
+
+ private:
+ // A version of PartialMerge that actually performs "partial merging".
+ // Use this to simulate the exact behaviour of the StringAppendOperator.
+ bool _AssocPartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const;
+
+ char delim_; // The delimiter is inserted between elements
+
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc b/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc
new file mode 100644
index 000000000..c5f5e3e7c
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/string_append/stringappend_test.cc
@@ -0,0 +1,601 @@
+/**
+ * A persistent map: key -> (list of strings), using rocksdb merge.
+ * This file is a test-harness / use-case for the StringAppendOperator.
+ *
+ * @author Deon Nicholas (dnicholas@fb.com)
+ * Copyright 2013 Facebook, Inc.
+*/
+
+#include <iostream>
+#include <map>
+
+#include "rocksdb/db.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/merge_operators/string_append/stringappend2.h"
+
+using namespace ROCKSDB_NAMESPACE;
+
+namespace ROCKSDB_NAMESPACE {
+
+// Path to the database on file system
+const std::string kDbName = test::PerThreadDBPath("stringappend_test");
+
+namespace {
+// OpenNormalDb opens a (possibly new) rocksdb database with a StringAppendOperator
+std::shared_ptr<DB> OpenNormalDb(char delim_char) {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ options.merge_operator.reset(new StringAppendOperator(delim_char));
+ EXPECT_OK(DB::Open(options, kDbName, &db));
+ return std::shared_ptr<DB>(db);
+}
+
+#ifndef ROCKSDB_LITE // TtlDb is not supported in Lite
+// Open a TtlDB with a non-associative StringAppendTESTOperator
+std::shared_ptr<DB> OpenTtlDb(char delim_char) {
+ DBWithTTL* db;
+ Options options;
+ options.create_if_missing = true;
+ options.merge_operator.reset(new StringAppendTESTOperator(delim_char));
+ EXPECT_OK(DBWithTTL::Open(options, kDbName, &db, 123456));
+ return std::shared_ptr<DB>(db);
+}
+#endif // !ROCKSDB_LITE
+} // namespace
+
+/// StringLists represents a set of string-lists, each with a key-index.
+/// Supports Append(list, string) and Get(list)
+class StringLists {
+ public:
+
+ //Constructor: specifies the rocksdb db
+ /* implicit */
+ StringLists(std::shared_ptr<DB> db)
+ : db_(db),
+ merge_option_(),
+ get_option_() {
+ assert(db);
+ }
+
+ // Append string val onto the list defined by key; return true on success
+ bool Append(const std::string& key, const std::string& val){
+ Slice valSlice(val.data(), val.size());
+ auto s = db_->Merge(merge_option_, key, valSlice);
+
+ if (s.ok()) {
+ return true;
+ } else {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ return false;
+ }
+ }
+
+ // Returns the list of strings associated with key (or "" if the key does not exist)
+ bool Get(const std::string& key, std::string* const result){
+ assert(result != nullptr); // we should have a place to store the result
+ auto s = db_->Get(get_option_, key, result);
+
+ if (s.ok()) {
+ return true;
+ }
+
+ // Either key does not exist, or there is some error.
+ *result = ""; // Always return empty string (just for convention)
+
+ // NotFound is okay; just return empty (similar to std::map)
+ // But network or db errors, etc., should fail the test (or at least yell)
+ if (!s.IsNotFound()) {
+ std::cerr << "ERROR " << s.ToString() << std::endl;
+ }
+
+ // Always return false if s.ok() was not true
+ return false;
+ }
+
+
+ private:
+ std::shared_ptr<DB> db_;
+ WriteOptions merge_option_;
+ ReadOptions get_option_;
+
+};
+
+
+// The class for unit-testing
+class StringAppendOperatorTest : public testing::Test {
+ public:
+ StringAppendOperatorTest() {
+ DestroyDB(kDbName, Options()); // Start each test with a fresh DB
+ }
+
+ typedef std::shared_ptr<DB> (* OpenFuncPtr)(char);
+
+ // Allows user to open databases with different configurations.
+ // e.g.: Can open a DB or a TtlDB, etc.
+ static void SetOpenDbFunction(OpenFuncPtr func) {
+ OpenDb = func;
+ }
+
+ protected:
+ static OpenFuncPtr OpenDb;
+};
+StringAppendOperatorTest::OpenFuncPtr StringAppendOperatorTest::OpenDb = nullptr;
+
+// THE TEST CASES BEGIN HERE
+
+TEST_F(StringAppendOperatorTest, IteratorTest) {
+ auto db_ = OpenDb(',');
+ StringLists slists(db_);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ slists.Append("k2", "a1");
+ slists.Append("k2", "a2");
+ slists.Append("k2", "a3");
+
+ std::string res;
+ std::unique_ptr<ROCKSDB_NAMESPACE::Iterator> it(
+ db_->NewIterator(ReadOptions()));
+ std::string k1("k1");
+ std::string k2("k2");
+ bool first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3");
+ }
+ }
+ slists.Append("k2", "a4");
+ slists.Append("k1", "v4");
+
+ // Snapshot should still be the same. Should ignore a4 and v4.
+ first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3");
+ }
+ }
+
+
+ // Should release the snapshot and be aware of the new stuff now
+ it.reset(db_->NewIterator(ReadOptions()));
+ first = true;
+ for (it->Seek(k1); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3,v4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ }
+ }
+
+ // start from k2 this time.
+ for (it->Seek(k2); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "v1,v2,v3,v4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ }
+ }
+
+ slists.Append("k3", "g1");
+
+ it.reset(db_->NewIterator(ReadOptions()));
+ first = true;
+ std::string k3("k3");
+ for(it->Seek(k2); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "g1");
+ }
+ }
+ for(it->Seek(k3); it->Valid(); it->Next()) {
+ res = it->value().ToString();
+ if (first) {
+ // should not be hit
+ ASSERT_EQ(res, "a1,a2,a3,a4");
+ first = false;
+ } else {
+ ASSERT_EQ(res, "g1");
+ }
+ }
+
+}
+
+TEST_F(StringAppendOperatorTest, SimpleTest) {
+ auto db = OpenDb(',');
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ bool status = slists.Get("k1", &res);
+
+ ASSERT_TRUE(status);
+ ASSERT_EQ(res, "v1,v2,v3");
+}
+
+TEST_F(StringAppendOperatorTest, SimpleDelimiterTest) {
+ auto db = OpenDb('|');
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ slists.Get("k1", &res);
+ ASSERT_EQ(res, "v1|v2|v3");
+}
+
+TEST_F(StringAppendOperatorTest, OneValueNoDelimiterTest) {
+ auto db = OpenDb('!');
+ StringLists slists(db);
+
+ slists.Append("random_key", "single_val");
+
+ std::string res;
+ slists.Get("random_key", &res);
+ ASSERT_EQ(res, "single_val");
+}
+
+TEST_F(StringAppendOperatorTest, VariousKeys) {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+
+ slists.Append("c", "asdasd");
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+ slists.Append("c", "asdasd");
+
+ std::string a, b, c;
+ bool sa, sb, sc;
+ sa = slists.Get("a", &a);
+ sb = slists.Get("b", &b);
+ sc = slists.Get("c", &c);
+
+ ASSERT_TRUE(sa && sb && sc); // All three keys should have been found
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+}
+
+// Generate semi random keys/words from a small distribution.
+TEST_F(StringAppendOperatorTest, RandomMixGetAppend) {
+ auto db = OpenDb(' ');
+ StringLists slists(db);
+
+ // Generate a list of random keys and values
+ const int kWordCount = 15;
+ std::string words[] = {"sdasd", "triejf", "fnjsdfn", "dfjisdfsf", "342839",
+ "dsuha", "mabuais", "sadajsid", "jf9834hf", "2d9j89",
+ "dj9823jd", "a", "dk02ed2dh", "$(jd4h984$(*", "mabz"};
+ const int kKeyCount = 6;
+ std::string keys[] = {"dhaiusdhu", "denidw", "daisda", "keykey", "muki",
+ "shzassdianmd"};
+
+ // Will store a local copy of all data in order to verify correctness
+ std::map<std::string, std::string> parallel_copy;
+
+ // Generate a bunch of random queries (Append and Get)!
+ enum query_t { APPEND_OP, GET_OP, NUM_OPS };
+ Random randomGen(1337); //deterministic seed; always get same results!
+
+ const int kNumQueries = 30;
+ for (int q=0; q<kNumQueries; ++q) {
+ // Generate a random query (Append or Get) and random parameters
+ query_t query = (query_t)randomGen.Uniform((int)NUM_OPS);
+ std::string key = keys[randomGen.Uniform((int)kKeyCount)];
+ std::string word = words[randomGen.Uniform((int)kWordCount)];
+
+ // Apply the query and any checks.
+ if (query == APPEND_OP) {
+
+ // Apply the rocksdb test-harness Append defined above
+ slists.Append(key, word); //apply the rocksdb append
+
+ // Apply the similar "Append" to the parallel copy
+ if (parallel_copy[key].size() > 0) {
+ parallel_copy[key] += " " + word;
+ } else {
+ parallel_copy[key] = word;
+ }
+
+ } else if (query == GET_OP) {
+ // Assumes that a non-existent key just returns <empty>
+ std::string res;
+ slists.Get(key, &res);
+ ASSERT_EQ(res, parallel_copy[key]);
+ }
+
+ }
+
+}
+
+TEST_F(StringAppendOperatorTest, BIGRandomMixGetAppend) {
+ auto db = OpenDb(' ');
+ StringLists slists(db);
+
+ // Generate a list of random keys and values
+ const int kWordCount = 15;
+ std::string words[] = {"sdasd", "triejf", "fnjsdfn", "dfjisdfsf", "342839",
+ "dsuha", "mabuais", "sadajsid", "jf9834hf", "2d9j89",
+ "dj9823jd", "a", "dk02ed2dh", "$(jd4h984$(*", "mabz"};
+ const int kKeyCount = 6;
+ std::string keys[] = {"dhaiusdhu", "denidw", "daisda", "keykey", "muki",
+ "shzassdianmd"};
+
+ // Will store a local copy of all data in order to verify correctness
+ std::map<std::string, std::string> parallel_copy;
+
+ // Generate a bunch of random queries (Append and Get)!
+ enum query_t { APPEND_OP, GET_OP, NUM_OPS };
+ Random randomGen(9138204); // deterministic seed
+
+ const int kNumQueries = 1000;
+ for (int q=0; q<kNumQueries; ++q) {
+ // Generate a random query (Append or Get) and random parameters
+ query_t query = (query_t)randomGen.Uniform((int)NUM_OPS);
+ std::string key = keys[randomGen.Uniform((int)kKeyCount)];
+ std::string word = words[randomGen.Uniform((int)kWordCount)];
+
+ //Apply the query and any checks.
+ if (query == APPEND_OP) {
+
+ // Apply the rocksdb test-harness Append defined above
+ slists.Append(key, word); //apply the rocksdb append
+
+ // Apply the similar "Append" to the parallel copy
+ if (parallel_copy[key].size() > 0) {
+ parallel_copy[key] += " " + word;
+ } else {
+ parallel_copy[key] = word;
+ }
+
+ } else if (query == GET_OP) {
+ // Assumes that a non-existent key just returns <empty>
+ std::string res;
+ slists.Get(key, &res);
+ ASSERT_EQ(res, parallel_copy[key]);
+ }
+
+ }
+
+}
+
+TEST_F(StringAppendOperatorTest, PersistentVariousKeys) {
+ // Perform the following operations in limited scope
+ {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+
+ slists.Append("c", "asdasd");
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+ slists.Append("c", "asdasd");
+
+ std::string a, b, c;
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+
+ slists.Append("c", "bbnagnagsx");
+ slists.Append("a", "sa");
+ slists.Append("b", "df");
+ slists.Append("a", "gh");
+ slists.Append("a", "jk");
+ slists.Append("b", "l;");
+ slists.Append("c", "rogosh");
+
+ // The previous changes should be on disk (L0)
+ // The most recent changes should be in memory (MemTable)
+ // Hence, this will test both Get() paths.
+ std::string a, b, c;
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+
+ // All changes should be on disk. This will test VersionSet Get()
+ std::string a, b, c;
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+ }
+}
+
+TEST_F(StringAppendOperatorTest, PersistentFlushAndCompaction) {
+ // Perform the following operations in limited scope
+ {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+ std::string a, b, c;
+ bool success;
+
+ // Append, Flush, Get
+ slists.Append("c", "asdasd");
+ db->Flush(ROCKSDB_NAMESPACE::FlushOptions());
+ success = slists.Get("c", &c);
+ ASSERT_TRUE(success);
+ ASSERT_EQ(c, "asdasd");
+
+ // Append, Flush, Append, Get
+ slists.Append("a", "x");
+ slists.Append("b", "y");
+ db->Flush(ROCKSDB_NAMESPACE::FlushOptions());
+ slists.Append("a", "t");
+ slists.Append("a", "r");
+ slists.Append("b", "2");
+
+ success = slists.Get("a", &a);
+ assert(success == true);
+ ASSERT_EQ(a, "x\nt\nr");
+
+ success = slists.Get("b", &b);
+ assert(success == true);
+ ASSERT_EQ(b, "y\n2");
+
+ // Append, Get
+ success = slists.Append("c", "asdasd");
+ assert(success);
+ success = slists.Append("b", "monkey");
+ assert(success);
+
+ // I omit the "assert(success)" checks here.
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+
+ ASSERT_EQ(a, "x\nt\nr");
+ ASSERT_EQ(b, "y\n2\nmonkey");
+ ASSERT_EQ(c, "asdasd\nasdasd");
+ }
+
+ // Reopen the database (the previous changes should persist / be remembered)
+ {
+ auto db = OpenDb('\n');
+ StringLists slists(db);
+ std::string a, b, c;
+
+ // Get (Quick check for persistence of previous database)
+ slists.Get("a", &a);
+ ASSERT_EQ(a, "x\nt\nr");
+
+ //Append, Compact, Get
+ slists.Append("c", "bbnagnagsx");
+ slists.Append("a", "sa");
+ slists.Append("b", "df");
+ db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+ ASSERT_EQ(a, "x\nt\nr\nsa");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx");
+
+ // Append, Get
+ slists.Append("a", "gh");
+ slists.Append("a", "jk");
+ slists.Append("b", "l;");
+ slists.Append("c", "rogosh");
+ slists.Get("a", &a);
+ slists.Get("b", &b);
+ slists.Get("c", &c);
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+
+ // Compact, Get
+ db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ ASSERT_EQ(a, "x\nt\nr\nsa\ngh\njk");
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;");
+ ASSERT_EQ(c, "asdasd\nasdasd\nbbnagnagsx\nrogosh");
+
+ // Append, Flush, Compact, Get
+ slists.Append("b", "afcg");
+ db->Flush(ROCKSDB_NAMESPACE::FlushOptions());
+ db->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ slists.Get("b", &b);
+ ASSERT_EQ(b, "y\n2\nmonkey\ndf\nl;\nafcg");
+ }
+}
+
+TEST_F(StringAppendOperatorTest, SimpleTestNullDelimiter) {
+ auto db = OpenDb('\0');
+ StringLists slists(db);
+
+ slists.Append("k1", "v1");
+ slists.Append("k1", "v2");
+ slists.Append("k1", "v3");
+
+ std::string res;
+ bool status = slists.Get("k1", &res);
+ ASSERT_TRUE(status);
+
+ // Construct the desired string. Default constructor doesn't like '\0' chars.
+ std::string checker("v1,v2,v3"); // Verify that the string is right size.
+ checker[2] = '\0'; // Use null delimiter instead of comma.
+ checker[5] = '\0';
+ assert(checker.size() == 8); // Verify it is still the correct size
+
+ // Check that the rocksdb result string matches the desired string
+ assert(res.size() == checker.size());
+ ASSERT_EQ(res, checker);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ // Run with regular database
+ int result;
+ {
+ fprintf(stderr, "Running tests with regular db and operator.\n");
+ StringAppendOperatorTest::SetOpenDbFunction(&OpenNormalDb);
+ result = RUN_ALL_TESTS();
+ }
+
+#ifndef ROCKSDB_LITE // TtlDb is not supported in Lite
+ // Run with TTL
+ {
+ fprintf(stderr, "Running tests with ttl db and generic operator.\n");
+ StringAppendOperatorTest::SetOpenDbFunction(&OpenTtlDb);
+ result |= RUN_ALL_TESTS();
+ }
+#endif // !ROCKSDB_LITE
+
+ return result;
+}
diff --git a/src/rocksdb/utilities/merge_operators/uint64add.cc b/src/rocksdb/utilities/merge_operators/uint64add.cc
new file mode 100644
index 000000000..3ef240928
--- /dev/null
+++ b/src/rocksdb/utilities/merge_operators/uint64add.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include <memory>
+
+#include "logging/logging.h"
+#include "rocksdb/env.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+using namespace ROCKSDB_NAMESPACE;
+
+namespace { // anonymous namespace
+
+// A 'model' merge operator with uint64 addition semantics.
+// Implemented as an AssociativeMergeOperator for simplicity and as an example.
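+//
+// Usage sketch (illustrative; assumes a DB opened with
+// options.merge_operator = MergeOperators::CreateUInt64AddOperator()):
+//
+//   std::string val;
+//   PutFixed64(&val, 5);
+//   db->Merge(WriteOptions(), "counter", val);  // counter += 5
+//   db->Get(ReadOptions(), "counter", &val);
+//   uint64_t total = DecodeFixed64(val.data());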
+class UInt64AddOperator : public AssociativeMergeOperator {
+ public:
+ bool Merge(const Slice& /*key*/, const Slice* existing_value,
+ const Slice& value, std::string* new_value,
+ Logger* logger) const override {
+ uint64_t orig_value = 0;
+ if (existing_value){
+ orig_value = DecodeInteger(*existing_value, logger);
+ }
+ uint64_t operand = DecodeInteger(value, logger);
+
+ assert(new_value);
+ new_value->clear();
+ PutFixed64(new_value, orig_value + operand);
+
+ return true; // Return true always since corruption will be treated as 0
+ }
+
+ const char* Name() const override { return "UInt64AddOperator"; }
+
+ private:
+ // Takes the string and decodes it into a uint64_t
+ // On error, prints a message and returns 0
+ uint64_t DecodeInteger(const Slice& value, Logger* logger) const {
+ uint64_t result = 0;
+
+ if (value.size() == sizeof(uint64_t)) {
+ result = DecodeFixed64(value.data());
+ } else if (logger != nullptr) {
+ // If value is corrupted, treat it as 0
+ ROCKS_LOG_ERROR(logger, "uint64 value corruption, size: %" ROCKSDB_PRIszt
+ " > %" ROCKSDB_PRIszt,
+ value.size(), sizeof(uint64_t));
+ }
+
+ return result;
+ }
+
+};
+
+} // anonymous namespace
+
+namespace ROCKSDB_NAMESPACE {
+
+std::shared_ptr<MergeOperator> MergeOperators::CreateUInt64AddOperator() {
+ return std::make_shared<UInt64AddOperator>();
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/object_registry.cc b/src/rocksdb/utilities/object_registry.cc
new file mode 100644
index 000000000..38e55020e
--- /dev/null
+++ b/src/rocksdb/utilities/object_registry.cc
@@ -0,0 +1,87 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/object_registry.h"
+
+#include "logging/logging.h"
+#include "rocksdb/env.h"
+
+namespace ROCKSDB_NAMESPACE {
+#ifndef ROCKSDB_LITE
+// Looks through the "type" factories for one that matches "name".
+// If found, returns the pointer to the Entry matching this name.
+// Otherwise, nullptr is returned
+const ObjectLibrary::Entry *ObjectLibrary::FindEntry(
+ const std::string &type, const std::string &name) const {
+ auto entries = entries_.find(type);
+ if (entries != entries_.end()) {
+ for (const auto &entry : entries->second) {
+ if (entry->matches(name)) {
+ return entry.get();
+ }
+ }
+ }
+ return nullptr;
+}
+
+void ObjectLibrary::AddEntry(const std::string &type,
+ std::unique_ptr<Entry> &entry) {
+ auto &entries = entries_[type];
+ entries.emplace_back(std::move(entry));
+}
+
+void ObjectLibrary::Dump(Logger *logger) const {
+ for (const auto &iter : entries_) {
+ ROCKS_LOG_HEADER(logger, " Registered factories for type[%s] ",
+ iter.first.c_str());
+ bool printed_one = false;
+ for (const auto &e : iter.second) {
+ ROCKS_LOG_HEADER(logger, "%c %s", (printed_one) ? ',' : ':',
+ e->Name().c_str());
+ printed_one = true;
+ }
+ }
+ ROCKS_LOG_HEADER(logger, "\n");
+}
+
+// Returns the Default singleton instance of the ObjectLibrary
+// This instance will contain most of the "standard" registered objects
+std::shared_ptr<ObjectLibrary> &ObjectLibrary::Default() {
+ static std::shared_ptr<ObjectLibrary> instance =
+ std::make_shared<ObjectLibrary>();
+ return instance;
+}
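+
+// Typical registration against the default library (illustrative sketch; see
+// object_registry_test.cc in this change for complete, working examples):
+//
+//   static FactoryFunc<Env> f = ObjectLibrary::Default()->Register<Env>(
+//       "custom://.*",
+//       [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard*/,
+//          std::string* /*errmsg*/) { return Env::Default(); });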
+
+std::shared_ptr<ObjectRegistry> ObjectRegistry::NewInstance() {
+ std::shared_ptr<ObjectRegistry> instance = std::make_shared<ObjectRegistry>();
+ return instance;
+}
+
+ObjectRegistry::ObjectRegistry() {
+ libraries_.push_back(ObjectLibrary::Default());
+}
+
+// Searches (from back to front) the libraries looking for an entry that
+// matches this pattern.
+// Returns the entry if it is found, and nullptr otherwise
+const ObjectLibrary::Entry *ObjectRegistry::FindEntry(
+ const std::string &type, const std::string &name) const {
+ for (auto iter = libraries_.crbegin(); iter != libraries_.crend(); ++iter) {
+ const auto *entry = iter->get()->FindEntry(type, name);
+ if (entry != nullptr) {
+ return entry;
+ }
+ }
+ return nullptr;
+}
+
+void ObjectRegistry::Dump(Logger *logger) const {
+ for (auto iter = libraries_.crbegin(); iter != libraries_.crend(); ++iter) {
+ iter->get()->Dump(logger);
+ }
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/object_registry_test.cc b/src/rocksdb/utilities/object_registry_test.cc
new file mode 100644
index 000000000..bbb76b142
--- /dev/null
+++ b/src/rocksdb/utilities/object_registry_test.cc
@@ -0,0 +1,174 @@
+// Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/object_registry.h"
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class EnvRegistryTest : public testing::Test {
+ public:
+ static int num_a, num_b;
+};
+
+int EnvRegistryTest::num_a = 0;
+int EnvRegistryTest::num_b = 0;
+static FactoryFunc<Env> test_reg_a = ObjectLibrary::Default()->Register<Env>(
+ "a://.*",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*env_guard*/,
+ std::string* /* errmsg */) {
+ ++EnvRegistryTest::num_a;
+ return Env::Default();
+ });
+
+static FactoryFunc<Env> test_reg_b = ObjectLibrary::Default()->Register<Env>(
+ "b://.*", [](const std::string& /*uri*/, std::unique_ptr<Env>* env_guard,
+ std::string* /* errmsg */) {
+ ++EnvRegistryTest::num_b;
+ // Env::Default() is a singleton so we can't grant ownership directly to
+ // the caller - we must wrap it first.
+ env_guard->reset(new EnvWrapper(Env::Default()));
+ return env_guard->get();
+ });
+
+TEST_F(EnvRegistryTest, Basics) {
+ std::string msg;
+ std::unique_ptr<Env> env_guard;
+ auto registry = ObjectRegistry::NewInstance();
+ auto res = registry->NewObject<Env>("a://test", &env_guard, &msg);
+ ASSERT_NE(res, nullptr);
+ ASSERT_EQ(env_guard, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(0, num_b);
+
+ res = registry->NewObject<Env>("b://test", &env_guard, &msg);
+ ASSERT_NE(res, nullptr);
+ ASSERT_NE(env_guard, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(1, num_b);
+
+ res = registry->NewObject<Env>("c://test", &env_guard, &msg);
+ ASSERT_EQ(res, nullptr);
+ ASSERT_EQ(env_guard, nullptr);
+ ASSERT_EQ(1, num_a);
+ ASSERT_EQ(1, num_b);
+}
+
+TEST_F(EnvRegistryTest, LocalRegistry) {
+ std::string msg;
+ std::unique_ptr<Env> guard;
+ auto registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library = std::make_shared<ObjectLibrary>();
+ registry->AddLibrary(library);
+ library->Register<Env>(
+ "test-local",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ ObjectLibrary::Default()->Register<Env>(
+ "test-global",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ ASSERT_EQ(
+ ObjectRegistry::NewInstance()->NewObject<Env>("test-local", &guard, &msg),
+ nullptr);
+ ASSERT_NE(
+ ObjectRegistry::NewInstance()->NewObject("test-global", &guard, &msg),
+ nullptr);
+ ASSERT_NE(registry->NewObject<Env>("test-local", &guard, &msg), nullptr);
+ ASSERT_NE(registry->NewObject<Env>("test-global", &guard, &msg), nullptr);
+}
+
+TEST_F(EnvRegistryTest, CheckShared) {
+ std::shared_ptr<Env> shared;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library = std::make_shared<ObjectLibrary>();
+ registry->AddLibrary(library);
+ library->Register<Env>(
+ "unguarded",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ library->Register<Env>(
+ "guarded", [](const std::string& /*uri*/, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new EnvWrapper(Env::Default()));
+ return guard->get();
+ });
+
+ ASSERT_OK(registry->NewSharedObject<Env>("guarded", &shared));
+ ASSERT_NE(shared, nullptr);
+ shared.reset();
+ ASSERT_NOK(registry->NewSharedObject<Env>("unguarded", &shared));
+ ASSERT_EQ(shared, nullptr);
+}
+
+TEST_F(EnvRegistryTest, CheckStatic) {
+ Env* env = nullptr;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library = std::make_shared<ObjectLibrary>();
+ registry->AddLibrary(library);
+ library->Register<Env>(
+ "unguarded",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ library->Register<Env>(
+ "guarded", [](const std::string& /*uri*/, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new EnvWrapper(Env::Default()));
+ return guard->get();
+ });
+
+ ASSERT_NOK(registry->NewStaticObject<Env>("guarded", &env));
+ ASSERT_EQ(env, nullptr);
+ env = nullptr;
+ ASSERT_OK(registry->NewStaticObject<Env>("unguarded", &env));
+ ASSERT_NE(env, nullptr);
+}
+
+TEST_F(EnvRegistryTest, CheckUnique) {
+ std::unique_ptr<Env> unique;
+ std::shared_ptr<ObjectRegistry> registry = ObjectRegistry::NewInstance();
+ std::shared_ptr<ObjectLibrary> library = std::make_shared<ObjectLibrary>();
+ registry->AddLibrary(library);
+ library->Register<Env>(
+ "unguarded",
+ [](const std::string& /*uri*/, std::unique_ptr<Env>* /*guard */,
+ std::string* /* errmsg */) { return Env::Default(); });
+
+ library->Register<Env>(
+ "guarded", [](const std::string& /*uri*/, std::unique_ptr<Env>* guard,
+ std::string* /* errmsg */) {
+ guard->reset(new EnvWrapper(Env::Default()));
+ return guard->get();
+ });
+
+ ASSERT_OK(registry->NewUniqueObject<Env>("guarded", &unique));
+ ASSERT_NE(unique, nullptr);
+ unique.reset();
+ ASSERT_NOK(registry->NewUniqueObject<Env>("unguarded", &unique));
+ ASSERT_EQ(unique, nullptr);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else // ROCKSDB_LITE
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as EnvRegistry is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/option_change_migration/option_change_migration.cc b/src/rocksdb/utilities/option_change_migration/option_change_migration.cc
new file mode 100644
index 000000000..f2382297b
--- /dev/null
+++ b/src/rocksdb/utilities/option_change_migration/option_change_migration.cc
@@ -0,0 +1,168 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/option_change_migration.h"
+
+#ifndef ROCKSDB_LITE
+#include "rocksdb/db.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+// Return a version of Options `opts` that allows us to open/write into a DB
+// without triggering an automatic compaction or stalling. This is guaranteed
+// by disabling automatic compactions and using huge values for stalling
+// triggers.
+Options GetNoCompactionOptions(const Options& opts) {
+ Options ret_opts = opts;
+ ret_opts.disable_auto_compactions = true;
+ ret_opts.level0_slowdown_writes_trigger = 999999;
+ ret_opts.level0_stop_writes_trigger = 999999;
+ ret_opts.soft_pending_compaction_bytes_limit = 0;
+ ret_opts.hard_pending_compaction_bytes_limit = 0;
+ return ret_opts;
+}
+
+Status OpenDb(const Options& options, const std::string& dbname,
+ std::unique_ptr<DB>* db) {
+ db->reset();
+ DB* tmpdb;
+ Status s = DB::Open(options, dbname, &tmpdb);
+ if (s.ok()) {
+ db->reset(tmpdb);
+ }
+ return s;
+}
+
+Status CompactToLevel(const Options& options, const std::string& dbname,
+ int dest_level, bool need_reopen) {
+ std::unique_ptr<DB> db;
+ Options no_compact_opts = GetNoCompactionOptions(options);
+ if (dest_level == 0) {
+ // L0 has strict sequenceID requirements for the files placed into it, so
+ // it is safer to put only one compacted file there.
+ // This is only used when converting to universal compaction with a
+ // single level; in that case, compacting to one file is also optimal.
+ no_compact_opts.target_file_size_base = 999999999999999;
+ no_compact_opts.max_compaction_bytes = 999999999999999;
+ }
+ Status s = OpenDb(no_compact_opts, dbname, &db);
+ if (!s.ok()) {
+ return s;
+ }
+ CompactRangeOptions cro;
+ cro.change_level = true;
+ cro.target_level = dest_level;
+ if (dest_level == 0) {
+ // cannot use kForceOptimized because the compaction is expected to
+ // generate one output file
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ }
+ db->CompactRange(cro, nullptr, nullptr);
+
+ if (need_reopen) {
+ // Need to restart DB to rewrite the manifest file.
+ // In order to open a DB with a specific num_levels, the manifest file must
+ // contain no record that mentions any level beyond num_levels. Issuing a
+ // full compaction will move all the data to a level not exceeding
+ // num_levels, but the manifest may still contain older records mentioning
+ // a higher level. Reopening the DB forces the manifest to be rewritten
+ // so that those records are cleared.
+ db.reset();
+ s = OpenDb(no_compact_opts, dbname, &db);
+ }
+ return s;
+}
+
+Status MigrateToUniversal(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (old_opts.num_levels <= new_opts.num_levels ||
+ old_opts.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ return Status::OK();
+ } else {
+ bool need_compact = false;
+ {
+ std::unique_ptr<DB> db;
+ Options opts = GetNoCompactionOptions(old_opts);
+ Status s = OpenDb(opts, dbname, &db);
+ if (!s.ok()) {
+ return s;
+ }
+ ColumnFamilyMetaData metadata;
+ db->GetColumnFamilyMetaData(&metadata);
+ if (!metadata.levels.empty() &&
+ metadata.levels.back().level >= new_opts.num_levels) {
+ need_compact = true;
+ }
+ }
+ if (need_compact) {
+ return CompactToLevel(old_opts, dbname, new_opts.num_levels - 1, true);
+ }
+ return Status::OK();
+ }
+}
+
+Status MigrateToLevelBase(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (!new_opts.level_compaction_dynamic_level_bytes) {
+ if (old_opts.num_levels == 1) {
+ return Status::OK();
+ }
+ // Compact everything to level 1 to guarantee it can be safely opened.
+ Options opts = old_opts;
+ opts.target_file_size_base = new_opts.target_file_size_base;
+ // Although the DB can sometimes be opened with the new options without
+ // error, we still want to compact the files so that the LSM tree does not
+ // get stuck in a bad shape. For example, if the user changed the level
+ // size multiplier from 4 to 8, the same data would fit into fewer levels.
+ // Unless we issue a full compaction, the LSM tree may remain stuck with
+ // more levels than needed and will not recover automatically.
+ return CompactToLevel(opts, dbname, 1, true);
+ } else {
+ // Compact everything to the last level to guarantee it can be safely
+ // opened.
+ if (old_opts.num_levels == 1) {
+ return Status::OK();
+ } else if (new_opts.num_levels > old_opts.num_levels) {
+ // Dynamic level mode requires data to be put in the last level first.
+ return CompactToLevel(new_opts, dbname, new_opts.num_levels - 1, false);
+ } else {
+ Options opts = old_opts;
+ opts.target_file_size_base = new_opts.target_file_size_base;
+ return CompactToLevel(opts, dbname, new_opts.num_levels - 1, true);
+ }
+ }
+}
+} // namespace
+
+Status OptionChangeMigration(std::string dbname, const Options& old_opts,
+ const Options& new_opts) {
+ if (old_opts.compaction_style == CompactionStyle::kCompactionStyleFIFO) {
+ // An LSM tree generated by FIFO compaction can be opened with any compaction style.
+ return Status::OK();
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleUniversal) {
+ return MigrateToUniversal(dbname, old_opts, new_opts);
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleLevel) {
+ return MigrateToLevelBase(dbname, old_opts, new_opts);
+ } else if (new_opts.compaction_style ==
+ CompactionStyle::kCompactionStyleFIFO) {
+ return CompactToLevel(old_opts, dbname, 0, true);
+ } else {
+ return Status::NotSupported(
+ "Do not know how to migrate to this compaction style");
+ }
+}
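+
+// Usage sketch (illustrative): run the migration with the DB closed, then
+// reopen it with the new options.
+//
+//   Options old_opts = ...;  // options the DB was last opened with
+//   Options new_opts = ...;  // desired options (e.g. universal compaction)
+//   Status s = OptionChangeMigration(dbname, old_opts, new_opts);
+//   if (s.ok()) {
+//     s = DB::Open(new_opts, dbname, &db);
+//   }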
+} // namespace ROCKSDB_NAMESPACE
+#else
+namespace ROCKSDB_NAMESPACE {
+Status OptionChangeMigration(std::string /*dbname*/,
+ const Options& /*old_opts*/,
+ const Options& /*new_opts*/) {
+ return Status::NotSupported();
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc b/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc
new file mode 100644
index 000000000..5bc883ff7
--- /dev/null
+++ b/src/rocksdb/utilities/option_change_migration/option_change_migration_test.cc
@@ -0,0 +1,425 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/utilities/option_change_migration.h"
+#include <set>
+#include "db/db_test_util.h"
+#include "port/stack_trace.h"
+namespace ROCKSDB_NAMESPACE {
+
+class DBOptionChangeMigrationTests
+ : public DBTestBase,
+ public testing::WithParamInterface<
+ std::tuple<int, int, bool, int, int, bool>> {
+ public:
+ DBOptionChangeMigrationTests()
+ : DBTestBase("/db_option_change_migration_test") {
+ level1_ = std::get<0>(GetParam());
+ compaction_style1_ = std::get<1>(GetParam());
+ is_dynamic1_ = std::get<2>(GetParam());
+
+ level2_ = std::get<3>(GetParam());
+ compaction_style2_ = std::get<4>(GetParam());
+ is_dynamic2_ = std::get<5>(GetParam());
+ }
+
+ // Required if inheriting from testing::WithParamInterface<>
+ static void SetUpTestCase() {}
+ static void TearDownTestCase() {}
+
+ int level1_;
+ int compaction_style1_;
+ bool is_dynamic1_;
+
+ int level2_;
+ int compaction_style2_;
+ bool is_dynamic2_;
+};
+
+#ifndef ROCKSDB_LITE
+TEST_P(DBOptionChangeMigrationTests, Migrate1) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level1_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+
+ Random rnd(301);
+ int key_idx = 0;
+
+ // Generate at least 2MB of data
+ for (int num = 0; num < 20; num++) {
+ GenerateNewFile(&rnd, &key_idx);
+ }
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level2_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+
+ // Wait for compaction to finish and make sure it can reopen
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate2) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level2_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+
+ Random rnd(301);
+ int key_idx = 0;
+
+ // Generate at least 2MB of data
+ for (int num = 0; num < 20; num++) {
+ GenerateNewFile(&rnd, &key_idx);
+ }
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level1_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate3) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level1_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), RandomString(&rnd, 900)));
+ }
+ Flush();
+ dbfull()->TEST_WaitForCompact();
+ if (num == 9) {
+ // Issue a full compaction to generate some zero-out files
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ dbfull()->CompactRange(cro, nullptr, nullptr);
+ }
+ }
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level2_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+
+ // Wait for compaction to finish and make sure it can reopen
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+TEST_P(DBOptionChangeMigrationTests, Migrate4) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style2_);
+ if (old_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ old_options.level_compaction_dynamic_level_bytes = is_dynamic2_;
+ }
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = level2_;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), RandomString(&rnd, 900)));
+ }
+ Flush();
+ dbfull()->TEST_WaitForCompact();
+ if (num == 9) {
+ // Issue a full compaction to generate some zero-out files
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ dbfull()->CompactRange(cro, nullptr, nullptr);
+ }
+ }
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style =
+ static_cast<CompactionStyle>(compaction_style1_);
+ if (new_options.compaction_style == CompactionStyle::kCompactionStyleLevel) {
+ new_options.level_compaction_dynamic_level_bytes = is_dynamic1_;
+ }
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = level1_;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DBOptionChangeMigrationTests, DBOptionChangeMigrationTests,
+ ::testing::Values(std::make_tuple(3, 0, false, 4, 0, false),
+ std::make_tuple(3, 0, true, 4, 0, true),
+ std::make_tuple(3, 0, true, 4, 0, false),
+ std::make_tuple(3, 0, false, 4, 0, true),
+ std::make_tuple(3, 1, false, 4, 1, false),
+ std::make_tuple(1, 1, false, 4, 1, false),
+ std::make_tuple(3, 0, false, 4, 1, false),
+ std::make_tuple(3, 0, false, 1, 1, false),
+ std::make_tuple(3, 0, true, 4, 1, false),
+ std::make_tuple(3, 0, true, 1, 1, false),
+ std::make_tuple(1, 1, false, 4, 0, false),
+ std::make_tuple(4, 0, false, 1, 2, false),
+ std::make_tuple(3, 0, true, 2, 2, false),
+ std::make_tuple(3, 1, false, 3, 2, false),
+ std::make_tuple(1, 1, false, 4, 2, false)));
+
+class DBOptionChangeMigrationTest : public DBTestBase {
+ public:
+ DBOptionChangeMigrationTest()
+ : DBTestBase("/db_option_change_migration_test2") {}
+};
+
+TEST_F(DBOptionChangeMigrationTest, CompactedSrcToUniversal) {
+ Options old_options = CurrentOptions();
+ old_options.compaction_style = CompactionStyle::kCompactionStyleLevel;
+ old_options.max_compaction_bytes = 200 * 1024;
+ old_options.level_compaction_dynamic_level_bytes = false;
+ old_options.level0_file_num_compaction_trigger = 3;
+ old_options.write_buffer_size = 64 * 1024;
+ old_options.target_file_size_base = 128 * 1024;
+ // Make level target of L1, L2 to be 200KB and 600KB
+ old_options.num_levels = 4;
+ old_options.max_bytes_for_level_multiplier = 3;
+ old_options.max_bytes_for_level_base = 200 * 1024;
+
+ Reopen(old_options);
+ Random rnd(301);
+ for (int num = 0; num < 20; num++) {
+ for (int i = 0; i < 50; i++) {
+ ASSERT_OK(Put(Key(num * 100 + i), RandomString(&rnd, 900)));
+ }
+ }
+ Flush();
+ CompactRangeOptions cro;
+ cro.bottommost_level_compaction = BottommostLevelCompaction::kForce;
+ dbfull()->CompactRange(cro, nullptr, nullptr);
+
+ // Will make sure exactly those keys are in the DB after migration.
+ std::set<std::string> keys;
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (; it->Valid(); it->Next()) {
+ keys.insert(it->key().ToString());
+ }
+ }
+
+ Close();
+
+ Options new_options = old_options;
+ new_options.compaction_style = CompactionStyle::kCompactionStyleUniversal;
+ new_options.target_file_size_base = 256 * 1024;
+ new_options.num_levels = 1;
+ new_options.max_bytes_for_level_base = 150 * 1024;
+ new_options.max_bytes_for_level_multiplier = 4;
+ ASSERT_OK(OptionChangeMigration(dbname_, old_options, new_options));
+ Reopen(new_options);
+ // Wait for compaction to finish and make sure it can reopen
+ dbfull()->TEST_WaitForFlushMemTable();
+ dbfull()->TEST_WaitForCompact();
+ Reopen(new_options);
+
+ {
+ std::unique_ptr<Iterator> it(db_->NewIterator(ReadOptions()));
+ it->SeekToFirst();
+ for (std::string key : keys) {
+ ASSERT_TRUE(it->Valid());
+ ASSERT_EQ(key, it->key().ToString());
+ it->Next();
+ }
+ ASSERT_TRUE(!it->Valid());
+ }
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/options/options_util.cc b/src/rocksdb/utilities/options/options_util.cc
new file mode 100644
index 000000000..0719798e3
--- /dev/null
+++ b/src/rocksdb/utilities/options/options_util.cc
@@ -0,0 +1,114 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/options_util.h"
+
+#include "env/composite_env_wrapper.h"
+#include "file/filename.h"
+#include "options/options_parser.h"
+#include "rocksdb/options.h"
+
+namespace ROCKSDB_NAMESPACE {
+Status LoadOptionsFromFile(const std::string& file_name, Env* env,
+ DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ bool ignore_unknown_options,
+ std::shared_ptr<Cache>* cache) {
+ RocksDBOptionsParser parser;
+ LegacyFileSystemWrapper fs(env);
+ Status s = parser.Parse(file_name, &fs, ignore_unknown_options,
+ 0 /* file_readahead_size */);
+ if (!s.ok()) {
+ return s;
+ }
+ *db_options = *parser.db_opt();
+ const std::vector<std::string>& cf_names = *parser.cf_names();
+ const std::vector<ColumnFamilyOptions>& cf_opts = *parser.cf_opts();
+ cf_descs->clear();
+ for (size_t i = 0; i < cf_opts.size(); ++i) {
+ cf_descs->push_back({cf_names[i], cf_opts[i]});
+ if (cache != nullptr) {
+ TableFactory* tf = cf_opts[i].table_factory.get();
+ if (tf != nullptr && tf->GetOptions() != nullptr &&
+ tf->Name() == BlockBasedTableFactory().Name()) {
+ auto* loaded_bbt_opt =
+ reinterpret_cast<BlockBasedTableOptions*>(tf->GetOptions());
+ loaded_bbt_opt->block_cache = *cache;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status GetLatestOptionsFileName(const std::string& dbpath,
+ Env* env, std::string* options_file_name) {
+ Status s;
+ std::string latest_file_name;
+ uint64_t latest_time_stamp = 0;
+ std::vector<std::string> file_names;
+ s = env->GetChildren(dbpath, &file_names);
+ if (!s.ok()) {
+ return s;
+ }
+ for (auto& file_name : file_names) {
+ uint64_t time_stamp;
+ FileType type;
+ if (ParseFileName(file_name, &time_stamp, &type) && type == kOptionsFile) {
+ if (time_stamp > latest_time_stamp) {
+ latest_time_stamp = time_stamp;
+ latest_file_name = file_name;
+ }
+ }
+ }
+ if (latest_file_name.size() == 0) {
+ return Status::NotFound("No options files found in the DB directory.");
+ }
+ *options_file_name = latest_file_name;
+ return Status::OK();
+}
+
+Status LoadLatestOptions(const std::string& dbpath, Env* env,
+ DBOptions* db_options,
+ std::vector<ColumnFamilyDescriptor>* cf_descs,
+ bool ignore_unknown_options,
+ std::shared_ptr<Cache>* cache) {
+ std::string options_file_name;
+ Status s = GetLatestOptionsFileName(dbpath, env, &options_file_name);
+ if (!s.ok()) {
+ return s;
+ }
+ return LoadOptionsFromFile(dbpath + "/" + options_file_name, env, db_options,
+ cf_descs, ignore_unknown_options, cache);
+}
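+
+// Illustrative usage sketch (hypothetical caller code, not part of this
+// translation unit): an application that wants to reopen a DB with its
+// persisted options might combine LoadLatestOptions() with DB::Open().
+// The variable names (db_opts, cf_descs, handles, db) are made up.
+//
+//   DBOptions db_opts;
+//   std::vector<ColumnFamilyDescriptor> cf_descs;
+//   Status s = LoadLatestOptions(dbname, Env::Default(), &db_opts, &cf_descs,
+//                                false /* ignore_unknown_options */, nullptr);
+//   if (s.ok()) {
+//     std::vector<ColumnFamilyHandle*> handles;
+//     DB* db = nullptr;
+//     s = DB::Open(db_opts, dbname, cf_descs, &handles, &db);
+//   }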
+
+Status CheckOptionsCompatibility(
+ const std::string& dbpath, Env* env, const DBOptions& db_options,
+ const std::vector<ColumnFamilyDescriptor>& cf_descs,
+ bool ignore_unknown_options) {
+ std::string options_file_name;
+ Status s = GetLatestOptionsFileName(dbpath, env, &options_file_name);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<std::string> cf_names;
+ std::vector<ColumnFamilyOptions> cf_opts;
+ for (const auto& cf_desc : cf_descs) {
+ cf_names.push_back(cf_desc.name);
+ cf_opts.push_back(cf_desc.options);
+ }
+
+ const OptionsSanityCheckLevel kDefaultLevel = kSanityLevelLooselyCompatible;
+ LegacyFileSystemWrapper fs(env);
+
+ return RocksDBOptionsParser::VerifyRocksDBOptionsFromFile(
+ db_options, cf_names, cf_opts, dbpath + "/" + options_file_name, &fs,
+ kDefaultLevel, ignore_unknown_options);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/options/options_util_test.cc b/src/rocksdb/utilities/options/options_util_test.cc
new file mode 100644
index 000000000..30ad76a99
--- /dev/null
+++ b/src/rocksdb/utilities/options/options_util_test.cc
@@ -0,0 +1,363 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <cinttypes>
+
+#include <cctype>
+#include <unordered_map>
+
+#include "options/options_parser.h"
+#include "rocksdb/db.h"
+#include "rocksdb/table.h"
+#include "rocksdb/utilities/options_util.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "util/random.h"
+
+#ifndef GFLAGS
+bool FLAGS_enable_print = false;
+#else
+#include "util/gflags_compat.h"
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+DEFINE_bool(enable_print, false, "Print options generated to console.");
+#endif // GFLAGS
+
+namespace ROCKSDB_NAMESPACE {
+class OptionsUtilTest : public testing::Test {
+ public:
+ OptionsUtilTest() : rnd_(0xFB) {
+ env_.reset(new test::StringEnv(Env::Default()));
+ fs_.reset(new LegacyFileSystemWrapper(env_.get()));
+ dbname_ = test::PerThreadDBPath("options_util_test");
+ }
+
+ protected:
+ std::unique_ptr<test::StringEnv> env_;
+ std::unique_ptr<LegacyFileSystemWrapper> fs_;
+ std::string dbname_;
+ Random rnd_;
+};
+
+bool IsBlockBasedTableFactory(TableFactory* tf) {
+ return tf->Name() == BlockBasedTableFactory().Name();
+}
+
+TEST_F(OptionsUtilTest, SaveAndLoad) {
+ const size_t kCFCount = 5;
+
+ DBOptions db_opt;
+ std::vector<std::string> cf_names;
+ std::vector<ColumnFamilyOptions> cf_opts;
+ test::RandomInitDBOptions(&db_opt, &rnd_);
+ for (size_t i = 0; i < kCFCount; ++i) {
+ cf_names.push_back(i == 0 ? kDefaultColumnFamilyName
+ : test::RandomName(&rnd_, 10));
+ cf_opts.emplace_back();
+ test::RandomInitCFOptions(&cf_opts.back(), db_opt, &rnd_);
+ }
+
+ const std::string kFileName = "OPTIONS-123456";
+ PersistRocksDBOptions(db_opt, cf_names, cf_opts, kFileName, fs_.get());
+
+ DBOptions loaded_db_opt;
+ std::vector<ColumnFamilyDescriptor> loaded_cf_descs;
+ ASSERT_OK(LoadOptionsFromFile(kFileName, env_.get(), &loaded_db_opt,
+ &loaded_cf_descs));
+
+ ASSERT_OK(RocksDBOptionsParser::VerifyDBOptions(db_opt, loaded_db_opt));
+ test::RandomInitDBOptions(&db_opt, &rnd_);
+ ASSERT_NOK(RocksDBOptionsParser::VerifyDBOptions(db_opt, loaded_db_opt));
+
+ for (size_t i = 0; i < kCFCount; ++i) {
+ ASSERT_EQ(cf_names[i], loaded_cf_descs[i].name);
+ ASSERT_OK(RocksDBOptionsParser::VerifyCFOptions(
+ cf_opts[i], loaded_cf_descs[i].options));
+ if (IsBlockBasedTableFactory(cf_opts[i].table_factory.get())) {
+ ASSERT_OK(RocksDBOptionsParser::VerifyTableFactory(
+ cf_opts[i].table_factory.get(),
+ loaded_cf_descs[i].options.table_factory.get()));
+ }
+ test::RandomInitCFOptions(&cf_opts[i], db_opt, &rnd_);
+ ASSERT_NOK(RocksDBOptionsParser::VerifyCFOptions(
+ cf_opts[i], loaded_cf_descs[i].options));
+ }
+
+ for (size_t i = 0; i < kCFCount; ++i) {
+ if (cf_opts[i].compaction_filter) {
+ delete cf_opts[i].compaction_filter;
+ }
+ }
+}
+
+TEST_F(OptionsUtilTest, SaveAndLoadWithCacheCheck) {
+ // creating db
+ DBOptions db_opt;
+ db_opt.create_if_missing = true;
+ // initialize BlockBasedTableOptions
+ std::shared_ptr<Cache> cache = NewLRUCache(1 * 1024);
+ BlockBasedTableOptions bbt_opts;
+ bbt_opts.block_size = 32 * 1024;
+ // saving cf options
+ std::vector<ColumnFamilyOptions> cf_opts;
+ ColumnFamilyOptions default_column_family_opt = ColumnFamilyOptions();
+ default_column_family_opt.table_factory.reset(
+ NewBlockBasedTableFactory(bbt_opts));
+ cf_opts.push_back(default_column_family_opt);
+
+ ColumnFamilyOptions cf_opt_sample = ColumnFamilyOptions();
+ cf_opt_sample.table_factory.reset(NewBlockBasedTableFactory(bbt_opts));
+ cf_opts.push_back(cf_opt_sample);
+
+ ColumnFamilyOptions cf_opt_plain_table_opt = ColumnFamilyOptions();
+ cf_opt_plain_table_opt.table_factory.reset(NewPlainTableFactory());
+ cf_opts.push_back(cf_opt_plain_table_opt);
+
+ std::vector<std::string> cf_names;
+ cf_names.push_back(kDefaultColumnFamilyName);
+ cf_names.push_back("cf_sample");
+ cf_names.push_back("cf_plain_table_sample");
+  // Save the options to a file
+ const std::string kFileName = "OPTIONS-LOAD_CACHE_123456";
+ PersistRocksDBOptions(db_opt, cf_names, cf_opts, kFileName, fs_.get());
+ DBOptions loaded_db_opt;
+ std::vector<ColumnFamilyDescriptor> loaded_cf_descs;
+ ASSERT_OK(LoadOptionsFromFile(kFileName, env_.get(), &loaded_db_opt,
+ &loaded_cf_descs, false, &cache));
+ for (size_t i = 0; i < loaded_cf_descs.size(); i++) {
+ if (IsBlockBasedTableFactory(cf_opts[i].table_factory.get())) {
+ auto* loaded_bbt_opt = reinterpret_cast<BlockBasedTableOptions*>(
+ loaded_cf_descs[i].options.table_factory->GetOptions());
+ // Expect the same cache will be loaded
+ if (loaded_bbt_opt != nullptr) {
+ ASSERT_EQ(loaded_bbt_opt->block_cache.get(), cache.get());
+ }
+ }
+ }
+}
+
+namespace {
+class DummyTableFactory : public TableFactory {
+ public:
+ DummyTableFactory() {}
+ ~DummyTableFactory() override {}
+
+ const char* Name() const override { return "DummyTableFactory"; }
+
+ Status NewTableReader(
+ const TableReaderOptions& /*table_reader_options*/,
+ std::unique_ptr<RandomAccessFileReader>&& /*file*/,
+ uint64_t /*file_size*/, std::unique_ptr<TableReader>* /*table_reader*/,
+ bool /*prefetch_index_and_filter_in_cache*/) const override {
+ return Status::NotSupported();
+ }
+
+ TableBuilder* NewTableBuilder(
+ const TableBuilderOptions& /*table_builder_options*/,
+ uint32_t /*column_family_id*/,
+ WritableFileWriter* /*file*/) const override {
+ return nullptr;
+ }
+
+ Status SanitizeOptions(
+ const DBOptions& /*db_opts*/,
+ const ColumnFamilyOptions& /*cf_opts*/) const override {
+ return Status::NotSupported();
+ }
+
+ std::string GetPrintableTableOptions() const override { return ""; }
+
+ Status GetOptionString(std::string* /*opt_string*/,
+ const std::string& /*delimiter*/) const override {
+ return Status::OK();
+ }
+};
+
+class DummyMergeOperator : public MergeOperator {
+ public:
+ DummyMergeOperator() {}
+ ~DummyMergeOperator() override {}
+
+ bool FullMergeV2(const MergeOperationInput& /*merge_in*/,
+ MergeOperationOutput* /*merge_out*/) const override {
+ return false;
+ }
+
+ bool PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& /*operand_list*/,
+ std::string* /*new_value*/,
+ Logger* /*logger*/) const override {
+ return false;
+ }
+
+ const char* Name() const override { return "DummyMergeOperator"; }
+};
+
+class DummySliceTransform : public SliceTransform {
+ public:
+ DummySliceTransform() {}
+ ~DummySliceTransform() override {}
+
+ // Return the name of this transformation.
+ const char* Name() const override { return "DummySliceTransform"; }
+
+ // transform a src in domain to a dst in the range
+ Slice Transform(const Slice& src) const override { return src; }
+
+ // determine whether this is a valid src upon the function applies
+ bool InDomain(const Slice& /*src*/) const override { return false; }
+
+ // determine whether dst=Transform(src) for some src
+ bool InRange(const Slice& /*dst*/) const override { return false; }
+};
+
+} // namespace
+
+TEST_F(OptionsUtilTest, SanityCheck) {
+ DBOptions db_opt;
+ std::vector<ColumnFamilyDescriptor> cf_descs;
+ const size_t kCFCount = 5;
+ for (size_t i = 0; i < kCFCount; ++i) {
+ cf_descs.emplace_back();
+ cf_descs.back().name =
+ (i == 0) ? kDefaultColumnFamilyName : test::RandomName(&rnd_, 10);
+
+ cf_descs.back().options.table_factory.reset(NewBlockBasedTableFactory());
+ // Assign non-null values to prefix_extractors except the first cf.
+ cf_descs.back().options.prefix_extractor.reset(
+ i != 0 ? test::RandomSliceTransform(&rnd_) : nullptr);
+ cf_descs.back().options.merge_operator.reset(
+ test::RandomMergeOperator(&rnd_));
+ }
+
+ db_opt.create_missing_column_families = true;
+ db_opt.create_if_missing = true;
+
+ DestroyDB(dbname_, Options(db_opt, cf_descs[0].options));
+ DB* db;
+ std::vector<ColumnFamilyHandle*> handles;
+ // open and persist the options
+ ASSERT_OK(DB::Open(db_opt, dbname_, cf_descs, &handles, &db));
+
+ // close the db
+ for (auto* handle : handles) {
+ delete handle;
+ }
+ delete db;
+
+ // perform sanity check
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ ASSERT_GE(kCFCount, 5);
+ // merge operator
+ {
+ std::shared_ptr<MergeOperator> merge_op =
+ cf_descs[0].options.merge_operator;
+
+ ASSERT_NE(merge_op.get(), nullptr);
+ cf_descs[0].options.merge_operator.reset();
+ ASSERT_NOK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[0].options.merge_operator.reset(new DummyMergeOperator());
+ ASSERT_NOK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[0].options.merge_operator = merge_op;
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+ }
+
+ // prefix extractor
+ {
+ std::shared_ptr<const SliceTransform> prefix_extractor =
+ cf_descs[1].options.prefix_extractor;
+
+ // It's okay to set prefix_extractor to nullptr.
+ ASSERT_NE(prefix_extractor, nullptr);
+ cf_descs[1].options.prefix_extractor.reset();
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[1].options.prefix_extractor.reset(new DummySliceTransform());
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[1].options.prefix_extractor = prefix_extractor;
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+ }
+
+ // prefix extractor nullptr case
+ {
+ std::shared_ptr<const SliceTransform> prefix_extractor =
+ cf_descs[0].options.prefix_extractor;
+
+ // It's okay to set prefix_extractor to nullptr.
+ ASSERT_EQ(prefix_extractor, nullptr);
+ cf_descs[0].options.prefix_extractor.reset();
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ // It's okay to change prefix_extractor from nullptr to non-nullptr
+ cf_descs[0].options.prefix_extractor.reset(new DummySliceTransform());
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[0].options.prefix_extractor = prefix_extractor;
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+ }
+
+ // comparator
+ {
+ test::SimpleSuffixReverseComparator comparator;
+
+ auto* prev_comparator = cf_descs[2].options.comparator;
+ cf_descs[2].options.comparator = &comparator;
+ ASSERT_NOK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[2].options.comparator = prev_comparator;
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+ }
+
+ // table factory
+ {
+ std::shared_ptr<TableFactory> table_factory =
+ cf_descs[3].options.table_factory;
+
+ ASSERT_NE(table_factory, nullptr);
+ cf_descs[3].options.table_factory.reset(new DummyTableFactory());
+ ASSERT_NOK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+
+ cf_descs[3].options.table_factory = table_factory;
+ ASSERT_OK(
+ CheckOptionsCompatibility(dbname_, Env::Default(), db_opt, cf_descs));
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+#ifdef GFLAGS
+ ParseCommandLineFlags(&argc, &argv, true);
+#endif // GFLAGS
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <cstdio>
+
+int main(int /*argc*/, char** /*argv*/) {
+ printf("Skipped in RocksDBLite as utilities are not supported.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc
new file mode 100644
index 000000000..658737571
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier.cc
@@ -0,0 +1,425 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier.h"
+
+#include <regex>
+#include <utility>
+#include <vector>
+
+#include "logging/logging.h"
+#include "port/port.h"
+#include "test_util/sync_point.h"
+#include "util/stop_watch.h"
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// BlockCacheImpl
+//
+Status BlockCacheTier::Open() {
+ Status status;
+
+ WriteLock _(&lock_);
+
+ assert(!size_);
+
+ // Check the validity of the options
+ status = opt_.ValidateSettings();
+ assert(status.ok());
+ if (!status.ok()) {
+ Error(opt_.log, "Invalid block cache options");
+ return status;
+ }
+
+ // Create base directory or cleanup existing directory
+ status = opt_.env->CreateDirIfMissing(opt_.path);
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating directory %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+ // Create base/<cache dir> directory
+ status = opt_.env->CreateDir(GetCachePath());
+ if (!status.ok()) {
+ // directory already exists, clean it up
+ status = CleanupCacheFolder(GetCachePath());
+ assert(status.ok());
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating directory %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+ }
+
+ // create a new file
+ assert(!cache_file_);
+ status = NewCacheFile();
+ if (!status.ok()) {
+ Error(opt_.log, "Error creating new file %s. %s", opt_.path.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+ assert(cache_file_);
+
+ if (opt_.pipeline_writes) {
+ assert(!insert_th_.joinable());
+ insert_th_ = port::Thread(&BlockCacheTier::InsertMain, this);
+ }
+
+ return Status::OK();
+}
+
+bool IsCacheFile(const std::string& file) {
+  // Check if the file has the .rc suffix.
+  // Unfortunately regex support is uneven across compilers, so we use simple
+  // string parsing
+ size_t pos = file.find(".");
+ if (pos == std::string::npos) {
+ return false;
+ }
+
+ std::string suffix = file.substr(pos);
+ return suffix == ".rc";
+}
+
+Status BlockCacheTier::CleanupCacheFolder(const std::string& folder) {
+ std::vector<std::string> files;
+ Status status = opt_.env->GetChildren(folder, &files);
+ if (!status.ok()) {
+ Error(opt_.log, "Error getting files for %s. %s", folder.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+
+  // clean up files matching the pattern <digits>.rc
+ for (auto file : files) {
+ if (IsCacheFile(file)) {
+ // cache file
+ Info(opt_.log, "Removing file %s.", file.c_str());
+ status = opt_.env->DeleteFile(folder + "/" + file);
+ if (!status.ok()) {
+ Error(opt_.log, "Error deleting file %s. %s", file.c_str(),
+ status.ToString().c_str());
+ return status;
+ }
+ } else {
+ ROCKS_LOG_DEBUG(opt_.log, "Skipping file %s", file.c_str());
+ }
+ }
+ return Status::OK();
+}
+
+Status BlockCacheTier::Close() {
+ // stop the insert thread
+ if (opt_.pipeline_writes && insert_th_.joinable()) {
+ InsertOp op(/*quit=*/true);
+ insert_ops_.Push(std::move(op));
+ insert_th_.join();
+ }
+
+  // stop the writer before clearing the metadata
+ writer_.Stop();
+
+ // clear all metadata
+ WriteLock _(&lock_);
+ metadata_.Clear();
+ return Status::OK();
+}
+
+template<class T>
+void Add(std::map<std::string, double>* stats, const std::string& key,
+ const T& t) {
+ stats->insert({key, static_cast<double>(t)});
+}
+
+PersistentCache::StatsType BlockCacheTier::Stats() {
+ std::map<std::string, double> stats;
+ Add(&stats, "persistentcache.blockcachetier.bytes_piplined",
+ stats_.bytes_pipelined_.Average());
+ Add(&stats, "persistentcache.blockcachetier.bytes_written",
+ stats_.bytes_written_.Average());
+ Add(&stats, "persistentcache.blockcachetier.bytes_read",
+ stats_.bytes_read_.Average());
+ Add(&stats, "persistentcache.blockcachetier.insert_dropped",
+ stats_.insert_dropped_);
+ Add(&stats, "persistentcache.blockcachetier.cache_hits",
+ stats_.cache_hits_);
+ Add(&stats, "persistentcache.blockcachetier.cache_misses",
+ stats_.cache_misses_);
+ Add(&stats, "persistentcache.blockcachetier.cache_errors",
+ stats_.cache_errors_);
+ Add(&stats, "persistentcache.blockcachetier.cache_hits_pct",
+ stats_.CacheHitPct());
+ Add(&stats, "persistentcache.blockcachetier.cache_misses_pct",
+ stats_.CacheMissPct());
+ Add(&stats, "persistentcache.blockcachetier.read_hit_latency",
+ stats_.read_hit_latency_.Average());
+ Add(&stats, "persistentcache.blockcachetier.read_miss_latency",
+ stats_.read_miss_latency_.Average());
+ Add(&stats, "persistentcache.blockcachetier.write_latency",
+ stats_.write_latency_.Average());
+
+ auto out = PersistentCacheTier::Stats();
+ out.push_back(stats);
+ return out;
+}
+
+Status BlockCacheTier::Insert(const Slice& key, const char* data,
+ const size_t size) {
+ // update stats
+ stats_.bytes_pipelined_.Add(size);
+
+ if (opt_.pipeline_writes) {
+    // offload the write to the write thread
+ insert_ops_.Push(
+ InsertOp(key.ToString(), std::move(std::string(data, size))));
+ return Status::OK();
+ }
+
+ assert(!opt_.pipeline_writes);
+ return InsertImpl(key, Slice(data, size));
+}
+
+void BlockCacheTier::InsertMain() {
+ while (true) {
+ InsertOp op(insert_ops_.Pop());
+
+ if (op.signal_) {
+      // this is the signal to exit the insert thread
+ break;
+ }
+
+ size_t retry = 0;
+ Status s;
+ while ((s = InsertImpl(Slice(op.key_), Slice(op.data_))).IsTryAgain()) {
+ if (retry > kMaxRetry) {
+ break;
+ }
+
+      // This can happen when the buffers are full; wait until some buffers are
+      // freed. We wait here rather than inside InsertImpl because InsertImpl
+      // is shared by both the pipelined and non-pipelined modes
+ buffer_allocator_.WaitUntilUsable();
+ retry++;
+ }
+
+ if (!s.ok()) {
+ stats_.insert_dropped_++;
+ }
+ }
+}
+
+Status BlockCacheTier::InsertImpl(const Slice& key, const Slice& data) {
+ // pre-condition
+ assert(key.size());
+ assert(data.size());
+ assert(cache_file_);
+
+ StopWatchNano timer(opt_.env, /*auto_start=*/ true);
+
+ WriteLock _(&lock_);
+
+ LBA lba;
+ if (metadata_.Lookup(key, &lba)) {
+    // the key already exists; this is a duplicate insert
+ return Status::OK();
+ }
+
+ while (!cache_file_->Append(key, data, &lba)) {
+ if (!cache_file_->Eof()) {
+ ROCKS_LOG_DEBUG(opt_.log, "Error inserting to cache file %d",
+ cache_file_->cacheid());
+ stats_.write_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::TryAgain();
+ }
+
+ assert(cache_file_->Eof());
+ Status status = NewCacheFile();
+ if (!status.ok()) {
+ return status;
+ }
+ }
+
+ // Insert into lookup index
+ BlockInfo* info = metadata_.Insert(key, lba);
+ assert(info);
+ if (!info) {
+ return Status::IOError("Unexpected error inserting to index");
+ }
+
+ // insert to cache file reverse mapping
+ cache_file_->Add(info);
+
+ // update stats
+ stats_.bytes_written_.Add(data.size());
+ stats_.write_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::OK();
+}
+
+Status BlockCacheTier::Lookup(const Slice& key, std::unique_ptr<char[]>* val,
+ size_t* size) {
+ StopWatchNano timer(opt_.env, /*auto_start=*/ true);
+
+ LBA lba;
+ bool status;
+ status = metadata_.Lookup(key, &lba);
+ if (!status) {
+ stats_.cache_misses_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: key not found");
+ }
+
+ BlockCacheFile* const file = metadata_.Lookup(lba.cache_id_);
+ if (!file) {
+ // this can happen because the block index and cache file index are
+ // different, and the cache file might be removed between the two lookups
+ stats_.cache_misses_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: cache file not found");
+ }
+
+ assert(file->refs_);
+
+ std::unique_ptr<char[]> scratch(new char[lba.size_]);
+ Slice blk_key;
+ Slice blk_val;
+
+ status = file->Read(lba, &blk_key, &blk_val, scratch.get());
+ --file->refs_;
+ if (!status) {
+ stats_.cache_misses_++;
+ stats_.cache_errors_++;
+ stats_.read_miss_latency_.Add(timer.ElapsedNanos() / 1000);
+ return Status::NotFound("blockcache: error reading data");
+ }
+
+ assert(blk_key == key);
+
+ val->reset(new char[blk_val.size()]);
+ memcpy(val->get(), blk_val.data(), blk_val.size());
+ *size = blk_val.size();
+
+ stats_.bytes_read_.Add(*size);
+ stats_.cache_hits_++;
+ stats_.read_hit_latency_.Add(timer.ElapsedNanos() / 1000);
+
+ return Status::OK();
+}
+
+bool BlockCacheTier::Erase(const Slice& key) {
+ WriteLock _(&lock_);
+ BlockInfo* info = metadata_.Remove(key);
+ assert(info);
+ delete info;
+ return true;
+}
+
+Status BlockCacheTier::NewCacheFile() {
+ lock_.AssertHeld();
+
+ TEST_SYNC_POINT_CALLBACK("BlockCacheTier::NewCacheFile:DeleteDir",
+ (void*)(GetCachePath().c_str()));
+
+ std::unique_ptr<WriteableCacheFile> f(
+ new WriteableCacheFile(opt_.env, &buffer_allocator_, &writer_,
+ GetCachePath(), writer_cache_id_,
+ opt_.cache_file_size, opt_.log));
+
+ bool status = f->Create(opt_.enable_direct_writes, opt_.enable_direct_reads);
+ if (!status) {
+ return Status::IOError("Error creating file");
+ }
+
+ Info(opt_.log, "Created cache file %d", writer_cache_id_);
+
+ writer_cache_id_++;
+ cache_file_ = f.release();
+
+ // insert to cache files tree
+ status = metadata_.Insert(cache_file_);
+ assert(status);
+ if (!status) {
+ Error(opt_.log, "Error inserting to metadata");
+ return Status::IOError("Error inserting to metadata");
+ }
+
+ return Status::OK();
+}
+
+bool BlockCacheTier::Reserve(const size_t size) {
+ WriteLock _(&lock_);
+ assert(size_ <= opt_.cache_size);
+
+ if (size + size_ <= opt_.cache_size) {
+ // there is enough space to write
+ size_ += size;
+ return true;
+ }
+
+ assert(size + size_ >= opt_.cache_size);
+ // there is not enough space to fit the requested data
+ // we can clear some space by evicting cold data
+
+ const double retain_fac = (100 - kEvictPct) / static_cast<double>(100);
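+  // Worked example (illustrative): with kEvictPct = 10, retain_fac is 0.9, so
+  // for a 100 MB cache files keep being evicted until size + size_ fits
+  // within 90 MB before the reservation is accepted.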
+ while (size + size_ > opt_.cache_size * retain_fac) {
+ std::unique_ptr<BlockCacheFile> f(metadata_.Evict());
+ if (!f) {
+ // nothing is evictable
+ return false;
+ }
+ assert(!f->refs_);
+ uint64_t file_size;
+ if (!f->Delete(&file_size).ok()) {
+ // unable to delete file
+ return false;
+ }
+
+ assert(file_size <= size_);
+ size_ -= file_size;
+ }
+
+ size_ += size;
+ assert(size_ <= opt_.cache_size * 0.9);
+ return true;
+}
+
+Status NewPersistentCache(Env* const env, const std::string& path,
+ const uint64_t size,
+ const std::shared_ptr<Logger>& log,
+ const bool optimized_for_nvm,
+ std::shared_ptr<PersistentCache>* cache) {
+ if (!cache) {
+ return Status::IOError("invalid argument cache");
+ }
+
+ auto opt = PersistentCacheConfig(env, path, size, log);
+ if (optimized_for_nvm) {
+    // The default settings are optimized for SSD. NVM devices are better
+    // accessed with 4K direct IO and benefit from parallel writes.
+ opt.enable_direct_writes = true;
+ opt.writer_qdepth = 4;
+ opt.writer_dispatch_size = 4 * 1024;
+ }
+
+ auto pcache = std::make_shared<BlockCacheTier>(opt);
+ Status s = pcache->Open();
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ *cache = pcache;
+ return s;
+}
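+
+// Illustrative usage sketch (hypothetical caller code): a typical way to
+// attach the persistent cache to a block-based table factory through
+// BlockBasedTableOptions::persistent_cache. The path, size and the Options
+// instance below are made up; a real Logger can be passed instead of nullptr.
+//
+//   std::shared_ptr<PersistentCache> pcache;
+//   Status s = NewPersistentCache(Env::Default(), "/path/to/cache_dir",
+//                                 1024 * 1024 * 1024 /* 1 GB */, nullptr,
+//                                 false /* optimized_for_nvm */, &pcache);
+//   if (s.ok()) {
+//     BlockBasedTableOptions table_opts;
+//     table_opts.persistent_cache = pcache;
+//     Options options;
+//     options.table_factory.reset(NewBlockBasedTableFactory(table_opts));
+//   }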
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ifndef ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier.h
new file mode 100644
index 000000000..ae0c13fdb
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif // ! OS_WIN
+
+#include <atomic>
+#include <list>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+
+#include "rocksdb/cache.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/persistent_cache.h"
+
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+#include "utilities/persistent_cache/block_cache_tier_metadata.h"
+#include "utilities/persistent_cache/persistent_cache_util.h"
+
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "util/coding.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Block cache tier implementation
+//
+class BlockCacheTier : public PersistentCacheTier {
+ public:
+ explicit BlockCacheTier(const PersistentCacheConfig& opt)
+ : opt_(opt),
+ insert_ops_(static_cast<size_t>(opt_.max_write_pipeline_backlog_size)),
+ buffer_allocator_(opt.write_buffer_size, opt.write_buffer_count()),
+ writer_(this, opt_.writer_qdepth, static_cast<size_t>(opt_.writer_dispatch_size)) {
+ Info(opt_.log, "Initializing allocator. size=%d B count=%" ROCKSDB_PRIszt,
+ opt_.write_buffer_size, opt_.write_buffer_count());
+ }
+
+ virtual ~BlockCacheTier() {
+ // Close is re-entrant so we can call close even if it is already closed
+ Close();
+ assert(!insert_th_.joinable());
+ }
+
+ Status Insert(const Slice& key, const char* data, const size_t size) override;
+ Status Lookup(const Slice& key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+ Status Open() override;
+ Status Close() override;
+ bool Erase(const Slice& key) override;
+ bool Reserve(const size_t size) override;
+
+ bool IsCompressed() override { return opt_.is_compressed; }
+
+ std::string GetPrintableOptions() const override { return opt_.ToString(); }
+
+ PersistentCache::StatsType Stats() override;
+
+ void TEST_Flush() override {
+ while (insert_ops_.Size()) {
+ /* sleep override */
+ Env::Default()->SleepForMicroseconds(1000000);
+ }
+ }
+
+ private:
+ // Percentage of cache to be evicted when the cache is full
+ static const size_t kEvictPct = 10;
+ // Max attempts to insert key, value to cache in pipelined mode
+ static const size_t kMaxRetry = 3;
+
+ // Pipelined operation
+ struct InsertOp {
+ explicit InsertOp(const bool signal) : signal_(signal) {}
+ explicit InsertOp(std::string&& key, const std::string& data)
+ : key_(std::move(key)), data_(data) {}
+ ~InsertOp() {}
+
+ InsertOp() = delete;
+ InsertOp(InsertOp&& /*rhs*/) = default;
+ InsertOp& operator=(InsertOp&& rhs) = default;
+
+ // used for estimating size by bounded queue
+ size_t Size() { return data_.size() + key_.size(); }
+
+ std::string key_;
+ std::string data_;
+ bool signal_ = false; // signal to request processing thread to exit
+ };
+
+ // entry point for insert thread
+ void InsertMain();
+ // insert implementation
+ Status InsertImpl(const Slice& key, const Slice& data);
+ // Create a new cache file
+ Status NewCacheFile();
+ // Get cache directory path
+ std::string GetCachePath() const { return opt_.path + "/cache"; }
+ // Cleanup folder
+ Status CleanupCacheFolder(const std::string& folder);
+
+ // Statistics
+ struct Statistics {
+ HistogramImpl bytes_pipelined_;
+ HistogramImpl bytes_written_;
+ HistogramImpl bytes_read_;
+ HistogramImpl read_hit_latency_;
+ HistogramImpl read_miss_latency_;
+ HistogramImpl write_latency_;
+ std::atomic<uint64_t> cache_hits_{0};
+ std::atomic<uint64_t> cache_misses_{0};
+ std::atomic<uint64_t> cache_errors_{0};
+ std::atomic<uint64_t> insert_dropped_{0};
+
+ double CacheHitPct() const {
+ const auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_hits_ / static_cast<double>(lookups) : 0.0;
+ }
+
+ double CacheMissPct() const {
+ const auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_misses_ / static_cast<double>(lookups) : 0.0;
+ }
+ };
+
+ port::RWMutex lock_; // Synchronization
+ const PersistentCacheConfig opt_; // BlockCache options
+ BoundedQueue<InsertOp> insert_ops_; // Ops waiting for insert
+ ROCKSDB_NAMESPACE::port::Thread insert_th_; // Insert thread
+ uint32_t writer_cache_id_ = 0; // Current cache file identifier
+ WriteableCacheFile* cache_file_ = nullptr; // Current cache file reference
+ CacheWriteBufferAllocator buffer_allocator_; // Buffer provider
+ ThreadedWriter writer_; // Writer threads
+ BlockCacheTierMetadata metadata_; // Cache meta data manager
+ std::atomic<uint64_t> size_{0}; // Size of the cache
+ Statistics stats_; // Statistics
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc
new file mode 100644
index 000000000..87ae603c5
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.cc
@@ -0,0 +1,608 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "env/composite_env_wrapper.h"
+#include "logging/logging.h"
+#include "port/port.h"
+#include "util/crc32c.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// File creation factories
+//
+Status NewWritableCacheFile(Env* const env, const std::string& filepath,
+ std::unique_ptr<WritableFile>* file,
+ const bool use_direct_writes = false) {
+ EnvOptions opt;
+ opt.use_direct_writes = use_direct_writes;
+ Status s = env->NewWritableFile(filepath, file, opt);
+ return s;
+}
+
+Status NewRandomAccessCacheFile(Env* const env, const std::string& filepath,
+ std::unique_ptr<RandomAccessFile>* file,
+ const bool use_direct_reads = true) {
+ assert(env);
+
+ EnvOptions opt;
+ opt.use_direct_reads = use_direct_reads;
+ Status s = env->NewRandomAccessFile(filepath, file, opt);
+ return s;
+}
+
+//
+// BlockCacheFile
+//
+Status BlockCacheFile::Delete(uint64_t* size) {
+ assert(env_);
+
+ Status status = env_->GetFileSize(Path(), size);
+ if (!status.ok()) {
+ return status;
+ }
+ return env_->DeleteFile(Path());
+}
+
+//
+// CacheRecord
+//
+// Cache record represents the record on disk
+//
+// +--------+---------+----------+------------+---------------+-------------+
+// | magic | crc | key size | value size | key data | value data |
+// +--------+---------+----------+------------+---------------+-------------+
+// <-- 4 --><-- 4 --><-- 4 --><-- 4 --><-- key size --><-- v-size -->
+//
+struct CacheRecordHeader {
+ CacheRecordHeader()
+ : magic_(0), crc_(0), key_size_(0), val_size_(0) {}
+ CacheRecordHeader(const uint32_t magic, const uint32_t key_size,
+ const uint32_t val_size)
+ : magic_(magic), crc_(0), key_size_(key_size), val_size_(val_size) {}
+
+ uint32_t magic_;
+ uint32_t crc_;
+ uint32_t key_size_;
+ uint32_t val_size_;
+};
+
+struct CacheRecord {
+ CacheRecord() {}
+ CacheRecord(const Slice& key, const Slice& val)
+ : hdr_(MAGIC, static_cast<uint32_t>(key.size()),
+ static_cast<uint32_t>(val.size())),
+ key_(key),
+ val_(val) {
+ hdr_.crc_ = ComputeCRC();
+ }
+
+ uint32_t ComputeCRC() const;
+ bool Serialize(std::vector<CacheWriteBuffer*>* bufs, size_t* woff);
+ bool Deserialize(const Slice& buf);
+
+ static uint32_t CalcSize(const Slice& key, const Slice& val) {
+ return static_cast<uint32_t>(sizeof(CacheRecordHeader) + key.size() +
+ val.size());
+ }
+
+ static const uint32_t MAGIC = 0xfefa;
+
+ bool Append(std::vector<CacheWriteBuffer*>* bufs, size_t* woff,
+ const char* data, const size_t size);
+
+ CacheRecordHeader hdr_;
+ Slice key_;
+ Slice val_;
+};
+
+static_assert(sizeof(CacheRecordHeader) == 16, "DataHeader is not aligned");
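+
+// Worked example (illustrative only): for a 3-byte key and a 5-byte value,
+// CacheRecord::CalcSize returns sizeof(CacheRecordHeader) + 3 + 5
+// = 16 + 3 + 5 = 24 bytes, which is the record's exact on-disk footprint in
+// the layout shown above.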
+
+uint32_t CacheRecord::ComputeCRC() const {
+ uint32_t crc = 0;
+ CacheRecordHeader tmp = hdr_;
+ tmp.crc_ = 0;
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(&tmp), sizeof(tmp));
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(key_.data()),
+ key_.size());
+ crc = crc32c::Extend(crc, reinterpret_cast<const char*>(val_.data()),
+ val_.size());
+ return crc;
+}
+
+bool CacheRecord::Serialize(std::vector<CacheWriteBuffer*>* bufs,
+ size_t* woff) {
+ assert(bufs->size());
+ return Append(bufs, woff, reinterpret_cast<const char*>(&hdr_),
+ sizeof(hdr_)) &&
+ Append(bufs, woff, reinterpret_cast<const char*>(key_.data()),
+ key_.size()) &&
+ Append(bufs, woff, reinterpret_cast<const char*>(val_.data()),
+ val_.size());
+}
+
+bool CacheRecord::Append(std::vector<CacheWriteBuffer*>* bufs, size_t* woff,
+ const char* data, const size_t data_size) {
+ assert(*woff < bufs->size());
+
+ const char* p = data;
+ size_t size = data_size;
+
+ while (size && *woff < bufs->size()) {
+ CacheWriteBuffer* buf = (*bufs)[*woff];
+ const size_t free = buf->Free();
+ if (size <= free) {
+ buf->Append(p, size);
+ size = 0;
+ } else {
+ buf->Append(p, free);
+ p += free;
+ size -= free;
+ assert(!buf->Free());
+ assert(buf->Used() == buf->Capacity());
+ }
+
+ if (!buf->Free()) {
+ *woff += 1;
+ }
+ }
+
+ assert(!size);
+
+ return !size;
+}
+
+bool CacheRecord::Deserialize(const Slice& data) {
+ assert(data.size() >= sizeof(CacheRecordHeader));
+ if (data.size() < sizeof(CacheRecordHeader)) {
+ return false;
+ }
+
+ memcpy(&hdr_, data.data(), sizeof(hdr_));
+
+ assert(hdr_.key_size_ + hdr_.val_size_ + sizeof(hdr_) == data.size());
+ if (hdr_.key_size_ + hdr_.val_size_ + sizeof(hdr_) != data.size()) {
+ return false;
+ }
+
+ key_ = Slice(data.data_ + sizeof(hdr_), hdr_.key_size_);
+ val_ = Slice(key_.data_ + hdr_.key_size_, hdr_.val_size_);
+
+ if (!(hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_)) {
+ fprintf(stderr, "** magic %d ** \n", hdr_.magic_);
+ fprintf(stderr, "** key_size %d ** \n", hdr_.key_size_);
+ fprintf(stderr, "** val_size %d ** \n", hdr_.val_size_);
+ fprintf(stderr, "** key %s ** \n", key_.ToString().c_str());
+ fprintf(stderr, "** val %s ** \n", val_.ToString().c_str());
+ for (size_t i = 0; i < hdr_.val_size_; ++i) {
+ fprintf(stderr, "%d.", (uint8_t)val_.data()[i]);
+ }
+ fprintf(stderr, "\n** cksum %d != %d **", hdr_.crc_, ComputeCRC());
+ }
+
+ assert(hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_);
+ return hdr_.magic_ == MAGIC && ComputeCRC() == hdr_.crc_;
+}
+
+//
+// RandomAccessFile
+//
+
+bool RandomAccessCacheFile::Open(const bool enable_direct_reads) {
+ WriteLock _(&rwlock_);
+ return OpenImpl(enable_direct_reads);
+}
+
+bool RandomAccessCacheFile::OpenImpl(const bool enable_direct_reads) {
+ rwlock_.AssertHeld();
+
+ ROCKS_LOG_DEBUG(log_, "Opening cache file %s", Path().c_str());
+
+ std::unique_ptr<RandomAccessFile> file;
+ Status status =
+ NewRandomAccessCacheFile(env_, Path(), &file, enable_direct_reads);
+ if (!status.ok()) {
+ Error(log_, "Error opening random access file %s. %s", Path().c_str(),
+ status.ToString().c_str());
+ return false;
+ }
+ freader_.reset(new RandomAccessFileReader(
+ NewLegacyRandomAccessFileWrapper(file), Path(), env_));
+
+ return true;
+}
+
+bool RandomAccessCacheFile::Read(const LBA& lba, Slice* key, Slice* val,
+ char* scratch) {
+ ReadLock _(&rwlock_);
+
+ assert(lba.cache_id_ == cache_id_);
+
+ if (!freader_) {
+ return false;
+ }
+
+ Slice result;
+ Status s = freader_->Read(lba.off_, lba.size_, &result, scratch);
+ if (!s.ok()) {
+ Error(log_, "Error reading from file %s. %s", Path().c_str(),
+ s.ToString().c_str());
+ return false;
+ }
+
+ assert(result.data() == scratch);
+
+ return ParseRec(lba, key, val, scratch);
+}
+
+bool RandomAccessCacheFile::ParseRec(const LBA& lba, Slice* key, Slice* val,
+ char* scratch) {
+ Slice data(scratch, lba.size_);
+
+ CacheRecord rec;
+ if (!rec.Deserialize(data)) {
+ assert(!"Error deserializing data");
+ Error(log_, "Error de-serializing record from file %s off %d",
+ Path().c_str(), lba.off_);
+ return false;
+ }
+
+ *key = Slice(rec.key_);
+ *val = Slice(rec.val_);
+
+ return true;
+}
+
+//
+// WriteableCacheFile
+//
+
+WriteableCacheFile::~WriteableCacheFile() {
+ WriteLock _(&rwlock_);
+ if (!eof_) {
+    // This file was never flushed. We give priority to shutdown since this is
+    // a cache
+    // TODO(krad): Figure out a way to flush the pending data
+ if (file_) {
+ assert(refs_ == 1);
+ --refs_;
+ }
+ }
+ assert(!refs_);
+ ClearBuffers();
+}
+
+bool WriteableCacheFile::Create(const bool /*enable_direct_writes*/,
+ const bool enable_direct_reads) {
+ WriteLock _(&rwlock_);
+
+ enable_direct_reads_ = enable_direct_reads;
+
+ ROCKS_LOG_DEBUG(log_, "Creating new cache %s (max size is %d B)",
+ Path().c_str(), max_size_);
+
+ assert(env_);
+
+ Status s = env_->FileExists(Path());
+ if (s.ok()) {
+ ROCKS_LOG_WARN(log_, "File %s already exists. %s", Path().c_str(),
+ s.ToString().c_str());
+ }
+
+ s = NewWritableCacheFile(env_, Path(), &file_);
+ if (!s.ok()) {
+ ROCKS_LOG_WARN(log_, "Unable to create file %s. %s", Path().c_str(),
+ s.ToString().c_str());
+ return false;
+ }
+
+ assert(!refs_);
+ ++refs_;
+
+ return true;
+}
+
+bool WriteableCacheFile::Append(const Slice& key, const Slice& val, LBA* lba) {
+ WriteLock _(&rwlock_);
+
+ if (eof_) {
+ // We can't append since the file is full
+ return false;
+ }
+
+ // estimate the space required to store the (key, val)
+ uint32_t rec_size = CacheRecord::CalcSize(key, val);
+
+ if (!ExpandBuffer(rec_size)) {
+ // unable to expand the buffer
+ ROCKS_LOG_DEBUG(log_, "Error expanding buffers. size=%d", rec_size);
+ return false;
+ }
+
+ lba->cache_id_ = cache_id_;
+ lba->off_ = disk_woff_;
+ lba->size_ = rec_size;
+
+ CacheRecord rec(key, val);
+ if (!rec.Serialize(&bufs_, &buf_woff_)) {
+ // unexpected error: unable to serialize the data
+ assert(!"Error serializing record");
+ return false;
+ }
+
+ disk_woff_ += rec_size;
+ eof_ = disk_woff_ >= max_size_;
+
+ // dispatch buffer for flush
+ DispatchBuffer();
+
+ return true;
+}
+
+bool WriteableCacheFile::ExpandBuffer(const size_t size) {
+ rwlock_.AssertHeld();
+ assert(!eof_);
+
+ // determine if there is enough space
+ size_t free = 0; // compute the free space left in buffer
+ for (size_t i = buf_woff_; i < bufs_.size(); ++i) {
+ free += bufs_[i]->Free();
+ if (size <= free) {
+ // we have enough space in the buffer
+ return true;
+ }
+ }
+
+ // expand the buffer until there is enough space to write `size` bytes
+ assert(free < size);
+ assert(alloc_);
+
+ while (free < size) {
+ CacheWriteBuffer* const buf = alloc_->Allocate();
+ if (!buf) {
+ ROCKS_LOG_DEBUG(log_, "Unable to allocate buffers");
+ return false;
+ }
+
+ size_ += static_cast<uint32_t>(buf->Free());
+ free += buf->Free();
+ bufs_.push_back(buf);
+ }
+
+ assert(free >= size);
+ return true;
+}
+
+void WriteableCacheFile::DispatchBuffer() {
+ rwlock_.AssertHeld();
+
+ assert(bufs_.size());
+ assert(buf_doff_ <= buf_woff_);
+ assert(buf_woff_ <= bufs_.size());
+
+ if (pending_ios_) {
+ return;
+ }
+
+ if (!eof_ && buf_doff_ == buf_woff_) {
+ // dispatch buffer is pointing to write buffer and we haven't hit eof
+ return;
+ }
+
+ assert(eof_ || buf_doff_ < buf_woff_);
+ assert(buf_doff_ < bufs_.size());
+ assert(file_);
+ assert(alloc_);
+
+ auto* buf = bufs_[buf_doff_];
+ const uint64_t file_off = buf_doff_ * alloc_->BufferSize();
+
+ assert(!buf->Free() ||
+ (eof_ && buf_doff_ == buf_woff_ && buf_woff_ < bufs_.size()));
+  // We have reached end of file and there is space left in the last buffer;
+  // pad it with zeros for direct IO
+ buf->FillTrailingZeros();
+
+ assert(buf->Used() % kFileAlignmentSize == 0);
+
+ writer_->Write(file_.get(), buf, file_off,
+ std::bind(&WriteableCacheFile::BufferWriteDone, this));
+ pending_ios_++;
+ buf_doff_++;
+}
+
+void WriteableCacheFile::BufferWriteDone() {
+ WriteLock _(&rwlock_);
+
+ assert(bufs_.size());
+
+ pending_ios_--;
+
+ if (buf_doff_ < bufs_.size()) {
+ DispatchBuffer();
+ }
+
+ if (eof_ && buf_doff_ >= bufs_.size() && !pending_ios_) {
+ // end-of-file reached, move to read mode
+ CloseAndOpenForReading();
+ }
+}
+
+void WriteableCacheFile::CloseAndOpenForReading() {
+  // Our env abstraction does not allow reading from a file opened for
+  // appending. We need to close the file and re-open it for reading.
+ Close();
+ RandomAccessCacheFile::OpenImpl(enable_direct_reads_);
+}
+
+bool WriteableCacheFile::ReadBuffer(const LBA& lba, Slice* key, Slice* block,
+ char* scratch) {
+ rwlock_.AssertHeld();
+
+ if (!ReadBuffer(lba, scratch)) {
+ Error(log_, "Error reading from buffer. cache=%d off=%d", cache_id_,
+ lba.off_);
+ return false;
+ }
+
+ return ParseRec(lba, key, block, scratch);
+}
+
+bool WriteableCacheFile::ReadBuffer(const LBA& lba, char* data) {
+ rwlock_.AssertHeld();
+
+ assert(lba.off_ < disk_woff_);
+ assert(alloc_);
+
+  // We read from the buffers as if reading from a flat file. The list of
+  // buffers is treated as a contiguous stream of data
+
+ char* tmp = data;
+ size_t pending_nbytes = lba.size_;
+ // start buffer
+ size_t start_idx = lba.off_ / alloc_->BufferSize();
+ // offset into the start buffer
+ size_t start_off = lba.off_ % alloc_->BufferSize();
+
+ assert(start_idx <= buf_woff_);
+
+ for (size_t i = start_idx; pending_nbytes && i < bufs_.size(); ++i) {
+ assert(i <= buf_woff_);
+ auto* buf = bufs_[i];
+ assert(i == buf_woff_ || !buf->Free());
+ // bytes to write to the buffer
+ size_t nbytes = pending_nbytes > (buf->Used() - start_off)
+ ? (buf->Used() - start_off)
+ : pending_nbytes;
+ memcpy(tmp, buf->Data() + start_off, nbytes);
+
+ // left over to be written
+ pending_nbytes -= nbytes;
+ start_off = 0;
+ tmp += nbytes;
+ }
+
+ assert(!pending_nbytes);
+ if (pending_nbytes) {
+ return false;
+ }
+
+ assert(tmp == data + lba.size_);
+ return true;
+}
+
+void WriteableCacheFile::Close() {
+ rwlock_.AssertHeld();
+
+ assert(size_ >= max_size_);
+ assert(disk_woff_ >= max_size_);
+ assert(buf_doff_ == bufs_.size());
+ assert(bufs_.size() - buf_woff_ <= 1);
+ assert(!pending_ios_);
+
+ Info(log_, "Closing file %s. size=%d written=%d", Path().c_str(), size_,
+ disk_woff_);
+
+ ClearBuffers();
+ file_.reset();
+
+ assert(refs_);
+ --refs_;
+}
+
+void WriteableCacheFile::ClearBuffers() {
+ assert(alloc_);
+
+ for (size_t i = 0; i < bufs_.size(); ++i) {
+ alloc_->Deallocate(bufs_[i]);
+ }
+
+ bufs_.clear();
+}
+
+//
+// ThreadedFileWriter implementation
+//
+ThreadedWriter::ThreadedWriter(PersistentCacheTier* const cache,
+ const size_t qdepth, const size_t io_size)
+ : Writer(cache), io_size_(io_size) {
+ for (size_t i = 0; i < qdepth; ++i) {
+ port::Thread th(&ThreadedWriter::ThreadMain, this);
+ threads_.push_back(std::move(th));
+ }
+}
+
+void ThreadedWriter::Stop() {
+ // notify all threads to exit
+ for (size_t i = 0; i < threads_.size(); ++i) {
+ q_.Push(IO(/*signal=*/true));
+ }
+
+ // wait for all threads to exit
+ for (auto& th : threads_) {
+ th.join();
+ assert(!th.joinable());
+ }
+ threads_.clear();
+}
+
+void ThreadedWriter::Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) {
+ q_.Push(IO(file, buf, file_off, callback));
+}
+
+void ThreadedWriter::ThreadMain() {
+ while (true) {
+ // Fetch the IO to process
+ IO io(q_.Pop());
+ if (io.signal_) {
+      // that's the signal to exit the writer thread
+ break;
+ }
+
+ // Reserve space for writing the buffer
+ while (!cache_->Reserve(io.buf_->Used())) {
+      // We can fail to reserve space if every file in the system
+      // is currently being accessed
+ /* sleep override */
+ Env::Default()->SleepForMicroseconds(1000000);
+ }
+
+ DispatchIO(io);
+
+ io.callback_();
+ }
+}
+
+void ThreadedWriter::DispatchIO(const IO& io) {
+ size_t written = 0;
+ while (written < io.buf_->Used()) {
+ Slice data(io.buf_->Data() + written, io_size_);
+ Status s = io.file_->Append(data);
+ assert(s.ok());
+ if (!s.ok()) {
+      // That is a definite IO error to the device. There is not much we can
+      // do but ignore the failure. This can leave corrupt data on disk, but
+      // the cache will skip the corrupt record while reading
+ fprintf(stderr, "Error writing data to file. %s\n", s.ToString().c_str());
+ }
+ written += io_size_;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h
new file mode 100644
index 000000000..95be4ec3c
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file.h
@@ -0,0 +1,296 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "file/random_access_file_reader.h"
+
+#include "rocksdb/comparator.h"
+#include "rocksdb/env.h"
+
+#include "utilities/persistent_cache/block_cache_tier_file_buffer.h"
+#include "utilities/persistent_cache/lrulist.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+#include "utilities/persistent_cache/persistent_cache_util.h"
+
+#include "port/port.h"
+#include "util/crc32c.h"
+#include "util/mutexlock.h"
+
+// The IO code path of the persistent cache uses a pipelined architecture
+//
+// client -> In Queue <-- BlockCacheTier --> Out Queue <-- Writer <--> Kernel
+//
+// This enables the system to scale to GB/s of throughput, which is expected
+// with modern devices like NVM.
+//
+// The file level operations are encapsulated in the following abstractions
+//
+// BlockCacheFile
+// ^
+// |
+// |
+// RandomAccessCacheFile (For reading)
+// ^
+// |
+// |
+// WriteableCacheFile (For writing)
+//
+// Write IO code path :
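+// (a summary of the implementation in block_cache_tier_file.cc, shown here
+// for orientation; see that file for the authoritative sequence)
+//
+//   WriteableCacheFile::Append -> ExpandBuffer -> CacheRecord::Serialize
+//     -> DispatchBuffer -> ThreadedWriter::Write (queued for a writer thread)
+//     -> ThreadedWriter::DispatchIO -> WriteableCacheFile::BufferWriteDone
+//     -> CloseAndOpenForReading once the file reaches its max size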
+//
+namespace ROCKSDB_NAMESPACE {
+
+class WriteableCacheFile;
+struct BlockInfo;
+
+// Represents a logical record on device
+//
+// (L)ogical (B)lock (A)ddress = { cache-file-id, offset, size }
+struct LogicalBlockAddress {
+ LogicalBlockAddress() {}
+ explicit LogicalBlockAddress(const uint32_t cache_id, const uint32_t off,
+ const uint16_t size)
+ : cache_id_(cache_id), off_(off), size_(size) {}
+
+ uint32_t cache_id_ = 0;
+ uint32_t off_ = 0;
+ uint32_t size_ = 0;
+};
+
+typedef LogicalBlockAddress LBA;
+
+// class Writer
+//
+// Writer is the abstraction used for writing data to a file. The component can
+// be multithreaded. It is the last step of the write pipeline
+class Writer {
+ public:
+ explicit Writer(PersistentCacheTier* const cache) : cache_(cache) {}
+ virtual ~Writer() {}
+
+ // write buffer to file at the given offset
+ virtual void Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) = 0;
+ // stop the writer
+ virtual void Stop() = 0;
+
+ PersistentCacheTier* const cache_;
+};
+
+// class BlockCacheFile
+//
+// Generic interface to support building files specialized for reading/writing
+class BlockCacheFile : public LRUElement<BlockCacheFile> {
+ public:
+ explicit BlockCacheFile(const uint32_t cache_id)
+ : LRUElement<BlockCacheFile>(), cache_id_(cache_id) {}
+
+ explicit BlockCacheFile(Env* const env, const std::string& dir,
+ const uint32_t cache_id)
+ : LRUElement<BlockCacheFile>(),
+ env_(env),
+ dir_(dir),
+ cache_id_(cache_id) {}
+
+ virtual ~BlockCacheFile() {}
+
+ // append key/value to file and return LBA locator to user
+ virtual bool Append(const Slice& /*key*/, const Slice& /*val*/,
+ LBA* const /*lba*/) {
+ assert(!"not implemented");
+ return false;
+ }
+
+ // read from the record locator (LBA) and return key, value and status
+ virtual bool Read(const LBA& /*lba*/, Slice* /*key*/, Slice* /*block*/,
+ char* /*scratch*/) {
+ assert(!"not implemented");
+ return false;
+ }
+
+ // get file path
+ std::string Path() const {
+ return dir_ + "/" + std::to_string(cache_id_) + ".rc";
+ }
+ // get cache ID
+ uint32_t cacheid() const { return cache_id_; }
+ // Add block information to file data
+ // Block information is the list of index reference for this file
+ virtual void Add(BlockInfo* binfo) {
+ WriteLock _(&rwlock_);
+ block_infos_.push_back(binfo);
+ }
+ // get block information
+ std::list<BlockInfo*>& block_infos() { return block_infos_; }
+ // delete file and return the size of the file
+ virtual Status Delete(uint64_t* size);
+
+ protected:
+ port::RWMutex rwlock_; // synchronization mutex
+ Env* const env_ = nullptr; // Env for OS
+ const std::string dir_; // Directory name
+ const uint32_t cache_id_; // Cache id for the file
+ std::list<BlockInfo*> block_infos_; // List of index entries mapping to the
+ // file content
+};
+
+// class RandomAccessFile
+//
+// Thread safe implementation for reading random data from file
+class RandomAccessCacheFile : public BlockCacheFile {
+ public:
+ explicit RandomAccessCacheFile(Env* const env, const std::string& dir,
+ const uint32_t cache_id,
+ const std::shared_ptr<Logger>& log)
+ : BlockCacheFile(env, dir, cache_id), log_(log) {}
+
+ virtual ~RandomAccessCacheFile() {}
+
+ // open file for reading
+ bool Open(const bool enable_direct_reads);
+ // read data from the disk
+ bool Read(const LBA& lba, Slice* key, Slice* block, char* scratch) override;
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> freader_;
+
+ protected:
+ bool OpenImpl(const bool enable_direct_reads);
+ bool ParseRec(const LBA& lba, Slice* key, Slice* val, char* scratch);
+
+ std::shared_ptr<Logger> log_; // log file
+};
+
+// class WriteableCacheFile
+//
+// All writes to the file are cached in buffers. The buffers are flushed to
+// disk as they fill up. When the file reaches its maximum size, a new file
+// will be created, provided there is free space
+class WriteableCacheFile : public RandomAccessCacheFile {
+ public:
+ explicit WriteableCacheFile(Env* const env, CacheWriteBufferAllocator* alloc,
+ Writer* writer, const std::string& dir,
+ const uint32_t cache_id, const uint32_t max_size,
+ const std::shared_ptr<Logger>& log)
+ : RandomAccessCacheFile(env, dir, cache_id, log),
+ alloc_(alloc),
+ writer_(writer),
+ max_size_(max_size) {}
+
+ virtual ~WriteableCacheFile();
+
+ // create file on disk
+ bool Create(const bool enable_direct_writes, const bool enable_direct_reads);
+
+ // read data from logical file
+ bool Read(const LBA& lba, Slice* key, Slice* block, char* scratch) override {
+ ReadLock _(&rwlock_);
+ const bool closed = eof_ && bufs_.empty();
+ if (closed) {
+ // the file is closed, read from disk
+ return RandomAccessCacheFile::Read(lba, key, block, scratch);
+ }
+ // file is still being written, read from buffers
+ return ReadBuffer(lba, key, block, scratch);
+ }
+
+ // append data to end of file
+ bool Append(const Slice&, const Slice&, LBA* const) override;
+ // End-of-file
+ bool Eof() const { return eof_; }
+
+ private:
+ friend class ThreadedWriter;
+
+ static const size_t kFileAlignmentSize = 4 * 1024; // align file size
+
+ bool ReadBuffer(const LBA& lba, Slice* key, Slice* block, char* scratch);
+ bool ReadBuffer(const LBA& lba, char* data);
+ bool ExpandBuffer(const size_t size);
+ void DispatchBuffer();
+ void BufferWriteDone();
+ void CloseAndOpenForReading();
+ void ClearBuffers();
+ void Close();
+
+ // File layout in memory
+ //
+ // +------+------+------+------+------+------+
+ // | b0 | b1 | b2 | b3 | b4 | b5 |
+ // +------+------+------+------+------+------+
+ // ^ ^
+ // | |
+ // buf_doff_ buf_woff_
+ // (next buffer to (next buffer to fill)
+ // flush to disk)
+ //
+ // The buffers are flushed to disk serially for a given file
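+  //
+  // Illustrative snapshot (hypothetical values): if b0..b2 have already been
+  // dispatched to disk and b3 is being filled, then buf_doff_ == 3 and
+  // buf_woff_ == 3; once b3 fills up, buf_woff_ advances to 4 and b3 becomes
+  // eligible for dispatch.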
+
+ CacheWriteBufferAllocator* const alloc_ = nullptr; // Buffer provider
+ Writer* const writer_ = nullptr; // File writer thread
+ std::unique_ptr<WritableFile> file_; // RocksDB Env file abstraction
+ std::vector<CacheWriteBuffer*> bufs_; // Written buffers
+ uint32_t size_ = 0; // Size of the file
+ const uint32_t max_size_; // Max size of the file
+ bool eof_ = false; // End of file
+ uint32_t disk_woff_ = 0; // Offset to write on disk
+ size_t buf_woff_ = 0; // off into bufs_ to write
+ size_t buf_doff_ = 0; // off into bufs_ to dispatch
+ size_t pending_ios_ = 0; // Number of ios to disk in-progress
+ bool enable_direct_reads_ = false; // Should we enable direct reads
+ // when reading from disk
+};
+
+//
+// Abstraction to do writing to device. It is part of pipelined architecture.
+//
+class ThreadedWriter : public Writer {
+ public:
+ // Representation of IO to device
+ struct IO {
+ explicit IO(const bool signal) : signal_(signal) {}
+ explicit IO(WritableFile* const file, CacheWriteBuffer* const buf,
+ const uint64_t file_off, const std::function<void()> callback)
+ : file_(file), buf_(buf), file_off_(file_off), callback_(callback) {}
+
+ IO(const IO&) = default;
+ IO& operator=(const IO&) = default;
+ size_t Size() const { return sizeof(IO); }
+
+ WritableFile* file_ = nullptr; // File to write to
+ CacheWriteBuffer* buf_ = nullptr; // buffer to write
+ uint64_t file_off_ = 0; // file offset
+ bool signal_ = false; // signal to exit thread loop
+ std::function<void()> callback_; // Callback on completion
+ };
+
+ explicit ThreadedWriter(PersistentCacheTier* const cache, const size_t qdepth,
+ const size_t io_size);
+ virtual ~ThreadedWriter() { assert(threads_.empty()); }
+
+ void Stop() override;
+ void Write(WritableFile* const file, CacheWriteBuffer* buf,
+ const uint64_t file_off,
+ const std::function<void()> callback) override;
+
+ private:
+ void ThreadMain();
+ void DispatchIO(const IO& io);
+
+ const size_t io_size_ = 0;
+ BoundedQueue<IO> q_;
+ std::vector<port::Thread> threads_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h
new file mode 100644
index 000000000..23013d720
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_file_buffer.h
@@ -0,0 +1,127 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <list>
+#include <memory>
+#include <string>
+
+#include "rocksdb/comparator.h"
+#include "memory/arena.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// CacheWriteBuffer
+//
+// Buffer abstraction that can be manipulated via append
+// (not thread safe)
+class CacheWriteBuffer {
+ public:
+ explicit CacheWriteBuffer(const size_t size) : size_(size), pos_(0) {
+ buf_.reset(new char[size_]);
+ assert(!pos_);
+ assert(size_);
+ }
+
+ virtual ~CacheWriteBuffer() {}
+
+ void Append(const char* buf, const size_t size) {
+ assert(pos_ + size <= size_);
+ memcpy(buf_.get() + pos_, buf, size);
+ pos_ += size;
+ assert(pos_ <= size_);
+ }
+
+  // Pad the unused tail of the buffer (with the character '0'); the padding
+  // only fills the buffer up to its full, aligned capacity.
+  void FillTrailingZeros() {
+    assert(pos_ <= size_);
+    memset(buf_.get() + pos_, '0', size_ - pos_);
+    pos_ = size_;
+  }
+
+ void Reset() { pos_ = 0; }
+ size_t Free() const { return size_ - pos_; }
+ size_t Capacity() const { return size_; }
+ size_t Used() const { return pos_; }
+ char* Data() const { return buf_.get(); }
+
+ private:
+ std::unique_ptr<char[]> buf_;
+ const size_t size_;
+ size_t pos_;
+};
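+
+// Illustrative usage sketch (not part of the original header): appending to a
+// buffer and padding it to its full capacity before handing it to a writer.
+//
+//   CacheWriteBuffer wbuf(/*size=*/4 * 1024);
+//   const char payload[] = "hello";
+//   if (wbuf.Free() >= sizeof(payload)) {
+//     wbuf.Append(payload, sizeof(payload));
+//   }
+//   wbuf.FillTrailingZeros();   // Used() == Capacity() afterwards
+//   // wbuf.Data()/wbuf.Used() can now be written out; Reset() allows reuse.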
+
+//
+// CacheWriteBufferAllocator
+//
+// Buffer pool abstraction (allocation and deallocation are synchronized with a mutex)
+//
+class CacheWriteBufferAllocator {
+ public:
+ explicit CacheWriteBufferAllocator(const size_t buffer_size,
+ const size_t buffer_count)
+ : cond_empty_(&lock_), buffer_size_(buffer_size) {
+ MutexLock _(&lock_);
+ for (uint32_t i = 0; i < buffer_count; i++) {
+ auto* buf = new CacheWriteBuffer(buffer_size_);
+ assert(buf);
+ if (buf) {
+ bufs_.push_back(buf);
+ cond_empty_.Signal();
+ }
+ }
+ }
+
+ virtual ~CacheWriteBufferAllocator() {
+ MutexLock _(&lock_);
+ assert(bufs_.size() * buffer_size_ == Capacity());
+ for (auto* buf : bufs_) {
+ delete buf;
+ }
+ bufs_.clear();
+ }
+
+ CacheWriteBuffer* Allocate() {
+ MutexLock _(&lock_);
+ if (bufs_.empty()) {
+ return nullptr;
+ }
+
+ assert(!bufs_.empty());
+ CacheWriteBuffer* const buf = bufs_.front();
+ bufs_.pop_front();
+ return buf;
+ }
+
+ void Deallocate(CacheWriteBuffer* const buf) {
+ assert(buf);
+ MutexLock _(&lock_);
+ buf->Reset();
+ bufs_.push_back(buf);
+ cond_empty_.Signal();
+ }
+
+ void WaitUntilUsable() {
+ // We are asked to wait till we have buffers available
+ MutexLock _(&lock_);
+ while (bufs_.empty()) {
+ cond_empty_.Wait();
+ }
+ }
+
+ size_t Capacity() const { return bufs_.size() * buffer_size_; }
+ size_t Free() const { return bufs_.size() * buffer_size_; }
+ size_t BufferSize() const { return buffer_size_; }
+
+ private:
+ port::Mutex lock_; // Sync lock
+ port::CondVar cond_empty_; // Condition var for empty buffers
+ size_t buffer_size_; // Size of each buffer
+ std::list<CacheWriteBuffer*> bufs_; // Buffer stash
+};
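+
+// Illustrative usage sketch (not part of the original header): a fixed pool of
+// buffers shared between the insert path and the write path.
+//
+//   CacheWriteBufferAllocator alloc(/*buffer_size=*/4 * 1024, /*buffer_count=*/8);
+//   CacheWriteBuffer* buf = alloc.Allocate();
+//   if (!buf) {
+//     alloc.WaitUntilUsable();   // block until a buffer is returned to the pool
+//     buf = alloc.Allocate();
+//   }
+//   buf->Append("key/value bytes", 15);
+//   // ... hand the buffer to the writer; once the IO completes:
+//   alloc.Deallocate(buf);       // resets the buffer and signals waiters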
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc
new file mode 100644
index 000000000..c99322e6b
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/block_cache_tier_metadata.h"
+
+#include <functional>
+
+namespace ROCKSDB_NAMESPACE {
+
+bool BlockCacheTierMetadata::Insert(BlockCacheFile* file) {
+ return cache_file_index_.Insert(file);
+}
+
+BlockCacheFile* BlockCacheTierMetadata::Lookup(const uint32_t cache_id) {
+ BlockCacheFile* ret = nullptr;
+ BlockCacheFile lookup_key(cache_id);
+ bool ok = cache_file_index_.Find(&lookup_key, &ret);
+ if (ok) {
+ assert(ret->refs_);
+ return ret;
+ }
+ return nullptr;
+}
+
+BlockCacheFile* BlockCacheTierMetadata::Evict() {
+ using std::placeholders::_1;
+ auto fn = std::bind(&BlockCacheTierMetadata::RemoveAllKeys, this, _1);
+ return cache_file_index_.Evict(fn);
+}
+
+void BlockCacheTierMetadata::Clear() {
+ cache_file_index_.Clear([](BlockCacheFile* arg){ delete arg; });
+ block_index_.Clear([](BlockInfo* arg){ delete arg; });
+}
+
+BlockInfo* BlockCacheTierMetadata::Insert(const Slice& key, const LBA& lba) {
+ std::unique_ptr<BlockInfo> binfo(new BlockInfo(key, lba));
+ if (!block_index_.Insert(binfo.get())) {
+ return nullptr;
+ }
+ return binfo.release();
+}
+
+bool BlockCacheTierMetadata::Lookup(const Slice& key, LBA* lba) {
+ BlockInfo lookup_key(key);
+ BlockInfo* block;
+ port::RWMutex* rlock = nullptr;
+ if (!block_index_.Find(&lookup_key, &block, &rlock)) {
+ return false;
+ }
+
+ ReadUnlock _(rlock);
+ assert(block->key_ == key.ToString());
+ if (lba) {
+ *lba = block->lba_;
+ }
+ return true;
+}
+
+BlockInfo* BlockCacheTierMetadata::Remove(const Slice& key) {
+ BlockInfo lookup_key(key);
+ BlockInfo* binfo = nullptr;
+ bool ok __attribute__((__unused__));
+ ok = block_index_.Erase(&lookup_key, &binfo);
+ assert(ok);
+ return binfo;
+}
+
+void BlockCacheTierMetadata::RemoveAllKeys(BlockCacheFile* f) {
+ for (BlockInfo* binfo : f->block_infos()) {
+ BlockInfo* tmp = nullptr;
+ bool status = block_index_.Erase(binfo, &tmp);
+ (void)status;
+ assert(status);
+ assert(tmp == binfo);
+ delete binfo;
+ }
+ f->block_infos().clear();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h
new file mode 100644
index 000000000..92adae2bf
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/block_cache_tier_metadata.h
@@ -0,0 +1,125 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "rocksdb/slice.h"
+
+#include "utilities/persistent_cache/block_cache_tier_file.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+#include "utilities/persistent_cache/lrulist.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Block Cache Tier Metadata
+//
+// The BlockCacheTierMetadata holds all the metadata associated with the block
+// cache. It fundamentally contains two indexes and an LRU.
+//
+// Block Cache Index
+//
+// This is a forward index that maps a given key to a LBA (Logical Block
+// Address). LBA is a disk pointer that points to a record on the cache.
+//
+// LBA = { cache-id, offset, size }
+//
+// Cache File Index
+//
+// This is a forward index that maps a given cache-id to a cache file object.
+// Typically you would look up a file using the cache-id from an LBA and use
+// the object to read or write.
+struct BlockInfo {
+ explicit BlockInfo(const Slice& key, const LBA& lba = LBA())
+ : key_(key.ToString()), lba_(lba) {}
+
+ std::string key_;
+ LBA lba_;
+};
+
+class BlockCacheTierMetadata {
+ public:
+ explicit BlockCacheTierMetadata(const uint32_t blocks_capacity = 1024 * 1024,
+ const uint32_t cachefile_capacity = 10 * 1024)
+ : cache_file_index_(cachefile_capacity), block_index_(blocks_capacity) {}
+
+ virtual ~BlockCacheTierMetadata() {}
+
+ // Insert a given cache file
+ bool Insert(BlockCacheFile* file);
+
+ // Lookup cache file based on cache_id
+ BlockCacheFile* Lookup(const uint32_t cache_id);
+
+ // Insert block information to block index
+ BlockInfo* Insert(const Slice& key, const LBA& lba);
+ // bool Insert(BlockInfo* binfo);
+
+ // Lookup block information from block index
+ bool Lookup(const Slice& key, LBA* lba);
+
+  // Remove a given key from the block index
+ BlockInfo* Remove(const Slice& key);
+
+ // Find and evict a cache file using LRU policy
+ BlockCacheFile* Evict();
+
+ // Clear the metadata contents
+ virtual void Clear();
+
+ protected:
+ // Remove all block information from a given file
+ virtual void RemoveAllKeys(BlockCacheFile* file);
+
+ private:
+ // Cache file index definition
+ //
+ // cache-id => BlockCacheFile
+ struct BlockCacheFileHash {
+ uint64_t operator()(const BlockCacheFile* rec) {
+ return std::hash<uint32_t>()(rec->cacheid());
+ }
+ };
+
+ struct BlockCacheFileEqual {
+ uint64_t operator()(const BlockCacheFile* lhs, const BlockCacheFile* rhs) {
+ return lhs->cacheid() == rhs->cacheid();
+ }
+ };
+
+ typedef EvictableHashTable<BlockCacheFile, BlockCacheFileHash,
+ BlockCacheFileEqual>
+ CacheFileIndexType;
+
+ // Block Lookup Index
+ //
+ // key => LBA
+ struct Hash {
+ size_t operator()(BlockInfo* node) const {
+ return std::hash<std::string>()(node->key_);
+ }
+ };
+
+ struct Equal {
+ size_t operator()(BlockInfo* lhs, BlockInfo* rhs) const {
+ return lhs->key_ == rhs->key_;
+ }
+ };
+
+ typedef HashTable<BlockInfo*, Hash, Equal> BlockIndexType;
+
+ CacheFileIndexType cache_file_index_;
+ BlockIndexType block_index_;
+};
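+
+// Illustrative usage sketch (not part of the original header): the typical
+// insert/lookup flow the block cache tier would drive. The LBA value is
+// assumed to have been filled in by the file write path.
+//
+//   BlockCacheTierMetadata metadata;
+//   LBA lba;                                       // {cache-id, offset, size}
+//   BlockInfo* info = metadata.Insert(Slice("block-key"), lba);
+//   assert(info != nullptr);                       // nullptr means duplicate key
+//   // `info` is now owned by the block index; do not delete it directly.
+//
+//   LBA found;
+//   if (metadata.Lookup(Slice("block-key"), &found)) {
+//     // use the cache-id in `found` with Lookup(cache_id) to get the file
+//   }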
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table.h b/src/rocksdb/utilities/persistent_cache/hash_table.h
new file mode 100644
index 000000000..3d0a1f993
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table.h
@@ -0,0 +1,238 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <assert.h>
+#include <list>
+#include <vector>
+
+#ifdef OS_LINUX
+#include <sys/mman.h>
+#endif
+
+#include "rocksdb/env.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// HashTable<T, Hash, Equal>
+//
+// Traditional implementations of hash tables with synchronization built on top
+// don't perform very well in multi-core scenarios. This is an implementation
+// designed for multi-core scenarios with high lock contention.
+//
+// |<-------- alpha ------------->|
+// Buckets Collision list
+// ---- +----+ +---+---+--- ...... ---+---+---+
+// / | |--->| | | | | |
+// / +----+ +---+---+--- ...... ---+---+---+
+// / | |
+// Locks/ +----+
+// +--+/ . .
+// | | . .
+// +--+ . .
+// | | . .
+// +--+ . .
+// | | . .
+// +--+ . .
+// \ +----+
+// \ | |
+// \ +----+
+// \ | |
+// \---- +----+
+//
+// The lock contention is spread over an array of locks. This helps improve
+// concurrent access. The spine is designed for a certain capacity and load
+// factor. When the capacity planning is done correctly (load factor close to
+// 1) we can expect O(1) insert, access and remove time.
+//
+// A micro benchmark on a debug build gives about 0.5 million/sec each for
+// insert, erase and lookup in parallel (about 1.5 million ops/sec in total).
+// If the blocks were 4 KB each, the hash table could support a virtual
+// throughput of 6 GB/s.
+//
+// T      Object type (contains both key and value)
+// Hash   Functor that returns a hash for an object of type T
+// Equal  Functor that returns whether two objects are equal
+//        (we need an explicit Equal for pointer types)
+//
+template <class T, class Hash, class Equal>
+class HashTable {
+ public:
+ explicit HashTable(const size_t capacity = 1024 * 1024,
+ const float load_factor = 2.0, const uint32_t nlocks = 256)
+ : nbuckets_(
+ static_cast<uint32_t>(load_factor ? capacity / load_factor : 0)),
+ nlocks_(nlocks) {
+ // pre-conditions
+ assert(capacity);
+ assert(load_factor);
+ assert(nbuckets_);
+ assert(nlocks_);
+
+ buckets_.reset(new Bucket[nbuckets_]);
+#ifdef OS_LINUX
+ mlock(buckets_.get(), nbuckets_ * sizeof(Bucket));
+#endif
+
+ // initialize locks
+ locks_.reset(new port::RWMutex[nlocks_]);
+#ifdef OS_LINUX
+ mlock(locks_.get(), nlocks_ * sizeof(port::RWMutex));
+#endif
+
+ // post-conditions
+ assert(buckets_);
+ assert(locks_);
+ }
+
+ virtual ~HashTable() { AssertEmptyBuckets(); }
+
+ //
+ // Insert given record to hash table
+ //
+ bool Insert(const T& t) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ WriteLock _(&locks_[lock_idx]);
+ auto& bucket = buckets_[bucket_idx];
+ return Insert(&bucket, t);
+ }
+
+ // Lookup hash table
+ //
+  // Note that on a successful lookup the stripe's read lock is handed back to
+  // the caller via ret_lock. The caller owns the data and must hold the read
+  // lock for as long as it operates on the data, then release it.
+ bool Find(const T& t, T* ret, port::RWMutex** ret_lock) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ port::RWMutex& lock = locks_[lock_idx];
+ lock.ReadLock();
+
+ auto& bucket = buckets_[bucket_idx];
+ if (Find(&bucket, t, ret)) {
+ *ret_lock = &lock;
+ return true;
+ }
+
+ lock.ReadUnlock();
+ return false;
+ }
+
+ //
+ // Erase a given key from the hash table
+ //
+ bool Erase(const T& t, T* ret) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ WriteLock _(&locks_[lock_idx]);
+
+ auto& bucket = buckets_[bucket_idx];
+ return Erase(&bucket, t, ret);
+ }
+
+ // Fetch the mutex associated with a key
+ // This call is used to hold the lock for a given data for extended period of
+ // time.
+ port::RWMutex* GetMutex(const T& t) {
+ const uint64_t h = Hash()(t);
+ const uint32_t bucket_idx = h % nbuckets_;
+ const uint32_t lock_idx = bucket_idx % nlocks_;
+
+ return &locks_[lock_idx];
+ }
+
+ void Clear(void (*fn)(T)) {
+ for (uint32_t i = 0; i < nbuckets_; ++i) {
+ const uint32_t lock_idx = i % nlocks_;
+ WriteLock _(&locks_[lock_idx]);
+ for (auto& t : buckets_[i].list_) {
+ (*fn)(t);
+ }
+ buckets_[i].list_.clear();
+ }
+ }
+
+ protected:
+ // Models bucket of keys that hash to the same bucket number
+ struct Bucket {
+ std::list<T> list_;
+ };
+
+ // Substitute for std::find with custom comparator operator
+ typename std::list<T>::iterator Find(std::list<T>* list, const T& t) {
+ for (auto it = list->begin(); it != list->end(); ++it) {
+ if (Equal()(*it, t)) {
+ return it;
+ }
+ }
+ return list->end();
+ }
+
+ bool Insert(Bucket* bucket, const T& t) {
+ // Check if the key already exists
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ return false;
+ }
+
+ // insert to bucket
+ bucket->list_.push_back(t);
+ return true;
+ }
+
+ bool Find(Bucket* bucket, const T& t, T* ret) {
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ if (ret) {
+ *ret = *it;
+ }
+ return true;
+ }
+ return false;
+ }
+
+ bool Erase(Bucket* bucket, const T& t, T* ret) {
+ auto it = Find(&bucket->list_, t);
+ if (it != bucket->list_.end()) {
+ if (ret) {
+ *ret = *it;
+ }
+
+ bucket->list_.erase(it);
+ return true;
+ }
+ return false;
+ }
+
+ // assert that all buckets are empty
+ void AssertEmptyBuckets() {
+#ifndef NDEBUG
+ for (size_t i = 0; i < nbuckets_; ++i) {
+ WriteLock _(&locks_[i % nlocks_]);
+ assert(buckets_[i].list_.empty());
+ }
+#endif
+ }
+
+ const uint32_t nbuckets_; // No. of buckets in the spine
+ std::unique_ptr<Bucket[]> buckets_; // Spine of the hash buckets
+ const uint32_t nlocks_; // No. of locks
+ std::unique_ptr<port::RWMutex[]> locks_; // Granular locks
+};
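+
+// Illustrative usage sketch (not part of the original header), assuming a
+// small record type with the required Hash and Equal functors:
+//
+//   struct Rec { uint64_t key; std::string val; };
+//   struct RecHash { uint64_t operator()(const Rec& r) { return r.key; } };
+//   struct RecEqual {
+//     bool operator()(const Rec& a, const Rec& b) { return a.key == b.key; }
+//   };
+//
+//   HashTable<Rec, RecHash, RecEqual> table;
+//   table.Insert({1, "one"});
+//
+//   Rec found;
+//   port::RWMutex* rlock = nullptr;
+//   if (table.Find({1, ""}, &found, &rlock)) {
+//     // ... use `found` while the stripe read lock is held ...
+//     rlock->ReadUnlock();   // caller must release the handed-out lock
+//   }
+//
+//   table.Erase({1, ""}, /*ret=*/nullptr);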
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc b/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc
new file mode 100644
index 000000000..a1f05007e
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_bench.cc
@@ -0,0 +1,308 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#if !defined(OS_WIN) && !defined(ROCKSDB_LITE)
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() { fprintf(stderr, "Please install gflags to run tools\n"); }
+#else
+
+#include <atomic>
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <unistd.h>
+#include <sys/time.h>
+
+#include "port/port_posix.h"
+#include "rocksdb/env.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table.h"
+
+using std::string;
+
+DEFINE_int32(nsec, 10, "nsec");
+DEFINE_int32(nthread_write, 1, "Number of insert threads");
+DEFINE_int32(nthread_read, 1, "Number of lookup threads");
+DEFINE_int32(nthread_erase, 1, "Number of erase threads");
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// HashTableImpl interface
+//
+// Abstraction of a hash table implementation
+template <class Key, class Value>
+class HashTableImpl {
+ public:
+ virtual ~HashTableImpl() {}
+
+ virtual bool Insert(const Key& key, const Value& val) = 0;
+ virtual bool Erase(const Key& key) = 0;
+ virtual bool Lookup(const Key& key, Value* val) = 0;
+};
+
+// HashTableBenchmark
+//
+// Abstraction to test a given hash table implementation. The test mostly
+// focuses on insert, lookup and erase. The test can operate in test mode and
+// in benchmark mode.
+class HashTableBenchmark {
+ public:
+ explicit HashTableBenchmark(HashTableImpl<size_t, std::string>* impl,
+ const size_t sec = 10,
+ const size_t nthread_write = 1,
+ const size_t nthread_read = 1,
+ const size_t nthread_erase = 1)
+ : impl_(impl),
+ sec_(sec),
+ ninserts_(0),
+ nreads_(0),
+ nerases_(0),
+ nerases_failed_(0),
+ quit_(false) {
+ Prepop();
+
+ StartThreads(nthread_write, WriteMain);
+ StartThreads(nthread_read, ReadMain);
+ StartThreads(nthread_erase, EraseMain);
+
+ uint64_t start = NowInMillSec();
+ while (!quit_) {
+ quit_ = NowInMillSec() - start > sec_ * 1000;
+ /* sleep override */ sleep(1);
+ }
+
+ Env* env = Env::Default();
+ env->WaitForJoin();
+
+ if (sec_) {
+ printf("Result \n");
+ printf("====== \n");
+ printf("insert/sec = %f \n", ninserts_ / static_cast<double>(sec_));
+ printf("read/sec = %f \n", nreads_ / static_cast<double>(sec_));
+ printf("erases/sec = %f \n", nerases_ / static_cast<double>(sec_));
+ const uint64_t ops = ninserts_ + nreads_ + nerases_;
+ printf("ops/sec = %f \n", ops / static_cast<double>(sec_));
+      printf("erase fail = %d (%f%%)\n", static_cast<int>(nerases_failed_),
+             static_cast<float>(nerases_failed_) / nerases_ * 100);
+ printf("====== \n");
+ }
+ }
+
+ void RunWrite() {
+ while (!quit_) {
+ size_t k = insert_key_++;
+ std::string tmp(1000, k % 255);
+ bool status = impl_->Insert(k, tmp);
+ assert(status);
+ ninserts_++;
+ }
+ }
+
+ void RunRead() {
+ Random64 rgen(time(nullptr));
+ while (!quit_) {
+ std::string s;
+ size_t k = rgen.Next() % max_prepop_key;
+ bool status = impl_->Lookup(k, &s);
+ assert(status);
+ assert(s == std::string(1000, k % 255));
+ nreads_++;
+ }
+ }
+
+ void RunErase() {
+ while (!quit_) {
+ size_t k = erase_key_++;
+ bool status = impl_->Erase(k);
+ nerases_failed_ += !status;
+ nerases_++;
+ }
+ }
+
+ private:
+ // Start threads for a given function
+ void StartThreads(const size_t n, void (*fn)(void*)) {
+ Env* env = Env::Default();
+ for (size_t i = 0; i < n; ++i) {
+ env->StartThread(fn, this);
+ }
+ }
+
+  // Pre-populate the hash table with 1M verifiable keys plus 10M filler keys
+ void Prepop() {
+ for (size_t i = 0; i < max_prepop_key; ++i) {
+ bool status = impl_->Insert(i, std::string(1000, i % 255));
+ assert(status);
+ }
+
+ erase_key_ = insert_key_ = max_prepop_key;
+
+ for (size_t i = 0; i < 10 * max_prepop_key; ++i) {
+ bool status = impl_->Insert(insert_key_++, std::string(1000, 'x'));
+ assert(status);
+ }
+ }
+
+ static uint64_t NowInMillSec() {
+ timeval tv;
+ gettimeofday(&tv, /*tz=*/nullptr);
+ return tv.tv_sec * 1000 + tv.tv_usec / 1000;
+ }
+
+ //
+ // Wrapper functions for thread entry
+ //
+ static void WriteMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunWrite();
+ }
+
+ static void ReadMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunRead();
+ }
+
+ static void EraseMain(void* args) {
+ reinterpret_cast<HashTableBenchmark*>(args)->RunErase();
+ }
+
+ HashTableImpl<size_t, std::string>* impl_; // Implementation to test
+ const size_t sec_; // Test time
+ const size_t max_prepop_key = 1ULL * 1024 * 1024; // Max prepop key
+ std::atomic<size_t> insert_key_; // Last inserted key
+ std::atomic<size_t> erase_key_; // Erase key
+ std::atomic<size_t> ninserts_; // Number of inserts
+ std::atomic<size_t> nreads_; // Number of reads
+ std::atomic<size_t> nerases_; // Number of erases
+ std::atomic<size_t> nerases_failed_; // Number of erases failed
+ bool quit_; // Should the threads quit ?
+};
+
+//
+// SimpleImpl
+// Thread-safe (single-lock) std::unordered_map implementation
+class SimpleImpl : public HashTableImpl<size_t, string> {
+ public:
+ bool Insert(const size_t& key, const string& val) override {
+ WriteLock _(&rwlock_);
+ map_.insert(make_pair(key, val));
+ return true;
+ }
+
+ bool Erase(const size_t& key) override {
+ WriteLock _(&rwlock_);
+ auto it = map_.find(key);
+ if (it == map_.end()) {
+ return false;
+ }
+ map_.erase(it);
+ return true;
+ }
+
+ bool Lookup(const size_t& key, string* val) override {
+ ReadLock _(&rwlock_);
+ auto it = map_.find(key);
+ if (it != map_.end()) {
+ *val = it->second;
+ }
+ return it != map_.end();
+ }
+
+ private:
+ port::RWMutex rwlock_;
+ std::unordered_map<size_t, string> map_;
+};
+
+//
+// GranularLockImpl
+// Thread safe custom RocksDB implementation of hash table with granular
+// locking
+class GranularLockImpl : public HashTableImpl<size_t, string> {
+ public:
+ bool Insert(const size_t& key, const string& val) override {
+ Node n(key, val);
+ return impl_.Insert(n);
+ }
+
+ bool Erase(const size_t& key) override {
+ Node n(key, string());
+ return impl_.Erase(n, nullptr);
+ }
+
+ bool Lookup(const size_t& key, string* val) override {
+ Node n(key, string());
+ port::RWMutex* rlock;
+ bool status = impl_.Find(n, &n, &rlock);
+ if (status) {
+ ReadUnlock _(rlock);
+ *val = n.val_;
+ }
+ return status;
+ }
+
+ private:
+ struct Node {
+ explicit Node(const size_t key, const string& val) : key_(key), val_(val) {}
+
+ size_t key_ = 0;
+ string val_;
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node& node) {
+ return std::hash<uint64_t>()(node.key_);
+ }
+ };
+
+ struct Equal {
+ bool operator()(const Node& lhs, const Node& rhs) {
+ return lhs.key_ == rhs.key_;
+ }
+ };
+
+ HashTable<Node, Hash, Equal> impl_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+//
+// main
+//
+int main(int argc, char** argv) {
+ GFLAGS_NAMESPACE::SetUsageMessage(std::string("\nUSAGE:\n") +
+ std::string(argv[0]) + " [OPTIONS]...");
+ GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, false);
+
+ //
+ // Micro benchmark unordered_map
+ //
+ printf("Micro benchmarking std::unordered_map \n");
+ {
+ ROCKSDB_NAMESPACE::SimpleImpl impl;
+ ROCKSDB_NAMESPACE::HashTableBenchmark _(
+ &impl, FLAGS_nsec, FLAGS_nthread_write, FLAGS_nthread_read,
+ FLAGS_nthread_erase);
+ }
+ //
+ // Micro benchmark scalable hash table
+ //
+ printf("Micro benchmarking scalable hash map \n");
+ {
+ ROCKSDB_NAMESPACE::GranularLockImpl impl;
+ ROCKSDB_NAMESPACE::HashTableBenchmark _(
+ &impl, FLAGS_nsec, FLAGS_nthread_write, FLAGS_nthread_read,
+ FLAGS_nthread_erase);
+ }
+
+ return 0;
+}
+#endif // #ifndef GFLAGS
+#else
+int main(int /*argc*/, char** /*argv*/) { return 0; }
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h b/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h
new file mode 100644
index 000000000..d27205d08
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_evictable.h
@@ -0,0 +1,168 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/lrulist.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Evictable Hash Table
+//
+// Hash table index where the least accessed (or one of the least accessed)
+// elements can be evicted.
+//
+// Please note EvictableHashTable can only be created for pointer type objects
+template <class T, class Hash, class Equal>
+class EvictableHashTable : private HashTable<T*, Hash, Equal> {
+ public:
+ typedef HashTable<T*, Hash, Equal> hash_table;
+
+ explicit EvictableHashTable(const size_t capacity = 1024 * 1024,
+ const float load_factor = 2.0,
+ const uint32_t nlocks = 256)
+ : HashTable<T*, Hash, Equal>(capacity, load_factor, nlocks),
+ lru_lists_(new LRUList<T>[hash_table::nlocks_]) {
+ assert(lru_lists_);
+ }
+
+ virtual ~EvictableHashTable() { AssertEmptyLRU(); }
+
+ //
+ // Insert given record to hash table (and LRU list)
+ //
+ bool Insert(T* t) {
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ LRUListType& lru = GetLRUList(h);
+ port::RWMutex& lock = GetMutex(h);
+
+ WriteLock _(&lock);
+ if (hash_table::Insert(&bucket, t)) {
+ lru.Push(t);
+ return true;
+ }
+ return false;
+ }
+
+ //
+ // Lookup hash table
+ //
+  // Note that a successful lookup increments the object's reference count
+  // (refs_) under the stripe's read lock. The caller owns that reference and
+  // is expected to hold it for as long as it operates on the data.
+ bool Find(T* t, T** ret) {
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ LRUListType& lru = GetLRUList(h);
+ port::RWMutex& lock = GetMutex(h);
+
+ ReadLock _(&lock);
+ if (hash_table::Find(&bucket, t, ret)) {
+ ++(*ret)->refs_;
+ lru.Touch(*ret);
+ return true;
+ }
+ return false;
+ }
+
+ //
+  // Evict one of the least recently used objects
+ //
+ T* Evict(const std::function<void(T*)>& fn = nullptr) {
+ uint32_t random = Random::GetTLSInstance()->Next();
+ const size_t start_idx = random % hash_table::nlocks_;
+ T* t = nullptr;
+
+    // iterate over all lock stripes, starting at start_idx and wrapping around
+ for (size_t i = 0; !t && i < hash_table::nlocks_; ++i) {
+ const size_t idx = (start_idx + i) % hash_table::nlocks_;
+
+ WriteLock _(&hash_table::locks_[idx]);
+ LRUListType& lru = lru_lists_[idx];
+ if (!lru.IsEmpty() && (t = lru.Pop()) != nullptr) {
+ assert(!t->refs_);
+ // We got an item to evict, erase from the bucket
+ const uint64_t h = Hash()(t);
+ typename hash_table::Bucket& bucket = GetBucket(h);
+ T* tmp = nullptr;
+ bool status = hash_table::Erase(&bucket, t, &tmp);
+ assert(t == tmp);
+ (void)status;
+ assert(status);
+ if (fn) {
+ fn(t);
+ }
+ break;
+ }
+ assert(!t);
+ }
+ return t;
+ }
+
+ void Clear(void (*fn)(T*)) {
+ for (uint32_t i = 0; i < hash_table::nbuckets_; ++i) {
+ const uint32_t lock_idx = i % hash_table::nlocks_;
+ WriteLock _(&hash_table::locks_[lock_idx]);
+ auto& lru_list = lru_lists_[lock_idx];
+ auto& bucket = hash_table::buckets_[i];
+ for (auto* t : bucket.list_) {
+ lru_list.Unlink(t);
+ (*fn)(t);
+ }
+ bucket.list_.clear();
+ }
+ // make sure that all LRU lists are emptied
+ AssertEmptyLRU();
+ }
+
+ void AssertEmptyLRU() {
+#ifndef NDEBUG
+ for (uint32_t i = 0; i < hash_table::nlocks_; ++i) {
+ WriteLock _(&hash_table::locks_[i]);
+ auto& lru_list = lru_lists_[i];
+ assert(lru_list.IsEmpty());
+ }
+#endif
+ }
+
+ //
+ // Fetch the mutex associated with a key
+ // This call is used to hold the lock for a given data for extended period of
+ // time.
+ port::RWMutex* GetMutex(T* t) { return hash_table::GetMutex(t); }
+
+ private:
+ typedef LRUList<T> LRUListType;
+
+ typename hash_table::Bucket& GetBucket(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ return hash_table::buckets_[bucket_idx];
+ }
+
+ LRUListType& GetLRUList(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ const uint32_t lock_idx = bucket_idx % hash_table::nlocks_;
+ return lru_lists_[lock_idx];
+ }
+
+ port::RWMutex& GetMutex(const uint64_t h) {
+ const uint32_t bucket_idx = h % hash_table::nbuckets_;
+ const uint32_t lock_idx = bucket_idx % hash_table::nlocks_;
+ return hash_table::locks_[lock_idx];
+ }
+
+ std::unique_ptr<LRUListType[]> lru_lists_;
+};
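+
+// Illustrative usage sketch (not part of the original header). T must derive
+// from LRUElement<T>, which provides the refs_ count used by Find/Evict:
+//
+//   struct Node : LRUElement<Node> {
+//     explicit Node(uint64_t k) : key(k) {}
+//     uint64_t key;
+//   };
+//   struct NodeHash { uint64_t operator()(const Node* n) { return n->key; } };
+//   struct NodeEqual {
+//     bool operator()(const Node* a, const Node* b) { return a->key == b->key; }
+//   };
+//
+//   EvictableHashTable<Node, NodeHash, NodeEqual> index;
+//   index.Insert(new Node(42));
+//   Node* evicted = index.Evict();   // pops an unreferenced node from an LRU
+//   delete evicted;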
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/hash_table_test.cc b/src/rocksdb/utilities/persistent_cache/hash_table_test.cc
new file mode 100644
index 000000000..62a5b1d40
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/hash_table_test.cc
@@ -0,0 +1,160 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include <stdlib.h>
+#include <iostream>
+#include <set>
+#include <string>
+
+#include "db/db_test_util.h"
+#include "memory/arena.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+
+#ifndef ROCKSDB_LITE
+
+namespace ROCKSDB_NAMESPACE {
+
+struct HashTableTest : public testing::Test {
+ ~HashTableTest() override { map_.Clear(&HashTableTest::ClearNode); }
+
+ struct Node {
+ Node() {}
+ explicit Node(const uint64_t key, const std::string& val = std::string())
+ : key_(key), val_(val) {}
+
+ uint64_t key_ = 0;
+ std::string val_;
+ };
+
+ struct Equal {
+ bool operator()(const Node& lhs, const Node& rhs) {
+ return lhs.key_ == rhs.key_;
+ }
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node& node) {
+ return std::hash<uint64_t>()(node.key_);
+ }
+ };
+
+ static void ClearNode(Node /*node*/) {}
+
+ HashTable<Node, Hash, Equal> map_;
+};
+
+struct EvictableHashTableTest : public testing::Test {
+ ~EvictableHashTableTest() override {
+ map_.Clear(&EvictableHashTableTest::ClearNode);
+ }
+
+ struct Node : LRUElement<Node> {
+ Node() {}
+ explicit Node(const uint64_t key, const std::string& val = std::string())
+ : key_(key), val_(val) {}
+
+ uint64_t key_ = 0;
+ std::string val_;
+ std::atomic<uint32_t> refs_{0};
+ };
+
+ struct Equal {
+ bool operator()(const Node* lhs, const Node* rhs) {
+ return lhs->key_ == rhs->key_;
+ }
+ };
+
+ struct Hash {
+ uint64_t operator()(const Node* node) {
+ return std::hash<uint64_t>()(node->key_);
+ }
+ };
+
+ static void ClearNode(Node* /*node*/) {}
+
+ EvictableHashTable<Node, Hash, Equal> map_;
+};
+
+TEST_F(HashTableTest, TestInsert) {
+ const uint64_t max_keys = 1024 * 1024;
+
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(Node(k, std::string(1000, k % 255)));
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node val;
+ port::RWMutex* rlock = nullptr;
+ assert(map_.Find(Node(k), &val, &rlock));
+ rlock->ReadUnlock();
+ assert(val.val_ == std::string(1000, k % 255));
+ }
+}
+
+TEST_F(HashTableTest, TestErase) {
+ const uint64_t max_keys = 1024 * 1024;
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(Node(k, std::string(1000, k % 255)));
+ }
+
+ auto rand = Random64(time(nullptr));
+ // erase a few keys randomly
+ std::set<uint64_t> erased;
+ for (int i = 0; i < 1024; ++i) {
+ uint64_t k = rand.Next() % max_keys;
+ if (erased.find(k) != erased.end()) {
+ continue;
+ }
+ assert(map_.Erase(Node(k), /*ret=*/nullptr));
+ erased.insert(k);
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node val;
+ port::RWMutex* rlock = nullptr;
+ bool status = map_.Find(Node(k), &val, &rlock);
+ if (erased.find(k) == erased.end()) {
+ assert(status);
+ rlock->ReadUnlock();
+ assert(val.val_ == std::string(1000, k % 255));
+ } else {
+ assert(!status);
+ }
+ }
+}
+
+TEST_F(EvictableHashTableTest, TestEvict) {
+ const uint64_t max_keys = 1024 * 1024;
+
+ // insert
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ map_.Insert(new Node(k, std::string(1000, k % 255)));
+ }
+
+ // verify
+ for (uint64_t k = 0; k < max_keys; ++k) {
+ Node* val = map_.Evict();
+    // unfortunately we can't predict the evicted value since it may come from
+    // any one of the lock stripes
+ assert(val);
+ assert(val->val_ == std::string(1000, val->key_ % 255));
+ delete val;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/persistent_cache/lrulist.h b/src/rocksdb/utilities/persistent_cache/lrulist.h
new file mode 100644
index 000000000..a608890fc
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/lrulist.h
@@ -0,0 +1,174 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// LRU element definition
+//
+// Any object that needs to be part of the LRU algorithm should extend this
+// class
+template <class T>
+struct LRUElement {
+ explicit LRUElement() : next_(nullptr), prev_(nullptr), refs_(0) {}
+
+ virtual ~LRUElement() { assert(!refs_); }
+
+ T* next_;
+ T* prev_;
+ std::atomic<size_t> refs_;
+};
+
+// LRU implementation
+//
+// In-place LRU implementation. There is no copy or allocation involved when
+// inserting or removing an element. This makes the data structure slim.
+template <class T>
+class LRUList {
+ public:
+ virtual ~LRUList() {
+ MutexLock _(&lock_);
+ assert(!head_);
+ assert(!tail_);
+ }
+
+ // Push element into the LRU at the cold end
+ inline void Push(T* const t) {
+ assert(t);
+ assert(!t->next_);
+ assert(!t->prev_);
+
+ MutexLock _(&lock_);
+
+ assert((!head_ && !tail_) || (head_ && tail_));
+ assert(!head_ || !head_->prev_);
+ assert(!tail_ || !tail_->next_);
+
+ t->next_ = head_;
+ if (head_) {
+ head_->prev_ = t;
+ }
+
+ head_ = t;
+ if (!tail_) {
+ tail_ = t;
+ }
+ }
+
+ // Unlink the element from the LRU
+ inline void Unlink(T* const t) {
+ MutexLock _(&lock_);
+ UnlinkImpl(t);
+ }
+
+ // Evict an element from the LRU
+ inline T* Pop() {
+ MutexLock _(&lock_);
+
+ assert(tail_ && head_);
+ assert(!tail_->next_);
+ assert(!head_->prev_);
+
+ T* t = head_;
+ while (t && t->refs_) {
+ t = t->next_;
+ }
+
+ if (!t) {
+ // nothing can be evicted
+ return nullptr;
+ }
+
+ assert(!t->refs_);
+
+    // unlink the element
+ UnlinkImpl(t);
+ return t;
+ }
+
+  // Move the element to the hot end (back) of the list
+ inline void Touch(T* const t) {
+ MutexLock _(&lock_);
+ UnlinkImpl(t);
+ PushBackImpl(t);
+ }
+
+ // Check if the LRU is empty
+ inline bool IsEmpty() const {
+ MutexLock _(&lock_);
+ return !head_ && !tail_;
+ }
+
+ private:
+ // Unlink an element from the LRU
+ void UnlinkImpl(T* const t) {
+ assert(t);
+
+ lock_.AssertHeld();
+
+ assert(head_ && tail_);
+ assert(t->prev_ || head_ == t);
+ assert(t->next_ || tail_ == t);
+
+ if (t->prev_) {
+ t->prev_->next_ = t->next_;
+ }
+ if (t->next_) {
+ t->next_->prev_ = t->prev_;
+ }
+
+ if (tail_ == t) {
+ tail_ = tail_->prev_;
+ }
+ if (head_ == t) {
+ head_ = head_->next_;
+ }
+
+ t->next_ = t->prev_ = nullptr;
+ }
+
+ // Insert an element at the hot end
+ inline void PushBack(T* const t) {
+ MutexLock _(&lock_);
+ PushBackImpl(t);
+ }
+
+ inline void PushBackImpl(T* const t) {
+ assert(t);
+ assert(!t->next_);
+ assert(!t->prev_);
+
+ lock_.AssertHeld();
+
+ assert((!head_ && !tail_) || (head_ && tail_));
+ assert(!head_ || !head_->prev_);
+ assert(!tail_ || !tail_->next_);
+
+ t->prev_ = tail_;
+ if (tail_) {
+ tail_->next_ = t;
+ }
+
+ tail_ = t;
+ if (!head_) {
+ head_ = tail_;
+ }
+ }
+
+ mutable port::Mutex lock_; // synchronization primitive
+ T* head_ = nullptr; // front (cold)
+ T* tail_ = nullptr; // back (hot)
+};
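+
+// Illustrative usage sketch (not part of the original header): elements enter
+// at the cold end, move towards the hot end when touched, and Pop() evicts the
+// coldest unreferenced element.
+//
+//   struct Item : LRUElement<Item> { int id = 0; };
+//
+//   LRUList<Item> lru;
+//   Item a, b;
+//   lru.Push(&a);               // cold end: a
+//   lru.Push(&b);               // cold end: b, hot end: a
+//   lru.Touch(&b);              // b moves to the hot end
+//   Item* victim = lru.Pop();   // evicts a (now the coldest, refs_ == 0)
+//   lru.Unlink(&b);             // remove the remaining element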
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc
new file mode 100644
index 000000000..359fcdd1d
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_bench.cc
@@ -0,0 +1,360 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() { fprintf(stderr, "Please install gflags to run tools\n"); }
+#else
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <sstream>
+#include <unordered_map>
+
+#include "rocksdb/env.h"
+
+#include "utilities/persistent_cache/block_cache_tier.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "table/block_based/block_builder.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/stop_watch.h"
+
+DEFINE_int32(nsec, 10, "nsec");
+DEFINE_int32(nthread_write, 1, "Insert threads");
+DEFINE_int32(nthread_read, 1, "Lookup threads");
+DEFINE_string(path, "/tmp/microbench/blkcache", "Path for cachefile");
+DEFINE_string(log_path, "/tmp/log", "Path for the log file");
+DEFINE_uint64(cache_size, std::numeric_limits<uint64_t>::max(), "Cache size");
+DEFINE_int32(iosize, 4 * 1024, "Read IO size");
+DEFINE_int32(writer_iosize, 4 * 1024, "File writer IO size");
+DEFINE_int32(writer_qdepth, 1, "File writer qdepth");
+DEFINE_bool(enable_pipelined_writes, false, "Enable async writes");
+DEFINE_string(cache_type, "block_cache",
+ "Cache type. (block_cache, volatile, tiered)");
+DEFINE_bool(benchmark, false, "Benchmark mode");
+DEFINE_int32(volatile_cache_pct, 10, "Percentage of cache in memory tier.");
+
+namespace ROCKSDB_NAMESPACE {
+
+std::unique_ptr<PersistentCacheTier> NewVolatileCache() {
+ assert(FLAGS_cache_size != std::numeric_limits<uint64_t>::max());
+ std::unique_ptr<PersistentCacheTier> pcache(
+ new VolatileCacheTier(FLAGS_cache_size));
+ return pcache;
+}
+
+std::unique_ptr<PersistentCacheTier> NewBlockCache() {
+ std::shared_ptr<Logger> log;
+ if (!Env::Default()->NewLogger(FLAGS_log_path, &log).ok()) {
+ fprintf(stderr, "Error creating log %s \n", FLAGS_log_path.c_str());
+ return nullptr;
+ }
+
+ PersistentCacheConfig opt(Env::Default(), FLAGS_path, FLAGS_cache_size, log);
+ opt.writer_dispatch_size = FLAGS_writer_iosize;
+ opt.writer_qdepth = FLAGS_writer_qdepth;
+ opt.pipeline_writes = FLAGS_enable_pipelined_writes;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ std::unique_ptr<PersistentCacheTier> cache(new BlockCacheTier(opt));
+ Status status = cache->Open();
+ return cache;
+}
+
+// create a new cache tier
+// construct a tiered RAM+Block cache
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ const size_t mem_size, const PersistentCacheConfig& opt) {
+ std::unique_ptr<PersistentTieredCache> tcache(new PersistentTieredCache());
+ // create primary tier
+ assert(mem_size);
+ auto pcache =
+ std::shared_ptr<PersistentCacheTier>(new VolatileCacheTier(mem_size));
+ tcache->AddTier(pcache);
+ // create secondary tier
+ auto scache = std::shared_ptr<PersistentCacheTier>(new BlockCacheTier(opt));
+ tcache->AddTier(scache);
+
+ Status s = tcache->Open();
+ assert(s.ok());
+ return tcache;
+}
+
+std::unique_ptr<PersistentTieredCache> NewTieredCache() {
+ std::shared_ptr<Logger> log;
+ if (!Env::Default()->NewLogger(FLAGS_log_path, &log).ok()) {
+ fprintf(stderr, "Error creating log %s \n", FLAGS_log_path.c_str());
+ abort();
+ }
+
+ auto pct = FLAGS_volatile_cache_pct / static_cast<double>(100);
+ PersistentCacheConfig opt(Env::Default(), FLAGS_path,
+ (1 - pct) * FLAGS_cache_size, log);
+ opt.writer_dispatch_size = FLAGS_writer_iosize;
+ opt.writer_qdepth = FLAGS_writer_qdepth;
+ opt.pipeline_writes = FLAGS_enable_pipelined_writes;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ return NewTieredCache(FLAGS_cache_size * pct, opt);
+}
+
+//
+// Benchmark driver
+//
+class CacheTierBenchmark {
+ public:
+ explicit CacheTierBenchmark(std::shared_ptr<PersistentCacheTier>&& cache)
+ : cache_(cache) {
+ if (FLAGS_nthread_read) {
+ fprintf(stdout, "Pre-populating\n");
+ Prepop();
+ fprintf(stdout, "Pre-population completed\n");
+ }
+
+ stats_.Clear();
+
+ // Start IO threads
+ std::list<port::Thread> threads;
+ Spawn(FLAGS_nthread_write, &threads,
+ std::bind(&CacheTierBenchmark::Write, this));
+ Spawn(FLAGS_nthread_read, &threads,
+ std::bind(&CacheTierBenchmark::Read, this));
+
+ // Wait till FLAGS_nsec and then signal to quit
+ StopWatchNano t(Env::Default(), /*auto_start=*/true);
+ size_t sec = t.ElapsedNanos() / 1000000000ULL;
+ while (!quit_) {
+ sec = t.ElapsedNanos() / 1000000000ULL;
+ quit_ = sec > size_t(FLAGS_nsec);
+ /* sleep override */ sleep(1);
+ }
+
+ // Wait for threads to exit
+ Join(&threads);
+ // Print stats
+ PrintStats(sec);
+ // Close the cache
+ cache_->TEST_Flush();
+ cache_->Close();
+ }
+
+ private:
+ void PrintStats(const size_t sec) {
+ std::ostringstream msg;
+ msg << "Test stats" << std::endl
+ << "* Elapsed: " << sec << " s" << std::endl
+ << "* Write Latency:" << std::endl
+ << stats_.write_latency_.ToString() << std::endl
+ << "* Read Latency:" << std::endl
+ << stats_.read_latency_.ToString() << std::endl
+ << "* Bytes written:" << std::endl
+ << stats_.bytes_written_.ToString() << std::endl
+ << "* Bytes read:" << std::endl
+ << stats_.bytes_read_.ToString() << std::endl
+ << "Cache stats:" << std::endl
+ << cache_->PrintStats() << std::endl;
+ fprintf(stderr, "%s\n", msg.str().c_str());
+ }
+
+ //
+ // Insert implementation and corresponding helper functions
+ //
+ void Prepop() {
+ for (uint64_t i = 0; i < 1024 * 1024; ++i) {
+ InsertKey(i);
+ insert_key_limit_++;
+ read_key_limit_++;
+ }
+
+ // Wait until data is flushed
+ cache_->TEST_Flush();
+ // warmup the cache
+ for (uint64_t i = 0; i < 1024 * 1024; ReadKey(i++)) {
+ }
+ }
+
+ void Write() {
+ while (!quit_) {
+ InsertKey(insert_key_limit_++);
+ }
+ }
+
+ void InsertKey(const uint64_t key) {
+ // construct key
+ uint64_t k[3];
+ Slice block_key = FillKey(k, key);
+
+ // construct value
+ auto block = NewBlock(key);
+
+ // insert
+ StopWatchNano timer(Env::Default(), /*auto_start=*/true);
+ while (true) {
+ Status status = cache_->Insert(block_key, block.get(), FLAGS_iosize);
+ if (status.ok()) {
+ break;
+ }
+
+ // transient error is possible if we run without pipelining
+ assert(!FLAGS_enable_pipelined_writes);
+ }
+
+ // adjust stats
+ const size_t elapsed_micro = timer.ElapsedNanos() / 1000;
+ stats_.write_latency_.Add(elapsed_micro);
+ stats_.bytes_written_.Add(FLAGS_iosize);
+ }
+
+ //
+ // Read implementation
+ //
+ void Read() {
+ while (!quit_) {
+ ReadKey(random() % read_key_limit_);
+ }
+ }
+
+ void ReadKey(const uint64_t val) {
+ // construct key
+ uint64_t k[3];
+ Slice key = FillKey(k, val);
+
+ // Lookup in cache
+ StopWatchNano timer(Env::Default(), /*auto_start=*/true);
+ std::unique_ptr<char[]> block;
+ size_t size;
+ Status status = cache_->Lookup(key, &block, &size);
+ if (!status.ok()) {
+ fprintf(stderr, "%s\n", status.ToString().c_str());
+ }
+ assert(status.ok());
+ assert(size == (size_t) FLAGS_iosize);
+
+ // adjust stats
+ const size_t elapsed_micro = timer.ElapsedNanos() / 1000;
+ stats_.read_latency_.Add(elapsed_micro);
+ stats_.bytes_read_.Add(FLAGS_iosize);
+
+ // verify content
+ if (!FLAGS_benchmark) {
+ auto expected_block = NewBlock(val);
+ assert(memcmp(block.get(), expected_block.get(), FLAGS_iosize) == 0);
+ }
+ }
+
+ // create data for a key by filling with a certain pattern
+ std::unique_ptr<char[]> NewBlock(const uint64_t val) {
+ std::unique_ptr<char[]> data(new char[FLAGS_iosize]);
+ memset(data.get(), val % 255, FLAGS_iosize);
+ return data;
+ }
+
+ // spawn threads
+ void Spawn(const size_t n, std::list<port::Thread>* threads,
+ const std::function<void()>& fn) {
+ for (size_t i = 0; i < n; ++i) {
+ threads->emplace_back(fn);
+ }
+ }
+
+ // join threads
+ void Join(std::list<port::Thread>* threads) {
+ for (auto& th : *threads) {
+ th.join();
+ }
+ }
+
+ // construct key
+ Slice FillKey(uint64_t (&k)[3], const uint64_t val) {
+ k[0] = k[1] = 0;
+ k[2] = val;
+ void* p = static_cast<void*>(&k);
+ return Slice(static_cast<char*>(p), sizeof(k));
+ }
+
+ // benchmark stats
+ struct Stats {
+ void Clear() {
+ bytes_written_.Clear();
+ bytes_read_.Clear();
+ read_latency_.Clear();
+ write_latency_.Clear();
+ }
+
+ HistogramImpl bytes_written_;
+ HistogramImpl bytes_read_;
+ HistogramImpl read_latency_;
+ HistogramImpl write_latency_;
+ };
+
+ std::shared_ptr<PersistentCacheTier> cache_; // cache implementation
+  std::atomic<uint64_t> insert_key_limit_{0};  // data inserted up to this key
+  std::atomic<uint64_t> read_key_limit_{0};    // data readable up to this key
+ bool quit_ = false; // Quit thread ?
+ mutable Stats stats_; // Stats
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+//
+// main
+//
+int main(int argc, char** argv) {
+ GFLAGS_NAMESPACE::SetUsageMessage(std::string("\nUSAGE:\n") +
+ std::string(argv[0]) + " [OPTIONS]...");
+ GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, false);
+
+ std::ostringstream msg;
+ msg << "Config" << std::endl
+ << "======" << std::endl
+ << "* nsec=" << FLAGS_nsec << std::endl
+ << "* nthread_write=" << FLAGS_nthread_write << std::endl
+ << "* path=" << FLAGS_path << std::endl
+ << "* cache_size=" << FLAGS_cache_size << std::endl
+ << "* iosize=" << FLAGS_iosize << std::endl
+ << "* writer_iosize=" << FLAGS_writer_iosize << std::endl
+ << "* writer_qdepth=" << FLAGS_writer_qdepth << std::endl
+ << "* enable_pipelined_writes=" << FLAGS_enable_pipelined_writes
+ << std::endl
+ << "* cache_type=" << FLAGS_cache_type << std::endl
+ << "* benchmark=" << FLAGS_benchmark << std::endl
+ << "* volatile_cache_pct=" << FLAGS_volatile_cache_pct << std::endl;
+
+ fprintf(stderr, "%s\n", msg.str().c_str());
+
+ std::shared_ptr<ROCKSDB_NAMESPACE::PersistentCacheTier> cache;
+ if (FLAGS_cache_type == "block_cache") {
+ fprintf(stderr, "Using block cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewBlockCache();
+ } else if (FLAGS_cache_type == "volatile") {
+ fprintf(stderr, "Using volatile cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewVolatileCache();
+ } else if (FLAGS_cache_type == "tiered") {
+ fprintf(stderr, "Using tiered cache implementation\n");
+ cache = ROCKSDB_NAMESPACE::NewTieredCache();
+ } else {
+ fprintf(stderr, "Unknown option for cache\n");
+ }
+
+ assert(cache);
+ if (!cache) {
+ fprintf(stderr, "Error creating cache\n");
+ abort();
+ }
+
+ std::unique_ptr<ROCKSDB_NAMESPACE::CacheTierBenchmark> benchmark(
+ new ROCKSDB_NAMESPACE::CacheTierBenchmark(std::move(cache)));
+
+ return 0;
+}
+#endif // #ifndef GFLAGS
+#else
+int main(int, char**) { return 0; }
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc
new file mode 100644
index 000000000..dce6e08e0
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.cc
@@ -0,0 +1,474 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+// GetUniqueIdFromFile is not implemented on Windows. Persistent cache
+// breaks when that function is not implemented
+#if !defined(ROCKSDB_LITE) && !defined(OS_WIN)
+
+#include "utilities/persistent_cache/persistent_cache_test.h"
+
+#include <functional>
+#include <memory>
+#include <thread>
+
+#include "utilities/persistent_cache/block_cache_tier.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+static const double kStressFactor = .125;
+
+#ifdef OS_LINUX
+static void OnOpenForRead(void* arg) {
+ int* val = static_cast<int*>(arg);
+ *val &= ~O_DIRECT;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT",
+ std::bind(OnOpenForRead, std::placeholders::_1));
+}
+
+static void OnOpenForWrite(void* arg) {
+ int* val = static_cast<int*>(arg);
+ *val &= ~O_DIRECT;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewWritableFile:O_DIRECT",
+ std::bind(OnOpenForWrite, std::placeholders::_1));
+}
+#endif
+
+static void RemoveDirectory(const std::string& folder) {
+ std::vector<std::string> files;
+ Status status = Env::Default()->GetChildren(folder, &files);
+ if (!status.ok()) {
+ // we assume the directory does not exist
+ return;
+ }
+
+  // clean up files matching the pattern <digits>.rc
+ for (auto file : files) {
+ if (file == "." || file == "..") {
+ continue;
+ }
+ status = Env::Default()->DeleteFile(folder + "/" + file);
+ assert(status.ok());
+ }
+
+ status = Env::Default()->DeleteDir(folder);
+ assert(status.ok());
+}
+
+static void OnDeleteDir(void* arg) {
+ char* dir = static_cast<char*>(arg);
+ RemoveDirectory(std::string(dir));
+}
+
+//
+// Simple logger that prints message on stdout
+//
+class ConsoleLogger : public Logger {
+ public:
+ using Logger::Logv;
+ ConsoleLogger() : Logger(InfoLogLevel::ERROR_LEVEL) {}
+
+ void Logv(const char* format, va_list ap) override {
+ MutexLock _(&lock_);
+ vprintf(format, ap);
+ printf("\n");
+ }
+
+ port::Mutex lock_;
+};
+
+// construct a tiered RAM+Block cache
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ const size_t mem_size, const PersistentCacheConfig& opt) {
+ std::unique_ptr<PersistentTieredCache> tcache(new PersistentTieredCache());
+ // create primary tier
+ assert(mem_size);
+ auto pcache = std::shared_ptr<PersistentCacheTier>(new VolatileCacheTier(
+ /*is_compressed*/ true, mem_size));
+ tcache->AddTier(pcache);
+ // create secondary tier
+ auto scache = std::shared_ptr<PersistentCacheTier>(new BlockCacheTier(opt));
+ tcache->AddTier(scache);
+
+ Status s = tcache->Open();
+ assert(s.ok());
+ return tcache;
+}
+
+// create block cache
+std::unique_ptr<PersistentCacheTier> NewBlockCache(
+ Env* env, const std::string& path,
+ const uint64_t max_size = std::numeric_limits<uint64_t>::max(),
+ const bool enable_direct_writes = false) {
+ const uint32_t max_file_size = static_cast<uint32_t>(12 * 1024 * 1024 * kStressFactor);
+ auto log = std::make_shared<ConsoleLogger>();
+ PersistentCacheConfig opt(env, path, max_size, log);
+ opt.cache_file_size = max_file_size;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ opt.enable_direct_writes = enable_direct_writes;
+ std::unique_ptr<PersistentCacheTier> scache(new BlockCacheTier(opt));
+ Status s = scache->Open();
+ assert(s.ok());
+ return scache;
+}
+
+// create a new cache tier
+std::unique_ptr<PersistentTieredCache> NewTieredCache(
+ Env* env, const std::string& path, const uint64_t max_volatile_cache_size,
+ const uint64_t max_block_cache_size =
+ std::numeric_limits<uint64_t>::max()) {
+ const uint32_t max_file_size = static_cast<uint32_t>(12 * 1024 * 1024 * kStressFactor);
+ auto log = std::make_shared<ConsoleLogger>();
+ auto opt = PersistentCacheConfig(env, path, max_block_cache_size, log);
+ opt.cache_file_size = max_file_size;
+ opt.max_write_pipeline_backlog_size = std::numeric_limits<uint64_t>::max();
+ // create tier out of the two caches
+ auto cache = NewTieredCache(max_volatile_cache_size, opt);
+ return cache;
+}
+
+PersistentCacheTierTest::PersistentCacheTierTest()
+ : path_(test::PerThreadDBPath("cache_test")) {
+#ifdef OS_LINUX
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT", OnOpenForRead);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewWritableFile:O_DIRECT", OnOpenForWrite);
+#endif
+}
+
+// Block cache tests
+TEST_F(PersistentCacheTierTest, DISABLED_BlockCacheInsertWithFileCreateError) {
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ /*direct_writes=*/ false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "BlockCacheTier::NewCacheFile:DeleteDir", OnDeleteDir);
+
+ RunNegativeInsertTest(/*nthreads=*/ 1,
+ /*max_keys*/
+ static_cast<size_t>(10 * 1024 * kStressFactor));
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+#if defined(TRAVIS) || defined(ROCKSDB_VALGRIND_RUN)
+// Travis is unable to handle the normal version of the tests: it runs out of
+// fds and disk space, and hits timeouts. This is an easier version of the test
+// written specifically for Travis.
+TEST_F(PersistentCacheTierTest, BasicTest) {
+ cache_ = std::make_shared<VolatileCacheTier>();
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ /*direct_writes=*/true);
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+
+ cache_ = NewTieredCache(Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024));
+ RunInsertTest(/*nthreads=*/1, /*max_keys=*/1024);
+}
+#else
+// Volatile cache tests
+TEST_F(PersistentCacheTierTest, VolatileCacheInsert) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = std::make_shared<VolatileCacheTier>();
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+TEST_F(PersistentCacheTierTest, VolatileCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = std::make_shared<VolatileCacheTier>(
+ /*compressed=*/true, /*size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// Block cache tests
+TEST_F(PersistentCacheTierTest, BlockCacheInsert) {
+ for (auto direct_writes : {true, false}) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*size=*/std::numeric_limits<uint64_t>::max(),
+ direct_writes);
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+ }
+}
+
+TEST_F(PersistentCacheTierTest, BlockCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewBlockCache(Env::Default(), path_,
+ /*max_size=*/static_cast<size_t>(200 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// Tiered cache tests
+TEST_F(PersistentCacheTierTest, TieredCacheInsert) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys :
+ {10 * 1024 * kStressFactor, 1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewTieredCache(Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor));
+ RunInsertTest(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+
+// The test causes a lot of file deletions, which Travis's limited testing
+// environment cannot handle.
+TEST_F(PersistentCacheTierTest, TieredCacheInsertWithEviction) {
+ for (auto nthreads : {1, 5}) {
+ for (auto max_keys : {1 * 1024 * 1024 * kStressFactor}) {
+ cache_ = NewTieredCache(
+ Env::Default(), path_,
+ /*memory_size=*/static_cast<size_t>(1 * 1024 * 1024 * kStressFactor),
+ /*block_cache_size*/ static_cast<size_t>(200 * 1024 * 1024 * kStressFactor));
+ RunInsertTestWithEviction(nthreads, static_cast<size_t>(max_keys));
+ }
+ }
+}
+#endif
+
+std::shared_ptr<PersistentCacheTier> MakeVolatileCache(
+ const std::string& /*dbname*/) {
+ return std::make_shared<VolatileCacheTier>();
+}
+
+std::shared_ptr<PersistentCacheTier> MakeBlockCache(const std::string& dbname) {
+ return NewBlockCache(Env::Default(), dbname);
+}
+
+std::shared_ptr<PersistentCacheTier> MakeTieredCache(
+ const std::string& dbname) {
+ const auto memory_size = 1 * 1024 * 1024 * kStressFactor;
+ return NewTieredCache(Env::Default(), dbname, static_cast<size_t>(memory_size));
+}
+
+#ifdef OS_LINUX
+static void UniqueIdCallback(void* arg) {
+ int* result = reinterpret_cast<int*>(arg);
+ if (*result == -1) {
+ *result = 0;
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearTrace();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GetUniqueIdFromFile:FS_IOC_GETVERSION", UniqueIdCallback);
+}
+#endif
+
+TEST_F(PersistentCacheTierTest, FactoryTest) {
+ for (auto nvm_opt : {true, false}) {
+ ASSERT_FALSE(cache_);
+ auto log = std::make_shared<ConsoleLogger>();
+ std::shared_ptr<PersistentCache> cache;
+ ASSERT_OK(NewPersistentCache(Env::Default(), path_,
+ /*size=*/1 * 1024 * 1024 * 1024, log, nvm_opt,
+ &cache));
+ ASSERT_TRUE(cache);
+ ASSERT_EQ(cache->Stats().size(), 1);
+ ASSERT_TRUE(cache->Stats()[0].size());
+ cache.reset();
+ }
+}
+
+PersistentCacheDBTest::PersistentCacheDBTest() : DBTestBase("/cache_test") {
+#ifdef OS_LINUX
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "GetUniqueIdFromFile:FS_IOC_GETVERSION", UniqueIdCallback);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "NewRandomAccessFile:O_DIRECT", OnOpenForRead);
+#endif
+}
+
+// test template
+void PersistentCacheDBTest::RunTest(
+ const std::function<std::shared_ptr<PersistentCacheTier>(bool)>& new_pcache,
+ const size_t max_keys = 100 * 1024, const size_t max_usecase = 5) {
+ if (!Snappy_Supported()) {
+ return;
+ }
+
+  // number of insertion iterations
+ int num_iter = static_cast<int>(max_keys * kStressFactor);
+
+ for (size_t iter = 0; iter < max_usecase; iter++) {
+ Options options;
+ options.write_buffer_size =
+ static_cast<size_t>(64 * 1024 * kStressFactor); // small write buffer
+ options.statistics = ROCKSDB_NAMESPACE::CreateDBStatistics();
+ options = CurrentOptions(options);
+
+ // setup page cache
+ std::shared_ptr<PersistentCacheTier> pcache;
+ BlockBasedTableOptions table_options;
+ table_options.cache_index_and_filter_blocks = true;
+
+ const size_t size_max = std::numeric_limits<size_t>::max();
+
+ switch (iter) {
+ case 0:
+ // page cache, block cache, no-compressed cache
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 1:
+ // page cache, block cache, compressed cache
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = NewLRUCache(size_max);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 2:
+        // page cache, block cache, compressed cache + kNoCompression
+ // both block cache and compressed cache, but DB is not compressed
+ // also, make block cache sizes bigger, to trigger block cache hits
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = NewLRUCache(size_max);
+ table_options.block_cache_compressed = NewLRUCache(size_max);
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ options.compression = kNoCompression;
+ break;
+ case 3:
+ // page cache, no block cache, no compressed cache
+ pcache = new_pcache(/*is_compressed=*/false);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = nullptr;
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ case 4:
+ // page cache, no block cache, no compressed cache
+ // Page cache caches compressed blocks
+ pcache = new_pcache(/*is_compressed=*/true);
+ table_options.persistent_cache = pcache;
+ table_options.block_cache = nullptr;
+ table_options.block_cache_compressed = nullptr;
+ options.table_factory.reset(NewBlockBasedTableFactory(table_options));
+ break;
+ default:
+ FAIL();
+ }
+
+ std::vector<std::string> values;
+ // insert data
+ Insert(options, table_options, num_iter, &values);
+ // flush all data in cache to device
+ pcache->TEST_Flush();
+ // verify data
+ Verify(num_iter, values);
+
+ auto block_miss = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ auto compressed_block_hit =
+ TestGetTickerCount(options, BLOCK_CACHE_COMPRESSED_HIT);
+ auto compressed_block_miss =
+ TestGetTickerCount(options, BLOCK_CACHE_COMPRESSED_MISS);
+ auto page_hit = TestGetTickerCount(options, PERSISTENT_CACHE_HIT);
+ auto page_miss = TestGetTickerCount(options, PERSISTENT_CACHE_MISS);
+
+ // check that we triggered the appropriate code paths in the cache
+ switch (iter) {
+ case 0:
+ // page cache, block cache, no-compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_EQ(compressed_block_miss, 0);
+ ASSERT_EQ(compressed_block_hit, 0);
+ break;
+ case 1:
+ // page cache, block cache, compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_GT(compressed_block_miss, 0);
+ break;
+ case 2:
+        // page cache, block cache, compressed cache + kNoCompression
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_GT(block_miss, 0);
+ ASSERT_GT(compressed_block_miss, 0);
+ // remember kNoCompression
+ ASSERT_EQ(compressed_block_hit, 0);
+ break;
+ case 3:
+ case 4:
+ // page cache, no block cache, no compressed cache
+ ASSERT_GT(page_miss, 0);
+ ASSERT_GT(page_hit, 0);
+ ASSERT_EQ(compressed_block_hit, 0);
+ ASSERT_EQ(compressed_block_miss, 0);
+ break;
+ default:
+ FAIL();
+ }
+
+ options.create_if_missing = true;
+ DestroyAndReopen(options);
+
+ pcache->Close();
+ }
+}
+
+#if defined(TRAVIS) || defined(ROCKSDB_VALGRIND_RUN)
+// Travis is unable to handle the normal version of the tests: it runs out of
+// fds and disk space, and times out. This is a lighter version of the test
+// written specifically for Travis
+TEST_F(PersistentCacheDBTest, BasicTest) {
+ RunTest(std::bind(&MakeBlockCache, dbname_), /*max_keys=*/1024,
+ /*max_usecase=*/1);
+}
+#else
+// test table with block page cache
+TEST_F(PersistentCacheDBTest, BlockCacheTest) {
+ RunTest(std::bind(&MakeBlockCache, dbname_));
+}
+
+// test table with volatile page cache
+TEST_F(PersistentCacheDBTest, VolatileCacheTest) {
+ RunTest(std::bind(&MakeVolatileCache, dbname_));
+}
+
+// test table with tiered page cache
+TEST_F(PersistentCacheDBTest, TieredCacheTest) {
+ RunTest(std::bind(&MakeTieredCache, dbname_));
+}
+#endif
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+#else // !defined(ROCKSDB_LITE) && !defined(OS_WIN)
+int main() { return 0; }
+#endif // !defined(ROCKSDB_LITE) && !defined(OS_WIN)
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h
new file mode 100644
index 000000000..47611ecd3
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_test.h
@@ -0,0 +1,285 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <limits>
+#include <list>
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "db/db_test_util.h"
+#include "memory/arena.h"
+#include "port/port.h"
+#include "rocksdb/cache.h"
+#include "table/block_based/block_builder.h"
+#include "test_util/testharness.h"
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Unit tests for testing PersistentCacheTier
+//
+class PersistentCacheTierTest : public testing::Test {
+ public:
+ PersistentCacheTierTest();
+ virtual ~PersistentCacheTierTest() {
+ if (cache_) {
+ Status s = cache_->Close();
+ assert(s.ok());
+ }
+ }
+
+ protected:
+ // Flush cache
+ void Flush() {
+ if (cache_) {
+ cache_->TEST_Flush();
+ }
+ }
+
+ // create threaded workload
+ template <class T>
+ std::list<port::Thread> SpawnThreads(const size_t n, const T& fn) {
+ std::list<port::Thread> threads;
+ for (size_t i = 0; i < n; i++) {
+ port::Thread th(fn);
+ threads.push_back(std::move(th));
+ }
+ return threads;
+ }
+
+ // Wait for threads to join
+ void Join(std::list<port::Thread>&& threads) {
+ for (auto& th : threads) {
+ th.join();
+ }
+ threads.clear();
+ }
+
+ // Run insert workload in threads
+ void Insert(const size_t nthreads, const size_t max_keys) {
+ key_ = 0;
+ max_keys_ = max_keys;
+ // spawn threads
+ auto fn = std::bind(&PersistentCacheTierTest::InsertImpl, this);
+ auto threads = SpawnThreads(nthreads, fn);
+ // join with threads
+ Join(std::move(threads));
+ // Flush cache
+ Flush();
+ }
+
+ // Run verification on the cache
+ void Verify(const size_t nthreads = 1, const bool eviction_enabled = false) {
+ stats_verify_hits_ = 0;
+ stats_verify_missed_ = 0;
+ key_ = 0;
+ // spawn threads
+ auto fn =
+ std::bind(&PersistentCacheTierTest::VerifyImpl, this, eviction_enabled);
+ auto threads = SpawnThreads(nthreads, fn);
+ // join with threads
+ Join(std::move(threads));
+ }
+
+ // pad 0 to numbers
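+  // For example, PaddedNumber(42, /*pad_size=*/8) returns "00000042".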
+ std::string PaddedNumber(const size_t data, const size_t pad_size) {
+ assert(pad_size);
+ char* ret = new char[pad_size];
+ int pos = static_cast<int>(pad_size) - 1;
+ size_t count = 0;
+ size_t t = data;
+ // copy numbers
+ while (t) {
+ count++;
+ ret[pos--] = '0' + t % 10;
+ t = t / 10;
+ }
+ // copy 0s
+ while (pos >= 0) {
+ ret[pos--] = '0';
+ }
+ // post condition
+ assert(count <= pad_size);
+ assert(pos == -1);
+ std::string result(ret, pad_size);
+ delete[] ret;
+ return result;
+ }
+
+ // Insert workload implementation
+ void InsertImpl() {
+ const std::string prefix = "key_prefix_";
+
+ while (true) {
+ size_t i = key_++;
+ if (i >= max_keys_) {
+ break;
+ }
+
+ char data[4 * 1024];
+ memset(data, '0' + (i % 10), sizeof(data));
+      auto k = prefix + PaddedNumber(i, /*pad_size=*/8);
+ Slice key(k);
+ while (true) {
+ Status status = cache_->Insert(key, data, sizeof(data));
+ if (status.ok()) {
+ break;
+ }
+ ASSERT_TRUE(status.IsTryAgain());
+ Env::Default()->SleepForMicroseconds(1 * 1000 * 1000);
+ }
+ }
+ }
+
+ // Verification implementation
+ void VerifyImpl(const bool eviction_enabled = false) {
+ const std::string prefix = "key_prefix_";
+ while (true) {
+ size_t i = key_++;
+ if (i >= max_keys_) {
+ break;
+ }
+
+ char edata[4 * 1024];
+ memset(edata, '0' + (i % 10), sizeof(edata));
+      auto k = prefix + PaddedNumber(i, /*pad_size=*/8);
+ Slice key(k);
+ std::unique_ptr<char[]> block;
+ size_t block_size;
+
+ if (eviction_enabled) {
+ if (!cache_->Lookup(key, &block, &block_size).ok()) {
+ // assume that the key is evicted
+ stats_verify_missed_++;
+ continue;
+ }
+ }
+
+ ASSERT_OK(cache_->Lookup(key, &block, &block_size));
+ ASSERT_EQ(block_size, sizeof(edata));
+ ASSERT_EQ(memcmp(edata, block.get(), sizeof(edata)), 0);
+ stats_verify_hits_++;
+ }
+ }
+
+ // template for insert test
+ void RunInsertTest(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads);
+ ASSERT_EQ(stats_verify_hits_, max_keys);
+ ASSERT_EQ(stats_verify_missed_, 0);
+
+ cache_->Close();
+ cache_.reset();
+ }
+
+ // template for negative insert test
+ void RunNegativeInsertTest(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads, /*eviction_enabled=*/true);
+ ASSERT_LT(stats_verify_hits_, max_keys);
+ ASSERT_GT(stats_verify_missed_, 0);
+
+ cache_->Close();
+ cache_.reset();
+ }
+
+ // template for insert with eviction test
+ void RunInsertTestWithEviction(const size_t nthreads, const size_t max_keys) {
+ Insert(nthreads, max_keys);
+ Verify(nthreads, /*eviction_enabled=*/true);
+ ASSERT_EQ(stats_verify_hits_ + stats_verify_missed_, max_keys);
+ ASSERT_GT(stats_verify_hits_, 0);
+ ASSERT_GT(stats_verify_missed_, 0);
+
+ cache_->Close();
+ cache_.reset();
+ }
+
+ const std::string path_;
+ std::shared_ptr<Logger> log_;
+ std::shared_ptr<PersistentCacheTier> cache_;
+ std::atomic<size_t> key_{0};
+ size_t max_keys_ = 0;
+ std::atomic<size_t> stats_verify_hits_{0};
+ std::atomic<size_t> stats_verify_missed_{0};
+};
+
+//
+// RocksDB tests
+//
+class PersistentCacheDBTest : public DBTestBase {
+ public:
+ PersistentCacheDBTest();
+
+ static uint64_t TestGetTickerCount(const Options& options,
+ Tickers ticker_type) {
+    return static_cast<uint64_t>(
+        options.statistics->getTickerCount(ticker_type));
+ }
+
+ // insert data to table
+ void Insert(const Options& options,
+ const BlockBasedTableOptions& /*table_options*/,
+ const int num_iter, std::vector<std::string>* values) {
+ CreateAndReopenWithCF({"pikachu"}, options);
+ // default column family doesn't have block cache
+ Options no_block_cache_opts;
+ no_block_cache_opts.statistics = options.statistics;
+ no_block_cache_opts = CurrentOptions(no_block_cache_opts);
+ BlockBasedTableOptions table_options_no_bc;
+ table_options_no_bc.no_block_cache = true;
+ no_block_cache_opts.table_factory.reset(
+ NewBlockBasedTableFactory(table_options_no_bc));
+ ReopenWithColumnFamilies(
+ {"default", "pikachu"},
+ std::vector<Options>({no_block_cache_opts, options}));
+
+ Random rnd(301);
+
+ // Write 8MB (80 values, each 100K)
+ ASSERT_EQ(NumTableFilesAtLevel(0, 1), 0);
+ std::string str;
+ for (int i = 0; i < num_iter; i++) {
+ if (i % 4 == 0) { // high compression ratio
+ str = RandomString(&rnd, 1000);
+ }
+ values->push_back(str);
+ ASSERT_OK(Put(1, Key(i), (*values)[i]));
+ }
+
+ // flush all data from memtable so that reads are from block cache
+ ASSERT_OK(Flush(1));
+ }
+
+ // verify data
+ void Verify(const int num_iter, const std::vector<std::string>& values) {
+ for (int j = 0; j < 2; ++j) {
+ for (int i = 0; i < num_iter; i++) {
+ ASSERT_EQ(Get(1, Key(i)), values[i]);
+ }
+ }
+ }
+
+ // test template
+ void RunTest(const std::function<std::shared_ptr<PersistentCacheTier>(bool)>&
+ new_pcache,
+ const size_t max_keys, const size_t max_usecase);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc
new file mode 100644
index 000000000..3847a4ee9
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.cc
@@ -0,0 +1,163 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+
+#include <cinttypes>
+#include <sstream>
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+std::string PersistentCacheConfig::ToString() const {
+ std::string ret;
+ ret.reserve(20000);
+ const int kBufferSize = 200;
+ char buffer[kBufferSize];
+
+ snprintf(buffer, kBufferSize, " path: %s\n", path.c_str());
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " enable_direct_reads: %d\n",
+ enable_direct_reads);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " enable_direct_writes: %d\n",
+ enable_direct_writes);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " cache_size: %" PRIu64 "\n", cache_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " cache_file_size: %" PRIu32 "\n",
+ cache_file_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " writer_qdepth: %" PRIu32 "\n",
+ writer_qdepth);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " pipeline_writes: %d\n", pipeline_writes);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize,
+ " max_write_pipeline_backlog_size: %" PRIu64 "\n",
+ max_write_pipeline_backlog_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " write_buffer_size: %" PRIu32 "\n",
+ write_buffer_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " writer_dispatch_size: %" PRIu64 "\n",
+ writer_dispatch_size);
+ ret.append(buffer);
+ snprintf(buffer, kBufferSize, " is_compressed: %d\n", is_compressed);
+ ret.append(buffer);
+
+ return ret;
+}
+
+//
+// PersistentCacheTier implementation
+//
+Status PersistentCacheTier::Open() {
+ if (next_tier_) {
+ return next_tier_->Open();
+ }
+ return Status::OK();
+}
+
+Status PersistentCacheTier::Close() {
+ if (next_tier_) {
+ return next_tier_->Close();
+ }
+ return Status::OK();
+}
+
+bool PersistentCacheTier::Reserve(const size_t /*size*/) {
+ // default implementation is a pass through
+ return true;
+}
+
+bool PersistentCacheTier::Erase(const Slice& /*key*/) {
+ // default implementation is a pass through since not all cache tiers might
+ // support erase
+ return true;
+}
+
+std::string PersistentCacheTier::PrintStats() {
+ std::ostringstream os;
+ for (auto tier_stats : Stats()) {
+ os << "---- next tier -----" << std::endl;
+ for (auto stat : tier_stats) {
+ os << stat.first << ": " << stat.second << std::endl;
+ }
+ }
+ return os.str();
+}
+
+PersistentCache::StatsType PersistentCacheTier::Stats() {
+ if (next_tier_) {
+ return next_tier_->Stats();
+ }
+ return PersistentCache::StatsType{};
+}
+
+//
+// PersistentTieredCache implementation
+//
+PersistentTieredCache::~PersistentTieredCache() { assert(tiers_.empty()); }
+
+Status PersistentTieredCache::Open() {
+ assert(!tiers_.empty());
+ return tiers_.front()->Open();
+}
+
+Status PersistentTieredCache::Close() {
+ assert(!tiers_.empty());
+ Status status = tiers_.front()->Close();
+ if (status.ok()) {
+ tiers_.clear();
+ }
+ return status;
+}
+
+bool PersistentTieredCache::Erase(const Slice& key) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Erase(key);
+}
+
+PersistentCache::StatsType PersistentTieredCache::Stats() {
+ assert(!tiers_.empty());
+ return tiers_.front()->Stats();
+}
+
+std::string PersistentTieredCache::PrintStats() {
+ assert(!tiers_.empty());
+ return tiers_.front()->PrintStats();
+}
+
+Status PersistentTieredCache::Insert(const Slice& page_key, const char* data,
+ const size_t size) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Insert(page_key, data, size);
+}
+
+Status PersistentTieredCache::Lookup(const Slice& page_key,
+ std::unique_ptr<char[]>* data,
+ size_t* size) {
+ assert(!tiers_.empty());
+ return tiers_.front()->Lookup(page_key, data, size);
+}
+
+void PersistentTieredCache::AddTier(const Tier& tier) {
+ if (!tiers_.empty()) {
+ tiers_.back()->set_next_tier(tier);
+ }
+ tiers_.push_back(tier);
+}
+
+bool PersistentTieredCache::IsCompressed() {
+ assert(tiers_.size());
+ return tiers_.front()->IsCompressed();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h
new file mode 100644
index 000000000..3905957de
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_tier.h
@@ -0,0 +1,336 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <limits>
+#include <list>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "monitoring/histogram.h"
+#include "rocksdb/env.h"
+#include "rocksdb/persistent_cache.h"
+#include "rocksdb/status.h"
+
+// Persistent Cache
+//
+// Persistent cache is a tiered key-value cache that can use a persistent
+// medium. It is a generic design and can leverage any storage medium --
+// disk/SSD/NVM/RAM. The code has been kept generic, but significant
+// benchmark/design/development time has been spent to make sure the cache
+// performs appropriately for each respective storage medium.
+// The file defines
+// PersistentCacheTier : Implementation that handles individual cache tier
+// PersistentTieredCache : Implementation that handles all tiers as a logical
+// unit
+//
+// PersistentTieredCache architecture:
+// +--------------------------+ PersistentCacheTier that handles multiple tiers
+// | +----------------+ |
+// | | RAM | PersistentCacheTier that handles RAM (VolatileCacheImpl)
+// | +----------------+ |
+// | | next |
+// | v |
+// | +----------------+ |
+// | | NVM | PersistentCacheTier implementation that handles NVM
+// | +----------------+ (BlockCacheImpl)
+// | | next |
+// | V |
+// | +----------------+ |
+// | | LE-SSD | PersistentCacheTier implementation that handles LE-SSD
+// | +----------------+ (BlockCacheImpl)
+// | | |
+// | V |
+// | null |
+// +--------------------------+
+// |
+// V
+// null
+namespace ROCKSDB_NAMESPACE {
+
+// Persistent Cache Config
+//
+// This struct captures all the options that are used to configure a persistent
+// cache. Some of the terms used in naming the options are:
+//
+// dispatch size :
+// This is the size at which IO is dispatched to the device
+//
+// write buffer size :
+// This is the size of an individual write buffer. Write buffers are
+// grouped to form a buffered file.
+//
+// cache size :
+// This is the logical maximum for the cache size
+//
+// qdepth :
+// This is the max number of IOs that can be issued to the device in parallel
+//
+// pipelining :
+// The writer code path follows a pipelined architecture, which means the
+// operations are handed off from one stage to another
+//
+// pipelining backlog size :
+// With the pipelined architecture, there can always be a backlog of ops in the
+// pipeline queues. This is the maximum backlog size after which ops are dropped
+// from the queue
+struct PersistentCacheConfig {
+ explicit PersistentCacheConfig(
+ Env* const _env, const std::string& _path, const uint64_t _cache_size,
+ const std::shared_ptr<Logger>& _log,
+ const uint32_t _write_buffer_size = 1 * 1024 * 1024 /*1MB*/) {
+ env = _env;
+ path = _path;
+ log = _log;
+ cache_size = _cache_size;
+ writer_dispatch_size = write_buffer_size = _write_buffer_size;
+ }
+
+ //
+  // Validate the settings. Our intention is to catch erroneous settings ahead
+  // of time instead of violating invariants or causing deadlocks.
+ //
+ Status ValidateSettings() const {
+ // (1) check pre-conditions for variables
+ if (!env || path.empty()) {
+ return Status::InvalidArgument("empty or null args");
+ }
+
+ // (2) assert size related invariants
+ // - cache size cannot be less than cache file size
+  // - individual write buffer size must be smaller than the cache file size
+ // - total write buffer size cannot be less than 2X cache file size
+ if (cache_size < cache_file_size || write_buffer_size >= cache_file_size ||
+ write_buffer_size * write_buffer_count() < 2 * cache_file_size) {
+ return Status::InvalidArgument("invalid cache size");
+ }
+
+    // (3) check writer settings
+ // - Queue depth cannot be 0
+ // - writer_dispatch_size cannot be greater than writer_buffer_size
+ // - dispatch size and buffer size need to be aligned
+ if (!writer_qdepth || writer_dispatch_size > write_buffer_size ||
+ write_buffer_size % writer_dispatch_size) {
+ return Status::InvalidArgument("invalid writer settings");
+ }
+
+ return Status::OK();
+ }
+
+ //
+  // Env abstraction to use for system level operations
+ //
+ Env* env;
+
+ //
+ // Path for the block cache where blocks are persisted
+ //
+ std::string path;
+
+ //
+ // Log handle for logging messages
+ //
+ std::shared_ptr<Logger> log;
+
+ //
+ // Enable direct IO for reading
+ //
+ bool enable_direct_reads = true;
+
+ //
+ // Enable direct IO for writing
+ //
+ bool enable_direct_writes = false;
+
+ //
+ // Logical cache size
+ //
+ uint64_t cache_size = std::numeric_limits<uint64_t>::max();
+
+ // cache-file-size
+ //
+ // Cache consists of multiples of small files. This parameter defines the
+ // size of an individual cache file
+ //
+  // default: 100MB
+ uint32_t cache_file_size = 100ULL * 1024 * 1024;
+
+ // writer-qdepth
+ //
+  // The writers can issue IO to the devices in parallel. This parameter
+  // controls the max number of IOs that can be issued in parallel to the block
+  // device
+  //
+  // default: 1
+ uint32_t writer_qdepth = 1;
+
+ // pipeline-writes
+ //
+  // The writes optionally follow a pipelined architecture. This helps
+ // avoid regression in the eviction code path of the primary tier. This
+ // parameter defines if pipelining is enabled or disabled
+ //
+ // default: true
+ bool pipeline_writes = true;
+
+ // max-write-pipeline-backlog-size
+ //
+ // Max pipeline buffer size. This is the maximum backlog we can accumulate
+ // while waiting for writes. After the limit, new ops will be dropped.
+ //
+ // Default: 1GiB
+ uint64_t max_write_pipeline_backlog_size = 1ULL * 1024 * 1024 * 1024;
+
+ // write-buffer-size
+ //
+ // This is the size in which buffer slabs are allocated.
+ //
+ // Default: 1M
+ uint32_t write_buffer_size = 1ULL * 1024 * 1024;
+
+ // write-buffer-count
+ //
+  // This is the total number of buffer slabs. It is calculated as a factor of
+  // the cache file size in order to avoid deadlock.
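+  //
+  // For example, with the defaults (writer_qdepth = 1, cache_file_size = 100MB,
+  // write_buffer_size = 1MB) this works out to (1 + 1.2) * 100 = 220 buffers.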
+ size_t write_buffer_count() const {
+ assert(write_buffer_size);
+ return static_cast<size_t>((writer_qdepth + 1.2) * cache_file_size /
+ write_buffer_size);
+ }
+
+ // writer-dispatch-size
+ //
+ // The writer thread will dispatch the IO at the specified IO size
+ //
+ // default: 1M
+ uint64_t writer_dispatch_size = 1ULL * 1024 * 1024;
+
+ // is_compressed
+ //
+ // This option determines if the cache will run in compressed mode or
+ // uncompressed mode
+ bool is_compressed = true;
+
+ PersistentCacheConfig MakePersistentCacheConfig(
+ const std::string& path, const uint64_t size,
+ const std::shared_ptr<Logger>& log);
+
+ std::string ToString() const;
+};
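+
+// Example (illustrative sketch; the path and sizes below are hypothetical):
+//
+//   PersistentCacheConfig config(Env::Default(), "/tmp/pcache",
+//                                /*_cache_size=*/1 * 1024 * 1024 * 1024,
+//                                /*_log=*/nullptr);
+//   config.cache_file_size = 100 * 1024 * 1024;    // 100MB cache files
+//   config.enable_direct_writes = true;
+//   Status s = config.ValidateSettings();          // catch bad settings early
+//   assert(s.ok());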
+
+// Persistent Cache Tier
+//
+// This is a logical abstraction that defines a tier of the persistent cache.
+// Tiers can be stacked over one another. PersistentCache provides the basic
+// definition for accessing/storing in the cache. PersistentCacheTier extends
+// the interface to enable management and stacking of tiers.
+class PersistentCacheTier : public PersistentCache {
+ public:
+ typedef std::shared_ptr<PersistentCacheTier> Tier;
+
+ virtual ~PersistentCacheTier() {}
+
+ // Open the persistent cache tier
+ virtual Status Open();
+
+ // Close the persistent cache tier
+ virtual Status Close();
+
+ // Reserve space up to 'size' bytes
+ virtual bool Reserve(const size_t size);
+
+ // Erase a key from the cache
+ virtual bool Erase(const Slice& key);
+
+ // Print stats to string recursively
+ virtual std::string PrintStats();
+
+ virtual PersistentCache::StatsType Stats() override;
+
+ // Insert to page cache
+ virtual Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override = 0;
+
+ // Lookup page cache by page identifier
+ virtual Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override = 0;
+
+ // Does it store compressed data ?
+ virtual bool IsCompressed() override = 0;
+
+ virtual std::string GetPrintableOptions() const override = 0;
+
+ // Return a reference to next tier
+ virtual Tier& next_tier() { return next_tier_; }
+
+ // Set the value for next tier
+ virtual void set_next_tier(const Tier& tier) {
+ assert(!next_tier_);
+ next_tier_ = tier;
+ }
+
+ virtual void TEST_Flush() {
+ if (next_tier_) {
+ next_tier_->TEST_Flush();
+ }
+ }
+
+ private:
+ Tier next_tier_; // next tier
+};
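+
+// Example (illustrative sketch of a pass-through tier; "NullCacheTier" is a
+// hypothetical name, not part of the library):
+//
+//   class NullCacheTier : public PersistentCacheTier {
+//    public:
+//     Status Insert(const Slice& page_key, const char* data,
+//                   const size_t size) override {
+//       // forward to the next tier, if any
+//       return next_tier() ? next_tier()->Insert(page_key, data, size)
+//                          : Status::OK();
+//     }
+//     Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+//                   size_t* size) override {
+//       return next_tier() ? next_tier()->Lookup(page_key, data, size)
+//                          : Status::NotFound();
+//     }
+//     bool IsCompressed() override { return false; }
+//     std::string GetPrintableOptions() const override {
+//       return "NullCacheTier";
+//     }
+//   };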
+
+// PersistentTieredCache
+//
+// Abstraction that helps you construct tiers of persistent caches as a
+// unified cache. The tier(s) of the cache will act as a single tier for
+// management ease and support the PersistentCache methods for accessing data.
+class PersistentTieredCache : public PersistentCacheTier {
+ public:
+ virtual ~PersistentTieredCache();
+
+ Status Open() override;
+ Status Close() override;
+ bool Erase(const Slice& key) override;
+ std::string PrintStats() override;
+ PersistentCache::StatsType Stats() override;
+ Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override;
+ Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+ bool IsCompressed() override;
+
+ std::string GetPrintableOptions() const override {
+ return "PersistentTieredCache";
+ }
+
+ void AddTier(const Tier& tier);
+
+  Tier& next_tier() override {
+    // return the next tier of the last (bottom-most) tier in the stack
+    assert(!tiers_.empty());
+    return tiers_.back()->next_tier();
+  }
+
+  void set_next_tier(const Tier& tier) override {
+    // chain the new tier after the last (bottom-most) tier in the stack
+    assert(!tiers_.empty());
+    tiers_.back()->set_next_tier(tier);
+  }
+
+ void TEST_Flush() override {
+ assert(!tiers_.empty());
+ tiers_.front()->TEST_Flush();
+ PersistentCacheTier::TEST_Flush();
+ }
+
+ protected:
+ std::list<Tier> tiers_; // list of tiers top-down
+};
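+
+// Example (illustrative sketch; "nvm_tier" stands in for a concrete lower tier
+// such as a block cache tier, and VolatileCacheTier lives in a separate header):
+//
+//   auto tiered = std::make_shared<PersistentTieredCache>();
+//   tiered->AddTier(std::make_shared<VolatileCacheTier>());  // RAM tier first
+//   tiered->AddTier(nvm_tier);          // hypothetical persistent tier below it
+//   Status s = tiered->Open();
+//   ...
+//   s = tiered->Close();                // closes and clears all tiers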
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h b/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h
new file mode 100644
index 000000000..2a769652d
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/persistent_cache_util.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <limits>
+#include <list>
+
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+//
+// Simple synchronized queue implementation with the option of
+// bounding the queue
+//
+// On overflow, the elements will be discarded
+//
+template <class T>
+class BoundedQueue {
+ public:
+ explicit BoundedQueue(
+ const size_t max_size = std::numeric_limits<size_t>::max())
+ : cond_empty_(&lock_), max_size_(max_size) {}
+
+ virtual ~BoundedQueue() {}
+
+ void Push(T&& t) {
+ MutexLock _(&lock_);
+ if (max_size_ != std::numeric_limits<size_t>::max() &&
+ size_ + t.Size() >= max_size_) {
+ // overflow
+ return;
+ }
+
+ size_ += t.Size();
+ q_.push_back(std::move(t));
+ cond_empty_.SignalAll();
+ }
+
+ T Pop() {
+ MutexLock _(&lock_);
+ while (q_.empty()) {
+ cond_empty_.Wait();
+ }
+
+ T t = std::move(q_.front());
+ size_ -= t.Size();
+ q_.pop_front();
+ return t;
+ }
+
+ size_t Size() const {
+ MutexLock _(&lock_);
+ return size_;
+ }
+
+ private:
+ mutable port::Mutex lock_;
+ port::CondVar cond_empty_;
+ std::list<T> q_;
+ size_t size_ = 0;
+ const size_t max_size_;
+};
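+
+// Example (illustrative sketch; "IOJob" is a hypothetical element type -- any
+// type with a Size() method works):
+//
+//   struct IOJob {
+//     std::string payload;
+//     size_t Size() const { return payload.size(); }
+//   };
+//
+//   BoundedQueue<IOJob> q(/*max_size=*/1024 * 1024);  // drop ops beyond ~1MB
+//   q.Push(IOJob{"write-this-block"});                // producer thread(s)
+//   IOJob job = q.Pop();                              // consumer blocks if empty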
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc
new file mode 100644
index 000000000..ee63f828c
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.cc
@@ -0,0 +1,138 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+
+#include "utilities/persistent_cache/volatile_tier_impl.h"
+
+#include <string>
+
+namespace ROCKSDB_NAMESPACE {
+
+void VolatileCacheTier::DeleteCacheData(VolatileCacheTier::CacheData* data) {
+ assert(data);
+ delete data;
+}
+
+VolatileCacheTier::~VolatileCacheTier() { index_.Clear(&DeleteCacheData); }
+
+PersistentCache::StatsType VolatileCacheTier::Stats() {
+ std::map<std::string, double> stat;
+ stat.insert({"persistent_cache.volatile_cache.hits",
+ static_cast<double>(stats_.cache_hits_)});
+ stat.insert({"persistent_cache.volatile_cache.misses",
+ static_cast<double>(stats_.cache_misses_)});
+ stat.insert({"persistent_cache.volatile_cache.inserts",
+ static_cast<double>(stats_.cache_inserts_)});
+ stat.insert({"persistent_cache.volatile_cache.evicts",
+ static_cast<double>(stats_.cache_evicts_)});
+ stat.insert({"persistent_cache.volatile_cache.hit_pct",
+ static_cast<double>(stats_.CacheHitPct())});
+ stat.insert({"persistent_cache.volatile_cache.miss_pct",
+ static_cast<double>(stats_.CacheMissPct())});
+
+ auto out = PersistentCacheTier::Stats();
+ out.push_back(stat);
+ return out;
+}
+
+Status VolatileCacheTier::Insert(const Slice& page_key, const char* data,
+ const size_t size) {
+ // precondition
+ assert(data);
+ assert(size);
+
+ // increment the size
+ size_ += size;
+
+ // check if we have overshot the limit, if so evict some space
+ while (size_ > max_size_) {
+ if (!Evict()) {
+ // unable to evict data, we give up so we don't spike read
+ // latency
+ assert(size_ >= size);
+ size_ -= size;
+ return Status::TryAgain("Unable to evict any data");
+ }
+ }
+
+ assert(size_ >= size);
+
+ // insert order: LRU, followed by index
+ std::string key(page_key.data(), page_key.size());
+ std::string value(data, size);
+ std::unique_ptr<CacheData> cache_data(
+ new CacheData(std::move(key), std::move(value)));
+ bool ok = index_.Insert(cache_data.get());
+ if (!ok) {
+ // decrement the size that we incremented ahead of time
+ assert(size_ >= size);
+ size_ -= size;
+ // failed to insert to cache, block already in cache
+ return Status::TryAgain("key already exists in volatile cache");
+ }
+
+ cache_data.release();
+ stats_.cache_inserts_++;
+ return Status::OK();
+}
+
+Status VolatileCacheTier::Lookup(const Slice& page_key,
+ std::unique_ptr<char[]>* result,
+ size_t* size) {
+  CacheData key(page_key.ToString());
+ CacheData* kv;
+ bool ok = index_.Find(&key, &kv);
+ if (ok) {
+ // set return data
+ result->reset(new char[kv->value.size()]);
+ memcpy(result->get(), kv->value.c_str(), kv->value.size());
+ *size = kv->value.size();
+ // drop the reference on cache data
+ kv->refs_--;
+ // update stats
+ stats_.cache_hits_++;
+ return Status::OK();
+ }
+
+ stats_.cache_misses_++;
+
+ if (next_tier()) {
+ return next_tier()->Lookup(page_key, result, size);
+ }
+
+ return Status::NotFound("key not found in volatile cache");
+}
+
+bool VolatileCacheTier::Erase(const Slice& /*key*/) {
+ assert(!"not supported");
+ return true;
+}
+
+bool VolatileCacheTier::Evict() {
+ CacheData* edata = index_.Evict();
+ if (!edata) {
+ // not able to evict any object
+ return false;
+ }
+
+ stats_.cache_evicts_++;
+
+ // push the evicted object to the next level
+ if (next_tier()) {
+ next_tier()->Insert(Slice(edata->key), edata->value.c_str(),
+ edata->value.size());
+ }
+
+ // adjust size and destroy data
+ size_ -= edata->value.size();
+ delete edata;
+
+ return true;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h
new file mode 100644
index 000000000..6116e894b
--- /dev/null
+++ b/src/rocksdb/utilities/persistent_cache/volatile_tier_impl.h
@@ -0,0 +1,142 @@
+// Copyright (c) 2013, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <atomic>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "rocksdb/cache.h"
+#include "utilities/persistent_cache/hash_table.h"
+#include "utilities/persistent_cache/hash_table_evictable.h"
+#include "utilities/persistent_cache/persistent_cache_tier.h"
+
+// VolatileCacheTier
+//
+// This file provides persistent cache tier implementation for caching
+// key/values in RAM.
+//
+// key/values
+// |
+// V
+// +-------------------+
+// | VolatileCacheTier | Store in an evictable hash table
+// +-------------------+
+// |
+// V
+// on eviction
+// pushed to next tier
+//
+// The implementation is designed to be concurrent. The evictable hash table
+// implementation is not concurrent at this point though.
+//
+// The eviction algorithm is LRU
+namespace ROCKSDB_NAMESPACE {
+
+class VolatileCacheTier : public PersistentCacheTier {
+ public:
+ explicit VolatileCacheTier(
+ const bool is_compressed = true,
+ const size_t max_size = std::numeric_limits<size_t>::max())
+ : is_compressed_(is_compressed), max_size_(max_size) {}
+
+ virtual ~VolatileCacheTier();
+
+ // insert to cache
+ Status Insert(const Slice& page_key, const char* data,
+ const size_t size) override;
+ // lookup key in cache
+ Status Lookup(const Slice& page_key, std::unique_ptr<char[]>* data,
+ size_t* size) override;
+
+ // is compressed cache ?
+ bool IsCompressed() override { return is_compressed_; }
+
+ // erase key from cache
+ bool Erase(const Slice& key) override;
+
+ std::string GetPrintableOptions() const override {
+ return "VolatileCacheTier";
+ }
+
+ // Expose stats as map
+ PersistentCache::StatsType Stats() override;
+
+ private:
+ //
+ // Cache data abstraction
+ //
+ struct CacheData : LRUElement<CacheData> {
+ explicit CacheData(CacheData&& rhs) ROCKSDB_NOEXCEPT
+ : key(std::move(rhs.key)),
+ value(std::move(rhs.value)) {}
+
+ explicit CacheData(const std::string& _key, const std::string& _value = "")
+ : key(_key), value(_value) {}
+
+ virtual ~CacheData() {}
+
+ const std::string key;
+ const std::string value;
+ };
+
+ static void DeleteCacheData(CacheData* data);
+
+ //
+ // Index and LRU definition
+ //
+ struct CacheDataHash {
+ uint64_t operator()(const CacheData* obj) const {
+ assert(obj);
+ return std::hash<std::string>()(obj->key);
+ }
+ };
+
+ struct CacheDataEqual {
+ bool operator()(const CacheData* lhs, const CacheData* rhs) const {
+ assert(lhs);
+ assert(rhs);
+ return lhs->key == rhs->key;
+ }
+ };
+
+ struct Statistics {
+ std::atomic<uint64_t> cache_misses_{0};
+ std::atomic<uint64_t> cache_hits_{0};
+ std::atomic<uint64_t> cache_inserts_{0};
+ std::atomic<uint64_t> cache_evicts_{0};
+
+ double CacheHitPct() const {
+ auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_hits_ / static_cast<double>(lookups) : 0.0;
+ }
+
+ double CacheMissPct() const {
+ auto lookups = cache_hits_ + cache_misses_;
+ return lookups ? 100 * cache_misses_ / static_cast<double>(lookups) : 0.0;
+ }
+ };
+
+ typedef EvictableHashTable<CacheData, CacheDataHash, CacheDataEqual>
+ IndexType;
+
+ // Evict LRU tail
+ bool Evict();
+
+ const bool is_compressed_ = true; // does it store compressed data
+ IndexType index_; // in-memory cache
+ std::atomic<uint64_t> max_size_{0}; // Maximum size of the cache
+ std::atomic<uint64_t> size_{0}; // Size of the cache
+ Statistics stats_;
+};
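+
+// Example (illustrative sketch; the key and data below are hypothetical):
+//
+//   auto cache = std::make_shared<VolatileCacheTier>(
+//       /*is_compressed=*/true, /*max_size=*/256 * 1024 * 1024);
+//   char block[4096] = {0};
+//   Status s = cache->Insert("block-key", block, sizeof(block));
+//   std::unique_ptr<char[]> data;
+//   size_t size = 0;
+//   s = cache->Lookup("block-key", &data, &size);  // hit: copies the block out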
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator.cc b/src/rocksdb/utilities/simulator_cache/cache_simulator.cc
new file mode 100644
index 000000000..16a78ea71
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator.cc
@@ -0,0 +1,274 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/simulator_cache/cache_simulator.h"
+#include <algorithm>
+#include "db/dbformat.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+const std::string kGhostCachePrefix = "ghost_";
+} // namespace
+
+GhostCache::GhostCache(std::shared_ptr<Cache> sim_cache)
+ : sim_cache_(sim_cache) {}
+
+bool GhostCache::Admit(const Slice& lookup_key) {
+ auto handle = sim_cache_->Lookup(lookup_key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ return true;
+ }
+ sim_cache_->Insert(lookup_key, /*value=*/nullptr, lookup_key.size(),
+ /*deleter=*/nullptr);
+ return false;
+}
+
+CacheSimulator::CacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache)
+ : ghost_cache_(std::move(ghost_cache)), sim_cache_(sim_cache) {}
+
+void CacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ bool admit = true;
+ const bool is_user_access =
+ BlockCacheTraceHelper::IsUserAccess(access.caller);
+ bool is_cache_miss = true;
+ if (ghost_cache_ && access.no_insert == Boolean::kFalse) {
+ admit = ghost_cache_->Admit(access.block_key);
+ }
+ auto handle = sim_cache_->Lookup(access.block_key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ is_cache_miss = false;
+ } else {
+ if (access.no_insert == Boolean::kFalse && admit && access.block_size > 0) {
+ sim_cache_->Insert(access.block_key, /*value=*/nullptr, access.block_size,
+ /*deleter=*/nullptr);
+ }
+ }
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp, is_user_access,
+ is_cache_miss);
+}
+
+void MissRatioStats::UpdateMetrics(uint64_t timestamp_in_ms,
+ bool is_user_access, bool is_cache_miss) {
+ uint64_t timestamp_in_seconds = timestamp_in_ms / kMicrosInSecond;
+ num_accesses_timeline_[timestamp_in_seconds] += 1;
+ num_accesses_ += 1;
+ if (num_misses_timeline_.find(timestamp_in_seconds) ==
+ num_misses_timeline_.end()) {
+ num_misses_timeline_[timestamp_in_seconds] = 0;
+ }
+ if (is_cache_miss) {
+ num_misses_ += 1;
+ num_misses_timeline_[timestamp_in_seconds] += 1;
+ }
+ if (is_user_access) {
+ user_accesses_ += 1;
+ if (is_cache_miss) {
+ user_misses_ += 1;
+ }
+ }
+}
+
+Cache::Priority PrioritizedCacheSimulator::ComputeBlockPriority(
+ const BlockCacheTraceRecord& access) const {
+ if (access.block_type == TraceType::kBlockTraceFilterBlock ||
+ access.block_type == TraceType::kBlockTraceIndexBlock ||
+ access.block_type == TraceType::kBlockTraceUncompressionDictBlock) {
+ return Cache::Priority::HIGH;
+ }
+ return Cache::Priority::LOW;
+}
+
+void PrioritizedCacheSimulator::AccessKVPair(
+ const Slice& key, uint64_t value_size, Cache::Priority priority,
+ const BlockCacheTraceRecord& access, bool no_insert, bool is_user_access,
+ bool* is_cache_miss, bool* admitted, bool update_metrics) {
+ assert(is_cache_miss);
+ assert(admitted);
+ *is_cache_miss = true;
+ *admitted = true;
+ if (ghost_cache_ && !no_insert) {
+ *admitted = ghost_cache_->Admit(key);
+ }
+ auto handle = sim_cache_->Lookup(key);
+ if (handle != nullptr) {
+ sim_cache_->Release(handle);
+ *is_cache_miss = false;
+ } else if (!no_insert && *admitted && value_size > 0) {
+ sim_cache_->Insert(key, /*value=*/nullptr, value_size, /*deleter=*/nullptr,
+ /*handle=*/nullptr, priority);
+ }
+ if (update_metrics) {
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp, is_user_access,
+ *is_cache_miss);
+ }
+}
+
+void PrioritizedCacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ bool is_cache_miss = true;
+ bool admitted = true;
+ AccessKVPair(access.block_key, access.block_size,
+ ComputeBlockPriority(access), access, access.no_insert,
+ BlockCacheTraceHelper::IsUserAccess(access.caller),
+ &is_cache_miss, &admitted, /*update_metrics=*/true);
+}
+
+void HybridRowBlockCacheSimulator::Access(const BlockCacheTraceRecord& access) {
+ // TODO (haoyu): We only support Get for now. We need to extend the tracing
+ // for MultiGet, i.e., non-data block accesses must log all keys in a
+ // MultiGet.
+ bool is_cache_miss = true;
+ bool admitted = false;
+ if (access.caller == TableReaderCaller::kUserGet &&
+ access.get_id != BlockCacheTraceHelper::kReservedGetId) {
+ // This is a Get request.
+ const std::string& row_key = BlockCacheTraceHelper::ComputeRowKey(access);
+ GetRequestStatus& status = getid_status_map_[access.get_id];
+ if (status.is_complete) {
+      // This Get request has already completed.
+      // Skip future accesses to its index/filter/data
+      // blocks. These block lookups are unnecessary if we have already observed
+      // a hit for the referenced key-value pair. Thus, we treat these lookups as
+      // hits. This also ensures the total number of accesses is the same
+      // when comparing to other policies.
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp,
+ /*is_user_access=*/true,
+ /*is_cache_miss=*/false);
+ return;
+ }
+ if (status.row_key_status.find(row_key) == status.row_key_status.end()) {
+ // This is the first time that this key is accessed. Look up the key-value
+      // pair first. Do not update the miss/accesses metrics here since they
+      // will be updated later.
+ AccessKVPair(row_key, access.referenced_data_size, Cache::Priority::HIGH,
+ access,
+ /*no_insert=*/false,
+ /*is_user_access=*/true, &is_cache_miss, &admitted,
+ /*update_metrics=*/false);
+ InsertResult result = InsertResult::NO_INSERT;
+ if (admitted && access.referenced_data_size > 0) {
+ result = InsertResult::INSERTED;
+ } else if (admitted) {
+ result = InsertResult::ADMITTED;
+ }
+ status.row_key_status[row_key] = result;
+ }
+ if (!is_cache_miss) {
+ // A cache hit.
+ status.is_complete = true;
+ miss_ratio_stats_.UpdateMetrics(access.access_timestamp,
+ /*is_user_access=*/true,
+ /*is_cache_miss=*/false);
+ return;
+ }
+ // The row key-value pair observes a cache miss. We need to access its
+ // index/filter/data blocks.
+ InsertResult inserted = status.row_key_status[row_key];
+ AccessKVPair(
+ access.block_key, access.block_size, ComputeBlockPriority(access),
+ access,
+ /*no_insert=*/!insert_blocks_upon_row_kvpair_miss_ || access.no_insert,
+ /*is_user_access=*/true, &is_cache_miss, &admitted,
+ /*update_metrics=*/true);
+ if (access.referenced_data_size > 0 && inserted == InsertResult::ADMITTED) {
+ sim_cache_->Insert(row_key, /*value=*/nullptr,
+ access.referenced_data_size, /*deleter=*/nullptr,
+ /*handle=*/nullptr, Cache::Priority::HIGH);
+ status.row_key_status[row_key] = InsertResult::INSERTED;
+ }
+ return;
+ }
+ AccessKVPair(access.block_key, access.block_size,
+ ComputeBlockPriority(access), access, access.no_insert,
+ BlockCacheTraceHelper::IsUserAccess(access.caller),
+ &is_cache_miss, &admitted, /*update_metrics=*/true);
+}
+
+BlockCacheTraceSimulator::BlockCacheTraceSimulator(
+ uint64_t warmup_seconds, uint32_t downsample_ratio,
+ const std::vector<CacheConfiguration>& cache_configurations)
+ : warmup_seconds_(warmup_seconds),
+ downsample_ratio_(downsample_ratio),
+ cache_configurations_(cache_configurations) {}
+
+Status BlockCacheTraceSimulator::InitializeCaches() {
+ for (auto const& config : cache_configurations_) {
+ for (auto cache_capacity : config.cache_capacities) {
+ // Scale down the cache capacity since the trace contains accesses on
+ // 1/'downsample_ratio' blocks.
+ uint64_t simulate_cache_capacity = cache_capacity / downsample_ratio_;
+ std::shared_ptr<CacheSimulator> sim_cache;
+ std::unique_ptr<GhostCache> ghost_cache;
+ std::string cache_name = config.cache_name;
+ if (cache_name.find(kGhostCachePrefix) != std::string::npos) {
+ ghost_cache.reset(new GhostCache(
+ NewLRUCache(config.ghost_cache_capacity, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_name = cache_name.substr(kGhostCachePrefix.size());
+ }
+ if (cache_name == "lru") {
+ sim_cache = std::make_shared<CacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0));
+ } else if (cache_name == "lru_priority") {
+ sim_cache = std::make_shared<PrioritizedCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5));
+ } else if (cache_name == "lru_hybrid") {
+ sim_cache = std::make_shared<HybridRowBlockCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5),
+ /*insert_blocks_upon_row_kvpair_miss=*/true);
+ } else if (cache_name == "lru_hybrid_no_insert_on_row_miss") {
+ sim_cache = std::make_shared<HybridRowBlockCacheSimulator>(
+ std::move(ghost_cache),
+ NewLRUCache(simulate_cache_capacity, config.num_shard_bits,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0.5),
+ /*insert_blocks_upon_row_kvpair_miss=*/false);
+ } else {
+ // Not supported.
+ return Status::InvalidArgument("Unknown cache name " +
+ config.cache_name);
+ }
+ sim_caches_[config].push_back(sim_cache);
+ }
+ }
+ return Status::OK();
+}
+
+void BlockCacheTraceSimulator::Access(const BlockCacheTraceRecord& access) {
+ if (trace_start_time_ == 0) {
+ trace_start_time_ = access.access_timestamp;
+ }
+ // access.access_timestamp is in microseconds.
+ if (!warmup_complete_ &&
+ trace_start_time_ + warmup_seconds_ * kMicrosInSecond <=
+ access.access_timestamp) {
+ for (auto& config_caches : sim_caches_) {
+ for (auto& sim_cache : config_caches.second) {
+ sim_cache->reset_counter();
+ }
+ }
+ warmup_complete_ = true;
+ }
+ for (auto& config_caches : sim_caches_) {
+ for (auto& sim_cache : config_caches.second) {
+ sim_cache->Access(access);
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator.h b/src/rocksdb/utilities/simulator_cache/cache_simulator.h
new file mode 100644
index 000000000..6d4979013
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator.h
@@ -0,0 +1,231 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <unordered_map>
+
+#include "cache/lru_cache.h"
+#include "trace_replay/block_cache_tracer.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// A cache configuration provided by user.
+struct CacheConfiguration {
+ std::string cache_name; // LRU.
+ uint32_t num_shard_bits;
+ uint64_t ghost_cache_capacity; // ghost cache capacity in bytes.
+ std::vector<uint64_t>
+ cache_capacities; // simulate cache capacities in bytes.
+
+ bool operator==(const CacheConfiguration& o) const {
+ return cache_name == o.cache_name && num_shard_bits == o.num_shard_bits &&
+ ghost_cache_capacity == o.ghost_cache_capacity;
+ }
+ bool operator<(const CacheConfiguration& o) const {
+ return cache_name < o.cache_name ||
+ (cache_name == o.cache_name && num_shard_bits < o.num_shard_bits) ||
+ (cache_name == o.cache_name && num_shard_bits == o.num_shard_bits &&
+ ghost_cache_capacity < o.ghost_cache_capacity);
+ }
+};
+
+class MissRatioStats {
+ public:
+ void reset_counter() {
+ num_misses_ = 0;
+ num_accesses_ = 0;
+ user_accesses_ = 0;
+ user_misses_ = 0;
+ }
+ double miss_ratio() const {
+ if (num_accesses_ == 0) {
+ return -1;
+ }
+ return static_cast<double>(num_misses_ * 100.0 / num_accesses_);
+ }
+ uint64_t total_accesses() const { return num_accesses_; }
+ uint64_t total_misses() const { return num_misses_; }
+
+ const std::map<uint64_t, uint64_t>& num_accesses_timeline() const {
+ return num_accesses_timeline_;
+ }
+
+ const std::map<uint64_t, uint64_t>& num_misses_timeline() const {
+ return num_misses_timeline_;
+ }
+
+ double user_miss_ratio() const {
+ if (user_accesses_ == 0) {
+ return -1;
+ }
+ return static_cast<double>(user_misses_ * 100.0 / user_accesses_);
+ }
+ uint64_t user_accesses() const { return user_accesses_; }
+ uint64_t user_misses() const { return user_misses_; }
+
+ void UpdateMetrics(uint64_t timestamp_in_ms, bool is_user_access,
+ bool is_cache_miss);
+
+ private:
+ uint64_t num_accesses_ = 0;
+ uint64_t num_misses_ = 0;
+ uint64_t user_accesses_ = 0;
+ uint64_t user_misses_ = 0;
+
+ std::map<uint64_t, uint64_t> num_accesses_timeline_;
+ std::map<uint64_t, uint64_t> num_misses_timeline_;
+};
+
+// A ghost cache admits an entry on its second access.
+class GhostCache {
+ public:
+ explicit GhostCache(std::shared_ptr<Cache> sim_cache);
+ ~GhostCache() = default;
+ // No copy and move.
+ GhostCache(const GhostCache&) = delete;
+ GhostCache& operator=(const GhostCache&) = delete;
+ GhostCache(GhostCache&&) = delete;
+ GhostCache& operator=(GhostCache&&) = delete;
+
+ // Returns true if the lookup_key is in the ghost cache.
+ // Returns false otherwise.
+ bool Admit(const Slice& lookup_key);
+
+ private:
+ std::shared_ptr<Cache> sim_cache_;
+};
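+
+// Example (illustrative): the first access to a key is not admitted, the
+// second one is.
+//
+//   GhostCache ghost(NewLRUCache(/*capacity=*/1024 * 1024));
+//   bool first = ghost.Admit("block-key");   // false, key is now remembered
+//   bool second = ghost.Admit("block-key");  // true, key was seen before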
+
+// A cache simulator that runs against a block cache trace.
+class CacheSimulator {
+ public:
+ CacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache);
+ virtual ~CacheSimulator() = default;
+ // No copy and move.
+ CacheSimulator(const CacheSimulator&) = delete;
+ CacheSimulator& operator=(const CacheSimulator&) = delete;
+ CacheSimulator(CacheSimulator&&) = delete;
+ CacheSimulator& operator=(CacheSimulator&&) = delete;
+
+ virtual void Access(const BlockCacheTraceRecord& access);
+
+ void reset_counter() { miss_ratio_stats_.reset_counter(); }
+
+ const MissRatioStats& miss_ratio_stats() const { return miss_ratio_stats_; }
+
+ protected:
+ MissRatioStats miss_ratio_stats_;
+ std::unique_ptr<GhostCache> ghost_cache_;
+ std::shared_ptr<Cache> sim_cache_;
+};
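+
+// Example (illustrative sketch; "record" is assumed to be a populated
+// BlockCacheTraceRecord read from a trace file):
+//
+//   CacheSimulator sim(/*ghost_cache=*/nullptr,
+//                      NewLRUCache(/*capacity=*/1024 * 1024 * 1024));
+//   sim.Access(record);
+//   double miss_pct = sim.miss_ratio_stats().miss_ratio();  // -1 if no accesses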
+
+// A prioritized cache simulator that runs against a block cache trace.
+// It inserts missing index/filter/uncompression-dictionary blocks with high
+// priority in the cache.
+class PrioritizedCacheSimulator : public CacheSimulator {
+ public:
+ PrioritizedCacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache)
+ : CacheSimulator(std::move(ghost_cache), sim_cache) {}
+ void Access(const BlockCacheTraceRecord& access) override;
+
+ protected:
+  // Access the key-value pair and set *is_cache_miss to true upon a cache miss.
+ void AccessKVPair(const Slice& key, uint64_t value_size,
+ Cache::Priority priority,
+ const BlockCacheTraceRecord& access, bool no_insert,
+ bool is_user_access, bool* is_cache_miss, bool* admitted,
+ bool update_metrics);
+
+ Cache::Priority ComputeBlockPriority(
+ const BlockCacheTraceRecord& access) const;
+};
+
+// A hybrid row and block cache simulator. It looks up/inserts key-value pairs
+// referenced by Get/MultiGet requests, and not their accessed index/filter/data
+// blocks.
+//
+// Upon a Get/MultiGet request, it looks up the referenced key first.
+// If it observes a cache hit, future block accesses on this key-value pair are
+// skipped since the request is already served. Otherwise, it continues to look
+// up/insert its index/filter/data blocks. It also inserts the referenced
+// key-value pair in the cache for future lookups.
+class HybridRowBlockCacheSimulator : public PrioritizedCacheSimulator {
+ public:
+ HybridRowBlockCacheSimulator(std::unique_ptr<GhostCache>&& ghost_cache,
+ std::shared_ptr<Cache> sim_cache,
+ bool insert_blocks_upon_row_kvpair_miss)
+ : PrioritizedCacheSimulator(std::move(ghost_cache), sim_cache),
+ insert_blocks_upon_row_kvpair_miss_(
+ insert_blocks_upon_row_kvpair_miss) {}
+ void Access(const BlockCacheTraceRecord& access) override;
+
+ private:
+ enum InsertResult : char {
+ INSERTED,
+ ADMITTED,
+ NO_INSERT,
+ };
+
+ // We set is_complete to true when the referenced row-key of a get request
+ // hits the cache. If is_complete is true, we treat future accesses of this
+ // get request as hits.
+ //
+ // For each row key, it stores an enum. It is INSERTED when the
+ // kv-pair has been inserted into the cache, ADMITTED if it should be inserted
+  // but has not been, and NO_INSERT if it should not be inserted.
+ //
+ // A kv-pair is in ADMITTED state when we encounter this kv-pair but do not
+ // know its size. This may happen if the first access on the referenced key is
+ // an index/filter block.
+ struct GetRequestStatus {
+ bool is_complete = false;
+ std::map<std::string, InsertResult> row_key_status;
+ };
+
+  // A map from get_id to the status of its row keys.
+ std::map<uint64_t, GetRequestStatus> getid_status_map_;
+ bool insert_blocks_upon_row_kvpair_miss_;
+};
+
+// A block cache simulator that reports miss ratio curves given a set of cache
+// configurations.
+class BlockCacheTraceSimulator {
+ public:
+ // warmup_seconds: The number of seconds to warmup simulated caches. The
+ // hit/miss counters are reset after the warmup completes.
+ BlockCacheTraceSimulator(
+ uint64_t warmup_seconds, uint32_t downsample_ratio,
+ const std::vector<CacheConfiguration>& cache_configurations);
+ ~BlockCacheTraceSimulator() = default;
+ // No copy and move.
+ BlockCacheTraceSimulator(const BlockCacheTraceSimulator&) = delete;
+ BlockCacheTraceSimulator& operator=(const BlockCacheTraceSimulator&) = delete;
+ BlockCacheTraceSimulator(BlockCacheTraceSimulator&&) = delete;
+ BlockCacheTraceSimulator& operator=(BlockCacheTraceSimulator&&) = delete;
+
+ Status InitializeCaches();
+
+ void Access(const BlockCacheTraceRecord& access);
+
+ const std::map<CacheConfiguration,
+ std::vector<std::shared_ptr<CacheSimulator>>>&
+ sim_caches() const {
+ return sim_caches_;
+ }
+
+ private:
+ const uint64_t warmup_seconds_;
+ const uint32_t downsample_ratio_;
+ const std::vector<CacheConfiguration> cache_configurations_;
+
+ bool warmup_complete_ = false;
+ std::map<CacheConfiguration, std::vector<std::shared_ptr<CacheSimulator>>>
+ sim_caches_;
+ uint64_t trace_start_time_ = 0;
+};
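+
+// Example (illustrative sketch), assuming `configs` holds the
+// CacheConfiguration values to evaluate and `trace_records` holds decoded
+// BlockCacheTraceRecord entries:
+//
+//   BlockCacheTraceSimulator simulator(/*warmup_seconds=*/3600,
+//                                      /*downsample_ratio=*/1, configs);
+//   Status s = simulator.InitializeCaches();
+//   for (const BlockCacheTraceRecord& record : trace_records) {
+//     simulator.Access(record);
+//   }
+//   // simulator.sim_caches() maps each configuration to its simulators.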
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc b/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc
new file mode 100644
index 000000000..a205315cc
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/cache_simulator_test.cc
@@ -0,0 +1,494 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/simulator_cache/cache_simulator.h"
+
+#include <cstdlib>
+#include "rocksdb/env.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+const std::string kBlockKeyPrefix = "test-block-";
+const std::string kRefKeyPrefix = "test-get-";
+const std::string kRefKeySequenceNumber = std::string(8, 'c');
+const uint64_t kGetId = 1;
+const uint64_t kGetBlockId = 100;
+const uint64_t kCompactionBlockId = 1000;
+const uint64_t kCacheSize = 1024 * 1024 * 1024;
+const uint64_t kGhostCacheSize = 1024 * 1024;
+} // namespace
+
+class CacheSimulatorTest : public testing::Test {
+ public:
+ const size_t kNumBlocks = 5;
+ const size_t kValueSize = 1000;
+
+ CacheSimulatorTest() { env_ = ROCKSDB_NAMESPACE::Env::Default(); }
+
+ BlockCacheTraceRecord GenerateGetRecord(uint64_t getid) {
+ BlockCacheTraceRecord record;
+ record.block_type = TraceType::kBlockTraceDataBlock;
+ record.block_size = 4096;
+ record.block_key = kBlockKeyPrefix + std::to_string(kGetBlockId);
+ record.access_timestamp = env_->NowMicros();
+ record.cf_id = 0;
+ record.cf_name = "test";
+ record.caller = TableReaderCaller::kUserGet;
+ record.level = 6;
+ record.sst_fd_number = 0;
+ record.get_id = getid;
+ record.is_cache_hit = Boolean::kFalse;
+ record.no_insert = Boolean::kFalse;
+ record.referenced_key =
+ kRefKeyPrefix + std::to_string(kGetId) + kRefKeySequenceNumber;
+ record.referenced_key_exist_in_block = Boolean::kTrue;
+ record.referenced_data_size = 100;
+ record.num_keys_in_block = 300;
+ return record;
+ }
+
+ BlockCacheTraceRecord GenerateCompactionRecord() {
+ BlockCacheTraceRecord record;
+ record.block_type = TraceType::kBlockTraceDataBlock;
+ record.block_size = 4096;
+ record.block_key = kBlockKeyPrefix + std::to_string(kCompactionBlockId);
+ record.access_timestamp = env_->NowMicros();
+ record.cf_id = 0;
+ record.cf_name = "test";
+ record.caller = TableReaderCaller::kCompaction;
+ record.level = 6;
+ record.sst_fd_number = kCompactionBlockId;
+ record.is_cache_hit = Boolean::kFalse;
+ record.no_insert = Boolean::kTrue;
+ return record;
+ }
+
+ void AssertCache(std::shared_ptr<Cache> sim_cache,
+ const MissRatioStats& miss_ratio_stats,
+ uint64_t expected_usage, uint64_t expected_num_accesses,
+ uint64_t expected_num_misses,
+ std::vector<std::string> blocks,
+ std::vector<std::string> keys) {
+ EXPECT_EQ(expected_usage, sim_cache->GetUsage());
+ EXPECT_EQ(expected_num_accesses, miss_ratio_stats.total_accesses());
+ EXPECT_EQ(expected_num_misses, miss_ratio_stats.total_misses());
+ for (auto const& block : blocks) {
+ auto handle = sim_cache->Lookup(block);
+ EXPECT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+ for (auto const& key : keys) {
+ std::string row_key = kRefKeyPrefix + key + kRefKeySequenceNumber;
+ auto handle =
+ sim_cache->Lookup("0_" + ExtractUserKey(row_key).ToString());
+ EXPECT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+ }
+
+ Env* env_;
+};
+
+TEST_F(CacheSimulatorTest, GhostCache) {
+ const std::string key1 = "test1";
+ const std::string key2 = "test2";
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
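+  // A key is admitted only after it has been seen at least once before; the
+  // first access to each key is rejected and merely recorded in the ghost
+  // cache.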
+ EXPECT_FALSE(ghost_cache->Admit(key1));
+ EXPECT_TRUE(ghost_cache->Admit(key1));
+ EXPECT_TRUE(ghost_cache->Admit(key1));
+ EXPECT_FALSE(ghost_cache->Admit(key2));
+ EXPECT_TRUE(ghost_cache->Admit(key2));
+}
+
+TEST_F(CacheSimulatorTest, CacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ const BlockCacheTraceRecord& compaction_access = GenerateCompactionRecord();
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<CacheSimulator> cache_simulator(
+ new CacheSimulator(nullptr, sim_cache));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().user_miss_ratio());
+
+ cache_simulator->Access(compaction_access);
+ cache_simulator->Access(compaction_access);
+ ASSERT_EQ(4, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(75, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().user_miss_ratio());
+
+ cache_simulator->reset_counter();
+ ASSERT_EQ(0, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(-1, cache_simulator->miss_ratio_stats().miss_ratio());
+ auto handle = sim_cache->Lookup(access.block_key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ handle = sim_cache->Lookup(compaction_access.block_key);
+ ASSERT_EQ(nullptr, handle);
+}
+
+TEST_F(CacheSimulatorTest, GhostCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ std::unique_ptr<CacheSimulator> cache_simulator(new CacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+  // Both accesses will be misses since we have a ghost cache.
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+}
+
+TEST_F(CacheSimulatorTest, PrioritizedCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<PrioritizedCacheSimulator> cache_simulator(
+ new PrioritizedCacheSimulator(nullptr, sim_cache));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(50, cache_simulator->miss_ratio_stats().miss_ratio());
+
+ auto handle = sim_cache->Lookup(access.block_key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+}
+
+TEST_F(CacheSimulatorTest, GhostPrioritizedCacheSimulator) {
+ const BlockCacheTraceRecord& access = GenerateGetRecord(kGetId);
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ std::unique_ptr<PrioritizedCacheSimulator> cache_simulator(
+ new PrioritizedCacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ cache_simulator->Access(access);
+ cache_simulator->Access(access);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+  // Both accesses will be misses since we have a ghost cache.
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockCacheSimulator) {
+ uint64_t block_id = 100;
+ BlockCacheTraceRecord first_get = GenerateGetRecord(kGetId);
+ first_get.get_from_user_specified_snapshot = Boolean::kTrue;
+ BlockCacheTraceRecord second_get = GenerateGetRecord(kGetId + 1);
+ second_get.referenced_data_size = 0;
+ second_get.referenced_key_exist_in_block = Boolean::kFalse;
+ second_get.get_from_user_specified_snapshot = Boolean::kTrue;
+ BlockCacheTraceRecord third_get = GenerateGetRecord(kGetId + 2);
+ third_get.referenced_data_size = 0;
+ third_get.referenced_key_exist_in_block = Boolean::kFalse;
+ third_get.referenced_key = kRefKeyPrefix + "third_get";
+ // We didn't find the referenced key in the third get.
+ third_get.referenced_key_exist_in_block = Boolean::kFalse;
+ third_get.referenced_data_size = 0;
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+        nullptr, sim_cache, /*insert_blocks_upon_row_kvpair_miss=*/true));
+ // The first get request accesses 10 blocks. We should only report 10 accesses
+  // and a 100% miss ratio.
+ for (uint32_t i = 0; i < 10; i++) {
+ first_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(first_get);
+ block_id++;
+ }
+
+ ASSERT_EQ(10, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(10, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().user_miss_ratio());
+ auto handle =
+ sim_cache->Lookup(std::to_string(first_get.sst_fd_number) + "_" +
+ ExtractUserKey(first_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+
+ // The second get request accesses the same key. We should report 15
+  // accesses and a 66% miss ratio (10 misses out of 15 accesses).
+ // We do not consider these 5 block lookups as misses since the row hits the
+ // cache.
+ for (uint32_t i = 0; i < 5; i++) {
+ second_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(second_get);
+ block_id++;
+ }
+ ASSERT_EQ(15, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(66, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(15, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(66, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+ handle =
+ sim_cache->Lookup(std::to_string(second_get.sst_fd_number) + "_" +
+ ExtractUserKey(second_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ if (i < 110) {
+ ASSERT_NE(nullptr, handle) << i;
+ sim_cache->Release(handle);
+ } else {
+ ASSERT_EQ(nullptr, handle) << i;
+ }
+ }
+
+  // The third get is on a different key and does not have a size.
+ // This key should not be inserted into the cache.
+ for (uint32_t i = 0; i < 5; i++) {
+ third_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(third_get);
+ block_id++;
+ }
+ ASSERT_EQ(20, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(75, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(20, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(75, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+ // Assert that the third key is not inserted into the cache.
+ handle = sim_cache->Lookup(std::to_string(third_get.sst_fd_number) + "_" +
+ third_get.referenced_key);
+ ASSERT_EQ(nullptr, handle);
+ for (uint32_t i = 100; i < block_id; i++) {
+ if (i < 110 || i >= 115) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_NE(nullptr, handle) << i;
+ sim_cache->Release(handle);
+ } else {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_EQ(nullptr, handle) << i;
+ }
+ }
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockCacheSimulatorGetTest) {
+ BlockCacheTraceRecord get = GenerateGetRecord(kGetId);
+ get.block_size = 1;
+ get.referenced_data_size = 0;
+ get.access_timestamp = 0;
+ get.block_key = "1";
+ get.get_id = 1;
+ get.get_from_user_specified_snapshot = Boolean::kFalse;
+ get.referenced_key =
+ kRefKeyPrefix + std::to_string(1) + kRefKeySequenceNumber;
+ get.no_insert = Boolean::kFalse;
+ get.sst_fd_number = 0;
+ get.get_from_user_specified_snapshot = Boolean::kFalse;
+
+ LRUCacheOptions co;
+ co.capacity = 16;
+ co.num_shard_bits = 1;
+ co.strict_capacity_limit = false;
+ co.high_pri_pool_ratio = 0;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<Cache> sim_cache = NewLRUCache(co);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+        nullptr, sim_cache, /*insert_blocks_upon_row_kvpair_miss=*/true));
+  // Expect a miss; the row key-value pair is not inserted since it does not
+  // have a size.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 1, 1, 1, {"1"},
+ {});
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.block_key = "2";
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 3, 2, 2,
+ {"1", "2"}, {"1"});
+ get.access_timestamp += 1;
+ get.block_key = "3";
+  // K1 should not be inserted again.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 4, 3, 3,
+ {"1", "2", "3"}, {"1"});
+
+ // A second get request referencing the same key.
+ get.access_timestamp += 1;
+ get.get_id = 2;
+ get.block_key = "4";
+ get.referenced_data_size = 0;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 4, 4, 3,
+ {"1", "2", "3"}, {"1"});
+
+ // A third get request searches three files, three different keys.
+ // And the second key observes a hit.
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "3";
+ get.referenced_key = kRefKeyPrefix + "2" + kRefKeySequenceNumber;
+ // K2 should observe a miss. Block 3 observes a hit.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 5, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "4";
+ get.referenced_data_size = 1;
+ get.referenced_key = kRefKeyPrefix + "1" + kRefKeySequenceNumber;
+ // K1 should observe a hit.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 6, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ get.access_timestamp += 1;
+ get.referenced_data_size = 1;
+ get.get_id = 3;
+ get.block_key = "4";
+ get.referenced_data_size = 1;
+ get.referenced_key = kRefKeyPrefix + "3" + kRefKeySequenceNumber;
+ // K3 should observe a miss.
+  // However, as the get is already complete, we should not access k3 anymore.
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 5, 7, 3,
+ {"1", "2", "3"}, {"1", "2"});
+
+ // A fourth get request searches one file and two blocks. One row key.
+ get.access_timestamp += 1;
+ get.get_id = 4;
+ get.block_key = "5";
+ get.referenced_key = kRefKeyPrefix + "4" + kRefKeySequenceNumber;
+ get.referenced_data_size = 1;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 7, 8, 4,
+ {"1", "2", "3", "5"}, {"1", "2", "4"});
+ for (auto const& key : {"1", "2", "4"}) {
+ auto handle = sim_cache->Lookup("0_" + kRefKeyPrefix + key);
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+ }
+
+ // A bunch of insertions which evict cached row keys.
+ for (uint32_t i = 6; i < 100; i++) {
+ get.access_timestamp += 1;
+ get.get_id = 0;
+ get.block_key = std::to_string(i);
+ cache_simulator->Access(get);
+ }
+
+ get.get_id = 4;
+ // A different block.
+ get.block_key = "100";
+  // Same row key; it should not be inserted again.
+ get.referenced_key = kRefKeyPrefix + "4" + kRefKeySequenceNumber;
+ get.referenced_data_size = 1;
+ cache_simulator->Access(get);
+ AssertCache(sim_cache, cache_simulator->miss_ratio_stats(), 16, 103, 99, {},
+ {});
+ for (auto const& key : {"1", "2", "4"}) {
+ auto handle = sim_cache->Lookup("0_" + kRefKeyPrefix + key);
+ ASSERT_EQ(nullptr, handle);
+ }
+}
+
+TEST_F(CacheSimulatorTest, HybridRowBlockNoInsertCacheSimulator) {
+ uint64_t block_id = 100;
+ BlockCacheTraceRecord first_get = GenerateGetRecord(kGetId);
+ std::shared_ptr<Cache> sim_cache =
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+        nullptr, sim_cache, /*insert_blocks_upon_row_kvpair_miss=*/false));
+ for (uint32_t i = 0; i < 9; i++) {
+ first_get.block_key = kBlockKeyPrefix + std::to_string(block_id);
+ cache_simulator->Access(first_get);
+ block_id++;
+ }
+ auto handle =
+ sim_cache->Lookup(std::to_string(first_get.sst_fd_number) + "_" +
+ ExtractUserKey(first_get.referenced_key).ToString());
+ ASSERT_NE(nullptr, handle);
+ sim_cache->Release(handle);
+  // All blocks are missing from the cache since
+  // insert_blocks_upon_row_kvpair_miss is set to false.
+ for (uint32_t i = 100; i < block_id; i++) {
+ handle = sim_cache->Lookup(kBlockKeyPrefix + std::to_string(i));
+ ASSERT_EQ(nullptr, handle);
+ }
+}
+
+TEST_F(CacheSimulatorTest, GhostHybridRowBlockCacheSimulator) {
+ std::unique_ptr<GhostCache> ghost_cache(new GhostCache(
+ NewLRUCache(/*capacity=*/kGhostCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0)));
+ const BlockCacheTraceRecord& first_get = GenerateGetRecord(kGetId);
+ const BlockCacheTraceRecord& second_get = GenerateGetRecord(kGetId + 1);
+ const BlockCacheTraceRecord& third_get = GenerateGetRecord(kGetId + 2);
+ std::unique_ptr<HybridRowBlockCacheSimulator> cache_simulator(
+ new HybridRowBlockCacheSimulator(
+ std::move(ghost_cache),
+ NewLRUCache(/*capacity=*/kCacheSize, /*num_shard_bits=*/1,
+ /*strict_capacity_limit=*/false,
+ /*high_pri_pool_ratio=*/0),
+        /*insert_blocks_upon_row_kvpair_miss=*/false));
+ // Two get requests access the same key.
+ cache_simulator->Access(first_get);
+ cache_simulator->Access(second_get);
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().miss_ratio());
+ ASSERT_EQ(2, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(100, cache_simulator->miss_ratio_stats().user_miss_ratio());
+ // We insert the key-value pair upon the second get request. A third get
+ // request should observe a hit.
+ for (uint32_t i = 0; i < 10; i++) {
+ cache_simulator->Access(third_get);
+ }
+ ASSERT_EQ(12, cache_simulator->miss_ratio_stats().total_accesses());
+ ASSERT_EQ(16, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().miss_ratio()));
+ ASSERT_EQ(12, cache_simulator->miss_ratio_stats().user_accesses());
+ ASSERT_EQ(16, static_cast<uint64_t>(
+ cache_simulator->miss_ratio_stats().user_miss_ratio()));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/simulator_cache/sim_cache.cc b/src/rocksdb/utilities/simulator_cache/sim_cache.cc
new file mode 100644
index 000000000..ec411cf9a
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/sim_cache.cc
@@ -0,0 +1,354 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/sim_cache.h"
+#include <atomic>
+#include "env/composite_env_wrapper.h"
+#include "file/writable_file_writer.h"
+#include "monitoring/statistics.h"
+#include "port/port.h"
+#include "rocksdb/env.h"
+#include "util/mutexlock.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+class CacheActivityLogger {
+ public:
+ CacheActivityLogger()
+ : activity_logging_enabled_(false), max_logging_size_(0) {}
+
+ ~CacheActivityLogger() {
+ MutexLock l(&mutex_);
+
+ StopLoggingInternal();
+ }
+
+ Status StartLogging(const std::string& activity_log_file, Env* env,
+ uint64_t max_logging_size = 0) {
+ assert(activity_log_file != "");
+ assert(env != nullptr);
+
+ Status status;
+ EnvOptions env_opts;
+ std::unique_ptr<WritableFile> log_file;
+
+ MutexLock l(&mutex_);
+
+ // Stop existing logging if any
+ StopLoggingInternal();
+
+ // Open log file
+ status = env->NewWritableFile(activity_log_file, &log_file, env_opts);
+ if (!status.ok()) {
+ return status;
+ }
+ file_writer_.reset(new WritableFileWriter(
+ NewLegacyWritableFileWrapper(std::move(log_file)), activity_log_file,
+ env_opts));
+
+ max_logging_size_ = max_logging_size;
+ activity_logging_enabled_.store(true);
+
+ return status;
+ }
+
+ void StopLogging() {
+ MutexLock l(&mutex_);
+
+ StopLoggingInternal();
+ }
+
+ void ReportLookup(const Slice& key) {
+ if (activity_logging_enabled_.load() == false) {
+ return;
+ }
+
+ std::string log_line = "LOOKUP - " + key.ToString(true) + "\n";
+
+ // line format: "LOOKUP - <KEY>"
+ MutexLock l(&mutex_);
+ Status s = file_writer_->Append(log_line);
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+ if (MaxLoggingSizeReached() || !bg_status_.ok()) {
+ // Stop logging if we have reached the max file size or
+ // encountered an error
+ StopLoggingInternal();
+ }
+ }
+
+ void ReportAdd(const Slice& key, size_t size) {
+ if (activity_logging_enabled_.load() == false) {
+ return;
+ }
+
+ std::string log_line = "ADD - ";
+ log_line += key.ToString(true);
+ log_line += " - ";
+ AppendNumberTo(&log_line, size);
+ // @lint-ignore TXT2 T25377293 Grandfathered in
+ log_line += "\n";
+
+ // line format: "ADD - <KEY> - <KEY-SIZE>"
+ MutexLock l(&mutex_);
+ Status s = file_writer_->Append(log_line);
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+
+ if (MaxLoggingSizeReached() || !bg_status_.ok()) {
+ // Stop logging if we have reached the max file size or
+ // encountered an error
+ StopLoggingInternal();
+ }
+ }
+
+ Status& bg_status() {
+ MutexLock l(&mutex_);
+ return bg_status_;
+ }
+
+ private:
+ bool MaxLoggingSizeReached() {
+ mutex_.AssertHeld();
+
+ return (max_logging_size_ > 0 &&
+ file_writer_->GetFileSize() >= max_logging_size_);
+ }
+
+ void StopLoggingInternal() {
+ mutex_.AssertHeld();
+
+ if (!activity_logging_enabled_) {
+ return;
+ }
+
+ activity_logging_enabled_.store(false);
+ Status s = file_writer_->Close();
+ if (!s.ok() && bg_status_.ok()) {
+ bg_status_ = s;
+ }
+ }
+
+  // Mutex to synchronize writes to file_writer_ and all the following
+  // class data members.
+ port::Mutex mutex_;
+ // Indicates if logging is currently enabled
+ // atomic to allow reads without mutex
+ std::atomic<bool> activity_logging_enabled_;
+ // When reached, we will stop logging and close the file
+ // Value of 0 means unlimited
+ uint64_t max_logging_size_;
+ std::unique_ptr<WritableFileWriter> file_writer_;
+ Status bg_status_;
+};
+
+// SimCacheImpl definition
+class SimCacheImpl : public SimCache {
+ public:
+  // cache is the real block cache (e.g. a sharded LRU cache); sim_cache is the
+  // key-only cache used to simulate a (possibly larger) cache.
+ SimCacheImpl(std::shared_ptr<Cache> sim_cache, std::shared_ptr<Cache> cache)
+ : cache_(cache),
+ key_only_cache_(sim_cache),
+ miss_times_(0),
+ hit_times_(0),
+ stats_(nullptr) {}
+
+ ~SimCacheImpl() override {}
+ void SetCapacity(size_t capacity) override { cache_->SetCapacity(capacity); }
+
+ void SetStrictCapacityLimit(bool strict_capacity_limit) override {
+ cache_->SetStrictCapacityLimit(strict_capacity_limit);
+ }
+
+ Status Insert(const Slice& key, void* value, size_t charge,
+ void (*deleter)(const Slice& key, void* value), Handle** handle,
+ Priority priority) override {
+    // The handle and value passed in are for the real cache, so we pass
+    // nullptr to key_only_cache_ for both instead. Also, the deleter function
+    // pointer will be called by the user to perform some external operation
+    // which should be applied only once. Thus key_only_cache_ accepts an empty
+    // deleter (a capture-less lambda can be assigned to a function pointer).
+ Handle* h = key_only_cache_->Lookup(key);
+ if (h == nullptr) {
+ key_only_cache_->Insert(key, nullptr, charge,
+ [](const Slice& /*k*/, void* /*v*/) {}, nullptr,
+ priority);
+ } else {
+ key_only_cache_->Release(h);
+ }
+
+ cache_activity_logger_.ReportAdd(key, charge);
+ if (!cache_) {
+ return Status::OK();
+ }
+ return cache_->Insert(key, value, charge, deleter, handle, priority);
+ }
+
+ Handle* Lookup(const Slice& key, Statistics* stats) override {
+ Handle* h = key_only_cache_->Lookup(key);
+ if (h != nullptr) {
+ key_only_cache_->Release(h);
+ inc_hit_counter();
+ RecordTick(stats, SIM_BLOCK_CACHE_HIT);
+ } else {
+ inc_miss_counter();
+ RecordTick(stats, SIM_BLOCK_CACHE_MISS);
+ }
+
+ cache_activity_logger_.ReportLookup(key);
+ if (!cache_) {
+ return nullptr;
+ }
+ return cache_->Lookup(key, stats);
+ }
+
+ bool Ref(Handle* handle) override { return cache_->Ref(handle); }
+
+ bool Release(Handle* handle, bool force_erase = false) override {
+ return cache_->Release(handle, force_erase);
+ }
+
+ void Erase(const Slice& key) override {
+ cache_->Erase(key);
+ key_only_cache_->Erase(key);
+ }
+
+ void* Value(Handle* handle) override { return cache_->Value(handle); }
+
+ uint64_t NewId() override { return cache_->NewId(); }
+
+ size_t GetCapacity() const override { return cache_->GetCapacity(); }
+
+ bool HasStrictCapacityLimit() const override {
+ return cache_->HasStrictCapacityLimit();
+ }
+
+ size_t GetUsage() const override { return cache_->GetUsage(); }
+
+ size_t GetUsage(Handle* handle) const override {
+ return cache_->GetUsage(handle);
+ }
+
+ size_t GetCharge(Handle* handle) const override {
+ return cache_->GetCharge(handle);
+ }
+
+ size_t GetPinnedUsage() const override { return cache_->GetPinnedUsage(); }
+
+ void DisownData() override {
+ cache_->DisownData();
+ key_only_cache_->DisownData();
+ }
+
+ void ApplyToAllCacheEntries(void (*callback)(void*, size_t),
+ bool thread_safe) override {
+    // Only apply to cache_ since key_only_cache_ doesn't hold values.
+ cache_->ApplyToAllCacheEntries(callback, thread_safe);
+ }
+
+ void EraseUnRefEntries() override {
+ cache_->EraseUnRefEntries();
+ key_only_cache_->EraseUnRefEntries();
+ }
+
+ size_t GetSimCapacity() const override {
+ return key_only_cache_->GetCapacity();
+ }
+ size_t GetSimUsage() const override { return key_only_cache_->GetUsage(); }
+ void SetSimCapacity(size_t capacity) override {
+ key_only_cache_->SetCapacity(capacity);
+ }
+
+ uint64_t get_miss_counter() const override {
+ return miss_times_.load(std::memory_order_relaxed);
+ }
+
+ uint64_t get_hit_counter() const override {
+ return hit_times_.load(std::memory_order_relaxed);
+ }
+
+ void reset_counter() override {
+ miss_times_.store(0, std::memory_order_relaxed);
+ hit_times_.store(0, std::memory_order_relaxed);
+ SetTickerCount(stats_, SIM_BLOCK_CACHE_HIT, 0);
+ SetTickerCount(stats_, SIM_BLOCK_CACHE_MISS, 0);
+ }
+
+ std::string ToString() const override {
+ std::string res;
+ res.append("SimCache MISSes: " + std::to_string(get_miss_counter()) + "\n");
+ res.append("SimCache HITs: " + std::to_string(get_hit_counter()) + "\n");
+ char buff[350];
+ auto lookups = get_miss_counter() + get_hit_counter();
+ snprintf(buff, sizeof(buff), "SimCache HITRATE: %.2f%%\n",
+ (lookups == 0 ? 0 : get_hit_counter() * 100.0f / lookups));
+ res.append(buff);
+ return res;
+ }
+
+ std::string GetPrintableOptions() const override {
+ std::string ret;
+ ret.reserve(20000);
+ ret.append(" cache_options:\n");
+ ret.append(cache_->GetPrintableOptions());
+ ret.append(" sim_cache_options:\n");
+ ret.append(key_only_cache_->GetPrintableOptions());
+ return ret;
+ }
+
+ Status StartActivityLogging(const std::string& activity_log_file, Env* env,
+ uint64_t max_logging_size = 0) override {
+ return cache_activity_logger_.StartLogging(activity_log_file, env,
+ max_logging_size);
+ }
+
+ void StopActivityLogging() override { cache_activity_logger_.StopLogging(); }
+
+ Status GetActivityLoggingStatus() override {
+ return cache_activity_logger_.bg_status();
+ }
+
+ private:
+ std::shared_ptr<Cache> cache_;
+ std::shared_ptr<Cache> key_only_cache_;
+ std::atomic<uint64_t> miss_times_;
+ std::atomic<uint64_t> hit_times_;
+ Statistics* stats_;
+ CacheActivityLogger cache_activity_logger_;
+
+ void inc_miss_counter() {
+ miss_times_.fetch_add(1, std::memory_order_relaxed);
+ }
+ void inc_hit_counter() { hit_times_.fetch_add(1, std::memory_order_relaxed); }
+};
+
+} // end anonymous namespace
+
+// For instrumentation purposes, use NewSimCache instead.
+std::shared_ptr<SimCache> NewSimCache(std::shared_ptr<Cache> cache,
+ size_t sim_capacity, int num_shard_bits) {
+ LRUCacheOptions co;
+ co.capacity = sim_capacity;
+ co.num_shard_bits = num_shard_bits;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ return NewSimCache(NewLRUCache(co), cache, num_shard_bits);
+}
+
+std::shared_ptr<SimCache> NewSimCache(std::shared_ptr<Cache> sim_cache,
+ std::shared_ptr<Cache> cache,
+ int num_shard_bits) {
+ if (num_shard_bits >= 20) {
+    return nullptr; // the cache cannot be sharded into too many fine-grained pieces
+ }
+ return std::make_shared<SimCacheImpl>(sim_cache, cache);
+}
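+
+// Example (illustrative sketch): wrap a real 1GB block cache with a SimCache
+// that simulates a hypothetical 10GB cache.
+//
+//   std::shared_ptr<SimCache> sim =
+//       NewSimCache(NewLRUCache(/*capacity=*/1ULL << 30),
+//                   /*sim_capacity=*/10ULL << 30, /*num_shard_bits=*/6);
+//   BlockBasedTableOptions table_options;
+//   table_options.block_cache = sim;
+//   // After running a workload, sim->ToString() reports the simulated hit
+//   // rate and sim->GetSimUsage() the simulated usage.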
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc b/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc
new file mode 100644
index 000000000..6cb495813
--- /dev/null
+++ b/src/rocksdb/utilities/simulator_cache/sim_cache_test.cc
@@ -0,0 +1,225 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/sim_cache.h"
+#include <cstdlib>
+#include "db/db_test_util.h"
+#include "port/stack_trace.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class SimCacheTest : public DBTestBase {
+ private:
+ size_t miss_count_ = 0;
+ size_t hit_count_ = 0;
+ size_t insert_count_ = 0;
+ size_t failure_count_ = 0;
+
+ public:
+ const size_t kNumBlocks = 5;
+ const size_t kValueSize = 1000;
+
+ SimCacheTest() : DBTestBase("/sim_cache_test") {}
+
+ BlockBasedTableOptions GetTableOptions() {
+ BlockBasedTableOptions table_options;
+    // Set a small enough block size so that each key-value gets its own block.
+ table_options.block_size = 1;
+ return table_options;
+ }
+
+ Options GetOptions(const BlockBasedTableOptions& table_options) {
+ Options options = CurrentOptions();
+ options.create_if_missing = true;
+ // options.compression = kNoCompression;
+ options.statistics = ROCKSDB_NAMESPACE::CreateDBStatistics();
+ options.table_factory.reset(new BlockBasedTableFactory(table_options));
+ return options;
+ }
+
+ void InitTable(const Options& /*options*/) {
+ std::string value(kValueSize, 'a');
+ for (size_t i = 0; i < kNumBlocks * 2; i++) {
+ ASSERT_OK(Put(ToString(i), value.c_str()));
+ }
+ }
+
+ void RecordCacheCounters(const Options& options) {
+ miss_count_ = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ hit_count_ = TestGetTickerCount(options, BLOCK_CACHE_HIT);
+ insert_count_ = TestGetTickerCount(options, BLOCK_CACHE_ADD);
+ failure_count_ = TestGetTickerCount(options, BLOCK_CACHE_ADD_FAILURES);
+ }
+
+ void CheckCacheCounters(const Options& options, size_t expected_misses,
+ size_t expected_hits, size_t expected_inserts,
+ size_t expected_failures) {
+ size_t new_miss_count = TestGetTickerCount(options, BLOCK_CACHE_MISS);
+ size_t new_hit_count = TestGetTickerCount(options, BLOCK_CACHE_HIT);
+ size_t new_insert_count = TestGetTickerCount(options, BLOCK_CACHE_ADD);
+ size_t new_failure_count =
+ TestGetTickerCount(options, BLOCK_CACHE_ADD_FAILURES);
+ ASSERT_EQ(miss_count_ + expected_misses, new_miss_count);
+ ASSERT_EQ(hit_count_ + expected_hits, new_hit_count);
+ ASSERT_EQ(insert_count_ + expected_inserts, new_insert_count);
+ ASSERT_EQ(failure_count_ + expected_failures, new_failure_count);
+ miss_count_ = new_miss_count;
+ hit_count_ = new_hit_count;
+ insert_count_ = new_insert_count;
+ failure_count_ = new_failure_count;
+ }
+};
+
+TEST_F(SimCacheTest, SimCache) {
+ ReadOptions read_options;
+ auto table_options = GetTableOptions();
+ auto options = GetOptions(table_options);
+ InitTable(options);
+ LRUCacheOptions co;
+ co.capacity = 0;
+ co.num_shard_bits = 0;
+ co.strict_capacity_limit = false;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<SimCache> simCache = NewSimCache(NewLRUCache(co), 20000, 0);
+ table_options.block_cache = simCache;
+ options.table_factory.reset(new BlockBasedTableFactory(table_options));
+ Reopen(options);
+ RecordCacheCounters(options);
+
+ std::vector<std::unique_ptr<Iterator>> iterators(kNumBlocks);
+ Iterator* iter = nullptr;
+
+ // Load blocks into cache.
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ iter = db_->NewIterator(read_options);
+ iter->Seek(ToString(i));
+ ASSERT_OK(iter->status());
+ CheckCacheCounters(options, 1, 0, 1, 0);
+ iterators[i].reset(iter);
+ }
+ ASSERT_EQ(kNumBlocks,
+ simCache->get_hit_counter() + simCache->get_miss_counter());
+ ASSERT_EQ(0, simCache->get_hit_counter());
+ size_t usage = simCache->GetUsage();
+ ASSERT_LT(0, usage);
+ ASSERT_EQ(usage, simCache->GetSimUsage());
+ simCache->SetCapacity(usage);
+ ASSERT_EQ(usage, simCache->GetPinnedUsage());
+
+ // Test with strict capacity limit.
+ simCache->SetStrictCapacityLimit(true);
+ iter = db_->NewIterator(read_options);
+ iter->Seek(ToString(kNumBlocks * 2 - 1));
+ ASSERT_TRUE(iter->status().IsIncomplete());
+ CheckCacheCounters(options, 1, 0, 0, 1);
+ delete iter;
+ iter = nullptr;
+
+ // Release iterators and access cache again.
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ iterators[i].reset();
+ CheckCacheCounters(options, 0, 0, 0, 0);
+ }
+ // Add kNumBlocks again
+ for (size_t i = 0; i < kNumBlocks; i++) {
+ std::unique_ptr<Iterator> it(db_->NewIterator(read_options));
+ it->Seek(ToString(i));
+ ASSERT_OK(it->status());
+ CheckCacheCounters(options, 0, 1, 0, 0);
+ }
+ ASSERT_EQ(5, simCache->get_hit_counter());
+ for (size_t i = kNumBlocks; i < kNumBlocks * 2; i++) {
+ std::unique_ptr<Iterator> it(db_->NewIterator(read_options));
+ it->Seek(ToString(i));
+ ASSERT_OK(it->status());
+ CheckCacheCounters(options, 1, 0, 1, 0);
+ }
+ ASSERT_EQ(0, simCache->GetPinnedUsage());
+ ASSERT_EQ(3 * kNumBlocks + 1,
+ simCache->get_hit_counter() + simCache->get_miss_counter());
+ ASSERT_EQ(6, simCache->get_hit_counter());
+}
+
+TEST_F(SimCacheTest, SimCacheLogging) {
+ auto table_options = GetTableOptions();
+ auto options = GetOptions(table_options);
+ options.disable_auto_compactions = true;
+ LRUCacheOptions co;
+ co.capacity = 1024 * 1024;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<SimCache> sim_cache = NewSimCache(NewLRUCache(co), 20000, 0);
+ table_options.block_cache = sim_cache;
+ options.table_factory.reset(new BlockBasedTableFactory(table_options));
+ Reopen(options);
+
+ int num_block_entries = 20;
+ for (int i = 0; i < num_block_entries; i++) {
+ Put(Key(i), "val");
+ Flush();
+ }
+
+ std::string log_file = test::PerThreadDBPath(env_, "cache_log.txt");
+ ASSERT_OK(sim_cache->StartActivityLogging(log_file, env_));
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ sim_cache->StopActivityLogging();
+ ASSERT_OK(sim_cache->GetActivityLoggingStatus());
+
+ std::string file_contents = "";
+ ReadFileToString(env_, log_file, &file_contents);
+
+ int lookup_num = 0;
+ int add_num = 0;
+ std::string::size_type pos;
+
+ // count number of lookups
+ pos = 0;
+ while ((pos = file_contents.find("LOOKUP -", pos)) != std::string::npos) {
+ ++lookup_num;
+ pos += 1;
+ }
+
+ // count number of additions
+ pos = 0;
+ while ((pos = file_contents.find("ADD -", pos)) != std::string::npos) {
+ ++add_num;
+ pos += 1;
+ }
+
+ // We asked for every block twice
+ ASSERT_EQ(lookup_num, num_block_entries * 2);
+
+ // We added every block only once, since the cache can hold all blocks
+ ASSERT_EQ(add_num, num_block_entries);
+
+ // Log things again but stop logging automatically after reaching 512 bytes
+ // @lint-ignore TXT2 T25377293 Grandfathered in
+ int max_size = 512;
+ ASSERT_OK(sim_cache->StartActivityLogging(log_file, env_, max_size));
+ for (int it = 0; it < 10; it++) {
+ for (int i = 0; i < num_block_entries; i++) {
+ ASSERT_EQ(Get(Key(i)), "val");
+ }
+ }
+ ASSERT_OK(sim_cache->GetActivityLoggingStatus());
+
+ uint64_t fsize = 0;
+ ASSERT_OK(env_->GetFileSize(log_file, &fsize));
+ // error margin of 100 bytes
+ ASSERT_LT(fsize, max_size + 100);
+ ASSERT_GT(fsize, max_size - 100);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc
new file mode 100644
index 000000000..89d666d4d
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+#include "utilities/table_properties_collectors/compact_on_deletion_collector.h"
+
+#include <memory>
+#include "rocksdb/utilities/table_properties_collectors.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+CompactOnDeletionCollector::CompactOnDeletionCollector(
+ size_t sliding_window_size, size_t deletion_trigger)
+ : bucket_size_((sliding_window_size + kNumBuckets - 1) / kNumBuckets),
+ current_bucket_(0),
+ num_keys_in_current_bucket_(0),
+ num_deletions_in_observation_window_(0),
+ deletion_trigger_(deletion_trigger),
+ need_compaction_(false),
+ finished_(false) {
+ memset(num_deletions_in_buckets_, 0, sizeof(size_t) * kNumBuckets);
+}
+
+// AddUserKey() will be called when a new key/value pair is inserted into the
+// table.
+// @param key the user key that is inserted into the table.
+// @param value the value that is inserted into the table.
+// @param file_size file size up to now
+Status CompactOnDeletionCollector::AddUserKey(const Slice& /*key*/,
+ const Slice& /*value*/,
+ EntryType type,
+ SequenceNumber /*seq*/,
+ uint64_t /*file_size*/) {
+ assert(!finished_);
+ if (bucket_size_ == 0) {
+ // This collector is effectively disabled
+ return Status::OK();
+ }
+
+ if (need_compaction_) {
+ // If the output file already needs to be compacted, skip the check.
+ return Status::OK();
+ }
+
+ if (num_keys_in_current_bucket_ == bucket_size_) {
+ // When the current bucket is full, advance the cursor of the
+ // ring buffer to the next bucket.
+ current_bucket_ = (current_bucket_ + 1) % kNumBuckets;
+
+ // Update the current count of observed deletion keys by excluding
+ // the number of deletion keys in the oldest bucket in the
+ // observation window.
+ assert(num_deletions_in_observation_window_ >=
+ num_deletions_in_buckets_[current_bucket_]);
+ num_deletions_in_observation_window_ -=
+ num_deletions_in_buckets_[current_bucket_];
+ num_deletions_in_buckets_[current_bucket_] = 0;
+ num_keys_in_current_bucket_ = 0;
+ }
+
+ num_keys_in_current_bucket_++;
+ if (type == kEntryDelete) {
+ num_deletions_in_observation_window_++;
+ num_deletions_in_buckets_[current_bucket_]++;
+ if (num_deletions_in_observation_window_ >= deletion_trigger_) {
+ need_compaction_ = true;
+ }
+ }
+ return Status::OK();
+}
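+
+// Worked example (illustrative): with sliding_window_size = 12800 and
+// kNumBuckets = 128, bucket_size_ is 100, so the ring buffer covers roughly
+// the most recent 12800 keys. With deletion_trigger = 6400, need_compaction_
+// is set once at least 6400 of those keys are deletion entries; the window is
+// approximate since counts are evicted one full bucket at a time.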
+
+TablePropertiesCollector*
+CompactOnDeletionCollectorFactory::CreateTablePropertiesCollector(
+ TablePropertiesCollectorFactory::Context /*context*/) {
+ return new CompactOnDeletionCollector(
+ sliding_window_size_.load(), deletion_trigger_.load());
+}
+
+std::shared_ptr<CompactOnDeletionCollectorFactory>
+ NewCompactOnDeletionCollectorFactory(
+ size_t sliding_window_size,
+ size_t deletion_trigger) {
+ return std::shared_ptr<CompactOnDeletionCollectorFactory>(
+ new CompactOnDeletionCollectorFactory(
+ sliding_window_size, deletion_trigger));
+}
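+
+// Example (illustrative sketch): mark an SST file for compaction when more
+// than half of any window of ~128K consecutive entries are deletions.
+//
+//   Options options;
+//   options.table_properties_collector_factories.emplace_back(
+//       NewCompactOnDeletionCollectorFactory(
+//           /*sliding_window_size=*/128 * 1024,
+//           /*deletion_trigger=*/64 * 1024));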
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h
new file mode 100644
index 000000000..cc559ab2b
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+#include "rocksdb/utilities/table_properties_collectors.h"
+namespace ROCKSDB_NAMESPACE {
+
+class CompactOnDeletionCollector : public TablePropertiesCollector {
+ public:
+ CompactOnDeletionCollector(
+ size_t sliding_window_size,
+ size_t deletion_trigger);
+
+ // AddUserKey() will be called when a new key/value pair is inserted into the
+ // table.
+  // @param key the user key that is inserted into the table.
+  // @param value the value that is inserted into the table.
+  // @param file_size file size up to now
+ virtual Status AddUserKey(const Slice& key, const Slice& value,
+ EntryType type, SequenceNumber seq,
+ uint64_t file_size) override;
+
+ // Finish() will be called when a table has already been built and is ready
+ // for writing the properties block.
+  // @param properties User will add their collected statistics to
+ // `properties`.
+ virtual Status Finish(UserCollectedProperties* /*properties*/) override {
+ finished_ = true;
+ return Status::OK();
+ }
+
+ // Return the human-readable properties, where the key is property name and
+ // the value is the human-readable form of value.
+ virtual UserCollectedProperties GetReadableProperties() const override {
+ return UserCollectedProperties();
+ }
+
+ // The name of the properties collector can be used for debugging purpose.
+ virtual const char* Name() const override {
+ return "CompactOnDeletionCollector";
+ }
+
+ // EXPERIMENTAL Return whether the output file should be further compacted
+ virtual bool NeedCompact() const override {
+ return need_compaction_;
+ }
+
+ static const int kNumBuckets = 128;
+
+ private:
+ void Reset();
+
+  // A ring buffer that is used to count the number of deletion entries for
+  // every "bucket_size_" keys.
+ size_t num_deletions_in_buckets_[kNumBuckets];
+ // the number of keys in a bucket
+ size_t bucket_size_;
+
+ size_t current_bucket_;
+ size_t num_keys_in_current_bucket_;
+ size_t num_deletions_in_observation_window_;
+ size_t deletion_trigger_;
+ // true if the current SST file needs to be compacted.
+ bool need_compaction_;
+ bool finished_;
+};
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc
new file mode 100644
index 000000000..9b94cc272
--- /dev/null
+++ b/src/rocksdb/utilities/table_properties_collectors/compact_on_deletion_collector_test.cc
@@ -0,0 +1,178 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include <stdio.h>
+
+#ifndef ROCKSDB_LITE
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "rocksdb/table.h"
+#include "rocksdb/table_properties.h"
+#include "rocksdb/utilities/table_properties_collectors.h"
+#include "util/random.h"
+#include "utilities/table_properties_collectors/compact_on_deletion_collector.h"
+
+int main(int /*argc*/, char** /*argv*/) {
+ const int kWindowSizes[] =
+ {1000, 10000, 10000, 127, 128, 129, 255, 256, 257, 2, 10000};
+ const int kDeletionTriggers[] =
+ {500, 9500, 4323, 47, 61, 128, 250, 250, 250, 2, 2};
+ ROCKSDB_NAMESPACE::TablePropertiesCollectorFactory::Context context;
+ context.column_family_id = ROCKSDB_NAMESPACE::
+ TablePropertiesCollectorFactory::Context::kUnknownColumnFamily;
+
+ std::vector<int> window_sizes;
+ std::vector<int> deletion_triggers;
+ // deterministic tests
+ for (int test = 0; test < 9; ++test) {
+ window_sizes.emplace_back(kWindowSizes[test]);
+ deletion_triggers.emplace_back(kDeletionTriggers[test]);
+ }
+
+ // randomize tests
+ ROCKSDB_NAMESPACE::Random rnd(301);
+ const int kMaxTestSize = 100000l;
+ for (int random_test = 0; random_test < 10; random_test++) {
+ int window_size = rnd.Uniform(kMaxTestSize) + 1;
+ int deletion_trigger = rnd.Uniform(window_size);
+ window_sizes.emplace_back(window_size);
+ deletion_triggers.emplace_back(deletion_trigger);
+ }
+
+ assert(window_sizes.size() == deletion_triggers.size());
+
+ for (size_t test = 0; test < window_sizes.size(); ++test) {
+ const int kBucketSize = 128;
+ const int kWindowSize = window_sizes[test];
+ const int kPaddedWindowSize =
+ kBucketSize * ((window_sizes[test] + kBucketSize - 1) / kBucketSize);
+ const int kNumDeletionTrigger = deletion_triggers[test];
+ const int kBias = (kNumDeletionTrigger + kBucketSize - 1) / kBucketSize;
+ // Simple test
+ {
+ auto factory = ROCKSDB_NAMESPACE::NewCompactOnDeletionCollectorFactory(
+ kWindowSize, kNumDeletionTrigger);
+ const int kSample = 10;
+ for (int delete_rate = 0; delete_rate <= kSample; ++delete_rate) {
+ std::unique_ptr<ROCKSDB_NAMESPACE::TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ int deletions = 0;
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i % kSample < delete_rate) {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryDelete, 0, 0);
+ deletions++;
+ } else {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryPut, 0, 0);
+ }
+ }
+ if (collector->NeedCompact() !=
+ (deletions >= kNumDeletionTrigger) &&
+ std::abs(deletions - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr, "[Error] collector->NeedCompact() != (%d >= %d)"
+ " with kWindowSize = %d and kNumDeletionTrigger = %d\n",
+ deletions, kNumDeletionTrigger,
+ kWindowSize, kNumDeletionTrigger);
+ assert(false);
+ }
+ collector->Finish(nullptr);
+ }
+ }
+
+ // Only one section of a file satisfies the compaction trigger
+ {
+ auto factory = ROCKSDB_NAMESPACE::NewCompactOnDeletionCollectorFactory(
+ kWindowSize, kNumDeletionTrigger);
+ const int kSample = 10;
+ for (int delete_rate = 0; delete_rate <= kSample; ++delete_rate) {
+ std::unique_ptr<ROCKSDB_NAMESPACE::TablePropertiesCollector> collector(
+ factory->CreateTablePropertiesCollector(context));
+ int deletions = 0;
+ for (int section = 0; section < 5; ++section) {
+ int initial_entries = rnd.Uniform(kWindowSize) + kWindowSize;
+ for (int i = 0; i < initial_entries; ++i) {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryPut, 0, 0);
+ }
+ }
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i % kSample < delete_rate) {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryDelete, 0, 0);
+ deletions++;
+ } else {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryPut, 0, 0);
+ }
+ }
+ for (int section = 0; section < 5; ++section) {
+ int ending_entries = rnd.Uniform(kWindowSize) + kWindowSize;
+ for (int i = 0; i < ending_entries; ++i) {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryPut, 0, 0);
+ }
+ }
+ if (collector->NeedCompact() != (deletions >= kNumDeletionTrigger) &&
+ std::abs(deletions - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr, "[Error] collector->NeedCompact() %d != (%d >= %d)"
+ " with kWindowSize = %d, kNumDeletionTrigger = %d\n",
+ collector->NeedCompact(),
+ deletions, kNumDeletionTrigger, kWindowSize,
+ kNumDeletionTrigger);
+ assert(false);
+ }
+ collector->Finish(nullptr);
+ }
+ }
+
+    // TEST 3: Issues a lot of deletes, but their density is not
+ // high enough to trigger compaction.
+ {
+ std::unique_ptr<ROCKSDB_NAMESPACE::TablePropertiesCollector> collector;
+ auto factory = ROCKSDB_NAMESPACE::NewCompactOnDeletionCollectorFactory(
+ kWindowSize, kNumDeletionTrigger);
+ collector.reset(factory->CreateTablePropertiesCollector(context));
+ assert(collector->NeedCompact() == false);
+ // Insert "kNumDeletionTrigger * 0.95" deletions for every
+ // "kWindowSize" and verify compaction is not needed.
+ const int kDeletionsPerSection = kNumDeletionTrigger * 95 / 100;
+ if (kDeletionsPerSection >= 0) {
+ for (int section = 0; section < 200; ++section) {
+ for (int i = 0; i < kPaddedWindowSize; ++i) {
+ if (i < kDeletionsPerSection) {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryDelete, 0, 0);
+ } else {
+ collector->AddUserKey("hello", "rocksdb",
+ ROCKSDB_NAMESPACE::kEntryPut, 0, 0);
+ }
+ }
+ }
+ if (collector->NeedCompact() &&
+ std::abs(kDeletionsPerSection - kNumDeletionTrigger) > kBias) {
+ fprintf(stderr, "[Error] collector->NeedCompact() != false"
+ " with kWindowSize = %d and kNumDeletionTrigger = %d\n",
+ kWindowSize, kNumDeletionTrigger);
+ assert(false);
+ }
+ collector->Finish(nullptr);
+ }
+ }
+ }
+ fprintf(stderr, "PASSED\n");
+}
+#else
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as RocksDBLite does not include utilities.\n");
+ return 0;
+}
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/trace/file_trace_reader_writer.cc b/src/rocksdb/utilities/trace/file_trace_reader_writer.cc
new file mode 100644
index 000000000..7160f7a4c
--- /dev/null
+++ b/src/rocksdb/utilities/trace/file_trace_reader_writer.cc
@@ -0,0 +1,123 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/trace/file_trace_reader_writer.h"
+
+#include "env/composite_env_wrapper.h"
+#include "file/random_access_file_reader.h"
+#include "file/writable_file_writer.h"
+#include "trace_replay/trace_replay.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+const unsigned int FileTraceReader::kBufferSize = 1024; // 1KB
+
+FileTraceReader::FileTraceReader(
+ std::unique_ptr<RandomAccessFileReader>&& reader)
+ : file_reader_(std::move(reader)),
+ offset_(0),
+ buffer_(new char[kBufferSize]) {}
+
+FileTraceReader::~FileTraceReader() {
+ Close();
+ delete[] buffer_;
+}
+
+Status FileTraceReader::Close() {
+ file_reader_.reset();
+ return Status::OK();
+}
+
+Status FileTraceReader::Read(std::string* data) {
+ assert(file_reader_ != nullptr);
+ Status s = file_reader_->Read(offset_, kTraceMetadataSize, &result_, buffer_);
+ if (!s.ok()) {
+ return s;
+ }
+ if (result_.size() == 0) {
+ // No more data to read
+    // TODO: Come up with a better way to indicate end of data. Maybe this
+    // could be avoided once a footer is introduced.
+ return Status::Incomplete();
+ }
+ if (result_.size() < kTraceMetadataSize) {
+ return Status::Corruption("Corrupted trace file.");
+ }
+ *data = result_.ToString();
+ offset_ += kTraceMetadataSize;
+
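+  // The header just read consists of a timestamp (kTraceTimestampSize bytes)
+  // and a trace type (kTraceTypeSize bytes), followed by a 32-bit payload
+  // length that tells us how many more bytes belong to this record.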
+ uint32_t payload_len =
+ DecodeFixed32(&buffer_[kTraceTimestampSize + kTraceTypeSize]);
+
+ // Read Payload
+ unsigned int bytes_to_read = payload_len;
+ unsigned int to_read =
+ bytes_to_read > kBufferSize ? kBufferSize : bytes_to_read;
+ while (to_read > 0) {
+ s = file_reader_->Read(offset_, to_read, &result_, buffer_);
+ if (!s.ok()) {
+ return s;
+ }
+ if (result_.size() < to_read) {
+ return Status::Corruption("Corrupted trace file.");
+ }
+ data->append(result_.data(), result_.size());
+
+ offset_ += to_read;
+ bytes_to_read -= to_read;
+ to_read = bytes_to_read > kBufferSize ? kBufferSize : bytes_to_read;
+ }
+
+ return s;
+}
+
+FileTraceWriter::~FileTraceWriter() { Close(); }
+
+Status FileTraceWriter::Close() {
+ file_writer_.reset();
+ return Status::OK();
+}
+
+Status FileTraceWriter::Write(const Slice& data) {
+ return file_writer_->Append(data);
+}
+
+uint64_t FileTraceWriter::GetFileSize() { return file_writer_->GetFileSize(); }
+
+Status NewFileTraceReader(Env* env, const EnvOptions& env_options,
+ const std::string& trace_filename,
+ std::unique_ptr<TraceReader>* trace_reader) {
+ std::unique_ptr<RandomAccessFile> trace_file;
+ Status s = env->NewRandomAccessFile(trace_filename, &trace_file, env_options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<RandomAccessFileReader> file_reader;
+ file_reader.reset(new RandomAccessFileReader(
+ NewLegacyRandomAccessFileWrapper(trace_file), trace_filename));
+ trace_reader->reset(new FileTraceReader(std::move(file_reader)));
+ return s;
+}
+
+Status NewFileTraceWriter(Env* env, const EnvOptions& env_options,
+ const std::string& trace_filename,
+ std::unique_ptr<TraceWriter>* trace_writer) {
+ std::unique_ptr<WritableFile> trace_file;
+ Status s = env->NewWritableFile(trace_filename, &trace_file, env_options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::unique_ptr<WritableFileWriter> file_writer;
+ file_writer.reset(new WritableFileWriter(
+ NewLegacyWritableFileWrapper(std::move(trace_file)), trace_filename,
+ env_options));
+ trace_writer->reset(new FileTraceWriter(std::move(file_writer)));
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/trace/file_trace_reader_writer.h b/src/rocksdb/utilities/trace/file_trace_reader_writer.h
new file mode 100644
index 000000000..a9eafa5af
--- /dev/null
+++ b/src/rocksdb/utilities/trace/file_trace_reader_writer.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "rocksdb/trace_reader_writer.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class RandomAccessFileReader;
+class WritableFileWriter;
+
+// FileTraceReader allows reading RocksDB traces from a file.
+class FileTraceReader : public TraceReader {
+ public:
+ explicit FileTraceReader(std::unique_ptr<RandomAccessFileReader>&& reader);
+ ~FileTraceReader();
+
+ virtual Status Read(std::string* data) override;
+ virtual Status Close() override;
+
+ private:
+ std::unique_ptr<RandomAccessFileReader> file_reader_;
+ Slice result_;
+ size_t offset_;
+ char* const buffer_;
+
+ static const unsigned int kBufferSize;
+};
+
+// FileTraceWriter allows writing RocksDB traces to a file.
+class FileTraceWriter : public TraceWriter {
+ public:
+ explicit FileTraceWriter(std::unique_ptr<WritableFileWriter>&& file_writer)
+ : file_writer_(std::move(file_writer)) {}
+ ~FileTraceWriter();
+
+ virtual Status Write(const Slice& data) override;
+ virtual Status Close() override;
+ virtual uint64_t GetFileSize() override;
+
+ private:
+ std::unique_ptr<WritableFileWriter> file_writer_;
+};
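+
+// Example (illustrative sketch): trace a workload to a file via
+// DB::StartTrace, assuming `db` is an open DB instance.
+//
+//   std::unique_ptr<TraceWriter> trace_writer;
+//   Status s = NewFileTraceWriter(Env::Default(), EnvOptions(),
+//                                 "/tmp/rocksdb_trace", &trace_writer);
+//   if (s.ok()) {
+//     s = db->StartTrace(TraceOptions(), std::move(trace_writer));
+//   }
+//   // ... run the workload, then db->EndTrace().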
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction.cc b/src/rocksdb/utilities/transactions/optimistic_transaction.cc
new file mode 100644
index 000000000..b01102bb2
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction.cc
@@ -0,0 +1,187 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/optimistic_transaction.h"
+
+#include <string>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/transactions/transaction_util.h"
+#include "utilities/transactions/optimistic_transaction.h"
+#include "utilities/transactions/optimistic_transaction_db_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+OptimisticTransaction::OptimisticTransaction(
+ OptimisticTransactionDB* txn_db, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options)
+ : TransactionBaseImpl(txn_db->GetBaseDB(), write_options), txn_db_(txn_db) {
+ Initialize(txn_options);
+}
+
+void OptimisticTransaction::Initialize(
+ const OptimisticTransactionOptions& txn_options) {
+ if (txn_options.set_snapshot) {
+ SetSnapshot();
+ }
+}
+
+void OptimisticTransaction::Reinitialize(
+ OptimisticTransactionDB* txn_db, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options) {
+ TransactionBaseImpl::Reinitialize(txn_db->GetBaseDB(), write_options);
+ Initialize(txn_options);
+}
+
+OptimisticTransaction::~OptimisticTransaction() {}
+
+void OptimisticTransaction::Clear() { TransactionBaseImpl::Clear(); }
+
+Status OptimisticTransaction::Prepare() {
+ return Status::InvalidArgument(
+ "Two phase commit not supported for optimistic transactions.");
+}
+
+Status OptimisticTransaction::Commit() {
+ auto txn_db_impl = static_cast_with_check<OptimisticTransactionDBImpl,
+ OptimisticTransactionDB>(txn_db_);
+ assert(txn_db_impl);
+ switch (txn_db_impl->GetValidatePolicy()) {
+ case OccValidationPolicy::kValidateParallel:
+ return CommitWithParallelValidate();
+ case OccValidationPolicy::kValidateSerial:
+ return CommitWithSerialValidate();
+ default:
+ assert(0);
+ }
+  // Unreachable; returned only to silence compiler warnings.
+ return Status::OK();
+}
+
+Status OptimisticTransaction::CommitWithSerialValidate() {
+ // Set up callback which will call CheckTransactionForConflicts() to
+ // check whether this transaction is safe to be committed.
+ OptimisticTransactionCallback callback(this);
+
+ DBImpl* db_impl = static_cast_with_check<DBImpl, DB>(db_->GetRootDB());
+
+ Status s = db_impl->WriteWithCallback(
+ write_options_, GetWriteBatch()->GetWriteBatch(), &callback);
+
+ if (s.ok()) {
+ Clear();
+ }
+
+ return s;
+}
+
+Status OptimisticTransaction::CommitWithParallelValidate() {
+ auto txn_db_impl = static_cast_with_check<OptimisticTransactionDBImpl,
+ OptimisticTransactionDB>(txn_db_);
+ assert(txn_db_impl);
+ DBImpl* db_impl = static_cast_with_check<DBImpl, DB>(db_->GetRootDB());
+ assert(db_impl);
+ const size_t space = txn_db_impl->GetLockBucketsSize();
+ std::set<size_t> lk_idxes;
+ std::vector<std::unique_lock<std::mutex>> lks;
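+  // Collect the bucket index of every tracked key. std::set keeps the
+  // indexes sorted and de-duplicated, so each bucket is locked at most once.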
+ for (auto& cfit : GetTrackedKeys()) {
+ for (auto& keyit : cfit.second) {
+ lk_idxes.insert(fastrange64(GetSliceNPHash64(keyit.first), space));
+ }
+ }
+  // NOTE: within a single txn, all bucket locks are acquired in ascending
+  // index order. Because every thread follows the same ordering, deadlock
+  // between concurrently committing txns is avoided.
+ for (auto v : lk_idxes) {
+ lks.emplace_back(txn_db_impl->LockBucket(v));
+ }
+
+ Status s = TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(),
+ true /* cache_only */);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = db_impl->Write(write_options_, GetWriteBatch()->GetWriteBatch());
+ if (s.ok()) {
+ Clear();
+ }
+
+ return s;
+}
+
+Status OptimisticTransaction::Rollback() {
+ Clear();
+ return Status::OK();
+}
+
+// Record this key so that we can check it for conflicts at commit time.
+//
+// 'exclusive' is unused for OptimisticTransaction.
+Status OptimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
+ const Slice& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ assert(!assume_tracked); // not supported
+ (void)assume_tracked;
+ if (!do_validate) {
+ return Status::OK();
+ }
+ uint32_t cfh_id = GetColumnFamilyID(column_family);
+
+ SetSnapshotIfNeeded();
+
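+  // Record the sequence number this key is tracked at; commit-time validation
+  // will look for conflicting writes to the key newer than this sequence.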
+ SequenceNumber seq;
+ if (snapshot_) {
+ seq = snapshot_->GetSequenceNumber();
+ } else {
+ seq = db_->GetLatestSequenceNumber();
+ }
+
+ std::string key_str = key.ToString();
+
+ TrackKey(cfh_id, key_str, seq, read_only, exclusive);
+
+  // Always return OK. Conflict checking will happen at commit time.
+ return Status::OK();
+}
+
+// Returns OK if it is safe to commit this transaction. Returns Status::Busy
+// if there are read or write conflicts that would prevent us from committing,
+// or if we cannot determine whether any such conflicts exist.
+//
+// Should only be called on writer thread in order to avoid any race conditions
+// in detecting write conflicts.
+Status OptimisticTransaction::CheckTransactionForConflicts(DB* db) {
+ Status result;
+
+ auto db_impl = static_cast_with_check<DBImpl, DB>(db);
+
+ // Since we are on the write thread and do not want to block other writers,
+ // we will do a cache-only conflict check. This can result in TryAgain
+ // getting returned if there is not sufficient memtable history to check
+ // for conflicts.
+ return TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(),
+ true /* cache_only */);
+}
+
+Status OptimisticTransaction::SetName(const TransactionName& /* unused */) {
+ return Status::InvalidArgument("Optimistic transactions cannot be named.");
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction.h b/src/rocksdb/utilities/transactions/optimistic_transaction.h
new file mode 100644
index 000000000..c337de2af
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction.h
@@ -0,0 +1,101 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransaction : public TransactionBaseImpl {
+ public:
+ OptimisticTransaction(OptimisticTransactionDB* db,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options);
+ // No copying allowed
+ OptimisticTransaction(const OptimisticTransaction&) = delete;
+ void operator=(const OptimisticTransaction&) = delete;
+
+ virtual ~OptimisticTransaction();
+
+ void Reinitialize(OptimisticTransactionDB* txn_db,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options);
+
+ Status Prepare() override;
+
+ Status Commit() override;
+
+ Status Rollback() override;
+
+ Status SetName(const TransactionName& name) override;
+
+ protected:
+ Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false) override;
+
+ private:
+ ROCKSDB_FIELD_UNUSED OptimisticTransactionDB* const txn_db_;
+
+ friend class OptimisticTransactionCallback;
+
+ void Initialize(const OptimisticTransactionOptions& txn_options);
+
+  // Returns OK if it is safe to commit this transaction. Returns Status::Busy
+  // if there are read or write conflicts that would prevent us from
+  // committing, or if we cannot determine whether any such conflicts exist.
+ //
+ // Should only be called on writer thread.
+ Status CheckTransactionForConflicts(DB* db);
+
+ void Clear() override;
+
+ void UnlockGetForUpdate(ColumnFamilyHandle* /* unused */,
+ const Slice& /* unused */) override {
+ // Nothing to unlock.
+ }
+
+ Status CommitWithSerialValidate();
+
+ Status CommitWithParallelValidate();
+};
+
+// Used at commit time to trigger transaction validation
+class OptimisticTransactionCallback : public WriteCallback {
+ public:
+ explicit OptimisticTransactionCallback(OptimisticTransaction* txn)
+ : txn_(txn) {}
+
+ Status Callback(DB* db) override {
+ return txn_->CheckTransactionForConflicts(db);
+ }
+
+ bool AllowWriteBatching() override { return false; }
+
+ private:
+ OptimisticTransaction* txn_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc
new file mode 100644
index 000000000..bffb3d5ed
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/optimistic_transaction_db_impl.h"
+
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "utilities/transactions/optimistic_transaction.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Transaction* OptimisticTransactionDBImpl::BeginTransaction(
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options, Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new OptimisticTransaction(this, write_options, txn_options);
+ }
+}
+
+std::unique_lock<std::mutex> OptimisticTransactionDBImpl::LockBucket(
+ size_t idx) {
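+  // The returned std::unique_lock keeps the bucket mutex held until the
+  // caller releases it or lets it go out of scope.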
+ assert(idx < bucketed_locks_.size());
+ return std::unique_lock<std::mutex>(*bucketed_locks_[idx]);
+}
+
+Status OptimisticTransactionDB::Open(const Options& options,
+ const std::string& dbname,
+ OptimisticTransactionDB** dbptr) {
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = Open(db_options, dbname, column_families, &handles, dbptr);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+    // We can delete the handle since DBImpl always holds a reference to the
+    // default column family.
+ delete handles[0];
+ }
+
+ return s;
+}
+
+Status OptimisticTransactionDB::Open(
+ const DBOptions& db_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ OptimisticTransactionDB** dbptr) {
+ return OptimisticTransactionDB::Open(db_options,
+ OptimisticTransactionDBOptions(), dbname,
+ column_families, handles, dbptr);
+}
+
+Status OptimisticTransactionDB::Open(
+ const DBOptions& db_options,
+ const OptimisticTransactionDBOptions& occ_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles,
+ OptimisticTransactionDB** dbptr) {
+ Status s;
+ DB* db;
+
+ std::vector<ColumnFamilyDescriptor> column_families_copy = column_families;
+
+ // Enable MemTable History if not already enabled
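+  // Commit-time conflict detection reads recent write history from the
+  // memtables, so some memtable history must be retained.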
+ for (auto& column_family : column_families_copy) {
+ ColumnFamilyOptions* options = &column_family.options;
+
+ if (options->max_write_buffer_size_to_maintain == 0 &&
+ options->max_write_buffer_number_to_maintain == 0) {
+ // Setting to -1 will set the History size to
+ // max_write_buffer_number * write_buffer_size.
+ options->max_write_buffer_size_to_maintain = -1;
+ }
+ }
+
+ s = DB::Open(db_options, dbname, column_families_copy, handles, &db);
+
+ if (s.ok()) {
+ *dbptr = new OptimisticTransactionDBImpl(db, occ_options);
+ }
+
+ return s;
+}
+
+void OptimisticTransactionDBImpl::ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options) {
+ assert(dynamic_cast<OptimisticTransaction*>(txn) != nullptr);
+ auto txn_impl = reinterpret_cast<OptimisticTransaction*>(txn);
+
+ txn_impl->Reinitialize(this, write_options, txn_options);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h
new file mode 100644
index 000000000..d895d49b8
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_db_impl.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <mutex>
+#include <vector>
+#include <algorithm>
+
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransactionDBImpl : public OptimisticTransactionDB {
+ public:
+ explicit OptimisticTransactionDBImpl(
+ DB* db, const OptimisticTransactionDBOptions& occ_options,
+ bool take_ownership = true)
+ : OptimisticTransactionDB(db),
+ db_owner_(take_ownership),
+ validate_policy_(occ_options.validate_policy) {
+ if (validate_policy_ == OccValidationPolicy::kValidateParallel) {
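+      // Pre-allocate the shared lock buckets; occ_lock_buckets values smaller
+      // than 16 are rounded up to 16.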
+ uint32_t bucket_size = std::max(16u, occ_options.occ_lock_buckets);
+ bucketed_locks_.reserve(bucket_size);
+ for (size_t i = 0; i < bucket_size; ++i) {
+ bucketed_locks_.emplace_back(
+ std::unique_ptr<std::mutex>(new std::mutex));
+ }
+ }
+ }
+
+ ~OptimisticTransactionDBImpl() {
+    // Prevent this stackable wrapper from destroying the base db.
+ if (!db_owner_) {
+ db_ = nullptr;
+ }
+ }
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ size_t GetLockBucketsSize() const { return bucketed_locks_.size(); }
+
+ OccValidationPolicy GetValidatePolicy() const { return validate_policy_; }
+
+ std::unique_lock<std::mutex> LockBucket(size_t idx);
+
+ private:
+  // NOTE: used in the validation phase. Each tracked key is hashed into a
+  // bucket, and bucket locks are acquired in ascending index order to avoid
+  // deadlock.
+ std::vector<std::unique_ptr<std::mutex>> bucketed_locks_;
+
+ bool db_owner_;
+
+ const OccValidationPolicy validate_policy_;
+
+ void ReinitializeTransaction(Transaction* txn,
+ const WriteOptions& write_options,
+ const OptimisticTransactionOptions& txn_options =
+ OptimisticTransactionOptions());
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc b/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc
new file mode 100644
index 000000000..63c1a255c
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/optimistic_transaction_test.cc
@@ -0,0 +1,1535 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "logging/logging.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/perf_context.h"
+#include "rocksdb/utilities/optimistic_transaction_db.h"
+#include "rocksdb/utilities/transaction.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/transaction_test_util.h"
+#include "util/crc32c.h"
+#include "util/random.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+class OptimisticTransactionTest
+ : public testing::Test,
+ public testing::WithParamInterface<OccValidationPolicy> {
+ public:
+ OptimisticTransactionDB* txn_db;
+ string dbname;
+ Options options;
+
+ OptimisticTransactionTest() {
+ options.create_if_missing = true;
+ options.max_write_buffer_number = 2;
+ options.max_write_buffer_size_to_maintain = 1600;
+ dbname = test::PerThreadDBPath("optimistic_transaction_testdb");
+
+ DestroyDB(dbname, options);
+ Open();
+ }
+ ~OptimisticTransactionTest() override {
+ delete txn_db;
+ DestroyDB(dbname, options);
+ }
+
+ void Reopen() {
+ delete txn_db;
+ txn_db = nullptr;
+ Open();
+ }
+
+ private:
+ void Open() {
+ ColumnFamilyOptions cf_options(options);
+ OptimisticTransactionDBOptions occ_opts;
+ occ_opts.validate_policy = GetParam();
+ std::vector<ColumnFamilyDescriptor> column_families;
+ std::vector<ColumnFamilyHandle*> handles;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ Status s =
+ OptimisticTransactionDB::Open(DBOptions(options), occ_opts, dbname,
+ column_families, &handles, &txn_db);
+
+ assert(s.ok());
+ assert(txn_db != nullptr);
+ assert(handles.size() == 1);
+ delete handles[0];
+ }
+};
+
+TEST_P(OptimisticTransactionTest, SuccessTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, Slice("foo"), Slice("bar"));
+ txn_db->Put(write_options, Slice("foo2"), Slice("bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->GetForUpdate(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ txn->Put(Slice("foo"), Slice("bar2"));
+
+ txn->GetForUpdate(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "foo", "bar");
+ txn_db->Put(write_options, "foo2", "bar");
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->Put("foo", "bar2");
+
+ // This Put outside of a transaction will conflict with the previous write
+ s = txn_db->Put(write_options, "foo", "barz");
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+ ASSERT_EQ(1, txn->GetNumKeys());
+
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+ txn_db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, WriteConflictTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "foo", "bar");
+ txn_db->Put(write_options, "foo2", "bar");
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ // This Put outside of a transaction will conflict with a later write
+ s = txn_db->Put(write_options, "foo", "barz");
+ ASSERT_OK(s);
+
+ txn->Put("foo", "bar2"); // Conflicts with write done after snapshot taken
+
+ s = txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+ txn_db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, ReadConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "foo", "bar");
+ txn_db->Put(write_options, "foo2", "bar");
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ // This Put outside of a transaction will conflict with the previous read
+ s = txn_db->Put(write_options, "foo", "barz");
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy()); // Txn should not commit
+
+ // Verify that transaction did not write anything
+ txn->GetForUpdate(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+ txn->GetForUpdate(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, TxnOnlyTest) {
+ // Test to make sure transactions work when there are no other writes in an
+ // empty db.
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->Put("x", "y");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, FlushTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, Slice("foo"), Slice("bar"));
+ txn_db->Put(write_options, Slice("foo2"), Slice("bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ txn->Put(Slice("foo"), Slice("bar2"));
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a memtable to flush
+ s = txn_db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ txn_db->Flush(flush_ops);
+
+ s = txn->Commit();
+ // txn should commit since the flushed table is still in MemtableList History
+ ASSERT_OK(s);
+
+ txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, FlushTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, Slice("foo"), Slice("bar"));
+ txn_db->Put(write_options, Slice("foo2"), Slice("bar"));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ txn->Put(Slice("foo"), Slice("bar2"));
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a MemTable to flush
+ s = txn_db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ txn_db->Flush(flush_ops);
+
+ // Put a random key so we have a MemTable to flush
+ s = txn_db->Put(write_options, "dummy", "dummy2");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ txn_db->Flush(flush_ops);
+
+ s = txn_db->Put(write_options, "dummy", "dummy3");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ // Since our test db has max_write_buffer_number=2, this flush will cause
+ // the first memtable to get purged from the MemtableList history.
+ txn_db->Flush(flush_ops);
+
+ s = txn->Commit();
+ // txn should not commit since MemTableList History is not large enough
+ ASSERT_TRUE(s.IsTryAgain());
+
+ txn_db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+}
+
+// Trigger the condition where some old memtables are skipped when doing
+// TransactionUtil::CheckKey(), and make sure the result is still correct.
+TEST_P(OptimisticTransactionTest, CheckKeySkipOldMemtable) {
+ const int kAttemptHistoryMemtable = 0;
+ const int kAttemptImmMemTable = 1;
+ for (int attempt = kAttemptHistoryMemtable; attempt <= kAttemptImmMemTable;
+ attempt++) {
+ options.max_write_buffer_number_to_maintain = 3;
+ Reopen();
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ ReadOptions snapshot_read_options;
+ ReadOptions snapshot_read_options2;
+ string value;
+ Status s;
+
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn != nullptr);
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn2 != nullptr);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ snapshot_read_options2.snapshot = txn2->GetSnapshot();
+ ASSERT_OK(txn2->GetForUpdate(snapshot_read_options2, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+ ASSERT_OK(txn2->Put(Slice("foo2"), Slice("bar2")));
+
+ // txn updates "foo" and txn2 updates "foo2", and now a write is
+ // issued for "foo", which conflicts with txn but not txn2
+ ASSERT_OK(txn_db->Put(write_options, "foo", "bar"));
+
+ if (attempt == kAttemptImmMemTable) {
+ // For the second attempt, hold flush from beginning. The memtable
+ // will be switched to immutable after calling TEST_SwitchMemtable()
+ // while CheckKey() is called.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"OptimisticTransactionTest.CheckKeySkipOldMemtable",
+ "FlushJob::Start"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ }
+
+ // force a memtable flush. The memtable should still be kept
+ FlushOptions flush_ops;
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_OK(txn_db->Flush(flush_ops));
+ } else {
+ assert(attempt == kAttemptImmMemTable);
+ DBImpl* db_impl = static_cast<DBImpl*>(txn_db->GetRootDB());
+ db_impl->TEST_SwitchMemtable();
+ }
+ uint64_t num_imm_mems;
+ ASSERT_TRUE(txn_db->GetIntProperty(DB::Properties::kNumImmutableMemTable,
+ &num_imm_mems));
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(0, num_imm_mems);
+ } else {
+ assert(attempt == kAttemptImmMemTable);
+ ASSERT_EQ(1, num_imm_mems);
+ }
+
+ // Put something in active memtable
+ ASSERT_OK(txn_db->Put(write_options, Slice("foo3"), Slice("bar")));
+
+    // Create txn3 after flushing. When this transaction is committed, only
+    // the active memtable needs to be checked.
+ Transaction* txn3 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn3 != nullptr);
+
+ // Commit both of txn and txn2. txn will conflict but txn2 will
+ // pass. In both ways, both memtables are queried.
+ SetPerfLevel(PerfLevel::kEnableCount);
+
+ get_perf_context()->Reset();
+ s = txn->Commit();
+ // We should have checked two memtables
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ // txn should fail because of conflict, even if the memtable
+ // has flushed, because it is still preserved in history.
+ ASSERT_TRUE(s.IsBusy());
+
+ get_perf_context()->Reset();
+ s = txn2->Commit();
+ // We should have checked two memtables
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ ASSERT_TRUE(s.ok());
+
+ txn3->Put(Slice("foo2"), Slice("bar2"));
+ get_perf_context()->Reset();
+ s = txn3->Commit();
+ // txn3 is created after the active memtable is created, so that is the only
+ // memtable to check.
+ ASSERT_EQ(1, get_perf_context()->get_from_memtable_count);
+ ASSERT_TRUE(s.ok());
+
+ TEST_SYNC_POINT("OptimisticTransactionTest.CheckKeySkipOldMemtable");
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ SetPerfLevel(PerfLevel::kDisable);
+
+ delete txn;
+ delete txn2;
+ delete txn3;
+ }
+}
+
+TEST_P(OptimisticTransactionTest, NoSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "AAA", "bar");
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Modify key after transaction start
+ txn_db->Put(write_options, "AAA", "bar1");
+
+ // Read and write without a snapshot
+ txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_EQ(value, "bar1");
+ txn->Put("AAA", "bar2");
+
+ // Should commit since read/write was done after data changed
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, MultipleSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "AAA", "bar");
+ txn_db->Put(write_options, "BBB", "bar");
+ txn_db->Put(write_options, "CCC", "bar");
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn_db->Put(write_options, "AAA", "bar1");
+
+ // Read and write without a snapshot
+ txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_EQ(value, "bar1");
+ txn->Put("AAA", "bar2");
+
+ // Modify BBB before snapshot is taken
+ txn_db->Put(write_options, "BBB", "bar1");
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ txn->GetForUpdate(snapshot_read_options, "BBB", &value);
+ ASSERT_EQ(value, "bar1");
+ txn->Put("BBB", "bar2");
+
+ txn_db->Put(write_options, "CCC", "bar1");
+
+ // Set a new snapshot
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ txn->GetForUpdate(snapshot_read_options, "CCC", &value);
+ ASSERT_EQ(value, "bar1");
+ txn->Put("CCC", "bar2");
+
+ s = txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+ s = txn_db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = txn_db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = txn_db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn_db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn_db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+  // Verify that we track multiple writes to the same key at different
+  // snapshots.
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ // Potentially conflicting writes
+ txn_db->Put(write_options, "ZZZ", "zzz");
+ txn_db->Put(write_options, "XXX", "xxx");
+
+ txn->SetSnapshot();
+
+ OptimisticTransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ txn2->SetSnapshot();
+
+ // This should not conflict in txn since the snapshot is later than the
+ // previous write (spoiler alert: it will later conflict with txn2).
+ txn->Put("ZZZ", "zzzz");
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // This will conflict since the snapshot is earlier than another write to ZZZ
+ txn2->Put("ZZZ", "xxxxx");
+
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn2;
+}
+
+TEST_P(OptimisticTransactionTest, ColumnFamiliesTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = txn_db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = txn_db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ delete cfa;
+ delete cfb;
+ delete txn_db;
+ txn_db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+ std::vector<ColumnFamilyHandle*> handles;
+ s = OptimisticTransactionDB::Open(options, dbname, column_families, &handles,
+ &txn_db);
+ ASSERT_OK(s);
+ assert(txn_db != nullptr);
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // Write some data to the db
+ WriteBatch batch;
+ batch.Put("foo", "foo");
+ batch.Put(handles[1], "AAA", "bar");
+ batch.Put(handles[1], "AAAZZZ", "bar");
+ s = txn_db->Write(write_options, &batch);
+ ASSERT_OK(s);
+ txn_db->Delete(write_options, handles[1], "AAAZZZ");
+
+  // These keys do not conflict with existing writes since they're in
+ // different column families
+ txn->Delete("AAA");
+ txn->GetForUpdate(snapshot_read_options, handles[1], "foo", &value);
+ Slice key_slice("AAAZZZ");
+ Slice value_slices[2] = {Slice("bar"), Slice("bar")};
+ txn->Put(handles[2], SliceParts(&key_slice, 1), SliceParts(value_slices, 2));
+
+ ASSERT_EQ(3, txn->GetNumKeys());
+
+ // Txn should commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = txn_db->Get(read_options, "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn_db->Get(read_options, handles[2], "AAAZZZ", &value);
+ ASSERT_EQ(value, "barbar");
+
+ Slice key_slices[3] = {Slice("AAA"), Slice("ZZ"), Slice("Z")};
+ Slice value_slice("barbarbar");
+ // This write will cause a conflict with the earlier batch write
+ txn2->Put(handles[1], SliceParts(key_slices, 3), SliceParts(&value_slice, 1));
+
+ txn2->Delete(handles[2], "XXX");
+ txn2->Delete(handles[1], "XXX");
+ s = txn2->GetForUpdate(snapshot_read_options, handles[1], "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Verify txn did not commit
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ s = txn_db->Get(read_options, handles[1], "AAAZZZ", &value);
+ ASSERT_EQ(value, "barbar");
+
+ delete txn;
+ delete txn2;
+
+ txn = txn_db->BeginTransaction(write_options, txn_options);
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ std::vector<ColumnFamilyHandle*> multiget_cfh = {handles[1], handles[2],
+ handles[0], handles[2]};
+ std::vector<Slice> multiget_keys = {"AAA", "AAAZZZ", "foo", "foo"};
+ std::vector<std::string> values(4);
+
+ std::vector<Status> results = txn->MultiGetForUpdate(
+ snapshot_read_options, multiget_cfh, multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+ txn->Delete(handles[2], "ZZZ");
+ txn->Put(handles[2], "ZZZ", "YYY");
+ txn->Put(handles[2], "ZZZ", "YYYY");
+ txn->Delete(handles[2], "ZZZ");
+ txn->Put(handles[2], "AAAZZZ", "barbarbar");
+
+ ASSERT_EQ(5, txn->GetNumKeys());
+
+ // Txn should commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = txn_db->Get(read_options, handles[2], "ZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Put a key which will conflict with the next txn using the previous snapshot
+ txn_db->Put(write_options, handles[2], "foo", "000");
+
+ results = txn2->MultiGetForUpdate(snapshot_read_options, multiget_cfh,
+ multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+  // Verify txn did not commit
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn_db->DropColumnFamily(handles[1]);
+ ASSERT_OK(s);
+ s = txn_db->DropColumnFamily(handles[2]);
+ ASSERT_OK(s);
+
+ delete txn;
+ delete txn2;
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(OptimisticTransactionTest, EmptyTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ s = txn_db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ txn->Rollback();
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = txn_db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ s = txn_db->Put(write_options, "aaa", "xxx");
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, PredicateManyPreceders) {
+ WriteOptions write_options;
+ ReadOptions read_options1, read_options2;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ std::vector<Slice> multiget_keys = {"1", "2", "3"};
+ std::vector<std::string> multiget_values;
+
+ std::vector<Status> results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[1].IsNotFound());
+
+ txn2->Put("2", "x");
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ multiget_values.clear();
+ results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[1].IsNotFound());
+
+ // should not commit since txn2 wrote a key txn has read
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ txn1->Put("4", "x");
+
+ txn2->Delete("4");
+
+ // txn1 can commit since txn2's delete hasn't happened yet (it's just batched)
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options2, "4", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // txn2 cannot commit since txn1 changed "4"
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(OptimisticTransactionTest, LostUpdate) {
+ WriteOptions write_options;
+ ReadOptions read_options, read_options1, read_options2;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ // Test 2 transactions writing to the same key in multiple orders and
+ // with/without snapshots
+
+ Transaction* txn1 = txn_db->BeginTransaction(write_options);
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+
+ txn1->Put("1", "1");
+ txn2->Put("1", "2");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn_options.set_snapshot = true;
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ txn1->Put("1", "3");
+ txn2->Put("1", "4");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ txn1->Put("1", "5");
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn2->Put("1", "6");
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = txn_db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ txn1->Put("1", "5");
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn2->SetSnapshot();
+ txn2->Put("1", "6");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+ txn2 = txn_db->BeginTransaction(write_options);
+
+ txn1->Put("1", "7");
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn2->Put("1", "8");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ s = txn_db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "8");
+}
+
+TEST_P(OptimisticTransactionTest, UntrackedWrites) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ // Verify transaction rollback works for untracked keys.
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ txn->PutUntracked("untracked", "0");
+ txn->Rollback();
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ txn->Put("tracked", "1");
+ txn->PutUntracked("untracked", "1");
+ txn->MergeUntracked("untracked", "2");
+ txn->DeleteUntracked("untracked");
+
+ // Write to the untracked key outside of the transaction and verify
+ // it doesn't prevent the transaction from committing.
+ s = txn_db->Put(write_options, "untracked", "x");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+
+ txn->Put("tracked", "10");
+ txn->PutUntracked("untracked", "A");
+
+ // Write to tracked key outside of the transaction and verify that the
+ // untracked keys are not written when the commit fails.
+ s = txn_db->Delete(write_options, "tracked");
+
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn_db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, IteratorTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ // Write some keys to the db
+ s = txn_db->Put(write_options, "A", "a");
+ ASSERT_OK(s);
+
+ s = txn_db->Put(write_options, "G", "g");
+ ASSERT_OK(s);
+
+ s = txn_db->Put(write_options, "F", "f");
+ ASSERT_OK(s);
+
+ s = txn_db->Put(write_options, "C", "c");
+ ASSERT_OK(s);
+
+ s = txn_db->Put(write_options, "D", "d");
+ ASSERT_OK(s);
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Write some keys in a txn
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Put("H", "h");
+ ASSERT_OK(s);
+
+ s = txn->Delete("D");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ txn->SetSnapshot();
+ const Snapshot* snapshot = txn->GetSnapshot();
+
+ // Write some keys to the db after the snapshot
+ s = txn_db->Put(write_options, "BB", "xx");
+ ASSERT_OK(s);
+
+ s = txn_db->Put(write_options, "C", "xx");
+ ASSERT_OK(s);
+
+ read_options.snapshot = snapshot;
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+
+ // Read all keys via iter and lock them all
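+  // "D" was deleted inside the txn, and "BB"/"C" were written to the db after
+  // the snapshot was taken, so the iterator sees the pre-snapshot "c" and no
+  // "BB".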
+ std::string results[] = {"a", "b", "c", "e", "f", "g", "h"};
+ for (int i = 0; i < 7; i++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(results[i], iter->value().ToString());
+
+ s = txn->GetForUpdate(read_options, iter->key(), nullptr);
+ ASSERT_OK(s);
+
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+
+ iter->Seek("G");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("g", iter->value().ToString());
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("f", iter->value().ToString());
+
+ iter->Seek("D");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("C");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("c", iter->value().ToString());
+
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a", iter->value().ToString());
+
+ iter->Seek("X");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("h", iter->value().ToString());
+
+ // key "C" was modified in the db after txn's snapshot. txn will not commit.
+ s = txn->Commit();
+ ASSERT_TRUE(s.IsBusy());
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, SavepointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ Transaction* txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn->SetSavePoint(); // 1
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ delete txn;
+ txn = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("B", "bb");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "c");
+ ASSERT_OK(s);
+
+ txn->SetSavePoint(); // 2
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "cc");
+ ASSERT_OK(s);
+
+ s = txn->Put("D", "d");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 2
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bb", value);
+
+ s = txn->Get(read_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+ txn->Rollback();
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "aa");
+ ASSERT_OK(s);
+
+ s = txn->Put("F", "f");
+ ASSERT_OK(s);
+
+ txn->SetSavePoint(); // 3
+ txn->SetSavePoint(); // 4
+
+ s = txn->Put("G", "g");
+ ASSERT_OK(s);
+
+ s = txn->Delete("F");
+ ASSERT_OK(s);
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = txn->Get(read_options, "F", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 3
+
+ s = txn->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = txn->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = txn_db->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = txn_db->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn_db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = txn_db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = txn_db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn_db->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn_db->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(OptimisticTransactionTest, UndoGetForUpdateTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ OptimisticTransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_db->Put(write_options, "A", "");
+
+ Transaction* txn1 = txn_db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+
+ Transaction* txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+ txn1->Put("A", "a");
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint();
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint();
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 cannot commit since A will still be conflict checked
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsBusy());
+ delete txn1;
+
+ txn1 = txn_db->BeginTransaction(write_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint();
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ txn1->UndoGetForUpdate("A");
+
+ txn1->RollbackToSavePoint();
+ txn1->UndoGetForUpdate("A");
+
+ txn2 = txn_db->BeginTransaction(write_options);
+ txn2->Put("A", "x");
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ // Verify that txn1 can commit since A isn't conflict checked
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+}
+
+namespace {
+Status OptimisticTransactionStressTestInserter(OptimisticTransactionDB* db,
+ const size_t num_transactions,
+ const size_t num_sets,
+ const size_t num_keys_per_set) {
+ size_t seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 _rand(seed);
+ WriteOptions write_options;
+ ReadOptions read_options;
+ OptimisticTransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+
+ RandomTransactionInserter inserter(&_rand, write_options, read_options,
+ num_keys_per_set,
+ static_cast<uint16_t>(num_sets));
+
+ for (size_t t = 0; t < num_transactions; t++) {
+ bool success = inserter.OptimisticTransactionDBInsert(db, txn_options);
+ if (!success) {
+ // unexpected failure
+ return inserter.GetLastStatus();
+ }
+ }
+
+ // Make sure at least some of the transactions succeeded. It's ok if
+ // some failed due to write-conflicts.
+ if (inserter.GetFailureCount() > num_transactions / 2) {
+ return Status::TryAgain("Too many transactions failed! " +
+ std::to_string(inserter.GetFailureCount()) + " / " +
+ std::to_string(num_transactions));
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+TEST_P(OptimisticTransactionTest, OptimisticTransactionStressTest) {
+ const size_t num_threads = 4;
+ const size_t num_transactions_per_thread = 10000;
+ const size_t num_sets = 3;
+ const size_t num_keys_per_set = 100;
+ // Setting the key-space to be 100 keys should cause enough write-conflicts
+ // to make this test interesting.
+
+ std::vector<port::Thread> threads;
+
+ std::function<void()> call_inserter = [&] {
+ ASSERT_OK(OptimisticTransactionStressTestInserter(
+ txn_db, num_transactions_per_thread, num_sets, num_keys_per_set));
+ };
+
+ // Create N threads that use RandomTransactionInserter to write
+ // many transactions.
+ for (uint32_t i = 0; i < num_threads; i++) {
+ threads.emplace_back(call_inserter);
+ }
+
+ // Wait for all threads to run
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify that data is consistent
+ Status s = RandomTransactionInserter::Verify(txn_db, num_sets);
+ ASSERT_OK(s);
+}
+
+TEST_P(OptimisticTransactionTest, SequenceNumberAfterRecoverTest) {
+ WriteOptions write_options;
+ OptimisticTransactionOptions transaction_options;
+
+  Transaction* transaction(
+      txn_db->BeginTransaction(write_options, transaction_options));
+ Status s = transaction->Put("foo", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("foo2", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("foo3", "val");
+ ASSERT_OK(s);
+ s = transaction->Commit();
+ ASSERT_OK(s);
+ delete transaction;
+
+ Reopen();
+ transaction = txn_db->BeginTransaction(write_options, transaction_options);
+ s = transaction->Put("bar", "val");
+ ASSERT_OK(s);
+ s = transaction->Put("bar2", "val");
+ ASSERT_OK(s);
+ s = transaction->Commit();
+ ASSERT_OK(s);
+
+ delete transaction;
+}
+
+INSTANTIATE_TEST_CASE_P(
+ InstanceOccGroup, OptimisticTransactionTest,
+ testing::Values(OccValidationPolicy::kValidateSerial,
+ OccValidationPolicy::kValidateParallel));
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(
+ stderr,
+ "SKIPPED as optimistic_transaction is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction.cc b/src/rocksdb/utilities/transactions/pessimistic_transaction.cc
new file mode 100644
index 000000000..5ae5fed08
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction.cc
@@ -0,0 +1,723 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/pessimistic_transaction.h"
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+std::atomic<TransactionID> PessimisticTransaction::txn_id_counter_(1);
+
+TransactionID PessimisticTransaction::GenTxnID() {
+ return txn_id_counter_.fetch_add(1);
+}
+
+PessimisticTransaction::PessimisticTransaction(
+ TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options, const bool init)
+ : TransactionBaseImpl(txn_db->GetRootDB(), write_options),
+ txn_db_impl_(nullptr),
+ expiration_time_(0),
+ txn_id_(0),
+ waiting_cf_id_(0),
+ waiting_key_(nullptr),
+ lock_timeout_(0),
+ deadlock_detect_(false),
+ deadlock_detect_depth_(0),
+ skip_concurrency_control_(false) {
+ txn_db_impl_ =
+ static_cast_with_check<PessimisticTransactionDB, TransactionDB>(txn_db);
+ db_impl_ = static_cast_with_check<DBImpl, DB>(db_);
+ if (init) {
+ Initialize(txn_options);
+ }
+}
+
+void PessimisticTransaction::Initialize(const TransactionOptions& txn_options) {
+ txn_id_ = GenTxnID();
+
+ txn_state_ = STARTED;
+
+ deadlock_detect_ = txn_options.deadlock_detect;
+ deadlock_detect_depth_ = txn_options.deadlock_detect_depth;
+ write_batch_.SetMaxBytes(txn_options.max_write_batch_size);
+ skip_concurrency_control_ = txn_options.skip_concurrency_control;
+
+ lock_timeout_ = txn_options.lock_timeout * 1000;
+ if (lock_timeout_ < 0) {
+ // Lock timeout not set, use default
+ lock_timeout_ =
+ txn_db_impl_->GetTxnDBOptions().transaction_lock_timeout * 1000;
+ }
+
+ if (txn_options.expiration >= 0) {
+ expiration_time_ = start_time_ + txn_options.expiration * 1000;
+ } else {
+ expiration_time_ = 0;
+ }
+
+ if (txn_options.set_snapshot) {
+ SetSnapshot();
+ }
+
+ if (expiration_time_ > 0) {
+ txn_db_impl_->InsertExpirableTransaction(txn_id_, this);
+ }
+ use_only_the_last_commit_time_batch_for_recovery_ =
+ txn_options.use_only_the_last_commit_time_batch_for_recovery;
+}
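
Initialize() above converts the millisecond lock_timeout and expiration supplied in TransactionOptions into microseconds, falls back to TransactionDBOptions::transaction_lock_timeout when no per-transaction timeout is given, and takes a snapshot immediately when set_snapshot is true. A minimal caller-side sketch of those options (the helper name and values are illustrative, not part of this patch):

// Sketch only: how TransactionOptions fields feed PessimisticTransaction::Initialize().
#include "rocksdb/utilities/transaction.h"
#include "rocksdb/utilities/transaction_db.h"
using namespace ROCKSDB_NAMESPACE;

Transaction* BeginConfiguredTxn(TransactionDB* txn_db) {
  WriteOptions write_options;
  TransactionOptions txn_options;
  txn_options.lock_timeout = 50;    // ms; stored internally as 50'000 us
  txn_options.expiration = 1000;    // ms; expired txns may have their locks stolen
  txn_options.set_snapshot = true;  // snapshot is taken inside Initialize()
  txn_options.deadlock_detect = true;          // build a wait-for graph on lock waits
  txn_options.max_write_batch_size = 1 << 20;  // cap on the indexed write batch
  return txn_db->BeginTransaction(write_options, txn_options);
}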
+
+PessimisticTransaction::~PessimisticTransaction() {
+ txn_db_impl_->UnLock(this, &GetTrackedKeys());
+ if (expiration_time_ > 0) {
+ txn_db_impl_->RemoveExpirableTransaction(txn_id_);
+ }
+ if (!name_.empty() && txn_state_ != COMMITED) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+}
+
+void PessimisticTransaction::Clear() {
+ txn_db_impl_->UnLock(this, &GetTrackedKeys());
+ TransactionBaseImpl::Clear();
+}
+
+void PessimisticTransaction::Reinitialize(
+ TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options) {
+ if (!name_.empty() && txn_state_ != COMMITED) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+ TransactionBaseImpl::Reinitialize(txn_db->GetRootDB(), write_options);
+ Initialize(txn_options);
+}
+
+bool PessimisticTransaction::IsExpired() const {
+ if (expiration_time_ > 0) {
+ if (db_->GetEnv()->NowMicros() >= expiration_time_) {
+ // Transaction is expired.
+ return true;
+ }
+ }
+
+ return false;
+}
+
+WriteCommittedTxn::WriteCommittedTxn(TransactionDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : PessimisticTransaction(txn_db, write_options, txn_options) {}
+
+Status PessimisticTransaction::CommitBatch(WriteBatch* batch) {
+ TransactionKeyMap keys_to_unlock;
+ Status s = LockBatch(batch, &keys_to_unlock);
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ bool can_commit = false;
+
+ if (IsExpired()) {
+ s = Status::Expired();
+ } else if (expiration_time_ > 0) {
+ TransactionState expected = STARTED;
+ can_commit = std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ AWAITING_COMMIT);
+ } else if (txn_state_ == STARTED) {
+ // lock stealing is not a concern
+ can_commit = true;
+ }
+
+ if (can_commit) {
+ txn_state_.store(AWAITING_COMMIT);
+ s = CommitBatchInternal(batch);
+ if (s.ok()) {
+ txn_state_.store(COMMITED);
+ }
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ txn_db_impl_->UnLock(this, &keys_to_unlock);
+
+ return s;
+}
+
+Status PessimisticTransaction::Prepare() {
+ Status s;
+
+ if (name_.empty()) {
+ return Status::InvalidArgument(
+ "Cannot prepare a transaction that has not been named.");
+ }
+
+ if (IsExpired()) {
+ return Status::Expired();
+ }
+
+ bool can_prepare = false;
+
+ if (expiration_time_ > 0) {
+ // must concern ourselves with expiration and/or lock stealing
+ // need to compare/exchange bc locks could be stolen under us here
+ TransactionState expected = STARTED;
+ can_prepare = std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ AWAITING_PREPARE);
+ } else if (txn_state_ == STARTED) {
+ // expiration and lock stealing is not possible
+ can_prepare = true;
+ }
+
+ if (can_prepare) {
+ txn_state_.store(AWAITING_PREPARE);
+ // transaction can't expire after preparation
+ expiration_time_ = 0;
+ assert(log_number_ == 0 ||
+ txn_db_impl_->GetTxnDBOptions().write_policy == WRITE_UNPREPARED);
+
+ s = PrepareInternal();
+ if (s.ok()) {
+ txn_state_.store(PREPARED);
+ }
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else if (txn_state_ == PREPARED) {
+ s = Status::InvalidArgument("Transaction has already been prepared.");
+ } else if (txn_state_ == COMMITED) {
+ s = Status::InvalidArgument("Transaction has already been committed.");
+ } else if (txn_state_ == ROLLEDBACK) {
+ s = Status::InvalidArgument("Transaction has already been rolledback.");
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ return s;
+}
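
Prepare() above only succeeds on a named transaction that is still in the STARTED state; on success the state moves to PREPARED and the expiration time is cleared. A caller-side two-phase commit sketch, assuming the includes/using from the earlier sketch and an illustrative transaction name:

// Sketch only: the caller-visible two-phase commit sequence.
Status TwoPhaseWrite(TransactionDB* txn_db) {
  Transaction* txn =
      txn_db->BeginTransaction(WriteOptions(), TransactionOptions());
  Status s = txn->SetName("xid-42");   // Prepare() rejects unnamed transactions
  if (s.ok()) s = txn->Put("key", "value");
  if (s.ok()) s = txn->Prepare();      // STARTED -> PREPARED, prepare marker in WAL
  if (s.ok()) s = txn->Commit();       // PREPARED -> COMMITED
  if (!s.ok()) txn->Rollback();        // best-effort cleanup on any failure
  delete txn;
  return s;
}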
+
+Status WriteCommittedTxn::PrepareInternal() {
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(), name_);
+ class MarkLogCallback : public PreReleaseCallback {
+ public:
+ MarkLogCallback(DBImpl* db, bool two_write_queues)
+ : db_(db), two_write_queues_(two_write_queues) {
+ (void)two_write_queues_; // to silence unused private field warning
+ }
+ virtual Status Callback(SequenceNumber, bool is_mem_disabled,
+ uint64_t log_number, size_t /*index*/,
+ size_t /*total*/) override {
+#ifdef NDEBUG
+ (void)is_mem_disabled;
+#endif
+ assert(log_number != 0);
+ assert(!two_write_queues_ || is_mem_disabled); // implies the 2nd queue
+ db_->logs_with_prep_tracker()->MarkLogAsContainingPrepSection(log_number);
+ return Status::OK();
+ }
+
+ private:
+ DBImpl* db_;
+ bool two_write_queues_;
+ } mark_log_callback(db_impl_,
+ db_impl_->immutable_db_options().two_write_queues);
+
+ WriteCallback* const kNoWriteCallback = nullptr;
+ const uint64_t kRefNoLog = 0;
+ const bool kDisableMemtable = true;
+ SequenceNumber* const kIgnoreSeqUsed = nullptr;
+ const size_t kNoBatchCount = 0;
+ Status s = db_impl_->WriteImpl(
+ write_options, GetWriteBatch()->GetWriteBatch(), kNoWriteCallback,
+ &log_number_, kRefNoLog, kDisableMemtable, kIgnoreSeqUsed, kNoBatchCount,
+ &mark_log_callback);
+ return s;
+}
+
+Status PessimisticTransaction::Commit() {
+ Status s;
+ bool commit_without_prepare = false;
+ bool commit_prepared = false;
+
+ if (IsExpired()) {
+ return Status::Expired();
+ }
+
+ if (expiration_time_ > 0) {
+ // we must atomically compare and exchange the state here because at
+ // this point in the transaction it is possible for another thread to
+ // change our state out from under us in the event that we expire and
+ // have our locks stolen. In this case the only valid state is STARTED
+ // because a state of PREPARED would have a cleared expiration_time_.
+ TransactionState expected = STARTED;
+ commit_without_prepare = std::atomic_compare_exchange_strong(
+ &txn_state_, &expected, AWAITING_COMMIT);
+ TEST_SYNC_POINT("TransactionTest::ExpirableTransactionDataRace:1");
+ } else if (txn_state_ == PREPARED) {
+ // expiration and lock stealing is not a concern
+ commit_prepared = true;
+ } else if (txn_state_ == STARTED) {
+ // expiration and lock stealing is not a concern
+ commit_without_prepare = true;
+ // TODO(myabandeh): what if the user mistakenly forgets prepare? We should
+ // add an option so that the user explicitly expresses the intention of
+ // skipping the prepare phase.
+ }
+
+ if (commit_without_prepare) {
+ assert(!commit_prepared);
+ if (WriteBatchInternal::Count(GetCommitTimeWriteBatch()) > 0) {
+ s = Status::InvalidArgument(
+ "Commit-time batch contains values that will not be committed.");
+ } else {
+ txn_state_.store(AWAITING_COMMIT);
+ if (log_number_ > 0) {
+ db_impl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ s = CommitWithoutPrepareInternal();
+ if (!name_.empty()) {
+ txn_db_impl_->UnregisterTransaction(this);
+ }
+ Clear();
+ if (s.ok()) {
+ txn_state_.store(COMMITED);
+ }
+ }
+ } else if (commit_prepared) {
+ txn_state_.store(AWAITING_COMMIT);
+
+ s = CommitInternal();
+
+ if (!s.ok()) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Commit write failed");
+ return s;
+ }
+
+ // FindObsoleteFiles must now look to the memtables
+ // to determine what prep logs must be kept around,
+ // not the prep section heap.
+ assert(log_number_ > 0);
+ db_impl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ txn_db_impl_->UnregisterTransaction(this);
+
+ Clear();
+ txn_state_.store(COMMITED);
+ } else if (txn_state_ == LOCKS_STOLEN) {
+ s = Status::Expired();
+ } else if (txn_state_ == COMMITED) {
+ s = Status::InvalidArgument("Transaction has already been committed.");
+ } else if (txn_state_ == ROLLEDBACK) {
+ s = Status::InvalidArgument("Transaction has already been rolledback.");
+ } else {
+ s = Status::InvalidArgument("Transaction is not in state for commit.");
+ }
+
+ return s;
+}
+
+Status WriteCommittedTxn::CommitWithoutPrepareInternal() {
+ uint64_t seq_used = kMaxSequenceNumber;
+ auto s =
+ db_impl_->WriteImpl(write_options_, GetWriteBatch()->GetWriteBatch(),
+ /*callback*/ nullptr, /*log_used*/ nullptr,
+ /*log_ref*/ 0, /*disable_memtable*/ false, &seq_used);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
+
+Status WriteCommittedTxn::CommitBatchInternal(WriteBatch* batch, size_t) {
+ uint64_t seq_used = kMaxSequenceNumber;
+ auto s = db_impl_->WriteImpl(write_options_, batch, /*callback*/ nullptr,
+ /*log_used*/ nullptr, /*log_ref*/ 0,
+ /*disable_memtable*/ false, &seq_used);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
+
+Status WriteCommittedTxn::CommitInternal() {
+ // We take the commit-time batch and append the Commit marker.
+ // The Memtable will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+ WriteBatchInternal::MarkCommit(working_batch, name_);
+
+ // any operations appended to this working_batch will be ignored from WAL
+ working_batch->MarkWalTerminationPoint();
+
+ // insert prepared batch into Memtable only skipping WAL.
+ // Memtable will ignore BeginPrepare/EndPrepare markers
+ // in non recovery mode and simply insert the values
+ WriteBatchInternal::Append(working_batch, GetWriteBatch()->GetWriteBatch());
+
+ uint64_t seq_used = kMaxSequenceNumber;
+ auto s =
+ db_impl_->WriteImpl(write_options_, working_batch, /*callback*/ nullptr,
+ /*log_used*/ nullptr, /*log_ref*/ log_number_,
+ /*disable_memtable*/ false, &seq_used);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (s.ok()) {
+ SetId(seq_used);
+ }
+ return s;
+}
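
CommitInternal() above appends the commit marker and the prepared batch to the commit-time write batch, so anything placed in GetCommitTimeWriteBatch() before Commit() is persisted atomically with the transaction; note that Commit() earlier in this file rejects a non-empty commit-time batch unless the transaction went through Prepare(). A hedged sketch of that usage (names are illustrative):

// Sketch only: attaching commit-time data to a prepared transaction.
Status CommitWithAuditRecord(TransactionDB* txn_db) {
  Transaction* txn =
      txn_db->BeginTransaction(WriteOptions(), TransactionOptions());
  Status s = txn->SetName("xid-audit-1");
  if (s.ok()) s = txn->Put("user/1", "payload");
  // Only written at commit time; requires the Prepare() path (see Commit()).
  if (s.ok()) s = txn->GetCommitTimeWriteBatch()->Put("audit/xid-audit-1", "ok");
  if (s.ok()) s = txn->Prepare();
  if (s.ok()) s = txn->Commit();
  delete txn;
  return s;
}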
+
+Status PessimisticTransaction::Rollback() {
+ Status s;
+ if (txn_state_ == PREPARED) {
+ txn_state_.store(AWAITING_ROLLBACK);
+
+ s = RollbackInternal();
+
+ if (s.ok()) {
+ // we do not need to keep our prepared section around
+ assert(log_number_ > 0);
+ db_impl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ Clear();
+ txn_state_.store(ROLLEDBACK);
+ }
+ } else if (txn_state_ == STARTED) {
+ if (log_number_ > 0) {
+ assert(txn_db_impl_->GetTxnDBOptions().write_policy == WRITE_UNPREPARED);
+ assert(GetId() > 0);
+ s = RollbackInternal();
+
+ if (s.ok()) {
+ db_impl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ }
+ // prepare couldn't have taken place
+ Clear();
+ } else if (txn_state_ == COMMITED) {
+ s = Status::InvalidArgument("This transaction has already been committed.");
+ } else {
+ s = Status::InvalidArgument(
+ "Two phase transaction is not in state for rollback.");
+ }
+
+ return s;
+}
+
+Status WriteCommittedTxn::RollbackInternal() {
+ WriteBatch rollback_marker;
+ WriteBatchInternal::MarkRollback(&rollback_marker, name_);
+ auto s = db_impl_->WriteImpl(write_options_, &rollback_marker);
+ return s;
+}
+
+Status PessimisticTransaction::RollbackToSavePoint() {
+ if (txn_state_ != STARTED) {
+ return Status::InvalidArgument("Transaction is beyond state for rollback.");
+ }
+
+ // Unlock any keys locked since last transaction
+ const std::unique_ptr<TransactionKeyMap>& keys =
+ GetTrackedKeysSinceSavePoint();
+
+ if (keys) {
+ txn_db_impl_->UnLock(this, keys.get());
+ }
+
+ return TransactionBaseImpl::RollbackToSavePoint();
+}
+
+// Lock all keys in this batch.
+// On success, caller should unlock keys_to_unlock
+Status PessimisticTransaction::LockBatch(WriteBatch* batch,
+ TransactionKeyMap* keys_to_unlock) {
+ class Handler : public WriteBatch::Handler {
+ public:
+ // Sorted map of column_family_id to sorted set of keys.
+ // Since LockBatch() always locks keys in sorted order, it cannot deadlock
+ // with itself. We're not using a comparator here since it doesn't matter
+ // what the sorting is as long as it's consistent.
+ std::map<uint32_t, std::set<std::string>> keys_;
+
+ Handler() {}
+
+ void RecordKey(uint32_t column_family_id, const Slice& key) {
+ std::string key_str = key.ToString();
+
+ auto& cfh_keys = keys_[column_family_id];
+ auto iter = cfh_keys.find(key_str);
+ if (iter == cfh_keys.end()) {
+ // key not yet seen, store it.
+ cfh_keys.insert({std::move(key_str)});
+ }
+ }
+
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& /* unused */) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& /* unused */) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ RecordKey(column_family_id, key);
+ return Status::OK();
+ }
+ };
+
+ // Iterating on this handler will add all keys in this batch into keys
+ Handler handler;
+ batch->Iterate(&handler);
+
+ Status s;
+
+ // Attempt to lock all keys
+ for (const auto& cf_iter : handler.keys_) {
+ uint32_t cfh_id = cf_iter.first;
+ auto& cfh_keys = cf_iter.second;
+
+ for (const auto& key_iter : cfh_keys) {
+ const std::string& key = key_iter;
+
+ s = txn_db_impl_->TryLock(this, cfh_id, key, true /* exclusive */);
+ if (!s.ok()) {
+ break;
+ }
+ TrackKey(keys_to_unlock, cfh_id, std::move(key), kMaxSequenceNumber,
+ false, true /* exclusive */);
+ }
+
+ if (!s.ok()) {
+ break;
+ }
+ }
+
+ if (!s.ok()) {
+ txn_db_impl_->UnLock(this, keys_to_unlock);
+ }
+
+ return s;
+}
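
LockBatch() above collects the batch's keys with a WriteBatch::Handler, de-duplicates them per column family in sorted order, and locks them before CommitBatch() writes the batch; this is what makes a plain Write() on a TransactionDB safe against concurrent transactions. A small sketch of the caller-visible effect (assuming rocksdb/write_batch.h is included along with the headers above):

// Sketch only: a non-transactional WriteBatch still takes locks on a TransactionDB.
Status BatchedWrite(TransactionDB* txn_db) {
  WriteBatch batch;
  batch.Put("a", "1");
  batch.Put("b", "2");
  batch.Delete("stale");
  // Internally wrapped in a transaction; keys are locked in sorted order, so
  // concurrent Write() calls cannot deadlock with each other.
  return txn_db->Write(WriteOptions(), &batch);
}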
+
+// Attempt to lock this key.
+// Returns OK if the key has been successfully locked. Non-ok, otherwise.
+// If do_validate is true and this transaction has a snapshot set,
+// this key will only be locked if there have been no writes to this key since
+// the snapshot time.
+Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
+ const Slice& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ assert(!assume_tracked || !do_validate);
+ Status s;
+ if (UNLIKELY(skip_concurrency_control_)) {
+ return s;
+ }
+ uint32_t cfh_id = GetColumnFamilyID(column_family);
+ std::string key_str = key.ToString();
+ bool previously_locked;
+ bool lock_upgrade = false;
+
+ // Lock this key if this transaction hasn't already locked it.
+ SequenceNumber tracked_at_seq = kMaxSequenceNumber;
+
+ const auto& tracked_keys = GetTrackedKeys();
+ const auto tracked_keys_cf = tracked_keys.find(cfh_id);
+ if (tracked_keys_cf == tracked_keys.end()) {
+ previously_locked = false;
+ } else {
+ auto iter = tracked_keys_cf->second.find(key_str);
+ if (iter == tracked_keys_cf->second.end()) {
+ previously_locked = false;
+ } else {
+ if (!iter->second.exclusive && exclusive) {
+ lock_upgrade = true;
+ }
+ previously_locked = true;
+ tracked_at_seq = iter->second.seq;
+ }
+ }
+
+ // Lock this key if this transaction hasn't already locked it or we require
+ // an upgrade.
+ if (!previously_locked || lock_upgrade) {
+ s = txn_db_impl_->TryLock(this, cfh_id, key_str, exclusive);
+ }
+
+ SetSnapshotIfNeeded();
+
+ // Even though we do not care about doing conflict checking for this write,
+ // we still need to take a lock to make sure we do not cause a conflict with
+ // some other write. However, we do not need to check if there have been
+ // any writes since this transaction's snapshot.
+ // TODO(agiardullo): could optimize by supporting shared txn locks in the
+ // future
+ if (!do_validate || snapshot_ == nullptr) {
+ if (assume_tracked && !previously_locked) {
+ s = Status::InvalidArgument(
+ "assume_tracked is set but it is not tracked yet");
+ }
+ // Need to remember the earliest sequence number after which we know this
+ // key has not been modified. This is useful if this same transaction
+ // later tries to lock this key again.
+ if (tracked_at_seq == kMaxSequenceNumber) {
+ // Since we haven't checked a snapshot, we only know this key has not
+ // been modified since we locked it.
+ // Note: when last_seq_same_as_publish_seq_==false this is less than the
+ // latest allocated seq but it is ok since i) this is just a heuristic
+ // used only as a hint to avoid an actual check for conflicts, ii) this
+ // would cause a false positive only if the snapshot is taken right after
+ // the lock, which would be an unusual sequence.
+ tracked_at_seq = db_->GetLatestSequenceNumber();
+ }
+ } else {
+ // If a snapshot is set, we need to make sure the key hasn't been modified
+ // since the snapshot. This must be done after we locked the key.
+ // If we have already validated an earlier snapshot, it must have been
+ // reflected in tracked_at_seq and ValidateSnapshot will return OK.
+ if (s.ok()) {
+ s = ValidateSnapshot(column_family, key, &tracked_at_seq);
+
+ if (!s.ok()) {
+ // Failed to validate key
+ if (!previously_locked) {
+ // Unlock key we just locked
+ if (lock_upgrade) {
+ s = txn_db_impl_->TryLock(this, cfh_id, key_str,
+ false /* exclusive */);
+ assert(s.ok());
+ } else {
+ txn_db_impl_->UnLock(this, cfh_id, key.ToString());
+ }
+ }
+ }
+ }
+ }
+
+ if (s.ok()) {
+ // We must track all the locked keys so that we can unlock them later. If
+ // the key is already locked, this func will update some stats on the
+ // tracked key. It could also update the tracked_at_seq if it is lower
+ // than the existing tracked key seq. These stats are necessary for
+ // RollbackToSavePoint to determine whether a key can be safely removed
+ // from tracked_keys_. Removal can only be done if a key was only locked
+ // during the current savepoint.
+ //
+ // Recall that if assume_tracked is true, we assume that TrackKey has been
+ // called previously since the last savepoint, with the same exclusive
+ // setting, and at a lower sequence number, so skipping here should be
+ // safe.
+ if (!assume_tracked) {
+ TrackKey(cfh_id, key_str, tracked_at_seq, read_only, exclusive);
+ } else {
+#ifndef NDEBUG
+ assert(tracked_keys_cf->second.count(key_str) > 0);
+ const auto& info = tracked_keys_cf->second.find(key_str)->second;
+ assert(info.seq <= tracked_at_seq);
+ assert(info.exclusive == exclusive);
+#endif
+ }
+ }
+
+ return s;
+}
+
+// Return OK() if this key has not been modified more recently than the
+// transaction snapshot_.
+// tracked_at_seq is the global seq at which we either locked the key or already
+// have done ValidateSnapshot.
+Status PessimisticTransaction::ValidateSnapshot(
+ ColumnFamilyHandle* column_family, const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ assert(snapshot_);
+
+ SequenceNumber snap_seq = snapshot_->GetSequenceNumber();
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previously validated (or locked) at a sequence number
+ // earlier than the current snapshot's sequence number, we already know it
+ // has not been modified after snap_seq either.
+ return Status::OK();
+ }
+ // Otherwise we have either
+ // 1: tracked_at_seq == kMaxSequenceNumber, i.e., first time tracking the key
+ // 2: snap_seq < tracked_at_seq: the last time we locked the key was via
+ // do_validate=false, which means we skipped ValidateSnapshot. In both
+ // cases we should do ValidateSnapshot now.
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ return TransactionUtil::CheckKeyForConflicts(
+ db_impl_, cfh, key.ToString(), snap_seq, false /* cache_only */);
+}
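
ValidateSnapshot() above is what layers an optimistic-style conflict check on top of the pessimistic lock: with set_snapshot, a key written after the snapshot causes GetForUpdate()/Put() on that key to fail (CheckKeyForConflicts typically reports Status::Busy) even though the lock itself was acquired. A hedged sketch of that interaction:

// Sketch only: snapshot validation rejecting a write that raced with the snapshot.
Status SnapshotConflictDemo(TransactionDB* txn_db) {
  TransactionOptions txn_options;
  txn_options.set_snapshot = true;
  Transaction* txn = txn_db->BeginTransaction(WriteOptions(), txn_options);

  // Another writer updates the key after txn's snapshot was taken.
  Status s = txn_db->Put(WriteOptions(), "k", "other");
  if (s.ok()) {
    // TryLock() gets the lock, but ValidateSnapshot() sees the newer write and
    // the call fails, so the caller can roll back and retry.
    std::string value;
    s = txn->GetForUpdate(ReadOptions(), "k", &value);  // expected: not ok
  }
  delete txn;
  return s;
}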
+
+bool PessimisticTransaction::TryStealingLocks() {
+ assert(IsExpired());
+ TransactionState expected = STARTED;
+ return std::atomic_compare_exchange_strong(&txn_state_, &expected,
+ LOCKS_STOLEN);
+}
+
+void PessimisticTransaction::UnlockGetForUpdate(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ txn_db_impl_->UnLock(this, GetColumnFamilyID(column_family), key.ToString());
+}
+
+Status PessimisticTransaction::SetName(const TransactionName& name) {
+ Status s;
+ if (txn_state_ == STARTED) {
+ if (name_.length()) {
+ s = Status::InvalidArgument("Transaction has already been named.");
+ } else if (txn_db_impl_->GetTransactionByName(name) != nullptr) {
+ s = Status::InvalidArgument("Transaction name must be unique.");
+ } else if (name.length() < 1 || name.length() > 512) {
+ s = Status::InvalidArgument(
+ "Transaction name length must be between 1 and 512 chars.");
+ } else {
+ name_ = name;
+ txn_db_impl_->RegisterTransaction(this);
+ }
+ } else {
+ s = Status::InvalidArgument("Transaction is beyond state for naming.");
+ }
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction.h b/src/rocksdb/utilities/transactions/pessimistic_transaction.h
new file mode 100644
index 000000000..8f2c84405
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction.h
@@ -0,0 +1,225 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <mutex>
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class PessimisticTransactionDB;
+
+// A transaction under pessimistic concurrency control. This class implements
+// the locking API and interfaces with the lock manager as well as the
+// pessimistic transactional db.
+class PessimisticTransaction : public TransactionBaseImpl {
+ public:
+ PessimisticTransaction(TransactionDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ const bool init = true);
+ // No copying allowed
+ PessimisticTransaction(const PessimisticTransaction&) = delete;
+ void operator=(const PessimisticTransaction&) = delete;
+
+ virtual ~PessimisticTransaction();
+
+ void Reinitialize(TransactionDB* txn_db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+
+ Status Prepare() override;
+
+ Status Commit() override;
+
+ // It is basically Commit without going through the Prepare phase. The write
+ // batch is provided directly instead of expecting the txn to gradually batch
+ // its writes into an internal write batch.
+ Status CommitBatch(WriteBatch* batch);
+
+ Status Rollback() override;
+
+ Status RollbackToSavePoint() override;
+
+ Status SetName(const TransactionName& name) override;
+
+ // Generate a new unique transaction identifier
+ static TransactionID GenTxnID();
+
+ TransactionID GetID() const override { return txn_id_; }
+
+ std::vector<TransactionID> GetWaitingTxns(uint32_t* column_family_id,
+ std::string* key) const override {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ std::vector<TransactionID> ids(waiting_txn_ids_.size());
+ if (key) *key = waiting_key_ ? *waiting_key_ : "";
+ if (column_family_id) *column_family_id = waiting_cf_id_;
+ std::copy(waiting_txn_ids_.begin(), waiting_txn_ids_.end(), ids.begin());
+ return ids;
+ }
+
+ void SetWaitingTxn(autovector<TransactionID> ids, uint32_t column_family_id,
+ const std::string* key) {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ waiting_txn_ids_ = ids;
+ waiting_cf_id_ = column_family_id;
+ waiting_key_ = key;
+ }
+
+ void ClearWaitingTxn() {
+ std::lock_guard<std::mutex> lock(wait_mutex_);
+ waiting_txn_ids_.clear();
+ waiting_cf_id_ = 0;
+ waiting_key_ = nullptr;
+ }
+
+ // Returns the time (in microseconds according to Env->NowMicros())
+ // at which this transaction will expire. Returns 0 if this transaction does
+ // not expire.
+ uint64_t GetExpirationTime() const { return expiration_time_; }
+
+ // Returns true if this transaction has an expiration_time and has expired.
+ bool IsExpired() const;
+
+ // Returns the number of microseconds a transaction can wait on acquiring a
+ // lock or -1 if there is no timeout.
+ int64_t GetLockTimeout() const { return lock_timeout_; }
+ void SetLockTimeout(int64_t timeout) override {
+ lock_timeout_ = timeout * 1000;
+ }
+
+ // Returns true if locks were stolen successfully, false otherwise.
+ bool TryStealingLocks();
+
+ bool IsDeadlockDetect() const override { return deadlock_detect_; }
+
+ int64_t GetDeadlockDetectDepth() const { return deadlock_detect_depth_; }
+
+ protected:
+ // Refer to
+ // TransactionOptions::use_only_the_last_commit_time_batch_for_recovery
+ bool use_only_the_last_commit_time_batch_for_recovery_ = false;
+
+ virtual Status PrepareInternal() = 0;
+
+ virtual Status CommitWithoutPrepareInternal() = 0;
+
+ // batch_cnt if non-zero is the number of sub-batches. A sub-batch is a batch
+ // with no duplicate keys. If zero, then the number of sub-batches is unknown.
+ virtual Status CommitBatchInternal(WriteBatch* batch,
+ size_t batch_cnt = 0) = 0;
+
+ virtual Status CommitInternal() = 0;
+
+ virtual Status RollbackInternal() = 0;
+
+ virtual void Initialize(const TransactionOptions& txn_options);
+
+ Status LockBatch(WriteBatch* batch, TransactionKeyMap* keys_to_unlock);
+
+ Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false) override;
+
+ void Clear() override;
+
+ PessimisticTransactionDB* txn_db_impl_;
+ DBImpl* db_impl_;
+
+ // If non-zero, this transaction should not be committed after this time (in
+ // microseconds according to Env->NowMicros())
+ uint64_t expiration_time_;
+
+ private:
+ friend class TransactionTest_ValidateSnapshotTest_Test;
+ // Used to create unique ids for transactions.
+ static std::atomic<TransactionID> txn_id_counter_;
+
+ // Unique ID for this transaction
+ TransactionID txn_id_;
+
+ // IDs for the transactions that are blocking the current transaction.
+ //
+ // empty if current transaction is not waiting.
+ autovector<TransactionID> waiting_txn_ids_;
+
+ // The following two represents the (cf, key) that a transaction is waiting
+ // on.
+ //
+ // If waiting_key_ is not null, then the pointer should always point to
+ // a valid string object. The reason is that it is only non-null when the
+ // transaction is blocked in the TransactionLockMgr::AcquireWithTimeout
+ // function. At that point, the key string object is one of the function
+ // parameters.
+ uint32_t waiting_cf_id_;
+ const std::string* waiting_key_;
+
+ // Mutex protecting waiting_txn_ids_, waiting_cf_id_ and waiting_key_.
+ mutable std::mutex wait_mutex_;
+
+ // Timeout in microseconds when locking a key or -1 if there is no timeout.
+ int64_t lock_timeout_;
+
+ // Whether to perform deadlock detection or not.
+ bool deadlock_detect_;
+
+ // Maximum depth to traverse when performing deadlock detection.
+ int64_t deadlock_detect_depth_;
+
+ // Refer to TransactionOptions::skip_concurrency_control
+ bool skip_concurrency_control_;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq);
+
+ void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+};
+
+class WriteCommittedTxn : public PessimisticTransaction {
+ public:
+ WriteCommittedTxn(TransactionDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+ // No copying allowed
+ WriteCommittedTxn(const WriteCommittedTxn&) = delete;
+ void operator=(const WriteCommittedTxn&) = delete;
+
+ virtual ~WriteCommittedTxn() {}
+
+ private:
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+
+ Status CommitBatchInternal(WriteBatch* batch, size_t batch_cnt) override;
+
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc
new file mode 100644
index 000000000..30d5b79f6
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.cc
@@ -0,0 +1,632 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/pessimistic_transaction_db.h"
+
+#include <cinttypes>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+PessimisticTransactionDB::PessimisticTransactionDB(
+ DB* db, const TransactionDBOptions& txn_db_options)
+ : TransactionDB(db),
+ db_impl_(static_cast_with_check<DBImpl, DB>(db)),
+ txn_db_options_(txn_db_options),
+ lock_mgr_(this, txn_db_options_.num_stripes, txn_db_options.max_num_locks,
+ txn_db_options_.max_num_deadlocks,
+ txn_db_options_.custom_mutex_factory
+ ? txn_db_options_.custom_mutex_factory
+ : std::shared_ptr<TransactionDBMutexFactory>(
+ new TransactionDBMutexFactoryImpl())) {
+ assert(db_impl_ != nullptr);
+ info_log_ = db_impl_->GetDBOptions().info_log;
+}
+
+// Support initializing PessimisticTransactionDB from a stackable db
+//
+// PessimisticTransactionDB
+// ^ ^
+// | |
+// | +
+// | StackableDB
+// | ^
+// | |
+// + +
+// DBImpl
+// ^
+// |(inherit)
+// +
+// DB
+//
+PessimisticTransactionDB::PessimisticTransactionDB(
+ StackableDB* db, const TransactionDBOptions& txn_db_options)
+ : TransactionDB(db),
+ db_impl_(static_cast_with_check<DBImpl, DB>(db->GetRootDB())),
+ txn_db_options_(txn_db_options),
+ lock_mgr_(this, txn_db_options_.num_stripes, txn_db_options.max_num_locks,
+ txn_db_options_.max_num_deadlocks,
+ txn_db_options_.custom_mutex_factory
+ ? txn_db_options_.custom_mutex_factory
+ : std::shared_ptr<TransactionDBMutexFactory>(
+ new TransactionDBMutexFactoryImpl())) {
+ assert(db_impl_ != nullptr);
+}
+
+PessimisticTransactionDB::~PessimisticTransactionDB() {
+ while (!transactions_.empty()) {
+ delete transactions_.begin()->second;
+ // TODO(myabandeh): this seems to be an unsafe approach as it is not quite
+ // clear whether delete would also remove the entry from transactions_.
+ }
+}
+
+Status PessimisticTransactionDB::VerifyCFOptions(const ColumnFamilyOptions&) {
+ return Status::OK();
+}
+
+Status PessimisticTransactionDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ for (auto cf_ptr : handles) {
+ AddColumnFamily(cf_ptr);
+ }
+ // Verify cf options
+ for (auto handle : handles) {
+ ColumnFamilyDescriptor cfd;
+ Status s = handle->GetDescriptor(&cfd);
+ if (!s.ok()) {
+ return s;
+ }
+ s = VerifyCFOptions(cfd.options);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ // Re-enable compaction for the column families that initially had
+ // compaction enabled.
+ std::vector<ColumnFamilyHandle*> compaction_enabled_cf_handles;
+ compaction_enabled_cf_handles.reserve(compaction_enabled_cf_indices.size());
+ for (auto index : compaction_enabled_cf_indices) {
+ compaction_enabled_cf_handles.push_back(handles[index]);
+ }
+
+ Status s = EnableAutoCompaction(compaction_enabled_cf_handles);
+
+ // create 'real' transactions from recovered shell transactions
+ auto dbimpl = static_cast_with_check<DBImpl, DB>(GetRootDB());
+ assert(dbimpl != nullptr);
+ auto rtrxs = dbimpl->recovered_transactions();
+
+ for (auto it = rtrxs.begin(); it != rtrxs.end(); ++it) {
+ auto recovered_trx = it->second;
+ assert(recovered_trx);
+ assert(recovered_trx->batches_.size() == 1);
+ const auto& seq = recovered_trx->batches_.begin()->first;
+ const auto& batch_info = recovered_trx->batches_.begin()->second;
+ assert(batch_info.log_number_);
+ assert(recovered_trx->name_.length());
+
+ WriteOptions w_options;
+ w_options.sync = true;
+ TransactionOptions t_options;
+ // This helps avoid deadlocks for keys that, although present in the WAL,
+ // did not go through concurrency control. This includes the merge that
+ // MyRocks uses for auto-inc columns. It is safe to do so, since (i) if
+ // there is a conflict between the keys of two transactions that must be
+ // avoided, it is already avoided by the application, MyRocks, before the
+ // restart, and (ii) the application, MyRocks, guarantees to rollback/commit
+ // the recovered transactions before new transactions start.
+ t_options.skip_concurrency_control = true;
+
+ Transaction* real_trx = BeginTransaction(w_options, t_options, nullptr);
+ assert(real_trx);
+ real_trx->SetLogNumber(batch_info.log_number_);
+ assert(seq != kMaxSequenceNumber);
+ if (GetTxnDBOptions().write_policy != WRITE_COMMITTED) {
+ real_trx->SetId(seq);
+ }
+
+ s = real_trx->SetName(recovered_trx->name_);
+ if (!s.ok()) {
+ break;
+ }
+
+ s = real_trx->RebuildFromWriteBatch(batch_info.batch_);
+ // WriteCommitted sets this to 0 to disable this check, which is specific to
+ // WritePrepared txns
+ assert(batch_info.batch_cnt_ == 0 ||
+ real_trx->GetWriteBatch()->SubBatchCnt() == batch_info.batch_cnt_);
+ real_trx->SetState(Transaction::PREPARED);
+ if (!s.ok()) {
+ break;
+ }
+ }
+ if (s.ok()) {
+ dbimpl->DeleteAllRecoveredTransactions();
+ }
+ return s;
+}
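
Initialize() above turns the prepared sections recovered from the WAL back into named transactions in the PREPARED state; after open, a two-phase-commit coordinator is expected to decide their fate. A hedged sketch of that caller-side step (the outcome lookup is illustrative):

// Sketch only: resolving transactions that were still PREPARED at crash time.
void ResolveRecoveredTxns(TransactionDB* txn_db) {
  std::vector<Transaction*> recovered;
  txn_db->GetAllPreparedTransactions(&recovered);
  for (Transaction* txn : recovered) {
    // A real coordinator would look up the decision by txn->GetName().
    Status s = txn->Rollback();  // or txn->Commit() if it was decided
    (void)s;                     // sketch: ignore the per-transaction status
    delete txn;
  }
}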
+
+Transaction* WriteCommittedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WriteCommittedTxn(this, write_options, txn_options);
+ }
+}
+
+TransactionDBOptions PessimisticTransactionDB::ValidateTxnDBOptions(
+ const TransactionDBOptions& txn_db_options) {
+ TransactionDBOptions validated = txn_db_options;
+
+ if (txn_db_options.num_stripes == 0) {
+ validated.num_stripes = 1;
+ }
+
+ return validated;
+}
+
+Status TransactionDB::Open(const Options& options,
+ const TransactionDBOptions& txn_db_options,
+ const std::string& dbname, TransactionDB** dbptr) {
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = TransactionDB::Open(db_options, txn_db_options, dbname,
+ column_families, &handles, dbptr);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // We can delete the handle since DBImpl always holds a reference to the
+ // default column family
+ delete handles[0];
+ }
+
+ return s;
+}
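
The convenience Open() above only handles the default column family and forwards to the full overload below. A minimal open/write/close sketch (the path is a placeholder):

// Sketch only: opening a WRITE_COMMITTED TransactionDB and running one transaction.
Status OpenAndWriteOnce(const std::string& db_path) {
  Options options;
  options.create_if_missing = true;
  TransactionDBOptions txn_db_options;  // write_policy defaults to WRITE_COMMITTED

  TransactionDB* txn_db = nullptr;
  Status s = TransactionDB::Open(options, txn_db_options, db_path, &txn_db);
  if (!s.ok()) return s;

  Transaction* txn = txn_db->BeginTransaction(WriteOptions());
  s = txn->Put("key", "value");
  if (s.ok()) s = txn->Commit();
  delete txn;
  delete txn_db;
  return s;
}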
+
+Status TransactionDB::Open(
+ const DBOptions& db_options, const TransactionDBOptions& txn_db_options,
+ const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles, TransactionDB** dbptr) {
+ Status s;
+ DB* db = nullptr;
+ if (txn_db_options.write_policy == WRITE_COMMITTED &&
+ db_options.unordered_write) {
+ return Status::NotSupported(
+ "WRITE_COMMITTED is incompatible with unordered_writes");
+ }
+ if (txn_db_options.write_policy == WRITE_UNPREPARED &&
+ db_options.unordered_write) {
+ // TODO(lth): support it
+ return Status::NotSupported(
+ "WRITE_UNPREPARED is currently incompatible with unordered_writes");
+ }
+ if (txn_db_options.write_policy == WRITE_PREPARED &&
+ db_options.unordered_write && !db_options.two_write_queues) {
+ return Status::NotSupported(
+ "WRITE_PREPARED is incompatible with unordered_writes if "
+ "two_write_queues is not enabled.");
+ }
+
+ std::vector<ColumnFamilyDescriptor> column_families_copy = column_families;
+ std::vector<size_t> compaction_enabled_cf_indices;
+ DBOptions db_options_2pc = db_options;
+ PrepareWrap(&db_options_2pc, &column_families_copy,
+ &compaction_enabled_cf_indices);
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ s = DBImpl::Open(db_options_2pc, dbname, column_families_copy, handles, &db,
+ use_seq_per_batch, use_batch_per_txn);
+ if (s.ok()) {
+ ROCKS_LOG_WARN(db->GetDBOptions().info_log,
+ "Transaction write_policy is %" PRId32,
+ static_cast<int>(txn_db_options.write_policy));
+ s = WrapDB(db, txn_db_options, compaction_enabled_cf_indices, *handles,
+ dbptr);
+ }
+ if (!s.ok()) {
+ // just in case it was not deleted (and not set to nullptr).
+ delete db;
+ }
+ return s;
+}
+
+void TransactionDB::PrepareWrap(
+ DBOptions* db_options, std::vector<ColumnFamilyDescriptor>* column_families,
+ std::vector<size_t>* compaction_enabled_cf_indices) {
+ compaction_enabled_cf_indices->clear();
+
+ // Enable MemTable History if not already enabled
+ for (size_t i = 0; i < column_families->size(); i++) {
+ ColumnFamilyOptions* cf_options = &(*column_families)[i].options;
+
+ if (cf_options->max_write_buffer_size_to_maintain == 0 &&
+ cf_options->max_write_buffer_number_to_maintain == 0) {
+ // Setting to -1 will set the History size to
+ // max_write_buffer_number * write_buffer_size.
+ cf_options->max_write_buffer_size_to_maintain = -1;
+ }
+ if (!cf_options->disable_auto_compactions) {
+ // Disable compactions momentarily to prevent race with DB::Open
+ cf_options->disable_auto_compactions = true;
+ compaction_enabled_cf_indices->push_back(i);
+ }
+ }
+ db_options->allow_2pc = true;
+}
+
+Status TransactionDB::WrapDB(
+ // make sure this db is already opened with memtable history enabled,
+ // auto compaction disabled and two-phase commit enabled
+ DB* db, const TransactionDBOptions& txn_db_options,
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles, TransactionDB** dbptr) {
+ assert(db != nullptr);
+ assert(dbptr != nullptr);
+ *dbptr = nullptr;
+ std::unique_ptr<PessimisticTransactionDB> txn_db;
+ switch (txn_db_options.write_policy) {
+ case WRITE_UNPREPARED:
+ txn_db.reset(new WriteUnpreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_PREPARED:
+ txn_db.reset(new WritePreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_COMMITTED:
+ default:
+ txn_db.reset(new WriteCommittedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ }
+ txn_db->UpdateCFComparatorMap(handles);
+ Status s = txn_db->Initialize(compaction_enabled_cf_indices, handles);
+ // In case of a failure at this point, db is deleted via the txn_db destructor
+ // and set to nullptr.
+ if (s.ok()) {
+ *dbptr = txn_db.release();
+ }
+ return s;
+}
+
+Status TransactionDB::WrapStackableDB(
+ // make sure this stackable_db is already opened with memtable history
+ // enabled, auto compaction disabled and two-phase commit enabled
+ StackableDB* db, const TransactionDBOptions& txn_db_options,
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles, TransactionDB** dbptr) {
+ assert(db != nullptr);
+ assert(dbptr != nullptr);
+ *dbptr = nullptr;
+ std::unique_ptr<PessimisticTransactionDB> txn_db;
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_UNPREPARED:
+ txn_db.reset(new WriteUnpreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_PREPARED:
+ txn_db.reset(new WritePreparedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ break;
+ case WRITE_COMMITTED:
+ default:
+ txn_db.reset(new WriteCommittedTxnDB(
+ db, PessimisticTransactionDB::ValidateTxnDBOptions(txn_db_options)));
+ }
+ txn_db->UpdateCFComparatorMap(handles);
+ Status s = txn_db->Initialize(compaction_enabled_cf_indices, handles);
+ // In case of a failure at this point, db is deleted via the txn_db destructor
+ // and set to nullptr.
+ if (s.ok()) {
+ *dbptr = txn_db.release();
+ }
+ return s;
+}
+
+// Let TransactionLockMgr know that this column family exists so it can
+// allocate a LockMap for it.
+void PessimisticTransactionDB::AddColumnFamily(
+ const ColumnFamilyHandle* handle) {
+ lock_mgr_.AddColumnFamily(handle->GetID());
+}
+
+Status PessimisticTransactionDB::CreateColumnFamily(
+ const ColumnFamilyOptions& options, const std::string& column_family_name,
+ ColumnFamilyHandle** handle) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+ Status s = VerifyCFOptions(options);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = db_->CreateColumnFamily(options, column_family_name, handle);
+ if (s.ok()) {
+ lock_mgr_.AddColumnFamily((*handle)->GetID());
+ UpdateCFComparatorMap(*handle);
+ }
+
+ return s;
+}
+
+// Let TransactionLockMgr know that it can deallocate the LockMap for this
+// column family.
+Status PessimisticTransactionDB::DropColumnFamily(
+ ColumnFamilyHandle* column_family) {
+ InstrumentedMutexLock l(&column_family_mutex_);
+
+ Status s = db_->DropColumnFamily(column_family);
+ if (s.ok()) {
+ lock_mgr_.RemoveColumnFamily(column_family->GetID());
+ }
+
+ return s;
+}
+
+Status PessimisticTransactionDB::TryLock(PessimisticTransaction* txn,
+ uint32_t cfh_id,
+ const std::string& key,
+ bool exclusive) {
+ return lock_mgr_.TryLock(txn, cfh_id, key, GetEnv(), exclusive);
+}
+
+void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn,
+ const TransactionKeyMap* keys) {
+ lock_mgr_.UnLock(txn, keys, GetEnv());
+}
+
+void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn,
+ uint32_t cfh_id, const std::string& key) {
+ lock_mgr_.UnLock(txn, cfh_id, key, GetEnv());
+}
+
+// Used when wrapping DB write operations in a transaction
+Transaction* PessimisticTransactionDB::BeginInternalTransaction(
+ const WriteOptions& options) {
+ TransactionOptions txn_options;
+ Transaction* txn = BeginTransaction(options, txn_options, nullptr);
+
+ // Use default timeout for non-transactional writes
+ txn->SetLockTimeout(txn_db_options_.default_lock_timeout);
+ return txn;
+}
+
+// All user Put, Merge, Delete, and Write requests must be intercepted to make
+// sure that they lock all keys that they are writing to avoid causing conflicts
+// with any concurrent transactions. The easiest way to do this is to wrap all
+// write operations in a transaction.
+//
+// Put(), Merge(), and Delete() only lock a single key per call. Write() will
+// sort its keys before locking them. This guarantees that TransactionDB write
+// methods cannot deadlock with each other (but still could deadlock with a
+// Transaction).
+Status PessimisticTransactionDB::Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& val) {
+ Status s;
+
+ Transaction* txn = BeginInternalTransaction(options);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do PutUntracked().
+ s = txn->PutUntracked(column_family, key, val);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
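
Because Put() above (and Delete()/SingleDelete()/Merge() below) wraps each call in an internal transaction with the default lock timeout, plain DB-style writes and explicit transactions can be mixed on the same TransactionDB without bypassing conflict detection. A hedged sketch; in this single-threaded form the plain Put is expected to time out because the transaction still holds the lock:

// Sketch only: a plain Put on a TransactionDB contends on the same locks as txns.
Status MixedWrites(TransactionDB* txn_db) {
  Transaction* txn = txn_db->BeginTransaction(WriteOptions());
  Status s = txn->Put("counter", "1");  // takes the lock on "counter"

  // Wrapped in an internal transaction; waits for the lock until
  // TransactionDBOptions::default_lock_timeout elapses.
  Status plain = txn_db->Put(WriteOptions(), "counter", "2");
  // `plain` is expected to be TimedOut here because txn still holds the lock.

  if (s.ok()) s = txn->Commit();
  delete txn;
  return plain.IsTimedOut() ? s : plain;
}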
+
+Status PessimisticTransactionDB::Delete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s;
+
+ Transaction* txn = BeginInternalTransaction(wopts);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // DeleteUntracked().
+ s = txn->DeleteUntracked(column_family, key);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::SingleDelete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s;
+
+ Transaction* txn = BeginInternalTransaction(wopts);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // SingleDeleteUntracked().
+ s = txn->SingleDeleteUntracked(column_family, key);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ Status s;
+
+ Transaction* txn = BeginInternalTransaction(options);
+ txn->DisableIndexing();
+
+ // Since the client didn't create a transaction, they don't care about
+ // conflict checking for this write. So we just need to do
+ // MergeUntracked().
+ s = txn->MergeUntracked(column_family, key, value);
+
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+
+ delete txn;
+
+ return s;
+}
+
+Status PessimisticTransactionDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ return WriteWithConcurrencyControl(opts, updates);
+}
+
+Status WriteCommittedTxnDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ if (txn_db_options_.skip_concurrency_control) {
+ return db_impl_->Write(opts, updates);
+ } else {
+ return WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WriteCommittedTxnDB::Write(
+ const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations, WriteBatch* updates) {
+ if (optimizations.skip_concurrency_control) {
+ return db_impl_->Write(opts, updates);
+ } else {
+ return WriteWithConcurrencyControl(opts, updates);
+ }
+}
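
The optimized Write() overload above lets a caller who can guarantee that no concurrent transaction touches the same keys skip the locking pass entirely. A hedged sketch of passing that hint:

// Sketch only: skipping concurrency control when the caller guarantees no conflicts.
Status FastPathWrite(TransactionDB* txn_db, WriteBatch* updates) {
  TransactionDBWriteOptimizations optimizations;
  optimizations.skip_concurrency_control = true;  // caller-provided guarantee
  return txn_db->Write(WriteOptions(), optimizations, updates);
}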
+
+void PessimisticTransactionDB::InsertExpirableTransaction(
+ TransactionID tx_id, PessimisticTransaction* tx) {
+ assert(tx->GetExpirationTime() > 0);
+ std::lock_guard<std::mutex> lock(map_mutex_);
+ expirable_transactions_map_.insert({tx_id, tx});
+}
+
+void PessimisticTransactionDB::RemoveExpirableTransaction(TransactionID tx_id) {
+ std::lock_guard<std::mutex> lock(map_mutex_);
+ expirable_transactions_map_.erase(tx_id);
+}
+
+bool PessimisticTransactionDB::TryStealingExpiredTransactionLocks(
+ TransactionID tx_id) {
+ std::lock_guard<std::mutex> lock(map_mutex_);
+
+ auto tx_it = expirable_transactions_map_.find(tx_id);
+ if (tx_it == expirable_transactions_map_.end()) {
+ return true;
+ }
+ PessimisticTransaction& tx = *(tx_it->second);
+ return tx.TryStealingLocks();
+}
+
+void PessimisticTransactionDB::ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const TransactionOptions& txn_options) {
+ auto txn_impl =
+ static_cast_with_check<PessimisticTransaction, Transaction>(txn);
+
+ txn_impl->Reinitialize(this, write_options, txn_options);
+}
+
+Transaction* PessimisticTransactionDB::GetTransactionByName(
+ const TransactionName& name) {
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ auto it = transactions_.find(name);
+ if (it == transactions_.end()) {
+ return nullptr;
+ } else {
+ return it->second;
+ }
+}
+
+void PessimisticTransactionDB::GetAllPreparedTransactions(
+ std::vector<Transaction*>* transv) {
+ assert(transv);
+ transv->clear();
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ for (auto it = transactions_.begin(); it != transactions_.end(); ++it) {
+ if (it->second->GetState() == Transaction::PREPARED) {
+ transv->push_back(it->second);
+ }
+ }
+}
+
+TransactionLockMgr::LockStatusData
+PessimisticTransactionDB::GetLockStatusData() {
+ return lock_mgr_.GetLockStatusData();
+}
+
+std::vector<DeadlockPath> PessimisticTransactionDB::GetDeadlockInfoBuffer() {
+ return lock_mgr_.GetDeadlockInfoBuffer();
+}
+
+void PessimisticTransactionDB::SetDeadlockInfoBufferSize(uint32_t target_size) {
+ lock_mgr_.Resize(target_size);
+}
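
GetLockStatusData() and the deadlock-buffer accessors above simply forward to the lock manager, which makes them handy for debugging lock contention. A hedged sketch (assuming <cstdio>; the field names follow the KeyLockInfo and DeadlockPath structs declared in rocksdb/utilities/transaction_db.h as understood here):

// Sketch only: dumping currently held key locks and recorded deadlock cycles.
void DumpLockDebugInfo(TransactionDB* txn_db) {
  // Multimap of column-family id -> KeyLockInfo (key, holder ids, exclusive flag).
  for (const auto& entry : txn_db->GetLockStatusData()) {
    fprintf(stderr, "cf %u key %s exclusive=%d\n", entry.first,
            entry.second.key.c_str(), entry.second.exclusive ? 1 : 0);
  }
  // Each DeadlockPath records one detected wait-for cycle.
  for (const auto& dl : txn_db->GetDeadlockInfoBuffer()) {
    fprintf(stderr, "deadlock cycle with %zu waiters%s\n", dl.path.size(),
            dl.limit_exceeded ? " (detection depth exceeded)" : "");
  }
}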
+
+void PessimisticTransactionDB::RegisterTransaction(Transaction* txn) {
+ assert(txn);
+ assert(txn->GetName().length() > 0);
+ assert(GetTransactionByName(txn->GetName()) == nullptr);
+ assert(txn->GetState() == Transaction::STARTED);
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ transactions_[txn->GetName()] = txn;
+}
+
+void PessimisticTransactionDB::UnregisterTransaction(Transaction* txn) {
+ assert(txn);
+ std::lock_guard<std::mutex> lock(name_map_mutex_);
+ auto it = transactions_.find(txn->GetName());
+ assert(it != transactions_.end());
+ transactions_.erase(it);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h
new file mode 100644
index 000000000..39346dddd
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/pessimistic_transaction_db.h
@@ -0,0 +1,220 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <mutex>
+#include <queue>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/db_iter.h"
+#include "db/read_callback.h"
+#include "db/snapshot_checker.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/transaction_lock_mgr.h"
+#include "utilities/transactions/write_prepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class PessimisticTransactionDB : public TransactionDB {
+ public:
+ explicit PessimisticTransactionDB(DB* db,
+ const TransactionDBOptions& txn_db_options);
+
+ explicit PessimisticTransactionDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options);
+
+ virtual ~PessimisticTransactionDB();
+
+ virtual const Snapshot* GetSnapshot() override { return db_->GetSnapshot(); }
+
+ virtual Status Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles);
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override = 0;
+
+ using StackableDB::Put;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) override;
+
+ using StackableDB::Delete;
+ virtual Status Delete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+
+ using StackableDB::SingleDelete;
+ virtual Status SingleDelete(const WriteOptions& wopts,
+ ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+
+ using StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+
+ using TransactionDB::Write;
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+ inline Status WriteWithConcurrencyControl(const WriteOptions& opts,
+ WriteBatch* updates) {
+ // Need to lock all keys in this batch to prevent write conflicts with
+ // concurrent transactions.
+ Transaction* txn = BeginInternalTransaction(opts);
+ txn->DisableIndexing();
+
+ auto txn_impl =
+ static_cast_with_check<PessimisticTransaction, Transaction>(txn);
+
+ // Since CommitBatch sorts the keys before locking, concurrent Write()
+ // operations will not cause a deadlock.
+ // In order to avoid a deadlock with a concurrent Transaction, Transactions
+ // should use a lock timeout.
+ Status s = txn_impl->CommitBatch(updates);
+
+ delete txn;
+
+ return s;
+ }
+
+ using StackableDB::CreateColumnFamily;
+ virtual Status CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) override;
+
+ using StackableDB::DropColumnFamily;
+ virtual Status DropColumnFamily(ColumnFamilyHandle* column_family) override;
+
+ Status TryLock(PessimisticTransaction* txn, uint32_t cfh_id,
+ const std::string& key, bool exclusive);
+
+ void UnLock(PessimisticTransaction* txn, const TransactionKeyMap* keys);
+ void UnLock(PessimisticTransaction* txn, uint32_t cfh_id,
+ const std::string& key);
+
+ void AddColumnFamily(const ColumnFamilyHandle* handle);
+
+ static TransactionDBOptions ValidateTxnDBOptions(
+ const TransactionDBOptions& txn_db_options);
+
+ const TransactionDBOptions& GetTxnDBOptions() const {
+ return txn_db_options_;
+ }
+
+ void InsertExpirableTransaction(TransactionID tx_id,
+ PessimisticTransaction* tx);
+ void RemoveExpirableTransaction(TransactionID tx_id);
+
+ // If transaction is no longer available, locks can be stolen
+ // If transaction is available, try stealing locks directly from transaction
+ // It is the caller's responsibility to ensure that the referred transaction
+ // is expirable (GetExpirationTime() > 0) and that it is expired.
+ bool TryStealingExpiredTransactionLocks(TransactionID tx_id);
+
+ Transaction* GetTransactionByName(const TransactionName& name) override;
+
+ void RegisterTransaction(Transaction* txn);
+ void UnregisterTransaction(Transaction* txn);
+
+ // Not thread-safe. The current use case is during recovery (single-threaded).
+ void GetAllPreparedTransactions(std::vector<Transaction*>* trans) override;
+
+ TransactionLockMgr::LockStatusData GetLockStatusData() override;
+
+ std::vector<DeadlockPath> GetDeadlockInfoBuffer() override;
+ void SetDeadlockInfoBufferSize(uint32_t target_size) override;
+
+ // The default implementation does nothing. The actual implementation is moved
+ // to the child classes that actually need this information. This was due to
+ // an odd performance drop we observed when we added a std::atomic member to
+ // the base class, even when the subclasses do not read it in the fast path.
+ virtual void UpdateCFComparatorMap(const std::vector<ColumnFamilyHandle*>&) {}
+ virtual void UpdateCFComparatorMap(ColumnFamilyHandle*) {}
+
+ protected:
+ DBImpl* db_impl_;
+ std::shared_ptr<Logger> info_log_;
+ const TransactionDBOptions txn_db_options_;
+
+ void ReinitializeTransaction(
+ Transaction* txn, const WriteOptions& write_options,
+ const TransactionOptions& txn_options = TransactionOptions());
+
+ virtual Status VerifyCFOptions(const ColumnFamilyOptions& cf_options);
+
+ private:
+ friend class WritePreparedTxnDB;
+ friend class WritePreparedTxnDBMock;
+ friend class WriteUnpreparedTxn;
+ friend class TransactionTest_DoubleCrashInRecovery_Test;
+ friend class TransactionTest_DoubleEmptyWrite_Test;
+ friend class TransactionTest_DuplicateKeys_Test;
+ friend class TransactionTest_PersistentTwoPhaseTransactionTest_Test;
+ friend class TransactionTest_TwoPhaseDoubleRecoveryTest_Test;
+ friend class TransactionTest_TwoPhaseOutOfOrderDelete_Test;
+ friend class TransactionStressTest_TwoPhaseLongPrepareTest_Test;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+ friend class WriteUnpreparedTransactionTest_MarkLogWithPrepSection_Test;
+ TransactionLockMgr lock_mgr_;
+
+ // Must be held when adding/dropping column families.
+ InstrumentedMutex column_family_mutex_;
+ Transaction* BeginInternalTransaction(const WriteOptions& options);
+
+ // Used to ensure that no locks are stolen from an expirable transaction
+ // that has started a commit. Only transactions with an expiration time
+ // should be in this map.
+ std::mutex map_mutex_;
+ std::unordered_map<TransactionID, PessimisticTransaction*>
+ expirable_transactions_map_;
+
+ // map from name to two phase transaction instance
+ std::mutex name_map_mutex_;
+ std::unordered_map<TransactionName, Transaction*> transactions_;
+
+ // Signal that we are testing a crash scenario. Some asserts could be relaxed
+ // in such cases.
+ virtual void TEST_Crash() {}
+};
+
+// A PessimisticTransactionDB that writes the data to the DB only when the
+// transaction commits. In this way the DB only contains committed data.
+class WriteCommittedTxnDB : public PessimisticTransactionDB {
+ public:
+ explicit WriteCommittedTxnDB(DB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options) {}
+
+ explicit WriteCommittedTxnDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options) {}
+
+ virtual ~WriteCommittedTxnDB() {}
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ // Optimized version of ::Write that makes use of the skip_concurrency_control
+ // hint.
+ using TransactionDB::Write;
+ virtual Status Write(const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations,
+ WriteBatch* updates) override;
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+};
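+
+// Usage sketch (illustrative only; the path "/tmp/txn_db" and the values are
+// made up, not part of this file):
+//
+//   Options options;
+//   options.create_if_missing = true;
+//   TransactionDBOptions txn_db_options;
+//   TransactionDB* db = nullptr;
+//   Status s = TransactionDB::Open(options, txn_db_options, "/tmp/txn_db", &db);
+//   assert(s.ok());
+//   Transaction* txn = db->BeginTransaction(WriteOptions());
+//   s = txn->Put("key", "value");   // buffered in the transaction
+//   s = txn->Commit();              // with WriteCommittedTxnDB, the data
+//                                   // reaches the DB only here
+//   delete txn;
+//   delete db;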
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/snapshot_checker.cc b/src/rocksdb/utilities/transactions/snapshot_checker.cc
new file mode 100644
index 000000000..9c43bef43
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/snapshot_checker.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "db/snapshot_checker.h"
+
+#ifdef ROCKSDB_LITE
+#include <assert.h>
+#endif // ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+#ifdef ROCKSDB_LITE
+WritePreparedSnapshotChecker::WritePreparedSnapshotChecker(
+ WritePreparedTxnDB* /*txn_db*/) {}
+
+SnapshotCheckerResult WritePreparedSnapshotChecker::CheckInSnapshot(
+ SequenceNumber /*sequence*/, SequenceNumber /*snapshot_sequence*/) const {
+ // Should never be called in LITE mode.
+ assert(false);
+ return SnapshotCheckerResult::kInSnapshot;
+}
+
+#else
+
+WritePreparedSnapshotChecker::WritePreparedSnapshotChecker(
+ WritePreparedTxnDB* txn_db)
+ : txn_db_(txn_db) {}
+
+SnapshotCheckerResult WritePreparedSnapshotChecker::CheckInSnapshot(
+ SequenceNumber sequence, SequenceNumber snapshot_sequence) const {
+ bool snapshot_released = false;
+ // TODO(myabandeh): set min_uncommitted
+ bool in_snapshot = txn_db_->IsInSnapshot(
+ sequence, snapshot_sequence, kMinUnCommittedSeq, &snapshot_released);
+ if (snapshot_released) {
+ return SnapshotCheckerResult::kSnapshotReleased;
+ }
+ return in_snapshot ? SnapshotCheckerResult::kInSnapshot
+ : SnapshotCheckerResult::kNotInSnapshot;
+}
+
+#endif // ROCKSDB_LITE
+DisableGCSnapshotChecker DisableGCSnapshotChecker::instance_;
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/transaction_base.cc b/src/rocksdb/utilities/transactions/transaction_base.cc
new file mode 100644
index 000000000..805d4ab36
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_base.cc
@@ -0,0 +1,837 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_base.h"
+
+#include <cinttypes>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+TransactionBaseImpl::TransactionBaseImpl(DB* db,
+ const WriteOptions& write_options)
+ : db_(db),
+ dbimpl_(static_cast_with_check<DBImpl, DB>(db)),
+ write_options_(write_options),
+ cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())),
+ start_time_(db_->GetEnv()->NowMicros()),
+ write_batch_(cmp_, 0, true, 0),
+ indexing_enabled_(true) {
+ assert(dynamic_cast<DBImpl*>(db_) != nullptr);
+ log_number_ = 0;
+ if (dbimpl_->allow_2pc()) {
+ InitWriteBatch();
+ }
+}
+
+TransactionBaseImpl::~TransactionBaseImpl() {
+ // Release snapshot if snapshot is set
+ SetSnapshotInternal(nullptr);
+}
+
+void TransactionBaseImpl::Clear() {
+ save_points_.reset(nullptr);
+ write_batch_.Clear();
+ commit_time_batch_.Clear();
+ tracked_keys_.clear();
+ num_puts_ = 0;
+ num_deletes_ = 0;
+ num_merges_ = 0;
+
+ if (dbimpl_->allow_2pc()) {
+ InitWriteBatch();
+ }
+}
+
+void TransactionBaseImpl::Reinitialize(DB* db,
+ const WriteOptions& write_options) {
+ Clear();
+ ClearSnapshot();
+ id_ = 0;
+ db_ = db;
+ name_.clear();
+ log_number_ = 0;
+ write_options_ = write_options;
+ start_time_ = db_->GetEnv()->NowMicros();
+ indexing_enabled_ = true;
+ cmp_ = GetColumnFamilyUserComparator(db_->DefaultColumnFamily());
+}
+
+void TransactionBaseImpl::SetSnapshot() {
+ const Snapshot* snapshot = dbimpl_->GetSnapshotForWriteConflictBoundary();
+ SetSnapshotInternal(snapshot);
+}
+
+void TransactionBaseImpl::SetSnapshotInternal(const Snapshot* snapshot) {
+ // Set a custom deleter for the snapshot_ shared_ptr, as the snapshot needs to
+ // be released, not deleted, when it is no longer referenced.
+ snapshot_.reset(snapshot, std::bind(&TransactionBaseImpl::ReleaseSnapshot,
+ this, std::placeholders::_1, db_));
+ snapshot_needed_ = false;
+ snapshot_notifier_ = nullptr;
+}
+
+void TransactionBaseImpl::SetSnapshotOnNextOperation(
+ std::shared_ptr<TransactionNotifier> notifier) {
+ snapshot_needed_ = true;
+ snapshot_notifier_ = notifier;
+}
+
+void TransactionBaseImpl::SetSnapshotIfNeeded() {
+ if (snapshot_needed_) {
+ std::shared_ptr<TransactionNotifier> notifier = snapshot_notifier_;
+ SetSnapshot();
+ if (notifier != nullptr) {
+ notifier->SnapshotCreated(GetSnapshot());
+ }
+ }
+}
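+
+// Note (an illustrative sketch, not part of the implementation): the snapshot
+// requested via SetSnapshotOnNextOperation() is not taken immediately. It is
+// acquired by SetSnapshotIfNeeded() at the start of the next operation that
+// needs one, and the notifier, if provided, is told which snapshot was taken:
+//
+//   txn->SetSnapshotOnNextOperation(notifier);
+//   // ... no snapshot has been taken yet ...
+//   txn->GetForUpdate(read_options, "key", &value);
+//   // snapshot taken here; notifier->SnapshotCreated(snapshot) was called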
+
+Status TransactionBaseImpl::TryLock(ColumnFamilyHandle* column_family,
+ const SliceParts& key, bool read_only,
+ bool exclusive, const bool do_validate,
+ const bool assume_tracked) {
+ size_t key_size = 0;
+ for (int i = 0; i < key.num_parts; ++i) {
+ key_size += key.parts[i].size();
+ }
+
+ std::string str;
+ str.reserve(key_size);
+
+ for (int i = 0; i < key.num_parts; ++i) {
+ str.append(key.parts[i].data(), key.parts[i].size());
+ }
+
+ return TryLock(column_family, str, read_only, exclusive, do_validate,
+ assume_tracked);
+}
+
+void TransactionBaseImpl::SetSavePoint() {
+ if (save_points_ == nullptr) {
+ save_points_.reset(
+     new std::stack<TransactionBaseImpl::SavePoint,
+                    autovector<TransactionBaseImpl::SavePoint>>());
+ }
+ save_points_->emplace(snapshot_, snapshot_needed_, snapshot_notifier_,
+ num_puts_, num_deletes_, num_merges_);
+ write_batch_.SetSavePoint();
+}
+
+Status TransactionBaseImpl::RollbackToSavePoint() {
+ if (save_points_ != nullptr && save_points_->size() > 0) {
+ // Restore saved SavePoint
+ TransactionBaseImpl::SavePoint& save_point = save_points_->top();
+ snapshot_ = save_point.snapshot_;
+ snapshot_needed_ = save_point.snapshot_needed_;
+ snapshot_notifier_ = save_point.snapshot_notifier_;
+ num_puts_ = save_point.num_puts_;
+ num_deletes_ = save_point.num_deletes_;
+ num_merges_ = save_point.num_merges_;
+
+ // Rollback batch
+ Status s = write_batch_.RollbackToSavePoint();
+ assert(s.ok());
+
+ // Rollback any keys that were tracked since the last savepoint
+ const TransactionKeyMap& key_map = save_point.new_keys_;
+ for (const auto& key_map_iter : key_map) {
+ uint32_t column_family_id = key_map_iter.first;
+ auto& keys = key_map_iter.second;
+
+ auto& cf_tracked_keys = tracked_keys_[column_family_id];
+
+ for (const auto& key_iter : keys) {
+ const std::string& key = key_iter.first;
+ uint32_t num_reads = key_iter.second.num_reads;
+ uint32_t num_writes = key_iter.second.num_writes;
+
+ auto tracked_keys_iter = cf_tracked_keys.find(key);
+ assert(tracked_keys_iter != cf_tracked_keys.end());
+
+ // Decrement the total reads/writes of this key by the number of
+ // reads/writes done since the last SavePoint.
+ if (num_reads > 0) {
+ assert(tracked_keys_iter->second.num_reads >= num_reads);
+ tracked_keys_iter->second.num_reads -= num_reads;
+ }
+ if (num_writes > 0) {
+ assert(tracked_keys_iter->second.num_writes >= num_writes);
+ tracked_keys_iter->second.num_writes -= num_writes;
+ }
+ if (tracked_keys_iter->second.num_reads == 0 &&
+ tracked_keys_iter->second.num_writes == 0) {
+ cf_tracked_keys.erase(tracked_keys_iter);
+ }
+ }
+ }
+
+ save_points_->pop();
+
+ return s;
+ } else {
+ assert(write_batch_.RollbackToSavePoint().IsNotFound());
+ return Status::NotFound();
+ }
+}
+
+Status TransactionBaseImpl::PopSavePoint() {
+ if (save_points_ == nullptr ||
+ save_points_->empty()) {
+ // No SavePoint yet.
+ assert(write_batch_.PopSavePoint().IsNotFound());
+ return Status::NotFound();
+ }
+
+ assert(!save_points_->empty());
+ // If there is another savepoint A below the current savepoint B, then A needs
+ // to inherit tracked_keys in B so that if we rollback to savepoint A, we
+ // remember to unlock keys in B. If there is no other savepoint below, then we
+ // can safely discard savepoint info.
+ if (save_points_->size() == 1) {
+ save_points_->pop();
+ } else {
+ TransactionBaseImpl::SavePoint top;
+ std::swap(top, save_points_->top());
+ save_points_->pop();
+
+ const TransactionKeyMap& curr_cf_key_map = top.new_keys_;
+ TransactionKeyMap& prev_cf_key_map = save_points_->top().new_keys_;
+
+ for (const auto& curr_cf_key_iter : curr_cf_key_map) {
+ uint32_t column_family_id = curr_cf_key_iter.first;
+ const std::unordered_map<std::string, TransactionKeyMapInfo>& curr_keys =
+ curr_cf_key_iter.second;
+
+ // If cfid was not previously tracked, just copy everything over.
+ auto prev_keys_iter = prev_cf_key_map.find(column_family_id);
+ if (prev_keys_iter == prev_cf_key_map.end()) {
+ prev_cf_key_map.emplace(curr_cf_key_iter);
+ } else {
+ std::unordered_map<std::string, TransactionKeyMapInfo>& prev_keys =
+ prev_keys_iter->second;
+ for (const auto& key_iter : curr_keys) {
+ const std::string& key = key_iter.first;
+ const TransactionKeyMapInfo& info = key_iter.second;
+ // If key was not previously tracked, just copy the whole struct over.
+ // Otherwise, some merging needs to occur.
+ auto prev_info = prev_keys.find(key);
+ if (prev_info == prev_keys.end()) {
+ prev_keys.emplace(key_iter);
+ } else {
+ prev_info->second.Merge(info);
+ }
+ }
+ }
+ }
+ }
+
+ return write_batch_.PopSavePoint();
+}
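+
+// Illustrative sketch of the key-inheritance rule in PopSavePoint() above
+// (keys are made up):
+//
+//   txn->SetSavePoint();          // savepoint A
+//   txn->Put("x", "1");           // "x" tracked under A
+//   txn->SetSavePoint();          // savepoint B
+//   txn->Put("y", "2");           // "y" tracked under B
+//   txn->PopSavePoint();          // B is discarded; "y" is merged into A
+//   txn->RollbackToSavePoint();   // back to A: both writes are rolled back and
+//                                 // both "x" and "y" can be un-tracked/unlocked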
+
+Status TransactionBaseImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s = Get(read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status TransactionBaseImpl::Get(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val) {
+ return write_batch_.GetFromBatchAndDB(db_, read_options, column_family, key,
+ pinnable_val);
+}
+
+Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value,
+ bool exclusive,
+ const bool do_validate) {
+ if (!do_validate && read_options.snapshot != nullptr) {
+ return Status::InvalidArgument(
+ "If do_validate is false then GetForUpdate with snapshot is not "
+ "defined.");
+ }
+ Status s =
+ TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
+
+ if (s.ok() && value != nullptr) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ s = Get(read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ }
+ return s;
+}
+
+Status TransactionBaseImpl::GetForUpdate(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ PinnableSlice* pinnable_val,
+ bool exclusive,
+ const bool do_validate) {
+ if (!do_validate && read_options.snapshot != nullptr) {
+ return Status::InvalidArgument(
+ "If do_validate is false then GetForUpdate with snapshot is not "
+ "defined.");
+ }
+ Status s =
+ TryLock(column_family, key, true /* read_only */, exclusive, do_validate);
+
+ if (s.ok() && pinnable_val != nullptr) {
+ s = Get(read_options, column_family, key, pinnable_val);
+ }
+ return s;
+}
+
+std::vector<Status> TransactionBaseImpl::MultiGet(
+ const ReadOptions& read_options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ std::string* value = values ? &(*values)[i] : nullptr;
+ stat_list[i] = Get(read_options, column_family[i], keys[i], value);
+ }
+
+ return stat_list;
+}
+
+void TransactionBaseImpl::MultiGet(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ write_batch_.MultiGetFromBatchAndDB(db_, read_options, column_family,
+ num_keys, keys, values, statuses,
+ sorted_input);
+}
+
+std::vector<Status> TransactionBaseImpl::MultiGetForUpdate(
+ const ReadOptions& read_options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ // Regardless of whether the MultiGet succeeded, track these keys.
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ // Lock all keys
+ for (size_t i = 0; i < num_keys; ++i) {
+ Status s = TryLock(column_family[i], keys[i], true /* read_only */,
+ true /* exclusive */);
+ if (!s.ok()) {
+ // Fail entire multiget if we cannot lock all keys
+ return std::vector<Status>(num_keys, s);
+ }
+ }
+
+ // TODO(agiardullo): optimize multiget?
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ std::string* value = values ? &(*values)[i] : nullptr;
+ stat_list[i] = Get(read_options, column_family[i], keys[i], value);
+ }
+
+ return stat_list;
+}
+
+Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options) {
+ Iterator* db_iter = db_->NewIterator(read_options);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(db_iter);
+}
+
+Iterator* TransactionBaseImpl::GetIterator(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family) {
+ Iterator* db_iter = db_->NewIterator(read_options, column_family);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(column_family, db_iter,
+ &read_options);
+}
+
+Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Put(ColumnFamilyHandle* column_family,
+ const SliceParts& key, const SliceParts& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Merge(column_family, key, value);
+ if (s.ok()) {
+ num_merges_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ const bool do_validate = !assume_tracked;
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, do_validate, assume_tracked);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::PutUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const SliceParts& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Put(column_family, key, value);
+ if (s.ok()) {
+ num_puts_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::MergeUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const Slice& value) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Merge(column_family, key, value);
+ if (s.ok()) {
+ num_merges_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->Delete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+Status TransactionBaseImpl::SingleDeleteUntracked(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ Status s = TryLock(column_family, key, false /* read_only */,
+ true /* exclusive */, false /* do_validate */);
+
+ if (s.ok()) {
+ s = GetBatchForWrite()->SingleDelete(column_family, key);
+ if (s.ok()) {
+ num_deletes_++;
+ }
+ }
+
+ return s;
+}
+
+void TransactionBaseImpl::PutLogData(const Slice& blob) {
+ write_batch_.PutLogData(blob);
+}
+
+WriteBatchWithIndex* TransactionBaseImpl::GetWriteBatch() {
+ return &write_batch_;
+}
+
+uint64_t TransactionBaseImpl::GetElapsedTime() const {
+ return (db_->GetEnv()->NowMicros() - start_time_) / 1000;
+}
+
+uint64_t TransactionBaseImpl::GetNumPuts() const { return num_puts_; }
+
+uint64_t TransactionBaseImpl::GetNumDeletes() const { return num_deletes_; }
+
+uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; }
+
+uint64_t TransactionBaseImpl::GetNumKeys() const {
+ uint64_t count = 0;
+
+ // sum up locked keys in all column families
+ for (const auto& key_map_iter : tracked_keys_) {
+ const auto& keys = key_map_iter.second;
+ count += keys.size();
+ }
+
+ return count;
+}
+
+void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key,
+ SequenceNumber seq, bool read_only,
+ bool exclusive) {
+ // Update map of all tracked keys for this transaction
+ TrackKey(&tracked_keys_, cfh_id, key, seq, read_only, exclusive);
+
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // Update map of tracked keys in this SavePoint
+ TrackKey(&save_points_->top().new_keys_, cfh_id, key, seq, read_only,
+ exclusive);
+ }
+}
+
+// Add a key to the given TransactionKeyMap
+// seq for pessimistic transactions is the sequence number from which we know
+// there has not been a concurrent update to the key.
+void TransactionBaseImpl::TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
+ const std::string& key, SequenceNumber seq,
+ bool read_only, bool exclusive) {
+ auto& cf_key_map = (*key_map)[cfh_id];
+#ifdef __cpp_lib_unordered_map_try_emplace
+ // use c++17's try_emplace if available, to avoid rehashing the key
+ // in case it is not already in the map
+ auto result = cf_key_map.try_emplace(key, seq);
+ auto iter = result.first;
+ if (!result.second && seq < iter->second.seq) {
+ // Now tracking this key with an earlier sequence number
+ iter->second.seq = seq;
+ }
+#else
+ auto iter = cf_key_map.find(key);
+ if (iter == cf_key_map.end()) {
+ auto result = cf_key_map.emplace(key, TransactionKeyMapInfo(seq));
+ iter = result.first;
+ } else if (seq < iter->second.seq) {
+ // Now tracking this key with an earlier sequence number
+ iter->second.seq = seq;
+ }
+#endif
+ // else we do not update the seq. The smaller the tracked seq, the stronger
+ // the guarantee, since it implies that from that seq onward there has not
+ // been a concurrent update to the key. So we update the seq only if doing so
+ // implies a stronger guarantee, i.e., if it is smaller than the existing
+ // tracked seq.
+
+ if (read_only) {
+ iter->second.num_reads++;
+ } else {
+ iter->second.num_writes++;
+ }
+ iter->second.exclusive |= exclusive;
+}
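+
+// Worked example of the seq rule above (numbers are made up): if a key is
+// first tracked at seq 100 and later tracked again at seq 90, the tracked seq
+// is lowered to 90, because "no concurrent update since 90" is the stronger
+// guarantee. Tracking the same key again at seq 120 leaves it at 90.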
+
+std::unique_ptr<TransactionKeyMap>
+TransactionBaseImpl::GetTrackedKeysSinceSavePoint() {
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // Examine the number of reads/writes performed on all keys written
+ // since the last SavePoint and compare to the total number of reads/writes
+ // for each key.
+ TransactionKeyMap* result = new TransactionKeyMap();
+ for (const auto& key_map_iter : save_points_->top().new_keys_) {
+ uint32_t column_family_id = key_map_iter.first;
+ auto& keys = key_map_iter.second;
+
+ auto& cf_tracked_keys = tracked_keys_[column_family_id];
+
+ for (const auto& key_iter : keys) {
+ const std::string& key = key_iter.first;
+ uint32_t num_reads = key_iter.second.num_reads;
+ uint32_t num_writes = key_iter.second.num_writes;
+
+ auto total_key_info = cf_tracked_keys.find(key);
+ assert(total_key_info != cf_tracked_keys.end());
+ assert(total_key_info->second.num_reads >= num_reads);
+ assert(total_key_info->second.num_writes >= num_writes);
+
+ if (total_key_info->second.num_reads == num_reads &&
+ total_key_info->second.num_writes == num_writes) {
+ // All the reads/writes to this key were done in the last savepoint.
+ bool read_only = (num_writes == 0);
+ TrackKey(result, column_family_id, key, key_iter.second.seq,
+ read_only, key_iter.second.exclusive);
+ }
+ }
+ }
+ return std::unique_ptr<TransactionKeyMap>(result);
+ }
+
+ // No SavePoint
+ return nullptr;
+}
+
+// Gets the write batch that should be used for Put/Merge/Deletes.
+//
+// Returns either a WriteBatch or WriteBatchWithIndex depending on whether
+// DisableIndexing() has been called.
+WriteBatchBase* TransactionBaseImpl::GetBatchForWrite() {
+ if (indexing_enabled_) {
+ // Use WriteBatchWithIndex
+ return &write_batch_;
+ } else {
+ // Don't use WriteBatchWithIndex. Return base WriteBatch.
+ return write_batch_.GetWriteBatch();
+ }
+}
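+
+// Illustration (a hedged sketch, not part of the implementation): with
+// indexing enabled, a write is visible to this transaction's own reads because
+// Get() consults the WriteBatchWithIndex first; after DisableIndexing(),
+// writes go into the base WriteBatch only and are not seen by the
+// transaction's own reads:
+//
+//   txn->Put("k", "v1");       // indexed; txn->Get(..., "k", &v) returns "v1"
+//   txn->DisableIndexing();
+//   txn->Put("k", "v2");       // unindexed; txn->Get(..., "k", &v) still
+//                              // returns "v1" until the transaction commits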
+
+void TransactionBaseImpl::ReleaseSnapshot(const Snapshot* snapshot, DB* db) {
+ if (snapshot != nullptr) {
+ ROCKS_LOG_DETAILS(dbimpl_->immutable_db_options().info_log,
+ "ReleaseSnapshot %" PRIu64 " Set",
+ snapshot->GetSequenceNumber());
+ db->ReleaseSnapshot(snapshot);
+ }
+}
+
+void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ uint32_t column_family_id = GetColumnFamilyID(column_family);
+ auto& cf_tracked_keys = tracked_keys_[column_family_id];
+ std::string key_str = key.ToString();
+ bool can_decrement = false;
+ bool can_unlock __attribute__((__unused__)) = false;
+
+ if (save_points_ != nullptr && !save_points_->empty()) {
+ // Check if this key was fetched ForUpdate in this SavePoint
+ auto& cf_savepoint_keys = save_points_->top().new_keys_[column_family_id];
+
+ auto savepoint_iter = cf_savepoint_keys.find(key_str);
+ if (savepoint_iter != cf_savepoint_keys.end()) {
+ if (savepoint_iter->second.num_reads > 0) {
+ savepoint_iter->second.num_reads--;
+ can_decrement = true;
+
+ if (savepoint_iter->second.num_reads == 0 &&
+ savepoint_iter->second.num_writes == 0) {
+ // No other GetForUpdates or write on this key in this SavePoint
+ cf_savepoint_keys.erase(savepoint_iter);
+ can_unlock = true;
+ }
+ }
+ }
+ } else {
+ // No SavePoint set
+ can_decrement = true;
+ can_unlock = true;
+ }
+
+ // We can only decrement the read count for this key if we were able to
+ // decrement the read count in the current SavePoint, OR if there is no
+ // SavePoint set.
+ if (can_decrement) {
+ auto key_iter = cf_tracked_keys.find(key_str);
+
+ if (key_iter != cf_tracked_keys.end()) {
+ if (key_iter->second.num_reads > 0) {
+ key_iter->second.num_reads--;
+
+ if (key_iter->second.num_reads == 0 &&
+ key_iter->second.num_writes == 0) {
+ // No other GetForUpdates or writes on this key
+ assert(can_unlock);
+ cf_tracked_keys.erase(key_iter);
+ UnlockGetForUpdate(column_family, key);
+ }
+ }
+ }
+ }
+}
+
+Status TransactionBaseImpl::RebuildFromWriteBatch(WriteBatch* src_batch) {
+ struct IndexedWriteBatchBuilder : public WriteBatch::Handler {
+ Transaction* txn_;
+ DBImpl* db_;
+ IndexedWriteBatchBuilder(Transaction* txn, DBImpl* db)
+ : txn_(txn), db_(db) {
+ assert(dynamic_cast<TransactionBaseImpl*>(txn_) != nullptr);
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& val) override {
+ return txn_->Put(db_->GetColumnFamilyHandle(cf), key, val);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return txn_->Delete(db_->GetColumnFamilyHandle(cf), key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return txn_->SingleDelete(db_->GetColumnFamilyHandle(cf), key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice& val) override {
+ return txn_->Merge(db_->GetColumnFamilyHandle(cf), key, val);
+ }
+
+ // This is used for reconstructing prepared transactions upon recovery.
+ // There should not be any meta markers in the batches we are processing.
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ IndexedWriteBatchBuilder copycat(this, dbimpl_);
+ return src_batch->Iterate(&copycat);
+}
+
+WriteBatch* TransactionBaseImpl::GetCommitTimeWriteBatch() {
+ return &commit_time_batch_;
+}
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_base.h b/src/rocksdb/utilities/transactions/transaction_base.h
new file mode 100644
index 000000000..f279676c6
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_base.h
@@ -0,0 +1,374 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <stack>
+#include <string>
+#include <vector>
+
+#include "db/write_batch_internal.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionBaseImpl : public Transaction {
+ public:
+ TransactionBaseImpl(DB* db, const WriteOptions& write_options);
+
+ virtual ~TransactionBaseImpl();
+
+ // Remove pending operations queued in this transaction.
+ virtual void Clear();
+
+ void Reinitialize(DB* db, const WriteOptions& write_options);
+
+ // Called before executing Put, Merge, Delete, and GetForUpdate. If TryLock
+ // returns non-OK, the Put/Merge/Delete/GetForUpdate will be failed.
+ // do_validate will be false if called from PutUntracked, DeleteUntracked,
+ // MergeUntracked, or GetForUpdate(do_validate=false)
+ virtual Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
+ bool read_only, bool exclusive,
+ const bool do_validate = true,
+ const bool assume_tracked = false) = 0;
+
+ void SetSavePoint() override;
+
+ Status RollbackToSavePoint() override;
+
+ Status PopSavePoint() override;
+
+ using Transaction::Get;
+ Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value) override;
+
+ Status Get(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) override;
+
+ Status Get(const ReadOptions& options, const Slice& key,
+ std::string* value) override {
+ return Get(options, db_->DefaultColumnFamily(), key, value);
+ }
+
+ using Transaction::GetForUpdate;
+ Status GetForUpdate(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ std::string* value, bool exclusive,
+ const bool do_validate) override;
+
+ Status GetForUpdate(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* pinnable_val, bool exclusive,
+ const bool do_validate) override;
+
+ Status GetForUpdate(const ReadOptions& options, const Slice& key,
+ std::string* value, bool exclusive,
+ const bool do_validate) override {
+ return GetForUpdate(options, db_->DefaultColumnFamily(), key, value,
+ exclusive, do_validate);
+ }
+
+ using Transaction::MultiGet;
+ std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ std::vector<Status> MultiGet(const ReadOptions& options,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ return MultiGet(options, std::vector<ColumnFamilyHandle*>(
+ keys.size(), db_->DefaultColumnFamily()),
+ keys, values);
+ }
+
+ void MultiGet(const ReadOptions& options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, const bool sorted_input = false) override;
+
+ using Transaction::MultiGetForUpdate;
+ std::vector<Status> MultiGetForUpdate(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ std::vector<Status> MultiGetForUpdate(
+ const ReadOptions& options, const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override {
+ return MultiGetForUpdate(options,
+ std::vector<ColumnFamilyHandle*>(
+ keys.size(), db_->DefaultColumnFamily()),
+ keys, values);
+ }
+
+ Iterator* GetIterator(const ReadOptions& read_options) override;
+ Iterator* GetIterator(const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family) override;
+
+ Status Put(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+ Status Put(const Slice& key, const Slice& value) override {
+ return Put(nullptr, key, value);
+ }
+
+ Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value,
+ const bool assume_tracked = false) override;
+ Status Put(const SliceParts& key, const SliceParts& value) override {
+ return Put(nullptr, key, value);
+ }
+
+ Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value, const bool assume_tracked = false) override;
+ Status Merge(const Slice& key, const Slice& value) override {
+ return Merge(nullptr, key, value);
+ }
+
+ Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status Delete(const Slice& key) override { return Delete(nullptr, key); }
+ Status Delete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+ Status Delete(const SliceParts& key) override { return Delete(nullptr, key); }
+
+ Status SingleDelete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ Status SingleDelete(const Slice& key) override {
+ return SingleDelete(nullptr, key);
+ }
+ Status SingleDelete(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const bool assume_tracked = false) override;
+ Status SingleDelete(const SliceParts& key) override {
+ return SingleDelete(nullptr, key);
+ }
+
+ Status PutUntracked(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+ Status PutUntracked(const Slice& key, const Slice& value) override {
+ return PutUntracked(nullptr, key, value);
+ }
+
+ Status PutUntracked(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value) override;
+ Status PutUntracked(const SliceParts& key, const SliceParts& value) override {
+ return PutUntracked(nullptr, key, value);
+ }
+
+ Status MergeUntracked(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+ Status MergeUntracked(const Slice& key, const Slice& value) override {
+ return MergeUntracked(nullptr, key, value);
+ }
+
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ Status DeleteUntracked(const Slice& key) override {
+ return DeleteUntracked(nullptr, key);
+ }
+ Status DeleteUntracked(ColumnFamilyHandle* column_family,
+ const SliceParts& key) override;
+ Status DeleteUntracked(const SliceParts& key) override {
+ return DeleteUntracked(nullptr, key);
+ }
+
+ Status SingleDeleteUntracked(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ Status SingleDeleteUntracked(const Slice& key) override {
+ return SingleDeleteUntracked(nullptr, key);
+ }
+
+ void PutLogData(const Slice& blob) override;
+
+ WriteBatchWithIndex* GetWriteBatch() override;
+
+ virtual void SetLockTimeout(int64_t /*timeout*/) override { /* Do nothing */ }
+
+ const Snapshot* GetSnapshot() const override {
+ return snapshot_ ? snapshot_.get() : nullptr;
+ }
+
+ virtual void SetSnapshot() override;
+ void SetSnapshotOnNextOperation(
+ std::shared_ptr<TransactionNotifier> notifier = nullptr) override;
+
+ void ClearSnapshot() override {
+ snapshot_.reset();
+ snapshot_needed_ = false;
+ snapshot_notifier_ = nullptr;
+ }
+
+ void DisableIndexing() override { indexing_enabled_ = false; }
+
+ void EnableIndexing() override { indexing_enabled_ = true; }
+
+ uint64_t GetElapsedTime() const override;
+
+ uint64_t GetNumPuts() const override;
+
+ uint64_t GetNumDeletes() const override;
+
+ uint64_t GetNumMerges() const override;
+
+ uint64_t GetNumKeys() const override;
+
+ void UndoGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) override;
+ void UndoGetForUpdate(const Slice& key) override {
+ return UndoGetForUpdate(nullptr, key);
+ }
+
+ // Get list of keys in this transaction that must not have any conflicts
+ // with writes in other transactions.
+ const TransactionKeyMap& GetTrackedKeys() const { return tracked_keys_; }
+
+ WriteOptions* GetWriteOptions() override { return &write_options_; }
+
+ void SetWriteOptions(const WriteOptions& write_options) override {
+ write_options_ = write_options;
+ }
+
+ // Used for memory management for snapshot_
+ void ReleaseSnapshot(const Snapshot* snapshot, DB* db);
+
+ // Iterates over the given batch and makes the appropriate inserts.
+ // Used for rebuilding prepared transactions after recovery.
+ virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
+
+ WriteBatch* GetCommitTimeWriteBatch() override;
+
+ protected:
+ // Add a key to the list of tracked keys.
+ //
+ // seqno is the earliest sequence number at which this key was involved in
+ // this transaction.
+ // readonly should be set to true if no data was written for this key.
+ void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno,
+ bool readonly, bool exclusive);
+
+ // Helper function to add a key to the given TransactionKeyMap
+ static void TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
+ const std::string& key, SequenceNumber seqno,
+ bool readonly, bool exclusive);
+
+ // Called when UndoGetForUpdate determines that this key can be unlocked.
+ virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
+ const Slice& key) = 0;
+
+ std::unique_ptr<TransactionKeyMap> GetTrackedKeysSinceSavePoint();
+
+ // Sets a snapshot if SetSnapshotOnNextOperation() has been called.
+ void SetSnapshotIfNeeded();
+
+ // Initialize write_batch_ for 2PC by inserting Noop.
+ inline void InitWriteBatch(bool clear = false) {
+ if (clear) {
+ write_batch_.Clear();
+ }
+ assert(write_batch_.GetDataSize() == WriteBatchInternal::kHeader);
+ WriteBatchInternal::InsertNoop(write_batch_.GetWriteBatch());
+ }
+
+ DB* db_;
+ DBImpl* dbimpl_;
+
+ WriteOptions write_options_;
+
+ const Comparator* cmp_;
+
+ // Stores the time the txn was constructed, in microseconds.
+ uint64_t start_time_;
+
+ // Stores the current snapshot that was set by SetSnapshot or null if
+ // no snapshot is currently set.
+ std::shared_ptr<const Snapshot> snapshot_;
+
+ // Count of various operations pending in this transaction
+ uint64_t num_puts_ = 0;
+ uint64_t num_deletes_ = 0;
+ uint64_t num_merges_ = 0;
+
+ struct SavePoint {
+ std::shared_ptr<const Snapshot> snapshot_;
+ bool snapshot_needed_ = false;
+ std::shared_ptr<TransactionNotifier> snapshot_notifier_;
+ uint64_t num_puts_ = 0;
+ uint64_t num_deletes_ = 0;
+ uint64_t num_merges_ = 0;
+
+ // Record all keys tracked since the last savepoint
+ TransactionKeyMap new_keys_;
+
+ SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed,
+ std::shared_ptr<TransactionNotifier> snapshot_notifier,
+ uint64_t num_puts, uint64_t num_deletes, uint64_t num_merges)
+ : snapshot_(snapshot),
+ snapshot_needed_(snapshot_needed),
+ snapshot_notifier_(snapshot_notifier),
+ num_puts_(num_puts),
+ num_deletes_(num_deletes),
+ num_merges_(num_merges) {}
+
+ SavePoint() = default;
+ };
+
+ // Records writes pending in this transaction
+ WriteBatchWithIndex write_batch_;
+
+ // Map from column_family_id to map of keys that are involved in this
+ // transaction.
+ // For Pessimistic Transactions this is the list of locked keys.
+ // Optimistic Transactions will wait till commit time to do conflict checking.
+ TransactionKeyMap tracked_keys_;
+
+ // Stack of the Snapshot saved at each save point. Saved snapshots may be
+ // nullptr if there was no snapshot at the time SetSavePoint() was called.
+ std::unique_ptr<std::stack<TransactionBaseImpl::SavePoint,
+ autovector<TransactionBaseImpl::SavePoint>>>
+ save_points_;
+
+ private:
+ friend class WritePreparedTxn;
+ // Extra data to be persisted with the commit. Note this is only used when
+ // the prepare phase is not skipped.
+ WriteBatch commit_time_batch_;
+
+ // If true, future Put/Merge/Deletes will be indexed in the
+ // WriteBatchWithIndex.
+ // If false, future Put/Merge/Deletes will be inserted directly into the
+ // underlying WriteBatch and not indexed in the WriteBatchWithIndex.
+ bool indexing_enabled_;
+
+ // SetSnapshotOnNextOperation() has been called and the snapshot has not yet
+ // been reset.
+ bool snapshot_needed_ = false;
+
+ // SetSnapshotOnNextOperation() has been called and the caller would like
+ // a notification through the TransactionNotifier interface
+ std::shared_ptr<TransactionNotifier> snapshot_notifier_ = nullptr;
+
+ Status TryLock(ColumnFamilyHandle* column_family, const SliceParts& key,
+ bool read_only, bool exclusive, const bool do_validate = true,
+ const bool assume_tracked = false);
+
+ WriteBatchBase* GetBatchForWrite();
+ void SetSnapshotInternal(const Snapshot* snapshot);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc
new file mode 100644
index 000000000..345c4be90
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+
+#include "rocksdb/utilities/transaction_db_mutex.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionDBMutexImpl : public TransactionDBMutex {
+ public:
+ TransactionDBMutexImpl() {}
+ ~TransactionDBMutexImpl() override {}
+
+ Status Lock() override;
+
+ Status TryLockFor(int64_t timeout_time) override;
+
+ void UnLock() override { mutex_.unlock(); }
+
+ friend class TransactionDBCondVarImpl;
+
+ private:
+ std::mutex mutex_;
+};
+
+class TransactionDBCondVarImpl : public TransactionDBCondVar {
+ public:
+ TransactionDBCondVarImpl() {}
+ ~TransactionDBCondVarImpl() override {}
+
+ Status Wait(std::shared_ptr<TransactionDBMutex> mutex) override;
+
+ Status WaitFor(std::shared_ptr<TransactionDBMutex> mutex,
+ int64_t timeout_time) override;
+
+ void Notify() override { cv_.notify_one(); }
+
+ void NotifyAll() override { cv_.notify_all(); }
+
+ private:
+ std::condition_variable cv_;
+};
+
+std::shared_ptr<TransactionDBMutex>
+TransactionDBMutexFactoryImpl::AllocateMutex() {
+ return std::shared_ptr<TransactionDBMutex>(new TransactionDBMutexImpl());
+}
+
+std::shared_ptr<TransactionDBCondVar>
+TransactionDBMutexFactoryImpl::AllocateCondVar() {
+ return std::shared_ptr<TransactionDBCondVar>(new TransactionDBCondVarImpl());
+}
+
+Status TransactionDBMutexImpl::Lock() {
+ mutex_.lock();
+ return Status::OK();
+}
+
+Status TransactionDBMutexImpl::TryLockFor(int64_t timeout_time) {
+ bool locked = true;
+
+ if (timeout_time == 0) {
+ locked = mutex_.try_lock();
+ } else {
+ // Previously, this code used a std::timed_mutex. However, this was changed
+ // due to known bugs in gcc versions < 4.9.
+ // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54562
+ //
+ // Since this mutex isn't held for long and only a single mutex is ever
+ // held at a time, it is reasonable to ignore the lock timeout_time here
+ // and only check it when waiting on the condition_variable.
+ mutex_.lock();
+ }
+
+ if (!locked) {
+ // timeout acquiring mutex
+ return Status::TimedOut(Status::SubCode::kMutexTimeout);
+ }
+
+ return Status::OK();
+}
+
+Status TransactionDBCondVarImpl::Wait(
+ std::shared_ptr<TransactionDBMutex> mutex) {
+ auto mutex_impl = reinterpret_cast<TransactionDBMutexImpl*>(mutex.get());
+
+ std::unique_lock<std::mutex> lock(mutex_impl->mutex_, std::adopt_lock);
+ cv_.wait(lock);
+
+ // Make sure unique_lock doesn't unlock mutex when it destructs
+ lock.release();
+
+ return Status::OK();
+}
+
+Status TransactionDBCondVarImpl::WaitFor(
+ std::shared_ptr<TransactionDBMutex> mutex, int64_t timeout_time) {
+ Status s;
+
+ auto mutex_impl = reinterpret_cast<TransactionDBMutexImpl*>(mutex.get());
+ std::unique_lock<std::mutex> lock(mutex_impl->mutex_, std::adopt_lock);
+
+ if (timeout_time < 0) {
+ // If timeout is negative, do not use a timeout
+ cv_.wait(lock);
+ } else {
+ auto duration = std::chrono::microseconds(timeout_time);
+ auto cv_status = cv_.wait_for(lock, duration);
+
+ // Check if the wait stopped due to timing out.
+ if (cv_status == std::cv_status::timeout) {
+ s = Status::TimedOut(Status::SubCode::kMutexTimeout);
+ }
+ }
+
+ // Make sure unique_lock doesn't unlock mutex when it destructs
+ lock.release();
+
+ // CV was signaled, or we spuriously woke up (but didn't time out)
+ return s;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h
new file mode 100644
index 000000000..fbee92832
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_db_mutex_impl.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/transaction_db_mutex.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class TransactionDBMutex;
+class TransactionDBCondVar;
+
+// Default implementation of TransactionDBMutexFactory. May be overridden
+// by TransactionDBOptions.custom_mutex_factory.
+class TransactionDBMutexFactoryImpl : public TransactionDBMutexFactory {
+ public:
+ std::shared_ptr<TransactionDBMutex> AllocateMutex() override;
+ std::shared_ptr<TransactionDBCondVar> AllocateCondVar() override;
+};
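+
+// Usage sketch (illustrative; relies on the TransactionDBOptions field
+// custom_mutex_factory):
+//
+//   TransactionDBOptions txn_db_options;
+//   txn_db_options.custom_mutex_factory =
+//       std::make_shared<TransactionDBMutexFactoryImpl>();
+//
+// Leaving custom_mutex_factory unset has the same effect, since this class is
+// the default factory used by the transaction lock manager.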
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_lock_mgr.cc b/src/rocksdb/utilities/transactions/transaction_lock_mgr.cc
new file mode 100644
index 000000000..82b614033
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_lock_mgr.cc
@@ -0,0 +1,745 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_lock_mgr.h"
+
+#include <cinttypes>
+
+#include <algorithm>
+#include <condition_variable>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <vector>
+
+#include "monitoring/perf_context_imp.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/transaction_db_mutex.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/hash.h"
+#include "util/thread_local.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct LockInfo {
+ bool exclusive;
+ autovector<TransactionID> txn_ids;
+
+ // Transaction locks are not valid after this time, in microseconds.
+ uint64_t expiration_time;
+
+ LockInfo(TransactionID id, uint64_t time, bool ex)
+ : exclusive(ex), expiration_time(time) {
+ txn_ids.push_back(id);
+ }
+ LockInfo(const LockInfo& lock_info)
+ : exclusive(lock_info.exclusive),
+ txn_ids(lock_info.txn_ids),
+ expiration_time(lock_info.expiration_time) {}
+};
+
+struct LockMapStripe {
+ explicit LockMapStripe(std::shared_ptr<TransactionDBMutexFactory> factory) {
+ stripe_mutex = factory->AllocateMutex();
+ stripe_cv = factory->AllocateCondVar();
+ assert(stripe_mutex);
+ assert(stripe_cv);
+ }
+
+ // Mutex must be held before modifying keys map
+ std::shared_ptr<TransactionDBMutex> stripe_mutex;
+
+ // Condition Variable per stripe for waiting on a lock
+ std::shared_ptr<TransactionDBCondVar> stripe_cv;
+
+ // Locked keys mapped to the info about the transactions that locked them.
+ // TODO(agiardullo): Explore performance of other data structures.
+ std::unordered_map<std::string, LockInfo> keys;
+};
+
+// Map of #num_stripes LockMapStripes
+struct LockMap {
+ explicit LockMap(size_t num_stripes,
+ std::shared_ptr<TransactionDBMutexFactory> factory)
+ : num_stripes_(num_stripes) {
+ lock_map_stripes_.reserve(num_stripes);
+ for (size_t i = 0; i < num_stripes; i++) {
+ LockMapStripe* stripe = new LockMapStripe(factory);
+ lock_map_stripes_.push_back(stripe);
+ }
+ }
+
+ ~LockMap() {
+ for (auto stripe : lock_map_stripes_) {
+ delete stripe;
+ }
+ }
+
+ // Number of separate LockMapStripes to create, each with its own Mutex
+ const size_t num_stripes_;
+
+ // Count of keys that are currently locked in this column family.
+ // (Only maintained if TransactionLockMgr::max_num_locks_ is positive.)
+ std::atomic<int64_t> lock_cnt{0};
+
+ std::vector<LockMapStripe*> lock_map_stripes_;
+
+ size_t GetStripe(const std::string& key) const;
+};
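+
+// Illustration of the striping scheme above (a hedged sketch): a key is hashed
+// onto one of num_stripes_ stripes and only that stripe's mutex is taken when
+// locking the key, so locks on keys that land in different stripes never
+// contend on the same mutex:
+//
+//   size_t stripe = lock_map->GetStripe(key);   // fastrange64 of the key hash
+//   lock_map->lock_map_stripes_[stripe]->stripe_mutex->Lock();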
+
+void DeadlockInfoBuffer::AddNewPath(DeadlockPath path) {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ if (paths_buffer_.empty()) {
+ return;
+ }
+
+ paths_buffer_[buffer_idx_] = std::move(path);
+ buffer_idx_ = (buffer_idx_ + 1) % paths_buffer_.size();
+}
+
+void DeadlockInfoBuffer::Resize(uint32_t target_size) {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ paths_buffer_ = Normalize();
+
+ // Drop the deadlocks that will no longer be needed after the normalize
+ if (target_size < paths_buffer_.size()) {
+ paths_buffer_.erase(
+ paths_buffer_.begin(),
+ paths_buffer_.begin() + (paths_buffer_.size() - target_size));
+ buffer_idx_ = 0;
+ }
+ // Resize the buffer to the target size and restore the buffer's idx
+ else {
+ auto prev_size = paths_buffer_.size();
+ paths_buffer_.resize(target_size);
+ buffer_idx_ = (uint32_t)prev_size;
+ }
+}
+
+std::vector<DeadlockPath> DeadlockInfoBuffer::Normalize() {
+ auto working = paths_buffer_;
+
+ if (working.empty()) {
+ return working;
+ }
+
+ // Next write occurs at a nonexistent path's slot
+ if (paths_buffer_[buffer_idx_].empty()) {
+ working.resize(buffer_idx_);
+ } else {
+ std::rotate(working.begin(), working.begin() + buffer_idx_, working.end());
+ }
+
+ return working;
+}
+
+std::vector<DeadlockPath> DeadlockInfoBuffer::PrepareBuffer() {
+ std::lock_guard<std::mutex> lock(paths_buffer_mutex_);
+
+ // Reversing the normalized vector returns the latest deadlocks first
+ auto working = Normalize();
+ std::reverse(working.begin(), working.end());
+
+ return working;
+}
+
+namespace {
+void UnrefLockMapsCache(void* ptr) {
+ // Called when a thread exits or a ThreadLocalPtr gets destroyed.
+ auto lock_maps_cache =
+ static_cast<std::unordered_map<uint32_t, std::shared_ptr<LockMap>>*>(ptr);
+ delete lock_maps_cache;
+}
+} // anonymous namespace
+
+TransactionLockMgr::TransactionLockMgr(
+ TransactionDB* txn_db, size_t default_num_stripes, int64_t max_num_locks,
+ uint32_t max_num_deadlocks,
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory)
+ : txn_db_impl_(nullptr),
+ default_num_stripes_(default_num_stripes),
+ max_num_locks_(max_num_locks),
+ lock_maps_cache_(new ThreadLocalPtr(&UnrefLockMapsCache)),
+ dlock_buffer_(max_num_deadlocks),
+ mutex_factory_(mutex_factory) {
+ assert(txn_db);
+ txn_db_impl_ =
+ static_cast_with_check<PessimisticTransactionDB, TransactionDB>(txn_db);
+}
+
+TransactionLockMgr::~TransactionLockMgr() {}
+
+size_t LockMap::GetStripe(const std::string& key) const {
+ assert(num_stripes_ > 0);
+ return fastrange64(GetSliceNPHash64(key), num_stripes_);
+}
+
+void TransactionLockMgr::AddColumnFamily(uint32_t column_family_id) {
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ if (lock_maps_.find(column_family_id) == lock_maps_.end()) {
+ lock_maps_.emplace(column_family_id,
+ std::make_shared<LockMap>(default_num_stripes_, mutex_factory_));
+ } else {
+ // column_family already exists in lock map
+ assert(false);
+ }
+}
+
+void TransactionLockMgr::RemoveColumnFamily(uint32_t column_family_id) {
+ // Remove lock_map for this column family. Since the lock map is stored
+ // as a shared ptr, concurrent transactions can still keep using it
+ // until they release their references to it.
+ {
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ auto lock_maps_iter = lock_maps_.find(column_family_id);
+ assert(lock_maps_iter != lock_maps_.end());
+
+ lock_maps_.erase(lock_maps_iter);
+ } // lock_map_mutex_
+
+ // Clear all thread-local caches
+ autovector<void*> local_caches;
+ lock_maps_cache_->Scrape(&local_caches, nullptr);
+ for (auto cache : local_caches) {
+ delete static_cast<LockMaps*>(cache);
+ }
+}
+
+// Look up the LockMap std::shared_ptr for a given column_family_id.
+// Note: The LockMap is only valid as long as the caller is still holding on
+// to the returned std::shared_ptr.
+std::shared_ptr<LockMap> TransactionLockMgr::GetLockMap(
+ uint32_t column_family_id) {
+ // First check thread-local cache
+ if (lock_maps_cache_->Get() == nullptr) {
+ lock_maps_cache_->Reset(new LockMaps());
+ }
+
+ auto lock_maps_cache = static_cast<LockMaps*>(lock_maps_cache_->Get());
+
+ auto lock_map_iter = lock_maps_cache->find(column_family_id);
+ if (lock_map_iter != lock_maps_cache->end()) {
+ // Found lock map for this column family.
+ return lock_map_iter->second;
+ }
+
+ // Not found in local cache, grab mutex and check shared LockMaps
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ lock_map_iter = lock_maps_.find(column_family_id);
+ if (lock_map_iter == lock_maps_.end()) {
+ return std::shared_ptr<LockMap>(nullptr);
+ } else {
+ // Found lock map. Store in thread-local cache and return.
+ std::shared_ptr<LockMap>& lock_map = lock_map_iter->second;
+ lock_maps_cache->insert({column_family_id, lock_map});
+
+ return lock_map;
+ }
+}
+
+// Returns true if this lock has expired and can be acquired by another
+// transaction.
+// If false, sets *expire_time to the expiration time of the lock according
+// to Env::NowMicros(), or to 0 if the lock has no expiration.
+bool TransactionLockMgr::IsLockExpired(TransactionID txn_id,
+ const LockInfo& lock_info, Env* env,
+ uint64_t* expire_time) {
+ auto now = env->NowMicros();
+
+ bool expired =
+ (lock_info.expiration_time > 0 && lock_info.expiration_time <= now);
+
+ if (!expired && lock_info.expiration_time > 0) {
+ // return the absolute time (in microseconds) at which the lock expires
+ *expire_time = lock_info.expiration_time;
+ } else {
+ for (auto id : lock_info.txn_ids) {
+ if (txn_id == id) {
+ continue;
+ }
+
+ bool success = txn_db_impl_->TryStealingExpiredTransactionLocks(id);
+ if (!success) {
+ expired = false;
+ break;
+ }
+ *expire_time = 0;
+ }
+ }
+
+ return expired;
+}
+
+Status TransactionLockMgr::TryLock(PessimisticTransaction* txn,
+ uint32_t column_family_id,
+ const std::string& key, Env* env,
+ bool exclusive) {
+ // Lookup lock map for this column family id
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id);
+ LockMap* lock_map = lock_map_ptr.get();
+ if (lock_map == nullptr) {
+ char msg[255];
+ snprintf(msg, sizeof(msg), "Column family id not found: %" PRIu32,
+ column_family_id);
+
+ return Status::InvalidArgument(msg);
+ }
+
+ // Need to lock the mutex for the stripe that this key hashes to
+ size_t stripe_num = lock_map->GetStripe(key);
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ LockInfo lock_info(txn->GetID(), txn->GetExpirationTime(), exclusive);
+ int64_t timeout = txn->GetLockTimeout();
+
+ return AcquireWithTimeout(txn, lock_map, stripe, column_family_id, key, env,
+ timeout, std::move(lock_info));
+}
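+// Hypothetical caller sketch, for illustration only (in the real code path
+// the PessimisticTransaction tracks its locked keys and releases them via the
+// batched UnLock() overload on commit or rollback):
+//
+//   Status s = lock_mgr.TryLock(txn, cf_id, key, env, /*exclusive=*/true);
+//   if (s.ok()) {
+//     // ... perform the tracked write under the lock ...
+//     lock_mgr.UnLock(txn, cf_id, key, env);
+//   }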
+
+// Helper function for TryLock().
+Status TransactionLockMgr::AcquireWithTimeout(
+ PessimisticTransaction* txn, LockMap* lock_map, LockMapStripe* stripe,
+ uint32_t column_family_id, const std::string& key, Env* env,
+ int64_t timeout, LockInfo&& lock_info) {
+ Status result;
+ uint64_t end_time = 0;
+
+ if (timeout > 0) {
+ uint64_t start_time = env->NowMicros();
+ end_time = start_time + timeout;
+ }
+
+ if (timeout < 0) {
+ // If timeout is negative, we wait indefinitely to acquire the lock
+ result = stripe->stripe_mutex->Lock();
+ } else {
+ result = stripe->stripe_mutex->TryLockFor(timeout);
+ }
+
+ if (!result.ok()) {
+ // failed to acquire mutex
+ return result;
+ }
+
+ // Acquire lock if we are able to
+ uint64_t expire_time_hint = 0;
+ autovector<TransactionID> wait_ids;
+ result = AcquireLocked(lock_map, stripe, key, env, std::move(lock_info),
+ &expire_time_hint, &wait_ids);
+
+ if (!result.ok() && timeout != 0) {
+ PERF_TIMER_GUARD(key_lock_wait_time);
+ PERF_COUNTER_ADD(key_lock_wait_count, 1);
+ // If we weren't able to acquire the lock, we will keep retrying as long
+ // as the timeout allows.
+ bool timed_out = false;
+ do {
+ // Decide how long to wait
+ int64_t cv_end_time = -1;
+
+ // Check if held lock's expiration time is sooner than our timeout
+ if (expire_time_hint > 0 &&
+ (timeout < 0 || (timeout > 0 && expire_time_hint < end_time))) {
+ // expiration time is sooner than our timeout
+ cv_end_time = expire_time_hint;
+ } else if (timeout >= 0) {
+ cv_end_time = end_time;
+ }
+
+ assert(result.IsBusy() || wait_ids.size() != 0);
+
+ // We are dependent on a transaction to finish, so perform deadlock
+ // detection.
+ if (wait_ids.size() != 0) {
+ if (txn->IsDeadlockDetect()) {
+ if (IncrementWaiters(txn, wait_ids, key, column_family_id,
+ lock_info.exclusive, env)) {
+ result = Status::Busy(Status::SubCode::kDeadlock);
+ stripe->stripe_mutex->UnLock();
+ return result;
+ }
+ }
+ txn->SetWaitingTxn(wait_ids, column_family_id, &key);
+ }
+
+ TEST_SYNC_POINT("TransactionLockMgr::AcquireWithTimeout:WaitingTxn");
+ if (cv_end_time < 0) {
+ // Wait indefinitely
+ result = stripe->stripe_cv->Wait(stripe->stripe_mutex);
+ } else {
+ uint64_t now = env->NowMicros();
+ if (static_cast<uint64_t>(cv_end_time) > now) {
+ result = stripe->stripe_cv->WaitFor(stripe->stripe_mutex,
+ cv_end_time - now);
+ }
+ }
+
+ if (wait_ids.size() != 0) {
+ txn->ClearWaitingTxn();
+ if (txn->IsDeadlockDetect()) {
+ DecrementWaiters(txn, wait_ids);
+ }
+ }
+
+ if (result.IsTimedOut()) {
+ timed_out = true;
+ // Even though we timed out, we will still make one more attempt to
+ // acquire lock below (it is possible the lock expired and we
+ // were never signaled).
+ }
+
+ if (result.ok() || result.IsTimedOut()) {
+ result = AcquireLocked(lock_map, stripe, key, env, std::move(lock_info),
+ &expire_time_hint, &wait_ids);
+ }
+ } while (!result.ok() && !timed_out);
+ }
+
+ stripe->stripe_mutex->UnLock();
+
+ return result;
+}
+
+void TransactionLockMgr::DecrementWaiters(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids) {
+ std::lock_guard<std::mutex> lock(wait_txn_map_mutex_);
+ DecrementWaitersImpl(txn, wait_ids);
+}
+
+void TransactionLockMgr::DecrementWaitersImpl(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids) {
+ auto id = txn->GetID();
+ assert(wait_txn_map_.Contains(id));
+ wait_txn_map_.Delete(id);
+
+ for (auto wait_id : wait_ids) {
+ rev_wait_txn_map_.Get(wait_id)--;
+ if (rev_wait_txn_map_.Get(wait_id) == 0) {
+ rev_wait_txn_map_.Delete(wait_id);
+ }
+ }
+}
+
+bool TransactionLockMgr::IncrementWaiters(
+ const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids, const std::string& key,
+ const uint32_t& cf_id, const bool& exclusive, Env* const env) {
+ auto id = txn->GetID();
+ std::vector<int> queue_parents(
+     static_cast<size_t>(txn->GetDeadlockDetectDepth()));
+ std::vector<TransactionID> queue_values(
+     static_cast<size_t>(txn->GetDeadlockDetectDepth()));
+ std::lock_guard<std::mutex> lock(wait_txn_map_mutex_);
+ assert(!wait_txn_map_.Contains(id));
+
+ wait_txn_map_.Insert(id, {wait_ids, cf_id, exclusive, key});
+
+ for (auto wait_id : wait_ids) {
+ if (rev_wait_txn_map_.Contains(wait_id)) {
+ rev_wait_txn_map_.Get(wait_id)++;
+ } else {
+ rev_wait_txn_map_.Insert(wait_id, 1);
+ }
+ }
+
+ // No deadlock if nobody is waiting on self.
+ if (!rev_wait_txn_map_.Contains(id)) {
+ return false;
+ }
+
+ const auto* next_ids = &wait_ids;
+ int parent = -1;
+ int64_t deadlock_time = 0;
+ for (int tail = 0, head = 0; head < txn->GetDeadlockDetectDepth(); head++) {
+ int i = 0;
+ if (next_ids) {
+ for (; i < static_cast<int>(next_ids->size()) &&
+ tail + i < txn->GetDeadlockDetectDepth();
+ i++) {
+ queue_values[tail + i] = (*next_ids)[i];
+ queue_parents[tail + i] = parent;
+ }
+ tail += i;
+ }
+
+ // No more items in the list, meaning no deadlock.
+ if (tail == head) {
+ return false;
+ }
+
+ auto next = queue_values[head];
+ if (next == id) {
+ std::vector<DeadlockInfo> path;
+ while (head != -1) {
+ assert(wait_txn_map_.Contains(queue_values[head]));
+
+ auto extracted_info = wait_txn_map_.Get(queue_values[head]);
+ path.push_back({queue_values[head], extracted_info.m_cf_id,
+ extracted_info.m_exclusive,
+ extracted_info.m_waiting_key});
+ head = queue_parents[head];
+ }
+ env->GetCurrentTime(&deadlock_time);
+ std::reverse(path.begin(), path.end());
+ dlock_buffer_.AddNewPath(DeadlockPath(path, deadlock_time));
+ deadlock_time = 0;
+ DecrementWaitersImpl(txn, wait_ids);
+ return true;
+ } else if (!wait_txn_map_.Contains(next)) {
+ next_ids = nullptr;
+ continue;
+ } else {
+ parent = head;
+ next_ids = &(wait_txn_map_.Get(next).m_neighbors);
+ }
+ }
+
+ // Wait cycle too big, just assume deadlock.
+ env->GetCurrentTime(&deadlock_time);
+ dlock_buffer_.AddNewPath(DeadlockPath(deadlock_time, true));
+ DecrementWaitersImpl(txn, wait_ids);
+ return true;
+}
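+// Worked example (illustrative): suppose T1 already waits on T2 and T2 waits
+// on T3, so wait_txn_map_ holds T1 -> {T2} and T2 -> {T3}. If T3 now calls
+// IncrementWaiters() with wait_ids = {T1}, the breadth-first search above
+// visits T1, then T2, then reaches T3 itself, detects the cycle, and records
+// the reversed parent chain {T1, T2, T3} as the deadlock path.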
+
+// Try to lock this key after we have acquired the mutex.
+// Sets *expire_time to the expiration time in microseconds
+// or 0 if no expiration.
+// REQUIRED: Stripe mutex must be held.
+Status TransactionLockMgr::AcquireLocked(LockMap* lock_map,
+ LockMapStripe* stripe,
+ const std::string& key, Env* env,
+ LockInfo&& txn_lock_info,
+ uint64_t* expire_time,
+ autovector<TransactionID>* txn_ids) {
+ assert(txn_lock_info.txn_ids.size() == 1);
+
+ Status result;
+ // Check if this key is already locked
+ auto stripe_iter = stripe->keys.find(key);
+ if (stripe_iter != stripe->keys.end()) {
+ // Lock already held
+ LockInfo& lock_info = stripe_iter->second;
+ assert(lock_info.txn_ids.size() == 1 || !lock_info.exclusive);
+
+ if (lock_info.exclusive || txn_lock_info.exclusive) {
+ if (lock_info.txn_ids.size() == 1 &&
+ lock_info.txn_ids[0] == txn_lock_info.txn_ids[0]) {
+ // The list contains one txn and we're it, so just take it.
+ lock_info.exclusive = txn_lock_info.exclusive;
+ lock_info.expiration_time = txn_lock_info.expiration_time;
+ } else {
+ // Check if it's expired. Skips over txn_lock_info.txn_ids[0] in case
+ // it's there for a shared lock with multiple holders which was not
+ // caught in the first case.
+ if (IsLockExpired(txn_lock_info.txn_ids[0], lock_info, env,
+ expire_time)) {
+ // lock is expired, can steal it
+ lock_info.txn_ids = txn_lock_info.txn_ids;
+ lock_info.exclusive = txn_lock_info.exclusive;
+ lock_info.expiration_time = txn_lock_info.expiration_time;
+ // lock_cnt does not change
+ } else {
+ result = Status::TimedOut(Status::SubCode::kLockTimeout);
+ *txn_ids = lock_info.txn_ids;
+ }
+ }
+ } else {
+ // We are requesting shared access to a shared lock, so just grant it.
+ lock_info.txn_ids.push_back(txn_lock_info.txn_ids[0]);
+ // Using std::max means that expiration time never goes down even when
+ // a transaction is removed from the list. The correct solution would be
+ // to track expiry for every transaction, but this would also work for
+ // now.
+ lock_info.expiration_time =
+ std::max(lock_info.expiration_time, txn_lock_info.expiration_time);
+ }
+ } else { // Lock not held.
+ // Check lock limit
+ if (max_num_locks_ > 0 &&
+ lock_map->lock_cnt.load(std::memory_order_acquire) >= max_num_locks_) {
+ result = Status::Busy(Status::SubCode::kLockLimit);
+ } else {
+ // acquire lock
+ stripe->keys.emplace(key, std::move(txn_lock_info));
+
+ // Maintain lock count if there is a limit on the number of locks
+ if (max_num_locks_) {
+ lock_map->lock_cnt++;
+ }
+ }
+ }
+
+ return result;
+}
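+// Summary of the outcomes above, for illustration (assuming the requester is
+// not already the sole holder of the lock):
+//   - key unlocked:              granted, subject to the max_num_locks_ limit
+//   - shared held, shared req:   granted; the requester's id is appended
+//   - any exclusive involved:    TimedOut(kLockTimeout) unless the existing
+//                                lock has expired, in which case it is stolen
+// A transaction that is the sole current holder simply re-takes the lock,
+// upgrading or downgrading it in place.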
+
+void TransactionLockMgr::UnLockKey(const PessimisticTransaction* txn,
+ const std::string& key,
+ LockMapStripe* stripe, LockMap* lock_map,
+ Env* env) {
+#ifdef NDEBUG
+ (void)env;
+#endif
+ TransactionID txn_id = txn->GetID();
+
+ auto stripe_iter = stripe->keys.find(key);
+ if (stripe_iter != stripe->keys.end()) {
+ auto& txns = stripe_iter->second.txn_ids;
+ auto txn_it = std::find(txns.begin(), txns.end(), txn_id);
+ // Found the key we locked. unlock it.
+ if (txn_it != txns.end()) {
+ if (txns.size() == 1) {
+ stripe->keys.erase(stripe_iter);
+ } else {
+ auto last_it = txns.end() - 1;
+ if (txn_it != last_it) {
+ *txn_it = *last_it;
+ }
+ txns.pop_back();
+ }
+
+ if (max_num_locks_ > 0) {
+ // Maintain lock count if there is a limit on the number of locks.
+ assert(lock_map->lock_cnt.load(std::memory_order_relaxed) > 0);
+ lock_map->lock_cnt--;
+ }
+ }
+ } else {
+ // This key is either not locked or locked by someone else. This should
+ // only happen if the unlocking transaction has expired.
+ assert(txn->GetExpirationTime() > 0 &&
+ txn->GetExpirationTime() < env->NowMicros());
+ }
+}
+
+void TransactionLockMgr::UnLock(PessimisticTransaction* txn,
+ uint32_t column_family_id,
+ const std::string& key, Env* env) {
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id);
+ LockMap* lock_map = lock_map_ptr.get();
+ if (lock_map == nullptr) {
+ // Column Family must have been dropped.
+ return;
+ }
+
+ // Lock the mutex for the stripe that this key hashes to
+ size_t stripe_num = lock_map->GetStripe(key);
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ stripe->stripe_mutex->Lock();
+ UnLockKey(txn, key, stripe, lock_map, env);
+ stripe->stripe_mutex->UnLock();
+
+ // Signal waiting threads to retry locking
+ stripe->stripe_cv->NotifyAll();
+}
+
+void TransactionLockMgr::UnLock(const PessimisticTransaction* txn,
+ const TransactionKeyMap* key_map, Env* env) {
+ for (auto& key_map_iter : *key_map) {
+ uint32_t column_family_id = key_map_iter.first;
+ auto& keys = key_map_iter.second;
+
+ std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id);
+ LockMap* lock_map = lock_map_ptr.get();
+
+ if (lock_map == nullptr) {
+ // Column Family must have been dropped.
+ return;
+ }
+
+ // Bucket keys by lock_map_ stripe
+ std::unordered_map<size_t, std::vector<const std::string*>> keys_by_stripe(
+ std::max(keys.size(), lock_map->num_stripes_));
+
+ for (auto& key_iter : keys) {
+ const std::string& key = key_iter.first;
+
+ size_t stripe_num = lock_map->GetStripe(key);
+ keys_by_stripe[stripe_num].push_back(&key);
+ }
+
+ // For each stripe, grab the stripe mutex and unlock all keys in this stripe
+ for (auto& stripe_iter : keys_by_stripe) {
+ size_t stripe_num = stripe_iter.first;
+ auto& stripe_keys = stripe_iter.second;
+
+ assert(lock_map->lock_map_stripes_.size() > stripe_num);
+ LockMapStripe* stripe = lock_map->lock_map_stripes_.at(stripe_num);
+
+ stripe->stripe_mutex->Lock();
+
+ for (const std::string* key : stripe_keys) {
+ UnLockKey(txn, *key, stripe, lock_map, env);
+ }
+
+ stripe->stripe_mutex->UnLock();
+
+ // Signal waiting threads to retry locking
+ stripe->stripe_cv->NotifyAll();
+ }
+ }
+}
+
+TransactionLockMgr::LockStatusData TransactionLockMgr::GetLockStatusData() {
+ LockStatusData data;
+ // Lock order here is important. The correct order is lock_map_mutex_, then
+ // for every column family ID in ascending order lock every stripe in
+ // ascending order.
+ InstrumentedMutexLock l(&lock_map_mutex_);
+
+ std::vector<uint32_t> cf_ids;
+ for (const auto& map : lock_maps_) {
+ cf_ids.push_back(map.first);
+ }
+ std::sort(cf_ids.begin(), cf_ids.end());
+
+ for (auto i : cf_ids) {
+ const auto& stripes = lock_maps_[i]->lock_map_stripes_;
+ // Iterate and lock all stripes in ascending order.
+ for (const auto& j : stripes) {
+ j->stripe_mutex->Lock();
+ for (const auto& it : j->keys) {
+ struct KeyLockInfo info;
+ info.exclusive = it.second.exclusive;
+ info.key = it.first;
+ for (const auto& id : it.second.txn_ids) {
+ info.ids.push_back(id);
+ }
+ data.insert({i, info});
+ }
+ }
+ }
+
+ // Unlock everything. Unlocking order is not important.
+ for (auto i : cf_ids) {
+ const auto& stripes = lock_maps_[i]->lock_map_stripes_;
+ for (const auto& j : stripes) {
+ j->stripe_mutex->UnLock();
+ }
+ }
+
+ return data;
+}
+std::vector<DeadlockPath> TransactionLockMgr::GetDeadlockInfoBuffer() {
+ return dlock_buffer_.PrepareBuffer();
+}
+
+void TransactionLockMgr::Resize(uint32_t target_size) {
+ dlock_buffer_.Resize(target_size);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_lock_mgr.h b/src/rocksdb/utilities/transactions/transaction_lock_mgr.h
new file mode 100644
index 000000000..b4fd85929
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_lock_mgr.h
@@ -0,0 +1,158 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <chrono>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "monitoring/instrumented_mutex.h"
+#include "rocksdb/utilities/transaction.h"
+#include "util/autovector.h"
+#include "util/hash_map.h"
+#include "util/thread_local.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class ColumnFamilyHandle;
+struct LockInfo;
+struct LockMap;
+struct LockMapStripe;
+
+struct DeadlockInfoBuffer {
+ private:
+ std::vector<DeadlockPath> paths_buffer_;
+ uint32_t buffer_idx_;
+ std::mutex paths_buffer_mutex_;
+ std::vector<DeadlockPath> Normalize();
+
+ public:
+ explicit DeadlockInfoBuffer(uint32_t n_latest_dlocks)
+ : paths_buffer_(n_latest_dlocks), buffer_idx_(0) {}
+ void AddNewPath(DeadlockPath path);
+ void Resize(uint32_t target_size);
+ std::vector<DeadlockPath> PrepareBuffer();
+};
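+// Illustrative note (not from the original header): paths_buffer_ acts as a
+// fixed-capacity ring buffer. Once n_latest_dlocks paths have been recorded,
+// AddNewPath() overwrites the oldest entry, and PrepareBuffer() returns the
+// stored paths newest-first.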
+
+struct TrackedTrxInfo {
+ autovector<TransactionID> m_neighbors;
+ uint32_t m_cf_id;
+ bool m_exclusive;
+ std::string m_waiting_key;
+};
+
+class Slice;
+class PessimisticTransactionDB;
+
+class TransactionLockMgr {
+ public:
+ TransactionLockMgr(TransactionDB* txn_db, size_t default_num_stripes,
+ int64_t max_num_locks, uint32_t max_num_deadlocks,
+ std::shared_ptr<TransactionDBMutexFactory> factory);
+ // No copying allowed
+ TransactionLockMgr(const TransactionLockMgr&) = delete;
+ void operator=(const TransactionLockMgr&) = delete;
+
+ ~TransactionLockMgr();
+
+ // Creates a new LockMap for this column family. Caller should guarantee
+ // that this column family does not already exist.
+ void AddColumnFamily(uint32_t column_family_id);
+
+ // Deletes the LockMap for this column family. Caller should guarantee that
+ // this column family is no longer in use.
+ void RemoveColumnFamily(uint32_t column_family_id);
+
+ // Attempt to lock key. If OK status is returned, the caller is responsible
+ // for calling UnLock() on this key.
+ Status TryLock(PessimisticTransaction* txn, uint32_t column_family_id,
+ const std::string& key, Env* env, bool exclusive);
+
+ // Unlock a key locked by TryLock(). txn must be the same Transaction that
+ // locked this key.
+ void UnLock(const PessimisticTransaction* txn, const TransactionKeyMap* keys,
+ Env* env);
+ void UnLock(PessimisticTransaction* txn, uint32_t column_family_id,
+ const std::string& key, Env* env);
+
+ using LockStatusData = std::unordered_multimap<uint32_t, KeyLockInfo>;
+ LockStatusData GetLockStatusData();
+ std::vector<DeadlockPath> GetDeadlockInfoBuffer();
+ void Resize(uint32_t);
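+ // Illustrative (hypothetical) use of the inspection APIs above:
+ //
+ //   TransactionLockMgr::LockStatusData locks = mgr.GetLockStatusData();
+ //   for (const auto& entry : locks) {
+ //     // entry.first is the column family id, entry.second a KeyLockInfo
+ //   }
+ //   std::vector<DeadlockPath> deadlocks = mgr.GetDeadlockInfoBuffer();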
+
+ private:
+ PessimisticTransactionDB* txn_db_impl_;
+
+ // Default number of lock map stripes per column family
+ const size_t default_num_stripes_;
+
+ // Limit on number of keys locked per column family
+ const int64_t max_num_locks_;
+
+ // The following lock order must be satisfied in order to avoid deadlocking
+ // ourselves.
+ // - lock_map_mutex_
+ // - stripe mutexes in ascending cf id, ascending stripe order
+ // - wait_txn_map_mutex_
+ //
+ // Must be held when accessing/modifying lock_maps_.
+ InstrumentedMutex lock_map_mutex_;
+
+ // Map of ColumnFamilyId to locked key info
+ using LockMaps = std::unordered_map<uint32_t, std::shared_ptr<LockMap>>;
+ LockMaps lock_maps_;
+
+ // Thread-local cache of entries in lock_maps_. This is an optimization
+ // to avoid acquiring a mutex in order to look up a LockMap
+ std::unique_ptr<ThreadLocalPtr> lock_maps_cache_;
+
+ // Must be held when modifying wait_txn_map_ and rev_wait_txn_map_.
+ std::mutex wait_txn_map_mutex_;
+
+ // Maps from waitee -> number of waiters.
+ HashMap<TransactionID, int> rev_wait_txn_map_;
+ // Maps from waiter -> waitee.
+ HashMap<TransactionID, TrackedTrxInfo> wait_txn_map_;
+ DeadlockInfoBuffer dlock_buffer_;
+
+ // Used to allocate mutexes/condvars to use when locking keys
+ std::shared_ptr<TransactionDBMutexFactory> mutex_factory_;
+
+ bool IsLockExpired(TransactionID txn_id, const LockInfo& lock_info, Env* env,
+ uint64_t* wait_time);
+
+ std::shared_ptr<LockMap> GetLockMap(uint32_t column_family_id);
+
+ Status AcquireWithTimeout(PessimisticTransaction* txn, LockMap* lock_map,
+ LockMapStripe* stripe, uint32_t column_family_id,
+ const std::string& key, Env* env, int64_t timeout,
+ LockInfo&& lock_info);
+
+ Status AcquireLocked(LockMap* lock_map, LockMapStripe* stripe,
+ const std::string& key, Env* env,
+ LockInfo&& lock_info, uint64_t* wait_time,
+ autovector<TransactionID>* txn_ids);
+
+ void UnLockKey(const PessimisticTransaction* txn, const std::string& key,
+ LockMapStripe* stripe, LockMap* lock_map, Env* env);
+
+ bool IncrementWaiters(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids,
+ const std::string& key, const uint32_t& cf_id,
+ const bool& exclusive, Env* const env);
+ void DecrementWaiters(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids);
+ void DecrementWaitersImpl(const PessimisticTransaction* txn,
+ const autovector<TransactionID>& wait_ids);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_test.cc b/src/rocksdb/utilities/transactions/transaction_test.cc
new file mode 100644
index 000000000..bdc2609f5
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_test.cc
@@ -0,0 +1,6224 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_test.h"
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/perf_context.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/fault_injection_test_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+
+#include "port/port.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+INSTANTIATE_TEST_CASE_P(
+ DBAsBaseDB, TransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite)));
+INSTANTIATE_TEST_CASE_P(
+ DBAsBaseDB, TransactionStressTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite)));
+INSTANTIATE_TEST_CASE_P(
+ StackableDBAsBaseDB, TransactionTest,
+ ::testing::Values(
+ std::make_tuple(true, true, WRITE_COMMITTED, kOrderedWrite),
+ std::make_tuple(true, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(true, true, WRITE_UNPREPARED, kOrderedWrite)));
+
+// MySQLStyleTransactionTest takes far too long for valgrind to run.
+#ifndef ROCKSDB_VALGRIND_RUN
+INSTANTIATE_TEST_CASE_P(
+ MySQLStyleTransactionTest, MySQLStyleTransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_COMMITTED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_COMMITTED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, true),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite, false),
+ std::make_tuple(false, false, WRITE_UNPREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite, false),
+ std::make_tuple(false, true, WRITE_UNPREPARED, kOrderedWrite, true),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, false),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, true)));
+#endif // ROCKSDB_VALGRIND_RUN
+
+TEST_P(TransactionTest, DoubleEmptyWrite) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+
+ WriteBatch batch;
+
+ ASSERT_OK(db->Write(write_options, &batch));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ // Also test committing empty transactions in 2PC
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+
+ // Also test that it works during recovery
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid2"));
+ txn0->Put(Slice("foo0"), Slice("bar0a"));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+ txn0 = db->GetTransactionByName("xid2");
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+}
+
+TEST_P(TransactionTest, SuccessTest) {
+ ASSERT_OK(db->ResetStats());
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_LE(0, txn->GetID());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar");
+
+ ASSERT_OK(txn->Put(Slice("foo"), Slice("bar2")));
+
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ ASSERT_OK(txn->GetForUpdate(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ ASSERT_OK(txn->Commit());
+
+ ASSERT_OK(db->Get(read_options, "foo", &value));
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+// The test clarifies the contract of do_validate and assume_tracked
+// in GetForUpdate and Put/Merge/Delete
+TEST_P(TransactionTest, AssumeExclusiveTracked) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1;
+ const bool EXCLUSIVE = true;
+ const bool DO_VALIDATE = true;
+ const bool ASSUME_LOCKED = true;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+ txn->SetSnapshot();
+
+ // commit a value after the snapshot is taken
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+
+ // By default the write should fail due to the commit after our snapshot
+ s = txn->GetForUpdate(read_options, "foo", &value, EXCLUSIVE);
+ ASSERT_TRUE(s.IsBusy());
+ // But the user could direct the db to skip validating the snapshot. The read
+ // value then should be the most recently committed
+ ASSERT_OK(
+ txn->GetForUpdate(read_options, "foo", &value, EXCLUSIVE, !DO_VALIDATE));
+ ASSERT_EQ(value, "bar");
+
+ // Although ValidateSnapshot is skipped, the key must still have been locked
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // By default the write operations should fail due to the commit after the
+ // snapshot
+ s = txn->Put(Slice("foo"), Slice("bar1"));
+ ASSERT_TRUE(s.IsBusy());
+ s = txn->Put(db->DefaultColumnFamily(), Slice("foo"), Slice("bar1"),
+ !ASSUME_LOCKED);
+ ASSERT_TRUE(s.IsBusy());
+ // But the user could tell the db to assume that it already holds an
+ // exclusive lock on the key due to the previous GetForUpdate call.
+ ASSERT_OK(txn->Put(db->DefaultColumnFamily(), Slice("foo"), Slice("bar1"),
+ ASSUME_LOCKED));
+ ASSERT_OK(txn->Merge(db->DefaultColumnFamily(), Slice("foo"), Slice("bar2"),
+ ASSUME_LOCKED));
+ ASSERT_OK(
+ txn->Delete(db->DefaultColumnFamily(), Slice("foo"), ASSUME_LOCKED));
+ ASSERT_OK(txn->SingleDelete(db->DefaultColumnFamily(), Slice("foo"),
+ ASSUME_LOCKED));
+
+ txn->Rollback();
+ delete txn;
+}
+
+// This test clarifies the contract of ValidateSnapshot
+TEST_P(TransactionTest, ValidateSnapshotTest) {
+ for (bool with_flush : {true}) {
+ for (bool with_2pc : {true}) {
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ assert(db != nullptr);
+ Transaction* txn1 =
+ db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn1);
+ ASSERT_OK(txn1->Put(Slice("foo"), Slice("bar1")));
+ if (with_2pc) {
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Prepare());
+ }
+
+ if (with_flush) {
+ auto db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ db_impl->TEST_FlushMemTable(true);
+ // Make sure the flushed memtable is not kept in memory
+ int max_memtable_in_history =
+ std::max(
+ options.max_write_buffer_number,
+ static_cast<int>(options.max_write_buffer_size_to_maintain) /
+ static_cast<int>(options.write_buffer_size)) +
+ 1;
+ for (int i = 0; i < max_memtable_in_history; i++) {
+ db->Put(write_options, Slice("key"), Slice("value"));
+ db_impl->TEST_FlushMemTable(true);
+ }
+ }
+
+ Transaction* txn2 =
+ db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn2);
+ txn2->SetSnapshot();
+
+ ASSERT_OK(txn1->Commit());
+ delete txn1;
+
+ auto pes_txn2 = dynamic_cast<PessimisticTransaction*>(txn2);
+ // Test the simple case where the key is not tracked yet
+ auto trakced_seq = kMaxSequenceNumber;
+ auto s = pes_txn2->ValidateSnapshot(db->DefaultColumnFamily(), "foo",
+ &trakced_seq);
+ ASSERT_TRUE(s.IsBusy());
+ delete txn2;
+ }
+ }
+}
+
+TEST_P(TransactionTest, WaitingTxn) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_options.lock_timeout = 1;
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ /* create second cf */
+ ColumnFamilyHandle* cfa;
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->Put(write_options, cfa, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ TransactionID id1 = txn1->GetID();
+ ASSERT_TRUE(txn1);
+ ASSERT_TRUE(txn2);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionLockMgr::AcquireWithTimeout:WaitingTxn", [&](void* /*arg*/) {
+ std::string key;
+ uint32_t cf_id;
+ std::vector<TransactionID> wait = txn2->GetWaitingTxns(&cf_id, &key);
+ ASSERT_EQ(key, "foo");
+ ASSERT_EQ(wait.size(), 1);
+ ASSERT_EQ(wait[0], id1);
+ ASSERT_EQ(cf_id, 0U);
+ });
+
+ get_perf_context()->Reset();
+ // lock key in default cf
+ s = txn1->GetForUpdate(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 0);
+
+ // lock key in cfa
+ s = txn1->GetForUpdate(read_options, cfa, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 0);
+
+ auto lock_data = db->GetLockStatusData();
+ // Locked keys exist in both column families.
+ ASSERT_EQ(lock_data.size(), 2);
+
+ auto cf_iterator = lock_data.begin();
+
+ // The iterator points into an unordered_multimap,
+ // so the test cannot assume any particular order.
+
+ // Column family is 1 or 0 (cfa).
+ if (cf_iterator->first != 1 && cf_iterator->first != 0) {
+ FAIL();
+ }
+ // The locked key is "foo" and is locked by txn1
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+ ASSERT_EQ(cf_iterator->second.ids.size(), 1);
+ ASSERT_EQ(cf_iterator->second.ids[0], txn1->GetID());
+
+ cf_iterator++;
+
+ // Column family is 0 (default) or 1.
+ if (cf_iterator->first != 1 && cf_iterator->first != 0) {
+ FAIL();
+ }
+ // The locked key is "foo" and is locked by txn1
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+ ASSERT_EQ(cf_iterator->second.ids.size(), 1);
+ ASSERT_EQ(cf_iterator->second.ids[0], txn1->GetID());
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ s = txn2->GetForUpdate(read_options, "foo", &value);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+ ASSERT_EQ(get_perf_context()->key_lock_wait_count, 1);
+ ASSERT_GE(get_perf_context()->key_lock_wait_time, 0);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ delete cfa;
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SharedLocks) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1;
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+ ASSERT_TRUE(txn2);
+ ASSERT_TRUE(txn3);
+
+ // Test shared access between txns
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn3->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ auto lock_data = db->GetLockStatusData();
+ ASSERT_EQ(lock_data.size(), 1);
+
+ auto cf_iterator = lock_data.begin();
+ ASSERT_EQ(cf_iterator->second.key, "foo");
+
+ // We compare whether the set of txns locking this key is the same. To do
+ // this, we need to sort both vectors so that the comparison is done
+ // correctly.
+ std::vector<TransactionID> expected_txns = {txn1->GetID(), txn2->GetID(),
+ txn3->GetID()};
+ std::vector<TransactionID> lock_txns = cf_iterator->second.ids;
+ ASSERT_EQ(expected_txns, lock_txns);
+ ASSERT_FALSE(cf_iterator->second.exclusive);
+
+ txn1->Rollback();
+ txn2->Rollback();
+ txn3->Rollback();
+
+ // Test txn1 and txn2 sharing a lock and txn3 trying to obtain it.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn2->UndoGetForUpdate("foo");
+ s = txn3->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ txn1->Rollback();
+ txn2->Rollback();
+ txn3->Rollback();
+
+ // Test txn1 and txn2 sharing a lock and txn2 trying to upgrade lock.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn2->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->Rollback());
+ ASSERT_OK(txn2->Rollback());
+
+ // Test txn1 trying to downgrade its lock.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, true /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ // Should still fail after "downgrading".
+ s = txn1->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->Rollback();
+ txn2->Rollback();
+
+ // Test txn1 holding an exclusive lock and txn2 trying to obtain shared
+ // access.
+ s = txn1->GetForUpdate(read_options, "foo", nullptr);
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ txn1->UndoGetForUpdate("foo");
+ s = txn2->GetForUpdate(read_options, "foo", nullptr, false /* exclusive */);
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+ delete txn3;
+}
+
+TEST_P(TransactionTest, DeadlockCycleShared) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+
+ // Set up a wait-for chain like this:
+ //
+ // Tn -> T(n*2)
+ // Tn -> T(n*2 + 1)
+ //
+ // So we have:
+ // T1 -> T2 -> T4 ...
+ // | |> T5 ...
+ // |> T3 -> T6 ...
+ // |> T7 ...
+ // up to T31, then T[16 - 31] -> T1.
+ // Note that Tn holds lock on floor(n / 2).
+
+ std::vector<Transaction*> txns(31);
+
+ for (uint32_t i = 0; i < 31; i++) {
+ txns[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns[i]);
+ auto s = txns[i]->GetForUpdate(read_options, ToString((i + 1) / 2), nullptr,
+ false /* exclusive */);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionLockMgr::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // We want the leaf transactions to block and hold everyone back.
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < 15; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s = txns[i]->GetForUpdate(read_options, ToString(i + 1), nullptr,
+ true /* exclusive */);
+ ASSERT_OK(s);
+ txns[i]->Rollback();
+ delete txns[i];
+ };
+ threads.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints.load() != 15) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle T[16 - 31] -> T1
+ for (uint32_t i = 15; i < 31; i++) {
+ auto s =
+ txns[i]->GetForUpdate(read_options, "0", nullptr, true /* exclusive */);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ // Calculate next buffer len, plateau at 5 when 5 records are inserted.
+ const uint32_t curr_dlock_buffer_len_ =
+ (i - 14 > kInitialMaxDeadlocks) ? kInitialMaxDeadlocks : (i - 14);
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer.size(), curr_dlock_buffer_len_);
+ auto dlock_entry = dlock_buffer[0].path;
+ ASSERT_EQ(dlock_entry.size(), kInitialMaxDeadlocks);
+ int64_t pre_deadlock_time = dlock_buffer[0].deadlock_time;
+ int64_t cur_deadlock_time = 0;
+ for (auto const& dl_path_rec : dlock_buffer) {
+ cur_deadlock_time = dl_path_rec.deadlock_time;
+ ASSERT_NE(cur_deadlock_time, 0);
+ ASSERT_TRUE(cur_deadlock_time <= pre_deadlock_time);
+ pre_deadlock_time = cur_deadlock_time;
+ }
+
+ int64_t curr_waiting_key = 0;
+
+ // Offset of each txn id from the root of the shared dlock tree's txn id.
+ int64_t offset_root = dlock_entry[0].m_txn_id - 1;
+ // Offset of the final entry in the dlock path from the root's txn id.
+ TransactionID leaf_id =
+ dlock_entry[dlock_entry.size() - 1].m_txn_id - offset_root;
+
+ for (auto it = dlock_entry.rbegin(); it != dlock_entry.rend(); ++it) {
+ auto dl_node = *it;
+ ASSERT_EQ(dl_node.m_txn_id, offset_root + leaf_id);
+ ASSERT_EQ(dl_node.m_cf_id, 0U);
+ ASSERT_EQ(dl_node.m_waiting_key, ToString(curr_waiting_key));
+ ASSERT_EQ(dl_node.m_exclusive, true);
+
+ if (curr_waiting_key == 0) {
+ curr_waiting_key = leaf_id;
+ }
+ curr_waiting_key /= 2;
+ leaf_id /= 2;
+ }
+ }
+
+ // Rollback the leaf transaction.
+ for (uint32_t i = 15; i < 31; i++) {
+ txns[i]->Rollback();
+ delete txns[i];
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Downsize the buffer and verify the 3 latest deadlocks are preserved.
+ auto dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(3);
+ auto dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 3);
+
+ for (uint32_t i = 0; i < dlock_buffer_after_resize.size(); i++) {
+ for (uint32_t j = 0; j < dlock_buffer_after_resize[i].path.size(); j++) {
+ ASSERT_EQ(dlock_buffer_after_resize[i].path[j].m_txn_id,
+ dlock_buffer_before_resize[i].path[j].m_txn_id);
+ }
+ }
+
+ // Upsize the buffer and verify the 3 latest deadlocks are preserved.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(5);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 3);
+
+ for (uint32_t i = 0; i < dlock_buffer_before_resize.size(); i++) {
+ for (uint32_t j = 0; j < dlock_buffer_before_resize[i].path.size(); j++) {
+ ASSERT_EQ(dlock_buffer_after_resize[i].path[j].m_txn_id,
+ dlock_buffer_before_resize[i].path[j].m_txn_id);
+ }
+ }
+
+ // Downsize to 0 and verify all stored entries are dropped.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(0);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 0);
+
+ // Upsize from 0 and verify the stored entry count stays at 0.
+ dlock_buffer_before_resize = db->GetDeadlockInfoBuffer();
+ db->SetDeadlockInfoBufferSize(3);
+ dlock_buffer_after_resize = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer_after_resize.size(), 0);
+
+ // Contrived case of shared lock of cycle size 2 to verify that a shared
+ // lock causing a deadlock is correctly reported as "shared" in the buffer.
+ std::vector<Transaction*> txns_shared(2);
+
+ // Create a cycle of size 2.
+ for (uint32_t i = 0; i < 2; i++) {
+ txns_shared[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns_shared[i]);
+ auto s = txns_shared[i]->GetForUpdate(read_options, ToString(i), nullptr);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints_shared(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionLockMgr::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints_shared.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ std::vector<port::Thread> threads_shared;
+ for (uint32_t i = 0; i < 1; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s =
+ txns_shared[i]->GetForUpdate(read_options, ToString(i + 1), nullptr);
+ ASSERT_OK(s);
+ txns_shared[i]->Rollback();
+ delete txns_shared[i];
+ };
+ threads_shared.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints_shared.load() != 1) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle T2 -> T1 with a shared lock.
+ auto s = txns_shared[1]->GetForUpdate(read_options, "0", nullptr, false);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+
+ // Verify the size of the buffer and the single path.
+ ASSERT_EQ(dlock_buffer.size(), 1);
+ ASSERT_EQ(dlock_buffer[0].path.size(), 2);
+
+ // Verify the exclusivity field of the transactions in the deadlock path.
+ ASSERT_TRUE(dlock_buffer[0].path[0].m_exclusive);
+ ASSERT_FALSE(dlock_buffer[0].path[1].m_exclusive);
+ txns_shared[1]->Rollback();
+ delete txns_shared[1];
+
+ for (auto& t : threads_shared) {
+ t.join();
+ }
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+TEST_P(TransactionStressTest, DeadlockCycle) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ // offset by 2 from the max depth to test edge case
+ const uint32_t kMaxCycleLength = 52;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+
+ for (uint32_t len = 2; len < kMaxCycleLength; len++) {
+ // Set up a long wait-for chain like this:
+ //
+ // T1 -> T2 -> T3 -> ... -> Tlen
+
+ std::vector<Transaction*> txns(len);
+
+ for (uint32_t i = 0; i < len; i++) {
+ txns[i] = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txns[i]);
+ auto s = txns[i]->GetForUpdate(read_options, ToString(i), nullptr);
+ ASSERT_OK(s);
+ }
+
+ std::atomic<uint32_t> checkpoints(0);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionLockMgr::AcquireWithTimeout:WaitingTxn",
+ [&](void* /*arg*/) { checkpoints.fetch_add(1); });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // We want the last transaction in the chain to block and hold everyone
+ // back.
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < len - 1; i++) {
+ std::function<void()> blocking_thread = [&, i] {
+ auto s = txns[i]->GetForUpdate(read_options, ToString(i + 1), nullptr);
+ ASSERT_OK(s);
+ txns[i]->Rollback();
+ delete txns[i];
+ };
+ threads.emplace_back(blocking_thread);
+ }
+
+ // Wait until all threads are waiting on each other.
+ while (checkpoints.load() != len - 1) {
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Complete the cycle Tlen -> T1
+ auto s = txns[len - 1]->GetForUpdate(read_options, "0", nullptr);
+ ASSERT_TRUE(s.IsDeadlock());
+
+ const uint32_t dlock_buffer_size_ = (len - 1 > 5) ? 5 : (len - 1);
+ uint32_t curr_waiting_key = 0;
+ TransactionID curr_txn_id = txns[0]->GetID();
+
+ auto dlock_buffer = db->GetDeadlockInfoBuffer();
+ ASSERT_EQ(dlock_buffer.size(), dlock_buffer_size_);
+ uint32_t check_len = len;
+ bool check_limit_flag = false;
+
+ // Special case for a deadlock path that exceeds the maximum depth.
+ if (len > 50) {
+ check_len = 0;
+ check_limit_flag = true;
+ }
+ auto dlock_entry = dlock_buffer[0].path;
+ ASSERT_EQ(dlock_entry.size(), check_len);
+ ASSERT_EQ(dlock_buffer[0].limit_exceeded, check_limit_flag);
+
+ int64_t pre_deadlock_time = dlock_buffer[0].deadlock_time;
+ int64_t cur_deadlock_time = 0;
+ for (auto const& dl_path_rec : dlock_buffer) {
+ cur_deadlock_time = dl_path_rec.deadlock_time;
+ ASSERT_NE(cur_deadlock_time, 0);
+ ASSERT_TRUE(cur_deadlock_time <= pre_deadlock_time);
+ pre_deadlock_time = cur_deadlock_time;
+ }
+
+ // Iterates backwards over path verifying decreasing txn_ids.
+ for (auto it = dlock_entry.rbegin(); it != dlock_entry.rend(); ++it) {
+ auto dl_node = *it;
+ ASSERT_EQ(dl_node.m_txn_id, len + curr_txn_id - 1);
+ ASSERT_EQ(dl_node.m_cf_id, 0u);
+ ASSERT_EQ(dl_node.m_waiting_key, ToString(curr_waiting_key));
+ ASSERT_EQ(dl_node.m_exclusive, true);
+
+ curr_txn_id--;
+ if (curr_waiting_key == 0) {
+ curr_waiting_key = len;
+ }
+ curr_waiting_key--;
+ }
+
+ // Rollback the last transaction.
+ txns[len - 1]->Rollback();
+ delete txns[len - 1];
+
+ for (auto& t : threads) {
+ t.join();
+ }
+ }
+}
+
+TEST_P(TransactionStressTest, DeadlockStress) {
+ const uint32_t NUM_TXN_THREADS = 10;
+ const uint32_t NUM_KEYS = 100;
+ const uint32_t NUM_ITERS = 10000;
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ txn_options.lock_timeout = 1000000;
+ txn_options.deadlock_detect = true;
+ std::vector<std::string> keys;
+
+ for (uint32_t i = 0; i < NUM_KEYS; i++) {
+ db->Put(write_options, Slice(ToString(i)), Slice(""));
+ keys.push_back(ToString(i));
+ }
+
+ size_t tid = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random rnd(static_cast<uint32_t>(tid));
+ std::function<void(uint32_t)> stress_thread = [&](uint32_t seed) {
+ std::default_random_engine g(seed);
+
+ Transaction* txn;
+ for (uint32_t i = 0; i < NUM_ITERS; i++) {
+ txn = db->BeginTransaction(write_options, txn_options);
+ auto random_keys = keys;
+ std::shuffle(random_keys.begin(), random_keys.end(), g);
+
+ // Lock keys in random order.
+ for (const auto& k : random_keys) {
+ // Lock mostly for shared access, but exclusive 1/4 of the time.
+ auto s =
+ txn->GetForUpdate(read_options, k, nullptr, txn->GetID() % 4 == 0);
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsDeadlock());
+ txn->Rollback();
+ break;
+ }
+ }
+
+ delete txn;
+ }
+ };
+
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < NUM_TXN_THREADS; i++) {
+ threads.emplace_back(stress_thread, rnd.Next());
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+TEST_P(TransactionTest, CommitTimeBatchFailTest) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ ASSERT_OK(txn1->GetCommitTimeWriteBatch()->Put("cat", "dog"));
+
+ s = txn1->Put("foo", "bar");
+ ASSERT_OK(s);
+
+ // fails due to non-empty commit-time batch
+ s = txn1->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, LogMarkLeakTest) {
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ options.write_buffer_size = 1024;
+ ASSERT_OK(ReOpenNoDelete());
+ assert(db != nullptr);
+ Random rnd(47);
+ std::vector<Transaction*> txns;
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ // At the beginning there should be no log containing prepare data
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ for (size_t i = 0; i < 100; i++) {
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + ToString(i)));
+ ASSERT_OK(txn->Put(Slice("foo" + ToString(i)), Slice("bar")));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_GT(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ if (rnd.OneIn(5)) {
+ txns.push_back(txn);
+ } else {
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ db_impl->TEST_FlushMemTable(true);
+ }
+ for (auto txn : txns) {
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ // At the end there should be no log left containing prepare data
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+ // Make sure that the underlying data structures are properly truncated
+ // and do not leak
+ ASSERT_EQ(db_impl->TEST_PreparedSectionCompletedSize(), 0);
+ ASSERT_EQ(db_impl->TEST_LogsWithPrepSize(), 0);
+}
+
+TEST_P(TransactionTest, SimpleTwoPhaseTransactionTest) {
+ for (bool cwb4recovery : {true, false}) {
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = cwb4recovery;
+
+ string value;
+ Status s;
+
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // regular db put
+ s = db->Put(write_options, Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // regular db read
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // commit time put
+ txn->GetCommitTimeWriteBatch()->Put(Slice("gtid"), Slice("dogs"));
+ txn->GetCommitTimeWriteBatch()->Put(Slice("gtid2"), Slice("cats"));
+
+ // nothing has been prepped yet
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // data is not in mem yet
+ s = db->Get(read_options, Slice("foo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(read_options, Slice("gtid"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // find trans in list of prepared transactions
+ std::vector<Transaction*> prepared_trans;
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 1);
+ ASSERT_EQ(prepared_trans.front()->GetName(), "xid");
+
+ auto log_containing_prep =
+ db_impl->TEST_FindMinLogContainingOutstandingPrep();
+ ASSERT_GT(log_containing_prep, 0);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // value is now available
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ if (!cwb4recovery) {
+ s = db->Get(read_options, "gtid", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs");
+
+ s = db->Get(read_options, "gtid2", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "cats");
+ }
+
+ // we already committed
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // no longer in the list of prepared transactions
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 0);
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+
+ // heap should not care about prepared section anymore
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // but now our memtable should be referencing the prep section
+ ASSERT_GE(log_containing_prep, db_impl->MinLogNumberToKeep());
+ ASSERT_EQ(log_containing_prep,
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // In these modes the memtable does not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ db_impl->TEST_FlushMemTable(true);
+ // After flush the recoverable state must be visible
+ if (cwb4recovery) {
+ s = db->Get(read_options, "gtid", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs");
+
+ s = db->Get(read_options, "gtid2", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "cats");
+ }
+
+ // after the memtable flush we can now release the log
+ ASSERT_GT(db_impl->MinLogNumberToKeep(), log_containing_prep);
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+
+ delete txn;
+
+ if (cwb4recovery) {
+ // kill and reopen to trigger recovery
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ s = db->Get(read_options, "gtid", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs");
+
+ s = db->Get(read_options, "gtid2", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "cats");
+ }
+ }
+}
+
+TEST_P(TransactionTest, TwoPhaseNameTest) {
+ Status s;
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn3);
+ delete txn3;
+
+ // can't prepare a txn without a name
+ s = txn1->Prepare();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // name too short
+ s = txn1->SetName("");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // name too long
+ s = txn1->SetName(std::string(513, 'x'));
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // valid set name
+ s = txn1->SetName("name1");
+ ASSERT_OK(s);
+
+ // can't have a duplicate name
+ s = txn2->SetName("name1");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // shouldn't be able to prepare
+ s = txn2->Prepare();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // valid name set
+ s = txn2->SetName("name2");
+ ASSERT_OK(s);
+
+ // can't reset the name
+ s = txn2->SetName("name3");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ ASSERT_EQ(txn1->GetName(), "name1");
+ ASSERT_EQ(txn2->GetName(), "name2");
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // can't rename after prepare
+ s = txn1->SetName("name4");
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ txn1->Rollback();
+ txn2->Rollback();
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TwoPhaseEmptyWriteTest) {
+ for (bool cwb4recovery : {true, false}) {
+ for (bool test_with_empty_wal : {true, false}) {
+ if (!cwb4recovery && test_with_empty_wal) {
+ continue;
+ }
+ ASSERT_OK(ReOpen());
+ Status s;
+ std::string value;
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery =
+ cwb4recovery;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn1->SetName("joe");
+ ASSERT_OK(s);
+
+ s = txn2->SetName("bob");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+
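+      // Stash an extra Put in txn2's commit-time write batch; it is written
+      // to the DB as part of txn2's commit.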
+ txn2->GetCommitTimeWriteBatch()->Put(Slice("foo"), Slice("bar"));
+
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn2;
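+      // With use_only_the_last_commit_time_batch_for_recovery the
+      // commit-time batch is not applied at commit, so "foo" only becomes
+      // readable after a flush or a WAL recovery; otherwise it is visible
+      // right away.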
+ if (!cwb4recovery) {
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ } else {
+ if (test_with_empty_wal) {
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ db_impl->TEST_FlushMemTable(true);
+ // After flush the state must be visible
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ }
+ db->FlushWAL(true);
+ // kill and reopen to trigger recovery
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+ }
+ }
+ }
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+TEST_P(TransactionStressTest, TwoPhaseExpirationTest) {
+ Status s;
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ txn_options.expiration = 500; // 500ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+  ASSERT_TRUE(txn2);
+
+ s = txn1->SetName("joe");
+ ASSERT_OK(s);
+ s = txn2->SetName("bob");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
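+  // Sleep past the 500ms expiration: txn1 is already prepared and can still
+  // commit, but txn2 has not prepared yet and should report Expired.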
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Prepare();
+ ASSERT_EQ(s, Status::Expired());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TwoPhaseRollbackTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("tfoo"), Slice("tbar"));
+ ASSERT_OK(s);
+
+  // value is readable from txn
+ s = txn->Get(read_options, Slice("tfoo"), &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "tbar");
+
+ // issue rollback
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+  // value is no longer readable
+ s = txn->Get(read_options, Slice("tfoo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(txn->GetNumPuts(), 0);
+
+ // put new txn values
+ s = txn->Put(Slice("tfoo2"), Slice("tbar2"));
+ ASSERT_OK(s);
+
+ // new value is readable from txn
+ s = txn->Get(read_options, Slice("tfoo2"), &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "tbar2");
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // flush to next wal
+ s = db->Put(write_options, Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ db_impl->TEST_FlushMemTable(true);
+
+ // issue rollback (marker written to WAL)
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+  // value is no longer readable
+ s = txn->Get(read_options, Slice("tfoo2"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(txn->GetNumPuts(), 0);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ // try rollback again
+ s = txn->Rollback();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, PersistentTwoPhaseTransactionTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ // txn read
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ // regular db put
+ s = db->Put(write_options, Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ db_impl->TEST_FlushMemTable(true);
+
+ // regular db read
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // nothing has been prepped yet
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // still not available to db
+ s = db->Get(read_options, Slice("foo"), &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ db->FlushWAL(false);
+ delete txn;
+ // kill and reopen
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ s = ReOpenNoDelete();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ // find trans in list of prepared transactions
+ std::vector<Transaction*> prepared_trans;
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 1);
+
+ txn = prepared_trans.front();
+ ASSERT_TRUE(txn);
+ ASSERT_EQ(txn->GetName(), "xid");
+ ASSERT_EQ(db->GetTransactionByName("xid"), txn);
+
+ // log has been marked
+ auto log_containing_prep =
+ db_impl->TEST_FindMinLogContainingOutstandingPrep();
+ ASSERT_GT(log_containing_prep, 0);
+
+ // value is readable from txn
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar");
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // value is now available
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ // we already committed
+ s = txn->Commit();
+ ASSERT_EQ(s, Status::InvalidArgument());
+
+  // should no longer be in the list of prepared transactions
+ prepared_trans.clear();
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), 0);
+
+ // transaction should no longer be visible
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+
+ // heap should not care about prepared section anymore
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // but now our memtable should be referencing the prep section
+ ASSERT_EQ(log_containing_prep,
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ ASSERT_GE(log_containing_prep, db_impl->MinLogNumberToKeep());
+
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // In these modes memtables do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // Add a dummy record to memtable before a flush. Otherwise, the
+ // memtable will be empty and flush will be skipped.
+ s = db->Put(write_options, Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+
+ db_impl->TEST_FlushMemTable(true);
+
+ // after memtable flush we can now release the log
+ ASSERT_GT(db_impl->MinLogNumberToKeep(), log_containing_prep);
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+
+ delete txn;
+
+ // deleting transaction should unregister transaction
+ ASSERT_EQ(db->GetTransactionByName("xid"), nullptr);
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+// TODO this test needs to be updated with serial commits
+TEST_P(TransactionTest, DISABLED_TwoPhaseMultiThreadTest) {
+ // mix transaction writes and regular writes
+ const uint32_t NUM_TXN_THREADS = 50;
+ std::atomic<uint32_t> txn_thread_num(0);
+
+ std::function<void()> txn_write_thread = [&]() {
+ uint32_t id = txn_thread_num.fetch_add(1);
+
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1000000;
+ if (id % 2 == 0) {
+ txn_options.expiration = 1000000;
+ }
+ TransactionName name("xid_" + std::string(1, 'A' + static_cast<char>(id)));
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName(name));
+ for (int i = 0; i < 10; i++) {
+ std::string key(name + "_" + std::string(1, static_cast<char>('A' + i)));
+ ASSERT_OK(txn->Put(key, "val"));
+ }
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ };
+
+  // ensure that all threads are in the same write group
+ std::atomic<uint32_t> t_wait_on_prepare(0);
+ std::atomic<uint32_t> t_wait_on_commit(0);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
+ auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
+
+ if (writer->ShouldWriteToWAL()) {
+ t_wait_on_prepare.fetch_add(1);
+ // wait for friends
+ while (t_wait_on_prepare.load() < NUM_TXN_THREADS) {
+ env->SleepForMicroseconds(10);
+ }
+ } else if (writer->ShouldWriteToMemtable()) {
+ t_wait_on_commit.fetch_add(1);
+ // wait for friends
+ while (t_wait_on_commit.load() < NUM_TXN_THREADS) {
+ env->SleepForMicroseconds(10);
+ }
+ } else {
+ FAIL();
+ }
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ // do all the writes
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < NUM_TXN_THREADS; i++) {
+ threads.emplace_back(txn_write_thread);
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+ for (uint32_t t = 0; t < NUM_TXN_THREADS; t++) {
+ TransactionName name("xid_" + std::string(1, 'A' + static_cast<char>(t)));
+ for (int i = 0; i < 10; i++) {
+ std::string key(name + "_" + std::string(1, static_cast<char>('A' + i)));
+ s = db->Get(read_options, key, &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "val");
+ }
+ }
+}
+
+TEST_P(TransactionStressTest, TwoPhaseLongPrepareTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("bob");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ delete txn;
+
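+  // Keep writing unrelated data while periodically crashing or closing and
+  // reopening the DB; the prepared transaction "bob" must survive every
+  // recovery so it can still be committed afterwards.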
+ for (int i = 0; i < 1000; i++) {
+ std::string key(i, 'k');
+ std::string val(1000, 'v');
+ assert(db != nullptr);
+ s = db->Put(write_options, key, val);
+ ASSERT_OK(s);
+
+ if (i % 29 == 0) {
+ // crash
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+ } else if (i % 37 == 0) {
+ // close
+ ReOpenNoDelete();
+ }
+ }
+
+ // commit old txn
+ txn = db->GetTransactionByName("bob");
+ ASSERT_TRUE(txn);
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+  // verify txn data
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ // verify non txn data
+ for (int i = 0; i < 1000; i++) {
+ std::string key(i, 'k');
+ std::string val(1000, 'v');
+ s = db->Get(read_options, key, &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, val);
+ }
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, TwoPhaseSequenceTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("xid");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn->Put(Slice("foo4"), Slice("bar4"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ // make commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+
+ // value is now available
+ s = db->Get(read_options, "foo4", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar4");
+}
+
+TEST_P(TransactionTest, TwoPhaseDoubleRecoveryTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ReadOptions read_options;
+
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("a");
+ ASSERT_OK(s);
+
+ // transaction put
+ s = txn->Put(Slice("foo"), Slice("bar"));
+ ASSERT_OK(s);
+
+ // prepare
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+
+ // commit old txn
+ txn = db->GetTransactionByName("a");
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ delete txn;
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ s = txn->SetName("b");
+ ASSERT_OK(s);
+
+ s = txn->Put(Slice("foo2"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+
+ // value is now available
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar");
+
+ s = db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(s, Status::OK());
+ ASSERT_EQ(value, "bar2");
+}
+
+TEST_P(TransactionTest, TwoPhaseLogRollingTest) {
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ Status s;
+ std::string v;
+ ColumnFamilyHandle *cfa, *cfb;
+
+ // Create 2 new column families
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ WriteOptions wopts;
+ wopts.disableWAL = false;
+ wopts.sync = true;
+
+ TransactionOptions topts1;
+ Transaction* txn1 = db->BeginTransaction(wopts, topts1);
+ s = txn1->SetName("xid1");
+ ASSERT_OK(s);
+
+ TransactionOptions topts2;
+ Transaction* txn2 = db->BeginTransaction(wopts, topts2);
+ s = txn2->SetName("xid2");
+ ASSERT_OK(s);
+
+ // transaction put in two column families
+ s = txn1->Put(cfa, "ka1", "va1");
+ ASSERT_OK(s);
+
+ // transaction put in two column families
+ s = txn2->Put(cfa, "ka2", "va2");
+ ASSERT_OK(s);
+ s = txn2->Put(cfb, "kb2", "vb2");
+ ASSERT_OK(s);
+
+ // write prep section to wal
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ // our log should be in the heap
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_EQ(db_impl->TEST_LogfileNumber(), txn1->GetLastLogNumber());
+
+  // flush default cf to create new log
+ s = db->Put(wopts, "foo", "bar");
+ ASSERT_OK(s);
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // make sure we are on a new log
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn1->GetLastLogNumber());
+
+ // put txn2 prep section in this log
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+ ASSERT_EQ(db_impl->TEST_LogfileNumber(), txn2->GetLastLogNumber());
+
+ // heap should still see first log
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+
+ // commit txn1
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+  // heap should now show txn2's log
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn2->GetLogNumber());
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+      // we should see txn1's log referenced by the memtables
+ ASSERT_EQ(txn1->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // In these modes memtables do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+  // flush default cf to create new log
+ s = db->Put(wopts, "foo", "bar2");
+ ASSERT_OK(s);
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // make sure we are on a new log
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn2->GetLastLogNumber());
+
+ // commit txn2
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ // heap should not show any logs
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // should show the first txn log
+ ASSERT_EQ(txn1->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // In these modes memtables do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // flush only cfa memtable
+ s = db_impl->TEST_FlushMemTable(true, false, cfa);
+ ASSERT_OK(s);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+      // should now show txn2's log since only cfb's memtable still
+      // references a prep section
+ ASSERT_EQ(txn2->GetLogNumber(),
+ db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // In these modes memtables do not ref the prep sections
+ ASSERT_EQ(0, db_impl->TEST_FindMinPrepLogReferencedByMemTable());
+ break;
+ default:
+ assert(false);
+ }
+
+ // flush only cfb memtable
+ s = db_impl->TEST_FlushMemTable(true, false, cfb);
+ ASSERT_OK(s);
+
+  // should show no dependency on logs
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ delete txn1;
+ delete txn2;
+ delete cfa;
+ delete cfb;
+}
+
+TEST_P(TransactionTest, TwoPhaseLogRollingTest2) {
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ Status s;
+ ColumnFamilyHandle *cfa, *cfb;
+
+ ColumnFamilyOptions cf_options;
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ WriteOptions wopts;
+ wopts.disableWAL = false;
+ wopts.sync = true;
+
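+  // Grab the impl handles so the test can inspect per-CF log numbers and
+  // immutable memtable flush state below.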
+ auto cfh_a = reinterpret_cast<ColumnFamilyHandleImpl*>(cfa);
+ auto cfh_b = reinterpret_cast<ColumnFamilyHandleImpl*>(cfb);
+
+ TransactionOptions topts1;
+ Transaction* txn1 = db->BeginTransaction(wopts, topts1);
+ s = txn1->SetName("xid1");
+ ASSERT_OK(s);
+ s = txn1->Put(cfa, "boys", "girls1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(wopts, topts1);
+ s = txn2->SetName("xid2");
+ ASSERT_OK(s);
+ s = txn2->Put(cfb, "up", "down1");
+ ASSERT_OK(s);
+
+  // prepare transaction in LOG A
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+  // prepare transaction in LOG A
+ s = txn2->Prepare();
+ ASSERT_OK(s);
+
+ // regular put so that mem table can actually be flushed for log rolling
+ s = db->Put(wopts, "cats", "dogs1");
+ ASSERT_OK(s);
+
+ auto prepare_log_no = txn1->GetLastLogNumber();
+
+ // roll to LOG B
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // now we pause background work so that
+ // imm()s are not flushed before we can check their status
+ s = db_impl->PauseBackgroundWork();
+ ASSERT_OK(s);
+
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), prepare_log_no);
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ // This cf is empty and should ref the latest log
+ ASSERT_GT(cfh_a->cfd()->GetLogNumber(), prepare_log_no);
+ ASSERT_EQ(cfh_a->cfd()->GetLogNumber(), db_impl->TEST_LogfileNumber());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+ // This cf is not flushed yet and should ref the log that has its data
+ ASSERT_EQ(cfh_a->cfd()->GetLogNumber(), prepare_log_no);
+ break;
+ default:
+ assert(false);
+ }
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+
+ // commit in LOG B
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(),
+ prepare_log_no);
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // In these modes memtables do not ref the prep sections
+ ASSERT_EQ(db_impl->TEST_FindMinPrepLogReferencedByMemTable(), 0);
+ break;
+ default:
+ assert(false);
+ }
+
+ ASSERT_TRUE(!db_impl->TEST_UnableToReleaseOldestLog());
+
+ // request a flush for all column families such that the earliest
+ // alive log file can be killed
+ db_impl->TEST_SwitchWAL();
+  // log cannot be flushed because txn2 has not been committed
+ ASSERT_TRUE(!db_impl->TEST_IsLogGettingFlushed());
+ ASSERT_TRUE(db_impl->TEST_UnableToReleaseOldestLog());
+
+ // assert that cfa has a flush requested
+ ASSERT_TRUE(cfh_a->cfd()->imm()->HasFlushRequested());
+
+ switch (txn_db_options.write_policy) {
+ case WRITE_COMMITTED:
+      // cfb should not be flushed because it has no data from LOG A
+ ASSERT_TRUE(!cfh_b->cfd()->imm()->HasFlushRequested());
+ break;
+ case WRITE_PREPARED:
+ case WRITE_UNPREPARED:
+      // cfb should be flushed because it has prepared data from LOG A
+ ASSERT_TRUE(cfh_b->cfd()->imm()->HasFlushRequested());
+ break;
+ default:
+ assert(false);
+ }
+
+ // cfb now has data from LOG A
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ db_impl->TEST_SwitchWAL();
+ ASSERT_TRUE(!db_impl->TEST_UnableToReleaseOldestLog());
+
+ // we should see that cfb now has a flush requested
+ ASSERT_TRUE(cfh_b->cfd()->imm()->HasFlushRequested());
+
+ // all data in LOG A resides in a memtable that has been
+ // requested for a flush
+ ASSERT_TRUE(db_impl->TEST_IsLogGettingFlushed());
+
+ delete txn1;
+ delete txn2;
+ delete cfa;
+ delete cfb;
+}
+
+/*
+ * 1) use prepare to keep first log around to determine starting sequence
+ * during recovery.
+ * 2) insert many values, skipping wal, to increase seqid.
+ * 3) insert final value into wal
+ * 4) recover and see that final value was properly recovered - not
+ * hidden behind improperly summed sequence ids
+ */
+TEST_P(TransactionTest, TwoPhaseOutOfOrderDelete) {
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ WriteOptions wal_on, wal_off;
+ wal_on.sync = true;
+ wal_on.disableWAL = false;
+ wal_off.disableWAL = true;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(wal_on, txn_options);
+
+ s = txn1->SetName("1");
+ ASSERT_OK(s);
+
+ s = db->Put(wal_on, "first", "first");
+ ASSERT_OK(s);
+
+ s = txn1->Put(Slice("dummy"), Slice("dummy"));
+ ASSERT_OK(s);
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+
+ s = db->Put(wal_off, "cats", "dogs1");
+ ASSERT_OK(s);
+ s = db->Put(wal_off, "cats", "dogs2");
+ ASSERT_OK(s);
+ s = db->Put(wal_off, "cats", "dogs3");
+ ASSERT_OK(s);
+
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ s = db->Put(wal_on, "cats", "dogs4");
+ ASSERT_OK(s);
+
+ db->FlushWAL(false);
+
+ // kill and reopen
+ env->SetFilesystemActive(false);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+
+ s = db->Get(read_options, "first", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "first");
+
+ s = db->Get(read_options, "cats", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "dogs4");
+}
+
+TEST_P(TransactionTest, FirstWriteTest) {
+ WriteOptions write_options;
+
+ // Test conflict checking against the very first write to a db.
+ // The transaction's snapshot will have seq 1 and the following write
+ // will have sequence 1.
+ Status s = db->Put(write_options, "A", "a");
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "b");
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FirstWriteTest2) {
+ WriteOptions write_options;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ // Test conflict checking against the very first write to a db.
+  // The transaction's snapshot is at seq 0 while the following write
+ // will have sequence 1.
+ Status s = db->Put(write_options, "A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "b");
+ ASSERT_TRUE(s.IsBusy());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteOptionsTest) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = true;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_TRUE(txn->GetWriteOptions()->sync);
+
+ write_options.sync = false;
+ txn->SetWriteOptions(write_options);
+ ASSERT_FALSE(txn->GetWriteOptions()->sync);
+ ASSERT_TRUE(txn->GetWriteOptions()->disableWAL);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ db->Put(write_options, "foo", "A");
+ db->Put(write_options, "foo2", "B");
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("foo", "A2");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo2", "B2");
+ ASSERT_OK(s);
+
+ // This Put outside of a transaction will conflict with the previous write
+ s = db->Put(write_options, "foo", "xxx");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "A");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "A2");
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "B2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, WriteConflictTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ db->Put(write_options, "foo", "bar");
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ // This Put outside of a transaction will conflict with a later write
+ s = db->Put(write_options, "foo", "barz");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo2", "X");
+ ASSERT_OK(s);
+
+ s = txn->Put("foo",
+ "bar2"); // Conflicts with write done after snapshot taken
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Put("foo3", "Y");
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ ASSERT_EQ(2, txn->GetNumKeys());
+
+ s = txn->Commit();
+ ASSERT_OK(s); // Txn should commit, but only write foo2 and foo3
+
+ // Verify that transaction wrote foo2 and foo3 but not foo
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "barz");
+
+ db->Get(read_options, "foo2", &value);
+ ASSERT_EQ(value, "X");
+
+ db->Get(read_options, "foo3", &value);
+ ASSERT_EQ(value, "Y");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ReadConflictTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ db->Put(write_options, "foo", "bar");
+ db->Put(write_options, "foo2", "bar");
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ // This Put outside of a transaction will conflict with the previous read
+ s = db->Put(write_options, "foo", "barz");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, TxnOnlyTest) {
+ // Test to make sure transactions work when there are no other writes in an
+ // empty db.
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("x", "y");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FlushTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ db->Put(write_options, Slice("foo"), Slice("bar"));
+ db->Put(write_options, Slice("foo2"), Slice("bar"));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Put(Slice("foo"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ // Put a random key so we have a memtable to flush
+ s = db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ FlushOptions flush_ops;
+ db->Flush(flush_ops);
+
+ s = txn->Commit();
+ // txn should commit since the flushed table is still in MemtableList History
+ ASSERT_OK(s);
+
+ db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, FlushTest2) {
+ const size_t num_tests = 3;
+
+ for (size_t n = 0; n < num_tests; n++) {
+ // Test different table factories
+ switch (n) {
+ case 0:
+ break;
+ case 1:
+ options.table_factory.reset(new mock::MockTableFactory());
+ break;
+ case 2: {
+ PlainTableOptions pt_opts;
+ pt_opts.hash_table_ratio = 0;
+ options.table_factory.reset(NewPlainTableFactory(pt_opts));
+ break;
+ }
+ }
+
+ Status s = ReOpen();
+ ASSERT_OK(s);
+ assert(db != nullptr);
+
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ string value;
+
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+
+ db->Put(write_options, Slice("foo"), Slice("bar"));
+ db->Put(write_options, Slice("foo2"), Slice("bar2"));
+ db->Put(write_options, Slice("foo3"), Slice("bar3"));
+
+ txn_options.set_snapshot = true;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar");
+
+ s = txn->Put(Slice("foo"), Slice("bar2"));
+ ASSERT_OK(s);
+
+ txn->GetForUpdate(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+ // verify foo is locked by txn
+ s = db->Delete(write_options, "foo");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = db->Put(write_options, "Z", "z");
+ ASSERT_OK(s);
+ s = db->Put(write_options, "dummy", "dummy");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "S", "s");
+ ASSERT_OK(s);
+ s = db->SingleDelete(write_options, "S");
+ ASSERT_OK(s);
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in memtable
+ ASSERT_TRUE(s.IsBusy());
+
+ // force a memtable flush
+ s = db_impl->TEST_FlushMemTable(true);
+ ASSERT_OK(s);
+
+ // Put a random key so we have a MemTable to flush
+ s = db->Put(write_options, "dummy", "dummy2");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ s = db->Put(write_options, "dummy", "dummy3");
+ ASSERT_OK(s);
+
+ // force a memtable flush
+ // Since our test db has max_write_buffer_number=2, this flush will cause
+ // the first memtable to get purged from the MemtableList history.
+ ASSERT_OK(db_impl->TEST_FlushMemTable(true));
+
+ s = txn->Put("X", "Y");
+ // Should succeed after verifying there is no write to X in SST file
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "zz");
+ // Should fail after encountering a write to Z in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->GetForUpdate(read_options, "foo2", &value);
+ // should succeed since key was written before txn started
+ ASSERT_OK(s);
+ // verify foo2 is locked by txn
+ s = db->Delete(write_options, "foo2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ // Write a bunch of keys to db to force a compaction
+ Random rnd(47);
+ for (int i = 0; i < 1000; i++) {
+ s = db->Put(write_options, std::to_string(i),
+ test::CompressibleString(&rnd, 0.8, 100, &value));
+ ASSERT_OK(s);
+ }
+
+ s = txn->Put("X", "yy");
+ // Should succeed after verifying there is no write to X in SST file
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "zzz");
+ // Should fail after encountering a write to Z in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Delete("S");
+ // Should fail after encountering a write to S in SST file
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->GetForUpdate(read_options, "foo3", &value);
+ // should succeed since key was written before txn started
+ ASSERT_OK(s);
+ // verify foo3 is locked by txn
+ s = db->Delete(write_options, "foo3");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ db_impl->TEST_WaitForCompact();
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ // Transaction should only write the keys that succeeded.
+ s = db->Get(read_options, "foo", &value);
+ ASSERT_EQ(value, "bar2");
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("yy", value);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z", value);
+
+ delete txn;
+ }
+}
+
+TEST_P(TransactionTest, NoSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ db->Put(write_options, "AAA", "bar");
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Modify key after transaction start
+ db->Put(write_options, "AAA", "bar1");
+
+ // Read and write without a snap
+ txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("AAA", "bar2");
+ ASSERT_OK(s);
+
+ // Should commit since read/write was done after data changed
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_EQ(value, "bar2");
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, MultipleSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, "AAA", "bar"));
+ ASSERT_OK(db->Put(write_options, "BBB", "bar"));
+ ASSERT_OK(db->Put(write_options, "CCC", "bar"));
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ db->Put(write_options, "AAA", "bar1");
+
+ // Read and write without a snapshot
+ ASSERT_OK(txn->GetForUpdate(read_options, "AAA", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("AAA", "bar2");
+ ASSERT_OK(s);
+
+ // Modify BBB before snapshot is taken
+ ASSERT_OK(db->Put(write_options, "BBB", "bar1"));
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ ASSERT_OK(txn->GetForUpdate(snapshot_read_options, "BBB", &value));
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("BBB", "bar2");
+ ASSERT_OK(s);
+
+ ASSERT_OK(db->Put(write_options, "CCC", "bar1"));
+
+ // Set a new snapshot
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ // Read and write with snapshot
+ txn->GetForUpdate(snapshot_read_options, "CCC", &value);
+ ASSERT_EQ(value, "bar1");
+ s = txn->Put("CCC", "bar2");
+ ASSERT_OK(s);
+
+ s = txn->GetForUpdate(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = txn->GetForUpdate(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+ s = db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar1");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = db->Get(read_options, "BBB", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+ s = db->Get(read_options, "CCC", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "bar2");
+
+ // verify that we track multiple writes to the same key at different snapshots
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+
+ // Potentially conflicting writes
+ db->Put(write_options, "ZZZ", "zzz");
+ db->Put(write_options, "XXX", "xxx");
+
+ txn->SetSnapshot();
+
+ TransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ txn2->SetSnapshot();
+
+ // This should not conflict in txn since the snapshot is later than the
+ // previous write (spoiler alert: it will later conflict with txn2).
+ s = txn->Put("ZZZ", "zzzz");
+ ASSERT_OK(s);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+
+ // This will conflict since the snapshot is earlier than another write to ZZZ
+ s = txn2->Put("ZZZ", "xxxxx");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "ZZZ", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ delete txn2;
+}
+
+TEST_P(TransactionTest, ColumnFamiliesTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ string value;
+ Status s;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ delete cfa;
+ delete cfb;
+ delete db;
+ db = nullptr;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // Write some data to the db
+ WriteBatch batch;
+ batch.Put("foo", "foo");
+ batch.Put(handles[1], "AAA", "bar");
+ batch.Put(handles[1], "AAAZZZ", "bar");
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+ db->Delete(write_options, handles[1], "AAAZZZ");
+
+ // These keys do not conflict with existing writes since they're in
+ // different column families
+ s = txn->Delete("AAA");
+ ASSERT_OK(s);
+ s = txn->GetForUpdate(snapshot_read_options, handles[1], "foo", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ Slice key_slice("AAAZZZ");
+ Slice value_slices[2] = {Slice("bar"), Slice("bar")};
+ s = txn->Put(handles[2], SliceParts(&key_slice, 1),
+ SliceParts(value_slices, 2));
+ ASSERT_OK(s);
+ ASSERT_EQ(3, txn->GetNumKeys());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, "AAA", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(read_options, handles[2], "AAAZZZ", &value);
+ ASSERT_EQ(value, "barbar");
+
+ Slice key_slices[3] = {Slice("AAA"), Slice("ZZ"), Slice("Z")};
+ Slice value_slice("barbarbar");
+
+ s = txn2->Delete(handles[2], "XXX");
+ ASSERT_OK(s);
+ s = txn2->Delete(handles[1], "XXX");
+ ASSERT_OK(s);
+
+ // This write will cause a conflict with the earlier batch write
+ s = txn2->Put(handles[1], SliceParts(key_slices, 3),
+ SliceParts(&value_slice, 1));
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+  // In the above the latest change to AAAZZZ in handles[1] is a delete.
+ s = db->Get(read_options, handles[1], "AAAZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ delete txn2;
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
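+  // MultiGetForUpdate across several column families; the last lookup asks
+  // handles[2] for "foo", which was only ever written to the default CF.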
+ std::vector<ColumnFamilyHandle*> multiget_cfh = {handles[1], handles[2],
+ handles[0], handles[2]};
+ std::vector<Slice> multiget_keys = {"AAA", "AAAZZZ", "foo", "foo"};
+ std::vector<std::string> values(4);
+ std::vector<Status> results = txn->MultiGetForUpdate(
+ snapshot_read_options, multiget_cfh, multiget_keys, &values);
+ ASSERT_OK(results[0]);
+ ASSERT_OK(results[1]);
+ ASSERT_OK(results[2]);
+ ASSERT_TRUE(results[3].IsNotFound());
+ ASSERT_EQ(values[0], "bar");
+ ASSERT_EQ(values[1], "barbar");
+ ASSERT_EQ(values[2], "foo");
+
+ s = txn->SingleDelete(handles[2], "ZZZ");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "ZZZ", "YYY");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "ZZZ", "YYYY");
+ ASSERT_OK(s);
+ s = txn->Delete(handles[2], "ZZZ");
+ ASSERT_OK(s);
+ s = txn->Put(handles[2], "AAAZZZ", "barbarbar");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumKeys());
+
+ // Txn should commit
+ s = txn->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, handles[2], "ZZZ", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Put a key which will conflict with the next txn using the previous snapshot
+ db->Put(write_options, handles[2], "foo", "000");
+
+ results = txn2->MultiGetForUpdate(snapshot_read_options, multiget_cfh,
+ multiget_keys, &values);
+ // All results should fail since there was a conflict
+ ASSERT_TRUE(results[0].IsBusy());
+ ASSERT_TRUE(results[1].IsBusy());
+ ASSERT_TRUE(results[2].IsBusy());
+ ASSERT_TRUE(results[3].IsBusy());
+
+ s = db->Get(read_options, handles[2], "foo", &value);
+ ASSERT_EQ(value, "000");
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->DropColumnFamily(handles[1]);
+ ASSERT_OK(s);
+ s = db->DropColumnFamily(handles[2]);
+ ASSERT_OK(s);
+
+ delete txn;
+ delete txn2;
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionTest, MultiGetBatchedTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ string value;
+ Status s;
+
+ ColumnFamilyHandle* cf;
+ ColumnFamilyOptions cf_options;
+
+  // Create a new column family
+ s = db->CreateColumnFamily(cf_options, "CF", &cf);
+ ASSERT_OK(s);
+
+ delete cf;
+ delete db;
+ db = nullptr;
+
+  // open DB with two column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ column_families.push_back(ColumnFamilyDescriptor("CF", cf_options));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ // Write some data to the db
+ WriteBatch batch;
+ batch.Put(handles[1], "aaa", "val1");
+ batch.Put(handles[1], "bbb", "val2");
+ batch.Put(handles[1], "ccc", "val3");
+ batch.Put(handles[1], "ddd", "foo");
+ batch.Put(handles[1], "eee", "val5");
+ batch.Put(handles[1], "fff", "val6");
+ batch.Merge(handles[1], "ggg", "foo");
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+
+ txn_options.set_snapshot = true;
+ // Write some data to the db
+ s = txn->Delete(handles[1], "bbb");
+ ASSERT_OK(s);
+ s = txn->Put(handles[1], "ccc", "val3_new");
+ ASSERT_OK(s);
+ s = txn->Merge(handles[1], "ddd", "bar");
+ ASSERT_OK(s);
+
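+  // Read back through the transaction: results should reflect the pending
+  // delete, overwrite and merge layered on top of the base data.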
+ std::vector<Slice> keys = {"aaa", "bbb", "ccc", "ddd", "eee", "fff", "ggg"};
+ std::vector<PinnableSlice> values(keys.size());
+ std::vector<Status> statuses(keys.size());
+
+ txn->MultiGet(snapshot_read_options, handles[1], keys.size(), keys.data(),
+ values.data(), statuses.data());
+ ASSERT_TRUE(statuses[0].ok());
+ ASSERT_EQ(values[0], "val1");
+ ASSERT_TRUE(statuses[1].IsNotFound());
+ ASSERT_TRUE(statuses[2].ok());
+ ASSERT_EQ(values[2], "val3_new");
+ ASSERT_TRUE(statuses[3].IsMergeInProgress());
+ ASSERT_TRUE(statuses[4].ok());
+ ASSERT_EQ(values[4], "val5");
+ ASSERT_TRUE(statuses[5].ok());
+ ASSERT_EQ(values[5], "val6");
+ ASSERT_TRUE(statuses[6].ok());
+ ASSERT_EQ(values[6], "foo");
+ delete txn;
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+// This test calls WriteBatchWithIndex::MultiGetFromBatchAndDB with a large
+// number of keys, i.e. greater than MultiGetContext::MAX_BATCH_SIZE, which
+// is 32. This forces autovector allocations in the MultiGet code paths to
+// use std::vector in addition to stack allocations. The MultiGet keys
+// include Merges, which are handled specially in MultiGetFromBatchAndDB by
+// allocating an autovector of MergeContexts.
+TEST_P(TransactionTest, MultiGetLargeBatchedTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ string value;
+ Status s;
+
+ ColumnFamilyHandle* cf;
+ ColumnFamilyOptions cf_options;
+
+ std::vector<std::string> key_str;
+ for (int i = 0; i < 100; ++i) {
+ key_str.emplace_back(std::to_string(i));
+ }
+  // Create a new column family
+ s = db->CreateColumnFamily(cf_options, "CF", &cf);
+ ASSERT_OK(s);
+
+ delete cf;
+ delete db;
+ db = nullptr;
+
+  // open DB with two column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ column_families.push_back(ColumnFamilyDescriptor("CF", cf_options));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+
+ // Write some data to the db
+ WriteBatch batch;
+ for (int i = 0; i < 3 * MultiGetContext::MAX_BATCH_SIZE; ++i) {
+ std::string val = "val" + std::to_string(i);
+ batch.Put(handles[1], key_str[i], val);
+ }
+ s = db->Write(write_options, &batch);
+ ASSERT_OK(s);
+
+ WriteBatchWithIndex wb;
+ // Write some data to the db
+ s = wb.Delete(handles[1], std::to_string(1));
+ ASSERT_OK(s);
+ s = wb.Put(handles[1], std::to_string(2), "new_val" + std::to_string(2));
+ ASSERT_OK(s);
+ // Write a lot of merges so when we call MultiGetFromBatchAndDB later on,
+ // it is forced to use std::vector in ROCKSDB_NAMESPACE::autovector to
+ // allocate MergeContexts. The number of merges needs to be >
+ // MultiGetContext::MAX_BATCH_SIZE
+ for (int i = 8; i < MultiGetContext::MAX_BATCH_SIZE + 24; ++i) {
+ s = wb.Merge(handles[1], std::to_string(i), "merge");
+ ASSERT_OK(s);
+ }
+
+ // MultiGet a lot of keys in order to force std::vector reallocations
+ std::vector<Slice> keys;
+ for (int i = 0; i < MultiGetContext::MAX_BATCH_SIZE + 32; ++i) {
+ keys.emplace_back(key_str[i]);
+ }
+ std::vector<PinnableSlice> values(keys.size());
+ std::vector<Status> statuses(keys.size());
+
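+  // Resolve the oversized MultiGet through the batch so overlay entries and
+  // pending merges are combined with the values already in the DB.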
+  wb.MultiGetFromBatchAndDB(db, snapshot_read_options, handles[1], keys.size(),
+                            keys.data(), values.data(), statuses.data(),
+                            false);
+  for (size_t i = 0; i < keys.size(); ++i) {
+ if (i == 1) {
+ ASSERT_TRUE(statuses[1].IsNotFound());
+ } else if (i == 2) {
+ ASSERT_TRUE(statuses[2].ok());
+ ASSERT_EQ(values[2], "new_val" + std::to_string(2));
+ } else if (i >= 8 && i < 56) {
+ ASSERT_TRUE(statuses[i].ok());
+ ASSERT_EQ(values[i], "val" + std::to_string(i) + ",merge");
+ } else {
+ ASSERT_TRUE(statuses[i].ok());
+ if (values[i] != "val" + std::to_string(i)) {
+ ASSERT_EQ(values[i], "val" + std::to_string(i));
+ }
+ }
+ }
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionTest, ColumnFamiliesTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ string value;
+ Status s;
+
+ ColumnFamilyHandle *one, *two;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "ONE", &one);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "TWO", &two);
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn1->Put(one, "X", "1");
+ ASSERT_OK(s);
+ s = txn1->Put(two, "X", "2");
+ ASSERT_OK(s);
+ s = txn1->Put("X", "0");
+ ASSERT_OK(s);
+
+ s = txn2->Put(one, "X", "11");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ // Drop first column family
+ s = db->DropColumnFamily(one);
+ ASSERT_OK(s);
+
+  // txn2's only write already failed with a timeout, so its commit succeeds
+  // even though the column family was dropped.
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ txn1 = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn1);
+
+ // Should fail since column family was dropped
+ s = txn1->Put(one, "X", "111");
+ ASSERT_TRUE(s.IsInvalidArgument());
+
+ s = txn1->Put(two, "X", "222");
+ ASSERT_OK(s);
+
+ s = txn1->Put("X", "000");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, two, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("222", value);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("000", value);
+
+ s = db->DropColumnFamily(two);
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ delete one;
+ delete two;
+}
+
+TEST_P(TransactionTest, EmptyTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ string value;
+ Status s;
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ txn->Rollback();
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = txn->GetForUpdate(read_options, "aaa", &value);
+ ASSERT_EQ(value, "aaa");
+
+ // Conflicts with previous GetForUpdate
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_TRUE(s.IsTimedOut());
+
+  // Commit should succeed; the conflicting Put above never went through
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+}
+
+TEST_P(TransactionTest, PredicateManyPreceders) {
+ WriteOptions write_options;
+ ReadOptions read_options1, read_options2;
+ TransactionOptions txn_options;
+ string value;
+ Status s;
+
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+ read_options2.snapshot = txn2->GetSnapshot();
+
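+  // txn1 takes locks on keys 1-3 via MultiGetForUpdate even though none of
+  // them exist yet; txn2's write to "2" below must therefore time out.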
+ std::vector<Slice> multiget_keys = {"1", "2", "3"};
+ std::vector<std::string> multiget_values;
+
+ std::vector<Status> results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[1].IsNotFound());
+
+ s = txn2->Put("2", "x"); // Conflict's with txn1's MultiGetForUpdate
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn2->Rollback();
+
+ multiget_values.clear();
+ results =
+ txn1->MultiGetForUpdate(read_options1, multiget_keys, &multiget_values);
+ ASSERT_TRUE(results[1].IsNotFound());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("4", "x");
+ ASSERT_OK(s);
+
+ s = txn2->Delete("4"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->GetForUpdate(read_options2, "4", &value);
+ ASSERT_TRUE(s.IsBusy());
+
+ txn2->Rollback();
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, LostUpdate) {
+ WriteOptions write_options;
+ ReadOptions read_options, read_options1, read_options2;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ // Test 2 transactions writing to the same key in multiple orders and
+ // with/without snapshots
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ s = txn1->Put("1", "1");
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "2"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("1", value);
+
+ delete txn1;
+ delete txn2;
+
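+  // Same race, but with snapshots taken at BeginTransaction; txn2's write
+  // still times out on txn1's lock and txn1's value wins.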
+ txn_options.set_snapshot = true;
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "3");
+ ASSERT_OK(s);
+ s = txn2->Put("1", "4"); // conflict
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ delete txn1;
+ delete txn2;
+
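+  // txn1 commits first this time; txn2's write under its older snapshot
+  // fails with Busy, so its commit writes nothing and txn1's value remains.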
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "5");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "6");
+ ASSERT_TRUE(s.IsBusy());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ delete txn1;
+ delete txn2;
+
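+  // Here txn2 refreshes its snapshot after txn1 commits, so its write no
+  // longer conflicts and overwrites txn1's value.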
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ read_options1.snapshot = txn1->GetSnapshot();
+
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ read_options2.snapshot = txn2->GetSnapshot();
+
+ s = txn1->Put("1", "7");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn2->SetSnapshot();
+ s = txn2->Put("1", "8");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("8", value);
+
+ delete txn1;
+ delete txn2;
+
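+  // Without snapshots the transactions are simply serialized; txn1 commits
+  // before txn2 writes, so txn2 sees no conflict and its value wins.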
+ txn1 = db->BeginTransaction(write_options);
+ txn2 = db->BeginTransaction(write_options);
+
+ s = txn1->Put("1", "9");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("1", "10");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+
+ s = db->Get(read_options, "1", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "10");
+}
+
+TEST_P(TransactionTest, UntrackedWrites) {
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ // TODO(lth): For WriteUnprepared, validate that untracked writes are
+ // not supported.
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ // Verify transaction rollback works for untracked keys.
+ Transaction* txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = txn->PutUntracked("untracked", "0");
+ ASSERT_OK(s);
+ txn->Rollback();
+ s = db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+ txn->SetSnapshot();
+
+ s = db->Put(write_options, "untracked", "x");
+ ASSERT_OK(s);
+
+ // Untracked writes should succeed even though key was written after snapshot
+ s = txn->PutUntracked("untracked", "1");
+ ASSERT_OK(s);
+ s = txn->MergeUntracked("untracked", "2");
+ ASSERT_OK(s);
+ s = txn->DeleteUntracked("untracked");
+ ASSERT_OK(s);
+
+ // Conflict
+ s = txn->Put("untracked", "3");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "untracked", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ExpiredTransaction) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ string value;
+ Status s;
+
+  // Set txn expiration timeout to 0 milliseconds (expires instantly)
+ txn_options.expiration = 0;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+
+ s = txn1->Put("Y", "1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ // txn2 should be able to write to X since txn1 has expired
+ s = txn2->Put("X", "2");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("2", value);
+
+ s = txn1->Put("Z", "1");
+ ASSERT_OK(s);
+
+ // txn1 should fail to commit since it is expired
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsExpired());
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, ReinitializeTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ // Set txn expiration timeout to 0 microseconds (expires instantly)
+ txn_options.expiration = 0;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ // Reinitialize the transaction so that it no longer expires
+ txn_options.expiration = -1;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->Put("Z", "z");
+ ASSERT_OK(s);
+
+ // Should commit since not expired
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->Put("Z", "zz");
+ ASSERT_OK(s);
+
+ // Reinitialize txn1 and verify that Z gets unlocked
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options, nullptr);
+ s = txn2->Put("Z", "zzz");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzz");
+
+ // Verify snapshots get reinitialized correctly
+ txn1->SetSnapshot();
+ s = txn1->Put("Z", "zzzz");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ const Snapshot* snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ txn_options.set_snapshot = true;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ snapshot = txn1->GetSnapshot();
+ ASSERT_TRUE(snapshot);
+
+ s = txn1->Put("Z", "a");
+ ASSERT_OK(s);
+
+ txn1->Rollback();
+
+ s = txn1->Put("Y", "y");
+ ASSERT_OK(s);
+
+ txn_options.set_snapshot = false;
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+ snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ s = txn1->Put("X", "x");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(value, "zzzz");
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->SetName("name");
+ ASSERT_OK(s);
+
+ s = txn1->Prepare();
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options, txn1);
+
+ s = txn1->SetName("name");
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, Rollback) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ ASSERT_OK(s);
+
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ // txn2 should not be able to write to X since txn1 has it locked
+ s = txn2->Put("X", "2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->Rollback();
+ delete txn1;
+
+ // txn2 should now be able to write to X
+ s = txn2->Put("X", "3");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ delete txn2;
+}
+
+TEST_P(TransactionTest, LockLimitTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ delete db;
+ db = nullptr;
+
+ // Open DB with a lock limit of 3
+ txn_db_options.max_num_locks = 3;
+ ASSERT_OK(ReOpen());
+ assert(db != nullptr);
+ ASSERT_OK(s);
+
+ // Create a txn and verify we can only lock up to 3 keys
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("X", "x");
+ ASSERT_OK(s);
+
+ s = txn->Put("Y", "y");
+ ASSERT_OK(s);
+
+ s = txn->Put("Z", "z");
+ ASSERT_OK(s);
+
+ // lock limit reached
+ s = txn->Put("W", "w");
+ ASSERT_TRUE(s.IsBusy());
+
+ // re-locking same key shouldn't put us over the limit
+ s = txn->Put("X", "xx");
+ ASSERT_OK(s);
+
+ s = txn->GetForUpdate(read_options, "W", &value);
+ ASSERT_TRUE(s.IsBusy());
+ s = txn->GetForUpdate(read_options, "V", &value);
+ ASSERT_TRUE(s.IsBusy());
+
+ // re-locking same key shouldn't put us over the limit
+ s = txn->GetForUpdate(read_options, "Y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("y", value);
+
+ s = txn->Get(read_options, "W", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // "X" currently locked
+ s = txn2->Put("X", "x");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ // lock limit reached
+ s = txn2->Put("M", "m");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xx", value);
+
+ s = db->Get(read_options, "W", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Committing txn should release its locks and allow txn2 to proceed
+ s = txn2->Put("X", "x2");
+ ASSERT_OK(s);
+
+ s = txn2->Delete("X");
+ ASSERT_OK(s);
+
+ s = txn2->Put("M", "m");
+ ASSERT_OK(s);
+
+ s = txn2->Put("Z", "z2");
+ ASSERT_OK(s);
+
+ // lock limit reached
+ s = txn2->Delete("Y");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "Z", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z2", value);
+
+ s = db->Get(read_options, "Y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("y", value);
+
+ s = db->Get(read_options, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, IteratorTest) {
+ // This test does writes without snapshot validation, and then tries to
+ // create an iterator later, which is unsupported in write unprepared.
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ // Write some keys to the db
+ s = db->Put(write_options, "A", "a");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "G", "g");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "F", "f");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "c");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "D", "d");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ // Write some keys in a txn
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Put("H", "h");
+ ASSERT_OK(s);
+
+ s = txn->Delete("D");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ txn->SetSnapshot();
+ const Snapshot* snapshot = txn->GetSnapshot();
+
+ // Write some keys to the db after the snapshot
+ s = db->Put(write_options, "BB", "xx");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "xx");
+ ASSERT_OK(s);
+
+ read_options.snapshot = snapshot;
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+ iter->SeekToFirst();
+
+ // Read all keys via iter and lock them all
+ std::string results[] = {"a", "b", "c", "e", "f", "g", "h"};
+ for (int i = 0; i < 7; i++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(results[i], iter->value().ToString());
+
+ s = txn->GetForUpdate(read_options, iter->key(), nullptr);
+ if (i == 2) {
+ // "C" was modified after txn's snapshot
+ ASSERT_TRUE(s.IsBusy());
+ } else {
+ ASSERT_OK(s);
+ }
+
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+
+ iter->Seek("G");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("g", iter->value().ToString());
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("f", iter->value().ToString());
+
+ iter->Seek("D");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("C");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("c", iter->value().ToString());
+
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("e", iter->value().ToString());
+
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a", iter->value().ToString());
+
+ iter->Seek("X");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("h", iter->value().ToString());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(TransactionTest, DisableIndexingTest) {
+ // Skip this test for write unprepared. It does not solely rely on WBWI for
+ // read-your-own-writes, so depending on whether batches are flushed or not,
+ // only some writes will be visible.
+ //
+ // Also, write unprepared does not support creating iterators if there has
+ // been txn->Put() without snapshot validation.
+ if (txn_db_options.write_policy == WRITE_UNPREPARED) {
+ return;
+ }
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn->DisableIndexing();
+
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ Iterator* iter = txn->GetIterator(read_options);
+ ASSERT_OK(iter->status());
+
+ iter->Seek("B");
+ ASSERT_OK(iter->status());
+ ASSERT_FALSE(iter->Valid());
+
+ s = txn->Delete("A");
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn->EnableIndexing();
+
+ s = txn->Put("B", "bb");
+ ASSERT_OK(s);
+
+ iter->Seek("B");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("bb", iter->value().ToString());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bb", value);
+
+ s = txn->Put("A", "aa");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ delete iter;
+ delete txn;
+}
+
+TEST_P(TransactionTest, SavepointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn->SetSavePoint(); // 1
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("B", "b");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ delete txn;
+ txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("B", "bb");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "c");
+ ASSERT_OK(s);
+
+ txn->SetSavePoint(); // 2
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Put("C", "cc");
+ ASSERT_OK(s);
+
+ s = txn->Put("D", "d");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumPuts());
+ ASSERT_EQ(1, txn->GetNumDeletes());
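+ // 5 puts so far in this txn: A, B, C before savepoint 2 plus C and D after
+ // it; the single delete is B.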
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 2
+
+ ASSERT_EQ(3, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("bb", value);
+
+ s = txn->Get(read_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("E", "e");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(5, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ // Rollback to beginning of txn
+ s = txn->RollbackToSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+ txn->Rollback();
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Put("A", "aa");
+ ASSERT_OK(s);
+
+ s = txn->Put("F", "f");
+ ASSERT_OK(s);
+
+ ASSERT_EQ(2, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ txn->SetSavePoint(); // 3
+ txn->SetSavePoint(); // 4
+
+ s = txn->Put("G", "g");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("F");
+ ASSERT_OK(s);
+
+ s = txn->Delete("B");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = txn->Get(read_options, "F", &value);
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefined behavior. So it is unclear what the
+ // result of fetching "F" should be. The current implementation will
+ // return NotFound in this case.
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ ASSERT_EQ(3, txn->GetNumPuts());
+ ASSERT_EQ(2, txn->GetNumDeletes());
+
+ ASSERT_OK(txn->RollbackToSavePoint()); // Rollback to 3
+
+ ASSERT_EQ(2, txn->GetNumPuts());
+ ASSERT_EQ(0, txn->GetNumDeletes());
+
+ s = txn->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = txn->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "F", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("f", value);
+
+ s = db->Get(read_options, "G", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("aa", value);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "E", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, SavepointTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->Put("A", "");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 1
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn1->Put("C", "c");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 2
+
+ // Verify that "A" and "C" is still locked while "B" is not
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_OK(s);
+
+ s = txn1->Put("A", "aa");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "bb");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->Put("A", "aaa");
+ ASSERT_OK(s);
+ s = txn1->Put("B", "bbb");
+ ASSERT_OK(s);
+ s = txn1->Put("C", "ccc");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 3
+
+ // Verify that "A", "B", "C" are still locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 1
+
+ // Verify that only "A" is locked
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c3po");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ // Verify "A" "C" "B" are no longer locked
+ s = txn2->Put("A", "a4");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b4");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c4");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SavepointTest3) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->PopSavePoint(); // No SavePoint present
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("A", "");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Still no SavePoint present
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn1->SetSavePoint(); // 1
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Remove 1
+ ASSERT_TRUE(txn1->RollbackToSavePoint().IsNotFound());
+
+ // Verify that "A" is still locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ delete txn2;
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+
+ s = txn1->Put("B", "b2");
+ ASSERT_OK(s);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Roll back to 2
+
+ s = txn1->PopSavePoint();
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint();
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ std::string value;
+
+ // txn1 should have modified "A" to "a"
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ // txn1 should have set "B" to just "b"
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = db->Get(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(TransactionTest, SavepointTest4) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ txn1->SetSavePoint(); // 1
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+ s = txn1->Put("B", "b");
+ ASSERT_OK(s);
+
+ s = txn1->PopSavePoint(); // Remove 2
+ ASSERT_OK(s);
+
+ // Verify that A/B still exists.
+ std::string value;
+ ASSERT_OK(txn1->Get(read_options, "A", &value));
+ ASSERT_EQ("a", value);
+
+ ASSERT_OK(txn1->Get(read_options, "B", &value));
+ ASSERT_EQ("b", value);
+
+ ASSERT_OK(txn1->RollbackToSavePoint()); // Rollback to 1
+
+ // Verify that everything was rolled back.
+ s = txn1->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Nothing should be locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ s = txn2->Put("A", "");
+ ASSERT_OK(s);
+
+ s = txn2->Put("B", "");
+ ASSERT_OK(s);
+
+ delete txn2;
+ delete txn1;
+}
+
+TEST_P(TransactionTest, UndoGetForUpdateTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ txn1->UndoGetForUpdate("A");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ txn1->UndoGetForUpdate("A");
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Verify that A is locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->UndoGetForUpdate("A");
+
+ // Verify that A is now unlocked
+ s = txn2->Put("A", "a2");
+ ASSERT_OK(s);
+ txn2->Commit();
+ delete txn2;
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a2", value);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+
+ // Verify that A and B are still locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a4");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b4");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->Rollback();
+ delete txn1;
+
+ // Verify that A and B are no longer locked
+ s = txn2->Put("A", "a5");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b5");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ delete txn2;
+ ASSERT_OK(s);
+
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+ s = txn1->Put("B", "b5");
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify A,B,C are locked
+ txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("X", "x6");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify A,B are locked and C is not
+ s = txn2->Put("A", "a6");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c6");
+ ASSERT_OK(s);
+ s = txn2->Put("X", "x6");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("X");
+
+ // Verify B is locked and A and C are not
+ s = txn2->Put("A", "a7");
+ ASSERT_OK(s);
+ s = txn2->Delete("B");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c7");
+ ASSERT_OK(s);
+ s = txn2->Put("X", "x7");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+}
+
+TEST_P(TransactionTest, UndoGetForUpdateTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "A", "");
+ ASSERT_OK(s);
+
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn1);
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("F", "f");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 1
+
+ txn1->UndoGetForUpdate("A");
+
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->GetForUpdate(read_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->Put("E", "e");
+ ASSERT_OK(s);
+ s = txn1->GetForUpdate(read_options, "E", &value);
+ ASSERT_OK(s);
+
+ s = txn1->GetForUpdate(read_options, "F", &value);
+ ASSERT_OK(s);
+
+ // Verify A,B,C,D,E,F are still locked
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ s = txn2->Put("A", "a1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e1");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f1");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("E");
+
+ // Verify A,B,D,E,F are still locked and C is not.
+ s = txn2->Put("A", "a2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c2");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = txn1->Put("H", "h");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify A,B,D,E,F,H are still locked and C,G are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("H", "h3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+
+ txn1->RollbackToSavePoint(); // rollback to 2
+
+ // Verify A,B,D,E,F are still locked and C,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("D", "d3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify A,B,E,F are still locked and C,D,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("E", "e3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ txn1->RollbackToSavePoint(); // rollback to 1
+
+ // Verify A,B,F are still locked and C,D,E,G,H are not.
+ s = txn2->Put("A", "a3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("B", "b3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("E", "e3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ txn1->UndoGetForUpdate("A");
+ txn1->UndoGetForUpdate("B");
+ txn1->UndoGetForUpdate("C");
+ txn1->UndoGetForUpdate("D");
+ txn1->UndoGetForUpdate("E");
+ txn1->UndoGetForUpdate("F");
+ txn1->UndoGetForUpdate("G");
+ txn1->UndoGetForUpdate("H");
+
+ // Verify F is still locked and A,B,C,D,E,G,H are not.
+ s = txn2->Put("F", "f3");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Put("A", "a3");
+ ASSERT_OK(s);
+ s = txn2->Put("B", "b3");
+ ASSERT_OK(s);
+ s = txn2->Put("C", "c3");
+ ASSERT_OK(s);
+ s = txn2->Put("D", "d3");
+ ASSERT_OK(s);
+ s = txn2->Put("E", "e3");
+ ASSERT_OK(s);
+ s = txn2->Put("G", "g3");
+ ASSERT_OK(s);
+ s = txn2->Put("H", "h3");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, TimeoutTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ delete db;
+ db = nullptr;
+
+ // Transaction writes have an infinite lock timeout by default (we will
+ // override this per transaction when we begin one below); db writes also
+ // have an infinite timeout.
+ txn_db_options.transaction_lock_timeout = -1;
+ txn_db_options.default_lock_timeout = -1;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ assert(db != nullptr);
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ TransactionOptions txn_options0;
+ txn_options0.expiration = 100; // 100ms
+ txn_options0.lock_timeout = 50; // txn timeout no longer infinite
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options0);
+
+ s = txn1->GetForUpdate(read_options, "aaa", nullptr);
+ ASSERT_OK(s);
+
+ // Conflicts with previous GetForUpdate.
+ // Since db writes do not have a timeout, this should eventually succeed when
+ // the transaction expires.
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_OK(s);
+
+ ASSERT_GE(txn1->GetElapsedTime(),
+ static_cast<uint64_t>(txn_options0.expiration));
+
+ s = txn1->Commit();
+ ASSERT_TRUE(s.IsExpired()); // expired!
+
+ s = db->Get(read_options, "aaa", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xxx", value);
+
+ delete txn1;
+ delete db;
+
+ // transaction writes have a 50ms lock timeout,
+ // db writes have infinite timeout
+ txn_db_options.transaction_lock_timeout = 50;
+ txn_db_options.default_lock_timeout = -1;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "aaa", "aaa");
+ ASSERT_OK(s);
+
+ TransactionOptions txn_options;
+ txn_options.expiration = 100; // 100ms
+ txn1 = db->BeginTransaction(write_options, txn_options);
+
+ s = txn1->GetForUpdate(read_options, "aaa", nullptr);
+ ASSERT_OK(s);
+
+ // Conflicts with previous GetForUpdate.
+ // Since db writes do not have a timeout, this should eventually succeed when
+ // the transaction expires.
+ s = db->Put(write_options, "aaa", "xxx");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_NOK(s); // expired!
+
+ s = db->Get(read_options, "aaa", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("xxx", value);
+
+ delete txn1;
+ txn_options.expiration = 6000000; // 100 minutes
+ txn_options.lock_timeout = 1; // 1ms
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ txn1->SetLockTimeout(100);
+
+ TransactionOptions txn_options2;
+ txn_options2.expiration = 10; // 10ms
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options2);
+ ASSERT_OK(s);
+
+ s = txn2->Put("a", "2");
+ ASSERT_OK(s);
+
+ // txn1 has a lock timeout longer than txn2's expiration, so it will win
+ s = txn1->Delete("a");
+ ASSERT_OK(s);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ // txn2 should have expired, since txn1 waited past txn2's expiration while
+ // acquiring the lock on "a".
+ s = txn2->Commit();
+ ASSERT_TRUE(s.IsExpired());
+
+ delete txn1;
+ delete txn2;
+ txn_options.expiration = 6000000; // 100 minutes
+ txn1 = db->BeginTransaction(write_options, txn_options);
+ txn_options2.expiration = 100000000;
+ txn2 = db->BeginTransaction(write_options, txn_options2);
+
+ s = txn1->Delete("asdf");
+ ASSERT_OK(s);
+
+ // txn2 has a smaller lock timeout than txn1's expiration, so it will time out
+ s = txn2->Delete("asdf");
+ ASSERT_TRUE(s.IsTimedOut());
+ ASSERT_EQ(s.ToString(), "Operation timed out: Timeout waiting to lock key");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ s = txn2->Put("asdf", "asdf");
+ ASSERT_OK(s);
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+
+ s = db->Get(read_options, "asdf", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("asdf", value);
+
+ delete txn1;
+ delete txn2;
+}
+
+TEST_P(TransactionTest, SingleDeleteTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ txn = db->BeginTransaction(write_options);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ txn = db->BeginTransaction(write_options);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+ txn2->SetSnapshot();
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Put("A", "a2");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("A");
+ ASSERT_OK(s);
+
+ s = txn->SingleDelete("B");
+ ASSERT_OK(s);
+
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefined behavior. So it is unclear what the
+ // result of fetching "A" should be. The current implementation will
+ // return NotFound in this case.
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn2->Put("B", "b");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ // According to db.h, doing a SingleDelete on a key that has been
+ // overwritten will have undefined behavior. So it is unclear what the
+ // result of fetching "A" should be. The current implementation will
+ // return NotFound in this case.
+ s = db->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_P(TransactionTest, MergeTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(write_options, TransactionOptions());
+ ASSERT_TRUE(txn);
+
+ s = db->Put(write_options, "A", "a0");
+ ASSERT_OK(s);
+
+ s = txn->Merge("A", "1");
+ ASSERT_OK(s);
+
+ s = txn->Merge("A", "2");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ s = txn->Put("A", "a");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a", value);
+
+ s = txn->Merge("A", "3");
+ ASSERT_OK(s);
+
+ s = txn->Get(read_options, "A", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ TransactionOptions txn_options;
+ txn_options.lock_timeout = 1; // 1 ms
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2);
+
+ // verify that txn has "A" locked
+ s = txn2->Merge("A", "4");
+ ASSERT_TRUE(s.IsTimedOut());
+
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+ delete txn;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a,3", value);
+}
+
+TEST_P(TransactionTest, DeferSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "A", "a0");
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+ Transaction* txn2 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshotOnNextOperation();
+ auto snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot);
+
+ s = txn2->Put("A", "a2");
+ ASSERT_OK(s);
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+
+ s = txn1->GetForUpdate(read_options, "A", &value);
+ // Should not conflict with txn2 since snapshot wasn't set until
+ // GetForUpdate was called.
+ ASSERT_OK(s);
+ ASSERT_EQ("a2", value);
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "B", "b0");
+ ASSERT_OK(s);
+
+ // Cannot lock B since it was written after the snapshot was set
+ s = txn1->Put("B", "b1");
+ ASSERT_TRUE(s.IsBusy());
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+
+ s = db->Get(read_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a1", value);
+
+ s = db->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b0", value);
+}
+
+TEST_P(TransactionTest, DeferSnapshotTest2) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshot();
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "C", "c0");
+ ASSERT_OK(s);
+ s = db->Put(write_options, "D", "d0");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+
+ txn1->SetSnapshotOnNextOperation();
+
+ s = txn1->Get(snapshot_read_options, "C", &value);
+ // Snapshot was set before C was written
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ // Snapshot was set before D was written
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Snapshot should not have changed yet.
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+
+ s = txn1->Get(snapshot_read_options, "C", &value);
+ // Snapshot was set before C was written
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ // Snapshot was set before D was written
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = txn1->GetForUpdate(read_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c0", value);
+
+ s = db->Put(write_options, "D", "d00");
+ ASSERT_OK(s);
+
+ // Snapshot is now set
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "D", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("d0", value);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+ delete txn1;
+}
+
+TEST_P(TransactionTest, DeferSnapshotSavePointTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSavePoint(); // 1
+
+ s = db->Put(write_options, "T", "1");
+ ASSERT_OK(s);
+
+ txn1->SetSnapshotOnNextOperation();
+
+ s = db->Put(write_options, "T", "2");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 2
+
+ s = db->Put(write_options, "T", "3");
+ ASSERT_OK(s);
+
+ s = txn1->Put("A", "a");
+ ASSERT_OK(s);
+
+ txn1->SetSavePoint(); // 3
+
+ s = db->Put(write_options, "T", "4");
+ ASSERT_OK(s);
+
+ txn1->SetSnapshot();
+ txn1->SetSnapshotOnNextOperation();
+
+ txn1->SetSavePoint(); // 4
+
+ s = db->Put(write_options, "T", "5");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("4", value);
+
+ s = txn1->Put("A", "a1");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 4
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("4", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 3
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("3", value);
+
+ s = txn1->Get(read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 2
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_TRUE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->RollbackToSavePoint(); // Rollback to 1
+ ASSERT_OK(s);
+
+ s = txn1->Delete("A");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn1->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+ s = txn1->Get(snapshot_read_options, "T", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("5", value);
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, SetSnapshotOnNextOperationWithNotification) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ std::string value;
+
+ class Notifier : public TransactionNotifier {
+ private:
+ const Snapshot** snapshot_ptr_;
+
+ public:
+ explicit Notifier(const Snapshot** snapshot_ptr)
+ : snapshot_ptr_(snapshot_ptr) {}
+
+ void SnapshotCreated(const Snapshot* newSnapshot) override {
+ *snapshot_ptr_ = newSnapshot;
+ }
+ };
+
+ std::shared_ptr<Notifier> notifier =
+ std::make_shared<Notifier>(&read_options.snapshot);
+ Status s;
+
+ s = db->Put(write_options, "B", "0");
+ ASSERT_OK(s);
+
+ Transaction* txn1 = db->BeginTransaction(write_options);
+
+ txn1->SetSnapshotOnNextOperation(notifier);
+ ASSERT_FALSE(read_options.snapshot);
+
+ s = db->Put(write_options, "B", "1");
+ ASSERT_OK(s);
+
+ // A Get does not generate the snapshot
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_FALSE(read_options.snapshot);
+ ASSERT_EQ(value, "1");
+
+ // Any other operation does
+ s = txn1->Put("A", "0");
+ ASSERT_OK(s);
+
+ // Now change "B".
+ s = db->Put(write_options, "B", "2");
+ ASSERT_OK(s);
+
+ // The original value should still be read
+ s = txn1->Get(read_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_TRUE(read_options.snapshot);
+ ASSERT_EQ(value, "1");
+
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ delete txn1;
+}
+
+TEST_P(TransactionTest, ClearSnapshotTest) {
+ WriteOptions write_options;
+ ReadOptions read_options, snapshot_read_options;
+ std::string value;
+ Status s;
+
+ s = db->Put(write_options, "foo", "0");
+ ASSERT_OK(s);
+
+ Transaction* txn = db->BeginTransaction(write_options);
+ ASSERT_TRUE(txn);
+
+ s = db->Put(write_options, "foo", "1");
+ ASSERT_OK(s);
+
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+
+ // No snapshot created yet
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "1");
+
+ txn->SetSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_TRUE(snapshot_read_options.snapshot);
+
+ s = db->Put(write_options, "foo", "2");
+ ASSERT_OK(s);
+
+ // Snapshot was created before change to '2'
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "1");
+
+ txn->ClearSnapshot();
+ snapshot_read_options.snapshot = txn->GetSnapshot();
+ ASSERT_FALSE(snapshot_read_options.snapshot);
+
+ // Snapshot has now been cleared
+ s = txn->Get(snapshot_read_options, "foo", &value);
+ ASSERT_EQ(value, "2");
+
+ s = txn->Commit();
+ ASSERT_OK(s);
+
+ delete txn;
+}
+
+TEST_P(TransactionTest, ToggleAutoCompactionTest) {
+ Status s;
+
+ ColumnFamilyHandle *cfa, *cfb;
+ ColumnFamilyOptions cf_options;
+
+ // Create 2 new column families
+ s = db->CreateColumnFamily(cf_options, "CFA", &cfa);
+ ASSERT_OK(s);
+ s = db->CreateColumnFamily(cf_options, "CFB", &cfb);
+ ASSERT_OK(s);
+
+ delete cfa;
+ delete cfb;
+ delete db;
+
+ // open DB with three column families
+ std::vector<ColumnFamilyDescriptor> column_families;
+ // have to open default column family
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, ColumnFamilyOptions()));
+ // open the new column families
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFA", ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("CFB", ColumnFamilyOptions()));
+
+ ColumnFamilyOptions* cf_opt_default = &column_families[0].options;
+ ColumnFamilyOptions* cf_opt_cfa = &column_families[1].options;
+ ColumnFamilyOptions* cf_opt_cfb = &column_families[2].options;
+ cf_opt_default->disable_auto_compactions = false;
+ cf_opt_cfa->disable_auto_compactions = true;
+ cf_opt_cfb->disable_auto_compactions = false;
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ s = TransactionDB::Open(options, txn_db_options, dbname, column_families,
+ &handles, &db);
+ ASSERT_OK(s);
+
+ auto cfh_default = reinterpret_cast<ColumnFamilyHandleImpl*>(handles[0]);
+ auto opt_default = *cfh_default->cfd()->GetLatestMutableCFOptions();
+
+ auto cfh_a = reinterpret_cast<ColumnFamilyHandleImpl*>(handles[1]);
+ auto opt_a = *cfh_a->cfd()->GetLatestMutableCFOptions();
+
+ auto cfh_b = reinterpret_cast<ColumnFamilyHandleImpl*>(handles[2]);
+ auto opt_b = *cfh_b->cfd()->GetLatestMutableCFOptions();
+
+ ASSERT_EQ(opt_default.disable_auto_compactions, false);
+ ASSERT_EQ(opt_a.disable_auto_compactions, true);
+ ASSERT_EQ(opt_b.disable_auto_compactions, false);
+
+ for (auto handle : handles) {
+ delete handle;
+ }
+}
+
+TEST_P(TransactionStressTest, ExpiredTransactionDataRace1) {
+ // In this test, txn1 should commit successfully, as the callback is called
+ // after txn1 has already started committing.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"TransactionTest::ExpirableTransactionDataRace:1"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "TransactionTest::ExpirableTransactionDataRace:1", [&](void* /*arg*/) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ // Force txn1 to expire
+ /* sleep override */
+ std::this_thread::sleep_for(std::chrono::milliseconds(150));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ Status s;
+ s = txn2->Put("X", "2");
+ ASSERT_TRUE(s.IsTimedOut());
+ s = txn2->Commit();
+ ASSERT_OK(s);
+ delete txn2;
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+
+ txn_options.expiration = 100;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+
+ Status s;
+ s = txn1->Put("X", "1");
+ ASSERT_OK(s);
+ s = txn1->Commit();
+ ASSERT_OK(s);
+
+ ReadOptions read_options;
+ std::string value;
+ s = db->Get(read_options, "X", &value);
+ ASSERT_EQ("1", value);
+
+ delete txn1;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+namespace {
+// cmt_delay_ms is the delay between prepare and commit
+// first_id is the id of the first transaction
+Status TransactionStressTestInserter(
+ TransactionDB* db, const size_t num_transactions, const size_t num_sets,
+ const size_t num_keys_per_set, Random64* rand,
+ const uint64_t cmt_delay_ms = 0, const uint64_t first_id = 0) {
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ if (rand->OneIn(2)) {
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = true;
+ }
+ // Inside the inserter we might also retake the snapshot. We exercise both
+ // paths since a separate function is engaged for each.
+ txn_options.set_snapshot = rand->OneIn(2);
+
+ RandomTransactionInserter inserter(
+ rand, write_options, read_options, num_keys_per_set,
+ static_cast<uint16_t>(num_sets), cmt_delay_ms, first_id);
+
+ for (size_t t = 0; t < num_transactions; t++) {
+ bool success = inserter.TransactionDBInsert(db, txn_options);
+ if (!success) {
+ // unexpected failure
+ return inserter.GetLastStatus();
+ }
+ }
+
+ // Make sure at least some of the transactions succeeded. It's ok if
+ // some failed due to write-conflicts.
+ if (num_transactions != 1 &&
+ inserter.GetFailureCount() > num_transactions / 2) {
+ return Status::TryAgain("Too many transactions failed! " +
+ std::to_string(inserter.GetFailureCount()) + " / " +
+ std::to_string(num_transactions));
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+// Worker threads add a number to a key from each set of keys. The checker
+// threads verify that the sums of the keys in each set are equal.
+TEST_P(MySQLStyleTransactionTest, TransactionStressTest) {
+ // Small write buffer to trigger more compactions
+ options.write_buffer_size = 1024;
+ ReOpenNoDelete();
+ const size_t num_workers = 4; // worker threads count
+ const size_t num_checkers = 2; // checker threads count
+ const size_t num_slow_checkers = 2; // checker threads emulating backups
+ const size_t num_slow_workers = 1; // slow worker threads count
+ const size_t num_transactions_per_thread = 10000;
+ const uint16_t num_sets = 3;
+ const size_t num_keys_per_set = 100;
+ // Setting the key-space to be 100 keys should cause enough write-conflicts
+ // to make this test interesting.
+
+ std::vector<port::Thread> threads;
+ std::atomic<uint32_t> finished = {0};
+ bool TAKE_SNAPSHOT = true;
+ uint64_t time_seed = env->NowMicros();
+ printf("time_seed is %" PRIu64 "\n", time_seed); // would help to reproduce
+
+ std::function<void()> call_inserter = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ ASSERT_OK(TransactionStressTestInserter(db, num_transactions_per_thread,
+ num_sets, num_keys_per_set, &rand));
+ finished++;
+ };
+ std::function<void()> call_checker = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ Status s = RandomTransactionInserter::Verify(
+ db, num_sets, num_keys_per_set, TAKE_SNAPSHOT, &rand);
+ ASSERT_OK(s);
+ }
+ };
+ std::function<void()> call_slow_checker = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ uint64_t delay_ms = rand.Uniform(100) + 1;
+ Status s = RandomTransactionInserter::Verify(
+ db, num_sets, num_keys_per_set, TAKE_SNAPSHOT, &rand, delay_ms);
+ ASSERT_OK(s);
+ }
+ };
+ std::function<void()> call_slow_inserter = [&] {
+ size_t thd_seed = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rand(time_seed * thd_seed);
+ uint64_t id = 0;
+ // Verify that data is consistent
+ while (finished < num_workers) {
+ uint64_t delay_ms = rand.Uniform(500) + 1;
+ ASSERT_OK(TransactionStressTestInserter(db, 1, num_sets, num_keys_per_set,
+ &rand, delay_ms, id++));
+ }
+ };
+
+ for (uint32_t i = 0; i < num_workers; i++) {
+ threads.emplace_back(call_inserter);
+ }
+ for (uint32_t i = 0; i < num_checkers; i++) {
+ threads.emplace_back(call_checker);
+ }
+ if (with_slow_threads_) {
+ for (uint32_t i = 0; i < num_slow_checkers; i++) {
+ threads.emplace_back(call_slow_checker);
+ }
+ for (uint32_t i = 0; i < num_slow_workers; i++) {
+ threads.emplace_back(call_slow_inserter);
+ }
+ }
+
+ // Wait for all threads to finish
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify that data is consistent
+ Status s = RandomTransactionInserter::Verify(db, num_sets, num_keys_per_set,
+ !TAKE_SNAPSHOT);
+ ASSERT_OK(s);
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+TEST_P(TransactionTest, MemoryLimitTest) {
+ TransactionOptions txn_options;
+ // Header (12 bytes) + NOOP (1 byte) + 2 * 8 bytes for data.
+ txn_options.max_write_batch_size = 29;
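+ // 12 + 1 + 2 * 8 = 29 bytes: enough room for the first two Puts below, so
+ // the third Put is expected to exceed the limit.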
+ // Set threshold to unlimited so that the write batch does not get flushed,
+ // and can hit the memory limit.
+ txn_options.write_batch_flush_threshold = 0;
+ std::string value;
+ Status s;
+
+ Transaction* txn = db->BeginTransaction(WriteOptions(), txn_options);
+ ASSERT_TRUE(txn);
+
+ ASSERT_EQ(0, txn->GetNumPuts());
+ ASSERT_LE(0, txn->GetID());
+
+ s = txn->Put(Slice("a"), Slice("...."));
+ ASSERT_OK(s);
+ ASSERT_EQ(1, txn->GetNumPuts());
+
+ s = txn->Put(Slice("b"), Slice("...."));
+ ASSERT_OK(s);
+ ASSERT_EQ(2, txn->GetNumPuts());
+
+ s = txn->Put(Slice("b"), Slice("...."));
+ ASSERT_TRUE(s.IsMemoryLimit());
+ ASSERT_EQ(2, txn->GetNumPuts());
+
+ txn->Rollback();
+ delete txn;
+}
+
+// This test clarifies the existing expectation from the sequence number
+// algorithm. It can detect mistakes when the code is updated, but it is not
+// necessarily the only acceptable behavior. If the algorithm is legitimately
+// changed, this unit test should be updated as well.
+TEST_P(TransactionStressTest, SeqAdvanceTest) {
+ // TODO(myabandeh): must be tested with false before new releases
+ const bool short_test = true;
+ WriteOptions wopts;
+ FlushOptions fopt;
+
+ options.disable_auto_compactions = true;
+ ASSERT_OK(ReOpen());
+
+ // Do the test with NUM_BRANCHES branches in it. Each run of the test takes
+ // some of the branches. This is the same as counting a binary number where
+ // the i-th bit represents whether we take branch i in the run represented by
+ // that number.
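+ // For example, the run with n = 5 (binary 000101) is meant to take branches
+ // 0 and 2 and skip the rest.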
+ const size_t NUM_BRANCHES = short_test ? 6 : 10;
+ // Helper function that shows if the branch is to be taken in the run
+ // represented by the number n.
+ auto branch_do = [&](size_t n, size_t* branch) {
+ assert(*branch < NUM_BRANCHES);
+ const size_t filter = static_cast<size_t>(1) << *branch;
+ return n & filter;
+ };
+ const size_t max_n = static_cast<size_t>(1) << NUM_BRANCHES;
+ for (size_t n = 0; n < max_n; n++) {
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ size_t branch = 0;
+ auto seq = db_impl->GetLatestSequenceNumber();
+ exp_seq = seq;
+ txn_t0(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ // Doing it twice might detect some bugs
+ txn_t0(1);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ txn_t1(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ txn_t3(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ txn_t4(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+
+ txn_t2(0);
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ if (branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->Flush(fopt));
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ if (!short_test && branch_do(n, &branch)) {
+ ASSERT_OK(db_impl->FlushWAL(true));
+ ASSERT_OK(ReOpenNoDelete());
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_EQ(exp_seq, seq);
+ }
+ ASSERT_OK(ReOpen());
+ }
+}
+
+// Verify that the optimizations do not compromise correctness
+TEST_P(TransactionTest, Optimizations) {
+ size_t comb_cnt = size_t(1) << 2; // 2 is number of optimization vars
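+ // new_comb ranges over 0..3, covering every on/off combination of the two
+ // optimization flags.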
+ for (size_t new_comb = 0; new_comb < comb_cnt; new_comb++) {
+ TransactionDBWriteOptimizations optimizations;
+ optimizations.skip_concurrency_control = IsInCombination(0, new_comb);
+ optimizations.skip_duplicate_key_check = IsInCombination(1, new_comb);
+
+ ASSERT_OK(ReOpen());
+ WriteOptions write_options;
+ WriteBatch batch;
+ batch.Put(Slice("k"), Slice("v1"));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, db->DefaultColumnFamily(), "k", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("v1"));
+ }
+}
+
+// A comparator that uses only the first three bytes
+class ThreeBytewiseComparator : public Comparator {
+ public:
+ ThreeBytewiseComparator() {}
+ const char* Name() const override { return "test.ThreeBytewiseComparator"; }
+ int Compare(const Slice& a, const Slice& b) const override {
+ Slice na = Slice(a.data(), a.size() < 3 ? a.size() : 3);
+ Slice nb = Slice(b.data(), b.size() < 3 ? b.size() : 3);
+ return na.compare(nb);
+ }
+ bool Equal(const Slice& a, const Slice& b) const override {
+ Slice na = Slice(a.data(), a.size() < 3 ? a.size() : 3);
+ Slice nb = Slice(b.data(), b.size() < 3 ? b.size() : 3);
+ return na == nb;
+ }
+ // The methods below do not seem relevant to this test. Implement them if
+ // proven otherwise.
+ void FindShortestSeparator(std::string* start,
+ const Slice& limit) const override {
+ const Comparator* bytewise_comp = BytewiseComparator();
+ bytewise_comp->FindShortestSeparator(start, limit);
+ }
+ void FindShortSuccessor(std::string* key) const override {
+ const Comparator* bytewise_comp = BytewiseComparator();
+ bytewise_comp->FindShortSuccessor(key);
+ }
+};
+
+#ifndef ROCKSDB_VALGRIND_RUN
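+// Test that a Get with no explicit snapshot never observes the intermediate
+// value that a concurrently committing transaction writes before its final
+// Put of the same key.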
+TEST_P(TransactionTest, GetWithoutSnapshot) {
+ WriteOptions write_options;
+ std::atomic<bool> finish = {false};
+ db->Put(write_options, "key", "value");
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ for (int i = 0; i < 100; i++) {
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put("key", "overridedvalue"));
+ ASSERT_OK(txn->Put("key", "value"));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ finish = true;
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ while (!finish) {
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value"));
+ }
+ });
+ commit_thread.join();
+ read_thread.join();
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+// Test that the transactional db can handle duplicate keys in the write batch
+TEST_P(TransactionTest, DuplicateKeys) {
+ ColumnFamilyOptions cf_options;
+ std::string cf_name = "two";
+ ColumnFamilyHandle* cf_handle = nullptr;
+ {
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ WriteBatch batch;
+ batch.Put(Slice("key"), Slice("value"));
+ batch.Put(Slice("key2"), Slice("value2"));
+ // duplicate the keys
+ batch.Put(Slice("key"), Slice("value3"));
+ // duplicate the 2nd key. It should not be counted as a duplicate since a
+ // sub-patch is cut after the last duplicate.
+ batch.Put(Slice("key2"), Slice("value4"));
+ // duplicate the keys but in a different cf. It should not be counted as
+ // duplicate keys
+ batch.Put(cf_handle, Slice("key"), Slice("value5"));
+
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value3"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "key2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value4"));
+ s = db->Get(ropt, cf_handle, "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("value5"));
+
+ delete cf_handle;
+ }
+
+ // Test with non-bytewise comparator
+ {
+ ASSERT_OK(ReOpen());
+ std::unique_ptr<const Comparator> comp_gc(new ThreeBytewiseComparator());
+ cf_options.comparator = comp_gc.get();
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ WriteBatch batch;
+ batch.Put(cf_handle, Slice("key"), Slice("value"));
+ // The first three bytes are the same, so it must be counted as a duplicate
+ batch.Put(cf_handle, Slice("key2"), Slice("value2"));
+ // check for 2nd duplicate key in cf with non-default comparator
+ batch.Put(cf_handle, Slice("key2b"), Slice("value2b"));
+ ASSERT_OK(db->Write(write_options, &batch));
+
+ // The value must be the most recent value for all the keys equal to "key",
+ // including "key2"
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ ASSERT_OK(db->Get(ropt, cf_handle, "key", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value2b"));
+
+ // Test duplicate keys with rollback
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(cf_handle, Slice("key3"), Slice("value3")));
+ ASSERT_OK(txn0->Merge(cf_handle, Slice("key4"), Slice("value4")));
+ ASSERT_OK(txn0->Rollback());
+ ASSERT_OK(db->Get(ropt, cf_handle, "key5", &pinnable_val));
+ ASSERT_TRUE(pinnable_val == ("value2b"));
+ delete txn0;
+
+ delete cf_handle;
+ cf_options.comparator = BytewiseComparator();
+ }
+
+ for (bool do_prepare : {true, false}) {
+ for (bool do_rollback : {true, false}) {
+ for (bool with_commit_batch : {true, false}) {
+ if (with_commit_batch && !do_prepare) {
+ continue;
+ }
+ if (with_commit_batch && do_rollback) {
+ continue;
+ }
+ ASSERT_OK(ReOpen());
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ TransactionOptions txn_options;
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = false;
+ WriteOptions write_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ auto s = txn0->SetName("xid");
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0"), Slice("bar0a"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0"), Slice("bar0b"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo1"), Slice("bar1"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo2"), Slice("bar2a"));
+ ASSERT_OK(s);
+ // Repeat a key after the start of a sub-patch. This should not cause a
+ // duplicate in the most recent sub-patch and hence should not create a new
+ // sub-patch.
+ s = txn0->Put(Slice("foo0"), Slice("bar0c"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo2"), Slice("bar2b"));
+ ASSERT_OK(s);
+ // duplicate the keys but in a different cf. It should not be counted as a
+ // duplicate.
+ s = txn0->Put(cf_handle, Slice("foo0"), Slice("bar0-cf1"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn0->Merge(Slice("foo3"), Slice("bar3"));
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo4"), Slice("bar4"));
+ ASSERT_OK(s);
+ s = txn0->Delete(Slice("foo4"));
+ ASSERT_OK(s);
+ s = txn0->SingleDelete(Slice("foo4"));
+ ASSERT_OK(s);
+ if (do_prepare) {
+ s = txn0->Prepare();
+ ASSERT_OK(s);
+ }
+ if (do_rollback) {
+ // Test rolling back the batch with duplicates
+ s = txn0->Rollback();
+ ASSERT_OK(s);
+ } else {
+ if (with_commit_batch) {
+ assert(do_prepare);
+ auto cb = txn0->GetCommitTimeWriteBatch();
+ // duplicate a key in the original batch
+ // TODO(myabandeh): the behavior of GetCommitTimeWriteBatch
+ // conflicting with the prepared batch is currently undefined and
+ // gives different results in different implementations.
+
+ // s = cb->Put(Slice("foo0"), Slice("bar0d"));
+ // ASSERT_OK(s);
+ // add a new duplicate key
+ s = cb->Put(Slice("foo6"), Slice("bar6a"));
+ ASSERT_OK(s);
+ s = cb->Put(Slice("foo6"), Slice("bar6b"));
+ ASSERT_OK(s);
+ // add a duplicate key that is removed in the same batch
+ s = cb->Put(Slice("foo7"), Slice("bar7a"));
+ ASSERT_OK(s);
+ s = cb->Delete(Slice("foo7"));
+ ASSERT_OK(s);
+ }
+ s = txn0->Commit();
+ ASSERT_OK(s);
+ }
+ delete txn0;
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+
+ if (do_rollback) {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, cf_handle, "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo3", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo4", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ } else {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0c"));
+ s = db->Get(ropt, cf_handle, "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0-cf1"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar2a,bar2b"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo3", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar3,bar3"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo4", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ if (with_commit_batch) {
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo6", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar6b"));
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo7", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ }
+ }
+ delete cf_handle;
+ } // with_commit_batch
+ } // do_rollback
+ } // do_prepare
+
+ if (!options.unordered_write) {
+ // Also test with max_successive_merges > 0. max_successive_merges will not
+ // affect our algorithm for duplicate key insertion but we add the test to
+ // verify that.
+ cf_options.max_successive_merges = 2;
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(ReOpen());
+ db->CreateColumnFamily(cf_options, cf_name, &cf_handle);
+ WriteOptions write_options;
+ // Ensure one value for the key
+ ASSERT_OK(db->Put(write_options, cf_handle, Slice("key"), Slice("value")));
+ WriteBatch batch;
+ // Merge more than max_successive_merges times
+ batch.Merge(cf_handle, Slice("key"), Slice("1"));
+ batch.Merge(cf_handle, Slice("key"), Slice("2"));
+ batch.Merge(cf_handle, Slice("key"), Slice("3"));
+ batch.Merge(cf_handle, Slice("key"), Slice("4"));
+ ASSERT_OK(db->Write(write_options, &batch));
+ ReadOptions read_options;
+ string value;
+ ASSERT_OK(db->Get(read_options, cf_handle, "key", &value));
+ ASSERT_EQ(value, "value,1,2,3,4");
+ delete cf_handle;
+ }
+
+ {
+ // Test that the duplicate detection is not compromised after rolling back
+ // to a save point
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0a")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0b")));
+ txn0->SetSavePoint();
+ ASSERT_OK(txn0->RollbackToSavePoint());
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ }
+
+ // Test successful recovery after a crash
+ {
+ ASSERT_OK(ReOpen());
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ ReadOptions ropt;
+ Transaction* txn0;
+ PinnableSlice pinnable_val;
+ Status s;
+
+ std::unique_ptr<const Comparator> comp_gc(new ThreeBytewiseComparator());
+ cf_options.comparator = comp_gc.get();
+ cf_options.merge_operator = MergeOperators::CreateStringAppendOperator();
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ delete cf_handle;
+ std::vector<ColumnFamilyDescriptor> cfds{
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName,
+ ColumnFamilyOptions(options)),
+ ColumnFamilyDescriptor(cf_name, cf_options),
+ };
+ std::vector<ColumnFamilyHandle*> handles;
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+
+ ASSERT_OK(db->Put(write_options, "foo0", "init"));
+ ASSERT_OK(db->Put(write_options, "foo1", "init"));
+ ASSERT_OK(db->Put(write_options, handles[1], "foo0", "init"));
+ ASSERT_OK(db->Put(write_options, handles[1], "foo1", "init"));
+
+ // one entry
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0a")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0a"));
+
+ // two entries, no duplicate
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("foo0"), Slice("bar0b")));
+ ASSERT_OK(txn0->Put(handles[1], Slice("fol1"), Slice("bar1b")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0b")));
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar1b")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ db->FlushWAL(true);
+ // Flush only cf 1
+ reinterpret_cast<DBImpl*>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0b"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "fol1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1b"));
+
+ // one duplicate with ::Put
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0c")));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey1"), Slice("bar1d")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0c")));
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar1c")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0d")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ reinterpret_cast<DBImpl*>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0d"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo1", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1c"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1d"));
+
+ // Duplicate with ::Put, ::Delete
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0e")));
+ ASSERT_OK(txn0->Delete(handles[1], Slice("key-nonkey1")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0e")));
+ ASSERT_OK(txn0->Delete(Slice("foo0")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ reinterpret_cast<DBImpl*>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Duplicate with ::Put, ::SingleDelete
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar0g")));
+ ASSERT_OK(txn0->SingleDelete(handles[1], Slice("key-nonkey1")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0e")));
+ ASSERT_OK(txn0->SingleDelete(Slice("foo0")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ reinterpret_cast<DBImpl*>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Duplicate with ::Put, ::Merge
+ txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(handles[1], Slice("key-nonkey0"), Slice("bar1i")));
+ ASSERT_OK(txn0->Merge(handles[1], Slice("key-nonkey1"), Slice("bar1j")));
+ ASSERT_OK(txn0->Put(Slice("foo0"), Slice("bar0f")));
+ ASSERT_OK(txn0->Merge(Slice("foo0"), Slice("bar0g")));
+ ASSERT_OK(txn0->Prepare());
+ delete txn0;
+ // This will check the asserts inside recovery code
+ ASSERT_OK(db->FlushWAL(true));
+ // Flush only cf 1
+ reinterpret_cast<DBImpl*>(db->GetRootDB())
+ ->TEST_FlushMemTable(true, false, handles[1]);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ ASSERT_OK(ReOpenNoDelete(cfds, &handles));
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn0 != nullptr);
+ ASSERT_OK(txn0->Commit());
+ delete txn0;
+ pinnable_val.Reset();
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar0f,bar0g"));
+ pinnable_val.Reset();
+ s = db->Get(ropt, handles[1], "key-nonkey2", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == ("bar1i,bar1j"));
+
+ for (auto h : handles) {
+ delete h;
+ }
+ delete db;
+ db = nullptr;
+ }
+}
+
+// Test that the reseek optimization in iterators will not result in an infinite
+// loop if there are too many uncommitted entries before the snapshot.
+TEST_P(TransactionTest, ReseekOptimization) {
+ WriteOptions write_options;
+ write_options.sync = true;
+ write_options.disableWAL = false;
+ ColumnFamilyDescriptor cfd;
+ db->DefaultColumnFamily()->GetDescriptor(&cfd);
+ auto max_skip = cfd.options.max_sequential_skip_in_iterations;
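+ // max_sequential_skip_in_iterations is the number of skips after which the
+ // iterator falls back to a reseek.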
+
+ ASSERT_OK(db->Put(write_options, Slice("foo0"), Slice("initv")));
+
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ // Duplicate keys result in separate sequence numbers in WritePrepared
+ // and WriteUnprepared
+ for (size_t i = 0; i < 2 * max_skip; i++) {
+ ASSERT_OK(txn0->Put(Slice("foo1"), Slice("bar")));
+ }
+ ASSERT_OK(txn0->Prepare());
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("initv")));
+
+ ReadOptions read_options;
+ // To avoid loops
+ read_options.max_skippable_internal_keys = 10 * max_skip;
+ Iterator* iter = db->NewIterator(read_options);
+ ASSERT_OK(iter->status());
+ size_t cnt = 0;
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ iter->Next();
+ ASSERT_OK(iter->status());
+ cnt++;
+ }
+ ASSERT_EQ(cnt, 2);
+ cnt = 0;
+ iter->SeekToLast();
+ while (iter->Valid()) {
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ cnt++;
+ }
+ ASSERT_EQ(cnt, 2);
+ delete iter;
+ txn0->Rollback();
+ delete txn0;
+}
+
+// After recovery in kPointInTimeRecovery mode, the corrupted log file remains
+// there. The new log files should still be read successfully during recovery
+// from the 2nd crash.
+TEST_P(TransactionTest, DoubleCrashInRecovery) {
+ for (const bool manual_wal_flush : {false, true}) {
+ for (const bool write_after_recovery : {false, true}) {
+ options.wal_recovery_mode = WALRecoveryMode::kPointInTimeRecovery;
+ options.manual_wal_flush = manual_wal_flush;
+ ReOpen();
+ std::string cf_name = "two";
+ ColumnFamilyOptions cf_options;
+ ColumnFamilyHandle* cf_handle = nullptr;
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+
+ // Add a prepare entry to prevent the older logs from being deleted.
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("foo-prepare"), Slice("bar-prepare")));
+ ASSERT_OK(txn->Prepare());
+
+ FlushOptions flush_ops;
+ db->Flush(flush_ops);
+ // Now we have a log that cannot be deleted
+
+ ASSERT_OK(db->Put(write_options, cf_handle, "foo1", "bar1"));
+ // Flush only the 2nd cf
+ db->Flush(flush_ops, cf_handle);
+
+ // The value is large enough to be touched by the corruption we ingest
+ // below.
+ std::string large_value(400, ' ');
+ // key/value not touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo2", "bar2"));
+ // key/value touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo3", large_value));
+ // key/value not touched by corruption
+ ASSERT_OK(db->Put(write_options, "foo4", "bar4"));
+
+ db->FlushWAL(true);
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
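+ // Remember the current WAL file so it can be corrupted after the crash.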
+ uint64_t wal_file_id = db_impl->TEST_LogfileNumber();
+ std::string fname = LogFileName(dbname, wal_file_id);
+ reinterpret_cast<PessimisticTransactionDB*>(db)->TEST_Crash();
+ delete txn;
+ delete cf_handle;
+ delete db;
+ db = nullptr;
+
+ // Corrupt the last log file in the middle, so that it is not corrupted
+ // in the tail.
+ std::string file_content;
+ ASSERT_OK(ReadFileToString(env, fname, &file_content));
+ file_content[400] = 'h';
+ file_content[401] = 'a';
+ ASSERT_OK(env->DeleteFile(fname));
+ ASSERT_OK(WriteStringToFile(env, file_content, fname, true));
+
+ // Recover from corruption
+ std::vector<ColumnFamilyHandle*> handles;
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(ColumnFamilyDescriptor(kDefaultColumnFamilyName,
+ ColumnFamilyOptions()));
+ column_families.push_back(
+ ColumnFamilyDescriptor("two", ColumnFamilyOptions()));
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+
+ if (write_after_recovery) {
+ // Write data to the log right after the corrupted log
+ ASSERT_OK(db->Put(write_options, "foo5", large_value));
+ }
+
+ // Persist data written to WAL during recovery or by the last Put
+ db->FlushWAL(true);
+ // 2nd crash to recover while having a valid log after the corrupted one.
+ ASSERT_OK(ReOpenNoDelete(column_families, &handles));
+ assert(db != nullptr);
+ txn = db->GetTransactionByName("xid");
+ ASSERT_TRUE(txn != nullptr);
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ for (auto handle : handles) {
+ delete handle;
+ }
+ }
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_test.h b/src/rocksdb/utilities/transactions/transaction_test.h
new file mode 100644
index 000000000..2e533d379
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_test.h
@@ -0,0 +1,517 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include <algorithm>
+#include <cinttypes>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/fault_injection_test_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+#include "port/port.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Return true if the ith bit is set in the combination represented by comb
+bool IsInCombination(size_t i, size_t comb) { return comb & (size_t(1) << i); }
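+// e.g. IsInCombination(1, 2) is true and IsInCombination(0, 2) is false.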
+
+enum WriteOrdering : bool { kOrderedWrite, kUnorderedWrite };
+
+class TransactionTestBase : public ::testing::Test {
+ public:
+ TransactionDB* db;
+ FaultInjectionTestEnv* env;
+ std::string dbname;
+ Options options;
+
+ TransactionDBOptions txn_db_options;
+ bool use_stackable_db_;
+
+ TransactionTestBase(bool use_stackable_db, bool two_write_queue,
+ TxnDBWritePolicy write_policy,
+ WriteOrdering write_ordering)
+ : db(nullptr), env(nullptr), use_stackable_db_(use_stackable_db) {
+ options.create_if_missing = true;
+ options.max_write_buffer_number = 2;
+ options.write_buffer_size = 4 * 1024;
+ options.unordered_write = write_ordering == kUnorderedWrite;
+ options.level0_file_num_compaction_trigger = 2;
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+ env = new FaultInjectionTestEnv(Env::Default());
+ options.env = env;
+ options.two_write_queues = two_write_queue;
+ dbname = test::PerThreadDBPath("transaction_testdb");
+
+ DestroyDB(dbname, options);
+ txn_db_options.transaction_lock_timeout = 0;
+ txn_db_options.default_lock_timeout = 0;
+ txn_db_options.write_policy = write_policy;
+ txn_db_options.rollback_merge_operands = true;
+ // This will stress write unprepared by forcing a write batch flush on every
+ // write.
+ txn_db_options.default_write_batch_flush_threshold = 1;
+ // Write unprepared requires all transactions to be named. This setting
+ // autogenerates the name so that existing tests can pass.
+ txn_db_options.autogenerate_name = true;
+ Status s;
+ if (use_stackable_db == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ assert(s.ok());
+ }
+
+ ~TransactionTestBase() {
+ delete db;
+ db = nullptr;
+ // This is to skip the assert statement in FaultInjectionTestEnv. There
+ // seems to be a bug in btrfs that makes readdir return recently
+ // unlinked files. By using the default fs we simply ignore errors resulting
+ // from attempting to delete such files in DestroyDB.
+ options.env = Env::Default();
+ DestroyDB(dbname, options);
+ delete env;
+ }
+
+ Status ReOpenNoDelete() {
+ delete db;
+ db = nullptr;
+ env->AssertNoOpenFile();
+ env->DropUnsyncedFileData();
+ env->ResetState();
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ assert(!s.ok() || db != nullptr);
+ return s;
+ }
+
+ Status ReOpenNoDelete(std::vector<ColumnFamilyDescriptor>& cfs,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ for (auto h : *handles) {
+ delete h;
+ }
+ handles->clear();
+ delete db;
+ db = nullptr;
+ env->AssertNoOpenFile();
+ env->DropUnsyncedFileData();
+ env->ResetState();
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, cfs, handles,
+ &db);
+ } else {
+ s = OpenWithStackableDB(cfs, handles);
+ }
+ assert(!s.ok() || db != nullptr);
+ return s;
+ }
+
+ Status ReOpen() {
+ delete db;
+ db = nullptr;
+ DestroyDB(dbname, options);
+ Status s;
+ if (use_stackable_db_ == false) {
+ s = TransactionDB::Open(options, txn_db_options, dbname, &db);
+ } else {
+ s = OpenWithStackableDB();
+ }
+ assert(db != nullptr);
+ return s;
+ }
+
+ Status OpenWithStackableDB(std::vector<ColumnFamilyDescriptor>& cfs,
+ std::vector<ColumnFamilyHandle*>* handles) {
+ std::vector<size_t> compaction_enabled_cf_indices;
+ TransactionDB::PrepareWrap(&options, &cfs, &compaction_enabled_cf_indices);
+ DB* root_db = nullptr;
+ Options options_copy(options);
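+ // seq_per_batch is required by WritePrepared and WriteUnprepared; batch_per_txn
+ // does not hold for WriteUnprepared, which may write multiple batches per txn.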
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ Status s = DBImpl::Open(options_copy, dbname, cfs, handles, &root_db,
+ use_seq_per_batch, use_batch_per_txn);
+ StackableDB* stackable_db = new StackableDB(root_db);
+ if (s.ok()) {
+ assert(root_db != nullptr);
+ s = TransactionDB::WrapStackableDB(stackable_db, txn_db_options,
+ compaction_enabled_cf_indices,
+ *handles, &db);
+ }
+ if (!s.ok()) {
+ delete stackable_db;
+ }
+ return s;
+ }
+
+ Status OpenWithStackableDB() {
+ std::vector<size_t> compaction_enabled_cf_indices;
+ std::vector<ColumnFamilyDescriptor> column_families{ColumnFamilyDescriptor(
+ kDefaultColumnFamilyName, ColumnFamilyOptions(options))};
+
+ TransactionDB::PrepareWrap(&options, &column_families,
+ &compaction_enabled_cf_indices);
+ std::vector<ColumnFamilyHandle*> handles;
+ DB* root_db = nullptr;
+ Options options_copy(options);
+ const bool use_seq_per_batch =
+ txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED;
+ const bool use_batch_per_txn =
+ txn_db_options.write_policy == WRITE_COMMITTED ||
+ txn_db_options.write_policy == WRITE_PREPARED;
+ Status s = DBImpl::Open(options_copy, dbname, column_families, &handles,
+ &root_db, use_seq_per_batch, use_batch_per_txn);
+ if (!s.ok()) {
+ delete root_db;
+ return s;
+ }
+ StackableDB* stackable_db = new StackableDB(root_db);
+ assert(root_db != nullptr);
+ assert(handles.size() == 1);
+ s = TransactionDB::WrapStackableDB(stackable_db, txn_db_options,
+ compaction_enabled_cf_indices, handles,
+ &db);
+ delete handles[0];
+ if (!s.ok()) {
+ delete stackable_db;
+ }
+ return s;
+ }
+
+ std::atomic<size_t> linked = {0};
+ std::atomic<size_t> exp_seq = {0};
+ std::atomic<size_t> commit_writes = {0};
+ std::atomic<size_t> expected_commits = {0};
+ // Without Prepare, the commit does not write to WAL
+ std::atomic<size_t> with_empty_commits = {0};
+ std::function<void(size_t, Status)> txn_t0_with_status = [&](size_t index,
+ Status exp_s) {
+ // Test DB's internal txn. It involves neither a prepare phase nor a commit marker.
+ WriteOptions wopts;
+ auto s = db->Put(wopts, "key" + std::to_string(index), "value");
+ ASSERT_EQ(exp_s, s);
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq++;
+ } else {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ }
+ with_empty_commits++;
+ };
+ std::function<void(size_t)> txn_t0 = [&](size_t index) {
+ return txn_t0_with_status(index, Status::OK());
+ };
+ std::function<void(size_t)> txn_t1 = [&](size_t index) {
+ // Testing directly writing a write batch. Functionality-wise it is
+ // equivalent to commit without prepare.
+ WriteBatch wb;
+ auto istr = std::to_string(index);
+ ASSERT_OK(wb.Put("k1" + istr, "v1"));
+ ASSERT_OK(wb.Put("k2" + istr, "v2"));
+ ASSERT_OK(wb.Put("k3" + istr, "v3"));
+ WriteOptions wopts;
+ auto s = db->Write(wopts, &wb);
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 3;
+ } else {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ }
+ ASSERT_OK(s);
+ with_empty_commits++;
+ };
+ std::function<void(size_t)> txn_t2 = [&](size_t index) {
+ // Commit without prepare. It should write to DB without a commit marker.
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Commit());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 4;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for commit
+ exp_seq++;
+ }
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 4;
+ // WriteUnprepared implements CommitWithoutPrepareInternal by simply
+ // calling Prepare then Commit. Consume one seq for the prepare.
+ exp_seq++;
+ }
+ delete txn;
+ with_empty_commits++;
+ };
+ std::function<void(size_t)> txn_t3 = [&](size_t index) {
+ // A full 2pc txn that also involves a commit marker.
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Put(Slice("foo5" + istr), Slice("bar5")));
+ expected_commits++;
+ ASSERT_OK(txn->Prepare());
+ commit_writes++;
+ ASSERT_OK(txn->Commit());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // Consume one seq per key
+ exp_seq += 5;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ // Consume one seq per commit marker
+ exp_seq++;
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 5;
+ // Consume one seq per commit marker
+ exp_seq++;
+ }
+ delete txn;
+ };
+ std::function<void(size_t)> txn_t4 = [&](size_t index) {
+ // A full 2pc txn that is rolled back after the prepare phase.
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ auto istr = std::to_string(index);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ ASSERT_OK(txn->Put(Slice("foo" + istr), Slice("bar")));
+ ASSERT_OK(txn->Put(Slice("foo2" + istr), Slice("bar2")));
+ ASSERT_OK(txn->Put(Slice("foo3" + istr), Slice("bar3")));
+ ASSERT_OK(txn->Put(Slice("foo4" + istr), Slice("bar4")));
+ ASSERT_OK(txn->Put(Slice("foo5" + istr), Slice("bar5")));
+ expected_commits++;
+ ASSERT_OK(txn->Prepare());
+ commit_writes++;
+ ASSERT_OK(txn->Rollback());
+ if (txn_db_options.write_policy == TxnDBWritePolicy::WRITE_COMMITTED) {
+ // No seq is consumed for deleting the txn buffer
+ exp_seq += 0;
+ } else if (txn_db_options.write_policy ==
+ TxnDBWritePolicy::WRITE_PREPARED) {
+ // Consume one seq per batch
+ exp_seq++;
+ // Consume one seq per rollback batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for rollback commit
+ exp_seq++;
+ }
+ } else {
+ // Flushed after each key, consume one seq per flushed batch
+ exp_seq += 5;
+ // Consume one seq per rollback batch
+ exp_seq++;
+ if (options.two_write_queues) {
+ // Consume one seq for rollback commit
+ exp_seq++;
+ }
+ }
+ delete txn;
+ };
+
+ // Test that we can change write policy after a clean shutdown (which would
+ // empty the WAL)
+ void CrossCompatibilityTest(TxnDBWritePolicy from_policy,
+ TxnDBWritePolicy to_policy, bool empty_wal) {
+ TransactionOptions txn_options;
+ ReadOptions read_options;
+ WriteOptions write_options;
+ uint32_t index = 0;
+ Random rnd(1103);
+ options.write_buffer_size = 1024; // To create more sst files
+ std::unordered_map<std::string, std::string> committed_kvs;
+ Transaction* txn;
+
+ txn_db_options.write_policy = from_policy;
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ options.unordered_write = false;
+ }
+ ReOpen();
+
+ for (int i = 0; i < 1024; i++) {
+ auto istr = std::to_string(index);
+ auto k = Slice("foo-" + istr).ToString();
+ auto v = Slice("bar-" + istr).ToString();
+ // To test duplicate keys
+ auto v2 = Slice("bar2-" + istr).ToString();
+ auto type = rnd.Uniform(4);
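+ // type 0: plain Puts, 1: a WriteBatch write, 2: txn commit without
+ // prepare, 3: txn with prepare then commit.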
+ switch (type) {
+ case 0:
+ committed_kvs[k] = v;
+ ASSERT_OK(db->Put(write_options, k, v));
+ committed_kvs[k] = v2;
+ ASSERT_OK(db->Put(write_options, k, v2));
+ break;
+ case 1: {
+ WriteBatch wb;
+ committed_kvs[k] = v;
+ wb.Put(k, v);
+ committed_kvs[k] = v2;
+ wb.Put(k, v2);
+ ASSERT_OK(db->Write(write_options, &wb));
+
+ } break;
+ case 2:
+ case 3:
+ txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + istr));
+ committed_kvs[k] = v;
+ ASSERT_OK(txn->Put(k, v));
+ committed_kvs[k] = v2;
+ ASSERT_OK(txn->Put(k, v2));
+
+ if (type == 3) {
+ ASSERT_OK(txn->Prepare());
+ }
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ break;
+ default:
+ assert(0);
+ }
+
+ index++;
+ } // for i
+
+ txn_db_options.write_policy = to_policy;
+ if (txn_db_options.write_policy == WRITE_COMMITTED) {
+ options.unordered_write = false;
+ }
+ auto db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ // Before upgrade/downgrade the WAL must be emptied
+ if (empty_wal) {
+ db_impl->TEST_FlushMemTable();
+ } else {
+ db_impl->FlushWAL(true);
+ }
+ auto s = ReOpenNoDelete();
+ if (empty_wal) {
+ ASSERT_OK(s);
+ } else {
+ // Test that we can detect the WAL that is produced by an incompatible
+ // WritePolicy and fail fast before mis-interpreting the WAL.
+ ASSERT_TRUE(s.IsNotSupported());
+ return;
+ }
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ // Check that WAL is empty
+ VectorLogPtr log_files;
+ db_impl->GetSortedWalFiles(log_files);
+ ASSERT_EQ(0, log_files.size());
+
+ for (auto& kv : committed_kvs) {
+ std::string value;
+ s = db->Get(read_options, kv.first, &value);
+ if (s.IsNotFound()) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_OK(s);
+ if (kv.second != value) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_EQ(kv.second, value);
+ }
+ }
+};
+
+class TransactionTest
+ : public TransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering>> {
+ public:
+ TransactionTest()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())){};
+};
+
+class TransactionStressTest : public TransactionTest {};
+
+class MySQLStyleTransactionTest
+ : public TransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering, bool>> {
+ public:
+ MySQLStyleTransactionTest()
+ : TransactionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ with_slow_threads_(std::get<4>(GetParam())) {
+ if (with_slow_threads_ &&
+ (txn_db_options.write_policy == WRITE_PREPARED ||
+ txn_db_options.write_policy == WRITE_UNPREPARED)) {
+ // The corner case with slow threads involves the caches filling
+ // up, which would not happen even with artificial delays. To help
+ // such cases show up we lower the size of the cache-related data
+ // structures.
+ txn_db_options.wp_snapshot_cache_bits = 1;
+ txn_db_options.wp_commit_cache_bits = 10;
+ options.write_buffer_size = 1024;
+ EXPECT_OK(ReOpen());
+ }
+ };
+
+ protected:
+ // Also emulate slow threads by adding artificial delays
+ const bool with_slow_threads_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/transactions/transaction_util.cc b/src/rocksdb/utilities/transactions/transaction_util.cc
new file mode 100644
index 000000000..23532ae42
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_util.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_util.h"
+
+#include <cinttypes>
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status TransactionUtil::CheckKeyForConflicts(
+ DBImpl* db_impl, ColumnFamilyHandle* column_family, const std::string& key,
+ SequenceNumber snap_seq, bool cache_only, ReadCallback* snap_checker,
+ SequenceNumber min_uncommitted) {
+ Status result;
+
+ auto cfh = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family);
+ auto cfd = cfh->cfd();
+ SuperVersion* sv = db_impl->GetAndRefSuperVersion(cfd);
+
+ if (sv == nullptr) {
+ result = Status::InvalidArgument("Could not access column family " +
+ cfh->GetName());
+ }
+
+ if (result.ok()) {
+ SequenceNumber earliest_seq =
+ db_impl->GetEarliestMemTableSequenceNumber(sv, true);
+
+ result = CheckKey(db_impl, sv, earliest_seq, snap_seq, key, cache_only,
+ snap_checker, min_uncommitted);
+
+ db_impl->ReturnAndCleanupSuperVersion(cfd, sv);
+ }
+
+ return result;
+}
+
+Status TransactionUtil::CheckKey(DBImpl* db_impl, SuperVersion* sv,
+ SequenceNumber earliest_seq,
+ SequenceNumber snap_seq,
+ const std::string& key, bool cache_only,
+ ReadCallback* snap_checker,
+ SequenceNumber min_uncommitted) {
+ // When `min_uncommitted` is provided, keys are not always committed
+ // in sequence number order, and `snap_checker` is used to check whether a
+ // specific sequence number in the database is visible to the transaction.
+ // So `snap_checker` must be provided.
+ assert(min_uncommitted == kMaxSequenceNumber || snap_checker != nullptr);
+
+ Status result;
+ bool need_to_read_sst = false;
+
+ // Since it would be too slow to check the SST files, we will only use
+ // the memtables to check whether there have been any recent writes
+ // to this key after it was accessed in this transaction. But if the
+ // Memtables do not contain a long enough history, we must fail the
+ // transaction.
+ if (earliest_seq == kMaxSequenceNumber) {
+ // The age of this memtable is unknown. Cannot rely on it to check
+ // for recent writes. This error shouldn't happen often in practice as
+ // the Memtable should have a valid earliest sequence number except in some
+ // corner cases (such as error cases during recovery).
+ need_to_read_sst = true;
+
+ if (cache_only) {
+ result = Status::TryAgain(
+ "Transaction could not check for conflicts as the MemTable does not "
+ "contain a long enough history to check write at SequenceNumber: ",
+ ToString(snap_seq));
+ }
+ } else if (snap_seq < earliest_seq || min_uncommitted <= earliest_seq) {
+ // Use <= for min_uncommitted since earliest_seq is actually the largest seq
+ // before this memtable was created
+ need_to_read_sst = true;
+
+ if (cache_only) {
+ // The age of this memtable is too new to use to check for recent
+ // writes.
+ char msg[300];
+ snprintf(msg, sizeof(msg),
+ "Transaction could not check for conflicts for operation at "
+ "SequenceNumber %" PRIu64
+ " as the MemTable only contains changes newer than "
+ "SequenceNumber %" PRIu64
+ ". Increasing the value of the "
+ "max_write_buffer_size_to_maintain option could reduce the "
+ "frequency "
+ "of this error.",
+ snap_seq, earliest_seq);
+ result = Status::TryAgain(msg);
+ }
+ }
+
+ if (result.ok()) {
+ SequenceNumber seq = kMaxSequenceNumber;
+ bool found_record_for_key = false;
+
+ // When min_uncommitted == kMaxSequenceNumber, writes are committed in
+ // sequence number order, so only keys larger than `snap_seq` can cause
+ // conflict.
+ // When min_uncommitted != kMaxSequenceNumber, keys lower than
+ // min_uncommitted will not trigger conflicts, while keys larger than
+ // min_uncommitted might create conflicts, so we need to read them out
+ // from the DB and call snap_checker to determine. So only
+ // keys lower than min_uncommitted can be skipped.
+ SequenceNumber lower_bound_seq =
+ (min_uncommitted == kMaxSequenceNumber) ? snap_seq : min_uncommitted;
+ Status s = db_impl->GetLatestSequenceForKey(sv, key, !need_to_read_sst,
+ lower_bound_seq, &seq,
+ &found_record_for_key);
+
+ if (!(s.ok() || s.IsNotFound() || s.IsMergeInProgress())) {
+ result = s;
+ } else if (found_record_for_key) {
+ bool write_conflict = snap_checker == nullptr
+ ? snap_seq < seq
+ : !snap_checker->IsVisible(seq);
+ if (write_conflict) {
+ result = Status::Busy();
+ }
+ }
+ }
+
+ return result;
+}
+
+Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl,
+ const TransactionKeyMap& key_map,
+ bool cache_only) {
+ Status result;
+
+ for (auto& key_map_iter : key_map) {
+ uint32_t cf_id = key_map_iter.first;
+ const auto& keys = key_map_iter.second;
+
+ SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf_id);
+ if (sv == nullptr) {
+ result = Status::InvalidArgument("Could not access column family " +
+ ToString(cf_id));
+ break;
+ }
+
+ SequenceNumber earliest_seq =
+ db_impl->GetEarliestMemTableSequenceNumber(sv, true);
+
+ // For each of the keys in this transaction, check to see if someone has
+ // written to this key since the start of the transaction.
+ for (const auto& key_iter : keys) {
+ const auto& key = key_iter.first;
+ const SequenceNumber key_seq = key_iter.second.seq;
+
+ result = CheckKey(db_impl, sv, earliest_seq, key_seq, key, cache_only);
+
+ if (!result.ok()) {
+ break;
+ }
+ }
+
+ db_impl->ReturnAndCleanupSuperVersion(cf_id, sv);
+
+ if (!result.ok()) {
+ break;
+ }
+ }
+
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/transaction_util.h b/src/rocksdb/utilities/transactions/transaction_util.h
new file mode 100644
index 000000000..2e48f84a4
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/transaction_util.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <string>
+#include <unordered_map>
+
+#include "db/dbformat.h"
+#include "db/read_callback.h"
+
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct TransactionKeyMapInfo {
+ // Earliest sequence number that is relevant to this transaction for this key
+ SequenceNumber seq;
+
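+ // Number of times this transaction has written / read the key.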
+ uint32_t num_writes;
+ uint32_t num_reads;
+
+ bool exclusive;
+
+ explicit TransactionKeyMapInfo(SequenceNumber seq_no)
+ : seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {}
+
+ // Used in PopSavePoint to collapse two savepoints together.
+ void Merge(const TransactionKeyMapInfo& info) {
+ assert(seq <= info.seq);
+ num_reads += info.num_reads;
+ num_writes += info.num_writes;
+ exclusive |= info.exclusive;
+ }
+};
+
+using TransactionKeyMap =
+ std::unordered_map<uint32_t,
+ std::unordered_map<std::string, TransactionKeyMapInfo>>;
+
+class DBImpl;
+struct SuperVersion;
+class WriteBatchWithIndex;
+
+class TransactionUtil {
+ public:
+ // Verifies there have been no commits to this key in the db since this
+ // sequence number.
+ //
+ // If cache_only is true, then this function will not attempt to read any
+ // SST files. This will make it more likely this function will
+ // return an error if it is unable to determine if there are any conflicts.
+ //
+ // See comment of CheckKey() for explanation of `snap_seq`, `snap_checker`
+ // and `min_uncommitted`.
+ //
+ // Returns OK on success, BUSY if there is a conflicting write, or other error
+ // status for any unexpected errors.
+ static Status CheckKeyForConflicts(
+ DBImpl* db_impl, ColumnFamilyHandle* column_family,
+ const std::string& key, SequenceNumber snap_seq, bool cache_only,
+ ReadCallback* snap_checker = nullptr,
+ SequenceNumber min_uncommitted = kMaxSequenceNumber);
+
+ // For each (key, SequenceNumber) pair in the TransactionKeyMap, this function
+ // will verify there have been no writes to the key in the db since that
+ // sequence number.
+ //
+ // Returns OK on success, BUSY if there is a conflicting write, or other error
+ // status for any unexpected errors.
+ //
+ // REQUIRED: this function should only be called on the write thread or if the
+ // mutex is held.
+ static Status CheckKeysForConflicts(DBImpl* db_impl,
+ const TransactionKeyMap& keys,
+ bool cache_only);
+
+ private:
+ // If `snap_checker` == nullptr, writes are always committed in sequence
+ // number order. Any write with sequence number <= `snap_seq` will not
+ // conflict, and any write to `key` with sequence number > `snap_seq` will
+ // trigger a conflict.
+ // If `snap_checker` != nullptr, writes may not commit in sequence number
+ // order. In this case `min_uncommitted` is a lower bound.
+ //  seq < `min_uncommitted`: no conflict
+ //  seq > `snap_seq`: triggers a conflict
+ //  `min_uncommitted` <= seq <= `snap_seq`: call `snap_checker` to determine.
+ static Status CheckKey(DBImpl* db_impl, SuperVersion* sv,
+ SequenceNumber earliest_seq, SequenceNumber snap_seq,
+ const std::string& key, bool cache_only,
+ ReadCallback* snap_checker = nullptr,
+ SequenceNumber min_uncommitted = kMaxSequenceNumber);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc b/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc
new file mode 100644
index 000000000..0171b9716
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_transaction_test.cc
@@ -0,0 +1,3524 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_test.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cinttypes>
+#include <functional>
+#include <string>
+#include <thread>
+
+#include "db/db_impl/db_impl.h"
+#include "db/dbformat.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/debug.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "table/mock_table.h"
+#include "test_util/fault_injection_test_env.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "test_util/transaction_test_util.h"
+#include "util/mutexlock.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+#include "port/port.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+using CommitEntry = WritePreparedTxnDB::CommitEntry;
+using CommitEntry64b = WritePreparedTxnDB::CommitEntry64b;
+using CommitEntry64bFormat = WritePreparedTxnDB::CommitEntry64bFormat;
+
+TEST(PreparedHeap, BasicsTest) {
+ WritePreparedTxnDB::PreparedHeap heap;
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(14l);
+ // Test with one element
+ ASSERT_EQ(14l, heap.top());
+ heap.push(24l);
+ heap.push(34l);
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.push(44l);
+ heap.push(54l);
+ heap.push(64l);
+ heap.push(74l);
+ heap.push(84l);
+ }
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.erase(24l);
+ // Test that old min is still on top
+ ASSERT_EQ(14l, heap.top());
+ heap.erase(14l);
+ // Test that the new min comes to the top after multiple erases
+ ASSERT_EQ(34l, heap.top());
+ heap.erase(34l);
+ // Test that the new min comes to the top after a single erase
+ ASSERT_EQ(44l, heap.top());
+ heap.erase(54l);
+ ASSERT_EQ(44l, heap.top());
+ heap.pop(); // pop 44l
+ // Test that the erased items are ignored after pop
+ ASSERT_EQ(64l, heap.top());
+ heap.erase(44l);
+ // Test that erasing an already popped item would work
+ ASSERT_EQ(64l, heap.top());
+ heap.erase(84l);
+ ASSERT_EQ(64l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(85l);
+ heap.push(86l);
+ heap.push(87l);
+ heap.push(88l);
+ heap.push(89l);
+ }
+ heap.erase(87l);
+ heap.erase(85l);
+ heap.erase(89l);
+ heap.erase(86l);
+ heap.erase(88l);
+ // Test top remains the same after a random order of many erases
+ ASSERT_EQ(64l, heap.top());
+ heap.pop();
+ // Test that pop works with a series of random pending erases
+ ASSERT_EQ(74l, heap.top());
+ ASSERT_FALSE(heap.empty());
+ heap.pop();
+ // Test that empty works
+ ASSERT_TRUE(heap.empty());
+}
+
+// This is a scenario reconstructed from a buggy trace. Test that the bug does
+// not resurface again.
+TEST(PreparedHeap, EmptyAtTheEnd) {
+ WritePreparedTxnDB::PreparedHeap heap;
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(40l);
+ }
+ ASSERT_EQ(40l, heap.top());
+ // Although not a recommended scenario, we must be resilient against erase
+ // without a prior push.
+ heap.erase(50l);
+ ASSERT_EQ(40l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(60l);
+ }
+ ASSERT_EQ(40l, heap.top());
+
+ heap.erase(60l);
+ ASSERT_EQ(40l, heap.top());
+ heap.erase(40l);
+ ASSERT_TRUE(heap.empty());
+
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(40l);
+ }
+ ASSERT_EQ(40l, heap.top());
+ heap.erase(50l);
+ ASSERT_EQ(40l, heap.top());
+ {
+ MutexLock ml(heap.push_pop_mutex());
+ heap.push(60l);
+ }
+ ASSERT_EQ(40l, heap.top());
+
+ heap.erase(40l);
+ // Test that the erase has not emptied the heap (we had a bug doing that)
+ ASSERT_FALSE(heap.empty());
+ ASSERT_EQ(60l, heap.top());
+ heap.erase(60l);
+ ASSERT_TRUE(heap.empty());
+}
+
+// Generate random order of PreparedHeap access and test that the heap will be
+// successfully emptied at the end.
+TEST(PreparedHeap, Concurrent) {
+ const size_t t_cnt = 10;
+ ROCKSDB_NAMESPACE::port::Thread t[t_cnt + 1];
+ WritePreparedTxnDB::PreparedHeap heap;
+ port::RWMutex prepared_mutex;
+ std::atomic<size_t> last;
+
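+ // One thread pushes sequence numbers (occasionally skipping a push) while
+ // t_cnt threads erase them concurrently; after each round the heap must be
+ // empty.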
+ for (size_t n = 0; n < 100; n++) {
+ last = 0;
+ t[0] = ROCKSDB_NAMESPACE::port::Thread([&]() {
+ Random rnd(1103);
+ for (size_t seq = 1; seq <= t_cnt; seq++) {
+ // This is not recommended usage but we should be resilient against it.
+ bool skip_push = rnd.OneIn(5);
+ if (!skip_push) {
+ MutexLock ml(heap.push_pop_mutex());
+ std::this_thread::yield();
+ heap.push(seq);
+ last.store(seq);
+ }
+ }
+ });
+ for (size_t i = 1; i <= t_cnt; i++) {
+ t[i] =
+ ROCKSDB_NAMESPACE::port::Thread([&heap, &prepared_mutex, &last, i]() {
+ auto seq = i;
+ do {
+ std::this_thread::yield();
+ } while (last.load() < seq);
+ WriteLock wl(&prepared_mutex);
+ heap.erase(seq);
+ });
+ }
+ for (size_t i = 0; i <= t_cnt; i++) {
+ t[i].join();
+ }
+ ASSERT_TRUE(heap.empty());
+ }
+}
+
+// Test that WriteBatchWithIndex correctly counts the number of sub-batches
+TEST(WriteBatchWithIndex, SubBatchCnt) {
+ ColumnFamilyOptions cf_options;
+ std::string cf_name = "two";
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ const std::string dbname = test::PerThreadDBPath("transaction_testdb");
+ DestroyDB(dbname, options);
+ ASSERT_OK(DB::Open(options, dbname, &db));
+ ColumnFamilyHandle* cf_handle = nullptr;
+ ASSERT_OK(db->CreateColumnFamily(cf_options, cf_name, &cf_handle));
+ WriteOptions write_options;
+ size_t batch_cnt = 1;
+ size_t save_points = 0;
+ std::vector<size_t> batch_cnt_at;
+ WriteBatchWithIndex batch(db->DefaultColumnFamily()->GetComparator(), 0, true,
+ 0);
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ batch.Put(Slice("key"), Slice("value"));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ batch.Put(Slice("key2"), Slice("value2"));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the keys
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ batch.Put(Slice("key"), Slice("value3"));
+ batch_cnt++;
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the 2nd key. It should not be counted as a duplicate since a
+ // sub-batch is cut after the last duplicate.
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ batch.Put(Slice("key2"), Slice("value4"));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
+ // duplicate the keys but in a different cf. They should not be counted as
+ // duplicate keys
+ batch_cnt_at.push_back(batch_cnt);
+ batch.SetSavePoint();
+ save_points++;
+ batch.Put(cf_handle, Slice("key"), Slice("value5"));
+ ASSERT_EQ(batch_cnt, batch.SubBatchCnt());
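+
+ // To summarize the rule exercised above: a new sub-batch is cut only when a
+ // key repeats within the same column family and the same sub-batch; a
+ // repeat after a cut, or in a different column family, does not increase
+ // the count.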
+
+ // Test that the number of sub-batches matches what we count with
+ // SubBatchCounter
+ std::map<uint32_t, const Comparator*> comparators;
+ comparators[0] = db->DefaultColumnFamily()->GetComparator();
+ comparators[cf_handle->GetID()] = cf_handle->GetComparator();
+ SubBatchCounter counter(comparators);
+ ASSERT_OK(batch.GetWriteBatch()->Iterate(&counter));
+ ASSERT_EQ(batch_cnt, counter.BatchCount());
+
+ // Test that RollbackToSavePoint properly resets the number of
+ // sub-batches
+ for (size_t i = save_points; i > 0; i--) {
+ batch.RollbackToSavePoint();
+ ASSERT_EQ(batch_cnt_at[i - 1], batch.SubBatchCnt());
+ }
+
+ // Test the count is right with random batches
+ {
+ const size_t TOTAL_KEYS = 20; // a pool of 20 keys vs 10 keys per batch so some batches repeat keys
+ Random rnd(1131);
+ std::string keys[TOTAL_KEYS];
+ for (size_t k = 0; k < TOTAL_KEYS; k++) {
+ int len = static_cast<int>(rnd.Uniform(50));
+ keys[k] = test::RandomKey(&rnd, len);
+ }
+ for (size_t i = 0; i < 1000; i++) { // 1000 random batches
+ WriteBatchWithIndex rndbatch(db->DefaultColumnFamily()->GetComparator(),
+ 0, true, 0);
+ for (size_t k = 0; k < 10; k++) { // 10 keys per batch
+ size_t ki = static_cast<size_t>(rnd.Uniform(TOTAL_KEYS));
+ Slice key = Slice(keys[ki]);
+ std::string buffer;
+ Slice value = Slice(test::RandomString(&rnd, 16, &buffer));
+ rndbatch.Put(key, value);
+ }
+ SubBatchCounter batch_counter(comparators);
+ ASSERT_OK(rndbatch.GetWriteBatch()->Iterate(&batch_counter));
+ ASSERT_EQ(rndbatch.SubBatchCnt(), batch_counter.BatchCount());
+ }
+ }
+
+ delete cf_handle;
+ delete db;
+}
+
+TEST(CommitEntry64b, BasicTest) {
+ const size_t INDEX_BITS = static_cast<size_t>(21);
+ const size_t INDEX_SIZE = static_cast<size_t>(1ull << INDEX_BITS);
+ const CommitEntry64bFormat FORMAT(static_cast<size_t>(INDEX_BITS));
+
+ // zero-initialized CommitEntry64b should indicate an empty entry
+ CommitEntry64b empty_entry64b;
+ uint64_t empty_index = 11ul;
+ CommitEntry empty_entry;
+ bool ok = empty_entry64b.Parse(empty_index, &empty_entry, FORMAT);
+ ASSERT_FALSE(ok);
+
+ // The zero entry is reserved for uninitialized entries
+ const size_t MAX_COMMIT = (1 << FORMAT.COMMIT_BITS) - 1 - 1;
+ // Samples over the numbers that are covered by that many index bits
+ std::array<uint64_t, 4> is = {{0, 1, INDEX_SIZE / 2 + 1, INDEX_SIZE - 1}};
+ // Samples over the numbers that are covered by that many commit bits
+ std::array<uint64_t, 4> ds = {{0, 1, MAX_COMMIT / 2 + 1, MAX_COMMIT}};
+ // Iterate over prepare numbers that i) cover all bits of a sequence
+ // number, and ii) include some bits that fall into the range of the index
+ // or commit bits
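+ // A rough sketch of the packing being exercised (an assumption based on the
+ // constants above, not a statement of the actual encoding): the 64b entry
+ // drops the low INDEX_BITS of the prepare seq, which are recoverable from
+ // the slot index, and stores the commit seq as a small delta from the
+ // prepare seq, with the all-zero encoding reserved for "empty". The loop
+ // below only relies on the round-trip property that Parse() recovers the
+ // original CommitEntry for any in-range prepare/commit pair.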
+ for (uint64_t base = 1; base < kMaxSequenceNumber; base *= 2) {
+ for (uint64_t i : is) {
+ for (uint64_t d : ds) {
+ uint64_t p = base + i + d;
+ for (uint64_t c : {p, p + d / 2, p + d}) {
+ uint64_t index = p % INDEX_SIZE;
+ CommitEntry before(p, c), after;
+ CommitEntry64b entry64b(before, FORMAT);
+ ok = entry64b.Parse(index, &after, FORMAT);
+ ASSERT_TRUE(ok);
+ if (!(before == after)) {
+ printf("base %" PRIu64 " i %" PRIu64 " d %" PRIu64 " p %" PRIu64
+ " c %" PRIu64 " index %" PRIu64 "\n",
+ base, i, d, p, c, index);
+ }
+ ASSERT_EQ(before, after);
+ }
+ }
+ }
+ }
+}
+
+class WritePreparedTxnDBMock : public WritePreparedTxnDB {
+ public:
+ WritePreparedTxnDBMock(DBImpl* db_impl, TransactionDBOptions& opt)
+ : WritePreparedTxnDB(db_impl, opt) {}
+ void SetDBSnapshots(const std::vector<SequenceNumber>& snapshots) {
+ snapshots_ = snapshots;
+ }
+ void TakeSnapshot(SequenceNumber seq) { snapshots_.push_back(seq); }
+
+ protected:
+ const std::vector<SequenceNumber> GetSnapshotListFromDB(
+ SequenceNumber /* unused */) override {
+ return snapshots_;
+ }
+
+ private:
+ std::vector<SequenceNumber> snapshots_;
+};
+
+class WritePreparedTransactionTestBase : public TransactionTestBase {
+ public:
+ WritePreparedTransactionTestBase(bool use_stackable_db, bool two_write_queue,
+ TxnDBWritePolicy write_policy,
+ WriteOrdering write_ordering)
+ : TransactionTestBase(use_stackable_db, two_write_queue, write_policy,
+ write_ordering){};
+
+ protected:
+ void UpdateTransactionDBOptions(size_t snapshot_cache_bits,
+ size_t commit_cache_bits) {
+ txn_db_options.wp_snapshot_cache_bits = snapshot_cache_bits;
+ txn_db_options.wp_commit_cache_bits = commit_cache_bits;
+ }
+ void UpdateTransactionDBOptions(size_t snapshot_cache_bits) {
+ txn_db_options.wp_snapshot_cache_bits = snapshot_cache_bits;
+ }
+ // If expect_update is set, check if it actually updated old_commit_map_. If
+ // it did not and yet suggested not to check the next snapshot, do the
+ // opposite to check if it was not a bad suggestion.
+ void MaybeUpdateOldCommitMapTestWithNext(uint64_t prepare, uint64_t commit,
+ uint64_t snapshot,
+ uint64_t next_snapshot,
+ bool expect_update) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // reset old_commit_map_empty_ so that its value indicates whether
+ // old_commit_map_ was updated
+ wp_db->old_commit_map_empty_ = true;
+ bool check_next = wp_db->MaybeUpdateOldCommitMap(prepare, commit, snapshot,
+ snapshot < next_snapshot);
+ if (expect_update == wp_db->old_commit_map_empty_) {
+ printf("prepare: %" PRIu64 " commit: %" PRIu64 " snapshot: %" PRIu64
+ " next: %" PRIu64 "\n",
+ prepare, commit, snapshot, next_snapshot);
+ }
+ EXPECT_EQ(!expect_update, wp_db->old_commit_map_empty_);
+ if (!check_next && wp_db->old_commit_map_empty_) {
+ // do the opposite to make sure it was not a bad suggestion
+ const bool dont_care_bool = true;
+ wp_db->MaybeUpdateOldCommitMap(prepare, commit, next_snapshot,
+ dont_care_bool);
+ if (!wp_db->old_commit_map_empty_) {
+ printf("prepare: %" PRIu64 " commit: %" PRIu64 " snapshot: %" PRIu64
+ " next: %" PRIu64 "\n",
+ prepare, commit, snapshot, next_snapshot);
+ }
+ EXPECT_TRUE(wp_db->old_commit_map_empty_);
+ }
+ }
+
+ // Test that a CheckAgainstSnapshots thread reading old_snapshots will not
+ // miss a snapshot because of a concurrent update by UpdateSnapshots that is
+ // writing new_snapshots. Each thread is interrupted at two points; the sync
+ // points that enforce this are specified by a1, a2, b1, and b2. The
+ // CommitEntry entry is expected to be relevant to one of the snapshots that
+ // is common between the old and new list of snapshots.
+ void SnapshotConcurrentAccessTestInternal(
+ WritePreparedTxnDB* wp_db,
+ const std::vector<SequenceNumber>& old_snapshots,
+ const std::vector<SequenceNumber>& new_snapshots, CommitEntry& entry,
+ SequenceNumber& version, size_t a1, size_t a2, size_t b1, size_t b2) {
+ // First reset the snapshot list
+ const std::vector<SequenceNumber> empty_snapshots;
+ wp_db->old_commit_map_empty_ = true;
+ wp_db->UpdateSnapshots(empty_snapshots, ++version);
+ // Then initialize it with the old_snapshots
+ wp_db->UpdateSnapshots(old_snapshots, ++version);
+
+ // Starting from the first thread, cut each thread at two points
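+ // The dependency pairs below enforce this interleaving: CheckAgainstSnapshots
+ // runs up to a1, then UpdateSnapshots starts and runs up to b1, then
+ // CheckAgainstSnapshots resumes up to a2, then UpdateSnapshots resumes up to
+ // b2, and finally CheckAgainstSnapshots runs to the end.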
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(a1),
+ "WritePreparedTxnDB::UpdateSnapshots:s:start"},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(b1),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(a1)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(a2),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(b1)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(b2),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(a2)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:end",
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(b2)},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_);
+ ROCKSDB_NAMESPACE::port::Thread t1(
+ [&]() { wp_db->UpdateSnapshots(new_snapshots, version); });
+ ROCKSDB_NAMESPACE::port::Thread t2(
+ [&]() { wp_db->CheckAgainstSnapshots(entry); });
+ t1.join();
+ t2.join();
+ ASSERT_FALSE(wp_db->old_commit_map_empty_);
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ wp_db->old_commit_map_empty_ = true;
+ wp_db->UpdateSnapshots(empty_snapshots, ++version);
+ wp_db->UpdateSnapshots(old_snapshots, ++version);
+ // Starting from the second thread, cut each thread at two points
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(a1),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:start"},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(b1),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(a1)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:" + std::to_string(a2),
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(b1)},
+ {"WritePreparedTxnDB::CheckAgainstSnapshots:p:" + std::to_string(b2),
+ "WritePreparedTxnDB::UpdateSnapshots:s:" + std::to_string(a2)},
+ {"WritePreparedTxnDB::UpdateSnapshots:p:end",
+ "WritePreparedTxnDB::CheckAgainstSnapshots:s:" + std::to_string(b2)},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_);
+ ROCKSDB_NAMESPACE::port::Thread t1(
+ [&]() { wp_db->UpdateSnapshots(new_snapshots, version); });
+ ROCKSDB_NAMESPACE::port::Thread t2(
+ [&]() { wp_db->CheckAgainstSnapshots(entry); });
+ t1.join();
+ t2.join();
+ ASSERT_FALSE(wp_db->old_commit_map_empty_);
+ }
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ }
+
+ // Verify value of keys.
+ void VerifyKeys(const std::unordered_map<std::string, std::string>& data,
+ const Snapshot* snapshot = nullptr) {
+ std::string value;
+ ReadOptions read_options;
+ read_options.snapshot = snapshot;
+ for (auto& kv : data) {
+ auto s = db->Get(read_options, kv.first, &value);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ if (kv.second != value) {
+ printf("key = %s\n", kv.first.c_str());
+ }
+ ASSERT_EQ(kv.second, value);
+ } else {
+ ASSERT_EQ(kv.second, "NOT_FOUND");
+ }
+
+ // Try with MultiGet API too
+ std::vector<std::string> values;
+ auto s_vec = db->MultiGet(read_options, {db->DefaultColumnFamily()},
+ {kv.first}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(kv.second == values[0]);
+ } else {
+ ASSERT_EQ(kv.second, "NOT_FOUND");
+ }
+ }
+ }
+
+ // Verify all versions of keys.
+ void VerifyInternalKeys(const std::vector<KeyVersion>& expected_versions) {
+ std::vector<KeyVersion> versions;
+ const size_t kMaxKeys = 100000;
+ ASSERT_OK(GetAllKeyVersions(db, expected_versions.front().user_key,
+ expected_versions.back().user_key, kMaxKeys,
+ &versions));
+ ASSERT_EQ(expected_versions.size(), versions.size());
+ for (size_t i = 0; i < versions.size(); i++) {
+ ASSERT_EQ(expected_versions[i].user_key, versions[i].user_key);
+ ASSERT_EQ(expected_versions[i].sequence, versions[i].sequence);
+ ASSERT_EQ(expected_versions[i].type, versions[i].type);
+ if (versions[i].type != kTypeDeletion &&
+ versions[i].type != kTypeSingleDeletion) {
+ ASSERT_EQ(expected_versions[i].value, versions[i].value);
+ }
+ // Range delete not supported.
+ assert(expected_versions[i].type != kTypeRangeDeletion);
+ }
+ }
+};
+
+class WritePreparedTransactionTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy, WriteOrdering>> {
+ public:
+ WritePreparedTransactionTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())){};
+};
+
+#ifndef ROCKSDB_VALGRIND_RUN
+class SnapshotConcurrentAccessTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<std::tuple<
+ bool, bool, TxnDBWritePolicy, WriteOrdering, size_t, size_t>> {
+ public:
+ SnapshotConcurrentAccessTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ split_id_(std::get<4>(GetParam())),
+ split_cnt_(std::get<5>(GetParam())){};
+
+ protected:
+ // A test is split into split_cnt_ tests, each identified with split_id_ where
+ // 0 <= split_id_ < split_cnt_
+ size_t split_id_;
+ size_t split_cnt_;
+};
+#endif // ROCKSDB_VALGRIND_RUN
+
+class SeqAdvanceConcurrentTest
+ : public WritePreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<std::tuple<
+ bool, bool, TxnDBWritePolicy, WriteOrdering, size_t, size_t>> {
+ public:
+ SeqAdvanceConcurrentTest()
+ : WritePreparedTransactionTestBase(
+ std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam())),
+ split_id_(std::get<4>(GetParam())),
+ split_cnt_(std::get<5>(GetParam())){};
+
+ protected:
+ // A test is split into split_cnt_ tests, each identified with split_id_ where
+ // 0 <= split_id_ < split_cnt_
+ size_t split_id_;
+ size_t split_cnt_;
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WritePreparedTransaction, WritePreparedTransactionTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite)));
+
+#ifndef ROCKSDB_VALGRIND_RUN
+INSTANTIATE_TEST_CASE_P(
+ TwoWriteQueues, SnapshotConcurrentAccessTest,
+ ::testing::Values(
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 0, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 1, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 2, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 3, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 4, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 5, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 6, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 7, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 8, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 9, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 10, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 11, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 12, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 13, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 14, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 15, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 16, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 17, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 18, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 19, 20),
+
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 0, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 1, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 2, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 3, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 4, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 5, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 6, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 7, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 8, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 9, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 10, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 11, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 12, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 13, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 14, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 15, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 16, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 17, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 18, 20),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 19, 20)));
+
+INSTANTIATE_TEST_CASE_P(
+ OneWriteQueue, SnapshotConcurrentAccessTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 0, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 1, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 2, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 3, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 4, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 5, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 6, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 7, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 8, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 9, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 10, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 11, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 12, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 13, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 14, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 15, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 16, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 17, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 18, 20),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 19, 20)));
+
+INSTANTIATE_TEST_CASE_P(
+ TwoWriteQueues, SeqAdvanceConcurrentTest,
+ ::testing::Values(
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 0, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 1, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 2, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 3, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 4, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 5, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 6, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 7, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 8, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kOrderedWrite, 9, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 0, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 1, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 2, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 3, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 4, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 5, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 6, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 7, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 8, 10),
+ std::make_tuple(false, true, WRITE_PREPARED, kUnorderedWrite, 9, 10)));
+
+INSTANTIATE_TEST_CASE_P(
+ OneWriteQueue, SeqAdvanceConcurrentTest,
+ ::testing::Values(
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 0, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 1, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 2, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 3, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 4, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 5, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 6, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 7, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 8, 10),
+ std::make_tuple(false, false, WRITE_PREPARED, kOrderedWrite, 9, 10)));
+#endif // ROCKSDB_VALGRIND_RUN
+
+TEST_P(WritePreparedTransactionTest, CommitMap) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ assert(wp_db);
+ assert(wp_db->db_impl_);
+ size_t size = wp_db->COMMIT_CACHE_SIZE;
+ CommitEntry c = {5, 12}, e;
+ bool evicted = wp_db->AddCommitEntry(c.prep_seq % size, c, &e);
+ ASSERT_FALSE(evicted);
+
+ // Should be able to read the same value
+ CommitEntry64b dont_care;
+ bool found = wp_db->GetCommitEntry(c.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(c, e);
+ // Should be able to distinguish between overlapping entries
+ found = wp_db->GetCommitEntry((c.prep_seq + size) % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_NE(c.prep_seq + size, e.prep_seq);
+ // Should be able to detect non-existent entry
+ found = wp_db->GetCommitEntry((c.prep_seq + 1) % size, &dont_care, &e);
+ ASSERT_FALSE(found);
+
+ // Reject an invalid exchange
+ CommitEntry e2 = {c.prep_seq + size, c.commit_seq + size};
+ CommitEntry64b e2_64b(e2, wp_db->FORMAT);
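+ // ExchangeCommitEntry is expected to behave like a compare-and-swap: the
+ // exchange below should fail because the slot currently holds c, not e2.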
+ bool exchanged = wp_db->ExchangeCommitEntry(e2.prep_seq % size, e2_64b, e);
+ ASSERT_FALSE(exchanged);
+ // check whether it did actually reject that
+ found = wp_db->GetCommitEntry(e2.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(c, e);
+
+ // Accept a valid exchange
+ CommitEntry64b c_64b(c, wp_db->FORMAT);
+ CommitEntry e3 = {c.prep_seq + size, c.commit_seq + size + 1};
+ exchanged = wp_db->ExchangeCommitEntry(c.prep_seq % size, c_64b, e3);
+ ASSERT_TRUE(exchanged);
+ // check that it actually accepted it
+ found = wp_db->GetCommitEntry(c.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(e3, e);
+
+ // Rewrite an entry
+ CommitEntry e4 = {e3.prep_seq + size, e3.commit_seq + size + 1};
+ evicted = wp_db->AddCommitEntry(e4.prep_seq % size, e4, &e);
+ ASSERT_TRUE(evicted);
+ ASSERT_EQ(e3, e);
+ found = wp_db->GetCommitEntry(e4.prep_seq % size, &dont_care, &e);
+ ASSERT_TRUE(found);
+ ASSERT_EQ(e4, e);
+}
+
+TEST_P(WritePreparedTransactionTest, MaybeUpdateOldCommitMap) {
+ // If prepare <= snapshot < commit we should keep the entry around since its
+ // nonexistence could be interpreted as committed in the snapshot while it is
+ // not true. We keep such entries around by adding them to the
+ // old_commit_map_.
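+ // For example, with p=10, c=15, s=20 nothing needs to be kept, while with
+ // p=10, c=25, s=20 the snapshot at 20 was taken after the prepare but
+ // before the commit, so without an old_commit_map_ entry the data prepared
+ // at 10 could wrongly appear committed to that snapshot.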
+ uint64_t p /*prepare*/, c /*commit*/, s /*snapshot*/, ns /*next_snapshot*/;
+ p = 10l, c = 15l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ // If we do not expect the old commit map to be updated, try also with a next
+ // snapshot that is expected to update the old commit map. This would test
+ // that MaybeUpdateOldCommitMap would not prevent us from checking the next
+ // snapshot that must be checked.
+ p = 10l, c = 15l, s = 20l, ns = 11l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 10l, c = 20l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 10l, c = 20l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 20l, c = 20l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 20l, c = 20l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+
+ p = 10l, c = 25l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, true);
+
+ p = 20l, c = 25l, s = 20l, ns = 21l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, true);
+
+ p = 21l, c = 25l, s = 20l, ns = 22l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+ p = 21l, c = 25l, s = 20l, ns = 19l;
+ MaybeUpdateOldCommitMapTestWithNext(p, c, s, ns, false);
+}
+
+// Trigger the condition where some old memtables are skipped when doing
+// TransactionUtil::CheckKey(), and make sure the result is still correct.
+TEST_P(WritePreparedTransactionTest, CheckKeySkipOldMemtable) {
+ const int kAttemptHistoryMemtable = 0;
+ const int kAttemptImmMemTable = 1;
+ for (int attempt = kAttemptHistoryMemtable; attempt <= kAttemptImmMemTable;
+ attempt++) {
+ options.max_write_buffer_number_to_maintain = 3;
+ ReOpen();
+
+ WriteOptions write_options;
+ ReadOptions read_options;
+ TransactionOptions txn_options;
+ txn_options.set_snapshot = true;
+ string value;
+ Status s;
+
+ ASSERT_OK(db->Put(write_options, Slice("foo"), Slice("bar")));
+ ASSERT_OK(db->Put(write_options, Slice("foo2"), Slice("bar")));
+
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn != nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn2 != nullptr);
+ ASSERT_OK(txn2->SetName("txn2"));
+
+ // This transaction is created to cause potential conflict.
+ Transaction* txn_x = db->BeginTransaction(write_options);
+ ASSERT_OK(txn_x->SetName("txn_x"));
+ ASSERT_OK(txn_x->Put(Slice("foo"), Slice("bar3")));
+ ASSERT_OK(txn_x->Prepare());
+
+ // Create snapshots after the prepare, but there should still
+ // be a conflict when trying to read "foo".
+
+ if (attempt == kAttemptImmMemTable) {
+ // For the second attempt, hold flush from beginning. The memtable
+ // will be switched to immutable after calling TEST_SwitchMemtable()
+ // while CheckKey() is called.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTransactionTest.CheckKeySkipOldMemtable",
+ "FlushJob::Start"}});
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ }
+
+ // Force a memtable flush. The old memtable should still be kept around
+ FlushOptions flush_ops;
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_OK(db->Flush(flush_ops));
+ } else {
+ assert(attempt == kAttemptImmMemTable);
+ DBImpl* db_impl = static_cast<DBImpl*>(db->GetRootDB());
+ db_impl->TEST_SwitchMemtable();
+ }
+ uint64_t num_imm_mems;
+ ASSERT_TRUE(db->GetIntProperty(DB::Properties::kNumImmutableMemTable,
+ &num_imm_mems));
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(0, num_imm_mems);
+ } else {
+ assert(attempt == kAttemptImmMemTable);
+ ASSERT_EQ(1, num_imm_mems);
+ }
+
+ // Put something in active memtable
+ ASSERT_OK(db->Put(write_options, Slice("foo3"), Slice("bar")));
+
+ // Create txn3 after flushing, but this transaction also needs to
+ // check all memtables because they contain uncommitted data.
+ Transaction* txn3 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn3 != nullptr);
+ ASSERT_OK(txn3->SetName("txn3"));
+
+ // Commit the pending write
+ ASSERT_OK(txn_x->Commit());
+
+ // Commit txn, txn2 and txn3. txn and txn3 will conflict but txn2 will
+ // pass. In all cases, both memtables are queried.
+ SetPerfLevel(PerfLevel::kEnableCount);
+ get_perf_context()->Reset();
+ ASSERT_TRUE(txn3->GetForUpdate(read_options, "foo", &value).IsBusy());
+ // We should have checked two memtables, active and either immutable
+ // or history memtable, depending on the test case.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+
+ get_perf_context()->Reset();
+ ASSERT_TRUE(txn->GetForUpdate(read_options, "foo", &value).IsBusy());
+ // We should have checked two memtables, active and either immutable
+ // or history memtable, depending on the test case.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+
+ get_perf_context()->Reset();
+ ASSERT_OK(txn2->GetForUpdate(read_options, "foo2", &value));
+ ASSERT_EQ(value, "bar");
+ // We should have checked two memtables, and since there is no
+ // conflict, another Get() will be made and fetch the data from
+ // DB. If it is in immutable memtable, two extra memtable reads
+ // will be issued. If it is not (in history), only one will
+ // be made, which is to the active memtable.
+ if (attempt == kAttemptHistoryMemtable) {
+ ASSERT_EQ(3, get_perf_context()->get_from_memtable_count);
+ } else {
+ assert(attempt == kAttemptImmMemTable);
+ ASSERT_EQ(4, get_perf_context()->get_from_memtable_count);
+ }
+
+ Transaction* txn4 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_TRUE(txn4 != nullptr);
+ ASSERT_OK(txn4->SetName("txn4"));
+ get_perf_context()->Reset();
+ ASSERT_OK(txn4->GetForUpdate(read_options, "foo", &value));
+ if (attempt == kAttemptHistoryMemtable) {
+ // Active memtable will be checked in snapshot validation and when
+ // getting the value.
+ ASSERT_EQ(2, get_perf_context()->get_from_memtable_count);
+ } else {
+ // Only the active memtable will be checked in snapshot validation but
+ // both the active and immutable memtables will be queried when
+ // getting the value.
+ assert(attempt == kAttemptImmMemTable);
+ ASSERT_EQ(3, get_perf_context()->get_from_memtable_count);
+ }
+
+ ASSERT_OK(txn2->Commit());
+ ASSERT_OK(txn4->Commit());
+
+ TEST_SYNC_POINT("WritePreparedTransactionTest.CheckKeySkipOldMemtable");
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+ SetPerfLevel(PerfLevel::kDisable);
+
+ delete txn;
+ delete txn2;
+ delete txn3;
+ delete txn4;
+ delete txn_x;
+ }
+}
+
+// Reproduce the bug with two snapshots with the same sequence number and test
+// that the release of the first snapshot will not affect the reads by the
+// other snapshot
+TEST_P(WritePreparedTransactionTest, DoubleSnapshot) {
+ TransactionOptions txn_options;
+ Status s;
+
+ // Insert initial value
+ ASSERT_OK(db->Put(WriteOptions(), "key", "value1"));
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ Transaction* txn =
+ wp_db->BeginTransaction(WriteOptions(), txn_options, nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key", "value2"));
+ ASSERT_OK(txn->Prepare());
+ // Three snapshots with the same seq number
+ const Snapshot* snapshot0 = wp_db->GetSnapshot();
+ const Snapshot* snapshot1 = wp_db->GetSnapshot();
+ const Snapshot* snapshot2 = wp_db->GetSnapshot();
+ ASSERT_OK(txn->Commit());
+ SequenceNumber cache_size = wp_db->COMMIT_CACHE_SIZE;
+ SequenceNumber overlap_seq = txn->GetId() + cache_size;
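+ // overlap_seq lands on the same commit cache slot as the txn's seq (the
+ // cache is presumably indexed by seq % COMMIT_CACHE_SIZE), so committing it
+ // below evicts that slot and advances the max evicted seq.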
+ delete txn;
+
+ // 4th snapshot with a larger seq
+ const Snapshot* snapshot3 = wp_db->GetSnapshot();
+ // Cause an eviction to advance max evicted seq number
+ // This also fetches the 4 snapshots from db since their seq is lower than the
+ // new max
+ wp_db->AddCommitted(overlap_seq, overlap_seq);
+
+ ReadOptions ropt;
+ // It should see the value before commit
+ ropt.snapshot = snapshot2;
+ PinnableSlice pinnable_val;
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ wp_db->ReleaseSnapshot(snapshot1);
+
+ // It should still see the value before commit
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ // Cause an eviction to advance max evicted seq number and trigger updating
+ // the snapshot list
+ overlap_seq += cache_size;
+ wp_db->AddCommitted(overlap_seq, overlap_seq);
+
+ // It should still see the value before commit
+ s = wp_db->Get(ropt, wp_db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_OK(s);
+ ASSERT_TRUE(pinnable_val == "value1");
+ pinnable_val.Reset();
+
+ wp_db->ReleaseSnapshot(snapshot0);
+ wp_db->ReleaseSnapshot(snapshot2);
+ wp_db->ReleaseSnapshot(snapshot3);
+}
+
+size_t UniqueCnt(std::vector<SequenceNumber> vec) {
+ std::set<SequenceNumber> aset;
+ for (auto i : vec) {
+ aset.insert(i);
+ }
+ return aset.size();
+}
+// Test that the entries in old_commit_map_ get garbage collected properly
+TEST_P(WritePreparedTransactionTest, OldCommitMapGC) {
+ const size_t snapshot_cache_bits = 0;
+ const size_t commit_cache_bits = 0;
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+
+ SequenceNumber seq = 0;
+ // Take the first snapshot that overlaps with two txns
+ auto prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ auto prep_seq2 = ++seq;
+ wp_db->AddPrepared(prep_seq2);
+ auto snap_seq1 = seq;
+ wp_db->TakeSnapshot(snap_seq1);
+ auto commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+ auto commit_seq2 = ++seq;
+ wp_db->AddCommitted(prep_seq2, commit_seq2);
+ wp_db->RemovePrepared(prep_seq2);
+ // Take the 2nd and 3rd snapshot that overlap with the same txn
+ prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ auto snap_seq2 = seq;
+ wp_db->TakeSnapshot(snap_seq2);
+ seq++;
+ auto snap_seq3 = seq;
+ wp_db->TakeSnapshot(snap_seq3);
+ seq++;
+ commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+ // Make sure max_evicted_seq_ will be larger than 2nd snapshot by evicting the
+ // only item in the commit_cache_ via another commit.
+ prep_seq = ++seq;
+ wp_db->AddPrepared(prep_seq);
+ commit_seq = ++seq;
+ wp_db->AddCommitted(prep_seq, commit_seq);
+ wp_db->RemovePrepared(prep_seq);
+
+ // Verify that the evicted commit entries for all snapshots are in the
+ // old_commit_map_
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(3, wp_db->old_commit_map_.size());
+ ASSERT_EQ(2, UniqueCnt(wp_db->old_commit_map_[snap_seq1]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq2]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 2nd snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq2);
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(2, wp_db->old_commit_map_.size());
+ ASSERT_EQ(2, UniqueCnt(wp_db->old_commit_map_[snap_seq1]));
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 1st snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq1);
+ {
+ ASSERT_FALSE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(1, wp_db->old_commit_map_.size());
+ ASSERT_EQ(1, UniqueCnt(wp_db->old_commit_map_[snap_seq3]));
+ }
+
+ // Verify that the 3rd snapshot is cleaned up after the release
+ wp_db->ReleaseSnapshotInternal(snap_seq3);
+ {
+ ASSERT_TRUE(wp_db->old_commit_map_empty_.load());
+ ReadLock rl(&wp_db->old_commit_map_mutex_);
+ ASSERT_EQ(0, wp_db->old_commit_map_.size());
+ }
+}
+
+TEST_P(WritePreparedTransactionTest, CheckAgainstSnapshots) {
+ std::vector<SequenceNumber> snapshots = {100l, 200l, 300l, 400l, 500l,
+ 600l, 700l, 800l, 900l};
+ const size_t snapshot_cache_bits = 2;
+ const uint64_t cache_size = 1ul << snapshot_cache_bits;
+ // Safety check to express the intended size in the test. Can be adjusted if
+ // the snapshot list changes.
+ assert((1ul << snapshot_cache_bits) * 2 + 1 == snapshots.size());
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ SequenceNumber version = 1000l;
+ ASSERT_EQ(0, wp_db->snapshots_total_);
+ wp_db->UpdateSnapshots(snapshots, version);
+ ASSERT_EQ(snapshots.size(), wp_db->snapshots_total_);
+ // Seq numbers are chosen so that we have two of them between each two
+ // snapshots. If the diff of two consecutive seqs is more than 5, there is a
+ // snapshot between them.
+ std::vector<SequenceNumber> seqs = {50l, 55l, 150l, 155l, 250l, 255l, 350l,
+ 355l, 450l, 455l, 550l, 555l, 650l, 655l,
+ 750l, 755l, 850l, 855l, 950l, 955l};
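+ // For example, 50 -> 55 has a diff of 5, so no snapshot lies between them,
+ // while 55 -> 150 has a diff above 5 and skips over snapshot 100.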
+ assert(seqs.size() > 1);
+ for (size_t i = 0; i < seqs.size() - 1; i++) {
+ wp_db->old_commit_map_empty_ = true; // reset
+ CommitEntry commit_entry = {seqs[i], seqs[i + 1]};
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ // Expect an update if there is a snapshot in between the prepare and commit
+ bool expect_update = commit_entry.commit_seq - commit_entry.prep_seq > 5 &&
+ commit_entry.commit_seq >= snapshots.front() &&
+ commit_entry.prep_seq <= snapshots.back();
+ ASSERT_EQ(expect_update, !wp_db->old_commit_map_empty_);
+ }
+
+ // Test that the search will include multiple snapshots from the snapshot cache
+ {
+ // exclude first and last item in the cache
+ CommitEntry commit_entry = {snapshots.front() + 1,
+ snapshots[cache_size - 1] - 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), cache_size - 2);
+ }
+
+ // Test that the search will include multiple snapshots from the old snapshots
+ {
+ // include two in the middle
+ CommitEntry commit_entry = {snapshots[cache_size] + 1,
+ snapshots[cache_size + 2] + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), 2);
+ }
+
+ // Test that search will include both snapshot cache and old snapshots
+ // Case 1: includes all in snapshot cache
+ {
+ CommitEntry commit_entry = {snapshots.front() - 1, snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size());
+ }
+
+ // Case 2: includes all snapshot cache entries except the smallest
+ {
+ CommitEntry commit_entry = {snapshots.front() + 1, snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size() - 1);
+ }
+
+ // Case 3: includes only the largest entry of the snapshot cache
+ {
+ CommitEntry commit_entry = {snapshots[cache_size - 1] - 1,
+ snapshots.back() + 1};
+ wp_db->old_commit_map_empty_ = true; // reset
+ wp_db->old_commit_map_.clear();
+ wp_db->CheckAgainstSnapshots(commit_entry);
+ ASSERT_EQ(wp_db->old_commit_map_.size(), snapshots.size() - cache_size + 1);
+ }
+}
+
+// This test is too slow for travis
+#ifndef TRAVIS
+#ifndef ROCKSDB_VALGRIND_RUN
+// Test that CheckAgainstSnapshots will not miss a live snapshot if it is run in
+// parallel with UpdateSnapshots.
+TEST_P(SnapshotConcurrentAccessTest, SnapshotConcurrentAccess) {
+ // We have a sync point in the method under test after checking each snapshot.
+ // If you increase the max number of snapshots in this test, more sync points
+ // in the methods must also be added.
+ const std::vector<SequenceNumber> snapshots = {10l, 20l, 30l, 40l, 50l,
+ 60l, 70l, 80l, 90l, 100l};
+ const size_t snapshot_cache_bits = 2;
+ // Safety check to express the intended size in the test. Can be adjusted if
+ // the snapshot list changes.
+ assert((1ul << snapshot_cache_bits) * 2 + 2 == snapshots.size());
+ SequenceNumber version = 1000l;
+ // Choose the cache size so that the new snapshot list could replace all the
+ // existing items in the cache and also have some overflow.
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ const size_t extra = 2;
+ size_t loop_id = 0;
+ // Add up to `extra` items that do not fit into the cache
+ for (size_t old_size = 1; old_size <= wp_db->SNAPSHOT_CACHE_SIZE + extra;
+ old_size++) {
+ const std::vector<SequenceNumber> old_snapshots(
+ snapshots.begin(), snapshots.begin() + old_size);
+
+ // Each member of the old snapshot list may or may not appear in the new
+ // list. We create a common_snapshots for each combination.
+ size_t new_comb_cnt = size_t(1) << old_size;
+ for (size_t new_comb = 0; new_comb < new_comb_cnt; new_comb++, loop_id++) {
+ if (loop_id % split_cnt_ != split_id_) continue;
+ printf("."); // To signal progress
+ fflush(stdout);
+ std::vector<SequenceNumber> common_snapshots;
+ for (size_t i = 0; i < old_snapshots.size(); i++) {
+ if (IsInCombination(i, new_comb)) {
+ common_snapshots.push_back(old_snapshots[i]);
+ }
+ }
+ // And add some new snapshots to the common list
+ for (size_t added_snapshots = 0;
+ added_snapshots <= snapshots.size() - old_snapshots.size();
+ added_snapshots++) {
+ std::vector<SequenceNumber> new_snapshots = common_snapshots;
+ for (size_t i = 0; i < added_snapshots; i++) {
+ new_snapshots.push_back(snapshots[old_snapshots.size() + i]);
+ }
+ for (auto it = common_snapshots.begin(); it != common_snapshots.end();
+ ++it) {
+ auto snapshot = *it;
+ // Create a commit entry that is around the snapshot and thus should
+ // not be discarded
+ CommitEntry entry = {static_cast<uint64_t>(snapshot - 1),
+ snapshot + 1};
+ // The critical part is when iterating the snapshot cache. Afterwards,
+ // we are operating under the lock
+ size_t a_range =
+ std::min(old_snapshots.size(), wp_db->SNAPSHOT_CACHE_SIZE) + 1;
+ size_t b_range =
+ std::min(new_snapshots.size(), wp_db->SNAPSHOT_CACHE_SIZE) + 1;
+ // Break each thread at two points
+ for (size_t a1 = 1; a1 <= a_range; a1++) {
+ for (size_t a2 = a1 + 1; a2 <= a_range; a2++) {
+ for (size_t b1 = 1; b1 <= b_range; b1++) {
+ for (size_t b2 = b1 + 1; b2 <= b_range; b2++) {
+ SnapshotConcurrentAccessTestInternal(
+ wp_db.get(), old_snapshots, new_snapshots, entry, version,
+ a1, a2, b1, b2);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ printf("\n");
+}
+#endif // ROCKSDB_VALGRIND_RUN
+#endif // TRAVIS
+
+// This test clarifies the contract of the AdvanceMaxEvictedSeq method
+TEST_P(WritePreparedTransactionTest, AdvanceMaxEvictedSeqBasic) {
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+
+ // 1. Set the initial values for max, prepared, and snapshots
+ SequenceNumber zero_max = 0l;
+ // Set the initial list of prepared txns
+ const std::vector<SequenceNumber> initial_prepared = {10, 30, 50, 100,
+ 150, 200, 250};
+ for (auto p : initial_prepared) {
+ wp_db->AddPrepared(p);
+ }
+ // This updates the max value and also sets the old prepared txns
+ SequenceNumber init_max = 100;
+ wp_db->AdvanceMaxEvictedSeq(zero_max, init_max);
+ const std::vector<SequenceNumber> initial_snapshots = {20, 40};
+ wp_db->SetDBSnapshots(initial_snapshots);
+ // This will update the internal cache of snapshots from the DB
+ wp_db->UpdateSnapshots(initial_snapshots, init_max);
+
+ // 2. Invoke AdvanceMaxEvictedSeq
+ const std::vector<SequenceNumber> latest_snapshots = {20, 110, 220, 300};
+ wp_db->SetDBSnapshots(latest_snapshots);
+ SequenceNumber new_max = 200;
+ wp_db->AdvanceMaxEvictedSeq(init_max, new_max);
+
+ // 3. Verify that the state matches with AdvanceMaxEvictedSeq contract
+ // a. max should be updated to new_max
+ ASSERT_EQ(wp_db->max_evicted_seq_, new_max);
+ // b. delayed prepared should contain every txn <= max and prepared should
+ // only contain txns > max
+ auto it = initial_prepared.begin();
+ for (; it != initial_prepared.end() && *it <= new_max; ++it) {
+ ASSERT_EQ(1, wp_db->delayed_prepared_.erase(*it));
+ }
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ for (; it != initial_prepared.end() && !wp_db->prepared_txns_.empty();
+ ++it, wp_db->prepared_txns_.pop()) {
+ ASSERT_EQ(*it, wp_db->prepared_txns_.top());
+ }
+ ASSERT_TRUE(it == initial_prepared.end());
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ // c. snapshots should contain everything below new_max
+ auto sit = latest_snapshots.begin();
+ for (size_t i = 0; sit != latest_snapshots.end() && *sit <= new_max &&
+ i < wp_db->snapshots_total_;
+ sit++, i++) {
+ ASSERT_TRUE(i < wp_db->snapshots_total_);
+ // This test is small scale and the list of snapshots is assumed to be
+ // within the cache size limit. This is just a safety check to double check
+ // that assumption.
+ ASSERT_TRUE(i < wp_db->SNAPSHOT_CACHE_SIZE);
+ ASSERT_EQ(*sit, wp_db->snapshot_cache_[i]);
+ }
+}
+
+// A new snapshot should always be larger than max_evicted_seq_
+// Otherwise the snapshot does not go through AdvanceMaxEvictedSeq
+TEST_P(WritePreparedTransactionTest, NewSnapshotLargerThanMax) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ Transaction* txn0 = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value")));
+ ASSERT_OK(txn0->Commit());
+ const SequenceNumber seq = txn0->GetId(); // is also prepare seq
+ delete txn0;
+ std::vector<Transaction*> txns;
+ // Inc seq without committing anything
+ for (int i = 0; i < 10; i++) {
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn->SetName("xid" + std::to_string(i)));
+ ASSERT_OK(txn->Put(Slice("key" + std::to_string(i)), Slice("value")));
+ ASSERT_OK(txn->Prepare());
+ txns.push_back(txn);
+ }
+
+ // The new commit is seq + 10
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ auto snap = wp_db->GetSnapshot();
+ const SequenceNumber last_seq = snap->GetSequenceNumber();
+ wp_db->ReleaseSnapshot(snap);
+ ASSERT_LT(seq, last_seq);
+ // Otherwise our test is not effective
+ ASSERT_LT(last_seq - seq, wp_db->INC_STEP_FOR_MAX_EVICTED);
+
+ // Evict seq out of commit cache
+ const SequenceNumber overwrite_seq = seq + wp_db->COMMIT_CACHE_SIZE;
+ // Check that the next write could make max go beyond last
+ auto last_max = wp_db->max_evicted_seq_.load();
+ wp_db->AddCommitted(overwrite_seq, overwrite_seq);
+ // Check that eviction has advanced the max
+ ASSERT_LT(last_max, wp_db->max_evicted_seq_.load());
+ // Check that the new max has not advanced the last seq
+ ASSERT_LT(wp_db->max_evicted_seq_.load(), last_seq);
+ for (auto txn : txns) {
+ txn->Rollback();
+ delete txn;
+ }
+}
+
+// A new snapshot should always be larger than max_evicted_seq_
+// In very rare cases max could be below the last published seq. Test that
+// taking a snapshot will wait for max to catch up.
+TEST_P(WritePreparedTransactionTest, MaxCatchupWithNewSnapshot) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ const int writes = 50;
+ const int batch_cnt = 4;
+ ROCKSDB_NAMESPACE::port::Thread t1([&]() {
+ for (int i = 0; i < writes; i++) {
+ WriteBatch batch;
+ // The duplicate keys cause 4 commit entries, each evicting an entry that
+ // is not published yet, thus causing max_evicted_seq_ to go higher than the
+ // last published seq.
+ for (int b = 0; b < batch_cnt; b++) {
+ batch.Put("foo", "foo");
+ }
+ db->Write(woptions, &batch);
+ }
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread t2([&]() {
+ while (wp_db->max_evicted_seq_ == 0) { // wait for insert thread
+ std::this_thread::yield();
+ }
+ for (int i = 0; i < 10; i++) {
+ SequenceNumber max_lower_bound = wp_db->max_evicted_seq_;
+ auto snap = db->GetSnapshot();
+ if (snap->GetSequenceNumber() != 0) {
+ // The value of max_evicted_seq_ when the snapshot was taken is unknown. We
+ // thus compare with the lower bound instead as an approximation.
+ ASSERT_LT(max_lower_bound, snap->GetSequenceNumber());
+ } // seq 0 is ok to be less than max since nothing is visible to it
+ db->ReleaseSnapshot(snap);
+ }
+ });
+
+ t1.join();
+ t2.join();
+
+ // Make sure that the test has worked and seq number has advanced as we
+ // thought
+ auto snap = db->GetSnapshot();
+ ASSERT_GT(snap->GetSequenceNumber(), batch_cnt * writes - 1);
+ db->ReleaseSnapshot(snap);
+}
+
+// Test that reads without snapshots would not hit an undefined state
+TEST_P(WritePreparedTransactionTest, MaxCatchupWithUnbackedSnapshot) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ const int writes = 50;
+ ROCKSDB_NAMESPACE::port::Thread t1([&]() {
+ for (int i = 0; i < writes; i++) {
+ WriteBatch batch;
+ batch.Put("key", "foo");
+ db->Write(woptions, &batch);
+ }
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread t2([&]() {
+ while (wp_db->max_evicted_seq_ == 0) { // wait for insert thread
+ std::this_thread::yield();
+ }
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ TransactionOptions txn_options;
+ for (int i = 0; i < 10; i++) {
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ pinnable_val.Reset();
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ s = txn->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ pinnable_val.Reset();
+ std::vector<std::string> values;
+ auto s_vec =
+ txn->MultiGet(ropt, {db->DefaultColumnFamily()}, {"key"}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ Slice key("key");
+ txn->MultiGet(ropt, db->DefaultColumnFamily(), 1, &key, &pinnable_val, &s,
+ true);
+ ASSERT_TRUE(s.ok() || s.IsTryAgain());
+ delete txn;
+ }
+ });
+
+ t1.join();
+ t2.join();
+
+ // Make sure that the test has worked and seq number has advanced as we
+ // thought
+ auto snap = db->GetSnapshot();
+ ASSERT_GT(snap->GetSequenceNumber(), writes - 1);
+ db->ReleaseSnapshot(snap);
+}
+
+// Check that old_commit_map_ cleanup works correctly if the snapshot equals
+// max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, CleanupSnapshotEqualToMax) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // only 1 entry => frequent eviction
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WriteOptions woptions;
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Insert something to increase seq
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ auto snap = db->GetSnapshot();
+ auto snap_seq = snap->GetSequenceNumber();
+ // Another insert should trigger eviction + load snapshot from db
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // This is the scenario that we check against
+ ASSERT_EQ(snap_seq, wp_db->max_evicted_seq_);
+ // old_commit_map_ now has some data that needs gc
+ ASSERT_EQ(1, wp_db->snapshots_total_);
+ ASSERT_EQ(1, wp_db->old_commit_map_.size());
+
+ db->ReleaseSnapshot(snap);
+
+ // Another insert should trigger eviction + load snapshot from db
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+
+ // the snapshot and related metadata must be properly garbage collected
+ ASSERT_EQ(0, wp_db->snapshots_total_);
+ ASSERT_TRUE(wp_db->snapshots_all_.empty());
+ ASSERT_EQ(0, wp_db->old_commit_map_.size());
+}
+
+TEST_P(WritePreparedTransactionTest, AdvanceSeqByOne) {
+ auto snap = db->GetSnapshot();
+ auto seq1 = snap->GetSequenceNumber();
+ db->ReleaseSnapshot(snap);
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ wp_db->AdvanceSeqByOne();
+
+ snap = db->GetSnapshot();
+ auto seq2 = snap->GetSequenceNumber();
+ db->ReleaseSnapshot(snap);
+
+ ASSERT_LT(seq1, seq2);
+}
+
+// Test that the txn Initialize calls the overridden functions
+TEST_P(WritePreparedTransactionTest, TxnInitialize) {
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ ASSERT_OK(db->Put(write_options, "key", "value"));
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value1")));
+ ASSERT_OK(txn0->Prepare());
+
+ // SetSnapshot is overridden to update min_uncommitted_
+ txn_options.set_snapshot = true;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ auto snap = txn1->GetSnapshot();
+ auto snap_impl = reinterpret_cast<const SnapshotImpl*>(snap);
+ // If ::Initialize calls the overridden SetSnapshot, min_uncommitted_ must be
+ // updated
+ ASSERT_GT(snap_impl->min_uncommitted_, kMinUnCommittedSeq);
+
+ txn0->Rollback();
+ txn1->Rollback();
+ delete txn0;
+ delete txn1;
+}
+
+// This tests that transactions with duplicate keys perform correctly after
+// max_evicted_seq_ advances past their prepared sequence numbers. This will
+// not be the case if, for example, the txn does not add the prepared seq for
+// the second sub-batch to the PreparedHeap structure.
+TEST_P(WritePreparedTransactionTest, AdvanceMaxEvictedSeqWithDuplicates) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // disable commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn0->SetName("xid"));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value1")));
+ ASSERT_OK(txn0->Put(Slice("key"), Slice("value2")));
+ ASSERT_OK(txn0->Prepare());
+
+ ASSERT_OK(db->Put(write_options, "key2", "value"));
+ // Will cause max advance due to disabled commit cache
+ ASSERT_OK(db->Put(write_options, "key3", "value"));
+
+ auto s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ delete txn0;
+
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ wp_db->db_impl_->FlushWAL(true);
+ wp_db->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ s = db->Get(ropt, db->DefaultColumnFamily(), "key", &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+
+ txn0 = db->GetTransactionByName("xid");
+ ASSERT_OK(txn0->Rollback());
+ delete txn0;
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+// Stress SmallestUnCommittedSeq, which reads from both prepared_txns_ and
+// delayed_prepared_, when it is run concurrently with advancing
+// max_evicted_seq_, which moves prepared txns from prepared_txns_ to
+// delayed_prepared_.
+TEST_P(WritePreparedTransactionTest, SmallestUnCommittedSeq) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // disable commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ std::vector<Transaction*> txns, committed_txns;
+
+ const int cnt = 100;
+ for (int i = 0; i < cnt; i++) {
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn->SetName("xid" + ToString(i)));
+ auto key = "key1" + ToString(i);
+ auto value = "value1" + ToString(i);
+ ASSERT_OK(txn->Put(Slice(key), Slice(value)));
+ ASSERT_OK(txn->Prepare());
+ txns.push_back(txn);
+ }
+
+ port::Mutex mutex;
+ Random rnd(1103);
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ for (int i = 0; i < cnt; i++) {
+ uint32_t index = rnd.Uniform(cnt - i);
+ Transaction* txn;
+ {
+ MutexLock l(&mutex);
+ txn = txns[index];
+ txns.erase(txns.begin() + index);
+ }
+ // Since commit cache is practically disabled, commit results in immediate
+ // advance in max_evicted_seq_ and subsequently moving some prepared txns
+ // to delayed_prepared_.
+ txn->Commit();
+ committed_txns.push_back(txn);
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ while (1) {
+ MutexLock l(&mutex);
+ if (txns.empty()) {
+ break;
+ }
+ auto min_uncommitted = wp_db->SmallestUnCommittedSeq();
+ ASSERT_LE(min_uncommitted, (*txns.begin())->GetId());
+ }
+ });
+
+ commit_thread.join();
+ read_thread.join();
+ for (auto txn : committed_txns) {
+ delete txn;
+ }
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+TEST_P(SeqAdvanceConcurrentTest, SeqAdvanceConcurrent) {
+ // Given the sequential run of txns, with this timeout we should never see a
+ // deadlock or a timeout unless we have a key conflict, which should be
+ // almost impossible.
+ txn_db_options.transaction_lock_timeout = 1000;
+ txn_db_options.default_lock_timeout = 1000;
+ ReOpen();
+ FlushOptions fopt;
+
+ // Number of different txn types we use in this test
+ const size_t type_cnt = 5;
+ // The size of the first write group
+ // TODO(myabandeh): This should be increased for pre-release tests
+ const size_t first_group_size = 2;
+ // Total number of txns we run in each test
+ // TODO(myabandeh): This should be increased for pre-release tests
+ const size_t txn_cnt = first_group_size + 1;
+
+ size_t base[txn_cnt + 1] = {
+ 1,
+ };
+ for (size_t bi = 1; bi <= txn_cnt; bi++) {
+ base[bi] = base[bi - 1] * type_cnt;
+ }
+ const size_t max_n = static_cast<size_t>(std::pow(type_cnt, txn_cnt));
+ printf("Number of cases being tested is %" ROCKSDB_PRIszt "\n", max_n);
+ for (size_t n = 0; n < max_n; n++, ReOpen()) {
+ if (n % split_cnt_ != split_id_) continue;
+ if (n % 1000 == 0) {
+ printf("Tested %" ROCKSDB_PRIszt " cases so far\n", n);
+ }
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ auto seq = db_impl->TEST_GetLastVisibleSequence();
+ with_empty_commits = 0;
+ exp_seq = seq;
+ // This is increased before writing the batch for commit
+ commit_writes = 0;
+ // This is increased before txn starts linking if it expects to do a commit
+ // eventually
+ expected_commits = 0;
+ std::vector<port::Thread> threads;
+
+ linked = 0;
+ std::atomic<bool> batch_formed(false);
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::EnterAsBatchGroupLeader:End",
+ [&](void* /*arg*/) { batch_formed = true; });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+ "WriteThread::JoinBatchGroup:Wait", [&](void* /*arg*/) {
+ linked++;
+ if (linked == 1) {
+ // Wait until the others are linked too.
+ while (linked < first_group_size) {
+ }
+ } else if (linked == 1 + first_group_size) {
+ // Make the 2nd batch of the rest of writes plus any followup
+ // commits from the first batch
+ while (linked < txn_cnt + commit_writes) {
+ }
+ }
+ // Then we will have one or more batches consisting of follow-up
+ // commits from the 2nd batch. There is a bit of non-determinism here
+ // but it should be tolerable.
+ });
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ for (size_t bi = 0; bi < txn_cnt; bi++) {
+ // get the bi-th digit of n in the number system with base type_cnt
+ size_t d = (n % base[bi + 1]) / base[bi];
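+ // Illustrative example (not part of the original test): with type_cnt = 5
+ // and txn_cnt = 3, base = {1, 5, 25, 125} and max_n = 125. For n = 27 the
+ // digits are d0 = (27 % 5) / 1 = 2, d1 = (27 % 25) / 5 = 0 and
+ // d2 = (27 % 125) / 25 = 1, so the three threads run txn types 2, 0 and 1.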
+ switch (d) {
+ case 0:
+ threads.emplace_back(txn_t0, bi);
+ break;
+ case 1:
+ threads.emplace_back(txn_t1, bi);
+ break;
+ case 2:
+ threads.emplace_back(txn_t2, bi);
+ break;
+ case 3:
+ threads.emplace_back(txn_t3, bi);
+ break;
+ case 4:
+ threads.emplace_back(txn_t4, bi);
+ break;
+ default:
+ assert(false);
+ }
+ // wait to be linked
+ while (linked.load() <= bi) {
+ }
+ // after a queue of size first_group_size
+ if (bi + 1 == first_group_size) {
+ while (!batch_formed) {
+ }
+ // to make it more deterministic, wait until the commits are linked
+ while (linked.load() <= bi + expected_commits) {
+ }
+ }
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+ if (options.two_write_queues) {
+ // In this case none of the above scheduling tricks to deterministically
+ // form merged batches work because the writes go to separate queues.
+ // This would result in different write groups in each run of the test. We
+ // still keep the test since, although non-deterministic and hard to debug,
+ // it is still useful to have.
+ // TODO(myabandeh): Add a deterministic unit test for two_write_queues
+ }
+
+ // Check if memtable inserts advanced seq number as expected
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_EQ(exp_seq, seq);
+
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+
+ // Check if recovery preserves the last sequence number
+ db_impl->FlushWAL(true);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->TEST_GetLastVisibleSequence();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+
+ // Check if flush preserves the last sequence number
+ db_impl->Flush(fopt);
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+
+ // Check if recovery after flush preserves the last sequence number
+ db_impl->FlushWAL(true);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ seq = db_impl->GetLatestSequenceNumber();
+ ASSERT_LE(exp_seq, seq + with_empty_commits);
+ }
+}
+
+// Run a couple of different txns, some of them left uncommitted. Restart the
+// db at a couple of points to check whether the list of uncommitted txns is
+// recovered properly.
+TEST_P(WritePreparedTransactionTest, BasicRecovery) {
+ options.disable_auto_compactions = true;
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+
+ txn_t0(0);
+
+ TransactionOptions txn_options;
+ WriteOptions write_options;
+ size_t index = 1000;
+ Transaction* txn0 = db->BeginTransaction(write_options, txn_options);
+ auto istr0 = std::to_string(index);
+ auto s = txn0->SetName("xid" + istr0);
+ ASSERT_OK(s);
+ s = txn0->Put(Slice("foo0" + istr0), Slice("bar0" + istr0));
+ ASSERT_OK(s);
+ s = txn0->Prepare();
+ auto prep_seq_0 = txn0->GetId();
+
+ txn_t1(0);
+
+ index++;
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ auto istr1 = std::to_string(index);
+ s = txn1->SetName("xid" + istr1);
+ ASSERT_OK(s);
+ s = txn1->Put(Slice("foo1" + istr1), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn1->Prepare();
+ auto prep_seq_1 = txn1->GetId();
+
+ txn_t2(0);
+
+ ReadOptions ropt;
+ PinnableSlice pinnable_val;
+ // Check the value is not committed before restart
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+
+ delete txn0;
+ delete txn1;
+ wp_db->db_impl_->FlushWAL(true);
+ wp_db->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // After recovery, all the uncommitted txns (0 and 1) should be inserted into
+ // delayed_prepared_
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+ ASSERT_LE(prep_seq_0, wp_db->max_evicted_seq_);
+ ASSERT_LE(prep_seq_1, wp_db->max_evicted_seq_);
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_EQ(2, wp_db->delayed_prepared_.size());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(prep_seq_0) !=
+ wp_db->delayed_prepared_.end());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(prep_seq_1) !=
+ wp_db->delayed_prepared_.end());
+ }
+
+ // Check the value is still not committed after restart
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.IsNotFound());
+ pinnable_val.Reset();
+
+ txn_t3(0);
+
+ // Test that recovered txns will be properly marked as committed for the
+ // next recovery
+ txn1 = db->GetTransactionByName("xid" + istr1);
+ ASSERT_NE(txn1, nullptr);
+ txn1->Commit();
+ delete txn1;
+
+ index++;
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ auto istr2 = std::to_string(index);
+ s = txn2->SetName("xid" + istr2);
+ ASSERT_OK(s);
+ s = txn2->Put(Slice("foo2" + istr2), Slice("bar"));
+ ASSERT_OK(s);
+ s = txn2->Prepare();
+ auto prep_seq_2 = txn2->GetId();
+
+ delete txn2;
+ wp_db->db_impl_->FlushWAL(true);
+ wp_db->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+
+ // 0 and 2 are prepared and 1 is committed
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_EQ(2, wp_db->delayed_prepared_.size());
+ const auto& end = wp_db->delayed_prepared_.end();
+ ASSERT_NE(wp_db->delayed_prepared_.find(prep_seq_0), end);
+ ASSERT_EQ(wp_db->delayed_prepared_.find(prep_seq_1), end);
+ ASSERT_NE(wp_db->delayed_prepared_.find(prep_seq_2), end);
+ }
+ ASSERT_LE(prep_seq_0, wp_db->max_evicted_seq_);
+ ASSERT_LE(prep_seq_2, wp_db->max_evicted_seq_);
+
+ // Commit all the remaining txns
+ txn0 = db->GetTransactionByName("xid" + istr0);
+ ASSERT_NE(txn0, nullptr);
+ txn0->Commit();
+ txn2 = db->GetTransactionByName("xid" + istr2);
+ ASSERT_NE(txn2, nullptr);
+ txn2->Commit();
+
+ // Check the value is committed after commit
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.ok());
+ ASSERT_TRUE(pinnable_val == ("bar0" + istr0));
+ pinnable_val.Reset();
+
+ delete txn0;
+ delete txn2;
+ wp_db->db_impl_->FlushWAL(true);
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_empty_);
+
+ // Check the value is still committed after recovery
+ s = db->Get(ropt, db->DefaultColumnFamily(), "foo0" + istr0, &pinnable_val);
+ ASSERT_TRUE(s.ok());
+ ASSERT_TRUE(pinnable_val == ("bar0" + istr0));
+ pinnable_val.Reset();
+}
+
+// After recovery the commit map is empty while the max is set. The code would
+// go through a different path which requires a separate test. Test that the
+// committed data before the restart is visible to all snapshots.
+TEST_P(WritePreparedTransactionTest, IsInSnapshotEmptyMap) {
+ for (bool end_with_prepare : {false, true}) {
+ ReOpen();
+ WriteOptions woptions;
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ SequenceNumber prepare_seq = kMaxSequenceNumber;
+ if (end_with_prepare) {
+ TransactionOptions txn_options;
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ ASSERT_OK(txn->SetName("xid0"));
+ ASSERT_OK(txn->Prepare());
+ prepare_seq = txn->GetId();
+ delete txn;
+ }
+ dynamic_cast<WritePreparedTxnDB*>(db)->TEST_Crash();
+ auto db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ db_impl->FlushWAL(true);
+ ReOpenNoDelete();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ assert(wp_db != nullptr);
+ ASSERT_GT(wp_db->max_evicted_seq_, 0); // max after recovery
+ // Take a snapshot right after recovery
+ const Snapshot* snap = db->GetSnapshot();
+ auto snap_seq = snap->GetSequenceNumber();
+ ASSERT_GT(snap_seq, 0);
+
+ for (SequenceNumber seq = 0;
+ seq <= wp_db->max_evicted_seq_ && seq != prepare_seq; seq++) {
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq));
+ }
+ if (end_with_prepare) {
+ ASSERT_FALSE(wp_db->IsInSnapshot(prepare_seq, snap_seq));
+ }
+ // trivial check
+ ASSERT_FALSE(wp_db->IsInSnapshot(snap_seq + 1, snap_seq));
+
+ db->ReleaseSnapshot(snap);
+
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // Take a snapshot after some writes
+ snap = db->GetSnapshot();
+ snap_seq = snap->GetSequenceNumber();
+ for (SequenceNumber seq = 0;
+ seq <= wp_db->max_evicted_seq_ && seq != prepare_seq; seq++) {
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq));
+ }
+ if (end_with_prepare) {
+ ASSERT_FALSE(wp_db->IsInSnapshot(prepare_seq, snap_seq));
+ }
+ // trivial check
+ ASSERT_FALSE(wp_db->IsInSnapshot(snap_seq + 1, snap_seq));
+
+ db->ReleaseSnapshot(snap);
+ }
+}
+
+// Shows the contract of IsInSnapshot when called on invalid/released snapshots
+TEST_P(WritePreparedTransactionTest, IsInSnapshotReleased) {
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ WriteOptions woptions;
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // snap seq = 1
+ const Snapshot* snap1 = db->GetSnapshot();
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ ASSERT_OK(db->Put(woptions, "key", "value"));
+ // snap seq = 3
+ const Snapshot* snap2 = db->GetSnapshot();
+ const SequenceNumber seq = 1;
+ // Evict seq out of commit cache
+ size_t overwrite_seq = wp_db->COMMIT_CACHE_SIZE + seq;
+ wp_db->AddCommitted(overwrite_seq, overwrite_seq);
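+ // Illustrative note (an assumption about the commit cache layout, not part
+ // of the original test): the commit cache is a fixed-size array indexed by
+ // seq % COMMIT_CACHE_SIZE, so overwrite_seq maps to the same slot as seq,
+ // e.g. (COMMIT_CACHE_SIZE + 1) % COMMIT_CACHE_SIZE == 1 % COMMIT_CACHE_SIZE.
+ // Committing overwrite_seq therefore reuses that slot, evicting seq's entry
+ // and advancing max_evicted_seq_ past it.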
+ SequenceNumber snap_seq;
+ uint64_t min_uncommitted = kMinUnCommittedSeq;
+ bool released;
+
+ released = false;
+ snap_seq = snap1->GetSequenceNumber();
+ ASSERT_LE(seq, snap_seq);
+ // Valid snapshot lower than max
+ ASSERT_LE(snap_seq, wp_db->max_evicted_seq_);
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_FALSE(released);
+
+ released = false;
+ snap_seq = snap1->GetSequenceNumber();
+ // Invalid snapshot lower than max
+ ASSERT_LE(snap_seq + 1, wp_db->max_evicted_seq_);
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ db->ReleaseSnapshot(snap1);
+
+ released = false;
+ // Released snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ // The release does not take effect until the next max advance
+ ASSERT_FALSE(released);
+
+ released = false;
+ // Invalid snapshot lower than max
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ // This makes the snapshot release reflect in the txn db structures
+ wp_db->AdvanceMaxEvictedSeq(wp_db->max_evicted_seq_,
+ wp_db->max_evicted_seq_ + 1);
+
+ released = false;
+ // Released snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ released = false;
+ // Invalid snapshot lower than max
+ ASSERT_TRUE(
+ wp_db->IsInSnapshot(seq, snap_seq + 1, min_uncommitted, &released));
+ ASSERT_TRUE(released);
+
+ snap_seq = snap2->GetSequenceNumber();
+
+ released = false;
+ // Unreleased snapshot lower than max
+ ASSERT_TRUE(wp_db->IsInSnapshot(seq, snap_seq, min_uncommitted, &released));
+ ASSERT_FALSE(released);
+
+ db->ReleaseSnapshot(snap2);
+}
+
+// Test WritePreparedTxnDB's IsInSnapshot against different ordering of
+// snapshot, max_committed_seq_, prepared, and commit entries.
+TEST_P(WritePreparedTransactionTest, IsInSnapshot) {
+ WriteOptions wo;
+ // Use small commit cache to trigger lots of eviction and fast advance of
+ // max_evicted_seq_
+ const size_t commit_cache_bits = 3;
+ // Same for snapshot cache size
+ const size_t snapshot_cache_bits = 2;
+
+ // Take some preliminary snapshots first. This is to stress the data
+ // structure that holds the old snapshots, as it is designed to be efficient
+ // when only a few snapshots are below the max_evicted_seq_.
+ for (int max_snapshots = 1; max_snapshots < 20; max_snapshots++) {
+ // Leave some gap between the preliminary snapshots and the final snapshot
+ // that we check. This should also test for different overlapping scenarios
+ // between the last snapshot and the commits.
+ for (int max_gap = 1; max_gap < 10; max_gap++) {
+ // Since we do not actually write to the db, we mock the seq as it would be
+ // increased by the db. The only exception is that we need the db seq to
+ // advance for our snapshots, for which we apply a dummy put each time we
+ // increase our mock of seq.
+ uint64_t seq = 0;
+ // At each step we prepare a txn and then we commit it in the next step.
+ // This emulates the consecutive transactions that write to the same key
+ uint64_t cur_txn = 0;
+ // Number of snapshots taken so far
+ int num_snapshots = 0;
+ // Number of gaps applied so far
+ int gap_cnt = 0;
+ // The final snapshot that we will inspect
+ uint64_t snapshot = 0;
+ bool found_committed = false;
+ // To stress the data structure that maintains prepared txns, at each cycle
+ // we add a new prepared txn. These are not meant to be committed for
+ // snapshot inspection.
+ std::set<uint64_t> prepared;
+ // We keep the list of txns committed before we take the last snapshot.
+ // These should be the only seq numbers that will be found in the snapshot
+ std::set<uint64_t> committed_before;
+ // The set of commit seq numbers to be excluded from IsInSnapshot queries
+ std::set<uint64_t> commit_seqs;
+ DBImpl* mock_db = new DBImpl(options, dbname);
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ std::unique_ptr<WritePreparedTxnDBMock> wp_db(
+ new WritePreparedTxnDBMock(mock_db, txn_db_options));
+ // We continue until max advances a bit beyond the snapshot.
+ while (!snapshot || wp_db->max_evicted_seq_ < snapshot + 100) {
+ // do prepare for a transaction
+ seq++;
+ wp_db->AddPrepared(seq);
+ prepared.insert(seq);
+
+ // If cur_txn is not started, do prepare for it.
+ if (!cur_txn) {
+ seq++;
+ cur_txn = seq;
+ wp_db->AddPrepared(cur_txn);
+ } else { // else commit it
+ seq++;
+ wp_db->AddCommitted(cur_txn, seq);
+ wp_db->RemovePrepared(cur_txn);
+ commit_seqs.insert(seq);
+ if (!snapshot) {
+ committed_before.insert(cur_txn);
+ }
+ cur_txn = 0;
+ }
+
+ if (num_snapshots < max_snapshots - 1) {
+ // Take preliminary snapshots
+ wp_db->TakeSnapshot(seq);
+ num_snapshots++;
+ } else if (gap_cnt < max_gap) {
+ // Wait for some gap before taking the final snapshot
+ gap_cnt++;
+ } else if (!snapshot) {
+ // Take the final snapshot if it is not already taken
+ snapshot = seq;
+ wp_db->TakeSnapshot(snapshot);
+ num_snapshots++;
+ }
+
+ // If the snapshot is taken, verify seq numbers visible to it. We redo
+ // it at each cycle to test that the system is still sound when
+ // max_evicted_seq_ advances.
+ if (snapshot) {
+ for (uint64_t s = 1;
+ s <= seq && commit_seqs.find(s) == commit_seqs.end(); s++) {
+ bool was_committed =
+ (committed_before.find(s) != committed_before.end());
+ bool is_in_snapshot = wp_db->IsInSnapshot(s, snapshot);
+ if (was_committed != is_in_snapshot) {
+ printf("max_snapshots %d max_gap %d seq %" PRIu64 " max %" PRIu64
+ " snapshot %" PRIu64
+ " gap_cnt %d num_snapshots %d s %" PRIu64 "\n",
+ max_snapshots, max_gap, seq,
+ wp_db->max_evicted_seq_.load(), snapshot, gap_cnt,
+ num_snapshots, s);
+ }
+ ASSERT_EQ(was_committed, is_in_snapshot);
+ found_committed = found_committed || is_in_snapshot;
+ }
+ }
+ }
+ // Safety check to make sure the test actually ran
+ ASSERT_TRUE(found_committed);
+ // As an extra check, verify that the prepared set becomes properly empty
+ // after the txns are committed.
+ if (cur_txn) {
+ wp_db->AddCommitted(cur_txn, seq);
+ wp_db->RemovePrepared(cur_txn);
+ }
+ for (auto p : prepared) {
+ wp_db->AddCommitted(p, seq);
+ wp_db->RemovePrepared(p);
+ }
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ }
+ }
+}
+
+void ASSERT_SAME(ReadOptions roptions, TransactionDB* db, Status exp_s,
+ PinnableSlice& exp_v, Slice key) {
+ Status s;
+ PinnableSlice v;
+ s = db->Get(roptions, db->DefaultColumnFamily(), key, &v);
+ ASSERT_TRUE(exp_s == s);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(exp_v == v);
+ }
+
+ // Try with MultiGet API too
+ std::vector<std::string> values;
+ auto s_vec =
+ db->MultiGet(roptions, {db->DefaultColumnFamily()}, {key}, &values);
+ ASSERT_EQ(1, values.size());
+ ASSERT_EQ(1, s_vec.size());
+ s = s_vec[0];
+ ASSERT_TRUE(exp_s == s);
+ ASSERT_TRUE(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ ASSERT_TRUE(exp_v == values[0]);
+ }
+}
+
+void ASSERT_SAME(TransactionDB* db, Status exp_s, PinnableSlice& exp_v,
+ Slice key) {
+ ASSERT_SAME(ReadOptions(), db, exp_s, exp_v, key);
+}
+
+TEST_P(WritePreparedTransactionTest, Rollback) {
+ ReadOptions roptions;
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ const size_t num_keys = 4;
+ const size_t num_values = 5;
+ for (size_t ikey = 1; ikey <= num_keys; ikey++) {
+ for (size_t ivalue = 0; ivalue < num_values; ivalue++) {
+ for (bool crash : {false, true}) {
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ std::string key_str = "key" + ToString(ikey);
+ switch (ivalue) {
+ case 0:
+ break;
+ case 1:
+ ASSERT_OK(db->Put(woptions, key_str, "initvalue1"));
+ break;
+ case 2:
+ ASSERT_OK(db->Merge(woptions, key_str, "initvalue2"));
+ break;
+ case 3:
+ ASSERT_OK(db->Delete(woptions, key_str));
+ break;
+ case 4:
+ ASSERT_OK(db->SingleDelete(woptions, key_str));
+ break;
+ default:
+ assert(0);
+ }
+
+ PinnableSlice v1;
+ auto s1 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key1"), &v1);
+ PinnableSlice v2;
+ auto s2 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key2"), &v2);
+ PinnableSlice v3;
+ auto s3 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key3"), &v3);
+ PinnableSlice v4;
+ auto s4 =
+ db->Get(roptions, db->DefaultColumnFamily(), Slice("key4"), &v4);
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ auto s = txn->SetName("xid0");
+ ASSERT_OK(s);
+ s = txn->Put(Slice("key1"), Slice("value1"));
+ ASSERT_OK(s);
+ s = txn->Merge(Slice("key2"), Slice("value2"));
+ ASSERT_OK(s);
+ s = txn->Delete(Slice("key3"));
+ ASSERT_OK(s);
+ s = txn->SingleDelete(Slice("key4"));
+ ASSERT_OK(s);
+ s = txn->Prepare();
+ ASSERT_OK(s);
+
+ {
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_FALSE(wp_db->prepared_txns_.empty());
+ ASSERT_EQ(txn->GetId(), wp_db->prepared_txns_.top());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+
+ if (crash) {
+ delete txn;
+ auto db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ db_impl->FlushWAL(true);
+ dynamic_cast<WritePreparedTxnDB*>(db)->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+ wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ txn = db->GetTransactionByName("xid0");
+ ASSERT_FALSE(wp_db->delayed_prepared_empty_);
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_FALSE(wp_db->delayed_prepared_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_.find(txn->GetId()) !=
+ wp_db->delayed_prepared_.end());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+
+ s = txn->Rollback();
+ ASSERT_OK(s);
+
+ {
+ ASSERT_TRUE(wp_db->delayed_prepared_empty_);
+ ReadLock rl(&wp_db->prepared_mutex_);
+ ASSERT_TRUE(wp_db->prepared_txns_.empty());
+ ASSERT_TRUE(wp_db->delayed_prepared_.empty());
+ }
+
+ ASSERT_SAME(db, s1, v1, "key1");
+ ASSERT_SAME(db, s2, v2, "key2");
+ ASSERT_SAME(db, s3, v3, "key3");
+ ASSERT_SAME(db, s4, v4, "key4");
+ delete txn;
+ }
+ }
+ }
+}
+
+TEST_P(WritePreparedTransactionTest, DisableGCDuringRecovery) {
+ // Use large buffer to avoid memtable flush after 1024 insertions
+ options.write_buffer_size = 1024 * 1024;
+ ReOpen();
+ std::vector<KeyVersion> versions;
+ uint64_t seq = 0;
+ for (uint64_t i = 1; i <= 1024; i++) {
+ std::string v = "bar" + ToString(i);
+ ASSERT_OK(db->Put(WriteOptions(), "foo", v));
+ VerifyKeys({{"foo", v}});
+ seq++; // one for the key/value
+ KeyVersion kv = {"foo", v, seq, kTypeValue};
+ if (options.two_write_queues) {
+ seq++; // one for the commit
+ }
+ versions.emplace_back(kv);
+ }
+ std::reverse(std::begin(versions), std::end(versions));
+ VerifyInternalKeys(versions);
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ db_impl->FlushWAL(true);
+ // Use small buffer to ensure memtable flush during recovery
+ options.write_buffer_size = 1024;
+ ReOpenNoDelete();
+ VerifyInternalKeys(versions);
+}
+
+TEST_P(WritePreparedTransactionTest, SequenceNumberZero) {
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "bar"));
+ VerifyKeys({{"foo", "bar"}});
+ const Snapshot* snapshot = db->GetSnapshot();
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // Compaction will output keys with sequence number 0, if they are visible to
+ // the earliest snapshot. Make sure IsInSnapshot() reports sequence number 0
+ // as visible to any snapshot.
+ VerifyKeys({{"foo", "bar"}});
+ VerifyKeys({{"foo", "bar"}}, snapshot);
+ VerifyInternalKeys({{"foo", "bar", 0, kTypeValue}});
+ db->ReleaseSnapshot(snapshot);
+}
+
+// Compaction should not remove a key if it is not committed, and should
+// proceed with older versions of the key as if the new version doesn't exist.
+TEST_P(WritePreparedTransactionTest, CompactionShouldKeepUncommittedKeys) {
+ options.disable_auto_compactions = true;
+ ReOpen();
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ // Snapshots to prevent keys from being evicted.
+ std::vector<const Snapshot*> snapshots;
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+
+ auto add_key = [&](std::function<Status()> func) {
+ ASSERT_OK(func());
+ expected_seq++;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ snapshots.push_back(db->GetSnapshot());
+ };
+
+ // Each key here represent a standalone test case.
+ add_key([&]() { return db->Put(WriteOptions(), "key1", "value1_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key2", "value2_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key3", "value3_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key4", "value4_1"); });
+ add_key([&]() { return db->Merge(WriteOptions(), "key5", "value5_1"); });
+ add_key([&]() { return db->Merge(WriteOptions(), "key5", "value5_2"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key6", "value6_1"); });
+ add_key([&]() { return db->Put(WriteOptions(), "key7", "value7_1"); });
+ ASSERT_OK(db->Flush(FlushOptions()));
+ add_key([&]() { return db->Delete(WriteOptions(), "key6"); });
+ add_key([&]() { return db->SingleDelete(WriteOptions(), "key7"); });
+
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1_2"));
+ ASSERT_OK(transaction->Delete("key2"));
+ ASSERT_OK(transaction->SingleDelete("key3"));
+ ASSERT_OK(transaction->Merge("key4", "value4_2"));
+ ASSERT_OK(transaction->Merge("key5", "value5_3"));
+ ASSERT_OK(transaction->Put("key6", "value6_2"));
+ ASSERT_OK(transaction->Put("key7", "value7_2"));
+ // Prepare but not commit.
+ ASSERT_OK(transaction->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ for (auto* s : snapshots) {
+ db->ReleaseSnapshot(s);
+ }
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({
+ {"key1", "value1_1"},
+ {"key2", "value2_1"},
+ {"key3", "value3_1"},
+ {"key4", "value4_1"},
+ {"key5", "value5_1,value5_2"},
+ {"key6", "NOT_FOUND"},
+ {"key7", "NOT_FOUND"},
+ });
+ VerifyInternalKeys({
+ {"key1", "value1_2", expected_seq, kTypeValue},
+ {"key1", "value1_1", 0, kTypeValue},
+ {"key2", "", expected_seq, kTypeDeletion},
+ {"key2", "value2_1", 0, kTypeValue},
+ {"key3", "", expected_seq, kTypeSingleDeletion},
+ {"key3", "value3_1", 0, kTypeValue},
+ {"key4", "value4_2", expected_seq, kTypeMerge},
+ {"key4", "value4_1", 0, kTypeValue},
+ {"key5", "value5_3", expected_seq, kTypeMerge},
+ {"key5", "value5_1,value5_2", 0, kTypeValue},
+ {"key6", "value6_2", expected_seq, kTypeValue},
+ {"key7", "value7_2", expected_seq, kTypeValue},
+ });
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({
+ {"key1", "value1_2"},
+ {"key2", "NOT_FOUND"},
+ {"key3", "NOT_FOUND"},
+ {"key4", "value4_1,value4_2"},
+ {"key5", "value5_1,value5_2,value5_3"},
+ {"key6", "value6_2"},
+ {"key7", "value7_2"},
+ });
+ delete transaction;
+}
+
+// Compaction should keep keys visible to a snapshot based on commit sequence,
+// not just prepare sequence.
+TEST_P(WritePreparedTransactionTest, CompactionShouldKeepSnapshotVisibleKeys) {
+ options.disable_auto_compactions = true;
+ ReOpen();
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+ auto* txn1 = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn1->SetName("txn1"));
+ ASSERT_OK(txn1->Put("key1", "value1_1"));
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ ASSERT_OK(txn1->Commit());
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ ASSERT_EQ(++expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ delete txn1;
+ // Take a snapshot to prevent keys from being evicted before compaction.
+ const Snapshot* snapshot1 = db->GetSnapshot();
+ auto* txn2 = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn2->SetName("txn2"));
+ ASSERT_OK(txn2->Put("key2", "value2_1"));
+ ASSERT_OK(txn2->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ // txn1 commits before snapshot2 and is visible to it.
+ // txn2 commits after snapshot2 and is not visible.
+ const Snapshot* snapshot2 = db->GetSnapshot();
+ ASSERT_OK(txn2->Commit());
+ ASSERT_EQ(++expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ delete txn2;
+ // Take a snapshot to prevent keys from being evicted before compaction.
+ const Snapshot* snapshot3 = db->GetSnapshot();
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1_2"));
+ expected_seq++; // 1 for write
+ SequenceNumber seq1 = expected_seq;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2_2"));
+ expected_seq++; // 1 for write
+ SequenceNumber seq2 = expected_seq;
+ if (options.two_write_queues) {
+ expected_seq++; // 1 for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ db->ReleaseSnapshot(snapshot1);
+ db->ReleaseSnapshot(snapshot3);
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({{"key1", "value1_2"}, {"key2", "value2_2"}});
+ VerifyKeys({{"key1", "value1_1"}, {"key2", "NOT_FOUND"}}, snapshot2);
+ VerifyInternalKeys({
+ {"key1", "value1_2", seq1, kTypeValue},
+ // "value1_1" is visible to snapshot2. Also keys at bottom level visible
+ // to earliest snapshot will output with seq = 0.
+ {"key1", "value1_1", 0, kTypeValue},
+ {"key2", "value2_2", seq2, kTypeValue},
+ });
+ db->ReleaseSnapshot(snapshot2);
+}
+
+TEST_P(WritePreparedTransactionTest, SmallestUncommittedOptimization) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // disable commit cache
+ for (bool has_recent_prepare : {true, false}) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Delete("key1"));
+ ASSERT_OK(transaction->Prepare());
+ // snapshot1 should get min_uncommitted from prepared_txns_ heap.
+ auto snapshot1 = db->GetSnapshot();
+ ASSERT_EQ(transaction->GetId(),
+ ((SnapshotImpl*)snapshot1)->min_uncommitted_);
+ // Add a commit to advance max_evicted_seq and move the prepared transaction
+ // into delayed_prepared_ set.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ Transaction* txn2 = nullptr;
+ if (has_recent_prepare) {
+ txn2 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn2->SetName("txn2"));
+ ASSERT_OK(txn2->Put("key3", "value3"));
+ ASSERT_OK(txn2->Prepare());
+ }
+ // snapshot2 should get min_uncommitted from delayed_prepared_ set.
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_EQ(transaction->GetId(),
+ ((SnapshotImpl*)snapshot2)->min_uncommitted_);
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ if (has_recent_prepare) {
+ ASSERT_OK(txn2->Commit());
+ delete txn2;
+ }
+ VerifyKeys({{"key1", "NOT_FOUND"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1"}}, snapshot2);
+ db->ReleaseSnapshot(snapshot1);
+ db->ReleaseSnapshot(snapshot2);
+ }
+}
+
+// Insert two values, v1 and v2, for a key. Between prepare and commit of v2
+// take two snapshots, s1 and s2. Release s1 during compaction.
+// Test to make sure compaction doesn't get confused and think s1 can see both
+// values, and thus compact out the older value by mistake.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1_1"));
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1_2"));
+ ASSERT_OK(transaction->Prepare());
+ auto snapshot1 = db->GetSnapshot();
+ // Increment sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ VerifyKeys({{"key1", "value1_2"}});
+ VerifyKeys({{"key1", "value1_1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1_1"}}, snapshot2);
+ // Add a flush to keep compaction from falling back to a trivial move.
+
+ auto callback = [&](void*) {
+ // Release snapshot1 after CompactionIterator init.
+ // CompactionIterator needs to figure out that the earliest snapshot
+ // that can see key1:value1_2 is kMaxSequenceNumber, not
+ // snapshot1 or snapshot2.
+ db->ReleaseSnapshot(snapshot1);
+ // Add some keys to advance max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "value3"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "value4"));
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ VerifyKeys({{"key1", "value1_2"}});
+ VerifyKeys({{"key1", "value1_1"}}, snapshot2);
+ db->ReleaseSnapshot(snapshot2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Insert two values, v1 and v2, for a key. Take two snapshots, s1 and s2,
+// after committing v2. Release s1 during compaction, right after compaction
+// processes v2 and before it processes v1. Test to make sure compaction
+// doesn't get confused and believe v1 and v2 are visible to different
+// snapshots (v1 by s2, v2 by s1) and refuse to compact out v1.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction2) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value2"));
+ SequenceNumber v2_seq = db->GetLatestSequenceNumber();
+ auto* s1 = db->GetSnapshot();
+ // Advance sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "dummy"));
+ auto* s2 = db->GetSnapshot();
+
+ int count_value = 0;
+ auto callback = [&](void* arg) {
+ auto* ikey = reinterpret_cast<ParsedInternalKey*>(arg);
+ if (ikey->user_key == "key1") {
+ count_value++;
+ if (count_value == 2) {
+ // Processing v1.
+ db->ReleaseSnapshot(s1);
+ // Add some keys to advance max_evicted_seq and update
+ // old_commit_map.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "dummy"));
+ }
+ }
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:ProcessKV",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // value1 should be compacted out.
+ VerifyInternalKeys({{"key1", "value2", v2_seq, kTypeValue}});
+
+ // cleanup
+ db->ReleaseSnapshot(s2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Insert two values, v1 and v2, for a key. Insert another dummy key
+// so as to evict the commit cache entry for v2, while v1 is still in the
+// commit cache. Take two snapshots, s1 and s2. Release s1 during compaction.
+// Since the commit cache entry for v2 is evicted, and old_commit_map doesn't
+// have s1 (it is released),
+// TODO(myabandeh): how can we be sure that v2's commit info is evicted
+// (and not v1's)? Instead of putting a dummy, we can directly call
+// AddCommitted(v2_seq + cache_size, ...) to evict v2's entry from commit cache.
+TEST_P(WritePreparedTransactionTest, ReleaseSnapshotDuringCompaction3) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 1; // commit cache size = 2
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ // Add a dummy key to evict v2's commit cache entry, but keep v1's.
+ // It also advances max_evicted_seq and can trigger old_commit_map cleanup.
+ auto add_dummy = [&]() {
+ auto* txn_dummy =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn_dummy->SetName("txn_dummy"));
+ ASSERT_OK(txn_dummy->Put("dummy", "dummy"));
+ ASSERT_OK(txn_dummy->Prepare());
+ ASSERT_OK(txn_dummy->Commit());
+ delete txn_dummy;
+ };
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key1", "value2"));
+ ASSERT_OK(txn->Prepare());
+ // TODO(myabandeh): replace it with GetId()?
+ auto v2_seq = db->GetLatestSequenceNumber();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ auto* s1 = db->GetSnapshot();
+ // Dummy key to advance sequence number.
+ add_dummy();
+ auto* s2 = db->GetSnapshot();
+
+ auto callback = [&](void*) {
+ db->ReleaseSnapshot(s1);
+ // Add some dummy entries to trigger s1 being cleaned up from old_commit_map.
+ add_dummy();
+ add_dummy();
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // value1 should be compacted out.
+ VerifyInternalKeys({{"key1", "value2", v2_seq, kTypeValue}});
+
+ db->ReleaseSnapshot(s2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+TEST_P(WritePreparedTransactionTest, ReleaseEarliestSnapshotDuringCompaction) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 0; // minimum commit cache
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* transaction =
+ db->BeginTransaction(WriteOptions(), TransactionOptions(), nullptr);
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Delete("key1"));
+ ASSERT_OK(transaction->Prepare());
+ SequenceNumber del_seq = db->GetLatestSequenceNumber();
+ auto snapshot1 = db->GetSnapshot();
+ // Increment sequence number.
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ auto snapshot2 = db->GetSnapshot();
+ ASSERT_OK(transaction->Commit());
+ delete transaction;
+ VerifyKeys({{"key1", "NOT_FOUND"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot1);
+ VerifyKeys({{"key1", "value1"}}, snapshot2);
+ ASSERT_OK(db->Flush(FlushOptions()));
+
+ auto callback = [&](void* compaction) {
+ // Release snapshot1 after CompactionIterator init.
+ // CompactionIterator needs to double check and find out that snapshot2 is
+ // now the earliest existing snapshot.
+ if (compaction != nullptr) {
+ db->ReleaseSnapshot(snapshot1);
+ // Add some keys to advance max_evicted_seq.
+ ASSERT_OK(db->Put(WriteOptions(), "key3", "value3"));
+ ASSERT_OK(db->Put(WriteOptions(), "key4", "value4"));
+ }
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ // Only verify for key1. Both the put and delete for the key should be kept.
+ // Since the delete tombstone is not visible to snapshot2, we need to keep
+ // at least one version of the key, for write-conflict check.
+ VerifyInternalKeys({{"key1", "", del_seq, kTypeDeletion},
+ {"key1", "value1", 0, kTypeValue}});
+ db->ReleaseSnapshot(snapshot2);
+ SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// A more complex test to verify compaction/flush should keep keys visible
+// to snapshots.
+TEST_P(WritePreparedTransactionTest,
+ CompactionKeepSnapshotVisibleKeysRandomized) {
+ constexpr size_t kNumTransactions = 10;
+ constexpr size_t kNumIterations = 1000;
+
+ std::vector<Transaction*> transactions(kNumTransactions, nullptr);
+ std::vector<size_t> versions(kNumTransactions, 0);
+ std::unordered_map<std::string, std::string> current_data;
+ std::vector<const Snapshot*> snapshots;
+ std::vector<std::unordered_map<std::string, std::string>> snapshot_data;
+
+ Random rnd(1103);
+ options.disable_auto_compactions = true;
+ ReOpen();
+
+ for (size_t i = 0; i < kNumTransactions; i++) {
+ std::string key = "key" + ToString(i);
+ std::string value = "value0";
+ ASSERT_OK(db->Put(WriteOptions(), key, value));
+ current_data[key] = value;
+ }
+ VerifyKeys(current_data);
+
+ for (size_t iter = 0; iter < kNumIterations; iter++) {
+ auto r = rnd.Next() % (kNumTransactions + 1);
+ if (r < kNumTransactions) {
+ std::string key = "key" + ToString(r);
+ if (transactions[r] == nullptr) {
+ std::string value = "value" + ToString(versions[r] + 1);
+ auto* txn = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn->SetName("txn" + ToString(r)));
+ ASSERT_OK(txn->Put(key, value));
+ ASSERT_OK(txn->Prepare());
+ transactions[r] = txn;
+ } else {
+ std::string value = "value" + ToString(++versions[r]);
+ ASSERT_OK(transactions[r]->Commit());
+ delete transactions[r];
+ transactions[r] = nullptr;
+ current_data[key] = value;
+ }
+ } else {
+ auto* snapshot = db->GetSnapshot();
+ VerifyKeys(current_data, snapshot);
+ snapshots.push_back(snapshot);
+ snapshot_data.push_back(current_data);
+ }
+ VerifyKeys(current_data);
+ }
+ // Take a last snapshot to test compaction with uncommitted prepared
+ // transaction.
+ snapshots.push_back(db->GetSnapshot());
+ snapshot_data.push_back(current_data);
+
+ assert(snapshots.size() == snapshot_data.size());
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ ASSERT_OK(db->Flush(FlushOptions()));
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ VerifyKeys(snapshot_data[i], snapshots[i]);
+ }
+ // cleanup
+ for (size_t i = 0; i < kNumTransactions; i++) {
+ if (transactions[i] == nullptr) {
+ continue;
+ }
+ ASSERT_OK(transactions[i]->Commit());
+ delete transactions[i];
+ }
+ for (size_t i = 0; i < snapshots.size(); i++) {
+ db->ReleaseSnapshot(snapshots[i]);
+ }
+}
+
+// Compaction should not apply the optimization of outputting keys with
+// sequence number 0 if the key is not visible to the earliest snapshot,
+// based on its commit sequence number.
+TEST_P(WritePreparedTransactionTest,
+ CompactionShouldKeepSequenceForUncommittedKeys) {
+ options.disable_auto_compactions = true;
+ ReOpen();
+ // Keep track of expected sequence number.
+ SequenceNumber expected_seq = 0;
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("key1", "value1"));
+ ASSERT_OK(transaction->Prepare());
+ ASSERT_EQ(++expected_seq, db->GetLatestSequenceNumber());
+ SequenceNumber seq1 = expected_seq;
+ ASSERT_OK(db->Put(WriteOptions(), "key2", "value2"));
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ expected_seq++; // one for data
+ if (options.two_write_queues) {
+ expected_seq++; // one for commit
+ }
+ ASSERT_EQ(expected_seq, db_impl->TEST_GetLastVisibleSequence());
+ ASSERT_OK(db->Flush(FlushOptions()));
+ // Dummy keys to prevent compaction from trivially moving files and bypassing
+ // the actual compaction logic.
+ ASSERT_OK(db->Put(WriteOptions(), "a", "dummy"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "dummy"));
+ ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ VerifyKeys({
+ {"key1", "NOT_FOUND"},
+ {"key2", "value2"},
+ });
+ VerifyInternalKeys({
+ // "key1" has not been committed. It keeps its sequence number.
+ {"key1", "value1", seq1, kTypeValue},
+ // "key2" is committed and output with seq = 0.
+ {"key2", "value2", 0, kTypeValue},
+ });
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({
+ {"key1", "value1"},
+ {"key2", "value2"},
+ });
+ delete transaction;
+}
+
+TEST_P(WritePreparedTransactionTest, CommitAndSnapshotDuringCompaction) {
+ options.disable_auto_compactions = true;
+ ReOpen();
+
+ const Snapshot* snapshot = nullptr;
+ ASSERT_OK(db->Put(WriteOptions(), "key1", "value1"));
+ auto* txn = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(txn->SetName("txn"));
+ ASSERT_OK(txn->Put("key1", "value2"));
+ ASSERT_OK(txn->Prepare());
+
+ auto callback = [&](void*) {
+ // Snapshot is taken after compaction start. It should be taken into
+ // consideration for whether to compact out value1.
+ snapshot = db->GetSnapshot();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ };
+ SyncPoint::GetInstance()->SetCallBack("CompactionIterator:AfterInit",
+ callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+ ASSERT_OK(db->Flush(FlushOptions()));
+ ASSERT_NE(nullptr, snapshot);
+ VerifyKeys({{"key1", "value2"}});
+ VerifyKeys({{"key1", "value1"}}, snapshot);
+ db->ReleaseSnapshot(snapshot);
+}
+
+TEST_P(WritePreparedTransactionTest, Iterate) {
+ auto verify_state = [](Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+ };
+
+ auto verify_iter = [&](const std::string& expected_val) {
+ // Get iterator from a concurrent transaction and make sure it has the
+ // same view as an iterator from the DB.
+ auto* txn = db->BeginTransaction(WriteOptions());
+
+ for (int i = 0; i < 2; i++) {
+ Iterator* iter = (i == 0)
+ ? db->NewIterator(ReadOptions())
+ : txn->GetIterator(ReadOptions());
+ // Seek
+ iter->Seek("foo");
+ verify_state(iter, "foo", expected_val);
+ // Next
+ iter->Seek("a");
+ verify_state(iter, "a", "va");
+ iter->Next();
+ verify_state(iter, "foo", expected_val);
+ // SeekForPrev
+ iter->SeekForPrev("y");
+ verify_state(iter, "foo", expected_val);
+ // Prev
+ iter->SeekForPrev("z");
+ verify_state(iter, "z", "vz");
+ iter->Prev();
+ verify_state(iter, "foo", expected_val);
+ delete iter;
+ }
+ delete txn;
+ };
+
+ ASSERT_OK(db->Put(WriteOptions(), "foo", "v1"));
+ auto* transaction = db->BeginTransaction(WriteOptions());
+ ASSERT_OK(transaction->SetName("txn"));
+ ASSERT_OK(transaction->Put("foo", "v2"));
+ ASSERT_OK(transaction->Prepare());
+ VerifyKeys({{"foo", "v1"}});
+ // dummy keys
+ ASSERT_OK(db->Put(WriteOptions(), "a", "va"));
+ ASSERT_OK(db->Put(WriteOptions(), "z", "vz"));
+ verify_iter("v1");
+ ASSERT_OK(transaction->Commit());
+ VerifyKeys({{"foo", "v2"}});
+ verify_iter("v2");
+ delete transaction;
+}
+
+TEST_P(WritePreparedTransactionTest, IteratorRefreshNotSupported) {
+ Iterator* iter = db->NewIterator(ReadOptions());
+ ASSERT_TRUE(iter->Refresh().IsNotSupported());
+ delete iter;
+}
+
+// Committing a delayed prepared txn has two non-atomic steps: update the
+// commit cache, then remove the seq from delayed_prepared_. The read in
+// IsInSnapshot also involves two non-atomic steps of checking these two data
+// structures. This test breaks each in the middle to ensure correctness in
+// spite of non-atomic execution.
+// Note: This test is limited to the case where the snapshot is larger than
+// max_evicted_seq_.
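+// One such problematic interleaving (an illustrative sketch, not a literal
+// trace from this test): a reader misses the entry in the commit cache, the
+// committer then adds the commit entry and removes the seq from
+// delayed_prepared_, and only afterwards does the reader check
+// delayed_prepared_ (now empty); the txn must still not appear committed to a
+// snapshot taken before the commit.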
+TEST_P(WritePreparedTransactionTest, NonAtomicCommitOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ for (auto split_read : {true, false}) {
+ std::vector<bool> split_options = {false};
+ if (split_read) {
+ // Also test for break before mutex
+ split_options.push_back(true);
+ }
+ for (auto split_before_mutex : split_options) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ DBImpl* db_impl = reinterpret_cast<DBImpl*>(db->GetRootDB());
+ // Fill up the commit cache
+ std::string init_value("value1");
+ for (int i = 0; i < 10; i++) {
+ db->Put(WriteOptions(), Slice("key1"), Slice(init_value));
+ }
+ // Prepare a transaction but do not commit it
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), Slice("value2")));
+ ASSERT_OK(txn->Prepare());
+ // Commit a bunch of entries to advance max evicted seq and make the
+ // prepared a delayed prepared
+ for (int i = 0; i < 10; i++) {
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ }
+ // The snapshot should not see the delayed prepared entry
+ auto snap = db->GetSnapshot();
+
+ if (split_read) {
+ if (split_before_mutex) {
+ // split before acquiring prepare_mutex_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:pause",
+ "AtomicCommitOfDelayedPrepared:Commit:before"},
+ {"AtomicCommitOfDelayedPrepared:Commit:after",
+ "WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:resume"}});
+ } else {
+ // split right after reading from the commit cache
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:pause",
+ "AtomicCommitOfDelayedPrepared:Commit:before"},
+ {"AtomicCommitOfDelayedPrepared:Commit:after",
+ "WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:resume"}});
+ }
+ } else { // split commit
+ // split right before removing from delayed_prepared_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::RemovePrepared:pause",
+ "AtomicCommitOfDelayedPrepared:Read:before"},
+ {"AtomicCommitOfDelayedPrepared:Read:after",
+ "WritePreparedTxnDB::RemovePrepared:resume"}});
+ }
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Commit:before");
+ ASSERT_OK(txn->Commit());
+ if (split_before_mutex) {
+ // Do a bunch of inserts to evict the commit entry from the cache. This
+ // would prevent the 2nd lookup into the commit cache under prepare_mutex_
+ // from seeing the commit entry.
+ auto seq = db_impl->TEST_GetLastVisibleSequence();
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < seq && tries < 50) {
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ }
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Commit:after");
+ delete txn;
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Read:before");
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should not see the commit of delayed prepared
+ ASSERT_TRUE(value == init_value);
+ TEST_SYNC_POINT("AtomicCommitOfDelayedPrepared:Read:after");
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ } // for split_before_mutex
+ } // for split_read
+}
+
+// When max evicted seq advances past a prepared seq, it involves two updates:
+// i) adding the prepared seq to delayed_prepared_, ii) updating
+// max_evicted_seq_. ::IsInSnapshot also reads these two values in a
+// non-atomic way. This test ensures correctness if the update occurs after
+// ::IsInSnapshot reads delayed_prepared_empty_ and before it reads
+// max_evicted_seq_.
+// Note: this test focuses on a read snapshot larger than max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, NonAtomicUpdateOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Fill up the commit cache
+ std::string init_value("value1");
+ for (int i = 0; i < 10; i++) {
+ db->Put(WriteOptions(), Slice("key1"), Slice(init_value));
+ }
+ // Prepare a transaction but do not commit it
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), Slice("value2")));
+ ASSERT_OK(txn->Prepare());
+ // Create a gap between prepare seq and snapshot seq
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ // The snapshot should not see the delayed prepared entry
+ auto snap = db->GetSnapshot();
+ ASSERT_LT(txn->GetId(), snap->GetSequenceNumber());
+
+ // split right after reading delayed_prepared_empty_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:pause",
+ "AtomicUpdateOfDelayedPrepared:before"},
+ {"AtomicUpdateOfDelayedPrepared:after",
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:resume"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("AtomicUpdateOfDelayedPrepared:before");
+ // Commit a bunch of entries to advance max evicted seq and make the
+ // prepared a delayed prepared
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < txn->GetId() && tries < 50) {
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ // This is the case on which the test focuses
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ TEST_SYNC_POINT("AtomicUpdateOfDelayedPrepared:after");
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should not see the uncommitted value of delayed prepared
+ ASSERT_TRUE(value == init_value);
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// Eviction from the commit cache and the update of max evicted seq are two
+// non-atomic steps. Similarly, the read of max_evicted_seq_ in ::IsInSnapshot
+// and reading from the commit cache are two non-atomic steps. This tests the
+// case where the update occurs after reading max_evicted_seq_ and before
+// reading the commit cache.
+// Note: the test focuses on a snapshot larger than max_evicted_seq_.
+TEST_P(WritePreparedTransactionTest, NonAtomicUpdateOfMaxEvictedSeq) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ const size_t commit_cache_bits = 3; // 8 entries
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ // Fill up the commit cache
+ std::string init_value("value1");
+ std::string last_value("value_final");
+ for (int i = 0; i < 10; i++) {
+ db->Put(WriteOptions(), Slice("key1"), Slice(init_value));
+ }
+ // Do an uncommitted write to prevent min_uncommitted optimization
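+ // (Illustrative note on the rationale, an assumption about the optimization
+ // rather than something this test states: reads of seqs below
+ // min_uncommitted can be declared visible immediately, skipping the
+ // max_evicted_seq_ and commit-cache checks that this test wants to exercise,
+ // so an old uncommitted txn is kept around to keep min_uncommitted low.)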
+ Transaction* txn1 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Put(Slice("key0"), last_value));
+ ASSERT_OK(txn1->Prepare());
+ // Do a write with prepare to get the prepare seq
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key1"), last_value));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ // Create a gap between commit entry and snapshot seq
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ // The snapshot should see the last commit
+ auto snap = db->GetSnapshot();
+ ASSERT_LE(txn->GetId(), snap->GetSequenceNumber());
+
+ // split right after reading max_evicted_seq_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
+ {{"WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:pause",
+ "NonAtomicUpdateOfMaxEvictedSeq:before"},
+ {"NonAtomicUpdateOfMaxEvictedSeq:after",
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:resume"}});
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread commit_thread([&]() {
+ TEST_SYNC_POINT("NonAtomicUpdateOfMaxEvictedSeq:before");
+ // Commit a bunch of entries to advance max evicted seq beyond txn->GetId()
+ size_t tries = 0;
+ while (wp_db->max_evicted_seq_ < txn->GetId() && tries < 50) {
+ db->Put(WriteOptions(), Slice("key3"), Slice("value3"));
+ tries++;
+ };
+ ASSERT_LT(tries, 50);
+ // This is the case on which the test focuses
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ TEST_SYNC_POINT("NonAtomicUpdateOfMaxEvictedSeq:after");
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ roptions.snapshot = snap;
+ PinnableSlice value;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key1", &value);
+ ASSERT_OK(s);
+ // It should see the committed value of the evicted entry
+ ASSERT_TRUE(value == last_value);
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ commit_thread.join();
+ delete txn;
+ txn1->Commit();
+ delete txn1;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
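+
+// A sketch of the arithmetic behind the commit_cache_bits parameter used in
+// these tests (an illustrative helper, not called by the tests): the commit
+// cache has (1 << commit_cache_bits) slots and a sequence number maps to a
+// slot by modulo, which is why a handful of extra commits on a tiny cache are
+// enough to evict older entries and advance max_evicted_seq_.
+size_t CommitCacheSlotSketch(size_t commit_cache_bits, uint64_t seq) {
+ const size_t commit_cache_size = static_cast<size_t>(1) << commit_cache_bits;
+ // e.g. commit_cache_bits = 3 gives 8 slots; seq 19 maps to slot 3.
+ return static_cast<size_t>(seq % commit_cache_size);
+}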
+
+// Test the case where we add a prepared seq after max_evicted_seq_ has
+// already advanced beyond it. The test focuses on a race condition between
+// the AddPrepared and AdvanceMaxEvictedSeq functions.
+TEST_P(WritePreparedTransactionTest, AddPreparedBeforeMax) {
+ if (!options.two_write_queues) {
+ // This test is only for two write queues
+ return;
+ }
+ const size_t snapshot_cache_bits = 7; // same as default
+ // 1 entry to advance max after the 2nd commit
+ const size_t commit_cache_bits = 0;
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ WritePreparedTxnDB* wp_db = dynamic_cast<WritePreparedTxnDB*>(db);
+ std::string some_value("value_some");
+ std::string uncommitted_value("value_uncommitted");
+ // Prepare two uncommitted transactions
+ Transaction* txn1 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn1->SetName("xid1"));
+ ASSERT_OK(txn1->Put(Slice("key1"), some_value));
+ ASSERT_OK(txn1->Prepare());
+ Transaction* txn2 =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn2->SetName("xid2"));
+ ASSERT_OK(txn2->Put(Slice("key2"), some_value));
+ ASSERT_OK(txn2->Prepare());
+ // Start the txn here so the other thread could get its id
+ Transaction* txn = db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key0"), uncommitted_value));
+ port::Mutex txn_mutex_;
+
+ // t1) Insert prepared entry, t2) commit other entries to advance max
+ // evicted seq and finish checking the existing prepared entries, t1)
+ // AddPrepared, t2) update max_evicted_seq_
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"AddPreparedCallback::AddPrepared::begin:pause",
+ "AddPreparedBeforeMax::read_thread:start"},
+ {"AdvanceMaxEvictedSeq::update_max:pause",
+ "AddPreparedCallback::AddPrepared::begin:resume"},
+ {"AddPreparedCallback::AddPrepared::end",
+ "AdvanceMaxEvictedSeq::update_max:resume"},
+ });
+ SyncPoint::GetInstance()->EnableProcessing();
+
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ txn_mutex_.Lock();
+ ASSERT_OK(txn->Prepare());
+ txn_mutex_.Unlock();
+ });
+
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ TEST_SYNC_POINT("AddPreparedBeforeMax::read_thread:start");
+ // Publish seq number with a commit
+ ASSERT_OK(txn1->Commit());
+ // Since the commit cache size is one, the 2nd commit evicts the 1st one
+ // and invokes AdvanceMaxEvictedSeq
+ ASSERT_OK(txn2->Commit());
+
+ ReadOptions roptions;
+ PinnableSlice value;
+ // The snapshot should not see the uncommitted value from write_thread
+ auto snap = db->GetSnapshot();
+ ASSERT_LT(wp_db->max_evicted_seq_, snap->GetSequenceNumber());
+ // This is the scenario that we test for
+ txn_mutex_.Lock();
+ ASSERT_GT(wp_db->max_evicted_seq_, txn->GetId());
+ txn_mutex_.Unlock();
+ roptions.snapshot = snap;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key0", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ db->ReleaseSnapshot(snap);
+ });
+
+ read_thread.join();
+ write_thread.join();
+ delete txn1;
+ delete txn2;
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+}
+
+// When an old prepared entry gets committed, there is a gap between the time
+// that it is published and when it is cleaned up from delayed_prepared_. This
+// test stresses such cases.
+TEST_P(WritePreparedTransactionTest, CommitOfDelayedPrepared) {
+ const size_t snapshot_cache_bits = 7; // same as default
+ for (const size_t commit_cache_bits : {0, 2, 3}) {
+ for (const size_t sub_batch_cnt : {1, 2, 3}) {
+ UpdateTransactionDBOptions(snapshot_cache_bits, commit_cache_bits);
+ ReOpen();
+ std::atomic<const Snapshot*> snap = {nullptr};
+ std::atomic<SequenceNumber> exp_prepare = {0};
+ ROCKSDB_NAMESPACE::port::Thread callback_thread;
+ // Value is synchronized via snap
+ PinnableSlice value;
+ // Take a snapshot after publish and before RemovePrepared:Start
+ auto snap_callback = [&]() {
+ ASSERT_EQ(nullptr, snap.load());
+ snap.store(db->GetSnapshot());
+ ReadOptions roptions;
+ roptions.snapshot = snap.load();
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key2", &value);
+ ASSERT_OK(s);
+ };
+ auto callback = [&](void* param) {
+ SequenceNumber prep_seq = *((SequenceNumber*)param);
+ if (prep_seq == exp_prepare.load()) { // only for write_thread
+ // We need to spawn a thread to avoid deadlock since getting a
+ // snapshot might end up calling AdvanceSeqByOne which needs joining
+ // the write queue.
+ callback_thread = ROCKSDB_NAMESPACE::port::Thread(snap_callback);
+ TEST_SYNC_POINT("callback:end");
+ }
+ };
+ // Wait for the first snapshot to be taken in GetSnapshotInternal. Although
+ // it might be updated before GetSnapshotInternal finishes, this should
+ // cover most of the cases.
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::GetSnapshotInternal:first", "callback:end"},
+ });
+ SyncPoint::GetInstance()->SetCallBack("RemovePrepared:Start", callback);
+ SyncPoint::GetInstance()->EnableProcessing();
+ // Thread to cause frequent evictions
+ ROCKSDB_NAMESPACE::port::Thread eviction_thread([&]() {
+ // Too many txns might cause commit_seq - prepare_seq in another thread
+ // to go beyond DELTA_UPPERBOUND
+ for (int i = 0; i < 25 * (1 << commit_cache_bits); i++) {
+ db->Put(WriteOptions(), Slice("key1"), Slice("value1"));
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ for (int i = 0; i < 25 * (1 << commit_cache_bits); i++) {
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ std::string val_str = "value" + ToString(i);
+ for (size_t b = 0; b < sub_batch_cnt; b++) {
+ ASSERT_OK(txn->Put(Slice("key2"), val_str));
+ }
+ ASSERT_OK(txn->Prepare());
+ // Let an eviction kick in
+ std::this_thread::yield();
+
+ exp_prepare.store(txn->GetId());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ // Wait for the snapshot-taking thread that is triggered by the
+ // RemovePrepared:Start callback
+ callback_thread.join();
+
+ // Read with the snapshot taken before delayed_prepared_ cleanup
+ ReadOptions roptions;
+ roptions.snapshot = snap.load();
+ ASSERT_NE(nullptr, roptions.snapshot);
+ PinnableSlice value2;
+ auto s =
+ db->Get(roptions, db->DefaultColumnFamily(), "key2", &value2);
+ ASSERT_OK(s);
+ // It should see its own write
+ ASSERT_TRUE(val_str == value2);
+ // The value read by snapshot should not change
+ ASSERT_STREQ(value2.ToString().c_str(), value.ToString().c_str());
+
+ db->ReleaseSnapshot(roptions.snapshot);
+ snap.store(nullptr);
+ }
+ });
+ write_thread.join();
+ eviction_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+ }
+}
+
+// Test that updating the commit map will not affect the existing snapshots
+TEST_P(WritePreparedTransactionTest, AtomicCommit) {
+ for (bool skip_prepare : {true, false}) {
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency({
+ {"WritePreparedTxnDB::AddCommitted:start",
+ "AtomicCommit::GetSnapshot:start"},
+ {"AtomicCommit::Get:end",
+ "WritePreparedTxnDB::AddCommitted:start:pause"},
+ {"WritePreparedTxnDB::AddCommitted:end", "AtomicCommit::Get2:start"},
+ {"AtomicCommit::Get2:end",
+ "WritePreparedTxnDB::AddCommitted:end:pause:"},
+ });
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+ ROCKSDB_NAMESPACE::port::Thread write_thread([&]() {
+ if (skip_prepare) {
+ db->Put(WriteOptions(), Slice("key"), Slice("value"));
+ } else {
+ Transaction* txn =
+ db->BeginTransaction(WriteOptions(), TransactionOptions());
+ ASSERT_OK(txn->SetName("xid"));
+ ASSERT_OK(txn->Put(Slice("key"), Slice("value")));
+ ASSERT_OK(txn->Prepare());
+ ASSERT_OK(txn->Commit());
+ delete txn;
+ }
+ });
+ ROCKSDB_NAMESPACE::port::Thread read_thread([&]() {
+ ReadOptions roptions;
+ TEST_SYNC_POINT("AtomicCommit::GetSnapshot:start");
+ roptions.snapshot = db->GetSnapshot();
+ PinnableSlice val;
+ auto s = db->Get(roptions, db->DefaultColumnFamily(), "key", &val);
+ TEST_SYNC_POINT("AtomicCommit::Get:end");
+ TEST_SYNC_POINT("AtomicCommit::Get2:start");
+ ASSERT_SAME(roptions, db, s, val, "key");
+ TEST_SYNC_POINT("AtomicCommit::Get2:end");
+ db->ReleaseSnapshot(roptions.snapshot);
+ });
+ read_thread.join();
+ write_thread.join();
+ ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+ }
+}
+
+// Test that we can change write policy from WriteCommitted to WritePrepared
+// after a clean shutdown (which would empty the WAL)
+TEST_P(WritePreparedTransactionTest, WP_WC_DBBackwardCompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_COMMITTED, WRITE_PREPARED, empty_wal);
+}
+
+// Test that we fail fast if WAL is not emptied between changing the write
+// policy from WriteCommitted to WritePrepared
+TEST_P(WritePreparedTransactionTest, WP_WC_WALBackwardIncompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_COMMITTED, WRITE_PREPARED, !empty_wal);
+}
+
+// Test that we can change write policy from WritePrepared back to WriteCommitted
+// after a clean shutdown (which would empty the WAL)
+TEST_P(WritePreparedTransactionTest, WC_WP_ForwardCompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_PREPARED, WRITE_COMMITTED, empty_wal);
+}
+
+// Test that we fail fast if WAL is not emptied between changing the write
+// policy from WritePrepared to WriteCommitted
+TEST_P(WritePreparedTransactionTest, WC_WP_WALForwardIncompatibility) {
+ bool empty_wal = true;
+ CrossCompatibilityTest(WRITE_PREPARED, WRITE_COMMITTED, !empty_wal);
+}
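+
+// A sketch of the policy-switch workflow that the cross-compatibility tests
+// above exercise (illustrative only; the path "/tmp/wp_policy_sketch" and the
+// key are hypothetical). The WAL must be empty before reopening with a
+// different write policy; flushing before a clean shutdown is one way to get
+// there.
+void WritePolicySwitchSketch() {
+ Options options;
+ options.create_if_missing = true;
+ TransactionDBOptions txn_db_options;
+ txn_db_options.write_policy = WRITE_COMMITTED;
+ TransactionDB* txn_db = nullptr;
+ Status s = TransactionDB::Open(options, txn_db_options,
+ "/tmp/wp_policy_sketch", &txn_db);
+ if (!s.ok()) {
+ return;
+ }
+ s = txn_db->Put(WriteOptions(), "key", "value");
+ if (s.ok()) {
+ // Make the WAL obsolete before the clean shutdown.
+ s = txn_db->Flush(FlushOptions());
+ }
+ delete txn_db;
+ txn_db = nullptr;
+ if (!s.ok()) {
+ return;
+ }
+ // Reopen with the other policy. Had the WAL not been emptied, this open is
+ // expected to fail fast rather than misinterpret the old log entries.
+ txn_db_options.write_policy = WRITE_PREPARED;
+ s = TransactionDB::Open(options, txn_db_options, "/tmp/wp_policy_sketch",
+ &txn_db);
+ if (s.ok()) {
+ delete txn_db;
+ }
+}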
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn.cc b/src/rocksdb/utilities/transactions/write_prepared_txn.cc
new file mode 100644
index 000000000..216d83555
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn.cc
@@ -0,0 +1,473 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn.h"
+
+#include <cinttypes>
+#include <map>
+#include <set>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+struct WriteOptions;
+
+WritePreparedTxn::WritePreparedTxn(WritePreparedTxnDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : PessimisticTransaction(txn_db, write_options, txn_options, false),
+ wpt_db_(txn_db) {
+ // Call Initialize outside the PessimisticTransaction constructor; otherwise it
+ // would skip overridden functions in WritePreparedTxn since they are not
+ // defined yet in the constructor of PessimisticTransaction
+ Initialize(txn_options);
+}
+
+void WritePreparedTxn::Initialize(const TransactionOptions& txn_options) {
+ PessimisticTransaction::Initialize(txn_options);
+ prepare_batch_cnt_ = 0;
+}
+
+void WritePreparedTxn::MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wpt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(wpt_db_, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ write_batch_.MultiGetFromBatchAndDB(db_, options, column_family, num_keys,
+ keys, values, statuses, sorted_input,
+ &callback);
+ if (UNLIKELY(!callback.valid() ||
+ !wpt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ wpt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ for (size_t i = 0; i < num_keys; i++) {
+ statuses[i] = Status::TryAgain();
+ }
+ }
+}
+
+Status WritePreparedTxn::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wpt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(wpt_db_, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ auto res = write_batch_.GetFromBatchAndDB(db_, options, column_family, key,
+ pinnable_val, &callback);
+ if (LIKELY(callback.valid() &&
+ wpt_db_->ValidateSnapshot(callback.max_visible_seq(),
+ backed_by_snapshot))) {
+ return res;
+ } else {
+ wpt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ return Status::TryAgain();
+ }
+}
+
+Iterator* WritePreparedTxn::GetIterator(const ReadOptions& options) {
+ // Make sure to get the iterator from WritePreparedTxnDB, not the root db.
+ Iterator* db_iter = wpt_db_->NewIterator(options);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(db_iter);
+}
+
+Iterator* WritePreparedTxn::GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ // Make sure to get the iterator from WritePreparedTxnDB, not the root db.
+ Iterator* db_iter = wpt_db_->NewIterator(options, column_family);
+ assert(db_iter);
+
+ return write_batch_.NewIteratorWithBase(column_family, db_iter);
+}
+
+Status WritePreparedTxn::PrepareInternal() {
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ const bool WRITE_AFTER_COMMIT = true;
+ const bool kFirstPrepareBatch = true;
+ WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(), name_,
+ !WRITE_AFTER_COMMIT);
+ // For each duplicate key we account for a new sub-batch
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+ // Having AddPrepared in the PreReleaseCallback allows in-order addition of
+ // prepared entries to PreparedHeap and hence enables an optimization. Refer to
+ // SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, prepare_batch_cnt_,
+ db_impl_->immutable_db_options().two_write_queues, kFirstPrepareBatch);
+ const bool DISABLE_MEMTABLE = true;
+ uint64_t seq_used = kMaxSequenceNumber;
+ Status s = db_impl_->WriteImpl(
+ write_options, GetWriteBatch()->GetWriteBatch(),
+ /*callback*/ nullptr, &log_number_, /*log ref*/ 0, !DISABLE_MEMTABLE,
+ &seq_used, prepare_batch_cnt_, &add_prepared_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ auto prepare_seq = seq_used;
+ SetId(prepare_seq);
+ return s;
+}
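+
+// A usage sketch of the duplicate-key accounting above (illustrative only;
+// the transaction name and keys are hypothetical, and the caller is assumed
+// to have opened txn_db with the WRITE_PREPARED policy). Writing the same key
+// twice in one transaction yields one extra sub-batch at Prepare time.
+Status WritePreparedDuplicateKeySketch(TransactionDB* txn_db) {
+ Transaction* txn =
+ txn_db->BeginTransaction(WriteOptions(), TransactionOptions());
+ Status s = txn->SetName("xid_sketch");
+ if (s.ok()) {
+ s = txn->Put("dup_key", "v1");
+ }
+ if (s.ok()) {
+ // Duplicate of the key above: accounted for as an extra sub-batch.
+ s = txn->Put("dup_key", "v2");
+ }
+ if (s.ok()) {
+ // prepare_batch_cnt_ is expected to be 2 here: one sub-batch plus one
+ // extra for the duplicate key.
+ s = txn->Prepare();
+ }
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+ delete txn;
+ return s;
+}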
+
+Status WritePreparedTxn::CommitWithoutPrepareInternal() {
+ // For each duplicate key we account for a new sub-batch
+ const size_t batch_cnt = GetWriteBatch()->SubBatchCnt();
+ return CommitBatchInternal(GetWriteBatch()->GetWriteBatch(), batch_cnt);
+}
+
+Status WritePreparedTxn::CommitBatchInternal(WriteBatch* batch,
+ size_t batch_cnt) {
+ return wpt_db_->WriteInternal(write_options_, batch, batch_cnt, this);
+}
+
+Status WritePreparedTxn::CommitInternal() {
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitInternal prepare_seq: %" PRIu64, GetID());
+ // We take the commit-time batch and append the Commit marker.
+ // The Memtable will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+ const bool empty = working_batch->Count() == 0;
+ WriteBatchInternal::MarkCommit(working_batch, name_);
+
+ const bool for_recovery = use_only_the_last_commit_time_batch_for_recovery_;
+ if (!empty && for_recovery) {
+ // When not writing to memtable, we can still cache the latest write batch.
+ // The cached batch will be written to memtable in WriteRecoverableState
+ // during FlushMemTable
+ WriteBatchInternal::SetAsLastestPersistentState(working_batch);
+ }
+
+ auto prepare_seq = GetId();
+ const bool includes_data = !empty && !for_recovery;
+ assert(prepare_batch_cnt_);
+ size_t commit_batch_cnt = 0;
+ if (UNLIKELY(includes_data)) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Duplicate key overhead");
+ SubBatchCounter counter(*wpt_db_->GetCFComparatorMap());
+ auto s = working_batch->Iterate(&counter);
+ assert(s.ok());
+ commit_batch_cnt = counter.BatchCount();
+ }
+ const bool disable_memtable = !includes_data;
+ const bool do_one_write =
+ !db_impl_->immutable_db_options().two_write_queues || disable_memtable;
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, prepare_seq, prepare_batch_cnt_, commit_batch_cnt);
+ // This is to call AddPrepared on CommitTimeWriteBatch
+ const bool kFirstPrepareBatch = true;
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, commit_batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ uint64_t seq_used = kMaxSequenceNumber;
+ // Since the prepared batch is directly written to memtable, there is already
+ // a connection between the memtable and its WAL, so there is no need to
+ // redundantly reference the log that contains the prepared data.
+ const uint64_t zero_log_number = 0ull;
+ size_t batch_cnt = UNLIKELY(commit_batch_cnt) ? commit_batch_cnt : 1;
+ auto s = db_impl_->WriteImpl(write_options_, working_batch, nullptr, nullptr,
+ zero_log_number, disable_memtable, &seq_used,
+ batch_cnt, pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ const SequenceNumber commit_batch_seq = seq_used;
+ if (LIKELY(do_one_write || !s.ok())) {
+ if (UNLIKELY(!db_impl_->immutable_db_options().two_write_queues &&
+ s.ok())) {
+ // Note: RemovePrepared should be called after the WriteImpl that published
+ // the seq. Otherwise the SmallestUnCommittedSeq optimization breaks.
+ wpt_db_->RemovePrepared(prepare_seq, prepare_batch_cnt_);
+ } // else RemovePrepared is called from within PreReleaseCallback
+ if (UNLIKELY(!do_one_write)) {
+ assert(!s.ok());
+ // Cleanup the prepared entry we added with add_prepared_callback
+ wpt_db_->RemovePrepared(commit_batch_seq, commit_batch_cnt);
+ }
+ return s;
+ } // else do the 2nd write to publish seq
+ // Note: the 2nd write comes with a performance penalty. So if we have too
+ // many commits accompanied by a CommitTimeWriteBatch and yet we cannot
+ // enable the use_only_the_last_commit_time_batch_for_recovery_ optimization,
+ // two_write_queues should be disabled to avoid many additional writes here.
+ const size_t kZeroData = 0;
+ // Update commit map only from the 2nd queue
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map_with_aux_batch(
+ wpt_db_, db_impl_, prepare_seq, prepare_batch_cnt_, kZeroData,
+ commit_batch_seq, commit_batch_cnt);
+ WriteBatch empty_batch;
+ empty_batch.PutLogData(Slice());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ WriteBatchInternal::InsertNoop(&empty_batch);
+ const bool DISABLE_MEMTABLE = true;
+ const size_t ONE_BATCH = 1;
+ const uint64_t NO_REF_LOG = 0;
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_aux_batch);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (UNLIKELY(!db_impl_->immutable_db_options().two_write_queues)) {
+ if (s.ok()) {
+ // Note: RemovePrepared should be called after the WriteImpl that published
+ // the seq. Otherwise the SmallestUnCommittedSeq optimization breaks.
+ wpt_db_->RemovePrepared(prepare_seq, prepare_batch_cnt_);
+ }
+ wpt_db_->RemovePrepared(commit_batch_seq, commit_batch_cnt);
+ } // else RemovePrepared is called from within PreReleaseCallback
+ return s;
+}
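+
+// A configuration sketch for the trade-off discussed above (illustrative
+// only; the path "/tmp/wp_commit_sketch" and the keys are hypothetical). With
+// two_write_queues enabled, a commit that carries data in its
+// CommitTimeWriteBatch needs the 2nd write, unless the batch is cached only
+// for recovery, in which case a single write suffices.
+void WritePreparedCommitTimeBatchSketch() {
+ Options options;
+ options.create_if_missing = true;
+ options.two_write_queues = true; // split writes into prepare/commit queues
+ TransactionDBOptions txn_db_options;
+ txn_db_options.write_policy = WRITE_PREPARED;
+ TransactionDB* txn_db = nullptr;
+ Status s = TransactionDB::Open(options, txn_db_options,
+ "/tmp/wp_commit_sketch", &txn_db);
+ if (!s.ok()) {
+ return;
+ }
+ TransactionOptions txn_options;
+ // Keep only the last commit-time batch for recovery so the commit does not
+ // have to write it to the memtable.
+ txn_options.use_only_the_last_commit_time_batch_for_recovery = true;
+ Transaction* txn = txn_db->BeginTransaction(WriteOptions(), txn_options);
+ s = txn->SetName("xid_sketch");
+ if (s.ok()) {
+ s = txn->Put("key", "value");
+ }
+ if (s.ok()) {
+ // Metadata that should become visible atomically with the commit.
+ s = txn->GetCommitTimeWriteBatch()->Put("commit_marker", "1");
+ }
+ if (s.ok()) {
+ s = txn->Prepare();
+ }
+ if (s.ok()) {
+ s = txn->Commit();
+ }
+ delete txn;
+ delete txn_db;
+}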
+
+Status WritePreparedTxn::RollbackInternal() {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal prepare_seq: %" PRIu64, GetId());
+ WriteBatch rollback_batch;
+ assert(GetId() != kMaxSequenceNumber);
+ assert(GetId() > 0);
+ auto cf_map_shared_ptr = wpt_db_->GetCFHandleMap();
+ auto cf_comp_map_shared_ptr = wpt_db_->GetCFComparatorMap();
+ auto read_at_seq = kMaxSequenceNumber;
+ ReadOptions roptions;
+ // to prevent the callback's seq from being overridden inside DBImpl::Get
+ roptions.snapshot = wpt_db_->GetMaxSnapshot();
+ struct RollbackWriteBatchBuilder : public WriteBatch::Handler {
+ DBImpl* db_;
+ WritePreparedTxnReadCallback callback;
+ WriteBatch* rollback_batch_;
+ std::map<uint32_t, const Comparator*>& comparators_;
+ std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ bool rollback_merge_operands_;
+ ReadOptions roptions_;
+ RollbackWriteBatchBuilder(
+ DBImpl* db, WritePreparedTxnDB* wpt_db, SequenceNumber snap_seq,
+ WriteBatch* dst_batch,
+ std::map<uint32_t, const Comparator*>& comparators,
+ std::map<uint32_t, ColumnFamilyHandle*>& handles,
+ bool rollback_merge_operands, ReadOptions _roptions)
+ : db_(db),
+ callback(wpt_db, snap_seq), // disable min_uncommitted optimization
+ rollback_batch_(dst_batch),
+ comparators_(comparators),
+ handles_(handles),
+ rollback_merge_operands_(rollback_merge_operands),
+ roptions_(_roptions) {}
+
+ Status Rollback(uint32_t cf, const Slice& key) {
+ Status s;
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+ }
+ auto it = cf_keys.insert(key);
+ if (it.second == false) { // second is false if an element already existed
+ return s;
+ }
+
+ PinnableSlice pinnable_val;
+ bool not_used;
+ auto cf_handle = handles_[cf];
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = &callback;
+ s = db_->GetImpl(roptions_, key, get_impl_options);
+ assert(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ s = rollback_batch_->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ // There has been no readable value before the txn. By adding a delete we
+ // make sure that there will be none afterwards either.
+ s = rollback_batch_->Delete(cf_handle, key);
+ assert(s.ok());
+ } else {
+ // Unexpected status. Return it to the user.
+ }
+ return s;
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& /*val*/) override {
+ return Rollback(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ if (rollback_merge_operands_) {
+ return Rollback(cf, key);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ Status MarkNoop(bool) override { return Status::OK(); }
+ Status MarkBeginPrepare(bool) override { return Status::OK(); }
+ Status MarkEndPrepare(const Slice&) override { return Status::OK(); }
+ Status MarkCommit(const Slice&) override { return Status::OK(); }
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ protected:
+ bool WriteAfterCommit() const override { return false; }
+ } rollback_handler(db_impl_, wpt_db_, read_at_seq, &rollback_batch,
+ *cf_comp_map_shared_ptr.get(), *cf_map_shared_ptr.get(),
+ wpt_db_->txn_db_options_.rollback_merge_operands,
+ roptions);
+ auto s = GetWriteBatch()->GetWriteBatch()->Iterate(&rollback_handler);
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+ // The Rollback marker will be used as a batch separator
+ WriteBatchInternal::MarkRollback(&rollback_batch, name_);
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t NO_REF_LOG = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+ const size_t ONE_BATCH = 1;
+ const bool kFirstPrepareBatch = true;
+ // We commit the rolled back prepared batches. Although this is
+ // counter-intuitive, i) it is safe to do so, since the prepared batches are
+ // already canceled out by the rollback batch, and ii) adding the commit entry
+ // to the CommitCache lets us benefit from the existing mechanism in
+ // CommitCache that keeps around an entry that was evicted due to max advance
+ // but still overlaps with a live snapshot, so that the live snapshot properly
+ // skips the entry even if its prepare seq is lower than max_evicted_seq_.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, ONE_BATCH,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, GetId(), prepare_batch_cnt_, ONE_BATCH);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ // Note: the rollback batch does not need AddPrepared since it is written to
+ // DB in one shot. min_uncommitted still works since it requires capturing
+ // data that is written to DB but not yet committed, while
+ // the rollback batch commits with PreReleaseCallback.
+ s = db_impl_->WriteImpl(write_options_, &rollback_batch, nullptr, nullptr,
+ NO_REF_LOG, !DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ assert(!db_impl_->immutable_db_options().two_write_queues);
+ wpt_db_->RemovePrepared(GetId(), prepare_batch_cnt_);
+ return s;
+ } // else do the 2nd write for commit
+ uint64_t rollback_seq = seq_used;
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal 2nd write rollback_seq: %" PRIu64,
+ rollback_seq);
+ // Commit the batch by writing an empty batch to the queue that will release
+ // the commit sequence number to readers.
+ WritePreparedRollbackPreReleaseCallback update_commit_map_with_prepare(
+ wpt_db_, db_impl_, GetId(), rollback_seq, prepare_batch_cnt_);
+ WriteBatch empty_batch;
+ empty_batch.PutLogData(Slice());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ WriteBatchInternal::InsertNoop(&empty_batch);
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_prepare);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal (status=%s) commit: %" PRIu64,
+ s.ToString().c_str(), GetId());
+ // TODO(lth): For WriteUnPrepared that rollback is called frequently,
+ // RemovePrepared could be moved to the callback to reduce lock contention.
+ if (s.ok()) {
+ wpt_db_->RemovePrepared(GetId(), prepare_batch_cnt_);
+ }
+ // Note: RemovePrepared for prepared batch is called from within
+ // PreReleaseCallback
+ wpt_db_->RemovePrepared(rollback_seq, ONE_BATCH);
+
+ return s;
+}
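+
+// A usage sketch of the rollback path above (illustrative only; the
+// transaction name, key, and value are hypothetical, and txn_db is assumed to
+// use the WRITE_PREPARED policy). Rolling back a prepared transaction writes
+// the pre-transaction values (or deletes) back, so the prepared data never
+// becomes visible to readers.
+Status WritePreparedRollbackSketch(TransactionDB* txn_db) {
+ Transaction* txn =
+ txn_db->BeginTransaction(WriteOptions(), TransactionOptions());
+ Status s = txn->SetName("xid_rollback_sketch");
+ if (s.ok()) {
+ s = txn->Put("key", "tentative_value");
+ }
+ if (s.ok()) {
+ s = txn->Prepare();
+ }
+ if (s.ok()) {
+ // Cancels the prepared data; a later Get on "key" returns whatever value
+ // existed before this transaction (or NotFound).
+ s = txn->Rollback();
+ }
+ delete txn;
+ return s;
+}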
+
+Status WritePreparedTxn::ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ assert(snapshot_);
+
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(
+ snapshot_.get())
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = snapshot_->GetSequenceNumber();
+ // tracked_at_seq is either max or the last snapshot with which this key was
+ // tracked, so there is no need to apply IsInSnapshot to this comparison
+ // here, as tracked_at_seq is not a prepare seq.
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previously validated at a sequence number no later
+ // than the current snapshot's sequence number, we already know it has not
+ // been modified.
+ return Status::OK();
+ }
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ WritePreparedTxnReadCallback snap_checker(wpt_db_, snap_seq, min_uncommitted,
+ kBackedByDBSnapshot);
+ return TransactionUtil::CheckKeyForConflicts(db_impl_, cfh, key.ToString(),
+ snap_seq, false /* cache_only */,
+ &snap_checker, min_uncommitted);
+}
+
+void WritePreparedTxn::SetSnapshot() {
+ const bool kForWWConflictCheck = true;
+ SnapshotImpl* snapshot = wpt_db_->GetSnapshotInternal(kForWWConflictCheck);
+ SetSnapshotInternal(snapshot);
+}
+
+Status WritePreparedTxn::RebuildFromWriteBatch(WriteBatch* src_batch) {
+ auto ret = PessimisticTransaction::RebuildFromWriteBatch(src_batch);
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+ return ret;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn.h b/src/rocksdb/utilities/transactions/write_prepared_txn.h
new file mode 100644
index 000000000..30d9bdb99
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn.h
@@ -0,0 +1,119 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+#include <mutex>
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/write_callback.h"
+#include "rocksdb/db.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/snapshot.h"
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+#include "rocksdb/utilities/transaction.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/autovector.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_base.h"
+#include "utilities/transactions/transaction_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WritePreparedTxnDB;
+
+// This impl may also write uncommitted data to the DB and then later tell
+// apart committed data from uncommitted data. Uncommitted data could be
+// written after the Prepare phase in 2PC (WritePreparedTxn) or before that
+// (WriteUnpreparedTxnImpl).
+class WritePreparedTxn : public PessimisticTransaction {
+ public:
+ WritePreparedTxn(WritePreparedTxnDB* db, const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+ // No copying allowed
+ WritePreparedTxn(const WritePreparedTxn&) = delete;
+ void operator=(const WritePreparedTxn&) = delete;
+
+ virtual ~WritePreparedTxn() {}
+
+ // To make WAL commit markers visible, the snapshot will be based on the last
+ // seq in the WAL that is also published, LastPublishedSequence, as opposed to
+ // the last seq in the memtable.
+ using Transaction::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using Transaction::MultiGet;
+ virtual void MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input = false) override;
+
+ // Note: The behavior is undefined in the presence of interleaved writes to
+ // the same transaction.
+ // To make WAL commit markers visible, the snapshot will be
+ // based on the last seq in the WAL that is also published,
+ // LastPublishedSequence, as opposed to the last seq in the memtable.
+ using Transaction::GetIterator;
+ virtual Iterator* GetIterator(const ReadOptions& options) override;
+ virtual Iterator* GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual void SetSnapshot() override;
+
+ protected:
+ void Initialize(const TransactionOptions& txn_options) override;
+ // Override the protected SetId to make it visible to the friend class
+ // WritePreparedTxnDB
+ inline void SetId(uint64_t id) override { Transaction::SetId(id); }
+
+ private:
+ friend class WritePreparedTransactionTest_BasicRecoveryTest_Test;
+ friend class WritePreparedTxnDB;
+ friend class WriteUnpreparedTxnDB;
+ friend class WriteUnpreparedTxn;
+
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+
+ Status CommitBatchInternal(WriteBatch* batch, size_t batch_cnt) override;
+
+ // Since the data is already written to memtables at the Prepare phase, the
+ // commit entails writing only a commit marker in the WAL. The sequence number
+ // of the commit marker is then the commit timestamp of the transaction. To
+ // make WAL commit markers visible, the snapshot will be based on the last seq
+ // in the WAL that is also published, LastPublishedSequence, as opposed to the
+ // last seq in the memtable.
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) override;
+
+ virtual Status RebuildFromWriteBatch(WriteBatch* src_batch) override;
+
+ WritePreparedTxnDB* wpt_db_;
+ // Number of sub-batches in prepare
+ size_t prepare_batch_cnt_ = 0;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc b/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc
new file mode 100644
index 000000000..051fae554
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn_db.cc
@@ -0,0 +1,998 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn_db.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "db/arena_wrapped_db_iter.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "test_util/sync_point.h"
+#include "util/cast_util.h"
+#include "util/mutexlock.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/transaction_db_mutex_impl.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+Status WritePreparedTxnDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ auto dbimpl = static_cast_with_check<DBImpl, DB>(GetRootDB());
+ assert(dbimpl != nullptr);
+ auto rtxns = dbimpl->recovered_transactions();
+ std::map<SequenceNumber, SequenceNumber> ordered_seq_cnt;
+ for (auto rtxn : rtxns) {
+ // There should be only one batch for the WritePrepared policy.
+ assert(rtxn.second->batches_.size() == 1);
+ const auto& seq = rtxn.second->batches_.begin()->first;
+ const auto& batch_info = rtxn.second->batches_.begin()->second;
+ auto cnt = batch_info.batch_cnt_ ? batch_info.batch_cnt_ : 1;
+ ordered_seq_cnt[seq] = cnt;
+ }
+ // AddPrepared must be called in order
+ for (auto seq_cnt : ordered_seq_cnt) {
+ auto seq = seq_cnt.first;
+ auto cnt = seq_cnt.second;
+ for (size_t i = 0; i < cnt; i++) {
+ AddPrepared(seq + i);
+ }
+ }
+ SequenceNumber prev_max = max_evicted_seq_;
+ SequenceNumber last_seq = db_impl_->GetLatestSequenceNumber();
+ AdvanceMaxEvictedSeq(prev_max, last_seq);
+ // Create a gap between max and the next snapshot. This simplifies the logic
+ // in IsInSnapshot by not having to consider the special case of max ==
+ // snapshot after recovery. This is tested in IsInSnapshotEmptyMapTest.
+ if (last_seq) {
+ db_impl_->versions_->SetLastAllocatedSequence(last_seq + 1);
+ db_impl_->versions_->SetLastSequence(last_seq + 1);
+ db_impl_->versions_->SetLastPublishedSequence(last_seq + 1);
+ }
+
+ db_impl_->SetSnapshotChecker(new WritePreparedSnapshotChecker(this));
+ // A callback to commit a single sub-batch
+ class CommitSubBatchPreReleaseCallback : public PreReleaseCallback {
+ public:
+ explicit CommitSubBatchPreReleaseCallback(WritePreparedTxnDB* db)
+ : db_(db) {}
+ Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)), uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ assert(!is_mem_disabled);
+ db_->AddCommitted(commit_seq, commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ };
+ db_impl_->SetRecoverableStatePreReleaseCallback(
+ new CommitSubBatchPreReleaseCallback(this));
+
+ auto s = PessimisticTransactionDB::Initialize(compaction_enabled_cf_indices,
+ handles);
+ return s;
+}
+
+Status WritePreparedTxnDB::VerifyCFOptions(
+ const ColumnFamilyOptions& cf_options) {
+ Status s = PessimisticTransactionDB::VerifyCFOptions(cf_options);
+ if (!s.ok()) {
+ return s;
+ }
+ if (!cf_options.memtable_factory->CanHandleDuplicatedKey()) {
+ return Status::InvalidArgument(
+ "memtable_factory->CanHandleDuplicatedKey() cannot be false with "
+ "WritePrepared transactions");
+ }
+ return Status::OK();
+}
+
+Transaction* WritePreparedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WritePreparedTxn(this, write_options, txn_options);
+ }
+}
+
+Status WritePreparedTxnDB::Write(const WriteOptions& opts,
+ WriteBatch* updates) {
+ if (txn_db_options_.skip_concurrency_control) {
+ // Skip locking the rows
+ const size_t UNKNOWN_BATCH_CNT = 0;
+ WritePreparedTxn* NO_TXN = nullptr;
+ return WriteInternal(opts, updates, UNKNOWN_BATCH_CNT, NO_TXN);
+ } else {
+ return PessimisticTransactionDB::WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WritePreparedTxnDB::Write(
+ const WriteOptions& opts,
+ const TransactionDBWriteOptimizations& optimizations, WriteBatch* updates) {
+ if (optimizations.skip_concurrency_control) {
+ // Skip locking the rows
+ const size_t UNKNOWN_BATCH_CNT = 0;
+ const size_t ONE_BATCH_CNT = 1;
+ const size_t batch_cnt = optimizations.skip_duplicate_key_check
+ ? ONE_BATCH_CNT
+ : UNKNOWN_BATCH_CNT;
+ WritePreparedTxn* NO_TXN = nullptr;
+ return WriteInternal(opts, updates, batch_cnt, NO_TXN);
+ } else {
+ // TODO(myabandeh): Make use of skip_duplicate_key_check hint
+ // Fall back to unoptimized version
+ return PessimisticTransactionDB::WriteWithConcurrencyControl(opts, updates);
+ }
+}
+
+Status WritePreparedTxnDB::WriteInternal(const WriteOptions& write_options_orig,
+ WriteBatch* batch, size_t batch_cnt,
+ WritePreparedTxn* txn) {
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitBatchInternal");
+ if (batch->Count() == 0) {
+ // Otherwise our 1-seq-per-batch logic will break since no seq would be
+ // allocated for this batch.
+ return Status::OK();
+ }
+ if (batch_cnt == 0) { // not provided, then compute it
+ // TODO(myabandeh): add an option to allow user skipping this cost
+ SubBatchCounter counter(*GetCFComparatorMap());
+ auto s = batch->Iterate(&counter);
+ assert(s.ok());
+ batch_cnt = counter.BatchCount();
+ WPRecordTick(TXN_DUPLICATE_KEY_OVERHEAD);
+ ROCKS_LOG_DETAILS(info_log_, "Duplicate key overhead: %" PRIu64 " batches",
+ static_cast<uint64_t>(batch_cnt));
+ }
+ assert(batch_cnt);
+
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ WriteOptions write_options(write_options_orig);
+ // In the absence of Prepare markers, use Noop as a batch separator
+ WriteBatchInternal::InsertNoop(batch);
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t no_log_ref = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+ const size_t ZERO_PREPARES = 0;
+ const bool kSeperatePrepareCommitBatches = true;
+ // Since this is not 2pc, there is no need for AddPrepared but having it in
+ // the PreReleaseCallback enables an optimization. Refer to
+ // SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ this, db_impl_, batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues,
+ !kSeperatePrepareCommitBatches);
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+ this, db_impl_, kMaxSequenceNumber, ZERO_PREPARES, batch_cnt);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ auto s = db_impl_->WriteImpl(write_options, batch, nullptr, nullptr,
+ no_log_ref, !DISABLE_MEMTABLE, &seq_used,
+ batch_cnt, pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ uint64_t prepare_seq = seq_used;
+ if (txn != nullptr) {
+ txn->SetId(prepare_seq);
+ }
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ return s;
+ } // else do the 2nd write for commit
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "CommitBatchInternal 2nd write prepare_seq: %" PRIu64,
+ prepare_seq);
+ // Commit the batch by writing an empty batch to the 2nd queue that will
+ // release the commit sequence number to readers.
+ const size_t ZERO_COMMITS = 0;
+ WritePreparedCommitEntryPreReleaseCallback update_commit_map_with_prepare(
+ this, db_impl_, prepare_seq, batch_cnt, ZERO_COMMITS);
+ WriteBatch empty_batch;
+ write_options.disableWAL = true;
+ write_options.sync = false;
+ const size_t ONE_BATCH = 1; // Just to inc the seq
+ s = db_impl_->WriteImpl(write_options, &empty_batch, nullptr, nullptr,
+ no_log_ref, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_prepare);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ // Note: RemovePrepared is called from within PreReleaseCallback
+ return s;
+}
+
+Status WritePreparedTxnDB::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WritePreparedTxnReadCallback callback(this, snap_seq, min_uncommitted,
+ backed_by_snapshot);
+ bool* dont_care = nullptr;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = value;
+ get_impl_options.value_found = dont_care;
+ get_impl_options.callback = &callback;
+ auto res = db_impl_->GetImpl(options, key, get_impl_options);
+ if (LIKELY(callback.valid() && ValidateSnapshot(callback.max_visible_seq(),
+ backed_by_snapshot))) {
+ return res;
+ } else {
+ WPRecordTick(TXN_GET_TRY_AGAIN);
+ return Status::TryAgain();
+ }
+}
+
+void WritePreparedTxnDB::UpdateCFComparatorMap(
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ auto cf_map = new std::map<uint32_t, const Comparator*>();
+ auto handle_map = new std::map<uint32_t, ColumnFamilyHandle*>();
+ for (auto h : handles) {
+ auto id = h->GetID();
+ const Comparator* comparator = h->GetComparator();
+ (*cf_map)[id] = comparator;
+ if (id != 0) {
+ (*handle_map)[id] = h;
+ } else {
+ // The pointer to the default cf handle in the handles will be deleted.
+ // Use the pointer maintained by the db instead.
+ (*handle_map)[id] = DefaultColumnFamily();
+ }
+ }
+ cf_map_.reset(cf_map);
+ handle_map_.reset(handle_map);
+}
+
+void WritePreparedTxnDB::UpdateCFComparatorMap(ColumnFamilyHandle* h) {
+ auto old_cf_map_ptr = cf_map_.get();
+ assert(old_cf_map_ptr);
+ auto cf_map = new std::map<uint32_t, const Comparator*>(*old_cf_map_ptr);
+ auto old_handle_map_ptr = handle_map_.get();
+ assert(old_handle_map_ptr);
+ auto handle_map =
+ new std::map<uint32_t, ColumnFamilyHandle*>(*old_handle_map_ptr);
+ auto id = h->GetID();
+ const Comparator* comparator = h->GetComparator();
+ (*cf_map)[id] = comparator;
+ (*handle_map)[id] = h;
+ cf_map_.reset(cf_map);
+ handle_map_.reset(handle_map);
+}
+
+
+std::vector<Status> WritePreparedTxnDB::MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ assert(values);
+ size_t num_keys = keys.size();
+ values->resize(num_keys);
+
+ std::vector<Status> stat_list(num_keys);
+ for (size_t i = 0; i < num_keys; ++i) {
+ std::string* value = values ? &(*values)[i] : nullptr;
+ stat_list[i] = this->Get(options, column_family[i], keys[i], value);
+ }
+ return stat_list;
+}
+
+// Struct to hold ownership of snapshot and read callback for iterator cleanup.
+struct WritePreparedTxnDB::IteratorState {
+ IteratorState(WritePreparedTxnDB* txn_db, SequenceNumber sequence,
+ std::shared_ptr<ManagedSnapshot> s,
+ SequenceNumber min_uncommitted)
+ : callback(txn_db, sequence, min_uncommitted, kBackedByDBSnapshot),
+ snapshot(s) {}
+
+ WritePreparedTxnReadCallback callback;
+ std::shared_ptr<ManagedSnapshot> snapshot;
+};
+
+namespace {
+static void CleanupWritePreparedTxnDBIterator(void* arg1, void* /*arg2*/) {
+ delete reinterpret_cast<WritePreparedTxnDB::IteratorState*>(arg1);
+}
+} // anonymous namespace
+
+Iterator* WritePreparedTxnDB::NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ constexpr bool ALLOW_BLOB = true;
+ constexpr bool ALLOW_REFRESH = true;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+ if (options.snapshot != nullptr) {
+ snapshot_seq = options.snapshot->GetSequenceNumber();
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(
+ options.snapshot)
+ ->min_uncommitted_;
+ } else {
+ auto* snapshot = GetSnapshot();
+ // We take a snapshot to make sure that the related data in the commit map
+ // are not deleted.
+ snapshot_seq = snapshot->GetSequenceNumber();
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(snapshot)
+ ->min_uncommitted_;
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ }
+ assert(snapshot_seq != kMaxSequenceNumber);
+ auto* cfd = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted);
+ auto* db_iter =
+ db_impl_->NewIteratorImpl(options, cfd, snapshot_seq, &state->callback,
+ !ALLOW_BLOB, !ALLOW_REFRESH);
+ db_iter->RegisterCleanup(CleanupWritePreparedTxnDBIterator, state, nullptr);
+ return db_iter;
+}
+
+Status WritePreparedTxnDB::NewIterators(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ std::vector<Iterator*>* iterators) {
+ constexpr bool ALLOW_BLOB = true;
+ constexpr bool ALLOW_REFRESH = true;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+ if (options.snapshot != nullptr) {
+ snapshot_seq = options.snapshot->GetSequenceNumber();
+ min_uncommitted = static_cast_with_check<const SnapshotImpl, const Snapshot>(
+ options.snapshot)
+ ->min_uncommitted_;
+ } else {
+ auto* snapshot = GetSnapshot();
+ // We take a snapshot to make sure that the related data in the commit map
+ // are not deleted.
+ snapshot_seq = snapshot->GetSequenceNumber();
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(snapshot)
+ ->min_uncommitted_;
+ }
+ iterators->clear();
+ iterators->reserve(column_families.size());
+ for (auto* column_family : column_families) {
+ auto* cfd = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted);
+ auto* db_iter =
+ db_impl_->NewIteratorImpl(options, cfd, snapshot_seq, &state->callback,
+ !ALLOW_BLOB, !ALLOW_REFRESH);
+ db_iter->RegisterCleanup(CleanupWritePreparedTxnDBIterator, state, nullptr);
+ iterators->push_back(db_iter);
+ }
+ return Status::OK();
+}
+
+void WritePreparedTxnDB::Init(const TransactionDBOptions& /* unused */) {
+ // Advance max_evicted_seq_ no more than 100 times before the cache wraps
+ // around.
+ INC_STEP_FOR_MAX_EVICTED =
+ std::max(COMMIT_CACHE_SIZE / 100, static_cast<size_t>(1));
+ snapshot_cache_ = std::unique_ptr<std::atomic<SequenceNumber>[]>(
+ new std::atomic<SequenceNumber>[SNAPSHOT_CACHE_SIZE] {});
+ commit_cache_ = std::unique_ptr<std::atomic<CommitEntry64b>[]>(
+ new std::atomic<CommitEntry64b>[COMMIT_CACHE_SIZE] {});
+ dummy_max_snapshot_.number_ = kMaxSequenceNumber;
+}
+
+void WritePreparedTxnDB::CheckPreparedAgainstMax(SequenceNumber new_max,
+ bool locked) {
+ // When max_evicted_seq_ advances, move older entries from prepared_txns_
+ // to delayed_prepared_. This guarantees that if a seq is lower than max,
+ // then it is not in prepared_txns_, and this saves an expensive, synchronized
+ // lookup from a shared set. delayed_prepared_ is expected to be empty in
+ // normal cases.
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "CheckPreparedAgainstMax prepared_txns_.empty() %d top: %" PRIu64,
+ prepared_txns_.empty(),
+ prepared_txns_.empty() ? 0 : prepared_txns_.top());
+ const SequenceNumber prepared_top = prepared_txns_.top();
+ const bool empty = prepared_top == kMaxSequenceNumber;
+ // Preliminary check to avoid the synchronization cost
+ if (!empty && prepared_top <= new_max) {
+ if (locked) {
+ // Needed to avoid double locking in pop().
+ prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ WriteLock wl(&prepared_mutex_);
+ // Need to fetch fresh values of ::top after mutex is acquired
+ while (!prepared_txns_.empty() && prepared_txns_.top() <= new_max) {
+ auto to_be_popped = prepared_txns_.top();
+ delayed_prepared_.insert(to_be_popped);
+ ROCKS_LOG_WARN(info_log_,
+ "prepared_mutex_ overhead %" PRIu64 " (prep=%" PRIu64
+ " new_max=%" PRIu64,
+ static_cast<uint64_t>(delayed_prepared_.size()),
+ to_be_popped, new_max);
+ delayed_prepared_empty_.store(false, std::memory_order_release);
+ // Update prepared_txns_ after updating delayed_prepared_empty_; otherwise
+ // there would be a point in time at which the entry is in neither
+ // prepared_txns_ nor delayed_prepared_, which will not be checked if
+ // delayed_prepared_empty_ is false.
+ prepared_txns_.pop();
+ }
+ if (locked) {
+ prepared_txns_.push_pop_mutex()->Lock();
+ }
+ }
+}
+
+void WritePreparedTxnDB::AddPrepared(uint64_t seq, bool locked) {
+ ROCKS_LOG_DETAILS(info_log_, "Txn %" PRIu64 " Preparing with max %" PRIu64,
+ seq, max_evicted_seq_.load());
+ TEST_SYNC_POINT("AddPrepared::begin:pause");
+ TEST_SYNC_POINT("AddPrepared::begin:resume");
+ if (!locked) {
+ prepared_txns_.push_pop_mutex()->Lock();
+ }
+ prepared_txns_.push_pop_mutex()->AssertHeld();
+ prepared_txns_.push(seq);
+ auto new_max = future_max_evicted_seq_.load();
+ if (UNLIKELY(seq <= new_max)) {
+ // This should not happen in the normal case
+ ROCKS_LOG_ERROR(
+ info_log_,
+ "Added prepare_seq is not larger than max_evicted_seq_: %" PRIu64
+ " <= %" PRIu64,
+ seq, new_max);
+ CheckPreparedAgainstMax(new_max, true /*locked*/);
+ }
+ if (!locked) {
+ prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ TEST_SYNC_POINT("AddPrepared::end");
+}
+
+void WritePreparedTxnDB::AddCommitted(uint64_t prepare_seq, uint64_t commit_seq,
+ uint8_t loop_cnt) {
+ ROCKS_LOG_DETAILS(info_log_, "Txn %" PRIu64 " Committing with %" PRIu64,
+ prepare_seq, commit_seq);
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:start:pause");
+ auto indexed_seq = prepare_seq % COMMIT_CACHE_SIZE;
+ CommitEntry64b evicted_64b;
+ CommitEntry evicted;
+ bool to_be_evicted = GetCommitEntry(indexed_seq, &evicted_64b, &evicted);
+ if (LIKELY(to_be_evicted)) {
+ assert(evicted.prep_seq != prepare_seq);
+ auto prev_max = max_evicted_seq_.load(std::memory_order_acquire);
+ ROCKS_LOG_DETAILS(info_log_,
+ "Evicting %" PRIu64 ",%" PRIu64 " with max %" PRIu64,
+ evicted.prep_seq, evicted.commit_seq, prev_max);
+ if (prev_max < evicted.commit_seq) {
+ auto last = db_impl_->GetLastPublishedSequence(); // could be 0
+ SequenceNumber max_evicted_seq;
+ if (LIKELY(evicted.commit_seq < last)) {
+ assert(last > 0);
+ // Inc max in larger steps to avoid frequent updates
+ max_evicted_seq =
+ std::min(evicted.commit_seq + INC_STEP_FOR_MAX_EVICTED, last - 1);
+ } else {
+ // legit when a commit entry in a write batch overwrites the previous one
+ max_evicted_seq = evicted.commit_seq;
+ }
+ ROCKS_LOG_DETAILS(info_log_,
+ "%lu Evicting %" PRIu64 ",%" PRIu64 " with max %" PRIu64
+ " => %lu",
+ prepare_seq, evicted.prep_seq, evicted.commit_seq,
+ prev_max, max_evicted_seq);
+ AdvanceMaxEvictedSeq(prev_max, max_evicted_seq);
+ }
+ // After each eviction from commit cache, check if the commit entry should
+ // be kept around because it overlaps with a live snapshot.
+ CheckAgainstSnapshots(evicted);
+ if (UNLIKELY(!delayed_prepared_empty_.load(std::memory_order_acquire))) {
+ WriteLock wl(&prepared_mutex_);
+ for (auto dp : delayed_prepared_) {
+ if (dp == evicted.prep_seq) {
+ // This is a rare case where the txn is committed but prepared_txns_ is
+ // not cleaned up yet. Refer to the delayed_prepared_commits_ definition
+ // for why it should be kept updated.
+ delayed_prepared_commits_[evicted.prep_seq] = evicted.commit_seq;
+ ROCKS_LOG_DEBUG(info_log_,
+ "delayed_prepared_commits_[%" PRIu64 "]=%" PRIu64,
+ evicted.prep_seq, evicted.commit_seq);
+ break;
+ }
+ }
+ }
+ }
+ bool succ =
+ ExchangeCommitEntry(indexed_seq, evicted_64b, {prepare_seq, commit_seq});
+ if (UNLIKELY(!succ)) {
+ ROCKS_LOG_ERROR(info_log_,
+ "ExchangeCommitEntry failed on [%" PRIu64 "] %" PRIu64
+ ",%" PRIu64 " retrying...",
+ indexed_seq, prepare_seq, commit_seq);
+ // A very rare event, in which the commit entry is updated before we do.
+ // Here we apply a very simple solution of retrying.
+ if (loop_cnt > 100) {
+ throw std::runtime_error("Infinite loop in AddCommitted!");
+ }
+ AddCommitted(prepare_seq, commit_seq, ++loop_cnt);
+ return;
+ }
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::AddCommitted:end:pause");
+}
+
+void WritePreparedTxnDB::RemovePrepared(const uint64_t prepare_seq,
+ const size_t batch_cnt) {
+ TEST_SYNC_POINT_CALLBACK(
+ "RemovePrepared:Start",
+ const_cast<void*>(reinterpret_cast<const void*>(&prepare_seq)));
+ TEST_SYNC_POINT("WritePreparedTxnDB::RemovePrepared:pause");
+ TEST_SYNC_POINT("WritePreparedTxnDB::RemovePrepared:resume");
+ ROCKS_LOG_DETAILS(info_log_,
+ "RemovePrepared %" PRIu64 " cnt: %" ROCKSDB_PRIszt,
+ prepare_seq, batch_cnt);
+ WriteLock wl(&prepared_mutex_);
+ for (size_t i = 0; i < batch_cnt; i++) {
+ prepared_txns_.erase(prepare_seq + i);
+ bool was_empty = delayed_prepared_.empty();
+ if (!was_empty) {
+ delayed_prepared_.erase(prepare_seq + i);
+ auto it = delayed_prepared_commits_.find(prepare_seq + i);
+ if (it != delayed_prepared_commits_.end()) {
+ ROCKS_LOG_DETAILS(info_log_, "delayed_prepared_commits_.erase %" PRIu64,
+ prepare_seq + i);
+ delayed_prepared_commits_.erase(it);
+ }
+ bool is_empty = delayed_prepared_.empty();
+ if (was_empty != is_empty) {
+ delayed_prepared_empty_.store(is_empty, std::memory_order_release);
+ }
+ }
+ }
+}
+
+bool WritePreparedTxnDB::GetCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b* entry_64b,
+ CommitEntry* entry) const {
+ *entry_64b = commit_cache_[static_cast<size_t>(indexed_seq)].load(
+ std::memory_order_acquire);
+ bool valid = entry_64b->Parse(indexed_seq, entry, FORMAT);
+ return valid;
+}
+
+bool WritePreparedTxnDB::AddCommitEntry(const uint64_t indexed_seq,
+ const CommitEntry& new_entry,
+ CommitEntry* evicted_entry) {
+ CommitEntry64b new_entry_64b(new_entry, FORMAT);
+ CommitEntry64b evicted_entry_64b =
+ commit_cache_[static_cast<size_t>(indexed_seq)].exchange(
+ new_entry_64b, std::memory_order_acq_rel);
+ bool valid = evicted_entry_64b.Parse(indexed_seq, evicted_entry, FORMAT);
+ return valid;
+}
+
+bool WritePreparedTxnDB::ExchangeCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b& expected_entry_64b,
+ const CommitEntry& new_entry) {
+ auto& atomic_entry = commit_cache_[static_cast<size_t>(indexed_seq)];
+ CommitEntry64b new_entry_64b(new_entry, FORMAT);
+ bool succ = atomic_entry.compare_exchange_strong(
+ expected_entry_64b, new_entry_64b, std::memory_order_acq_rel,
+ std::memory_order_acquire);
+ return succ;
+}
+
+void WritePreparedTxnDB::AdvanceMaxEvictedSeq(const SequenceNumber& prev_max,
+ const SequenceNumber& new_max) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "AdvanceMaxEvictedSeq overhead %" PRIu64 " => %" PRIu64,
+ prev_max, new_max);
+ // Declare the intention before getting the snapshot list from the DB. This
+ // helps a concurrent GetSnapshot wait to catch up with
+ // future_max_evicted_seq_ if it has not already. Otherwise a concurrently
+ // taken snapshot could be missed when we ask the DB for snapshots smaller
+ // than the future max.
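+ // The weak CAS loop below terminates either when this thread publishes
+ // new_max into future_max_evicted_seq_ or when another thread has already
+ // advanced it to a value >= new_max; on each failed attempt
+ // updated_future_max is refreshed with the current value.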
+ auto updated_future_max = prev_max;
+ while (updated_future_max < new_max &&
+ !future_max_evicted_seq_.compare_exchange_weak(
+ updated_future_max, new_max, std::memory_order_acq_rel,
+ std::memory_order_relaxed)) {
+ };
+
+ CheckPreparedAgainstMax(new_max, false /*locked*/);
+
+ // With each change to max_evicted_seq_ fetch the live snapshots behind it.
+ // We use max as the version of the snapshot list to identify how fresh the
+ // list is. This works because the snapshots are between 0 and max, so the
+ // larger the max, the more complete the list is.
+ SequenceNumber new_snapshots_version = new_max;
+ std::vector<SequenceNumber> snapshots;
+ bool update_snapshots = false;
+ if (new_snapshots_version > snapshots_version_) {
+ // This is to avoid updating snapshots_ if it was already updated
+ // with a more recent version by a concurrent thread
+ update_snapshots = true;
+ // We only care about snapshots lower than max
+ snapshots = GetSnapshotListFromDB(new_max);
+ }
+ if (update_snapshots) {
+ UpdateSnapshots(snapshots, new_snapshots_version);
+ if (!snapshots.empty()) {
+ WriteLock wl(&old_commit_map_mutex_);
+ for (auto snap : snapshots) {
+ // This allows IsInSnapshot to tell apart reads from invalid (released)
+ // snapshots from reads of committed values in valid snapshots.
+ old_commit_map_[snap];
+ }
+ old_commit_map_empty_.store(false, std::memory_order_release);
+ }
+ }
+ auto updated_prev_max = prev_max;
+ TEST_SYNC_POINT("AdvanceMaxEvictedSeq::update_max:pause");
+ TEST_SYNC_POINT("AdvanceMaxEvictedSeq::update_max:resume");
+ while (updated_prev_max < new_max &&
+ !max_evicted_seq_.compare_exchange_weak(updated_prev_max, new_max,
+ std::memory_order_acq_rel,
+ std::memory_order_relaxed)) {
+ };
+}
+
+const Snapshot* WritePreparedTxnDB::GetSnapshot() {
+ const bool kForWWConflictCheck = true;
+ return GetSnapshotInternal(!kForWWConflictCheck);
+}
+
+SnapshotImpl* WritePreparedTxnDB::GetSnapshotInternal(
+ bool for_ww_conflict_check) {
+ // Note: for this optimization, setting the last sequence number and
+ // obtaining the smallest uncommitted seq should be done atomically. However,
+ // to avoid the mutex overhead, we call SmallestUnCommittedSeq BEFORE taking
+ // the snapshot. Since we always update the list of uncommitted seqs (via
+ // AddPrepared) AFTER the last sequence is updated, this guarantees that the
+ // smallest uncommitted seq that we pair with the snapshot is smaller than or
+ // equal to the value that would be obtained otherwise atomically. That is ok
+ // since this optimization works as long as min_uncommitted is less than or
+ // equal to the smallest uncommitted seq when the snapshot was taken.
+ auto min_uncommitted = WritePreparedTxnDB::SmallestUnCommittedSeq();
+ SnapshotImpl* snap_impl = db_impl_->GetSnapshotImpl(for_ww_conflict_check);
+ TEST_SYNC_POINT("WritePreparedTxnDB::GetSnapshotInternal:first");
+ assert(snap_impl);
+ SequenceNumber snap_seq = snap_impl->GetSequenceNumber();
+ // Note: Check against future_max_evicted_seq_ (in contrast with
+ // max_evicted_seq_) in case there is a concurrent AdvanceMaxEvictedSeq.
+ if (UNLIKELY(snap_seq != 0 && snap_seq <= future_max_evicted_seq_)) {
+ // There is a very rare case in which the commit entry evicts another commit
+ // entry that is not published yet, thus advancing max evicted seq beyond the
+ // last published seq. This case is unlikely in a real-world setup, so we
+ // handle it with a few retries.
+ size_t retry = 0;
+ SequenceNumber max;
+ while ((max = future_max_evicted_seq_.load()) != 0 &&
+ snap_impl->GetSequenceNumber() <= max && retry < 100) {
+ ROCKS_LOG_WARN(info_log_,
+ "GetSnapshot snap: %" PRIu64 " max: %" PRIu64
+ " retry %" ROCKSDB_PRIszt,
+ snap_impl->GetSequenceNumber(), max, retry);
+ ReleaseSnapshot(snap_impl);
+ // Wait for last visible seq to catch up with max, and also go beyond it
+ // by one.
+ AdvanceSeqByOne();
+ snap_impl = db_impl_->GetSnapshotImpl(for_ww_conflict_check);
+ assert(snap_impl);
+ retry++;
+ }
+ assert(snap_impl->GetSequenceNumber() > max);
+ if (snap_impl->GetSequenceNumber() <= max) {
+ throw std::runtime_error(
+ "Snapshot seq " + ToString(snap_impl->GetSequenceNumber()) +
+ " after " + ToString(retry) +
+ " retries is still less than futre_max_evicted_seq_" + ToString(max));
+ }
+ }
+ EnhanceSnapshot(snap_impl, min_uncommitted);
+ ROCKS_LOG_DETAILS(
+ db_impl_->immutable_db_options().info_log,
+ "GetSnapshot %" PRIu64 " ww:%" PRIi32 " min_uncommitted: %" PRIu64,
+ snap_impl->GetSequenceNumber(), for_ww_conflict_check, min_uncommitted);
+ TEST_SYNC_POINT("WritePreparedTxnDB::GetSnapshotInternal:end");
+ return snap_impl;
+}
+
+void WritePreparedTxnDB::AdvanceSeqByOne() {
+ // Inserting an empty value will i) let the max evicted entry be
+ // published, i.e., max == last_published, and ii) increase the last
+ // published to be one beyond max, i.e., max < last_published.
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ Transaction* txn0 = BeginTransaction(woptions, txn_options, nullptr);
+ std::hash<std::thread::id> hasher;
+ char name[64];
+ snprintf(name, 64, "txn%" ROCKSDB_PRIszt, hasher(std::this_thread::get_id()));
+ assert(strlen(name) < 64 - 1);
+ Status s = txn0->SetName(name);
+ assert(s.ok());
+ if (s.ok()) {
+ // Without prepare it would simply skip the commit
+ s = txn0->Prepare();
+ }
+ assert(s.ok());
+ if (s.ok()) {
+ s = txn0->Commit();
+ }
+ assert(s.ok());
+ delete txn0;
+}
+
+const std::vector<SequenceNumber> WritePreparedTxnDB::GetSnapshotListFromDB(
+ SequenceNumber max) {
+ ROCKS_LOG_DETAILS(info_log_, "GetSnapshotListFromDB with max %" PRIu64, max);
+ InstrumentedMutexLock dblock(db_impl_->mutex());
+ db_impl_->mutex()->AssertHeld();
+ return db_impl_->snapshots().GetAll(nullptr, max);
+}
+
+void WritePreparedTxnDB::ReleaseSnapshotInternal(
+ const SequenceNumber snap_seq) {
+ // TODO(myabandeh): relaxed should be enough since the synchronization is
+ // already done by snapshots_mutex_, under which this function is called.
+ if (snap_seq <= max_evicted_seq_.load(std::memory_order_acquire)) {
+ // Then this is a rare case in which the transaction did not finish before
+ // max advanced. It is expected for a few read-only backup snapshots. For
+ // such snapshots we might have kept around a couple of entries in
+ // old_commit_map_. Check and do garbage collection if that is the case.
+ bool need_gc = false;
+ {
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_, "old_commit_map_mutex_ overhead for %" PRIu64,
+ snap_seq);
+ ReadLock rl(&old_commit_map_mutex_);
+ auto prep_set_entry = old_commit_map_.find(snap_seq);
+ need_gc = prep_set_entry != old_commit_map_.end();
+ }
+ if (need_gc) {
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_, "old_commit_map_mutex_ overhead for %" PRIu64,
+ snap_seq);
+ WriteLock wl(&old_commit_map_mutex_);
+ old_commit_map_.erase(snap_seq);
+ old_commit_map_empty_.store(old_commit_map_.empty(),
+ std::memory_order_release);
+ }
+ }
+}
+
+void WritePreparedTxnDB::CleanupReleasedSnapshots(
+ const std::vector<SequenceNumber>& new_snapshots,
+ const std::vector<SequenceNumber>& old_snapshots) {
+ auto newi = new_snapshots.begin();
+ auto oldi = old_snapshots.begin();
+ for (; newi != new_snapshots.end() && oldi != old_snapshots.end();) {
+ assert(*newi >= *oldi); // cannot have new snapshots with lower seq
+ if (*newi == *oldi) { // still not released
+ auto value = *newi;
+ while (newi != new_snapshots.end() && *newi == value) {
+ newi++;
+ }
+ while (oldi != old_snapshots.end() && *oldi == value) {
+ oldi++;
+ }
+ } else {
+ assert(*newi > *oldi); // *oldi is released
+ ReleaseSnapshotInternal(*oldi);
+ oldi++;
+ }
+ }
+ // Everything remaining in old_snapshots has been released and must be
+ // cleaned up
+ for (; oldi != old_snapshots.end(); oldi++) {
+ ReleaseSnapshotInternal(*oldi);
+ }
+}
+
+void WritePreparedTxnDB::UpdateSnapshots(
+ const std::vector<SequenceNumber>& snapshots,
+ const SequenceNumber& version) {
+ ROCKS_LOG_DETAILS(info_log_, "UpdateSnapshots with version %" PRIu64,
+ version);
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:start");
+#ifndef NDEBUG
+ size_t sync_i = 0;
+#endif
+ ROCKS_LOG_DETAILS(info_log_, "snapshots_mutex_ overhead");
+ WriteLock wl(&snapshots_mutex_);
+ snapshots_version_ = version;
+ // We update the list concurrently with the readers.
+ // Both new and old lists are sorted, and the new list is a subset of the
+ // previous list plus some new items. Thus if a snapshot appears in both the
+ // new and the old list, its index in the new list is the same or lower. So
+ // if we simply write the new snapshots in order, a snapshot that is still
+ // valid in the new list is either written to the same slot in the array or
+ // to a lower slot before its old slot gets overwritten by another item.
+ // This guarantees that a reader that scans the array from the highest slot
+ // down will eventually see a snapshot that survives the update, either
+ // before the writer overwrites its old slot or afterwards (in its new
+ // slot).
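+ // Illustrative example (hypothetical values): old cache = [10, 20, 30] and
+ // the new list is [20, 30, 40] (10 was released). The writer stores 20 to
+ // slot 0, 30 to slot 1, and 40 to slot 2, in that order. A reader scanning
+ // from the highest slot down sees the surviving snapshot 30 either in slot
+ // 2 (before it is overwritten with 40) or in slot 1 (which was written
+ // earlier), so no surviving snapshot is missed.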
+ size_t i = 0;
+ auto it = snapshots.begin();
+ for (; it != snapshots.end() && i < SNAPSHOT_CACHE_SIZE; ++it, ++i) {
+ snapshot_cache_[i].store(*it, std::memory_order_release);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:", ++sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:", sync_i);
+ }
+#ifndef NDEBUG
+ // Release the remaining sync points since they are useless given that the
+ // reader would also use lock to access snapshots
+ for (++sync_i; sync_i <= 10; ++sync_i) {
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:", sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:", sync_i);
+ }
+#endif
+ snapshots_.clear();
+ for (; it != snapshots.end(); ++it) {
+ // Insert them to a vector that is less efficient to access
+ // concurrently
+ snapshots_.push_back(*it);
+ }
+ // Update the size at the end. Otherwise a parallel reader might read
+ // items that are not set yet.
+ snapshots_total_.store(snapshots.size(), std::memory_order_release);
+
+ // Note: this must be done after the snapshots data structures are updated
+ // with the new list of snapshots.
+ CleanupReleasedSnapshots(snapshots, snapshots_all_);
+ snapshots_all_ = snapshots;
+
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:p:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::UpdateSnapshots:s:end");
+}
+
+void WritePreparedTxnDB::CheckAgainstSnapshots(const CommitEntry& evicted) {
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:start");
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:start");
+#ifndef NDEBUG
+ size_t sync_i = 0;
+#endif
+ // First check the snapshot cache that is efficient for concurrent access
+ auto cnt = snapshots_total_.load(std::memory_order_acquire);
+ // The list might get updated concurrently as we are reading from it. The
+ // reader should be able to read all the snapshots that are still valid
+ // after the update. Since a surviving snapshot is written to a lower slot
+ // before its old slot gets overwritten, a reader that scans the array from
+ // the highest slot down will eventually see it.
+ const bool next_is_larger = true;
+ // Will be set to true if the borderline snapshot suggests it.
+ bool search_larger_list = false;
+ size_t ip1 = std::min(cnt, SNAPSHOT_CACHE_SIZE);
+ for (; 0 < ip1; ip1--) {
+ SequenceNumber snapshot_seq =
+ snapshot_cache_[ip1 - 1].load(std::memory_order_acquire);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:",
+ ++sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:", sync_i);
+ if (ip1 == SNAPSHOT_CACHE_SIZE) { // borderline snapshot
+ // If the largest cached snapshot is still smaller than commit_seq, then
+ // snapshots beyond the cache could overlap with the evicted entry as
+ // well, so later also continue the search into the larger list.
+ search_larger_list = snapshot_seq < evicted.commit_seq;
+ }
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq, !next_is_larger)) {
+ break;
+ }
+ }
+#ifndef NDEBUG
+ // Release the remaining sync points before acquiring the lock
+ for (++sync_i; sync_i <= 10; ++sync_i) {
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:", sync_i);
+ TEST_IDX_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:", sync_i);
+ }
+#endif
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:p:end");
+ TEST_SYNC_POINT("WritePreparedTxnDB::CheckAgainstSnapshots:s:end");
+ if (UNLIKELY(SNAPSHOT_CACHE_SIZE < cnt && search_larger_list)) {
+ // Then access the less efficient list of snapshots_
+ WPRecordTick(TXN_SNAPSHOT_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_,
+ "snapshots_mutex_ overhead for <%" PRIu64 ",%" PRIu64
+ "> with %" ROCKSDB_PRIszt " snapshots",
+ evicted.prep_seq, evicted.commit_seq, cnt);
+ ReadLock rl(&snapshots_mutex_);
+ // Items could have moved from snapshots_ to snapshot_cache_ before
+ // acquiring the lock. To make sure that we do not miss a valid snapshot,
+ // read snapshot_cache_ again while holding the lock.
+ for (size_t i = 0; i < SNAPSHOT_CACHE_SIZE; i++) {
+ SequenceNumber snapshot_seq =
+ snapshot_cache_[i].load(std::memory_order_acquire);
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq, next_is_larger)) {
+ break;
+ }
+ }
+ for (auto snapshot_seq_2 : snapshots_) {
+ if (!MaybeUpdateOldCommitMap(evicted.prep_seq, evicted.commit_seq,
+ snapshot_seq_2, next_is_larger)) {
+ break;
+ }
+ }
+ }
+}
+
+bool WritePreparedTxnDB::MaybeUpdateOldCommitMap(
+ const uint64_t& prep_seq, const uint64_t& commit_seq,
+ const uint64_t& snapshot_seq, const bool next_is_larger = true) {
+ // If we do not store an entry in old_commit_map_ we assume it is committed
+ // in all snapshots. If commit_seq <= snapshot_seq, it is considered already
+ // in the snapshot so we need not keep the entry around for this snapshot.
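+ // Illustrative example (hypothetical values): for an evicted entry
+ // <prep_seq=5, commit_seq=9> and snapshot_seq=7, we have 5 <= 7 < 9, so the
+ // commit is invisible to that snapshot; 5 is recorded under
+ // old_commit_map_[7] so that IsInSnapshot(5, 7) keeps returning false after
+ // the entry has left the commit cache.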
+ if (commit_seq <= snapshot_seq) {
+ // continue the search if the next snapshot could be smaller than commit_seq
+ return !next_is_larger;
+ }
+ // then snapshot_seq < commit_seq
+ if (prep_seq <= snapshot_seq) { // overlapping range
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ROCKS_LOG_WARN(info_log_,
+ "old_commit_map_mutex_ overhead for %" PRIu64
+ " commit entry: <%" PRIu64 ",%" PRIu64 ">",
+ snapshot_seq, prep_seq, commit_seq);
+ WriteLock wl(&old_commit_map_mutex_);
+ old_commit_map_empty_.store(false, std::memory_order_release);
+ auto& vec = old_commit_map_[snapshot_seq];
+ vec.insert(std::upper_bound(vec.begin(), vec.end(), prep_seq), prep_seq);
+ // We need to store it once for each overlapping snapshot. Return true to
+ // continue the search in case there are more overlapping snapshots.
+ return true;
+ }
+ // continue the search if the next snapshot could be larger than prep_seq
+ return next_is_larger;
+}
+
+WritePreparedTxnDB::~WritePreparedTxnDB() {
+ // At this point there could be running compaction/flush holding a
+ // SnapshotChecker, which holds a pointer back to WritePreparedTxnDB.
+ // Make sure those jobs finished before destructing WritePreparedTxnDB.
+ if (!db_impl_->shutting_down_) {
+ db_impl_->CancelAllBackgroundWork(true /*wait*/);
+ }
+}
+
+void SubBatchCounter::InitWithComp(const uint32_t cf) {
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+}
+
+void SubBatchCounter::AddKey(const uint32_t cf, const Slice& key) {
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ InitWithComp(cf);
+ }
+ auto it = cf_keys.insert(key);
+ if (it.second == false) { // second is false if an element already existed.
+ batches_++;
+ keys_.clear();
+ InitWithComp(cf);
+ keys_[cf].insert(key);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_prepared_txn_db.h b/src/rocksdb/utilities/transactions/write_prepared_txn_db.h
new file mode 100644
index 000000000..964b72689
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_prepared_txn_db.h
@@ -0,0 +1,1111 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include <cinttypes>
+#include <mutex>
+#include <queue>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "db/db_iter.h"
+#include "db/pre_release_callback.h"
+#include "db/read_callback.h"
+#include "db/snapshot_checker.h"
+#include "rocksdb/db.h"
+#include "rocksdb/options.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+#include "util/set_comparator.h"
+#include "util/string_util.h"
+#include "utilities/transactions/pessimistic_transaction.h"
+#include "utilities/transactions/pessimistic_transaction_db.h"
+#include "utilities/transactions/transaction_lock_mgr.h"
+#include "utilities/transactions/write_prepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+enum SnapshotBackup : bool { kUnbackedByDBSnapshot, kBackedByDBSnapshot };
+
+// A PessimisticTransactionDB that writes data to DB after prepare phase of 2PC.
+// In this way some data in the DB might not be committed. The DB provides
+// mechanisms to tell such data apart from committed data.
+class WritePreparedTxnDB : public PessimisticTransactionDB {
+ public:
+ explicit WritePreparedTxnDB(DB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options),
+ SNAPSHOT_CACHE_BITS(txn_db_options.wp_snapshot_cache_bits),
+ SNAPSHOT_CACHE_SIZE(static_cast<size_t>(1ull << SNAPSHOT_CACHE_BITS)),
+ COMMIT_CACHE_BITS(txn_db_options.wp_commit_cache_bits),
+ COMMIT_CACHE_SIZE(static_cast<size_t>(1ull << COMMIT_CACHE_BITS)),
+ FORMAT(COMMIT_CACHE_BITS) {
+ Init(txn_db_options);
+ }
+
+ explicit WritePreparedTxnDB(StackableDB* db,
+ const TransactionDBOptions& txn_db_options)
+ : PessimisticTransactionDB(db, txn_db_options),
+ SNAPSHOT_CACHE_BITS(txn_db_options.wp_snapshot_cache_bits),
+ SNAPSHOT_CACHE_SIZE(static_cast<size_t>(1ull << SNAPSHOT_CACHE_BITS)),
+ COMMIT_CACHE_BITS(txn_db_options.wp_commit_cache_bits),
+ COMMIT_CACHE_SIZE(static_cast<size_t>(1ull << COMMIT_CACHE_BITS)),
+ FORMAT(COMMIT_CACHE_BITS) {
+ Init(txn_db_options);
+ }
+
+ virtual ~WritePreparedTxnDB();
+
+ virtual Status Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ using TransactionDB::Write;
+ Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ // Optimized version of ::Write that receives more optimization requests such
+ // as skip_concurrency_control.
+ using PessimisticTransactionDB::Write;
+ Status Write(const WriteOptions& opts, const TransactionDBWriteOptimizations&,
+ WriteBatch* updates) override;
+
+ // Write the batch to the underlying DB and mark it as committed. Could be
+ // used either directly from the TxnDB or through a transaction.
+ Status WriteInternal(const WriteOptions& write_options, WriteBatch* batch,
+ size_t batch_cnt, WritePreparedTxn* txn);
+
+ using DB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using DB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ using DB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ using DB::NewIterators;
+ virtual Status NewIterators(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_families,
+ std::vector<Iterator*>* iterators) override;
+
+ // Check whether the transaction that wrote the value with prepare sequence
+ // number prep_seq is visible to the snapshot with sequence number
+ // snapshot_seq. Returns true if commit_seq <= snapshot_seq.
+ // If snapshot_seq has already been released and snapshot_seq <= max, sets
+ // *snap_released to true and returns true as well.
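+ // For example, a value prepared at seq 10 and committed at seq 12 is
+ // visible to a snapshot at seq 15 (12 <= 15) but not to a snapshot at
+ // seq 11.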
+ inline bool IsInSnapshot(uint64_t prep_seq, uint64_t snapshot_seq,
+ uint64_t min_uncommitted = kMinUnCommittedSeq,
+ bool* snap_released = nullptr) const {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " min_uncommitted %" PRIu64,
+ prep_seq, snapshot_seq, min_uncommitted);
+ assert(min_uncommitted >= kMinUnCommittedSeq);
+ // Caller is responsible to initialize snap_released.
+ assert(snap_released == nullptr || *snap_released == false);
+ // Here we try to infer the return value without looking into the prepare
+ // list. This helps avoid synchronization over a shared map.
+ // TODO(myabandeh): optimize this. This sequence of checks must be correct
+ // but is not necessarily efficient.
+ if (prep_seq == 0) {
+ // Compaction will output keys to bottom-level with sequence number 0 if
+ // it is visible to the earliest snapshot.
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ if (snapshot_seq < prep_seq) {
+ // snapshot_seq < prep_seq <= commit_seq => snapshot_seq < commit_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+ if (prep_seq < min_uncommitted) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32
+ " because of min_uncommitted %" PRIu64,
+ prep_seq, snapshot_seq, 1, min_uncommitted);
+ return true;
+ }
+ // Commit of a delayed prepared has two non-atomic steps: add to the commit
+ // cache, remove from delayed prepared. Our reads from these two are also
+ // non-atomic. Thus by looking into the commit cache first we might find the
+ // prep_seq neither in the commit cache nor in delayed_prepared_. To fix
+ // that, i) we check if there was any delayed prepared BEFORE looking into
+ // the commit cache, and ii) if there was, the search steps become: commit
+ // cache, delayed prepared, commit cache again. In this way, if the first
+ // query to the commit cache missed the commit, the second will catch it.
+ bool was_empty;
+ SequenceNumber max_evicted_seq_lb, max_evicted_seq_ub;
+ CommitEntry64b dont_care;
+ auto indexed_seq = prep_seq % COMMIT_CACHE_SIZE;
+ size_t repeats = 0;
+ do {
+ repeats++;
+ assert(repeats < 100);
+ if (UNLIKELY(repeats >= 100)) {
+ throw std::runtime_error(
+ "The read was intrupted 100 times by update to max_evicted_seq_. "
+ "This is unexpected in all setups");
+ }
+ max_evicted_seq_lb = max_evicted_seq_.load(std::memory_order_acquire);
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:max_evicted_seq_:resume");
+ was_empty = delayed_prepared_empty_.load(std::memory_order_acquire);
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:delayed_prepared_empty_:resume");
+ CommitEntry cached;
+ bool exist = GetCommitEntry(indexed_seq, &dont_care, &cached);
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:pause");
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:GetCommitEntry:resume");
+ if (exist && prep_seq == cached.prep_seq) {
+ // It is committed and also not evicted from commit cache
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, cached.commit_seq <= snapshot_seq);
+ return cached.commit_seq <= snapshot_seq;
+ }
+ // else it could be committed but not inserted in the map which could
+ // happen after recovery, or it could be committed and evicted by another
+ // commit, or never committed.
+
+ // At this point we don't know if it was committed or it is still prepared
+ max_evicted_seq_ub = max_evicted_seq_.load(std::memory_order_acquire);
+ if (UNLIKELY(max_evicted_seq_lb != max_evicted_seq_ub)) {
+ continue;
+ }
+ // Note: max_evicted_seq_ when we did GetCommitEntry <= max_evicted_seq_ub
+ if (max_evicted_seq_ub < prep_seq) {
+ // Not evicted from cache and also not present, so must be still
+ // prepared
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+ TEST_SYNC_POINT("WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:pause");
+ TEST_SYNC_POINT(
+ "WritePreparedTxnDB::IsInSnapshot:prepared_mutex_:resume");
+ if (!was_empty) {
+ // We should not normally reach here
+ WPRecordTick(TXN_PREPARE_MUTEX_OVERHEAD);
+ ReadLock rl(&prepared_mutex_);
+ ROCKS_LOG_WARN(
+ info_log_, "prepared_mutex_ overhead %" PRIu64 " for %" PRIu64,
+ static_cast<uint64_t>(delayed_prepared_.size()), prep_seq);
+ if (delayed_prepared_.find(prep_seq) != delayed_prepared_.end()) {
+ // This is the order: 1) delayed_prepared_commits_ update, 2) publish,
+ // 3) delayed_prepared_ cleanup. So check if it is the case of a late
+ // cleanup.
+ auto it = delayed_prepared_commits_.find(prep_seq);
+ if (it == delayed_prepared_commits_.end()) {
+ // Then it is not committed yet
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ } else {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " commit: %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, it->second,
+ snapshot_seq <= it->second);
+ return it->second <= snapshot_seq;
+ }
+ } else {
+ // 2nd query to commit cache. Refer to was_empty comment above.
+ exist = GetCommitEntry(indexed_seq, &dont_care, &cached);
+ if (exist && prep_seq == cached.prep_seq) {
+ ROCKS_LOG_DETAILS(
+ info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, cached.commit_seq <= snapshot_seq);
+ return cached.commit_seq <= snapshot_seq;
+ }
+ max_evicted_seq_ub = max_evicted_seq_.load(std::memory_order_acquire);
+ }
+ }
+ } while (UNLIKELY(max_evicted_seq_lb != max_evicted_seq_ub));
+ // When advancing max_evicted_seq_, we move older entries from prepared to
+ // delayed_prepared_. Also we move evicted entries from the commit cache to
+ // old_commit_map_ if they overlap with any snapshot. Since prep_seq <=
+ // max_evicted_seq_, we have three cases: i) in delayed_prepared_, ii) in
+ // old_commit_map_, iii) committed with no conflict with any snapshot. Case
+ // (i), delayed_prepared_, is checked above
+ if (max_evicted_seq_ub < snapshot_seq) { // then (ii) cannot be the case
+ // only (iii) is the case: committed
+ // commit_seq <= max_evicted_seq_ < snapshot_seq => commit_seq <
+ // snapshot_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ // else (ii) might be the case: check the commit data saved for this
+ // snapshot. If there was no overlapping commit entry, then it is committed
+ // with a commit_seq lower than any live snapshot, including snapshot_seq.
+ if (old_commit_map_empty_.load(std::memory_order_acquire)) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32 " released=1",
+ prep_seq, snapshot_seq, 0);
+ assert(snap_released);
+ // This snapshot is not valid anymore. We cannot tell if prep_seq is
+ // committed before or after the snapshot. Return true but also set
+ // snap_released to true.
+ *snap_released = true;
+ return true;
+ }
+ {
+ // We should not normally reach here unless snapshot_seq is old. This is a
+ // rare case and it is ok to pay the cost of mutex ReadLock for such old,
+ // reading transactions.
+ WPRecordTick(TXN_OLD_COMMIT_MAP_MUTEX_OVERHEAD);
+ ReadLock rl(&old_commit_map_mutex_);
+ auto prep_set_entry = old_commit_map_.find(snapshot_seq);
+ bool found = prep_set_entry != old_commit_map_.end();
+ if (found) {
+ auto& vec = prep_set_entry->second;
+ found = std::binary_search(vec.begin(), vec.end(), prep_seq);
+ } else {
+ // coming from compaction
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32 " released=1",
+ prep_seq, snapshot_seq, 0);
+ // This snapshot is not valid anymore. We cannot tell if prep_seq is
+ // committed before or after the snapshot. Return true but also set
+ // snap_released to true.
+ assert(snap_released);
+ *snap_released = true;
+ return true;
+ }
+
+ if (!found) {
+ ROCKS_LOG_DETAILS(info_log_,
+ "IsInSnapshot %" PRIu64 " in %" PRIu64
+ " returns %" PRId32,
+ prep_seq, snapshot_seq, 1);
+ return true;
+ }
+ }
+ // (ii) is the case: it is committed but after the snapshot_seq
+ ROCKS_LOG_DETAILS(
+ info_log_, "IsInSnapshot %" PRIu64 " in %" PRIu64 " returns %" PRId32,
+ prep_seq, snapshot_seq, 0);
+ return false;
+ }
+
+ // Add the transaction with prepare sequence seq to the prepared list.
+ // Note: must be called serially with increasing seq on each call.
+ // locked is true if the push_pop_mutex of prepared_txns_ is already held.
+ void AddPrepared(uint64_t seq, bool locked = false);
+ // Check if any of the prepared txns are less than new max_evicted_seq_. Must
+ // be called with prepared_mutex_ write locked.
+ void CheckPreparedAgainstMax(SequenceNumber new_max, bool locked);
+ // Remove the transaction with prepare sequence seq from the prepared list
+ void RemovePrepared(const uint64_t seq, const size_t batch_cnt = 1);
+ // Add the transaction with prepare sequence prepare_seq and commit sequence
+ // commit_seq to the commit map. loop_cnt is to detect infinite loops.
+ // Note: must be called serially.
+ void AddCommitted(uint64_t prepare_seq, uint64_t commit_seq,
+ uint8_t loop_cnt = 0);
+
+ struct CommitEntry {
+ uint64_t prep_seq;
+ uint64_t commit_seq;
+ CommitEntry() : prep_seq(0), commit_seq(0) {}
+ CommitEntry(uint64_t ps, uint64_t cs) : prep_seq(ps), commit_seq(cs) {}
+ bool operator==(const CommitEntry& rhs) const {
+ return prep_seq == rhs.prep_seq && commit_seq == rhs.commit_seq;
+ }
+ };
+
+ struct CommitEntry64bFormat {
+ explicit CommitEntry64bFormat(size_t index_bits)
+ : INDEX_BITS(index_bits),
+ PREP_BITS(static_cast<size_t>(64 - PAD_BITS - INDEX_BITS)),
+ COMMIT_BITS(static_cast<size_t>(64 - PREP_BITS)),
+ COMMIT_FILTER(static_cast<uint64_t>((1ull << COMMIT_BITS) - 1)),
+ DELTA_UPPERBOUND(static_cast<uint64_t>((1ull << COMMIT_BITS))) {}
+ // Number of higher bits of a sequence number that are not used. They are
+ // used to encode the value type, ...
+ const size_t PAD_BITS = static_cast<size_t>(8);
+ // Number of lower bits from prepare seq that can be skipped as they are
+ // implied by the index of the entry in the array
+ const size_t INDEX_BITS;
+ // Number of bits we use to encode the prepare seq
+ const size_t PREP_BITS;
+ // Number of bits we use to encode the commit seq.
+ const size_t COMMIT_BITS;
+ // Filter to encode/decode commit seq
+ const uint64_t COMMIT_FILTER;
+ // The value of commit_seq - prepare_seq + 1 must be less than this bound
+ const uint64_t DELTA_UPPERBOUND;
+ };
+
+ // Prepare Seq (64 bits) = PAD .. PAD PREP PREP .. PREP INDEX INDEX .. INDEX
+ // Delta Seq   (64 bits) = 0 0 .. 0 0 0 0 0 0 0 0 DELTA DELTA .. DELTA DELTA
+ // Encoded Value         = PREP PREP .. PREP PREP DELTA DELTA .. DELTA DELTA
+ // PAD: the first bits of a seq that are reserved for tagging and hence
+ // ignored
+ // PREP/INDEX: the used bits in a prepare seq number
+ // INDEX: the bits that do not have to be encoded (will be provided
+ // externally)
+ // DELTA: commit seq - prepare seq + 1
+ // The number of DELTA bits should be equal to the number of INDEX bits plus
+ // the number of PAD bits
+ struct CommitEntry64b {
+ constexpr CommitEntry64b() noexcept : rep_(0) {}
+
+ CommitEntry64b(const CommitEntry& entry, const CommitEntry64bFormat& format)
+ : CommitEntry64b(entry.prep_seq, entry.commit_seq, format) {}
+
+ CommitEntry64b(const uint64_t ps, const uint64_t cs,
+ const CommitEntry64bFormat& format) {
+ assert(ps < static_cast<uint64_t>(
+ (1ull << (format.PREP_BITS + format.INDEX_BITS))));
+ assert(ps <= cs);
+ uint64_t delta = cs - ps + 1; // make initialized delta always >= 1
+ // zero is reserved for uninitialized entries
+ assert(0 < delta);
+ assert(delta < format.DELTA_UPPERBOUND);
+ if (delta >= format.DELTA_UPPERBOUND) {
+ throw std::runtime_error(
+ "commit_seq >> prepare_seq. The allowed distance is " +
+ ToString(format.DELTA_UPPERBOUND) + " commit_seq is " +
+ ToString(cs) + " prepare_seq is " + ToString(ps));
+ }
+ rep_ = (ps << format.PAD_BITS) & ~format.COMMIT_FILTER;
+ rep_ = rep_ | delta;
+ }
+
+ // Return false if the entry is empty
+ bool Parse(const uint64_t indexed_seq, CommitEntry* entry,
+ const CommitEntry64bFormat& format) {
+ uint64_t delta = rep_ & format.COMMIT_FILTER;
+ // zero is reserved for uninitialized entries
+ assert(delta < static_cast<uint64_t>((1ull << format.COMMIT_BITS)));
+ if (delta == 0) {
+ return false; // initialized entry would have non-zero delta
+ }
+
+ assert(indexed_seq < static_cast<uint64_t>((1ull << format.INDEX_BITS)));
+ uint64_t prep_up = rep_ & ~format.COMMIT_FILTER;
+ prep_up >>= format.PAD_BITS;
+ const uint64_t& prep_low = indexed_seq;
+ entry->prep_seq = prep_up | prep_low;
+
+ entry->commit_seq = entry->prep_seq + delta - 1;
+ return true;
+ }
+
+ private:
+ uint64_t rep_;
+ };
+
+ // Struct to hold ownership of snapshot and read callback for cleanup.
+ struct IteratorState;
+
+ std::shared_ptr<std::map<uint32_t, const Comparator*>> GetCFComparatorMap() {
+ return cf_map_;
+ }
+ std::shared_ptr<std::map<uint32_t, ColumnFamilyHandle*>> GetCFHandleMap() {
+ return handle_map_;
+ }
+ void UpdateCFComparatorMap(
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+ void UpdateCFComparatorMap(ColumnFamilyHandle* handle) override;
+
+ virtual const Snapshot* GetSnapshot() override;
+ SnapshotImpl* GetSnapshotInternal(bool for_ww_conflict_check);
+
+ protected:
+ virtual Status VerifyCFOptions(
+ const ColumnFamilyOptions& cf_options) override;
+ // Assign the min and max sequence numbers for reading from the db. A seq >
+ // max is not valid, and a seq < min is valid, and a min <= seq < max requires
+ // further checking. Normally max is defined by the snapshot and min is by
+ // minimum uncommitted seq.
+ inline SnapshotBackup AssignMinMaxSeqs(const Snapshot* snapshot,
+ SequenceNumber* min,
+ SequenceNumber* max);
+ // Validate whether a snapshot sequence number is still valid based on the
+ // latest db status. backed_by_snapshot specifies if the number is backed by
+ // an actual snapshot object. order specifies the memory order with which we
+ // load the atomic variables: relaxed is enough for the default since we care
+ // about the last value seen by the same thread.
+ inline bool ValidateSnapshot(
+ const SequenceNumber snap_seq, const SnapshotBackup backed_by_snapshot,
+ std::memory_order order = std::memory_order_relaxed);
+ // Get a dummy snapshot that refers to kMaxSequenceNumber
+ Snapshot* GetMaxSnapshot() { return &dummy_max_snapshot_; }
+
+ private:
+ friend class AddPreparedCallback;
+ friend class PreparedHeap_BasicsTest_Test;
+ friend class PreparedHeap_Concurrent_Test;
+ friend class PreparedHeap_EmptyAtTheEnd_Test;
+ friend class SnapshotConcurrentAccessTest_SnapshotConcurrentAccess_Test;
+ friend class WritePreparedCommitEntryPreReleaseCallback;
+ friend class WritePreparedTransactionTestBase;
+ friend class WritePreparedTxn;
+ friend class WritePreparedTxnDBMock;
+ friend class WritePreparedTransactionTest_AddPreparedBeforeMax_Test;
+ friend class WritePreparedTransactionTest_AdvanceMaxEvictedSeqBasic_Test;
+ friend class
+ WritePreparedTransactionTest_AdvanceMaxEvictedSeqWithDuplicates_Test;
+ friend class WritePreparedTransactionTest_AdvanceSeqByOne_Test;
+ friend class WritePreparedTransactionTest_BasicRecovery_Test;
+ friend class WritePreparedTransactionTest_CheckAgainstSnapshots_Test;
+ friend class WritePreparedTransactionTest_CleanupSnapshotEqualToMax_Test;
+ friend class WritePreparedTransactionTest_ConflictDetectionAfterRecovery_Test;
+ friend class WritePreparedTransactionTest_CommitMap_Test;
+ friend class WritePreparedTransactionTest_DoubleSnapshot_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshotEmptyMap_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshotReleased_Test;
+ friend class WritePreparedTransactionTest_IsInSnapshot_Test;
+ friend class WritePreparedTransactionTest_NewSnapshotLargerThanMax_Test;
+ friend class WritePreparedTransactionTest_MaxCatchupWithNewSnapshot_Test;
+ friend class WritePreparedTransactionTest_MaxCatchupWithUnbackedSnapshot_Test;
+ friend class
+ WritePreparedTransactionTest_NonAtomicCommitOfDelayedPrepared_Test;
+ friend class
+ WritePreparedTransactionTest_NonAtomicUpdateOfDelayedPrepared_Test;
+ friend class WritePreparedTransactionTest_NonAtomicUpdateOfMaxEvictedSeq_Test;
+ friend class WritePreparedTransactionTest_OldCommitMapGC_Test;
+ friend class WritePreparedTransactionTest_Rollback_Test;
+ friend class WritePreparedTransactionTest_SmallestUnCommittedSeq_Test;
+ friend class WriteUnpreparedTxn;
+ friend class WriteUnpreparedTxnDB;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+
+ void Init(const TransactionDBOptions& /* unused */);
+
+ void WPRecordTick(uint32_t ticker_type) const {
+ RecordTick(db_impl_->immutable_db_options_.statistics.get(), ticker_type);
+ }
+
+ // A heap with the amortized O(1) complexity for erase. It uses one extra heap
+ // to keep track of erased entries that are not yet on top of the main heap.
+ class PreparedHeap {
+ // The mutex is required for push and pop from PreparedHeap. ::erase will
+ // use external synchronization via prepared_mutex_.
+ port::Mutex push_pop_mutex_;
+ std::deque<uint64_t> heap_;
+ std::priority_queue<uint64_t, std::vector<uint64_t>, std::greater<uint64_t>>
+ erased_heap_;
+ std::atomic<uint64_t> heap_top_ = {kMaxSequenceNumber};
+ // True when testing crash recovery
+ bool TEST_CRASH_ = false;
+ friend class WritePreparedTxnDB;
+
+ public:
+ ~PreparedHeap() {
+ if (!TEST_CRASH_) {
+ assert(heap_.empty());
+ assert(erased_heap_.empty());
+ }
+ }
+ port::Mutex* push_pop_mutex() { return &push_pop_mutex_; }
+
+ inline bool empty() { return top() == kMaxSequenceNumber; }
+ // Returns kMaxSequenceNumber if empty() and the smallest otherwise.
+ inline uint64_t top() { return heap_top_.load(std::memory_order_acquire); }
+ inline void push(uint64_t v) {
+ push_pop_mutex_.AssertHeld();
+ if (heap_.empty()) {
+ heap_top_.store(v, std::memory_order_release);
+ } else {
+ assert(heap_top_.load() < v);
+ }
+ heap_.push_back(v);
+ }
+ void pop(bool locked = false) {
+ if (!locked) {
+ push_pop_mutex()->Lock();
+ }
+ push_pop_mutex_.AssertHeld();
+ heap_.pop_front();
+ while (!heap_.empty() && !erased_heap_.empty() &&
+ // heap_.front() > erased_heap_.top() could happen if we have erased
+ // a non-existent entry. Ideally the user should not do that but we
+ // should be resilient against it.
+ heap_.front() >= erased_heap_.top()) {
+ if (heap_.front() == erased_heap_.top()) {
+ heap_.pop_front();
+ }
+ uint64_t erased __attribute__((__unused__));
+ erased = erased_heap_.top();
+ erased_heap_.pop();
+ // No duplicate prepare sequence numbers
+ assert(erased_heap_.empty() || erased_heap_.top() != erased);
+ }
+ while (heap_.empty() && !erased_heap_.empty()) {
+ erased_heap_.pop();
+ }
+ heap_top_.store(!heap_.empty() ? heap_.front() : kMaxSequenceNumber,
+ std::memory_order_release);
+ if (!locked) {
+ push_pop_mutex()->Unlock();
+ }
+ }
+ // Concurrent calls need external synchronization. It is safe to call it
+ // concurrently with push and pop though.
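+ // Example: after push(1), push(2), push(3), a call to erase(2) cannot
+ // remove 2 from the middle of heap_, so 2 is recorded in erased_heap_. A
+ // later erase(1) pops 1 and then also drains 2 (now at the front and equal
+ // to erased_heap_.top()), leaving 3 on top.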
+ void erase(uint64_t seq) {
+ if (!empty()) {
+ auto top_seq = top();
+ if (seq < top_seq) {
+ // Already popped, ignore it.
+ } else if (top_seq == seq) {
+ pop();
+#ifndef NDEBUG
+ MutexLock ml(push_pop_mutex());
+ assert(heap_.empty() || heap_.front() != seq);
+#endif
+ } else { // top() > seq
+ // Down the heap, remember to pop it later
+ erased_heap_.push(seq);
+ }
+ }
+ }
+ };
+
+ void TEST_Crash() override { prepared_txns_.TEST_CRASH_ = true; }
+
+ // Get the commit entry with index indexed_seq from the commit table. It
+ // returns true if such entry exists.
+ bool GetCommitEntry(const uint64_t indexed_seq, CommitEntry64b* entry_64b,
+ CommitEntry* entry) const;
+
+ // Rewrite the entry with the index indexed_seq in the commit table with the
+ // commit entry <prep_seq, commit_seq>. If the rewrite results into eviction,
+ // sets the evicted_entry and returns true.
+ bool AddCommitEntry(const uint64_t indexed_seq, const CommitEntry& new_entry,
+ CommitEntry* evicted_entry);
+
+ // Rewrite the entry with the index indexed_seq in the commit table with the
+ // commit entry new_entry only if the existing entry matches the
+ // expected_entry. Returns false otherwise.
+ bool ExchangeCommitEntry(const uint64_t indexed_seq,
+ CommitEntry64b& expected_entry,
+ const CommitEntry& new_entry);
+
+ // Increase max_evicted_seq_ from the previous value prev_max to the new
+ // value. This also involves taking care of prepared txns that are not
+ // committed before new_max, as well as updating the list of live snapshots at
+ // the time of updating the max. Thread-safety: this function can be called
+ // concurrently. The concurrent invocations of this function are equivalent to
+ // a serial invocation in which the last invocation is the one with the
+ // largest new_max value.
+ void AdvanceMaxEvictedSeq(const SequenceNumber& prev_max,
+ const SequenceNumber& new_max);
+
+ inline SequenceNumber SmallestUnCommittedSeq() {
+ // Note: We have two lists to look into, but for performance reasons they
+ // are not read atomically. Since CheckPreparedAgainstMax copies the entry
+ // to delayed_prepared_ before removing it from prepared_txns_, to ensure
+ // that a prepared entry will not be missed, we look into them in opposite
+ // order: first read prepared_txns_ and then delayed_prepared_.
+
+ // This must be called before calling ::top. This is because a concurrent
+ // thread would call ::RemovePrepared before updating
+ // GetLatestSequenceNumber(). Reading them in the opposite order here
+ // guarantees that the ::top we read is no higher than the ::top we would
+ // have read had we updated/read them atomically.
+ auto next_prepare = db_impl_->GetLatestSequenceNumber() + 1;
+ auto min_prepare = prepared_txns_.top();
+ // Since we update the prepare_heap always from the main write queue via
+ // PreReleaseCallback, the prepared_txns_.top() indicates the smallest
+ // prepared data in 2pc transactions. For non-2pc transactions that are
+ // written in two steps, we also update prepared_txns_ at the first step
+ // (via the same mechanism) so that their uncommitted data is reflected in
+ // SmallestUnCommittedSeq.
+ if (!delayed_prepared_empty_.load()) {
+ ReadLock rl(&prepared_mutex_);
+ if (!delayed_prepared_.empty()) {
+ return *delayed_prepared_.begin();
+ }
+ }
+ bool empty = min_prepare == kMaxSequenceNumber;
+ if (empty) {
+ // Since GetLatestSequenceNumber is updated
+ // after prepared_txns_ are, the value of GetLatestSequenceNumber would
+ // reflect any uncommitted data that is not added to prepared_txns_ yet.
+ // Otherwise, if there is no concurrent txn, this value simply reflects
+ // the latest value in the memtable.
+ return next_prepare;
+ } else {
+ return std::min(min_prepare, next_prepare);
+ }
+ }
+
+ // Enhance the snapshot object by recording in it the smallest uncommitted seq
+ inline void EnhanceSnapshot(SnapshotImpl* snapshot,
+ SequenceNumber min_uncommitted) {
+ assert(snapshot);
+ assert(min_uncommitted <= snapshot->number_ + 1);
+ snapshot->min_uncommitted_ = min_uncommitted;
+ }
+
+ virtual const std::vector<SequenceNumber> GetSnapshotListFromDB(
+ SequenceNumber max);
+
+ // Will be called by the public ReleaseSnapshot method. Does the maintenance
+ // internal to WritePreparedTxnDB
+ void ReleaseSnapshotInternal(const SequenceNumber snap_seq);
+
+ // Update the list of snapshots corresponding to the soon-to-be-updated
+ // max_evicted_seq_. Thread-safety: this function can be called concurrently.
+ // The concurrent invocations of this function are equivalent to a serial
+ // invocation in which the last invocation is the one with the largest
+ // version value.
+ void UpdateSnapshots(const std::vector<SequenceNumber>& snapshots,
+ const SequenceNumber& version);
+ // Check the new list of new snapshots against the old one to see if any of
+ // the snapshots are released and to do the cleanup for the released snapshot.
+ void CleanupReleasedSnapshots(
+ const std::vector<SequenceNumber>& new_snapshots,
+ const std::vector<SequenceNumber>& old_snapshots);
+
+ // Check an evicted entry against live snapshots to see if it should be kept
+ // around or whether it can be safely discarded (and hence assumed committed
+ // for all snapshots). Thread-safety: this function can be called
+ // concurrently. If it
+ // is called concurrently with multiple UpdateSnapshots, the result is the
+ // same as checking the intersection of the snapshot list before updates with
+ // the snapshot list of all the concurrent updates.
+ void CheckAgainstSnapshots(const CommitEntry& evicted);
+
+ // Add a new entry to old_commit_map_ if prep_seq <= snapshot_seq <
+ // commit_seq. Return false if checking the next snapshot(s) is not needed.
+ // This is the case if none of the next snapshots could satisfy the condition.
+ // next_is_larger: the next snapshot will be a larger value
+ bool MaybeUpdateOldCommitMap(const uint64_t& prep_seq,
+ const uint64_t& commit_seq,
+ const uint64_t& snapshot_seq,
+ const bool next_is_larger);
+
+ // A trick to increase the last visible sequence number by one and also wait
+ // for the in-flight commits to be visible.
+ void AdvanceSeqByOne();
+
+ // The list of live snapshots at the last time that max_evicted_seq_ advanced.
+ // The list is stored in two data structures: in snapshot_cache_, which is
+ // efficient for concurrent reads, and in snapshots_ if the data does not fit
+ // into snapshot_cache_. The member below is the total number of snapshots in
+ // the two lists.
+ std::atomic<size_t> snapshots_total_ = {};
+ // The list is sorted in ascending order. Thread-safety for writes is
+ // provided by snapshots_mutex_, and concurrent reads are safe due to the
+ // std::atomic for each entry. On the x86_64 architecture such reads compile
+ // to simple read instructions.
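+ // For example, with wp_snapshot_cache_bits = 7 the cache below holds up to
+ // 2^7 = 128 snapshot seqs; any additional live snapshots spill over into
+ // snapshots_.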
+ const size_t SNAPSHOT_CACHE_BITS;
+ const size_t SNAPSHOT_CACHE_SIZE;
+ std::unique_ptr<std::atomic<SequenceNumber>[]> snapshot_cache_;
+ // 2nd list for storing snapshots. The list is sorted in ascending order.
+ // Thread-safety is provided with snapshots_mutex_.
+ std::vector<SequenceNumber> snapshots_;
+ // The list of all snapshots: snapshots_ + snapshot_cache_. This list,
+ // although redundant, simplifies the CleanupReleasedSnapshots implementation.
+ // Thread-safety is provided with snapshots_mutex_.
+ std::vector<SequenceNumber> snapshots_all_;
+ // The version of the latest list of snapshots. This can be used to avoid
+ // rewriting a list that is concurrently updated with a more recent version.
+ SequenceNumber snapshots_version_ = 0;
+
+ // A heap of prepared transactions. Thread-safety is provided with
+ // prepared_mutex_.
+ PreparedHeap prepared_txns_;
+ const size_t COMMIT_CACHE_BITS;
+ const size_t COMMIT_CACHE_SIZE;
+ const CommitEntry64bFormat FORMAT;
+ // commit_cache_ must be initialized to zero to tell apart an empty index from
+ // a filled one. Thread-safety is provided with commit_cache_mutex_.
+ std::unique_ptr<std::atomic<CommitEntry64b>[]> commit_cache_;
+ // The largest evicted *commit* sequence number from the commit_cache_. If a
+ // seq is smaller than max_evicted_seq_ it might or might not be present in
+ // commit_cache_. So commit_cache_ must be checked first before consulting
+ // max_evicted_seq_.
+ std::atomic<uint64_t> max_evicted_seq_ = {};
+ // Order: 1) update future_max_evicted_seq_ = new_max, 2)
+ // GetSnapshotListFromDB(new_max), 3) max_evicted_seq_ = new_max. Since
+ // GetSnapshotInternal guarantees that the snapshot seq is larger than
+ // future_max_evicted_seq_, this guarantees that any snapshot not larger
+ // than max has already been looked at via GetSnapshotListFromDB(new_max).
+ std::atomic<uint64_t> future_max_evicted_seq_ = {};
+ // Advance max_evicted_seq_ by this value each time it needs an update. The
+ // larger the value, the less frequent advances we would have. We do not want
+ // it to be too large either as it would cause stalls by doing too much
+ // maintenance work under the lock.
+ size_t INC_STEP_FOR_MAX_EVICTED = 1;
+ // A map from old snapshots (expected to be used by a few read-only txns) to
+ // the prepare sequence numbers of the evicted entries from commit_cache_
+ // that overlap with such a snapshot. These are the prepare sequence numbers
+ // that the snapshot, to which they are mapped, cannot assume to be committed
+ // just because they are no longer in the commit_cache_. The vector must be
+ // sorted after each update.
+ // Thread-safety is provided with old_commit_map_mutex_.
+ std::map<SequenceNumber, std::vector<SequenceNumber>> old_commit_map_;
+ // A set of long-running prepared transactions that are not finished by the
+ // time max_evicted_seq_ advances their sequence number. This is expected to
+ // be empty normally. Thread-safety is provided with prepared_mutex_.
+ std::set<uint64_t> delayed_prepared_;
+ // Commit of a delayed prepared: 1) update the commit cache, 2) update
+ // delayed_prepared_commits_, 3) publish seq, 4) clean up delayed_prepared_.
+ // delayed_prepared_commits_ helps us tell apart the txns that are still
+ // prepared from the ones that are committed but not cleaned up yet.
+ std::unordered_map<SequenceNumber, SequenceNumber> delayed_prepared_commits_;
+ // Update when delayed_prepared_.empty() changes. Expected to be true
+ // normally.
+ std::atomic<bool> delayed_prepared_empty_ = {true};
+ // Update when old_commit_map_.empty() changes. Expected to be true normally.
+ std::atomic<bool> old_commit_map_empty_ = {true};
+ mutable port::RWMutex prepared_mutex_;
+ mutable port::RWMutex old_commit_map_mutex_;
+ mutable port::RWMutex commit_cache_mutex_;
+ mutable port::RWMutex snapshots_mutex_;
+ // A cache of the cf comparators
+ // Thread safety: since it is a const it is safe to read it concurrently
+ std::shared_ptr<std::map<uint32_t, const Comparator*>> cf_map_;
+ // A cache of the cf handles
+ // Thread safety: since the handle is a read-only object, it is safe to read
+ // it concurrently
+ std::shared_ptr<std::map<uint32_t, ColumnFamilyHandle*>> handle_map_;
+ // A dummy snapshot object that refers to kMaxSequenceNumber
+ SnapshotImpl dummy_max_snapshot_;
+};
+
+class WritePreparedTxnReadCallback : public ReadCallback {
+ public:
+ WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot)
+ : ReadCallback(snapshot),
+ db_(db),
+ backed_by_snapshot_(kBackedByDBSnapshot) {}
+ WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot,
+ SequenceNumber min_uncommitted,
+ SnapshotBackup backed_by_snapshot)
+ : ReadCallback(snapshot, min_uncommitted),
+ db_(db),
+ backed_by_snapshot_(backed_by_snapshot) {
+ (void)backed_by_snapshot_; // to silence unused private field warning
+ }
+
+ virtual ~WritePreparedTxnReadCallback() {
+ // If it is not backed by snapshot, the caller must check validity
+ assert(valid_checked_ || backed_by_snapshot_ == kBackedByDBSnapshot);
+ }
+
+ // Will be called to see if the seq number is visible; if not, it moves on
+ // to the next seq number.
+ inline virtual bool IsVisibleFullCheck(SequenceNumber seq) override {
+ auto snapshot = max_visible_seq_;
+ bool snap_released = false;
+ auto ret =
+ db_->IsInSnapshot(seq, snapshot, min_uncommitted_, &snap_released);
+ assert(!snap_released || backed_by_snapshot_ == kUnbackedByDBSnapshot);
+ snap_released_ |= snap_released;
+ return ret;
+ }
+
+ inline bool valid() {
+ valid_checked_ = true;
+ return snap_released_ == false;
+ }
+
+ // TODO(myabandeh): override Refresh when Iterator::Refresh is supported
+ private:
+ WritePreparedTxnDB* db_;
+ // Whether max_visible_seq_ is backed by a snapshot
+ const SnapshotBackup backed_by_snapshot_;
+ bool snap_released_ = false;
+ // Safety check to ensure that the caller has checked invalid statuses
+ bool valid_checked_ = false;
+};
+
+class AddPreparedCallback : public PreReleaseCallback {
+ public:
+ AddPreparedCallback(WritePreparedTxnDB* db, DBImpl* db_impl,
+ size_t sub_batch_cnt, bool two_write_queues,
+ bool first_prepare_batch)
+ : db_(db),
+ db_impl_(db_impl),
+ sub_batch_cnt_(sub_batch_cnt),
+ two_write_queues_(two_write_queues),
+ first_prepare_batch_(first_prepare_batch) {
+ (void)two_write_queues_; // to silence unused private field warning
+ }
+ virtual Status Callback(SequenceNumber prepare_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t log_number, size_t index,
+ size_t total) override {
+ assert(index < total);
+ // To reduce the cost of lock acquisition competing with the concurrent
+ // prepare requests, lock on the first callback and unlock on the last.
+ const bool do_lock = !two_write_queues_ || index == 0;
+ const bool do_unlock = !two_write_queues_ || index + 1 == total;
+ // Always Prepare from the main queue
+ assert(!two_write_queues_ || !is_mem_disabled); // implies the 1st queue
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::begin:pause");
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::begin:resume");
+ if (do_lock) {
+ db_->prepared_txns_.push_pop_mutex()->Lock();
+ }
+ const bool kLocked = true;
+ for (size_t i = 0; i < sub_batch_cnt_; i++) {
+ db_->AddPrepared(prepare_seq + i, kLocked);
+ }
+ if (do_unlock) {
+ db_->prepared_txns_.push_pop_mutex()->Unlock();
+ }
+ TEST_SYNC_POINT("AddPreparedCallback::AddPrepared::end");
+ if (first_prepare_batch_) {
+ assert(log_number != 0);
+ db_impl_->logs_with_prep_tracker()->MarkLogAsContainingPrepSection(
+ log_number);
+ }
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ size_t sub_batch_cnt_;
+ bool two_write_queues_;
+  // Whether this is 2PC and this is the first prepare batch. This is always
+  // the case in 2PC unless it is WriteUnPrepared.
+ bool first_prepare_batch_;
+};
+
+class WritePreparedCommitEntryPreReleaseCallback : public PreReleaseCallback {
+ public:
+  // includes_data indicates that the commit also writes a non-empty
+  // CommitTimeWriteBatch to memtable, which needs to be committed separately.
+ WritePreparedCommitEntryPreReleaseCallback(
+ WritePreparedTxnDB* db, DBImpl* db_impl, SequenceNumber prep_seq,
+ size_t prep_batch_cnt, size_t data_batch_cnt = 0,
+ SequenceNumber aux_seq = kMaxSequenceNumber, size_t aux_batch_cnt = 0)
+ : db_(db),
+ db_impl_(db_impl),
+ prep_seq_(prep_seq),
+ prep_batch_cnt_(prep_batch_cnt),
+ data_batch_cnt_(data_batch_cnt),
+ includes_data_(data_batch_cnt_ > 0),
+ aux_seq_(aux_seq),
+ aux_batch_cnt_(aux_batch_cnt),
+ includes_aux_batch_(aux_batch_cnt > 0) {
+ assert((prep_batch_cnt_ > 0) != (prep_seq == kMaxSequenceNumber)); // xor
+ assert(prep_batch_cnt_ > 0 || data_batch_cnt_ > 0);
+ assert((aux_batch_cnt_ > 0) != (aux_seq == kMaxSequenceNumber)); // xor
+ }
+
+ virtual Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t, size_t /*index*/,
+ size_t /*total*/) override {
+ // Always commit from the 2nd queue
+ assert(!db_impl_->immutable_db_options().two_write_queues ||
+ is_mem_disabled);
+ assert(includes_data_ || prep_seq_ != kMaxSequenceNumber);
+    // The data batch is what accompanies the commit marker and affects the
+    // last seq in the commit batch.
+ const uint64_t last_commit_seq = LIKELY(data_batch_cnt_ <= 1)
+ ? commit_seq
+ : commit_seq + data_batch_cnt_ - 1;
+ if (prep_seq_ != kMaxSequenceNumber) {
+ for (size_t i = 0; i < prep_batch_cnt_; i++) {
+ db_->AddCommitted(prep_seq_ + i, last_commit_seq);
+ }
+ } // else there was no prepare phase
+ if (includes_aux_batch_) {
+ for (size_t i = 0; i < aux_batch_cnt_; i++) {
+ db_->AddCommitted(aux_seq_ + i, last_commit_seq);
+ }
+ }
+ if (includes_data_) {
+ assert(data_batch_cnt_);
+ // Commit the data that is accompanied with the commit request
+ for (size_t i = 0; i < data_batch_cnt_; i++) {
+        // For the commit seq of each batch, use the commit seq of the last
+        // batch. This makes debugging easier by giving all the batches the
+        // same commit sequence number.
+ db_->AddCommitted(commit_seq + i, last_commit_seq);
+ }
+ }
+ if (db_impl_->immutable_db_options().two_write_queues) {
+ assert(is_mem_disabled); // implies the 2nd queue
+ // Publish the sequence number. We can do that here assuming the callback
+ // is invoked only from one write queue, which would guarantee that the
+      // published sequence numbers will be in order, i.e., once a seq is
+      // published, all the seqs prior to it are also publishable.
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ // Note RemovePrepared should be called after publishing the seq.
+ // Otherwise SmallestUnCommittedSeq optimization breaks.
+ if (prep_seq_ != kMaxSequenceNumber) {
+ db_->RemovePrepared(prep_seq_, prep_batch_cnt_);
+ } // else there was no prepare phase
+ if (includes_aux_batch_) {
+ db_->RemovePrepared(aux_seq_, aux_batch_cnt_);
+ }
+ }
+ // else SequenceNumber that is updated as part of the write already does the
+ // publishing
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ // kMaxSequenceNumber if there was no prepare phase
+ SequenceNumber prep_seq_;
+ size_t prep_batch_cnt_;
+ size_t data_batch_cnt_;
+  // Data here is the batch that is written with the commit marker, either
+  // because it is a commit without prepare or because the commit has a
+  // CommitTimeWriteBatch.
+ bool includes_data_;
+ // Auxiliary batch (if there is any) is a batch that is written before, but
+ // gets the same commit seq as prepare batch or data batch. This is used in
+ // two write queues where the CommitTimeWriteBatch becomes the aux batch and
+ // we do a separate write to actually commit everything.
+ SequenceNumber aux_seq_;
+ size_t aux_batch_cnt_;
+ bool includes_aux_batch_;
+};
+
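+// Illustrative wiring sketch (not part of the original sources). The commit
+// map callback above is assumed to be handed to DBImpl::WriteImpl as the
+// pre-release callback, mirroring the commit paths later in this change:
+//
+//   WritePreparedCommitEntryPreReleaseCallback update_commit_map(
+//       wpt_db, db_impl, prep_seq, prep_batch_cnt);
+//   s = db_impl->WriteImpl(write_options, &commit_batch, /*callback=*/nullptr,
+//                          /*log_used=*/nullptr, /*log_ref=*/0,
+//                          /*disable_memtable=*/true, &seq_used,
+//                          /*batch_cnt=*/1, &update_commit_map);
+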
+// For two_write_queues, commit both the aborted batch and the cleanup batch
+// and then publish the seq
+class WritePreparedRollbackPreReleaseCallback : public PreReleaseCallback {
+ public:
+ WritePreparedRollbackPreReleaseCallback(WritePreparedTxnDB* db,
+ DBImpl* db_impl,
+ SequenceNumber prep_seq,
+ SequenceNumber rollback_seq,
+ size_t prep_batch_cnt)
+ : db_(db),
+ db_impl_(db_impl),
+ prep_seq_(prep_seq),
+ rollback_seq_(rollback_seq),
+ prep_batch_cnt_(prep_batch_cnt) {
+ assert(prep_seq != kMaxSequenceNumber);
+ assert(rollback_seq != kMaxSequenceNumber);
+ assert(prep_batch_cnt_ > 0);
+ }
+
+ Status Callback(SequenceNumber commit_seq, bool is_mem_disabled, uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ // Always commit from the 2nd queue
+ assert(is_mem_disabled); // implies the 2nd queue
+ assert(db_impl_->immutable_db_options().two_write_queues);
+#ifdef NDEBUG
+ (void)is_mem_disabled;
+#endif
+ const uint64_t last_commit_seq = commit_seq;
+ db_->AddCommitted(rollback_seq_, last_commit_seq);
+ for (size_t i = 0; i < prep_batch_cnt_; i++) {
+ db_->AddCommitted(prep_seq_ + i, last_commit_seq);
+ }
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ SequenceNumber prep_seq_;
+ SequenceNumber rollback_seq_;
+ size_t prep_batch_cnt_;
+};
+
+// Count the number of sub-batches inside a batch. A sub-batch does not have
+// duplicate keys.
+struct SubBatchCounter : public WriteBatch::Handler {
+ explicit SubBatchCounter(std::map<uint32_t, const Comparator*>& comparators)
+ : comparators_(comparators), batches_(1) {}
+ std::map<uint32_t, const Comparator*>& comparators_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ size_t batches_;
+ size_t BatchCount() { return batches_; }
+ void AddKey(const uint32_t cf, const Slice& key);
+ void InitWithComp(const uint32_t cf);
+ Status MarkNoop(bool) override { return Status::OK(); }
+ Status MarkEndPrepare(const Slice&) override { return Status::OK(); }
+ Status MarkCommit(const Slice&) override { return Status::OK(); }
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ AddKey(cf, key);
+ return Status::OK();
+ }
+ Status MarkBeginPrepare(bool) override { return Status::OK(); }
+ Status MarkRollback(const Slice&) override { return Status::OK(); }
+ bool WriteAfterCommit() const override { return false; }
+};
+
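+// Illustrative sketch (not part of the original sources): SubBatchCounter is
+// driven through WriteBatch::Iterate, mirroring its use in CommitInternal
+// later in this change. Assuming `batch` is a WriteBatch* and `cf_comparators`
+// comes from GetCFComparatorMap():
+//
+//   SubBatchCounter counter(cf_comparators);
+//   Status s = batch->Iterate(&counter);
+//   assert(s.ok());
+//   // At least 1; incremented whenever a key repeats within the batch.
+//   size_t sub_batch_cnt = counter.BatchCount();
+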
+SnapshotBackup WritePreparedTxnDB::AssignMinMaxSeqs(const Snapshot* snapshot,
+ SequenceNumber* min,
+ SequenceNumber* max) {
+ if (snapshot != nullptr) {
+ *min = static_cast_with_check<const SnapshotImpl, const Snapshot>(snapshot)
+ ->min_uncommitted_;
+ *max = static_cast_with_check<const SnapshotImpl, const Snapshot>(snapshot)
+ ->number_;
+ return kBackedByDBSnapshot;
+ } else {
+ *min = SmallestUnCommittedSeq();
+ *max = 0; // to be assigned later after sv is referenced.
+ return kUnbackedByDBSnapshot;
+ }
+}
+
+bool WritePreparedTxnDB::ValidateSnapshot(
+ const SequenceNumber snap_seq, const SnapshotBackup backed_by_snapshot,
+ std::memory_order order) {
+ if (backed_by_snapshot == kBackedByDBSnapshot) {
+ return true;
+ } else {
+ SequenceNumber max = max_evicted_seq_.load(order);
+    // Validate that max has not advanced past the snapshot seq, since the seq
+    // is not backed by a real snapshot. This is a very rare case that should
+    // not happen in real workloads.
+ if (UNLIKELY(snap_seq <= max && snap_seq != 0)) {
+ return false;
+ }
+ }
+ return true;
+}
+
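+// Illustrative read-path sketch (not part of the original sources). A caller,
+// assumed to be along the lines of WritePreparedTxnDB::Get, combines
+// AssignMinMaxSeqs, the read callback, and ValidateSnapshot roughly as:
+//
+//   SequenceNumber min_uncommitted, snap_seq;
+//   const SnapshotBackup backed_by_snapshot =
+//       AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+//   WritePreparedTxnReadCallback callback(this, snap_seq, min_uncommitted,
+//                                         backed_by_snapshot);
+//   // ... perform the lookup, passing &callback to the read path ...
+//   if (LIKELY(callback.valid() &&
+//              ValidateSnapshot(callback.max_visible_seq(),
+//                               backed_by_snapshot))) {
+//     // The result is consistent; the snapshot seq was not invalidated.
+//   } else {
+//     // Retry: an unbacked snapshot seq was overtaken by max_evicted_seq_.
+//   }
+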
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc b/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc
new file mode 100644
index 000000000..8b1613b2e
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_transaction_test.cc
@@ -0,0 +1,727 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/transaction_test.h"
+#include "utilities/transactions/write_unprepared_txn.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTransactionTestBase : public TransactionTestBase {
+ public:
+ WriteUnpreparedTransactionTestBase(bool use_stackable_db,
+ bool two_write_queue,
+ TxnDBWritePolicy write_policy)
+ : TransactionTestBase(use_stackable_db, two_write_queue, write_policy,
+ kOrderedWrite) {}
+};
+
+class WriteUnpreparedTransactionTest
+ : public WriteUnpreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, bool, TxnDBWritePolicy>> {
+ public:
+ WriteUnpreparedTransactionTest()
+ : WriteUnpreparedTransactionTestBase(std::get<0>(GetParam()),
+ std::get<1>(GetParam()),
+ std::get<2>(GetParam())){}
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WriteUnpreparedTransactionTest, WriteUnpreparedTransactionTest,
+ ::testing::Values(std::make_tuple(false, false, WRITE_UNPREPARED),
+ std::make_tuple(false, true, WRITE_UNPREPARED)));
+
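+// Illustrative setup sketch (not part of the original sources). Outside of the
+// test harness, a write-unprepared DB is assumed to be opened by selecting the
+// WRITE_UNPREPARED write policy:
+//
+//   Options options;
+//   options.create_if_missing = true;
+//   TransactionDBOptions txn_db_options;
+//   txn_db_options.write_policy = TxnDBWritePolicy::WRITE_UNPREPARED;
+//   TransactionDB* txn_db = nullptr;
+//   Status s = TransactionDB::Open(options, txn_db_options, "/tmp/testdb",
+//                                  &txn_db);
+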
+enum StressAction { NO_SNAPSHOT, RO_SNAPSHOT, REFRESH_SNAPSHOT };
+class WriteUnpreparedStressTest : public WriteUnpreparedTransactionTestBase,
+ virtual public ::testing::WithParamInterface<
+ std::tuple<bool, StressAction>> {
+ public:
+ WriteUnpreparedStressTest()
+ : WriteUnpreparedTransactionTestBase(false, std::get<0>(GetParam()),
+ WRITE_UNPREPARED),
+ action_(std::get<1>(GetParam())) {}
+ StressAction action_;
+};
+
+INSTANTIATE_TEST_CASE_P(
+ WriteUnpreparedStressTest, WriteUnpreparedStressTest,
+ ::testing::Values(std::make_tuple(false, NO_SNAPSHOT),
+ std::make_tuple(false, RO_SNAPSHOT),
+ std::make_tuple(false, REFRESH_SNAPSHOT),
+ std::make_tuple(true, NO_SNAPSHOT),
+ std::make_tuple(true, RO_SNAPSHOT),
+ std::make_tuple(true, REFRESH_SNAPSHOT)));
+
+TEST_P(WriteUnpreparedTransactionTest, ReadYourOwnWrite) {
+  // The following test checks whether reading your own writes within a
+  // transaction works for write unprepared, when there are uncommitted
+  // values written into the DB.
+ auto verify_state = [](Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+ };
+
+ // Test always reseeking vs never reseeking.
+ for (uint64_t max_skip : {0, std::numeric_limits<int>::max()}) {
+ options.max_sequential_skip_in_iterations = max_skip;
+ options.disable_auto_compactions = true;
+ ReOpen();
+
+ TransactionOptions txn_options;
+ WriteOptions woptions;
+ ReadOptions roptions;
+
+ ASSERT_OK(db->Put(woptions, "a", ""));
+ ASSERT_OK(db->Put(woptions, "b", ""));
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ txn->SetSnapshot();
+
+ for (int i = 0; i < 5; i++) {
+ std::string stored_value = "v" + ToString(i);
+ ASSERT_OK(txn->Put("a", stored_value));
+ ASSERT_OK(txn->Put("b", stored_value));
+ wup_txn->FlushWriteBatchToDB(false);
+
+ // Test Get()
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, stored_value);
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, stored_value);
+
+ // Test Next()
+ auto iter = txn->GetIterator(roptions);
+ iter->Seek("a");
+ verify_state(iter, "a", stored_value);
+
+ iter->Next();
+ verify_state(iter, "b", stored_value);
+
+ iter->SeekToFirst();
+ verify_state(iter, "a", stored_value);
+
+ iter->Next();
+ verify_state(iter, "b", stored_value);
+
+ delete iter;
+
+ // Test Prev()
+ iter = txn->GetIterator(roptions);
+ iter->SeekForPrev("b");
+ verify_state(iter, "b", stored_value);
+
+ iter->Prev();
+ verify_state(iter, "a", stored_value);
+
+ iter->SeekToLast();
+ verify_state(iter, "b", stored_value);
+
+ iter->Prev();
+ verify_state(iter, "a", stored_value);
+
+ delete iter;
+ }
+
+ delete txn;
+ }
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+TEST_P(WriteUnpreparedStressTest, ReadYourOwnWriteStress) {
+  // This is a stress test where different threads write random keys, and then,
+  // before committing or aborting the transaction, each thread validates that
+  // it can read the keys it wrote and that the keys it did not write respect
+  // the snapshot. To avoid row lock contention (and simply stressing the
+  // locking system), each thread mostly writes only to its own set of keys.
+ const uint32_t kNumIter = 1000;
+ const uint32_t kNumThreads = 10;
+ const uint32_t kNumKeys = 5;
+
+ std::default_random_engine rand(static_cast<uint32_t>(
+ std::hash<std::thread::id>()(std::this_thread::get_id())));
+
+ // Test with
+ // 1. no snapshots set
+ // 2. snapshot set on ReadOptions
+ // 3. snapshot set, and refreshing after every write.
+ StressAction a = action_;
+ WriteOptions write_options;
+ txn_db_options.transaction_lock_timeout = -1;
+ options.disable_auto_compactions = true;
+ ReOpen();
+
+ std::vector<std::string> keys;
+ for (uint32_t k = 0; k < kNumKeys * kNumThreads; k++) {
+ keys.push_back("k" + ToString(k));
+ }
+ std::shuffle(keys.begin(), keys.end(), rand);
+
+ // This counter will act as a "sequence number" to help us validate
+ // visibility logic with snapshots. If we had direct access to the seqno of
+ // snapshots and key/values, then we should directly compare those instead.
+ std::atomic<int64_t> counter(0);
+
+ std::function<void(uint32_t)> stress_thread = [&](int id) {
+ size_t tid = std::hash<std::thread::id>()(std::this_thread::get_id());
+ Random64 rnd(static_cast<uint32_t>(tid));
+
+ Transaction* txn;
+ TransactionOptions txn_options;
+ // batch_size of 1 causes writes to DB for every marker.
+ txn_options.write_batch_flush_threshold = 1;
+ ReadOptions read_options;
+
+ for (uint32_t i = 0; i < kNumIter; i++) {
+ std::set<std::string> owned_keys(&keys[id * kNumKeys],
+ &keys[(id + 1) * kNumKeys]);
+ // Add unowned keys to make the workload more interesting, but this
+ // increases row lock contention, so just do it sometimes.
+ if (rnd.OneIn(2)) {
+ owned_keys.insert(keys[rnd.Uniform(kNumKeys * kNumThreads)]);
+ }
+
+ txn = db->BeginTransaction(write_options, txn_options);
+ txn->SetName(ToString(id));
+ txn->SetSnapshot();
+ if (a >= RO_SNAPSHOT) {
+ read_options.snapshot = txn->GetSnapshot();
+ ASSERT_TRUE(read_options.snapshot != nullptr);
+ }
+
+ uint64_t buf[2];
+ buf[0] = id;
+
+ // When scanning through the database, make sure that all unprepared
+ // keys have value >= snapshot and all other keys have value < snapshot.
+ int64_t snapshot_num = counter.fetch_add(1);
+
+ Status s;
+ for (const auto& key : owned_keys) {
+ buf[1] = counter.fetch_add(1);
+ s = txn->Put(key, Slice((const char*)buf, sizeof(buf)));
+ if (!s.ok()) {
+ break;
+ }
+ if (a == REFRESH_SNAPSHOT) {
+ txn->SetSnapshot();
+ read_options.snapshot = txn->GetSnapshot();
+ snapshot_num = counter.fetch_add(1);
+ }
+ }
+
+ // Failure is possible due to snapshot validation. In this case,
+ // rollback and move onto next iteration.
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsBusy());
+ ASSERT_OK(txn->Rollback());
+ delete txn;
+ continue;
+ }
+
+ auto verify_key = [&owned_keys, &a, &id, &snapshot_num](
+ const std::string& key, const std::string& value) {
+ if (owned_keys.count(key) > 0) {
+ ASSERT_EQ(value.size(), 16);
+
+          // Since this key is part of owned_keys, it must have been written
+          // (unprepared) by the transaction identified by 'id'.
+ ASSERT_EQ(((int64_t*)value.c_str())[0], id);
+ if (a == REFRESH_SNAPSHOT) {
+ // If refresh snapshot is true, then the snapshot is refreshed
+ // after every Put(), meaning that the current snapshot in
+ // snapshot_num must be greater than the "seqno" of any keys
+ // written by the current transaction.
+ ASSERT_LT(((int64_t*)value.c_str())[1], snapshot_num);
+ } else {
+ // If refresh snapshot is not on, then the snapshot was taken at
+ // the beginning of the transaction, meaning all writes must come
+ // after snapshot_num
+ ASSERT_GT(((int64_t*)value.c_str())[1], snapshot_num);
+ }
+ } else if (a >= RO_SNAPSHOT) {
+ // If this is not an unprepared key, just assert that the key
+ // "seqno" is smaller than the snapshot seqno.
+ ASSERT_EQ(value.size(), 16);
+ ASSERT_LT(((int64_t*)value.c_str())[1], snapshot_num);
+ }
+ };
+
+ // Validate Get()/Next()/Prev(). Do only one of them to save time, and
+ // reduce lock contention.
+ switch (rnd.Uniform(3)) {
+ case 0: // Validate Get()
+ {
+ for (const auto& key : keys) {
+ std::string value;
+ s = txn->Get(read_options, Slice(key), &value);
+ if (!s.ok()) {
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ(owned_keys.count(key), 0);
+ } else {
+ verify_key(key, value);
+ }
+ }
+ break;
+ }
+ case 1: // Validate Next()
+ {
+ Iterator* iter = txn->GetIterator(read_options);
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ verify_key(iter->key().ToString(), iter->value().ToString());
+ }
+ delete iter;
+ break;
+ }
+ case 2: // Validate Prev()
+ {
+ Iterator* iter = txn->GetIterator(read_options);
+ for (iter->SeekToLast(); iter->Valid(); iter->Prev()) {
+ verify_key(iter->key().ToString(), iter->value().ToString());
+ }
+ delete iter;
+ break;
+ }
+ default:
+ ASSERT_TRUE(false);
+ }
+
+ if (rnd.OneIn(2)) {
+ ASSERT_OK(txn->Commit());
+ } else {
+ ASSERT_OK(txn->Rollback());
+ }
+ delete txn;
+ }
+ };
+
+ std::vector<port::Thread> threads;
+ for (uint32_t i = 0; i < kNumThreads; i++) {
+ threads.emplace_back(stress_thread, i);
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+#endif // ROCKSDB_VALGRIND_RUN
+
+// This tests how write unprepared behaves during recovery when the DB crashes
+// after a transaction has either been unprepared or prepared, and tests whether
+// the changes are correctly applied for prepared transactions when we decide to
+// rollback/commit.
+TEST_P(WriteUnpreparedTransactionTest, RecoveryTest) {
+ WriteOptions write_options;
+ write_options.disableWAL = false;
+ TransactionOptions txn_options;
+ std::vector<Transaction*> prepared_trans;
+ WriteUnpreparedTxnDB* wup_db;
+ options.disable_auto_compactions = true;
+
+ enum Action { UNPREPARED, ROLLBACK, COMMIT };
+
+ // batch_size of 1 causes writes to DB for every marker.
+ for (size_t batch_size : {1, 1000000}) {
+ txn_options.write_batch_flush_threshold = batch_size;
+ for (bool empty : {true, false}) {
+ for (Action a : {UNPREPARED, ROLLBACK, COMMIT}) {
+ for (int num_batches = 1; num_batches < 10; num_batches++) {
+ // Reset database.
+ prepared_trans.clear();
+ ReOpen();
+ wup_db = dynamic_cast<WriteUnpreparedTxnDB*>(db);
+ if (!empty) {
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_OK(db->Put(WriteOptions(), "k" + ToString(i),
+ "before value" + ToString(i)));
+ }
+ }
+
+ // Write num_batches unprepared batches.
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ txn->SetName("xid");
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_OK(txn->Put("k" + ToString(i), "value" + ToString(i)));
+ if (txn_options.write_batch_flush_threshold == 1) {
+ // WriteUnprepared will check write_batch_flush_threshold and
+ // possibly flush before appending to the write batch. No flush
+ // will happen at the first write because the batch is still
+ // empty, so after k puts, there should be k-1 flushed batches.
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), i);
+ } else {
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), 0);
+ }
+ }
+ if (a == UNPREPARED) {
+ // This is done to prevent the destructor from rolling back the
+ // transaction for us, since we want to pretend we crashed and
+ // test that recovery does the rollback.
+ wup_txn->unprep_seqs_.clear();
+ } else {
+ txn->Prepare();
+ }
+ delete txn;
+
+ // Crash and run recovery code paths.
+ wup_db->db_impl_->FlushWAL(true);
+ wup_db->TEST_Crash();
+ ReOpenNoDelete();
+ assert(db != nullptr);
+
+ db->GetAllPreparedTransactions(&prepared_trans);
+ ASSERT_EQ(prepared_trans.size(), a == UNPREPARED ? 0 : 1);
+ if (a == ROLLBACK) {
+ ASSERT_OK(prepared_trans[0]->Rollback());
+ delete prepared_trans[0];
+ } else if (a == COMMIT) {
+ ASSERT_OK(prepared_trans[0]->Commit());
+ delete prepared_trans[0];
+ }
+
+ Iterator* iter = db->NewIterator(ReadOptions());
+ iter->SeekToFirst();
+ // Check that DB has before values.
+ if (!empty || a == COMMIT) {
+ for (int i = 0; i < num_batches; i++) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->key().ToString(), "k" + ToString(i));
+ if (a == COMMIT) {
+ ASSERT_EQ(iter->value().ToString(), "value" + ToString(i));
+ } else {
+ ASSERT_EQ(iter->value().ToString(),
+ "before value" + ToString(i));
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_FALSE(iter->Valid());
+ delete iter;
+ }
+ }
+ }
+ }
+}
+
+// Basic test to see that the unprepared batch gets written to the DB when the
+// batch size is exceeded. It also does some basic checks to see if
+// commit/rollback works as expected for write unprepared.
+TEST_P(WriteUnpreparedTransactionTest, UnpreparedBatch) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ const int kNumKeys = 10;
+
+ // batch_size of 1 causes writes to DB for every marker.
+ for (size_t batch_size : {1, 1000000}) {
+ txn_options.write_batch_flush_threshold = batch_size;
+ for (bool prepare : {false, true}) {
+ for (bool commit : {false, true}) {
+ ReOpen();
+ Transaction* txn = db->BeginTransaction(write_options, txn_options);
+ WriteUnpreparedTxn* wup_txn = dynamic_cast<WriteUnpreparedTxn*>(txn);
+ txn->SetName("xid");
+
+ for (int i = 0; i < kNumKeys; i++) {
+ txn->Put("k" + ToString(i), "v" + ToString(i));
+ if (txn_options.write_batch_flush_threshold == 1) {
+ // WriteUnprepared will check write_batch_flush_threshold and
+ // possibly flush before appending to the write batch. No flush will
+ // happen at the first write because the batch is still empty, so
+ // after k puts, there should be k-1 flushed batches.
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), i);
+ } else {
+ ASSERT_EQ(wup_txn->GetUnpreparedSequenceNumbers().size(), 0);
+ }
+ }
+
+ if (prepare) {
+ ASSERT_OK(txn->Prepare());
+ }
+
+ Iterator* iter = db->NewIterator(ReadOptions());
+ iter->SeekToFirst();
+ assert(!iter->Valid());
+ ASSERT_FALSE(iter->Valid());
+ delete iter;
+
+ if (commit) {
+ ASSERT_OK(txn->Commit());
+ } else {
+ ASSERT_OK(txn->Rollback());
+ }
+ delete txn;
+
+ iter = db->NewIterator(ReadOptions());
+ iter->SeekToFirst();
+
+ for (int i = 0; i < (commit ? kNumKeys : 0); i++) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(iter->key().ToString(), "k" + ToString(i));
+ ASSERT_EQ(iter->value().ToString(), "v" + ToString(i));
+ iter->Next();
+ }
+ ASSERT_FALSE(iter->Valid());
+ delete iter;
+ }
+ }
+ }
+}
+
+// Test whether logs containing unprepared/prepared batches are kept even
+// after memtable finishes flushing, and whether they are removed when
+// transaction commits/aborts.
+//
+// TODO(lth): Merge with TransactionTest/TwoPhaseLogRollingTest tests.
+TEST_P(WriteUnpreparedTransactionTest, MarkLogWithPrepSection) {
+ WriteOptions write_options;
+ TransactionOptions txn_options;
+ // batch_size of 1 causes writes to DB for every marker.
+ txn_options.write_batch_flush_threshold = 1;
+ const int kNumKeys = 10;
+
+ WriteOptions wopts;
+ wopts.sync = true;
+
+ for (bool prepare : {false, true}) {
+ for (bool commit : {false, true}) {
+ ReOpen();
+ auto wup_db = dynamic_cast<WriteUnpreparedTxnDB*>(db);
+ auto db_impl = wup_db->db_impl_;
+
+ Transaction* txn1 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn1->SetName("xid1"));
+
+ Transaction* txn2 = db->BeginTransaction(write_options, txn_options);
+ ASSERT_OK(txn2->SetName("xid2"));
+
+ // Spread this transaction across multiple log files.
+ for (int i = 0; i < kNumKeys; i++) {
+ ASSERT_OK(txn1->Put("k1" + ToString(i), "v" + ToString(i)));
+ if (i >= kNumKeys / 2) {
+ ASSERT_OK(txn2->Put("k2" + ToString(i), "v" + ToString(i)));
+ }
+
+ if (i > 0) {
+ db_impl->TEST_SwitchWAL();
+ }
+ }
+
+ ASSERT_GT(txn1->GetLogNumber(), 0);
+ ASSERT_GT(txn2->GetLogNumber(), 0);
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ ASSERT_GT(db_impl->TEST_LogfileNumber(), txn1->GetLogNumber());
+
+ if (prepare) {
+ ASSERT_OK(txn1->Prepare());
+ ASSERT_OK(txn2->Prepare());
+ }
+
+ ASSERT_GE(db_impl->TEST_LogfileNumber(), txn1->GetLogNumber());
+ ASSERT_GE(db_impl->TEST_LogfileNumber(), txn2->GetLogNumber());
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn1->GetLogNumber());
+ if (commit) {
+ ASSERT_OK(txn1->Commit());
+ } else {
+ ASSERT_OK(txn1->Rollback());
+ }
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(),
+ txn2->GetLogNumber());
+
+ if (commit) {
+ ASSERT_OK(txn2->Commit());
+ } else {
+ ASSERT_OK(txn2->Rollback());
+ }
+
+ ASSERT_EQ(db_impl->TEST_FindMinLogContainingOutstandingPrep(), 0);
+
+ delete txn1;
+ delete txn2;
+ }
+ }
+}
+
+TEST_P(WriteUnpreparedTransactionTest, NoSnapshotWrite) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+
+ // Do some writes with no snapshot
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Put("c", "c"));
+
+ // Test that it is still possible to create iterators after writes with no
+ // snapshot, if iterator snapshot is fresh enough.
+ ReadOptions roptions;
+ auto iter = txn->GetIterator(roptions);
+ int keys = 0;
+ for (iter->SeekToLast(); iter->Valid(); iter->Prev(), keys++) {
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(iter->key().ToString(), iter->value().ToString());
+ }
+ ASSERT_EQ(keys, 3);
+
+ delete iter;
+ delete txn;
+}
+
+// Test whether write to a transaction while iterating is supported.
+TEST_P(WriteUnpreparedTransactionTest, IterateAndWrite) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ enum Action { DO_DELETE, DO_UPDATE };
+
+ for (Action a : {DO_DELETE, DO_UPDATE}) {
+ for (int i = 0; i < 100; i++) {
+ ASSERT_OK(db->Put(woptions, ToString(i), ToString(i)));
+ }
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ // write_batch_ now contains 1 key.
+ ASSERT_OK(txn->Put("9", "a"));
+
+ ReadOptions roptions;
+ auto iter = txn->GetIterator(roptions);
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ ASSERT_OK(iter->status());
+ if (iter->key() == "9") {
+ ASSERT_EQ(iter->value().ToString(), "a");
+ } else {
+ ASSERT_EQ(iter->key().ToString(), iter->value().ToString());
+ }
+
+ if (a == DO_DELETE) {
+ ASSERT_OK(txn->Delete(iter->key()));
+ } else {
+ ASSERT_OK(txn->Put(iter->key(), "b"));
+ }
+ }
+
+ delete iter;
+ ASSERT_OK(txn->Commit());
+
+ iter = db->NewIterator(roptions);
+ if (a == DO_DELETE) {
+ // Check that db is empty.
+ iter->SeekToFirst();
+ ASSERT_FALSE(iter->Valid());
+ } else {
+ int keys = 0;
+ // Check that all values are updated to b.
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next(), keys++) {
+ ASSERT_OK(iter->status());
+ ASSERT_EQ(iter->value().ToString(), "b");
+ }
+ ASSERT_EQ(keys, 100);
+ }
+
+ delete iter;
+ delete txn;
+ }
+}
+
+TEST_P(WriteUnpreparedTransactionTest, SavePoint) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ txn->SetSavePoint();
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Commit());
+
+ ReadOptions roptions;
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, "b");
+ delete txn;
+}
+
+TEST_P(WriteUnpreparedTransactionTest, UntrackedKeys) {
+ WriteOptions woptions;
+ TransactionOptions txn_options;
+ txn_options.write_batch_flush_threshold = 1;
+
+ Transaction* txn = db->BeginTransaction(woptions, txn_options);
+ auto wb = txn->GetWriteBatch()->GetWriteBatch();
+ ASSERT_OK(txn->Put("a", "a"));
+ ASSERT_OK(wb->Put("a_untrack", "a_untrack"));
+ txn->SetSavePoint();
+ ASSERT_OK(txn->Put("b", "b"));
+ ASSERT_OK(txn->Put("b_untrack", "b_untrack"));
+
+ ReadOptions roptions;
+ std::string value;
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "a_untrack", &value));
+ ASSERT_EQ(value, "a_untrack");
+ ASSERT_OK(txn->Get(roptions, "b", &value));
+ ASSERT_EQ(value, "b");
+ ASSERT_OK(txn->Get(roptions, "b_untrack", &value));
+ ASSERT_EQ(value, "b_untrack");
+
+ // b and b_untrack should be rolled back.
+ ASSERT_OK(txn->RollbackToSavePoint());
+ ASSERT_OK(txn->Get(roptions, "a", &value));
+ ASSERT_EQ(value, "a");
+ ASSERT_OK(txn->Get(roptions, "a_untrack", &value));
+ ASSERT_EQ(value, "a_untrack");
+ auto s = txn->Get(roptions, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ // Everything should be rolled back.
+ ASSERT_OK(txn->Rollback());
+ s = txn->Get(roptions, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "a_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = txn->Get(roptions, "b_untrack", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete txn;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr,
+ "SKIPPED as Transactions are not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn.cc b/src/rocksdb/utilities/transactions/write_unprepared_txn.cc
new file mode 100644
index 000000000..01ec298cf
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn.cc
@@ -0,0 +1,999 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_unprepared_txn.h"
+#include "db/db_impl/db_impl.h"
+#include "util/cast_util.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+bool WriteUnpreparedTxnReadCallback::IsVisibleFullCheck(SequenceNumber seq) {
+ // Since unprep_seqs maps prep_seq => prepare_batch_cnt, to check if seq is
+ // in unprep_seqs, we have to check if seq is equal to prep_seq or any of
+ // the prepare_batch_cnt seq nums after it.
+ //
+ // TODO(lth): Can be optimized with std::lower_bound if unprep_seqs is
+ // large.
+ for (const auto& it : unprep_seqs_) {
+ if (it.first <= seq && seq < it.first + it.second) {
+ return true;
+ }
+ }
+
+ bool snap_released = false;
+ auto ret =
+ db_->IsInSnapshot(seq, wup_snapshot_, min_uncommitted_, &snap_released);
+ assert(!snap_released || backed_by_snapshot_ == kUnbackedByDBSnapshot);
+ snap_released_ |= snap_released;
+ return ret;
+}
+
+WriteUnpreparedTxn::WriteUnpreparedTxn(WriteUnpreparedTxnDB* txn_db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options)
+ : WritePreparedTxn(txn_db, write_options, txn_options),
+ wupt_db_(txn_db),
+ last_log_number_(0),
+ recovered_txn_(false),
+ largest_validated_seq_(0) {
+ if (txn_options.write_batch_flush_threshold < 0) {
+ write_batch_flush_threshold_ =
+ txn_db_impl_->GetTxnDBOptions().default_write_batch_flush_threshold;
+ } else {
+ write_batch_flush_threshold_ = txn_options.write_batch_flush_threshold;
+ }
+}
+
+WriteUnpreparedTxn::~WriteUnpreparedTxn() {
+ if (!unprep_seqs_.empty()) {
+ assert(log_number_ > 0);
+ assert(GetId() > 0);
+ assert(!name_.empty());
+
+ // We should rollback regardless of GetState, but some unit tests that
+ // test crash recovery run the destructor assuming that rollback does not
+ // happen, so that rollback during recovery can be exercised.
+ if (GetState() == STARTED || GetState() == LOCKS_STOLEN) {
+ auto s = RollbackInternal();
+ assert(s.ok());
+ if (!s.ok()) {
+ ROCKS_LOG_FATAL(
+ wupt_db_->info_log_,
+ "Rollback of WriteUnprepared transaction failed in destructor: %s",
+ s.ToString().c_str());
+ }
+ dbimpl_->logs_with_prep_tracker()->MarkLogAsHavingPrepSectionFlushed(
+ log_number_);
+ }
+ }
+
+ // Call tracked_keys_.clear() so that ~PessimisticTransaction does not
+ // try to unlock keys for recovered transactions.
+ if (recovered_txn_) {
+ tracked_keys_.clear();
+ }
+}
+
+void WriteUnpreparedTxn::Initialize(const TransactionOptions& txn_options) {
+ PessimisticTransaction::Initialize(txn_options);
+ if (txn_options.write_batch_flush_threshold < 0) {
+ write_batch_flush_threshold_ =
+ txn_db_impl_->GetTxnDBOptions().default_write_batch_flush_threshold;
+ } else {
+ write_batch_flush_threshold_ = txn_options.write_batch_flush_threshold;
+ }
+
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ recovered_txn_ = false;
+ largest_validated_seq_ = 0;
+ assert(active_iterators_.empty());
+ active_iterators_.clear();
+ untracked_keys_.clear();
+}
+
+Status WriteUnpreparedTxn::HandleWrite(std::function<Status()> do_write) {
+ Status s;
+ if (active_iterators_.empty()) {
+ s = MaybeFlushWriteBatchToDB();
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ s = do_write();
+ if (s.ok()) {
+ if (snapshot_) {
+ largest_validated_seq_ =
+ std::max(largest_validated_seq_, snapshot_->GetSequenceNumber());
+ } else {
+ // TODO(lth): We should use the same number as tracked_at_seq in TryLock,
+      // because what is actually being tracked is the sequence number at which
+      // this key was locked.
+ largest_validated_seq_ = db_impl_->GetLastPublishedSequence();
+ }
+ }
+ return s;
+}
+
+Status WriteUnpreparedTxn::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Put(column_family, key, value, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Put(ColumnFamilyHandle* column_family,
+ const SliceParts& key, const SliceParts& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Put(column_family, key, value, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Merge(column_family, key, value,
+ assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key, const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Delete(column_family, key, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::Delete(column_family, key, assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::SingleDelete(column_family, key,
+ assume_tracked);
+ });
+}
+
+Status WriteUnpreparedTxn::SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked) {
+ return HandleWrite([&]() {
+ return TransactionBaseImpl::SingleDelete(column_family, key,
+ assume_tracked);
+ });
+}
+
+// WriteUnpreparedTxn::RebuildFromWriteBatch is only called on recovery. For
+// WriteUnprepared, the write batches have already been written into the
+// database during WAL replay, so all we have to do is "retrack" the keys
+// so that rollbacks are possible.
+//
+// Calling TryLock instead of TrackKey is also possible, but as an optimization,
+// recovered transactions do not hold locks on their keys. This follows the
+// implementation in PessimisticTransactionDB::Initialize where we set
+// skip_concurrency_control to true.
+Status WriteUnpreparedTxn::RebuildFromWriteBatch(WriteBatch* wb) {
+ struct TrackKeyHandler : public WriteBatch::Handler {
+ WriteUnpreparedTxn* txn_;
+ bool rollback_merge_operands_;
+
+ TrackKeyHandler(WriteUnpreparedTxn* txn, bool rollback_merge_operands)
+ : txn_(txn), rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ return Status::OK();
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ if (rollback_merge_operands_) {
+ txn_->TrackKey(cf, key.ToString(), kMaxSequenceNumber,
+ false /* read_only */, true /* exclusive */);
+ }
+ return Status::OK();
+ }
+
+ // Recovered batches do not contain 2PC markers.
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkNoop(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ TrackKeyHandler handler(this,
+ wupt_db_->txn_db_options_.rollback_merge_operands);
+ return wb->Iterate(&handler);
+}
+
+Status WriteUnpreparedTxn::MaybeFlushWriteBatchToDB() {
+ const bool kPrepared = true;
+ Status s;
+ if (write_batch_flush_threshold_ > 0 &&
+ write_batch_.GetWriteBatch()->Count() > 0 &&
+ write_batch_.GetDataSize() >
+ static_cast<size_t>(write_batch_flush_threshold_)) {
+ assert(GetState() != PREPARED);
+ s = FlushWriteBatchToDB(!kPrepared);
+ }
+ return s;
+}
+
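+// Illustrative usage sketch (not part of the original sources). Callers opt
+// into unprepared batch flushing by setting
+// TransactionOptions::write_batch_flush_threshold, as the tests in this change
+// do:
+//
+//   TransactionOptions txn_options;
+//   txn_options.write_batch_flush_threshold = 1 << 20;  // bytes per batch
+//   Transaction* txn = txn_db->BeginTransaction(WriteOptions(), txn_options);
+//   // Each Put()/Delete()/Merge() goes through HandleWrite(), which calls
+//   // MaybeFlushWriteBatchToDB() and writes the accumulated batch to the DB
+//   // once its data size exceeds the threshold.
+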
+Status WriteUnpreparedTxn::FlushWriteBatchToDB(bool prepared) {
+ // If the current write batch contains savepoints, then some special handling
+ // is required so that RollbackToSavepoint can work.
+ //
+ // RollbackToSavepoint is not supported after Prepare() is called, so only do
+ // this for unprepared batches.
+ if (!prepared && unflushed_save_points_ != nullptr &&
+ !unflushed_save_points_->empty()) {
+ return FlushWriteBatchWithSavePointToDB();
+ }
+
+ return FlushWriteBatchToDBInternal(prepared);
+}
+
+Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) {
+ if (name_.empty()) {
+ assert(!prepared);
+#ifndef NDEBUG
+ static std::atomic_ullong autogen_id{0};
+ // To avoid changing all tests to call SetName, just autogenerate one.
+ if (wupt_db_->txn_db_options_.autogenerate_name) {
+ SetName(std::string("autoxid") + ToString(autogen_id.fetch_add(1)));
+ } else
+#endif
+ {
+ return Status::InvalidArgument("Cannot write to DB without SetName.");
+ }
+ }
+
+ struct UntrackedKeyHandler : public WriteBatch::Handler {
+ WriteUnpreparedTxn* txn_;
+ bool rollback_merge_operands_;
+
+ UntrackedKeyHandler(WriteUnpreparedTxn* txn, bool rollback_merge_operands)
+ : txn_(txn), rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status AddUntrackedKey(uint32_t cf, const Slice& key) {
+ auto str = key.ToString();
+ if (txn_->tracked_keys_[cf].count(str) == 0) {
+ txn_->untracked_keys_[cf].push_back(str);
+ }
+ return Status::OK();
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice&) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return AddUntrackedKey(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override {
+ if (rollback_merge_operands_) {
+ return AddUntrackedKey(cf, key);
+ }
+ return Status::OK();
+ }
+
+ // The only expected 2PC marker is the initial Noop marker.
+ Status MarkNoop(bool empty_batch) override {
+ return empty_batch ? Status::OK() : Status::InvalidArgument();
+ }
+
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+ UntrackedKeyHandler handler(
+ this, wupt_db_->txn_db_options_.rollback_merge_operands);
+ auto s = GetWriteBatch()->GetWriteBatch()->Iterate(&handler);
+ assert(s.ok());
+
+ // TODO(lth): Reduce duplicate code with WritePrepared prepare logic.
+ WriteOptions write_options = write_options_;
+ write_options.disableWAL = false;
+ const bool WRITE_AFTER_COMMIT = true;
+ const bool first_prepare_batch = log_number_ == 0;
+ // MarkEndPrepare will change Noop marker to the appropriate marker.
+ WriteBatchInternal::MarkEndPrepare(GetWriteBatch()->GetWriteBatch(), name_,
+ !WRITE_AFTER_COMMIT, !prepared);
+ // For each duplicate key we account for a new sub-batch
+ prepare_batch_cnt_ = GetWriteBatch()->SubBatchCnt();
+  // AddPrepared is better called in the pre-release callback; otherwise there
+  // is a non-zero chance of max advancing past prepare_seq and readers
+  // assuming the data is committed.
+ // Also having it in the PreReleaseCallback allows in-order addition of
+ // prepared entries to PreparedHeap and hence enables an optimization. Refer
+ // to SmallestUnCommittedSeq for more details.
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, prepare_batch_cnt_,
+ db_impl_->immutable_db_options().two_write_queues, first_prepare_batch);
+ const bool DISABLE_MEMTABLE = true;
+ uint64_t seq_used = kMaxSequenceNumber;
+ // log_number_ should refer to the oldest log containing uncommitted data
+ // from the current transaction. This means that if log_number_ is set,
+ // WriteImpl should not overwrite that value, so set log_used to nullptr if
+ // log_number_ is already set.
+ s = db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(),
+ /*callback*/ nullptr, &last_log_number_,
+ /*log ref*/ 0, !DISABLE_MEMTABLE, &seq_used,
+ prepare_batch_cnt_, &add_prepared_callback);
+ if (log_number_ == 0) {
+ log_number_ = last_log_number_;
+ }
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ auto prepare_seq = seq_used;
+
+ // Only call SetId if it hasn't been set yet.
+ if (GetId() == 0) {
+ SetId(prepare_seq);
+ }
+ // unprep_seqs_ will also contain prepared seqnos since they are treated in
+ // the same way in the prepare/commit callbacks. See the comment on the
+ // definition of unprep_seqs_.
+ unprep_seqs_[prepare_seq] = prepare_batch_cnt_;
+
+ // Reset transaction state.
+ if (!prepared) {
+ prepare_batch_cnt_ = 0;
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+ }
+
+ return s;
+}
+
+Status WriteUnpreparedTxn::FlushWriteBatchWithSavePointToDB() {
+ assert(unflushed_save_points_ != nullptr &&
+ unflushed_save_points_->size() > 0);
+ assert(save_points_ != nullptr && save_points_->size() > 0);
+ assert(save_points_->size() >= unflushed_save_points_->size());
+
+ // Handler class for creating an unprepared batch from a savepoint.
+ struct SavePointBatchHandler : public WriteBatch::Handler {
+ WriteBatchWithIndex* wb_;
+ const std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+
+ SavePointBatchHandler(
+ WriteBatchWithIndex* wb,
+ const std::map<uint32_t, ColumnFamilyHandle*>& handles)
+ : wb_(wb), handles_(handles) {}
+
+ Status PutCF(uint32_t cf, const Slice& key, const Slice& value) override {
+ return wb_->Put(handles_.at(cf), key, value);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return wb_->Delete(handles_.at(cf), key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return wb_->SingleDelete(handles_.at(cf), key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key, const Slice& value) override {
+ return wb_->Merge(handles_.at(cf), key, value);
+ }
+
+ // The only expected 2PC marker is the initial Noop marker.
+ Status MarkNoop(bool empty_batch) override {
+ return empty_batch ? Status::OK() : Status::InvalidArgument();
+ }
+
+ Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); }
+
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ };
+
+  // The comparator of the default cf is passed in, similar to the
+  // initialization of TransactionBaseImpl::write_batch_. This comparator is
+  // only used if the write batch encounters an invalid cf id, in which case it
+  // falls back to the default cf comparator.
+ WriteBatchWithIndex wb(wpt_db_->DefaultColumnFamily()->GetComparator(), 0,
+ true, 0);
+ // Swap with write_batch_ so that wb contains the complete write batch. The
+ // actual write batch that will be flushed to DB will be built in
+ // write_batch_, and will be read by FlushWriteBatchToDBInternal.
+ std::swap(wb, write_batch_);
+ TransactionBaseImpl::InitWriteBatch();
+
+ size_t prev_boundary = WriteBatchInternal::kHeader;
+ const bool kPrepared = true;
+ for (size_t i = 0; i < unflushed_save_points_->size() + 1; i++) {
+ bool trailing_batch = i == unflushed_save_points_->size();
+ SavePointBatchHandler sp_handler(&write_batch_,
+ *wupt_db_->GetCFHandleMap().get());
+ size_t curr_boundary = trailing_batch ? wb.GetWriteBatch()->GetDataSize()
+ : (*unflushed_save_points_)[i];
+
+ // Construct the partial write batch up to the savepoint.
+ //
+ // Theoretically, a memcpy between the write batches should be sufficient
+ // since the rewriting into the batch should produce the exact same byte
+ // representation. Rebuilding the WriteBatchWithIndex index is still
+    // necessary though, and would imply doing two passes over the batch.
+ Status s = WriteBatchInternal::Iterate(wb.GetWriteBatch(), &sp_handler,
+ prev_boundary, curr_boundary);
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (write_batch_.GetWriteBatch()->Count() > 0) {
+ // Flush the write batch.
+ s = FlushWriteBatchToDBInternal(!kPrepared);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ if (!trailing_batch) {
+ if (flushed_save_points_ == nullptr) {
+ flushed_save_points_.reset(
+ new autovector<WriteUnpreparedTxn::SavePoint>());
+ }
+ flushed_save_points_->emplace_back(
+ unprep_seqs_, new ManagedSnapshot(db_impl_, wupt_db_->GetSnapshot()));
+ }
+
+ prev_boundary = curr_boundary;
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+ }
+
+ unflushed_save_points_->clear();
+ return Status::OK();
+}
+
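+// Illustrative view of the savepoint slicing above (not part of the original
+// sources). unflushed_save_points_ records the write batch data size at each
+// SetSavePoint() call, so a batch with two savepoints is flushed as three
+// unprepared batches:
+//
+//   offsets: [kHeader, sp0)  [sp0, sp1)  [sp1, GetDataSize())
+//   flush:    batch #1        batch #2    trailing batch #3
+//
+// Each non-trailing flush also captures a snapshot in flushed_save_points_ so
+// that RollbackToSavePointInternal() can later undo the flushed writes.
+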
+Status WriteUnpreparedTxn::PrepareInternal() {
+ const bool kPrepared = true;
+ return FlushWriteBatchToDB(kPrepared);
+}
+
+Status WriteUnpreparedTxn::CommitWithoutPrepareInternal() {
+ if (unprep_seqs_.empty()) {
+ assert(log_number_ == 0);
+ assert(GetId() == 0);
+ return WritePreparedTxn::CommitWithoutPrepareInternal();
+ }
+
+ // TODO(lth): We should optimize commit without prepare to not perform
+ // a prepare under the hood.
+ auto s = PrepareInternal();
+ if (!s.ok()) {
+ return s;
+ }
+ return CommitInternal();
+}
+
+Status WriteUnpreparedTxn::CommitInternal() {
+ // TODO(lth): Reduce duplicate code with WritePrepared commit logic.
+
+ // We take the commit-time batch and append the Commit marker. The Memtable
+ // will ignore the Commit marker in non-recovery mode
+ WriteBatch* working_batch = GetCommitTimeWriteBatch();
+ const bool empty = working_batch->Count() == 0;
+ WriteBatchInternal::MarkCommit(working_batch, name_);
+
+ const bool for_recovery = use_only_the_last_commit_time_batch_for_recovery_;
+ if (!empty && for_recovery) {
+ // When not writing to memtable, we can still cache the latest write batch.
+ // The cached batch will be written to memtable in WriteRecoverableState
+ // during FlushMemTable
+ WriteBatchInternal::SetAsLastestPersistentState(working_batch);
+ }
+
+ const bool includes_data = !empty && !for_recovery;
+ size_t commit_batch_cnt = 0;
+ if (UNLIKELY(includes_data)) {
+ ROCKS_LOG_WARN(db_impl_->immutable_db_options().info_log,
+ "Duplicate key overhead");
+ SubBatchCounter counter(*wpt_db_->GetCFComparatorMap());
+ auto s = working_batch->Iterate(&counter);
+ assert(s.ok());
+ commit_batch_cnt = counter.BatchCount();
+ }
+ const bool disable_memtable = !includes_data;
+ const bool do_one_write =
+ !db_impl_->immutable_db_options().two_write_queues || disable_memtable;
+
+ WriteUnpreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, unprep_seqs_, commit_batch_cnt);
+ const bool kFirstPrepareBatch = true;
+ AddPreparedCallback add_prepared_callback(
+ wpt_db_, db_impl_, commit_batch_cnt,
+ db_impl_->immutable_db_options().two_write_queues, !kFirstPrepareBatch);
+ PreReleaseCallback* pre_release_callback;
+ if (do_one_write) {
+ pre_release_callback = &update_commit_map;
+ } else {
+ pre_release_callback = &add_prepared_callback;
+ }
+ uint64_t seq_used = kMaxSequenceNumber;
+ // Since the prepared batch is directly written to memtable, there is
+ // already a connection between the memtable and its WAL, so there is no
+ // need to redundantly reference the log that contains the prepared data.
+ const uint64_t zero_log_number = 0ull;
+ size_t batch_cnt = UNLIKELY(commit_batch_cnt) ? commit_batch_cnt : 1;
+ auto s = db_impl_->WriteImpl(write_options_, working_batch, nullptr, nullptr,
+ zero_log_number, disable_memtable, &seq_used,
+ batch_cnt, pre_release_callback);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ const SequenceNumber commit_batch_seq = seq_used;
+ if (LIKELY(do_one_write || !s.ok())) {
+ if (LIKELY(s.ok())) {
+      // Note RemovePrepared should be called after the WriteImpl that published
+ // the seq. Otherwise SmallestUnCommittedSeq optimization breaks.
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ }
+ if (UNLIKELY(!do_one_write)) {
+ wpt_db_->RemovePrepared(commit_batch_seq, commit_batch_cnt);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+ } // else do the 2nd write to publish seq
+
+ // Populate unprep_seqs_ with commit_batch_seq, since we treat data in the
+ // commit write batch as just another "unprepared" batch. This will also
+ // update the unprep_seqs_ in the update_commit_map callback.
+ unprep_seqs_[commit_batch_seq] = commit_batch_cnt;
+
+  // Note: the 2nd write comes with a performance penalty. So if we have too
+  // many commits accompanied with a CommitTimeWriteBatch and yet we cannot
+  // enable the use_only_the_last_commit_time_batch_for_recovery_ optimization,
+ // two_write_queues should be disabled to avoid many additional writes here.
+
+ // Update commit map only from the 2nd queue
+ WriteBatch empty_batch;
+ empty_batch.PutLogData(Slice());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ WriteBatchInternal::InsertNoop(&empty_batch);
+ const bool DISABLE_MEMTABLE = true;
+ const size_t ONE_BATCH = 1;
+ const uint64_t NO_REF_LOG = 0;
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+  // Note RemovePrepared should be called after the WriteImpl that published the
+ // seq. Otherwise SmallestUnCommittedSeq optimization breaks.
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+}
+
+Status WriteUnpreparedTxn::WriteRollbackKeys(
+ const TransactionKeyMap& tracked_keys, WriteBatchWithIndex* rollback_batch,
+ ReadCallback* callback, const ReadOptions& roptions) {
+ const auto& cf_map = *wupt_db_->GetCFHandleMap();
+ auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) {
+ const auto& cf_handle = cf_map.at(cfid);
+ PinnableSlice pinnable_val;
+ bool not_used;
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = callback;
+ auto s = db_impl_->GetImpl(roptions, key, get_impl_options);
+
+ if (s.ok()) {
+ s = rollback_batch->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ s = rollback_batch->Delete(cf_handle, key);
+ assert(s.ok());
+ } else {
+ return s;
+ }
+
+ return Status::OK();
+ };
+
+ for (const auto& cfkey : tracked_keys) {
+ const auto cfid = cfkey.first;
+ const auto& keys = cfkey.second;
+ for (const auto& pair : keys) {
+ auto s = WriteRollbackKey(pair.first, cfid);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+
+ for (const auto& cfkey : untracked_keys_) {
+ const auto cfid = cfkey.first;
+ const auto& keys = cfkey.second;
+ for (const auto& key : keys) {
+ auto s = WriteRollbackKey(key, cfid);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status WriteUnpreparedTxn::RollbackInternal() {
+ // TODO(lth): Reduce duplicate code with WritePrepared rollback logic.
+ WriteBatchWithIndex rollback_batch(
+ wpt_db_->DefaultColumnFamily()->GetComparator(), 0, true, 0);
+ assert(GetId() != kMaxSequenceNumber);
+ assert(GetId() > 0);
+ Status s;
+ auto read_at_seq = kMaxSequenceNumber;
+ ReadOptions roptions;
+  // To prevent the callback's seq from being overridden inside DBImpl::Get.
+ roptions.snapshot = wpt_db_->GetMaxSnapshot();
+ // Note that we do not use WriteUnpreparedTxnReadCallback because we do not
+ // need to read our own writes when reading prior versions of the key for
+ // rollback.
+ WritePreparedTxnReadCallback callback(wpt_db_, read_at_seq);
+ WriteRollbackKeys(GetTrackedKeys(), &rollback_batch, &callback, roptions);
+
+ // The Rollback marker will be used as a batch separator
+ WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_);
+ bool do_one_write = !db_impl_->immutable_db_options().two_write_queues;
+ const bool DISABLE_MEMTABLE = true;
+ const uint64_t NO_REF_LOG = 0;
+ uint64_t seq_used = kMaxSequenceNumber;
+  // TODO(lth): We write the rollback batch all in a single batch here, but it
+  // should be subdivided into multiple batches as well. In phase 2, when key
+  // sets are read from the WAL, this will happen naturally.
+ const size_t ONE_BATCH = 1;
+  // We commit the rolled back prepared batches. Although this is
+  // counter-intuitive, i) it is safe to do so, since the prepared batches are
+  // already canceled out by the rollback batch, and ii) adding the commit entry
+  // to CommitCache lets us benefit from the existing mechanism in CommitCache
+  // that keeps around an entry that was evicted due to max advance yet still
+  // overlaps with a live snapshot, so that the live snapshot properly skips the
+  // entry even if its prepare seq is lower than max_evicted_seq_.
+ WriteUnpreparedCommitEntryPreReleaseCallback update_commit_map(
+ wpt_db_, db_impl_, unprep_seqs_, ONE_BATCH);
+ // Note: the rollback batch does not need AddPrepared since it is written to
+ // DB in one shot. min_uncommitted still works since it requires capturing
+  // data that is written to DB but not yet committed, while the rollback
+ // batch commits with PreReleaseCallback.
+ s = db_impl_->WriteImpl(write_options_, rollback_batch.GetWriteBatch(),
+ nullptr, nullptr, NO_REF_LOG, !DISABLE_MEMTABLE,
+ &seq_used, rollback_batch.SubBatchCnt(),
+ do_one_write ? &update_commit_map : nullptr);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ if (!s.ok()) {
+ return s;
+ }
+ if (do_one_write) {
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+ } // else do the 2nd write for commit
+ uint64_t& prepare_seq = seq_used;
+ ROCKS_LOG_DETAILS(db_impl_->immutable_db_options().info_log,
+ "RollbackInternal 2nd write prepare_seq: %" PRIu64,
+ prepare_seq);
+ // Commit the batch by writing an empty batch to the queue that will release
+ // the commit sequence number to readers.
+ WriteUnpreparedRollbackPreReleaseCallback update_commit_map_with_prepare(
+ wpt_db_, db_impl_, unprep_seqs_, prepare_seq);
+ WriteBatch empty_batch;
+ empty_batch.PutLogData(Slice());
+ // In the absence of Prepare markers, use Noop as a batch separator
+ WriteBatchInternal::InsertNoop(&empty_batch);
+ s = db_impl_->WriteImpl(write_options_, &empty_batch, nullptr, nullptr,
+ NO_REF_LOG, DISABLE_MEMTABLE, &seq_used, ONE_BATCH,
+ &update_commit_map_with_prepare);
+ assert(!s.ok() || seq_used != kMaxSequenceNumber);
+ // Mark the txn as rolled back
+ if (s.ok()) {
+ for (const auto& seq : unprep_seqs_) {
+ wpt_db_->RemovePrepared(seq.first, seq.second);
+ }
+ }
+
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ return s;
+}
+
+void WriteUnpreparedTxn::Clear() {
+ if (!recovered_txn_) {
+ txn_db_impl_->UnLock(this, &GetTrackedKeys());
+ }
+ unprep_seqs_.clear();
+ flushed_save_points_.reset(nullptr);
+ unflushed_save_points_.reset(nullptr);
+ recovered_txn_ = false;
+ largest_validated_seq_ = 0;
+ assert(active_iterators_.empty());
+ active_iterators_.clear();
+ untracked_keys_.clear();
+ TransactionBaseImpl::Clear();
+}
+
+void WriteUnpreparedTxn::SetSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ PessimisticTransaction::SetSavePoint();
+ if (unflushed_save_points_ == nullptr) {
+ unflushed_save_points_.reset(new autovector<size_t>());
+ }
+ unflushed_save_points_->push_back(write_batch_.GetDataSize());
+}
+
+Status WriteUnpreparedTxn::RollbackToSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ if (unflushed_save_points_ != nullptr && unflushed_save_points_->size() > 0) {
+ Status s = PessimisticTransaction::RollbackToSavePoint();
+ assert(!s.IsNotFound());
+ unflushed_save_points_->pop_back();
+ return s;
+ }
+
+ if (flushed_save_points_ != nullptr && !flushed_save_points_->empty()) {
+ return RollbackToSavePointInternal();
+ }
+
+ return Status::NotFound();
+}
+
+Status WriteUnpreparedTxn::RollbackToSavePointInternal() {
+ Status s;
+
+ const bool kClear = true;
+ TransactionBaseImpl::InitWriteBatch(kClear);
+
+ assert(flushed_save_points_->size() > 0);
+ WriteUnpreparedTxn::SavePoint& top = flushed_save_points_->back();
+
+ assert(save_points_ != nullptr && save_points_->size() > 0);
+ const TransactionKeyMap& tracked_keys = save_points_->top().new_keys_;
+
+ ReadOptions roptions;
+ roptions.snapshot = top.snapshot_->snapshot();
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(
+ roptions.snapshot)
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = roptions.snapshot->GetSequenceNumber();
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ top.unprep_seqs_,
+ kBackedByDBSnapshot);
+ WriteRollbackKeys(tracked_keys, &write_batch_, &callback, roptions);
+
+ const bool kPrepared = true;
+ s = FlushWriteBatchToDBInternal(!kPrepared);
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+
+ // PessimisticTransaction::RollbackToSavePoint will also call
+ // RollbackToSavePoint on write_batch_. However, write_batch_ is empty and has
+ // no savepoints because this savepoint has already been flushed. Work around
+ // this by setting a fake savepoint.
+ write_batch_.SetSavePoint();
+ s = PessimisticTransaction::RollbackToSavePoint();
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+
+ flushed_save_points_->pop_back();
+ return s;
+}
+
+Status WriteUnpreparedTxn::PopSavePoint() {
+ assert((unflushed_save_points_ ? unflushed_save_points_->size() : 0) +
+ (flushed_save_points_ ? flushed_save_points_->size() : 0) ==
+ (save_points_ ? save_points_->size() : 0));
+ if (unflushed_save_points_ != nullptr && unflushed_save_points_->size() > 0) {
+ Status s = PessimisticTransaction::PopSavePoint();
+ assert(!s.IsNotFound());
+ unflushed_save_points_->pop_back();
+ return s;
+ }
+
+ if (flushed_save_points_ != nullptr && !flushed_save_points_->empty()) {
+ // PessimisticTransaction::PopSavePoint will also call PopSavePoint on
+ // write_batch_. However, write_batch_ is empty and has no savepoints
+ // because this savepoint has already been flushed. Work around this by
+ // setting a fake savepoint.
+ write_batch_.SetSavePoint();
+ Status s = PessimisticTransaction::PopSavePoint();
+ assert(!s.IsNotFound());
+ flushed_save_points_->pop_back();
+ return s;
+ }
+
+ return Status::NotFound();
+}
+
+void WriteUnpreparedTxn::MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wupt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ unprep_seqs_, backed_by_snapshot);
+ write_batch_.MultiGetFromBatchAndDB(db_, options, column_family, num_keys,
+ keys, values, statuses, sorted_input,
+ &callback);
+ if (UNLIKELY(!callback.valid() ||
+ !wupt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ wupt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ for (size_t i = 0; i < num_keys; i++) {
+ statuses[i] = Status::TryAgain();
+ }
+ }
+}
+
+Status WriteUnpreparedTxn::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* value) {
+ SequenceNumber min_uncommitted, snap_seq;
+ const SnapshotBackup backed_by_snapshot =
+ wupt_db_->AssignMinMaxSeqs(options.snapshot, &min_uncommitted, &snap_seq);
+ WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted,
+ unprep_seqs_, backed_by_snapshot);
+ auto res = write_batch_.GetFromBatchAndDB(db_, options, column_family, key,
+ value, &callback);
+ if (LIKELY(callback.valid() &&
+ wupt_db_->ValidateSnapshot(snap_seq, backed_by_snapshot))) {
+ return res;
+ } else {
+ wupt_db_->WPRecordTick(TXN_GET_TRY_AGAIN);
+ return Status::TryAgain();
+ }
+}
+
+namespace {
+static void CleanupWriteUnpreparedWBWIIterator(void* arg1, void* arg2) {
+ auto txn = reinterpret_cast<WriteUnpreparedTxn*>(arg1);
+ auto iter = reinterpret_cast<Iterator*>(arg2);
+ txn->RemoveActiveIterator(iter);
+}
+} // anonymous namespace
+
+Iterator* WriteUnpreparedTxn::GetIterator(const ReadOptions& options) {
+ return GetIterator(options, wupt_db_->DefaultColumnFamily());
+}
+
+Iterator* WriteUnpreparedTxn::GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) {
+ // Make sure to get the iterator from WriteUnpreparedTxnDB, not the root db.
+ Iterator* db_iter = wupt_db_->NewIterator(options, column_family, this);
+ assert(db_iter);
+
+ auto iter = write_batch_.NewIteratorWithBase(column_family, db_iter);
+ active_iterators_.push_back(iter);
+ iter->RegisterCleanup(CleanupWriteUnpreparedWBWIIterator, this, iter);
+ return iter;
+}
+
+Status WriteUnpreparedTxn::ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) {
+ // TODO(lth): Reduce duplicate code with WritePrepared ValidateSnapshot logic.
+ assert(snapshot_);
+
+ SequenceNumber min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(
+ snapshot_.get())
+ ->min_uncommitted_;
+ SequenceNumber snap_seq = snapshot_->GetSequenceNumber();
+ // tracked_at_seq is either max or the last snapshot with which this key was
+ // tracked, so there is no need to apply IsInSnapshot to this comparison
+ // here, as tracked_at_seq is not a prepare seq.
+ if (*tracked_at_seq <= snap_seq) {
+ // If the key has been previously validated at a sequence number earlier
+ // than the current snapshot's sequence number, we already know it has not
+ // been modified.
+ return Status::OK();
+ }
+
+ *tracked_at_seq = snap_seq;
+
+ ColumnFamilyHandle* cfh =
+ column_family ? column_family : db_impl_->DefaultColumnFamily();
+
+ WriteUnpreparedTxnReadCallback snap_checker(
+ wupt_db_, snap_seq, min_uncommitted, unprep_seqs_, kBackedByDBSnapshot);
+ return TransactionUtil::CheckKeyForConflicts(db_impl_, cfh, key.ToString(),
+ snap_seq, false /* cache_only */,
+ &snap_checker, min_uncommitted);
+}
+
+const std::map<SequenceNumber, size_t>&
+WriteUnpreparedTxn::GetUnpreparedSequenceNumbers() {
+ return unprep_seqs_;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
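For orientation, a minimal sketch of how the write-unprepared code above is driven through the public TransactionDB API; the database path and transaction name are placeholders, and the only configuration assumed here is selecting the policy via TransactionDBOptions::write_policy.

#include <cassert>
#include <string>

#include "rocksdb/utilities/transaction_db.h"

using namespace ROCKSDB_NAMESPACE;

int main() {
  Options options;
  options.create_if_missing = true;
  TransactionDBOptions txn_db_options;
  // Select the write-unprepared policy implemented by WriteUnpreparedTxn.
  txn_db_options.write_policy = TxnDBWritePolicy::WRITE_UNPREPARED;

  TransactionDB* txn_db = nullptr;
  Status s =
      TransactionDB::Open(options, txn_db_options, "/tmp/wup_demo", &txn_db);
  assert(s.ok());

  // A two-phase transaction; large write batches may be flushed to the DB as
  // unprepared batches before Prepare() is ever called.
  Transaction* txn = txn_db->BeginTransaction(WriteOptions());
  s = txn->SetName("xid1");
  assert(s.ok());
  s = txn->Put("key1", "value1");
  assert(s.ok());
  s = txn->Prepare();
  assert(s.ok());
  s = txn->Commit();
  assert(s.ok());

  delete txn;
  delete txn_db;
  return 0;
}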
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn.h b/src/rocksdb/utilities/transactions/write_unprepared_txn.h
new file mode 100644
index 000000000..30c8f4c55
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn.h
@@ -0,0 +1,341 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <set>
+
+#include "utilities/transactions/write_prepared_txn.h"
+#include "utilities/transactions/write_unprepared_txn_db.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTxnDB;
+class WriteUnpreparedTxn;
+
+// WriteUnprepared transactions need to be able to read their own uncommitted
+// writes, and supporting this requires some careful consideration. Because
+// writes in the current transaction may be flushed to DB already, we cannot
+// rely on the contents of WriteBatchWithIndex to determine whether a key should
+// be visible or not, so we have to remember to check the DB for any uncommitted
+// keys that should be visible to us. First, we will need to change the seek to
+// snapshot logic, to seek to max_visible_seq = max(snap_seq, max_unprep_seq).
+// Any key with seq greater than max_visible_seq should not be visible because
+// it cannot have been unprepared by the current transaction and it is not in
+// its snapshot.
+//
+// When we seek to max_visible_seq, one of these cases will happen:
+// 1. We hit an unprepared key from the current transaction.
+// 2. We hit an unprepared key from another transaction.
+// 3. We hit a committed key with snap_seq < seq < max_unprep_seq.
+// 4. We hit a committed key with seq <= snap_seq.
+//
+// IsVisibleFullCheck handles all cases correctly.
+//
+// Other notes:
+// Note that max_visible_seq is only calculated once at iterator construction
+// time, meaning if the same transaction is adding more unprep seqs through
+// writes during iteration, these newer writes may not be visible. This is not a
+// problem for MySQL though because it avoids modifying the index as it is
+// scanning through it to avoid the Halloween Problem. Instead, it scans the
+// index once up front, and modifies based on a temporary copy.
+//
+// In DBIter, there is a "reseek" optimization if the iterator skips over too
+// many keys. However, this assumes that the reseek seeks exactly to the
+// required key. In write unprepared, even after seeking directly to
+// max_visible_seq, some iteration may be required before hitting a visible key,
+// and special precautions must be taken to avoid performing another reseek,
+// leading to an infinite loop.
+//
+class WriteUnpreparedTxnReadCallback : public ReadCallback {
+ public:
+ WriteUnpreparedTxnReadCallback(
+ WritePreparedTxnDB* db, SequenceNumber snapshot,
+ SequenceNumber min_uncommitted,
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ SnapshotBackup backed_by_snapshot)
+ // Pass our last uncommitted seq as the snapshot to the parent class to
+ // ensure that the parent will not prematurely filter out our own writes. We
+ // will do the exact comparison against snapshots in the IsVisibleFullCheck
+ // override.
+ : ReadCallback(CalcMaxVisibleSeq(unprep_seqs, snapshot), min_uncommitted),
+ db_(db),
+ unprep_seqs_(unprep_seqs),
+ wup_snapshot_(snapshot),
+ backed_by_snapshot_(backed_by_snapshot) {
+ (void)backed_by_snapshot_; // to silence unused private field warning
+ }
+
+ virtual ~WriteUnpreparedTxnReadCallback() {
+ // If it is not backed by snapshot, the caller must check validity
+ assert(valid_checked_ || backed_by_snapshot_ == kBackedByDBSnapshot);
+ }
+
+ virtual bool IsVisibleFullCheck(SequenceNumber seq) override;
+
+ inline bool valid() {
+ valid_checked_ = true;
+ return snap_released_ == false;
+ }
+
+ void Refresh(SequenceNumber seq) override {
+ max_visible_seq_ = std::max(max_visible_seq_, seq);
+ wup_snapshot_ = seq;
+ }
+
+ static SequenceNumber CalcMaxVisibleSeq(
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ SequenceNumber snapshot_seq) {
+ SequenceNumber max_unprepared = 0;
+ if (unprep_seqs.size()) {
+ max_unprepared =
+ unprep_seqs.rbegin()->first + unprep_seqs.rbegin()->second - 1;
+ }
+ return std::max(max_unprepared, snapshot_seq);
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ const std::map<SequenceNumber, size_t>& unprep_seqs_;
+ SequenceNumber wup_snapshot_;
+ // Whether max_visible_seq_ is backed by a snapshot
+ const SnapshotBackup backed_by_snapshot_;
+ bool snap_released_ = false;
+ // Safety check to ensure that the caller has checked invalid statuses
+ bool valid_checked_ = false;
+};
+
+class WriteUnpreparedTxn : public WritePreparedTxn {
+ public:
+ WriteUnpreparedTxn(WriteUnpreparedTxnDB* db,
+ const WriteOptions& write_options,
+ const TransactionOptions& txn_options);
+
+ virtual ~WriteUnpreparedTxn();
+
+ using TransactionBaseImpl::Put;
+ virtual Status Put(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value,
+ const bool assume_tracked = false) override;
+ virtual Status Put(ColumnFamilyHandle* column_family, const SliceParts& key,
+ const SliceParts& value,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::Merge;
+ virtual Status Merge(ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::Delete;
+ virtual Status Delete(ColumnFamilyHandle* column_family, const Slice& key,
+ const bool assume_tracked = false) override;
+ virtual Status Delete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ using TransactionBaseImpl::SingleDelete;
+ virtual Status SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ const bool assume_tracked = false) override;
+ virtual Status SingleDelete(ColumnFamilyHandle* column_family,
+ const SliceParts& key,
+ const bool assume_tracked = false) override;
+
+ // In WriteUnprepared, untracked writes will break snapshot validation logic.
+ // Snapshot validation will only check the largest sequence number of a key to
+ // see if it was committed or not. However, an untracked unprepared write will
+ // hide smaller committed sequence numbers.
+ //
+ // TODO(lth): Investigate whether it is worth having snapshot validation
+ // validate all values larger than snap_seq. Otherwise, we should return
+ // Status::NotSupported for untracked writes.
+
+ virtual Status RebuildFromWriteBatch(WriteBatch*) override;
+
+ virtual uint64_t GetLastLogNumber() const override {
+ return last_log_number_;
+ }
+
+ void RemoveActiveIterator(Iterator* iter) {
+ active_iterators_.erase(
+ std::remove(active_iterators_.begin(), active_iterators_.end(), iter),
+ active_iterators_.end());
+ }
+
+ protected:
+ void Initialize(const TransactionOptions& txn_options) override;
+
+ Status PrepareInternal() override;
+
+ Status CommitWithoutPrepareInternal() override;
+ Status CommitInternal() override;
+
+ Status RollbackInternal() override;
+
+ void Clear() override;
+
+ void SetSavePoint() override;
+ Status RollbackToSavePoint() override;
+ Status PopSavePoint() override;
+
+ // Get and GetIterator needs to be overridden so that a ReadCallback to
+ // handle read-your-own-write is used.
+ using Transaction::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using Transaction::MultiGet;
+ virtual void MultiGet(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys,
+ PinnableSlice* values, Status* statuses,
+ const bool sorted_input = false) override;
+
+ using Transaction::GetIterator;
+ virtual Iterator* GetIterator(const ReadOptions& options) override;
+ virtual Iterator* GetIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual Status ValidateSnapshot(ColumnFamilyHandle* column_family,
+ const Slice& key,
+ SequenceNumber* tracked_at_seq) override;
+
+ private:
+ friend class WriteUnpreparedTransactionTest_ReadYourOwnWrite_Test;
+ friend class WriteUnpreparedTransactionTest_RecoveryTest_Test;
+ friend class WriteUnpreparedTransactionTest_UnpreparedBatch_Test;
+ friend class WriteUnpreparedTxnDB;
+
+ const std::map<SequenceNumber, size_t>& GetUnpreparedSequenceNumbers();
+ Status WriteRollbackKeys(const TransactionKeyMap& tracked_keys,
+ WriteBatchWithIndex* rollback_batch,
+ ReadCallback* callback, const ReadOptions& roptions);
+
+ Status MaybeFlushWriteBatchToDB();
+ Status FlushWriteBatchToDB(bool prepared);
+ Status FlushWriteBatchToDBInternal(bool prepared);
+ Status FlushWriteBatchWithSavePointToDB();
+ Status RollbackToSavePointInternal();
+ Status HandleWrite(std::function<Status()> do_write);
+
+ // For write unprepared, we check on every write batch append to see if
+ // write_batch_flush_threshold_ has been exceeded, and then call
+ // FlushWriteBatchToDB if so. This logic is encapsulated in
+ // MaybeFlushWriteBatchToDB.
+ int64_t write_batch_flush_threshold_;
+ WriteUnpreparedTxnDB* wupt_db_;
+
+ // Ordered list of unprep_seq sequence numbers that we have already written
+ // to DB.
+ //
+ // This maps unprep_seq => prepare_batch_cnt for each unprepared batch
+ // written by this transaction.
+ //
+ // Note that this contains both prepared and unprepared batches, since they
+ // are treated similarily in prepare heap/commit map, so it simplifies the
+ // commit callbacks.
+ std::map<SequenceNumber, size_t> unprep_seqs_;
+
+ uint64_t last_log_number_;
+
+ // Recovered transactions have tracked_keys_ populated, but are not actually
+ // locked for efficiency reasons. For recovered transactions, skip unlocking
+ // keys when transaction ends.
+ bool recovered_txn_;
+
+ // Track the largest sequence number at which we performed snapshot
+ // validation. If snapshot validation was skipped because no snapshot was set,
+ // then this is set to GetLastPublishedSequence. This value is useful because
+ // it means that for keys that have unprepared seqnos, we can guarantee that
+ // no committed keys by other transactions can exist between
+ // largest_validated_seq_ and max_unprep_seq. See
+ // WriteUnpreparedTxnDB::NewIterator for an explanation for why this is
+ // necessary for iterator Prev().
+ //
+ // Currently this value only increases during the lifetime of a transaction,
+ // but in some cases, we should be able to restore the previously largest
+ // value when calling RollbackToSavepoint.
+ SequenceNumber largest_validated_seq_;
+
+ using KeySet = std::unordered_map<uint32_t, std::vector<std::string>>;
+ struct SavePoint {
+ // Record of unprep_seqs_ at this savepoint. The set of unprep_seq is
+ // used during RollbackToSavepoint to determine visibility when restoring
+ // old values.
+ //
+ // TODO(lth): Since all unprep_seqs_ sets further down the stack must be
+ // subsets, this can potentially be deduplicated by just storing set
+ // difference. Investigate if this is worth it.
+ std::map<SequenceNumber, size_t> unprep_seqs_;
+
+ // This snapshot will be used to read keys at this savepoint if we call
+ // RollbackToSavePoint.
+ std::unique_ptr<ManagedSnapshot> snapshot_;
+
+ SavePoint(const std::map<SequenceNumber, size_t>& seqs,
+ ManagedSnapshot* snapshot)
+ : unprep_seqs_(seqs), snapshot_(snapshot){};
+ };
+
+ // We have 3 data structures holding savepoint information:
+ // 1. TransactionBaseImpl::save_points_
+ // 2. WriteUnpreparedTxn::flushed_save_points_
+ // 3. WriteUnpreparedTxn::unflushed_save_points_
+ //
+ // TransactionBaseImpl::save_points_ holds information about all write
+ // batches, including the current in-memory write_batch_, or unprepared
+ // batches that have been written out. Its responsibility is just to track
+ // which keys have been modified in every savepoint.
+ //
+ // WriteUnpreparedTxn::flushed_save_points_ holds information about savepoints
+ // set on unprepared batches that have already flushed. It holds the snapshot
+ // and unprep_seqs at that savepoint, so that the rollback process can
+ // determine which keys were visible at that point in time.
+ //
+ // WriteUnpreparedTxn::unflushed_save_points_ holds information about
+ // savepoints on the current in-memory write_batch_. It simply records the
+ // size of the write batch at every savepoint.
+ //
+ // TODO(lth): Remove the redundancy between save_point_boundaries_ and
+ // write_batch_.save_points_.
+ //
+ // Based on this information, here are some invariants:
+ // size(unflushed_save_points_) = size(write_batch_.save_points_)
+ // size(flushed_save_points_) + size(unflushed_save_points_)
+ // = size(save_points_)
+ //
+ std::unique_ptr<autovector<WriteUnpreparedTxn::SavePoint>>
+ flushed_save_points_;
+ std::unique_ptr<autovector<size_t>> unflushed_save_points_;
+
+ // It is currently unsafe to flush a write batch if there are active iterators
+ // created from this transaction. This is because we use WriteBatchWithIndex
+ // to do merging reads from the DB and the write batch. If we flush the write
+ // batch, it is possible that the delta iterator on the iterator will point to
+ // invalid memory.
+ std::vector<Iterator*> active_iterators_;
+
+ // Untracked keys that we have to rollback.
+ //
+ // TODO(lth): Currently we do not record untracked keys per-savepoint.
+ // This means that when rolling back to savepoints, we have to check all
+ // keys in the current transaction for rollback. Note that this is merely
+ // inefficient, but still correct, because we take a snapshot at every
+ // savepoint, and we will use that snapshot to construct the rollback batch.
+ // The rollback batch will then contain a reissue of the same marker.
+ //
+ // A more optimal solution would be to only check keys changed since the
+ // last savepoint. Also, it may make sense to merge this into tracked_keys_
+ // and differentiate between tracked but not locked keys to avoid having two
+ // very similar data structures.
+ KeySet untracked_keys_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // ROCKSDB_LITE
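A brief sketch of what the savepoint bookkeeping above looks like from the caller's side, assuming a TransactionDB already opened with the write-unprepared policy; the keys and values are placeholders.

#include <cassert>

#include "rocksdb/utilities/transaction_db.h"

using namespace ROCKSDB_NAMESPACE;

void SavePointSketch(TransactionDB* txn_db) {
  Transaction* txn = txn_db->BeginTransaction(WriteOptions());
  Status s = txn->Put("a", "1");
  assert(s.ok());
  // Recorded in unflushed_save_points_ while the batch is still in memory,
  // and tracked via flushed_save_points_ once the batch is flushed to the DB.
  txn->SetSavePoint();
  s = txn->Put("b", "2");
  assert(s.ok());
  // Undoes Put("b"). If the batch had been flushed, prior values are re-read
  // through the savepoint's snapshot and written out as a rollback batch.
  s = txn->RollbackToSavePoint();
  assert(s.ok());
  s = txn->Commit();  // commits only Put("a")
  assert(s.ok());
  delete txn;
}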
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc
new file mode 100644
index 000000000..ca365d044
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.cc
@@ -0,0 +1,468 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_unprepared_txn_db.h"
+#include "db/arena_wrapped_db_iter.h"
+#include "rocksdb/utilities/transaction_db.h"
+#include "util/cast_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Instead of reconstructing a Transaction object, and calling rollback on it,
+// we can be more efficient with RollbackRecoveredTransaction by skipping
+// unnecessary steps (e.g. updating CommitMap, reconstructing the keyset).
+Status WriteUnpreparedTxnDB::RollbackRecoveredTransaction(
+ const DBImpl::RecoveredTransaction* rtxn) {
+ // TODO(lth): Reduce duplicate code with WritePrepared rollback logic.
+ assert(rtxn->unprepared_);
+ auto cf_map_shared_ptr = WritePreparedTxnDB::GetCFHandleMap();
+ auto cf_comp_map_shared_ptr = WritePreparedTxnDB::GetCFComparatorMap();
+ // In theory we could write with disableWAL = true during recovery, and
+ // assume that if we crash again during recovery, we can just replay from
+ // the very beginning. Unfortunately, the XIDs from the application may not
+ // necessarily be unique across restarts, potentially leading to situations
+ // like this:
+ //
+ // BEGIN_PREPARE(unprepared) Put(a) END_PREPARE(xid = 1)
+ // -- crash and recover with Put(a) rolled back as it was not prepared
+ // BEGIN_PREPARE(prepared) Put(b) END_PREPARE(xid = 1)
+ // COMMIT(xid = 1)
+ // -- crash and recover with both a, b
+ //
+ // We could just write the rollback marker, but then we would have to extend
+ // MemTableInserter during recovery to actually do writes into the DB
+ // instead of just dropping the in-memory write batch.
+ //
+ WriteOptions w_options;
+
+ class InvalidSnapshotReadCallback : public ReadCallback {
+ public:
+ InvalidSnapshotReadCallback(SequenceNumber snapshot)
+ : ReadCallback(snapshot) {}
+
+ inline bool IsVisibleFullCheck(SequenceNumber) override {
+ // The seq provided as the snapshot is the seq right before we locked and
+ // wrote to it, so whatever is there is committed.
+ return true;
+ }
+
+ // Ignore the refresh request since we are confident that our snapshot seq
+ // is not going to be affected by concurrent compactions (not enabled yet.)
+ void Refresh(SequenceNumber) override {}
+ };
+
+ // Iterate starting with largest sequence number.
+ for (auto it = rtxn->batches_.rbegin(); it != rtxn->batches_.rend(); ++it) {
+ auto last_visible_txn = it->first - 1;
+ const auto& batch = it->second.batch_;
+ WriteBatch rollback_batch;
+
+ struct RollbackWriteBatchBuilder : public WriteBatch::Handler {
+ DBImpl* db_;
+ ReadOptions roptions;
+ InvalidSnapshotReadCallback callback;
+ WriteBatch* rollback_batch_;
+ std::map<uint32_t, const Comparator*>& comparators_;
+ std::map<uint32_t, ColumnFamilyHandle*>& handles_;
+ using CFKeys = std::set<Slice, SetComparator>;
+ std::map<uint32_t, CFKeys> keys_;
+ bool rollback_merge_operands_;
+ RollbackWriteBatchBuilder(
+ DBImpl* db, SequenceNumber snap_seq, WriteBatch* dst_batch,
+ std::map<uint32_t, const Comparator*>& comparators,
+ std::map<uint32_t, ColumnFamilyHandle*>& handles,
+ bool rollback_merge_operands)
+ : db_(db),
+ callback(snap_seq),
+ // disable min_uncommitted optimization
+ rollback_batch_(dst_batch),
+ comparators_(comparators),
+ handles_(handles),
+ rollback_merge_operands_(rollback_merge_operands) {}
+
+ Status Rollback(uint32_t cf, const Slice& key) {
+ Status s;
+ CFKeys& cf_keys = keys_[cf];
+ if (cf_keys.size() == 0) { // just inserted
+ auto cmp = comparators_[cf];
+ keys_[cf] = CFKeys(SetComparator(cmp));
+ }
+ auto res = cf_keys.insert(key);
+ if (res.second ==
+ false) { // second is false if an element already existed.
+ return s;
+ }
+
+ PinnableSlice pinnable_val;
+ bool not_used;
+ auto cf_handle = handles_[cf];
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = cf_handle;
+ get_impl_options.value = &pinnable_val;
+ get_impl_options.value_found = &not_used;
+ get_impl_options.callback = &callback;
+ s = db_->GetImpl(roptions, key, get_impl_options);
+ assert(s.ok() || s.IsNotFound());
+ if (s.ok()) {
+ s = rollback_batch_->Put(cf_handle, key, pinnable_val);
+ assert(s.ok());
+ } else if (s.IsNotFound()) {
+ // There has been no readable value before txn. By adding a delete we
+ // make sure that there will be none afterwards either.
+ s = rollback_batch_->Delete(cf_handle, key);
+ assert(s.ok());
+ } else {
+ // Unexpected status. Return it to the user.
+ }
+ return s;
+ }
+
+ Status PutCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ return Rollback(cf, key);
+ }
+
+ Status DeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status SingleDeleteCF(uint32_t cf, const Slice& key) override {
+ return Rollback(cf, key);
+ }
+
+ Status MergeCF(uint32_t cf, const Slice& key,
+ const Slice& /*val*/) override {
+ if (rollback_merge_operands_) {
+ return Rollback(cf, key);
+ } else {
+ return Status::OK();
+ }
+ }
+
+ // Recovered batches do not contain 2PC markers.
+ Status MarkNoop(bool) override { return Status::InvalidArgument(); }
+ Status MarkBeginPrepare(bool) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkEndPrepare(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkCommit(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ Status MarkRollback(const Slice&) override {
+ return Status::InvalidArgument();
+ }
+ } rollback_handler(db_impl_, last_visible_txn, &rollback_batch,
+ *cf_comp_map_shared_ptr.get(), *cf_map_shared_ptr.get(),
+ txn_db_options_.rollback_merge_operands);
+
+ auto s = batch->Iterate(&rollback_handler);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // The Rollback marker will be used as a batch separator
+ WriteBatchInternal::MarkRollback(&rollback_batch, rtxn->name_);
+
+ const uint64_t kNoLogRef = 0;
+ const bool kDisableMemtable = true;
+ const size_t kOneBatch = 1;
+ uint64_t seq_used = kMaxSequenceNumber;
+ s = db_impl_->WriteImpl(w_options, &rollback_batch, nullptr, nullptr,
+ kNoLogRef, !kDisableMemtable, &seq_used, kOneBatch);
+ if (!s.ok()) {
+ return s;
+ }
+
+ // If two_write_queues, we must manually release the sequence number to
+ // readers.
+ if (db_impl_->immutable_db_options().two_write_queues) {
+ db_impl_->SetLastPublishedSequence(seq_used);
+ }
+ }
+
+ return Status::OK();
+}
+
+Status WriteUnpreparedTxnDB::Initialize(
+ const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) {
+ // TODO(lth): Reduce code duplication in this function.
+ auto dbimpl = static_cast_with_check<DBImpl, DB>(GetRootDB());
+ assert(dbimpl != nullptr);
+
+ db_impl_->SetSnapshotChecker(new WritePreparedSnapshotChecker(this));
+ // A callback to commit a single sub-batch
+ class CommitSubBatchPreReleaseCallback : public PreReleaseCallback {
+ public:
+ explicit CommitSubBatchPreReleaseCallback(WritePreparedTxnDB* db)
+ : db_(db) {}
+ Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)), uint64_t,
+ size_t /*index*/, size_t /*total*/) override {
+ assert(!is_mem_disabled);
+ db_->AddCommitted(commit_seq, commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ };
+ db_impl_->SetRecoverableStatePreReleaseCallback(
+ new CommitSubBatchPreReleaseCallback(this));
+
+ // PessimisticTransactionDB::Initialize
+ for (auto cf_ptr : handles) {
+ AddColumnFamily(cf_ptr);
+ }
+ // Verify cf options
+ for (auto handle : handles) {
+ ColumnFamilyDescriptor cfd;
+ Status s = handle->GetDescriptor(&cfd);
+ if (!s.ok()) {
+ return s;
+ }
+ s = VerifyCFOptions(cfd.options);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ // Re-enable compaction for the column families that initially had
+ // compaction enabled.
+ std::vector<ColumnFamilyHandle*> compaction_enabled_cf_handles;
+ compaction_enabled_cf_handles.reserve(compaction_enabled_cf_indices.size());
+ for (auto index : compaction_enabled_cf_indices) {
+ compaction_enabled_cf_handles.push_back(handles[index]);
+ }
+
+ // create 'real' transactions from recovered shell transactions
+ auto rtxns = dbimpl->recovered_transactions();
+ std::map<SequenceNumber, SequenceNumber> ordered_seq_cnt;
+ for (auto rtxn : rtxns) {
+ auto recovered_trx = rtxn.second;
+ assert(recovered_trx);
+ assert(recovered_trx->batches_.size() >= 1);
+ assert(recovered_trx->name_.length());
+
+ // We can only rollback transactions after AdvanceMaxEvictedSeq is called,
+ // but AddPrepared must occur before AdvanceMaxEvictedSeq, which is why
+ // two iterations are required.
+ if (recovered_trx->unprepared_) {
+ continue;
+ }
+
+ WriteOptions w_options;
+ w_options.sync = true;
+ TransactionOptions t_options;
+
+ auto first_log_number = recovered_trx->batches_.begin()->second.log_number_;
+ auto first_seq = recovered_trx->batches_.begin()->first;
+ auto last_prepare_batch_cnt =
+ recovered_trx->batches_.begin()->second.batch_cnt_;
+
+ Transaction* real_trx = BeginTransaction(w_options, t_options, nullptr);
+ assert(real_trx);
+ auto wupt =
+ static_cast_with_check<WriteUnpreparedTxn, Transaction>(real_trx);
+ wupt->recovered_txn_ = true;
+
+ real_trx->SetLogNumber(first_log_number);
+ real_trx->SetId(first_seq);
+ Status s = real_trx->SetName(recovered_trx->name_);
+ if (!s.ok()) {
+ return s;
+ }
+ wupt->prepare_batch_cnt_ = last_prepare_batch_cnt;
+
+ for (auto batch : recovered_trx->batches_) {
+ const auto& seq = batch.first;
+ const auto& batch_info = batch.second;
+ auto cnt = batch_info.batch_cnt_ ? batch_info.batch_cnt_ : 1;
+ assert(batch_info.log_number_);
+
+ ordered_seq_cnt[seq] = cnt;
+ assert(wupt->unprep_seqs_.count(seq) == 0);
+ wupt->unprep_seqs_[seq] = cnt;
+
+ s = wupt->RebuildFromWriteBatch(batch_info.batch_);
+ assert(s.ok());
+ if (!s.ok()) {
+ return s;
+ }
+ }
+
+ const bool kClear = true;
+ wupt->InitWriteBatch(kClear);
+
+ real_trx->SetState(Transaction::PREPARED);
+ if (!s.ok()) {
+ return s;
+ }
+ }
+ // AddPrepared must be called in order
+ for (auto seq_cnt : ordered_seq_cnt) {
+ auto seq = seq_cnt.first;
+ auto cnt = seq_cnt.second;
+ for (size_t i = 0; i < cnt; i++) {
+ AddPrepared(seq + i);
+ }
+ }
+
+ SequenceNumber prev_max = max_evicted_seq_;
+ SequenceNumber last_seq = db_impl_->GetLatestSequenceNumber();
+ AdvanceMaxEvictedSeq(prev_max, last_seq);
+ // Create a gap between max and the next snapshot. This simplifies the logic
+ // in IsInSnapshot by not having to consider the special case of max ==
+ // snapshot after recovery. This is tested in IsInSnapshotEmptyMapTest.
+ if (last_seq) {
+ db_impl_->versions_->SetLastAllocatedSequence(last_seq + 1);
+ db_impl_->versions_->SetLastSequence(last_seq + 1);
+ db_impl_->versions_->SetLastPublishedSequence(last_seq + 1);
+ }
+
+ Status s;
+ // Rollback unprepared transactions.
+ for (auto rtxn : rtxns) {
+ auto recovered_trx = rtxn.second;
+ if (recovered_trx->unprepared_) {
+ s = RollbackRecoveredTransaction(recovered_trx);
+ if (!s.ok()) {
+ return s;
+ }
+ continue;
+ }
+ }
+
+ if (s.ok()) {
+ dbimpl->DeleteAllRecoveredTransactions();
+
+ // Compaction should start only after max_evicted_seq_ is set AND recovered
+ // transactions are either added to PrepareHeap or rolled back.
+ s = EnableAutoCompaction(compaction_enabled_cf_handles);
+ }
+
+ return s;
+}
+
+Transaction* WriteUnpreparedTxnDB::BeginTransaction(
+ const WriteOptions& write_options, const TransactionOptions& txn_options,
+ Transaction* old_txn) {
+ if (old_txn != nullptr) {
+ ReinitializeTransaction(old_txn, write_options, txn_options);
+ return old_txn;
+ } else {
+ return new WriteUnpreparedTxn(this, write_options, txn_options);
+ }
+}
+
+// Struct to hold ownership of snapshot and read callback for iterator cleanup.
+struct WriteUnpreparedTxnDB::IteratorState {
+ IteratorState(WritePreparedTxnDB* txn_db, SequenceNumber sequence,
+ std::shared_ptr<ManagedSnapshot> s,
+ SequenceNumber min_uncommitted, WriteUnpreparedTxn* txn)
+ : callback(txn_db, sequence, min_uncommitted, txn->unprep_seqs_,
+ kBackedByDBSnapshot),
+ snapshot(s) {}
+ SequenceNumber MaxVisibleSeq() { return callback.max_visible_seq(); }
+
+ WriteUnpreparedTxnReadCallback callback;
+ std::shared_ptr<ManagedSnapshot> snapshot;
+};
+
+namespace {
+static void CleanupWriteUnpreparedTxnDBIterator(void* arg1, void* /*arg2*/) {
+ delete reinterpret_cast<WriteUnpreparedTxnDB::IteratorState*>(arg1);
+}
+} // anonymous namespace
+
+Iterator* WriteUnpreparedTxnDB::NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ WriteUnpreparedTxn* txn) {
+ // TODO(lth): Refactor so that this logic is shared with WritePrepared.
+ constexpr bool ALLOW_BLOB = true;
+ constexpr bool ALLOW_REFRESH = true;
+ std::shared_ptr<ManagedSnapshot> own_snapshot = nullptr;
+ SequenceNumber snapshot_seq = kMaxSequenceNumber;
+ SequenceNumber min_uncommitted = 0;
+
+ // Currently, the Prev() iterator logic does not work well without snapshot
+ // validation. The logic simply iterates through values of a key in
+ // ascending seqno order, stopping at the first non-visible value and
+ // returning the last visible value.
+ //
+ // For example, if snapshot sequence is 3, and we have the following keys:
+ // foo: v1 1
+ // foo: v2 2
+ // foo: v3 3
+ // foo: v4 4
+ // foo: v5 5
+ //
+ // Then 1, 2, 3 will be visible, but 4 will be non-visible, so we return v3,
+ // which is the last visible value.
+ //
+ // For unprepared transactions, if we have snap_seq = 3, but the current
+ // transaction has unprep_seq 5, then returning the first non-visible value
+ // would be incorrect, as we should return v5, and not v3. The problem is that
+ // there are committed values at snapshot_seq < commit_seq < unprep_seq.
+ //
+ // Snapshot validation can prevent this problem by ensuring that no committed
+ // values exist at snapshot_seq < commit_seq, and thus any value with a
+ // sequence number greater than snapshot_seq must be an unprepared value. For
+ // example, if the transaction had a snapshot at 3, then snapshot validation
+ // would be performed during the Put(v5) call. It would find v4, and the Put
+ // would fail with snapshot validation failure.
+ //
+ // TODO(lth): Improve Prev() logic to continue iterating until
+ // max_visible_seq, and then return the last visible value, so that this
+ // restriction can be lifted.
+ const Snapshot* snapshot = nullptr;
+ if (options.snapshot == nullptr) {
+ snapshot = GetSnapshot();
+ own_snapshot = std::make_shared<ManagedSnapshot>(db_impl_, snapshot);
+ } else {
+ snapshot = options.snapshot;
+ }
+
+ snapshot_seq = snapshot->GetSequenceNumber();
+ assert(snapshot_seq != kMaxSequenceNumber);
+ // Iteration is safe as long as largest_validated_seq <= snapshot_seq. We are
+ // guaranteed that for keys that were modified by this transaction (and thus
+ // might have unprepared values), no committed values exist at
+ // largest_validated_seq < commit_seq (or the contrapositive: any committed
+ // value must exist at commit_seq <= largest_validated_seq). This implies
+ // that commit_seq <= largest_validated_seq <= snapshot_seq or commit_seq <=
+ // snapshot_seq. As explained above, the problem with Prev() only happens when
+ // snapshot_seq < commit_seq.
+ //
+ // For keys that were not modified by this transaction, largest_validated_seq_
+ // is meaningless, and Prev() should just work with the existing visibility
+ // logic.
+ if (txn->largest_validated_seq_ > snapshot->GetSequenceNumber() &&
+ !txn->unprep_seqs_.empty()) {
+ ROCKS_LOG_ERROR(info_log_,
+ "WriteUnprepared iterator creation failed since the "
+ "transaction has performed unvalidated writes");
+ return nullptr;
+ }
+ min_uncommitted =
+ static_cast_with_check<const SnapshotImpl, const Snapshot>(snapshot)
+ ->min_uncommitted_;
+
+ auto* cfd = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family)->cfd();
+ auto* state =
+ new IteratorState(this, snapshot_seq, own_snapshot, min_uncommitted, txn);
+ auto* db_iter =
+ db_impl_->NewIteratorImpl(options, cfd, state->MaxVisibleSeq(),
+ &state->callback, !ALLOW_BLOB, !ALLOW_REFRESH);
+ db_iter->RegisterCleanup(CleanupWriteUnpreparedTxnDBIterator, state, nullptr);
+ return db_iter;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
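The read callback and iterator plumbing above exist so that a transaction can see its own flushed-but-uncommitted writes; here is a small sketch of that behaviour from the API side, again with placeholder keys and an already-opened write-unprepared TransactionDB.

#include <cassert>
#include <string>

#include "rocksdb/utilities/transaction_db.h"

using namespace ROCKSDB_NAMESPACE;

void ReadYourOwnWritesSketch(TransactionDB* txn_db) {
  Transaction* txn = txn_db->BeginTransaction(WriteOptions());
  // This write may already live in the DB as an unprepared batch, yet it must
  // remain visible to this transaction and invisible to everyone else.
  Status s = txn->Put("k", "v");
  assert(s.ok());

  std::string value;
  s = txn->Get(ReadOptions(), "k", &value);  // sees the uncommitted "v"
  assert(s.ok() && value == "v");

  Iterator* it = txn->GetIterator(ReadOptions());
  for (it->SeekToFirst(); it->Valid(); it->Next()) {
    // Unprepared keys written by this transaction are visible here; unprepared
    // keys from other transactions are filtered out by the read callback.
  }
  delete it;  // release before flushing more batches (see active_iterators_)

  s = txn->Commit();
  assert(s.ok());
  delete txn;
}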
diff --git a/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h
new file mode 100644
index 000000000..ad8e40f94
--- /dev/null
+++ b/src/rocksdb/utilities/transactions/write_unprepared_txn_db.h
@@ -0,0 +1,148 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#ifndef ROCKSDB_LITE
+
+#include "utilities/transactions/write_prepared_txn_db.h"
+#include "utilities/transactions/write_unprepared_txn.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteUnpreparedTxn;
+
+class WriteUnpreparedTxnDB : public WritePreparedTxnDB {
+ public:
+ using WritePreparedTxnDB::WritePreparedTxnDB;
+
+ Status Initialize(const std::vector<size_t>& compaction_enabled_cf_indices,
+ const std::vector<ColumnFamilyHandle*>& handles) override;
+
+ Transaction* BeginTransaction(const WriteOptions& write_options,
+ const TransactionOptions& txn_options,
+ Transaction* old_txn) override;
+
+ // Struct to hold ownership of snapshot and read callback for cleanup.
+ struct IteratorState;
+
+ using WritePreparedTxnDB::NewIterator;
+ Iterator* NewIterator(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ WriteUnpreparedTxn* txn);
+
+ private:
+ Status RollbackRecoveredTransaction(const DBImpl::RecoveredTransaction* rtxn);
+};
+
+class WriteUnpreparedCommitEntryPreReleaseCallback : public PreReleaseCallback {
+ // TODO(lth): Reduce code duplication with
+ // WritePreparedCommitEntryPreReleaseCallback
+ public:
+ // includes_data indicates that the commit also writes a non-empty
+ // CommitTimeWriteBatch to the memtable, which needs to be committed separately.
+ WriteUnpreparedCommitEntryPreReleaseCallback(
+ WritePreparedTxnDB* db, DBImpl* db_impl,
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ size_t data_batch_cnt = 0, bool publish_seq = true)
+ : db_(db),
+ db_impl_(db_impl),
+ unprep_seqs_(unprep_seqs),
+ data_batch_cnt_(data_batch_cnt),
+ includes_data_(data_batch_cnt_ > 0),
+ publish_seq_(publish_seq) {
+ assert(unprep_seqs.size() > 0);
+ }
+
+ virtual Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t, size_t /*index*/,
+ size_t /*total*/) override {
+ const uint64_t last_commit_seq = LIKELY(data_batch_cnt_ <= 1)
+ ? commit_seq
+ : commit_seq + data_batch_cnt_ - 1;
+ // Recall that unprep_seqs maps (un)prepared_seq => prepare_batch_cnt.
+ for (const auto& s : unprep_seqs_) {
+ for (size_t i = 0; i < s.second; i++) {
+ db_->AddCommitted(s.first + i, last_commit_seq);
+ }
+ }
+
+ if (includes_data_) {
+ assert(data_batch_cnt_);
+ // Commit the data that is accompanied with the commit request
+ for (size_t i = 0; i < data_batch_cnt_; i++) {
+ // For commit seq of each batch use the commit seq of the last batch.
+ // This would make debugging easier by having all the batches having
+ // the same sequence number.
+ db_->AddCommitted(commit_seq + i, last_commit_seq);
+ }
+ }
+ if (db_impl_->immutable_db_options().two_write_queues && publish_seq_) {
+ assert(is_mem_disabled); // implies the 2nd queue
+ // Publish the sequence number. We can do that here assuming the callback
+ // is invoked only from one write queue, which would guarantee that the
+ // publish sequence numbers will be in order, i.e., once a seq is
+ // published all the seq prior to that are also publishable.
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ }
+ // else the sequence number that is updated as part of the write already does
+ // the publishing
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ const std::map<SequenceNumber, size_t>& unprep_seqs_;
+ size_t data_batch_cnt_;
+ // Either because it is commit without prepare or it has a
+ // CommitTimeWriteBatch
+ bool includes_data_;
+ // Whether the callback should also publish the commit seq number
+ bool publish_seq_;
+};
+
+class WriteUnpreparedRollbackPreReleaseCallback : public PreReleaseCallback {
+ // TODO(lth): Reduce code duplication with
+ // WritePreparedCommitEntryPreReleaseCallback
+ public:
+ WriteUnpreparedRollbackPreReleaseCallback(
+ WritePreparedTxnDB* db, DBImpl* db_impl,
+ const std::map<SequenceNumber, size_t>& unprep_seqs,
+ SequenceNumber rollback_seq)
+ : db_(db),
+ db_impl_(db_impl),
+ unprep_seqs_(unprep_seqs),
+ rollback_seq_(rollback_seq) {
+ assert(unprep_seqs.size() > 0);
+ assert(db_impl_->immutable_db_options().two_write_queues);
+ }
+
+ virtual Status Callback(SequenceNumber commit_seq,
+ bool is_mem_disabled __attribute__((__unused__)),
+ uint64_t, size_t /*index*/,
+ size_t /*total*/) override {
+ assert(is_mem_disabled); // implies the 2nd queue
+ const uint64_t last_commit_seq = commit_seq;
+ db_->AddCommitted(rollback_seq_, last_commit_seq);
+ // Recall that unprep_seqs maps (un)prepared_seq => prepare_batch_cnt.
+ for (const auto& s : unprep_seqs_) {
+ for (size_t i = 0; i < s.second; i++) {
+ db_->AddCommitted(s.first + i, last_commit_seq);
+ }
+ }
+ db_impl_->SetLastPublishedSequence(last_commit_seq);
+ return Status::OK();
+ }
+
+ private:
+ WritePreparedTxnDB* db_;
+ DBImpl* db_impl_;
+ const std::map<SequenceNumber, size_t>& unprep_seqs_;
+ SequenceNumber rollback_seq_;
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/ttl/db_ttl_impl.cc b/src/rocksdb/utilities/ttl/db_ttl_impl.cc
new file mode 100644
index 000000000..9ebaa247f
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/db_ttl_impl.cc
@@ -0,0 +1,335 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+#ifndef ROCKSDB_LITE
+
+#include "utilities/ttl/db_ttl_impl.h"
+
+#include "db/write_batch_internal.h"
+#include "file/filename.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/env.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+void DBWithTTLImpl::SanitizeOptions(int32_t ttl, ColumnFamilyOptions* options,
+ Env* env) {
+ if (options->compaction_filter) {
+ options->compaction_filter =
+ new TtlCompactionFilter(ttl, env, options->compaction_filter);
+ } else {
+ options->compaction_filter_factory =
+ std::shared_ptr<CompactionFilterFactory>(new TtlCompactionFilterFactory(
+ ttl, env, options->compaction_filter_factory));
+ }
+
+ if (options->merge_operator) {
+ options->merge_operator.reset(
+ new TtlMergeOperator(options->merge_operator, env));
+ }
+}
+
+// Open the db inside DBWithTTLImpl because the options need a pointer to its ttl
+DBWithTTLImpl::DBWithTTLImpl(DB* db) : DBWithTTL(db), closed_(false) {}
+
+DBWithTTLImpl::~DBWithTTLImpl() {
+ if (!closed_) {
+ Close();
+ }
+}
+
+Status DBWithTTLImpl::Close() {
+ Status ret = Status::OK();
+ if (!closed_) {
+ Options default_options = GetOptions();
+ // Need to stop background compaction before getting rid of the filter
+ CancelAllBackgroundWork(db_, /* wait = */ true);
+ ret = db_->Close();
+ delete default_options.compaction_filter;
+ closed_ = true;
+ }
+ return ret;
+}
+
+Status UtilityDB::OpenTtlDB(const Options& options, const std::string& dbname,
+ StackableDB** dbptr, int32_t ttl, bool read_only) {
+ DBWithTTL* db;
+ Status s = DBWithTTL::Open(options, dbname, &db, ttl, read_only);
+ if (s.ok()) {
+ *dbptr = db;
+ } else {
+ *dbptr = nullptr;
+ }
+ return s;
+}
+
+Status DBWithTTL::Open(const Options& options, const std::string& dbname,
+ DBWithTTL** dbptr, int32_t ttl, bool read_only) {
+
+ DBOptions db_options(options);
+ ColumnFamilyOptions cf_options(options);
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(
+ ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+ std::vector<ColumnFamilyHandle*> handles;
+ Status s = DBWithTTL::Open(db_options, dbname, column_families, &handles,
+ dbptr, {ttl}, read_only);
+ if (s.ok()) {
+ assert(handles.size() == 1);
+ // We can delete the handle since DBImpl always holds a reference to the
+ // default column family.
+ delete handles[0];
+ }
+ return s;
+}
+
+Status DBWithTTL::Open(
+ const DBOptions& db_options, const std::string& dbname,
+ const std::vector<ColumnFamilyDescriptor>& column_families,
+ std::vector<ColumnFamilyHandle*>* handles, DBWithTTL** dbptr,
+ std::vector<int32_t> ttls, bool read_only) {
+
+ if (ttls.size() != column_families.size()) {
+ return Status::InvalidArgument(
+ "ttls size has to be the same as number of column families");
+ }
+
+ std::vector<ColumnFamilyDescriptor> column_families_sanitized =
+ column_families;
+ for (size_t i = 0; i < column_families_sanitized.size(); ++i) {
+ DBWithTTLImpl::SanitizeOptions(
+ ttls[i], &column_families_sanitized[i].options,
+ db_options.env == nullptr ? Env::Default() : db_options.env);
+ }
+ DB* db;
+
+ Status st;
+ if (read_only) {
+ st = DB::OpenForReadOnly(db_options, dbname, column_families_sanitized,
+ handles, &db);
+ } else {
+ st = DB::Open(db_options, dbname, column_families_sanitized, handles, &db);
+ }
+ if (st.ok()) {
+ *dbptr = new DBWithTTLImpl(db);
+ } else {
+ *dbptr = nullptr;
+ }
+ return st;
+}
+
+Status DBWithTTLImpl::CreateColumnFamilyWithTtl(
+ const ColumnFamilyOptions& options, const std::string& column_family_name,
+ ColumnFamilyHandle** handle, int ttl) {
+ ColumnFamilyOptions sanitized_options = options;
+ DBWithTTLImpl::SanitizeOptions(ttl, &sanitized_options, GetEnv());
+
+ return DBWithTTL::CreateColumnFamily(sanitized_options, column_family_name,
+ handle);
+}
+
+Status DBWithTTLImpl::CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) {
+ return CreateColumnFamilyWithTtl(options, column_family_name, handle, 0);
+}
+
+// Appends the current timestamp to the string.
+// Returns a non-OK status if the current time could not be obtained
+Status DBWithTTLImpl::AppendTS(const Slice& val, std::string* val_with_ts,
+ Env* env) {
+ val_with_ts->reserve(kTSLength + val.size());
+ char ts_string[kTSLength];
+ int64_t curtime;
+ Status st = env->GetCurrentTime(&curtime);
+ if (!st.ok()) {
+ return st;
+ }
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ val_with_ts->append(val.data(), val.size());
+ val_with_ts->append(ts_string, kTSLength);
+ return st;
+}
+
+// Returns Corruption if the string is shorter than the timestamp, or if the
+// timestamp refers to a time earlier than the ttl-feature release time
+Status DBWithTTLImpl::SanityCheckTimestamp(const Slice& str) {
+ if (str.size() < kTSLength) {
+ return Status::Corruption("Error: value's length less than timestamp's\n");
+ }
+ // Checks that TS is not less than kMinTimestamp
+ // Guards against corruption & a normal database opened incorrectly in ttl mode
+ int32_t timestamp_value = DecodeFixed32(str.data() + str.size() - kTSLength);
+ if (timestamp_value < kMinTimestamp) {
+ return Status::Corruption("Error: Timestamp < ttl feature release time!\n");
+ }
+ return Status::OK();
+}
+
+// Checks if the string is stale or not according to the TTL provided
+bool DBWithTTLImpl::IsStale(const Slice& value, int32_t ttl, Env* env) {
+ if (ttl <= 0) { // Data is fresh if TTL is non-positive
+ return false;
+ }
+ int64_t curtime;
+ if (!env->GetCurrentTime(&curtime).ok()) {
+ return false; // Treat the data as fresh if the current time could not be obtained
+ }
+ int32_t timestamp_value =
+ DecodeFixed32(value.data() + value.size() - kTSLength);
+ return (timestamp_value + ttl) < curtime;
+}
+
+// Strips the TS from the end of the slice
+Status DBWithTTLImpl::StripTS(PinnableSlice* pinnable_val) {
+ Status st;
+ if (pinnable_val->size() < kTSLength) {
+ return Status::Corruption("Bad timestamp in key-value");
+ }
+ // Erasing characters which hold the TS
+ pinnable_val->remove_suffix(kTSLength);
+ return st;
+}
+
+// Strips the TS from the end of the string
+Status DBWithTTLImpl::StripTS(std::string* str) {
+ Status st;
+ if (str->length() < kTSLength) {
+ return Status::Corruption("Bad timestamp in key-value");
+ }
+ // Erasing characters which hold the TS
+ str->erase(str->length() - kTSLength, kTSLength);
+ return st;
+}
+
+Status DBWithTTLImpl::Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) {
+ WriteBatch batch;
+ batch.Put(column_family, key, val);
+ return Write(options, &batch);
+}
+
+Status DBWithTTLImpl::Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) {
+ Status st = db_->Get(options, column_family, key, value);
+ if (!st.ok()) {
+ return st;
+ }
+ st = SanityCheckTimestamp(*value);
+ if (!st.ok()) {
+ return st;
+ }
+ return StripTS(value);
+}
+
+std::vector<Status> DBWithTTLImpl::MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys, std::vector<std::string>* values) {
+ auto statuses = db_->MultiGet(options, column_family, keys, values);
+ for (size_t i = 0; i < keys.size(); ++i) {
+ if (!statuses[i].ok()) {
+ continue;
+ }
+ statuses[i] = SanityCheckTimestamp((*values)[i]);
+ if (!statuses[i].ok()) {
+ continue;
+ }
+ statuses[i] = StripTS(&(*values)[i]);
+ }
+ return statuses;
+}
+
+bool DBWithTTLImpl::KeyMayExist(const ReadOptions& options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key, std::string* value,
+ bool* value_found) {
+ bool ret = db_->KeyMayExist(options, column_family, key, value, value_found);
+ if (ret && value != nullptr && value_found != nullptr && *value_found) {
+ if (!SanityCheckTimestamp(*value).ok() || !StripTS(value).ok()) {
+ return false;
+ }
+ }
+ return ret;
+}
+
+Status DBWithTTLImpl::Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) {
+ WriteBatch batch;
+ batch.Merge(column_family, key, value);
+ return Write(options, &batch);
+}
+
+Status DBWithTTLImpl::Write(const WriteOptions& opts, WriteBatch* updates) {
+ class Handler : public WriteBatch::Handler {
+ public:
+ explicit Handler(Env* env) : env_(env) {}
+ WriteBatch updates_ttl;
+ Status batch_rewrite_status;
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ std::string value_with_ts;
+ Status st = AppendTS(value, &value_with_ts, env_);
+ if (!st.ok()) {
+ batch_rewrite_status = st;
+ } else {
+ WriteBatchInternal::Put(&updates_ttl, column_family_id, key,
+ value_with_ts);
+ }
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ std::string value_with_ts;
+ Status st = AppendTS(value, &value_with_ts, env_);
+ if (!st.ok()) {
+ batch_rewrite_status = st;
+ } else {
+ WriteBatchInternal::Merge(&updates_ttl, column_family_id, key,
+ value_with_ts);
+ }
+ return Status::OK();
+ }
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ WriteBatchInternal::Delete(&updates_ttl, column_family_id, key);
+ return Status::OK();
+ }
+ void LogData(const Slice& blob) override { updates_ttl.PutLogData(blob); }
+
+ private:
+ Env* env_;
+ };
+ Handler handler(GetEnv());
+ updates->Iterate(&handler);
+ if (!handler.batch_rewrite_status.ok()) {
+ return handler.batch_rewrite_status;
+ } else {
+ return db_->Write(opts, &(handler.updates_ttl));
+ }
+}
+
+Iterator* DBWithTTLImpl::NewIterator(const ReadOptions& opts,
+ ColumnFamilyHandle* column_family) {
+ return new TtlIterator(db_->NewIterator(opts, column_family));
+}
+
+void DBWithTTLImpl::SetTtl(ColumnFamilyHandle *h, int32_t ttl) {
+ std::shared_ptr<TtlCompactionFilterFactory> filter;
+ Options opts;
+ opts = GetOptions(h);
+ filter = std::static_pointer_cast<TtlCompactionFilterFactory>(
+ opts.compaction_filter_factory);
+ if (!filter)
+ return;
+ filter->SetTtl(ttl);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
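A minimal sketch of the TTL wrapper above in use; the database path and the 3600-second TTL are placeholders.

#include <cassert>
#include <string>

#include "rocksdb/utilities/db_ttl.h"

using namespace ROCKSDB_NAMESPACE;

int main() {
  Options options;
  options.create_if_missing = true;
  DBWithTTL* db = nullptr;
  // Values written through this handle get a 4-byte timestamp suffix; entries
  // older than 3600 seconds become eligible for removal during compaction.
  Status s = DBWithTTL::Open(options, "/tmp/ttl_demo", &db, /*ttl=*/3600);
  assert(s.ok());

  s = db->Put(WriteOptions(), "key", "value");
  assert(s.ok());

  std::string value;
  s = db->Get(ReadOptions(), "key", &value);  // timestamp is stripped on read
  assert(s.ok() && value == "value");

  delete db;
  return 0;
}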
diff --git a/src/rocksdb/utilities/ttl/db_ttl_impl.h b/src/rocksdb/utilities/ttl/db_ttl_impl.h
new file mode 100644
index 000000000..ab6063a47
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/db_ttl_impl.h
@@ -0,0 +1,361 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#ifndef ROCKSDB_LITE
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/db.h"
+#include "rocksdb/env.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "rocksdb/utilities/utility_db.h"
+
+#ifdef _WIN32
+// Windows API macro interference
+#undef GetCurrentTime
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+class DBWithTTLImpl : public DBWithTTL {
+ public:
+ static void SanitizeOptions(int32_t ttl, ColumnFamilyOptions* options,
+ Env* env);
+
+ explicit DBWithTTLImpl(DB* db);
+
+ virtual ~DBWithTTLImpl();
+
+ virtual Status Close() override;
+
+ Status CreateColumnFamilyWithTtl(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle,
+ int ttl) override;
+
+ Status CreateColumnFamily(const ColumnFamilyOptions& options,
+ const std::string& column_family_name,
+ ColumnFamilyHandle** handle) override;
+
+ using StackableDB::Put;
+ virtual Status Put(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& val) override;
+
+ using StackableDB::Get;
+ virtual Status Get(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ PinnableSlice* value) override;
+
+ using StackableDB::MultiGet;
+ virtual std::vector<Status> MultiGet(
+ const ReadOptions& options,
+ const std::vector<ColumnFamilyHandle*>& column_family,
+ const std::vector<Slice>& keys,
+ std::vector<std::string>* values) override;
+
+ using StackableDB::KeyMayExist;
+ virtual bool KeyMayExist(const ReadOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ std::string* value,
+ bool* value_found = nullptr) override;
+
+ using StackableDB::Merge;
+ virtual Status Merge(const WriteOptions& options,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ const Slice& value) override;
+
+ virtual Status Write(const WriteOptions& opts, WriteBatch* updates) override;
+
+ using StackableDB::NewIterator;
+ virtual Iterator* NewIterator(const ReadOptions& opts,
+ ColumnFamilyHandle* column_family) override;
+
+ virtual DB* GetBaseDB() override { return db_; }
+
+ static bool IsStale(const Slice& value, int32_t ttl, Env* env);
+
+ static Status AppendTS(const Slice& val, std::string* val_with_ts, Env* env);
+
+ static Status SanityCheckTimestamp(const Slice& str);
+
+ static Status StripTS(std::string* str);
+
+ static Status StripTS(PinnableSlice* str);
+
+ static const uint32_t kTSLength = sizeof(int32_t); // size of timestamp
+
+ static const int32_t kMinTimestamp = 1368146402; // 05/09/2013:5:40PM GMT-8
+
+ static const int32_t kMaxTimestamp = 2147483647; // 01/18/2038:7:14PM GMT-8
+
+ void SetTtl(int32_t ttl) override { SetTtl(DefaultColumnFamily(), ttl); }
+
+ void SetTtl(ColumnFamilyHandle* h, int32_t ttl) override;
+
+ private:
+ // Remembers whether Close() has completed or not
+ bool closed_;
+};
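+
+// Value layout used by this class (illustrative sketch, not part of the
+// original code): every stored value is <user bytes><4-byte creation time>;
+// the suffix is written by AppendTS() and removed by StripTS(). Assuming a raw
+// value that already passed SanityCheckTimestamp():
+//
+//   std::string raw = ...;  // value as returned by the underlying DB
+//   int32_t written_at =
+//       DecodeFixed32(raw.data() + raw.size() - DBWithTTLImpl::kTSLength);
+//   raw.resize(raw.size() - DBWithTTLImpl::kTSLength);  // user value only
+//   // IsStale() then checks, roughly: ttl > 0 && written_at + ttl < now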
+
+class TtlIterator : public Iterator {
+
+ public:
+ explicit TtlIterator(Iterator* iter) : iter_(iter) { assert(iter_); }
+
+ ~TtlIterator() { delete iter_; }
+
+ bool Valid() const override { return iter_->Valid(); }
+
+ void SeekToFirst() override { iter_->SeekToFirst(); }
+
+ void SeekToLast() override { iter_->SeekToLast(); }
+
+ void Seek(const Slice& target) override { iter_->Seek(target); }
+
+ void SeekForPrev(const Slice& target) override { iter_->SeekForPrev(target); }
+
+ void Next() override { iter_->Next(); }
+
+ void Prev() override { iter_->Prev(); }
+
+ Slice key() const override { return iter_->key(); }
+
+ int32_t timestamp() const {
+ return DecodeFixed32(iter_->value().data() + iter_->value().size() -
+ DBWithTTLImpl::kTSLength);
+ }
+
+ Slice value() const override {
+ // TODO: handle timestamp corruption like in general iterator semantics
+ assert(DBWithTTLImpl::SanityCheckTimestamp(iter_->value()).ok());
+ Slice trimmed_value = iter_->value();
+ trimmed_value.size_ -= DBWithTTLImpl::kTSLength;
+ return trimmed_value;
+ }
+
+ Status status() const override { return iter_->status(); }
+
+ private:
+ Iterator* iter_;
+};
+
+class TtlCompactionFilter : public CompactionFilter {
+ public:
+ TtlCompactionFilter(
+ int32_t ttl, Env* env, const CompactionFilter* user_comp_filter,
+ std::unique_ptr<const CompactionFilter> user_comp_filter_from_factory =
+ nullptr)
+ : ttl_(ttl),
+ env_(env),
+ user_comp_filter_(user_comp_filter),
+ user_comp_filter_from_factory_(
+ std::move(user_comp_filter_from_factory)) {
+ // Unlike the merge operator, a compaction filter is mandatory for TTL, so
+ // this is called even if the user does not specify a compaction filter.
+ if (!user_comp_filter_) {
+ user_comp_filter_ = user_comp_filter_from_factory_.get();
+ }
+ }
+
+ virtual bool Filter(int level, const Slice& key, const Slice& old_val,
+ std::string* new_val, bool* value_changed) const
+ override {
+ if (DBWithTTLImpl::IsStale(old_val, ttl_, env_)) {
+ return true;
+ }
+ if (user_comp_filter_ == nullptr) {
+ return false;
+ }
+ assert(old_val.size() >= DBWithTTLImpl::kTSLength);
+ Slice old_val_without_ts(old_val.data(),
+ old_val.size() - DBWithTTLImpl::kTSLength);
+ if (user_comp_filter_->Filter(level, key, old_val_without_ts, new_val,
+ value_changed)) {
+ return true;
+ }
+ if (*value_changed) {
+ new_val->append(
+ old_val.data() + old_val.size() - DBWithTTLImpl::kTSLength,
+ DBWithTTLImpl::kTSLength);
+ }
+ return false;
+ }
+
+ virtual const char* Name() const override { return "Delete By TTL"; }
+
+ private:
+ int32_t ttl_;
+ Env* env_;
+ const CompactionFilter* user_comp_filter_;
+ std::unique_ptr<const CompactionFilter> user_comp_filter_from_factory_;
+};
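+
+// Worked example (illustrative): with ttl_ = 100 and current time T, an entry
+// written at T-150 is reported stale by IsStale() and dropped before the user
+// filter ever sees it. An entry written at T-10 has its 4-byte suffix stripped
+// and is offered to the user filter; if the user filter rewrites the value,
+// the original timestamp is re-appended, so the rewritten entry keeps its
+// original age rather than getting a fresh TTL.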
+
+class TtlCompactionFilterFactory : public CompactionFilterFactory {
+ public:
+ TtlCompactionFilterFactory(
+ int32_t ttl, Env* env,
+ std::shared_ptr<CompactionFilterFactory> comp_filter_factory)
+ : ttl_(ttl), env_(env), user_comp_filter_factory_(comp_filter_factory) {}
+
+ virtual std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& context) override {
+ std::unique_ptr<const CompactionFilter> user_comp_filter_from_factory =
+ nullptr;
+ if (user_comp_filter_factory_) {
+ user_comp_filter_from_factory =
+ user_comp_filter_factory_->CreateCompactionFilter(context);
+ }
+
+ return std::unique_ptr<TtlCompactionFilter>(new TtlCompactionFilter(
+ ttl_, env_, nullptr, std::move(user_comp_filter_from_factory)));
+ }
+
+ void SetTtl(int32_t ttl) {
+ ttl_ = ttl;
+ }
+
+ virtual const char* Name() const override {
+ return "TtlCompactionFilterFactory";
+ }
+
+ private:
+ int32_t ttl_;
+ Env* env_;
+ std::shared_ptr<CompactionFilterFactory> user_comp_filter_factory_;
+};
+
+class TtlMergeOperator : public MergeOperator {
+
+ public:
+ explicit TtlMergeOperator(const std::shared_ptr<MergeOperator>& merge_op,
+ Env* env)
+ : user_merge_op_(merge_op), env_(env) {
+ assert(merge_op);
+ assert(env);
+ }
+
+ virtual bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override {
+ const uint32_t ts_len = DBWithTTLImpl::kTSLength;
+ if (merge_in.existing_value && merge_in.existing_value->size() < ts_len) {
+ ROCKS_LOG_ERROR(merge_in.logger,
+ "Error: Could not remove timestamp from existing value.");
+ return false;
+ }
+
+ // Strip the time-stamp from each operand before passing it to user_merge_op_
+ std::vector<Slice> operands_without_ts;
+ for (const auto& operand : merge_in.operand_list) {
+ if (operand.size() < ts_len) {
+ ROCKS_LOG_ERROR(
+ merge_in.logger,
+ "Error: Could not remove timestamp from operand value.");
+ return false;
+ }
+ operands_without_ts.push_back(operand);
+ operands_without_ts.back().remove_suffix(ts_len);
+ }
+
+ // Apply the user merge operator (store result in *new_value)
+ bool good = true;
+ MergeOperationOutput user_merge_out(merge_out->new_value,
+ merge_out->existing_operand);
+ if (merge_in.existing_value) {
+ Slice existing_value_without_ts(merge_in.existing_value->data(),
+ merge_in.existing_value->size() - ts_len);
+ good = user_merge_op_->FullMergeV2(
+ MergeOperationInput(merge_in.key, &existing_value_without_ts,
+ operands_without_ts, merge_in.logger),
+ &user_merge_out);
+ } else {
+ good = user_merge_op_->FullMergeV2(
+ MergeOperationInput(merge_in.key, nullptr, operands_without_ts,
+ merge_in.logger),
+ &user_merge_out);
+ }
+
+ // Return false if the user merge operator returned false
+ if (!good) {
+ return false;
+ }
+
+ if (merge_out->existing_operand.data()) {
+ merge_out->new_value.assign(merge_out->existing_operand.data(),
+ merge_out->existing_operand.size());
+ merge_out->existing_operand = Slice(nullptr, 0);
+ }
+
+ // Augment the *new_value with the ttl time-stamp
+ int64_t curtime;
+ if (!env_->GetCurrentTime(&curtime).ok()) {
+ ROCKS_LOG_ERROR(
+ merge_in.logger,
+ "Error: Could not get current time to be attached internally "
+ "to the new value.");
+ return false;
+ } else {
+ char ts_string[ts_len];
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ merge_out->new_value.append(ts_string, ts_len);
+ return true;
+ }
+ }
+
+ virtual bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const
+ override {
+ const uint32_t ts_len = DBWithTTLImpl::kTSLength;
+ std::deque<Slice> operands_without_ts;
+
+ for (const auto& operand : operand_list) {
+ if (operand.size() < ts_len) {
+ ROCKS_LOG_ERROR(logger,
+ "Error: Could not remove timestamp from value.");
+ return false;
+ }
+
+ operands_without_ts.push_back(
+ Slice(operand.data(), operand.size() - ts_len));
+ }
+
+ // Apply the user partial-merge operator (store result in *new_value)
+ assert(new_value);
+ if (!user_merge_op_->PartialMergeMulti(key, operands_without_ts, new_value,
+ logger)) {
+ return false;
+ }
+
+ // Augment the *new_value with the ttl time-stamp
+ int64_t curtime;
+ if (!env_->GetCurrentTime(&curtime).ok()) {
+ ROCKS_LOG_ERROR(
+ logger,
+ "Error: Could not get current time to be attached internally "
+ "to the new value.");
+ return false;
+ } else {
+ char ts_string[ts_len];
+ EncodeFixed32(ts_string, (int32_t)curtime);
+ new_value->append(ts_string, ts_len);
+ return true;
+ }
+ }
+
+ virtual const char* Name() const override { return "Merge By TTL"; }
+
+ private:
+ std::shared_ptr<MergeOperator> user_merge_op_;
+ Env* env_;
+};
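+
+// Illustrative sketch of the flow above (not part of the original code),
+// assuming a string-append user merge operator and kTSLength == 4:
+//
+//   existing value:  "A" + ts0              -> user operator sees "A"
+//   operands:        "B" + ts1, "C" + ts2   -> user operator sees "B", "C"
+//   user result:     "ABC"
+//   stored result:   "ABC" + <current time> // merge output gets a fresh TTL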
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/ttl/ttl_test.cc b/src/rocksdb/utilities/ttl/ttl_test.cc
new file mode 100644
index 000000000..3960bc625
--- /dev/null
+++ b/src/rocksdb/utilities/ttl/ttl_test.cc
@@ -0,0 +1,693 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include <map>
+#include <memory>
+#include "rocksdb/compaction_filter.h"
+#include "rocksdb/utilities/db_ttl.h"
+#include "test_util/testharness.h"
+#include "util/string_util.h"
+#ifndef OS_WIN
+#include <unistd.h>
+#endif
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+
+typedef std::map<std::string, std::string> KVMap;
+
+enum BatchOperation { OP_PUT = 0, OP_DELETE = 1 };
+}  // namespace
+
+class SpecialTimeEnv : public EnvWrapper {
+ public:
+ explicit SpecialTimeEnv(Env* base) : EnvWrapper(base) {
+ base->GetCurrentTime(&current_time_);
+ }
+
+ void Sleep(int64_t sleep_time) { current_time_ += sleep_time; }
+ Status GetCurrentTime(int64_t* current_time) override {
+ *current_time = current_time_;
+ return Status::OK();
+ }
+
+ private:
+ int64_t current_time_ = 0;
+};
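+
+// The mocked clock above makes TTL expiry deterministic: tests advance time
+// with env_->Sleep(n) instead of actually waiting, so GetCurrentTime() (and
+// therefore IsStale()) observes exactly n more seconds. Illustrative use, with
+// the names defined in the test fixture below:
+//
+//   env_->Sleep(2);    // "2 seconds" pass instantly
+//   ManualCompact();   // compaction now treats entries older than ttl as stale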
+
+class TtlTest : public testing::Test {
+ public:
+ TtlTest() {
+ env_.reset(new SpecialTimeEnv(Env::Default()));
+ dbname_ = test::PerThreadDBPath("db_ttl");
+ options_.create_if_missing = true;
+ options_.env = env_.get();
+ // ensure that compaction is kicked in to always strip timestamp from kvs
+ options_.max_compaction_bytes = 1;
+ // compaction should always take place from level 0, for determinism
+ db_ttl_ = nullptr;
+ DestroyDB(dbname_, Options());
+ }
+
+ ~TtlTest() override {
+ CloseTtl();
+ DestroyDB(dbname_, Options());
+ }
+
+ // Open the database with TTL support, without specifying a TTL
+ void OpenTtl() {
+ ASSERT_TRUE(db_ttl_ ==
+ nullptr); // db should be closed before opening again
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_));
+ }
+
+ // Open the database with TTL support, with the given TTL
+ void OpenTtl(int32_t ttl) {
+ ASSERT_TRUE(db_ttl_ == nullptr);
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_, ttl));
+ }
+
+ // Open with TestFilter compaction filter
+ void OpenTtlWithTestCompaction(int32_t ttl) {
+ options_.compaction_filter_factory =
+ std::shared_ptr<CompactionFilterFactory>(
+ new TestFilterFactory(kSampleSize_, kNewValue_));
+ OpenTtl(ttl);
+ }
+
+ // Open database with TTL support in read_only mode
+ void OpenReadOnlyTtl(int32_t ttl) {
+ ASSERT_TRUE(db_ttl_ == nullptr);
+ ASSERT_OK(DBWithTTL::Open(options_, dbname_, &db_ttl_, ttl, true));
+ }
+
+ // Call db_ttl_->Close() before delete db_ttl_
+ void CloseTtl() { CloseTtlHelper(true); }
+
+ // No db_ttl_->Close() before delete db_ttl_
+ void CloseTtlNoDBClose() { CloseTtlHelper(false); }
+
+ void CloseTtlHelper(bool close_db) {
+ if (db_ttl_ != nullptr) {
+ if (close_db) {
+ db_ttl_->Close();
+ }
+ delete db_ttl_;
+ db_ttl_ = nullptr;
+ }
+ }
+
+ // Populates kvmap_ with num_entries key-value pairs
+ void MakeKVMap(int64_t num_entries) {
+ kvmap_.clear();
+ int digits = 1;
+ for (int64_t dummy = num_entries; dummy /= 10; ++digits) {
+ }
+ int digits_in_i = 1;
+ for (int64_t i = 0; i < num_entries; i++) {
+ std::string key = "key";
+ std::string value = "value";
+ if (i % 10 == 0) {
+ digits_in_i++;
+ }
+ for (int j = digits_in_i; j < digits; j++) {
+ key.append("0");
+ value.append("0");
+ }
+ AppendNumberTo(&key, i);
+ AppendNumberTo(&value, i);
+ kvmap_[key] = value;
+ }
+ ASSERT_EQ(static_cast<int64_t>(kvmap_.size()),
+ num_entries); // check all insertions done
+ }
+
+ // Makes a write-batch with key-vals from kvmap_ and writes it with Write()
+ void MakePutWriteBatch(const BatchOperation* batch_ops, int64_t num_ops) {
+ ASSERT_LE(num_ops, static_cast<int64_t>(kvmap_.size()));
+ static WriteOptions wopts;
+ static FlushOptions flush_opts;
+ WriteBatch batch;
+ kv_it_ = kvmap_.begin();
+ for (int64_t i = 0; i < num_ops && kv_it_ != kvmap_.end(); i++, ++kv_it_) {
+ switch (batch_ops[i]) {
+ case OP_PUT:
+ batch.Put(kv_it_->first, kv_it_->second);
+ break;
+ case OP_DELETE:
+ batch.Delete(kv_it_->first);
+ break;
+ default:
+ FAIL();
+ }
+ }
+ db_ttl_->Write(wopts, &batch);
+ db_ttl_->Flush(flush_opts);
+ }
+
+ // Puts num_entries starting from start_pos_map from kvmap_ into the database
+ void PutValues(int64_t start_pos_map, int64_t num_entries, bool flush = true,
+ ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+ ASSERT_LE(start_pos_map + num_entries, static_cast<int64_t>(kvmap_.size()));
+ static WriteOptions wopts;
+ static FlushOptions flush_opts;
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, start_pos_map);
+ for (int64_t i = 0; kv_it_ != kvmap_.end() && i < num_entries;
+ i++, ++kv_it_) {
+ ASSERT_OK(cf == nullptr
+ ? db_ttl_->Put(wopts, kv_it_->first, kv_it_->second)
+ : db_ttl_->Put(wopts, cf, kv_it_->first, kv_it_->second));
+ }
+ // Put a mock kv at the end because CompactionFilter doesn't delete last key
+ ASSERT_OK(cf == nullptr ? db_ttl_->Put(wopts, "keymock", "valuemock")
+ : db_ttl_->Put(wopts, cf, "keymock", "valuemock"));
+ if (flush) {
+ if (cf == nullptr) {
+ db_ttl_->Flush(flush_opts);
+ } else {
+ db_ttl_->Flush(flush_opts, cf);
+ }
+ }
+ }
+
+ // Runs a manual compaction
+ void ManualCompact(ColumnFamilyHandle* cf = nullptr) {
+ if (cf == nullptr) {
+ db_ttl_->CompactRange(CompactRangeOptions(), nullptr, nullptr);
+ } else {
+ db_ttl_->CompactRange(CompactRangeOptions(), cf, nullptr, nullptr);
+ }
+ }
+
+ // Checks that KeyMayExist returns the correct value for every key in kvmap_
+ void SimpleKeyMayExistCheck() {
+ static ReadOptions ropts;
+ bool value_found;
+ std::string val;
+ for (auto& kv : kvmap_) {
+ bool ret = db_ttl_->KeyMayExist(ropts, kv.first, &val, &value_found);
+ if (ret == false || value_found == false) {
+ fprintf(stderr, "KeyMayExist could not find key=%s in the database but"
+ " should have\n", kv.first.c_str());
+ FAIL();
+ } else if (val.compare(kv.second) != 0) {
+ fprintf(stderr, " value for key=%s present in database is %s but"
+ " should be %s\n", kv.first.c_str(), val.c_str(),
+ kv.second.c_str());
+ FAIL();
+ }
+ }
+ }
+
+ // Checks that MultiGet returns the correct value for every key in kvmap_
+ void SimpleMultiGetTest() {
+ static ReadOptions ropts;
+ std::vector<Slice> keys;
+ std::vector<std::string> values;
+
+ for (auto& kv : kvmap_) {
+ keys.emplace_back(kv.first);
+ }
+
+ auto statuses = db_ttl_->MultiGet(ropts, keys, &values);
+ size_t i = 0;
+ for (auto& kv : kvmap_) {
+ ASSERT_OK(statuses[i]);
+ ASSERT_EQ(values[i], kv.second);
+ ++i;
+ }
+ }
+
+ // Sleeps for slp_tim, then runs a manual compaction.
+ // Checks 'span' keys of kvmap_ starting at st_pos against the db:
+ // Get should succeed if 'check' is true and fail otherwise.
+ // Also verifies that the value read matches the inserted one, or kNewValue_
+ // if test_compaction_change is true.
+ void SleepCompactCheck(int slp_tim, int64_t st_pos, int64_t span,
+ bool check = true, bool test_compaction_change = false,
+ ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+
+ env_->Sleep(slp_tim);
+ ManualCompact(cf);
+ static ReadOptions ropts;
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, st_pos);
+ std::string v;
+ for (int64_t i = 0; kv_it_ != kvmap_.end() && i < span; i++, ++kv_it_) {
+ Status s = (cf == nullptr) ? db_ttl_->Get(ropts, kv_it_->first, &v)
+ : db_ttl_->Get(ropts, cf, kv_it_->first, &v);
+ if (s.ok() != check) {
+ fprintf(stderr, "key=%s ", kv_it_->first.c_str());
+ if (!s.ok()) {
+ fprintf(stderr, "is absent from db but was expected to be present\n");
+ } else {
+ fprintf(stderr, "is present in db but was expected to be absent\n");
+ }
+ FAIL();
+ } else if (s.ok()) {
+ if (test_compaction_change && v.compare(kNewValue_) != 0) {
+ fprintf(stderr, " value for key=%s present in database is %s but "
+ " should be %s\n", kv_it_->first.c_str(), v.c_str(),
+ kNewValue_.c_str());
+ FAIL();
+ } else if (!test_compaction_change && v.compare(kv_it_->second) != 0) {
+ fprintf(stderr, " value for key=%s present in database is %s but "
+ " should be %s\n", kv_it_->first.c_str(), v.c_str(),
+ kv_it_->second.c_str());
+ FAIL();
+ }
+ }
+ }
+ }
+
+ // Similar to SleepCompactCheck but uses TtlIterator to read from db
+ void SleepCompactCheckIter(int slp, int st_pos, int64_t span,
+ bool check = true) {
+ ASSERT_TRUE(db_ttl_);
+ env_->Sleep(slp);
+ ManualCompact();
+ static ReadOptions ropts;
+ Iterator *dbiter = db_ttl_->NewIterator(ropts);
+ kv_it_ = kvmap_.begin();
+ advance(kv_it_, st_pos);
+
+ dbiter->Seek(kv_it_->first);
+ if (!check) {
+ if (dbiter->Valid()) {
+ ASSERT_NE(dbiter->value().compare(kv_it_->second), 0);
+ }
+ } else { // dbiter should have found out kvmap_[st_pos]
+ for (int64_t i = st_pos; kv_it_ != kvmap_.end() && i < st_pos + span;
+ i++, ++kv_it_) {
+ ASSERT_TRUE(dbiter->Valid());
+ ASSERT_EQ(dbiter->value().compare(kv_it_->second), 0);
+ dbiter->Next();
+ }
+ }
+ delete dbiter;
+ }
+
+ // Set ttl on open db
+ void SetTtl(int32_t ttl, ColumnFamilyHandle* cf = nullptr) {
+ ASSERT_TRUE(db_ttl_);
+ cf == nullptr ? db_ttl_->SetTtl(ttl) : db_ttl_->SetTtl(cf, ttl);
+ }
+
+ class TestFilter : public CompactionFilter {
+ public:
+ TestFilter(const int64_t kSampleSize, const std::string& kNewValue)
+ : kSampleSize_(kSampleSize),
+ kNewValue_(kNewValue) {
+ }
+
+ // Works on keys of the form "key<number>".
+ // Drops the key if the number at the end is in [0, kSampleSize_/3),
+ // keeps the key if it is in [kSampleSize_/3, 2*kSampleSize_/3), and
+ // changes the value if it is in [2*kSampleSize_/3, kSampleSize_).
+ // E.g. kSampleSize_=6. Drop:key0-1...Keep:key2-3...Change:key4-5...
+ bool Filter(int /*level*/, const Slice& key, const Slice& /*value*/,
+ std::string* new_value, bool* value_changed) const override {
+ assert(new_value != nullptr);
+
+ std::string search_str = "0123456789";
+ std::string key_string = key.ToString();
+ size_t pos = key_string.find_first_of(search_str);
+ int num_key_end;
+ if (pos != std::string::npos) {
+ auto key_substr = key_string.substr(pos, key.size() - pos);
+#ifndef CYGWIN
+ num_key_end = std::stoi(key_substr);
+#else
+ num_key_end = std::strtol(key_substr.c_str(), 0, 10);
+#endif
+
+ } else {
+ return false; // Keep keys not matching the format "key<NUMBER>"
+ }
+
+ int64_t partition = kSampleSize_ / 3;
+ if (num_key_end < partition) {
+ return true;
+ } else if (num_key_end < partition * 2) {
+ return false;
+ } else {
+ *new_value = kNewValue_;
+ *value_changed = true;
+ return false;
+ }
+ }
+
+ const char* Name() const override { return "TestFilter"; }
+
+ private:
+ const int64_t kSampleSize_;
+ const std::string kNewValue_;
+ };
+
+ class TestFilterFactory : public CompactionFilterFactory {
+ public:
+ TestFilterFactory(const int64_t kSampleSize, const std::string& kNewValue)
+ : kSampleSize_(kSampleSize),
+ kNewValue_(kNewValue) {
+ }
+
+ std::unique_ptr<CompactionFilter> CreateCompactionFilter(
+ const CompactionFilter::Context& /*context*/) override {
+ return std::unique_ptr<CompactionFilter>(
+ new TestFilter(kSampleSize_, kNewValue_));
+ }
+
+ const char* Name() const override { return "TestFilterFactory"; }
+
+ private:
+ const int64_t kSampleSize_;
+ const std::string kNewValue_;
+ };
+
+
+ // Chosen carefully so that Puts, Gets and Compaction complete within the 1-second buffer
+ static const int64_t kSampleSize_ = 100;
+ std::string dbname_;
+ DBWithTTL* db_ttl_;
+ std::unique_ptr<SpecialTimeEnv> env_;
+
+ private:
+ Options options_;
+ KVMap kvmap_;
+ KVMap::iterator kv_it_;
+ const std::string kNewValue_ = "new_value";
+ std::unique_ptr<CompactionFilter> test_comp_filter_;
+}; // class TtlTest
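+
+// Minimal usage sketch of the API under test (illustrative only; path and ttl
+// are arbitrary):
+//
+//   DBWithTTL* db;
+//   Options opts;
+//   opts.create_if_missing = true;
+//   ASSERT_OK(DBWithTTL::Open(opts, test::PerThreadDBPath("ttl_example"), &db,
+//                             /*ttl=*/1));
+//   ASSERT_OK(db->Put(WriteOptions(), "k", "v"));  // droppable after ~1 second
+//   delete db;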
+
+// If TTL is non-positive or not provided, the behaviour is TTL = infinity.
+// This test opens the db 3 times with such default behaviour and inserts a
+// bunch of kvs each time. All kvs should accumulate in the db till the end.
+// Partitions the sample-size provided into 3 sets over boundary1 and boundary2.
+TEST_F(TtlTest, NoEffect) {
+ MakeKVMap(kSampleSize_);
+ int64_t boundary1 = kSampleSize_ / 3;
+ int64_t boundary2 = 2 * boundary1;
+
+ OpenTtl();
+ PutValues(0, boundary1); //T=0: Set1 never deleted
+ SleepCompactCheck(1, 0, boundary1); //T=1: Set1 still there
+ CloseTtl();
+
+ OpenTtl(0);
+ PutValues(boundary1, boundary2 - boundary1); //T=1: Set2 never deleted
+ SleepCompactCheck(1, 0, boundary2); //T=2: Sets1 & 2 still there
+ CloseTtl();
+
+ OpenTtl(-1);
+ PutValues(boundary2, kSampleSize_ - boundary2); //T=3: Set3 never deleted
+ SleepCompactCheck(1, 0, kSampleSize_, true); //T=4: Sets 1,2,3 still there
+ CloseTtl();
+}
+
+// Rerun the NoEffect test with a different version of CloseTtl
+// function, where db is directly deleted without close.
+TEST_F(TtlTest, DestructWithoutClose) {
+ MakeKVMap(kSampleSize_);
+ int64_t boundary1 = kSampleSize_ / 3;
+ int64_t boundary2 = 2 * boundary1;
+
+ OpenTtl();
+ PutValues(0, boundary1); // T=0: Set1 never deleted
+ SleepCompactCheck(1, 0, boundary1); // T=1: Set1 still there
+ CloseTtlNoDBClose();
+
+ OpenTtl(0);
+ PutValues(boundary1, boundary2 - boundary1); // T=1: Set2 never deleted
+ SleepCompactCheck(1, 0, boundary2); // T=2: Sets1 & 2 still there
+ CloseTtlNoDBClose();
+
+ OpenTtl(-1);
+ PutValues(boundary2, kSampleSize_ - boundary2); // T=3: Set3 never deleted
+ SleepCompactCheck(1, 0, kSampleSize_, true); // T=4: Sets 1,2,3 still there
+ CloseTtlNoDBClose();
+}
+
+// Puts a set of values and checks its presence using Get during ttl
+TEST_F(TtlTest, PresentDuringTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2); // T=0:Open the db with ttl = 2
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=2
+ SleepCompactCheck(1, 0, kSampleSize_, true); // T=1:Set1 should still be there
+ CloseTtl();
+}
+
+// Puts a set of values and checks its absence using Get after ttl
+TEST_F(TtlTest, AbsentAfterTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db with ttl = 1
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=1
+ SleepCompactCheck(2, 0, kSampleSize_, false); // T=2:Set1 should not be there
+ CloseTtl();
+}
+
+// Resets the timestamp of a set of kvs by updating them and checks that they
+// are not deleted according to the old timestamp
+TEST_F(TtlTest, ResetTimestamp) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(3);
+ PutValues(0, kSampleSize_); // T=0: Insert Set1. Delete at t=3
+ env_->Sleep(2); // T=2
+ PutValues(0, kSampleSize_); // T=2: Insert Set1. Delete at t=5
+ SleepCompactCheck(2, 0, kSampleSize_); // T=4: Set1 should still be there
+ CloseTtl();
+}
+
+// Similar to PresentDuringTTL but uses Iterator
+TEST_F(TtlTest, IterPresentDuringTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=2
+ SleepCompactCheckIter(1, 0, kSampleSize_); // T=1: Set should be there
+ CloseTtl();
+}
+
+// Similar to AbsentAfterTTL but uses Iterator
+TEST_F(TtlTest, IterAbsentAfterTTL) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ SleepCompactCheckIter(2, 0, kSampleSize_, false); // T=2: Should not be there
+ CloseTtl();
+}
+
+// Checks presence while opening the same db more than once with the same ttl
+// Note: The second open will open the same db
+TEST_F(TtlTest, MultiOpenSamePresent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(2);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=2
+ CloseTtl();
+
+ OpenTtl(2); // T=0. Delete at t=2
+ SleepCompactCheck(1, 0, kSampleSize_); // T=1: Set should be there
+ CloseTtl();
+}
+
+// Checks absence while opening the same db more than once with the same ttl
+// Note: The second open will open the same db
+TEST_F(TtlTest, MultiOpenSameAbsent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ CloseTtl();
+
+ OpenTtl(1); // T=0.Delete at t=1
+ SleepCompactCheck(2, 0, kSampleSize_, false); // T=2: Set should not be there
+ CloseTtl();
+}
+
+// Checks presence while opening the same db more than once with bigger ttl
+TEST_F(TtlTest, MultiOpenDifferent) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1);
+ PutValues(0, kSampleSize_); // T=0: Insert. Delete at t=1
+ CloseTtl();
+
+ OpenTtl(3); // T=0: Set deleted at t=3
+ SleepCompactCheck(2, 0, kSampleSize_); // T=2: Set should be there
+ CloseTtl();
+}
+
+// Checks presence during ttl in read_only mode
+TEST_F(TtlTest, ReadOnlyPresentForever) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db normally
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=1
+ CloseTtl();
+
+ OpenReadOnlyTtl(1);
+ SleepCompactCheck(2, 0, kSampleSize_); // T=2:Set1 should still be there
+ CloseTtl();
+}
+
+// Checks whether WriteBatch works well with TTL
+// Puts all kvs in kvmap_ in a batch and writes first, then deletes first half
+TEST_F(TtlTest, WriteBatchTest) {
+ MakeKVMap(kSampleSize_);
+ BatchOperation batch_ops[kSampleSize_];
+ for (int i = 0; i < kSampleSize_; i++) {
+ batch_ops[i] = OP_PUT;
+ }
+
+ OpenTtl(2);
+ MakePutWriteBatch(batch_ops, kSampleSize_);
+ for (int i = 0; i < kSampleSize_ / 2; i++) {
+ batch_ops[i] = OP_DELETE;
+ }
+ MakePutWriteBatch(batch_ops, kSampleSize_ / 2);
+ SleepCompactCheck(0, 0, kSampleSize_ / 2, false);
+ SleepCompactCheck(0, kSampleSize_ / 2, kSampleSize_ - kSampleSize_ / 2);
+ CloseTtl();
+}
+
+// Checks user's compaction filter for correctness with TTL logic
+TEST_F(TtlTest, CompactionFilter) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtlWithTestCompaction(1);
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=1
+ // T=2: TTL logic takes precedence over TestFilter: Set1 should not be there
+ SleepCompactCheck(2, 0, kSampleSize_, false);
+ CloseTtl();
+
+ OpenTtlWithTestCompaction(3);
+ PutValues(0, kSampleSize_); // T=0:Insert Set1.
+ int64_t partition = kSampleSize_ / 3;
+ SleepCompactCheck(1, 0, partition, false); // Part dropped
+ SleepCompactCheck(0, partition, partition); // Part kept
+ SleepCompactCheck(0, 2 * partition, partition, true, true); // Part changed
+ CloseTtl();
+}
+
+// Inserts some key-values which KeyMayExist should be able to find, and checks
+// that the values it returns are correct
+TEST_F(TtlTest, KeyMayExist) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl();
+ PutValues(0, kSampleSize_, false);
+
+ SimpleKeyMayExistCheck();
+
+ CloseTtl();
+}
+
+TEST_F(TtlTest, MultiGetTest) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl();
+ PutValues(0, kSampleSize_, false);
+
+ SimpleMultiGetTest();
+
+ CloseTtl();
+}
+
+TEST_F(TtlTest, ColumnFamiliesTest) {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ options.env = env_.get();
+
+ ASSERT_OK(DB::Open(options, dbname_, &db));
+ ColumnFamilyHandle* handle;
+ ASSERT_OK(db->CreateColumnFamily(ColumnFamilyOptions(options),
+ "ttl_column_family", &handle));
+
+ delete handle;
+ delete db;
+
+ std::vector<ColumnFamilyDescriptor> column_families;
+ column_families.push_back(ColumnFamilyDescriptor(
+ kDefaultColumnFamilyName, ColumnFamilyOptions(options)));
+ column_families.push_back(ColumnFamilyDescriptor(
+ "ttl_column_family", ColumnFamilyOptions(options)));
+
+ std::vector<ColumnFamilyHandle*> handles;
+
+ ASSERT_OK(DBWithTTL::Open(DBOptions(options), dbname_, column_families,
+ &handles, &db_ttl_, {3, 5}, false));
+ ASSERT_EQ(handles.size(), 2U);
+ ColumnFamilyHandle* new_handle;
+ ASSERT_OK(db_ttl_->CreateColumnFamilyWithTtl(options, "ttl_column_family_2",
+ &new_handle, 2));
+ handles.push_back(new_handle);
+
+ MakeKVMap(kSampleSize_);
+ PutValues(0, kSampleSize_, false, handles[0]);
+ PutValues(0, kSampleSize_, false, handles[1]);
+ PutValues(0, kSampleSize_, false, handles[2]);
+
+ // everything should be there after 1 second
+ SleepCompactCheck(1, 0, kSampleSize_, true, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[2]);
+
+ // only column family 1 should be alive after 4 seconds
+ SleepCompactCheck(3, 0, kSampleSize_, false, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, true, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[2]);
+
+ // nothing should be there after 6 seconds
+ SleepCompactCheck(2, 0, kSampleSize_, false, false, handles[0]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[1]);
+ SleepCompactCheck(0, 0, kSampleSize_, false, false, handles[2]);
+
+ for (auto h : handles) {
+ delete h;
+ }
+ delete db_ttl_;
+ db_ttl_ = nullptr;
+}
+
+// Puts a set of values after raising the TTL on the open db, and checks that
+// they are still present after the original TTL would have expired
+TEST_F(TtlTest, ChangeTtlOnOpenDb) {
+ MakeKVMap(kSampleSize_);
+
+ OpenTtl(1); // T=0:Open the db with ttl = 1
+ SetTtl(3);
+ // @lint-ignore TXT2 T25377293 Grandfathered in
+ PutValues(0, kSampleSize_); // T=0:Insert Set1. Delete at t=3
+ SleepCompactCheck(2, 0, kSampleSize_, true); // T=2:Set1 should be there
+ CloseTtl();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+// A black-box test for the ttl wrapper around rocksdb
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+ fprintf(stderr, "SKIPPED as DBWithTTL is not supported in ROCKSDB_LITE\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/util_merge_operators_test.cc b/src/rocksdb/utilities/util_merge_operators_test.cc
new file mode 100644
index 000000000..3b043ea2f
--- /dev/null
+++ b/src/rocksdb/utilities/util_merge_operators_test.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_util/testharness.h"
+#include "test_util/testutil.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class UtilMergeOperatorTest : public testing::Test {
+ public:
+ UtilMergeOperatorTest() {}
+
+ std::string FullMergeV2(std::string existing_value,
+ std::vector<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ Slice result_operand(nullptr, 0);
+
+ Slice existing_value_slice(existing_value);
+ std::vector<Slice> operands_slice(operands.begin(), operands.end());
+
+ const MergeOperator::MergeOperationInput merge_in(
+ key, &existing_value_slice, operands_slice, nullptr);
+ MergeOperator::MergeOperationOutput merge_out(result, result_operand);
+ merge_operator_->FullMergeV2(merge_in, &merge_out);
+
+ if (result_operand.data()) {
+ result.assign(result_operand.data(), result_operand.size());
+ }
+ return result;
+ }
+
+ std::string FullMergeV2(std::vector<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ Slice result_operand(nullptr, 0);
+
+ std::vector<Slice> operands_slice(operands.begin(), operands.end());
+
+ const MergeOperator::MergeOperationInput merge_in(key, nullptr,
+ operands_slice, nullptr);
+ MergeOperator::MergeOperationOutput merge_out(result, result_operand);
+ merge_operator_->FullMergeV2(merge_in, &merge_out);
+
+ if (result_operand.data()) {
+ result.assign(result_operand.data(), result_operand.size());
+ }
+ return result;
+ }
+
+ std::string PartialMerge(std::string left, std::string right,
+ std::string key = "") {
+ std::string result;
+
+ merge_operator_->PartialMerge(key, left, right, &result, nullptr);
+ return result;
+ }
+
+ std::string PartialMergeMulti(std::deque<std::string> operands,
+ std::string key = "") {
+ std::string result;
+ std::deque<Slice> operands_slice(operands.begin(), operands.end());
+
+ merge_operator_->PartialMergeMulti(key, operands_slice, &result, nullptr);
+ return result;
+ }
+
+ protected:
+ std::shared_ptr<MergeOperator> merge_operator_;
+};
+
+TEST_F(UtilMergeOperatorTest, MaxMergeOperator) {
+ merge_operator_ = MergeOperators::CreateMaxOperator();
+
+ EXPECT_EQ("B", FullMergeV2("B", {"A"}));
+ EXPECT_EQ("B", FullMergeV2("A", {"B"}));
+ EXPECT_EQ("", FullMergeV2({"", "", ""}));
+ EXPECT_EQ("A", FullMergeV2({"A"}));
+ EXPECT_EQ("ABC", FullMergeV2({"ABC"}));
+ EXPECT_EQ("Z", FullMergeV2({"ABC", "Z", "C", "AXX"}));
+ EXPECT_EQ("ZZZ", FullMergeV2({"ABC", "CC", "Z", "ZZZ"}));
+ EXPECT_EQ("a", FullMergeV2("a", {"ABC", "CC", "Z", "ZZZ"}));
+
+ EXPECT_EQ("z", PartialMergeMulti({"a", "z", "efqfqwgwew", "aaz", "hhhhh"}));
+
+ EXPECT_EQ("b", PartialMerge("a", "b"));
+ EXPECT_EQ("z", PartialMerge("z", "azzz"));
+ EXPECT_EQ("a", PartialMerge("a", ""));
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc
new file mode 100644
index 000000000..2df6bcaf3
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index.cc
@@ -0,0 +1,1065 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/write_batch_with_index.h"
+
+#include <memory>
+
+#include "db/column_family.h"
+#include "db/db_impl/db_impl.h"
+#include "db/merge_context.h"
+#include "db/merge_helper.h"
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "options/db_options.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/iterator.h"
+#include "util/cast_util.h"
+#include "util/string_util.h"
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// when direction == forward
+// * current_at_base_ <=> base_iterator key < delta_iterator key
+// when direction == backwards
+// * current_at_base_ <=> base_iterator key > delta_iterator key
+// always:
+// * equal_keys_ <=> base_iterator key == delta_iterator key
+class BaseDeltaIterator : public Iterator {
+ public:
+ BaseDeltaIterator(Iterator* base_iterator, WBWIIterator* delta_iterator,
+ const Comparator* comparator,
+ const ReadOptions* read_options = nullptr)
+ : forward_(true),
+ current_at_base_(true),
+ equal_keys_(false),
+ status_(Status::OK()),
+ base_iterator_(base_iterator),
+ delta_iterator_(delta_iterator),
+ comparator_(comparator),
+ iterate_upper_bound_(read_options ? read_options->iterate_upper_bound
+ : nullptr) {}
+
+ ~BaseDeltaIterator() override {}
+
+ bool Valid() const override {
+ return current_at_base_ ? BaseValid() : DeltaValid();
+ }
+
+ void SeekToFirst() override {
+ forward_ = true;
+ base_iterator_->SeekToFirst();
+ delta_iterator_->SeekToFirst();
+ UpdateCurrent();
+ }
+
+ void SeekToLast() override {
+ forward_ = false;
+ base_iterator_->SeekToLast();
+ delta_iterator_->SeekToLast();
+ UpdateCurrent();
+ }
+
+ void Seek(const Slice& k) override {
+ forward_ = true;
+ base_iterator_->Seek(k);
+ delta_iterator_->Seek(k);
+ UpdateCurrent();
+ }
+
+ void SeekForPrev(const Slice& k) override {
+ forward_ = false;
+ base_iterator_->SeekForPrev(k);
+ delta_iterator_->SeekForPrev(k);
+ UpdateCurrent();
+ }
+
+ void Next() override {
+ if (!Valid()) {
+ status_ = Status::NotSupported("Next() on invalid iterator");
+ return;
+ }
+
+ if (!forward_) {
+ // Need to change direction
+ // If our direction was backward and we're not equal, we have two states:
+ // * both iterators are valid: we're already in a good state (current
+ // points to the smaller key)
+ // * only one iterator is valid: we need to advance that iterator
+ forward_ = true;
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ assert(DeltaValid());
+ base_iterator_->SeekToFirst();
+ } else if (!DeltaValid()) {
+ delta_iterator_->SeekToFirst();
+ } else if (current_at_base_) {
+ // Change delta from larger than base to smaller
+ AdvanceDelta();
+ } else {
+ // Change base from larger than delta to smaller
+ AdvanceBase();
+ }
+ if (DeltaValid() && BaseValid()) {
+ if (comparator_->Equal(delta_iterator_->Entry().key,
+ base_iterator_->key())) {
+ equal_keys_ = true;
+ }
+ }
+ }
+ Advance();
+ }
+
+ void Prev() override {
+ if (!Valid()) {
+ status_ = Status::NotSupported("Prev() on invalid iterator");
+ return;
+ }
+
+ if (forward_) {
+ // Need to change direction
+ // If our direction was forward and we're not equal, we have two states:
+ // * both iterators are valid: we're already in a good state (current
+ // points to the smaller key)
+ // * only one iterator is valid: we need to advance that iterator
+ forward_ = false;
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ assert(DeltaValid());
+ base_iterator_->SeekToLast();
+ } else if (!DeltaValid()) {
+ delta_iterator_->SeekToLast();
+ } else if (current_at_base_) {
+ // Change delta from less advanced than base to more advanced
+ AdvanceDelta();
+ } else {
+ // Change base from less advanced than delta to more advanced
+ AdvanceBase();
+ }
+ if (DeltaValid() && BaseValid()) {
+ if (comparator_->Equal(delta_iterator_->Entry().key,
+ base_iterator_->key())) {
+ equal_keys_ = true;
+ }
+ }
+ }
+
+ Advance();
+ }
+
+ Slice key() const override {
+ return current_at_base_ ? base_iterator_->key()
+ : delta_iterator_->Entry().key;
+ }
+
+ Slice value() const override {
+ return current_at_base_ ? base_iterator_->value()
+ : delta_iterator_->Entry().value;
+ }
+
+ Status status() const override {
+ if (!status_.ok()) {
+ return status_;
+ }
+ if (!base_iterator_->status().ok()) {
+ return base_iterator_->status();
+ }
+ return delta_iterator_->status();
+ }
+
+ private:
+ void AssertInvariants() {
+#ifndef NDEBUG
+ bool not_ok = false;
+ if (!base_iterator_->status().ok()) {
+ assert(!base_iterator_->Valid());
+ not_ok = true;
+ }
+ if (!delta_iterator_->status().ok()) {
+ assert(!delta_iterator_->Valid());
+ not_ok = true;
+ }
+ if (not_ok) {
+ assert(!Valid());
+ assert(!status().ok());
+ return;
+ }
+
+ if (!Valid()) {
+ return;
+ }
+ if (!BaseValid()) {
+ assert(!current_at_base_ && delta_iterator_->Valid());
+ return;
+ }
+ if (!DeltaValid()) {
+ assert(current_at_base_ && base_iterator_->Valid());
+ return;
+ }
+ // we don't support those yet
+ assert(delta_iterator_->Entry().type != kMergeRecord &&
+ delta_iterator_->Entry().type != kLogDataRecord);
+ int compare = comparator_->Compare(delta_iterator_->Entry().key,
+ base_iterator_->key());
+ if (forward_) {
+ // current_at_base -> base key is smaller, i.e. compare > 0
+ assert(!current_at_base_ || compare > 0);
+ // !current_at_base -> compare <= 0
+ assert(current_at_base_ || compare <= 0);
+ } else {
+ // current_at_base -> base key is larger, i.e. compare < 0
+ assert(!current_at_base_ || compare < 0);
+ // !current_at_base -> compare >= 0
+ assert(current_at_base_ || compare >= 0);
+ }
+ // equal_keys_ <=> compare == 0
+ assert((equal_keys_ || compare != 0) && (!equal_keys_ || compare == 0));
+#endif
+ }
+
+ void Advance() {
+ if (equal_keys_) {
+ assert(BaseValid() && DeltaValid());
+ AdvanceBase();
+ AdvanceDelta();
+ } else {
+ if (current_at_base_) {
+ assert(BaseValid());
+ AdvanceBase();
+ } else {
+ assert(DeltaValid());
+ AdvanceDelta();
+ }
+ }
+ UpdateCurrent();
+ }
+
+ void AdvanceDelta() {
+ if (forward_) {
+ delta_iterator_->Next();
+ } else {
+ delta_iterator_->Prev();
+ }
+ }
+ void AdvanceBase() {
+ if (forward_) {
+ base_iterator_->Next();
+ } else {
+ base_iterator_->Prev();
+ }
+ }
+ bool BaseValid() const { return base_iterator_->Valid(); }
+ bool DeltaValid() const { return delta_iterator_->Valid(); }
+ void UpdateCurrent() {
+// Suppress false positive clang analyzer warnings.
+#ifndef __clang_analyzer__
+ status_ = Status::OK();
+ while (true) {
+ WriteEntry delta_entry;
+ if (DeltaValid()) {
+ assert(delta_iterator_->status().ok());
+ delta_entry = delta_iterator_->Entry();
+ } else if (!delta_iterator_->status().ok()) {
+ // Expose the error status and stop.
+ current_at_base_ = false;
+ return;
+ }
+ equal_keys_ = false;
+ if (!BaseValid()) {
+ if (!base_iterator_->status().ok()) {
+ // Expose the error status and stop.
+ current_at_base_ = true;
+ return;
+ }
+
+ // Base has finished.
+ if (!DeltaValid()) {
+ // Finished
+ return;
+ }
+ if (iterate_upper_bound_) {
+ if (comparator_->Compare(delta_entry.key, *iterate_upper_bound_) >=
+ 0) {
+ // out of upper bound -> finished.
+ return;
+ }
+ }
+ if (delta_entry.type == kDeleteRecord ||
+ delta_entry.type == kSingleDeleteRecord) {
+ AdvanceDelta();
+ } else {
+ current_at_base_ = false;
+ return;
+ }
+ } else if (!DeltaValid()) {
+ // Delta has finished.
+ current_at_base_ = true;
+ return;
+ } else {
+ int compare =
+ (forward_ ? 1 : -1) *
+ comparator_->Compare(delta_entry.key, base_iterator_->key());
+ if (compare <= 0) { // delta is at or before base in the iteration order
+ if (compare == 0) {
+ equal_keys_ = true;
+ }
+ if (delta_entry.type != kDeleteRecord &&
+ delta_entry.type != kSingleDeleteRecord) {
+ current_at_base_ = false;
+ return;
+ }
+ // Delta is less advanced and is delete.
+ AdvanceDelta();
+ if (equal_keys_) {
+ AdvanceBase();
+ }
+ } else {
+ current_at_base_ = true;
+ return;
+ }
+ }
+ }
+
+ AssertInvariants();
+#endif // __clang_analyzer__
+ }
+
+ bool forward_;
+ bool current_at_base_;
+ bool equal_keys_;
+ Status status_;
+ std::unique_ptr<Iterator> base_iterator_;
+ std::unique_ptr<WBWIIterator> delta_iterator_;
+ const Comparator* comparator_; // not owned
+ const Slice* iterate_upper_bound_;
+};
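+
+// Worked example (illustrative): with base = {a=1, c=3} and a delta batch
+// containing Put(b, 2) and Delete(c), a forward scan through BaseDeltaIterator
+// yields a=1 (current_at_base_), then b=2 (from the delta), and skips c
+// because the delta entry for c is a delete; equal_keys_ is only true while
+// both iterators sit on "c".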
+
+typedef SkipList<WriteBatchIndexEntry*, const WriteBatchEntryComparator&>
+ WriteBatchEntrySkipList;
+
+class WBWIIteratorImpl : public WBWIIterator {
+ public:
+ WBWIIteratorImpl(uint32_t column_family_id,
+ WriteBatchEntrySkipList* skip_list,
+ const ReadableWriteBatch* write_batch)
+ : column_family_id_(column_family_id),
+ skip_list_iter_(skip_list),
+ write_batch_(write_batch) {}
+
+ ~WBWIIteratorImpl() override {}
+
+ bool Valid() const override {
+ if (!skip_list_iter_.Valid()) {
+ return false;
+ }
+ const WriteBatchIndexEntry* iter_entry = skip_list_iter_.key();
+ return (iter_entry != nullptr &&
+ iter_entry->column_family == column_family_id_);
+ }
+
+ void SeekToFirst() override {
+ WriteBatchIndexEntry search_entry(
+ nullptr /* search_key */, column_family_id_,
+ true /* is_forward_direction */, true /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ }
+
+ void SeekToLast() override {
+ WriteBatchIndexEntry search_entry(
+ nullptr /* search_key */, column_family_id_ + 1,
+ true /* is_forward_direction */, true /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ if (!skip_list_iter_.Valid()) {
+ skip_list_iter_.SeekToLast();
+ } else {
+ skip_list_iter_.Prev();
+ }
+ }
+
+ void Seek(const Slice& key) override {
+ WriteBatchIndexEntry search_entry(&key, column_family_id_,
+ true /* is_forward_direction */,
+ false /* is_seek_to_first */);
+ skip_list_iter_.Seek(&search_entry);
+ }
+
+ void SeekForPrev(const Slice& key) override {
+ WriteBatchIndexEntry search_entry(&key, column_family_id_,
+ false /* is_forward_direction */,
+ false /* is_seek_to_first */);
+ skip_list_iter_.SeekForPrev(&search_entry);
+ }
+
+ void Next() override { skip_list_iter_.Next(); }
+
+ void Prev() override { skip_list_iter_.Prev(); }
+
+ WriteEntry Entry() const override {
+ WriteEntry ret;
+ Slice blob, xid;
+ const WriteBatchIndexEntry* iter_entry = skip_list_iter_.key();
+ // this is guaranteed by Valid()
+ assert(iter_entry != nullptr &&
+ iter_entry->column_family == column_family_id_);
+ auto s = write_batch_->GetEntryFromDataOffset(
+ iter_entry->offset, &ret.type, &ret.key, &ret.value, &blob, &xid);
+ assert(s.ok());
+ assert(ret.type == kPutRecord || ret.type == kDeleteRecord ||
+ ret.type == kSingleDeleteRecord || ret.type == kDeleteRangeRecord ||
+ ret.type == kMergeRecord);
+ return ret;
+ }
+
+ Status status() const override {
+ // This is an in-memory data structure, so the only way status can be
+ // non-ok is through memory corruption
+ return Status::OK();
+ }
+
+ const WriteBatchIndexEntry* GetRawEntry() const {
+ return skip_list_iter_.key();
+ }
+
+ private:
+ uint32_t column_family_id_;
+ WriteBatchEntrySkipList::Iterator skip_list_iter_;
+ const ReadableWriteBatch* write_batch_;
+};
+
+struct WriteBatchWithIndex::Rep {
+ explicit Rep(const Comparator* index_comparator, size_t reserved_bytes = 0,
+ size_t max_bytes = 0, bool _overwrite_key = false)
+ : write_batch(reserved_bytes, max_bytes),
+ comparator(index_comparator, &write_batch),
+ skip_list(comparator, &arena),
+ overwrite_key(_overwrite_key),
+ last_entry_offset(0),
+ last_sub_batch_offset(0),
+ sub_batch_cnt(1) {}
+ ReadableWriteBatch write_batch;
+ WriteBatchEntryComparator comparator;
+ Arena arena;
+ WriteBatchEntrySkipList skip_list;
+ bool overwrite_key;
+ size_t last_entry_offset;
+ // The starting offset of the last sub-batch. A sub-batch starts right before
+ // inserting a key that is a duplicate of a key in the last sub-batch. Zero,
+ // the default, means that no duplicate key is detected so far.
+ size_t last_sub_batch_offset;
+ // Total number of sub-batches in the write batch. Default is 1.
+ size_t sub_batch_cnt;
+
+ // Remember current offset of internal write batch, which is used as
+ // the starting offset of the next record.
+ void SetLastEntryOffset() { last_entry_offset = write_batch.GetDataSize(); }
+
+ // In overwrite mode, find the existing entry for the same key and update it
+ // to point to the current entry.
+ // Return true if the key is found and updated.
+ bool UpdateExistingEntry(ColumnFamilyHandle* column_family, const Slice& key);
+ bool UpdateExistingEntryWithCfId(uint32_t column_family_id, const Slice& key);
+
+ // Add the recent entry to the update.
+ // In overwrite mode, if key already exists in the index, update it.
+ void AddOrUpdateIndex(ColumnFamilyHandle* column_family, const Slice& key);
+ void AddOrUpdateIndex(const Slice& key);
+
+ // Allocate an index entry pointing to the last entry in the write batch and
+ // put it to skip list.
+ void AddNewEntry(uint32_t column_family_id);
+
+ // Clear all updates buffered in this batch.
+ void Clear();
+ void ClearIndex();
+
+ // Rebuild index by reading all records from the batch.
+ // Returns non-ok status on corruption.
+ Status ReBuildIndex();
+};
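+
+// Illustrative example of sub-batch counting (assumes overwrite_key == true):
+// after Put(k1), Put(k2), Put(k1), the second Put(k1) finds k1 already indexed
+// in the current sub-batch, so last_sub_batch_offset moves forward and
+// sub_batch_cnt becomes 2; yet another Put(k1) would start a third sub-batch.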
+
+bool WriteBatchWithIndex::Rep::UpdateExistingEntry(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ uint32_t cf_id = GetColumnFamilyID(column_family);
+ return UpdateExistingEntryWithCfId(cf_id, key);
+}
+
+bool WriteBatchWithIndex::Rep::UpdateExistingEntryWithCfId(
+ uint32_t column_family_id, const Slice& key) {
+ if (!overwrite_key) {
+ return false;
+ }
+
+ WBWIIteratorImpl iter(column_family_id, &skip_list, &write_batch);
+ iter.Seek(key);
+ if (!iter.Valid()) {
+ return false;
+ }
+ if (comparator.CompareKey(column_family_id, key, iter.Entry().key) != 0) {
+ return false;
+ }
+ WriteBatchIndexEntry* non_const_entry =
+ const_cast<WriteBatchIndexEntry*>(iter.GetRawEntry());
+ if (LIKELY(last_sub_batch_offset <= non_const_entry->offset)) {
+ last_sub_batch_offset = last_entry_offset;
+ sub_batch_cnt++;
+ }
+ non_const_entry->offset = last_entry_offset;
+ return true;
+}
+
+void WriteBatchWithIndex::Rep::AddOrUpdateIndex(
+ ColumnFamilyHandle* column_family, const Slice& key) {
+ if (!UpdateExistingEntry(column_family, key)) {
+ uint32_t cf_id = GetColumnFamilyID(column_family);
+ const auto* cf_cmp = GetColumnFamilyUserComparator(column_family);
+ if (cf_cmp != nullptr) {
+ comparator.SetComparatorForCF(cf_id, cf_cmp);
+ }
+ AddNewEntry(cf_id);
+ }
+}
+
+void WriteBatchWithIndex::Rep::AddOrUpdateIndex(const Slice& key) {
+ if (!UpdateExistingEntryWithCfId(0, key)) {
+ AddNewEntry(0);
+ }
+}
+
+void WriteBatchWithIndex::Rep::AddNewEntry(uint32_t column_family_id) {
+ const std::string& wb_data = write_batch.Data();
+ Slice entry_ptr = Slice(wb_data.data() + last_entry_offset,
+ wb_data.size() - last_entry_offset);
+ // Extract key
+ Slice key;
+ bool success __attribute__((__unused__));
+ success =
+ ReadKeyFromWriteBatchEntry(&entry_ptr, &key, column_family_id != 0);
+ assert(success);
+
+ auto* mem = arena.Allocate(sizeof(WriteBatchIndexEntry));
+ auto* index_entry =
+ new (mem) WriteBatchIndexEntry(last_entry_offset, column_family_id,
+ key.data() - wb_data.data(), key.size());
+ skip_list.Insert(index_entry);
+}
+
+void WriteBatchWithIndex::Rep::Clear() {
+ write_batch.Clear();
+ ClearIndex();
+}
+
+void WriteBatchWithIndex::Rep::ClearIndex() {
+ skip_list.~WriteBatchEntrySkipList();
+ arena.~Arena();
+ new (&arena) Arena();
+ new (&skip_list) WriteBatchEntrySkipList(comparator, &arena);
+ last_entry_offset = 0;
+ last_sub_batch_offset = 0;
+ sub_batch_cnt = 1;
+}
+
+Status WriteBatchWithIndex::Rep::ReBuildIndex() {
+ Status s;
+
+ ClearIndex();
+
+ if (write_batch.Count() == 0) {
+ // Nothing to re-index
+ return s;
+ }
+
+ size_t offset = WriteBatchInternal::GetFirstOffset(&write_batch);
+
+ Slice input(write_batch.Data());
+ input.remove_prefix(offset);
+
+ // Loop through all entries in Rep and add each one to the index
+ uint32_t found = 0;
+ while (s.ok() && !input.empty()) {
+ Slice key, value, blob, xid;
+ uint32_t column_family_id = 0; // default
+ char tag = 0;
+
+ // set offset of current entry for call to AddNewEntry()
+ last_entry_offset = input.data() - write_batch.Data().data();
+
+ s = ReadRecordFromWriteBatch(&input, &tag, &column_family_id, &key,
+ &value, &blob, &xid);
+ if (!s.ok()) {
+ break;
+ }
+
+ switch (tag) {
+ case kTypeColumnFamilyValue:
+ case kTypeValue:
+ case kTypeColumnFamilyDeletion:
+ case kTypeDeletion:
+ case kTypeColumnFamilySingleDeletion:
+ case kTypeSingleDeletion:
+ case kTypeColumnFamilyMerge:
+ case kTypeMerge:
+ found++;
+ if (!UpdateExistingEntryWithCfId(column_family_id, key)) {
+ AddNewEntry(column_family_id);
+ }
+ break;
+ case kTypeLogData:
+ case kTypeBeginPrepareXID:
+ case kTypeBeginPersistedPrepareXID:
+ case kTypeBeginUnprepareXID:
+ case kTypeEndPrepareXID:
+ case kTypeCommitXID:
+ case kTypeRollbackXID:
+ case kTypeNoop:
+ break;
+ default:
+ return Status::Corruption("unknown WriteBatch tag in ReBuildIndex",
+ ToString(static_cast<unsigned int>(tag)));
+ }
+ }
+
+ if (s.ok() && found != write_batch.Count()) {
+ s = Status::Corruption("WriteBatch has wrong count");
+ }
+
+ return s;
+}
+
+WriteBatchWithIndex::WriteBatchWithIndex(
+ const Comparator* default_index_comparator, size_t reserved_bytes,
+ bool overwrite_key, size_t max_bytes)
+ : rep(new Rep(default_index_comparator, reserved_bytes, max_bytes,
+ overwrite_key)) {}
+
+WriteBatchWithIndex::~WriteBatchWithIndex() {}
+
+WriteBatchWithIndex::WriteBatchWithIndex(WriteBatchWithIndex&&) = default;
+
+WriteBatchWithIndex& WriteBatchWithIndex::operator=(WriteBatchWithIndex&&) =
+ default;
+
+WriteBatch* WriteBatchWithIndex::GetWriteBatch() { return &rep->write_batch; }
+
+size_t WriteBatchWithIndex::SubBatchCnt() { return rep->sub_batch_cnt; }
+
+WBWIIterator* WriteBatchWithIndex::NewIterator() {
+ return new WBWIIteratorImpl(0, &(rep->skip_list), &rep->write_batch);
+}
+
+WBWIIterator* WriteBatchWithIndex::NewIterator(
+ ColumnFamilyHandle* column_family) {
+ return new WBWIIteratorImpl(GetColumnFamilyID(column_family),
+ &(rep->skip_list), &rep->write_batch);
+}
+
+Iterator* WriteBatchWithIndex::NewIteratorWithBase(
+ ColumnFamilyHandle* column_family, Iterator* base_iterator,
+ const ReadOptions* read_options) {
+ if (rep->overwrite_key == false) {
+ assert(false);
+ return nullptr;
+ }
+ return new BaseDeltaIterator(base_iterator, NewIterator(column_family),
+ GetColumnFamilyUserComparator(column_family),
+ read_options);
+}
+
+Iterator* WriteBatchWithIndex::NewIteratorWithBase(Iterator* base_iterator) {
+ if (rep->overwrite_key == false) {
+ assert(false);
+ return nullptr;
+ }
+ // default column family's comparator
+ return new BaseDeltaIterator(base_iterator, NewIterator(),
+ rep->comparator.default_comparator());
+}
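+
+// Illustrative usage (sketch only; requires overwrite_key == true, as the
+// asserts above enforce, and assumes db points to an open DB). The returned
+// iterator takes ownership of base_iterator:
+//
+//   WriteBatchWithIndex wbwi(BytewiseComparator(), 0, /*overwrite_key=*/true);
+//   wbwi.Put("k", "v");
+//   std::unique_ptr<Iterator> it(
+//       wbwi.NewIteratorWithBase(db->NewIterator(ReadOptions())));
+//   it->Seek("k");  // sees "v" even though the batch was never written to db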
+
+Status WriteBatchWithIndex::Put(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Put(column_family, key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Put(const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Put(key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Delete(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Delete(column_family, key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Delete(const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Delete(key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::SingleDelete(ColumnFamilyHandle* column_family,
+ const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.SingleDelete(column_family, key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::SingleDelete(const Slice& key) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.SingleDelete(key);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Merge(ColumnFamilyHandle* column_family,
+ const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Merge(column_family, key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(column_family, key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::Merge(const Slice& key, const Slice& value) {
+ rep->SetLastEntryOffset();
+ auto s = rep->write_batch.Merge(key, value);
+ if (s.ok()) {
+ rep->AddOrUpdateIndex(key);
+ }
+ return s;
+}
+
+Status WriteBatchWithIndex::PutLogData(const Slice& blob) {
+ return rep->write_batch.PutLogData(blob);
+}
+
+void WriteBatchWithIndex::Clear() { rep->Clear(); }
+
+Status WriteBatchWithIndex::GetFromBatch(ColumnFamilyHandle* column_family,
+ const DBOptions& options,
+ const Slice& key, std::string* value) {
+ Status s;
+ MergeContext merge_context;
+ const ImmutableDBOptions immuable_db_options(options);
+
+ WriteBatchWithIndexInternal::Result result =
+ WriteBatchWithIndexInternal::GetFromBatch(
+ immuable_db_options, this, column_family, key, &merge_context,
+ &rep->comparator, value, rep->overwrite_key, &s);
+
+ switch (result) {
+ case WriteBatchWithIndexInternal::Result::kFound:
+ case WriteBatchWithIndexInternal::Result::kError:
+ // use returned status
+ break;
+ case WriteBatchWithIndexInternal::Result::kDeleted:
+ case WriteBatchWithIndexInternal::Result::kNotFound:
+ s = Status::NotFound();
+ break;
+ case WriteBatchWithIndexInternal::Result::kMergeInProgress:
+ s = Status::MergeInProgress();
+ break;
+ default:
+ assert(false);
+ }
+
+ return s;
+}
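
For context, this is the mapping a caller observes: kFound and kError pass the internal status through, kDeleted and kNotFound surface as Status::NotFound(), and kMergeInProgress surfaces as Status::MergeInProgress(). A small self-contained sketch of batch-only reads (no DB instance is required; the key names are illustrative):

    #include <cassert>
    #include <string>

    #include "rocksdb/options.h"
    #include "rocksdb/utilities/write_batch_with_index.h"

    int main() {
      rocksdb::WriteBatchWithIndex batch;
      rocksdb::DBOptions options;  // no merge operator configured
      std::string value;

      batch.Put("k", "v");
      assert(batch.GetFromBatch(options, "k", &value).ok() && value == "v");

      batch.Delete("k");  // kDeleted -> Status::NotFound()
      assert(batch.GetFromBatch(options, "k", &value).IsNotFound());

      // A lone Merge cannot be resolved from the batch alone:
      // kMergeInProgress -> Status::MergeInProgress().
      batch.Merge("m", "operand");
      assert(batch.GetFromBatch(options, "m", &value).IsMergeInProgress());
      return 0;
    }
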
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ const Slice& key,
+ std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s = GetFromBatchAndDB(db, read_options, db->DefaultColumnFamily(), key,
+ &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ const Slice& key,
+ PinnableSlice* pinnable_val) {
+ return GetFromBatchAndDB(db, read_options, db->DefaultColumnFamily(), key,
+ pinnable_val);
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ std::string* value) {
+ assert(value != nullptr);
+ PinnableSlice pinnable_val(value);
+ assert(!pinnable_val.IsPinned());
+ auto s =
+ GetFromBatchAndDB(db, read_options, column_family, key, &pinnable_val);
+ if (s.ok() && pinnable_val.IsPinned()) {
+ value->assign(pinnable_val.data(), pinnable_val.size());
+ } // else value is already assigned
+ return s;
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
+ const ReadOptions& read_options,
+ ColumnFamilyHandle* column_family,
+ const Slice& key,
+ PinnableSlice* pinnable_val) {
+ return GetFromBatchAndDB(db, read_options, column_family, key, pinnable_val,
+ nullptr);
+}
+
+Status WriteBatchWithIndex::GetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const Slice& key, PinnableSlice* pinnable_val, ReadCallback* callback) {
+ Status s;
+ MergeContext merge_context;
+ const ImmutableDBOptions& immuable_db_options =
+ static_cast_with_check<DBImpl, DB>(db->GetRootDB())
+ ->immutable_db_options();
+
+  // Since the lifetime of the WriteBatch is the same as that of the
+  // transaction, we cannot pin it; otherwise the returned value will not be
+  // available after the transaction finishes.
+ std::string& batch_value = *pinnable_val->GetSelf();
+ WriteBatchWithIndexInternal::Result result =
+ WriteBatchWithIndexInternal::GetFromBatch(
+ immuable_db_options, this, column_family, key, &merge_context,
+ &rep->comparator, &batch_value, rep->overwrite_key, &s);
+
+ if (result == WriteBatchWithIndexInternal::Result::kFound) {
+ pinnable_val->PinSelf();
+ return s;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kDeleted) {
+ return Status::NotFound();
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kError) {
+ return s;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress &&
+ rep->overwrite_key == true) {
+ // Since we've overwritten keys, we do not know what other operations are
+ // in this batch for this key, so we cannot do a Merge to compute the
+ // result. Instead, we will simply return MergeInProgress.
+ return Status::MergeInProgress();
+ }
+
+ assert(result == WriteBatchWithIndexInternal::Result::kMergeInProgress ||
+ result == WriteBatchWithIndexInternal::Result::kNotFound);
+
+ // Did not find key in batch OR could not resolve Merges. Try DB.
+ if (!callback) {
+ s = db->Get(read_options, column_family, key, pinnable_val);
+ } else {
+ DBImpl::GetImplOptions get_impl_options;
+ get_impl_options.column_family = column_family;
+ get_impl_options.value = pinnable_val;
+ get_impl_options.callback = callback;
+ s = static_cast_with_check<DBImpl, DB>(db->GetRootDB())
+ ->GetImpl(read_options, key, get_impl_options);
+ }
+
+ if (s.ok() || s.IsNotFound()) { // DB Get Succeeded
+ if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress) {
+ // Merge result from DB with merges in Batch
+ auto cfh = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family);
+ const MergeOperator* merge_operator =
+ cfh->cfd()->ioptions()->merge_operator;
+ Statistics* statistics = immuable_db_options.statistics.get();
+ Env* env = immuable_db_options.env;
+ Logger* logger = immuable_db_options.info_log.get();
+
+ Slice* merge_data;
+ if (s.ok()) {
+ merge_data = pinnable_val;
+ } else { // Key not present in db (s.IsNotFound())
+ merge_data = nullptr;
+ }
+
+ if (merge_operator) {
+ std::string merge_result;
+ s = MergeHelper::TimedFullMerge(merge_operator, key, merge_data,
+ merge_context.GetOperands(),
+ &merge_result, logger, statistics, env);
+ pinnable_val->Reset();
+ *pinnable_val->GetSelf() = std::move(merge_result);
+ pinnable_val->PinSelf();
+ } else {
+ s = Status::InvalidArgument("Options::merge_operator must be set");
+ }
+ }
+ }
+
+ return s;
+}
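
To make the merge path above concrete: a Merge recorded only in the batch is combined with the base value read from the DB using the column family's merge operator. A sketch under assumptions (placeholder path; the string-append operator is created via the in-tree MergeOperators helper, the same one the tests later in this diff use):

    #include <cassert>
    #include <string>

    #include "rocksdb/db.h"
    #include "rocksdb/utilities/write_batch_with_index.h"
    #include "utilities/merge_operators.h"  // in-tree helper, as used by the tests

    int main() {
      rocksdb::DB* db = nullptr;
      rocksdb::Options options;
      options.create_if_missing = true;
      options.merge_operator =
          rocksdb::MergeOperators::CreateFromStringId("stringappend");
      assert(rocksdb::DB::Open(options, "/tmp/wbwi_merge_demo", &db).ok());
      assert(db->Put(rocksdb::WriteOptions(), "k", "base").ok());

      rocksdb::WriteBatchWithIndex batch;
      batch.Merge("k", "delta");  // only a merge operand in the batch

      std::string value;
      // The base value comes from the DB, the operand from the batch; the two
      // are combined with options.merge_operator.
      assert(
          batch.GetFromBatchAndDB(db, rocksdb::ReadOptions(), "k", &value).ok());
      assert(value == "base,delta");
      delete db;
      return 0;
    }
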
+
+void WriteBatchWithIndex::MultiGetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, bool sorted_input) {
+ MultiGetFromBatchAndDB(db, read_options, column_family, num_keys, keys,
+ values, statuses, sorted_input, nullptr);
+}
+
+void WriteBatchWithIndex::MultiGetFromBatchAndDB(
+ DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
+ const size_t num_keys, const Slice* keys, PinnableSlice* values,
+ Status* statuses, bool sorted_input, ReadCallback* callback) {
+ const ImmutableDBOptions& immuable_db_options =
+ static_cast_with_check<DBImpl, DB>(db->GetRootDB())
+ ->immutable_db_options();
+
+ autovector<KeyContext, MultiGetContext::MAX_BATCH_SIZE> key_context;
+ autovector<KeyContext*, MultiGetContext::MAX_BATCH_SIZE> sorted_keys;
+ // To hold merges from the write batch
+ autovector<std::pair<WriteBatchWithIndexInternal::Result, MergeContext>,
+ MultiGetContext::MAX_BATCH_SIZE>
+ merges;
+  // Since the lifetime of the WriteBatch is the same as that of the
+  // transaction, we cannot pin it; otherwise the returned value will not be
+  // available after the transaction finishes.
+ for (size_t i = 0; i < num_keys; ++i) {
+ MergeContext merge_context;
+ PinnableSlice* pinnable_val = &values[i];
+ std::string& batch_value = *pinnable_val->GetSelf();
+ Status* s = &statuses[i];
+ WriteBatchWithIndexInternal::Result result =
+ WriteBatchWithIndexInternal::GetFromBatch(
+ immuable_db_options, this, column_family, keys[i], &merge_context,
+ &rep->comparator, &batch_value, rep->overwrite_key, s);
+
+ if (result == WriteBatchWithIndexInternal::Result::kFound) {
+ pinnable_val->PinSelf();
+ continue;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kDeleted) {
+ *s = Status::NotFound();
+ continue;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kError) {
+ continue;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress &&
+ rep->overwrite_key == true) {
+ // Since we've overwritten keys, we do not know what other operations are
+ // in this batch for this key, so we cannot do a Merge to compute the
+ // result. Instead, we will simply return MergeInProgress.
+ *s = Status::MergeInProgress();
+ continue;
+ }
+
+ assert(result == WriteBatchWithIndexInternal::Result::kMergeInProgress ||
+ result == WriteBatchWithIndexInternal::Result::kNotFound);
+ key_context.emplace_back(column_family, keys[i], &values[i], &statuses[i]);
+ merges.emplace_back(result, std::move(merge_context));
+ }
+
+ for (KeyContext& key : key_context) {
+ sorted_keys.emplace_back(&key);
+ }
+
+ // Did not find key in batch OR could not resolve Merges. Try DB.
+ static_cast_with_check<DBImpl, DB>(db->GetRootDB())
+ ->PrepareMultiGetKeys(key_context.size(), sorted_input, &sorted_keys);
+ static_cast_with_check<DBImpl, DB>(db->GetRootDB())
+ ->MultiGetWithCallback(read_options, column_family, callback,
+ &sorted_keys);
+
+ ColumnFamilyHandleImpl* cfh =
+ reinterpret_cast<ColumnFamilyHandleImpl*>(column_family);
+ const MergeOperator* merge_operator = cfh->cfd()->ioptions()->merge_operator;
+ for (auto iter = key_context.begin(); iter != key_context.end(); ++iter) {
+ KeyContext& key = *iter;
+ if (key.s->ok() || key.s->IsNotFound()) { // DB Get Succeeded
+ size_t index = iter - key_context.begin();
+ std::pair<WriteBatchWithIndexInternal::Result, MergeContext>&
+ merge_result = merges[index];
+ if (merge_result.first ==
+ WriteBatchWithIndexInternal::Result::kMergeInProgress) {
+ // Merge result from DB with merges in Batch
+ Statistics* statistics = immuable_db_options.statistics.get();
+ Env* env = immuable_db_options.env;
+ Logger* logger = immuable_db_options.info_log.get();
+
+ Slice* merge_data;
+ if (key.s->ok()) {
+ merge_data = iter->value;
+ } else { // Key not present in db (s.IsNotFound())
+ merge_data = nullptr;
+ }
+
+ if (merge_operator) {
+ *key.s = MergeHelper::TimedFullMerge(
+ merge_operator, *key.key, merge_data,
+ merge_result.second.GetOperands(), key.value->GetSelf(), logger,
+ statistics, env);
+ key.value->PinSelf();
+ } else {
+ *key.s =
+ Status::InvalidArgument("Options::merge_operator must be set");
+ }
+ }
+ }
+ }
+}
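
MultiGetFromBatchAndDB() is the batched counterpart of the function above: keys fully answered by the write batch are filled in directly, the rest are forwarded to the DB in a single MultiGet, and pending merges are resolved afterwards. A minimal call-site sketch (assumes an already-open db and a populated batch; the names are placeholders):

    #include <cstddef>

    #include "rocksdb/db.h"
    #include "rocksdb/utilities/write_batch_with_index.h"

    // Reads three keys through the combined batch-plus-DB view in one call.
    void MultiGetDemo(rocksdb::DB* db, rocksdb::WriteBatchWithIndex* batch) {
      constexpr size_t kNumKeys = 3;
      rocksdb::Slice keys[kNumKeys] = {"a", "b", "c"};
      rocksdb::PinnableSlice values[kNumKeys];
      rocksdb::Status statuses[kNumKeys];

      batch->MultiGetFromBatchAndDB(db, rocksdb::ReadOptions(),
                                    db->DefaultColumnFamily(), kNumKeys, keys,
                                    values, statuses, /*sorted_input=*/false);
      for (size_t i = 0; i < kNumKeys; ++i) {
        if (statuses[i].ok()) {
          // values[i] is the merged batch-plus-DB result for keys[i].
        } else if (statuses[i].IsNotFound()) {
          // Deleted in the batch, or absent from both the batch and the DB.
        }
      }
    }
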
+
+void WriteBatchWithIndex::SetSavePoint() { rep->write_batch.SetSavePoint(); }
+
+Status WriteBatchWithIndex::RollbackToSavePoint() {
+ Status s = rep->write_batch.RollbackToSavePoint();
+
+ if (s.ok()) {
+ rep->sub_batch_cnt = 1;
+ rep->last_sub_batch_offset = 0;
+ s = rep->ReBuildIndex();
+ }
+
+ return s;
+}
+
+Status WriteBatchWithIndex::PopSavePoint() {
+ return rep->write_batch.PopSavePoint();
+}
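
Note that RollbackToSavePoint() above truncates the underlying WriteBatch and then rebuilds the whole skip-list index, so the cost of a rollback is proportional to the remaining batch size rather than to the number of discarded entries. The observable behaviour, as a small batch-only sketch:

    #include <cassert>
    #include <string>

    #include "rocksdb/options.h"
    #include "rocksdb/utilities/write_batch_with_index.h"

    int main() {
      rocksdb::WriteBatchWithIndex batch;
      rocksdb::DBOptions options;
      std::string value;

      batch.Put("a", "1");
      batch.SetSavePoint();
      batch.Put("b", "2");

      assert(batch.RollbackToSavePoint().ok());
      // "b" was added after the savepoint, so the rebuilt index no longer
      // sees it, while "a" is still visible.
      assert(batch.GetFromBatch(options, "b", &value).IsNotFound());
      assert(batch.GetFromBatch(options, "a", &value).ok() && value == "1");
      return 0;
    }
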
+
+void WriteBatchWithIndex::SetMaxBytes(size_t max_bytes) {
+ rep->write_batch.SetMaxBytes(max_bytes);
+}
+
+size_t WriteBatchWithIndex::GetDataSize() const {
+ return rep->write_batch.GetDataSize();
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc
new file mode 100644
index 000000000..8c1222f21
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.cc
@@ -0,0 +1,288 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "utilities/write_batch_with_index/write_batch_with_index_internal.h"
+
+#include "db/column_family.h"
+#include "db/merge_context.h"
+#include "db/merge_helper.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/db.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include "util/coding.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class Env;
+class Logger;
+class Statistics;
+
+Status ReadableWriteBatch::GetEntryFromDataOffset(size_t data_offset,
+ WriteType* type, Slice* Key,
+ Slice* value, Slice* blob,
+ Slice* xid) const {
+ if (type == nullptr || Key == nullptr || value == nullptr ||
+ blob == nullptr || xid == nullptr) {
+ return Status::InvalidArgument("Output parameters cannot be null");
+ }
+
+ if (data_offset == GetDataSize()) {
+ // reached end of batch.
+ return Status::NotFound();
+ }
+
+ if (data_offset > GetDataSize()) {
+    return Status::InvalidArgument("data offset exceeds write batch size");
+ }
+ Slice input = Slice(rep_.data() + data_offset, rep_.size() - data_offset);
+ char tag;
+ uint32_t column_family;
+ Status s = ReadRecordFromWriteBatch(&input, &tag, &column_family, Key, value,
+ blob, xid);
+
+ switch (tag) {
+ case kTypeColumnFamilyValue:
+ case kTypeValue:
+ *type = kPutRecord;
+ break;
+ case kTypeColumnFamilyDeletion:
+ case kTypeDeletion:
+ *type = kDeleteRecord;
+ break;
+ case kTypeColumnFamilySingleDeletion:
+ case kTypeSingleDeletion:
+ *type = kSingleDeleteRecord;
+ break;
+ case kTypeColumnFamilyRangeDeletion:
+ case kTypeRangeDeletion:
+ *type = kDeleteRangeRecord;
+ break;
+ case kTypeColumnFamilyMerge:
+ case kTypeMerge:
+ *type = kMergeRecord;
+ break;
+ case kTypeLogData:
+ *type = kLogDataRecord;
+ break;
+ case kTypeNoop:
+ case kTypeBeginPrepareXID:
+ case kTypeBeginPersistedPrepareXID:
+ case kTypeBeginUnprepareXID:
+ case kTypeEndPrepareXID:
+ case kTypeCommitXID:
+ case kTypeRollbackXID:
+ *type = kXIDRecord;
+ break;
+ default:
+ return Status::Corruption("unknown WriteBatch tag ",
+ ToString(static_cast<unsigned int>(tag)));
+ }
+ return Status::OK();
+}
+
+// If both `entry1` and `entry2` point to real entries in the write batch, we
+// compare them as follows:
+// 1. First compare the column family; the entry with the larger CF ID is
+//    larger.
+// 2. Within the same CF, decode each entry to find its key; the entry with
+//    the larger key is larger.
+// 3. If two entries are in the same CF and have the same key, the one with
+//    the larger offset is larger.
+// Sometimes either `entry1` or `entry2` is a dummy entry, which is actually a
+// search key. In that case, in step 2, we do not decode the entry but use the
+// value in WriteBatchIndexEntry::search_key instead.
+// One special case is when WriteBatchIndexEntry::key_size is kFlagMinInCf.
+// This indicates a seek to the beginning of the column family; such an entry
+// compares smaller than all real entries of that column family.
+int WriteBatchEntryComparator::operator()(
+ const WriteBatchIndexEntry* entry1,
+ const WriteBatchIndexEntry* entry2) const {
+ if (entry1->column_family > entry2->column_family) {
+ return 1;
+ } else if (entry1->column_family < entry2->column_family) {
+ return -1;
+ }
+
+ // Deal with special case of seeking to the beginning of a column family
+ if (entry1->is_min_in_cf()) {
+ return -1;
+ } else if (entry2->is_min_in_cf()) {
+ return 1;
+ }
+
+ Slice key1, key2;
+ if (entry1->search_key == nullptr) {
+ key1 = Slice(write_batch_->Data().data() + entry1->key_offset,
+ entry1->key_size);
+ } else {
+ key1 = *(entry1->search_key);
+ }
+ if (entry2->search_key == nullptr) {
+ key2 = Slice(write_batch_->Data().data() + entry2->key_offset,
+ entry2->key_size);
+ } else {
+ key2 = *(entry2->search_key);
+ }
+
+ int cmp = CompareKey(entry1->column_family, key1, key2);
+ if (cmp != 0) {
+ return cmp;
+ } else if (entry1->offset > entry2->offset) {
+ return 1;
+ } else if (entry1->offset < entry2->offset) {
+ return -1;
+ }
+ return 0;
+}
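
A compact way to restate the ordering implemented above without depending on the internal headers: entries are ranked by column family ID, then by user key (per-CF comparator), then by offset within the batch buffer. The following standalone model (plain structs, hypothetical names, bytewise key order only) is purely illustrative, not the real comparator:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <tuple>
    #include <vector>

    // Simplified stand-in for WriteBatchIndexEntry: only the fields that
    // matter for ordering in this sketch.
    struct EntryModel {
      uint32_t column_family;
      std::string key;
      size_t offset;  // position of the record in the batch buffer
    };

    int main() {
      std::vector<EntryModel> entries = {
          {1, "b", 30}, {0, "z", 10}, {1, "a", 50}, {1, "a", 20}};

      // Same precedence as WriteBatchEntryComparator: CF ID, then key, then
      // offset.
      std::sort(entries.begin(), entries.end(),
                [](const EntryModel& a, const EntryModel& b) {
                  return std::tie(a.column_family, a.key, a.offset) <
                         std::tie(b.column_family, b.key, b.offset);
                });

      // CF 0 sorts first; within CF 1, "a" precedes "b", and the two "a"
      // entries keep insertion (offset) order, so backwards iteration finds
      // the newest write first.
      assert(entries.front().column_family == 0);
      assert(entries[1].key == "a" && entries[1].offset == 20);
      return 0;
    }
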
+
+int WriteBatchEntryComparator::CompareKey(uint32_t column_family,
+ const Slice& key1,
+ const Slice& key2) const {
+ if (column_family < cf_comparators_.size() &&
+ cf_comparators_[column_family] != nullptr) {
+ return cf_comparators_[column_family]->Compare(key1, key2);
+ } else {
+ return default_comparator_->Compare(key1, key2);
+ }
+}
+
+WriteBatchWithIndexInternal::Result WriteBatchWithIndexInternal::GetFromBatch(
+ const ImmutableDBOptions& immuable_db_options, WriteBatchWithIndex* batch,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ MergeContext* merge_context, WriteBatchEntryComparator* cmp,
+ std::string* value, bool overwrite_key, Status* s) {
+ uint32_t cf_id = GetColumnFamilyID(column_family);
+ *s = Status::OK();
+ WriteBatchWithIndexInternal::Result result =
+ WriteBatchWithIndexInternal::Result::kNotFound;
+
+ std::unique_ptr<WBWIIterator> iter =
+ std::unique_ptr<WBWIIterator>(batch->NewIterator(column_family));
+
+  // We want to iterate in the reverse of the order in which the writes were
+  // added to the batch. Since we don't have a reverse iterator, we must seek
+  // past the end of the entries for this key.
+ // TODO(agiardullo): consider adding support for reverse iteration
+ iter->Seek(key);
+ while (iter->Valid()) {
+ const WriteEntry entry = iter->Entry();
+ if (cmp->CompareKey(cf_id, entry.key, key) != 0) {
+ break;
+ }
+
+ iter->Next();
+ }
+
+ if (!(*s).ok()) {
+ return WriteBatchWithIndexInternal::Result::kError;
+ }
+
+ if (!iter->Valid()) {
+ // Read past end of results. Reposition on last result.
+ iter->SeekToLast();
+ } else {
+ iter->Prev();
+ }
+
+ Slice entry_value;
+ while (iter->Valid()) {
+ const WriteEntry entry = iter->Entry();
+ if (cmp->CompareKey(cf_id, entry.key, key) != 0) {
+ // Unexpected error or we've reached a different next key
+ break;
+ }
+
+ switch (entry.type) {
+ case kPutRecord: {
+ result = WriteBatchWithIndexInternal::Result::kFound;
+ entry_value = entry.value;
+ break;
+ }
+ case kMergeRecord: {
+ result = WriteBatchWithIndexInternal::Result::kMergeInProgress;
+ merge_context->PushOperand(entry.value);
+ break;
+ }
+ case kDeleteRecord:
+ case kSingleDeleteRecord: {
+ result = WriteBatchWithIndexInternal::Result::kDeleted;
+ break;
+ }
+ case kLogDataRecord:
+ case kXIDRecord: {
+ // ignore
+ break;
+ }
+ default: {
+ result = WriteBatchWithIndexInternal::Result::kError;
+ (*s) = Status::Corruption("Unexpected entry in WriteBatchWithIndex:",
+ ToString(entry.type));
+ break;
+ }
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kFound ||
+ result == WriteBatchWithIndexInternal::Result::kDeleted ||
+ result == WriteBatchWithIndexInternal::Result::kError) {
+ // We can stop iterating once we find a PUT or DELETE
+ break;
+ }
+ if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress &&
+ overwrite_key == true) {
+ // Since we've overwritten keys, we do not know what other operations are
+ // in this batch for this key, so we cannot do a Merge to compute the
+ // result. Instead, we will simply return MergeInProgress.
+ break;
+ }
+
+ iter->Prev();
+ }
+
+ if ((*s).ok()) {
+ if (result == WriteBatchWithIndexInternal::Result::kFound ||
+ result == WriteBatchWithIndexInternal::Result::kDeleted) {
+ // Found a Put or Delete. Merge if necessary.
+ if (merge_context->GetNumOperands() > 0) {
+ const MergeOperator* merge_operator;
+
+ if (column_family != nullptr) {
+ auto cfh = reinterpret_cast<ColumnFamilyHandleImpl*>(column_family);
+ merge_operator = cfh->cfd()->ioptions()->merge_operator;
+ } else {
+ *s = Status::InvalidArgument("Must provide a column_family");
+ result = WriteBatchWithIndexInternal::Result::kError;
+ return result;
+ }
+ Statistics* statistics = immuable_db_options.statistics.get();
+ Env* env = immuable_db_options.env;
+ Logger* logger = immuable_db_options.info_log.get();
+
+ if (merge_operator) {
+ *s = MergeHelper::TimedFullMerge(merge_operator, key, &entry_value,
+ merge_context->GetOperands(), value,
+ logger, statistics, env);
+ } else {
+ *s = Status::InvalidArgument("Options::merge_operator must be set");
+ }
+ if ((*s).ok()) {
+ result = WriteBatchWithIndexInternal::Result::kFound;
+ } else {
+ result = WriteBatchWithIndexInternal::Result::kError;
+ }
+ } else { // nothing to merge
+ if (result == WriteBatchWithIndexInternal::Result::kFound) { // PUT
+ value->assign(entry_value.data(), entry_value.size());
+ }
+ }
+ }
+ }
+
+ return result;
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h
new file mode 100644
index 000000000..6a859e072
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_internal.h
@@ -0,0 +1,145 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#ifndef ROCKSDB_LITE
+
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "options/db_options.h"
+#include "port/port.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/iterator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/status.h"
+#include "rocksdb/utilities/write_batch_with_index.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class MergeContext;
+struct Options;
+
+// Key used by skip list, as the binary searchable index of WriteBatchWithIndex.
+struct WriteBatchIndexEntry {
+ WriteBatchIndexEntry(size_t o, uint32_t c, size_t ko, size_t ksz)
+ : offset(o),
+ column_family(c),
+ key_offset(ko),
+ key_size(ksz),
+ search_key(nullptr) {}
+  // Create a dummy entry as the search key. This index entry is not backed by
+  // an entry from the write batch, but by a pointer to the search key.
+  // Alternatively, a special key_size flag can indicate that we are seeking
+  // to the first entry of the column family.
+  // @_search_key: the search key
+  // @_column_family: column family
+  // @is_forward_direction: true for Seek(); false for SeekForPrev()
+  // @is_seek_to_first: true if we seek to the beginning of the column family;
+  //                    _search_key should be null in this case.
+ WriteBatchIndexEntry(const Slice* _search_key, uint32_t _column_family,
+ bool is_forward_direction, bool is_seek_to_first)
+  // For SeekForPrev(), we need to make the dummy entry larger than any entry
+  // that has the same search key. Otherwise, we'll miss those entries.
+ : offset(is_forward_direction ? 0 : port::kMaxSizet),
+ column_family(_column_family),
+ key_offset(0),
+ key_size(is_seek_to_first ? kFlagMinInCf : 0),
+ search_key(_search_key) {
+ assert(_search_key != nullptr || is_seek_to_first);
+ }
+
+  // If this flag appears in key_size, the entry is a dummy that compares
+  // smaller than any real entry in the same column family.
+ static const size_t kFlagMinInCf = port::kMaxSizet;
+
+ bool is_min_in_cf() const {
+ assert(key_size != kFlagMinInCf ||
+ (key_offset == 0 && search_key == nullptr));
+ return key_size == kFlagMinInCf;
+ }
+
+  // Offset of an entry in the write batch's string buffer. If this is a dummy
+  // lookup key (search_key != nullptr), offset is set to either 0 or the
+  // maximum value, purely for comparison purposes: entries with the same key
+  // are ordered by offset, so offset = 0 makes the seek key compare smaller
+  // than or equal to all entries with that key, and Seek() therefore finds
+  // all entries of the key. Similarly, offset = MAX makes the dummy entry
+  // just larger than all entries with the search key, so SeekForPrev() also
+  // sees all entries of the key.
+ size_t offset;
+  uint32_t column_family;  // column family of the entry.
+ size_t key_offset; // offset of the key in write batch's string buffer.
+  size_t key_size;        // size of the key. kFlagMinInCf indicates that
+                          // this is a dummy lookup entry used to seek to the
+                          // beginning of the column family. We use the flag
+                          // here to save a boolean in the struct.
+
+ const Slice* search_key; // if not null, instead of reading keys from
+ // write batch, use it to compare. This is used
+ // for lookup key.
+};
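
Tying the offset convention to the dummy search keys: a Seek() lookup entry is built with offset 0 so it sorts at or before every real entry with the same user key, while a SeekForPrev() lookup entry uses the maximum offset so it sorts just after them. Reusing the same simplified (cf, key, offset) model as in the earlier sketch, purely as an illustration:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <limits>
    #include <string>
    #include <tuple>

    // Same simplified model as in the comparator sketch: (cf, key, offset).
    using EntryModel = std::tuple<uint32_t, std::string, size_t>;

    int main() {
      const EntryModel real_old{1, "k", 100};  // older write of "k"
      const EntryModel real_new{1, "k", 200};  // newer write of "k"
      const EntryModel seek_key{1, "k", 0};    // dummy entry built for Seek()
      const EntryModel seek_for_prev_key{      // dummy entry for SeekForPrev()
          1, "k", std::numeric_limits<size_t>::max()};

      // Seek() lands at or before the first real "k" entry...
      assert(seek_key < real_old && real_old < real_new);
      // ...while SeekForPrev() lands just after the last one.
      assert(real_new < seek_for_prev_key);
      return 0;
    }
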
+
+class ReadableWriteBatch : public WriteBatch {
+ public:
+ explicit ReadableWriteBatch(size_t reserved_bytes = 0, size_t max_bytes = 0)
+ : WriteBatch(reserved_bytes, max_bytes) {}
+ // Retrieve some information from a write entry in the write batch, given
+ // the start offset of the write entry.
+ Status GetEntryFromDataOffset(size_t data_offset, WriteType* type, Slice* Key,
+ Slice* value, Slice* blob, Slice* xid) const;
+};
+
+class WriteBatchEntryComparator {
+ public:
+ WriteBatchEntryComparator(const Comparator* _default_comparator,
+ const ReadableWriteBatch* write_batch)
+ : default_comparator_(_default_comparator), write_batch_(write_batch) {}
+ // Compare a and b. Return a negative value if a is less than b, 0 if they
+ // are equal, and a positive value if a is greater than b
+ int operator()(const WriteBatchIndexEntry* entry1,
+ const WriteBatchIndexEntry* entry2) const;
+
+ int CompareKey(uint32_t column_family, const Slice& key1,
+ const Slice& key2) const;
+
+ void SetComparatorForCF(uint32_t column_family_id,
+ const Comparator* comparator) {
+ if (column_family_id >= cf_comparators_.size()) {
+ cf_comparators_.resize(column_family_id + 1, nullptr);
+ }
+ cf_comparators_[column_family_id] = comparator;
+ }
+
+ const Comparator* default_comparator() { return default_comparator_; }
+
+ private:
+ const Comparator* default_comparator_;
+ std::vector<const Comparator*> cf_comparators_;
+ const ReadableWriteBatch* write_batch_;
+};
+
+class WriteBatchWithIndexInternal {
+ public:
+ enum Result { kFound, kDeleted, kNotFound, kMergeInProgress, kError };
+
+  // If the batch contains a value for key, store it in *value and return
+  // kFound.
+  // If the batch contains a deletion for key, return kDeleted.
+  // If the batch contains Merge operations as the most recent entry for the
+  // key, and the merge process does not terminate (no value or deletion is
+  // reached), add the merge operands to *merge_context and return
+  // kMergeInProgress.
+  // If the batch does not contain this key, return kNotFound.
+  // Otherwise, return kError with the error Status stored in *s.
+ static WriteBatchWithIndexInternal::Result GetFromBatch(
+ const ImmutableDBOptions& ioptions, WriteBatchWithIndex* batch,
+ ColumnFamilyHandle* column_family, const Slice& key,
+ MergeContext* merge_context, WriteBatchEntryComparator* cmp,
+ std::string* value, bool overwrite_key, Status* s);
+};
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // !ROCKSDB_LITE
diff --git a/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc
new file mode 100644
index 000000000..ac4ab7af4
--- /dev/null
+++ b/src/rocksdb/utilities/write_batch_with_index/write_batch_with_index_test.cc
@@ -0,0 +1,1846 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef ROCKSDB_LITE
+
+#include "rocksdb/utilities/write_batch_with_index.h"
+#include <map>
+#include <memory>
+#include "db/column_family.h"
+#include "port/stack_trace.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+#include "util/string_util.h"
+#include "utilities/merge_operators.h"
+#include "utilities/merge_operators/string_append/stringappend.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+class ColumnFamilyHandleImplDummy : public ColumnFamilyHandleImpl {
+ public:
+ explicit ColumnFamilyHandleImplDummy(int id, const Comparator* comparator)
+ : ColumnFamilyHandleImpl(nullptr, nullptr, nullptr),
+ id_(id),
+ comparator_(comparator) {}
+ uint32_t GetID() const override { return id_; }
+ const Comparator* GetComparator() const override { return comparator_; }
+
+ private:
+ uint32_t id_;
+ const Comparator* comparator_;
+};
+
+struct Entry {
+ std::string key;
+ std::string value;
+ WriteType type;
+};
+
+struct TestHandler : public WriteBatch::Handler {
+ std::map<uint32_t, std::vector<Entry>> seen;
+ Status PutCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = value.ToString();
+ e.type = kPutRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+ Status MergeCF(uint32_t column_family_id, const Slice& key,
+ const Slice& value) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = value.ToString();
+ e.type = kMergeRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+ void LogData(const Slice& /*blob*/) override {}
+ Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
+ Entry e;
+ e.key = key.ToString();
+ e.value = "";
+ e.type = kDeleteRecord;
+ seen[column_family_id].push_back(e);
+ return Status::OK();
+ }
+};
+}  // namespace
+
+class WriteBatchWithIndexTest : public testing::Test {};
+
+void TestValueAsSecondaryIndexHelper(std::vector<Entry> entries,
+ WriteBatchWithIndex* batch) {
+  // In this test, we insert <key, value> into column family `data`, and
+  // <value, key> into column family `index`. Then we iterate both in order
+  // and seek to them by key.
+
+ // Sort entries by key
+ std::map<std::string, std::vector<Entry*>> data_map;
+ // Sort entries by value
+ std::map<std::string, std::vector<Entry*>> index_map;
+ for (auto& e : entries) {
+ data_map[e.key].push_back(&e);
+ index_map[e.value].push_back(&e);
+ }
+
+ ColumnFamilyHandleImplDummy data(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy index(8, BytewiseComparator());
+ for (auto& e : entries) {
+ if (e.type == kPutRecord) {
+ batch->Put(&data, e.key, e.value);
+ batch->Put(&index, e.value, e.key);
+ } else if (e.type == kMergeRecord) {
+ batch->Merge(&data, e.key, e.value);
+ batch->Put(&index, e.value, e.key);
+ } else {
+ assert(e.type == kDeleteRecord);
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+ iter->Seek(e.key);
+ ASSERT_OK(iter->status());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(e.key, write_entry.key.ToString());
+ ASSERT_EQ(e.value, write_entry.value.ToString());
+ batch->Delete(&data, e.key);
+ batch->Put(&index, e.value, "");
+ }
+ }
+
+  // Iterate over all keys
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+ for (int seek_to_first : {0, 1}) {
+ if (seek_to_first) {
+ iter->SeekToFirst();
+ } else {
+ iter->Seek("");
+ }
+ for (auto pair : data_map) {
+ for (auto v : pair.second) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair.first, write_entry.key.ToString());
+ ASSERT_EQ(v->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(v->value, write_entry.value.ToString());
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+ iter->SeekToLast();
+ for (auto pair = data_map.rbegin(); pair != data_map.rend(); ++pair) {
+ for (auto v = pair->second.rbegin(); v != pair->second.rend(); v++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ((*v)->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ((*v)->value, write_entry.value.ToString());
+ }
+ iter->Prev();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+  // Iterate over all index entries
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&index));
+ for (int seek_to_first : {0, 1}) {
+ if (seek_to_first) {
+ iter->SeekToFirst();
+ } else {
+ iter->Seek("");
+ }
+ for (auto pair : index_map) {
+ for (auto v : pair.second) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair.first, write_entry.key.ToString());
+ if (v->type != kDeleteRecord) {
+ ASSERT_EQ(v->key, write_entry.value.ToString());
+ ASSERT_EQ(v->value, write_entry.key.ToString());
+ }
+ iter->Next();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ iter->SeekToLast();
+ for (auto pair = index_map.rbegin(); pair != index_map.rend(); ++pair) {
+ for (auto v = pair->second.rbegin(); v != pair->second.rend(); v++) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ if ((*v)->type != kDeleteRecord) {
+ ASSERT_EQ((*v)->key, write_entry.value.ToString());
+ ASSERT_EQ((*v)->value, write_entry.key.ToString());
+ }
+ iter->Prev();
+ }
+ }
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ // Seek to every key
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&data));
+
+ // Seek the keys one by one in reverse order
+ for (auto pair = data_map.rbegin(); pair != data_map.rend(); ++pair) {
+ iter->Seek(pair->first);
+ ASSERT_OK(iter->status());
+ for (auto v : pair->second) {
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ(v->type, write_entry.type);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(v->value, write_entry.value.ToString());
+ }
+ iter->Next();
+ ASSERT_OK(iter->status());
+ }
+ }
+ }
+
+ // Seek to every index
+ {
+ std::unique_ptr<WBWIIterator> iter(batch->NewIterator(&index));
+
+ // Seek the keys one by one in reverse order
+ for (auto pair = index_map.rbegin(); pair != index_map.rend(); ++pair) {
+ iter->Seek(pair->first);
+ ASSERT_OK(iter->status());
+ for (auto v : pair->second) {
+ ASSERT_TRUE(iter->Valid());
+ auto write_entry = iter->Entry();
+ ASSERT_EQ(pair->first, write_entry.key.ToString());
+ ASSERT_EQ(v->value, write_entry.key.ToString());
+ if (v->type != kDeleteRecord) {
+ ASSERT_EQ(v->key, write_entry.value.ToString());
+ }
+ iter->Next();
+ ASSERT_OK(iter->status());
+ }
+ }
+ }
+
+ // Verify WriteBatch can be iterated
+ TestHandler handler;
+ batch->GetWriteBatch()->Iterate(&handler);
+
+ // Verify data column family
+ {
+ ASSERT_EQ(entries.size(), handler.seen[data.GetID()].size());
+ size_t i = 0;
+ for (auto e : handler.seen[data.GetID()]) {
+ auto write_entry = entries[i++];
+ ASSERT_EQ(e.type, write_entry.type);
+ ASSERT_EQ(e.key, write_entry.key);
+ if (e.type != kDeleteRecord) {
+ ASSERT_EQ(e.value, write_entry.value);
+ }
+ }
+ }
+
+ // Verify index column family
+ {
+ ASSERT_EQ(entries.size(), handler.seen[index.GetID()].size());
+ size_t i = 0;
+ for (auto e : handler.seen[index.GetID()]) {
+ auto write_entry = entries[i++];
+ ASSERT_EQ(e.key, write_entry.value);
+ if (write_entry.type != kDeleteRecord) {
+ ASSERT_EQ(e.value, write_entry.key);
+ }
+ }
+ }
+}
+
+TEST_F(WriteBatchWithIndexTest, TestValueAsSecondaryIndex) {
+ Entry entries[] = {
+ {"aaa", "0005", kPutRecord},
+ {"b", "0002", kPutRecord},
+ {"cdd", "0002", kMergeRecord},
+ {"aab", "00001", kPutRecord},
+ {"cc", "00005", kPutRecord},
+ {"cdd", "0002", kPutRecord},
+ {"aab", "0003", kPutRecord},
+ {"cc", "00005", kDeleteRecord},
+ };
+ std::vector<Entry> entries_list(entries, entries + 8);
+
+ WriteBatchWithIndex batch(nullptr, 20);
+
+ TestValueAsSecondaryIndexHelper(entries_list, &batch);
+
+ // Clear batch and re-run test with new values
+ batch.Clear();
+
+ Entry new_entries[] = {
+ {"aaa", "0005", kPutRecord},
+ {"e", "0002", kPutRecord},
+ {"add", "0002", kMergeRecord},
+ {"aab", "00001", kPutRecord},
+ {"zz", "00005", kPutRecord},
+ {"add", "0002", kPutRecord},
+ {"aab", "0003", kPutRecord},
+ {"zz", "00005", kDeleteRecord},
+ };
+
+ entries_list = std::vector<Entry>(new_entries, new_entries + 8);
+
+ TestValueAsSecondaryIndexHelper(entries_list, &batch);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestComparatorForCF) {
+ ColumnFamilyHandleImplDummy cf1(6, nullptr);
+ ColumnFamilyHandleImplDummy reverse_cf(66, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(88, BytewiseComparator());
+ WriteBatchWithIndex batch(BytewiseComparator(), 20);
+
+ batch.Put(&cf1, "ddd", "");
+ batch.Put(&cf2, "aaa", "");
+ batch.Put(&cf2, "eee", "");
+ batch.Put(&cf1, "ccc", "");
+ batch.Put(&reverse_cf, "a11", "");
+ batch.Put(&cf1, "bbb", "");
+
+ Slice key_slices[] = {"a", "3", "3"};
+ Slice value_slice = "";
+ batch.Put(&reverse_cf, SliceParts(key_slices, 3),
+ SliceParts(&value_slice, 1));
+ batch.Put(&reverse_cf, "a22", "");
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&cf1));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("bbb", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ccc", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ddd", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&cf2));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&reverse_cf));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("z");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a22", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("a22");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a22", iter->Entry().key.ToString());
+
+ iter->Seek("a13");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ }
+}
+
+TEST_F(WriteBatchWithIndexTest, TestOverwriteKey) {
+ ColumnFamilyHandleImplDummy cf1(6, nullptr);
+ ColumnFamilyHandleImplDummy reverse_cf(66, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(88, BytewiseComparator());
+ WriteBatchWithIndex batch(BytewiseComparator(), 20, true);
+
+ batch.Put(&cf1, "ddd", "");
+ batch.Merge(&cf1, "ddd", "");
+ batch.Delete(&cf1, "ddd");
+ batch.Put(&cf2, "aaa", "");
+ batch.Delete(&cf2, "aaa");
+ batch.Put(&cf2, "aaa", "aaa");
+ batch.Put(&cf2, "eee", "eee");
+ batch.Put(&cf1, "ccc", "");
+ batch.Put(&reverse_cf, "a11", "");
+ batch.Delete(&cf1, "ccc");
+ batch.Put(&reverse_cf, "a33", "a33");
+ batch.Put(&reverse_cf, "a11", "a11");
+ Slice slices[] = {"a", "3", "3"};
+ batch.Delete(&reverse_cf, SliceParts(slices, 3));
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&cf1));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ccc", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("ddd", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&cf2));
+ iter->SeekToLast();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ ASSERT_EQ("eee", iter->Entry().value.ToString());
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ ASSERT_EQ("aaa", iter->Entry().value.ToString());
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToFirst();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("aaa", iter->Entry().key.ToString());
+ ASSERT_EQ("aaa", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("eee", iter->Entry().key.ToString());
+ ASSERT_EQ("eee", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ {
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator(&reverse_cf));
+ iter->Seek("");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("z");
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ ASSERT_EQ("a11", iter->Entry().value.ToString());
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a11", iter->Entry().key.ToString());
+ ASSERT_EQ("a11", iter->Entry().value.ToString());
+ iter->Prev();
+
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ("a33", iter->Entry().key.ToString());
+ ASSERT_TRUE(iter->Entry().type == WriteType::kDeleteRecord);
+ iter->Prev();
+ ASSERT_TRUE(!iter->Valid());
+ }
+}
+
+namespace {
+typedef std::map<std::string, std::string> KVMap;
+
+class KVIter : public Iterator {
+ public:
+ explicit KVIter(const KVMap* map) : map_(map), iter_(map_->end()) {}
+ bool Valid() const override { return iter_ != map_->end(); }
+ void SeekToFirst() override { iter_ = map_->begin(); }
+ void SeekToLast() override {
+ if (map_->empty()) {
+ iter_ = map_->end();
+ } else {
+ iter_ = map_->find(map_->rbegin()->first);
+ }
+ }
+ void Seek(const Slice& k) override {
+ iter_ = map_->lower_bound(k.ToString());
+ }
+ void SeekForPrev(const Slice& k) override {
+ iter_ = map_->upper_bound(k.ToString());
+ Prev();
+ }
+ void Next() override { ++iter_; }
+ void Prev() override {
+ if (iter_ == map_->begin()) {
+ iter_ = map_->end();
+ return;
+ }
+ --iter_;
+ }
+
+ Slice key() const override { return iter_->first; }
+ Slice value() const override { return iter_->second; }
+ Status status() const override { return Status::OK(); }
+
+ private:
+ const KVMap* const map_;
+ KVMap::const_iterator iter_;
+};
+
+void AssertIter(Iterator* iter, const std::string& key,
+ const std::string& value) {
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->key().ToString());
+ ASSERT_EQ(value, iter->value().ToString());
+}
+
+void AssertItersEqual(Iterator* iter1, Iterator* iter2) {
+ ASSERT_EQ(iter1->Valid(), iter2->Valid());
+ if (iter1->Valid()) {
+ ASSERT_EQ(iter1->key().ToString(), iter2->key().ToString());
+ ASSERT_EQ(iter1->value().ToString(), iter2->value().ToString());
+ }
+}
+} // namespace
+
+TEST_F(WriteBatchWithIndexTest, TestRandomIteraratorWithBase) {
+ std::vector<std::string> source_strings = {"a", "b", "c", "d", "e",
+ "f", "g", "h", "i", "j"};
+ for (int rand_seed = 301; rand_seed < 366; rand_seed++) {
+ Random rnd(rand_seed);
+
+ ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf3(8, BytewiseComparator());
+
+ WriteBatchWithIndex batch(BytewiseComparator(), 20, true);
+
+ if (rand_seed % 2 == 0) {
+ batch.Put(&cf2, "zoo", "bar");
+ }
+ if (rand_seed % 4 == 1) {
+ batch.Put(&cf3, "zoo", "bar");
+ }
+
+ KVMap map;
+ KVMap merged_map;
+ for (auto key : source_strings) {
+ std::string value = key + key;
+ int type = rnd.Uniform(6);
+ switch (type) {
+ case 0:
+ // only base has it
+ map[key] = value;
+ merged_map[key] = value;
+ break;
+ case 1:
+ // only delta has it
+ batch.Put(&cf1, key, value);
+ map[key] = value;
+ merged_map[key] = value;
+ break;
+ case 2:
+          // both have it. Delta should win
+ batch.Put(&cf1, key, value);
+ map[key] = "wrong_value";
+ merged_map[key] = value;
+ break;
+ case 3:
+          // both have it. Delta is a delete
+ batch.Delete(&cf1, key);
+ map[key] = "wrong_value";
+ break;
+ case 4:
+          // only delta has it. Delta is a delete
+ batch.Delete(&cf1, key);
+ map[key] = "wrong_value";
+ break;
+ default:
+ // Neither iterator has it.
+ break;
+ }
+ }
+
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&map)));
+ std::unique_ptr<Iterator> result_iter(new KVIter(&merged_map));
+
+ bool is_valid = false;
+ for (int i = 0; i < 128; i++) {
+      // Random walk and make sure iter and result_iter return the same key
+      // and value
+ int type = rnd.Uniform(6);
+ ASSERT_OK(iter->status());
+ switch (type) {
+ case 0:
+ // Seek to First
+ iter->SeekToFirst();
+ result_iter->SeekToFirst();
+ break;
+ case 1:
+ // Seek to last
+ iter->SeekToLast();
+ result_iter->SeekToLast();
+ break;
+ case 2: {
+ // Seek to random key
+ auto key_idx = rnd.Uniform(static_cast<int>(source_strings.size()));
+ auto key = source_strings[key_idx];
+ iter->Seek(key);
+ result_iter->Seek(key);
+ break;
+ }
+ case 3: {
+ // SeekForPrev to random key
+ auto key_idx = rnd.Uniform(static_cast<int>(source_strings.size()));
+ auto key = source_strings[key_idx];
+ iter->SeekForPrev(key);
+ result_iter->SeekForPrev(key);
+ break;
+ }
+ case 4:
+ // Next
+ if (is_valid) {
+ iter->Next();
+ result_iter->Next();
+ } else {
+ continue;
+ }
+ break;
+ default:
+ assert(type == 5);
+ // Prev
+ if (is_valid) {
+ iter->Prev();
+ result_iter->Prev();
+ } else {
+ continue;
+ }
+ break;
+ }
+ AssertItersEqual(iter.get(), result_iter.get());
+ is_valid = iter->Valid();
+ }
+ }
+}
+
+TEST_F(WriteBatchWithIndexTest, TestIteraratorWithBase) {
+ ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, BytewiseComparator());
+ WriteBatchWithIndex batch(BytewiseComparator(), 20, true);
+
+ {
+ KVMap map;
+ map["a"] = "aa";
+ map["c"] = "cc";
+ map["e"] = "ee";
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "e", "ee");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "e", "ee");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Seek("a");
+ AssertIter(iter.get(), "a", "aa");
+ }
+
+ // Test the case that there is one element in the write batch
+ batch.Put(&cf2, "zoo", "bar");
+ batch.Put(&cf1, "a", "aa");
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ batch.Delete(&cf1, "b");
+ batch.Put(&cf1, "c", "cc");
+ batch.Put(&cf1, "d", "dd");
+ batch.Delete(&cf1, "e");
+
+ {
+ KVMap map;
+ map["b"] = "";
+ map["cc"] = "cccc";
+ map["f"] = "ff";
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Next();
+ AssertIter(iter.get(), "f", "ff");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "f", "ff");
+ iter->Prev();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Prev();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("c");
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Seek("cb");
+ AssertIter(iter.get(), "cc", "cccc");
+
+ iter->Seek("cc");
+ AssertIter(iter.get(), "cc", "cccc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Seek("e");
+ AssertIter(iter.get(), "f", "ff");
+
+ iter->Prev();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Next();
+ AssertIter(iter.get(), "f", "ff");
+ }
+
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "d", "dd");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("aa");
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Seek("ca");
+ AssertIter(iter.get(), "d", "dd");
+
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ }
+}
+
+TEST_F(WriteBatchWithIndexTest, TestIteraratorWithBaseReverseCmp) {
+ ColumnFamilyHandleImplDummy cf1(6, ReverseBytewiseComparator());
+ ColumnFamilyHandleImplDummy cf2(2, ReverseBytewiseComparator());
+ WriteBatchWithIndex batch(BytewiseComparator(), 20, true);
+
+ // Test the case that there is one element in the write batch
+ batch.Put(&cf2, "zoo", "bar");
+ batch.Put(&cf1, "a", "aa");
+ {
+ KVMap empty_map;
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&empty_map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+ }
+
+ batch.Put(&cf1, "c", "cc");
+ {
+ KVMap map;
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(&cf1, new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Next();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "a", "aa");
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "a", "aa");
+
+ iter->Prev();
+ AssertIter(iter.get(), "c", "cc");
+
+ iter->Seek("a");
+ AssertIter(iter.get(), "a", "aa");
+ }
+
+ // default column family
+ batch.Put("a", "b");
+ {
+ KVMap map;
+ map["b"] = "";
+ std::unique_ptr<Iterator> iter(batch.NewIteratorWithBase(new KVIter(&map)));
+
+ iter->SeekToFirst();
+ AssertIter(iter.get(), "a", "b");
+ iter->Next();
+ AssertIter(iter.get(), "b", "");
+ iter->Next();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->SeekToLast();
+ AssertIter(iter.get(), "b", "");
+ iter->Prev();
+ AssertIter(iter.get(), "a", "b");
+ iter->Prev();
+ ASSERT_OK(iter->status());
+ ASSERT_TRUE(!iter->Valid());
+
+ iter->Seek("b");
+ AssertIter(iter.get(), "b", "");
+
+ iter->Prev();
+ AssertIter(iter.get(), "a", "b");
+
+ iter->Seek("0");
+ AssertIter(iter.get(), "a", "b");
+ }
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatch) {
+ Options options;
+ WriteBatchWithIndex batch;
+ Status s;
+ std::string value;
+
+ s = batch.GetFromBatch(options, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch.Put("a", "a");
+ batch.Put("b", "b");
+ batch.Put("c", "c");
+ batch.Put("a", "z");
+ batch.Delete("c");
+ batch.Delete("d");
+ batch.Delete("e");
+ batch.Put("e", "e");
+
+ s = batch.GetFromBatch(options, "b", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ s = batch.GetFromBatch(options, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("z", value);
+
+ s = batch.GetFromBatch(options, "c", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch.GetFromBatch(options, "d", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch.GetFromBatch(options, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch.GetFromBatch(options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e", value);
+
+ batch.Merge("z", "z");
+
+ s = batch.GetFromBatch(options, "z", &value);
+ ASSERT_NOK(s); // No merge operator specified.
+
+ s = batch.GetFromBatch(options, "b", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchMerge) {
+ DB* db;
+ Options options;
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+ options.create_if_missing = true;
+
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ ASSERT_OK(s);
+
+ ColumnFamilyHandle* column_family = db->DefaultColumnFamily();
+ WriteBatchWithIndex batch;
+ std::string value;
+
+ s = batch.GetFromBatch(options, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch.Put("x", "X");
+ std::string expected = "X";
+
+ for (int i = 0; i < 5; i++) {
+ batch.Merge("x", ToString(i));
+ expected = expected + "," + ToString(i);
+
+ if (i % 2 == 0) {
+ batch.Put("y", ToString(i / 2));
+ }
+
+ batch.Merge("z", "z");
+
+ s = batch.GetFromBatch(column_family, options, "x", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(expected, value);
+
+ s = batch.GetFromBatch(column_family, options, "y", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ(ToString(i / 2), value);
+
+ s = batch.GetFromBatch(column_family, options, "z", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+ }
+
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchMerge2) {
+ DB* db;
+ Options options;
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+ options.create_if_missing = true;
+
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ ASSERT_OK(s);
+
+ ColumnFamilyHandle* column_family = db->DefaultColumnFamily();
+
+ // Test batch with overwrite_key=true
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+ std::string value;
+
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch.Put(column_family, "X", "x");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("x", value);
+
+ batch.Put(column_family, "X", "x2");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("x2", value);
+
+ batch.Merge(column_family, "X", "aaa");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ batch.Merge(column_family, "X", "bbb");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ batch.Put(column_family, "X", "x3");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("x3", value);
+
+ batch.Merge(column_family, "X", "ccc");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ batch.Delete(column_family, "X");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch.Merge(column_family, "X", "ddd");
+ s = batch.GetFromBatch(column_family, options, "X", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchAndDB) {
+ DB* db;
+ Options options;
+ options.create_if_missing = true;
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ ASSERT_OK(s);
+
+ WriteBatchWithIndex batch;
+ ReadOptions read_options;
+ WriteOptions write_options;
+ std::string value;
+
+ s = db->Put(write_options, "a", "a");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "b", "b");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "c", "c");
+ ASSERT_OK(s);
+
+ batch.Put("a", "batch.a");
+ batch.Delete("b");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("batch.a", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "b", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch.GetFromBatchAndDB(db, read_options, "c", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ db->Delete(write_options, "x");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchAndDBMerge) {
+ DB* db;
+ Options options;
+
+ options.create_if_missing = true;
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ assert(s.ok());
+
+ WriteBatchWithIndex batch;
+ ReadOptions read_options;
+ WriteOptions write_options;
+ std::string value;
+
+ s = db->Put(write_options, "a", "a0");
+ ASSERT_OK(s);
+
+ s = db->Put(write_options, "b", "b0");
+ ASSERT_OK(s);
+
+ s = db->Merge(write_options, "b", "b1");
+ ASSERT_OK(s);
+
+ s = db->Merge(write_options, "c", "c0");
+ ASSERT_OK(s);
+
+ s = db->Merge(write_options, "d", "d0");
+ ASSERT_OK(s);
+
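+ // Batch Merges are appended to whatever value the DB already holds for the
+ // key (if any).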
+ batch.Merge("a", "a1");
+ batch.Merge("a", "a2");
+ batch.Merge("b", "b2");
+ batch.Merge("d", "d1");
+ batch.Merge("e", "e0");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a0,a1,a2", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "b", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b0,b1,b2", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "c", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c0", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "d", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("d0,d1", value);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e0", value);
+
+ s = db->Delete(write_options, "x");
+ ASSERT_OK(s);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "x", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
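+ // Take a snapshot so DB mutations made below can be compared against the
+ // pre-mutation state through snapshot_read_options.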
+ const Snapshot* snapshot = db->GetSnapshot();
+ ReadOptions snapshot_read_options;
+ snapshot_read_options.snapshot = snapshot;
+
+ s = db->Delete(write_options, "a");
+ ASSERT_OK(s);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a1,a2", value);
+
+ s = batch.GetFromBatchAndDB(db, snapshot_read_options, "a", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a0,a1,a2", value);
+
+ batch.Delete("a");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = batch.GetFromBatchAndDB(db, snapshot_read_options, "a", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ s = db->Merge(write_options, "c", "c1");
+ ASSERT_OK(s);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "c", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c0,c1", value);
+
+ s = batch.GetFromBatchAndDB(db, snapshot_read_options, "c", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c0", value);
+
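+ // A newer DB Put becomes the base value, and the pending batch Merge is
+ // appended on top of it.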
+ s = db->Put(write_options, "e", "e1");
+ ASSERT_OK(s);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e1,e0", value);
+
+ s = batch.GetFromBatchAndDB(db, snapshot_read_options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e0", value);
+
+ s = db->Delete(write_options, "e");
+ ASSERT_OK(s);
+
+ s = batch.GetFromBatchAndDB(db, read_options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e0", value);
+
+ s = batch.GetFromBatchAndDB(db, snapshot_read_options, "e", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("e0", value);
+
+ db->ReleaseSnapshot(snapshot);
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchAndDBMerge2) {
+ DB* db;
+ Options options;
+
+ options.create_if_missing = true;
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ ASSERT_OK(s);
+
+ // Test batch with overwrite_key=true
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+
+ ReadOptions read_options;
+ WriteOptions write_options;
+ std::string value;
+
+ s = batch.GetFromBatchAndDB(db, read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ batch.Merge("A", "xxx");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "A", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ batch.Merge("A", "yyy");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "A", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ s = db->Put(write_options, "A", "a0");
+ ASSERT_OK(s);
+
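+ // Even with a base value now present in the DB, the read still reports
+ // MergeInProgress for this overwrite_key=true batch.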
+ s = batch.GetFromBatchAndDB(db, read_options, "A", &value);
+ ASSERT_TRUE(s.IsMergeInProgress());
+
+ batch.Delete("A");
+
+ s = batch.GetFromBatchAndDB(db, read_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+TEST_F(WriteBatchWithIndexTest, TestGetFromBatchAndDBMerge3) {
+ DB* db;
+ Options options;
+
+ options.create_if_missing = true;
+ std::string dbname = test::PerThreadDBPath("write_batch_with_index_test");
+
+ options.merge_operator = MergeOperators::CreateFromStringId("stringappend");
+
+ DestroyDB(dbname, options);
+ Status s = DB::Open(options, dbname, &db);
+ ASSERT_OK(s);
+
+ ReadOptions read_options;
+ WriteOptions write_options;
+ FlushOptions flush_options;
+ std::string value;
+
+ WriteBatchWithIndex batch;
+
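+ // A single batch Merge on top of a flushed DB value resolves to the
+ // appended result.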
+ ASSERT_OK(db->Put(write_options, "A", "1"));
+ ASSERT_OK(db->Flush(flush_options, db->DefaultColumnFamily()));
+ ASSERT_OK(batch.Merge("A", "2"));
+
+ ASSERT_OK(batch.GetFromBatchAndDB(db, read_options, "A", &value));
+ ASSERT_EQ(value, "1,2");
+
+ delete db;
+ DestroyDB(dbname, options);
+}
+
+void AssertKey(std::string key, WBWIIterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->Entry().key.ToString());
+}
+
+void AssertValue(std::string value, WBWIIterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(value, iter->Entry().value.ToString());
+}
+
+// Tests that we can write to the WBWI while we iterate (from a single thread).
+// Iteration should see the newest writes.
+TEST_F(WriteBatchWithIndexTest, MutateWhileIteratingCorrectnessTest) {
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+ for (char c = 'a'; c <= 'z'; ++c) {
+ batch.Put(std::string(1, c), std::string(1, c));
+ }
+
+ std::unique_ptr<WBWIIterator> iter(batch.NewIterator());
+ iter->Seek("k");
+ AssertKey("k", iter.get());
+ iter->Next();
+ AssertKey("l", iter.get());
+ batch.Put("ab", "cc");
+ iter->Next();
+ AssertKey("m", iter.get());
+ batch.Put("mm", "kk");
+ iter->Next();
+ AssertKey("mm", iter.get());
+ AssertValue("kk", iter.get());
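+ // Deleting a key mid-iteration: the WBWI iterator still surfaces the
+ // entry, now typed as a delete record.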
+ batch.Delete("mm");
+
+ iter->Next();
+ AssertKey("n", iter.get());
+ iter->Prev();
+ AssertKey("mm", iter.get());
+ ASSERT_EQ(kDeleteRecord, iter->Entry().type);
+
+ iter->Seek("ab");
+ AssertKey("ab", iter.get());
+ batch.Delete("x");
+ iter->Seek("x");
+ AssertKey("x", iter.get());
+ ASSERT_EQ(kDeleteRecord, iter->Entry().type);
+ iter->Prev();
+ AssertKey("w", iter.get());
+}
+
+void AssertIterKey(std::string key, Iterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(key, iter->key().ToString());
+}
+
+void AssertIterValue(std::string value, Iterator* iter) {
+ ASSERT_TRUE(iter->Valid());
+ ASSERT_EQ(value, iter->value().ToString());
+}
+
+// Same as above, but exercising the iterator created with NewIteratorWithBase.
+TEST_F(WriteBatchWithIndexTest, MutateWhileIteratingBaseCorrectnessTest) {
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+ for (char c = 'a'; c <= 'z'; ++c) {
+ batch.Put(std::string(1, c), std::string(1, c));
+ }
+
+ KVMap map;
+ map["aa"] = "aa";
+ map["cc"] = "cc";
+ map["ee"] = "ee";
+ map["em"] = "me";
+
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(new KVIter(&map)));
+ iter->Seek("k");
+ AssertIterKey("k", iter.get());
+ iter->Next();
+ AssertIterKey("l", iter.get());
+ batch.Put("ab", "cc");
+ iter->Next();
+ AssertIterKey("m", iter.get());
+ batch.Put("mm", "kk");
+ iter->Next();
+ AssertIterKey("mm", iter.get());
+ AssertIterValue("kk", iter.get());
+ batch.Delete("mm");
+ iter->Next();
+ AssertIterKey("n", iter.get());
+ iter->Prev();
+ // "mm" is deleted, so we're back at "m"
+ AssertIterKey("m", iter.get());
+
+ iter->Seek("ab");
+ AssertIterKey("ab", iter.get());
+ iter->Prev();
+ AssertIterKey("aa", iter.get());
+ iter->Prev();
+ AssertIterKey("a", iter.get());
+ batch.Delete("aa");
+ iter->Next();
+ AssertIterKey("ab", iter.get());
+ iter->Prev();
+ AssertIterKey("a", iter.get());
+
+ batch.Delete("x");
+ iter->Seek("x");
+ AssertIterKey("y", iter.get());
+ iter->Next();
+ AssertIterKey("z", iter.get());
+ iter->Prev();
+ iter->Prev();
+ AssertIterKey("w", iter.get());
+
+ batch.Delete("e");
+ iter->Seek("e");
+ AssertIterKey("ee", iter.get());
+ AssertIterValue("ee", iter.get());
+ batch.Put("ee", "xx");
+ // still the same value
+ AssertIterValue("ee", iter.get());
+ iter->Next();
+ AssertIterKey("em", iter.get());
+ iter->Prev();
+ // new value
+ AssertIterValue("xx", iter.get());
+}
+
+// Stress-tests mutations interleaved with an iterator from NewIteratorWithBase.
+TEST_F(WriteBatchWithIndexTest, MutateWhileIteratingBaseStressTest) {
+ WriteBatchWithIndex batch(BytewiseComparator(), 0, true);
+ for (char c = 'a'; c <= 'z'; ++c) {
+ batch.Put(std::string(1, c), std::string(1, c));
+ }
+
+ KVMap map;
+ for (char c = 'a'; c <= 'z'; ++c) {
+ map[std::string(2, c)] = std::string(2, c);
+ }
+
+ std::unique_ptr<Iterator> iter(
+ batch.NewIteratorWithBase(new KVIter(&map)));
+
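+ // Randomly interleave batch mutations with iterator moves; the loop only
+ // checks that iteration stays well-formed (no crashes or assertion
+ // failures), it does not verify values.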
+ Random rnd(301);
+ for (int i = 0; i < 1000000; ++i) {
+ int random = rnd.Uniform(8);
+ char c = static_cast<char>(rnd.Uniform(26) + 'a');
+ switch (random) {
+ case 0:
+ batch.Put(std::string(1, c), "xxx");
+ break;
+ case 1:
+ batch.Put(std::string(2, c), "xxx");
+ break;
+ case 2:
+ batch.Delete(std::string(1, c));
+ break;
+ case 3:
+ batch.Delete(std::string(2, c));
+ break;
+ case 4:
+ iter->Seek(std::string(1, c));
+ break;
+ case 5:
+ iter->Seek(std::string(2, c));
+ break;
+ case 6:
+ if (iter->Valid()) {
+ iter->Next();
+ }
+ break;
+ case 7:
+ if (iter->Valid()) {
+ iter->Prev();
+ }
+ break;
+ default:
+ assert(false);
+ }
+ }
+}
+
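+// Dumps the batch (without a base) as a comma-separated list of typed
+// records, e.g. "PUT(A):a,DEL(B),".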
+static std::string PrintContents(WriteBatchWithIndex* batch,
+ ColumnFamilyHandle* column_family) {
+ std::string result;
+
+ WBWIIterator* iter;
+ if (column_family == nullptr) {
+ iter = batch->NewIterator();
+ } else {
+ iter = batch->NewIterator(column_family);
+ }
+
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ WriteEntry e = iter->Entry();
+
+ if (e.type == kPutRecord) {
+ result.append("PUT(");
+ result.append(e.key.ToString());
+ result.append("):");
+ result.append(e.value.ToString());
+ } else if (e.type == kMergeRecord) {
+ result.append("MERGE(");
+ result.append(e.key.ToString());
+ result.append("):");
+ result.append(e.value.ToString());
+ } else if (e.type == kSingleDeleteRecord) {
+ result.append("SINGLE-DEL(");
+ result.append(e.key.ToString());
+ result.append(")");
+ } else {
+ assert(e.type == kDeleteRecord);
+ result.append("DEL(");
+ result.append(e.key.ToString());
+ result.append(")");
+ }
+
+ result.append(",");
+ iter->Next();
+ }
+
+ delete iter;
+ return result;
+}
+
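+// Dumps the merged view of the batch applied on top of base_map as
+// "key:value," pairs.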
+static std::string PrintContents(WriteBatchWithIndex* batch, KVMap* base_map,
+ ColumnFamilyHandle* column_family) {
+ std::string result;
+
+ Iterator* iter;
+ if (column_family == nullptr) {
+ iter = batch->NewIteratorWithBase(new KVIter(base_map));
+ } else {
+ iter = batch->NewIteratorWithBase(column_family, new KVIter(base_map));
+ }
+
+ iter->SeekToFirst();
+ while (iter->Valid()) {
+ assert(iter->status().ok());
+
+ Slice key = iter->key();
+ Slice value = iter->value();
+
+ result.append(key.ToString());
+ result.append(":");
+ result.append(value.ToString());
+ result.append(",");
+
+ iter->Next();
+ }
+
+ delete iter;
+ return result;
+}
+
+TEST_F(WriteBatchWithIndexTest, SavePointTest) {
+ WriteBatchWithIndex batch;
+ ColumnFamilyHandleImplDummy cf1(1, BytewiseComparator());
+ Status s;
+
+ batch.Put("A", "a");
+ batch.Put("B", "b");
+ batch.Put("A", "aa");
+ batch.Put(&cf1, "A", "a1");
+ batch.Delete(&cf1, "B");
+ batch.Put(&cf1, "C", "c1");
+ batch.Put(&cf1, "E", "e1");
+
+ batch.SetSavePoint(); // 1
+
+ batch.Put("C", "cc");
+ batch.Put("B", "bb");
+ batch.Delete("A");
+ batch.Put(&cf1, "B", "b1");
+ batch.Delete(&cf1, "A");
+ batch.SingleDelete(&cf1, "E");
+ batch.SetSavePoint(); // 2
+
+ batch.Put("A", "aaa");
+ batch.Put("A", "xxx");
+ batch.Delete("B");
+ batch.Put(&cf1, "B", "b2");
+ batch.Delete(&cf1, "C");
+ batch.SetSavePoint(); // 3
+ batch.SetSavePoint(); // 4
+ batch.SingleDelete("D");
+ batch.Delete(&cf1, "D");
+ batch.Delete(&cf1, "E");
+
+ ASSERT_EQ(
+ "PUT(A):a,PUT(A):aa,DEL(A),PUT(A):aaa,PUT(A):xxx,PUT(B):b,PUT(B):bb,DEL("
+ "B)"
+ ",PUT(C):cc,SINGLE-DEL(D),",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_EQ(
+ "PUT(A):a1,DEL(A),DEL(B),PUT(B):b1,PUT(B):b2,PUT(C):c1,DEL(C),"
+ "DEL(D),PUT(E):e1,SINGLE-DEL(E),DEL(E),",
+ PrintContents(&batch, &cf1));
+
+ ASSERT_OK(batch.RollbackToSavePoint()); // rollback to 4
+ ASSERT_EQ(
+ "PUT(A):a,PUT(A):aa,DEL(A),PUT(A):aaa,PUT(A):xxx,PUT(B):b,PUT(B):bb,DEL("
+ "B)"
+ ",PUT(C):cc,",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_EQ(
+ "PUT(A):a1,DEL(A),DEL(B),PUT(B):b1,PUT(B):b2,PUT(C):c1,DEL(C),"
+ "PUT(E):e1,SINGLE-DEL(E),",
+ PrintContents(&batch, &cf1));
+
+ ASSERT_OK(batch.RollbackToSavePoint()); // rollback to 3
+ ASSERT_EQ(
+ "PUT(A):a,PUT(A):aa,DEL(A),PUT(A):aaa,PUT(A):xxx,PUT(B):b,PUT(B):bb,DEL("
+ "B)"
+ ",PUT(C):cc,",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_EQ(
+ "PUT(A):a1,DEL(A),DEL(B),PUT(B):b1,PUT(B):b2,PUT(C):c1,DEL(C),"
+ "PUT(E):e1,SINGLE-DEL(E),",
+ PrintContents(&batch, &cf1));
+
+ ASSERT_OK(batch.RollbackToSavePoint()); // rollback to 2
+ ASSERT_EQ("PUT(A):a,PUT(A):aa,DEL(A),PUT(B):b,PUT(B):bb,PUT(C):cc,",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_EQ(
+ "PUT(A):a1,DEL(A),DEL(B),PUT(B):b1,PUT(C):c1,"
+ "PUT(E):e1,SINGLE-DEL(E),",
+ PrintContents(&batch, &cf1));
+
+ batch.SetSavePoint(); // 5
+ batch.Put("X", "x");
+
+ ASSERT_EQ("PUT(A):a,PUT(A):aa,DEL(A),PUT(B):b,PUT(B):bb,PUT(C):cc,PUT(X):x,",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_OK(batch.RollbackToSavePoint()); // rollback to 5
+ ASSERT_EQ("PUT(A):a,PUT(A):aa,DEL(A),PUT(B):b,PUT(B):bb,PUT(C):cc,",
+ PrintContents(&batch, nullptr));
+
+ ASSERT_EQ(
+ "PUT(A):a1,DEL(A),DEL(B),PUT(B):b1,PUT(C):c1,"
+ "PUT(E):e1,SINGLE-DEL(E),",
+ PrintContents(&batch, &cf1));
+
+ ASSERT_OK(batch.RollbackToSavePoint()); // rollback to 1
+ ASSERT_EQ("PUT(A):a,PUT(A):aa,PUT(B):b,", PrintContents(&batch, nullptr));
+
+ ASSERT_EQ("PUT(A):a1,DEL(B),PUT(C):c1,PUT(E):e1,",
+ PrintContents(&batch, &cf1));
+
+ s = batch.RollbackToSavePoint(); // no savepoint found
+ ASSERT_TRUE(s.IsNotFound());
+ ASSERT_EQ("PUT(A):a,PUT(A):aa,PUT(B):b,", PrintContents(&batch, nullptr));
+
+ ASSERT_EQ("PUT(A):a1,DEL(B),PUT(C):c1,PUT(E):e1,",
+ PrintContents(&batch, &cf1));
+
+ batch.SetSavePoint(); // 6
+
+ batch.Clear();
+ ASSERT_EQ("", PrintContents(&batch, nullptr));
+ ASSERT_EQ("", PrintContents(&batch, &cf1));
+
+ s = batch.RollbackToSavePoint(); // savepoint 6 was removed by Clear()
+ ASSERT_TRUE(s.IsNotFound());
+}
+
+TEST_F(WriteBatchWithIndexTest, SingleDeleteTest) {
+ WriteBatchWithIndex batch;
+ Status s;
+ std::string value;
+ DBOptions db_options;
+
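+ // A SingleDelete with no prior Put in the batch still makes the key read
+ // as NotFound.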
+ batch.SingleDelete("A");
+
+ s = batch.GetFromBatch(db_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch.GetFromBatch(db_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ value = PrintContents(&batch, nullptr);
+ ASSERT_EQ("SINGLE-DEL(A),", value);
+
+ batch.Clear();
+ batch.Put("A", "a");
+ batch.Put("A", "a2");
+ batch.Put("B", "b");
+ batch.SingleDelete("A");
+
+ s = batch.GetFromBatch(db_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch.GetFromBatch(db_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b", value);
+
+ value = PrintContents(&batch, nullptr);
+ ASSERT_EQ("PUT(A):a,PUT(A):a2,SINGLE-DEL(A),PUT(B):b,", value);
+
+ batch.Put("C", "c");
+ batch.Put("A", "a3");
+ batch.Delete("B");
+ batch.SingleDelete("B");
+ batch.SingleDelete("C");
+
+ s = batch.GetFromBatch(db_options, "A", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("a3", value);
+ s = batch.GetFromBatch(db_options, "B", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch.GetFromBatch(db_options, "C", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch.GetFromBatch(db_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ value = PrintContents(&batch, nullptr);
+ ASSERT_EQ(
+ "PUT(A):a,PUT(A):a2,SINGLE-DEL(A),PUT(A):a3,PUT(B):b,DEL(B),SINGLE-DEL(B)"
+ ",PUT(C):c,SINGLE-DEL(C),",
+ value);
+
+ batch.Put("B", "b4");
+ batch.Put("C", "c4");
+ batch.Put("D", "d4");
+ batch.SingleDelete("D");
+ batch.SingleDelete("D");
+ batch.Delete("A");
+
+ s = batch.GetFromBatch(db_options, "A", &value);
+ ASSERT_TRUE(s.IsNotFound());
+ s = batch.GetFromBatch(db_options, "B", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("b4", value);
+ s = batch.GetFromBatch(db_options, "C", &value);
+ ASSERT_OK(s);
+ ASSERT_EQ("c4", value);
+ s = batch.GetFromBatch(db_options, "D", &value);
+ ASSERT_TRUE(s.IsNotFound());
+
+ value = PrintContents(&batch, nullptr);
+ ASSERT_EQ(
+ "PUT(A):a,PUT(A):a2,SINGLE-DEL(A),PUT(A):a3,DEL(A),PUT(B):b,DEL(B),"
+ "SINGLE-DEL(B),PUT(B):b4,PUT(C):c,SINGLE-DEL(C),PUT(C):c4,PUT(D):d4,"
+ "SINGLE-DEL(D),SINGLE-DEL(D),",
+ value);
+}
+
+TEST_F(WriteBatchWithIndexTest, SingleDeleteDeltaIterTest) {
+ Status s;
+ std::string value;
+ DBOptions db_options;
+ WriteBatchWithIndex batch(BytewiseComparator(), 20, true /* overwrite_key */);
+ batch.Put("A", "a");
+ batch.Put("A", "a2");
+ batch.Put("B", "b");
+ batch.SingleDelete("A");
+ batch.Delete("B");
+
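+ // With a base map underneath, the merged view only shows keys that survive
+ // the batch's Deletes and SingleDeletes.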
+ KVMap map;
+ value = PrintContents(&batch, &map, nullptr);
+ ASSERT_EQ("", value);
+
+ map["A"] = "aa";
+ map["C"] = "cc";
+ map["D"] = "dd";
+
+ batch.SingleDelete("B");
+ batch.SingleDelete("C");
+ batch.SingleDelete("Z");
+
+ value = PrintContents(&batch, &map, nullptr);
+ ASSERT_EQ("D:dd,", value);
+
+ batch.Put("A", "a3");
+ batch.Put("B", "b3");
+ batch.SingleDelete("A");
+ batch.SingleDelete("A");
+ batch.SingleDelete("D");
+ batch.SingleDelete("D");
+ batch.Delete("D");
+
+ map["E"] = "ee";
+
+ value = PrintContents(&batch, &map, nullptr);
+ ASSERT_EQ("B:b3,E:ee,", value);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main() {
+ fprintf(stderr, "SKIPPED\n");
+ return 0;
+}
+
+#endif // !ROCKSDB_LITE