summaryrefslogtreecommitdiffstats
path: root/src/rocksdb/cache/sharded_cache.h
blob: e3271cc7bd3beef0f18ee6bafc6484746c33c26d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).
//
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.

#pragma once

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <string>

#include "port/lang.h"
#include "port/port.h"
#include "rocksdb/cache.h"
#include "util/hash.h"
#include "util/mutexlock.h"

namespace ROCKSDB_NAMESPACE {

// Optional base class for classes implementing the CacheShard concept
class CacheShardBase {
 public:
  explicit CacheShardBase(CacheMetadataChargePolicy metadata_charge_policy)
      : metadata_charge_policy_(metadata_charge_policy) {}

  using DeleterFn = Cache::DeleterFn;

  // Expected by concept CacheShard (TODO with C++20 support)
  // Some Defaults
  std::string GetPrintableOptions() const { return ""; }
  using HashVal = uint64_t;
  using HashCref = uint64_t;
  static inline HashVal ComputeHash(const Slice& key) {
    return GetSliceNPHash64(key);
  }
  static inline uint32_t HashPieceForSharding(HashCref hash) {
    return Lower32of64(hash);
  }
  void AppendPrintableOptions(std::string& /*str*/) const {}

  // Must be provided for concept CacheShard (TODO with C++20 support)
  /*
  struct HandleImpl {  // for concept HandleImpl
    HashVal hash;
    HashCref GetHash() const;
    ...
  };
  Status Insert(const Slice& key, HashCref hash, void* value, size_t charge,
                DeleterFn deleter, HandleImpl** handle,
                Cache::Priority priority) = 0;
  Status Insert(const Slice& key, HashCref hash, void* value,
                const Cache::CacheItemHelper* helper, size_t charge,
                HandleImpl** handle, Cache::Priority priority) = 0;
  HandleImpl* Lookup(const Slice& key, HashCref hash) = 0;
  HandleImpl* Lookup(const Slice& key, HashCref hash,
                        const Cache::CacheItemHelper* helper,
                        const Cache::CreateCallback& create_cb,
                        Cache::Priority priority, bool wait,
                        Statistics* stats) = 0;
  bool Release(HandleImpl* handle, bool useful, bool erase_if_last_ref) = 0;
  bool IsReady(HandleImpl* handle) = 0;
  void Wait(HandleImpl* handle) = 0;
  bool Ref(HandleImpl* handle) = 0;
  void Erase(const Slice& key, HashCref hash) = 0;
  void SetCapacity(size_t capacity) = 0;
  void SetStrictCapacityLimit(bool strict_capacity_limit) = 0;
  size_t GetUsage() const = 0;
  size_t GetPinnedUsage() const = 0;
  size_t GetOccupancyCount() const = 0;
  size_t GetTableAddressCount() const = 0;
  // Handles iterating over roughly `average_entries_per_lock` entries, using
  // `state` to somehow record where it last ended up. Caller initially uses
  // *state == 0 and implementation sets *state = SIZE_MAX to indicate
  // completion.
  void ApplyToSomeEntries(
      const std::function<void(const Slice& key, void* value, size_t charge,
                               DeleterFn deleter)>& callback,
      size_t average_entries_per_lock, size_t* state) = 0;
  void EraseUnRefEntries() = 0;
  */

 protected:
  const CacheMetadataChargePolicy metadata_charge_policy_;
};

// Portions of ShardedCache that do not depend on the template parameter
//
// This class only declares its non-inline members; their definitions are not
// in this header. It holds the state shared by every ShardedCache<CacheShard>
// instantiation: the id counter, the shard mask and the dynamic configuration
// (capacity, strict capacity limit) together with the mutex guarding it.
class ShardedCacheBase : public Cache {
 public:
  // Defined out of line.
  // NOTE(review): presumably derives shard_mask_ from `num_shard_bits` and
  // records the initial capacity / strict limit -- confirm in the .cc file.
  ShardedCacheBase(size_t capacity, int num_shard_bits,
                   bool strict_capacity_limit,
                   std::shared_ptr<MemoryAllocator> memory_allocator);
  virtual ~ShardedCacheBase() = default;

