summaryrefslogtreecommitdiffstats
path: root/src/rocksdb/memtable
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/rocksdb/memtable/alloc_tracker.cc63
-rw-r--r--src/rocksdb/memtable/hash_linklist_rep.cc926
-rw-r--r--src/rocksdb/memtable/hash_skiplist_rep.cc393
-rw-r--r--src/rocksdb/memtable/inlineskiplist.h1051
-rw-r--r--src/rocksdb/memtable/inlineskiplist_test.cc664
-rw-r--r--src/rocksdb/memtable/memtablerep_bench.cc689
-rw-r--r--src/rocksdb/memtable/skiplist.h498
-rw-r--r--src/rocksdb/memtable/skiplist_test.cc387
-rw-r--r--src/rocksdb/memtable/skiplistrep.cc370
-rw-r--r--src/rocksdb/memtable/stl_wrappers.h33
-rw-r--r--src/rocksdb/memtable/vectorrep.cc309
-rw-r--r--src/rocksdb/memtable/write_buffer_manager.cc202
-rw-r--r--src/rocksdb/memtable/write_buffer_manager_test.cc305
13 files changed, 5890 insertions, 0 deletions
diff --git a/src/rocksdb/memtable/alloc_tracker.cc b/src/rocksdb/memtable/alloc_tracker.cc
new file mode 100644
index 000000000..4c6d35431
--- /dev/null
+++ b/src/rocksdb/memtable/alloc_tracker.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include <assert.h>
+
+#include "memory/allocator.h"
+#include "memory/arena.h"
+#include "rocksdb/write_buffer_manager.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+AllocTracker::AllocTracker(WriteBufferManager* write_buffer_manager)
+ : write_buffer_manager_(write_buffer_manager),
+ bytes_allocated_(0),
+ done_allocating_(false),
+ freed_(false) {}
+
+AllocTracker::~AllocTracker() { FreeMem(); }
+
+void AllocTracker::Allocate(size_t bytes) {
+ assert(write_buffer_manager_ != nullptr);
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed);
+ write_buffer_manager_->ReserveMem(bytes);
+ }
+}
+
+void AllocTracker::DoneAllocating() {
+ if (write_buffer_manager_ != nullptr && !done_allocating_) {
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ write_buffer_manager_->ScheduleFreeMem(
+ bytes_allocated_.load(std::memory_order_relaxed));
+ } else {
+ assert(bytes_allocated_.load(std::memory_order_relaxed) == 0);
+ }
+ done_allocating_ = true;
+ }
+}
+
+void AllocTracker::FreeMem() {
+ if (!done_allocating_) {
+ DoneAllocating();
+ }
+ if (write_buffer_manager_ != nullptr && !freed_) {
+ if (write_buffer_manager_->enabled() ||
+ write_buffer_manager_->cost_to_cache()) {
+ write_buffer_manager_->FreeMem(
+ bytes_allocated_.load(std::memory_order_relaxed));
+ } else {
+ assert(bytes_allocated_.load(std::memory_order_relaxed) == 0);
+ }
+ freed_ = true;
+ }
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/hash_linklist_rep.cc b/src/rocksdb/memtable/hash_linklist_rep.cc
new file mode 100644
index 000000000..a71768304
--- /dev/null
+++ b/src/rocksdb/memtable/hash_linklist_rep.cc
@@ -0,0 +1,926 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#ifndef ROCKSDB_LITE
+
+#include <algorithm>
+#include <atomic>
+
+#include "db/memtable.h"
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "monitoring/histogram.h"
+#include "port/port.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/hash.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+
+using Key = const char*;
+using MemtableSkipList = SkipList<Key, const MemTableRep::KeyComparator&>;
+using Pointer = std::atomic<void*>;
+
+// A data structure used as the header of a link list of a hash bucket.
+struct BucketHeader {
+ Pointer next;
+ std::atomic<uint32_t> num_entries;
+
+ explicit BucketHeader(void* n, uint32_t count)
+ : next(n), num_entries(count) {}
+
+ bool IsSkipListBucket() {
+ return next.load(std::memory_order_relaxed) == this;
+ }
+
+ uint32_t GetNumEntries() const {
+ return num_entries.load(std::memory_order_relaxed);
+ }
+
+ // REQUIRES: called from single-threaded Insert()
+ void IncNumEntries() {
+ // Only one thread can do write at one time. No need to do atomic
+ // incremental. Update it with relaxed load and store.
+ num_entries.store(GetNumEntries() + 1, std::memory_order_relaxed);
+ }
+};
+
+// A data structure used as the header of a skip list of a hash bucket.
+struct SkipListBucketHeader {
+ BucketHeader Counting_header;
+ MemtableSkipList skip_list;
+
+ explicit SkipListBucketHeader(const MemTableRep::KeyComparator& cmp,
+ Allocator* allocator, uint32_t count)
+ : Counting_header(this, // Pointing to itself to indicate header type.
+ count),
+ skip_list(cmp, allocator) {}
+};
+
+struct Node {
+ // Accessors/mutators for links. Wrapped in methods so we can
+ // add the appropriate barriers as necessary.
+ Node* Next() {
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return next_.load(std::memory_order_acquire);
+ }
+ void SetNext(Node* x) {
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_.store(x, std::memory_order_release);
+ }
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next() { return next_.load(std::memory_order_relaxed); }
+
+ void NoBarrier_SetNext(Node* x) { next_.store(x, std::memory_order_relaxed); }
+
+ // Needed for placement new below which is fine
+ Node() {}
+
+ private:
+ std::atomic<Node*> next_;
+
+ // Prohibit copying due to the below
+ Node(const Node&) = delete;
+ Node& operator=(const Node&) = delete;
+
+ public:
+ char key[1];
+};
+
+// Memory structure of the mem table:
+// It is a hash table, each bucket points to one entry, a linked list or a
+// skip list. In order to track total number of records in a bucket to determine
+// whether should switch to skip list, a header is added just to indicate
+// number of entries in the bucket.
+//
+//
+// +-----> NULL Case 1. Empty bucket
+// |
+// |
+// | +---> +-------+
+// | | | Next +--> NULL
+// | | +-------+
+// +-----+ | | | | Case 2. One Entry in bucket.
+// | +-+ | | Data | next pointer points to
+// +-----+ | | | NULL. All other cases
+// | | | | | next pointer is not NULL.
+// +-----+ | +-------+
+// | +---+
+// +-----+ +-> +-------+ +> +-------+ +-> +-------+
+// | | | | Next +--+ | Next +--+ | Next +-->NULL
+// +-----+ | +-------+ +-------+ +-------+
+// | +-----+ | Count | | | | |
+// +-----+ +-------+ | Data | | Data |
+// | | | | | |
+// +-----+ Case 3. | | | |
+// | | A header +-------+ +-------+
+// +-----+ points to
+// | | a linked list. Count indicates total number
+// +-----+ of rows in this bucket.
+// | |
+// +-----+ +-> +-------+ <--+
+// | | | | Next +----+
+// +-----+ | +-------+ Case 4. A header points to a skip
+// | +----+ | Count | list and next pointer points to
+// +-----+ +-------+ itself, to distinguish case 3 or 4.
+// | | | | Count still is kept to indicates total
+// +-----+ | Skip +--> of entries in the bucket for debugging
+// | | | List | Data purpose.
+// | | | +-->
+// +-----+ | |
+// | | +-------+
+// +-----+
+//
+// We don't have data race when changing cases because:
+// (1) When changing from case 2->3, we create a new bucket header, put the
+// single node there first without changing the original node, and do a
+// release store when changing the bucket pointer. In that case, a reader
+// who sees a stale value of the bucket pointer will read this node, while
+// a reader sees the correct value because of the release store.
+// (2) When changing case 3->4, a new header is created with skip list points
+// to the data, before doing an acquire store to change the bucket pointer.
+// The old header and nodes are never changed, so any reader sees any
+// of those existing pointers will guarantee to be able to iterate to the
+// end of the linked list.
+// (3) Header's next pointer in case 3 might change, but they are never equal
+// to itself, so no matter a reader sees any stale or newer value, it will
+// be able to correctly distinguish case 3 and 4.
+//
+// The reason that we use case 2 is we want to make the format to be efficient
+// when the utilization of buckets is relatively low. If we use case 3 for
+// single entry bucket, we will need to waste 12 bytes for every entry,
+// which can be significant decrease of memory utilization.
+class HashLinkListRep : public MemTableRep {
+ public:
+ HashLinkListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ size_t bucket_size, uint32_t threshold_use_skiplist,
+ size_t huge_page_tlb_size, Logger* logger,
+ int bucket_entries_logging_threshold,
+ bool if_log_bucket_dist_when_flash);
+
+ KeyHandle Allocate(const size_t len, char** buf) override;
+
+ void Insert(KeyHandle handle) override;
+
+ bool Contains(const char* key) const override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~HashLinkListRep() override;
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override;
+
+ MemTableRep::Iterator* GetDynamicPrefixIterator(
+ Arena* arena = nullptr) override;
+
+ private:
+ friend class DynamicIterator;
+
+ size_t bucket_size_;
+
+ // Maps slices (which are transformed user keys) to buckets of keys sharing
+ // the same transform.
+ Pointer* buckets_;
+
+ const uint32_t threshold_use_skiplist_;
+
+ // The user-supplied transform whose domain is the user keys.
+ const SliceTransform* transform_;
+
+ const MemTableRep::KeyComparator& compare_;
+
+ Logger* logger_;
+ int bucket_entries_logging_threshold_;
+ bool if_log_bucket_dist_when_flash_;
+
+ bool LinkListContains(Node* head, const Slice& key) const;
+
+ bool IsEmptyBucket(Pointer& bucket_pointer) const {
+ return bucket_pointer.load(std::memory_order_acquire) == nullptr;
+ }
+
+ // Precondition: GetLinkListFirstNode() must have been called first and return
+ // null so that it must be a skip list bucket
+ SkipListBucketHeader* GetSkipListBucketHeader(Pointer& bucket_pointer) const;
+
+ // Returning nullptr indicates it is a skip list bucket.
+ Node* GetLinkListFirstNode(Pointer& bucket_pointer) const;
+
+ Slice GetPrefix(const Slice& internal_key) const {
+ return transform_->Transform(ExtractUserKey(internal_key));
+ }
+
+ size_t GetHash(const Slice& slice) const {
+ return GetSliceRangedNPHash(slice, bucket_size_);
+ }
+
+ Pointer& GetBucket(size_t i) const { return buckets_[i]; }
+
+ Pointer& GetBucket(const Slice& slice) const {
+ return GetBucket(GetHash(slice));
+ }
+
+ bool Equal(const Slice& a, const Key& b) const {
+ return (compare_(b, a) == 0);
+ }
+
+ bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); }
+
+ bool KeyIsAfterNode(const Slice& internal_key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, internal_key) < 0);
+ }
+
+ bool KeyIsAfterNode(const Key& key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) < 0);
+ }
+
+ bool KeyIsAfterOrAtNode(const Slice& internal_key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, internal_key) <= 0);
+ }
+
+ bool KeyIsAfterOrAtNode(const Key& key, const Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) <= 0);
+ }
+
+ Node* FindGreaterOrEqualInBucket(Node* head, const Slice& key) const;
+ Node* FindLessOrEqualInBucket(Node* head, const Slice& key) const;
+
+ class FullListIterator : public MemTableRep::Iterator {
+ public:
+ explicit FullListIterator(MemtableSkipList* list, Allocator* allocator)
+ : iter_(list), full_list_(list), allocator_(allocator) {}
+
+ ~FullListIterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ iter_.Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.Seek(encoded_key);
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& internal_key,
+ const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.SeekForPrev(encoded_key);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override { iter_.SeekToFirst(); }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override { iter_.SeekToLast(); }
+
+ private:
+ MemtableSkipList::Iterator iter_;
+ // To destruct with the iterator.
+ std::unique_ptr<MemtableSkipList> full_list_;
+ std::unique_ptr<Allocator> allocator_;
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ class LinkListIterator : public MemTableRep::Iterator {
+ public:
+ explicit LinkListIterator(const HashLinkListRep* const hash_link_list_rep,
+ Node* head)
+ : hash_link_list_rep_(hash_link_list_rep),
+ head_(head),
+ node_(nullptr) {}
+
+ ~LinkListIterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return node_ != nullptr; }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return node_->key;
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ node_ = node_->Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key,
+ const char* /*memtable_key*/) override {
+ node_ =
+ hash_link_list_rep_->FindGreaterOrEqualInBucket(head_, internal_key);
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {
+ // Since we do not support Prev()
+ // We simply do not support SeekForPrev
+ Reset(nullptr);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ protected:
+ void Reset(Node* head) {
+ head_ = head;
+ node_ = nullptr;
+ }
+
+ private:
+ friend class HashLinkListRep;
+ const HashLinkListRep* const hash_link_list_rep_;
+ Node* head_;
+ Node* node_;
+
+ virtual void SeekToHead() { node_ = head_; }
+ };
+
+ class DynamicIterator : public HashLinkListRep::LinkListIterator {
+ public:
+ explicit DynamicIterator(HashLinkListRep& memtable_rep)
+ : HashLinkListRep::LinkListIterator(&memtable_rep, nullptr),
+ memtable_rep_(memtable_rep) {}
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& k, const char* memtable_key) override {
+ auto transformed = memtable_rep_.GetPrefix(k);
+ Pointer& bucket = memtable_rep_.GetBucket(transformed);
+
+ if (memtable_rep_.IsEmptyBucket(bucket)) {
+ skip_list_iter_.reset();
+ Reset(nullptr);
+ } else {
+ Node* first_linked_list_node =
+ memtable_rep_.GetLinkListFirstNode(bucket);
+ if (first_linked_list_node != nullptr) {
+ // The bucket is organized as a linked list
+ skip_list_iter_.reset();
+ Reset(first_linked_list_node);
+ HashLinkListRep::LinkListIterator::Seek(k, memtable_key);
+
+ } else {
+ SkipListBucketHeader* skip_list_header =
+ memtable_rep_.GetSkipListBucketHeader(bucket);
+ assert(skip_list_header != nullptr);
+ // The bucket is organized as a skip list
+ if (!skip_list_iter_) {
+ skip_list_iter_.reset(
+ new MemtableSkipList::Iterator(&skip_list_header->skip_list));
+ } else {
+ skip_list_iter_->SetList(&skip_list_header->skip_list);
+ }
+ if (memtable_key != nullptr) {
+ skip_list_iter_->Seek(memtable_key);
+ } else {
+ IterKey encoded_key;
+ encoded_key.EncodeLengthPrefixedKey(k);
+ skip_list_iter_->Seek(encoded_key.GetUserKey().data());
+ }
+ }
+ }
+ }
+
+ bool Valid() const override {
+ if (skip_list_iter_) {
+ return skip_list_iter_->Valid();
+ }
+ return HashLinkListRep::LinkListIterator::Valid();
+ }
+
+ const char* key() const override {
+ if (skip_list_iter_) {
+ return skip_list_iter_->key();
+ }
+ return HashLinkListRep::LinkListIterator::key();
+ }
+
+ void Next() override {
+ if (skip_list_iter_) {
+ skip_list_iter_->Next();
+ } else {
+ HashLinkListRep::LinkListIterator::Next();
+ }
+ }
+
+ private:
+ // the underlying memtable
+ const HashLinkListRep& memtable_rep_;
+ std::unique_ptr<MemtableSkipList::Iterator> skip_list_iter_;
+ };
+
+ class EmptyIterator : public MemTableRep::Iterator {
+ // This is used when there wasn't a bucket. It is cheaper than
+ // instantiating an empty bucket over which to iterate.
+ public:
+ EmptyIterator() {}
+ bool Valid() const override { return false; }
+ const char* key() const override {
+ assert(false);
+ return nullptr;
+ }
+ void Next() override {}
+ void Prev() override {}
+ void Seek(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekForPrev(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekToFirst() override {}
+ void SeekToLast() override {}
+
+ private:
+ };
+};
+
+HashLinkListRep::HashLinkListRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, size_t bucket_size,
+ uint32_t threshold_use_skiplist, size_t huge_page_tlb_size, Logger* logger,
+ int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash)
+ : MemTableRep(allocator),
+ bucket_size_(bucket_size),
+ // Threshold to use skip list doesn't make sense if less than 3, so we
+ // force it to be minimum of 3 to simplify implementation.
+ threshold_use_skiplist_(std::max(threshold_use_skiplist, 3U)),
+ transform_(transform),
+ compare_(compare),
+ logger_(logger),
+ bucket_entries_logging_threshold_(bucket_entries_logging_threshold),
+ if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) {
+ char* mem = allocator_->AllocateAligned(sizeof(Pointer) * bucket_size,
+ huge_page_tlb_size, logger);
+
+ buckets_ = new (mem) Pointer[bucket_size];
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ buckets_[i].store(nullptr, std::memory_order_relaxed);
+ }
+}
+
+HashLinkListRep::~HashLinkListRep() {}
+
+KeyHandle HashLinkListRep::Allocate(const size_t len, char** buf) {
+ char* mem = allocator_->AllocateAligned(sizeof(Node) + len);
+ Node* x = new (mem) Node();
+ *buf = x->key;
+ return static_cast<void*>(x);
+}
+
+SkipListBucketHeader* HashLinkListRep::GetSkipListBucketHeader(
+ Pointer& bucket_pointer) const {
+ Pointer* first_next_pointer =
+ static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire));
+ assert(first_next_pointer != nullptr);
+ assert(first_next_pointer->load(std::memory_order_relaxed) != nullptr);
+
+ // Counting header
+ BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ assert(header->IsSkipListBucket());
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ auto* skip_list_bucket_header =
+ reinterpret_cast<SkipListBucketHeader*>(header);
+ assert(skip_list_bucket_header->Counting_header.next.load(
+ std::memory_order_relaxed) == header);
+ return skip_list_bucket_header;
+}
+
+Node* HashLinkListRep::GetLinkListFirstNode(Pointer& bucket_pointer) const {
+ Pointer* first_next_pointer =
+ static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire));
+ assert(first_next_pointer != nullptr);
+ if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) {
+ // Single entry bucket
+ return reinterpret_cast<Node*>(first_next_pointer);
+ }
+
+ // It is possible that after we fetch first_next_pointer it is modified
+ // and the next is not null anymore. In this case, the bucket should have been
+ // modified to a counting header, so we should reload the first_next_pointer
+ // to make sure we see the update.
+ first_next_pointer =
+ static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire));
+ // Counting header
+ BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ if (!header->IsSkipListBucket()) {
+ assert(header->GetNumEntries() <= threshold_use_skiplist_);
+ return reinterpret_cast<Node*>(
+ header->next.load(std::memory_order_acquire));
+ }
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ return nullptr;
+}
+
+void HashLinkListRep::Insert(KeyHandle handle) {
+ Node* x = static_cast<Node*>(handle);
+ assert(!Contains(x->key));
+ Slice internal_key = GetLengthPrefixedSlice(x->key);
+ auto transformed = GetPrefix(internal_key);
+ auto& bucket = buckets_[GetHash(transformed)];
+ Pointer* first_next_pointer =
+ static_cast<Pointer*>(bucket.load(std::memory_order_relaxed));
+
+ if (first_next_pointer == nullptr) {
+ // Case 1. empty bucket
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(nullptr);
+ bucket.store(x, std::memory_order_release);
+ return;
+ }
+
+ BucketHeader* header = nullptr;
+ if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) {
+ // Case 2. only one entry in the bucket
+ // Need to convert to a Counting bucket and turn to case 4.
+ Node* first = reinterpret_cast<Node*>(first_next_pointer);
+ // Need to add a bucket header.
+ // We have to first convert it to a bucket with header before inserting
+ // the new node. Otherwise, we might need to change next pointer of first.
+ // In that case, a reader might sees the next pointer is NULL and wrongly
+ // think the node is a bucket header.
+ auto* mem = allocator_->AllocateAligned(sizeof(BucketHeader));
+ header = new (mem) BucketHeader(first, 1);
+ bucket.store(header, std::memory_order_release);
+ } else {
+ header = reinterpret_cast<BucketHeader*>(first_next_pointer);
+ if (header->IsSkipListBucket()) {
+ // Case 4. Bucket is already a skip list
+ assert(header->GetNumEntries() > threshold_use_skiplist_);
+ auto* skip_list_bucket_header =
+ reinterpret_cast<SkipListBucketHeader*>(header);
+ // Only one thread can execute Insert() at one time. No need to do atomic
+ // incremental.
+ skip_list_bucket_header->Counting_header.IncNumEntries();
+ skip_list_bucket_header->skip_list.Insert(x->key);
+ return;
+ }
+ }
+
+ if (bucket_entries_logging_threshold_ > 0 &&
+ header->GetNumEntries() ==
+ static_cast<uint32_t>(bucket_entries_logging_threshold_)) {
+ Info(logger_,
+ "HashLinkedList bucket %" ROCKSDB_PRIszt
+ " has more than %d "
+ "entries. Key to insert: %s",
+ GetHash(transformed), header->GetNumEntries(),
+ GetLengthPrefixedSlice(x->key).ToString(true).c_str());
+ }
+
+ if (header->GetNumEntries() == threshold_use_skiplist_) {
+ // Case 3. number of entries reaches the threshold so need to convert to
+ // skip list.
+ LinkListIterator bucket_iter(
+ this, reinterpret_cast<Node*>(
+ first_next_pointer->load(std::memory_order_relaxed)));
+ auto mem = allocator_->AllocateAligned(sizeof(SkipListBucketHeader));
+ SkipListBucketHeader* new_skip_list_header = new (mem)
+ SkipListBucketHeader(compare_, allocator_, header->GetNumEntries() + 1);
+ auto& skip_list = new_skip_list_header->skip_list;
+
+ // Add all current entries to the skip list
+ for (bucket_iter.SeekToHead(); bucket_iter.Valid(); bucket_iter.Next()) {
+ skip_list.Insert(bucket_iter.key());
+ }
+
+ // insert the new entry
+ skip_list.Insert(x->key);
+ // Set the bucket
+ bucket.store(new_skip_list_header, std::memory_order_release);
+ } else {
+ // Case 5. Need to insert to the sorted linked list without changing the
+ // header.
+ Node* first =
+ reinterpret_cast<Node*>(header->next.load(std::memory_order_relaxed));
+ assert(first != nullptr);
+ // Advance counter unless the bucket needs to be advanced to skip list.
+ // In that case, we need to make sure the previous count never exceeds
+ // threshold_use_skiplist_ to avoid readers to cast to wrong format.
+ header->IncNumEntries();
+
+ Node* cur = first;
+ Node* prev = nullptr;
+ while (true) {
+ if (cur == nullptr) {
+ break;
+ }
+ Node* next = cur->Next();
+ // Make sure the lists are sorted.
+ // If x points to head_ or next points nullptr, it is trivially satisfied.
+ assert((cur == first) || (next == nullptr) ||
+ KeyIsAfterNode(next->key, cur));
+ if (KeyIsAfterNode(internal_key, cur)) {
+ // Keep searching in this list
+ prev = cur;
+ cur = next;
+ } else {
+ break;
+ }
+ }
+
+ // Our data structure does not allow duplicate insertion
+ assert(cur == nullptr || !Equal(x->key, cur->key));
+
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(cur);
+
+ if (prev) {
+ prev->SetNext(x);
+ } else {
+ header->next.store(static_cast<void*>(x), std::memory_order_release);
+ }
+ }
+}
+
+bool HashLinkListRep::Contains(const char* key) const {
+ Slice internal_key = GetLengthPrefixedSlice(key);
+
+ auto transformed = GetPrefix(internal_key);
+ Pointer& bucket = GetBucket(transformed);
+ if (IsEmptyBucket(bucket)) {
+ return false;
+ }
+
+ Node* linked_list_node = GetLinkListFirstNode(bucket);
+ if (linked_list_node != nullptr) {
+ return LinkListContains(linked_list_node, internal_key);
+ }
+
+ SkipListBucketHeader* skip_list_header = GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ return skip_list_header->skip_list.Contains(key);
+ }
+ return false;
+}
+
+size_t HashLinkListRep::ApproximateMemoryUsage() {
+ // Memory is always allocated from the allocator.
+ return 0;
+}
+
+void HashLinkListRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ auto transformed = transform_->Transform(k.user_key());
+ Pointer& bucket = GetBucket(transformed);
+
+ if (IsEmptyBucket(bucket)) {
+ return;
+ }
+
+ auto* link_list_head = GetLinkListFirstNode(bucket);
+ if (link_list_head != nullptr) {
+ LinkListIterator iter(this, link_list_head);
+ for (iter.Seek(k.internal_key(), nullptr);
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ } else {
+ auto* skip_list_header = GetSkipListBucketHeader(bucket);
+ if (skip_list_header != nullptr) {
+ // Is a skip list
+ MemtableSkipList::Iterator iter(&skip_list_header->skip_list);
+ for (iter.Seek(k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ }
+ }
+}
+
+MemTableRep::Iterator* HashLinkListRep::GetIterator(Arena* alloc_arena) {
+ // allocate a new arena of similar size to the one currently in use
+ Arena* new_arena = new Arena(allocator_->BlockSize());
+ auto list = new MemtableSkipList(compare_, new_arena);
+ HistogramImpl keys_per_bucket_hist;
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ int count = 0;
+ Pointer& bucket = GetBucket(i);
+ if (!IsEmptyBucket(bucket)) {
+ auto* link_list_head = GetLinkListFirstNode(bucket);
+ if (link_list_head != nullptr) {
+ LinkListIterator itr(this, link_list_head);
+ for (itr.SeekToHead(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ count++;
+ }
+ } else {
+ auto* skip_list_header = GetSkipListBucketHeader(bucket);
+ assert(skip_list_header != nullptr);
+ // Is a skip list
+ MemtableSkipList::Iterator itr(&skip_list_header->skip_list);
+ for (itr.SeekToFirst(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ count++;
+ }
+ }
+ }
+ if (if_log_bucket_dist_when_flash_) {
+ keys_per_bucket_hist.Add(count);
+ }
+ }
+ if (if_log_bucket_dist_when_flash_ && logger_ != nullptr) {
+ Info(logger_, "hashLinkedList Entry distribution among buckets: %s",
+ keys_per_bucket_hist.ToString().c_str());
+ }
+
+ if (alloc_arena == nullptr) {
+ return new FullListIterator(list, new_arena);
+ } else {
+ auto mem = alloc_arena->AllocateAligned(sizeof(FullListIterator));
+ return new (mem) FullListIterator(list, new_arena);
+ }
+}
+
+MemTableRep::Iterator* HashLinkListRep::GetDynamicPrefixIterator(
+ Arena* alloc_arena) {
+ if (alloc_arena == nullptr) {
+ return new DynamicIterator(*this);
+ } else {
+ auto mem = alloc_arena->AllocateAligned(sizeof(DynamicIterator));
+ return new (mem) DynamicIterator(*this);
+ }
+}
+
+bool HashLinkListRep::LinkListContains(Node* head,
+ const Slice& user_key) const {
+ Node* x = FindGreaterOrEqualInBucket(head, user_key);
+ return (x != nullptr && Equal(user_key, x->key));
+}
+
+Node* HashLinkListRep::FindGreaterOrEqualInBucket(Node* head,
+ const Slice& key) const {
+ Node* x = head;
+ while (true) {
+ if (x == nullptr) {
+ return x;
+ }
+ Node* next = x->Next();
+ // Make sure the lists are sorted.
+ // If x points to head_ or next points nullptr, it is trivially satisfied.
+ assert((x == head) || (next == nullptr) || KeyIsAfterNode(next->key, x));
+ if (KeyIsAfterNode(key, x)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ break;
+ }
+ }
+ return x;
+}
+
+struct HashLinkListRepOptions {
+ static const char* kName() { return "HashLinkListRepFactoryOptions"; }
+ size_t bucket_count;
+ uint32_t threshold_use_skiplist;
+ size_t huge_page_tlb_size;
+ int bucket_entries_logging_threshold;
+ bool if_log_bucket_dist_when_flash;
+};
+
+static std::unordered_map<std::string, OptionTypeInfo> hash_linklist_info = {
+ {"bucket_count",
+ {offsetof(struct HashLinkListRepOptions, bucket_count), OptionType::kSizeT,
+ OptionVerificationType::kNormal, OptionTypeFlags::kNone}},
+ {"threshold",
+ {offsetof(struct HashLinkListRepOptions, threshold_use_skiplist),
+ OptionType::kUInt32T, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"huge_page_size",
+ {offsetof(struct HashLinkListRepOptions, huge_page_tlb_size),
+ OptionType::kSizeT, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"logging_threshold",
+ {offsetof(struct HashLinkListRepOptions, bucket_entries_logging_threshold),
+ OptionType::kInt, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"log_when_flash",
+ {offsetof(struct HashLinkListRepOptions, if_log_bucket_dist_when_flash),
+ OptionType::kBoolean, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+};
+
+class HashLinkListRepFactory : public MemTableRepFactory {
+ public:
+ explicit HashLinkListRepFactory(size_t bucket_count,
+ uint32_t threshold_use_skiplist,
+ size_t huge_page_tlb_size,
+ int bucket_entries_logging_threshold,
+ bool if_log_bucket_dist_when_flash) {
+ options_.bucket_count = bucket_count;
+ options_.threshold_use_skiplist = threshold_use_skiplist;
+ options_.huge_page_tlb_size = huge_page_tlb_size;
+ options_.bucket_entries_logging_threshold =
+ bucket_entries_logging_threshold;
+ options_.if_log_bucket_dist_when_flash = if_log_bucket_dist_when_flash;
+ RegisterOptions(&options_, &hash_linklist_info);
+ }
+
+ using MemTableRepFactory::CreateMemTableRep;
+ virtual MemTableRep* CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) override;
+
+ static const char* kClassName() { return "HashLinkListRepFactory"; }
+ static const char* kNickName() { return "hash_linkedlist"; }
+ virtual const char* Name() const override { return kClassName(); }
+ virtual const char* NickName() const override { return kNickName(); }
+
+ private:
+ HashLinkListRepOptions options_;
+};
+
+} // namespace
+
+MemTableRep* HashLinkListRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) {
+ return new HashLinkListRep(
+ compare, allocator, transform, options_.bucket_count,
+ options_.threshold_use_skiplist, options_.huge_page_tlb_size, logger,
+ options_.bucket_entries_logging_threshold,
+ options_.if_log_bucket_dist_when_flash);
+}
+
+MemTableRepFactory* NewHashLinkListRepFactory(
+ size_t bucket_count, size_t huge_page_tlb_size,
+ int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash,
+ uint32_t threshold_use_skiplist) {
+ return new HashLinkListRepFactory(
+ bucket_count, threshold_use_skiplist, huge_page_tlb_size,
+ bucket_entries_logging_threshold, if_log_bucket_dist_when_flash);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/hash_skiplist_rep.cc b/src/rocksdb/memtable/hash_skiplist_rep.cc
new file mode 100644
index 000000000..9d093829b
--- /dev/null
+++ b/src/rocksdb/memtable/hash_skiplist_rep.cc
@@ -0,0 +1,393 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+
+#ifndef ROCKSDB_LITE
+#include <atomic>
+
+#include "db/memtable.h"
+#include "memory/arena.h"
+#include "memtable/skiplist.h"
+#include "port/port.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/murmurhash.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+
+class HashSkipListRep : public MemTableRep {
+ public:
+ HashSkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ size_t bucket_size, int32_t skiplist_height,
+ int32_t skiplist_branching_factor);
+
+ void Insert(KeyHandle handle) override;
+
+ bool Contains(const char* key) const override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~HashSkipListRep() override;
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override;
+
+ MemTableRep::Iterator* GetDynamicPrefixIterator(
+ Arena* arena = nullptr) override;
+
+ private:
+ friend class DynamicIterator;
+ using Bucket = SkipList<const char*, const MemTableRep::KeyComparator&>;
+
+ size_t bucket_size_;
+
+ const int32_t skiplist_height_;
+ const int32_t skiplist_branching_factor_;
+
+ // Maps slices (which are transformed user keys) to buckets of keys sharing
+ // the same transform.
+ std::atomic<Bucket*>* buckets_;
+
+ // The user-supplied transform whose domain is the user keys.
+ const SliceTransform* transform_;
+
+ const MemTableRep::KeyComparator& compare_;
+ // immutable after construction
+ Allocator* const allocator_;
+
+ inline size_t GetHash(const Slice& slice) const {
+ return MurmurHash(slice.data(), static_cast<int>(slice.size()), 0) %
+ bucket_size_;
+ }
+ inline Bucket* GetBucket(size_t i) const {
+ return buckets_[i].load(std::memory_order_acquire);
+ }
+ inline Bucket* GetBucket(const Slice& slice) const {
+ return GetBucket(GetHash(slice));
+ }
+ // Get a bucket from buckets_. If the bucket hasn't been initialized yet,
+ // initialize it before returning.
+ Bucket* GetInitializedBucket(const Slice& transformed);
+
+ class Iterator : public MemTableRep::Iterator {
+ public:
+ explicit Iterator(Bucket* list, bool own_list = true,
+ Arena* arena = nullptr)
+ : list_(list), iter_(list), own_list_(own_list), arena_(arena) {}
+
+ ~Iterator() override {
+ // if we own the list, we should also delete it
+ if (own_list_) {
+ assert(list_ != nullptr);
+ delete list_;
+ }
+ }
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return list_ != nullptr && iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override {
+ assert(Valid());
+ iter_.Next();
+ }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ if (list_ != nullptr) {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.Seek(encoded_key);
+ }
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {
+ // not supported
+ assert(false);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ if (list_ != nullptr) {
+ iter_.SeekToFirst();
+ }
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ if (list_ != nullptr) {
+ iter_.SeekToLast();
+ }
+ }
+
+ protected:
+ void Reset(Bucket* list) {
+ if (own_list_) {
+ assert(list_ != nullptr);
+ delete list_;
+ }
+ list_ = list;
+ iter_.SetList(list);
+ own_list_ = false;
+ }
+
+ private:
+ // if list_ is nullptr, we should NEVER call any methods on iter_
+ // if list_ is nullptr, this Iterator is not Valid()
+ Bucket* list_;
+ Bucket::Iterator iter_;
+ // here we track if we own list_. If we own it, we are also
+ // responsible for it's cleaning. This is a poor man's std::shared_ptr
+ bool own_list_;
+ std::unique_ptr<Arena> arena_;
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ class DynamicIterator : public HashSkipListRep::Iterator {
+ public:
+ explicit DynamicIterator(const HashSkipListRep& memtable_rep)
+ : HashSkipListRep::Iterator(nullptr, false),
+ memtable_rep_(memtable_rep) {}
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& k, const char* memtable_key) override {
+ auto transformed = memtable_rep_.transform_->Transform(ExtractUserKey(k));
+ Reset(memtable_rep_.GetBucket(transformed));
+ HashSkipListRep::Iterator::Seek(k, memtable_key);
+ }
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override {
+ // Prefix iterator does not support total order.
+ // We simply set the iterator to invalid state
+ Reset(nullptr);
+ }
+
+ private:
+ // the underlying memtable
+ const HashSkipListRep& memtable_rep_;
+ };
+
+ class EmptyIterator : public MemTableRep::Iterator {
+ // This is used when there wasn't a bucket. It is cheaper than
+ // instantiating an empty bucket over which to iterate.
+ public:
+ EmptyIterator() {}
+ bool Valid() const override { return false; }
+ const char* key() const override {
+ assert(false);
+ return nullptr;
+ }
+ void Next() override {}
+ void Prev() override {}
+ void Seek(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekForPrev(const Slice& /*internal_key*/,
+ const char* /*memtable_key*/) override {}
+ void SeekToFirst() override {}
+ void SeekToLast() override {}
+
+ private:
+ };
+};
+
+HashSkipListRep::HashSkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator,
+ const SliceTransform* transform,
+ size_t bucket_size, int32_t skiplist_height,
+ int32_t skiplist_branching_factor)
+ : MemTableRep(allocator),
+ bucket_size_(bucket_size),
+ skiplist_height_(skiplist_height),
+ skiplist_branching_factor_(skiplist_branching_factor),
+ transform_(transform),
+ compare_(compare),
+ allocator_(allocator) {
+ auto mem =
+ allocator->AllocateAligned(sizeof(std::atomic<void*>) * bucket_size);
+ buckets_ = new (mem) std::atomic<Bucket*>[bucket_size];
+
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ buckets_[i].store(nullptr, std::memory_order_relaxed);
+ }
+}
+
+HashSkipListRep::~HashSkipListRep() {}
+
+HashSkipListRep::Bucket* HashSkipListRep::GetInitializedBucket(
+ const Slice& transformed) {
+ size_t hash = GetHash(transformed);
+ auto bucket = GetBucket(hash);
+ if (bucket == nullptr) {
+ auto addr = allocator_->AllocateAligned(sizeof(Bucket));
+ bucket = new (addr) Bucket(compare_, allocator_, skiplist_height_,
+ skiplist_branching_factor_);
+ buckets_[hash].store(bucket, std::memory_order_release);
+ }
+ return bucket;
+}
+
+void HashSkipListRep::Insert(KeyHandle handle) {
+ auto* key = static_cast<char*>(handle);
+ assert(!Contains(key));
+ auto transformed = transform_->Transform(UserKey(key));
+ auto bucket = GetInitializedBucket(transformed);
+ bucket->Insert(key);
+}
+
+bool HashSkipListRep::Contains(const char* key) const {
+ auto transformed = transform_->Transform(UserKey(key));
+ auto bucket = GetBucket(transformed);
+ if (bucket == nullptr) {
+ return false;
+ }
+ return bucket->Contains(key);
+}
+
+size_t HashSkipListRep::ApproximateMemoryUsage() { return 0; }
+
+void HashSkipListRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ auto transformed = transform_->Transform(k.user_key());
+ auto bucket = GetBucket(transformed);
+ if (bucket != nullptr) {
+ Bucket::Iterator iter(bucket);
+ for (iter.Seek(k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ }
+}
+
+MemTableRep::Iterator* HashSkipListRep::GetIterator(Arena* arena) {
+ // allocate a new arena of similar size to the one currently in use
+ Arena* new_arena = new Arena(allocator_->BlockSize());
+ auto list = new Bucket(compare_, new_arena);
+ for (size_t i = 0; i < bucket_size_; ++i) {
+ auto bucket = GetBucket(i);
+ if (bucket != nullptr) {
+ Bucket::Iterator itr(bucket);
+ for (itr.SeekToFirst(); itr.Valid(); itr.Next()) {
+ list->Insert(itr.key());
+ }
+ }
+ }
+ if (arena == nullptr) {
+ return new Iterator(list, true, new_arena);
+ } else {
+ auto mem = arena->AllocateAligned(sizeof(Iterator));
+ return new (mem) Iterator(list, true, new_arena);
+ }
+}
+
+MemTableRep::Iterator* HashSkipListRep::GetDynamicPrefixIterator(Arena* arena) {
+ if (arena == nullptr) {
+ return new DynamicIterator(*this);
+ } else {
+ auto mem = arena->AllocateAligned(sizeof(DynamicIterator));
+ return new (mem) DynamicIterator(*this);
+ }
+}
+
+struct HashSkipListRepOptions {
+ static const char* kName() { return "HashSkipListRepFactoryOptions"; }
+ size_t bucket_count;
+ int32_t skiplist_height;
+ int32_t skiplist_branching_factor;
+};
+
+static std::unordered_map<std::string, OptionTypeInfo> hash_skiplist_info = {
+ {"bucket_count",
+ {offsetof(struct HashSkipListRepOptions, bucket_count), OptionType::kSizeT,
+ OptionVerificationType::kNormal, OptionTypeFlags::kNone}},
+ {"skiplist_height",
+ {offsetof(struct HashSkipListRepOptions, skiplist_height),
+ OptionType::kInt32T, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+ {"branching_factor",
+ {offsetof(struct HashSkipListRepOptions, skiplist_branching_factor),
+ OptionType::kInt32T, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+};
+
+class HashSkipListRepFactory : public MemTableRepFactory {
+ public:
+ explicit HashSkipListRepFactory(size_t bucket_count, int32_t skiplist_height,
+ int32_t skiplist_branching_factor) {
+ options_.bucket_count = bucket_count;
+ options_.skiplist_height = skiplist_height;
+ options_.skiplist_branching_factor = skiplist_branching_factor;
+ RegisterOptions(&options_, &hash_skiplist_info);
+ }
+
+ using MemTableRepFactory::CreateMemTableRep;
+ virtual MemTableRep* CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* logger) override;
+
+ static const char* kClassName() { return "HashSkipListRepFactory"; }
+ static const char* kNickName() { return "prefix_hash"; }
+
+ virtual const char* Name() const override { return kClassName(); }
+ virtual const char* NickName() const override { return kNickName(); }
+
+ private:
+ HashSkipListRepOptions options_;
+};
+
+} // namespace
+
+MemTableRep* HashSkipListRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* /*logger*/) {
+ return new HashSkipListRep(compare, allocator, transform,
+ options_.bucket_count, options_.skiplist_height,
+ options_.skiplist_branching_factor);
+}
+
+MemTableRepFactory* NewHashSkipListRepFactory(
+ size_t bucket_count, int32_t skiplist_height,
+ int32_t skiplist_branching_factor) {
+ return new HashSkipListRepFactory(bucket_count, skiplist_height,
+ skiplist_branching_factor);
+}
+
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/inlineskiplist.h b/src/rocksdb/memtable/inlineskiplist.h
new file mode 100644
index 000000000..abb3c3ddb
--- /dev/null
+++ b/src/rocksdb/memtable/inlineskiplist.h
@@ -0,0 +1,1051 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found
+// in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// InlineSkipList is derived from SkipList (skiplist.h), but it optimizes
+// the memory layout by requiring that the key storage be allocated through
+// the skip list instance. For the common case of SkipList<const char*,
+// Cmp> this saves 1 pointer per skip list node and gives better cache
+// locality, at the expense of wasted padding from using AllocateAligned
+// instead of Allocate for the keys. The unused padding will be from
+// 0 to sizeof(void*)-1 bytes, and the space savings are sizeof(void*)
+// bytes, so despite the padding the space used is always less than
+// SkipList<const char*, ..>.
+//
+// Thread safety -------------
+//
+// Writes via Insert require external synchronization, most likely a mutex.
+// InsertConcurrently can be safely called concurrently with reads and
+// with other concurrent inserts. Reads require a guarantee that the
+// InlineSkipList will not be destroyed while the read is in progress.
+// Apart from that, reads progress without any internal locking or
+// synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the InlineSkipList is
+// destroyed. This is trivially guaranteed by the code since we never
+// delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the InlineSkipList.
+// Only Insert() modifies the list, and it is careful to initialize a
+// node and use release-stores to publish the nodes in one or more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <atomic>
+#include <type_traits>
+
+#include "memory/allocator.h"
+#include "port/likely.h"
+#include "port/port.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+template <class Comparator>
+class InlineSkipList {
+ private:
+ struct Node;
+ struct Splice;
+
+ public:
+ using DecodedKey =
+ typename std::remove_reference<Comparator>::type::DecodedType;
+
+ static const uint16_t kMaxPossibleHeight = 32;
+
+ // Create a new InlineSkipList object that will use "cmp" for comparing
+ // keys, and will allocate memory using "*allocator". Objects allocated
+ // in the allocator must remain allocated for the lifetime of the
+ // skiplist object.
+ explicit InlineSkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12,
+ int32_t branching_factor = 4);
+ // No copying allowed
+ InlineSkipList(const InlineSkipList&) = delete;
+ InlineSkipList& operator=(const InlineSkipList&) = delete;
+
+ // Allocates a key and a skip-list node, returning a pointer to the key
+ // portion of the node. This method is thread-safe if the allocator
+ // is thread-safe.
+ char* AllocateKey(size_t key_size);
+
+ // Allocate a splice using allocator.
+ Splice* AllocateSplice();
+
+ // Allocate a splice on heap.
+ Splice* AllocateSpliceOnHeap();
+
+ // Inserts a key allocated by AllocateKey, after the actual key value
+ // has been filled in.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ bool Insert(const char* key);
+
+ // Inserts a key allocated by AllocateKey with a hint of last insert
+ // position in the skip-list. If hint points to nullptr, a new hint will be
+ // populated, which can be used in subsequent calls.
+ //
+ // It can be used to optimize the workload where there are multiple groups
+ // of keys, and each key is likely to insert to a location close to the last
+ // inserted key in the same group. One example is sequential inserts.
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls to any of inserts.
+ bool InsertWithHint(const char* key, void** hint);
+
+ // Like InsertConcurrently, but with a hint
+ //
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ // REQUIRES: no concurrent calls that use same hint
+ bool InsertWithHintConcurrently(const char* key, void** hint);
+
+ // Like Insert, but external synchronization is not required.
+ bool InsertConcurrently(const char* key);
+
+ // Inserts a node into the skip list. key must have been allocated by
+ // AllocateKey and then filled in by the caller. If UseCAS is true,
+ // then external synchronization is not required, otherwise this method
+ // may not be called concurrently with any other insertions.
+ //
+ // Regardless of whether UseCAS is true, the splice must be owned
+ // exclusively by the current thread. If allow_partial_splice_fix is
+ // true, then the cost of insertion is amortized O(log D), where D is
+ // the distance from the splice to the inserted key (measured as the
+ // number of intervening nodes). Note that this bound is very good for
+ // sequential insertions! If allow_partial_splice_fix is false then
+ // the existing splice will be ignored unless the current key is being
+ // inserted immediately after the splice. allow_partial_splice_fix ==
+ // false has worse running time for the non-sequential case O(log N),
+ // but a better constant factor.
+ template <bool UseCAS>
+ bool Insert(const char* key, Splice* splice, bool allow_partial_splice_fix);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const char* key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const char* key) const;
+
+ // Validate correctness of the skip-list.
+ void TEST_Validate() const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const InlineSkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const InlineSkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const char* target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const char* target);
+
+ // Advance to a random entry in the list.
+ void RandomSeek();
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const InlineSkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+ // Immutable after construction
+ Comparator const compare_;
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // seq_splice_ is a Splice used for insertions in the non-concurrent
+ // case. It caches the prev and next found during the most recent
+ // non-concurrent insertion.
+ Splice* seq_splice_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ int RandomHeight();
+
+ Node* AllocateNode(size_t key_size, int height);
+
+ bool Equal(const char* a, const char* b) const {
+ return (compare_(a, b) == 0);
+ }
+
+ bool LessThan(const char* a, const char* b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n". Null n
+ // is considered infinite. n should not be head_.
+ bool KeyIsAfterNode(const char* key, Node* n) const;
+ bool KeyIsAfterNode(const DecodedKey& key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const char* key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev = nullptr) const;
+
+ // Return the latest node with a key < key on bottom_level. Start searching
+ // from root node on the level below top_level.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [bottom_level..top_level-1], if prev is non-null.
+ Node* FindLessThan(const char* key, Node** prev, Node* root, int top_level,
+ int bottom_level) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+
+ // Returns a random entry.
+ Node* FindRandomEntry() const;
+
+ // Traverses a single level of the list, setting *out_prev to the last
+ // node before the key and *out_next to the first node after. Assumes
+ // that the key is not present in the skip list. On entry, before should
+ // point to a node that is before the key, and after should point to
+ // a node that is after the key. after should be nullptr if a good after
+ // node isn't conveniently available.
+ template <bool prefetch_before>
+ void FindSpliceForLevel(const DecodedKey& key, Node* before, Node* after,
+ int level, Node** out_prev, Node** out_next);
+
+ // Recomputes Splice levels from highest_level (inclusive) down to
+ // lowest_level (inclusive).
+ void RecomputeSpliceLevels(const DecodedKey& key, Splice* splice,
+ int recompute_level);
+};
+
+// Implementation details follow
+
+template <class Comparator>
+struct InlineSkipList<Comparator>::Splice {
+ // The invariant of a Splice is that prev_[i+1].key <= prev_[i].key <
+ // next_[i].key <= next_[i+1].key for all i. That means that if a
+ // key is bracketed by prev_[i] and next_[i] then it is bracketed by
+ // all higher levels. It is _not_ required that prev_[i]->Next(i) ==
+ // next_[i] (it probably did at some point in the past, but intervening
+ // or concurrent operations might have inserted nodes in between).
+ int height_ = 0;
+ Node** prev_;
+ Node** next_;
+};
+
+// The Node data type is more of a pointer into custom-managed memory than
+// a traditional C++ struct. The key is stored in the bytes immediately
+// after the struct, and the next_ pointers for nodes with height > 1 are
+// stored immediately _before_ the struct. This avoids the need to include
+// any pointer or sizing data, which reduces per-node memory overheads.
+template <class Comparator>
+struct InlineSkipList<Comparator>::Node {
+ // Stores the height of the node in the memory location normally used for
+ // next_[0]. This is used for passing data from AllocateKey to Insert.
+ void StashHeight(const int height) {
+ assert(sizeof(int) <= sizeof(next_[0]));
+ memcpy(static_cast<void*>(&next_[0]), &height, sizeof(int));
+ }
+
+ // Retrieves the value passed to StashHeight. Undefined after a call
+ // to SetNext or NoBarrier_SetNext.
+ int UnstashHeight() const {
+ int rv;
+ memcpy(&rv, &next_[0], sizeof(int));
+ return rv;
+ }
+
+ const char* Key() const { return reinterpret_cast<const char*>(&next_[1]); }
+
+ // Accessors/mutators for links. Wrapped in methods so we can add
+ // the appropriate barriers as necessary, and perform the necessary
+ // addressing trickery for storing links below the Node in memory.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return ((&next_[0] - n)->load(std::memory_order_acquire));
+ }
+
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ (&next_[0] - n)->store(x, std::memory_order_release);
+ }
+
+ bool CASNext(int n, Node* expected, Node* x) {
+ assert(n >= 0);
+ return (&next_[0] - n)->compare_exchange_strong(expected, x);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return (&next_[0] - n)->load(std::memory_order_relaxed);
+ }
+
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ (&next_[0] - n)->store(x, std::memory_order_relaxed);
+ }
+
+ // Insert node after prev on specific level.
+ void InsertAfter(Node* prev, int level) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "this" in prev.
+ NoBarrier_SetNext(level, prev->NoBarrier_Next(level));
+ prev->SetNext(level, this);
+ }
+
+ private:
+ // next_[0] is the lowest level link (level 0). Higher levels are
+ // stored _earlier_, so level 1 is at next_[-1].
+ std::atomic<Node*> next_[1];
+};
+
+template <class Comparator>
+inline InlineSkipList<Comparator>::Iterator::Iterator(
+ const InlineSkipList* list) {
+ SetList(list);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SetList(
+ const InlineSkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template <class Comparator>
+inline bool InlineSkipList<Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template <class Comparator>
+inline const char* InlineSkipList<Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->Key();
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->Key());
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::Seek(const char* target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekForPrev(
+ const char* target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::RandomSeek() {
+ node_ = list_->FindRandomEntry();
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template <class Comparator>
+inline void InlineSkipList<Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <class Comparator>
+int InlineSkipList<Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && height < kMaxPossibleHeight &&
+ rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ assert(height <= kMaxPossibleHeight);
+ return height;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::KeyIsAfterNode(const char* key,
+ Node* n) const {
+ // nullptr n is considered infinite
+ assert(n != head_);
+ return (n != nullptr) && (compare_(n->Key(), key) < 0);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::KeyIsAfterNode(const DecodedKey& key,
+ Node* n) const {
+ // nullptr n is considered infinite
+ assert(n != head_);
+ return (n != nullptr) && (compare_(n->Key(), key) < 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindGreaterOrEqual(const char* key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key_decoded, x));
+ int cmp = (next == nullptr || next == last_bigger)
+ ? 1
+ : compare_(next->Key(), key_decoded);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev) const {
+ return FindLessThan(key, prev, head_, GetMaxHeight(), 0);
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev,
+ Node* root, int top_level,
+ int bottom_level) const {
+ assert(top_level > bottom_level);
+ int level = top_level - 1;
+ Node* x = root;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x));
+ assert(x == head_ || KeyIsAfterNode(key_decoded, x));
+ if (next != last_not_after && KeyIsAfterNode(key_decoded, next)) {
+ // Keep searching in this list
+ assert(next != nullptr);
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == bottom_level) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindLast() const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::FindRandomEntry() const {
+ // TODO(bjlemaire): consider adding PREFETCH calls.
+ Node *x = head_, *scan_node = nullptr, *limit_node = nullptr;
+
+ // We start at the max level.
+ // FOr each level, we look at all the nodes at the level, and
+ // we randomly pick one of them. Then decrement the level
+ // and reiterate the process.
+ // eg: assume GetMaxHeight()=5, and there are #100 elements (nodes).
+ // level 4 nodes: lvl_nodes={#1, #15, #67, #84}. Randomly pick #15.
+ // We will consider all the nodes between #15 (inclusive) and #67
+ // (exclusive). #67 is called 'limit_node' here.
+ // level 3 nodes: lvl_nodes={#15, #21, #45, #51}. Randomly choose
+ // #51. #67 remains 'limit_node'.
+ // [...]
+ // level 0 nodes: lvl_nodes={#56,#57,#58,#59}. Randomly pick $57.
+ // Return Node #57.
+ std::vector<Node*> lvl_nodes;
+ Random* rnd = Random::GetTLSInstance();
+ int level = GetMaxHeight() - 1;
+
+ while (level >= 0) {
+ lvl_nodes.clear();
+ scan_node = x;
+ while (scan_node != limit_node) {
+ lvl_nodes.push_back(scan_node);
+ scan_node = scan_node->Next(level);
+ }
+ uint32_t rnd_idx = rnd->Next() % lvl_nodes.size();
+ x = lvl_nodes[rnd_idx];
+ if (rnd_idx + 1 < lvl_nodes.size()) {
+ limit_node = lvl_nodes[rnd_idx + 1];
+ }
+ level--;
+ }
+ // There is a special case where x could still be the head_
+ // (note that the head_ contains no key).
+ return x == head_ && head_ != nullptr ? head_->Next(0) : x;
+}
+
+template <class Comparator>
+uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ while (true) {
+ assert(x == head_ || compare_(x->Key(), key_decoded) < 0);
+ Node* next = x->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ if (next == nullptr || compare_(next->Key(), key_decoded) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <class Comparator>
+InlineSkipList<Comparator>::InlineSkipList(const Comparator cmp,
+ Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(static_cast<uint16_t>(max_height)),
+ kBranching_(static_cast<uint16_t>(branching_factor)),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ allocator_(allocator),
+ compare_(cmp),
+ head_(AllocateNode(0, max_height)),
+ max_height_(1),
+ seq_splice_(AllocateSplice()) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 1 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+
+ for (int i = 0; i < kMaxHeight_; ++i) {
+ head_->SetNext(i, nullptr);
+ }
+}
+
+template <class Comparator>
+char* InlineSkipList<Comparator>::AllocateKey(size_t key_size) {
+ return const_cast<char*>(AllocateNode(key_size, RandomHeight())->Key());
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Node*
+InlineSkipList<Comparator>::AllocateNode(size_t key_size, int height) {
+ auto prefix = sizeof(std::atomic<Node*>) * (height - 1);
+
+ // prefix is space for the height - 1 pointers that we store before
+ // the Node instance (next_[-(height - 1) .. -1]). Node starts at
+ // raw + prefix, and holds the bottom-mode (level 0) skip list pointer
+ // next_[0]. key_size is the bytes for the key, which comes just after
+ // the Node.
+ char* raw = allocator_->AllocateAligned(prefix + sizeof(Node) + key_size);
+ Node* x = reinterpret_cast<Node*>(raw + prefix);
+
+ // Once we've linked the node into the skip list we don't actually need
+ // to know its height, because we can implicitly use the fact that we
+ // traversed into a node at level h to known that h is a valid level
+ // for that node. We need to convey the height to the Insert step,
+ // however, so that it can perform the proper links. Since we're not
+ // using the pointers at the moment, StashHeight temporarily borrow
+ // storage from next_[0] for that purpose.
+ x->StashHeight(height);
+ return x;
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Splice*
+InlineSkipList<Comparator>::AllocateSplice() {
+ // size of prev_ and next_
+ size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1);
+ char* raw = allocator_->AllocateAligned(sizeof(Splice) + array_size * 2);
+ Splice* splice = reinterpret_cast<Splice*>(raw);
+ splice->height_ = 0;
+ splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice));
+ splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size);
+ return splice;
+}
+
+template <class Comparator>
+typename InlineSkipList<Comparator>::Splice*
+InlineSkipList<Comparator>::AllocateSpliceOnHeap() {
+ size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1);
+ char* raw = new char[sizeof(Splice) + array_size * 2];
+ Splice* splice = reinterpret_cast<Splice*>(raw);
+ splice->height_ = 0;
+ splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice));
+ splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size);
+ return splice;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::Insert(const char* key) {
+ return Insert<false>(key, seq_splice_, false);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::InsertConcurrently(const char* key) {
+ Node* prev[kMaxPossibleHeight];
+ Node* next[kMaxPossibleHeight];
+ Splice splice;
+ splice.prev_ = prev;
+ splice.next_ = next;
+ return Insert<true>(key, &splice, false);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::InsertWithHint(const char* key, void** hint) {
+ assert(hint != nullptr);
+ Splice* splice = reinterpret_cast<Splice*>(*hint);
+ if (splice == nullptr) {
+ splice = AllocateSplice();
+ *hint = reinterpret_cast<void*>(splice);
+ }
+ return Insert<false>(key, splice, true);
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::InsertWithHintConcurrently(const char* key,
+ void** hint) {
+ assert(hint != nullptr);
+ Splice* splice = reinterpret_cast<Splice*>(*hint);
+ if (splice == nullptr) {
+ splice = AllocateSpliceOnHeap();
+ *hint = reinterpret_cast<void*>(splice);
+ }
+ return Insert<true>(key, splice, true);
+}
+
+template <class Comparator>
+template <bool prefetch_before>
+void InlineSkipList<Comparator>::FindSpliceForLevel(const DecodedKey& key,
+ Node* before, Node* after,
+ int level, Node** out_prev,
+ Node** out_next) {
+ while (true) {
+ Node* next = before->Next(level);
+ if (next != nullptr) {
+ PREFETCH(next->Next(level), 0, 1);
+ }
+ if (prefetch_before == true) {
+ if (next != nullptr && level > 0) {
+ PREFETCH(next->Next(level - 1), 0, 1);
+ }
+ }
+ assert(before == head_ || next == nullptr ||
+ KeyIsAfterNode(next->Key(), before));
+ assert(before == head_ || KeyIsAfterNode(key, before));
+ if (next == after || !KeyIsAfterNode(key, next)) {
+ // found it
+ *out_prev = before;
+ *out_next = next;
+ return;
+ }
+ before = next;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::RecomputeSpliceLevels(const DecodedKey& key,
+ Splice* splice,
+ int recompute_level) {
+ assert(recompute_level > 0);
+ assert(recompute_level <= splice->height_);
+ for (int i = recompute_level - 1; i >= 0; --i) {
+ FindSpliceForLevel<true>(key, splice->prev_[i + 1], splice->next_[i + 1], i,
+ &splice->prev_[i], &splice->next_[i]);
+ }
+}
+
+template <class Comparator>
+template <bool UseCAS>
+bool InlineSkipList<Comparator>::Insert(const char* key, Splice* splice,
+ bool allow_partial_splice_fix) {
+ Node* x = reinterpret_cast<Node*>(const_cast<char*>(key)) - 1;
+ const DecodedKey key_decoded = compare_.decode_key(key);
+ int height = x->UnstashHeight();
+ assert(height >= 1 && height <= kMaxHeight_);
+
+ int max_height = max_height_.load(std::memory_order_relaxed);
+ while (height > max_height) {
+ if (max_height_.compare_exchange_weak(max_height, height)) {
+ // successfully updated it
+ max_height = height;
+ break;
+ }
+ // else retry, possibly exiting the loop because somebody else
+ // increased it
+ }
+ assert(max_height <= kMaxPossibleHeight);
+
+ int recompute_height = 0;
+ if (splice->height_ < max_height) {
+ // Either splice has never been used or max_height has grown since
+ // last use. We could potentially fix it in the latter case, but
+ // that is tricky.
+ splice->prev_[max_height] = head_;
+ splice->next_[max_height] = nullptr;
+ splice->height_ = max_height;
+ recompute_height = max_height;
+ } else {
+ // Splice is a valid proper-height splice that brackets some
+ // key, but does it bracket this one? We need to validate it and
+ // recompute a portion of the splice (levels 0..recompute_height-1)
+ // that is a superset of all levels that don't bracket the new key.
+ // Several choices are reasonable, because we have to balance the work
+ // saved against the extra comparisons required to validate the Splice.
+ //
+ // One strategy is just to recompute all of orig_splice_height if the
+ // bottom level isn't bracketing. This pessimistically assumes that
+ // we will either get a perfect Splice hit (increasing sequential
+ // inserts) or have no locality.
+ //
+ // Another strategy is to walk up the Splice's levels until we find
+ // a level that brackets the key. This strategy lets the Splice
+ // hint help for other cases: it turns insertion from O(log N) into
+ // O(log D), where D is the number of nodes in between the key that
+ // produced the Splice and the current insert (insertion is aided
+ // whether the new key is before or after the splice). If you have
+ // a way of using a prefix of the key to map directly to the closest
+ // Splice out of O(sqrt(N)) Splices and we make it so that splices
+ // can also be used as hints during read, then we end up with Oshman's
+ // and Shavit's SkipTrie, which has O(log log N) lookup and insertion
+ // (compare to O(log N) for skip list).
+ //
+ // We control the pessimistic strategy with allow_partial_splice_fix.
+ // A good strategy is probably to be pessimistic for seq_splice_,
+ // optimistic if the caller actually went to the work of providing
+ // a Splice.
+ while (recompute_height < max_height) {
+ if (splice->prev_[recompute_height]->Next(recompute_height) !=
+ splice->next_[recompute_height]) {
+ // splice isn't tight at this level, there must have been some inserts
+ // to this
+ // location that didn't update the splice. We might only be a little
+ // stale, but if
+ // the splice is very stale it would be O(N) to fix it. We haven't used
+ // up any of
+ // our budget of comparisons, so always move up even if we are
+ // pessimistic about
+ // our chances of success.
+ ++recompute_height;
+ } else if (splice->prev_[recompute_height] != head_ &&
+ !KeyIsAfterNode(key_decoded,
+ splice->prev_[recompute_height])) {
+ // key is from before splice
+ if (allow_partial_splice_fix) {
+ // skip all levels with the same node without more comparisons
+ Node* bad = splice->prev_[recompute_height];
+ while (splice->prev_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ // we're pessimistic, recompute everything
+ recompute_height = max_height;
+ }
+ } else if (KeyIsAfterNode(key_decoded, splice->next_[recompute_height])) {
+ // key is from after splice
+ if (allow_partial_splice_fix) {
+ Node* bad = splice->next_[recompute_height];
+ while (splice->next_[recompute_height] == bad) {
+ ++recompute_height;
+ }
+ } else {
+ recompute_height = max_height;
+ }
+ } else {
+ // this level brackets the key, we won!
+ break;
+ }
+ }
+ }
+ assert(recompute_height <= max_height);
+ if (recompute_height > 0) {
+ RecomputeSpliceLevels(key_decoded, splice, recompute_height);
+ }
+
+ bool splice_is_valid = true;
+ if (UseCAS) {
+ for (int i = 0; i < height; ++i) {
+ while (true) {
+ // Checking for duplicate keys on the level 0 is sufficient
+ if (UNLIKELY(i == 0 && splice->next_[i] != nullptr &&
+ compare_(x->Key(), splice->next_[i]->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ if (UNLIKELY(i == 0 && splice->prev_[i] != head_ &&
+ compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ if (splice->prev_[i]->CASNext(i, splice->next_[i], x)) {
+ // success
+ break;
+ }
+ // CAS failed, we need to recompute prev and next. It is unlikely
+ // to be helpful to try to use a different level as we redo the
+ // search, because it should be unlikely that lots of nodes have
+ // been inserted between prev[i] and next[i]. No point in using
+ // next[i] as the after hint, because we know it is stale.
+ FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i,
+ &splice->prev_[i], &splice->next_[i]);
+
+ // Since we've narrowed the bracket for level i, we might have
+ // violated the Splice constraint between i and i-1. Make sure
+ // we recompute the whole thing next time.
+ if (i > 0) {
+ splice_is_valid = false;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < height; ++i) {
+ if (i >= recompute_height &&
+ splice->prev_[i]->Next(i) != splice->next_[i]) {
+ FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i,
+ &splice->prev_[i], &splice->next_[i]);
+ }
+ // Checking for duplicate keys on the level 0 is sufficient
+ if (UNLIKELY(i == 0 && splice->next_[i] != nullptr &&
+ compare_(x->Key(), splice->next_[i]->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ if (UNLIKELY(i == 0 && splice->prev_[i] != head_ &&
+ compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) {
+ // duplicate key
+ return false;
+ }
+ assert(splice->next_[i] == nullptr ||
+ compare_(x->Key(), splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), x->Key()) < 0);
+ assert(splice->prev_[i]->Next(i) == splice->next_[i]);
+ x->NoBarrier_SetNext(i, splice->next_[i]);
+ splice->prev_[i]->SetNext(i, x);
+ }
+ }
+ if (splice_is_valid) {
+ for (int i = 0; i < height; ++i) {
+ splice->prev_[i] = x;
+ }
+ assert(splice->prev_[splice->height_] == head_);
+ assert(splice->next_[splice->height_] == nullptr);
+ for (int i = 0; i < splice->height_; ++i) {
+ assert(splice->next_[i] == nullptr ||
+ compare_(key, splice->next_[i]->Key()) < 0);
+ assert(splice->prev_[i] == head_ ||
+ compare_(splice->prev_[i]->Key(), key) <= 0);
+ assert(splice->prev_[i + 1] == splice->prev_[i] ||
+ splice->prev_[i + 1] == head_ ||
+ compare_(splice->prev_[i + 1]->Key(), splice->prev_[i]->Key()) <
+ 0);
+ assert(splice->next_[i + 1] == splice->next_[i] ||
+ splice->next_[i + 1] == nullptr ||
+ compare_(splice->next_[i]->Key(), splice->next_[i + 1]->Key()) <
+ 0);
+ }
+ } else {
+ splice->height_ = 0;
+ }
+ return true;
+}
+
+template <class Comparator>
+bool InlineSkipList<Comparator>::Contains(const char* key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->Key())) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+template <class Comparator>
+void InlineSkipList<Comparator>::TEST_Validate() const {
+ // Interate over all levels at the same time, and verify nodes appear in
+ // the right order, and nodes appear in upper level also appear in lower
+ // levels.
+ Node* nodes[kMaxPossibleHeight];
+ int max_height = GetMaxHeight();
+ assert(max_height > 0);
+ for (int i = 0; i < max_height; i++) {
+ nodes[i] = head_;
+ }
+ while (nodes[0] != nullptr) {
+ Node* l0_next = nodes[0]->Next(0);
+ if (l0_next == nullptr) {
+ break;
+ }
+ assert(nodes[0] == head_ || compare_(nodes[0]->Key(), l0_next->Key()) < 0);
+ nodes[0] = l0_next;
+
+ int i = 1;
+ while (i < max_height) {
+ Node* next = nodes[i]->Next(i);
+ if (next == nullptr) {
+ break;
+ }
+ auto cmp = compare_(nodes[0]->Key(), next->Key());
+ assert(cmp <= 0);
+ if (cmp == 0) {
+ assert(next == nodes[0]);
+ nodes[i] = next;
+ } else {
+ break;
+ }
+ i++;
+ }
+ }
+ for (int i = 1; i < max_height; i++) {
+ assert(nodes[i] != nullptr && nodes[i]->Next(i) == nullptr);
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/inlineskiplist_test.cc b/src/rocksdb/memtable/inlineskiplist_test.cc
new file mode 100644
index 000000000..f85644064
--- /dev/null
+++ b/src/rocksdb/memtable/inlineskiplist_test.cc
@@ -0,0 +1,664 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/inlineskiplist.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "memory/concurrent_arena.h"
+#include "rocksdb/env.h"
+#include "test_util/testharness.h"
+#include "util/hash.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Our test skip list stores 8-byte unsigned integers
+using Key = uint64_t;
+
+static const char* Encode(const uint64_t* key) {
+ return reinterpret_cast<const char*>(key);
+}
+
+static Key Decode(const char* key) {
+ Key rv;
+ memcpy(&rv, key, sizeof(Key));
+ return rv;
+}
+
+struct TestComparator {
+ using DecodedType = Key;
+
+ static DecodedType decode_key(const char* b) { return Decode(b); }
+
+ int operator()(const char* a, const char* b) const {
+ if (Decode(a) < Decode(b)) {
+ return -1;
+ } else if (Decode(a) > Decode(b)) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+
+ int operator()(const char* a, const DecodedType b) const {
+ if (Decode(a) < b) {
+ return -1;
+ } else if (Decode(a) > b) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+using TestInlineSkipList = InlineSkipList<TestComparator>;
+
+class InlineSkipTest : public testing::Test {
+ public:
+ void Insert(TestInlineSkipList* list, Key key) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list->Insert(buf);
+ keys_.insert(key);
+ }
+
+ bool InsertWithHint(TestInlineSkipList* list, Key key, void** hint) {
+ char* buf = list->AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ bool res = list->InsertWithHint(buf, hint);
+ keys_.insert(key);
+ return res;
+ }
+
+ void Validate(TestInlineSkipList* list) {
+ // Check keys exist.
+ for (Key key : keys_) {
+ ASSERT_TRUE(list->Contains(Encode(&key)));
+ }
+ // Iterate over the list, make sure keys appears in order and no extra
+ // keys exist.
+ TestInlineSkipList::Iterator iter(list);
+ ASSERT_FALSE(iter.Valid());
+ Key zero = 0;
+ iter.Seek(Encode(&zero));
+ for (Key key : keys_) {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(key, Decode(iter.key()));
+ iter.Next();
+ }
+ ASSERT_FALSE(iter.Valid());
+ // Validate the list is well-formed.
+ list->TEST_Validate();
+ }
+
+ private:
+ std::set<Key> keys_;
+};
+
+TEST_F(InlineSkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ Key key = 10;
+ ASSERT_TRUE(!list.Contains(Encode(&key)));
+
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ key = 100;
+ iter.Seek(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(Encode(&key));
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(InlineSkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ ConcurrentArena arena;
+ TestComparator cmp;
+ InlineSkipList<TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ char* buf = list.AllocateKey(sizeof(Key));
+ memcpy(buf, &key, sizeof(Key));
+ list.Insert(buf);
+ }
+ }
+
+ for (Key i = 0; i < R; i++) {
+ if (list.Contains(Encode(&i))) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ uint64_t zero = 0;
+ iter.Seek(Encode(&zero));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ uint64_t max_key = R - 1;
+ iter.SeekForPrev(Encode(&max_key));
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), Decode(iter.key()));
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), Decode(iter.key()));
+ }
+
+ // Forward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.Seek(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, Decode(iter.key()));
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (Key i = 0; i < R; i++) {
+ InlineSkipList<TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(Encode(&i));
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, Decode(iter.key()));
+ iter.Prev();
+ }
+ }
+ }
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_Sequential) {
+ const int N = 100000;
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hint = nullptr;
+ for (int i = 0; i < N; i++) {
+ Key key = i;
+ InsertWithHint(&list, key, &hint);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHints) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ Key last_key[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ last_key[i] = 0;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + (++last_key[s]);
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_MultipleHintsRandom) {
+ const int N = 100000;
+ const int S = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ void* hints[S];
+ for (int i = 0; i < S; i++) {
+ hints[i] = nullptr;
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S);
+ Key key = (s << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ }
+ Validate(&list);
+}
+
+TEST_F(InlineSkipTest, InsertWithHint_CompatibleWithInsertWithoutHint) {
+ const int N = 100000;
+ const int S1 = 100;
+ const int S2 = 100;
+ Random rnd(534);
+ Arena arena;
+ TestComparator cmp;
+ TestInlineSkipList list(cmp, &arena);
+ std::unordered_set<Key> used;
+ Key with_hint[S1];
+ Key without_hint[S2];
+ void* hints[S1];
+ for (int i = 0; i < S1; i++) {
+ hints[i] = nullptr;
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ with_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < S2; i++) {
+ while (true) {
+ Key s = rnd.Next();
+ if (used.insert(s).second) {
+ without_hint[i] = s;
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < N; i++) {
+ Key s = rnd.Uniform(S1 + S2);
+ if (s < S1) {
+ Key key = (with_hint[s] << 32) + rnd.Next();
+ InsertWithHint(&list, key, &hints[s]);
+ } else {
+ Key key = (without_hint[s - S1] << 32) + rnd.Next();
+ Insert(&list, key);
+ }
+ }
+ Validate(&list);
+}
+
+#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ public:
+ static const uint32_t K = 8;
+
+ private:
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = {k, g};
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ ConcurrentArena arena_;
+
+ // InlineSkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ InlineSkipList<TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: No concurrent calls to WriteStep or ConcurrentWriteStep
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ list_.Insert(buf);
+ current_.Set(k, g);
+ }
+
+ // REQUIRES: No concurrent calls for the same k
+ void ConcurrentWriteStep(uint32_t k, bool use_hint = false) {
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ char* buf = list_.AllocateKey(sizeof(Key));
+ memcpy(buf, &new_key, sizeof(Key));
+ if (use_hint) {
+ void* hint = nullptr;
+ list_.InsertWithHintConcurrently(buf, &hint);
+ delete[] reinterpret_cast<char*>(hint);
+ } else {
+ list_.InsertConcurrently(buf);
+ }
+ ASSERT_EQ(g, current_.Get(k) + 1);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ InlineSkipList<TestComparator>::Iterator iter(&list_);
+ iter.Seek(Encode(&pos));
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = Decode(iter.key());
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(Encode(&new_target));
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(InlineSkipTest, ConcurrentReadWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentInsertWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ uint32_t base = rnd.Next();
+ for (int j = 0; j < 4; ++j) {
+ test.ConcurrentWriteStep((base + j) % ConcurrentTest::K);
+ }
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ bool use_hint_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+ std::atomic<uint32_t> next_writer_;
+
+ enum ReaderState { STARTING, RUNNING, DONE };
+
+ explicit TestState(int s)
+ : seed_(s),
+ quit_flag_(false),
+ state_(STARTING),
+ pending_writers_(0),
+ state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ void AdjustPendingWriters(int delta) {
+ mu_.Lock();
+ pending_writers_ += delta;
+ if (pending_writers_ == 0) {
+ state_cv_.Signal();
+ }
+ mu_.Unlock();
+ }
+
+ void WaitForPendingWriters() {
+ mu_.Lock();
+ while (pending_writers_ != 0) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ int pending_writers_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void ConcurrentWriter(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ uint32_t k = state->next_writer_++ % ConcurrentTest::K;
+ state->t_.ConcurrentWriteStep(k, state->use_hint_);
+ state->AdjustPendingWriters(-1);
+}
+
+static void RunConcurrentRead(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->SetBackgroundThreads(1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; ++k) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+static void RunConcurrentInsert(int run, bool use_hint = false,
+ int write_parallelism = 4) {
+ Env::Default()->SetBackgroundThreads(1 + write_parallelism,
+ Env::Priority::LOW);
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ state.use_hint_ = use_hint;
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k += write_parallelism) {
+ state.next_writer_ = rnd.Next();
+ state.AdjustPendingWriters(write_parallelism);
+ for (int p = 0; p < write_parallelism; ++p) {
+ Env::Default()->Schedule(ConcurrentWriter, &state);
+ }
+ state.WaitForPendingWriters();
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(InlineSkipTest, ConcurrentRead1) { RunConcurrentRead(1); }
+TEST_F(InlineSkipTest, ConcurrentRead2) { RunConcurrentRead(2); }
+TEST_F(InlineSkipTest, ConcurrentRead3) { RunConcurrentRead(3); }
+TEST_F(InlineSkipTest, ConcurrentRead4) { RunConcurrentRead(4); }
+TEST_F(InlineSkipTest, ConcurrentRead5) { RunConcurrentRead(5); }
+TEST_F(InlineSkipTest, ConcurrentInsert1) { RunConcurrentInsert(1); }
+TEST_F(InlineSkipTest, ConcurrentInsert2) { RunConcurrentInsert(2); }
+TEST_F(InlineSkipTest, ConcurrentInsert3) { RunConcurrentInsert(3); }
+TEST_F(InlineSkipTest, ConcurrentInsertWithHint1) {
+ RunConcurrentInsert(1, true);
+}
+TEST_F(InlineSkipTest, ConcurrentInsertWithHint2) {
+ RunConcurrentInsert(2, true);
+}
+TEST_F(InlineSkipTest, ConcurrentInsertWithHint3) {
+ RunConcurrentInsert(3, true);
+}
+
+#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/memtable/memtablerep_bench.cc b/src/rocksdb/memtable/memtablerep_bench.cc
new file mode 100644
index 000000000..a915abed7
--- /dev/null
+++ b/src/rocksdb/memtable/memtablerep_bench.cc
@@ -0,0 +1,689 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef GFLAGS
+#include <cstdio>
+int main() {
+ fprintf(stderr, "Please install gflags to run rocksdb tools\n");
+ return 1;
+}
+#else
+
+#include <atomic>
+#include <iostream>
+#include <memory>
+#include <thread>
+#include <type_traits>
+#include <vector>
+
+#include "db/dbformat.h"
+#include "db/memtable.h"
+#include "memory/arena.h"
+#include "port/port.h"
+#include "port/stack_trace.h"
+#include "rocksdb/comparator.h"
+#include "rocksdb/convenience.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/options.h"
+#include "rocksdb/slice_transform.h"
+#include "rocksdb/system_clock.h"
+#include "rocksdb/write_buffer_manager.h"
+#include "test_util/testutil.h"
+#include "util/gflags_compat.h"
+#include "util/mutexlock.h"
+#include "util/stop_watch.h"
+
+using GFLAGS_NAMESPACE::ParseCommandLineFlags;
+using GFLAGS_NAMESPACE::RegisterFlagValidator;
+using GFLAGS_NAMESPACE::SetUsageMessage;
+
+DEFINE_string(benchmarks, "fillrandom",
+ "Comma-separated list of benchmarks to run. Options:\n"
+ "\tfillrandom -- write N random values\n"
+ "\tfillseq -- write N values in sequential order\n"
+ "\treadrandom -- read N values in random order\n"
+ "\treadseq -- scan the DB\n"
+ "\treadwrite -- 1 thread writes while N - 1 threads "
+ "do random\n"
+ "\t reads\n"
+ "\tseqreadwrite -- 1 thread writes while N - 1 threads "
+ "do scans\n");
+
+DEFINE_string(memtablerep, "skiplist",
+ "Which implementation of memtablerep to use. See "
+ "include/memtablerep.h for\n"
+ " more details. Options:\n"
+ "\tskiplist -- backed by a skiplist\n"
+ "\tvector -- backed by an std::vector\n"
+ "\thashskiplist -- backed by a hash skip list\n"
+ "\thashlinklist -- backed by a hash linked list\n"
+ "\tcuckoo -- backed by a cuckoo hash table");
+
+DEFINE_int64(bucket_count, 1000000,
+ "bucket_count parameter to pass into NewHashSkiplistRepFactory or "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ hashskiplist_height, 4,
+ "skiplist_height parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ hashskiplist_branching_factor, 4,
+ "branching_factor parameter to pass into NewHashSkiplistRepFactory");
+
+DEFINE_int32(
+ huge_page_tlb_size, 0,
+ "huge_page_tlb_size parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int32(bucket_entries_logging_threshold, 4096,
+ "bucket_entries_logging_threshold parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_bool(if_log_bucket_dist_when_flash, true,
+ "if_log_bucket_dist_when_flash parameter to pass into "
+ "NewHashLinkListRepFactory");
+
+DEFINE_int32(
+ threshold_use_skiplist, 256,
+ "threshold_use_skiplist parameter to pass into NewHashLinkListRepFactory");
+
+DEFINE_int64(write_buffer_size, 256,
+ "write_buffer_size parameter to pass into WriteBufferManager");
+
+DEFINE_int32(
+ num_threads, 1,
+ "Number of concurrent threads to run. If the benchmark includes writes,\n"
+ "then at most one thread will be a writer");
+
+DEFINE_int32(num_operations, 1000000,
+ "Number of operations to do for write and random read benchmarks");
+
+DEFINE_int32(num_scans, 10,
+ "Number of times for each thread to scan the memtablerep for "
+ "sequential read "
+ "benchmarks");
+
+DEFINE_int32(item_size, 100, "Number of bytes each item should be");
+
+DEFINE_int32(prefix_length, 8,
+ "Prefix length to pass into NewFixedPrefixTransform");
+
+/* VectorRep settings */
+DEFINE_int64(vectorrep_count, 0,
+ "Number of entries to reserve on VectorRep initialization");
+
+DEFINE_int64(seed, 0,
+ "Seed base for random number generators. "
+ "When 0 it is deterministic.");
+
+namespace ROCKSDB_NAMESPACE {
+
+namespace {
+struct CallbackVerifyArgs {
+ bool found;
+ LookupKey* key;
+ MemTableRep* table;
+ InternalKeyComparator* comparator;
+};
+} // namespace
+
+// Helper for quickly generating random data.
+class RandomGenerator {
+ private:
+ std::string data_;
+ unsigned int pos_;
+
+ public:
+ RandomGenerator() {
+ Random rnd(301);
+ auto size = (unsigned)std::max(1048576, FLAGS_item_size);
+ data_ = rnd.RandomString(size);
+ pos_ = 0;
+ }
+
+ Slice Generate(unsigned int len) {
+ assert(len <= data_.size());
+ if (pos_ + len > data_.size()) {
+ pos_ = 0;
+ }
+ pos_ += len;
+ return Slice(data_.data() + pos_ - len, len);
+ }
+};
+
+enum WriteMode { SEQUENTIAL, RANDOM, UNIQUE_RANDOM };
+
+class KeyGenerator {
+ public:
+ KeyGenerator(Random64* rand, WriteMode mode, uint64_t num)
+ : rand_(rand), mode_(mode), num_(num), next_(0) {
+ if (mode_ == UNIQUE_RANDOM) {
+ // NOTE: if memory consumption of this approach becomes a concern,
+ // we can either break it into pieces and only random shuffle a section
+ // each time. Alternatively, use a bit map implementation
+ // (https://reviews.facebook.net/differential/diff/54627/)
+ values_.resize(num_);
+ for (uint64_t i = 0; i < num_; ++i) {
+ values_[i] = i;
+ }
+ RandomShuffle(values_.begin(), values_.end(),
+ static_cast<uint32_t>(FLAGS_seed));
+ }
+ }
+
+ uint64_t Next() {
+ switch (mode_) {
+ case SEQUENTIAL:
+ return next_++;
+ case RANDOM:
+ return rand_->Next() % num_;
+ case UNIQUE_RANDOM:
+ return values_[next_++];
+ }
+ assert(false);
+ return std::numeric_limits<uint64_t>::max();
+ }
+
+ private:
+ Random64* rand_;
+ WriteMode mode_;
+ const uint64_t num_;
+ uint64_t next_;
+ std::vector<uint64_t> values_;
+};
+
+class BenchmarkThread {
+ public:
+ explicit BenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : table_(table),
+ key_gen_(key_gen),
+ bytes_written_(bytes_written),
+ bytes_read_(bytes_read),
+ sequence_(sequence),
+ num_ops_(num_ops),
+ read_hits_(read_hits) {}
+
+ virtual void operator()() = 0;
+ virtual ~BenchmarkThread() {}
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* bytes_written_;
+ uint64_t* bytes_read_;
+ uint64_t* sequence_;
+ uint64_t num_ops_;
+ uint64_t* read_hits_;
+ RandomGenerator generator_;
+};
+
+class FillBenchmarkThread : public BenchmarkThread {
+ public:
+ FillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void FillOne() {
+ char* buf = nullptr;
+ auto internal_key_size = 16;
+ auto encoded_len =
+ FLAGS_item_size + VarintLength(internal_key_size) + internal_key_size;
+ KeyHandle handle = table_->Allocate(encoded_len, &buf);
+ assert(buf != nullptr);
+ char* p = EncodeVarint32(buf, internal_key_size);
+ auto key = key_gen_->Next();
+ EncodeFixed64(p, key);
+ p += 8;
+ EncodeFixed64(p, ++(*sequence_));
+ p += 8;
+ Slice bytes = generator_.Generate(FLAGS_item_size);
+ memcpy(p, bytes.data(), FLAGS_item_size);
+ p += FLAGS_item_size;
+ assert(p == buf + encoded_len);
+ table_->Insert(handle);
+ *bytes_written_ += encoded_len;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ FillOne();
+ }
+ }
+};
+
+class ConcurrentFillBenchmarkThread : public FillBenchmarkThread {
+ public:
+ ConcurrentFillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : FillBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ // # of read threads will be total threads - write threads (always 1). Loop
+ // while all reads complete.
+ while ((*threads_done_).load() < (FLAGS_num_threads - 1)) {
+ FillOne();
+ }
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class ReadBenchmarkThread : public BenchmarkThread {
+ public:
+ ReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ static bool callback(void* arg, const char* entry) {
+ CallbackVerifyArgs* callback_args = static_cast<CallbackVerifyArgs*>(arg);
+ assert(callback_args != nullptr);
+ uint32_t key_length;
+ const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
+ if ((callback_args->comparator)
+ ->user_comparator()
+ ->Equal(Slice(key_ptr, key_length - 8),
+ callback_args->key->user_key())) {
+ callback_args->found = true;
+ }
+ return false;
+ }
+
+ void ReadOne() {
+ std::string user_key;
+ auto key = key_gen_->Next();
+ PutFixed64(&user_key, key);
+ LookupKey lookup_key(user_key, *sequence_);
+ InternalKeyComparator internal_key_comp(BytewiseComparator());
+ CallbackVerifyArgs verify_args;
+ verify_args.found = false;
+ verify_args.key = &lookup_key;
+ verify_args.table = table_;
+ verify_args.comparator = &internal_key_comp;
+ table_->Get(lookup_key, &verify_args, callback);
+ if (verify_args.found) {
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ ++*read_hits_;
+ }
+ }
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ }
+};
+
+class SeqReadBenchmarkThread : public BenchmarkThread {
+ public:
+ SeqReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits)
+ : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {}
+
+ void ReadOneSeq() {
+ std::unique_ptr<MemTableRep::Iterator> iter(table_->GetIterator());
+ for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+ // pretend to read the value
+ *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size;
+ }
+ ++*read_hits_;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ { ReadOneSeq(); }
+ }
+ }
+};
+
+class ConcurrentReadBenchmarkThread : public ReadBenchmarkThread {
+ public:
+ ConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ uint64_t* sequence, uint64_t num_ops,
+ uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : ReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence,
+ num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOne();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class SeqConcurrentReadBenchmarkThread : public SeqReadBenchmarkThread {
+ public:
+ SeqConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* bytes_written,
+ uint64_t* bytes_read, uint64_t* sequence,
+ uint64_t num_ops, uint64_t* read_hits,
+ std::atomic_int* threads_done)
+ : SeqReadBenchmarkThread(table, key_gen, bytes_written, bytes_read,
+ sequence, num_ops, read_hits) {
+ threads_done_ = threads_done;
+ }
+
+ void operator()() override {
+ for (unsigned int i = 0; i < num_ops_; ++i) {
+ ReadOneSeq();
+ }
+ ++*threads_done_;
+ }
+
+ private:
+ std::atomic_int* threads_done_;
+};
+
+class Benchmark {
+ public:
+ explicit Benchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence, uint32_t num_threads)
+ : table_(table),
+ key_gen_(key_gen),
+ sequence_(sequence),
+ num_threads_(num_threads) {}
+
+ virtual ~Benchmark() {}
+ virtual void Run() {
+ std::cout << "Number of threads: " << num_threads_ << std::endl;
+ std::vector<port::Thread> threads;
+ uint64_t bytes_written = 0;
+ uint64_t bytes_read = 0;
+ uint64_t read_hits = 0;
+ StopWatchNano timer(SystemClock::Default().get(), true);
+ RunThreads(&threads, &bytes_written, &bytes_read, true, &read_hits);
+ auto elapsed_time = static_cast<double>(timer.ElapsedNanos() / 1000);
+ std::cout << "Elapsed time: " << static_cast<int>(elapsed_time) << " us"
+ << std::endl;
+
+ if (bytes_written > 0) {
+ auto MiB_written = static_cast<double>(bytes_written) / (1 << 20);
+ auto write_throughput = MiB_written / (elapsed_time / 1000000);
+ std::cout << "Total bytes written: " << MiB_written << " MiB"
+ << std::endl;
+ std::cout << "Write throughput: " << write_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_write_ops_per_thread_;
+ std::cout << "write us/op: " << us_per_op << std::endl;
+ }
+ if (bytes_read > 0) {
+ auto MiB_read = static_cast<double>(bytes_read) / (1 << 20);
+ auto read_throughput = MiB_read / (elapsed_time / 1000000);
+ std::cout << "Total bytes read: " << MiB_read << " MiB" << std::endl;
+ std::cout << "Read throughput: " << read_throughput << " MiB/s"
+ << std::endl;
+ auto us_per_op = elapsed_time / num_read_ops_per_thread_;
+ std::cout << "read us/op: " << us_per_op << std::endl;
+ }
+ }
+
+ virtual void RunThreads(std::vector<port::Thread>* threads,
+ uint64_t* bytes_written, uint64_t* bytes_read,
+ bool write, uint64_t* read_hits) = 0;
+
+ protected:
+ MemTableRep* table_;
+ KeyGenerator* key_gen_;
+ uint64_t* sequence_;
+ uint64_t num_write_ops_per_thread_ = 0;
+ uint64_t num_read_ops_per_thread_ = 0;
+ const uint32_t num_threads_;
+};
+
+class FillBenchmark : public Benchmark {
+ public:
+ explicit FillBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, 1) {
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* /*threads*/,
+ uint64_t* bytes_written, uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ FillBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits)();
+ }
+};
+
+class ReadBenchmark : public Benchmark {
+ public:
+ explicit ReadBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_operations / FLAGS_num_threads;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadBenchmarkThread(table_, key_gen_, bytes_written, bytes_read,
+ sequence_, num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ std::cout << "read hit%: "
+ << (static_cast<double>(*read_hits) / FLAGS_num_operations) * 100
+ << std::endl;
+ }
+};
+
+class SeqReadBenchmark : public Benchmark {
+ public:
+ explicit SeqReadBenchmark(MemTableRep* table, uint64_t* sequence)
+ : Benchmark(table, nullptr, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ = FLAGS_num_scans;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ for (int i = 0; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(SeqReadBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+template <class ReadThreadType>
+class ReadWriteBenchmark : public Benchmark {
+ public:
+ explicit ReadWriteBenchmark(MemTableRep* table, KeyGenerator* key_gen,
+ uint64_t* sequence)
+ : Benchmark(table, key_gen, sequence, FLAGS_num_threads) {
+ num_read_ops_per_thread_ =
+ FLAGS_num_threads <= 1
+ ? 0
+ : (FLAGS_num_operations / (FLAGS_num_threads - 1));
+ num_write_ops_per_thread_ = FLAGS_num_operations;
+ }
+
+ void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written,
+ uint64_t* bytes_read, bool /*write*/,
+ uint64_t* read_hits) override {
+ std::atomic_int threads_done;
+ threads_done.store(0);
+ threads->emplace_back(ConcurrentFillBenchmarkThread(
+ table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_write_ops_per_thread_, read_hits, &threads_done));
+ for (int i = 1; i < FLAGS_num_threads; ++i) {
+ threads->emplace_back(
+ ReadThreadType(table_, key_gen_, bytes_written, bytes_read, sequence_,
+ num_read_ops_per_thread_, read_hits, &threads_done));
+ }
+ for (auto& thread : *threads) {
+ thread.join();
+ }
+ }
+};
+
+} // namespace ROCKSDB_NAMESPACE
+
+void PrintWarnings() {
+#if defined(__GNUC__) && !defined(__OPTIMIZE__)
+ fprintf(stdout,
+ "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n");
+#endif
+#ifndef NDEBUG
+ fprintf(stdout,
+ "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n");
+#endif
+}
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) +
+ " [OPTIONS]...");
+ ParseCommandLineFlags(&argc, &argv, true);
+
+ PrintWarnings();
+
+ ROCKSDB_NAMESPACE::Options options;
+
+ std::unique_ptr<ROCKSDB_NAMESPACE::MemTableRepFactory> factory;
+ if (FLAGS_memtablerep == "skiplist") {
+ factory.reset(new ROCKSDB_NAMESPACE::SkipListFactory);
+#ifndef ROCKSDB_LITE
+ } else if (FLAGS_memtablerep == "vector") {
+ factory.reset(new ROCKSDB_NAMESPACE::VectorRepFactory);
+ } else if (FLAGS_memtablerep == "hashskiplist" ||
+ FLAGS_memtablerep == "prefix_hash") {
+ factory.reset(ROCKSDB_NAMESPACE::NewHashSkipListRepFactory(
+ FLAGS_bucket_count, FLAGS_hashskiplist_height,
+ FLAGS_hashskiplist_branching_factor));
+ options.prefix_extractor.reset(
+ ROCKSDB_NAMESPACE::NewFixedPrefixTransform(FLAGS_prefix_length));
+ } else if (FLAGS_memtablerep == "hashlinklist" ||
+ FLAGS_memtablerep == "hash_linkedlist") {
+ factory.reset(ROCKSDB_NAMESPACE::NewHashLinkListRepFactory(
+ FLAGS_bucket_count, FLAGS_huge_page_tlb_size,
+ FLAGS_bucket_entries_logging_threshold,
+ FLAGS_if_log_bucket_dist_when_flash, FLAGS_threshold_use_skiplist));
+ options.prefix_extractor.reset(
+ ROCKSDB_NAMESPACE::NewFixedPrefixTransform(FLAGS_prefix_length));
+#endif // ROCKSDB_LITE
+ } else {
+ ROCKSDB_NAMESPACE::ConfigOptions config_options;
+ config_options.ignore_unsupported_options = false;
+
+ ROCKSDB_NAMESPACE::Status s =
+ ROCKSDB_NAMESPACE::MemTableRepFactory::CreateFromString(
+ config_options, FLAGS_memtablerep, &factory);
+ if (!s.ok()) {
+ fprintf(stdout, "Unknown memtablerep: %s\n", s.ToString().c_str());
+ exit(1);
+ }
+ }
+
+ ROCKSDB_NAMESPACE::InternalKeyComparator internal_key_comp(
+ ROCKSDB_NAMESPACE::BytewiseComparator());
+ ROCKSDB_NAMESPACE::MemTable::KeyComparator key_comp(internal_key_comp);
+ ROCKSDB_NAMESPACE::Arena arena;
+ ROCKSDB_NAMESPACE::WriteBufferManager wb(FLAGS_write_buffer_size);
+ uint64_t sequence;
+ auto createMemtableRep = [&] {
+ sequence = 0;
+ return factory->CreateMemTableRep(key_comp, &arena,
+ options.prefix_extractor.get(),
+ options.info_log.get());
+ };
+ std::unique_ptr<ROCKSDB_NAMESPACE::MemTableRep> memtablerep;
+ ROCKSDB_NAMESPACE::Random64 rng(FLAGS_seed);
+ const char* benchmarks = FLAGS_benchmarks.c_str();
+ while (benchmarks != nullptr) {
+ std::unique_ptr<ROCKSDB_NAMESPACE::KeyGenerator> key_gen;
+ const char* sep = strchr(benchmarks, ',');
+ ROCKSDB_NAMESPACE::Slice name;
+ if (sep == nullptr) {
+ name = benchmarks;
+ benchmarks = nullptr;
+ } else {
+ name = ROCKSDB_NAMESPACE::Slice(benchmarks, sep - benchmarks);
+ benchmarks = sep + 1;
+ }
+ std::unique_ptr<ROCKSDB_NAMESPACE::Benchmark> benchmark;
+ if (name == ROCKSDB_NAMESPACE::Slice("fillseq")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::SEQUENTIAL, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::FillBenchmark(
+ memtablerep.get(), key_gen.get(), &sequence));
+ } else if (name == ROCKSDB_NAMESPACE::Slice("fillrandom")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::UNIQUE_RANDOM, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::FillBenchmark(
+ memtablerep.get(), key_gen.get(), &sequence));
+ } else if (name == ROCKSDB_NAMESPACE::Slice("readrandom")) {
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::ReadBenchmark(
+ memtablerep.get(), key_gen.get(), &sequence));
+ } else if (name == ROCKSDB_NAMESPACE::Slice("readseq")) {
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::SEQUENTIAL, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::SeqReadBenchmark(memtablerep.get(),
+ &sequence));
+ } else if (name == ROCKSDB_NAMESPACE::Slice("readwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::ReadWriteBenchmark<
+ ROCKSDB_NAMESPACE::ConcurrentReadBenchmarkThread>(
+ memtablerep.get(), key_gen.get(), &sequence));
+ } else if (name == ROCKSDB_NAMESPACE::Slice("seqreadwrite")) {
+ memtablerep.reset(createMemtableRep());
+ key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator(
+ &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations));
+ benchmark.reset(new ROCKSDB_NAMESPACE::ReadWriteBenchmark<
+ ROCKSDB_NAMESPACE::SeqConcurrentReadBenchmarkThread>(
+ memtablerep.get(), key_gen.get(), &sequence));
+ } else {
+ std::cout << "WARNING: skipping unknown benchmark '" << name.ToString()
+ << std::endl;
+ continue;
+ }
+ std::cout << "Running " << name.ToString() << std::endl;
+ benchmark->Run();
+ }
+
+ return 0;
+}
+
+#endif // GFLAGS
diff --git a/src/rocksdb/memtable/skiplist.h b/src/rocksdb/memtable/skiplist.h
new file mode 100644
index 000000000..e3cecd30c
--- /dev/null
+++ b/src/rocksdb/memtable/skiplist.h
@@ -0,0 +1,498 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Thread safety
+// -------------
+//
+// Writes require external synchronization, most likely a mutex.
+// Reads require a guarantee that the SkipList will not be destroyed
+// while the read is in progress. Apart from that, reads progress
+// without any internal locking or synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the SkipList is
+// destroyed. This is trivially guaranteed by the code since we
+// never delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the SkipList.
+// Only Insert() modifies the list, and it is careful to initialize
+// a node and use release-stores to publish the nodes in one or
+// more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <stdlib.h>
+
+#include <atomic>
+
+#include "memory/allocator.h"
+#include "port/port.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+template <typename Key, class Comparator>
+class SkipList {
+ private:
+ struct Node;
+
+ public:
+ // Create a new SkipList object that will use "cmp" for comparing keys,
+ // and will allocate memory using "*allocator". Objects allocated in the
+ // allocator must remain allocated for the lifetime of the skiplist object.
+ explicit SkipList(Comparator cmp, Allocator* allocator,
+ int32_t max_height = 12, int32_t branching_factor = 4);
+ // No copying allowed
+ SkipList(const SkipList&) = delete;
+ void operator=(const SkipList&) = delete;
+
+ // Insert key into the list.
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ void Insert(const Key& key);
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const Key& key) const;
+
+ // Return estimated number of entries smaller than `key`.
+ uint64_t EstimateCount(const Key& key) const;
+
+ // Iteration over the contents of a skip list
+ class Iterator {
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(const SkipList* list);
+
+ // Change the underlying skiplist used for this iterator
+ // This enables us not changing the iterator without deallocating
+ // an old one and then allocating a new one
+ void SetList(const SkipList* list);
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const Key& key() const;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next();
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev();
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Key& target);
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Key& target);
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst();
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast();
+
+ private:
+ const SkipList* list_;
+ Node* node_;
+ // Intentionally copyable
+ };
+
+ private:
+ const uint16_t kMaxHeight_;
+ const uint16_t kBranching_;
+ const uint32_t kScaledInverseBranching_;
+
+ // Immutable after construction
+ Comparator const compare_;
+ Allocator* const allocator_; // Allocator used for allocations of nodes
+
+ Node* const head_;
+
+ // Modified only by Insert(). Read racily by readers, but stale
+ // values are ok.
+ std::atomic<int> max_height_; // Height of the entire list
+
+ // Used for optimizing sequential insert patterns. Tricky. prev_[i] for
+ // i up to max_height_ is the predecessor of prev_[0] and prev_height_
+ // is the height of prev_[0]. prev_[0] can only be equal to head before
+ // insertion, in which case max_height_ and prev_height_ are 1.
+ Node** prev_;
+ int32_t prev_height_;
+
+ inline int GetMaxHeight() const {
+ return max_height_.load(std::memory_order_relaxed);
+ }
+
+ Node* NewNode(const Key& key, int height);
+ int RandomHeight();
+ bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); }
+ bool LessThan(const Key& a, const Key& b) const {
+ return (compare_(a, b) < 0);
+ }
+
+ // Return true if key is greater than the data stored in "n"
+ bool KeyIsAfterNode(const Key& key, Node* n) const;
+
+ // Returns the earliest node with a key >= key.
+ // Return nullptr if there is no such node.
+ Node* FindGreaterOrEqual(const Key& key) const;
+
+ // Return the latest node with a key < key.
+ // Return head_ if there is no such node.
+ // Fills prev[level] with pointer to previous node at "level" for every
+ // level in [0..max_height_-1], if prev is non-null.
+ Node* FindLessThan(const Key& key, Node** prev = nullptr) const;
+
+ // Return the last node in the list.
+ // Return head_ if list is empty.
+ Node* FindLast() const;
+};
+
+// Implementation details follow
+template <typename Key, class Comparator>
+struct SkipList<Key, Comparator>::Node {
+ explicit Node(const Key& k) : key(k) {}
+
+ Key const key;
+
+ // Accessors/mutators for links. Wrapped in methods so we can
+ // add the appropriate barriers as necessary.
+ Node* Next(int n) {
+ assert(n >= 0);
+ // Use an 'acquire load' so that we observe a fully initialized
+ // version of the returned Node.
+ return (next_[n].load(std::memory_order_acquire));
+ }
+ void SetNext(int n, Node* x) {
+ assert(n >= 0);
+ // Use a 'release store' so that anybody who reads through this
+ // pointer observes a fully initialized version of the inserted node.
+ next_[n].store(x, std::memory_order_release);
+ }
+
+ // No-barrier variants that can be safely used in a few locations.
+ Node* NoBarrier_Next(int n) {
+ assert(n >= 0);
+ return next_[n].load(std::memory_order_relaxed);
+ }
+ void NoBarrier_SetNext(int n, Node* x) {
+ assert(n >= 0);
+ next_[n].store(x, std::memory_order_relaxed);
+ }
+
+ private:
+ // Array of length equal to the node height. next_[0] is lowest level link.
+ std::atomic<Node*> next_[1];
+};
+
+template <typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::NewNode(
+ const Key& key, int height) {
+ char* mem = allocator_->AllocateAligned(
+ sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1));
+ return new (mem) Node(key);
+}
+
+template <typename Key, class Comparator>
+inline SkipList<Key, Comparator>::Iterator::Iterator(const SkipList* list) {
+ SetList(list);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SetList(const SkipList* list) {
+ list_ = list;
+ node_ = nullptr;
+}
+
+template <typename Key, class Comparator>
+inline bool SkipList<Key, Comparator>::Iterator::Valid() const {
+ return node_ != nullptr;
+}
+
+template <typename Key, class Comparator>
+inline const Key& SkipList<Key, Comparator>::Iterator::key() const {
+ assert(Valid());
+ return node_->key;
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Next() {
+ assert(Valid());
+ node_ = node_->Next(0);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Prev() {
+ // Instead of using explicit "prev" links, we just search for the
+ // last node that falls before key.
+ assert(Valid());
+ node_ = list_->FindLessThan(node_->key);
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::Seek(const Key& target) {
+ node_ = list_->FindGreaterOrEqual(target);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekForPrev(
+ const Key& target) {
+ Seek(target);
+ if (!Valid()) {
+ SeekToLast();
+ }
+ while (Valid() && list_->LessThan(target, key())) {
+ Prev();
+ }
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToFirst() {
+ node_ = list_->head_->Next(0);
+}
+
+template <typename Key, class Comparator>
+inline void SkipList<Key, Comparator>::Iterator::SeekToLast() {
+ node_ = list_->FindLast();
+ if (node_ == list_->head_) {
+ node_ = nullptr;
+ }
+}
+
+template <typename Key, class Comparator>
+int SkipList<Key, Comparator>::RandomHeight() {
+ auto rnd = Random::GetTLSInstance();
+
+ // Increase height with probability 1 in kBranching
+ int height = 1;
+ while (height < kMaxHeight_ && rnd->Next() < kScaledInverseBranching_) {
+ height++;
+ }
+ assert(height > 0);
+ assert(height <= kMaxHeight_);
+ return height;
+}
+
+template <typename Key, class Comparator>
+bool SkipList<Key, Comparator>::KeyIsAfterNode(const Key& key, Node* n) const {
+ // nullptr n is considered infinite
+ return (n != nullptr) && (compare_(n->key, key) < 0);
+}
+
+template <typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::FindGreaterOrEqual(const Key& key) const {
+ // Note: It looks like we could reduce duplication by implementing
+ // this function as FindLessThan(key)->Next(0), but we wouldn't be able
+ // to exit early on equality and the result wouldn't even be correct.
+ // A concurrent insert might occur after FindLessThan(key) but before
+ // we get a chance to call Next(0).
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ Node* last_bigger = nullptr;
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ // Make sure the lists are sorted
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ // Make sure we haven't overshot during our search
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ int cmp =
+ (next == nullptr || next == last_bigger) ? 1 : compare_(next->key, key);
+ if (cmp == 0 || (cmp > 0 && level == 0)) {
+ return next;
+ } else if (cmp < 0) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ // Switch to next list, reuse compare_() result
+ last_bigger = next;
+ level--;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node*
+SkipList<Key, Comparator>::FindLessThan(const Key& key, Node** prev) const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ // KeyIsAfter(key, last_not_after) is definitely false
+ Node* last_not_after = nullptr;
+ while (true) {
+ assert(x != nullptr);
+ Node* next = x->Next(level);
+ assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x));
+ assert(x == head_ || KeyIsAfterNode(key, x));
+ if (next != last_not_after && KeyIsAfterNode(key, next)) {
+ // Keep searching in this list
+ x = next;
+ } else {
+ if (prev != nullptr) {
+ prev[level] = x;
+ }
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list, reuse KeyIUsAfterNode() result
+ last_not_after = next;
+ level--;
+ }
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast()
+ const {
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ Node* next = x->Next(level);
+ if (next == nullptr) {
+ if (level == 0) {
+ return x;
+ } else {
+ // Switch to next list
+ level--;
+ }
+ } else {
+ x = next;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const {
+ uint64_t count = 0;
+
+ Node* x = head_;
+ int level = GetMaxHeight() - 1;
+ while (true) {
+ assert(x == head_ || compare_(x->key, key) < 0);
+ Node* next = x->Next(level);
+ if (next == nullptr || compare_(next->key, key) >= 0) {
+ if (level == 0) {
+ return count;
+ } else {
+ // Switch to next list
+ count *= kBranching_;
+ level--;
+ }
+ } else {
+ x = next;
+ count++;
+ }
+ }
+}
+
+template <typename Key, class Comparator>
+SkipList<Key, Comparator>::SkipList(const Comparator cmp, Allocator* allocator,
+ int32_t max_height,
+ int32_t branching_factor)
+ : kMaxHeight_(static_cast<uint16_t>(max_height)),
+ kBranching_(static_cast<uint16_t>(branching_factor)),
+ kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_),
+ compare_(cmp),
+ allocator_(allocator),
+ head_(NewNode(0 /* any key will do */, max_height)),
+ max_height_(1),
+ prev_height_(1) {
+ assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height));
+ assert(branching_factor > 0 &&
+ kBranching_ == static_cast<uint32_t>(branching_factor));
+ assert(kScaledInverseBranching_ > 0);
+ // Allocate the prev_ Node* array, directly from the passed-in allocator.
+ // prev_ does not need to be freed, as its life cycle is tied up with
+ // the allocator as a whole.
+ prev_ = reinterpret_cast<Node**>(
+ allocator_->AllocateAligned(sizeof(Node*) * kMaxHeight_));
+ for (int i = 0; i < kMaxHeight_; i++) {
+ head_->SetNext(i, nullptr);
+ prev_[i] = head_;
+ }
+}
+
+template <typename Key, class Comparator>
+void SkipList<Key, Comparator>::Insert(const Key& key) {
+ // fast path for sequential insertion
+ if (!KeyIsAfterNode(key, prev_[0]->NoBarrier_Next(0)) &&
+ (prev_[0] == head_ || KeyIsAfterNode(key, prev_[0]))) {
+ assert(prev_[0] != head_ || (prev_height_ == 1 && GetMaxHeight() == 1));
+
+ // Outside of this method prev_[1..max_height_] is the predecessor
+ // of prev_[0], and prev_height_ refers to prev_[0]. Inside Insert
+ // prev_[0..max_height - 1] is the predecessor of key. Switch from
+ // the external state to the internal
+ for (int i = 1; i < prev_height_; i++) {
+ prev_[i] = prev_[0];
+ }
+ } else {
+ // TODO(opt): we could use a NoBarrier predecessor search as an
+ // optimization for architectures where memory_order_acquire needs
+ // a synchronization instruction. Doesn't matter on x86
+ FindLessThan(key, prev_);
+ }
+
+ // Our data structure does not allow duplicate insertion
+ assert(prev_[0]->Next(0) == nullptr || !Equal(key, prev_[0]->Next(0)->key));
+
+ int height = RandomHeight();
+ if (height > GetMaxHeight()) {
+ for (int i = GetMaxHeight(); i < height; i++) {
+ prev_[i] = head_;
+ }
+ // fprintf(stderr, "Change height from %d to %d\n", max_height_, height);
+
+ // It is ok to mutate max_height_ without any synchronization
+ // with concurrent readers. A concurrent reader that observes
+ // the new value of max_height_ will see either the old value of
+ // new level pointers from head_ (nullptr), or a new value set in
+ // the loop below. In the former case the reader will
+ // immediately drop to the next level since nullptr sorts after all
+ // keys. In the latter case the reader will use the new node.
+ max_height_.store(height, std::memory_order_relaxed);
+ }
+
+ Node* x = NewNode(key, height);
+ for (int i = 0; i < height; i++) {
+ // NoBarrier_SetNext() suffices since we will add a barrier when
+ // we publish a pointer to "x" in prev[i].
+ x->NoBarrier_SetNext(i, prev_[i]->NoBarrier_Next(i));
+ prev_[i]->SetNext(i, x);
+ }
+ prev_[0] = x;
+ prev_height_ = height;
+}
+
+template <typename Key, class Comparator>
+bool SkipList<Key, Comparator>::Contains(const Key& key) const {
+ Node* x = FindGreaterOrEqual(key);
+ if (x != nullptr && Equal(key, x->key)) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/skiplist_test.cc b/src/rocksdb/memtable/skiplist_test.cc
new file mode 100644
index 000000000..a07088511
--- /dev/null
+++ b/src/rocksdb/memtable/skiplist_test.cc
@@ -0,0 +1,387 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "memtable/skiplist.h"
+
+#include <set>
+
+#include "memory/arena.h"
+#include "rocksdb/env.h"
+#include "test_util/testharness.h"
+#include "util/hash.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+using Key = uint64_t;
+
+struct TestComparator {
+ int operator()(const Key& a, const Key& b) const {
+ if (a < b) {
+ return -1;
+ } else if (a > b) {
+ return +1;
+ } else {
+ return 0;
+ }
+ }
+};
+
+class SkipTest : public testing::Test {};
+
+TEST_F(SkipTest, Empty) {
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ ASSERT_TRUE(!list.Contains(10));
+
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToFirst();
+ ASSERT_TRUE(!iter.Valid());
+ iter.Seek(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekForPrev(100);
+ ASSERT_TRUE(!iter.Valid());
+ iter.SeekToLast();
+ ASSERT_TRUE(!iter.Valid());
+}
+
+TEST_F(SkipTest, InsertAndLookup) {
+ const int N = 2000;
+ const int R = 5000;
+ Random rnd(1000);
+ std::set<Key> keys;
+ Arena arena;
+ TestComparator cmp;
+ SkipList<Key, TestComparator> list(cmp, &arena);
+ for (int i = 0; i < N; i++) {
+ Key key = rnd.Next() % R;
+ if (keys.insert(key).second) {
+ list.Insert(key);
+ }
+ }
+
+ for (int i = 0; i < R; i++) {
+ if (list.Contains(i)) {
+ ASSERT_EQ(keys.count(i), 1U);
+ } else {
+ ASSERT_EQ(keys.count(i), 0U);
+ }
+ }
+
+ // Simple iterator tests
+ {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ ASSERT_TRUE(!iter.Valid());
+
+ iter.Seek(0);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekForPrev(R - 1);
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+
+ iter.SeekToFirst();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.begin()), iter.key());
+
+ iter.SeekToLast();
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*(keys.rbegin()), iter.key());
+ }
+
+ // Forward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.Seek(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.lower_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.end()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*model_iter, iter.key());
+ ++model_iter;
+ iter.Next();
+ }
+ }
+ }
+
+ // Backward iteration test
+ for (int i = 0; i < R; i++) {
+ SkipList<Key, TestComparator>::Iterator iter(&list);
+ iter.SeekForPrev(i);
+
+ // Compare against model iterator
+ std::set<Key>::iterator model_iter = keys.upper_bound(i);
+ for (int j = 0; j < 3; j++) {
+ if (model_iter == keys.begin()) {
+ ASSERT_TRUE(!iter.Valid());
+ break;
+ } else {
+ ASSERT_TRUE(iter.Valid());
+ ASSERT_EQ(*--model_iter, iter.key());
+ iter.Prev();
+ }
+ }
+ }
+}
+
+// We want to make sure that with a single writer and multiple
+// concurrent readers (with no synchronization other than when a
+// reader's iterator is created), the reader always observes all the
+// data that was present in the skip list when the iterator was
+// constructor. Because insertions are happening concurrently, we may
+// also observe new values that were inserted since the iterator was
+// constructed, but we should never miss any values that were present
+// at iterator construction time.
+//
+// We generate multi-part keys:
+// <key,gen,hash>
+// where:
+// key is in range [0..K-1]
+// gen is a generation number for key
+// hash is hash(key,gen)
+//
+// The insertion code picks a random key, sets gen to be 1 + the last
+// generation number inserted for that key, and sets hash to Hash(key,gen).
+//
+// At the beginning of a read, we snapshot the last inserted
+// generation number for each key. We then iterate, including random
+// calls to Next() and Seek(). For every key we encounter, we
+// check that it is either expected given the initial snapshot or has
+// been concurrently added since the iterator started.
+class ConcurrentTest {
+ private:
+ static const uint32_t K = 4;
+
+ static uint64_t key(Key key) { return (key >> 40); }
+ static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; }
+ static uint64_t hash(Key key) { return key & 0xff; }
+
+ static uint64_t HashNumbers(uint64_t k, uint64_t g) {
+ uint64_t data[2] = {k, g};
+ return Hash(reinterpret_cast<char*>(data), sizeof(data), 0);
+ }
+
+ static Key MakeKey(uint64_t k, uint64_t g) {
+ assert(sizeof(Key) == sizeof(uint64_t));
+ assert(k <= K); // We sometimes pass K to seek to the end of the skiplist
+ assert(g <= 0xffffffffu);
+ return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff));
+ }
+
+ static bool IsValidKey(Key k) {
+ return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff);
+ }
+
+ static Key RandomTarget(Random* rnd) {
+ switch (rnd->Next() % 10) {
+ case 0:
+ // Seek to beginning
+ return MakeKey(0, 0);
+ case 1:
+ // Seek to end
+ return MakeKey(K, 0);
+ default:
+ // Seek to middle
+ return MakeKey(rnd->Next() % K, 0);
+ }
+ }
+
+ // Per-key generation
+ struct State {
+ std::atomic<int> generation[K];
+ void Set(int k, int v) {
+ generation[k].store(v, std::memory_order_release);
+ }
+ int Get(int k) { return generation[k].load(std::memory_order_acquire); }
+
+ State() {
+ for (unsigned int k = 0; k < K; k++) {
+ Set(k, 0);
+ }
+ }
+ };
+
+ // Current state of the test
+ State current_;
+
+ Arena arena_;
+
+ // SkipList is not protected by mu_. We just use a single writer
+ // thread to modify it.
+ SkipList<Key, TestComparator> list_;
+
+ public:
+ ConcurrentTest() : list_(TestComparator(), &arena_) {}
+
+ // REQUIRES: External synchronization
+ void WriteStep(Random* rnd) {
+ const uint32_t k = rnd->Next() % K;
+ const int g = current_.Get(k) + 1;
+ const Key new_key = MakeKey(k, g);
+ list_.Insert(new_key);
+ current_.Set(k, g);
+ }
+
+ void ReadStep(Random* rnd) {
+ // Remember the initial committed state of the skiplist.
+ State initial_state;
+ for (unsigned int k = 0; k < K; k++) {
+ initial_state.Set(k, current_.Get(k));
+ }
+
+ Key pos = RandomTarget(rnd);
+ SkipList<Key, TestComparator>::Iterator iter(&list_);
+ iter.Seek(pos);
+ while (true) {
+ Key current;
+ if (!iter.Valid()) {
+ current = MakeKey(K, 0);
+ } else {
+ current = iter.key();
+ ASSERT_TRUE(IsValidKey(current)) << current;
+ }
+ ASSERT_LE(pos, current) << "should not go backwards";
+
+ // Verify that everything in [pos,current) was not present in
+ // initial_state.
+ while (pos < current) {
+ ASSERT_LT(key(pos), K) << pos;
+
+ // Note that generation 0 is never inserted, so it is ok if
+ // <*,0,*> is missing.
+ ASSERT_TRUE((gen(pos) == 0U) ||
+ (gen(pos) > static_cast<uint64_t>(initial_state.Get(
+ static_cast<int>(key(pos))))))
+ << "key: " << key(pos) << "; gen: " << gen(pos)
+ << "; initgen: " << initial_state.Get(static_cast<int>(key(pos)));
+
+ // Advance to next key in the valid key space
+ if (key(pos) < key(current)) {
+ pos = MakeKey(key(pos) + 1, 0);
+ } else {
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ }
+ }
+
+ if (!iter.Valid()) {
+ break;
+ }
+
+ if (rnd->Next() % 2) {
+ iter.Next();
+ pos = MakeKey(key(pos), gen(pos) + 1);
+ } else {
+ Key new_target = RandomTarget(rnd);
+ if (new_target > pos) {
+ pos = new_target;
+ iter.Seek(new_target);
+ }
+ }
+ }
+ }
+};
+const uint32_t ConcurrentTest::K;
+
+// Simple test that does single-threaded testing of the ConcurrentTest
+// scaffolding.
+TEST_F(SkipTest, ConcurrentWithoutThreads) {
+ ConcurrentTest test;
+ Random rnd(test::RandomSeed());
+ for (int i = 0; i < 10000; i++) {
+ test.ReadStep(&rnd);
+ test.WriteStep(&rnd);
+ }
+}
+
+class TestState {
+ public:
+ ConcurrentTest t_;
+ int seed_;
+ std::atomic<bool> quit_flag_;
+
+ enum ReaderState { STARTING, RUNNING, DONE };
+
+ explicit TestState(int s)
+ : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {}
+
+ void Wait(ReaderState s) {
+ mu_.Lock();
+ while (state_ != s) {
+ state_cv_.Wait();
+ }
+ mu_.Unlock();
+ }
+
+ void Change(ReaderState s) {
+ mu_.Lock();
+ state_ = s;
+ state_cv_.Signal();
+ mu_.Unlock();
+ }
+
+ private:
+ port::Mutex mu_;
+ ReaderState state_;
+ port::CondVar state_cv_;
+};
+
+static void ConcurrentReader(void* arg) {
+ TestState* state = reinterpret_cast<TestState*>(arg);
+ Random rnd(state->seed_);
+ int64_t reads = 0;
+ state->Change(TestState::RUNNING);
+ while (!state->quit_flag_.load(std::memory_order_acquire)) {
+ state->t_.ReadStep(&rnd);
+ ++reads;
+ }
+ state->Change(TestState::DONE);
+}
+
+static void RunConcurrent(int run) {
+ const int seed = test::RandomSeed() + (run * 100);
+ Random rnd(seed);
+ const int N = 1000;
+ const int kSize = 1000;
+ for (int i = 0; i < N; i++) {
+ if ((i % 100) == 0) {
+ fprintf(stderr, "Run %d of %d\n", i, N);
+ }
+ TestState state(seed + 1);
+ Env::Default()->SetBackgroundThreads(1);
+ Env::Default()->Schedule(ConcurrentReader, &state);
+ state.Wait(TestState::RUNNING);
+ for (int k = 0; k < kSize; k++) {
+ state.t_.WriteStep(&rnd);
+ }
+ state.quit_flag_.store(true, std::memory_order_release);
+ state.Wait(TestState::DONE);
+ }
+}
+
+TEST_F(SkipTest, Concurrent1) { RunConcurrent(1); }
+TEST_F(SkipTest, Concurrent2) { RunConcurrent(2); }
+TEST_F(SkipTest, Concurrent3) { RunConcurrent(3); }
+TEST_F(SkipTest, Concurrent4) { RunConcurrent(4); }
+TEST_F(SkipTest, Concurrent5) { RunConcurrent(5); }
+
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/memtable/skiplistrep.cc b/src/rocksdb/memtable/skiplistrep.cc
new file mode 100644
index 000000000..40f13a2c1
--- /dev/null
+++ b/src/rocksdb/memtable/skiplistrep.cc
@@ -0,0 +1,370 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#include <random>
+
+#include "db/memtable.h"
+#include "memory/arena.h"
+#include "memtable/inlineskiplist.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/string_util.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+class SkipListRep : public MemTableRep {
+ InlineSkipList<const MemTableRep::KeyComparator&> skip_list_;
+ const MemTableRep::KeyComparator& cmp_;
+ const SliceTransform* transform_;
+ const size_t lookahead_;
+
+ friend class LookaheadIterator;
+
+ public:
+ explicit SkipListRep(const MemTableRep::KeyComparator& compare,
+ Allocator* allocator, const SliceTransform* transform,
+ const size_t lookahead)
+ : MemTableRep(allocator),
+ skip_list_(compare, allocator),
+ cmp_(compare),
+ transform_(transform),
+ lookahead_(lookahead) {}
+
+ KeyHandle Allocate(const size_t len, char** buf) override {
+ *buf = skip_list_.AllocateKey(len);
+ return static_cast<KeyHandle>(*buf);
+ }
+
+ // Insert key into the list.
+ // REQUIRES: nothing that compares equal to key is currently in the list.
+ void Insert(KeyHandle handle) override {
+ skip_list_.Insert(static_cast<char*>(handle));
+ }
+
+ bool InsertKey(KeyHandle handle) override {
+ return skip_list_.Insert(static_cast<char*>(handle));
+ }
+
+ void InsertWithHint(KeyHandle handle, void** hint) override {
+ skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
+ }
+
+ bool InsertKeyWithHint(KeyHandle handle, void** hint) override {
+ return skip_list_.InsertWithHint(static_cast<char*>(handle), hint);
+ }
+
+ void InsertWithHintConcurrently(KeyHandle handle, void** hint) override {
+ skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle), hint);
+ }
+
+ bool InsertKeyWithHintConcurrently(KeyHandle handle, void** hint) override {
+ return skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle),
+ hint);
+ }
+
+ void InsertConcurrently(KeyHandle handle) override {
+ skip_list_.InsertConcurrently(static_cast<char*>(handle));
+ }
+
+ bool InsertKeyConcurrently(KeyHandle handle) override {
+ return skip_list_.InsertConcurrently(static_cast<char*>(handle));
+ }
+
+ // Returns true iff an entry that compares equal to key is in the list.
+ bool Contains(const char* key) const override {
+ return skip_list_.Contains(key);
+ }
+
+ size_t ApproximateMemoryUsage() override {
+ // All memory is allocated through allocator; nothing to report here
+ return 0;
+ }
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override {
+ SkipListRep::Iterator iter(&skip_list_);
+ Slice dummy_slice;
+ for (iter.Seek(dummy_slice, k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key());
+ iter.Next()) {
+ }
+ }
+
+ uint64_t ApproximateNumEntries(const Slice& start_ikey,
+ const Slice& end_ikey) override {
+ std::string tmp;
+ uint64_t start_count =
+ skip_list_.EstimateCount(EncodeKey(&tmp, start_ikey));
+ uint64_t end_count = skip_list_.EstimateCount(EncodeKey(&tmp, end_ikey));
+ return (end_count >= start_count) ? (end_count - start_count) : 0;
+ }
+
+ void UniqueRandomSample(const uint64_t num_entries,
+ const uint64_t target_sample_size,
+ std::unordered_set<const char*>* entries) override {
+ entries->clear();
+ // Avoid divide-by-0.
+ assert(target_sample_size > 0);
+ assert(num_entries > 0);
+ // NOTE: the size of entries is not enforced to be exactly
+ // target_sample_size at the end of this function, it might be slightly
+ // greater or smaller.
+ SkipListRep::Iterator iter(&skip_list_);
+ // There are two methods to create the subset of samples (size m)
+ // from the table containing N elements:
+ // 1-Iterate linearly through the N memtable entries. For each entry i,
+ // add it to the sample set with a probability
+ // (target_sample_size - entries.size() ) / (N-i).
+ //
+ // 2-Pick m random elements without repetition.
+ // We pick Option 2 when m<sqrt(N) and
+ // Option 1 when m > sqrt(N).
+ if (target_sample_size >
+ static_cast<uint64_t>(std::sqrt(1.0 * num_entries))) {
+ Random* rnd = Random::GetTLSInstance();
+ iter.SeekToFirst();
+ uint64_t counter = 0, num_samples_left = target_sample_size;
+ for (; iter.Valid() && (num_samples_left > 0); iter.Next(), counter++) {
+ // Add entry to sample set with probability
+ // num_samples_left/(num_entries - counter).
+ if (rnd->Next() % (num_entries - counter) < num_samples_left) {
+ entries->insert(iter.key());
+ num_samples_left--;
+ }
+ }
+ } else {
+ // Option 2: pick m random elements with no duplicates.
+ // If Option 2 is picked, then target_sample_size<sqrt(N)
+ // Using a set spares the need to check for duplicates.
+ for (uint64_t i = 0; i < target_sample_size; i++) {
+ // We give it 5 attempts to find a non-duplicate
+ // With 5 attempts, the chances of returning `entries` set
+ // of size target_sample_size is:
+ // PROD_{i=1}^{target_sample_size-1} [1-(i/N)^5]
+ // which is monotonically increasing with N in the worse case
+ // of target_sample_size=sqrt(N), and is always >99.9% for N>4.
+ // At worst, for the final pick , when m=sqrt(N) there is
+ // a probability of p= 1/sqrt(N) chances to find a duplicate.
+ for (uint64_t j = 0; j < 5; j++) {
+ iter.RandomSeek();
+ // unordered_set::insert returns pair<iterator, bool>.
+ // The second element is true if an insert successfully happened.
+ // If element is already in the set, this bool will be false, and
+ // true otherwise.
+ if ((entries->insert(iter.key())).second) {
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ ~SkipListRep() override {}
+
+ // Iteration over the contents of a skip list
+ class Iterator : public MemTableRep::Iterator {
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
+
+ public:
+ // Initialize an iterator over the specified list.
+ // The returned iterator is not valid.
+ explicit Iterator(
+ const InlineSkipList<const MemTableRep::KeyComparator&>* list)
+ : iter_(list) {}
+
+ ~Iterator() override {}
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override { return iter_.Valid(); }
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override { return iter_.key(); }
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override { iter_.Next(); }
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override { iter_.Prev(); }
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& user_key, const char* memtable_key) override {
+ if (memtable_key != nullptr) {
+ iter_.Seek(memtable_key);
+ } else {
+ iter_.Seek(EncodeKey(&tmp_, user_key));
+ }
+ }
+
+ // Retreat to the last entry with a key <= target
+ void SeekForPrev(const Slice& user_key, const char* memtable_key) override {
+ if (memtable_key != nullptr) {
+ iter_.SeekForPrev(memtable_key);
+ } else {
+ iter_.SeekForPrev(EncodeKey(&tmp_, user_key));
+ }
+ }
+
+ void RandomSeek() override { iter_.RandomSeek(); }
+
+ // Position at the first entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToFirst() override { iter_.SeekToFirst(); }
+
+ // Position at the last entry in list.
+ // Final state of iterator is Valid() iff list is not empty.
+ void SeekToLast() override { iter_.SeekToLast(); }
+
+ protected:
+ std::string tmp_; // For passing to EncodeKey
+ };
+
+ // Iterator over the contents of a skip list which also keeps track of the
+ // previously visited node. In Seek(), it examines a few nodes after it
+ // first, falling back to O(log n) search from the head of the list only if
+ // the target key hasn't been found.
+ class LookaheadIterator : public MemTableRep::Iterator {
+ public:
+ explicit LookaheadIterator(const SkipListRep& rep)
+ : rep_(rep), iter_(&rep_.skip_list_), prev_(iter_) {}
+
+ ~LookaheadIterator() override {}
+
+ bool Valid() const override { return iter_.Valid(); }
+
+ const char* key() const override {
+ assert(Valid());
+ return iter_.key();
+ }
+
+ void Next() override {
+ assert(Valid());
+
+ bool advance_prev = true;
+ if (prev_.Valid()) {
+ auto k1 = rep_.UserKey(prev_.key());
+ auto k2 = rep_.UserKey(iter_.key());
+
+ if (k1.compare(k2) == 0) {
+ // same user key, don't move prev_
+ advance_prev = false;
+ } else if (rep_.transform_) {
+ // only advance prev_ if it has the same prefix as iter_
+ auto t1 = rep_.transform_->Transform(k1);
+ auto t2 = rep_.transform_->Transform(k2);
+ advance_prev = t1.compare(t2) == 0;
+ }
+ }
+
+ if (advance_prev) {
+ prev_ = iter_;
+ }
+ iter_.Next();
+ }
+
+ void Prev() override {
+ assert(Valid());
+ iter_.Prev();
+ prev_ = iter_;
+ }
+
+ void Seek(const Slice& internal_key, const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+
+ if (prev_.Valid() && rep_.cmp_(encoded_key, prev_.key()) >= 0) {
+ // prev_.key() is smaller or equal to our target key; do a quick
+ // linear search (at most lookahead_ steps) starting from prev_
+ iter_ = prev_;
+
+ size_t cur = 0;
+ while (cur++ <= rep_.lookahead_ && iter_.Valid()) {
+ if (rep_.cmp_(encoded_key, iter_.key()) <= 0) {
+ return;
+ }
+ Next();
+ }
+ }
+
+ iter_.Seek(encoded_key);
+ prev_ = iter_;
+ }
+
+ void SeekForPrev(const Slice& internal_key,
+ const char* memtable_key) override {
+ const char* encoded_key = (memtable_key != nullptr)
+ ? memtable_key
+ : EncodeKey(&tmp_, internal_key);
+ iter_.SeekForPrev(encoded_key);
+ prev_ = iter_;
+ }
+
+ void SeekToFirst() override {
+ iter_.SeekToFirst();
+ prev_ = iter_;
+ }
+
+ void SeekToLast() override {
+ iter_.SeekToLast();
+ prev_ = iter_;
+ }
+
+ protected:
+ std::string tmp_; // For passing to EncodeKey
+
+ private:
+ const SkipListRep& rep_;
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_;
+ InlineSkipList<const MemTableRep::KeyComparator&>::Iterator prev_;
+ };
+
+ MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override {
+ if (lookahead_ > 0) {
+ void* mem =
+ arena ? arena->AllocateAligned(sizeof(SkipListRep::LookaheadIterator))
+ :
+ operator new(sizeof(SkipListRep::LookaheadIterator));
+ return new (mem) SkipListRep::LookaheadIterator(*this);
+ } else {
+ void* mem = arena ? arena->AllocateAligned(sizeof(SkipListRep::Iterator))
+ :
+ operator new(sizeof(SkipListRep::Iterator));
+ return new (mem) SkipListRep::Iterator(&skip_list_);
+ }
+ }
+};
+} // namespace
+
+static std::unordered_map<std::string, OptionTypeInfo> skiplist_factory_info = {
+#ifndef ROCKSDB_LITE
+ {"lookahead",
+ {0, OptionType::kSizeT, OptionVerificationType::kNormal,
+ OptionTypeFlags::kDontSerialize /*Since it is part of the ID*/}},
+#endif
+};
+
+SkipListFactory::SkipListFactory(size_t lookahead) : lookahead_(lookahead) {
+ RegisterOptions("SkipListFactoryOptions", &lookahead_,
+ &skiplist_factory_info);
+}
+
+std::string SkipListFactory::GetId() const {
+ std::string id = Name();
+ if (lookahead_ > 0) {
+ id.append(":").append(std::to_string(lookahead_));
+ }
+ return id;
+}
+
+MemTableRep* SkipListFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform* transform, Logger* /*logger*/) {
+ return new SkipListRep(compare, allocator, transform, lookahead_);
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/stl_wrappers.h b/src/rocksdb/memtable/stl_wrappers.h
new file mode 100644
index 000000000..783a8088d
--- /dev/null
+++ b/src/rocksdb/memtable/stl_wrappers.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+#pragma once
+
+#include <map>
+#include <string>
+
+#include "rocksdb/comparator.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/slice.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace stl_wrappers {
+
+class Base {
+ protected:
+ const MemTableRep::KeyComparator& compare_;
+ explicit Base(const MemTableRep::KeyComparator& compare)
+ : compare_(compare) {}
+};
+
+struct Compare : private Base {
+ explicit Compare(const MemTableRep::KeyComparator& compare) : Base(compare) {}
+ inline bool operator()(const char* a, const char* b) const {
+ return compare_(a, b) < 0;
+ }
+};
+
+} // namespace stl_wrappers
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/vectorrep.cc b/src/rocksdb/memtable/vectorrep.cc
new file mode 100644
index 000000000..293163349
--- /dev/null
+++ b/src/rocksdb/memtable/vectorrep.cc
@@ -0,0 +1,309 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+#ifndef ROCKSDB_LITE
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <type_traits>
+#include <unordered_set>
+
+#include "db/memtable.h"
+#include "memory/arena.h"
+#include "memtable/stl_wrappers.h"
+#include "port/port.h"
+#include "rocksdb/memtablerep.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/mutexlock.h"
+
+namespace ROCKSDB_NAMESPACE {
+namespace {
+
+class VectorRep : public MemTableRep {
+ public:
+ VectorRep(const KeyComparator& compare, Allocator* allocator, size_t count);
+
+ // Insert key into the collection. (The caller will pack key and value into a
+ // single buffer and pass that in as the parameter to Insert)
+ // REQUIRES: nothing that compares equal to key is currently in the
+ // collection.
+ void Insert(KeyHandle handle) override;
+
+ // Returns true iff an entry that compares equal to key is in the collection.
+ bool Contains(const char* key) const override;
+
+ void MarkReadOnly() override;
+
+ size_t ApproximateMemoryUsage() override;
+
+ void Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) override;
+
+ ~VectorRep() override {}
+
+ class Iterator : public MemTableRep::Iterator {
+ class VectorRep* vrep_;
+ std::shared_ptr<std::vector<const char*>> bucket_;
+ std::vector<const char*>::const_iterator mutable cit_;
+ const KeyComparator& compare_;
+ std::string tmp_; // For passing to EncodeKey
+ bool mutable sorted_;
+ void DoSort() const;
+
+ public:
+ explicit Iterator(class VectorRep* vrep,
+ std::shared_ptr<std::vector<const char*>> bucket,
+ const KeyComparator& compare);
+
+ // Initialize an iterator over the specified collection.
+ // The returned iterator is not valid.
+ // explicit Iterator(const MemTableRep* collection);
+ ~Iterator() override{};
+
+ // Returns true iff the iterator is positioned at a valid node.
+ bool Valid() const override;
+
+ // Returns the key at the current position.
+ // REQUIRES: Valid()
+ const char* key() const override;
+
+ // Advances to the next position.
+ // REQUIRES: Valid()
+ void Next() override;
+
+ // Advances to the previous position.
+ // REQUIRES: Valid()
+ void Prev() override;
+
+ // Advance to the first entry with a key >= target
+ void Seek(const Slice& user_key, const char* memtable_key) override;
+
+ // Advance to the first entry with a key <= target
+ void SeekForPrev(const Slice& user_key, const char* memtable_key) override;
+
+ // Position at the first entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToFirst() override;
+
+ // Position at the last entry in collection.
+ // Final state of iterator is Valid() iff collection is not empty.
+ void SeekToLast() override;
+ };
+
+ // Return an iterator over the keys in this representation.
+ MemTableRep::Iterator* GetIterator(Arena* arena) override;
+
+ private:
+ friend class Iterator;
+ using Bucket = std::vector<const char*>;
+ std::shared_ptr<Bucket> bucket_;
+ mutable port::RWMutex rwlock_;
+ bool immutable_;
+ bool sorted_;
+ const KeyComparator& compare_;
+};
+
+void VectorRep::Insert(KeyHandle handle) {
+ auto* key = static_cast<char*>(handle);
+ WriteLock l(&rwlock_);
+ assert(!immutable_);
+ bucket_->push_back(key);
+}
+
+// Returns true iff an entry that compares equal to key is in the collection.
+bool VectorRep::Contains(const char* key) const {
+ ReadLock l(&rwlock_);
+ return std::find(bucket_->begin(), bucket_->end(), key) != bucket_->end();
+}
+
+void VectorRep::MarkReadOnly() {
+ WriteLock l(&rwlock_);
+ immutable_ = true;
+}
+
+size_t VectorRep::ApproximateMemoryUsage() {
+ return sizeof(bucket_) + sizeof(*bucket_) +
+ bucket_->size() *
+ sizeof(
+ std::remove_reference<decltype(*bucket_)>::type::value_type);
+}
+
+VectorRep::VectorRep(const KeyComparator& compare, Allocator* allocator,
+ size_t count)
+ : MemTableRep(allocator),
+ bucket_(new Bucket()),
+ immutable_(false),
+ sorted_(false),
+ compare_(compare) {
+ bucket_.get()->reserve(count);
+}
+
+VectorRep::Iterator::Iterator(class VectorRep* vrep,
+ std::shared_ptr<std::vector<const char*>> bucket,
+ const KeyComparator& compare)
+ : vrep_(vrep),
+ bucket_(bucket),
+ cit_(bucket_->end()),
+ compare_(compare),
+ sorted_(false) {}
+
+void VectorRep::Iterator::DoSort() const {
+ // vrep is non-null means that we are working on an immutable memtable
+ if (!sorted_ && vrep_ != nullptr) {
+ WriteLock l(&vrep_->rwlock_);
+ if (!vrep_->sorted_) {
+ std::sort(bucket_->begin(), bucket_->end(),
+ stl_wrappers::Compare(compare_));
+ cit_ = bucket_->begin();
+ vrep_->sorted_ = true;
+ }
+ sorted_ = true;
+ }
+ if (!sorted_) {
+ std::sort(bucket_->begin(), bucket_->end(),
+ stl_wrappers::Compare(compare_));
+ cit_ = bucket_->begin();
+ sorted_ = true;
+ }
+ assert(sorted_);
+ assert(vrep_ == nullptr || vrep_->sorted_);
+}
+
+// Returns true iff the iterator is positioned at a valid node.
+bool VectorRep::Iterator::Valid() const {
+ DoSort();
+ return cit_ != bucket_->end();
+}
+
+// Returns the key at the current position.
+// REQUIRES: Valid()
+const char* VectorRep::Iterator::key() const {
+ assert(sorted_);
+ return *cit_;
+}
+
+// Advances to the next position.
+// REQUIRES: Valid()
+void VectorRep::Iterator::Next() {
+ assert(sorted_);
+ if (cit_ == bucket_->end()) {
+ return;
+ }
+ ++cit_;
+}
+
+// Advances to the previous position.
+// REQUIRES: Valid()
+void VectorRep::Iterator::Prev() {
+ assert(sorted_);
+ if (cit_ == bucket_->begin()) {
+ // If you try to go back from the first element, the iterator should be
+ // invalidated. So we set it to past-the-end. This means that you can
+ // treat the container circularly.
+ cit_ = bucket_->end();
+ } else {
+ --cit_;
+ }
+}
+
+// Advance to the first entry with a key >= target
+void VectorRep::Iterator::Seek(const Slice& user_key,
+ const char* memtable_key) {
+ DoSort();
+ // Do binary search to find first value not less than the target
+ const char* encoded_key =
+ (memtable_key != nullptr) ? memtable_key : EncodeKey(&tmp_, user_key);
+ cit_ = std::equal_range(bucket_->begin(), bucket_->end(), encoded_key,
+ [this](const char* a, const char* b) {
+ return compare_(a, b) < 0;
+ })
+ .first;
+}
+
+// Advance to the first entry with a key <= target
+void VectorRep::Iterator::SeekForPrev(const Slice& /*user_key*/,
+ const char* /*memtable_key*/) {
+ assert(false);
+}
+
+// Position at the first entry in collection.
+// Final state of iterator is Valid() iff collection is not empty.
+void VectorRep::Iterator::SeekToFirst() {
+ DoSort();
+ cit_ = bucket_->begin();
+}
+
+// Position at the last entry in collection.
+// Final state of iterator is Valid() iff collection is not empty.
+void VectorRep::Iterator::SeekToLast() {
+ DoSort();
+ cit_ = bucket_->end();
+ if (bucket_->size() != 0) {
+ --cit_;
+ }
+}
+
+void VectorRep::Get(const LookupKey& k, void* callback_args,
+ bool (*callback_func)(void* arg, const char* entry)) {
+ rwlock_.ReadLock();
+ VectorRep* vector_rep;
+ std::shared_ptr<Bucket> bucket;
+ if (immutable_) {
+ vector_rep = this;
+ } else {
+ vector_rep = nullptr;
+ bucket.reset(new Bucket(*bucket_)); // make a copy
+ }
+ VectorRep::Iterator iter(vector_rep, immutable_ ? bucket_ : bucket, compare_);
+ rwlock_.ReadUnlock();
+
+ for (iter.Seek(k.user_key(), k.memtable_key().data());
+ iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) {
+ }
+}
+
+MemTableRep::Iterator* VectorRep::GetIterator(Arena* arena) {
+ char* mem = nullptr;
+ if (arena != nullptr) {
+ mem = arena->AllocateAligned(sizeof(Iterator));
+ }
+ ReadLock l(&rwlock_);
+ // Do not sort here. The sorting would be done the first time
+ // a Seek is performed on the iterator.
+ if (immutable_) {
+ if (arena == nullptr) {
+ return new Iterator(this, bucket_, compare_);
+ } else {
+ return new (mem) Iterator(this, bucket_, compare_);
+ }
+ } else {
+ std::shared_ptr<Bucket> tmp;
+ tmp.reset(new Bucket(*bucket_)); // make a copy
+ if (arena == nullptr) {
+ return new Iterator(nullptr, tmp, compare_);
+ } else {
+ return new (mem) Iterator(nullptr, tmp, compare_);
+ }
+ }
+}
+} // namespace
+
+static std::unordered_map<std::string, OptionTypeInfo> vector_rep_table_info = {
+ {"count",
+ {0, OptionType::kSizeT, OptionVerificationType::kNormal,
+ OptionTypeFlags::kNone}},
+};
+
+VectorRepFactory::VectorRepFactory(size_t count) : count_(count) {
+ RegisterOptions("VectorRepFactoryOptions", &count_, &vector_rep_table_info);
+}
+
+MemTableRep* VectorRepFactory::CreateMemTableRep(
+ const MemTableRep::KeyComparator& compare, Allocator* allocator,
+ const SliceTransform*, Logger* /*logger*/) {
+ return new VectorRep(compare, allocator, count_);
+}
+} // namespace ROCKSDB_NAMESPACE
+#endif // ROCKSDB_LITE
diff --git a/src/rocksdb/memtable/write_buffer_manager.cc b/src/rocksdb/memtable/write_buffer_manager.cc
new file mode 100644
index 000000000..8db9816be
--- /dev/null
+++ b/src/rocksdb/memtable/write_buffer_manager.cc
@@ -0,0 +1,202 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/write_buffer_manager.h"
+
+#include <memory>
+
+#include "cache/cache_entry_roles.h"
+#include "cache/cache_reservation_manager.h"
+#include "db/db_impl/db_impl.h"
+#include "rocksdb/status.h"
+#include "util/coding.h"
+
+namespace ROCKSDB_NAMESPACE {
+WriteBufferManager::WriteBufferManager(size_t _buffer_size,
+ std::shared_ptr<Cache> cache,
+ bool allow_stall)
+ : buffer_size_(_buffer_size),
+ mutable_limit_(buffer_size_ * 7 / 8),
+ memory_used_(0),
+ memory_active_(0),
+ cache_res_mgr_(nullptr),
+ allow_stall_(allow_stall),
+ stall_active_(false) {
+#ifndef ROCKSDB_LITE
+ if (cache) {
+ // Memtable's memory usage tends to fluctuate frequently
+ // therefore we set delayed_decrease = true to save some dummy entry
+ // insertion on memory increase right after memory decrease
+ cache_res_mgr_ = std::make_shared<
+ CacheReservationManagerImpl<CacheEntryRole::kWriteBuffer>>(
+ cache, true /* delayed_decrease */);
+ }
+#else
+ (void)cache;
+#endif // ROCKSDB_LITE
+}
+
+WriteBufferManager::~WriteBufferManager() {
+#ifndef NDEBUG
+ std::unique_lock<std::mutex> lock(mu_);
+ assert(queue_.empty());
+#endif
+}
+
+std::size_t WriteBufferManager::dummy_entries_in_cache_usage() const {
+ if (cache_res_mgr_ != nullptr) {
+ return cache_res_mgr_->GetTotalReservedCacheSize();
+ } else {
+ return 0;
+ }
+}
+
+void WriteBufferManager::ReserveMem(size_t mem) {
+ if (cache_res_mgr_ != nullptr) {
+ ReserveMemWithCache(mem);
+ } else if (enabled()) {
+ memory_used_.fetch_add(mem, std::memory_order_relaxed);
+ }
+ if (enabled()) {
+ memory_active_.fetch_add(mem, std::memory_order_relaxed);
+ }
+}
+
+// Should only be called from write thread
+void WriteBufferManager::ReserveMemWithCache(size_t mem) {
+#ifndef ROCKSDB_LITE
+ assert(cache_res_mgr_ != nullptr);
+ // Use a mutex to protect various data structures. Can be optimized to a
+ // lock-free solution if it ends up with a performance bottleneck.
+ std::lock_guard<std::mutex> lock(cache_res_mgr_mu_);
+
+ size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) + mem;
+ memory_used_.store(new_mem_used, std::memory_order_relaxed);
+ Status s = cache_res_mgr_->UpdateCacheReservation(new_mem_used);
+
+ // We absorb the error since WriteBufferManager is not able to handle
+ // this failure properly. Ideallly we should prevent this allocation
+ // from happening if this cache charging fails.
+ // [TODO] We'll need to improve it in the future and figure out what to do on
+ // error
+ s.PermitUncheckedError();
+#else
+ (void)mem;
+#endif // ROCKSDB_LITE
+}
+
+void WriteBufferManager::ScheduleFreeMem(size_t mem) {
+ if (enabled()) {
+ memory_active_.fetch_sub(mem, std::memory_order_relaxed);
+ }
+}
+
+void WriteBufferManager::FreeMem(size_t mem) {
+ if (cache_res_mgr_ != nullptr) {
+ FreeMemWithCache(mem);
+ } else if (enabled()) {
+ memory_used_.fetch_sub(mem, std::memory_order_relaxed);
+ }
+ // Check if stall is active and can be ended.
+ MaybeEndWriteStall();
+}
+
+void WriteBufferManager::FreeMemWithCache(size_t mem) {
+#ifndef ROCKSDB_LITE
+ assert(cache_res_mgr_ != nullptr);
+ // Use a mutex to protect various data structures. Can be optimized to a
+ // lock-free solution if it ends up with a performance bottleneck.
+ std::lock_guard<std::mutex> lock(cache_res_mgr_mu_);
+ size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) - mem;
+ memory_used_.store(new_mem_used, std::memory_order_relaxed);
+ Status s = cache_res_mgr_->UpdateCacheReservation(new_mem_used);
+
+ // We absorb the error since WriteBufferManager is not able to handle
+ // this failure properly.
+ // [TODO] We'll need to improve it in the future and figure out what to do on
+ // error
+ s.PermitUncheckedError();
+#else
+ (void)mem;
+#endif // ROCKSDB_LITE
+}
+
+void WriteBufferManager::BeginWriteStall(StallInterface* wbm_stall) {
+ assert(wbm_stall != nullptr);
+ assert(allow_stall_);
+
+ // Allocate outside of the lock.
+ std::list<StallInterface*> new_node = {wbm_stall};
+
+ {
+ std::unique_lock<std::mutex> lock(mu_);
+ // Verify if the stall conditions are stil active.
+ if (ShouldStall()) {
+ stall_active_.store(true, std::memory_order_relaxed);
+ queue_.splice(queue_.end(), std::move(new_node));
+ }
+ }
+
+ // If the node was not consumed, the stall has ended already and we can signal
+ // the caller.
+ if (!new_node.empty()) {
+ new_node.front()->Signal();
+ }
+}
+
+// Called when memory is freed in FreeMem or the buffer size has changed.
+void WriteBufferManager::MaybeEndWriteStall() {
+ // Cannot early-exit on !enabled() because SetBufferSize(0) needs to unblock
+ // the writers.
+ if (!allow_stall_) {
+ return;
+ }
+
+ if (IsStallThresholdExceeded()) {
+ return; // Stall conditions have not resolved.
+ }
+
+ // Perform all deallocations outside of the lock.
+ std::list<StallInterface*> cleanup;
+
+ std::unique_lock<std::mutex> lock(mu_);
+ if (!stall_active_.load(std::memory_order_relaxed)) {
+ return; // Nothing to do.
+ }
+
+ // Unblock new writers.
+ stall_active_.store(false, std::memory_order_relaxed);
+
+ // Unblock the writers in the queue.
+ for (StallInterface* wbm_stall : queue_) {
+ wbm_stall->Signal();
+ }
+ cleanup = std::move(queue_);
+}
+
+void WriteBufferManager::RemoveDBFromQueue(StallInterface* wbm_stall) {
+ assert(wbm_stall != nullptr);
+
+ // Deallocate the removed nodes outside of the lock.
+ std::list<StallInterface*> cleanup;
+
+ if (enabled() && allow_stall_) {
+ std::unique_lock<std::mutex> lock(mu_);
+ for (auto it = queue_.begin(); it != queue_.end();) {
+ auto next = std::next(it);
+ if (*it == wbm_stall) {
+ cleanup.splice(cleanup.end(), queue_, std::move(it));
+ }
+ it = next;
+ }
+ }
+ wbm_stall->Signal();
+}
+
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/write_buffer_manager_test.cc b/src/rocksdb/memtable/write_buffer_manager_test.cc
new file mode 100644
index 000000000..1cc4c2cc5
--- /dev/null
+++ b/src/rocksdb/memtable/write_buffer_manager_test.cc
@@ -0,0 +1,305 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "rocksdb/write_buffer_manager.h"
+
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+class WriteBufferManagerTest : public testing::Test {};
+
+#ifndef ROCKSDB_LITE
+const size_t kSizeDummyEntry = 256 * 1024;
+
+TEST_F(WriteBufferManagerTest, ShouldFlush) {
+ // A write buffer manager of size 10MB
+ std::unique_ptr<WriteBufferManager> wbf(
+ new WriteBufferManager(10 * 1024 * 1024));
+
+ wbf->ReserveMem(8 * 1024 * 1024);
+ ASSERT_FALSE(wbf->ShouldFlush());
+ // 90% of the hard limit will hit the condition
+ wbf->ReserveMem(1 * 1024 * 1024);
+ ASSERT_TRUE(wbf->ShouldFlush());
+ // Scheduling for freeing will release the condition
+ wbf->ScheduleFreeMem(1 * 1024 * 1024);
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(2 * 1024 * 1024);
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(4 * 1024 * 1024);
+ // 11MB total, 6MB mutable. hard limit still hit
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(2 * 1024 * 1024);
+ // 11MB total, 4MB mutable. hard limit stills but won't flush because more
+ // than half data is already being flushed.
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(4 * 1024 * 1024);
+ // 15 MB total, 8MB mutable.
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->FreeMem(7 * 1024 * 1024);
+ // 8MB total, 8MB mutable.
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // change size: 8M limit, 7M mutable limit
+ wbf->SetBufferSize(8 * 1024 * 1024);
+ // 8MB total, 8MB mutable.
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(2 * 1024 * 1024);
+ // 8MB total, 6MB mutable.
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->FreeMem(2 * 1024 * 1024);
+ // 6MB total, 6MB mutable.
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(1 * 1024 * 1024);
+ // 7MB total, 7MB mutable.
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ wbf->ReserveMem(1 * 1024 * 1024);
+ // 8MB total, 8MB mutable.
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ wbf->ScheduleFreeMem(1 * 1024 * 1024);
+ wbf->FreeMem(1 * 1024 * 1024);
+ // 7MB total, 7MB mutable.
+ ASSERT_FALSE(wbf->ShouldFlush());
+}
+
+class ChargeWriteBufferTest : public testing::Test {};
+
+TEST_F(ChargeWriteBufferTest, Basic) {
+ constexpr std::size_t kMetaDataChargeOverhead = 10000;
+
+ LRUCacheOptions co;
+ // 1GB cache
+ co.capacity = 1024 * 1024 * 1024;
+ co.num_shard_bits = 4;
+ co.metadata_charge_policy = kDontChargeCacheMetadata;
+ std::shared_ptr<Cache> cache = NewLRUCache(co);
+ // A write buffer manager of size 50MB
+ std::unique_ptr<WriteBufferManager> wbf(
+ new WriteBufferManager(50 * 1024 * 1024, cache));
+
+ // Allocate 333KB will allocate 512KB, memory_used_ = 333KB
+ wbf->ReserveMem(333 * 1024);
+ // 2 dummy entries are added for size 333 KB
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 2 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 2 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 2 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Allocate another 512KB, memory_used_ = 845KB
+ wbf->ReserveMem(512 * 1024);
+ // 2 more dummy entries are added for size 512 KB
+ // since ceil((memory_used_ - dummy_entries_in_cache_usage) % kSizeDummyEntry)
+ // = 2
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Allocate another 10MB, memory_used_ = 11085KB
+ wbf->ReserveMem(10 * 1024 * 1024);
+ // 40 more entries are added for size 10 * 1024 * 1024 KB
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Free 1MB, memory_used_ = 10061KB
+ // It will not cause any change in cache cost
+ // since memory_used_ > dummy_entries_in_cache_usage * (3/4)
+ wbf->FreeMem(1 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // Allocate another 41MB, memory_used_ = 52045KB
+ wbf->ReserveMem(41 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 204 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 204 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 204 * 256 * 1024 + kMetaDataChargeOverhead);
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ // Schedule free 20MB, memory_used_ = 52045KB
+ // It will not cause any change in memory_used and cache cost
+ wbf->ScheduleFreeMem(20 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 204 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 204 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 204 * 256 * 1024 + kMetaDataChargeOverhead);
+ // Still need flush as the hard limit hits
+ ASSERT_TRUE(wbf->ShouldFlush());
+
+ // Free 20MB, memory_used_ = 31565KB
+ // It will releae 80 dummy entries from cache since
+ // since memory_used_ < dummy_entries_in_cache_usage * (3/4)
+ // and floor((dummy_entries_in_cache_usage - memory_used_) % kSizeDummyEntry)
+ // = 80
+ wbf->FreeMem(20 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 124 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 124 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 124 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // Free 16KB, memory_used_ = 31549KB
+ // It will not release any dummy entry since memory_used_ >=
+ // dummy_entries_in_cache_usage * (3/4)
+ wbf->FreeMem(16 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 124 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 124 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 124 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Free 20MB, memory_used_ = 11069KB
+ // It will releae 80 dummy entries from cache
+ // since memory_used_ < dummy_entries_in_cache_usage * (3/4)
+ // and floor((dummy_entries_in_cache_usage - memory_used_) % kSizeDummyEntry)
+ // = 80
+ wbf->FreeMem(20 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Free 1MB, memory_used_ = 10045KB
+ // It will not cause any change in cache cost
+ // since memory_used_ > dummy_entries_in_cache_usage * (3/4)
+ wbf->FreeMem(1 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Reserve 512KB, memory_used_ = 10557KB
+ // It will not casue any change in cache cost
+ // since memory_used_ > dummy_entries_in_cache_usage * (3/4)
+ // which reflects the benefit of saving dummy entry insertion on memory
+ // reservation after delay decrease
+ wbf->ReserveMem(512 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Destroy write buffer manger should free everything
+ wbf.reset();
+ ASSERT_EQ(cache->GetPinnedUsage(), 0);
+}
+
+TEST_F(ChargeWriteBufferTest, BasicWithNoBufferSizeLimit) {
+ constexpr std::size_t kMetaDataChargeOverhead = 10000;
+ // 1GB cache
+ std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4);
+ // A write buffer manager of size 256MB
+ std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache));
+
+ // Allocate 10MB, memory_used_ = 10240KB
+ // It will allocate 40 dummy entries
+ wbf->ReserveMem(10 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 40 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 40 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 40 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ ASSERT_FALSE(wbf->ShouldFlush());
+
+ // Free 9MB, memory_used_ = 1024KB
+ // It will free 36 dummy entries
+ wbf->FreeMem(9 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+
+ // Free 160KB gradually, memory_used_ = 864KB
+ // It will not cause any change
+ // since memory_used_ > dummy_entries_in_cache_usage * 3/4
+ for (int i = 0; i < 40; i++) {
+ wbf->FreeMem(4 * 1024);
+ }
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+ ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+}
+
+TEST_F(ChargeWriteBufferTest, BasicWithCacheFull) {
+ constexpr std::size_t kMetaDataChargeOverhead = 20000;
+
+ // 12MB cache size with strict capacity
+ LRUCacheOptions lo;
+ lo.capacity = 12 * 1024 * 1024;
+ lo.num_shard_bits = 0;
+ lo.strict_capacity_limit = true;
+ std::shared_ptr<Cache> cache = NewLRUCache(lo);
+ std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache));
+
+ // Allocate 10MB, memory_used_ = 10240KB
+ wbf->ReserveMem(10 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 40 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 40 * kSizeDummyEntry);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 40 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+ // Allocate 10MB, memory_used_ = 20480KB
+ // Some dummy entry insertion will fail due to full cache
+ wbf->ReserveMem(10 * 1024 * 1024);
+ ASSERT_GE(cache->GetPinnedUsage(), 40 * kSizeDummyEntry);
+ ASSERT_LE(cache->GetPinnedUsage(), 12 * 1024 * 1024);
+ ASSERT_LT(wbf->dummy_entries_in_cache_usage(), 80 * kSizeDummyEntry);
+
+ // Free 15MB after encoutering cache full, memory_used_ = 5120KB
+ wbf->FreeMem(15 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 20 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 20 * kSizeDummyEntry);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 20 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+ // Reserve 15MB, creating cache full again, memory_used_ = 20480KB
+ wbf->ReserveMem(15 * 1024 * 1024);
+ ASSERT_LE(cache->GetPinnedUsage(), 12 * 1024 * 1024);
+ ASSERT_LT(wbf->dummy_entries_in_cache_usage(), 80 * kSizeDummyEntry);
+
+ // Increase capacity so next insert will fully succeed
+ cache->SetCapacity(40 * 1024 * 1024);
+
+ // Allocate 10MB, memory_used_ = 30720KB
+ wbf->ReserveMem(10 * 1024 * 1024);
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 120 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 120 * kSizeDummyEntry);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 120 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+ // Gradually release 20 MB
+ // It ended up sequentially releasing 32, 24, 18 dummy entries when
+ // memory_used_ decreases to 22528KB, 16384KB, 11776KB.
+ // In total, it releases 74 dummy entries
+ for (int i = 0; i < 40; i++) {
+ wbf->FreeMem(512 * 1024);
+ }
+
+ ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 46 * kSizeDummyEntry);
+ ASSERT_GE(cache->GetPinnedUsage(), 46 * kSizeDummyEntry);
+ ASSERT_LT(cache->GetPinnedUsage(),
+ 46 * kSizeDummyEntry + kMetaDataChargeOverhead);
+}
+
+#endif // ROCKSDB_LITE
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}