 src/rocksdb/cache/sharded_cache.h | 322 ++++++++++++++++++++++++++++++++++++
 1 file changed, 322 insertions(+), 0 deletions(-)
diff --git a/src/rocksdb/cache/sharded_cache.h b/src/rocksdb/cache/sharded_cache.h
new file mode 100644
index 000000000..e3271cc7b
--- /dev/null
+++ b/src/rocksdb/cache/sharded_cache.h
@@ -0,0 +1,322 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "port/lang.h"
+#include "port/port.h"
+#include "rocksdb/cache.h"
+#include "util/hash.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Optional base class for classes implementing the CacheShard concept
+class CacheShardBase {
+ public:
+ explicit CacheShardBase(CacheMetadataChargePolicy metadata_charge_policy)
+ : metadata_charge_policy_(metadata_charge_policy) {}
+
+ using DeleterFn = Cache::DeleterFn;
+
+  // Expected by concept CacheShard (TODO with C++20 support)
+  // Reasonable defaults, which implementations may inherit or shadow:
+ std::string GetPrintableOptions() const { return ""; }
+ using HashVal = uint64_t;
+ using HashCref = uint64_t;
+ static inline HashVal ComputeHash(const Slice& key) {
+ return GetSliceNPHash64(key);
+ }
+ static inline uint32_t HashPieceForSharding(HashCref hash) {
+ return Lower32of64(hash);
+ }
+ void AppendPrintableOptions(std::string& /*str*/) const {}
+
+ // Must be provided for concept CacheShard (TODO with C++20 support)
+ /*
+ struct HandleImpl { // for concept HandleImpl
+ HashVal hash;
+ HashCref GetHash() const;
+ ...
+ };
+ Status Insert(const Slice& key, HashCref hash, void* value, size_t charge,
+ DeleterFn deleter, HandleImpl** handle,
+ Cache::Priority priority) = 0;
+ Status Insert(const Slice& key, HashCref hash, void* value,
+ const Cache::CacheItemHelper* helper, size_t charge,
+ HandleImpl** handle, Cache::Priority priority) = 0;
+ HandleImpl* Lookup(const Slice& key, HashCref hash) = 0;
+ HandleImpl* Lookup(const Slice& key, HashCref hash,
+ const Cache::CacheItemHelper* helper,
+ const Cache::CreateCallback& create_cb,
+ Cache::Priority priority, bool wait,
+ Statistics* stats) = 0;
+ bool Release(HandleImpl* handle, bool useful, bool erase_if_last_ref) = 0;
+ bool IsReady(HandleImpl* handle) = 0;
+ void Wait(HandleImpl* handle) = 0;
+ bool Ref(HandleImpl* handle) = 0;
+ void Erase(const Slice& key, HashCref hash) = 0;
+ void SetCapacity(size_t capacity) = 0;
+ void SetStrictCapacityLimit(bool strict_capacity_limit) = 0;
+ size_t GetUsage() const = 0;
+ size_t GetPinnedUsage() const = 0;
+ size_t GetOccupancyCount() const = 0;
+ size_t GetTableAddressCount() const = 0;
+  // Handles iterating over roughly `average_entries_per_lock` entries, using
+  // `state` to record where the previous call left off. The caller initially
+  // passes *state == 0, and the implementation sets *state = SIZE_MAX to
+  // indicate completion.
+ void ApplyToSomeEntries(
+ const std::function<void(const Slice& key, void* value, size_t charge,
+ DeleterFn deleter)>& callback,
+ size_t average_entries_per_lock, size_t* state) = 0;
+ void EraseUnRefEntries() = 0;
+ */
+
+ protected:
+ const CacheMetadataChargePolicy metadata_charge_policy_;
+};
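+
+// For illustration: a hypothetical shard type (MyCacheShard is not a real
+// RocksDB class) might build on CacheShardBase roughly like this, inheriting
+// the hashing defaults above and adding the members the concept requires:
+//
+//   class MyCacheShard : public CacheShardBase {
+//    public:
+//     explicit MyCacheShard(CacheMetadataChargePolicy policy)
+//         : CacheShardBase(policy) {}
+//     struct HandleImpl {
+//       HashVal hash;
+//       HashCref GetHash() const { return hash; }
+//       // ... value, charge, deleter, reference count, etc.
+//     };
+//     // ... Insert/Lookup/Release/Erase/etc. as listed above
+//   };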
+
+// Portions of ShardedCache that do not depend on the template parameter
+class ShardedCacheBase : public Cache {
+ public:
+ ShardedCacheBase(size_t capacity, int num_shard_bits,
+ bool strict_capacity_limit,
+ std::shared_ptr<MemoryAllocator> memory_allocator);
+ virtual ~ShardedCacheBase() = default;
+
+ int GetNumShardBits() const;
+ uint32_t GetNumShards() const;
+
+ uint64_t NewId() override;
+
+ bool HasStrictCapacityLimit() const override;
+ size_t GetCapacity() const override;
+
+ using Cache::GetUsage;
+ size_t GetUsage(Handle* handle) const override;
+ std::string GetPrintableOptions() const override;
+
+ protected: // fns
+ virtual void AppendPrintableOptions(std::string& str) const = 0;
+ size_t GetPerShardCapacity() const;
+ size_t ComputePerShardCapacity(size_t capacity) const;
+
+ protected: // data
+ std::atomic<uint64_t> last_id_; // For NewId
+ const uint32_t shard_mask_;
+
+ // Dynamic configuration parameters, guarded by config_mutex_
+ bool strict_capacity_limit_;
+ size_t capacity_;
+ mutable port::Mutex config_mutex_;
+};
+
+// Generic cache interface that shards cache by hash of keys. 2^num_shard_bits
+// shards will be created, with capacity split evenly to each of the shards.
+// Keys are typically sharded by the lowest num_shard_bits bits of the hash
+// value, so that the upper bits of the hash can keep a stable ordering of
+// table entries even as the table grows (by using more of the upper hash
+// bits).
+// See CacheShardBase above for what is expected of the CacheShard parameter.
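+//
+// For example (illustrative numbers): with num_shard_bits == 6 there are
+// 2^6 == 64 shards and shard_mask_ == 63, so a key is routed to
+//   shards_[CacheShard::HashPieceForSharding(hash) & 63]
+// which, under the CacheShardBase defaults, selects on the low 32 bits of
+// the 64-bit key hash masked down to 6 bits.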
+template <class CacheShard>
+class ShardedCache : public ShardedCacheBase {
+ public:
+ using HashVal = typename CacheShard::HashVal;
+ using HashCref = typename CacheShard::HashCref;
+ using HandleImpl = typename CacheShard::HandleImpl;
+
+ ShardedCache(size_t capacity, int num_shard_bits, bool strict_capacity_limit,
+ std::shared_ptr<MemoryAllocator> allocator)
+ : ShardedCacheBase(capacity, num_shard_bits, strict_capacity_limit,
+ allocator),
+ shards_(reinterpret_cast<CacheShard*>(port::cacheline_aligned_alloc(
+ sizeof(CacheShard) * GetNumShards()))),
+ destroy_shards_in_dtor_(false) {}
+
+ virtual ~ShardedCache() {
+ if (destroy_shards_in_dtor_) {
+ ForEachShard([](CacheShard* cs) { cs->~CacheShard(); });
+ }
+ port::cacheline_aligned_free(shards_);
+ }
+
+ CacheShard& GetShard(HashCref hash) {
+ return shards_[CacheShard::HashPieceForSharding(hash) & shard_mask_];
+ }
+
+ const CacheShard& GetShard(HashCref hash) const {
+ return shards_[CacheShard::HashPieceForSharding(hash) & shard_mask_];
+ }
+
+ void SetCapacity(size_t capacity) override {
+ MutexLock l(&config_mutex_);
+ capacity_ = capacity;
+ auto per_shard = ComputePerShardCapacity(capacity);
+ ForEachShard([=](CacheShard* cs) { cs->SetCapacity(per_shard); });
+ }
+
+ void SetStrictCapacityLimit(bool s_c_l) override {
+ MutexLock l(&config_mutex_);
+ strict_capacity_limit_ = s_c_l;
+ ForEachShard(
+ [s_c_l](CacheShard* cs) { cs->SetStrictCapacityLimit(s_c_l); });
+ }
+
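+  // Note: Cache::Handle is opaque to callers; the overrides below convert
+  // between Handle* and the shard's concrete HandleImpl* via
+  // reinterpret_cast. This is safe because every Handle* returned by this
+  // class originated as a HandleImpl* from one of its shards.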
+ Status Insert(const Slice& key, void* value, size_t charge, DeleterFn deleter,
+ Handle** handle, Priority priority) override {
+ HashVal hash = CacheShard::ComputeHash(key);
+ auto h_out = reinterpret_cast<HandleImpl**>(handle);
+ return GetShard(hash).Insert(key, hash, value, charge, deleter, h_out,
+ priority);
+ }
+ Status Insert(const Slice& key, void* value, const CacheItemHelper* helper,
+ size_t charge, Handle** handle = nullptr,
+ Priority priority = Priority::LOW) override {
+ if (!helper) {
+ return Status::InvalidArgument();
+ }
+ HashVal hash = CacheShard::ComputeHash(key);
+ auto h_out = reinterpret_cast<HandleImpl**>(handle);
+ return GetShard(hash).Insert(key, hash, value, helper, charge, h_out,
+ priority);
+ }
+
+ Handle* Lookup(const Slice& key, Statistics* /*stats*/) override {
+ HashVal hash = CacheShard::ComputeHash(key);
+ HandleImpl* result = GetShard(hash).Lookup(key, hash);
+ return reinterpret_cast<Handle*>(result);
+ }
+ Handle* Lookup(const Slice& key, const CacheItemHelper* helper,
+ const CreateCallback& create_cb, Priority priority, bool wait,
+ Statistics* stats = nullptr) override {
+ HashVal hash = CacheShard::ComputeHash(key);
+ HandleImpl* result = GetShard(hash).Lookup(key, hash, helper, create_cb,
+ priority, wait, stats);
+ return reinterpret_cast<Handle*>(result);
+ }
+
+ void Erase(const Slice& key) override {
+ HashVal hash = CacheShard::ComputeHash(key);
+ GetShard(hash).Erase(key, hash);
+ }
+
+ bool Release(Handle* handle, bool useful,
+ bool erase_if_last_ref = false) override {
+ auto h = reinterpret_cast<HandleImpl*>(handle);
+ return GetShard(h->GetHash()).Release(h, useful, erase_if_last_ref);
+ }
+ bool IsReady(Handle* handle) override {
+ auto h = reinterpret_cast<HandleImpl*>(handle);
+ return GetShard(h->GetHash()).IsReady(h);
+ }
+ void Wait(Handle* handle) override {
+ auto h = reinterpret_cast<HandleImpl*>(handle);
+ GetShard(h->GetHash()).Wait(h);
+ }
+ bool Ref(Handle* handle) override {
+ auto h = reinterpret_cast<HandleImpl*>(handle);
+ return GetShard(h->GetHash()).Ref(h);
+ }
+ bool Release(Handle* handle, bool erase_if_last_ref = false) override {
+ return Release(handle, true /*useful*/, erase_if_last_ref);
+ }
+ using ShardedCacheBase::GetUsage;
+ size_t GetUsage() const override {
+ return SumOverShards2(&CacheShard::GetUsage);
+ }
+ size_t GetPinnedUsage() const override {
+ return SumOverShards2(&CacheShard::GetPinnedUsage);
+ }
+  size_t GetOccupancyCount() const override {
+    return SumOverShards2(&CacheShard::GetOccupancyCount);
+  }
+ size_t GetTableAddressCount() const override {
+ return SumOverShards2(&CacheShard::GetTableAddressCount);
+ }
+ void ApplyToAllEntries(
+ const std::function<void(const Slice& key, void* value, size_t charge,
+ DeleterFn deleter)>& callback,
+ const ApplyToAllEntriesOptions& opts) override {
+ uint32_t num_shards = GetNumShards();
+ // Iterate over part of each shard, rotating between shards, to
+ // minimize impact on latency of concurrent operations.
+ std::unique_ptr<size_t[]> states(new size_t[num_shards]{});
+
+    size_t aepl = opts.average_entries_per_lock;
+    // Guard against 0, which would prevent forward progress below.
+    aepl = std::max(aepl, size_t{1});
+
+ bool remaining_work;
+ do {
+ remaining_work = false;
+ for (uint32_t i = 0; i < num_shards; i++) {
+ if (states[i] != SIZE_MAX) {
+ shards_[i].ApplyToSomeEntries(callback, aepl, &states[i]);
+ remaining_work |= states[i] != SIZE_MAX;
+ }
+ }
+ } while (remaining_work);
+ }
+
+  void EraseUnRefEntries() override {
+ ForEachShard([](CacheShard* cs) { cs->EraseUnRefEntries(); });
+ }
+
+ void DisownData() override {
+ // Leak data only if that won't generate an ASAN/valgrind warning.
+ if (!kMustFreeHeapAllocations) {
+ destroy_shards_in_dtor_ = false;
+ }
+ }
+
+ protected:
+ inline void ForEachShard(const std::function<void(CacheShard*)>& fn) {
+ uint32_t num_shards = GetNumShards();
+ for (uint32_t i = 0; i < num_shards; i++) {
+ fn(shards_ + i);
+ }
+ }
+
+ inline size_t SumOverShards(
+ const std::function<size_t(CacheShard&)>& fn) const {
+ uint32_t num_shards = GetNumShards();
+ size_t result = 0;
+ for (uint32_t i = 0; i < num_shards; i++) {
+ result += fn(shards_[i]);
+ }
+ return result;
+ }
+
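+  // Convenience overload taking a pointer to a const member function of
+  // CacheShard, e.g. SumOverShards2(&CacheShard::GetUsage), which is
+  // invoked on each shard as (cs.*fn)().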
+ inline size_t SumOverShards2(size_t (CacheShard::*fn)() const) const {
+ return SumOverShards([fn](CacheShard& cs) { return (cs.*fn)(); });
+ }
+
+  // Must be called exactly once by the derived class's constructor.
+ void InitShards(const std::function<void(CacheShard*)>& placement_new) {
+ ForEachShard(placement_new);
+ destroy_shards_in_dtor_ = true;
+ }
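+
+  // For illustration: a hypothetical derived cache (MyShardedCache is not a
+  // real RocksDB class) would typically placement-new its shards into the
+  // pre-allocated, cacheline-aligned shards_ array from its constructor:
+  //
+  //   MyShardedCache::MyShardedCache(size_t capacity, int num_shard_bits,
+  //                                  bool strict_capacity_limit)
+  //       : ShardedCache(capacity, num_shard_bits, strict_capacity_limit,
+  //                      /*allocator=*/nullptr) {
+  //     InitShards([&](CacheShard* cs) {
+  //       new (cs) CacheShard(kDontChargeCacheMetadata /*, ... */);
+  //     });
+  //   }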
+
+ void AppendPrintableOptions(std::string& str) const override {
+ shards_[0].AppendPrintableOptions(str);
+ }
+
+ private:
+ CacheShard* const shards_;
+ bool destroy_shards_in_dtor_;
+};
+
+// 512KB is the traditional minimum shard size.
+int GetDefaultCacheShardBits(size_t capacity,
+ size_t min_shard_size = 512U * 1024U);
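+// For example (illustrative arithmetic): with the default 512KB minimum
+// shard size, an 8MB cache capacity supports at most 8MB / 512KB == 16
+// shards, so the computed result would be at most 4 shard bits
+// (2^4 == 16 shards).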
+
+} // namespace ROCKSDB_NAMESPACE