  // Defined out of line. ShardedCache uses GetNumShards() as the length of
  // its shard array.
  int GetNumShardBits() const;
  uint32_t GetNumShards() const;

  // Overrides Cache::NewId; defined out of line.
  uint64_t NewId() override;

  // Accessors for the dynamic configuration; defined out of line.
  bool HasStrictCapacityLimit() const override;
  size_t GetCapacity() const override;

  // Keep the other Cache::GetUsage overloads visible next to the per-handle
  // override declared below.
  using Cache::GetUsage;
  size_t GetUsage(Handle* handle) const override;
  std::string GetPrintableOptions() const override;

 protected:  // fns
  // Hook for the derived class to append shard-specific options to `str`.
  // ShardedCache implements it by delegating to its first shard.
  virtual void AppendPrintableOptions(std::string& str) const = 0;
  // Defined out of line.
  // NOTE(review): presumably the per-shard share of capacity_ and of the
  // given `capacity` respectively -- confirm in the .cc file.
  size_t GetPerShardCapacity() const;
  size_t ComputePerShardCapacity(size_t capacity) const;

 protected:                        // data
  std::atomic<uint64_t> last_id_;  // For NewId
  // ANDed with the sharding piece of a key's hash by ShardedCache::GetShard()
  // to obtain the shard index.
  const uint32_t shard_mask_;

  // Dynamic configuration parameters, guarded by config_mutex_
  bool strict_capacity_limit_;
  size_t capacity_;
  // Mutable so that it can also be locked from const member functions.
  mutable port::Mutex config_mutex_;
};

// Generic cache interface that shards cache by hash of keys. 2^num_shard_bits
// shards will be created, with capacity split evenly to each of the shards.
// Keys are typically sharded by the lowest num_shard_bits bits of hash value
// so that the upper bits of the hash value can keep a stable ordering of
// table entries even as the table grows (using more upper hash bits).
// See CacheShardBase above for what is expected of the CacheShard parameter.
//
// Shard lifetime: the constructor only allocates raw, cacheline-aligned
// storage for GetNumShards() shards. The derived class constructor must
// construct the shards in place through InitShards(); the destructor runs the
// shard destructors only if that happened and DisownData() did not opt out.
template <class CacheShard>
class ShardedCache : public ShardedCacheBase {
 public:
  using HashVal = typename CacheShard::HashVal;
  using HashCref = typename CacheShard::HashCref;
  using HandleImpl = typename CacheShard::HandleImpl;

  // Allocates uninitialized storage for the shard array. No CacheShard object
  // is constructed here; see InitShards().
  ShardedCache(size_t capacity, int num_shard_bits, bool strict_capacity_limit,
               std::shared_ptr<MemoryAllocator> allocator)
      : ShardedCacheBase(capacity, num_shard_bits, strict_capacity_limit,
                         allocator),
        shards_(reinterpret_cast<CacheShard*>(port::cacheline_aligned_alloc(
            sizeof(CacheShard) * GetNumShards()))),
        destroy_shards_in_dtor_(false) {}

  // Destroys the shards (when owned and initialized) and always releases the
  // shard array storage.
  virtual ~ShardedCache() {
    if (destroy_shards_in_dtor_) {
      ForEachShard([](CacheShard* cs) { cs->~CacheShard(); });
    }
    port::cacheline_aligned_free(shards_);
  }

  // Returns the shard responsible for `hash`: the sharding piece of the hash
  // masked with shard_mask_ is the index into the shard array.
  CacheShard& GetShard(HashCref hash) {
    return shards_[CacheShard::HashPieceForSharding(hash) & shard_mask_];
  }

  const CacheShard& GetShard(HashCref hash) const {
    return shards_[CacheShard::HashPieceForSharding(hash) & shard_mask_];
  }

  // Records the new total capacity and pushes the corresponding per-shard
  // capacity to every shard, all while holding config_mutex_.
  void SetCapacity(size_t capacity) override {
    MutexLock l(&config_mutex_);
    capacity_ = capacity;
    auto per_shard = ComputePerShardCapacity(capacity);
    ForEachShard(
        [per_shard](CacheShard* cs) { cs->SetCapacity(per_shard); });
  }

