// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. // This source code is licensed under both the GPLv2 (found in the // COPYING file in the root directory) and Apache 2.0 License // (found in the LICENSE.Apache file in the root directory). #include #include #include #include "port/port.h" #include "rocksdb/env.h" #include "test_util/sync_point.h" #include "test_util/testharness.h" #include "test_util/testutil.h" #include "util/autovector.h" #include "util/thread_local.h" namespace ROCKSDB_NAMESPACE { class ThreadLocalTest : public testing::Test { public: ThreadLocalTest() : env_(Env::Default()) {} Env* env_; }; namespace { struct Params { Params(port::Mutex* m, port::CondVar* c, int* u, int n, UnrefHandler handler = nullptr) : mu(m), cv(c), unref(u), total(n), started(0), completed(0), doWrite(false), tls1(handler), tls2(nullptr) {} port::Mutex* mu; port::CondVar* cv; int* unref; int total; int started; int completed; bool doWrite; ThreadLocalPtr tls1; ThreadLocalPtr* tls2; }; class IDChecker : public ThreadLocalPtr { public: static uint32_t PeekId() { return TEST_PeekId(); } }; } // anonymous namespace // Suppress false positive clang analyzer warnings. #ifndef __clang_analyzer__ TEST_F(ThreadLocalTest, UniqueIdTest) { port::Mutex mu; port::CondVar cv(&mu); uint32_t base_id = IDChecker::PeekId(); // New ThreadLocal instance bumps id by 1 { // Id used 0 Params p1(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); // Id used 1 Params p2(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 2u); // Id used 2 Params p3(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); // Id used 3 Params p4(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 4u); } // id 3, 2, 1, 0 are in the free queue in order ASSERT_EQ(IDChecker::PeekId(), base_id + 0u); // pick up 0 Params p1(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); // pick up 1 Params* p2 = new Params(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 2u); // pick up 2 Params p3(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); // return up 1 delete p2; ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); // Now we have 3, 1 in queue // pick up 1 Params p4(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); // pick up 3 Params p5(&mu, &cv, nullptr, 1u); // next new id ASSERT_EQ(IDChecker::PeekId(), base_id + 4u); // After exit, id sequence in queue: // 3, 1, 2, 0 } #endif // __clang_analyzer__ TEST_F(ThreadLocalTest, SequentialReadWriteTest) { // global id list carries over 3, 1, 2, 0 uint32_t base_id = IDChecker::PeekId(); port::Mutex mu; port::CondVar cv(&mu); Params p(&mu, &cv, nullptr, 1); ThreadLocalPtr tls2; p.tls2 = &tls2; auto func = [](void* ptr) { auto& params = *static_cast(ptr); ASSERT_TRUE(params.tls1.Get() == nullptr); params.tls1.Reset(reinterpret_cast(1)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast(1)); params.tls1.Reset(reinterpret_cast(2)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast(2)); ASSERT_TRUE(params.tls2->Get() == nullptr); params.tls2->Reset(reinterpret_cast(1)); ASSERT_TRUE(params.tls2->Get() == reinterpret_cast(1)); params.tls2->Reset(reinterpret_cast(2)); ASSERT_TRUE(params.tls2->Get() == reinterpret_cast(2)); params.mu->Lock(); ++(params.completed); params.cv->SignalAll(); params.mu->Unlock(); }; for (int iter = 0; iter < 1024; ++iter) { ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); // Another new thread, read/write should not see value from previous thread env_->StartThread(func, static_cast(&p)); mu.Lock(); while (p.completed != iter + 1) { cv.Wait(); } mu.Unlock(); ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); } } TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) { // global id list carries over 3, 1, 2, 0 uint32_t base_id = IDChecker::PeekId(); ThreadLocalPtr tls2; port::Mutex mu1; port::CondVar cv1(&mu1); Params p1(&mu1, &cv1, nullptr, 16); p1.tls2 = &tls2; port::Mutex mu2; port::CondVar cv2(&mu2); Params p2(&mu2, &cv2, nullptr, 16); p2.doWrite = true; p2.tls2 = &tls2; auto func = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); // Size_T switches size along with the ptr size // we want to cast to. size_t own = ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); // Let write threads write a different value from the read threads if (p.doWrite) { own += 8192; } ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); auto* env = Env::Default(); auto start = env->NowMicros(); p.tls1.Reset(reinterpret_cast(own)); p.tls2->Reset(reinterpret_cast(own + 1)); // Loop for 1 second while (env->NowMicros() - start < 1000 * 1000) { for (int iter = 0; iter < 100000; ++iter) { ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(own)); ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(own + 1)); if (p.doWrite) { p.tls1.Reset(reinterpret_cast(own)); p.tls2->Reset(reinterpret_cast(own + 1)); } } } p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); p.mu->Unlock(); }; // Initiate 2 instnaces: one keeps writing and one keeps reading. // The read instance should not see data from the write instance. // Each thread local copy of the value are also different from each // other. for (int th = 0; th < p1.total; ++th) { env_->StartThread(func, static_cast(&p1)); } for (int th = 0; th < p2.total; ++th) { env_->StartThread(func, static_cast(&p2)); } mu1.Lock(); while (p1.completed != p1.total) { cv1.Wait(); } mu1.Unlock(); mu2.Lock(); while (p2.completed != p2.total) { cv2.Wait(); } mu2.Unlock(); ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); } TEST_F(ThreadLocalTest, Unref) { auto unref = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); ++(*p.unref); p.mu->Unlock(); }; // Case 0: no unref triggered if ThreadLocalPtr is never accessed auto func0 = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func0, static_cast(&p)); } env_->WaitForJoin(); ASSERT_EQ(unref_count, 0); } // Case 1: unref triggered by thread exit auto func1 = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; ThreadLocalPtr tls2(unref); Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = &tls2; for (int i = 0; i < p.total; ++i) { env_->StartThread(func1, static_cast(&p)); } env_->WaitForJoin(); // N threads x 2 ThreadLocal instance cleanup on thread exit ASSERT_EQ(unref_count, 2 * p.total); } // Case 2: unref triggered by ThreadLocal instance destruction auto func2 = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); // Waiting for instruction to exit thread while (p.completed != 0) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func2, static_cast(&p)); } // Wait for all threads to finish using Params mu.Lock(); while (p.completed != p.total) { cv.Wait(); } mu.Unlock(); // Now destroy one ThreadLocal instance delete p.tls2; p.tls2 = nullptr; // instance destroy for N threads ASSERT_EQ(unref_count, p.total); // Signal to exit mu.Lock(); p.completed = 0; cv.SignalAll(); mu.Unlock(); env_->WaitForJoin(); // additional N threads exit unref for the left instance ASSERT_EQ(unref_count, 2 * p.total); } } TEST_F(ThreadLocalTest, Swap) { ThreadLocalPtr tls; tls.Reset(reinterpret_cast(1)); ASSERT_EQ(reinterpret_cast(tls.Swap(nullptr)), 1); ASSERT_TRUE(tls.Swap(reinterpret_cast(2)) == nullptr); ASSERT_EQ(reinterpret_cast(tls.Get()), 2); ASSERT_EQ(reinterpret_cast(tls.Swap(reinterpret_cast(3))), 2); } TEST_F(ThreadLocalTest, Scrape) { auto unref = [](void* ptr) { auto& p = *static_cast(ptr); p.mu->Lock(); ++(*p.unref); p.mu->Unlock(); }; auto func = [](void* ptr) { auto& p = *static_cast(ptr); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); // Waiting for instruction to exit thread while (p.completed != 0) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func, static_cast(&p)); } // Wait for all threads to finish using Params mu.Lock(); while (p.completed != p.total) { cv.Wait(); } mu.Unlock(); ASSERT_EQ(unref_count, 0); // Scrape all thread local data. No unref at thread // exit or ThreadLocalPtr destruction autovector ptrs; p.tls1.Scrape(&ptrs, nullptr); p.tls2->Scrape(&ptrs, nullptr); delete p.tls2; // Signal to exit mu.Lock(); p.completed = 0; cv.SignalAll(); mu.Unlock(); env_->WaitForJoin(); ASSERT_EQ(unref_count, 0); } } TEST_F(ThreadLocalTest, Fold) { auto unref = [](void* ptr) { delete static_cast*>(ptr); }; static const int kNumThreads = 16; static const int kItersPerThread = 10; port::Mutex mu; port::CondVar cv(&mu); Params params(&mu, &cv, nullptr, kNumThreads, unref); auto func = [](void* ptr) { auto& p = *static_cast(ptr); ASSERT_TRUE(p.tls1.Get() == nullptr); p.tls1.Reset(new std::atomic(0)); for (int i = 0; i < kItersPerThread; ++i) { static_cast*>(p.tls1.Get())->fetch_add(1); } p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); // Waiting for instruction to exit thread while (p.completed != 0) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 0; th < params.total; ++th) { env_->StartThread(func, static_cast(¶ms)); } // Wait for all threads to finish using Params mu.Lock(); while (params.completed != params.total) { cv.Wait(); } mu.Unlock(); // Verify Fold() behavior int64_t sum = 0; params.tls1.Fold( [](void* ptr, void* res) { auto sum_ptr = static_cast(res); *sum_ptr += static_cast*>(ptr)->load(); }, &sum); ASSERT_EQ(sum, kNumThreads * kItersPerThread); // Signal to exit mu.Lock(); params.completed = 0; cv.SignalAll(); mu.Unlock(); env_->WaitForJoin(); } TEST_F(ThreadLocalTest, CompareAndSwap) { ThreadLocalPtr tls; ASSERT_TRUE(tls.Swap(reinterpret_cast(1)) == nullptr); void* expected = reinterpret_cast(1); // Swap in 2 ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast(2), expected)); expected = reinterpret_cast(100); // Fail Swap, still 2 ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast(2), expected)); ASSERT_EQ(expected, reinterpret_cast(2)); // Swap in 3 expected = reinterpret_cast(2); ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast(3), expected)); ASSERT_EQ(tls.Get(), reinterpret_cast(3)); } namespace { void* AccessThreadLocal(void* /*arg*/) { TEST_SYNC_POINT("AccessThreadLocal:Start"); ThreadLocalPtr tlp; tlp.Reset(new std::string("hello RocksDB")); TEST_SYNC_POINT("AccessThreadLocal:End"); return nullptr; } } // namespace // The following test is disabled as it requires manual steps to run it // correctly. // // Currently we have no way to acess SyncPoint w/o ASAN error when the // child thread dies after the main thread dies. So if you manually enable // this test and only see an ASAN error on SyncPoint, it means you pass the // test. TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) { ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency( {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"}, {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}}); // Triggers the initialization of singletons. Env::Default(); #ifndef ROCKSDB_LITE try { #endif // ROCKSDB_LITE ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr); th.detach(); TEST_SYNC_POINT("MainThreadDiesFirst:End"); #ifndef ROCKSDB_LITE } catch (const std::system_error& ex) { std::cerr << "Start thread: " << ex.code() << std::endl; FAIL(); } #endif // ROCKSDB_LITE } } // namespace ROCKSDB_NAMESPACE int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }