Diffstat (limited to 'src/rocksdb/memtable')
-rw-r--r-- | src/rocksdb/memtable/alloc_tracker.cc | 63
-rw-r--r-- | src/rocksdb/memtable/hash_linklist_rep.cc | 926
-rw-r--r-- | src/rocksdb/memtable/hash_skiplist_rep.cc | 393
-rw-r--r-- | src/rocksdb/memtable/inlineskiplist.h | 1051
-rw-r--r-- | src/rocksdb/memtable/inlineskiplist_test.cc | 664
-rw-r--r-- | src/rocksdb/memtable/memtablerep_bench.cc | 689
-rw-r--r-- | src/rocksdb/memtable/skiplist.h | 498
-rw-r--r-- | src/rocksdb/memtable/skiplist_test.cc | 387
-rw-r--r-- | src/rocksdb/memtable/skiplistrep.cc | 370
-rw-r--r-- | src/rocksdb/memtable/stl_wrappers.h | 33
-rw-r--r-- | src/rocksdb/memtable/vectorrep.cc | 309
-rw-r--r-- | src/rocksdb/memtable/write_buffer_manager.cc | 202
-rw-r--r-- | src/rocksdb/memtable/write_buffer_manager_test.cc | 305
13 files changed, 5890 insertions(+), 0 deletions(-)
diff --git a/src/rocksdb/memtable/alloc_tracker.cc b/src/rocksdb/memtable/alloc_tracker.cc new file mode 100644 index 000000000..4c6d35431 --- /dev/null +++ b/src/rocksdb/memtable/alloc_tracker.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include <assert.h> + +#include "memory/allocator.h" +#include "memory/arena.h" +#include "rocksdb/write_buffer_manager.h" + +namespace ROCKSDB_NAMESPACE { + +AllocTracker::AllocTracker(WriteBufferManager* write_buffer_manager) + : write_buffer_manager_(write_buffer_manager), + bytes_allocated_(0), + done_allocating_(false), + freed_(false) {} + +AllocTracker::~AllocTracker() { FreeMem(); } + +void AllocTracker::Allocate(size_t bytes) { + assert(write_buffer_manager_ != nullptr); + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed); + write_buffer_manager_->ReserveMem(bytes); + } +} + +void AllocTracker::DoneAllocating() { + if (write_buffer_manager_ != nullptr && !done_allocating_) { + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + write_buffer_manager_->ScheduleFreeMem( + bytes_allocated_.load(std::memory_order_relaxed)); + } else { + assert(bytes_allocated_.load(std::memory_order_relaxed) == 0); + } + done_allocating_ = true; + } +} + +void AllocTracker::FreeMem() { + if (!done_allocating_) { + DoneAllocating(); + } + if (write_buffer_manager_ != nullptr && !freed_) { + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + write_buffer_manager_->FreeMem( + bytes_allocated_.load(std::memory_order_relaxed)); + } else { + assert(bytes_allocated_.load(std::memory_order_relaxed) == 0); + } + freed_ = true; + } +} +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/memtable/hash_linklist_rep.cc b/src/rocksdb/memtable/hash_linklist_rep.cc new file mode 100644 index 000000000..a71768304 --- /dev/null +++ b/src/rocksdb/memtable/hash_linklist_rep.cc @@ -0,0 +1,926 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// + +#ifndef ROCKSDB_LITE + +#include <algorithm> +#include <atomic> + +#include "db/memtable.h" +#include "memory/arena.h" +#include "memtable/skiplist.h" +#include "monitoring/histogram.h" +#include "port/port.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/slice.h" +#include "rocksdb/slice_transform.h" +#include "rocksdb/utilities/options_type.h" +#include "util/hash.h" + +namespace ROCKSDB_NAMESPACE { +namespace { + +using Key = const char*; +using MemtableSkipList = SkipList<Key, const MemTableRep::KeyComparator&>; +using Pointer = std::atomic<void*>; + +// A data structure used as the header of a link list of a hash bucket. 
+struct BucketHeader { + Pointer next; + std::atomic<uint32_t> num_entries; + + explicit BucketHeader(void* n, uint32_t count) + : next(n), num_entries(count) {} + + bool IsSkipListBucket() { + return next.load(std::memory_order_relaxed) == this; + } + + uint32_t GetNumEntries() const { + return num_entries.load(std::memory_order_relaxed); + } + + // REQUIRES: called from single-threaded Insert() + void IncNumEntries() { + // Only one thread can do write at one time. No need to do atomic + // incremental. Update it with relaxed load and store. + num_entries.store(GetNumEntries() + 1, std::memory_order_relaxed); + } +}; + +// A data structure used as the header of a skip list of a hash bucket. +struct SkipListBucketHeader { + BucketHeader Counting_header; + MemtableSkipList skip_list; + + explicit SkipListBucketHeader(const MemTableRep::KeyComparator& cmp, + Allocator* allocator, uint32_t count) + : Counting_header(this, // Pointing to itself to indicate header type. + count), + skip_list(cmp, allocator) {} +}; + +struct Node { + // Accessors/mutators for links. Wrapped in methods so we can + // add the appropriate barriers as necessary. + Node* Next() { + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return next_.load(std::memory_order_acquire); + } + void SetNext(Node* x) { + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + next_.store(x, std::memory_order_release); + } + // No-barrier variants that can be safely used in a few locations. + Node* NoBarrier_Next() { return next_.load(std::memory_order_relaxed); } + + void NoBarrier_SetNext(Node* x) { next_.store(x, std::memory_order_relaxed); } + + // Needed for placement new below which is fine + Node() {} + + private: + std::atomic<Node*> next_; + + // Prohibit copying due to the below + Node(const Node&) = delete; + Node& operator=(const Node&) = delete; + + public: + char key[1]; +}; + +// Memory structure of the mem table: +// It is a hash table, each bucket points to one entry, a linked list or a +// skip list. In order to track total number of records in a bucket to determine +// whether should switch to skip list, a header is added just to indicate +// number of entries in the bucket. +// +// +// +-----> NULL Case 1. Empty bucket +// | +// | +// | +---> +-------+ +// | | | Next +--> NULL +// | | +-------+ +// +-----+ | | | | Case 2. One Entry in bucket. +// | +-+ | | Data | next pointer points to +// +-----+ | | | NULL. All other cases +// | | | | | next pointer is not NULL. +// +-----+ | +-------+ +// | +---+ +// +-----+ +-> +-------+ +> +-------+ +-> +-------+ +// | | | | Next +--+ | Next +--+ | Next +-->NULL +// +-----+ | +-------+ +-------+ +-------+ +// | +-----+ | Count | | | | | +// +-----+ +-------+ | Data | | Data | +// | | | | | | +// +-----+ Case 3. | | | | +// | | A header +-------+ +-------+ +// +-----+ points to +// | | a linked list. Count indicates total number +// +-----+ of rows in this bucket. +// | | +// +-----+ +-> +-------+ <--+ +// | | | | Next +----+ +// +-----+ | +-------+ Case 4. A header points to a skip +// | +----+ | Count | list and next pointer points to +// +-----+ +-------+ itself, to distinguish case 3 or 4. +// | | | | Count still is kept to indicates total +// +-----+ | Skip +--> of entries in the bucket for debugging +// | | | List | Data purpose. 
+// | | | +--> +// +-----+ | | +// | | +-------+ +// +-----+ +// +// We don't have data race when changing cases because: +// (1) When changing from case 2->3, we create a new bucket header, put the +// single node there first without changing the original node, and do a +// release store when changing the bucket pointer. In that case, a reader +// who sees a stale value of the bucket pointer will read this node, while +// a reader sees the correct value because of the release store. +// (2) When changing case 3->4, a new header is created with skip list points +// to the data, before doing an acquire store to change the bucket pointer. +// The old header and nodes are never changed, so any reader sees any +// of those existing pointers will guarantee to be able to iterate to the +// end of the linked list. +// (3) Header's next pointer in case 3 might change, but they are never equal +// to itself, so no matter a reader sees any stale or newer value, it will +// be able to correctly distinguish case 3 and 4. +// +// The reason that we use case 2 is we want to make the format to be efficient +// when the utilization of buckets is relatively low. If we use case 3 for +// single entry bucket, we will need to waste 12 bytes for every entry, +// which can be significant decrease of memory utilization. +class HashLinkListRep : public MemTableRep { + public: + HashLinkListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + size_t bucket_size, uint32_t threshold_use_skiplist, + size_t huge_page_tlb_size, Logger* logger, + int bucket_entries_logging_threshold, + bool if_log_bucket_dist_when_flash); + + KeyHandle Allocate(const size_t len, char** buf) override; + + void Insert(KeyHandle handle) override; + + bool Contains(const char* key) const override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~HashLinkListRep() override; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override; + + MemTableRep::Iterator* GetDynamicPrefixIterator( + Arena* arena = nullptr) override; + + private: + friend class DynamicIterator; + + size_t bucket_size_; + + // Maps slices (which are transformed user keys) to buckets of keys sharing + // the same transform. + Pointer* buckets_; + + const uint32_t threshold_use_skiplist_; + + // The user-supplied transform whose domain is the user keys. + const SliceTransform* transform_; + + const MemTableRep::KeyComparator& compare_; + + Logger* logger_; + int bucket_entries_logging_threshold_; + bool if_log_bucket_dist_when_flash_; + + bool LinkListContains(Node* head, const Slice& key) const; + + bool IsEmptyBucket(Pointer& bucket_pointer) const { + return bucket_pointer.load(std::memory_order_acquire) == nullptr; + } + + // Precondition: GetLinkListFirstNode() must have been called first and return + // null so that it must be a skip list bucket + SkipListBucketHeader* GetSkipListBucketHeader(Pointer& bucket_pointer) const; + + // Returning nullptr indicates it is a skip list bucket. 
+ Node* GetLinkListFirstNode(Pointer& bucket_pointer) const; + + Slice GetPrefix(const Slice& internal_key) const { + return transform_->Transform(ExtractUserKey(internal_key)); + } + + size_t GetHash(const Slice& slice) const { + return GetSliceRangedNPHash(slice, bucket_size_); + } + + Pointer& GetBucket(size_t i) const { return buckets_[i]; } + + Pointer& GetBucket(const Slice& slice) const { + return GetBucket(GetHash(slice)); + } + + bool Equal(const Slice& a, const Key& b) const { + return (compare_(b, a) == 0); + } + + bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); } + + bool KeyIsAfterNode(const Slice& internal_key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, internal_key) < 0); + } + + bool KeyIsAfterNode(const Key& key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) < 0); + } + + bool KeyIsAfterOrAtNode(const Slice& internal_key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, internal_key) <= 0); + } + + bool KeyIsAfterOrAtNode(const Key& key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) <= 0); + } + + Node* FindGreaterOrEqualInBucket(Node* head, const Slice& key) const; + Node* FindLessOrEqualInBucket(Node* head, const Slice& key) const; + + class FullListIterator : public MemTableRep::Iterator { + public: + explicit FullListIterator(MemtableSkipList* list, Allocator* allocator) + : iter_(list), full_list_(list), allocator_(allocator) {} + + ~FullListIterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + iter_.Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + assert(Valid()); + iter_.Prev(); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.Seek(encoded_key); + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& internal_key, + const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.SeekForPrev(encoded_key); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { iter_.SeekToFirst(); } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { iter_.SeekToLast(); } + + private: + MemtableSkipList::Iterator iter_; + // To destruct with the iterator. 
+ std::unique_ptr<MemtableSkipList> full_list_; + std::unique_ptr<Allocator> allocator_; + std::string tmp_; // For passing to EncodeKey + }; + + class LinkListIterator : public MemTableRep::Iterator { + public: + explicit LinkListIterator(const HashLinkListRep* const hash_link_list_rep, + Node* head) + : hash_link_list_rep_(hash_link_list_rep), + head_(head), + node_(nullptr) {} + + ~LinkListIterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return node_ != nullptr; } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return node_->key; + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + node_ = node_->Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, + const char* /*memtable_key*/) override { + node_ = + hash_link_list_rep_->FindGreaterOrEqualInBucket(head_, internal_key); + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override { + // Since we do not support Prev() + // We simply do not support SeekForPrev + Reset(nullptr); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + // Prefix iterator does not support total order. 
+ // We simply set the iterator to invalid state + Reset(nullptr); + } + + protected: + void Reset(Node* head) { + head_ = head; + node_ = nullptr; + } + + private: + friend class HashLinkListRep; + const HashLinkListRep* const hash_link_list_rep_; + Node* head_; + Node* node_; + + virtual void SeekToHead() { node_ = head_; } + }; + + class DynamicIterator : public HashLinkListRep::LinkListIterator { + public: + explicit DynamicIterator(HashLinkListRep& memtable_rep) + : HashLinkListRep::LinkListIterator(&memtable_rep, nullptr), + memtable_rep_(memtable_rep) {} + + // Advance to the first entry with a key >= target + void Seek(const Slice& k, const char* memtable_key) override { + auto transformed = memtable_rep_.GetPrefix(k); + Pointer& bucket = memtable_rep_.GetBucket(transformed); + + if (memtable_rep_.IsEmptyBucket(bucket)) { + skip_list_iter_.reset(); + Reset(nullptr); + } else { + Node* first_linked_list_node = + memtable_rep_.GetLinkListFirstNode(bucket); + if (first_linked_list_node != nullptr) { + // The bucket is organized as a linked list + skip_list_iter_.reset(); + Reset(first_linked_list_node); + HashLinkListRep::LinkListIterator::Seek(k, memtable_key); + + } else { + SkipListBucketHeader* skip_list_header = + memtable_rep_.GetSkipListBucketHeader(bucket); + assert(skip_list_header != nullptr); + // The bucket is organized as a skip list + if (!skip_list_iter_) { + skip_list_iter_.reset( + new MemtableSkipList::Iterator(&skip_list_header->skip_list)); + } else { + skip_list_iter_->SetList(&skip_list_header->skip_list); + } + if (memtable_key != nullptr) { + skip_list_iter_->Seek(memtable_key); + } else { + IterKey encoded_key; + encoded_key.EncodeLengthPrefixedKey(k); + skip_list_iter_->Seek(encoded_key.GetUserKey().data()); + } + } + } + } + + bool Valid() const override { + if (skip_list_iter_) { + return skip_list_iter_->Valid(); + } + return HashLinkListRep::LinkListIterator::Valid(); + } + + const char* key() const override { + if (skip_list_iter_) { + return skip_list_iter_->key(); + } + return HashLinkListRep::LinkListIterator::key(); + } + + void Next() override { + if (skip_list_iter_) { + skip_list_iter_->Next(); + } else { + HashLinkListRep::LinkListIterator::Next(); + } + } + + private: + // the underlying memtable + const HashLinkListRep& memtable_rep_; + std::unique_ptr<MemtableSkipList::Iterator> skip_list_iter_; + }; + + class EmptyIterator : public MemTableRep::Iterator { + // This is used when there wasn't a bucket. It is cheaper than + // instantiating an empty bucket over which to iterate. + public: + EmptyIterator() {} + bool Valid() const override { return false; } + const char* key() const override { + assert(false); + return nullptr; + } + void Next() override {} + void Prev() override {} + void Seek(const Slice& /*user_key*/, + const char* /*memtable_key*/) override {} + void SeekForPrev(const Slice& /*user_key*/, + const char* /*memtable_key*/) override {} + void SeekToFirst() override {} + void SeekToLast() override {} + + private: + }; +}; + +HashLinkListRep::HashLinkListRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, size_t bucket_size, + uint32_t threshold_use_skiplist, size_t huge_page_tlb_size, Logger* logger, + int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash) + : MemTableRep(allocator), + bucket_size_(bucket_size), + // Threshold to use skip list doesn't make sense if less than 3, so we + // force it to be minimum of 3 to simplify implementation. 
+ threshold_use_skiplist_(std::max(threshold_use_skiplist, 3U)), + transform_(transform), + compare_(compare), + logger_(logger), + bucket_entries_logging_threshold_(bucket_entries_logging_threshold), + if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) { + char* mem = allocator_->AllocateAligned(sizeof(Pointer) * bucket_size, + huge_page_tlb_size, logger); + + buckets_ = new (mem) Pointer[bucket_size]; + + for (size_t i = 0; i < bucket_size_; ++i) { + buckets_[i].store(nullptr, std::memory_order_relaxed); + } +} + +HashLinkListRep::~HashLinkListRep() {} + +KeyHandle HashLinkListRep::Allocate(const size_t len, char** buf) { + char* mem = allocator_->AllocateAligned(sizeof(Node) + len); + Node* x = new (mem) Node(); + *buf = x->key; + return static_cast<void*>(x); +} + +SkipListBucketHeader* HashLinkListRep::GetSkipListBucketHeader( + Pointer& bucket_pointer) const { + Pointer* first_next_pointer = + static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire)); + assert(first_next_pointer != nullptr); + assert(first_next_pointer->load(std::memory_order_relaxed) != nullptr); + + // Counting header + BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer); + assert(header->IsSkipListBucket()); + assert(header->GetNumEntries() > threshold_use_skiplist_); + auto* skip_list_bucket_header = + reinterpret_cast<SkipListBucketHeader*>(header); + assert(skip_list_bucket_header->Counting_header.next.load( + std::memory_order_relaxed) == header); + return skip_list_bucket_header; +} + +Node* HashLinkListRep::GetLinkListFirstNode(Pointer& bucket_pointer) const { + Pointer* first_next_pointer = + static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire)); + assert(first_next_pointer != nullptr); + if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) { + // Single entry bucket + return reinterpret_cast<Node*>(first_next_pointer); + } + + // It is possible that after we fetch first_next_pointer it is modified + // and the next is not null anymore. In this case, the bucket should have been + // modified to a counting header, so we should reload the first_next_pointer + // to make sure we see the update. + first_next_pointer = + static_cast<Pointer*>(bucket_pointer.load(std::memory_order_acquire)); + // Counting header + BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer); + if (!header->IsSkipListBucket()) { + assert(header->GetNumEntries() <= threshold_use_skiplist_); + return reinterpret_cast<Node*>( + header->next.load(std::memory_order_acquire)); + } + assert(header->GetNumEntries() > threshold_use_skiplist_); + return nullptr; +} + +void HashLinkListRep::Insert(KeyHandle handle) { + Node* x = static_cast<Node*>(handle); + assert(!Contains(x->key)); + Slice internal_key = GetLengthPrefixedSlice(x->key); + auto transformed = GetPrefix(internal_key); + auto& bucket = buckets_[GetHash(transformed)]; + Pointer* first_next_pointer = + static_cast<Pointer*>(bucket.load(std::memory_order_relaxed)); + + if (first_next_pointer == nullptr) { + // Case 1. empty bucket + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. + x->NoBarrier_SetNext(nullptr); + bucket.store(x, std::memory_order_release); + return; + } + + BucketHeader* header = nullptr; + if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) { + // Case 2. only one entry in the bucket + // Need to convert to a Counting bucket and turn to case 4. 
+ Node* first = reinterpret_cast<Node*>(first_next_pointer); + // Need to add a bucket header. + // We have to first convert it to a bucket with header before inserting + // the new node. Otherwise, we might need to change next pointer of first. + // In that case, a reader might sees the next pointer is NULL and wrongly + // think the node is a bucket header. + auto* mem = allocator_->AllocateAligned(sizeof(BucketHeader)); + header = new (mem) BucketHeader(first, 1); + bucket.store(header, std::memory_order_release); + } else { + header = reinterpret_cast<BucketHeader*>(first_next_pointer); + if (header->IsSkipListBucket()) { + // Case 4. Bucket is already a skip list + assert(header->GetNumEntries() > threshold_use_skiplist_); + auto* skip_list_bucket_header = + reinterpret_cast<SkipListBucketHeader*>(header); + // Only one thread can execute Insert() at one time. No need to do atomic + // incremental. + skip_list_bucket_header->Counting_header.IncNumEntries(); + skip_list_bucket_header->skip_list.Insert(x->key); + return; + } + } + + if (bucket_entries_logging_threshold_ > 0 && + header->GetNumEntries() == + static_cast<uint32_t>(bucket_entries_logging_threshold_)) { + Info(logger_, + "HashLinkedList bucket %" ROCKSDB_PRIszt + " has more than %d " + "entries. Key to insert: %s", + GetHash(transformed), header->GetNumEntries(), + GetLengthPrefixedSlice(x->key).ToString(true).c_str()); + } + + if (header->GetNumEntries() == threshold_use_skiplist_) { + // Case 3. number of entries reaches the threshold so need to convert to + // skip list. + LinkListIterator bucket_iter( + this, reinterpret_cast<Node*>( + first_next_pointer->load(std::memory_order_relaxed))); + auto mem = allocator_->AllocateAligned(sizeof(SkipListBucketHeader)); + SkipListBucketHeader* new_skip_list_header = new (mem) + SkipListBucketHeader(compare_, allocator_, header->GetNumEntries() + 1); + auto& skip_list = new_skip_list_header->skip_list; + + // Add all current entries to the skip list + for (bucket_iter.SeekToHead(); bucket_iter.Valid(); bucket_iter.Next()) { + skip_list.Insert(bucket_iter.key()); + } + + // insert the new entry + skip_list.Insert(x->key); + // Set the bucket + bucket.store(new_skip_list_header, std::memory_order_release); + } else { + // Case 5. Need to insert to the sorted linked list without changing the + // header. + Node* first = + reinterpret_cast<Node*>(header->next.load(std::memory_order_relaxed)); + assert(first != nullptr); + // Advance counter unless the bucket needs to be advanced to skip list. + // In that case, we need to make sure the previous count never exceeds + // threshold_use_skiplist_ to avoid readers to cast to wrong format. + header->IncNumEntries(); + + Node* cur = first; + Node* prev = nullptr; + while (true) { + if (cur == nullptr) { + break; + } + Node* next = cur->Next(); + // Make sure the lists are sorted. + // If x points to head_ or next points nullptr, it is trivially satisfied. + assert((cur == first) || (next == nullptr) || + KeyIsAfterNode(next->key, cur)); + if (KeyIsAfterNode(internal_key, cur)) { + // Keep searching in this list + prev = cur; + cur = next; + } else { + break; + } + } + + // Our data structure does not allow duplicate insertion + assert(cur == nullptr || !Equal(x->key, cur->key)); + + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. 
+ x->NoBarrier_SetNext(cur); + + if (prev) { + prev->SetNext(x); + } else { + header->next.store(static_cast<void*>(x), std::memory_order_release); + } + } +} + +bool HashLinkListRep::Contains(const char* key) const { + Slice internal_key = GetLengthPrefixedSlice(key); + + auto transformed = GetPrefix(internal_key); + Pointer& bucket = GetBucket(transformed); + if (IsEmptyBucket(bucket)) { + return false; + } + + Node* linked_list_node = GetLinkListFirstNode(bucket); + if (linked_list_node != nullptr) { + return LinkListContains(linked_list_node, internal_key); + } + + SkipListBucketHeader* skip_list_header = GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + return skip_list_header->skip_list.Contains(key); + } + return false; +} + +size_t HashLinkListRep::ApproximateMemoryUsage() { + // Memory is always allocated from the allocator. + return 0; +} + +void HashLinkListRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + auto transformed = transform_->Transform(k.user_key()); + Pointer& bucket = GetBucket(transformed); + + if (IsEmptyBucket(bucket)) { + return; + } + + auto* link_list_head = GetLinkListFirstNode(bucket); + if (link_list_head != nullptr) { + LinkListIterator iter(this, link_list_head); + for (iter.Seek(k.internal_key(), nullptr); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } else { + auto* skip_list_header = GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + // Is a skip list + MemtableSkipList::Iterator iter(&skip_list_header->skip_list); + for (iter.Seek(k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } + } +} + +MemTableRep::Iterator* HashLinkListRep::GetIterator(Arena* alloc_arena) { + // allocate a new arena of similar size to the one currently in use + Arena* new_arena = new Arena(allocator_->BlockSize()); + auto list = new MemtableSkipList(compare_, new_arena); + HistogramImpl keys_per_bucket_hist; + + for (size_t i = 0; i < bucket_size_; ++i) { + int count = 0; + Pointer& bucket = GetBucket(i); + if (!IsEmptyBucket(bucket)) { + auto* link_list_head = GetLinkListFirstNode(bucket); + if (link_list_head != nullptr) { + LinkListIterator itr(this, link_list_head); + for (itr.SeekToHead(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + count++; + } + } else { + auto* skip_list_header = GetSkipListBucketHeader(bucket); + assert(skip_list_header != nullptr); + // Is a skip list + MemtableSkipList::Iterator itr(&skip_list_header->skip_list); + for (itr.SeekToFirst(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + count++; + } + } + } + if (if_log_bucket_dist_when_flash_) { + keys_per_bucket_hist.Add(count); + } + } + if (if_log_bucket_dist_when_flash_ && logger_ != nullptr) { + Info(logger_, "hashLinkedList Entry distribution among buckets: %s", + keys_per_bucket_hist.ToString().c_str()); + } + + if (alloc_arena == nullptr) { + return new FullListIterator(list, new_arena); + } else { + auto mem = alloc_arena->AllocateAligned(sizeof(FullListIterator)); + return new (mem) FullListIterator(list, new_arena); + } +} + +MemTableRep::Iterator* HashLinkListRep::GetDynamicPrefixIterator( + Arena* alloc_arena) { + if (alloc_arena == nullptr) { + return new DynamicIterator(*this); + } else { + auto mem = alloc_arena->AllocateAligned(sizeof(DynamicIterator)); + return new (mem) DynamicIterator(*this); + } +} + +bool HashLinkListRep::LinkListContains(Node* head, + const 
Slice& user_key) const { + Node* x = FindGreaterOrEqualInBucket(head, user_key); + return (x != nullptr && Equal(user_key, x->key)); +} + +Node* HashLinkListRep::FindGreaterOrEqualInBucket(Node* head, + const Slice& key) const { + Node* x = head; + while (true) { + if (x == nullptr) { + return x; + } + Node* next = x->Next(); + // Make sure the lists are sorted. + // If x points to head_ or next points nullptr, it is trivially satisfied. + assert((x == head) || (next == nullptr) || KeyIsAfterNode(next->key, x)); + if (KeyIsAfterNode(key, x)) { + // Keep searching in this list + x = next; + } else { + break; + } + } + return x; +} + +struct HashLinkListRepOptions { + static const char* kName() { return "HashLinkListRepFactoryOptions"; } + size_t bucket_count; + uint32_t threshold_use_skiplist; + size_t huge_page_tlb_size; + int bucket_entries_logging_threshold; + bool if_log_bucket_dist_when_flash; +}; + +static std::unordered_map<std::string, OptionTypeInfo> hash_linklist_info = { + {"bucket_count", + {offsetof(struct HashLinkListRepOptions, bucket_count), OptionType::kSizeT, + OptionVerificationType::kNormal, OptionTypeFlags::kNone}}, + {"threshold", + {offsetof(struct HashLinkListRepOptions, threshold_use_skiplist), + OptionType::kUInt32T, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, + {"huge_page_size", + {offsetof(struct HashLinkListRepOptions, huge_page_tlb_size), + OptionType::kSizeT, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, + {"logging_threshold", + {offsetof(struct HashLinkListRepOptions, bucket_entries_logging_threshold), + OptionType::kInt, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, + {"log_when_flash", + {offsetof(struct HashLinkListRepOptions, if_log_bucket_dist_when_flash), + OptionType::kBoolean, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, +}; + +class HashLinkListRepFactory : public MemTableRepFactory { + public: + explicit HashLinkListRepFactory(size_t bucket_count, + uint32_t threshold_use_skiplist, + size_t huge_page_tlb_size, + int bucket_entries_logging_threshold, + bool if_log_bucket_dist_when_flash) { + options_.bucket_count = bucket_count; + options_.threshold_use_skiplist = threshold_use_skiplist; + options_.huge_page_tlb_size = huge_page_tlb_size; + options_.bucket_entries_logging_threshold = + bucket_entries_logging_threshold; + options_.if_log_bucket_dist_when_flash = if_log_bucket_dist_when_flash; + RegisterOptions(&options_, &hash_linklist_info); + } + + using MemTableRepFactory::CreateMemTableRep; + virtual MemTableRep* CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) override; + + static const char* kClassName() { return "HashLinkListRepFactory"; } + static const char* kNickName() { return "hash_linkedlist"; } + virtual const char* Name() const override { return kClassName(); } + virtual const char* NickName() const override { return kNickName(); } + + private: + HashLinkListRepOptions options_; +}; + +} // namespace + +MemTableRep* HashLinkListRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) { + return new HashLinkListRep( + compare, allocator, transform, options_.bucket_count, + options_.threshold_use_skiplist, options_.huge_page_tlb_size, logger, + options_.bucket_entries_logging_threshold, + options_.if_log_bucket_dist_when_flash); +} + +MemTableRepFactory* NewHashLinkListRepFactory( + size_t 
bucket_count, size_t huge_page_tlb_size, + int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash, + uint32_t threshold_use_skiplist) { + return new HashLinkListRepFactory( + bucket_count, threshold_use_skiplist, huge_page_tlb_size, + bucket_entries_logging_threshold, if_log_bucket_dist_when_flash); +} + +} // namespace ROCKSDB_NAMESPACE +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/hash_skiplist_rep.cc b/src/rocksdb/memtable/hash_skiplist_rep.cc new file mode 100644 index 000000000..9d093829b --- /dev/null +++ b/src/rocksdb/memtable/hash_skiplist_rep.cc @@ -0,0 +1,393 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// + +#ifndef ROCKSDB_LITE +#include <atomic> + +#include "db/memtable.h" +#include "memory/arena.h" +#include "memtable/skiplist.h" +#include "port/port.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/slice.h" +#include "rocksdb/slice_transform.h" +#include "rocksdb/utilities/options_type.h" +#include "util/murmurhash.h" + +namespace ROCKSDB_NAMESPACE { +namespace { + +class HashSkipListRep : public MemTableRep { + public: + HashSkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + size_t bucket_size, int32_t skiplist_height, + int32_t skiplist_branching_factor); + + void Insert(KeyHandle handle) override; + + bool Contains(const char* key) const override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~HashSkipListRep() override; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override; + + MemTableRep::Iterator* GetDynamicPrefixIterator( + Arena* arena = nullptr) override; + + private: + friend class DynamicIterator; + using Bucket = SkipList<const char*, const MemTableRep::KeyComparator&>; + + size_t bucket_size_; + + const int32_t skiplist_height_; + const int32_t skiplist_branching_factor_; + + // Maps slices (which are transformed user keys) to buckets of keys sharing + // the same transform. + std::atomic<Bucket*>* buckets_; + + // The user-supplied transform whose domain is the user keys. + const SliceTransform* transform_; + + const MemTableRep::KeyComparator& compare_; + // immutable after construction + Allocator* const allocator_; + + inline size_t GetHash(const Slice& slice) const { + return MurmurHash(slice.data(), static_cast<int>(slice.size()), 0) % + bucket_size_; + } + inline Bucket* GetBucket(size_t i) const { + return buckets_[i].load(std::memory_order_acquire); + } + inline Bucket* GetBucket(const Slice& slice) const { + return GetBucket(GetHash(slice)); + } + // Get a bucket from buckets_. If the bucket hasn't been initialized yet, + // initialize it before returning. + Bucket* GetInitializedBucket(const Slice& transformed); + + class Iterator : public MemTableRep::Iterator { + public: + explicit Iterator(Bucket* list, bool own_list = true, + Arena* arena = nullptr) + : list_(list), iter_(list), own_list_(own_list), arena_(arena) {} + + ~Iterator() override { + // if we own the list, we should also delete it + if (own_list_) { + assert(list_ != nullptr); + delete list_; + } + } + + // Returns true iff the iterator is positioned at a valid node. 
+ bool Valid() const override { return list_ != nullptr && iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + iter_.Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + assert(Valid()); + iter_.Prev(); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, const char* memtable_key) override { + if (list_ != nullptr) { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.Seek(encoded_key); + } + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override { + // not supported + assert(false); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + if (list_ != nullptr) { + iter_.SeekToFirst(); + } + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + if (list_ != nullptr) { + iter_.SeekToLast(); + } + } + + protected: + void Reset(Bucket* list) { + if (own_list_) { + assert(list_ != nullptr); + delete list_; + } + list_ = list; + iter_.SetList(list); + own_list_ = false; + } + + private: + // if list_ is nullptr, we should NEVER call any methods on iter_ + // if list_ is nullptr, this Iterator is not Valid() + Bucket* list_; + Bucket::Iterator iter_; + // here we track if we own list_. If we own it, we are also + // responsible for it's cleaning. This is a poor man's std::shared_ptr + bool own_list_; + std::unique_ptr<Arena> arena_; + std::string tmp_; // For passing to EncodeKey + }; + + class DynamicIterator : public HashSkipListRep::Iterator { + public: + explicit DynamicIterator(const HashSkipListRep& memtable_rep) + : HashSkipListRep::Iterator(nullptr, false), + memtable_rep_(memtable_rep) {} + + // Advance to the first entry with a key >= target + void Seek(const Slice& k, const char* memtable_key) override { + auto transformed = memtable_rep_.transform_->Transform(ExtractUserKey(k)); + Reset(memtable_rep_.GetBucket(transformed)); + HashSkipListRep::Iterator::Seek(k, memtable_key); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + private: + // the underlying memtable + const HashSkipListRep& memtable_rep_; + }; + + class EmptyIterator : public MemTableRep::Iterator { + // This is used when there wasn't a bucket. It is cheaper than + // instantiating an empty bucket over which to iterate. 
+ public: + EmptyIterator() {} + bool Valid() const override { return false; } + const char* key() const override { + assert(false); + return nullptr; + } + void Next() override {} + void Prev() override {} + void Seek(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override {} + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override {} + void SeekToFirst() override {} + void SeekToLast() override {} + + private: + }; +}; + +HashSkipListRep::HashSkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, + const SliceTransform* transform, + size_t bucket_size, int32_t skiplist_height, + int32_t skiplist_branching_factor) + : MemTableRep(allocator), + bucket_size_(bucket_size), + skiplist_height_(skiplist_height), + skiplist_branching_factor_(skiplist_branching_factor), + transform_(transform), + compare_(compare), + allocator_(allocator) { + auto mem = + allocator->AllocateAligned(sizeof(std::atomic<void*>) * bucket_size); + buckets_ = new (mem) std::atomic<Bucket*>[bucket_size]; + + for (size_t i = 0; i < bucket_size_; ++i) { + buckets_[i].store(nullptr, std::memory_order_relaxed); + } +} + +HashSkipListRep::~HashSkipListRep() {} + +HashSkipListRep::Bucket* HashSkipListRep::GetInitializedBucket( + const Slice& transformed) { + size_t hash = GetHash(transformed); + auto bucket = GetBucket(hash); + if (bucket == nullptr) { + auto addr = allocator_->AllocateAligned(sizeof(Bucket)); + bucket = new (addr) Bucket(compare_, allocator_, skiplist_height_, + skiplist_branching_factor_); + buckets_[hash].store(bucket, std::memory_order_release); + } + return bucket; +} + +void HashSkipListRep::Insert(KeyHandle handle) { + auto* key = static_cast<char*>(handle); + assert(!Contains(key)); + auto transformed = transform_->Transform(UserKey(key)); + auto bucket = GetInitializedBucket(transformed); + bucket->Insert(key); +} + +bool HashSkipListRep::Contains(const char* key) const { + auto transformed = transform_->Transform(UserKey(key)); + auto bucket = GetBucket(transformed); + if (bucket == nullptr) { + return false; + } + return bucket->Contains(key); +} + +size_t HashSkipListRep::ApproximateMemoryUsage() { return 0; } + +void HashSkipListRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + auto transformed = transform_->Transform(k.user_key()); + auto bucket = GetBucket(transformed); + if (bucket != nullptr) { + Bucket::Iterator iter(bucket); + for (iter.Seek(k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } +} + +MemTableRep::Iterator* HashSkipListRep::GetIterator(Arena* arena) { + // allocate a new arena of similar size to the one currently in use + Arena* new_arena = new Arena(allocator_->BlockSize()); + auto list = new Bucket(compare_, new_arena); + for (size_t i = 0; i < bucket_size_; ++i) { + auto bucket = GetBucket(i); + if (bucket != nullptr) { + Bucket::Iterator itr(bucket); + for (itr.SeekToFirst(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + } + } + } + if (arena == nullptr) { + return new Iterator(list, true, new_arena); + } else { + auto mem = arena->AllocateAligned(sizeof(Iterator)); + return new (mem) Iterator(list, true, new_arena); + } +} + +MemTableRep::Iterator* HashSkipListRep::GetDynamicPrefixIterator(Arena* arena) { + if (arena == nullptr) { + return new DynamicIterator(*this); + } else { + auto mem = arena->AllocateAligned(sizeof(DynamicIterator)); + return new (mem) 
DynamicIterator(*this); + } +} + +struct HashSkipListRepOptions { + static const char* kName() { return "HashSkipListRepFactoryOptions"; } + size_t bucket_count; + int32_t skiplist_height; + int32_t skiplist_branching_factor; +}; + +static std::unordered_map<std::string, OptionTypeInfo> hash_skiplist_info = { + {"bucket_count", + {offsetof(struct HashSkipListRepOptions, bucket_count), OptionType::kSizeT, + OptionVerificationType::kNormal, OptionTypeFlags::kNone}}, + {"skiplist_height", + {offsetof(struct HashSkipListRepOptions, skiplist_height), + OptionType::kInt32T, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, + {"branching_factor", + {offsetof(struct HashSkipListRepOptions, skiplist_branching_factor), + OptionType::kInt32T, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, +}; + +class HashSkipListRepFactory : public MemTableRepFactory { + public: + explicit HashSkipListRepFactory(size_t bucket_count, int32_t skiplist_height, + int32_t skiplist_branching_factor) { + options_.bucket_count = bucket_count; + options_.skiplist_height = skiplist_height; + options_.skiplist_branching_factor = skiplist_branching_factor; + RegisterOptions(&options_, &hash_skiplist_info); + } + + using MemTableRepFactory::CreateMemTableRep; + virtual MemTableRep* CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) override; + + static const char* kClassName() { return "HashSkipListRepFactory"; } + static const char* kNickName() { return "prefix_hash"; } + + virtual const char* Name() const override { return kClassName(); } + virtual const char* NickName() const override { return kNickName(); } + + private: + HashSkipListRepOptions options_; +}; + +} // namespace + +MemTableRep* HashSkipListRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* /*logger*/) { + return new HashSkipListRep(compare, allocator, transform, + options_.bucket_count, options_.skiplist_height, + options_.skiplist_branching_factor); +} + +MemTableRepFactory* NewHashSkipListRepFactory( + size_t bucket_count, int32_t skiplist_height, + int32_t skiplist_branching_factor) { + return new HashSkipListRepFactory(bucket_count, skiplist_height, + skiplist_branching_factor); +} + +} // namespace ROCKSDB_NAMESPACE +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/inlineskiplist.h b/src/rocksdb/memtable/inlineskiplist.h new file mode 100644 index 000000000..abb3c3ddb --- /dev/null +++ b/src/rocksdb/memtable/inlineskiplist.h @@ -0,0 +1,1051 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found +// in the LICENSE file. See the AUTHORS file for names of contributors. +// +// InlineSkipList is derived from SkipList (skiplist.h), but it optimizes +// the memory layout by requiring that the key storage be allocated through +// the skip list instance. For the common case of SkipList<const char*, +// Cmp> this saves 1 pointer per skip list node and gives better cache +// locality, at the expense of wasted padding from using AllocateAligned +// instead of Allocate for the keys. 
The unused padding will be from +// 0 to sizeof(void*)-1 bytes, and the space savings are sizeof(void*) +// bytes, so despite the padding the space used is always less than +// SkipList<const char*, ..>. +// +// Thread safety ------------- +// +// Writes via Insert require external synchronization, most likely a mutex. +// InsertConcurrently can be safely called concurrently with reads and +// with other concurrent inserts. Reads require a guarantee that the +// InlineSkipList will not be destroyed while the read is in progress. +// Apart from that, reads progress without any internal locking or +// synchronization. +// +// Invariants: +// +// (1) Allocated nodes are never deleted until the InlineSkipList is +// destroyed. This is trivially guaranteed by the code since we never +// delete any skip list nodes. +// +// (2) The contents of a Node except for the next/prev pointers are +// immutable after the Node has been linked into the InlineSkipList. +// Only Insert() modifies the list, and it is careful to initialize a +// node and use release-stores to publish the nodes in one or more lists. +// +// ... prev vs. next pointer ordering ... +// + +#pragma once +#include <assert.h> +#include <stdlib.h> + +#include <algorithm> +#include <atomic> +#include <type_traits> + +#include "memory/allocator.h" +#include "port/likely.h" +#include "port/port.h" +#include "rocksdb/slice.h" +#include "util/coding.h" +#include "util/random.h" + +namespace ROCKSDB_NAMESPACE { + +template <class Comparator> +class InlineSkipList { + private: + struct Node; + struct Splice; + + public: + using DecodedKey = + typename std::remove_reference<Comparator>::type::DecodedType; + + static const uint16_t kMaxPossibleHeight = 32; + + // Create a new InlineSkipList object that will use "cmp" for comparing + // keys, and will allocate memory using "*allocator". Objects allocated + // in the allocator must remain allocated for the lifetime of the + // skiplist object. + explicit InlineSkipList(Comparator cmp, Allocator* allocator, + int32_t max_height = 12, + int32_t branching_factor = 4); + // No copying allowed + InlineSkipList(const InlineSkipList&) = delete; + InlineSkipList& operator=(const InlineSkipList&) = delete; + + // Allocates a key and a skip-list node, returning a pointer to the key + // portion of the node. This method is thread-safe if the allocator + // is thread-safe. + char* AllocateKey(size_t key_size); + + // Allocate a splice using allocator. + Splice* AllocateSplice(); + + // Allocate a splice on heap. + Splice* AllocateSpliceOnHeap(); + + // Inserts a key allocated by AllocateKey, after the actual key value + // has been filled in. + // + // REQUIRES: nothing that compares equal to key is currently in the list. + // REQUIRES: no concurrent calls to any of inserts. + bool Insert(const char* key); + + // Inserts a key allocated by AllocateKey with a hint of last insert + // position in the skip-list. If hint points to nullptr, a new hint will be + // populated, which can be used in subsequent calls. + // + // It can be used to optimize the workload where there are multiple groups + // of keys, and each key is likely to insert to a location close to the last + // inserted key in the same group. One example is sequential inserts. + // + // REQUIRES: nothing that compares equal to key is currently in the list. + // REQUIRES: no concurrent calls to any of inserts. 
+ bool InsertWithHint(const char* key, void** hint); + + // Like InsertConcurrently, but with a hint + // + // REQUIRES: nothing that compares equal to key is currently in the list. + // REQUIRES: no concurrent calls that use same hint + bool InsertWithHintConcurrently(const char* key, void** hint); + + // Like Insert, but external synchronization is not required. + bool InsertConcurrently(const char* key); + + // Inserts a node into the skip list. key must have been allocated by + // AllocateKey and then filled in by the caller. If UseCAS is true, + // then external synchronization is not required, otherwise this method + // may not be called concurrently with any other insertions. + // + // Regardless of whether UseCAS is true, the splice must be owned + // exclusively by the current thread. If allow_partial_splice_fix is + // true, then the cost of insertion is amortized O(log D), where D is + // the distance from the splice to the inserted key (measured as the + // number of intervening nodes). Note that this bound is very good for + // sequential insertions! If allow_partial_splice_fix is false then + // the existing splice will be ignored unless the current key is being + // inserted immediately after the splice. allow_partial_splice_fix == + // false has worse running time for the non-sequential case O(log N), + // but a better constant factor. + template <bool UseCAS> + bool Insert(const char* key, Splice* splice, bool allow_partial_splice_fix); + + // Returns true iff an entry that compares equal to key is in the list. + bool Contains(const char* key) const; + + // Return estimated number of entries smaller than `key`. + uint64_t EstimateCount(const char* key) const; + + // Validate correctness of the skip-list. + void TEST_Validate() const; + + // Iteration over the contents of a skip list + class Iterator { + public: + // Initialize an iterator over the specified list. + // The returned iterator is not valid. + explicit Iterator(const InlineSkipList* list); + + // Change the underlying skiplist used for this iterator + // This enables us not changing the iterator without deallocating + // an old one and then allocating a new one + void SetList(const InlineSkipList* list); + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const; + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const; + + // Advances to the next position. + // REQUIRES: Valid() + void Next(); + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev(); + + // Advance to the first entry with a key >= target + void Seek(const char* target); + + // Retreat to the last entry with a key <= target + void SeekForPrev(const char* target); + + // Advance to a random entry in the list. + void RandomSeek(); + + // Position at the first entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst(); + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast(); + + private: + const InlineSkipList* list_; + Node* node_; + // Intentionally copyable + }; + + private: + const uint16_t kMaxHeight_; + const uint16_t kBranching_; + const uint32_t kScaledInverseBranching_; + + Allocator* const allocator_; // Allocator used for allocations of nodes + // Immutable after construction + Comparator const compare_; + Node* const head_; + + // Modified only by Insert(). Read racily by readers, but stale + // values are ok. 
+ std::atomic<int> max_height_; // Height of the entire list + + // seq_splice_ is a Splice used for insertions in the non-concurrent + // case. It caches the prev and next found during the most recent + // non-concurrent insertion. + Splice* seq_splice_; + + inline int GetMaxHeight() const { + return max_height_.load(std::memory_order_relaxed); + } + + int RandomHeight(); + + Node* AllocateNode(size_t key_size, int height); + + bool Equal(const char* a, const char* b) const { + return (compare_(a, b) == 0); + } + + bool LessThan(const char* a, const char* b) const { + return (compare_(a, b) < 0); + } + + // Return true if key is greater than the data stored in "n". Null n + // is considered infinite. n should not be head_. + bool KeyIsAfterNode(const char* key, Node* n) const; + bool KeyIsAfterNode(const DecodedKey& key, Node* n) const; + + // Returns the earliest node with a key >= key. + // Return nullptr if there is no such node. + Node* FindGreaterOrEqual(const char* key) const; + + // Return the latest node with a key < key. + // Return head_ if there is no such node. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [0..max_height_-1], if prev is non-null. + Node* FindLessThan(const char* key, Node** prev = nullptr) const; + + // Return the latest node with a key < key on bottom_level. Start searching + // from root node on the level below top_level. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [bottom_level..top_level-1], if prev is non-null. + Node* FindLessThan(const char* key, Node** prev, Node* root, int top_level, + int bottom_level) const; + + // Return the last node in the list. + // Return head_ if list is empty. + Node* FindLast() const; + + // Returns a random entry. + Node* FindRandomEntry() const; + + // Traverses a single level of the list, setting *out_prev to the last + // node before the key and *out_next to the first node after. Assumes + // that the key is not present in the skip list. On entry, before should + // point to a node that is before the key, and after should point to + // a node that is after the key. after should be nullptr if a good after + // node isn't conveniently available. + template <bool prefetch_before> + void FindSpliceForLevel(const DecodedKey& key, Node* before, Node* after, + int level, Node** out_prev, Node** out_next); + + // Recomputes Splice levels from highest_level (inclusive) down to + // lowest_level (inclusive). + void RecomputeSpliceLevels(const DecodedKey& key, Splice* splice, + int recompute_level); +}; + +// Implementation details follow + +template <class Comparator> +struct InlineSkipList<Comparator>::Splice { + // The invariant of a Splice is that prev_[i+1].key <= prev_[i].key < + // next_[i].key <= next_[i+1].key for all i. That means that if a + // key is bracketed by prev_[i] and next_[i] then it is bracketed by + // all higher levels. It is _not_ required that prev_[i]->Next(i) == + // next_[i] (it probably did at some point in the past, but intervening + // or concurrent operations might have inserted nodes in between). + int height_ = 0; + Node** prev_; + Node** next_; +}; + +// The Node data type is more of a pointer into custom-managed memory than +// a traditional C++ struct. The key is stored in the bytes immediately +// after the struct, and the next_ pointers for nodes with height > 1 are +// stored immediately _before_ the struct. 
This avoids the need to include +// any pointer or sizing data, which reduces per-node memory overheads. +template <class Comparator> +struct InlineSkipList<Comparator>::Node { + // Stores the height of the node in the memory location normally used for + // next_[0]. This is used for passing data from AllocateKey to Insert. + void StashHeight(const int height) { + assert(sizeof(int) <= sizeof(next_[0])); + memcpy(static_cast<void*>(&next_[0]), &height, sizeof(int)); + } + + // Retrieves the value passed to StashHeight. Undefined after a call + // to SetNext or NoBarrier_SetNext. + int UnstashHeight() const { + int rv; + memcpy(&rv, &next_[0], sizeof(int)); + return rv; + } + + const char* Key() const { return reinterpret_cast<const char*>(&next_[1]); } + + // Accessors/mutators for links. Wrapped in methods so we can add + // the appropriate barriers as necessary, and perform the necessary + // addressing trickery for storing links below the Node in memory. + Node* Next(int n) { + assert(n >= 0); + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return ((&next_[0] - n)->load(std::memory_order_acquire)); + } + + void SetNext(int n, Node* x) { + assert(n >= 0); + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + (&next_[0] - n)->store(x, std::memory_order_release); + } + + bool CASNext(int n, Node* expected, Node* x) { + assert(n >= 0); + return (&next_[0] - n)->compare_exchange_strong(expected, x); + } + + // No-barrier variants that can be safely used in a few locations. + Node* NoBarrier_Next(int n) { + assert(n >= 0); + return (&next_[0] - n)->load(std::memory_order_relaxed); + } + + void NoBarrier_SetNext(int n, Node* x) { + assert(n >= 0); + (&next_[0] - n)->store(x, std::memory_order_relaxed); + } + + // Insert node after prev on specific level. + void InsertAfter(Node* prev, int level) { + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "this" in prev. + NoBarrier_SetNext(level, prev->NoBarrier_Next(level)); + prev->SetNext(level, this); + } + + private: + // next_[0] is the lowest level link (level 0). Higher levels are + // stored _earlier_, so level 1 is at next_[-1]. + std::atomic<Node*> next_[1]; +}; + +template <class Comparator> +inline InlineSkipList<Comparator>::Iterator::Iterator( + const InlineSkipList* list) { + SetList(list); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SetList( + const InlineSkipList* list) { + list_ = list; + node_ = nullptr; +} + +template <class Comparator> +inline bool InlineSkipList<Comparator>::Iterator::Valid() const { + return node_ != nullptr; +} + +template <class Comparator> +inline const char* InlineSkipList<Comparator>::Iterator::key() const { + assert(Valid()); + return node_->Key(); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Next() { + assert(Valid()); + node_ = node_->Next(0); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Prev() { + // Instead of using explicit "prev" links, we just search for the + // last node that falls before key. 
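+  // As a consequence Prev() costs O(log N) rather than O(1), but nodes
+  // only need forward links, which keeps them small.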
+ assert(Valid()); + node_ = list_->FindLessThan(node_->Key()); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Seek(const char* target) { + node_ = list_->FindGreaterOrEqual(target); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekForPrev( + const char* target) { + Seek(target); + if (!Valid()) { + SeekToLast(); + } + while (Valid() && list_->LessThan(target, key())) { + Prev(); + } +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::RandomSeek() { + node_ = list_->FindRandomEntry(); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekToFirst() { + node_ = list_->head_->Next(0); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekToLast() { + node_ = list_->FindLast(); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <class Comparator> +int InlineSkipList<Comparator>::RandomHeight() { + auto rnd = Random::GetTLSInstance(); + + // Increase height with probability 1 in kBranching + int height = 1; + while (height < kMaxHeight_ && height < kMaxPossibleHeight && + rnd->Next() < kScaledInverseBranching_) { + height++; + } + assert(height > 0); + assert(height <= kMaxHeight_); + assert(height <= kMaxPossibleHeight); + return height; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::KeyIsAfterNode(const char* key, + Node* n) const { + // nullptr n is considered infinite + assert(n != head_); + return (n != nullptr) && (compare_(n->Key(), key) < 0); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::KeyIsAfterNode(const DecodedKey& key, + Node* n) const { + // nullptr n is considered infinite + assert(n != head_); + return (n != nullptr) && (compare_(n->Key(), key) < 0); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindGreaterOrEqual(const char* key) const { + // Note: It looks like we could reduce duplication by implementing + // this function as FindLessThan(key)->Next(0), but we wouldn't be able + // to exit early on equality and the result wouldn't even be correct. + // A concurrent insert might occur after FindLessThan(key) but before + // we get a chance to call Next(0). + Node* x = head_; + int level = GetMaxHeight() - 1; + Node* last_bigger = nullptr; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + // Make sure the lists are sorted + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x)); + // Make sure we haven't overshot during our search + assert(x == head_ || KeyIsAfterNode(key_decoded, x)); + int cmp = (next == nullptr || next == last_bigger) + ? 
1 + : compare_(next->Key(), key_decoded); + if (cmp == 0 || (cmp > 0 && level == 0)) { + return next; + } else if (cmp < 0) { + // Keep searching in this list + x = next; + } else { + // Switch to next list, reuse compare_() result + last_bigger = next; + level--; + } + } +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev) const { + return FindLessThan(key, prev, head_, GetMaxHeight(), 0); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev, + Node* root, int top_level, + int bottom_level) const { + assert(top_level > bottom_level); + int level = top_level - 1; + Node* x = root; + // KeyIsAfter(key, last_not_after) is definitely false + Node* last_not_after = nullptr; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x)); + assert(x == head_ || KeyIsAfterNode(key_decoded, x)); + if (next != last_not_after && KeyIsAfterNode(key_decoded, next)) { + // Keep searching in this list + assert(next != nullptr); + x = next; + } else { + if (prev != nullptr) { + prev[level] = x; + } + if (level == bottom_level) { + return x; + } else { + // Switch to next list, reuse KeyIsAfterNode() result + last_not_after = next; + level--; + } + } + } +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLast() const { + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + Node* next = x->Next(level); + if (next == nullptr) { + if (level == 0) { + return x; + } else { + // Switch to next list + level--; + } + } else { + x = next; + } + } +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindRandomEntry() const { + // TODO(bjlemaire): consider adding PREFETCH calls. + Node *x = head_, *scan_node = nullptr, *limit_node = nullptr; + + // We start at the max level. + // FOr each level, we look at all the nodes at the level, and + // we randomly pick one of them. Then decrement the level + // and reiterate the process. + // eg: assume GetMaxHeight()=5, and there are #100 elements (nodes). + // level 4 nodes: lvl_nodes={#1, #15, #67, #84}. Randomly pick #15. + // We will consider all the nodes between #15 (inclusive) and #67 + // (exclusive). #67 is called 'limit_node' here. + // level 3 nodes: lvl_nodes={#15, #21, #45, #51}. Randomly choose + // #51. #67 remains 'limit_node'. + // [...] + // level 0 nodes: lvl_nodes={#56,#57,#58,#59}. Randomly pick $57. + // Return Node #57. + std::vector<Node*> lvl_nodes; + Random* rnd = Random::GetTLSInstance(); + int level = GetMaxHeight() - 1; + + while (level >= 0) { + lvl_nodes.clear(); + scan_node = x; + while (scan_node != limit_node) { + lvl_nodes.push_back(scan_node); + scan_node = scan_node->Next(level); + } + uint32_t rnd_idx = rnd->Next() % lvl_nodes.size(); + x = lvl_nodes[rnd_idx]; + if (rnd_idx + 1 < lvl_nodes.size()) { + limit_node = lvl_nodes[rnd_idx + 1]; + } + level--; + } + // There is a special case where x could still be the head_ + // (note that the head_ contains no key). + return x == head_ && head_ != nullptr ? 
head_->Next(0) : x; +} + +template <class Comparator> +uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const { + uint64_t count = 0; + + Node* x = head_; + int level = GetMaxHeight() - 1; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + assert(x == head_ || compare_(x->Key(), key_decoded) < 0); + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + if (next == nullptr || compare_(next->Key(), key_decoded) >= 0) { + if (level == 0) { + return count; + } else { + // Switch to next list + count *= kBranching_; + level--; + } + } else { + x = next; + count++; + } + } +} + +template <class Comparator> +InlineSkipList<Comparator>::InlineSkipList(const Comparator cmp, + Allocator* allocator, + int32_t max_height, + int32_t branching_factor) + : kMaxHeight_(static_cast<uint16_t>(max_height)), + kBranching_(static_cast<uint16_t>(branching_factor)), + kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_), + allocator_(allocator), + compare_(cmp), + head_(AllocateNode(0, max_height)), + max_height_(1), + seq_splice_(AllocateSplice()) { + assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height)); + assert(branching_factor > 1 && + kBranching_ == static_cast<uint32_t>(branching_factor)); + assert(kScaledInverseBranching_ > 0); + + for (int i = 0; i < kMaxHeight_; ++i) { + head_->SetNext(i, nullptr); + } +} + +template <class Comparator> +char* InlineSkipList<Comparator>::AllocateKey(size_t key_size) { + return const_cast<char*>(AllocateNode(key_size, RandomHeight())->Key()); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::AllocateNode(size_t key_size, int height) { + auto prefix = sizeof(std::atomic<Node*>) * (height - 1); + + // prefix is space for the height - 1 pointers that we store before + // the Node instance (next_[-(height - 1) .. -1]). Node starts at + // raw + prefix, and holds the bottom-mode (level 0) skip list pointer + // next_[0]. key_size is the bytes for the key, which comes just after + // the Node. + char* raw = allocator_->AllocateAligned(prefix + sizeof(Node) + key_size); + Node* x = reinterpret_cast<Node*>(raw + prefix); + + // Once we've linked the node into the skip list we don't actually need + // to know its height, because we can implicitly use the fact that we + // traversed into a node at level h to known that h is a valid level + // for that node. We need to convey the height to the Insert step, + // however, so that it can perform the proper links. Since we're not + // using the pointers at the moment, StashHeight temporarily borrow + // storage from next_[0] for that purpose. 
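+  //
+  // The resulting layout for a node of height h, from low to high address:
+  //   [ next_[-(h-1)] ... next_[-1] ]   <- prefix: the h-1 upper-level links
+  //   [ Node, i.e. next_[0]         ]   <- the level-0 link
+  //   [ key bytes                   ]   <- key_size bytes; Key() points here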
+ x->StashHeight(height); + return x; +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Splice* +InlineSkipList<Comparator>::AllocateSplice() { + // size of prev_ and next_ + size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1); + char* raw = allocator_->AllocateAligned(sizeof(Splice) + array_size * 2); + Splice* splice = reinterpret_cast<Splice*>(raw); + splice->height_ = 0; + splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice)); + splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size); + return splice; +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Splice* +InlineSkipList<Comparator>::AllocateSpliceOnHeap() { + size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1); + char* raw = new char[sizeof(Splice) + array_size * 2]; + Splice* splice = reinterpret_cast<Splice*>(raw); + splice->height_ = 0; + splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice)); + splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size); + return splice; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::Insert(const char* key) { + return Insert<false>(key, seq_splice_, false); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::InsertConcurrently(const char* key) { + Node* prev[kMaxPossibleHeight]; + Node* next[kMaxPossibleHeight]; + Splice splice; + splice.prev_ = prev; + splice.next_ = next; + return Insert<true>(key, &splice, false); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::InsertWithHint(const char* key, void** hint) { + assert(hint != nullptr); + Splice* splice = reinterpret_cast<Splice*>(*hint); + if (splice == nullptr) { + splice = AllocateSplice(); + *hint = reinterpret_cast<void*>(splice); + } + return Insert<false>(key, splice, true); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::InsertWithHintConcurrently(const char* key, + void** hint) { + assert(hint != nullptr); + Splice* splice = reinterpret_cast<Splice*>(*hint); + if (splice == nullptr) { + splice = AllocateSpliceOnHeap(); + *hint = reinterpret_cast<void*>(splice); + } + return Insert<true>(key, splice, true); +} + +template <class Comparator> +template <bool prefetch_before> +void InlineSkipList<Comparator>::FindSpliceForLevel(const DecodedKey& key, + Node* before, Node* after, + int level, Node** out_prev, + Node** out_next) { + while (true) { + Node* next = before->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + if (prefetch_before == true) { + if (next != nullptr && level > 0) { + PREFETCH(next->Next(level - 1), 0, 1); + } + } + assert(before == head_ || next == nullptr || + KeyIsAfterNode(next->Key(), before)); + assert(before == head_ || KeyIsAfterNode(key, before)); + if (next == after || !KeyIsAfterNode(key, next)) { + // found it + *out_prev = before; + *out_next = next; + return; + } + before = next; + } +} + +template <class Comparator> +void InlineSkipList<Comparator>::RecomputeSpliceLevels(const DecodedKey& key, + Splice* splice, + int recompute_level) { + assert(recompute_level > 0); + assert(recompute_level <= splice->height_); + for (int i = recompute_level - 1; i >= 0; --i) { + FindSpliceForLevel<true>(key, splice->prev_[i + 1], splice->next_[i + 1], i, + &splice->prev_[i], &splice->next_[i]); + } +} + +template <class Comparator> +template <bool UseCAS> +bool InlineSkipList<Comparator>::Insert(const char* key, Splice* splice, + bool allow_partial_splice_fix) { + Node* x = 
reinterpret_cast<Node*>(const_cast<char*>(key)) - 1; + const DecodedKey key_decoded = compare_.decode_key(key); + int height = x->UnstashHeight(); + assert(height >= 1 && height <= kMaxHeight_); + + int max_height = max_height_.load(std::memory_order_relaxed); + while (height > max_height) { + if (max_height_.compare_exchange_weak(max_height, height)) { + // successfully updated it + max_height = height; + break; + } + // else retry, possibly exiting the loop because somebody else + // increased it + } + assert(max_height <= kMaxPossibleHeight); + + int recompute_height = 0; + if (splice->height_ < max_height) { + // Either splice has never been used or max_height has grown since + // last use. We could potentially fix it in the latter case, but + // that is tricky. + splice->prev_[max_height] = head_; + splice->next_[max_height] = nullptr; + splice->height_ = max_height; + recompute_height = max_height; + } else { + // Splice is a valid proper-height splice that brackets some + // key, but does it bracket this one? We need to validate it and + // recompute a portion of the splice (levels 0..recompute_height-1) + // that is a superset of all levels that don't bracket the new key. + // Several choices are reasonable, because we have to balance the work + // saved against the extra comparisons required to validate the Splice. + // + // One strategy is just to recompute all of orig_splice_height if the + // bottom level isn't bracketing. This pessimistically assumes that + // we will either get a perfect Splice hit (increasing sequential + // inserts) or have no locality. + // + // Another strategy is to walk up the Splice's levels until we find + // a level that brackets the key. This strategy lets the Splice + // hint help for other cases: it turns insertion from O(log N) into + // O(log D), where D is the number of nodes in between the key that + // produced the Splice and the current insert (insertion is aided + // whether the new key is before or after the splice). If you have + // a way of using a prefix of the key to map directly to the closest + // Splice out of O(sqrt(N)) Splices and we make it so that splices + // can also be used as hints during read, then we end up with Oshman's + // and Shavit's SkipTrie, which has O(log log N) lookup and insertion + // (compare to O(log N) for skip list). + // + // We control the pessimistic strategy with allow_partial_splice_fix. + // A good strategy is probably to be pessimistic for seq_splice_, + // optimistic if the caller actually went to the work of providing + // a Splice. + while (recompute_height < max_height) { + if (splice->prev_[recompute_height]->Next(recompute_height) != + splice->next_[recompute_height]) { + // splice isn't tight at this level, there must have been some inserts + // to this + // location that didn't update the splice. We might only be a little + // stale, but if + // the splice is very stale it would be O(N) to fix it. We haven't used + // up any of + // our budget of comparisons, so always move up even if we are + // pessimistic about + // our chances of success. 
+ ++recompute_height; + } else if (splice->prev_[recompute_height] != head_ && + !KeyIsAfterNode(key_decoded, + splice->prev_[recompute_height])) { + // key is from before splice + if (allow_partial_splice_fix) { + // skip all levels with the same node without more comparisons + Node* bad = splice->prev_[recompute_height]; + while (splice->prev_[recompute_height] == bad) { + ++recompute_height; + } + } else { + // we're pessimistic, recompute everything + recompute_height = max_height; + } + } else if (KeyIsAfterNode(key_decoded, splice->next_[recompute_height])) { + // key is from after splice + if (allow_partial_splice_fix) { + Node* bad = splice->next_[recompute_height]; + while (splice->next_[recompute_height] == bad) { + ++recompute_height; + } + } else { + recompute_height = max_height; + } + } else { + // this level brackets the key, we won! + break; + } + } + } + assert(recompute_height <= max_height); + if (recompute_height > 0) { + RecomputeSpliceLevels(key_decoded, splice, recompute_height); + } + + bool splice_is_valid = true; + if (UseCAS) { + for (int i = 0; i < height; ++i) { + while (true) { + // Checking for duplicate keys on the level 0 is sufficient + if (UNLIKELY(i == 0 && splice->next_[i] != nullptr && + compare_(x->Key(), splice->next_[i]->Key()) >= 0)) { + // duplicate key + return false; + } + if (UNLIKELY(i == 0 && splice->prev_[i] != head_ && + compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) { + // duplicate key + return false; + } + assert(splice->next_[i] == nullptr || + compare_(x->Key(), splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), x->Key()) < 0); + x->NoBarrier_SetNext(i, splice->next_[i]); + if (splice->prev_[i]->CASNext(i, splice->next_[i], x)) { + // success + break; + } + // CAS failed, we need to recompute prev and next. It is unlikely + // to be helpful to try to use a different level as we redo the + // search, because it should be unlikely that lots of nodes have + // been inserted between prev[i] and next[i]. No point in using + // next[i] as the after hint, because we know it is stale. + FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i, + &splice->prev_[i], &splice->next_[i]); + + // Since we've narrowed the bracket for level i, we might have + // violated the Splice constraint between i and i-1. Make sure + // we recompute the whole thing next time. 
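+        // (The Splice invariant requires prev_[i+1].key <= prev_[i].key <
+        // next_[i].key <= next_[i+1].key; tightening the level-i bracket
+        // can break that ordering against level i-1, so the splice must be
+        // fully recomputed before it is trusted again.)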
+ if (i > 0) { + splice_is_valid = false; + } + } + } + } else { + for (int i = 0; i < height; ++i) { + if (i >= recompute_height && + splice->prev_[i]->Next(i) != splice->next_[i]) { + FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i, + &splice->prev_[i], &splice->next_[i]); + } + // Checking for duplicate keys on the level 0 is sufficient + if (UNLIKELY(i == 0 && splice->next_[i] != nullptr && + compare_(x->Key(), splice->next_[i]->Key()) >= 0)) { + // duplicate key + return false; + } + if (UNLIKELY(i == 0 && splice->prev_[i] != head_ && + compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) { + // duplicate key + return false; + } + assert(splice->next_[i] == nullptr || + compare_(x->Key(), splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), x->Key()) < 0); + assert(splice->prev_[i]->Next(i) == splice->next_[i]); + x->NoBarrier_SetNext(i, splice->next_[i]); + splice->prev_[i]->SetNext(i, x); + } + } + if (splice_is_valid) { + for (int i = 0; i < height; ++i) { + splice->prev_[i] = x; + } + assert(splice->prev_[splice->height_] == head_); + assert(splice->next_[splice->height_] == nullptr); + for (int i = 0; i < splice->height_; ++i) { + assert(splice->next_[i] == nullptr || + compare_(key, splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), key) <= 0); + assert(splice->prev_[i + 1] == splice->prev_[i] || + splice->prev_[i + 1] == head_ || + compare_(splice->prev_[i + 1]->Key(), splice->prev_[i]->Key()) < + 0); + assert(splice->next_[i + 1] == splice->next_[i] || + splice->next_[i + 1] == nullptr || + compare_(splice->next_[i]->Key(), splice->next_[i + 1]->Key()) < + 0); + } + } else { + splice->height_ = 0; + } + return true; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::Contains(const char* key) const { + Node* x = FindGreaterOrEqual(key); + if (x != nullptr && Equal(key, x->Key())) { + return true; + } else { + return false; + } +} + +template <class Comparator> +void InlineSkipList<Comparator>::TEST_Validate() const { + // Interate over all levels at the same time, and verify nodes appear in + // the right order, and nodes appear in upper level also appear in lower + // levels. + Node* nodes[kMaxPossibleHeight]; + int max_height = GetMaxHeight(); + assert(max_height > 0); + for (int i = 0; i < max_height; i++) { + nodes[i] = head_; + } + while (nodes[0] != nullptr) { + Node* l0_next = nodes[0]->Next(0); + if (l0_next == nullptr) { + break; + } + assert(nodes[0] == head_ || compare_(nodes[0]->Key(), l0_next->Key()) < 0); + nodes[0] = l0_next; + + int i = 1; + while (i < max_height) { + Node* next = nodes[i]->Next(i); + if (next == nullptr) { + break; + } + auto cmp = compare_(nodes[0]->Key(), next->Key()); + assert(cmp <= 0); + if (cmp == 0) { + assert(next == nodes[0]); + nodes[i] = next; + } else { + break; + } + i++; + } + } + for (int i = 1; i < max_height; i++) { + assert(nodes[i] != nullptr && nodes[i]->Next(i) == nullptr); + } +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/memtable/inlineskiplist_test.cc b/src/rocksdb/memtable/inlineskiplist_test.cc new file mode 100644 index 000000000..f85644064 --- /dev/null +++ b/src/rocksdb/memtable/inlineskiplist_test.cc @@ -0,0 +1,664 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. 
+// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "memtable/inlineskiplist.h" + +#include <set> +#include <unordered_set> + +#include "memory/concurrent_arena.h" +#include "rocksdb/env.h" +#include "test_util/testharness.h" +#include "util/hash.h" +#include "util/random.h" + +namespace ROCKSDB_NAMESPACE { + +// Our test skip list stores 8-byte unsigned integers +using Key = uint64_t; + +static const char* Encode(const uint64_t* key) { + return reinterpret_cast<const char*>(key); +} + +static Key Decode(const char* key) { + Key rv; + memcpy(&rv, key, sizeof(Key)); + return rv; +} + +struct TestComparator { + using DecodedType = Key; + + static DecodedType decode_key(const char* b) { return Decode(b); } + + int operator()(const char* a, const char* b) const { + if (Decode(a) < Decode(b)) { + return -1; + } else if (Decode(a) > Decode(b)) { + return +1; + } else { + return 0; + } + } + + int operator()(const char* a, const DecodedType b) const { + if (Decode(a) < b) { + return -1; + } else if (Decode(a) > b) { + return +1; + } else { + return 0; + } + } +}; + +using TestInlineSkipList = InlineSkipList<TestComparator>; + +class InlineSkipTest : public testing::Test { + public: + void Insert(TestInlineSkipList* list, Key key) { + char* buf = list->AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + list->Insert(buf); + keys_.insert(key); + } + + bool InsertWithHint(TestInlineSkipList* list, Key key, void** hint) { + char* buf = list->AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + bool res = list->InsertWithHint(buf, hint); + keys_.insert(key); + return res; + } + + void Validate(TestInlineSkipList* list) { + // Check keys exist. + for (Key key : keys_) { + ASSERT_TRUE(list->Contains(Encode(&key))); + } + // Iterate over the list, make sure keys appears in order and no extra + // keys exist. + TestInlineSkipList::Iterator iter(list); + ASSERT_FALSE(iter.Valid()); + Key zero = 0; + iter.Seek(Encode(&zero)); + for (Key key : keys_) { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(key, Decode(iter.key())); + iter.Next(); + } + ASSERT_FALSE(iter.Valid()); + // Validate the list is well-formed. 
+ list->TEST_Validate(); + } + + private: + std::set<Key> keys_; +}; + +TEST_F(InlineSkipTest, Empty) { + Arena arena; + TestComparator cmp; + InlineSkipList<TestComparator> list(cmp, &arena); + Key key = 10; + ASSERT_TRUE(!list.Contains(Encode(&key))); + + InlineSkipList<TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToFirst(); + ASSERT_TRUE(!iter.Valid()); + key = 100; + iter.Seek(Encode(&key)); + ASSERT_TRUE(!iter.Valid()); + iter.SeekForPrev(Encode(&key)); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToLast(); + ASSERT_TRUE(!iter.Valid()); +} + +TEST_F(InlineSkipTest, InsertAndLookup) { + const int N = 2000; + const int R = 5000; + Random rnd(1000); + std::set<Key> keys; + ConcurrentArena arena; + TestComparator cmp; + InlineSkipList<TestComparator> list(cmp, &arena); + for (int i = 0; i < N; i++) { + Key key = rnd.Next() % R; + if (keys.insert(key).second) { + char* buf = list.AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + list.Insert(buf); + } + } + + for (Key i = 0; i < R; i++) { + if (list.Contains(Encode(&i))) { + ASSERT_EQ(keys.count(i), 1U); + } else { + ASSERT_EQ(keys.count(i), 0U); + } + } + + // Simple iterator tests + { + InlineSkipList<TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + + uint64_t zero = 0; + iter.Seek(Encode(&zero)); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), Decode(iter.key())); + + uint64_t max_key = R - 1; + iter.SeekForPrev(Encode(&max_key)); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), Decode(iter.key())); + + iter.SeekToFirst(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), Decode(iter.key())); + + iter.SeekToLast(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), Decode(iter.key())); + } + + // Forward iteration test + for (Key i = 0; i < R; i++) { + InlineSkipList<TestComparator>::Iterator iter(&list); + iter.Seek(Encode(&i)); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.lower_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.end()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*model_iter, Decode(iter.key())); + ++model_iter; + iter.Next(); + } + } + } + + // Backward iteration test + for (Key i = 0; i < R; i++) { + InlineSkipList<TestComparator>::Iterator iter(&list); + iter.SeekForPrev(Encode(&i)); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.upper_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.begin()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*--model_iter, Decode(iter.key())); + iter.Prev(); + } + } + } +} + +TEST_F(InlineSkipTest, InsertWithHint_Sequential) { + const int N = 100000; + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hint = nullptr; + for (int i = 0; i < N; i++) { + Key key = i; + InsertWithHint(&list, key, &hint); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_MultipleHints) { + const int N = 100000; + const int S = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hints[S]; + Key last_key[S]; + for (int i = 0; i < S; i++) { + hints[i] = nullptr; + last_key[i] = 0; + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S); + Key key = (s << 32) + (++last_key[s]); + InsertWithHint(&list, key, &hints[s]); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_MultipleHintsRandom) { + 
const int N = 100000; + const int S = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hints[S]; + for (int i = 0; i < S; i++) { + hints[i] = nullptr; + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S); + Key key = (s << 32) + rnd.Next(); + InsertWithHint(&list, key, &hints[s]); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_CompatibleWithInsertWithoutHint) { + const int N = 100000; + const int S1 = 100; + const int S2 = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + std::unordered_set<Key> used; + Key with_hint[S1]; + Key without_hint[S2]; + void* hints[S1]; + for (int i = 0; i < S1; i++) { + hints[i] = nullptr; + while (true) { + Key s = rnd.Next(); + if (used.insert(s).second) { + with_hint[i] = s; + break; + } + } + } + for (int i = 0; i < S2; i++) { + while (true) { + Key s = rnd.Next(); + if (used.insert(s).second) { + without_hint[i] = s; + break; + } + } + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S1 + S2); + if (s < S1) { + Key key = (with_hint[s] << 32) + rnd.Next(); + InsertWithHint(&list, key, &hints[s]); + } else { + Key key = (without_hint[s - S1] << 32) + rnd.Next(); + Insert(&list, key); + } + } + Validate(&list); +} + +#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN) +// We want to make sure that with a single writer and multiple +// concurrent readers (with no synchronization other than when a +// reader's iterator is created), the reader always observes all the +// data that was present in the skip list when the iterator was +// constructor. Because insertions are happening concurrently, we may +// also observe new values that were inserted since the iterator was +// constructed, but we should never miss any values that were present +// at iterator construction time. +// +// We generate multi-part keys: +// <key,gen,hash> +// where: +// key is in range [0..K-1] +// gen is a generation number for key +// hash is hash(key,gen) +// +// The insertion code picks a random key, sets gen to be 1 + the last +// generation number inserted for that key, and sets hash to Hash(key,gen). +// +// At the beginning of a read, we snapshot the last inserted +// generation number for each key. We then iterate, including random +// calls to Next() and Seek(). For every key we encounter, we +// check that it is either expected given the initial snapshot or has +// been concurrently added since the iterator started. 
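+//
+// Concretely, the three parts are packed into one 64-bit Key (see MakeKey()
+// below):
+//   bits 40..63  key   -- the logical key, 0..K-1 (K itself is used as an
+//                         end-of-list seek target)
+//   bits  8..39  gen   -- generation number for that key
+//   bits  0..7   hash  -- low 8 bits of a hash of (key, gen), so readers can
+//                         verify via IsValidKey() that every key they observe
+//                         is well formed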
+class ConcurrentTest { + public: + static const uint32_t K = 8; + + private: + static uint64_t key(Key key) { return (key >> 40); } + static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; } + static uint64_t hash(Key key) { return key & 0xff; } + + static uint64_t HashNumbers(uint64_t k, uint64_t g) { + uint64_t data[2] = {k, g}; + return Hash(reinterpret_cast<char*>(data), sizeof(data), 0); + } + + static Key MakeKey(uint64_t k, uint64_t g) { + assert(sizeof(Key) == sizeof(uint64_t)); + assert(k <= K); // We sometimes pass K to seek to the end of the skiplist + assert(g <= 0xffffffffu); + return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff)); + } + + static bool IsValidKey(Key k) { + return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff); + } + + static Key RandomTarget(Random* rnd) { + switch (rnd->Next() % 10) { + case 0: + // Seek to beginning + return MakeKey(0, 0); + case 1: + // Seek to end + return MakeKey(K, 0); + default: + // Seek to middle + return MakeKey(rnd->Next() % K, 0); + } + } + + // Per-key generation + struct State { + std::atomic<int> generation[K]; + void Set(int k, int v) { + generation[k].store(v, std::memory_order_release); + } + int Get(int k) { return generation[k].load(std::memory_order_acquire); } + + State() { + for (unsigned int k = 0; k < K; k++) { + Set(k, 0); + } + } + }; + + // Current state of the test + State current_; + + ConcurrentArena arena_; + + // InlineSkipList is not protected by mu_. We just use a single writer + // thread to modify it. + InlineSkipList<TestComparator> list_; + + public: + ConcurrentTest() : list_(TestComparator(), &arena_) {} + + // REQUIRES: No concurrent calls to WriteStep or ConcurrentWriteStep + void WriteStep(Random* rnd) { + const uint32_t k = rnd->Next() % K; + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + char* buf = list_.AllocateKey(sizeof(Key)); + memcpy(buf, &new_key, sizeof(Key)); + list_.Insert(buf); + current_.Set(k, g); + } + + // REQUIRES: No concurrent calls for the same k + void ConcurrentWriteStep(uint32_t k, bool use_hint = false) { + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + char* buf = list_.AllocateKey(sizeof(Key)); + memcpy(buf, &new_key, sizeof(Key)); + if (use_hint) { + void* hint = nullptr; + list_.InsertWithHintConcurrently(buf, &hint); + delete[] reinterpret_cast<char*>(hint); + } else { + list_.InsertConcurrently(buf); + } + ASSERT_EQ(g, current_.Get(k) + 1); + current_.Set(k, g); + } + + void ReadStep(Random* rnd) { + // Remember the initial committed state of the skiplist. + State initial_state; + for (unsigned int k = 0; k < K; k++) { + initial_state.Set(k, current_.Get(k)); + } + + Key pos = RandomTarget(rnd); + InlineSkipList<TestComparator>::Iterator iter(&list_); + iter.Seek(Encode(&pos)); + while (true) { + Key current; + if (!iter.Valid()) { + current = MakeKey(K, 0); + } else { + current = Decode(iter.key()); + ASSERT_TRUE(IsValidKey(current)) << current; + } + ASSERT_LE(pos, current) << "should not go backwards"; + + // Verify that everything in [pos,current) was not present in + // initial_state. + while (pos < current) { + ASSERT_LT(key(pos), K) << pos; + + // Note that generation 0 is never inserted, so it is ok if + // <*,0,*> is missing. 
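+        // Any other generation we skipped over must be newer than the
+        // snapshot taken when the iterator was created; otherwise the
+        // iterator missed an entry that was already in the list.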
+ ASSERT_TRUE((gen(pos) == 0U) || + (gen(pos) > static_cast<uint64_t>(initial_state.Get( + static_cast<int>(key(pos)))))) + << "key: " << key(pos) << "; gen: " << gen(pos) + << "; initgen: " << initial_state.Get(static_cast<int>(key(pos))); + + // Advance to next key in the valid key space + if (key(pos) < key(current)) { + pos = MakeKey(key(pos) + 1, 0); + } else { + pos = MakeKey(key(pos), gen(pos) + 1); + } + } + + if (!iter.Valid()) { + break; + } + + if (rnd->Next() % 2) { + iter.Next(); + pos = MakeKey(key(pos), gen(pos) + 1); + } else { + Key new_target = RandomTarget(rnd); + if (new_target > pos) { + pos = new_target; + iter.Seek(Encode(&new_target)); + } + } + } + } +}; +const uint32_t ConcurrentTest::K; + +// Simple test that does single-threaded testing of the ConcurrentTest +// scaffolding. +TEST_F(InlineSkipTest, ConcurrentReadWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + test.WriteStep(&rnd); + } +} + +TEST_F(InlineSkipTest, ConcurrentInsertWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + uint32_t base = rnd.Next(); + for (int j = 0; j < 4; ++j) { + test.ConcurrentWriteStep((base + j) % ConcurrentTest::K); + } + } +} + +class TestState { + public: + ConcurrentTest t_; + bool use_hint_; + int seed_; + std::atomic<bool> quit_flag_; + std::atomic<uint32_t> next_writer_; + + enum ReaderState { STARTING, RUNNING, DONE }; + + explicit TestState(int s) + : seed_(s), + quit_flag_(false), + state_(STARTING), + pending_writers_(0), + state_cv_(&mu_) {} + + void Wait(ReaderState s) { + mu_.Lock(); + while (state_ != s) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + void Change(ReaderState s) { + mu_.Lock(); + state_ = s; + state_cv_.Signal(); + mu_.Unlock(); + } + + void AdjustPendingWriters(int delta) { + mu_.Lock(); + pending_writers_ += delta; + if (pending_writers_ == 0) { + state_cv_.Signal(); + } + mu_.Unlock(); + } + + void WaitForPendingWriters() { + mu_.Lock(); + while (pending_writers_ != 0) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + private: + port::Mutex mu_; + ReaderState state_; + int pending_writers_; + port::CondVar state_cv_; +}; + +static void ConcurrentReader(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + Random rnd(state->seed_); + int64_t reads = 0; + state->Change(TestState::RUNNING); + while (!state->quit_flag_.load(std::memory_order_acquire)) { + state->t_.ReadStep(&rnd); + ++reads; + } + state->Change(TestState::DONE); +} + +static void ConcurrentWriter(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + uint32_t k = state->next_writer_++ % ConcurrentTest::K; + state->t_.ConcurrentWriteStep(k, state->use_hint_); + state->AdjustPendingWriters(-1); +} + +static void RunConcurrentRead(int run) { + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + Env::Default()->SetBackgroundThreads(1); + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; ++k) { + state.t_.WriteStep(&rnd); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +static void RunConcurrentInsert(int run, bool use_hint = false, + int write_parallelism = 4) { + 
Env::Default()->SetBackgroundThreads(1 + write_parallelism, + Env::Priority::LOW); + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + state.use_hint_ = use_hint; + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; k += write_parallelism) { + state.next_writer_ = rnd.Next(); + state.AdjustPendingWriters(write_parallelism); + for (int p = 0; p < write_parallelism; ++p) { + Env::Default()->Schedule(ConcurrentWriter, &state); + } + state.WaitForPendingWriters(); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +TEST_F(InlineSkipTest, ConcurrentRead1) { RunConcurrentRead(1); } +TEST_F(InlineSkipTest, ConcurrentRead2) { RunConcurrentRead(2); } +TEST_F(InlineSkipTest, ConcurrentRead3) { RunConcurrentRead(3); } +TEST_F(InlineSkipTest, ConcurrentRead4) { RunConcurrentRead(4); } +TEST_F(InlineSkipTest, ConcurrentRead5) { RunConcurrentRead(5); } +TEST_F(InlineSkipTest, ConcurrentInsert1) { RunConcurrentInsert(1); } +TEST_F(InlineSkipTest, ConcurrentInsert2) { RunConcurrentInsert(2); } +TEST_F(InlineSkipTest, ConcurrentInsert3) { RunConcurrentInsert(3); } +TEST_F(InlineSkipTest, ConcurrentInsertWithHint1) { + RunConcurrentInsert(1, true); +} +TEST_F(InlineSkipTest, ConcurrentInsertWithHint2) { + RunConcurrentInsert(2, true); +} +TEST_F(InlineSkipTest, ConcurrentInsertWithHint3) { + RunConcurrentInsert(3, true); +} + +#endif // !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN) +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/rocksdb/memtable/memtablerep_bench.cc b/src/rocksdb/memtable/memtablerep_bench.cc new file mode 100644 index 000000000..a915abed7 --- /dev/null +++ b/src/rocksdb/memtable/memtablerep_bench.cc @@ -0,0 +1,689 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
+ +#ifndef GFLAGS +#include <cstdio> +int main() { + fprintf(stderr, "Please install gflags to run rocksdb tools\n"); + return 1; +} +#else + +#include <atomic> +#include <iostream> +#include <memory> +#include <thread> +#include <type_traits> +#include <vector> + +#include "db/dbformat.h" +#include "db/memtable.h" +#include "memory/arena.h" +#include "port/port.h" +#include "port/stack_trace.h" +#include "rocksdb/comparator.h" +#include "rocksdb/convenience.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/options.h" +#include "rocksdb/slice_transform.h" +#include "rocksdb/system_clock.h" +#include "rocksdb/write_buffer_manager.h" +#include "test_util/testutil.h" +#include "util/gflags_compat.h" +#include "util/mutexlock.h" +#include "util/stop_watch.h" + +using GFLAGS_NAMESPACE::ParseCommandLineFlags; +using GFLAGS_NAMESPACE::RegisterFlagValidator; +using GFLAGS_NAMESPACE::SetUsageMessage; + +DEFINE_string(benchmarks, "fillrandom", + "Comma-separated list of benchmarks to run. Options:\n" + "\tfillrandom -- write N random values\n" + "\tfillseq -- write N values in sequential order\n" + "\treadrandom -- read N values in random order\n" + "\treadseq -- scan the DB\n" + "\treadwrite -- 1 thread writes while N - 1 threads " + "do random\n" + "\t reads\n" + "\tseqreadwrite -- 1 thread writes while N - 1 threads " + "do scans\n"); + +DEFINE_string(memtablerep, "skiplist", + "Which implementation of memtablerep to use. See " + "include/memtablerep.h for\n" + " more details. Options:\n" + "\tskiplist -- backed by a skiplist\n" + "\tvector -- backed by an std::vector\n" + "\thashskiplist -- backed by a hash skip list\n" + "\thashlinklist -- backed by a hash linked list\n" + "\tcuckoo -- backed by a cuckoo hash table"); + +DEFINE_int64(bucket_count, 1000000, + "bucket_count parameter to pass into NewHashSkiplistRepFactory or " + "NewHashLinkListRepFactory"); + +DEFINE_int32( + hashskiplist_height, 4, + "skiplist_height parameter to pass into NewHashSkiplistRepFactory"); + +DEFINE_int32( + hashskiplist_branching_factor, 4, + "branching_factor parameter to pass into NewHashSkiplistRepFactory"); + +DEFINE_int32( + huge_page_tlb_size, 0, + "huge_page_tlb_size parameter to pass into NewHashLinkListRepFactory"); + +DEFINE_int32(bucket_entries_logging_threshold, 4096, + "bucket_entries_logging_threshold parameter to pass into " + "NewHashLinkListRepFactory"); + +DEFINE_bool(if_log_bucket_dist_when_flash, true, + "if_log_bucket_dist_when_flash parameter to pass into " + "NewHashLinkListRepFactory"); + +DEFINE_int32( + threshold_use_skiplist, 256, + "threshold_use_skiplist parameter to pass into NewHashLinkListRepFactory"); + +DEFINE_int64(write_buffer_size, 256, + "write_buffer_size parameter to pass into WriteBufferManager"); + +DEFINE_int32( + num_threads, 1, + "Number of concurrent threads to run. If the benchmark includes writes,\n" + "then at most one thread will be a writer"); + +DEFINE_int32(num_operations, 1000000, + "Number of operations to do for write and random read benchmarks"); + +DEFINE_int32(num_scans, 10, + "Number of times for each thread to scan the memtablerep for " + "sequential read " + "benchmarks"); + +DEFINE_int32(item_size, 100, "Number of bytes each item should be"); + +DEFINE_int32(prefix_length, 8, + "Prefix length to pass into NewFixedPrefixTransform"); + +/* VectorRep settings */ +DEFINE_int64(vectorrep_count, 0, + "Number of entries to reserve on VectorRep initialization"); + +DEFINE_int64(seed, 0, + "Seed base for random number generators. 
" + "When 0 it is deterministic."); + +namespace ROCKSDB_NAMESPACE { + +namespace { +struct CallbackVerifyArgs { + bool found; + LookupKey* key; + MemTableRep* table; + InternalKeyComparator* comparator; +}; +} // namespace + +// Helper for quickly generating random data. +class RandomGenerator { + private: + std::string data_; + unsigned int pos_; + + public: + RandomGenerator() { + Random rnd(301); + auto size = (unsigned)std::max(1048576, FLAGS_item_size); + data_ = rnd.RandomString(size); + pos_ = 0; + } + + Slice Generate(unsigned int len) { + assert(len <= data_.size()); + if (pos_ + len > data_.size()) { + pos_ = 0; + } + pos_ += len; + return Slice(data_.data() + pos_ - len, len); + } +}; + +enum WriteMode { SEQUENTIAL, RANDOM, UNIQUE_RANDOM }; + +class KeyGenerator { + public: + KeyGenerator(Random64* rand, WriteMode mode, uint64_t num) + : rand_(rand), mode_(mode), num_(num), next_(0) { + if (mode_ == UNIQUE_RANDOM) { + // NOTE: if memory consumption of this approach becomes a concern, + // we can either break it into pieces and only random shuffle a section + // each time. Alternatively, use a bit map implementation + // (https://reviews.facebook.net/differential/diff/54627/) + values_.resize(num_); + for (uint64_t i = 0; i < num_; ++i) { + values_[i] = i; + } + RandomShuffle(values_.begin(), values_.end(), + static_cast<uint32_t>(FLAGS_seed)); + } + } + + uint64_t Next() { + switch (mode_) { + case SEQUENTIAL: + return next_++; + case RANDOM: + return rand_->Next() % num_; + case UNIQUE_RANDOM: + return values_[next_++]; + } + assert(false); + return std::numeric_limits<uint64_t>::max(); + } + + private: + Random64* rand_; + WriteMode mode_; + const uint64_t num_; + uint64_t next_; + std::vector<uint64_t> values_; +}; + +class BenchmarkThread { + public: + explicit BenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits) + : table_(table), + key_gen_(key_gen), + bytes_written_(bytes_written), + bytes_read_(bytes_read), + sequence_(sequence), + num_ops_(num_ops), + read_hits_(read_hits) {} + + virtual void operator()() = 0; + virtual ~BenchmarkThread() {} + + protected: + MemTableRep* table_; + KeyGenerator* key_gen_; + uint64_t* bytes_written_; + uint64_t* bytes_read_; + uint64_t* sequence_; + uint64_t num_ops_; + uint64_t* read_hits_; + RandomGenerator generator_; +}; + +class FillBenchmarkThread : public BenchmarkThread { + public: + FillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + void FillOne() { + char* buf = nullptr; + auto internal_key_size = 16; + auto encoded_len = + FLAGS_item_size + VarintLength(internal_key_size) + internal_key_size; + KeyHandle handle = table_->Allocate(encoded_len, &buf); + assert(buf != nullptr); + char* p = EncodeVarint32(buf, internal_key_size); + auto key = key_gen_->Next(); + EncodeFixed64(p, key); + p += 8; + EncodeFixed64(p, ++(*sequence_)); + p += 8; + Slice bytes = generator_.Generate(FLAGS_item_size); + memcpy(p, bytes.data(), FLAGS_item_size); + p += FLAGS_item_size; + assert(p == buf + encoded_len); + table_->Insert(handle); + *bytes_written_ += encoded_len; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + FillOne(); + } + } +}; + +class 
ConcurrentFillBenchmarkThread : public FillBenchmarkThread { + public: + ConcurrentFillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits, + std::atomic_int* threads_done) + : FillBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + // # of read threads will be total threads - write threads (always 1). Loop + // while all reads complete. + while ((*threads_done_).load() < (FLAGS_num_threads - 1)) { + FillOne(); + } + } + + private: + std::atomic_int* threads_done_; +}; + +class ReadBenchmarkThread : public BenchmarkThread { + public: + ReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + static bool callback(void* arg, const char* entry) { + CallbackVerifyArgs* callback_args = static_cast<CallbackVerifyArgs*>(arg); + assert(callback_args != nullptr); + uint32_t key_length; + const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); + if ((callback_args->comparator) + ->user_comparator() + ->Equal(Slice(key_ptr, key_length - 8), + callback_args->key->user_key())) { + callback_args->found = true; + } + return false; + } + + void ReadOne() { + std::string user_key; + auto key = key_gen_->Next(); + PutFixed64(&user_key, key); + LookupKey lookup_key(user_key, *sequence_); + InternalKeyComparator internal_key_comp(BytewiseComparator()); + CallbackVerifyArgs verify_args; + verify_args.found = false; + verify_args.key = &lookup_key; + verify_args.table = table_; + verify_args.comparator = &internal_key_comp; + table_->Get(lookup_key, &verify_args, callback); + if (verify_args.found) { + *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size; + ++*read_hits_; + } + } + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOne(); + } + } +}; + +class SeqReadBenchmarkThread : public BenchmarkThread { + public: + SeqReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + void ReadOneSeq() { + std::unique_ptr<MemTableRep::Iterator> iter(table_->GetIterator()); + for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { + // pretend to read the value + *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size; + } + ++*read_hits_; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + { ReadOneSeq(); } + } + } +}; + +class ConcurrentReadBenchmarkThread : public ReadBenchmarkThread { + public: + ConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits, + std::atomic_int* threads_done) + : ReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOne(); + } + ++*threads_done_; + } + + private: + std::atomic_int* threads_done_; +}; + +class SeqConcurrentReadBenchmarkThread : public 
SeqReadBenchmarkThread { + public: + SeqConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, + uint64_t* bytes_read, uint64_t* sequence, + uint64_t num_ops, uint64_t* read_hits, + std::atomic_int* threads_done) + : SeqReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, + sequence, num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOneSeq(); + } + ++*threads_done_; + } + + private: + std::atomic_int* threads_done_; +}; + +class Benchmark { + public: + explicit Benchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence, uint32_t num_threads) + : table_(table), + key_gen_(key_gen), + sequence_(sequence), + num_threads_(num_threads) {} + + virtual ~Benchmark() {} + virtual void Run() { + std::cout << "Number of threads: " << num_threads_ << std::endl; + std::vector<port::Thread> threads; + uint64_t bytes_written = 0; + uint64_t bytes_read = 0; + uint64_t read_hits = 0; + StopWatchNano timer(SystemClock::Default().get(), true); + RunThreads(&threads, &bytes_written, &bytes_read, true, &read_hits); + auto elapsed_time = static_cast<double>(timer.ElapsedNanos() / 1000); + std::cout << "Elapsed time: " << static_cast<int>(elapsed_time) << " us" + << std::endl; + + if (bytes_written > 0) { + auto MiB_written = static_cast<double>(bytes_written) / (1 << 20); + auto write_throughput = MiB_written / (elapsed_time / 1000000); + std::cout << "Total bytes written: " << MiB_written << " MiB" + << std::endl; + std::cout << "Write throughput: " << write_throughput << " MiB/s" + << std::endl; + auto us_per_op = elapsed_time / num_write_ops_per_thread_; + std::cout << "write us/op: " << us_per_op << std::endl; + } + if (bytes_read > 0) { + auto MiB_read = static_cast<double>(bytes_read) / (1 << 20); + auto read_throughput = MiB_read / (elapsed_time / 1000000); + std::cout << "Total bytes read: " << MiB_read << " MiB" << std::endl; + std::cout << "Read throughput: " << read_throughput << " MiB/s" + << std::endl; + auto us_per_op = elapsed_time / num_read_ops_per_thread_; + std::cout << "read us/op: " << us_per_op << std::endl; + } + } + + virtual void RunThreads(std::vector<port::Thread>* threads, + uint64_t* bytes_written, uint64_t* bytes_read, + bool write, uint64_t* read_hits) = 0; + + protected: + MemTableRep* table_; + KeyGenerator* key_gen_; + uint64_t* sequence_; + uint64_t num_write_ops_per_thread_ = 0; + uint64_t num_read_ops_per_thread_ = 0; + const uint32_t num_threads_; +}; + +class FillBenchmark : public Benchmark { + public: + explicit FillBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, 1) { + num_write_ops_per_thread_ = FLAGS_num_operations; + } + + void RunThreads(std::vector<port::Thread>* /*threads*/, + uint64_t* bytes_written, uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + FillBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, sequence_, + num_write_ops_per_thread_, read_hits)(); + } +}; + +class ReadBenchmark : public Benchmark { + public: + explicit ReadBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = FLAGS_num_operations / FLAGS_num_threads; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + for (int 
i = 0; i < FLAGS_num_threads; ++i) { + threads->emplace_back( + ReadBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, + sequence_, num_read_ops_per_thread_, read_hits)); + } + for (auto& thread : *threads) { + thread.join(); + } + std::cout << "read hit%: " + << (static_cast<double>(*read_hits) / FLAGS_num_operations) * 100 + << std::endl; + } +}; + +class SeqReadBenchmark : public Benchmark { + public: + explicit SeqReadBenchmark(MemTableRep* table, uint64_t* sequence) + : Benchmark(table, nullptr, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = FLAGS_num_scans; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + for (int i = 0; i < FLAGS_num_threads; ++i) { + threads->emplace_back(SeqReadBenchmarkThread( + table_, key_gen_, bytes_written, bytes_read, sequence_, + num_read_ops_per_thread_, read_hits)); + } + for (auto& thread : *threads) { + thread.join(); + } + } +}; + +template <class ReadThreadType> +class ReadWriteBenchmark : public Benchmark { + public: + explicit ReadWriteBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = + FLAGS_num_threads <= 1 + ? 0 + : (FLAGS_num_operations / (FLAGS_num_threads - 1)); + num_write_ops_per_thread_ = FLAGS_num_operations; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + std::atomic_int threads_done; + threads_done.store(0); + threads->emplace_back(ConcurrentFillBenchmarkThread( + table_, key_gen_, bytes_written, bytes_read, sequence_, + num_write_ops_per_thread_, read_hits, &threads_done)); + for (int i = 1; i < FLAGS_num_threads; ++i) { + threads->emplace_back( + ReadThreadType(table_, key_gen_, bytes_written, bytes_read, sequence_, + num_read_ops_per_thread_, read_hits, &threads_done)); + } + for (auto& thread : *threads) { + thread.join(); + } + } +}; + +} // namespace ROCKSDB_NAMESPACE + +void PrintWarnings() { +#if defined(__GNUC__) && !defined(__OPTIMIZE__) + fprintf(stdout, + "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n"); +#endif +#ifndef NDEBUG + fprintf(stdout, + "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n"); +#endif +} + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) + + " [OPTIONS]..."); + ParseCommandLineFlags(&argc, &argv, true); + + PrintWarnings(); + + ROCKSDB_NAMESPACE::Options options; + + std::unique_ptr<ROCKSDB_NAMESPACE::MemTableRepFactory> factory; + if (FLAGS_memtablerep == "skiplist") { + factory.reset(new ROCKSDB_NAMESPACE::SkipListFactory); +#ifndef ROCKSDB_LITE + } else if (FLAGS_memtablerep == "vector") { + factory.reset(new ROCKSDB_NAMESPACE::VectorRepFactory); + } else if (FLAGS_memtablerep == "hashskiplist" || + FLAGS_memtablerep == "prefix_hash") { + factory.reset(ROCKSDB_NAMESPACE::NewHashSkipListRepFactory( + FLAGS_bucket_count, FLAGS_hashskiplist_height, + FLAGS_hashskiplist_branching_factor)); + options.prefix_extractor.reset( + ROCKSDB_NAMESPACE::NewFixedPrefixTransform(FLAGS_prefix_length)); + } else if (FLAGS_memtablerep == "hashlinklist" || + FLAGS_memtablerep == "hash_linkedlist") { + factory.reset(ROCKSDB_NAMESPACE::NewHashLinkListRepFactory( + FLAGS_bucket_count, FLAGS_huge_page_tlb_size, + 
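// Illustrative invocation sketch (the gflags used throughout this benchmark
// map one-to-one onto command-line flags; the binary name and the specific
// values below are assumptions, not taken from this file):
//
//   ./memtablerep_bench --memtablerep=hash_linkedlist --prefix_length=4 \
//       --benchmarks=fillrandom,readrandom --num_operations=1000000 \
//       --num_threads=4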
FLAGS_bucket_entries_logging_threshold, + FLAGS_if_log_bucket_dist_when_flash, FLAGS_threshold_use_skiplist)); + options.prefix_extractor.reset( + ROCKSDB_NAMESPACE::NewFixedPrefixTransform(FLAGS_prefix_length)); +#endif // ROCKSDB_LITE + } else { + ROCKSDB_NAMESPACE::ConfigOptions config_options; + config_options.ignore_unsupported_options = false; + + ROCKSDB_NAMESPACE::Status s = + ROCKSDB_NAMESPACE::MemTableRepFactory::CreateFromString( + config_options, FLAGS_memtablerep, &factory); + if (!s.ok()) { + fprintf(stdout, "Unknown memtablerep: %s\n", s.ToString().c_str()); + exit(1); + } + } + + ROCKSDB_NAMESPACE::InternalKeyComparator internal_key_comp( + ROCKSDB_NAMESPACE::BytewiseComparator()); + ROCKSDB_NAMESPACE::MemTable::KeyComparator key_comp(internal_key_comp); + ROCKSDB_NAMESPACE::Arena arena; + ROCKSDB_NAMESPACE::WriteBufferManager wb(FLAGS_write_buffer_size); + uint64_t sequence; + auto createMemtableRep = [&] { + sequence = 0; + return factory->CreateMemTableRep(key_comp, &arena, + options.prefix_extractor.get(), + options.info_log.get()); + }; + std::unique_ptr<ROCKSDB_NAMESPACE::MemTableRep> memtablerep; + ROCKSDB_NAMESPACE::Random64 rng(FLAGS_seed); + const char* benchmarks = FLAGS_benchmarks.c_str(); + while (benchmarks != nullptr) { + std::unique_ptr<ROCKSDB_NAMESPACE::KeyGenerator> key_gen; + const char* sep = strchr(benchmarks, ','); + ROCKSDB_NAMESPACE::Slice name; + if (sep == nullptr) { + name = benchmarks; + benchmarks = nullptr; + } else { + name = ROCKSDB_NAMESPACE::Slice(benchmarks, sep - benchmarks); + benchmarks = sep + 1; + } + std::unique_ptr<ROCKSDB_NAMESPACE::Benchmark> benchmark; + if (name == ROCKSDB_NAMESPACE::Slice("fillseq")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::SEQUENTIAL, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::FillBenchmark( + memtablerep.get(), key_gen.get(), &sequence)); + } else if (name == ROCKSDB_NAMESPACE::Slice("fillrandom")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::UNIQUE_RANDOM, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::FillBenchmark( + memtablerep.get(), key_gen.get(), &sequence)); + } else if (name == ROCKSDB_NAMESPACE::Slice("readrandom")) { + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::ReadBenchmark( + memtablerep.get(), key_gen.get(), &sequence)); + } else if (name == ROCKSDB_NAMESPACE::Slice("readseq")) { + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::SEQUENTIAL, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::SeqReadBenchmark(memtablerep.get(), + &sequence)); + } else if (name == ROCKSDB_NAMESPACE::Slice("readwrite")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::ReadWriteBenchmark< + ROCKSDB_NAMESPACE::ConcurrentReadBenchmarkThread>( + memtablerep.get(), key_gen.get(), &sequence)); + } else if (name == ROCKSDB_NAMESPACE::Slice("seqreadwrite")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new ROCKSDB_NAMESPACE::KeyGenerator( + &rng, ROCKSDB_NAMESPACE::RANDOM, FLAGS_num_operations)); + benchmark.reset(new ROCKSDB_NAMESPACE::ReadWriteBenchmark< + ROCKSDB_NAMESPACE::SeqConcurrentReadBenchmarkThread>( + 
memtablerep.get(), key_gen.get(), &sequence));
+    } else {
+      std::cout << "WARNING: skipping unknown benchmark '" << name.ToString()
+                << "'" << std::endl;
+      continue;
+    }
+    std::cout << "Running " << name.ToString() << std::endl;
+    benchmark->Run();
+  }
+
+  return 0;
+}
+
+#endif  // GFLAGS
diff --git a/src/rocksdb/memtable/skiplist.h b/src/rocksdb/memtable/skiplist.h
new file mode 100644
index 000000000..e3cecd30c
--- /dev/null
+++ b/src/rocksdb/memtable/skiplist.h
@@ -0,0 +1,498 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Thread safety
+// -------------
+//
+// Writes require external synchronization, most likely a mutex.
+// Reads require a guarantee that the SkipList will not be destroyed
+// while the read is in progress. Apart from that, reads progress
+// without any internal locking or synchronization.
+//
+// Invariants:
+//
+// (1) Allocated nodes are never deleted until the SkipList is
+// destroyed. This is trivially guaranteed by the code since we
+// never delete any skip list nodes.
+//
+// (2) The contents of a Node except for the next/prev pointers are
+// immutable after the Node has been linked into the SkipList.
+// Only Insert() modifies the list, and it is careful to initialize
+// a node and use release-stores to publish the nodes in one or
+// more lists.
+//
+// ... prev vs. next pointer ordering ...
+//
+
+#pragma once
+#include <assert.h>
+#include <stdlib.h>
+
+#include <atomic>
+
+#include "memory/allocator.h"
+#include "port/port.h"
+#include "util/random.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+template <typename Key, class Comparator>
+class SkipList {
+ private:
+  struct Node;
+
+ public:
+  // Create a new SkipList object that will use "cmp" for comparing keys,
+  // and will allocate memory using "*allocator". Objects allocated in the
+  // allocator must remain allocated for the lifetime of the skiplist object.
+  explicit SkipList(Comparator cmp, Allocator* allocator,
+                    int32_t max_height = 12, int32_t branching_factor = 4);
+  // No copying allowed
+  SkipList(const SkipList&) = delete;
+  void operator=(const SkipList&) = delete;
+
+  // Insert key into the list.
+  // REQUIRES: nothing that compares equal to key is currently in the list.
+  void Insert(const Key& key);
+
+  // Returns true iff an entry that compares equal to key is in the list.
+  bool Contains(const Key& key) const;
+
+  // Return estimated number of entries smaller than `key`.
+  uint64_t EstimateCount(const Key& key) const;
+
+  // Iteration over the contents of a skip list
+  class Iterator {
+   public:
+    // Initialize an iterator over the specified list.
+    // The returned iterator is not valid.
+    explicit Iterator(const SkipList* list);
+
+    // Change the underlying skiplist used for this iterator.
+    // This enables reuse of the iterator without deallocating the old one
+    // and allocating a new one.
+    void SetList(const SkipList* list);
+
+    // Returns true iff the iterator is positioned at a valid node.
+    bool Valid() const;
+
+    // Returns the key at the current position.
+ // REQUIRES: Valid() + const Key& key() const; + + // Advances to the next position. + // REQUIRES: Valid() + void Next(); + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev(); + + // Advance to the first entry with a key >= target + void Seek(const Key& target); + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Key& target); + + // Position at the first entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst(); + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast(); + + private: + const SkipList* list_; + Node* node_; + // Intentionally copyable + }; + + private: + const uint16_t kMaxHeight_; + const uint16_t kBranching_; + const uint32_t kScaledInverseBranching_; + + // Immutable after construction + Comparator const compare_; + Allocator* const allocator_; // Allocator used for allocations of nodes + + Node* const head_; + + // Modified only by Insert(). Read racily by readers, but stale + // values are ok. + std::atomic<int> max_height_; // Height of the entire list + + // Used for optimizing sequential insert patterns. Tricky. prev_[i] for + // i up to max_height_ is the predecessor of prev_[0] and prev_height_ + // is the height of prev_[0]. prev_[0] can only be equal to head before + // insertion, in which case max_height_ and prev_height_ are 1. + Node** prev_; + int32_t prev_height_; + + inline int GetMaxHeight() const { + return max_height_.load(std::memory_order_relaxed); + } + + Node* NewNode(const Key& key, int height); + int RandomHeight(); + bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); } + bool LessThan(const Key& a, const Key& b) const { + return (compare_(a, b) < 0); + } + + // Return true if key is greater than the data stored in "n" + bool KeyIsAfterNode(const Key& key, Node* n) const; + + // Returns the earliest node with a key >= key. + // Return nullptr if there is no such node. + Node* FindGreaterOrEqual(const Key& key) const; + + // Return the latest node with a key < key. + // Return head_ if there is no such node. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [0..max_height_-1], if prev is non-null. + Node* FindLessThan(const Key& key, Node** prev = nullptr) const; + + // Return the last node in the list. + // Return head_ if list is empty. + Node* FindLast() const; +}; + +// Implementation details follow +template <typename Key, class Comparator> +struct SkipList<Key, Comparator>::Node { + explicit Node(const Key& k) : key(k) {} + + Key const key; + + // Accessors/mutators for links. Wrapped in methods so we can + // add the appropriate barriers as necessary. + Node* Next(int n) { + assert(n >= 0); + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return (next_[n].load(std::memory_order_acquire)); + } + void SetNext(int n, Node* x) { + assert(n >= 0); + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + next_[n].store(x, std::memory_order_release); + } + + // No-barrier variants that can be safely used in a few locations. 
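// A short sketch of why the relaxed ("no-barrier") accessors below are safe
// in this code base (this restates the acquire/release pairing above; the
// names prev/x are illustrative, matching the pattern used by Insert()):
//
//   // Writer, under external synchronization:
//   Node* x = NewNode(key, height);
//   x->NoBarrier_SetNext(0, prev->NoBarrier_Next(0));  // x not yet visible
//   prev->SetNext(0, x);  // release store publishes x, including its key
//
//   // Concurrent reader:
//   Node* n = prev->Next(0);  // acquire load; if it observes x ...
//   // ... then x->key and the writes to x->next_[] are also visible.
//
// The relaxed stores are therefore only used on nodes that have not been
// published yet; the release store in SetNext() is what makes them visible.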
+ Node* NoBarrier_Next(int n) { + assert(n >= 0); + return next_[n].load(std::memory_order_relaxed); + } + void NoBarrier_SetNext(int n, Node* x) { + assert(n >= 0); + next_[n].store(x, std::memory_order_relaxed); + } + + private: + // Array of length equal to the node height. next_[0] is lowest level link. + std::atomic<Node*> next_[1]; +}; + +template <typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::NewNode( + const Key& key, int height) { + char* mem = allocator_->AllocateAligned( + sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1)); + return new (mem) Node(key); +} + +template <typename Key, class Comparator> +inline SkipList<Key, Comparator>::Iterator::Iterator(const SkipList* list) { + SetList(list); +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SetList(const SkipList* list) { + list_ = list; + node_ = nullptr; +} + +template <typename Key, class Comparator> +inline bool SkipList<Key, Comparator>::Iterator::Valid() const { + return node_ != nullptr; +} + +template <typename Key, class Comparator> +inline const Key& SkipList<Key, Comparator>::Iterator::key() const { + assert(Valid()); + return node_->key; +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Next() { + assert(Valid()); + node_ = node_->Next(0); +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Prev() { + // Instead of using explicit "prev" links, we just search for the + // last node that falls before key. + assert(Valid()); + node_ = list_->FindLessThan(node_->key); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Seek(const Key& target) { + node_ = list_->FindGreaterOrEqual(target); +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekForPrev( + const Key& target) { + Seek(target); + if (!Valid()) { + SeekToLast(); + } + while (Valid() && list_->LessThan(target, key())) { + Prev(); + } +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekToFirst() { + node_ = list_->head_->Next(0); +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekToLast() { + node_ = list_->FindLast(); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <typename Key, class Comparator> +int SkipList<Key, Comparator>::RandomHeight() { + auto rnd = Random::GetTLSInstance(); + + // Increase height with probability 1 in kBranching + int height = 1; + while (height < kMaxHeight_ && rnd->Next() < kScaledInverseBranching_) { + height++; + } + assert(height > 0); + assert(height <= kMaxHeight_); + return height; +} + +template <typename Key, class Comparator> +bool SkipList<Key, Comparator>::KeyIsAfterNode(const Key& key, Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) < 0); +} + +template <typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* +SkipList<Key, Comparator>::FindGreaterOrEqual(const Key& key) const { + // Note: It looks like we could reduce duplication by implementing + // this function as FindLessThan(key)->Next(0), but we wouldn't be able + // to exit early on equality and the result wouldn't even be correct. 
+ // A concurrent insert might occur after FindLessThan(key) but before + // we get a chance to call Next(0). + Node* x = head_; + int level = GetMaxHeight() - 1; + Node* last_bigger = nullptr; + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + // Make sure the lists are sorted + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x)); + // Make sure we haven't overshot during our search + assert(x == head_ || KeyIsAfterNode(key, x)); + int cmp = + (next == nullptr || next == last_bigger) ? 1 : compare_(next->key, key); + if (cmp == 0 || (cmp > 0 && level == 0)) { + return next; + } else if (cmp < 0) { + // Keep searching in this list + x = next; + } else { + // Switch to next list, reuse compare_() result + last_bigger = next; + level--; + } + } +} + +template <typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* +SkipList<Key, Comparator>::FindLessThan(const Key& key, Node** prev) const { + Node* x = head_; + int level = GetMaxHeight() - 1; + // KeyIsAfter(key, last_not_after) is definitely false + Node* last_not_after = nullptr; + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x)); + assert(x == head_ || KeyIsAfterNode(key, x)); + if (next != last_not_after && KeyIsAfterNode(key, next)) { + // Keep searching in this list + x = next; + } else { + if (prev != nullptr) { + prev[level] = x; + } + if (level == 0) { + return x; + } else { + // Switch to next list, reuse KeyIUsAfterNode() result + last_not_after = next; + level--; + } + } + } +} + +template <typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast() + const { + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + Node* next = x->Next(level); + if (next == nullptr) { + if (level == 0) { + return x; + } else { + // Switch to next list + level--; + } + } else { + x = next; + } + } +} + +template <typename Key, class Comparator> +uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const { + uint64_t count = 0; + + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + assert(x == head_ || compare_(x->key, key) < 0); + Node* next = x->Next(level); + if (next == nullptr || compare_(next->key, key) >= 0) { + if (level == 0) { + return count; + } else { + // Switch to next list + count *= kBranching_; + level--; + } + } else { + x = next; + count++; + } + } +} + +template <typename Key, class Comparator> +SkipList<Key, Comparator>::SkipList(const Comparator cmp, Allocator* allocator, + int32_t max_height, + int32_t branching_factor) + : kMaxHeight_(static_cast<uint16_t>(max_height)), + kBranching_(static_cast<uint16_t>(branching_factor)), + kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_), + compare_(cmp), + allocator_(allocator), + head_(NewNode(0 /* any key will do */, max_height)), + max_height_(1), + prev_height_(1) { + assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height)); + assert(branching_factor > 0 && + kBranching_ == static_cast<uint32_t>(branching_factor)); + assert(kScaledInverseBranching_ > 0); + // Allocate the prev_ Node* array, directly from the passed-in allocator. + // prev_ does not need to be freed, as its life cycle is tied up with + // the allocator as a whole. 
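// A worked example of what these constants mean (using the default
// max_height = 12 and branching_factor = 4 from the declaration above):
// RandomHeight() adds a level while rnd->Next() < kScaledInverseBranching_,
// which holds with probability ~1/kBranching_ = 1/4 per level, so
//   P(height >= h) = (1/4)^(h - 1)  and  E[height] = 1 / (1 - 1/4) = 4/3.
// The tallest tower in a list of N keys is therefore about log_4(N) levels,
// which is why the default cap of 12 levels is ample for memtable-sized N.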
+ prev_ = reinterpret_cast<Node**>( + allocator_->AllocateAligned(sizeof(Node*) * kMaxHeight_)); + for (int i = 0; i < kMaxHeight_; i++) { + head_->SetNext(i, nullptr); + prev_[i] = head_; + } +} + +template <typename Key, class Comparator> +void SkipList<Key, Comparator>::Insert(const Key& key) { + // fast path for sequential insertion + if (!KeyIsAfterNode(key, prev_[0]->NoBarrier_Next(0)) && + (prev_[0] == head_ || KeyIsAfterNode(key, prev_[0]))) { + assert(prev_[0] != head_ || (prev_height_ == 1 && GetMaxHeight() == 1)); + + // Outside of this method prev_[1..max_height_] is the predecessor + // of prev_[0], and prev_height_ refers to prev_[0]. Inside Insert + // prev_[0..max_height - 1] is the predecessor of key. Switch from + // the external state to the internal + for (int i = 1; i < prev_height_; i++) { + prev_[i] = prev_[0]; + } + } else { + // TODO(opt): we could use a NoBarrier predecessor search as an + // optimization for architectures where memory_order_acquire needs + // a synchronization instruction. Doesn't matter on x86 + FindLessThan(key, prev_); + } + + // Our data structure does not allow duplicate insertion + assert(prev_[0]->Next(0) == nullptr || !Equal(key, prev_[0]->Next(0)->key)); + + int height = RandomHeight(); + if (height > GetMaxHeight()) { + for (int i = GetMaxHeight(); i < height; i++) { + prev_[i] = head_; + } + // fprintf(stderr, "Change height from %d to %d\n", max_height_, height); + + // It is ok to mutate max_height_ without any synchronization + // with concurrent readers. A concurrent reader that observes + // the new value of max_height_ will see either the old value of + // new level pointers from head_ (nullptr), or a new value set in + // the loop below. In the former case the reader will + // immediately drop to the next level since nullptr sorts after all + // keys. In the latter case the reader will use the new node. + max_height_.store(height, std::memory_order_relaxed); + } + + Node* x = NewNode(key, height); + for (int i = 0; i < height; i++) { + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. + x->NoBarrier_SetNext(i, prev_[i]->NoBarrier_Next(i)); + prev_[i]->SetNext(i, x); + } + prev_[0] = x; + prev_height_ = height; +} + +template <typename Key, class Comparator> +bool SkipList<Key, Comparator>::Contains(const Key& key) const { + Node* x = FindGreaterOrEqual(key); + if (x != nullptr && Equal(key, x->key)) { + return true; + } else { + return false; + } +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/memtable/skiplist_test.cc b/src/rocksdb/memtable/skiplist_test.cc new file mode 100644 index 000000000..a07088511 --- /dev/null +++ b/src/rocksdb/memtable/skiplist_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
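// A minimal usage sketch of the SkipList defined above, following the
// thread-safety contract from its header (writes externally synchronized,
// reads lock-free). `U64Comparator` and `writer_mu` are illustrative names,
// not part of the RocksDB API, and the sketch additionally needs <mutex>:
//
//   struct U64Comparator {
//     int operator()(const uint64_t& a, const uint64_t& b) const {
//       return a < b ? -1 : (a > b ? 1 : 0);
//     }
//   };
//
//   ROCKSDB_NAMESPACE::Arena arena;
//   ROCKSDB_NAMESPACE::SkipList<uint64_t, U64Comparator> list(U64Comparator(),
//                                                             &arena);
//   std::mutex writer_mu;
//   {
//     std::lock_guard<std::mutex> guard(writer_mu);  // writers need a mutex
//     list.Insert(42);  // keys must be unique
//   }
//   // Readers need no lock, only a guarantee that the list outlives them.
//   ROCKSDB_NAMESPACE::SkipList<uint64_t, U64Comparator>::Iterator it(&list);
//   for (it.SeekToFirst(); it.Valid(); it.Next()) {
//     // it.key() == 42
//   }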
+ +#include "memtable/skiplist.h" + +#include <set> + +#include "memory/arena.h" +#include "rocksdb/env.h" +#include "test_util/testharness.h" +#include "util/hash.h" +#include "util/random.h" + +namespace ROCKSDB_NAMESPACE { + +using Key = uint64_t; + +struct TestComparator { + int operator()(const Key& a, const Key& b) const { + if (a < b) { + return -1; + } else if (a > b) { + return +1; + } else { + return 0; + } + } +}; + +class SkipTest : public testing::Test {}; + +TEST_F(SkipTest, Empty) { + Arena arena; + TestComparator cmp; + SkipList<Key, TestComparator> list(cmp, &arena); + ASSERT_TRUE(!list.Contains(10)); + + SkipList<Key, TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToFirst(); + ASSERT_TRUE(!iter.Valid()); + iter.Seek(100); + ASSERT_TRUE(!iter.Valid()); + iter.SeekForPrev(100); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToLast(); + ASSERT_TRUE(!iter.Valid()); +} + +TEST_F(SkipTest, InsertAndLookup) { + const int N = 2000; + const int R = 5000; + Random rnd(1000); + std::set<Key> keys; + Arena arena; + TestComparator cmp; + SkipList<Key, TestComparator> list(cmp, &arena); + for (int i = 0; i < N; i++) { + Key key = rnd.Next() % R; + if (keys.insert(key).second) { + list.Insert(key); + } + } + + for (int i = 0; i < R; i++) { + if (list.Contains(i)) { + ASSERT_EQ(keys.count(i), 1U); + } else { + ASSERT_EQ(keys.count(i), 0U); + } + } + + // Simple iterator tests + { + SkipList<Key, TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + + iter.Seek(0); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), iter.key()); + + iter.SeekForPrev(R - 1); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), iter.key()); + + iter.SeekToFirst(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), iter.key()); + + iter.SeekToLast(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), iter.key()); + } + + // Forward iteration test + for (int i = 0; i < R; i++) { + SkipList<Key, TestComparator>::Iterator iter(&list); + iter.Seek(i); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.lower_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.end()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*model_iter, iter.key()); + ++model_iter; + iter.Next(); + } + } + } + + // Backward iteration test + for (int i = 0; i < R; i++) { + SkipList<Key, TestComparator>::Iterator iter(&list); + iter.SeekForPrev(i); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.upper_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.begin()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*--model_iter, iter.key()); + iter.Prev(); + } + } + } +} + +// We want to make sure that with a single writer and multiple +// concurrent readers (with no synchronization other than when a +// reader's iterator is created), the reader always observes all the +// data that was present in the skip list when the iterator was +// constructor. Because insertions are happening concurrently, we may +// also observe new values that were inserted since the iterator was +// constructed, but we should never miss any values that were present +// at iterator construction time. 
+// +// We generate multi-part keys: +// <key,gen,hash> +// where: +// key is in range [0..K-1] +// gen is a generation number for key +// hash is hash(key,gen) +// +// The insertion code picks a random key, sets gen to be 1 + the last +// generation number inserted for that key, and sets hash to Hash(key,gen). +// +// At the beginning of a read, we snapshot the last inserted +// generation number for each key. We then iterate, including random +// calls to Next() and Seek(). For every key we encounter, we +// check that it is either expected given the initial snapshot or has +// been concurrently added since the iterator started. +class ConcurrentTest { + private: + static const uint32_t K = 4; + + static uint64_t key(Key key) { return (key >> 40); } + static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; } + static uint64_t hash(Key key) { return key & 0xff; } + + static uint64_t HashNumbers(uint64_t k, uint64_t g) { + uint64_t data[2] = {k, g}; + return Hash(reinterpret_cast<char*>(data), sizeof(data), 0); + } + + static Key MakeKey(uint64_t k, uint64_t g) { + assert(sizeof(Key) == sizeof(uint64_t)); + assert(k <= K); // We sometimes pass K to seek to the end of the skiplist + assert(g <= 0xffffffffu); + return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff)); + } + + static bool IsValidKey(Key k) { + return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff); + } + + static Key RandomTarget(Random* rnd) { + switch (rnd->Next() % 10) { + case 0: + // Seek to beginning + return MakeKey(0, 0); + case 1: + // Seek to end + return MakeKey(K, 0); + default: + // Seek to middle + return MakeKey(rnd->Next() % K, 0); + } + } + + // Per-key generation + struct State { + std::atomic<int> generation[K]; + void Set(int k, int v) { + generation[k].store(v, std::memory_order_release); + } + int Get(int k) { return generation[k].load(std::memory_order_acquire); } + + State() { + for (unsigned int k = 0; k < K; k++) { + Set(k, 0); + } + } + }; + + // Current state of the test + State current_; + + Arena arena_; + + // SkipList is not protected by mu_. We just use a single writer + // thread to modify it. + SkipList<Key, TestComparator> list_; + + public: + ConcurrentTest() : list_(TestComparator(), &arena_) {} + + // REQUIRES: External synchronization + void WriteStep(Random* rnd) { + const uint32_t k = rnd->Next() % K; + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + list_.Insert(new_key); + current_.Set(k, g); + } + + void ReadStep(Random* rnd) { + // Remember the initial committed state of the skiplist. + State initial_state; + for (unsigned int k = 0; k < K; k++) { + initial_state.Set(k, current_.Get(k)); + } + + Key pos = RandomTarget(rnd); + SkipList<Key, TestComparator>::Iterator iter(&list_); + iter.Seek(pos); + while (true) { + Key current; + if (!iter.Valid()) { + current = MakeKey(K, 0); + } else { + current = iter.key(); + ASSERT_TRUE(IsValidKey(current)) << current; + } + ASSERT_LE(pos, current) << "should not go backwards"; + + // Verify that everything in [pos,current) was not present in + // initial_state. + while (pos < current) { + ASSERT_LT(key(pos), K) << pos; + + // Note that generation 0 is never inserted, so it is ok if + // <*,0,*> is missing. 
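// A concrete instance of the packing defined by MakeKey() above:
// MakeKey(2, 5) places 2 in bits [40, 63], 5 in bits [8, 39] and
// (HashNumbers(2, 5) & 0xff) in the low byte, so key(pos) == 2,
// gen(pos) == 5, and hash(pos) is the checksum IsValidKey() re-derives.
// The assertion below only relies on the gen() component, compared against
// the per-key generation snapshot taken at the start of ReadStep().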
+ ASSERT_TRUE((gen(pos) == 0U) || + (gen(pos) > static_cast<uint64_t>(initial_state.Get( + static_cast<int>(key(pos)))))) + << "key: " << key(pos) << "; gen: " << gen(pos) + << "; initgen: " << initial_state.Get(static_cast<int>(key(pos))); + + // Advance to next key in the valid key space + if (key(pos) < key(current)) { + pos = MakeKey(key(pos) + 1, 0); + } else { + pos = MakeKey(key(pos), gen(pos) + 1); + } + } + + if (!iter.Valid()) { + break; + } + + if (rnd->Next() % 2) { + iter.Next(); + pos = MakeKey(key(pos), gen(pos) + 1); + } else { + Key new_target = RandomTarget(rnd); + if (new_target > pos) { + pos = new_target; + iter.Seek(new_target); + } + } + } + } +}; +const uint32_t ConcurrentTest::K; + +// Simple test that does single-threaded testing of the ConcurrentTest +// scaffolding. +TEST_F(SkipTest, ConcurrentWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + test.WriteStep(&rnd); + } +} + +class TestState { + public: + ConcurrentTest t_; + int seed_; + std::atomic<bool> quit_flag_; + + enum ReaderState { STARTING, RUNNING, DONE }; + + explicit TestState(int s) + : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {} + + void Wait(ReaderState s) { + mu_.Lock(); + while (state_ != s) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + void Change(ReaderState s) { + mu_.Lock(); + state_ = s; + state_cv_.Signal(); + mu_.Unlock(); + } + + private: + port::Mutex mu_; + ReaderState state_; + port::CondVar state_cv_; +}; + +static void ConcurrentReader(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + Random rnd(state->seed_); + int64_t reads = 0; + state->Change(TestState::RUNNING); + while (!state->quit_flag_.load(std::memory_order_acquire)) { + state->t_.ReadStep(&rnd); + ++reads; + } + state->Change(TestState::DONE); +} + +static void RunConcurrent(int run) { + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + Env::Default()->SetBackgroundThreads(1); + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; k++) { + state.t_.WriteStep(&rnd); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +TEST_F(SkipTest, Concurrent1) { RunConcurrent(1); } +TEST_F(SkipTest, Concurrent2) { RunConcurrent(2); } +TEST_F(SkipTest, Concurrent3) { RunConcurrent(3); } +TEST_F(SkipTest, Concurrent4) { RunConcurrent(4); } +TEST_F(SkipTest, Concurrent5) { RunConcurrent(5); } + +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/rocksdb/memtable/skiplistrep.cc b/src/rocksdb/memtable/skiplistrep.cc new file mode 100644 index 000000000..40f13a2c1 --- /dev/null +++ b/src/rocksdb/memtable/skiplistrep.cc @@ -0,0 +1,370 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). 
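// How a factory like the SkipListFactory defined at the bottom of this file
// is typically selected (a sketch; ColumnFamilyOptions::memtable_factory and
// NewFixedPrefixTransform() are assumed from the public options API, they are
// not declared here):
//
//   ROCKSDB_NAMESPACE::Options options;
//   options.memtable_factory.reset(
//       new ROCKSDB_NAMESPACE::SkipListFactory(8 /* lookahead */));
//   // Prefix-based reps (hash skiplist / hash linklist) additionally need a
//   // prefix extractor, e.g.:
//   // options.prefix_extractor.reset(
//   //     ROCKSDB_NAMESPACE::NewFixedPrefixTransform(4));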
+// +#include <random> + +#include "db/memtable.h" +#include "memory/arena.h" +#include "memtable/inlineskiplist.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/utilities/options_type.h" +#include "util/string_util.h" + +namespace ROCKSDB_NAMESPACE { +namespace { +class SkipListRep : public MemTableRep { + InlineSkipList<const MemTableRep::KeyComparator&> skip_list_; + const MemTableRep::KeyComparator& cmp_; + const SliceTransform* transform_; + const size_t lookahead_; + + friend class LookaheadIterator; + + public: + explicit SkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + const size_t lookahead) + : MemTableRep(allocator), + skip_list_(compare, allocator), + cmp_(compare), + transform_(transform), + lookahead_(lookahead) {} + + KeyHandle Allocate(const size_t len, char** buf) override { + *buf = skip_list_.AllocateKey(len); + return static_cast<KeyHandle>(*buf); + } + + // Insert key into the list. + // REQUIRES: nothing that compares equal to key is currently in the list. + void Insert(KeyHandle handle) override { + skip_list_.Insert(static_cast<char*>(handle)); + } + + bool InsertKey(KeyHandle handle) override { + return skip_list_.Insert(static_cast<char*>(handle)); + } + + void InsertWithHint(KeyHandle handle, void** hint) override { + skip_list_.InsertWithHint(static_cast<char*>(handle), hint); + } + + bool InsertKeyWithHint(KeyHandle handle, void** hint) override { + return skip_list_.InsertWithHint(static_cast<char*>(handle), hint); + } + + void InsertWithHintConcurrently(KeyHandle handle, void** hint) override { + skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle), hint); + } + + bool InsertKeyWithHintConcurrently(KeyHandle handle, void** hint) override { + return skip_list_.InsertWithHintConcurrently(static_cast<char*>(handle), + hint); + } + + void InsertConcurrently(KeyHandle handle) override { + skip_list_.InsertConcurrently(static_cast<char*>(handle)); + } + + bool InsertKeyConcurrently(KeyHandle handle) override { + return skip_list_.InsertConcurrently(static_cast<char*>(handle)); + } + + // Returns true iff an entry that compares equal to key is in the list. + bool Contains(const char* key) const override { + return skip_list_.Contains(key); + } + + size_t ApproximateMemoryUsage() override { + // All memory is allocated through allocator; nothing to report here + return 0; + } + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override { + SkipListRep::Iterator iter(&skip_list_); + Slice dummy_slice; + for (iter.Seek(dummy_slice, k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } + + uint64_t ApproximateNumEntries(const Slice& start_ikey, + const Slice& end_ikey) override { + std::string tmp; + uint64_t start_count = + skip_list_.EstimateCount(EncodeKey(&tmp, start_ikey)); + uint64_t end_count = skip_list_.EstimateCount(EncodeKey(&tmp, end_ikey)); + return (end_count >= start_count) ? (end_count - start_count) : 0; + } + + void UniqueRandomSample(const uint64_t num_entries, + const uint64_t target_sample_size, + std::unordered_set<const char*>* entries) override { + entries->clear(); + // Avoid divide-by-0. + assert(target_sample_size > 0); + assert(num_entries > 0); + // NOTE: the size of entries is not enforced to be exactly + // target_sample_size at the end of this function, it might be slightly + // greater or smaller. 
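// A worked example for the two strategies described below: with
// num_entries = 1,000,000 and target_sample_size = 5,000 (> sqrt(N) = 1,000)
// the linear pass is used; keeping entry i with probability
// num_samples_left / (num_entries - i) is classic selection sampling and
// yields exactly target_sample_size entries whenever num_entries matches the
// true entry count (it is only an estimate, hence the NOTE above). With
// target_sample_size = 100 (< sqrt(N)), random seeks are cheaper because the
// chance of re-drawing an already chosen entry stays small.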
+ SkipListRep::Iterator iter(&skip_list_); + // There are two methods to create the subset of samples (size m) + // from the table containing N elements: + // 1-Iterate linearly through the N memtable entries. For each entry i, + // add it to the sample set with a probability + // (target_sample_size - entries.size() ) / (N-i). + // + // 2-Pick m random elements without repetition. + // We pick Option 2 when m<sqrt(N) and + // Option 1 when m > sqrt(N). + if (target_sample_size > + static_cast<uint64_t>(std::sqrt(1.0 * num_entries))) { + Random* rnd = Random::GetTLSInstance(); + iter.SeekToFirst(); + uint64_t counter = 0, num_samples_left = target_sample_size; + for (; iter.Valid() && (num_samples_left > 0); iter.Next(), counter++) { + // Add entry to sample set with probability + // num_samples_left/(num_entries - counter). + if (rnd->Next() % (num_entries - counter) < num_samples_left) { + entries->insert(iter.key()); + num_samples_left--; + } + } + } else { + // Option 2: pick m random elements with no duplicates. + // If Option 2 is picked, then target_sample_size<sqrt(N) + // Using a set spares the need to check for duplicates. + for (uint64_t i = 0; i < target_sample_size; i++) { + // We give it 5 attempts to find a non-duplicate + // With 5 attempts, the chances of returning `entries` set + // of size target_sample_size is: + // PROD_{i=1}^{target_sample_size-1} [1-(i/N)^5] + // which is monotonically increasing with N in the worse case + // of target_sample_size=sqrt(N), and is always >99.9% for N>4. + // At worst, for the final pick , when m=sqrt(N) there is + // a probability of p= 1/sqrt(N) chances to find a duplicate. + for (uint64_t j = 0; j < 5; j++) { + iter.RandomSeek(); + // unordered_set::insert returns pair<iterator, bool>. + // The second element is true if an insert successfully happened. + // If element is already in the set, this bool will be false, and + // true otherwise. + if ((entries->insert(iter.key())).second) { + break; + } + } + } + } + } + + ~SkipListRep() override {} + + // Iteration over the contents of a skip list + class Iterator : public MemTableRep::Iterator { + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_; + + public: + // Initialize an iterator over the specified list. + // The returned iterator is not valid. + explicit Iterator( + const InlineSkipList<const MemTableRep::KeyComparator&>* list) + : iter_(list) {} + + ~Iterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { return iter_.key(); } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { iter_.Next(); } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { iter_.Prev(); } + + // Advance to the first entry with a key >= target + void Seek(const Slice& user_key, const char* memtable_key) override { + if (memtable_key != nullptr) { + iter_.Seek(memtable_key); + } else { + iter_.Seek(EncodeKey(&tmp_, user_key)); + } + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& user_key, const char* memtable_key) override { + if (memtable_key != nullptr) { + iter_.SeekForPrev(memtable_key); + } else { + iter_.SeekForPrev(EncodeKey(&tmp_, user_key)); + } + } + + void RandomSeek() override { iter_.RandomSeek(); } + + // Position at the first entry in list. 
+ // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst() override { iter_.SeekToFirst(); } + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast() override { iter_.SeekToLast(); } + + protected: + std::string tmp_; // For passing to EncodeKey + }; + + // Iterator over the contents of a skip list which also keeps track of the + // previously visited node. In Seek(), it examines a few nodes after it + // first, falling back to O(log n) search from the head of the list only if + // the target key hasn't been found. + class LookaheadIterator : public MemTableRep::Iterator { + public: + explicit LookaheadIterator(const SkipListRep& rep) + : rep_(rep), iter_(&rep_.skip_list_), prev_(iter_) {} + + ~LookaheadIterator() override {} + + bool Valid() const override { return iter_.Valid(); } + + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + void Next() override { + assert(Valid()); + + bool advance_prev = true; + if (prev_.Valid()) { + auto k1 = rep_.UserKey(prev_.key()); + auto k2 = rep_.UserKey(iter_.key()); + + if (k1.compare(k2) == 0) { + // same user key, don't move prev_ + advance_prev = false; + } else if (rep_.transform_) { + // only advance prev_ if it has the same prefix as iter_ + auto t1 = rep_.transform_->Transform(k1); + auto t2 = rep_.transform_->Transform(k2); + advance_prev = t1.compare(t2) == 0; + } + } + + if (advance_prev) { + prev_ = iter_; + } + iter_.Next(); + } + + void Prev() override { + assert(Valid()); + iter_.Prev(); + prev_ = iter_; + } + + void Seek(const Slice& internal_key, const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + + if (prev_.Valid() && rep_.cmp_(encoded_key, prev_.key()) >= 0) { + // prev_.key() is smaller or equal to our target key; do a quick + // linear search (at most lookahead_ steps) starting from prev_ + iter_ = prev_; + + size_t cur = 0; + while (cur++ <= rep_.lookahead_ && iter_.Valid()) { + if (rep_.cmp_(encoded_key, iter_.key()) <= 0) { + return; + } + Next(); + } + } + + iter_.Seek(encoded_key); + prev_ = iter_; + } + + void SeekForPrev(const Slice& internal_key, + const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.SeekForPrev(encoded_key); + prev_ = iter_; + } + + void SeekToFirst() override { + iter_.SeekToFirst(); + prev_ = iter_; + } + + void SeekToLast() override { + iter_.SeekToLast(); + prev_ = iter_; + } + + protected: + std::string tmp_; // For passing to EncodeKey + + private: + const SkipListRep& rep_; + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_; + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator prev_; + }; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override { + if (lookahead_ > 0) { + void* mem = + arena ? arena->AllocateAligned(sizeof(SkipListRep::LookaheadIterator)) + : + operator new(sizeof(SkipListRep::LookaheadIterator)); + return new (mem) SkipListRep::LookaheadIterator(*this); + } else { + void* mem = arena ? 
arena->AllocateAligned(sizeof(SkipListRep::Iterator)) + : + operator new(sizeof(SkipListRep::Iterator)); + return new (mem) SkipListRep::Iterator(&skip_list_); + } + } +}; +} // namespace + +static std::unordered_map<std::string, OptionTypeInfo> skiplist_factory_info = { +#ifndef ROCKSDB_LITE + {"lookahead", + {0, OptionType::kSizeT, OptionVerificationType::kNormal, + OptionTypeFlags::kDontSerialize /*Since it is part of the ID*/}}, +#endif +}; + +SkipListFactory::SkipListFactory(size_t lookahead) : lookahead_(lookahead) { + RegisterOptions("SkipListFactoryOptions", &lookahead_, + &skiplist_factory_info); +} + +std::string SkipListFactory::GetId() const { + std::string id = Name(); + if (lookahead_ > 0) { + id.append(":").append(std::to_string(lookahead_)); + } + return id; +} + +MemTableRep* SkipListFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* /*logger*/) { + return new SkipListRep(compare, allocator, transform, lookahead_); +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/memtable/stl_wrappers.h b/src/rocksdb/memtable/stl_wrappers.h new file mode 100644 index 000000000..783a8088d --- /dev/null +++ b/src/rocksdb/memtable/stl_wrappers.h @@ -0,0 +1,33 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +#pragma once + +#include <map> +#include <string> + +#include "rocksdb/comparator.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/slice.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { +namespace stl_wrappers { + +class Base { + protected: + const MemTableRep::KeyComparator& compare_; + explicit Base(const MemTableRep::KeyComparator& compare) + : compare_(compare) {} +}; + +struct Compare : private Base { + explicit Compare(const MemTableRep::KeyComparator& compare) : Base(compare) {} + inline bool operator()(const char* a, const char* b) const { + return compare_(a, b) < 0; + } +}; + +} // namespace stl_wrappers +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/memtable/vectorrep.cc b/src/rocksdb/memtable/vectorrep.cc new file mode 100644 index 000000000..293163349 --- /dev/null +++ b/src/rocksdb/memtable/vectorrep.cc @@ -0,0 +1,309 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +#ifndef ROCKSDB_LITE +#include <algorithm> +#include <memory> +#include <set> +#include <type_traits> +#include <unordered_set> + +#include "db/memtable.h" +#include "memory/arena.h" +#include "memtable/stl_wrappers.h" +#include "port/port.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/utilities/options_type.h" +#include "util/mutexlock.h" + +namespace ROCKSDB_NAMESPACE { +namespace { + +class VectorRep : public MemTableRep { + public: + VectorRep(const KeyComparator& compare, Allocator* allocator, size_t count); + + // Insert key into the collection. (The caller will pack key and value into a + // single buffer and pass that in as the parameter to Insert) + // REQUIRES: nothing that compares equal to key is currently in the + // collection. 
+ void Insert(KeyHandle handle) override; + + // Returns true iff an entry that compares equal to key is in the collection. + bool Contains(const char* key) const override; + + void MarkReadOnly() override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~VectorRep() override {} + + class Iterator : public MemTableRep::Iterator { + class VectorRep* vrep_; + std::shared_ptr<std::vector<const char*>> bucket_; + std::vector<const char*>::const_iterator mutable cit_; + const KeyComparator& compare_; + std::string tmp_; // For passing to EncodeKey + bool mutable sorted_; + void DoSort() const; + + public: + explicit Iterator(class VectorRep* vrep, + std::shared_ptr<std::vector<const char*>> bucket, + const KeyComparator& compare); + + // Initialize an iterator over the specified collection. + // The returned iterator is not valid. + // explicit Iterator(const MemTableRep* collection); + ~Iterator() override{}; + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override; + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override; + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override; + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override; + + // Advance to the first entry with a key >= target + void Seek(const Slice& user_key, const char* memtable_key) override; + + // Advance to the first entry with a key <= target + void SeekForPrev(const Slice& user_key, const char* memtable_key) override; + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override; + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override; + }; + + // Return an iterator over the keys in this representation. + MemTableRep::Iterator* GetIterator(Arena* arena) override; + + private: + friend class Iterator; + using Bucket = std::vector<const char*>; + std::shared_ptr<Bucket> bucket_; + mutable port::RWMutex rwlock_; + bool immutable_; + bool sorted_; + const KeyComparator& compare_; +}; + +void VectorRep::Insert(KeyHandle handle) { + auto* key = static_cast<char*>(handle); + WriteLock l(&rwlock_); + assert(!immutable_); + bucket_->push_back(key); +} + +// Returns true iff an entry that compares equal to key is in the collection. 
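// Note on the lookup below: while the memtable is mutable the bucket is an
// unsorted std::vector, so Contains() is a linear std::find over the raw
// entry pointers, and iterators only become cheap once MarkReadOnly() has
// been called and the first Seek() sorts the now-immutable bucket. This is
// why the vector rep is generally suggested for bulk-load style workloads
// rather than point-lookup heavy ones (an interpretation, not stated in this
// file).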
+bool VectorRep::Contains(const char* key) const { + ReadLock l(&rwlock_); + return std::find(bucket_->begin(), bucket_->end(), key) != bucket_->end(); +} + +void VectorRep::MarkReadOnly() { + WriteLock l(&rwlock_); + immutable_ = true; +} + +size_t VectorRep::ApproximateMemoryUsage() { + return sizeof(bucket_) + sizeof(*bucket_) + + bucket_->size() * + sizeof( + std::remove_reference<decltype(*bucket_)>::type::value_type); +} + +VectorRep::VectorRep(const KeyComparator& compare, Allocator* allocator, + size_t count) + : MemTableRep(allocator), + bucket_(new Bucket()), + immutable_(false), + sorted_(false), + compare_(compare) { + bucket_.get()->reserve(count); +} + +VectorRep::Iterator::Iterator(class VectorRep* vrep, + std::shared_ptr<std::vector<const char*>> bucket, + const KeyComparator& compare) + : vrep_(vrep), + bucket_(bucket), + cit_(bucket_->end()), + compare_(compare), + sorted_(false) {} + +void VectorRep::Iterator::DoSort() const { + // vrep is non-null means that we are working on an immutable memtable + if (!sorted_ && vrep_ != nullptr) { + WriteLock l(&vrep_->rwlock_); + if (!vrep_->sorted_) { + std::sort(bucket_->begin(), bucket_->end(), + stl_wrappers::Compare(compare_)); + cit_ = bucket_->begin(); + vrep_->sorted_ = true; + } + sorted_ = true; + } + if (!sorted_) { + std::sort(bucket_->begin(), bucket_->end(), + stl_wrappers::Compare(compare_)); + cit_ = bucket_->begin(); + sorted_ = true; + } + assert(sorted_); + assert(vrep_ == nullptr || vrep_->sorted_); +} + +// Returns true iff the iterator is positioned at a valid node. +bool VectorRep::Iterator::Valid() const { + DoSort(); + return cit_ != bucket_->end(); +} + +// Returns the key at the current position. +// REQUIRES: Valid() +const char* VectorRep::Iterator::key() const { + assert(sorted_); + return *cit_; +} + +// Advances to the next position. +// REQUIRES: Valid() +void VectorRep::Iterator::Next() { + assert(sorted_); + if (cit_ == bucket_->end()) { + return; + } + ++cit_; +} + +// Advances to the previous position. +// REQUIRES: Valid() +void VectorRep::Iterator::Prev() { + assert(sorted_); + if (cit_ == bucket_->begin()) { + // If you try to go back from the first element, the iterator should be + // invalidated. So we set it to past-the-end. This means that you can + // treat the container circularly. + cit_ = bucket_->end(); + } else { + --cit_; + } +} + +// Advance to the first entry with a key >= target +void VectorRep::Iterator::Seek(const Slice& user_key, + const char* memtable_key) { + DoSort(); + // Do binary search to find first value not less than the target + const char* encoded_key = + (memtable_key != nullptr) ? memtable_key : EncodeKey(&tmp_, user_key); + cit_ = std::equal_range(bucket_->begin(), bucket_->end(), encoded_key, + [this](const char* a, const char* b) { + return compare_(a, b) < 0; + }) + .first; +} + +// Advance to the first entry with a key <= target +void VectorRep::Iterator::SeekForPrev(const Slice& /*user_key*/, + const char* /*memtable_key*/) { + assert(false); +} + +// Position at the first entry in collection. +// Final state of iterator is Valid() iff collection is not empty. +void VectorRep::Iterator::SeekToFirst() { + DoSort(); + cit_ = bucket_->begin(); +} + +// Position at the last entry in collection. +// Final state of iterator is Valid() iff collection is not empty. 
+void VectorRep::Iterator::SeekToLast() { + DoSort(); + cit_ = bucket_->end(); + if (bucket_->size() != 0) { + --cit_; + } +} + +void VectorRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + rwlock_.ReadLock(); + VectorRep* vector_rep; + std::shared_ptr<Bucket> bucket; + if (immutable_) { + vector_rep = this; + } else { + vector_rep = nullptr; + bucket.reset(new Bucket(*bucket_)); // make a copy + } + VectorRep::Iterator iter(vector_rep, immutable_ ? bucket_ : bucket, compare_); + rwlock_.ReadUnlock(); + + for (iter.Seek(k.user_key(), k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) { + } +} + +MemTableRep::Iterator* VectorRep::GetIterator(Arena* arena) { + char* mem = nullptr; + if (arena != nullptr) { + mem = arena->AllocateAligned(sizeof(Iterator)); + } + ReadLock l(&rwlock_); + // Do not sort here. The sorting would be done the first time + // a Seek is performed on the iterator. + if (immutable_) { + if (arena == nullptr) { + return new Iterator(this, bucket_, compare_); + } else { + return new (mem) Iterator(this, bucket_, compare_); + } + } else { + std::shared_ptr<Bucket> tmp; + tmp.reset(new Bucket(*bucket_)); // make a copy + if (arena == nullptr) { + return new Iterator(nullptr, tmp, compare_); + } else { + return new (mem) Iterator(nullptr, tmp, compare_); + } + } +} +} // namespace + +static std::unordered_map<std::string, OptionTypeInfo> vector_rep_table_info = { + {"count", + {0, OptionType::kSizeT, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, +}; + +VectorRepFactory::VectorRepFactory(size_t count) : count_(count) { + RegisterOptions("VectorRepFactoryOptions", &count_, &vector_rep_table_info); +} + +MemTableRep* VectorRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform*, Logger* /*logger*/) { + return new VectorRep(compare, allocator, count_); +} +} // namespace ROCKSDB_NAMESPACE +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/write_buffer_manager.cc b/src/rocksdb/memtable/write_buffer_manager.cc new file mode 100644 index 000000000..8db9816be --- /dev/null +++ b/src/rocksdb/memtable/write_buffer_manager.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
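// A sketch of how a WriteBufferManager is typically shared across column
// families / DB instances (DBOptions::write_buffer_manager and NewLRUCache()
// are assumed from the public API; the sizes are examples only):
//
//   auto cache = ROCKSDB_NAMESPACE::NewLRUCache(1 << 30);
//   auto wbm = std::make_shared<ROCKSDB_NAMESPACE::WriteBufferManager>(
//       512 << 20 /* buffer_size */, cache /* also charge the block cache */);
//   ROCKSDB_NAMESPACE::Options options;
//   options.write_buffer_manager = wbm;  // reuse the same manager everywhere
//
// A buffer_size of 0 leaves the manager disabled (no flush trigger), which is
// how enabled() is interpreted throughout this file.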
+ +#include "rocksdb/write_buffer_manager.h" + +#include <memory> + +#include "cache/cache_entry_roles.h" +#include "cache/cache_reservation_manager.h" +#include "db/db_impl/db_impl.h" +#include "rocksdb/status.h" +#include "util/coding.h" + +namespace ROCKSDB_NAMESPACE { +WriteBufferManager::WriteBufferManager(size_t _buffer_size, + std::shared_ptr<Cache> cache, + bool allow_stall) + : buffer_size_(_buffer_size), + mutable_limit_(buffer_size_ * 7 / 8), + memory_used_(0), + memory_active_(0), + cache_res_mgr_(nullptr), + allow_stall_(allow_stall), + stall_active_(false) { +#ifndef ROCKSDB_LITE + if (cache) { + // Memtable's memory usage tends to fluctuate frequently + // therefore we set delayed_decrease = true to save some dummy entry + // insertion on memory increase right after memory decrease + cache_res_mgr_ = std::make_shared< + CacheReservationManagerImpl<CacheEntryRole::kWriteBuffer>>( + cache, true /* delayed_decrease */); + } +#else + (void)cache; +#endif // ROCKSDB_LITE +} + +WriteBufferManager::~WriteBufferManager() { +#ifndef NDEBUG + std::unique_lock<std::mutex> lock(mu_); + assert(queue_.empty()); +#endif +} + +std::size_t WriteBufferManager::dummy_entries_in_cache_usage() const { + if (cache_res_mgr_ != nullptr) { + return cache_res_mgr_->GetTotalReservedCacheSize(); + } else { + return 0; + } +} + +void WriteBufferManager::ReserveMem(size_t mem) { + if (cache_res_mgr_ != nullptr) { + ReserveMemWithCache(mem); + } else if (enabled()) { + memory_used_.fetch_add(mem, std::memory_order_relaxed); + } + if (enabled()) { + memory_active_.fetch_add(mem, std::memory_order_relaxed); + } +} + +// Should only be called from write thread +void WriteBufferManager::ReserveMemWithCache(size_t mem) { +#ifndef ROCKSDB_LITE + assert(cache_res_mgr_ != nullptr); + // Use a mutex to protect various data structures. Can be optimized to a + // lock-free solution if it ends up with a performance bottleneck. + std::lock_guard<std::mutex> lock(cache_res_mgr_mu_); + + size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) + mem; + memory_used_.store(new_mem_used, std::memory_order_relaxed); + Status s = cache_res_mgr_->UpdateCacheReservation(new_mem_used); + + // We absorb the error since WriteBufferManager is not able to handle + // this failure properly. Ideallly we should prevent this allocation + // from happening if this cache charging fails. + // [TODO] We'll need to improve it in the future and figure out what to do on + // error + s.PermitUncheckedError(); +#else + (void)mem; +#endif // ROCKSDB_LITE +} + +void WriteBufferManager::ScheduleFreeMem(size_t mem) { + if (enabled()) { + memory_active_.fetch_sub(mem, std::memory_order_relaxed); + } +} + +void WriteBufferManager::FreeMem(size_t mem) { + if (cache_res_mgr_ != nullptr) { + FreeMemWithCache(mem); + } else if (enabled()) { + memory_used_.fetch_sub(mem, std::memory_order_relaxed); + } + // Check if stall is active and can be ended. + MaybeEndWriteStall(); +} + +void WriteBufferManager::FreeMemWithCache(size_t mem) { +#ifndef ROCKSDB_LITE + assert(cache_res_mgr_ != nullptr); + // Use a mutex to protect various data structures. Can be optimized to a + // lock-free solution if it ends up with a performance bottleneck. 
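// Context for the reservation update below: the cache charge is maintained as
// fixed-size dummy entries in the block cache (256 KiB each, matching the
// kSizeDummyEntry constant used by this file's unit test), and the manager
// was constructed with delayed_decrease = true, so decreases in memory_used_
// release dummy entries lazily rather than on every FreeMem() call.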
+  std::lock_guard<std::mutex> lock(cache_res_mgr_mu_);
+  size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) - mem;
+  memory_used_.store(new_mem_used, std::memory_order_relaxed);
+  Status s = cache_res_mgr_->UpdateCacheReservation(new_mem_used);
+
+  // We absorb the error since WriteBufferManager is not able to handle
+  // this failure properly.
+  // [TODO] We'll need to improve it in the future and figure out what to do on
+  // error
+  s.PermitUncheckedError();
+#else
+  (void)mem;
+#endif  // ROCKSDB_LITE
+}
+
+void WriteBufferManager::BeginWriteStall(StallInterface* wbm_stall) {
+  assert(wbm_stall != nullptr);
+  assert(allow_stall_);
+
+  // Allocate outside of the lock.
+  std::list<StallInterface*> new_node = {wbm_stall};
+
+  {
+    std::unique_lock<std::mutex> lock(mu_);
+    // Verify that the stall conditions are still active.
+    if (ShouldStall()) {
+      stall_active_.store(true, std::memory_order_relaxed);
+      queue_.splice(queue_.end(), std::move(new_node));
+    }
+  }
+
+  // If the node was not consumed, the stall has ended already and we can
+  // signal the caller.
+  if (!new_node.empty()) {
+    new_node.front()->Signal();
+  }
+}
+
+// Called when memory is freed in FreeMem or the buffer size has changed.
+void WriteBufferManager::MaybeEndWriteStall() {
+  // Cannot early-exit on !enabled() because SetBufferSize(0) needs to unblock
+  // the writers.
+  if (!allow_stall_) {
+    return;
+  }
+
+  if (IsStallThresholdExceeded()) {
+    return;  // Stall conditions have not resolved.
+  }
+
+  // Perform all deallocations outside of the lock.
+  std::list<StallInterface*> cleanup;
+
+  std::unique_lock<std::mutex> lock(mu_);
+  if (!stall_active_.load(std::memory_order_relaxed)) {
+    return;  // Nothing to do.
+  }
+
+  // Unblock new writers.
+  stall_active_.store(false, std::memory_order_relaxed);
+
+  // Unblock the writers in the queue.
+  for (StallInterface* wbm_stall : queue_) {
+    wbm_stall->Signal();
+  }
+  cleanup = std::move(queue_);
+}
+
+void WriteBufferManager::RemoveDBFromQueue(StallInterface* wbm_stall) {
+  assert(wbm_stall != nullptr);
+
+  // Deallocate the removed nodes outside of the lock.
+  std::list<StallInterface*> cleanup;
+
+  if (enabled() && allow_stall_) {
+    std::unique_lock<std::mutex> lock(mu_);
+    for (auto it = queue_.begin(); it != queue_.end();) {
+      auto next = std::next(it);
+      if (*it == wbm_stall) {
+        cleanup.splice(cleanup.end(), queue_, std::move(it));
+      }
+      it = next;
+    }
+  }
+  wbm_stall->Signal();
+}
+
+}  // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/memtable/write_buffer_manager_test.cc b/src/rocksdb/memtable/write_buffer_manager_test.cc
new file mode 100644
index 000000000..1cc4c2cc5
--- /dev/null
+++ b/src/rocksdb/memtable/write_buffer_manager_test.cc
@@ -0,0 +1,305 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+//
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
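(Editorial aside before the tests: a minimal sketch, not part of this patch, of how a WriteBufferManager like the one implemented above is typically shared between databases so that their combined memtable usage is capped and charged to the block cache. The fields write_buffer_manager, NewLRUCache, and the three-argument WriteBufferManager constructor come from the public RocksDB headers; treat the exact wiring as an assumption.)

// Illustrative sketch only -- not part of the patch.
#include <memory>

#include "rocksdb/cache.h"
#include "rocksdb/options.h"
#include "rocksdb/write_buffer_manager.h"

void ShareWriteBufferAcrossDBs(ROCKSDB_NAMESPACE::Options* db1_options,
                               ROCKSDB_NAMESPACE::Options* db2_options) {
  using namespace ROCKSDB_NAMESPACE;
  // Charge memtable memory against a 1GB block cache and cap total memtable
  // memory at 128MB; stall writers instead of failing when the cap is hit.
  std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024);
  auto wbm = std::make_shared<WriteBufferManager>(
      128 * 1024 * 1024 /* buffer_size */, cache, true /* allow_stall */);
  db1_options->write_buffer_manager = wbm;
  db2_options->write_buffer_manager = wbm;
}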
+
+#include "rocksdb/write_buffer_manager.h"
+
+#include "test_util/testharness.h"
+
+namespace ROCKSDB_NAMESPACE {
+class WriteBufferManagerTest : public testing::Test {};
+
+#ifndef ROCKSDB_LITE
+const size_t kSizeDummyEntry = 256 * 1024;
+
+TEST_F(WriteBufferManagerTest, ShouldFlush) {
+  // A write buffer manager of size 10MB
+  std::unique_ptr<WriteBufferManager> wbf(
+      new WriteBufferManager(10 * 1024 * 1024));
+
+  wbf->ReserveMem(8 * 1024 * 1024);
+  ASSERT_FALSE(wbf->ShouldFlush());
+  // 90% of the hard limit will hit the condition
+  wbf->ReserveMem(1 * 1024 * 1024);
+  ASSERT_TRUE(wbf->ShouldFlush());
+  // Scheduling the memory for freeing releases the condition
+  wbf->ScheduleFreeMem(1 * 1024 * 1024);
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  wbf->ReserveMem(2 * 1024 * 1024);
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->ScheduleFreeMem(4 * 1024 * 1024);
+  // 11MB total, 6MB mutable. hard limit still hit
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->ScheduleFreeMem(2 * 1024 * 1024);
+  // 11MB total, 4MB mutable. hard limit still hit, but won't flush because
+  // more than half the data is already being flushed.
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  wbf->ReserveMem(4 * 1024 * 1024);
+  // 15MB total, 8MB mutable.
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->FreeMem(7 * 1024 * 1024);
+  // 8MB total, 8MB mutable.
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  // change size: 8MB limit, 7MB mutable limit
+  wbf->SetBufferSize(8 * 1024 * 1024);
+  // 8MB total, 8MB mutable.
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->ScheduleFreeMem(2 * 1024 * 1024);
+  // 8MB total, 6MB mutable.
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->FreeMem(2 * 1024 * 1024);
+  // 6MB total, 6MB mutable.
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  wbf->ReserveMem(1 * 1024 * 1024);
+  // 7MB total, 7MB mutable.
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  wbf->ReserveMem(1 * 1024 * 1024);
+  // 8MB total, 8MB mutable.
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  wbf->ScheduleFreeMem(1 * 1024 * 1024);
+  wbf->FreeMem(1 * 1024 * 1024);
+  // 7MB total, 7MB mutable.
+  ASSERT_FALSE(wbf->ShouldFlush());
+}
+
+class ChargeWriteBufferTest : public testing::Test {};
+
+TEST_F(ChargeWriteBufferTest, Basic) {
+  constexpr std::size_t kMetaDataChargeOverhead = 10000;
+
+  LRUCacheOptions co;
+  // 1GB cache
+  co.capacity = 1024 * 1024 * 1024;
+  co.num_shard_bits = 4;
+  co.metadata_charge_policy = kDontChargeCacheMetadata;
+  std::shared_ptr<Cache> cache = NewLRUCache(co);
+  // A write buffer manager of size 50MB
+  std::unique_ptr<WriteBufferManager> wbf(
+      new WriteBufferManager(50 * 1024 * 1024, cache));
+
+  // Allocating 333KB will reserve 512KB in the cache, memory_used_ = 333KB
+  wbf->ReserveMem(333 * 1024);
+  // 2 dummy entries are added for the 333KB reservation
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 2 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 2 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 2 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Allocate another 512KB, memory_used_ = 845KB
+  wbf->ReserveMem(512 * 1024);
+  // 2 more dummy entries are added for the 512KB reservation
+  // since ceil((memory_used_ - dummy_entries_in_cache_usage) / kSizeDummyEntry)
+  // = 2
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Allocate another 10MB, memory_used_ = 11085KB
+  wbf->ReserveMem(10 * 1024 * 1024);
+  // 40 more dummy entries are added for the 10MB reservation
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Free 1MB, memory_used_ = 10061KB
+  // It will not cause any change in cache cost
+  // since memory_used_ > dummy_entries_in_cache_usage * (3/4)
+  wbf->FreeMem(1 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  // Allocate another 41MB, memory_used_ = 52045KB
+  wbf->ReserveMem(41 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 204 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 204 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            204 * 256 * 1024 + kMetaDataChargeOverhead);
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  // Schedule free 20MB, memory_used_ = 52045KB
+  // It will not cause any change in memory_used_ or cache cost
+  wbf->ScheduleFreeMem(20 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 204 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 204 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            204 * 256 * 1024 + kMetaDataChargeOverhead);
+  // Still needs a flush since the hard limit is still hit
+  ASSERT_TRUE(wbf->ShouldFlush());
+
+  // Free 20MB, memory_used_ = 31565KB
+  // It will release 80 dummy entries from the cache
+  // since memory_used_ < dummy_entries_in_cache_usage * (3/4)
+  // and floor((dummy_entries_in_cache_usage - memory_used_) / kSizeDummyEntry)
+  // = 80
+  wbf->FreeMem(20 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 124 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 124 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            124 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  // Free 16KB, memory_used_ = 31549KB
+  // It will not release any dummy entry since memory_used_ >=
+  // dummy_entries_in_cache_usage * (3/4)
+  wbf->FreeMem(16 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 124 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 124 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            124 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Free 20MB, memory_used_ = 11069KB
+  // It will release 80 dummy entries from the cache
+  // since memory_used_ < dummy_entries_in_cache_usage * (3/4)
+  // and floor((dummy_entries_in_cache_usage - memory_used_) / kSizeDummyEntry)
+  // = 80
+  wbf->FreeMem(20 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Free 1MB, memory_used_ = 10045KB
+  // It will not cause any change in cache cost
+  // since memory_used_ > dummy_entries_in_cache_usage * (3/4)
+  wbf->FreeMem(1 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Reserve 512KB, memory_used_ = 10557KB
+  // It will not cause any change in cache cost
+  // since memory_used_ > dummy_entries_in_cache_usage * (3/4),
+  // which reflects the benefit of skipping dummy entry insertion on memory
+  // reservation right after a delayed decrease
+  wbf->ReserveMem(512 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 44 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 44 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 44 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Destroying the write buffer manager should free everything
+  wbf.reset();
+  ASSERT_EQ(cache->GetPinnedUsage(), 0);
+}
+
+TEST_F(ChargeWriteBufferTest, BasicWithNoBufferSizeLimit) {
+  constexpr std::size_t kMetaDataChargeOverhead = 10000;
+  // 1GB cache
+  std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4);
+  // A write buffer manager with no size limit (buffer_size = 0); memtable
+  // memory is only charged to the cache
+  std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache));
+
+  // Allocate 10MB, memory_used_ = 10240KB
+  // It will allocate 40 dummy entries
+  wbf->ReserveMem(10 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 40 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 40 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 40 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  ASSERT_FALSE(wbf->ShouldFlush());
+
+  // Free 9MB, memory_used_ = 1024KB
+  // It will free 36 dummy entries
+  wbf->FreeMem(9 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+
+  // Free 160KB gradually, memory_used_ = 864KB
+  // It will not cause any change
+  // since memory_used_ > dummy_entries_in_cache_usage * 3/4
+  for (int i = 0; i < 40; i++) {
+    wbf->FreeMem(4 * 1024);
+  }
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 4 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 4 * 256 * 1024);
+  ASSERT_LT(cache->GetPinnedUsage(), 4 * 256 * 1024 + kMetaDataChargeOverhead);
+}
+
+TEST_F(ChargeWriteBufferTest, BasicWithCacheFull) {
+  constexpr std::size_t kMetaDataChargeOverhead = 20000;
+
+  // 12MB cache size with strict capacity
+  LRUCacheOptions lo;
+  lo.capacity = 12 * 1024 * 1024;
+  lo.num_shard_bits = 0;
+  lo.strict_capacity_limit = true;
+  std::shared_ptr<Cache> cache = NewLRUCache(lo);
+
+  std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache));
+
+  // Allocate 10MB, memory_used_ = 10240KB
+  wbf->ReserveMem(10 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 40 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 40 * kSizeDummyEntry);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            40 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+  // Allocate 10MB, memory_used_ = 20480KB
+  // Some dummy entry insertion will fail due to the full cache
+  wbf->ReserveMem(10 * 1024 * 1024);
+  ASSERT_GE(cache->GetPinnedUsage(), 40 * kSizeDummyEntry);
+  ASSERT_LE(cache->GetPinnedUsage(), 12 * 1024 * 1024);
+  ASSERT_LT(wbf->dummy_entries_in_cache_usage(), 80 * kSizeDummyEntry);
+
+  // Free 15MB after encountering cache full, memory_used_ = 5120KB
+  wbf->FreeMem(15 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 20 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 20 * kSizeDummyEntry);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            20 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+  // Reserve 15MB, filling the cache again, memory_used_ = 20480KB
+  wbf->ReserveMem(15 * 1024 * 1024);
+  ASSERT_LE(cache->GetPinnedUsage(), 12 * 1024 * 1024);
+  ASSERT_LT(wbf->dummy_entries_in_cache_usage(), 80 * kSizeDummyEntry);
+
+  // Increase capacity so the next insert will fully succeed
+  cache->SetCapacity(40 * 1024 * 1024);
+
+  // Allocate 10MB, memory_used_ = 30720KB
+  wbf->ReserveMem(10 * 1024 * 1024);
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 120 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 120 * kSizeDummyEntry);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            120 * kSizeDummyEntry + kMetaDataChargeOverhead);
+
+  // Gradually release 20MB
+  // It ends up sequentially releasing 32, 24, and 18 dummy entries as
+  // memory_used_ decreases to 22528KB, 16384KB, and 11776KB.
+  // In total, it releases 74 dummy entries
+  for (int i = 0; i < 40; i++) {
+    wbf->FreeMem(512 * 1024);
+  }
+
+  ASSERT_EQ(wbf->dummy_entries_in_cache_usage(), 46 * kSizeDummyEntry);
+  ASSERT_GE(cache->GetPinnedUsage(), 46 * kSizeDummyEntry);
+  ASSERT_LT(cache->GetPinnedUsage(),
+            46 * kSizeDummyEntry + kMetaDataChargeOverhead);
+}
+
+#endif  // ROCKSDB_LITE
+}  // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+  ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
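(Editorial aside: the dummy-entry counts asserted in the tests above follow a simple pattern -- 256KB per dummy entry, reservations rounded up, and releases deferred until usage falls below 3/4 of what is currently reserved. The sketch below is a back-of-the-envelope model of that accounting, inferred from the test comments; it is not the actual CacheReservationManager implementation.)

// Illustrative model only -- a simplification of the behavior the tests check.
#include <cstddef>

std::size_t ExpectedDummyEntries(std::size_t memory_used_bytes,
                                 std::size_t reserved_bytes) {
  constexpr std::size_t kDummy = 256 * 1024;  // kSizeDummyEntry
  if (memory_used_bytes > reserved_bytes) {
    // Growing: round the requirement up to whole dummy entries.
    return (memory_used_bytes + kDummy - 1) / kDummy;
  }
  if (memory_used_bytes < reserved_bytes * 3 / 4) {
    // Usage dropped below 3/4 of the reservation: release down to the
    // rounded-up requirement.
    return (memory_used_bytes + kDummy - 1) / kDummy;
  }
  // Otherwise keep the current reservation (delayed decrease).
  return reserved_bytes / kDummy;
}

// For example, reserving 333KB and then 512KB gives
// ExpectedDummyEntries(845KB, 512KB) = 4, matching the 4 * kSizeDummyEntry
// assertion in ChargeWriteBufferTest::Basic.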