Diffstat (limited to 'src/rocksdb/db/write_callback_test.cc')
-rw-r--r-- | src/rocksdb/db/write_callback_test.cc | 465
1 file changed, 465 insertions, 0 deletions
diff --git a/src/rocksdb/db/write_callback_test.cc b/src/rocksdb/db/write_callback_test.cc
new file mode 100644
index 000000000..e6ebaae08
--- /dev/null
+++ b/src/rocksdb/db/write_callback_test.cc
@@ -0,0 +1,465 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#ifndef ROCKSDB_LITE
+
+#include "db/write_callback.h"
+
+#include <atomic>
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "db/db_impl/db_impl.h"
+#include "port/port.h"
+#include "rocksdb/db.h"
+#include "rocksdb/write_batch.h"
+#include "test_util/sync_point.h"
+#include "test_util/testharness.h"
+#include "util/random.h"
+
+using std::string;
+
+namespace ROCKSDB_NAMESPACE {
+
+class WriteCallbackTest : public testing::Test {
+ public:
+  string dbname;
+
+  WriteCallbackTest() {
+    dbname = test::PerThreadDBPath("write_callback_testdb");
+  }
+};
+
+class WriteCallbackTestWriteCallback1 : public WriteCallback {
+ public:
+  bool was_called = false;
+
+  Status Callback(DB* db) override {
+    was_called = true;
+
+    // Make sure db is a DBImpl
+    DBImpl* db_impl = dynamic_cast<DBImpl*>(db);
+    if (db_impl == nullptr) {
+      return Status::InvalidArgument("");
+    }
+
+    return Status::OK();
+  }
+
+  bool AllowWriteBatching() override { return true; }
+};
+
+class WriteCallbackTestWriteCallback2 : public WriteCallback {
+ public:
+  Status Callback(DB* /*db*/) override { return Status::Busy(); }
+  bool AllowWriteBatching() override { return true; }
+};
+
+class MockWriteCallback : public WriteCallback {
+ public:
+  bool should_fail_ = false;
+  bool allow_batching_ = false;
+  std::atomic<bool> was_called_{false};
+
+  MockWriteCallback() {}
+
+  MockWriteCallback(const MockWriteCallback& other) {
+    should_fail_ = other.should_fail_;
+    allow_batching_ = other.allow_batching_;
+    was_called_.store(other.was_called_.load());
+  }
+
+  Status Callback(DB* /*db*/) override {
+    was_called_.store(true);
+    if (should_fail_) {
+      return Status::Busy();
+    } else {
+      return Status::OK();
+    }
+  }
+
+  bool AllowWriteBatching() override { return allow_batching_; }
+};
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+class WriteCallbackPTest
+    : public WriteCallbackTest,
+      public ::testing::WithParamInterface<
+          std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
+ public:
+  WriteCallbackPTest() {
+    std::tie(unordered_write_, seq_per_batch_, two_queues_, allow_parallel_,
+             allow_batching_, enable_WAL_, enable_pipelined_write_) =
+        GetParam();
+  }
+
+ protected:
+  bool unordered_write_;
+  bool seq_per_batch_;
+  bool two_queues_;
+  bool allow_parallel_;
+  bool allow_batching_;
+  bool enable_WAL_;
+  bool enable_pipelined_write_;
+};
+
+TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
+  struct WriteOP {
+    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }
+
+    void Put(const string& key, const string& val) {
+      kvs_.push_back(std::make_pair(key, val));
+      ASSERT_OK(write_batch_.Put(key, val));
+    }
+
+    void Clear() {
+      kvs_.clear();
+      write_batch_.Clear();
+      callback_.was_called_.store(false);
+    }
+
+    MockWriteCallback callback_;
+    WriteBatch write_batch_;
+    std::vector<std::pair<string, string>> kvs_;
+  };
+
+  // In each scenario we'll launch multiple threads to write.
+  // The size of each array equals the number of threads, and
+  // each boolean denotes whether the callback of the corresponding
+  // thread should succeed or fail.
+  std::vector<std::vector<WriteOP>> write_scenarios = {
+      {true},
+      {false},
+      {false, false},
+      {true, true},
+      {true, false},
+      {false, true},
+      {false, false, false},
+      {true, true, true},
+      {false, true, false},
+      {true, false, true},
+      {true, false, false, false, false},
+      {false, false, false, false, true},
+      {false, false, true, false, true},
+  };
+
+  for (auto& write_group : write_scenarios) {
+    Options options;
+    options.create_if_missing = true;
+    options.unordered_write = unordered_write_;
+    options.allow_concurrent_memtable_write = allow_parallel_;
+    options.enable_pipelined_write = enable_pipelined_write_;
+    options.two_write_queues = two_queues_;
+    // Skip unsupported combinations
+    if (options.enable_pipelined_write && seq_per_batch_) {
+      continue;
+    }
+    if (options.enable_pipelined_write && options.two_write_queues) {
+      continue;
+    }
+    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
+      continue;
+    }
+    if (options.unordered_write && options.enable_pipelined_write) {
+      continue;
+    }
+
+    ReadOptions read_options;
+    DB* db;
+    DBImpl* db_impl;
+
+    ASSERT_OK(DestroyDB(dbname, options));
+
+    DBOptions db_options(options);
+    ColumnFamilyOptions cf_options(options);
+    std::vector<ColumnFamilyDescriptor> column_families;
+    column_families.push_back(
+        ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
+    std::vector<ColumnFamilyHandle*> handles;
+    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
+                               &db, seq_per_batch_, true /* batch_per_txn */);
+    ASSERT_OK(open_s);
+    assert(handles.size() == 1);
+    delete handles[0];
+
+    db_impl = dynamic_cast<DBImpl*>(db);
+    ASSERT_TRUE(db_impl);
+
+    // Writers that have called JoinBatchGroup.
+    std::atomic<uint64_t> threads_joining(0);
+    // Writers that have linked to the queue.
+    std::atomic<uint64_t> threads_linked(0);
+    // Writers that pass the WriteThread::JoinBatchGroup:Wait sync-point.
+    std::atomic<uint64_t> threads_verified(0);
+
+    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
+    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);
+
+    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+        "WriteThread::JoinBatchGroup:Start", [&](void*) {
+          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
+          // Wait for the last joined writer to link to the queue.
+          // In this way the writers link to the queue one by one.
+          // This allows us to confidently detect the first writer
+          // who increases threads_linked as the leader.
+          while (threads_linked.load() < cur_threads_joining) {
+          }
+        });
+
+    // Verification once writers call JoinBatchGroup.
+    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
+          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
+          bool is_leader = false;
+          bool is_last = false;
+
+          // Determine this writer's role in the group.
+          is_leader = (cur_threads_linked == 0);
+          is_last = (cur_threads_linked == write_group.size() - 1);
+
+          // Check this writer's state.
+          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
+
+          if (is_leader) {
+            ASSERT_TRUE(writer->state ==
+                        WriteThread::State::STATE_GROUP_LEADER);
+          } else {
+            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
+          }
+
+          // (meta test) the first WriteOP should indeed be the first
+          // and the last should be the last (all others can be out of
+          // order)
+          if (is_leader) {
+            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
+                        !write_group.front().callback_.should_fail_);
+          } else if (is_last) {
+            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
+                        !write_group.back().callback_.should_fail_);
+          }
+
+          threads_verified.fetch_add(1);
+          // Wait here until all verification in this sync-point
+          // callback finishes for all writers.
+          while (threads_verified.load() < write_group.size()) {
+          }
+        });
+
+    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
+        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
+          // Check this writer's state.
+          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);
+
+          if (!allow_batching_) {
+            // No batching, so every writer should be a leader.
+            ASSERT_TRUE(writer->state ==
+                        WriteThread::State::STATE_GROUP_LEADER);
+          } else if (!allow_parallel_) {
+            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
+                        (enable_pipelined_write_ &&
+                         writer->state ==
+                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
+          }
+        });
+
+    std::atomic<uint32_t> thread_num(0);
+    std::atomic<char> dummy_key(0);
+
+    // Each write thread creates a random write batch and writes it to the
+    // DB with a write callback.
+    std::function<void()> write_with_callback_func = [&]() {
+      uint32_t i = thread_num.fetch_add(1);
+      Random rnd(i);
+
+      // Non-leader writers wait until the leader has been verified.
+      while (i > 0 && threads_verified.load() < 1) {
+      }
+
+      // The last writer waits until all other writers have been verified.
+      while (i == write_group.size() - 1 &&
+             threads_verified.load() < write_group.size() - 1) {
+      }
+
+      auto& write_op = write_group.at(i);
+      write_op.Clear();
+      write_op.callback_.allow_batching_ = allow_batching_;
+
+      // insert some keys
+      for (uint32_t j = 0; j < rnd.Next() % 50; j++) {
+        // grab a unique key
+        char my_key = dummy_key.fetch_add(1);
+
+        string skey(5, my_key);
+        string sval(10, my_key);
+        write_op.Put(skey, sval);
+
+        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
+          seq.fetch_add(1);
+        }
+      }
+      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
+        seq.fetch_add(1);
+      }
+
+      WriteOptions woptions;
+      woptions.disableWAL = !enable_WAL_;
+      woptions.sync = enable_WAL_;
+      if (woptions.protection_bytes_per_key > 0) {
+        ASSERT_OK(WriteBatchInternal::UpdateProtectionInfo(
+            &write_op.write_batch_, woptions.protection_bytes_per_key));
+      }
+      Status s;
+      if (seq_per_batch_) {
+        class PublishSeqCallback : public PreReleaseCallback {
+         public:
+          PublishSeqCallback(DBImpl* db_impl_in) : db_impl_(db_impl_in) {}
+          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
+                          size_t /*index*/, size_t /*total*/) override {
+            db_impl_->SetLastPublishedSequence(last_seq);
+            return Status::OK();
+          }
+          DBImpl* db_impl_;
+        } publish_seq_callback(db_impl);
+        // seq_per_batch_ requires a natural batch separator or Noop
+        ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
+        const size_t ONE_BATCH = 1;
+        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
+                               &write_op.callback_, nullptr, 0, false, nullptr,
+                               ONE_BATCH,
+                               two_queues_ ? &publish_seq_callback : nullptr);
+      } else {
+        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
+                                       &write_op.callback_);
+      }
+
+      if (write_op.callback_.should_fail_) {
+        ASSERT_TRUE(s.IsBusy());
+      } else {
+        ASSERT_OK(s);
+      }
+    };
+
+    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();
+
+    // do all the writes
+    std::vector<port::Thread> threads;
+    for (uint32_t i = 0; i < write_group.size(); i++) {
+      threads.emplace_back(write_with_callback_func);
+    }
+    for (auto& t : threads) {
+      t.join();
+    }
+
+    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
+
+    // check for keys
+    string value;
+    for (auto& w : write_group) {
+      ASSERT_TRUE(w.callback_.was_called_.load());
+      for (auto& kvp : w.kvs_) {
+        if (w.callback_.should_fail_) {
+          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
+        } else {
+          ASSERT_OK(db->Get(read_options, kvp.first, &value));
+          ASSERT_EQ(value, kvp.second);
+        }
+      }
+    }
+
+    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());
+
+    delete db;
+    ASSERT_OK(DestroyDB(dbname, options));
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(WriteCallbackPTest, WriteCallbackPTest,
+                        ::testing::Combine(::testing::Bool(), ::testing::Bool(),
+                                           ::testing::Bool(), ::testing::Bool(),
+                                           ::testing::Bool(), ::testing::Bool(),
+                                           ::testing::Bool()));
+#endif  // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+
+TEST_F(WriteCallbackTest, WriteCallBackTest) {
+  Options options;
+  WriteOptions write_options;
+  ReadOptions read_options;
+  string value;
+  DB* db;
+  DBImpl* db_impl;
+
+  ASSERT_OK(DestroyDB(dbname, options));
+
+  options.create_if_missing = true;
+  Status s = DB::Open(options, dbname, &db);
+  ASSERT_OK(s);
+
+  db_impl = dynamic_cast<DBImpl*>(db);
+  ASSERT_TRUE(db_impl);
+
+  WriteBatch wb;
+
+  ASSERT_OK(wb.Put("a", "value.a"));
+  ASSERT_OK(wb.Delete("x"));
+
+  // Test a simple Write
+  s = db->Write(write_options, &wb);
+  ASSERT_OK(s);
+
+  s = db->Get(read_options, "a", &value);
+  ASSERT_OK(s);
+  ASSERT_EQ("value.a", value);
+
+  // Test WriteWithCallback
+  WriteCallbackTestWriteCallback1 callback1;
+  WriteBatch wb2;
+
+  ASSERT_OK(wb2.Put("a", "value.a2"));
+
+  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
+  ASSERT_OK(s);
+  ASSERT_TRUE(callback1.was_called);
+
+  s = db->Get(read_options, "a", &value);
+  ASSERT_OK(s);
+  ASSERT_EQ("value.a2", value);
+
+  // Test WriteWithCallback for a callback that fails
+  WriteCallbackTestWriteCallback2 callback2;
+  WriteBatch wb3;
+
+  ASSERT_OK(wb3.Put("a", "value.a3"));
+
+  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
+  ASSERT_NOK(s);
+
+  s = db->Get(read_options, "a", &value);
+  ASSERT_OK(s);
+  ASSERT_EQ("value.a2", value);
+
+  delete db;
+  ASSERT_OK(DestroyDB(dbname, options));
+}
+
+}  // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+  ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
+
+#else
+#include <stdio.h>
+
+int main(int /*argc*/, char** /*argv*/) {
+  fprintf(stderr,
+          "SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n");
+  return 0;
+}
+
+#endif  // !ROCKSDB_LITE
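
For readers skimming the diff, the core contract under test can be reduced to a short standalone program: a WriteCallback whose Callback() returns a non-OK status causes the whole batch to be rejected before it reaches the memtable. The following is a minimal sketch, not part of the diff; it uses only calls that appear in the test above (DB::Open, the dynamic_cast to DBImpl, DBImpl::WriteWithCallback) and assumes a build with RocksDB's internal headers on the include path, since WriteWithCallback lives on DBImpl rather than the public DB interface. The database path /tmp/write_callback_demo is an arbitrary placeholder.

// write_callback_demo.cc: minimal sketch of the API exercised by the test.
// Assumes RocksDB internal headers are available, as in the test above.
#include <cassert>
#include <string>

#include "db/db_impl/db_impl.h"
#include "db/write_callback.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"

using namespace ROCKSDB_NAMESPACE;

// A callback that vetoes every write, mirroring
// WriteCallbackTestWriteCallback2 in the test.
class RejectingCallback : public WriteCallback {
 public:
  Status Callback(DB* /*db*/) override { return Status::Busy(); }
  bool AllowWriteBatching() override { return true; }
};

int main() {
  Options options;
  options.create_if_missing = true;

  DB* db = nullptr;
  Status s = DB::Open(options, "/tmp/write_callback_demo", &db);
  assert(s.ok());

  // WriteWithCallback is on DBImpl, so downcast exactly as the test does.
  auto* db_impl = dynamic_cast<DBImpl*>(db);
  assert(db_impl != nullptr);

  WriteBatch batch;
  Status put_s = batch.Put("k", "v");
  assert(put_s.ok());

  RejectingCallback cb;
  s = db_impl->WriteWithCallback(WriteOptions(), &batch, &cb);
  // The callback returned Busy, so the batch was never applied.
  assert(s.IsBusy());

  std::string value;
  assert(db->Get(ReadOptions(), "k", &value).IsNotFound());

  delete db;
  return 0;
}

This is the same fail-then-verify pattern the test applies concurrently across whole write groups: a rejected callback must leave no trace of its batch, while batches whose callbacks succeed must remain readable.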