  // Records the new strict-capacity setting and forwards it to every shard,
  // all while holding config_mutex_.
  void SetStrictCapacityLimit(bool s_c_l) override {
    MutexLock l(&config_mutex_);
    strict_capacity_limit_ = s_c_l;
    ForEachShard(
        [s_c_l](CacheShard* cs) { cs->SetStrictCapacityLimit(s_c_l); });
  }

  // Hashes `key` once and forwards the insertion to the owning shard. The
  // public Handle** is reinterpreted as the shard's HandleImpl**.
  Status Insert(const Slice& key, void* value, size_t charge, DeleterFn deleter,
                Handle** handle, Priority priority) override {
    HashVal hash = CacheShard::ComputeHash(key);
    auto h_out = reinterpret_cast<HandleImpl**>(handle);
    return GetShard(hash).Insert(key, hash, value, charge, deleter, h_out,
                                 priority);
  }
  // Helper-based insertion. Returns Status::InvalidArgument() when `helper`
  // is null; otherwise forwards to the owning shard.
  Status Insert(const Slice& key, void* value, const CacheItemHelper* helper,
                size_t charge, Handle** handle = nullptr,
                Priority priority = Priority::LOW) override {
    if (!helper) {
      return Status::InvalidArgument();
    }
    HashVal hash = CacheShard::ComputeHash(key);
    auto h_out = reinterpret_cast<HandleImpl**>(handle);
    return GetShard(hash).Insert(key, hash, value, helper, charge, h_out,
                                 priority);
  }

  // Looks `key` up in the owning shard. The statistics argument is unused by
  // this overload.
  Handle* Lookup(const Slice& key, Statistics* /*stats*/) override {
    HashVal hash = CacheShard::ComputeHash(key);
    HandleImpl* result = GetShard(hash).Lookup(key, hash);
    return reinterpret_cast<Handle*>(result);
  }
  // Helper-based lookup; every argument is passed through to the owning
  // shard unchanged.
  Handle* Lookup(const Slice& key, const CacheItemHelper* helper,
                 const CreateCallback& create_cb, Priority priority, bool wait,
                 Statistics* stats = nullptr) override {
    HashVal hash = CacheShard::ComputeHash(key);
    HandleImpl* result = GetShard(hash).Lookup(key, hash, helper, create_cb,
                                               priority, wait, stats);
    return reinterpret_cast<Handle*>(result);
  }

  // Erases `key` from the owning shard.
  void Erase(const Slice& key) override {
    HashVal hash = CacheShard::ComputeHash(key);
    GetShard(hash).Erase(key, hash);
  }

  // The handle-based operations below locate the owning shard from the hash
  // stored in the handle (HandleImpl::GetHash()) and forward to it.
  bool Release(Handle* handle, bool useful,
               bool erase_if_last_ref = false) override {
    auto h = reinterpret_cast<HandleImpl*>(handle);
    return GetShard(h->GetHash()).Release(h, useful, erase_if_last_ref);
  }
  bool IsReady(Handle* handle) override {
    auto h = reinterpret_cast<HandleImpl*>(handle);
    return GetShard(h->GetHash()).IsReady(h);
  }
  void Wait(Handle* handle) override {
    auto h = reinterpret_cast<HandleImpl*>(handle);
    GetShard(h->GetHash()).Wait(h);
  }
  bool Ref(Handle* handle) override {
    auto h = reinterpret_cast<HandleImpl*>(handle);
    return GetShard(h->GetHash()).Ref(h);
  }
  // Same as the three-argument Release() with useful == true.
  bool Release(Handle* handle, bool erase_if_last_ref = false) override {
    return Release(handle, true /*useful*/, erase_if_last_ref);
  }

  // Aggregate statistics: each one is the sum of the same-named per-shard
  // value over all shards.
  using ShardedCacheBase::GetUsage;
  size_t GetUsage() const override {
    return SumOverShards2(&CacheShard::GetUsage);
  }
  size_t GetPinnedUsage() const override {
    return SumOverShards2(&CacheShard::GetPinnedUsage);
  }
  size_t GetOccupancyCount() const override {
    // Sum the per-shard occupancy counts (not the pinned usage).
    return SumOverShards2(&CacheShard::GetOccupancyCount);
  }
  size_t GetTableAddressCount() const override {
    return SumOverShards2(&CacheShard::GetTableAddressCount);
  }

