summaryrefslogtreecommitdiffstats
path: root/src/rocksdb/memtable
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
commit483eb2f56657e8e7f419ab1a4fab8dce9ade8609 (patch)
treee5d88d25d870d5dedacb6bbdbe2a966086a0a5cf /src/rocksdb/memtable
parentInitial commit. (diff)
downloadceph-483eb2f56657e8e7f419ab1a4fab8dce9ade8609.tar.xz
ceph-483eb2f56657e8e7f419ab1a4fab8dce9ade8609.zip
Adding upstream version 14.2.21.upstream/14.2.21upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/rocksdb/memtable/alloc_tracker.cc62
-rw-r--r--src/rocksdb/memtable/hash_linklist_rep.cc845
-rw-r--r--src/rocksdb/memtable/hash_linklist_rep.h49
-rw-r--r--src/rocksdb/memtable/hash_skiplist_rep.cc349
-rw-r--r--src/rocksdb/memtable/hash_skiplist_rep.h44
-rw-r--r--src/rocksdb/memtable/inlineskiplist.h965
-rw-r--r--src/rocksdb/memtable/inlineskiplist_test.cc645
-rw-r--r--src/rocksdb/memtable/memtablerep_bench.cc682
-rw-r--r--src/rocksdb/memtable/skiplist.h497
-rw-r--r--src/rocksdb/memtable/skiplist_test.cc388
-rw-r--r--src/rocksdb/memtable/skiplistrep.cc271
-rw-r--r--src/rocksdb/memtable/stl_wrappers.h33
-rw-r--r--src/rocksdb/memtable/vectorrep.cc301
-rw-r--r--src/rocksdb/memtable/write_buffer_manager.cc130
-rw-r--r--src/rocksdb/memtable/write_buffer_manager_test.cc151
15 files changed, 5412 insertions, 0 deletions
diff --git a/src/rocksdb/memtable/alloc_tracker.cc b/src/rocksdb/memtable/alloc_tracker.cc
new file mode 100644
index 00000000..a1fa4938
--- /dev/null
+++ b/src/rocksdb/memtable/alloc_tracker.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include <assert.h>
+#include "rocksdb/write_buffer_manager.h"
+#include "util/allocator.h"
+#include "util/arena.h"
+
+namespace rocksdb {
+
+AllocTracker::AllocTracker(WriteBufferManager* write_buffer_manager)
+ : write_buffer_manager_(write_buffer_manager),
+ bytes_allocated_(0),
+ done_allocating_(false),
+ freed_(false) {}
+
+AllocTracker::~AllocTracker() { FreeMem(); }
+
+void AllocTracker::Allocate(size_t bytes) {
+ assert(write_buffer_manager_ != nullptr);
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed);
+ write_buffer_manager_->ReserveMem(bytes);
+ }
+}
+
+void AllocTracker::DoneAllocating() {
+ if (write_buffer_manager_ != nullptr && !done_allocating_) {
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ write_buffer_manager_->ScheduleFreeMem(
+ bytes_allocated_.load(std::memory_order_relaxed));
+ } else {
+ assert(bytes_allocated_.load(std::memory_order_relaxed) == 0);
+ }
+ done_allocating_ = true;
+ }
+}
+
+void AllocTracker::FreeMem() {
+ if (!done_allocating_) {
+ DoneAllocating();
+ }
+ if (write_buffer_manager_ != nullptr && !freed_) {
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ write_buffer_manager_->FreeMem(
+ bytes_allocated_.load(std::memory_order_relaxed));
+ } else {
+ assert(bytes_allocated_.load(std::memory_order_relaxed) == 0);
+ }
+ freed_ = true;
+ }
+}
+} // namespace rocksdb
diff --git a/src/rocksdb/memtable/hash_linklist_rep.cc b/src/rocksdb/memtable/hash_linklist_rep.cc
new file mode 100644
index 00000000..878d2338
--- /dev/null
+++ b/src/rocksdb/memtable/hash_linklist_rep.cc
@@ -0,0 +1,845 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#ifndef ROCKSDB_LITE
+#include "memtable/hash_linklist_rep.h"
+
+#include <algorithm>
+#include <atomic>
+#include "db/memtable.h"
+#include "memtable/skiplist.h"
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+#include "util/arena.h"
+#include "util/hash.h"
+
+namespace rocksdb {
+namespace {
+
+typedef const char* Key;
+typedef SkipList<Key, const MemTableRep::KeyComparator&> MemtableSkipList;
+typedef std::atomic<void*> Pointer;
+
+// A data structure used as the header of a link list of a hash bucket.
+struct BucketHeader {
+ Pointer next;
+ std::atomic<uint32_t> num_entries;
+
+ explicit BucketHeader(void* n, uint32_t count)
+ : next(n), num_entries(count) {}
+
+ bool IsSkipListBucket() {
+ return next.load(std::memory_order_relaxed) == this;
+ }
+
+ uint32_t GetNumEntries() const {
+ return num_entries.load(std::memory_order_relaxed);
+ }
+
+ // REQUIRES: called from single-threaded Insert()
+ void IncNumEntries() {
+ // Only one thread can do write at one time. No need to do atomic
+ // incremental. Update it with relaxed load and store.
+ num_entries.store(GetNumEntries() + 1, std::memory_order_relaxed);
+ }
+};
+
+// A data structure used as the header of a skip list of a hash bucket.
+struct SkipListBucketHeader {
+ BucketHeader Counting_header;
+ MemtableSkipList skip_list;
+
+ explicit SkipListBucketHeader(const MemTableRep::KeyComparator& cmp,
+ Allocator* allocator, uint32_t count)
+ : Counting_header(this, // Pointing to itself to indicate header type.
+ count),
+ skip_list(cmp, allocator) {}
+};
+
+struct Node {
+ // Accessors/mutators for links. Wrapped in methods so we can
+ // add the appropriate barriers as necessary.
+ Node* Next() {
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return next_.load(std::memory_order_acquire);
+ }
+ void SetNext(Node* x) {
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_.store(x, std::memory_order_release);
+ }
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next() {
+ return next_.load(std::memory_order_relaxed);
+ }
+
+ void NoBarrier_SetNext(Node* x) { next_.store(x, std::memory_order_relaxed); }
+
+ // Needed for placement new below which is fine
+ Node() {}
+
+ private:
+ std::atomic<Node*> next_;
+
+ // Prohibit copying due to the below
+ Node(const Node&) = delete;
+ Node& operator=(const Node&) = delete;
+
+ public:
+ char key[1];
+};
+
+// Memory structure of the mem table:
+// It is a hash table, each bucket points to one entry, a linked list or a
+// skip list. In order to track total number of records in a bucket to determine
+// whether should switch to skip list, a header is added just to indicate
+// number of entries in the bucket.
+//
+//
+// +-----> NULL Case 1. Empty bucket
+// |
+// |
+// | +---> +-------+
+// | | | Next +--> NULL
+// | | +-------+
+// +-----+ | | | | Case 2. One Entry in bucket.
+// | +-+ | | Data | next pointer points to
+// +-----+ | | | NULL. All other cases
+// | | | | | next pointer is not NULL.
+// +-----+ | +-------+
+// | +---+
+// +-----+ +-> +-------+ +> +-------+ +-> +-------+
+// | | | | Next +--+ | Next +--+ | Next +-->NULL
+// +-----+ | +-------+ +-------+ +-------+
+// | +-----+ | Count | | | | |
+// +-----+ +-------+ | Data | | Data |
+// | | | | | |
+// +-----+ Case 3. | | | |
+// | | A header +-------+ +-------+
+// +-----+ points to
+// | | a linked list. Count indicates total number
+// +-----+ of rows in this bucket.
+// | |
+// +-----+ +-> +-------+ <--+
+// | | | | Next +----+
+// +-----+ | +-------+ Case 4. A header points to a skip
+// | +----+ | Count | list and next pointer points to
+// +-----+ +-------+ itself, to distinguish case 3 or 4.
+// | | | | Count still is kept to indicates total
+// +-----+ | Skip +--> of entries in the bucket for debugging
+// | | | List | Data purpose.
+// | | | +-->
+// +-----+ | |
+// | | +-------+
+// +-----+
+//
+// We don't have data race when changing cases because:
+// (1) When changing from case 2->3, we create a new bucket header, put the
+// single node there first without changing the original node, and do a
+// release store when changing the bucket pointer. In that case, a reader
+// who sees a stale value of the bucket pointer will read this node, while
+// a reader sees the correct value because of the release store.
+// (2) When changing case 3->4, a new header is created with skip list points
+// to the data, before doing an acquire store to change the bucket pointer.
+// The old header and nodes are never changed, so any reader sees any
+// of those existing pointers will guarantee to be able to iterate to the
+// end of the linked list.
+// (3) Header's next pointer in case 3 might change, but they are never equal
+// to itself, so no matter a reader sees any stale or newer value, it will
+// be able to correctly distinguish case 3 and 4.
+//
+// The reason that we use case 2 is we want to make the format to be efficient
+// when the utilization of buckets is relatively low. If we use case 3 for
+// single entry bucket, we will need to waste 12 bytes for every entry,
+// which can be significant decrease of memory utilization.
+class HashLinkListRep : public MemTableRep {
+ public:
+ HashLinkListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ size_t bucket_size, uint32_t threshold_use_skiplist,
+ size_t huge_page_tlb_size, Logger* logger,
+ int bucket_entries_logging_threshold,
+ bool if_log_bucket_dist_when_flash);
+
+ KeyHandle Allocate(const size_t len, char** buf) override;
+
+ void Insert(KeyHandle handle) override;
+
+ bool Contains(const char* key) const override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~HashLinkListRep() override;
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override;
+
+ MemTableRep::Iterator* GetDynamicPrefixIterator(
+ Arena* arena = nullptr) override;
+
+ private:
+ friend class DynamicIterator;
+
+ size_t bucket_size_;
+
+ // Maps slices (which are transformed user keys) to buckets of keys sharing
+ // the same transform.
+ Pointer* buckets_;
+
+ const uint32_t threshold_use_skiplist_;
+
+ // The user-supplied transform whose domain is the user keys.
+ const SliceTransform* transform_;
+
+ const MemTableRep::KeyComparator& compare_;
+
+ Logger* logger_;
+ int bucket_entries_logging_threshold_;
+ bool if_log_bucket_dist_when_flash_;
+
+ bool LinkListContains(Node* head, const Slice& key) const;
+
+ SkipListBucketHeader* GetSkipListBucketHeader(Pointer* first_next_pointer)
+ const;
+
+ Node* GetLinkListFirstNode(Pointer* first_next_pointer) const;
+
+ Slice GetPrefix(const Slice& internal_key) const {
+ return transform_->Transform(ExtractUserKey(internal_key));
+ }
+
+ size_t GetHash(const Slice& slice) const {
+ return NPHash64(slice.data(), static_cast<int>(slice.size()), 0) %
+ bucket_size_;
+ }
+
+ Pointer* GetBucket(size_t i) const {
+ return static_cast<Pointer*>(buckets_[i].load(std::memory_order_acquire));
+ }
+
+ Pointer* GetBucket(const Slice& slice) const {
+ return GetBucket(GetHash(slice));
+ }
+
+ bool Equal(const Slice& a, const Key& b) const {
+ return (compare_(b, a) == 0);
+ }
+
+ bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); }
+
+ bool KeyIsAfterNode(const Slice& internal_key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, internal_key) < 0);
+ }
+
+ bool KeyIsAfterNode(const Key& key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) < 0);
+ }
+
+ bool KeyIsAfterOrAtNode(const Slice& internal_key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, internal_key) <= 0);
+ }
+
+ bool KeyIsAfterOrAtNode(const Key& key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) <= 0);
+ }
+
+ Node* FindGreaterOrEqualInBucket(Node* head, const Slice& key) const;
+ Node* FindLessOrEqualInBucket(Node* head, const Slice& key) const;
+
+ class FullListIterator : public MemTableRep::Iterator {
+ public:
+ explicit FullListIterator(MemtableSkipList* list, Allocator* allocator)
+ : iter_(list), full_list_(list), allocator_(allocator) {}
+
+ ~FullListIterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ iter_.Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ const char* encoded_key =
+ (memtable_key != nullptr) ?
+ memtable_key : EncodeKey(&tmp_, internal_key);
+ iter_.Seek(encoded_key);
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& internal_key,
+ const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.SeekForPrev(encoded_key);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override { iter_.SeekToFirst(); }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override { iter_.SeekToLast(); }
+
+ private:
+ MemtableSkipList::Iterator iter_;
+ // To destruct with the iterator.
+ std::unique_ptr<MemtableSkipList> full_list_;
+ std::unique_ptr<Allocator> allocator_;
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ class LinkListIterator : public MemTableRep::Iterator {
+ public:
+ explicit LinkListIterator(const HashLinkListRep* const hash_link_list_rep,
+ Node* head)
+ : hash_link_list_rep_(hash_link_list_rep),
+ head_(head),
+ node_(nullptr) {}
+
+ ~LinkListIterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return node_ != nullptr; }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return node_->key;
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ node_ = node_->Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key,
+ const char* /*memtable_key*/) override {
+ node_ = hash_link_list_rep_->FindGreaterOrEqualInBucket(head_,
+ internal_key);
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {
+ // Since we do not support Prev()
+ // We simply do not support SeekForPrev
+ Reset(nullptr);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ protected:
+ void Reset(Node* head) {
+ head_ = head;
+ node_ = nullptr;
+ }
+ private:
+ friend class HashLinkListRep;
+ const HashLinkListRep* const hash_link_list_rep_;
+ Node* head_;
+ Node* node_;
+
+ virtual void SeekToHead() {
+ node_ = head_;
+ }
+ };
+
+ class DynamicIterator : public HashLinkListRep::LinkListIterator {
+ public:
+ explicit DynamicIterator(HashLinkListRep& memtable_rep)
+ : HashLinkListRep::LinkListIterator(&memtable_rep, nullptr),
+ memtable_rep_(memtable_rep) {}
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& k, const char* memtable_key) override {
+ auto transformed = memtable_rep_.GetPrefix(k);
+ auto* bucket = memtable_rep_.GetBucket(transformed);
+
+ SkipListBucketHeader* skip_list_header =
+ memtable_rep_.GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ // The bucket is organized as a skip list
+ if (!skip_list_iter_) {
+ skip_list_iter_.reset(
+ new MemtableSkipList::Iterator(&skip_list_header->skip_list));
+ } else {
+ skip_list_iter_->SetList(&skip_list_header->skip_list);
+ }
+ if (memtable_key != nullptr) {
+ skip_list_iter_->Seek(memtable_key);
+ } else {
+ IterKey encoded_key;
+ encoded_key.EncodeLengthPrefixedKey(k);
+ skip_list_iter_->Seek(encoded_key.GetUserKey().data());
+ }
+ } else {
+ // The bucket is organized as a linked list
+ skip_list_iter_.reset();
+ Reset(memtable_rep_.GetLinkListFirstNode(bucket));
+ HashLinkListRep::LinkListIterator::Seek(k, memtable_key);
+ }
+ }
+
+ bool Valid() const override {
+ if (skip_list_iter_) {
+ return skip_list_iter_->Valid();
+ }
+ return HashLinkListRep::LinkListIterator::Valid();
+ }
+
+ const char* key() const override {
+ if (skip_list_iter_) {
+ return skip_list_iter_->key();
+ }
+ return HashLinkListRep::LinkListIterator::key();
+ }
+
+ void Next() override {
+ if (skip_list_iter_) {
+ skip_list_iter_->Next();
+ } else {
+ HashLinkListRep::LinkListIterator::Next();
+ }
+ }
+
+ private:
+ // the underlying memtable
+ const HashLinkListRep& memtable_rep_;
+ std::unique_ptr<MemtableSkipList::Iterator> skip_list_iter_;
+ };
+
+ class EmptyIterator : public MemTableRep::Iterator {
+ // This is used when there wasn't a bucket. It is cheaper than
+ // instantiating an empty bucket over which to iterate.
+ public:
+ EmptyIterator() { }
+ bool Valid() const override { return false; }
+ const char* key() const override {
+ assert(false);
+ return nullptr;
+ }
+ void Next() override {}
+ void Prev() override {}
+ void Seek(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekForPrev(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekToFirst() override {}
+ void SeekToLast() override {}
+
+ private:
+ };
+};
+
+HashLinkListRep::HashLinkListRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, size_t bucket_size,
+ uint32_t threshold_use_skiplist, size_t huge_page_tlb_size, Logger* logger,
+ int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash)
+ : MemTableRep(allocator),
+ bucket_size_(bucket_size),
+ // Threshold to use skip list doesn't make sense if less than 3, so we
+ // force it to be minimum of 3 to simplify implementation.
+ threshold_use_skiplist_(std::max(threshold_use_skiplist, 3U)),
+ transform_(transform),
+ compare_(compare),
+ logger_(logger),
+ bucket_entries_logging_threshold_(bucket_entries_logging_threshold),
+ if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) {
+ char* mem = allocator_->AllocateAligned(sizeof(Pointer) * bucket_size,
+ huge_page_tlb_size, logger);
+
+ buckets_ = new (mem) Pointer[bucket_size];
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ buckets_[i].store(nullptr, std::memory_order_relaxed);
+ }
+}
+
+HashLinkListRep::~HashLinkListRep() {
+}
+
+KeyHandle HashLinkListRep::Allocate(const size_t len, char** buf) {
+ char* mem = allocator_->AllocateAligned(sizeof(Node) + len);
+ Node* x = new (mem) Node();
+ *buf = x->key;
+ return static_cast<void*>(x);
+}
+
+SkipListBucketHeader* HashLinkListRep::GetSkipListBucketHeader(
+ Pointer* first_next_pointer) const {
+ if (first_next_pointer == nullptr) {
+ return nullptr;
+ }
+ if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) {
+ // Single entry bucket
+ return nullptr;
+ }
+ // Counting header
+ BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ if (header->IsSkipListBucket()) {
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ auto* skip_list_bucket_header =
+ reinterpret_cast<SkipListBucketHeader*>(header);
+ assert(skip_list_bucket_header->Counting_header.next.load(
+ std::memory_order_relaxed) == header);
+ return skip_list_bucket_header;
+ }
+ assert(header->GetNumEntries() <= threshold_use_skiplist_);
+ return nullptr;
+}
+
+Node* HashLinkListRep::GetLinkListFirstNode(Pointer* first_next_pointer) const {
+ if (first_next_pointer == nullptr) {
+ return nullptr;
+ }
+ if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) {
+ // Single entry bucket
+ return reinterpret_cast<Node*>(first_next_pointer);
+ }
+ // Counting header
+ BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ if (!header->IsSkipListBucket()) {
+ assert(header->GetNumEntries() <= threshold_use_skiplist_);
+ return reinterpret_cast<Node*>(
+ header->next.load(std::memory_order_acquire));
+ }
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ return nullptr;
+}
+
+void HashLinkListRep::Insert(KeyHandle handle) {
+ Node* x = static_cast<Node*>(handle);
+ assert(!Contains(x->key));
+ Slice internal_key = GetLengthPrefixedSlice(x->key);
+ auto transformed = GetPrefix(internal_key);
+ auto& bucket = buckets_[GetHash(transformed)];
+ Pointer* first_next_pointer =
+ static_cast<Pointer*>(bucket.load(std::memory_order_relaxed));
+
+ if (first_next_pointer == nullptr) {
+ // Case 1. empty bucket
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(nullptr);
+ bucket.store(x, std::memory_order_release);
+ return;
+ }
+
+ BucketHeader* header = nullptr;
+ if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) {
+ // Case 2. only one entry in the bucket
+ // Need to convert to a Counting bucket and turn to case 4.
+ Node* first = reinterpret_cast<Node*>(first_next_pointer);
+ // Need to add a bucket header.
+ // We have to first convert it to a bucket with header before inserting
+ // the new node. Otherwise, we might need to change next pointer of first.
+ // In that case, a reader might sees the next pointer is NULL and wrongly
+ // think the node is a bucket header.
+ auto* mem = allocator_->AllocateAligned(sizeof(BucketHeader));
+ header = new (mem) BucketHeader(first, 1);
+ bucket.store(header, std::memory_order_release);
+ } else {
+ header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ if (header->IsSkipListBucket()) {
+ // Case 4. Bucket is already a skip list
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ auto* skip_list_bucket_header =
+ reinterpret_cast<SkipListBucketHeader*>(header);
+ // Only one thread can execute Insert() at one time. No need to do atomic
+ // incremental.
+ skip_list_bucket_header->Counting_header.IncNumEntries();
+ skip_list_bucket_header->skip_list.Insert(x->key);
+ return;
+ }
+ }
+
+ if (bucket_entries_logging_threshold_ > 0 &&
+ header->GetNumEntries() ==
+ static_cast<uint32_t>(bucket_entries_logging_threshold_)) {
+ Info(logger_, "HashLinkedList bucket %" ROCKSDB_PRIszt
+ " has more than %d "
+ "entries. Key to insert: %s",
+ GetHash(transformed), header->GetNumEntries(),
+ GetLengthPrefixedSlice(x->key).ToString(true).c_str());
+ }
+
+ if (header->GetNumEntries() == threshold_use_skiplist_) {
+ // Case 3. number of entries reaches the threshold so need to convert to
+ // skip list.
+ LinkListIterator bucket_iter(
+ this, reinterpret_cast<Node*>(
+ first_next_pointer->load(std::memory_order_relaxed)));
+ auto mem = allocator_->AllocateAligned(sizeof(SkipListBucketHeader));
+ SkipListBucketHeader* new_skip_list_header = new (mem)
+ SkipListBucketHeader(compare_, allocator_, header->GetNumEntries() + 1);
+ auto& skip_list = new_skip_list_header->skip_list;
+
+ // Add all current entries to the skip list
+ for (bucket_iter.SeekToHead(); bucket_iter.Valid(); bucket_iter.Next()) {
+ skip_list.Insert(bucket_iter.key());
+ }
+
+ // insert the new entry
+ skip_list.Insert(x->key);
+ // Set the bucket
+ bucket.store(new_skip_list_header, std::memory_order_release);
+ } else {
+ // Case 5. Need to insert to the sorted linked list without changing the
+ // header.
+ Node* first =
+ reinterpret_cast<Node*>(header->next.load(std::memory_order_relaxed));
+ assert(first != nullptr);
+ // Advance counter unless the bucket needs to be advanced to skip list.
+ // In that case, we need to make sure the previous count never exceeds
+ // threshold_use_skiplist_ to avoid readers to cast to wrong format.
+ header->IncNumEntries();
+
+ Node* cur = first;
+ Node* prev = nullptr;
+ while (true) {
+ if (cur == nullptr) {
+ break;
+ }
+ Node* next = cur->Next();
+ // Make sure the lists are sorted.
+ // If x points to head_ or next points nullptr, it is trivially satisfied.
+ assert((cur == first) || (next == nullptr) ||
+ KeyIsAfterNode(next->key, cur));
+ if (KeyIsAfterNode(internal_key, cur)) {
+ // Keep searching in this list
+ prev = cur;
+ cur = next;
+ } else {
+ break;
+ }
+ }
+
+ // Our data structure does not allow duplicate insertion
+ assert(cur == nullptr || !Equal(x->key, cur->key));
+
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(cur);
+
+ if (prev) {
+ prev->SetNext(x);
+ } else {
+ header->next.store(static_cast<void*>(x), std::memory_order_release);
+ }
+ }
+}
+
+bool HashLinkListRep::Contains(const char* key) const {
+ Slice internal_key = GetLengthPrefixedSlice(key);
+
+ auto transformed = GetPrefix(internal_key);
+ auto bucket = GetBucket(transformed);
+ if (bucket == nullptr) {
+ return false;
+ }
+
+ SkipListBucketHeader* skip_list_header = GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ return skip_list_header->skip_list.Contains(key);
+ } else {
+ return LinkListContains(GetLinkListFirstNode(bucket), internal_key);
+ }
+}
+
+size_t HashLinkListRep::ApproximateMemoryUsage() {
+ // Memory is always allocated from the allocator.
+ return 0;
+}
+
+void HashLinkListRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ auto transformed = transform_->Transform(k.user_key());
+ auto bucket = GetBucket(transformed);
+
+ auto* skip_list_header = GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ // Is a skip list
+ MemtableSkipList::Iterator iter(&skip_list_header->skip_list);
+ for (iter.Seek(k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ } else {
+ auto* link_list_head = GetLinkListFirstNode(bucket);
+ if (link_list_head != nullptr) {
+ LinkListIterator iter(this, link_list_head);
+ for (iter.Seek(k.internal_key(), nullptr);
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ }
+ }
+}
+
+MemTableRep::Iterator* HashLinkListRep::GetIterator(Arena* alloc_arena) {
+ // allocate a new arena of similar size to the one currently in use
+ Arena* new_arena = new Arena(allocator_->BlockSize());
+ auto list = new MemtableSkipList(compare_, new_arena);
+ HistogramImpl keys_per_bucket_hist;
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ int count = 0;
+ auto* bucket = GetBucket(i);
+ if (bucket != nullptr) {
+ auto* skip_list_header = GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ // Is a skip list
+ MemtableSkipList::Iterator itr(&skip_list_header->skip_list);
+ for (itr.SeekToFirst(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ count++;
+ }
+ } else {
+ auto* link_list_head = GetLinkListFirstNode(bucket);
+ if (link_list_head != nullptr) {
+ LinkListIterator itr(this, link_list_head);
+ for (itr.SeekToHead(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ count++;
+ }
+ }
+ }
+ }
+ if (if_log_bucket_dist_when_flash_) {
+ keys_per_bucket_hist.Add(count);
+ }
+ }
+ if (if_log_bucket_dist_when_flash_ && logger_ != nullptr) {
+ Info(logger_, "hashLinkedList Entry distribution among buckets: %s",
+ keys_per_bucket_hist.ToString().c_str());
+ }
+
+ if (alloc_arena == nullptr) {
+ return new FullListIterator(list, new_arena);
+ } else {
+ auto mem = alloc_arena->AllocateAligned(sizeof(FullListIterator));
+ return new (mem) FullListIterator(list, new_arena);
+ }
+}
+
+MemTableRep::Iterator* HashLinkListRep::GetDynamicPrefixIterator(
+ Arena* alloc_arena) {
+ if (alloc_arena == nullptr) {
+ return new DynamicIterator(*this);
+ } else {
+ auto mem = alloc_arena->AllocateAligned(sizeof(DynamicIterator));
+ return new (mem) DynamicIterator(*this);
+ }
+}
+
+bool HashLinkListRep::LinkListContains(Node* head,
+ const Slice& user_key) const {
+ Node* x = FindGreaterOrEqualInBucket(head, user_key);
+ return (x != nullptr && Equal(user_key, x->key));
+}
+
+Node* HashLinkListRep::FindGreaterOrEqualInBucket(Node* head,
+ const Slice& key) const {
+ Node* x = head;
+ while (true) {
+ if (x == nullptr) {
+ return x;
+ }
+ Node* next = x->Next();
+ // Make sure the lists are sorted.
+ // If x points to head_ or next points nullptr, it is trivially satisfied.
+ assert((x == head) || (next == nullptr) || KeyIsAfterNode(next->key, x));
+ if (KeyIsAfterNode(key, x)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ break;
+ }
+ }
+ return x;
+}
+
+} // anon namespace
+
+MemTableRep* HashLinkListRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) {
+ return new HashLinkListRep(compare, allocator, transform, bucket_count_,
+ threshold_use_skiplist_, huge_page_tlb_size_,
+ logger, bucket_entries_logging_threshold_,
+ if_log_bucket_dist_when_flash_);
+}
+
+MemTableRepFactory* NewHashLinkListRepFactory(
+ size_t bucket_count, size_t huge_page_tlb_size,
+ int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash,
+ uint32_t threshold_use_skiplist) {
+ return new HashLinkListRepFactory(
+ bucket_count, threshold_use_skiplist, huge_page_tlb_size,
+ bucket_entries_logging_threshold, if_log_bucket_dist_when_flash);
+}
+
+} // namespace rocksdb
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/hash_linklist_rep.h b/src/rocksdb/memtable/hash_linklist_rep.h
new file mode 100644
index 00000000..a6da3eed
--- /dev/null
+++ b/src/rocksdb/memtable/hash_linklist_rep.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#ifndef ROCKSDB_LITE
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/memtablerep.h"
+
+namespace rocksdb {
+
+class HashLinkListRepFactory : public MemTableRepFactory {
+ public:
+ explicit HashLinkListRepFactory(size_t bucket_count,
+ uint32_t threshold_use_skiplist,
+ size_t huge_page_tlb_size,
+ int bucket_entries_logging_threshold,
+ bool if_log_bucket_dist_when_flash)
+ : bucket_count_(bucket_count),
+ threshold_use_skiplist_(threshold_use_skiplist),
+ huge_page_tlb_size_(huge_page_tlb_size),
+ bucket_entries_logging_threshold_(bucket_entries_logging_threshold),
+ if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) {}
+
+ virtual ~HashLinkListRepFactory() {}
+
+ using MemTableRepFactory::CreateMemTableRep;
+ virtual MemTableRep* CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) override;
+
+ virtual const char* Name() const override {
+ return "HashLinkListRepFactory";
+ }
+
+ private:
+ const size_t bucket_count_;
+ const uint32_t threshold_use_skiplist_;
+ const size_t huge_page_tlb_size_;
+ int bucket_entries_logging_threshold_;
+ bool if_log_bucket_dist_when_flash_;
+};
+
+}
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/hash_skiplist_rep.cc b/src/rocksdb/memtable/hash_skiplist_rep.cc
new file mode 100644
index 00000000..d02919cd
--- /dev/null
+++ b/src/rocksdb/memtable/hash_skiplist_rep.cc
@@ -0,0 +1,349 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#ifndef ROCKSDB_LITE
+#include "memtable/hash_skiplist_rep.h"
+
+#include <atomic>
+
+#include "rocksdb/memtablerep.h"
+#include "util/arena.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+#include "port/port.h"
+#include "util/murmurhash.h"
+#include "db/memtable.h"
+#include "memtable/skiplist.h"
+
+namespace rocksdb {
+namespace {
+
+class HashSkipListRep : public MemTableRep {
+ public:
+ HashSkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ size_t bucket_size, int32_t skiplist_height,
+ int32_t skiplist_branching_factor);
+
+ void Insert(KeyHandle handle) override;
+
+ bool Contains(const char* key) const override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~HashSkipListRep() override;
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override;
+
+ MemTableRep::Iterator* GetDynamicPrefixIterator(
+ Arena* arena = nullptr) override;
+
+ private:
+ friend class DynamicIterator;
+ typedef SkipList<const char*, const MemTableRep::KeyComparator&> Bucket;
+
+ size_t bucket_size_;
+
+ const int32_t skiplist_height_;
+ const int32_t skiplist_branching_factor_;
+
+ // Maps slices (which are transformed user keys) to buckets of keys sharing
+ // the same transform.
+ std::atomic<Bucket*>* buckets_;
+
+ // The user-supplied transform whose domain is the user keys.
+ const SliceTransform* transform_;
+
+ const MemTableRep::KeyComparator& compare_;
+ // immutable after construction
+ Allocator* const allocator_;
+
+ inline size_t GetHash(const Slice& slice) const {
+ return MurmurHash(slice.data(), static_cast<int>(slice.size()), 0) %
+ bucket_size_;
+ }
+ inline Bucket* GetBucket(size_t i) const {
+ return buckets_[i].load(std::memory_order_acquire);
+ }
+ inline Bucket* GetBucket(const Slice& slice) const {
+ return GetBucket(GetHash(slice));
+ }
+ // Get a bucket from buckets_. If the bucket hasn't been initialized yet,
+ // initialize it before returning.
+ Bucket* GetInitializedBucket(const Slice& transformed);
+
+ class Iterator : public MemTableRep::Iterator {
+ public:
+ explicit Iterator(Bucket* list, bool own_list = true,
+ Arena* arena = nullptr)
+ : list_(list), iter_(list), own_list_(own_list), arena_(arena) {}
+
+ ~Iterator() override {
+ // if we own the list, we should also delete it
+ if (own_list_) {
+ assert(list_ != nullptr);
+ delete list_;
+ }
+ }
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return list_ != nullptr && iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ iter_.Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ if (list_ != nullptr) {
+ const char* encoded_key =
+ (memtable_key != nullptr) ?
+ memtable_key : EncodeKey(&tmp_, internal_key);
+ iter_.Seek(encoded_key);
+ }
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {
+ // not supported
+ assert(false);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ if (list_ != nullptr) {
+ iter_.SeekToFirst();
+ }
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ if (list_ != nullptr) {
+ iter_.SeekToLast();
+ }
+ }
+
+ protected:
+ void Reset(Bucket* list) {
+ if (own_list_) {
+ assert(list_ != nullptr);
+ delete list_;
+ }
+ list_ = list;
+ iter_.SetList(list);
+ own_list_ = false;
+ }
+ private:
+ // if list_ is nullptr, we should NEVER call any methods on iter_
+ // if list_ is nullptr, this Iterator is not Valid()
+ Bucket* list_;
+ Bucket::Iterator iter_;
+ // here we track if we own list_. If we own it, we are also
+ // responsible for it's cleaning. This is a poor man's std::shared_ptr
+ bool own_list_;
+ std::unique_ptr<Arena> arena_;
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ class DynamicIterator : public HashSkipListRep::Iterator {
+ public:
+ explicit DynamicIterator(const HashSkipListRep& memtable_rep)
+ : HashSkipListRep::Iterator(nullptr, false),
+ memtable_rep_(memtable_rep) {}
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& k, const char* memtable_key) override {
+ auto transformed = memtable_rep_.transform_->Transform(ExtractUserKey(k));
+ Reset(memtable_rep_.GetBucket(transformed));
+ HashSkipListRep::Iterator::Seek(k, memtable_key);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ private:
+ // the underlying memtable
+ const HashSkipListRep& memtable_rep_;
+ };
+
+ class EmptyIterator : public MemTableRep::Iterator {
+ // This is used when there wasn't a bucket. It is cheaper than
+ // instantiating an empty bucket over which to iterate.
+ public:
+ EmptyIterator() { }
+ bool Valid() const override { return false; }
+ const char* key() const override {
+ assert(false);
+ return nullptr;
+ }
+ void Next() override {}
+ void Prev() override {}
+ void Seek(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekToFirst() override {}
+ void SeekToLast() override {}
+
+ private:
+ };
+};
+
+HashSkipListRep::HashSkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator,
+ const SliceTransform* transform,
+ size_t bucket_size, int32_t skiplist_height,
+ int32_t skiplist_branching_factor)
+ : MemTableRep(allocator),
+ bucket_size_(bucket_size),
+ skiplist_height_(skiplist_height),
+ skiplist_branching_factor_(skiplist_branching_factor),
+ transform_(transform),
+ compare_(compare),
+ allocator_(allocator) {
+ auto mem = allocator->AllocateAligned(
+ sizeof(std::atomic<void*>) * bucket_size);
+ buckets_ = new (mem) std::atomic<Bucket*>[bucket_size];
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ buckets_[i].store(nullptr, std::memory_order_relaxed);
+ }
+}
+
+HashSkipListRep::~HashSkipListRep() {
+}
+
+HashSkipListRep::Bucket* HashSkipListRep::GetInitializedBucket(
+ const Slice& transformed) {
+ size_t hash = GetHash(transformed);
+ auto bucket = GetBucket(hash);
+ if (bucket == nullptr) {
+ auto addr = allocator_->AllocateAligned(sizeof(Bucket));
+ bucket = new (addr) Bucket(compare_, allocator_, skiplist_height_,
+ skiplist_branching_factor_);
+ buckets_[hash].store(bucket, std::memory_order_release);
+ }
+ return bucket;
+}
+
+void HashSkipListRep::Insert(KeyHandle handle) {
+ auto* key = static_cast<char*>(handle);
+ assert(!Contains(key));
+ auto transformed = transform_->Transform(UserKey(key));
+ auto bucket = GetInitializedBucket(transformed);
+ bucket->Insert(key);
+}
+
+bool HashSkipListRep::Contains(const char* key) const {
+ auto transformed = transform_->Transform(UserKey(key));
+ auto bucket = GetBucket(transformed);
+ if (bucket == nullptr) {
+ return false;
+ }
+ return bucket->Contains(key);
+}
+
+size_t HashSkipListRep::ApproximateMemoryUsage() {
+ return 0;
+}
+
+void HashSkipListRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ auto transformed = transform_->Transform(k.user_key());
+ auto bucket = GetBucket(transformed);
+ if (bucket != nullptr) {
+ Bucket::Iterator iter(bucket);
+ for (iter.Seek(k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ }
+}
+
+MemTableRep::Iterator* HashSkipListRep::GetIterator(Arena* arena) {
+ // allocate a new arena of similar size to the one currently in use
+ Arena* new_arena = new Arena(allocator_->BlockSize());
+ auto list = new Bucket(compare_, new_arena);
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ auto bucket = GetBucket(i);
+ if (bucket != nullptr) {
+ Bucket::Iterator itr(bucket);
+ for (itr.SeekToFirst(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ }
+ }
+ }
+ if (arena == nullptr) {
+ return new Iterator(list, true, new_arena);
+ } else {
+ auto mem = arena->AllocateAligned(sizeof(Iterator));
+ return new (mem) Iterator(list, true, new_arena);
+ }
+}
+
+MemTableRep::Iterator* HashSkipListRep::GetDynamicPrefixIterator(Arena* arena) {
+ if (arena == nullptr) {
+ return new DynamicIterator(*this);
+ } else {
+ auto mem = arena->AllocateAligned(sizeof(DynamicIterator));
+ return new (mem) DynamicIterator(*this);
+ }
+}
+
+} // anon namespace
+
+MemTableRep* HashSkipListRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* /*logger*/) {
+ return new HashSkipListRep(compare, allocator, transform, bucket_count_,
+ skiplist_height_, skiplist_branching_factor_);
+}
+
+MemTableRepFactory* NewHashSkipListRepFactory(
+ size_t bucket_count, int32_t skiplist_height,
+ int32_t skiplist_branching_factor) {
+ return new HashSkipListRepFactory(bucket_count, skiplist_height,
+ skiplist_branching_factor);
+}
+
+} // namespace rocksdb
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/hash_skiplist_rep.h b/src/rocksdb/memtable/hash_skiplist_rep.h
new file mode 100644
index 00000000..5d1e04f3
--- /dev/null
+++ b/src/rocksdb/memtable/hash_skiplist_rep.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#pragma once
+#ifndef ROCKSDB_LITE
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/memtablerep.h"
+
+namespace rocksdb {
+
+class HashSkipListRepFactory : public MemTableRepFactory {
+ public:
+ explicit HashSkipListRepFactory(
+ size_t bucket_count,
+ int32_t skiplist_height,
+ int32_t skiplist_branching_factor)
+ : bucket_count_(bucket_count),
+ skiplist_height_(skiplist_height),
+ skiplist_branching_factor_(skiplist_branching_factor) { }
+
+ virtual ~HashSkipListRepFactory() {}
+
+ using MemTableRepFactory::CreateMemTableRep;
+ virtual MemTableRep* CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) override;
+
+ virtual const char* Name() const override {
+ return "HashSkipListRepFactory";
+ }
+
+ private:
+ const size_t bucket_count_;
+ const int32_t skiplist_height_;
+ const int32_t skiplist_branching_factor_;
+};
+
+}
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/inlineskiplist.h b/src/rocksdb/memtable/inlineskiplist.h
new file mode 100644
index 00000000..1ef8f2b6
--- /dev/null
+++ b/src/rocksdb/memtable/inlineskiplist.h
@@ -0,0 +1,965 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found
+// in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// InlineSkipList is derived from SkipList (skiplist.h), but it optimizes
+// the memory layout by requiring that the key storage be allocated through
+// the skip list instance. For the common case of SkipList<const char*,
+// Cmp> this saves 1 pointer per skip list node and gives better cache
+// locality, at the expense of wasted padding from using AllocateAligned
+// instead of Allocate for the keys. The unused padding will be from
+// 0 to sizeof(void*)-1 bytes, and the space savings are sizeof(void*)
+// bytes, so despite the padding the space used is always less than
+// SkipList<const char*, ..>.
+//
+// Thread safety -------------
+//
+// Writes via Insert require external synchronization, most likely a mutex.
+// InsertConcurrently can be safely called concurrently with reads and
+// with other concurrent inserts. Reads require a guarantee that the
+// InlineSkipList will not be destroyed while the read is in progress.
+// Apart from that, reads progress without any internal locking or
+// synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the InlineSkipList is
+// destroyed. This is trivially guaranteed by the code since we never
+// delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the InlineSkipList.
+// Only Insert() modifies the list, and it is careful to initialize a
+// node and use release-stores to publish the nodes in one or more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <stdlib.h>
+#include <algorithm>
+#include <atomic>
+#include <type_traits>
+#include "port/likely.h"
+#include "port/port.h"
+#include "rocksdb/slice.h"
+#include "util/allocator.h"
+#include "util/coding.h"
+#include "util/random.h"
+
+namespace rocksdb {
+
+template <class Comparator>
+class InlineSkipList {
+ private:
+ struct Node;
+ struct Splice;
+
+ public:
+ using DecodedKey = \
+ typename std::remove_reference<Comparator>::type::DecodedType;
+
+ static const uint16_t kMaxPossibleHeight = 32;
+
+ // Create a new InlineSkipList object that will use "cmp" for comparing
+ // keys, and will allocate memory using "*allocator". Objects allocated
+ // in the allocator must remain allocated for the lifetime of the
+ // skiplist object.
+ explicit InlineSkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12,
+ int32_t branching_factor = 4);
+
+ // Allocates a key and a skip-list node, returning a pointer to the key
+ // portion of the node. This method is thread-safe if the allocator
+ // is thread-safe.
+ char* AllocateKey(size_t key_size);
+
+ // Allocate a splice using allocator.
+ Splice* AllocateSplice();
+
+ // Inserts a key allocated by AllocateKey, after the actual key value
+ // has been filled in.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ bool Insert(const char* key);
+
+ // Inserts a key allocated by AllocateKey with a hint of last insert
+ // position in the skip-list. If hint points to nullptr, a new hint will be
+ // populated, which can be used in subsequent calls.
+ //
+ // It can be used to optimize the workload where there are multiple groups
+ // of keys, and each key is likely to insert to a location close to the last
+ // inserted key in the same group. One example is sequential inserts.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ bool InsertWithHint(const char* key, void** hint);
+
+ // Like Insert, but external synchronization is not required.
+ bool InsertConcurrently(const char* key);
+
+ // Inserts a node into the skip list. key must have been allocated by
+ // AllocateKey and then filled in by the caller. If UseCAS is true,
+ // then external synchronization is not required, otherwise this method
+ // may not be called concurrently with any other insertions.
+ //
+ // Regardless of whether UseCAS is true, the splice must be owned
+ // exclusively by the current thread. If allow_partial_splice_fix is
+ // true, then the cost of insertion is amortized O(log D), where D is
+ // the distance from the splice to the inserted key (measured as the
+ // number of intervening nodes). Note that this bound is very good for
+ // sequential insertions! If allow_partial_splice_fix is false then
+ // the existing splice will be ignored unless the current key is being
+ // inserted immediately after the splice. allow_partial_splice_fix ==
+ // false has worse running time for the non-sequential case O(log N),
+ // but a better constant factor.
+ template <bool UseCAS>
+ bool Insert(const char* key, Splice* splice, bool allow_partial_splice_fix);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const char* key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const char* key) const;
+
+ // Validate correctness of the skip-list.
+ void TEST_Validate() const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const InlineSkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const InlineSkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const char* target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const char* target);
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const InlineSkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+ // Immutable after construction
+ Comparator const compare_;
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // seq_splice_ is a Splice used for insertions in the non-concurrent
+ // case. It caches the prev and next found during the most recent
+ // non-concurrent insertion.
+ Splice* seq_splice_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ int RandomHeight();
+
+ Node* AllocateNode(size_t key_size, int height);
+
+ bool Equal(const char* a, const char* b) const {
+ return (compare_(a, b) == 0);
+ }
+
+ bool LessThan(const char* a, const char* b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n". Null n
+ // is considered infinite. n should not be head_.
+ bool KeyIsAfterNode(const char* key, Node* n) const;
+ bool KeyIsAfterNode(const DecodedKey& key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const char* key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev = nullptr) const;
+
+ // Return the latest node with a key < key on bottom_level. Start searching
+ // from root node on the level below top_level.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [bottom_level..top_level-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev, Node* root, int top_level,
+ int bottom_level) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+
+ // Traverses a single level of the list, setting *out_prev to the last
+ // node before the key and *out_next to the first node after. Assumes
+ // that the key is not present in the skip list. On entry, before should
+ // point to a node that is before the key, and after should point to
+ // a node that is after the key. after should be nullptr if a good after
+ // node isn't conveniently available.
+ template<bool prefetch_before>
+ void FindSpliceForLevel(const DecodedKey& key, Node* before, Node* after, int level,
+ Node** out_prev, Node** out_next);
+
+ // Recomputes Splice levels from highest_level (inclusive) down to
+ // lowest_level (inclusive).
+ void RecomputeSpliceLevels(const DecodedKey& key, Splice* splice,
+ int recompute_level);
+
+ // No copying allowed
+ InlineSkipList(const InlineSkipList&);
+ InlineSkipList& operator=(const InlineSkipList&);
+};
+
+// Implementation details follow
+
+template <class Comparator>
+struct InlineSkipList<Comparator>::Splice {
+ // The invariant of a Splice is that prev_[i+1].key <= prev_[i].key <
+ // next_[i].key <= next_[i+1].key for all i. That means that if a
+ // key is bracketed by prev_[i] and next_[i] then it is bracketed by
+ // all higher levels. It is _not_ required that prev_[i]->Next(i) ==
+ // next_[i] (it probably did at some point in the past, but intervening
+ // or concurrent operations might have inserted nodes in between).
+ int height_ = 0;
+ Node** prev_;
+ Node** next_;
+};
+
+// The Node data type is more of a pointer into custom-managed memory than
+// a traditional C++ struct. The key is stored in the bytes immediately
+// after the struct, and the next_ pointers for nodes with height > 1 are
+// stored immediately _before_ the struct. This avoids the need to include
+// any pointer or sizing data, which reduces per-node memory overheads.
+template <class Comparator>
+struct InlineSkipList<Comparator>::Node {
+ // Stores the height of the node in the memory location normally used for
+ // next_[0]. This is used for passing data from AllocateKey to Insert.
+ void StashHeight(const int height) {
+ assert(sizeof(int) <= sizeof(next_[0]));
+ memcpy(static_cast<void*>(&next_[0]), &height, sizeof(int));
+ }
+
+ // Retrieves the value passed to StashHeight. Undefined after a call
+ // to SetNext or NoBarrier_SetNext.
+ int UnstashHeight() const {
+ int rv;
+ memcpy(&rv, &next_[0], sizeof(int));
+ return rv;
+ }
+
+ const char* Key() const { return reinterpret_cast<const char*>(&next_[1]); }
+
+ // Accessors/mutators for links. Wrapped in methods so we can add
+ // the appropriate barriers as necessary, and perform the necessary
+ // addressing trickery for storing links below the Node in memory.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return ((&next_[0] - n)->load(std::memory_order_acquire));
+ }
+
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ (&next_[0] - n)->store(x, std::memory_order_release);
+ }
+
+ bool CASNext(int n, Node* expected, Node* x) {
+ assert(n >= 0);
+ return (&next_[0] - n)->compare_exchange_strong(expected, x);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return (&next_[0] - n)->load(std::memory_order_relaxed);
+ }
+
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ (&next_[0] - n)->store(x, std::memory_order_relaxed);
+ }
+
+ // Insert node after prev on specific level.
+ void InsertAfter(Node* prev, int level) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "this" in prev.
+ NoBarrier_SetNext(level, prev->NoBarrier_Next(level));
+ prev->SetNext(level, this);
+ }
+
+ private:
+ // next_[0] is the lowest level link (level 0). Higher levels are
+ // stored _earlier_, so level 1 is at next_[-1].
+ std::atomic<Node*> next_[1];
+};
+
+template <class Comparator>
+inline InlineSkipList<Comparator>::Iterator::Iterator(
+ const InlineSkipList* list) {
+ SetList(list);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SetList(
+ const InlineSkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template <class Comparator>
+inline bool InlineSkipList<Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template <class Comparator>
+inline const char* InlineSkipList<Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->Key();
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->Key());
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Seek(const char* target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekForPrev(
+ const char* target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+int InlineSkipList<Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && height < kMaxPossibleHeight &&
+ rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ assert(height <= kMaxPossibleHeight);
+ return height;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::KeyIsAfterNode(const char* key,
+ Node* n) const {
+ // nullptr n is considered infinite
+ assert(n != head_);
+ return (n != nullptr) && (compare_(n->Key(), key) < 0);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::KeyIsAfterNode(const DecodedKey& key,
+ Node* n) const {
+ // nullptr n is considered infinite
+ assert(n != head_);
+ return (n != nullptr) && (compare_(n->Key(), key) < 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindGreaterOrEqual(const char* key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key_decoded, x));
+ int cmp = (next == nullptr || next == last_bigger)
+ ? 1
+ : compare_(next->Key(), key_decoded);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev) const {
+ return FindLessThan(key, prev, head_, GetMaxHeight(), 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev,
+ Node* root, int top_level,
+ int bottom_level) const {
+ assert(top_level > bottom_level);
+ int level = top_level - 1;
+ Node* x = root;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ assert(x == head_ || KeyIsAfterNode(key_decoded, x));
+ if (next != last_not_after && KeyIsAfterNode(key_decoded, next)) {
+ // Keep searching in this list
+ assert(next != nullptr);
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == bottom_level) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLast() const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <class Comparator>
+uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ assert(x == head_ || compare_(x->Key(), key_decoded) < 0);
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ if (next == nullptr || compare_(next->Key(), key_decoded) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <class Comparator>
+InlineSkipList<Comparator>::InlineSkipList(const Comparator cmp,
+ Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(static_cast<uint16_t>(max_height)),
+ kBranching_(static_cast<uint16_t>(branching_factor)),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ allocator_(allocator),
+ compare_(cmp),
+ head_(AllocateNode(0, max_height)),
+ max_height_(1),
+ seq_splice_(AllocateSplice()) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 1 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+
+ for (int i = 0; i < kMaxHeight_; ++i) {
+ head_->SetNext(i, nullptr);
+ }
+}
+
+template <class Comparator>
+char* InlineSkipList<Comparator>::AllocateKey(size_t key_size) {
+ return const_cast<char*>(AllocateNode(key_size, RandomHeight())->Key());
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::AllocateNode(size_t key_size, int height) {
+ auto prefix = sizeof(std::atomic<Node*>) * (height - 1);
+
+ // prefix is space for the height - 1 pointers that we store before
+ // the Node instance (next_[-(height - 1) .. -1]). Node starts at
+ // raw + prefix, and holds the bottom-mode (level 0) skip list pointer
+ // next_[0]. key_size is the bytes for the key, which comes just after
+ // the Node.
+ char* raw = allocator_->AllocateAligned(prefix + sizeof(Node) + key_size);
+ Node* x = reinterpret_cast<Node*>(raw + prefix);
+
+ // Once we've linked the node into the skip list we don't actually need
+ // to know its height, because we can implicitly use the fact that we
+ // traversed into a node at level h to known that h is a valid level
+ // for that node. We need to convey the height to the Insert step,
+ // however, so that it can perform the proper links. Since we're not
+ // using the pointers at the moment, StashHeight temporarily borrow
+ // storage from next_[0] for that purpose.
+ x->StashHeight(height);
+ return x;
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Splice*
+InlineSkipList<Comparator>::AllocateSplice() {
+ // size of prev_ and next_
+ size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1);
+ char* raw = allocator_->AllocateAligned(sizeof(Splice) + array_size * 2);
+ Splice* splice = reinterpret_cast<Splice*>(raw);
+ splice->height_ = 0;
+ splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice));
+ splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size);
+ return splice;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::Insert(const char* key) {
+ return Insert<false>(key, seq_splice_, false);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::InsertConcurrently(const char* key) {
+ Node* prev[kMaxPossibleHeight];
+ Node* next[kMaxPossibleHeight];
+ Splice splice;
+ splice.prev_ = prev;
+ splice.next_ = next;
+ return Insert<true>(key, &splice, false);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::InsertWithHint(const char* key, void** hint) {
+ assert(hint != nullptr);
+ Splice* splice = reinterpret_cast<Splice*>(*hint);
+ if (splice == nullptr) {
+ splice = AllocateSplice();
+ *hint = reinterpret_cast<void*>(splice);
+ }
+ return Insert<false>(key, splice, true);
+}
+
+template <class Comparator>
+template <bool prefetch_before>
+void InlineSkipList<Comparator>::FindSpliceForLevel(const DecodedKey& key,
+ Node* before, Node* after,
+ int level, Node** out_prev,
+ Node** out_next) {
+ while (true) {
+ Node* next = before->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ if (prefetch_before == true) {
+ if (next != nullptr && level>0) {
+ PREFETCH(next->Next(level-1), 0, 1);
+ }
+ }
+ assert(before == head_ || next == nullptr ||
+ KeyIsAfterNode(next->Key(), before));
+ assert(before == head_ || KeyIsAfterNode(key, before));
+ if (next == after || !KeyIsAfterNode(key, next)) {
+ // found it
+ *out_prev = before;
+ *out_next = next;
+ return;
+ }
+ before = next;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::RecomputeSpliceLevels(const DecodedKey& key,
+ Splice* splice,
+ int recompute_level) {
+ assert(recompute_level > 0);
+ assert(recompute_level <= splice->height_);
+ for (int i = recompute_level - 1; i >= 0; --i) {
+ FindSpliceForLevel<true>(key, splice->prev_[i + 1], splice->next_[i + 1], i,
+ &splice->prev_[i], &splice->next_[i]);
+ }
+}
+
+template <class Comparator>
+template <bool UseCAS>
+bool InlineSkipList<Comparator>::Insert(const char* key, Splice* splice,
+ bool allow_partial_splice_fix) {
+ Node* x = reinterpret_cast<Node*>(const_cast<char*>(key)) - 1;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ int height = x->UnstashHeight();
+ assert(height >= 1 && height <= kMaxHeight_);
+
+ int max_height = max_height_.load(std::memory_order_relaxed);
+ while (height > max_height) {
+ if (max_height_.compare_exchange_weak(max_height, height)) {
+ // successfully updated it
+ max_height = height;
+ break;
+ }
+ // else retry, possibly exiting the loop because somebody else
+ // increased it
+ }
+ assert(max_height <= kMaxPossibleHeight);
+
+ int recompute_height = 0;
+ if (splice->height_ < max_height) {
+ // Either splice has never been used or max_height has grown since
+ // last use. We could potentially fix it in the latter case, but
+ // that is tricky.
+ splice->prev_[max_height] = head_;
+ splice->next_[max_height] = nullptr;
+ splice->height_ = max_height;
+ recompute_height = max_height;
+ } else {
+ // Splice is a valid proper-height splice that brackets some
+ // key, but does it bracket this one? We need to validate it and
+ // recompute a portion of the splice (levels 0..recompute_height-1)
+ // that is a superset of all levels that don't bracket the new key.
+ // Several choices are reasonable, because we have to balance the work
+ // saved against the extra comparisons required to validate the Splice.
+ //
+ // One strategy is just to recompute all of orig_splice_height if the
+ // bottom level isn't bracketing. This pessimistically assumes that
+ // we will either get a perfect Splice hit (increasing sequential
+ // inserts) or have no locality.
+ //
+ // Another strategy is to walk up the Splice's levels until we find
+ // a level that brackets the key. This strategy lets the Splice
+ // hint help for other cases: it turns insertion from O(log N) into
+ // O(log D), where D is the number of nodes in between the key that
+ // produced the Splice and the current insert (insertion is aided
+ // whether the new key is before or after the splice). If you have
+ // a way of using a prefix of the key to map directly to the closest
+ // Splice out of O(sqrt(N)) Splices and we make it so that splices
+ // can also be used as hints during read, then we end up with Oshman's
+ // and Shavit's SkipTrie, which has O(log log N) lookup and insertion
+ // (compare to O(log N) for skip list).
+ //
+ // We control the pessimistic strategy with allow_partial_splice_fix.
+ // A good strategy is probably to be pessimistic for seq_splice_,
+ // optimistic if the caller actually went to the work of providing
+ // a Splice.
+ while (recompute_height < max_height) {
+ if (splice->prev_[recompute_height]->Next(recompute_height) !=
+ splice->next_[recompute_height]) {
+ // splice isn't tight at this level, there must have been some inserts
+ // to this
+ // location that didn't update the splice. We might only be a little
+ // stale, but if
+ // the splice is very stale it would be O(N) to fix it. We haven't used
+ // up any of
+ // our budget of comparisons, so always move up even if we are
+ // pessimistic about
+ // our chances of success.
+ ++recompute_height;
+ } else if (splice->prev_[recompute_height] != head_ &&
+ !KeyIsAfterNode(key_decoded,
+ splice->prev_[recompute_height])) {
+ // key is from before splice
+ if (allow_partial_splice_fix) {
+ // skip all levels with the same node without more comparisons
+ Node* bad = splice->prev_[recompute_height];
+ while (splice->prev_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ // we're pessimistic, recompute everything
+ recompute_height = max_height;
+ }
+ } else if (KeyIsAfterNode(key_decoded,
+ splice->next_[recompute_height])) {
+ // key is from after splice
+ if (allow_partial_splice_fix) {
+ Node* bad = splice->next_[recompute_height];
+ while (splice->next_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ recompute_height = max_height;
+ }
+ } else {
+ // this level brackets the key, we won!
+ break;
+ }
+ }
+ }
+ assert(recompute_height <= max_height);
+ if (recompute_height > 0) {
+ RecomputeSpliceLevels(key_decoded, splice, recompute_height);
+ }
+
+ bool splice_is_valid = true;
+ if (UseCAS) {
+ for (int i = 0; i < height; ++i) {
+ while (true) {
+ // Checking for duplicate keys on the level 0 is sufficient
+ if (UNLIKELY(i == 0 && splice->next_[i] != nullptr &&
+ compare_(x->Key(), splice->next_[i]->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ if (UNLIKELY(i == 0 && splice->prev_[i] != head_ &&
+ compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ if (splice->prev_[i]->CASNext(i, splice->next_[i], x)) {
+ // success
+ break;
+ }
+ // CAS failed, we need to recompute prev and next. It is unlikely
+ // to be helpful to try to use a different level as we redo the
+ // search, because it should be unlikely that lots of nodes have
+ // been inserted between prev[i] and next[i]. No point in using
+ // next[i] as the after hint, because we know it is stale.
+ FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i,
+ &splice->prev_[i], &splice->next_[i]);
+
+ // Since we've narrowed the bracket for level i, we might have
+ // violated the Splice constraint between i and i-1. Make sure
+ // we recompute the whole thing next time.
+ if (i > 0) {
+ splice_is_valid = false;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < height; ++i) {
+ if (i >= recompute_height &&
+ splice->prev_[i]->Next(i) != splice->next_[i]) {
+ FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i,
+ &splice->prev_[i], &splice->next_[i]);
+ }
+ // Checking for duplicate keys on the level 0 is sufficient
+ if (UNLIKELY(i == 0 && splice->next_[i] != nullptr &&
+ compare_(x->Key(), splice->next_[i]->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ if (UNLIKELY(i == 0 && splice->prev_[i] != head_ &&
+ compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ assert(splice->prev_[i]->Next(i) == splice->next_[i]);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ splice->prev_[i]->SetNext(i, x);
+ }
+ }
+ if (splice_is_valid) {
+ for (int i = 0; i < height; ++i) {
+ splice->prev_[i] = x;
+ }
+ assert(splice->prev_[splice->height_] == head_);
+ assert(splice->next_[splice->height_] == nullptr);
+ for (int i = 0; i < splice->height_; ++i) {
+ assert(splice->next_[i] == nullptr ||
+ compare_(key, splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), key) <= 0);
+ assert(splice->prev_[i + 1] == splice->prev_[i] ||
+ splice->prev_[i + 1] == head_ ||
+ compare_(splice->prev_[i + 1]->Key(), splice->prev_[i]->Key()) <
+ 0);
+ assert(splice->next_[i + 1] == splice->next_[i] ||
+ splice->next_[i + 1] == nullptr ||
+ compare_(splice->next_[i]->Key(), splice->next_[i + 1]->Key()) <
+ 0);
+ }
+ } else {
+ splice->height_ = 0;
+ }
+ return true;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::Contains(const char* key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->Key())) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::TEST_Validate() const {
+ // Interate over all levels at the same time, and verify nodes appear in
+ // the right order, and nodes appear in upper level also appear in lower
+ // levels.
+ Node* nodes[kMaxPossibleHeight];
+ int max_height = GetMaxHeight();
+ assert(max_height > 0);
+ for (int i = 0; i < max_height; i++) {
+ nodes[i] = head_;
+ }
+ while (nodes[0] != nullptr) {
+ Node* l0_next = nodes[0]->Next(0);
+ if (l0_next == nullptr) {
+ break;
+ }
+ assert(nodes[0] == head_ || compare_(nodes[0]->Key(), l0_next->Key()) < 0);
+ nodes[0] = l0_next;
+
+ int i = 1;
+ while (i < max_height) {
+ Node* next = nodes[i]->Next(i);
+ if (next == nullptr) {
+ break;
+ }
+ auto cmp = compare_(nodes[0]->Key(), next->Key());
+ assert(cmp <= 0);
+ if (cmp == 0) {
+ assert(next == nodes[0]);
+ nodes[i] = next;
+ } else {
+ break;
+ }
+ i++;
+ }
+ }
+ for (int i = 1; i < max_height; i++) {
+ assert(nodes[i] != nullptr && nodes[i]->Next(i) == nullptr);
+ }
+}
+
+} // namespace rocksdb
diff --git a/src/rocksdb/memtable/inlineskiplist_test.cc b/src/rocksdb/memtable/inlineskiplist_test.cc
new file mode 100644
index 00000000..b416ef7c
--- /dev/null
+++ b/src/rocksdb/memtable/inlineskiplist_test.cc
@@ -0,0 +1,645 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/inlineskiplist.h"
+#include <set>
+#include <unordered_set>
+#include "rocksdb/env.h"
+#include "util/concurrent_arena.h"
+#include "util/hash.h"
+#include "util/random.h"
+#include "util/testharness.h"
+
+namespace rocksdb {
+
+// Our test skip list stores 8-byte unsigned integers
+typedef uint64_t Key;
+
+static const char* Encode(const uint64_t* key) {
+ return reinterpret_cast<const char*>(key);
+}
+
+static Key Decode(const char* key) {
+ Key rv;
+ memcpy(&rv, key, sizeof(Key));
+ return rv;
+}
+
+struct TestComparator {
+ typedef Key DecodedType;
+
+ static DecodedType decode_key(const char* b) {
+ return Decode(b);
+ }
+
+ int operator()(const char* a, const char* b) const {
+ if (Decode(a) < Decode(b)) {
+ return -1;
+ } else if (Decode(a) > Decode(b)) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+
+ int operator()(const char* a, const DecodedType b) const {
+ if (Decode(a) < b) {
+ return -1;
+ } else if (Decode(a) > b) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+typedef InlineSkipList<TestComparator> TestInlineSkipList;
+
+class InlineSkipTest : public testing::Test {
+ public:
+ void Insert(TestInlineSkipList* list, Key key) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list->Insert(buf);
+ keys_.insert(key);
+ }
+
+ bool InsertWithHint(TestInlineSkipList* list, Key key, void** hint) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ bool res = list->InsertWithHint(buf, hint);
+ keys_.insert(key);
+ return res;
+ }
+
+ void Validate(TestInlineSkipList* list) {
+ // Check keys exist.
+ for (Key key : keys_) {
+ ASSERT_TRUE(list->Contains(Encode(&key)));
+ }
+ // Iterate over the list, make sure keys appears in order and no extra
+ // keys exist.
+ TestInlineSkipList::Iterator iter(list);
+ ASSERT_FALSE(iter.Valid());
+ Key zero = 0;
+ iter.Seek(Encode(&zero));
+ for (Key key : keys_) {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(key, Decode(iter.key()));
+ iter.Next();
+ }
+ ASSERT_FALSE(iter.Valid());
+ // Validate the list is well-formed.
+ list->TEST_Validate();
+ }
+
+ private:
+ std::set<Key> keys_;
+};
+
+TEST_F(InlineSkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ Key key = 10;
+ ASSERT_TRUE(!list.Contains(Encode(&key)));
+
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ key = 100;
+ iter.Seek(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(InlineSkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ ConcurrentArena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ char* buf = list.AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list.Insert(buf);
+ }
+ }
+
+ for (Key i = 0; i < R; i++) {
+ if (list.Contains(Encode(&i))) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ uint64_t zero = 0;
+ iter.Seek(Encode(&zero));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ uint64_t max_key = R - 1;
+ iter.SeekForPrev(Encode(&max_key));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+ }
+
+ // Forward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.Seek(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, Decode(iter.key()));
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, Decode(iter.key()));
+ iter.Prev();
+ }
+ }
+ }
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_Sequential) {
+ const int N = 100000;
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hint = nullptr;
+ for (int i = 0; i < N; i++) {
+ Key key = i;
+ InsertWithHint(&list, key, &hint);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHints) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ Key last_key[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ last_key[i] = 0;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + (++last_key[s]);
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHintsRandom) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_CompatibleWithInsertWithoutHint) {
+ const int N = 100000;
+ const int S1 = 100;
+ const int S2 = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ std::unordered_set<Key> used;
+ Key with_hint[S1];
+ Key without_hint[S2];
+ void* hints[S1];
+ for (int i = 0; i < S1; i++) {
+ hints[i] = nullptr;
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ with_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < S2; i++) {
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ without_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S1 + S2);
+ if (s < S1) {
+ Key key = (with_hint[s] << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ } else {
+ Key key = (without_hint[s - S1] << 32) + rnd.Next();
+ Insert(&list, key);
+ }
+ }
+ Validate(&list);
+}
+
+#ifndef ROCKSDB_VALGRIND_RUN
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ public:
+ static const uint32_t K = 8;
+
+ private:
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = {k, g};
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ ConcurrentArena arena_;
+
+ // InlineSkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ InlineSkipList<TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: No concurrent calls to WriteStep or ConcurrentWriteStep
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ list_.Insert(buf);
+ current_.Set(k, g);
+ }
+
+ // REQUIRES: No concurrent calls for the same k
+ void ConcurrentWriteStep(uint32_t k) {
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ list_.InsertConcurrently(buf);
+ ASSERT_EQ(g, current_.Get(k) + 1);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ InlineSkipList<TestComparator>::Iterator iter(&list_);
+ iter.Seek(Encode(&pos));
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = Decode(iter.key());
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(Encode(&new_target));
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(InlineSkipTest, ConcurrentReadWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentInsertWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ uint32_t base = rnd.Next();
+ for (int j = 0; j < 4; ++j) {
+ test.ConcurrentWriteStep((base + j) % ConcurrentTest::K);
+ }
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+ std::atomic<uint32_t> next_writer_;
+
+ enum ReaderState { STARTING, RUNNING, DONE };
+
+ explicit TestState(int s)
+ : seed_(s),
+ quit_flag_(false),
+ state_(STARTING),
+ pending_writers_(0),
+ state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ void AdjustPendingWriters(int delta) {
+ mu_.Lock();
+ pending_writers_ += delta;
+ if (pending_writers_ == 0) {
+ state_cv_.Signal();
+ }
+ mu_.Unlock();
+ }
+
+ void WaitForPendingWriters() {
+ mu_.Lock();
+ while (pending_writers_ != 0) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ int pending_writers_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void ConcurrentWriter(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ uint32_t k = state->next_writer_++ % ConcurrentTest::K;
+ state->t_.ConcurrentWriteStep(k);
+ state->AdjustPendingWriters(-1);
+}
+
+static void RunConcurrentRead(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->SetBackgroundThreads(1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; ++k) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+static void RunConcurrentInsert(int run, int write_parallelism = 4) {
+ Env::Default()->SetBackgroundThreads(1 + write_parallelism,
+ Env::Priority::LOW);
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k += write_parallelism) {
+ state.next_writer_ = rnd.Next();
+ state.AdjustPendingWriters(write_parallelism);
+ for (int p = 0; p < write_parallelism; ++p) {
+ Env::Default()->Schedule(ConcurrentWriter, &state);
+ }
+ state.WaitForPendingWriters();
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentRead1) { RunConcurrentRead(1); }
+TEST_F(InlineSkipTest, ConcurrentRead2) { RunConcurrentRead(2); }
+TEST_F(InlineSkipTest, ConcurrentRead3) { RunConcurrentRead(3); }
+TEST_F(InlineSkipTest, ConcurrentRead4) { RunConcurrentRead(4); }
+TEST_F(InlineSkipTest, ConcurrentRead5) { RunConcurrentRead(5); }
+TEST_F(InlineSkipTest, ConcurrentInsert1) { RunConcurrentInsert(1); }
+TEST_F(InlineSkipTest, ConcurrentInsert2) { RunConcurrentInsert(2); }
+TEST_F(InlineSkipTest, ConcurrentInsert3) { RunConcurrentInsert(3); }
+
+#endif // ROCKSDB_VALGRIND_RUN
+} // namespace rocksdb
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/memtable/memtablerep_bench.cc b/src/rocksdb/memtable/memtablerep_bench.cc
new file mode 100644
index 00000000..51ff11a0
--- /dev/null
+++ b/src/rocksdb/memtable/memtablerep_bench.cc
@@ -0,0 +1,682 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef __STDC_FORMAT_MACROS
+#define __STDC_FORMAT_MACROS
+#endif
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run rocksdb tools\n");
+ return 1;
+}
+#else
+
+#include <atomic>
+#include <iostream>
+#include <memory>
+#include <thread>
+#include <type_traits>
+#include <vector>
+
+#include "db/dbformat.h"
+#include "db/memtable.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/options.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/write_buffer_manager.h"
+#include "util/arena.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/stop_watch.h"
+#include "util/testutil.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+using GFLAGS_NAMESPACE::RegisterFlagValidator;
+using GFLAGS_NAMESPACE::SetUsageMessage;
+
+DEFINE_string(benchmarks, "fillrandom",
+ "Comma-separated list of benchmarks to run. Options:\n"
+ "\tfillrandom -- write N random values\n"
+ "\tfillseq -- write N values in sequential order\n"
+ "\treadrandom -- read N values in random order\n"
+ "\treadseq -- scan the DB\n"
+ "\treadwrite -- 1 thread writes while N - 1 threads "
+ "do random\n"
+ "\t reads\n"
+ "\tseqreadwrite -- 1 thread writes while N - 1 threads "
+ "do scans\n");
+
+DEFINE_string(memtablerep, "skiplist",
+ "Which implementation of memtablerep to use. See "
+ "include/memtablerep.h for\n"
+ " more details. Options:\n"
+ "\tskiplist -- backed by a skiplist\n"
+ "\tvector -- backed by an std::vector\n"
+ "\thashskiplist -- backed by a hash skip list\n"
+ "\thashlinklist -- backed by a hash linked list\n"
+ "\tcuckoo -- backed by a cuckoo hash table");
+
+DEFINE_int64(bucket_count, 1000000,
+ "bucket_count parameter to pass into NewHashSkiplistRepFactory or "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ hashskiplist_height, 4,
+ "skiplist_height parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ hashskiplist_branching_factor, 4,
+ "branching_factor parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ huge_page_tlb_size, 0,
+ "huge_page_tlb_size parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int32(bucket_entries_logging_threshold, 4096,
+ "bucket_entries_logging_threshold parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_bool(if_log_bucket_dist_when_flash, true,
+ "if_log_bucket_dist_when_flash parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ threshold_use_skiplist, 256,
+ "threshold_use_skiplist parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int64(write_buffer_size, 256,
+ "write_buffer_size parameter to pass into WriteBufferManager");
+
+DEFINE_int32(
+ num_threads, 1,
+ "Number of concurrent threads to run. If the benchmark includes writes,\n"
+ "then at most one thread will be a writer");
+
+DEFINE_int32(num_operations, 1000000,
+ "Number of operations to do for write and random read benchmarks");
+
+DEFINE_int32(num_scans, 10,
+ "Number of times for each thread to scan the memtablerep for "
+ "sequential read "
+ "benchmarks");
+
+DEFINE_int32(item_size, 100, "Number of bytes each item should be");
+
+DEFINE_int32(prefix_length, 8,
+ "Prefix length to pass into NewFixedPrefixTransform");
+
+/* VectorRep settings */
+DEFINE_int64(vectorrep_count, 0,
+ "Number of entries to reserve on VectorRep initialization");
+
+DEFINE_int64(seed, 0,
+ "Seed base for random number generators. "
+ "When 0 it is deterministic.");
+
+namespace rocksdb {
+
+namespace {
+struct CallbackVerifyArgs {
+ bool found;
+ LookupKey* key;
+ MemTableRep* table;
+ InternalKeyComparator* comparator;
+};
+} // namespace
+
+// Helper for quickly generating random data.
+class RandomGenerator {
+ private:
+ std::string data_;
+ unsigned int pos_;
+
+ public:
+ RandomGenerator() {
+ Random rnd(301);
+ auto size = (unsigned)std::max(1048576, FLAGS_item_size);
+ test::RandomString(&rnd, size, &data_);
+ pos_ = 0;
+ }
+
+ Slice Generate(unsigned int len) {
+ assert(len <= data_.size());
+ if (pos_ + len > data_.size()) {
+ pos_ = 0;
+ }
+ pos_ += len;
+ return Slice(data_.data() + pos_ - len, len);
+ }
+};
+
+enum WriteMode { SEQUENTIAL, RANDOM, UNIQUE_RANDOM };
+
+class KeyGenerator {
+ public:
+ KeyGenerator(Random64* rand, WriteMode mode, uint64_t num)
+ : rand_(rand), mode_(mode), num_(num), next_(0) {
+ if (mode_ == UNIQUE_RANDOM) {
+ // NOTE: if memory consumption of this approach becomes a concern,
+ // we can either break it into pieces and only random shuffle a section
+ // each time. Alternatively, use a bit map implementation
+ // (https://reviews.facebook.net/differential/diff/54627/)
+ values_.resize(num_);
+ for (uint64_t i = 0; i < num_; ++i) {
+ values_[i] = i;
+ }
+ std::shuffle(
+ values_.begin(), values_.end(),
+ std::default_random_engine(static_cast<unsigned int>(FLAGS_seed)));
+ }
+ }
+
+ uint64_t Next() {
+ switch (mode_) {
+ case SEQUENTIAL:
+ return next_++;
+ case RANDOM:
+ return rand_->Next() % num_;
+ case UNIQUE_RANDOM:
+ return values_[next_++];
+ }
+ assert(false);
+ return std::numeric_limits<uint64_t>::max();
+ }
+
+ private:
+ Random64* rand_;
+ WriteMode mode_;
+ const uint64_t num_;
+ uint64_t next_;
+ std::vector<uint64_t> values_;
+};
+
+class BenchmarkThread {
+ public:
+ explicit BenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : table_(table),
+ key_gen_(key_gen),
+ bytes_written_(bytes_written),
+ bytes_read_(bytes_read),
+ sequence_(sequence),
+ num_ops_(num_ops),
+ read_hits_(read_hits) {}
+
+ virtual void operator()() = 0;
+ virtual ~BenchmarkThread() {}
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* bytes_written_;
+ uint64_t* bytes_read_;
+ uint64_t* sequence_;
+ uint64_t num_ops_;
+ uint64_t* read_hits_;
+ RandomGenerator generator_;
+};
+
+class FillBenchmarkThread : public BenchmarkThread {
+ public:
+ FillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void FillOne() {
+ char* buf = nullptr;
+ auto internal_key_size = 16;
+ auto encoded_len =
+ FLAGS_item_size + VarintLength(internal_key_size) + internal_key_size;
+ KeyHandle handle = table_->Allocate(encoded_len, &buf);
+ assert(buf != nullptr);
+ char* p = EncodeVarint32(buf, internal_key_size);
+ auto key = key_gen_->Next();
+ EncodeFixed64(p, key);
+ p += 8;
+ EncodeFixed64(p, ++(*sequence_));
+ p += 8;
+ Slice bytes = generator_.Generate(FLAGS_item_size);
+ memcpy(p, bytes.data(), FLAGS_item_size);
+ p += FLAGS_item_size;
+ assert(p == buf + encoded_len);
+ table_->Insert(handle);
+ *bytes_written_ += encoded_len;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ FillOne();
+ }
+ }
+};
+
+class ConcurrentFillBenchmarkThread : public FillBenchmarkThread {
+ public:
+ ConcurrentFillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : FillBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ // # of read threads will be total threads - write threads (always 1). Loop
+ // while all reads complete.
+ while ((*threads_done_).load() < (FLAGS_num_threads - 1)) {
+ FillOne();
+ }
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class ReadBenchmarkThread : public BenchmarkThread {
+ public:
+ ReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ static bool callback(void* arg, const char* entry) {
+ CallbackVerifyArgs* callback_args = static_cast<CallbackVerifyArgs*>(arg);
+ assert(callback_args != nullptr);
+ uint32_t key_length;
+ const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
+ if ((callback_args->comparator)
+ ->user_comparator()
+ ->Equal(Slice(key_ptr, key_length - 8),
+ callback_args->key->user_key())) {
+ callback_args->found = true;
+ }
+ return false;
+ }
+
+ void ReadOne() {
+ std::string user_key;
+ auto key = key_gen_->Next();
+ PutFixed64(&user_key, key);
+ LookupKey lookup_key(user_key, *sequence_);
+ InternalKeyComparator internal_key_comp(BytewiseComparator());
+ CallbackVerifyArgs verify_args;
+ verify_args.found = false;
+ verify_args.key = &lookup_key;
+ verify_args.table = table_;
+ verify_args.comparator = &internal_key_comp;
+ table_->Get(lookup_key, &verify_args, callback);
+ if (verify_args.found) {
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ ++*read_hits_;
+ }
+ }
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ }
+};
+
+class SeqReadBenchmarkThread : public BenchmarkThread {
+ public:
+ SeqReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void ReadOneSeq() {
+ std::unique_ptr<MemTableRep::Iterator> iter(table_->GetIterator());
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ // pretend to read the value
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ }
+ ++*read_hits_;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ { ReadOneSeq(); }
+ }
+ }
+};
+
+class ConcurrentReadBenchmarkThread : public ReadBenchmarkThread {
+ public:
+ ConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : ReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class SeqConcurrentReadBenchmarkThread : public SeqReadBenchmarkThread {
+ public:
+ SeqConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written,
+ uint64_t* bytes_read, uint64_t* sequence,
+ uint64_t num_ops, uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : SeqReadBenchmarkThread(table, key_gen, bytes_written, bytes_read,
+ sequence, num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOneSeq();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class Benchmark {
+ public:
+ explicit Benchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence, uint32_t num_threads)
+ : table_(table),
+ key_gen_(key_gen),
+ sequence_(sequence),
+ num_threads_(num_threads) {}
+
+ virtual ~Benchmark() {}
+ virtual void Run() {
+ std::cout << "Number of threads: " << num_threads_ << std::endl;
+ std::vector<port::Thread> threads;
+ uint64_t bytes_written = 0;
+ uint64_t bytes_read = 0;
+ uint64_t read_hits = 0;
+ StopWatchNano timer(Env::Default(), true);
+ RunThreads(&threads, &bytes_written, &bytes_read, true, &read_hits);
+ auto elapsed_time = static_cast<double>(timer.ElapsedNanos() / 1000);
+ std::cout << "Elapsed time: " << static_cast<int>(elapsed_time) << " us"
+ << std::endl;
+
+ if (bytes_written > 0) {
+ auto MiB_written = static_cast<double>(bytes_written) / (1 << 20);
+ auto write_throughput = MiB_written / (elapsed_time / 1000000);
+ std::cout << "Total bytes written: " << MiB_written << " MiB"
+ << std::endl;
+ std::cout << "Write throughput: " << write_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_write_ops_per_thread_;
+ std::cout << "write us/op: " << us_per_op << std::endl;
+ }
+ if (bytes_read > 0) {
+ auto MiB_read = static_cast<double>(bytes_read) / (1 << 20);
+ auto read_throughput = MiB_read / (elapsed_time / 1000000);
+ std::cout << "Total bytes read: " << MiB_read << " MiB" << std::endl;
+ std::cout << "Read throughput: " << read_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_read_ops_per_thread_;
+ std::cout << "read us/op: " << us_per_op << std::endl;
+ }
+ }
+
+ virtual void RunThreads(std::vector<port::Thread>* threads,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ bool write, uint64_t* read_hits) = 0;
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* sequence_;
+ uint64_t num_write_ops_per_thread_;
+ uint64_t num_read_ops_per_thread_;
+ const uint32_t num_threads_;
+};
+
+class FillBenchmark : public Benchmark {
+ public:
+ explicit FillBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, 1) {
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* /*threads*/, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ FillBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits)();
+ }
+};
+
+class ReadBenchmark : public Benchmark {
+ public:
+ explicit ReadBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_operations / FLAGS_num_threads;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadBenchmarkThread(table_, key_gen_, bytes_written, bytes_read,
+ sequence_, num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ std::cout << "read hit%: "
+ << (static_cast<double>(*read_hits) / FLAGS_num_operations) * 100
+ << std::endl;
+ }
+};
+
+class SeqReadBenchmark : public Benchmark {
+ public:
+ explicit SeqReadBenchmark(MemTableRep* table, uint64_t* sequence)
+ : Benchmark(table, nullptr, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_scans;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(SeqReadBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+template <class ReadThreadType>
+class ReadWriteBenchmark : public Benchmark {
+ public:
+ explicit ReadWriteBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ =
+ FLAGS_num_threads <= 1
+ ? 0
+ : (FLAGS_num_operations / (FLAGS_num_threads - 1));
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ std::atomic_int threads_done;
+ threads_done.store(0);
+ threads->emplace_back(ConcurrentFillBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits, &threads_done));
+ for (int i = 1; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadThreadType(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits, &threads_done));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+} // namespace rocksdb
+
+void PrintWarnings() {
+#if defined(__GNUC__) && !defined(__OPTIMIZE__)
+ fprintf(stdout,
+ "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n");
+#endif
+#ifndef NDEBUG
+ fprintf(stdout,
+ "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n");
+#endif
+}
+
+int main(int argc, char** argv) {
+ rocksdb::port::InstallStackTraceHandler();
+ SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
+ " [OPTIONS]...");
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ PrintWarnings();
+
+ rocksdb::Options options;
+
+ std::unique_ptr<rocksdb::MemTableRepFactory> factory;
+ if (FLAGS_memtablerep == "skiplist") {
+ factory.reset(new rocksdb::SkipListFactory);
+#ifndef ROCKSDB_LITE
+ } else if (FLAGS_memtablerep == "vector") {
+ factory.reset(new rocksdb::VectorRepFactory);
+ } else if (FLAGS_memtablerep == "hashskiplist") {
+ factory.reset(rocksdb::NewHashSkipListRepFactory(
+ FLAGS_bucket_count, FLAGS_hashskiplist_height,
+ FLAGS_hashskiplist_branching_factor));
+ options.prefix_extractor.reset(
+ rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length));
+ } else if (FLAGS_memtablerep == "hashlinklist") {
+ factory.reset(rocksdb::NewHashLinkListRepFactory(
+ FLAGS_bucket_count, FLAGS_huge_page_tlb_size,
+ FLAGS_bucket_entries_logging_threshold,
+ FLAGS_if_log_bucket_dist_when_flash, FLAGS_threshold_use_skiplist));
+ options.prefix_extractor.reset(
+ rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length));
+#endif // ROCKSDB_LITE
+ } else {
+ fprintf(stdout, "Unknown memtablerep: %s\n", FLAGS_memtablerep.c_str());
+ exit(1);
+ }
+
+ rocksdb::InternalKeyComparator internal_key_comp(
+ rocksdb::BytewiseComparator());
+ rocksdb::MemTable::KeyComparator key_comp(internal_key_comp);
+ rocksdb::Arena arena;
+ rocksdb::WriteBufferManager wb(FLAGS_write_buffer_size);
+ uint64_t sequence;
+ auto createMemtableRep = [&] {
+ sequence = 0;
+ return factory->CreateMemTableRep(key_comp, &arena,
+ options.prefix_extractor.get(),
+ options.info_log.get());
+ };
+ std::unique_ptr<rocksdb::MemTableRep> memtablerep;
+ rocksdb::Random64 rng(FLAGS_seed);
+ const char* benchmarks = FLAGS_benchmarks.c_str();
+ while (benchmarks != nullptr) {
+ std::unique_ptr<rocksdb::KeyGenerator> key_gen;
+ const char* sep = strchr(benchmarks, ',');
+ rocksdb::Slice name;
+ if (sep == nullptr) {
+ name = benchmarks;
+ benchmarks = nullptr;
+ } else {
+ name = rocksdb::Slice(benchmarks, sep - benchmarks);
+ benchmarks = sep + 1;
+ }
+ std::unique_ptr<rocksdb::Benchmark> benchmark;
+ if (name == rocksdb::Slice("fillseq")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("fillrandom")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::UNIQUE_RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("readrandom")) {
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadBenchmark(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("readseq")) {
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL,
+ FLAGS_num_operations));
+ benchmark.reset(
+ new rocksdb::SeqReadBenchmark(memtablerep.get(), &sequence));
+ } else if (name == rocksdb::Slice("readwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadWriteBenchmark<
+ rocksdb::ConcurrentReadBenchmarkThread>(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else if (name == rocksdb::Slice("seqreadwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM,
+ FLAGS_num_operations));
+ benchmark.reset(new rocksdb::ReadWriteBenchmark<
+ rocksdb::SeqConcurrentReadBenchmarkThread>(memtablerep.get(),
+ key_gen.get(), &sequence));
+ } else {
+ std::cout << "WARNING: skipping unknown benchmark '" << name.ToString()
+ << std::endl;
+ continue;
+ }
+ std::cout << "Running " << name.ToString() << std::endl;
+ benchmark->Run();
+ }
+
+ return 0;
+}
+
+#endif // GFLAGS
diff --git a/src/rocksdb/memtable/skiplist.h b/src/rocksdb/memtable/skiplist.h
new file mode 100644
index 00000000..47a89034
--- /dev/null
+++ b/src/rocksdb/memtable/skiplist.h
@@ -0,0 +1,497 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Thread safety
+// -------------
+//
+// Writes require external synchronization, most likely a mutex.
+// Reads require a guarantee that the SkipList will not be destroyed
+// while the read is in progress. Apart from that, reads progress
+// without any internal locking or synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the SkipList is
+// destroyed. This is trivially guaranteed by the code since we
+// never delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the SkipList.
+// Only Insert() modifies the list, and it is careful to initialize
+// a node and use release-stores to publish the nodes in one or
+// more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <atomic>
+#include <stdlib.h>
+#include "port/port.h"
+#include "util/allocator.h"
+#include "util/random.h"
+
+namespace rocksdb {
+
+template<typename Key, class Comparator>
+class SkipList {
+ private:
+ struct Node;
+
+ public:
+ // Create a new SkipList object that will use "cmp" for comparing keys,
+ // and will allocate memory using "*allocator". Objects allocated in the
+ // allocator must remain allocated for the lifetime of the skiplist object.
+ explicit SkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12, int32_t branching_factor = 4);
+
+ // Insert key into the list.
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ void Insert(const Key& key);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const Key& key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const Key& key) const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const SkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const SkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const Key& key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Key& target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Key& target);
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const SkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ // Immutable after construction
+ Comparator const compare_;
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // Used for optimizing sequential insert patterns. Tricky. prev_[i] for
+ // i up to max_height_ is the predecessor of prev_[0] and prev_height_
+ // is the height of prev_[0]. prev_[0] can only be equal to head before
+ // insertion, in which case max_height_ and prev_height_ are 1.
+ Node** prev_;
+ int32_t prev_height_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ Node* NewNode(const Key& key, int height);
+ int RandomHeight();
+ bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); }
+ bool LessThan(const Key& a, const Key& b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n"
+ bool KeyIsAfterNode(const Key& key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const Key& key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const Key& key, Node** prev = nullptr) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+
+ // No copying allowed
+ SkipList(const SkipList&);
+ void operator=(const SkipList&);
+};
+
+// Implementation details follow
+template<typename Key, class Comparator>
+struct SkipList<Key, Comparator>::Node {
+ explicit Node(const Key& k) : key(k) { }
+
+ Key const key;
+
+ // Accessors/mutators for links. Wrapped in methods so we can
+ // add the appropriate barriers as necessary.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return (next_[n].load(std::memory_order_acquire));
+ }
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_[n].store(x, std::memory_order_release);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return next_[n].load(std::memory_order_relaxed);
+ }
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ next_[n].store(x, std::memory_order_relaxed);
+ }
+
+ private:
+ // Array of length equal to the node height. next_[0] is lowest level link.
+ std::atomic<Node*> next_[1];
+};
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::NewNode(const Key& key, int height) {
+ char* mem = allocator_->AllocateAligned(
+ sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1));
+ return new (mem) Node(key);
+}
+
+template<typename Key, class Comparator>
+inline SkipList<Key, Comparator>::Iterator::Iterator(const SkipList* list) {
+ SetList(list);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SetList(const SkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template<typename Key, class Comparator>
+inline bool SkipList<Key, Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template<typename Key, class Comparator>
+inline const Key& SkipList<Key, Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->key;
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->key);
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Seek(const Key& target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekForPrev(
+ const Key& target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template<typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template<typename Key, class Comparator>
+int SkipList<Key, Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ return height;
+}
+
+template<typename Key, class Comparator>
+bool SkipList<Key, Comparator>::KeyIsAfterNode(const Key& key, Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) < 0);
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::
+ FindGreaterOrEqual(const Key& key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ int cmp = (next == nullptr || next == last_bigger)
+ ? 1 : compare_(next->key, key);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::FindLessThan(const Key& key, Node** prev) const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ if (next != last_not_after && KeyIsAfterNode(key, next)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIUsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template<typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast()
+ const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ assert(x == head_ || compare_(x->key, key) < 0);
+ Node* next = x->Next(level);
+ if (next == nullptr || compare_(next->key, key) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+SkipList<Key, Comparator>::SkipList(const Comparator cmp, Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(static_cast<uint16_t>(max_height)),
+ kBranching_(static_cast<uint16_t>(branching_factor)),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ compare_(cmp),
+ allocator_(allocator),
+ head_(NewNode(0 /* any key will do */, max_height)),
+ max_height_(1),
+ prev_height_(1) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 0 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+ // Allocate the prev_ Node* array, directly from the passed-in allocator.
+ // prev_ does not need to be freed, as its life cycle is tied up with
+ // the allocator as a whole.
+ prev_ = reinterpret_cast<Node**>(
+ allocator_->AllocateAligned(sizeof(Node*) * kMaxHeight_));
+ for (int i = 0; i < kMaxHeight_; i++) {
+ head_->SetNext(i, nullptr);
+ prev_[i] = head_;
+ }
+}
+
+template<typename Key, class Comparator>
+void SkipList<Key, Comparator>::Insert(const Key& key) {
+ // fast path for sequential insertion
+ if (!KeyIsAfterNode(key, prev_[0]->NoBarrier_Next(0)) &&
+ (prev_[0] == head_ || KeyIsAfterNode(key, prev_[0]))) {
+ assert(prev_[0] != head_ || (prev_height_ == 1 && GetMaxHeight() == 1));
+
+ // Outside of this method prev_[1..max_height_] is the predecessor
+ // of prev_[0], and prev_height_ refers to prev_[0]. Inside Insert
+ // prev_[0..max_height - 1] is the predecessor of key. Switch from
+ // the external state to the internal
+ for (int i = 1; i < prev_height_; i++) {
+ prev_[i] = prev_[0];
+ }
+ } else {
+ // TODO(opt): we could use a NoBarrier predecessor search as an
+ // optimization for architectures where memory_order_acquire needs
+ // a synchronization instruction. Doesn't matter on x86
+ FindLessThan(key, prev_);
+ }
+
+ // Our data structure does not allow duplicate insertion
+ assert(prev_[0]->Next(0) == nullptr || !Equal(key, prev_[0]->Next(0)->key));
+
+ int height = RandomHeight();
+ if (height > GetMaxHeight()) {
+ for (int i = GetMaxHeight(); i < height; i++) {
+ prev_[i] = head_;
+ }
+ //fprintf(stderr, "Change height from %d to %d\n", max_height_, height);
+
+ // It is ok to mutate max_height_ without any synchronization
+ // with concurrent readers. A concurrent reader that observes
+ // the new value of max_height_ will see either the old value of
+ // new level pointers from head_ (nullptr), or a new value set in
+ // the loop below. In the former case the reader will
+ // immediately drop to the next level since nullptr sorts after all
+ // keys. In the latter case the reader will use the new node.
+ max_height_.store(height, std::memory_order_relaxed);
+ }
+
+ Node* x = NewNode(key, height);
+ for (int i = 0; i < height; i++) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(i, prev_[i]->NoBarrier_Next(i));
+ prev_[i]->SetNext(i, x);
+ }
+ prev_[0] = x;
+ prev_height_ = height;
+}
+
+template<typename Key, class Comparator>
+bool SkipList<Key, Comparator>::Contains(const Key& key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->key)) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace rocksdb
diff --git a/src/rocksdb/memtable/skiplist_test.cc b/src/rocksdb/memtable/skiplist_test.cc
new file mode 100644
index 00000000..50c3588b
--- /dev/null
+++ b/src/rocksdb/memtable/skiplist_test.cc
@@ -0,0 +1,388 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/skiplist.h"
+#include <set>
+#include "rocksdb/env.h"
+#include "util/arena.h"
+#include "util/hash.h"
+#include "util/random.h"
+#include "util/testharness.h"
+
+namespace rocksdb {
+
+typedef uint64_t Key;
+
+struct TestComparator {
+ int operator()(const Key& a, const Key& b) const {
+ if (a < b) {
+ return -1;
+ } else if (a > b) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+class SkipTest : public testing::Test {};
+
+TEST_F(SkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ ASSERT_TRUE(!list.Contains(10));
+
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ iter.Seek(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(SkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ list.Insert(key);
+ }
+ }
+
+ for (int i = 0; i < R; i++) {
+ if (list.Contains(i)) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ iter.Seek(0);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekForPrev(R - 1);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+ }
+
+ // Forward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.Seek(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, iter.key());
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, iter.key());
+ iter.Prev();
+ }
+ }
+ }
+}
+
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ private:
+ static const uint32_t K = 4;
+
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = { k, g };
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ Arena arena_;
+
+ // SkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ SkipList<Key, TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: External synchronization
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ list_.Insert(new_key);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ SkipList<Key, TestComparator>::Iterator iter(&list_);
+ iter.Seek(pos);
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = iter.key();
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(new_target);
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(SkipTest, ConcurrentWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+
+ enum ReaderState {
+ STARTING,
+ RUNNING,
+ DONE
+ };
+
+ explicit TestState(int s)
+ : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void RunConcurrent(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->SetBackgroundThreads(1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k++) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(SkipTest, Concurrent1) { RunConcurrent(1); }
+TEST_F(SkipTest, Concurrent2) { RunConcurrent(2); }
+TEST_F(SkipTest, Concurrent3) { RunConcurrent(3); }
+TEST_F(SkipTest, Concurrent4) { RunConcurrent(4); }
+TEST_F(SkipTest, Concurrent5) { RunConcurrent(5); }
+
+} // namespace rocksdb
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/memtable/skiplistrep.cc b/src/rocksdb/memtable/skiplistrep.cc
new file mode 100644
index 00000000..32870b12
--- /dev/null
+++ b/src/rocksdb/memtable/skiplistrep.cc
@@ -0,0 +1,271 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include "memtable/inlineskiplist.h"
+#include "db/memtable.h"
+#include "rocksdb/memtablerep.h"
+#include "util/arena.h"
+
+namespace rocksdb {
+namespace {
+class SkipListRep : public MemTableRep {
+ InlineSkipList<const MemTableRep::KeyComparator&> skip_list_;
+ const MemTableRep::KeyComparator& cmp_;
+ const SliceTransform* transform_;
+ const size_t lookahead_;
+
+ friend class LookaheadIterator;
+public:
+ explicit SkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ const size_t lookahead)
+ : MemTableRep(allocator),
+ skip_list_(compare, allocator),
+ cmp_(compare),
+ transform_(transform),
+ lookahead_(lookahead) {}
+
+ KeyHandle Allocate(const size_t len, char** buf) override {
+ *buf = skip_list_.AllocateKey(len);
+ return static_cast<KeyHandle>(*buf);
+ }
+
+ // Insert key into the list.
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ void Insert(KeyHandle handle) override {
+ skip_list_.Insert(static_cast<char*>(handle));
+ }
+
+ bool InsertKey(KeyHandle handle) override {
+ return skip_list_.Insert(static_cast<char*>(handle));
+ }
+
+ void InsertWithHint(KeyHandle handle, void** hint) override {
+ skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
+ }
+
+ bool InsertKeyWithHint(KeyHandle handle, void** hint) override {
+ return skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
+ }
+
+ void InsertConcurrently(KeyHandle handle) override {
+ skip_list_.InsertConcurrently(static_cast<char*>(handle));
+ }
+
+ bool InsertKeyConcurrently(KeyHandle handle) override {
+ return skip_list_.InsertConcurrently(static_cast<char*>(handle));
+ }
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const char* key) const override {
+ return skip_list_.Contains(key);
+ }
+
+ size_t ApproximateMemoryUsage() override {
+ // All memory is allocated through allocator; nothing to report here
+ return 0;
+ }
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override {
+ SkipListRep::Iterator iter(&skip_list_);
+ Slice dummy_slice;
+ for (iter.Seek(dummy_slice, k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) {
+ }
+ }
+
+ uint64_t ApproximateNumEntries(const Slice& start_ikey,
+ const Slice& end_ikey) override {
+ std::string tmp;
+ uint64_t start_count =
+ skip_list_.EstimateCount(EncodeKey(&tmp, start_ikey));
+ uint64_t end_count = skip_list_.EstimateCount(EncodeKey(&tmp, end_ikey));
+ return (end_count >= start_count) ? (end_count - start_count) : 0;
+ }
+
+ ~SkipListRep() override {}
+
+ // Iteration over the contents of a skip list
+ class Iterator : public MemTableRep::Iterator {
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
+
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(
+ const InlineSkipList<const MemTableRep::KeyComparator&>* list)
+ : iter_(list) {}
+
+ ~Iterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override { return iter_.key(); }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override { iter_.Next(); }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override { iter_.Prev(); }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& user_key, const char* memtable_key) override {
+ if (memtable_key != nullptr) {
+ iter_.Seek(memtable_key);
+ } else {
+ iter_.Seek(EncodeKey(&tmp_, user_key));
+ }
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& user_key, const char* memtable_key) override {
+ if (memtable_key != nullptr) {
+ iter_.SeekForPrev(memtable_key);
+ } else {
+ iter_.SeekForPrev(EncodeKey(&tmp_, user_key));
+ }
+ }
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst() override { iter_.SeekToFirst(); }
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast() override { iter_.SeekToLast(); }
+
+ protected:
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ // Iterator over the contents of a skip list which also keeps track of the
+ // previously visited node. In Seek(), it examines a few nodes after it
+ // first, falling back to O(log n) search from the head of the list only if
+ // the target key hasn't been found.
+ class LookaheadIterator : public MemTableRep::Iterator {
+ public:
+ explicit LookaheadIterator(const SkipListRep& rep) :
+ rep_(rep), iter_(&rep_.skip_list_), prev_(iter_) {}
+
+ ~LookaheadIterator() override {}
+
+ bool Valid() const override { return iter_.Valid(); }
+
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ void Next() override {
+ assert(Valid());
+
+ bool advance_prev = true;
+ if (prev_.Valid()) {
+ auto k1 = rep_.UserKey(prev_.key());
+ auto k2 = rep_.UserKey(iter_.key());
+
+ if (k1.compare(k2) == 0) {
+ // same user key, don't move prev_
+ advance_prev = false;
+ } else if (rep_.transform_) {
+ // only advance prev_ if it has the same prefix as iter_
+ auto t1 = rep_.transform_->Transform(k1);
+ auto t2 = rep_.transform_->Transform(k2);
+ advance_prev = t1.compare(t2) == 0;
+ }
+ }
+
+ if (advance_prev) {
+ prev_ = iter_;
+ }
+ iter_.Next();
+ }
+
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ prev_ = iter_;
+ }
+
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ const char *encoded_key =
+ (memtable_key != nullptr) ?
+ memtable_key : EncodeKey(&tmp_, internal_key);
+
+ if (prev_.Valid() && rep_.cmp_(encoded_key, prev_.key()) >= 0) {
+ // prev_.key() is smaller or equal to our target key; do a quick
+ // linear search (at most lookahead_ steps) starting from prev_
+ iter_ = prev_;
+
+ size_t cur = 0;
+ while (cur++ <= rep_.lookahead_ && iter_.Valid()) {
+ if (rep_.cmp_(encoded_key, iter_.key()) <= 0) {
+ return;
+ }
+ Next();
+ }
+ }
+
+ iter_.Seek(encoded_key);
+ prev_ = iter_;
+ }
+
+ void SeekForPrev(const Slice& internal_key,
+ const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.SeekForPrev(encoded_key);
+ prev_ = iter_;
+ }
+
+ void SeekToFirst() override {
+ iter_.SeekToFirst();
+ prev_ = iter_;
+ }
+
+ void SeekToLast() override {
+ iter_.SeekToLast();
+ prev_ = iter_;
+ }
+
+ protected:
+ std::string tmp_; // For passing to EncodeKey
+
+ private:
+ const SkipListRep& rep_;
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator prev_;
+ };
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override {
+ if (lookahead_ > 0) {
+ void *mem =
+ arena ? arena->AllocateAligned(sizeof(SkipListRep::LookaheadIterator))
+ : operator new(sizeof(SkipListRep::LookaheadIterator));
+ return new (mem) SkipListRep::LookaheadIterator(*this);
+ } else {
+ void *mem =
+ arena ? arena->AllocateAligned(sizeof(SkipListRep::Iterator))
+ : operator new(sizeof(SkipListRep::Iterator));
+ return new (mem) SkipListRep::Iterator(&skip_list_);
+ }
+ }
+};
+}
+
+MemTableRep* SkipListFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* /*logger*/) {
+ return new SkipListRep(compare, allocator, transform, lookahead_);
+}
+
+} // namespace rocksdb
diff --git a/src/rocksdb/memtable/stl_wrappers.h b/src/rocksdb/memtable/stl_wrappers.h
new file mode 100644
index 00000000..0287f4f8
--- /dev/null
+++ b/src/rocksdb/memtable/stl_wrappers.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <map>
+#include <string>
+
+#include "rocksdb/comparator.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+
+namespace rocksdb {
+namespace stl_wrappers {
+
+class Base {
+ protected:
+ const MemTableRep::KeyComparator& compare_;
+ explicit Base(const MemTableRep::KeyComparator& compare)
+ : compare_(compare) {}
+};
+
+struct Compare : private Base {
+ explicit Compare(const MemTableRep::KeyComparator& compare) : Base(compare) {}
+ inline bool operator()(const char* a, const char* b) const {
+ return compare_(a, b) < 0;
+ }
+};
+
+}
+}
diff --git a/src/rocksdb/memtable/vectorrep.cc b/src/rocksdb/memtable/vectorrep.cc
new file mode 100644
index 00000000..827ab8a5
--- /dev/null
+++ b/src/rocksdb/memtable/vectorrep.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+#include "rocksdb/memtablerep.h"
+
+#include <unordered_set>
+#include <set>
+#include <memory>
+#include <algorithm>
+#include <type_traits>
+
+#include "util/arena.h"
+#include "db/memtable.h"
+#include "memtable/stl_wrappers.h"
+#include "port/port.h"
+#include "util/mutexlock.h"
+
+namespace rocksdb {
+namespace {
+
+using namespace stl_wrappers;
+
+class VectorRep : public MemTableRep {
+ public:
+ VectorRep(const KeyComparator& compare, Allocator* allocator, size_t count);
+
+ // Insert key into the collection. (The caller will pack key and value into a
+ // single buffer and pass that in as the parameter to Insert)
+ // REQUIRES: nothing that compares equal to key is currently in the
+ // collection.
+ void Insert(KeyHandle handle) override;
+
+ // Returns true iff an entry that compares equal to key is in the collection.
+ bool Contains(const char* key) const override;
+
+ void MarkReadOnly() override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~VectorRep() override {}
+
+ class Iterator : public MemTableRep::Iterator {
+ class VectorRep* vrep_;
+ std::shared_ptr<std::vector<const char*>> bucket_;
+ std::vector<const char*>::const_iterator mutable cit_;
+ const KeyComparator& compare_;
+ std::string tmp_; // For passing to EncodeKey
+ bool mutable sorted_;
+ void DoSort() const;
+ public:
+ explicit Iterator(class VectorRep* vrep,
+ std::shared_ptr<std::vector<const char*>> bucket,
+ const KeyComparator& compare);
+
+ // Initialize an iterator over the specified collection.
+ // The returned iterator is not valid.
+ // explicit Iterator(const MemTableRep* collection);
+ ~Iterator() override{};
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override;
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override;
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& user_key, const char* memtable_key) override;
+
+ // Advance to the first entry with a key <= target
+ void SeekForPrev(const Slice& user_key, const char* memtable_key) override;
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override;
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override;
+ };
+
+ // Return an iterator over the keys in this representation.
+ MemTableRep::Iterator* GetIterator(Arena* arena) override;
+
+ private:
+ friend class Iterator;
+ typedef std::vector<const char*> Bucket;
+ std::shared_ptr<Bucket> bucket_;
+ mutable port::RWMutex rwlock_;
+ bool immutable_;
+ bool sorted_;
+ const KeyComparator& compare_;
+};
+
+void VectorRep::Insert(KeyHandle handle) {
+ auto* key = static_cast<char*>(handle);
+ WriteLock l(&rwlock_);
+ assert(!immutable_);
+ bucket_->push_back(key);
+}
+
+// Returns true iff an entry that compares equal to key is in the collection.
+bool VectorRep::Contains(const char* key) const {
+ ReadLock l(&rwlock_);
+ return std::find(bucket_->begin(), bucket_->end(), key) != bucket_->end();
+}
+
+void VectorRep::MarkReadOnly() {
+ WriteLock l(&rwlock_);
+ immutable_ = true;
+}
+
+size_t VectorRep::ApproximateMemoryUsage() {
+ return
+ sizeof(bucket_) + sizeof(*bucket_) +
+ bucket_->size() *
+ sizeof(
+ std::remove_reference<decltype(*bucket_)>::type::value_type
+ );
+}
+
+VectorRep::VectorRep(const KeyComparator& compare, Allocator* allocator,
+ size_t count)
+ : MemTableRep(allocator),
+ bucket_(new Bucket()),
+ immutable_(false),
+ sorted_(false),
+ compare_(compare) {
+ bucket_.get()->reserve(count);
+}
+
+VectorRep::Iterator::Iterator(class VectorRep* vrep,
+ std::shared_ptr<std::vector<const char*>> bucket,
+ const KeyComparator& compare)
+: vrep_(vrep),
+ bucket_(bucket),
+ cit_(bucket_->end()),
+ compare_(compare),
+ sorted_(false) { }
+
+void VectorRep::Iterator::DoSort() const {
+ // vrep is non-null means that we are working on an immutable memtable
+ if (!sorted_ && vrep_ != nullptr) {
+ WriteLock l(&vrep_->rwlock_);
+ if (!vrep_->sorted_) {
+ std::sort(bucket_->begin(), bucket_->end(), Compare(compare_));
+ cit_ = bucket_->begin();
+ vrep_->sorted_ = true;
+ }
+ sorted_ = true;
+ }
+ if (!sorted_) {
+ std::sort(bucket_->begin(), bucket_->end(), Compare(compare_));
+ cit_ = bucket_->begin();
+ sorted_ = true;
+ }
+ assert(sorted_);
+ assert(vrep_ == nullptr || vrep_->sorted_);
+}
+
+// Returns true iff the iterator is positioned at a valid node.
+bool VectorRep::Iterator::Valid() const {
+ DoSort();
+ return cit_ != bucket_->end();
+}
+
+// Returns the key at the current position.
+// REQUIRES: Valid()
+const char* VectorRep::Iterator::key() const {
+ assert(sorted_);
+ return *cit_;
+}
+
+// Advances to the next position.
+// REQUIRES: Valid()
+void VectorRep::Iterator::Next() {
+ assert(sorted_);
+ if (cit_ == bucket_->end()) {
+ return;
+ }
+ ++cit_;
+}
+
+// Advances to the previous position.
+// REQUIRES: Valid()
+void VectorRep::Iterator::Prev() {
+ assert(sorted_);
+ if (cit_ == bucket_->begin()) {
+ // If you try to go back from the first element, the iterator should be
+ // invalidated. So we set it to past-the-end. This means that you can
+ // treat the container circularly.
+ cit_ = bucket_->end();
+ } else {
+ --cit_;
+ }
+}
+
+// Advance to the first entry with a key >= target
+void VectorRep::Iterator::Seek(const Slice& user_key,
+ const char* memtable_key) {
+ DoSort();
+ // Do binary search to find first value not less than the target
+ const char* encoded_key =
+ (memtable_key != nullptr) ? memtable_key : EncodeKey(&tmp_, user_key);
+ cit_ = std::equal_range(bucket_->begin(),
+ bucket_->end(),
+ encoded_key,
+ [this] (const char* a, const char* b) {
+ return compare_(a, b) < 0;
+ }).first;
+}
+
+// Advance to the first entry with a key <= target
+void VectorRep::Iterator::SeekForPrev(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) {
+ assert(false);
+}
+
+// Position at the first entry in collection.
+// Final state of iterator is Valid() iff collection is not empty.
+void VectorRep::Iterator::SeekToFirst() {
+ DoSort();
+ cit_ = bucket_->begin();
+}
+
+// Position at the last entry in collection.
+// Final state of iterator is Valid() iff collection is not empty.
+void VectorRep::Iterator::SeekToLast() {
+ DoSort();
+ cit_ = bucket_->end();
+ if (bucket_->size() != 0) {
+ --cit_;
+ }
+}
+
+void VectorRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ rwlock_.ReadLock();
+ VectorRep* vector_rep;
+ std::shared_ptr<Bucket> bucket;
+ if (immutable_) {
+ vector_rep = this;
+ } else {
+ vector_rep = nullptr;
+ bucket.reset(new Bucket(*bucket_)); // make a copy
+ }
+ VectorRep::Iterator iter(vector_rep, immutable_ ? bucket_ : bucket, compare_);
+ rwlock_.ReadUnlock();
+
+ for (iter.Seek(k.user_key(), k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) {
+ }
+}
+
+MemTableRep::Iterator* VectorRep::GetIterator(Arena* arena) {
+ char* mem = nullptr;
+ if (arena != nullptr) {
+ mem = arena->AllocateAligned(sizeof(Iterator));
+ }
+ ReadLock l(&rwlock_);
+ // Do not sort here. The sorting would be done the first time
+ // a Seek is performed on the iterator.
+ if (immutable_) {
+ if (arena == nullptr) {
+ return new Iterator(this, bucket_, compare_);
+ } else {
+ return new (mem) Iterator(this, bucket_, compare_);
+ }
+ } else {
+ std::shared_ptr<Bucket> tmp;
+ tmp.reset(new Bucket(*bucket_)); // make a copy
+ if (arena == nullptr) {
+ return new Iterator(nullptr, tmp, compare_);
+ } else {
+ return new (mem) Iterator(nullptr, tmp, compare_);
+ }
+ }
+}
+} // anon namespace
+
+MemTableRep* VectorRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform*, Logger* /*logger*/) {
+ return new VectorRep(compare, allocator, count_);
+}
+} // namespace rocksdb
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/write_buffer_manager.cc b/src/rocksdb/memtable/write_buffer_manager.cc
new file mode 100644
index 00000000..7f2e664a
--- /dev/null
+++ b/src/rocksdb/memtable/write_buffer_manager.cc
@@ -0,0 +1,130 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/write_buffer_manager.h"
+#include <mutex>
+#include "util/coding.h"
+
+namespace rocksdb {
+#ifndef ROCKSDB_LITE
+namespace {
+const size_t kSizeDummyEntry = 1024 * 1024;
+// The key will be longer than keys for blocks in SST files so they won't
+// conflict.
+const size_t kCacheKeyPrefix = kMaxVarint64Length * 4 + 1;
+} // namespace
+
+struct WriteBufferManager::CacheRep {
+ std::shared_ptr<Cache> cache_;
+ std::mutex cache_mutex_;
+ std::atomic<size_t> cache_allocated_size_;
+ // The non-prefix part will be updated according to the ID to use.
+ char cache_key_[kCacheKeyPrefix + kMaxVarint64Length];
+ uint64_t next_cache_key_id_ = 0;
+ std::vector<Cache::Handle*> dummy_handles_;
+
+ explicit CacheRep(std::shared_ptr<Cache> cache)
+ : cache_(cache), cache_allocated_size_(0) {
+ memset(cache_key_, 0, kCacheKeyPrefix);
+ size_t pointer_size = sizeof(const void*);
+ assert(pointer_size <= kCacheKeyPrefix);
+ memcpy(cache_key_, static_cast<const void*>(this), pointer_size);
+ }
+
+ Slice GetNextCacheKey() {
+ memset(cache_key_ + kCacheKeyPrefix, 0, kMaxVarint64Length);
+ char* end =
+ EncodeVarint64(cache_key_ + kCacheKeyPrefix, next_cache_key_id_++);
+ return Slice(cache_key_, static_cast<size_t>(end - cache_key_));
+ }
+};
+#else
+struct WriteBufferManager::CacheRep {};
+#endif // ROCKSDB_LITE
+
+WriteBufferManager::WriteBufferManager(size_t _buffer_size,
+ std::shared_ptr<Cache> cache)
+ : buffer_size_(_buffer_size),
+ mutable_limit_(buffer_size_ * 7 / 8),
+ memory_used_(0),
+ memory_active_(0),
+ cache_rep_(nullptr) {
+#ifndef ROCKSDB_LITE
+ if (cache) {
+ // Construct the cache key using the pointer to this.
+ cache_rep_.reset(new CacheRep(cache));
+ }
+#else
+ (void)cache;
+#endif // ROCKSDB_LITE
+}
+
+WriteBufferManager::~WriteBufferManager() {
+#ifndef ROCKSDB_LITE
+ if (cache_rep_) {
+ for (auto* handle : cache_rep_->dummy_handles_) {
+ cache_rep_->cache_->Release(handle, true);
+ }
+ }
+#endif // ROCKSDB_LITE
+}
+
+// Should only be called from write thread
+void WriteBufferManager::ReserveMemWithCache(size_t mem) {
+#ifndef ROCKSDB_LITE
+ assert(cache_rep_ != nullptr);
+ // Use a mutex to protect various data structures. Can be optimized to a
+ // lock-free solution if it ends up with a performance bottleneck.
+ std::lock_guard<std::mutex> lock(cache_rep_->cache_mutex_);
+
+ size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) + mem;
+ memory_used_.store(new_mem_used, std::memory_order_relaxed);
+ while (new_mem_used > cache_rep_->cache_allocated_size_) {
+ // Expand size by at least 1MB.
+ // Add a dummy record to the cache
+ Cache::Handle* handle;
+ cache_rep_->cache_->Insert(cache_rep_->GetNextCacheKey(), nullptr,
+ kSizeDummyEntry, nullptr, &handle);
+ cache_rep_->dummy_handles_.push_back(handle);
+ cache_rep_->cache_allocated_size_ += kSizeDummyEntry;
+ }
+#else
+ (void)mem;
+#endif // ROCKSDB_LITE
+}
+
+void WriteBufferManager::FreeMemWithCache(size_t mem) {
+#ifndef ROCKSDB_LITE
+ assert(cache_rep_ != nullptr);
+ // Use a mutex to protect various data structures. Can be optimized to a
+ // lock-free solution if it ends up with a performance bottleneck.
+ std::lock_guard<std::mutex> lock(cache_rep_->cache_mutex_);
+ size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) - mem;
+ memory_used_.store(new_mem_used, std::memory_order_relaxed);
+ // Gradually shrink memory costed in the block cache if the actual
+ // usage is less than 3/4 of what we reserve from the block cache.
+ // We do this because:
+ // 1. we don't pay the cost of the block cache immediately a memtable is
+ // freed, as block cache insert is expensive;
+ // 2. eventually, if we walk away from a temporary memtable size increase,
+ // we make sure shrink the memory costed in block cache over time.
+ // In this way, we only shrink costed memory showly even there is enough
+ // margin.
+ if (new_mem_used < cache_rep_->cache_allocated_size_ / 4 * 3 &&
+ cache_rep_->cache_allocated_size_ - kSizeDummyEntry > new_mem_used) {
+ assert(!cache_rep_->dummy_handles_.empty());
+ cache_rep_->cache_->Release(cache_rep_->dummy_handles_.back(), true);
+ cache_rep_->dummy_handles_.pop_back();
+ cache_rep_->cache_allocated_size_ -= kSizeDummyEntry;
+ }
+#else
+ (void)mem;
+#endif // ROCKSDB_LITE
+}
+} // namespace rocksdb
diff --git a/src/rocksdb/memtable/write_buffer_manager_test.cc b/src/rocksdb/memtable/write_buffer_manager_test.cc
new file mode 100644
index 00000000..0fc9fd06
--- /dev/null
+++ b/src/rocksdb/memtable/write_buffer_manager_test.cc
@@ -0,0 +1,151 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/write_buffer_manager.h"
+#include "util/testharness.h"
+
+namespace rocksdb {
+
+class WriteBufferManagerTest : public testing::Test {};
+
+#ifndef ROCKSDB_LITE
+TEST_F(WriteBufferManagerTest, ShouldFlush) {
+ // A write buffer manager of size 10MB
+ std::unique_ptr<WriteBufferManager> wbf(
+ new WriteBufferManager(10 * 1024 * 1024));
+
+ wbf->ReserveMem(8 * 1024 * 1024);
+ ASSERT_FALSE(wbf->ShouldFlush());
+ // 90% of the hard limit will hit the condition
+ wbf->ReserveMem(1 * 1024 * 1024);
+ ASSERT_TRUE(wbf->ShouldFlush());
+ // Scheduling for freeing will release the condition
+ wbf->ScheduleFreeMem(1 * 1024 * 1024);
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(2 * 1024 * 1024);
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(4 * 1024 * 1024);
+ // 11MB total, 6MB mutable. hard limit still hit
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(2 * 1024 * 1024);
+ // 11MB total, 4MB mutable. hard limit stills but won't flush because more
+ // than half data is already being flushed.
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(4 * 1024 * 1024);
+ // 15 MB total, 8MB mutable.
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->FreeMem(7 * 1024 * 1024);
+ // 9MB total, 8MB mutable.
+ ASSERT_FALSE(wbf->ShouldFlush());
+}
+
+TEST_F(WriteBufferManagerTest, CacheCost) {
+ // 1GB cache
+ std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4);
+ // A write buffer manager of size 50MB
+ std::unique_ptr<WriteBufferManager> wbf(
+ new WriteBufferManager(50 * 1024 * 1024, cache));
+
+ // Allocate 1.5MB will allocate 2MB
+ wbf->ReserveMem(1536 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 2 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 2 * 1024 * 1024 + 10000);
+
+ // Allocate another 2MB
+ wbf->ReserveMem(2 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 4 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 4 * 1024 * 1024 + 10000);
+
+ // Allocate another 20MB
+ wbf->ReserveMem(20 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 24 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 24 * 1024 * 1024 + 10000);
+
+ // Free 2MB will not cause any change in cache cost
+ wbf->FreeMem(2 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 24 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 24 * 1024 * 1024 + 10000);
+
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // Allocate another 30MB
+ wbf->ReserveMem(30 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 52 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 52 * 1024 * 1024 + 10000);
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(20 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 52 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 52 * 1024 * 1024 + 10000);
+
+ // Still need flush as the hard limit hits
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ // Free 20MB will releae 1MB from cache
+ wbf->FreeMem(20 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 51 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 51 * 1024 * 1024 + 10000);
+
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // Every free will release 1MB if still not hit 3/4
+ wbf->FreeMem(16 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 50 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 50 * 1024 * 1024 + 10000);
+
+ wbf->FreeMem(16 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 49 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 49 * 1024 * 1024 + 10000);
+
+ // Free 2MB will not cause any change in cache cost
+ wbf->ReserveMem(2 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 49 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 49 * 1024 * 1024 + 10000);
+
+ wbf->FreeMem(16 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 48 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 48 * 1024 * 1024 + 10000);
+
+ // Destory write buffer manger should free everything
+ wbf.reset();
+ ASSERT_LT(cache->GetPinnedUsage(), 1024 * 1024);
+}
+
+TEST_F(WriteBufferManagerTest, NoCapCacheCost) {
+ // 1GB cache
+ std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4);
+ // A write buffer manager of size 256MB
+ std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache));
+ // Allocate 1.5MB will allocate 2MB
+ wbf->ReserveMem(10 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 10 * 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 10 * 1024 * 1024 + 10000);
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->FreeMem(9 * 1024 * 1024);
+ for (int i = 0; i < 10; i++) {
+ wbf->FreeMem(16 * 1024);
+ }
+ ASSERT_GE(cache->GetPinnedUsage(), 1024 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 1024 * 1024 + 10000);
+}
+#endif // ROCKSDB_LITE
+} // namespace rocksdb
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}