From e6918187568dbd01842d8d1d2c808ce16a894239 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 21 Apr 2024 13:54:28 +0200 Subject: Adding upstream version 18.2.2. Signed-off-by: Daniel Baumann --- src/rocksdb/util/thread_local_test.cc | 582 ++++++++++++++++++++++++++++++++++ 1 file changed, 582 insertions(+) create mode 100644 src/rocksdb/util/thread_local_test.cc (limited to 'src/rocksdb/util/thread_local_test.cc') diff --git a/src/rocksdb/util/thread_local_test.cc b/src/rocksdb/util/thread_local_test.cc new file mode 100644 index 000000000..25ef5c0ee --- /dev/null +++ b/src/rocksdb/util/thread_local_test.cc @@ -0,0 +1,582 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "util/thread_local.h" + +#include +#include +#include + +#include "port/port.h" +#include "rocksdb/env.h" +#include "test_util/sync_point.h" +#include "test_util/testharness.h" +#include "test_util/testutil.h" +#include "util/autovector.h" + +namespace ROCKSDB_NAMESPACE { + +class ThreadLocalTest : public testing::Test { + public: + ThreadLocalTest() : env_(Env::Default()) {} + + Env* env_; +}; + +namespace { + +struct Params { + Params(port::Mutex* m, port::CondVar* c, int* u, int n, + UnrefHandler handler = nullptr) + : mu(m), + cv(c), + unref(u), + total(n), + started(0), + completed(0), + doWrite(false), + tls1(handler), + tls2(nullptr) {} + + port::Mutex* mu; + port::CondVar* cv; + int* unref; + int total; + int started; + int completed; + bool doWrite; + ThreadLocalPtr tls1; + ThreadLocalPtr* tls2; +}; + +class IDChecker : public ThreadLocalPtr { + public: + static uint32_t PeekId() { return TEST_PeekId(); } +}; + +} // anonymous namespace + +// Suppress false positive clang analyzer warnings. +#ifndef __clang_analyzer__ +TEST_F(ThreadLocalTest, UniqueIdTest) { + port::Mutex mu; + port::CondVar cv(&mu); + + uint32_t base_id = IDChecker::PeekId(); + // New ThreadLocal instance bumps id by 1 + { + // Id used 0 + Params p1(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); + // Id used 1 + Params p2(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 2u); + // Id used 2 + Params p3(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); + // Id used 3 + Params p4(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 4u); + } + // id 3, 2, 1, 0 are in the free queue in order + ASSERT_EQ(IDChecker::PeekId(), base_id + 0u); + + // pick up 0 + Params p1(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); + // pick up 1 + Params* p2 = new Params(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 2u); + // pick up 2 + Params p3(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); + // return up 1 + delete p2; + ASSERT_EQ(IDChecker::PeekId(), base_id + 1u); + // Now we have 3, 1 in queue + // pick up 1 + Params p4(&mu, &cv, nullptr, 1u); + ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); + // pick up 3 + Params p5(&mu, &cv, nullptr, 1u); + // next new id + ASSERT_EQ(IDChecker::PeekId(), base_id + 4u); + // After exit, id sequence in queue: + // 3, 1, 2, 0 +} +#endif // __clang_analyzer__ + +TEST_F(ThreadLocalTest, SequentialReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + uint32_t base_id = IDChecker::PeekId(); + + port::Mutex mu; + port::CondVar cv(&mu); + Params p(&mu, &cv, nullptr, 1); + ThreadLocalPtr tls2; + p.tls2 = &tls2; + + ASSERT_GT(IDChecker::PeekId(), base_id); + base_id = IDChecker::PeekId(); + + auto func = [](Params* ptr) { + Params& params = *ptr; + ASSERT_TRUE(params.tls1.Get() == nullptr); + params.tls1.Reset(reinterpret_cast(1)); + ASSERT_TRUE(params.tls1.Get() == reinterpret_cast(1)); + params.tls1.Reset(reinterpret_cast(2)); + ASSERT_TRUE(params.tls1.Get() == reinterpret_cast(2)); + + ASSERT_TRUE(params.tls2->Get() == nullptr); + params.tls2->Reset(reinterpret_cast(1)); + ASSERT_TRUE(params.tls2->Get() == reinterpret_cast(1)); + params.tls2->Reset(reinterpret_cast(2)); + ASSERT_TRUE(params.tls2->Get() == reinterpret_cast(2)); + + params.mu->Lock(); + ++(params.completed); + params.cv->SignalAll(); + params.mu->Unlock(); + }; + + for (int iter = 0; iter < 1024; ++iter) { + ASSERT_EQ(IDChecker::PeekId(), base_id); + // Another new thread, read/write should not see value from previous thread + env_->StartThreadTyped(func, &p); + + mu.Lock(); + while (p.completed != iter + 1) { + cv.Wait(); + } + mu.Unlock(); + ASSERT_EQ(IDChecker::PeekId(), base_id); + } +} + +TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + uint32_t base_id = IDChecker::PeekId(); + + ThreadLocalPtr tls2; + port::Mutex mu1; + port::CondVar cv1(&mu1); + Params p1(&mu1, &cv1, nullptr, 16); + p1.tls2 = &tls2; + + port::Mutex mu2; + port::CondVar cv2(&mu2); + Params p2(&mu2, &cv2, nullptr, 16); + p2.doWrite = true; + p2.tls2 = &tls2; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + // Size_T switches size along with the ptr size + // we want to cast to. + size_t own = ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + // Let write threads write a different value from the read threads + if (p.doWrite) { + own += 8192; + } + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + auto* env = Env::Default(); + auto start = env->NowMicros(); + + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + // Loop for 1 second + while (env->NowMicros() - start < 1000 * 1000) { + for (int iter = 0; iter < 100000; ++iter) { + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(own)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(own + 1)); + if (p.doWrite) { + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + } + } + } + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + p.mu->Unlock(); + }; + + // Initiate 2 instnaces: one keeps writing and one keeps reading. + // The read instance should not see data from the write instance. + // Each thread local copy of the value are also different from each + // other. + for (int th = 0; th < p1.total; ++th) { + env_->StartThreadTyped(func, &p1); + } + for (int th = 0; th < p2.total; ++th) { + env_->StartThreadTyped(func, &p2); + } + + mu1.Lock(); + while (p1.completed != p1.total) { + cv1.Wait(); + } + mu1.Unlock(); + + mu2.Lock(); + while (p2.completed != p2.total) { + cv2.Wait(); + } + mu2.Unlock(); + + ASSERT_EQ(IDChecker::PeekId(), base_id + 3u); +} + +TEST_F(ThreadLocalTest, Unref) { + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + // Case 0: no unref triggered if ThreadLocalPtr is never accessed + auto func0 = [](Params* ptr) { + auto& p = *ptr; + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThreadTyped(func0, &p); + } + env_->WaitForJoin(); + ASSERT_EQ(unref_count, 0); + } + + // Case 1: unref triggered by thread exit + auto func1 = [](Params* ptr) { + auto& p = *ptr; + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + ThreadLocalPtr tls2(unref); + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = &tls2; + + for (int i = 0; i < p.total; ++i) { + env_->StartThreadTyped(func1, &p); + } + + env_->WaitForJoin(); + + // N threads x 2 ThreadLocal instance cleanup on thread exit + ASSERT_EQ(unref_count, 2 * p.total); + } + + // Case 2: unref triggered by ThreadLocal instance destruction + auto func2 = [](Params* ptr) { + auto& p = *ptr; + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThreadTyped(func2, &p); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + // Now destroy one ThreadLocal instance + delete p.tls2; + p.tls2 = nullptr; + // instance destroy for N threads + ASSERT_EQ(unref_count, p.total); + + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + // additional N threads exit unref for the left instance + ASSERT_EQ(unref_count, 2 * p.total); + } +} + +TEST_F(ThreadLocalTest, Swap) { + ThreadLocalPtr tls; + tls.Reset(reinterpret_cast(1)); + ASSERT_EQ(reinterpret_cast(tls.Swap(nullptr)), 1); + ASSERT_TRUE(tls.Swap(reinterpret_cast(2)) == nullptr); + ASSERT_EQ(reinterpret_cast(tls.Get()), 2); + ASSERT_EQ(reinterpret_cast(tls.Swap(reinterpret_cast(3))), 2); +} + +TEST_F(ThreadLocalTest, Scrape) { + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThreadTyped(func, &p); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + ASSERT_EQ(unref_count, 0); + + // Scrape all thread local data. No unref at thread + // exit or ThreadLocalPtr destruction + autovector ptrs; + p.tls1.Scrape(&ptrs, nullptr); + p.tls2->Scrape(&ptrs, nullptr); + delete p.tls2; + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + + ASSERT_EQ(unref_count, 0); + } +} + +TEST_F(ThreadLocalTest, Fold) { + auto unref = [](void* ptr) { + delete static_cast*>(ptr); + }; + static const int kNumThreads = 16; + static const int kItersPerThread = 10; + port::Mutex mu; + port::CondVar cv(&mu); + Params params(&mu, &cv, nullptr, kNumThreads, unref); + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + ASSERT_TRUE(p.tls1.Get() == nullptr); + p.tls1.Reset(new std::atomic(0)); + + for (int i = 0; i < kItersPerThread; ++i) { + static_cast*>(p.tls1.Get())->fetch_add(1); + } + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 0; th < params.total; ++th) { + env_->StartThread(func, ¶ms); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (params.completed != params.total) { + cv.Wait(); + } + mu.Unlock(); + + // Verify Fold() behavior + int64_t sum = 0; + params.tls1.Fold( + [](void* ptr, void* res) { + auto sum_ptr = static_cast(res); + *sum_ptr += static_cast*>(ptr)->load(); + }, + &sum); + ASSERT_EQ(sum, kNumThreads * kItersPerThread); + + // Signal to exit + mu.Lock(); + params.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); +} + +TEST_F(ThreadLocalTest, CompareAndSwap) { + ThreadLocalPtr tls; + ASSERT_TRUE(tls.Swap(reinterpret_cast(1)) == nullptr); + void* expected = reinterpret_cast(1); + // Swap in 2 + ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast(2), expected)); + expected = reinterpret_cast(100); + // Fail Swap, still 2 + ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast(2), expected)); + ASSERT_EQ(expected, reinterpret_cast(2)); + // Swap in 3 + expected = reinterpret_cast(2); + ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast(3), expected)); + ASSERT_EQ(tls.Get(), reinterpret_cast(3)); +} + +namespace { + +void* AccessThreadLocal(void* /*arg*/) { + TEST_SYNC_POINT("AccessThreadLocal:Start"); + ThreadLocalPtr tlp; + tlp.Reset(new std::string("hello RocksDB")); + TEST_SYNC_POINT("AccessThreadLocal:End"); + return nullptr; +} + +} // namespace + +// The following test is disabled as it requires manual steps to run it +// correctly. +// +// Currently we have no way to acess SyncPoint w/o ASAN error when the +// child thread dies after the main thread dies. So if you manually enable +// this test and only see an ASAN error on SyncPoint, it means you pass the +// test. +TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) { + ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency( + {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"}, + {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}}); + + // Triggers the initialization of singletons. + Env::Default(); + +#ifndef ROCKSDB_LITE + try { +#endif // ROCKSDB_LITE + ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr); + th.detach(); + TEST_SYNC_POINT("MainThreadDiesFirst:End"); +#ifndef ROCKSDB_LITE + } catch (const std::system_error& ex) { + std::cerr << "Start thread: " << ex.code() << std::endl; + FAIL(); + } +#endif // ROCKSDB_LITE +} + +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- cgit v1.2.3