  // Invokes `callback` for the entries of every shard. Each round visits the
  // shards in order and asks every unfinished shard to process roughly
  // `opts.average_entries_per_lock` entries; rounds repeat until every shard
  // has reported completion by setting its state to SIZE_MAX.
  void ApplyToAllEntries(
      const std::function<void(const Slice& key, void* value, size_t charge,
                               DeleterFn deleter)>& callback,
      const ApplyToAllEntriesOptions& opts) override {
    uint32_t num_shards = GetNumShards();
    // Iterate over part of each shard, rotating between shards, to
    // minimize impact on latency of concurrent operations.
    // One resume state per shard, value-initialized to 0 (the start state).
    std::unique_ptr<size_t[]> states(new size_t[num_shards]{});

    // Honor the requested batch size, but never pass 0 to a shard.
    // NOTE(review): a batch size of 0 would presumably keep a shard from ever
    // advancing its state -- confirm against the shard implementations.
    size_t aepl = opts.average_entries_per_lock;
    aepl = std::max(aepl, size_t{1});

    bool remaining_work;
    do {
      remaining_work = false;
      for (uint32_t i = 0; i < num_shards; i++) {
        if (states[i] != SIZE_MAX) {
          shards_[i].ApplyToSomeEntries(callback, aepl, &states[i]);
          remaining_work |= states[i] != SIZE_MAX;
        }
      }
    } while (remaining_work);
  }

  // Forwards EraseUnRefEntries() to every shard.
  void EraseUnRefEntries() override {
    ForEachShard([](CacheShard* cs) { cs->EraseUnRefEntries(); });
  }

  // After this call the destructor no longer runs the shard destructors
  // (the shard array storage itself is still freed).
  void DisownData() override {
    // Leak data only if that won't generate an ASAN/valgrind warning.
    if (!kMustFreeHeapAllocations) {
      destroy_shards_in_dtor_ = false;
    }
  }

 protected:
  // Calls `fn` with a pointer to each shard slot, in index order.
  inline void ForEachShard(const std::function<void(CacheShard*)>& fn) {
    uint32_t num_shards = GetNumShards();
    for (uint32_t i = 0; i < num_shards; i++) {
      fn(shards_ + i);
    }
  }

  // Returns the sum of `fn(shard)` over all shards.
  inline size_t SumOverShards(
      const std::function<size_t(CacheShard&)>& fn) const {
    uint32_t num_shards = GetNumShards();
    size_t result = 0;
    for (uint32_t i = 0; i < num_shards; i++) {
      result += fn(shards_[i]);
    }
    return result;
  }

  // Convenience form of SumOverShards() for a const member function of
  // CacheShard that returns size_t.
  inline size_t SumOverShards2(size_t (CacheShard::*fn)() const) const {
    return SumOverShards([fn](CacheShard& cs) { return (cs.*fn)(); });
  }

  // Must be called exactly once by derived class constructor
  // `placement_new` receives the address of each raw shard slot and is
  // expected to construct a CacheShard there; afterwards the destructor is
  // armed to destroy the shards.
  void InitShards(const std::function<void(CacheShard*)>& placement_new) {
    ForEachShard(placement_new);
    destroy_shards_in_dtor_ = true;
  }

  // Uses the first shard to append the shard-level options to `str`.
  void AppendPrintableOptions(std::string& str) const override {
    shards_[0].AppendPrintableOptions(str);
  }

 private:
  // Array of GetNumShards() shards in cacheline-aligned storage.
  CacheShard* const shards_;
  // True once InitShards() has run, until DisownData() clears it.
  bool destroy_shards_in_dtor_;
};

// Declaration only; the definition is not part of this header.
// Takes the total cache `capacity` and a `min_shard_size` and returns an int.
// 512KB is traditional minimum shard size.
// (hence the default of 512U * 1024U for `min_shard_size`).
// NOTE(review): presumably the result is a num_shard_bits value chosen so
// that no shard is smaller than `min_shard_size` -- confirm against the
// definition.
int GetDefaultCacheShardBits(size_t capacity,
                             size_t min_shard_size = 512U * 1024U);

}  // namespace ROCKSDB_NAMESPACE