diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-27 18:24:20 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-27 18:24:20 +0000 |
commit | 483eb2f56657e8e7f419ab1a4fab8dce9ade8609 (patch) | |
tree | e5d88d25d870d5dedacb6bbdbe2a966086a0a5cf /src/rocksdb/memtable | |
parent | Initial commit. (diff) | |
download | ceph-upstream.tar.xz ceph-upstream.zip |
Adding upstream version 14.2.21.upstream/14.2.21upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/rocksdb/memtable')
-rw-r--r-- | src/rocksdb/memtable/alloc_tracker.cc | 62 | ||||
-rw-r--r-- | src/rocksdb/memtable/hash_linklist_rep.cc | 845 | ||||
-rw-r--r-- | src/rocksdb/memtable/hash_linklist_rep.h | 49 | ||||
-rw-r--r-- | src/rocksdb/memtable/hash_skiplist_rep.cc | 349 | ||||
-rw-r--r-- | src/rocksdb/memtable/hash_skiplist_rep.h | 44 | ||||
-rw-r--r-- | src/rocksdb/memtable/inlineskiplist.h | 965 | ||||
-rw-r--r-- | src/rocksdb/memtable/inlineskiplist_test.cc | 645 | ||||
-rw-r--r-- | src/rocksdb/memtable/memtablerep_bench.cc | 682 | ||||
-rw-r--r-- | src/rocksdb/memtable/skiplist.h | 497 | ||||
-rw-r--r-- | src/rocksdb/memtable/skiplist_test.cc | 388 | ||||
-rw-r--r-- | src/rocksdb/memtable/skiplistrep.cc | 271 | ||||
-rw-r--r-- | src/rocksdb/memtable/stl_wrappers.h | 33 | ||||
-rw-r--r-- | src/rocksdb/memtable/vectorrep.cc | 301 | ||||
-rw-r--r-- | src/rocksdb/memtable/write_buffer_manager.cc | 130 | ||||
-rw-r--r-- | src/rocksdb/memtable/write_buffer_manager_test.cc | 151 |
15 files changed, 5412 insertions, 0 deletions
diff --git a/src/rocksdb/memtable/alloc_tracker.cc b/src/rocksdb/memtable/alloc_tracker.cc new file mode 100644 index 00000000..a1fa4938 --- /dev/null +++ b/src/rocksdb/memtable/alloc_tracker.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include <assert.h> +#include "rocksdb/write_buffer_manager.h" +#include "util/allocator.h" +#include "util/arena.h" + +namespace rocksdb { + +AllocTracker::AllocTracker(WriteBufferManager* write_buffer_manager) + : write_buffer_manager_(write_buffer_manager), + bytes_allocated_(0), + done_allocating_(false), + freed_(false) {} + +AllocTracker::~AllocTracker() { FreeMem(); } + +void AllocTracker::Allocate(size_t bytes) { + assert(write_buffer_manager_ != nullptr); + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + bytes_allocated_.fetch_add(bytes, std::memory_order_relaxed); + write_buffer_manager_->ReserveMem(bytes); + } +} + +void AllocTracker::DoneAllocating() { + if (write_buffer_manager_ != nullptr && !done_allocating_) { + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + write_buffer_manager_->ScheduleFreeMem( + bytes_allocated_.load(std::memory_order_relaxed)); + } else { + assert(bytes_allocated_.load(std::memory_order_relaxed) == 0); + } + done_allocating_ = true; + } +} + +void AllocTracker::FreeMem() { + if (!done_allocating_) { + DoneAllocating(); + } + if (write_buffer_manager_ != nullptr && !freed_) { + if (write_buffer_manager_->enabled() || + write_buffer_manager_->cost_to_cache()) { + write_buffer_manager_->FreeMem( + bytes_allocated_.load(std::memory_order_relaxed)); + } else { + assert(bytes_allocated_.load(std::memory_order_relaxed) == 0); + } + freed_ = true; + } +} +} // namespace rocksdb diff --git a/src/rocksdb/memtable/hash_linklist_rep.cc b/src/rocksdb/memtable/hash_linklist_rep.cc new file mode 100644 index 00000000..878d2338 --- /dev/null +++ b/src/rocksdb/memtable/hash_linklist_rep.cc @@ -0,0 +1,845 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// + +#ifndef ROCKSDB_LITE +#include "memtable/hash_linklist_rep.h" + +#include <algorithm> +#include <atomic> +#include "db/memtable.h" +#include "memtable/skiplist.h" +#include "monitoring/histogram.h" +#include "port/port.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/slice.h" +#include "rocksdb/slice_transform.h" +#include "util/arena.h" +#include "util/hash.h" + +namespace rocksdb { +namespace { + +typedef const char* Key; +typedef SkipList<Key, const MemTableRep::KeyComparator&> MemtableSkipList; +typedef std::atomic<void*> Pointer; + +// A data structure used as the header of a link list of a hash bucket. +struct BucketHeader { + Pointer next; + std::atomic<uint32_t> num_entries; + + explicit BucketHeader(void* n, uint32_t count) + : next(n), num_entries(count) {} + + bool IsSkipListBucket() { + return next.load(std::memory_order_relaxed) == this; + } + + uint32_t GetNumEntries() const { + return num_entries.load(std::memory_order_relaxed); + } + + // REQUIRES: called from single-threaded Insert() + void IncNumEntries() { + // Only one thread can do write at one time. No need to do atomic + // incremental. Update it with relaxed load and store. + num_entries.store(GetNumEntries() + 1, std::memory_order_relaxed); + } +}; + +// A data structure used as the header of a skip list of a hash bucket. +struct SkipListBucketHeader { + BucketHeader Counting_header; + MemtableSkipList skip_list; + + explicit SkipListBucketHeader(const MemTableRep::KeyComparator& cmp, + Allocator* allocator, uint32_t count) + : Counting_header(this, // Pointing to itself to indicate header type. + count), + skip_list(cmp, allocator) {} +}; + +struct Node { + // Accessors/mutators for links. Wrapped in methods so we can + // add the appropriate barriers as necessary. + Node* Next() { + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return next_.load(std::memory_order_acquire); + } + void SetNext(Node* x) { + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + next_.store(x, std::memory_order_release); + } + // No-barrier variants that can be safely used in a few locations. + Node* NoBarrier_Next() { + return next_.load(std::memory_order_relaxed); + } + + void NoBarrier_SetNext(Node* x) { next_.store(x, std::memory_order_relaxed); } + + // Needed for placement new below which is fine + Node() {} + + private: + std::atomic<Node*> next_; + + // Prohibit copying due to the below + Node(const Node&) = delete; + Node& operator=(const Node&) = delete; + + public: + char key[1]; +}; + +// Memory structure of the mem table: +// It is a hash table, each bucket points to one entry, a linked list or a +// skip list. In order to track total number of records in a bucket to determine +// whether should switch to skip list, a header is added just to indicate +// number of entries in the bucket. +// +// +// +-----> NULL Case 1. Empty bucket +// | +// | +// | +---> +-------+ +// | | | Next +--> NULL +// | | +-------+ +// +-----+ | | | | Case 2. One Entry in bucket. +// | +-+ | | Data | next pointer points to +// +-----+ | | | NULL. All other cases +// | | | | | next pointer is not NULL. +// +-----+ | +-------+ +// | +---+ +// +-----+ +-> +-------+ +> +-------+ +-> +-------+ +// | | | | Next +--+ | Next +--+ | Next +-->NULL +// +-----+ | +-------+ +-------+ +-------+ +// | +-----+ | Count | | | | | +// +-----+ +-------+ | Data | | Data | +// | | | | | | +// +-----+ Case 3. | | | | +// | | A header +-------+ +-------+ +// +-----+ points to +// | | a linked list. Count indicates total number +// +-----+ of rows in this bucket. +// | | +// +-----+ +-> +-------+ <--+ +// | | | | Next +----+ +// +-----+ | +-------+ Case 4. A header points to a skip +// | +----+ | Count | list and next pointer points to +// +-----+ +-------+ itself, to distinguish case 3 or 4. +// | | | | Count still is kept to indicates total +// +-----+ | Skip +--> of entries in the bucket for debugging +// | | | List | Data purpose. +// | | | +--> +// +-----+ | | +// | | +-------+ +// +-----+ +// +// We don't have data race when changing cases because: +// (1) When changing from case 2->3, we create a new bucket header, put the +// single node there first without changing the original node, and do a +// release store when changing the bucket pointer. In that case, a reader +// who sees a stale value of the bucket pointer will read this node, while +// a reader sees the correct value because of the release store. +// (2) When changing case 3->4, a new header is created with skip list points +// to the data, before doing an acquire store to change the bucket pointer. +// The old header and nodes are never changed, so any reader sees any +// of those existing pointers will guarantee to be able to iterate to the +// end of the linked list. +// (3) Header's next pointer in case 3 might change, but they are never equal +// to itself, so no matter a reader sees any stale or newer value, it will +// be able to correctly distinguish case 3 and 4. +// +// The reason that we use case 2 is we want to make the format to be efficient +// when the utilization of buckets is relatively low. If we use case 3 for +// single entry bucket, we will need to waste 12 bytes for every entry, +// which can be significant decrease of memory utilization. +class HashLinkListRep : public MemTableRep { + public: + HashLinkListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + size_t bucket_size, uint32_t threshold_use_skiplist, + size_t huge_page_tlb_size, Logger* logger, + int bucket_entries_logging_threshold, + bool if_log_bucket_dist_when_flash); + + KeyHandle Allocate(const size_t len, char** buf) override; + + void Insert(KeyHandle handle) override; + + bool Contains(const char* key) const override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~HashLinkListRep() override; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override; + + MemTableRep::Iterator* GetDynamicPrefixIterator( + Arena* arena = nullptr) override; + + private: + friend class DynamicIterator; + + size_t bucket_size_; + + // Maps slices (which are transformed user keys) to buckets of keys sharing + // the same transform. + Pointer* buckets_; + + const uint32_t threshold_use_skiplist_; + + // The user-supplied transform whose domain is the user keys. + const SliceTransform* transform_; + + const MemTableRep::KeyComparator& compare_; + + Logger* logger_; + int bucket_entries_logging_threshold_; + bool if_log_bucket_dist_when_flash_; + + bool LinkListContains(Node* head, const Slice& key) const; + + SkipListBucketHeader* GetSkipListBucketHeader(Pointer* first_next_pointer) + const; + + Node* GetLinkListFirstNode(Pointer* first_next_pointer) const; + + Slice GetPrefix(const Slice& internal_key) const { + return transform_->Transform(ExtractUserKey(internal_key)); + } + + size_t GetHash(const Slice& slice) const { + return NPHash64(slice.data(), static_cast<int>(slice.size()), 0) % + bucket_size_; + } + + Pointer* GetBucket(size_t i) const { + return static_cast<Pointer*>(buckets_[i].load(std::memory_order_acquire)); + } + + Pointer* GetBucket(const Slice& slice) const { + return GetBucket(GetHash(slice)); + } + + bool Equal(const Slice& a, const Key& b) const { + return (compare_(b, a) == 0); + } + + bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); } + + bool KeyIsAfterNode(const Slice& internal_key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, internal_key) < 0); + } + + bool KeyIsAfterNode(const Key& key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) < 0); + } + + bool KeyIsAfterOrAtNode(const Slice& internal_key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, internal_key) <= 0); + } + + bool KeyIsAfterOrAtNode(const Key& key, const Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) <= 0); + } + + Node* FindGreaterOrEqualInBucket(Node* head, const Slice& key) const; + Node* FindLessOrEqualInBucket(Node* head, const Slice& key) const; + + class FullListIterator : public MemTableRep::Iterator { + public: + explicit FullListIterator(MemtableSkipList* list, Allocator* allocator) + : iter_(list), full_list_(list), allocator_(allocator) {} + + ~FullListIterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + iter_.Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + assert(Valid()); + iter_.Prev(); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, const char* memtable_key) override { + const char* encoded_key = + (memtable_key != nullptr) ? + memtable_key : EncodeKey(&tmp_, internal_key); + iter_.Seek(encoded_key); + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& internal_key, + const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.SeekForPrev(encoded_key); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { iter_.SeekToFirst(); } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { iter_.SeekToLast(); } + + private: + MemtableSkipList::Iterator iter_; + // To destruct with the iterator. + std::unique_ptr<MemtableSkipList> full_list_; + std::unique_ptr<Allocator> allocator_; + std::string tmp_; // For passing to EncodeKey + }; + + class LinkListIterator : public MemTableRep::Iterator { + public: + explicit LinkListIterator(const HashLinkListRep* const hash_link_list_rep, + Node* head) + : hash_link_list_rep_(hash_link_list_rep), + head_(head), + node_(nullptr) {} + + ~LinkListIterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return node_ != nullptr; } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return node_->key; + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + node_ = node_->Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, + const char* /*memtable_key*/) override { + node_ = hash_link_list_rep_->FindGreaterOrEqualInBucket(head_, + internal_key); + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override { + // Since we do not support Prev() + // We simply do not support SeekForPrev + Reset(nullptr); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + protected: + void Reset(Node* head) { + head_ = head; + node_ = nullptr; + } + private: + friend class HashLinkListRep; + const HashLinkListRep* const hash_link_list_rep_; + Node* head_; + Node* node_; + + virtual void SeekToHead() { + node_ = head_; + } + }; + + class DynamicIterator : public HashLinkListRep::LinkListIterator { + public: + explicit DynamicIterator(HashLinkListRep& memtable_rep) + : HashLinkListRep::LinkListIterator(&memtable_rep, nullptr), + memtable_rep_(memtable_rep) {} + + // Advance to the first entry with a key >= target + void Seek(const Slice& k, const char* memtable_key) override { + auto transformed = memtable_rep_.GetPrefix(k); + auto* bucket = memtable_rep_.GetBucket(transformed); + + SkipListBucketHeader* skip_list_header = + memtable_rep_.GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + // The bucket is organized as a skip list + if (!skip_list_iter_) { + skip_list_iter_.reset( + new MemtableSkipList::Iterator(&skip_list_header->skip_list)); + } else { + skip_list_iter_->SetList(&skip_list_header->skip_list); + } + if (memtable_key != nullptr) { + skip_list_iter_->Seek(memtable_key); + } else { + IterKey encoded_key; + encoded_key.EncodeLengthPrefixedKey(k); + skip_list_iter_->Seek(encoded_key.GetUserKey().data()); + } + } else { + // The bucket is organized as a linked list + skip_list_iter_.reset(); + Reset(memtable_rep_.GetLinkListFirstNode(bucket)); + HashLinkListRep::LinkListIterator::Seek(k, memtable_key); + } + } + + bool Valid() const override { + if (skip_list_iter_) { + return skip_list_iter_->Valid(); + } + return HashLinkListRep::LinkListIterator::Valid(); + } + + const char* key() const override { + if (skip_list_iter_) { + return skip_list_iter_->key(); + } + return HashLinkListRep::LinkListIterator::key(); + } + + void Next() override { + if (skip_list_iter_) { + skip_list_iter_->Next(); + } else { + HashLinkListRep::LinkListIterator::Next(); + } + } + + private: + // the underlying memtable + const HashLinkListRep& memtable_rep_; + std::unique_ptr<MemtableSkipList::Iterator> skip_list_iter_; + }; + + class EmptyIterator : public MemTableRep::Iterator { + // This is used when there wasn't a bucket. It is cheaper than + // instantiating an empty bucket over which to iterate. + public: + EmptyIterator() { } + bool Valid() const override { return false; } + const char* key() const override { + assert(false); + return nullptr; + } + void Next() override {} + void Prev() override {} + void Seek(const Slice& /*user_key*/, + const char* /*memtable_key*/) override {} + void SeekForPrev(const Slice& /*user_key*/, + const char* /*memtable_key*/) override {} + void SeekToFirst() override {} + void SeekToLast() override {} + + private: + }; +}; + +HashLinkListRep::HashLinkListRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, size_t bucket_size, + uint32_t threshold_use_skiplist, size_t huge_page_tlb_size, Logger* logger, + int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash) + : MemTableRep(allocator), + bucket_size_(bucket_size), + // Threshold to use skip list doesn't make sense if less than 3, so we + // force it to be minimum of 3 to simplify implementation. + threshold_use_skiplist_(std::max(threshold_use_skiplist, 3U)), + transform_(transform), + compare_(compare), + logger_(logger), + bucket_entries_logging_threshold_(bucket_entries_logging_threshold), + if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) { + char* mem = allocator_->AllocateAligned(sizeof(Pointer) * bucket_size, + huge_page_tlb_size, logger); + + buckets_ = new (mem) Pointer[bucket_size]; + + for (size_t i = 0; i < bucket_size_; ++i) { + buckets_[i].store(nullptr, std::memory_order_relaxed); + } +} + +HashLinkListRep::~HashLinkListRep() { +} + +KeyHandle HashLinkListRep::Allocate(const size_t len, char** buf) { + char* mem = allocator_->AllocateAligned(sizeof(Node) + len); + Node* x = new (mem) Node(); + *buf = x->key; + return static_cast<void*>(x); +} + +SkipListBucketHeader* HashLinkListRep::GetSkipListBucketHeader( + Pointer* first_next_pointer) const { + if (first_next_pointer == nullptr) { + return nullptr; + } + if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) { + // Single entry bucket + return nullptr; + } + // Counting header + BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer); + if (header->IsSkipListBucket()) { + assert(header->GetNumEntries() > threshold_use_skiplist_); + auto* skip_list_bucket_header = + reinterpret_cast<SkipListBucketHeader*>(header); + assert(skip_list_bucket_header->Counting_header.next.load( + std::memory_order_relaxed) == header); + return skip_list_bucket_header; + } + assert(header->GetNumEntries() <= threshold_use_skiplist_); + return nullptr; +} + +Node* HashLinkListRep::GetLinkListFirstNode(Pointer* first_next_pointer) const { + if (first_next_pointer == nullptr) { + return nullptr; + } + if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) { + // Single entry bucket + return reinterpret_cast<Node*>(first_next_pointer); + } + // Counting header + BucketHeader* header = reinterpret_cast<BucketHeader*>(first_next_pointer); + if (!header->IsSkipListBucket()) { + assert(header->GetNumEntries() <= threshold_use_skiplist_); + return reinterpret_cast<Node*>( + header->next.load(std::memory_order_acquire)); + } + assert(header->GetNumEntries() > threshold_use_skiplist_); + return nullptr; +} + +void HashLinkListRep::Insert(KeyHandle handle) { + Node* x = static_cast<Node*>(handle); + assert(!Contains(x->key)); + Slice internal_key = GetLengthPrefixedSlice(x->key); + auto transformed = GetPrefix(internal_key); + auto& bucket = buckets_[GetHash(transformed)]; + Pointer* first_next_pointer = + static_cast<Pointer*>(bucket.load(std::memory_order_relaxed)); + + if (first_next_pointer == nullptr) { + // Case 1. empty bucket + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. + x->NoBarrier_SetNext(nullptr); + bucket.store(x, std::memory_order_release); + return; + } + + BucketHeader* header = nullptr; + if (first_next_pointer->load(std::memory_order_relaxed) == nullptr) { + // Case 2. only one entry in the bucket + // Need to convert to a Counting bucket and turn to case 4. + Node* first = reinterpret_cast<Node*>(first_next_pointer); + // Need to add a bucket header. + // We have to first convert it to a bucket with header before inserting + // the new node. Otherwise, we might need to change next pointer of first. + // In that case, a reader might sees the next pointer is NULL and wrongly + // think the node is a bucket header. + auto* mem = allocator_->AllocateAligned(sizeof(BucketHeader)); + header = new (mem) BucketHeader(first, 1); + bucket.store(header, std::memory_order_release); + } else { + header = reinterpret_cast<BucketHeader*>(first_next_pointer); + if (header->IsSkipListBucket()) { + // Case 4. Bucket is already a skip list + assert(header->GetNumEntries() > threshold_use_skiplist_); + auto* skip_list_bucket_header = + reinterpret_cast<SkipListBucketHeader*>(header); + // Only one thread can execute Insert() at one time. No need to do atomic + // incremental. + skip_list_bucket_header->Counting_header.IncNumEntries(); + skip_list_bucket_header->skip_list.Insert(x->key); + return; + } + } + + if (bucket_entries_logging_threshold_ > 0 && + header->GetNumEntries() == + static_cast<uint32_t>(bucket_entries_logging_threshold_)) { + Info(logger_, "HashLinkedList bucket %" ROCKSDB_PRIszt + " has more than %d " + "entries. Key to insert: %s", + GetHash(transformed), header->GetNumEntries(), + GetLengthPrefixedSlice(x->key).ToString(true).c_str()); + } + + if (header->GetNumEntries() == threshold_use_skiplist_) { + // Case 3. number of entries reaches the threshold so need to convert to + // skip list. + LinkListIterator bucket_iter( + this, reinterpret_cast<Node*>( + first_next_pointer->load(std::memory_order_relaxed))); + auto mem = allocator_->AllocateAligned(sizeof(SkipListBucketHeader)); + SkipListBucketHeader* new_skip_list_header = new (mem) + SkipListBucketHeader(compare_, allocator_, header->GetNumEntries() + 1); + auto& skip_list = new_skip_list_header->skip_list; + + // Add all current entries to the skip list + for (bucket_iter.SeekToHead(); bucket_iter.Valid(); bucket_iter.Next()) { + skip_list.Insert(bucket_iter.key()); + } + + // insert the new entry + skip_list.Insert(x->key); + // Set the bucket + bucket.store(new_skip_list_header, std::memory_order_release); + } else { + // Case 5. Need to insert to the sorted linked list without changing the + // header. + Node* first = + reinterpret_cast<Node*>(header->next.load(std::memory_order_relaxed)); + assert(first != nullptr); + // Advance counter unless the bucket needs to be advanced to skip list. + // In that case, we need to make sure the previous count never exceeds + // threshold_use_skiplist_ to avoid readers to cast to wrong format. + header->IncNumEntries(); + + Node* cur = first; + Node* prev = nullptr; + while (true) { + if (cur == nullptr) { + break; + } + Node* next = cur->Next(); + // Make sure the lists are sorted. + // If x points to head_ or next points nullptr, it is trivially satisfied. + assert((cur == first) || (next == nullptr) || + KeyIsAfterNode(next->key, cur)); + if (KeyIsAfterNode(internal_key, cur)) { + // Keep searching in this list + prev = cur; + cur = next; + } else { + break; + } + } + + // Our data structure does not allow duplicate insertion + assert(cur == nullptr || !Equal(x->key, cur->key)); + + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. + x->NoBarrier_SetNext(cur); + + if (prev) { + prev->SetNext(x); + } else { + header->next.store(static_cast<void*>(x), std::memory_order_release); + } + } +} + +bool HashLinkListRep::Contains(const char* key) const { + Slice internal_key = GetLengthPrefixedSlice(key); + + auto transformed = GetPrefix(internal_key); + auto bucket = GetBucket(transformed); + if (bucket == nullptr) { + return false; + } + + SkipListBucketHeader* skip_list_header = GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + return skip_list_header->skip_list.Contains(key); + } else { + return LinkListContains(GetLinkListFirstNode(bucket), internal_key); + } +} + +size_t HashLinkListRep::ApproximateMemoryUsage() { + // Memory is always allocated from the allocator. + return 0; +} + +void HashLinkListRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + auto transformed = transform_->Transform(k.user_key()); + auto bucket = GetBucket(transformed); + + auto* skip_list_header = GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + // Is a skip list + MemtableSkipList::Iterator iter(&skip_list_header->skip_list); + for (iter.Seek(k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } else { + auto* link_list_head = GetLinkListFirstNode(bucket); + if (link_list_head != nullptr) { + LinkListIterator iter(this, link_list_head); + for (iter.Seek(k.internal_key(), nullptr); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } + } +} + +MemTableRep::Iterator* HashLinkListRep::GetIterator(Arena* alloc_arena) { + // allocate a new arena of similar size to the one currently in use + Arena* new_arena = new Arena(allocator_->BlockSize()); + auto list = new MemtableSkipList(compare_, new_arena); + HistogramImpl keys_per_bucket_hist; + + for (size_t i = 0; i < bucket_size_; ++i) { + int count = 0; + auto* bucket = GetBucket(i); + if (bucket != nullptr) { + auto* skip_list_header = GetSkipListBucketHeader(bucket); + if (skip_list_header != nullptr) { + // Is a skip list + MemtableSkipList::Iterator itr(&skip_list_header->skip_list); + for (itr.SeekToFirst(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + count++; + } + } else { + auto* link_list_head = GetLinkListFirstNode(bucket); + if (link_list_head != nullptr) { + LinkListIterator itr(this, link_list_head); + for (itr.SeekToHead(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + count++; + } + } + } + } + if (if_log_bucket_dist_when_flash_) { + keys_per_bucket_hist.Add(count); + } + } + if (if_log_bucket_dist_when_flash_ && logger_ != nullptr) { + Info(logger_, "hashLinkedList Entry distribution among buckets: %s", + keys_per_bucket_hist.ToString().c_str()); + } + + if (alloc_arena == nullptr) { + return new FullListIterator(list, new_arena); + } else { + auto mem = alloc_arena->AllocateAligned(sizeof(FullListIterator)); + return new (mem) FullListIterator(list, new_arena); + } +} + +MemTableRep::Iterator* HashLinkListRep::GetDynamicPrefixIterator( + Arena* alloc_arena) { + if (alloc_arena == nullptr) { + return new DynamicIterator(*this); + } else { + auto mem = alloc_arena->AllocateAligned(sizeof(DynamicIterator)); + return new (mem) DynamicIterator(*this); + } +} + +bool HashLinkListRep::LinkListContains(Node* head, + const Slice& user_key) const { + Node* x = FindGreaterOrEqualInBucket(head, user_key); + return (x != nullptr && Equal(user_key, x->key)); +} + +Node* HashLinkListRep::FindGreaterOrEqualInBucket(Node* head, + const Slice& key) const { + Node* x = head; + while (true) { + if (x == nullptr) { + return x; + } + Node* next = x->Next(); + // Make sure the lists are sorted. + // If x points to head_ or next points nullptr, it is trivially satisfied. + assert((x == head) || (next == nullptr) || KeyIsAfterNode(next->key, x)); + if (KeyIsAfterNode(key, x)) { + // Keep searching in this list + x = next; + } else { + break; + } + } + return x; +} + +} // anon namespace + +MemTableRep* HashLinkListRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) { + return new HashLinkListRep(compare, allocator, transform, bucket_count_, + threshold_use_skiplist_, huge_page_tlb_size_, + logger, bucket_entries_logging_threshold_, + if_log_bucket_dist_when_flash_); +} + +MemTableRepFactory* NewHashLinkListRepFactory( + size_t bucket_count, size_t huge_page_tlb_size, + int bucket_entries_logging_threshold, bool if_log_bucket_dist_when_flash, + uint32_t threshold_use_skiplist) { + return new HashLinkListRepFactory( + bucket_count, threshold_use_skiplist, huge_page_tlb_size, + bucket_entries_logging_threshold, if_log_bucket_dist_when_flash); +} + +} // namespace rocksdb +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/hash_linklist_rep.h b/src/rocksdb/memtable/hash_linklist_rep.h new file mode 100644 index 00000000..a6da3eed --- /dev/null +++ b/src/rocksdb/memtable/hash_linklist_rep.h @@ -0,0 +1,49 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#pragma once +#ifndef ROCKSDB_LITE +#include "rocksdb/slice_transform.h" +#include "rocksdb/memtablerep.h" + +namespace rocksdb { + +class HashLinkListRepFactory : public MemTableRepFactory { + public: + explicit HashLinkListRepFactory(size_t bucket_count, + uint32_t threshold_use_skiplist, + size_t huge_page_tlb_size, + int bucket_entries_logging_threshold, + bool if_log_bucket_dist_when_flash) + : bucket_count_(bucket_count), + threshold_use_skiplist_(threshold_use_skiplist), + huge_page_tlb_size_(huge_page_tlb_size), + bucket_entries_logging_threshold_(bucket_entries_logging_threshold), + if_log_bucket_dist_when_flash_(if_log_bucket_dist_when_flash) {} + + virtual ~HashLinkListRepFactory() {} + + using MemTableRepFactory::CreateMemTableRep; + virtual MemTableRep* CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) override; + + virtual const char* Name() const override { + return "HashLinkListRepFactory"; + } + + private: + const size_t bucket_count_; + const uint32_t threshold_use_skiplist_; + const size_t huge_page_tlb_size_; + int bucket_entries_logging_threshold_; + bool if_log_bucket_dist_when_flash_; +}; + +} +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/hash_skiplist_rep.cc b/src/rocksdb/memtable/hash_skiplist_rep.cc new file mode 100644 index 00000000..d02919cd --- /dev/null +++ b/src/rocksdb/memtable/hash_skiplist_rep.cc @@ -0,0 +1,349 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// + +#ifndef ROCKSDB_LITE +#include "memtable/hash_skiplist_rep.h" + +#include <atomic> + +#include "rocksdb/memtablerep.h" +#include "util/arena.h" +#include "rocksdb/slice.h" +#include "rocksdb/slice_transform.h" +#include "port/port.h" +#include "util/murmurhash.h" +#include "db/memtable.h" +#include "memtable/skiplist.h" + +namespace rocksdb { +namespace { + +class HashSkipListRep : public MemTableRep { + public: + HashSkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + size_t bucket_size, int32_t skiplist_height, + int32_t skiplist_branching_factor); + + void Insert(KeyHandle handle) override; + + bool Contains(const char* key) const override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~HashSkipListRep() override; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override; + + MemTableRep::Iterator* GetDynamicPrefixIterator( + Arena* arena = nullptr) override; + + private: + friend class DynamicIterator; + typedef SkipList<const char*, const MemTableRep::KeyComparator&> Bucket; + + size_t bucket_size_; + + const int32_t skiplist_height_; + const int32_t skiplist_branching_factor_; + + // Maps slices (which are transformed user keys) to buckets of keys sharing + // the same transform. + std::atomic<Bucket*>* buckets_; + + // The user-supplied transform whose domain is the user keys. + const SliceTransform* transform_; + + const MemTableRep::KeyComparator& compare_; + // immutable after construction + Allocator* const allocator_; + + inline size_t GetHash(const Slice& slice) const { + return MurmurHash(slice.data(), static_cast<int>(slice.size()), 0) % + bucket_size_; + } + inline Bucket* GetBucket(size_t i) const { + return buckets_[i].load(std::memory_order_acquire); + } + inline Bucket* GetBucket(const Slice& slice) const { + return GetBucket(GetHash(slice)); + } + // Get a bucket from buckets_. If the bucket hasn't been initialized yet, + // initialize it before returning. + Bucket* GetInitializedBucket(const Slice& transformed); + + class Iterator : public MemTableRep::Iterator { + public: + explicit Iterator(Bucket* list, bool own_list = true, + Arena* arena = nullptr) + : list_(list), iter_(list), own_list_(own_list), arena_(arena) {} + + ~Iterator() override { + // if we own the list, we should also delete it + if (own_list_) { + assert(list_ != nullptr); + delete list_; + } + } + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return list_ != nullptr && iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { + assert(Valid()); + iter_.Next(); + } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { + assert(Valid()); + iter_.Prev(); + } + + // Advance to the first entry with a key >= target + void Seek(const Slice& internal_key, const char* memtable_key) override { + if (list_ != nullptr) { + const char* encoded_key = + (memtable_key != nullptr) ? + memtable_key : EncodeKey(&tmp_, internal_key); + iter_.Seek(encoded_key); + } + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override { + // not supported + assert(false); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + if (list_ != nullptr) { + iter_.SeekToFirst(); + } + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + if (list_ != nullptr) { + iter_.SeekToLast(); + } + } + + protected: + void Reset(Bucket* list) { + if (own_list_) { + assert(list_ != nullptr); + delete list_; + } + list_ = list; + iter_.SetList(list); + own_list_ = false; + } + private: + // if list_ is nullptr, we should NEVER call any methods on iter_ + // if list_ is nullptr, this Iterator is not Valid() + Bucket* list_; + Bucket::Iterator iter_; + // here we track if we own list_. If we own it, we are also + // responsible for it's cleaning. This is a poor man's std::shared_ptr + bool own_list_; + std::unique_ptr<Arena> arena_; + std::string tmp_; // For passing to EncodeKey + }; + + class DynamicIterator : public HashSkipListRep::Iterator { + public: + explicit DynamicIterator(const HashSkipListRep& memtable_rep) + : HashSkipListRep::Iterator(nullptr, false), + memtable_rep_(memtable_rep) {} + + // Advance to the first entry with a key >= target + void Seek(const Slice& k, const char* memtable_key) override { + auto transformed = memtable_rep_.transform_->Transform(ExtractUserKey(k)); + Reset(memtable_rep_.GetBucket(transformed)); + HashSkipListRep::Iterator::Seek(k, memtable_key); + } + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override { + // Prefix iterator does not support total order. + // We simply set the iterator to invalid state + Reset(nullptr); + } + + private: + // the underlying memtable + const HashSkipListRep& memtable_rep_; + }; + + class EmptyIterator : public MemTableRep::Iterator { + // This is used when there wasn't a bucket. It is cheaper than + // instantiating an empty bucket over which to iterate. + public: + EmptyIterator() { } + bool Valid() const override { return false; } + const char* key() const override { + assert(false); + return nullptr; + } + void Next() override {} + void Prev() override {} + void Seek(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override {} + void SeekForPrev(const Slice& /*internal_key*/, + const char* /*memtable_key*/) override {} + void SeekToFirst() override {} + void SeekToLast() override {} + + private: + }; +}; + +HashSkipListRep::HashSkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, + const SliceTransform* transform, + size_t bucket_size, int32_t skiplist_height, + int32_t skiplist_branching_factor) + : MemTableRep(allocator), + bucket_size_(bucket_size), + skiplist_height_(skiplist_height), + skiplist_branching_factor_(skiplist_branching_factor), + transform_(transform), + compare_(compare), + allocator_(allocator) { + auto mem = allocator->AllocateAligned( + sizeof(std::atomic<void*>) * bucket_size); + buckets_ = new (mem) std::atomic<Bucket*>[bucket_size]; + + for (size_t i = 0; i < bucket_size_; ++i) { + buckets_[i].store(nullptr, std::memory_order_relaxed); + } +} + +HashSkipListRep::~HashSkipListRep() { +} + +HashSkipListRep::Bucket* HashSkipListRep::GetInitializedBucket( + const Slice& transformed) { + size_t hash = GetHash(transformed); + auto bucket = GetBucket(hash); + if (bucket == nullptr) { + auto addr = allocator_->AllocateAligned(sizeof(Bucket)); + bucket = new (addr) Bucket(compare_, allocator_, skiplist_height_, + skiplist_branching_factor_); + buckets_[hash].store(bucket, std::memory_order_release); + } + return bucket; +} + +void HashSkipListRep::Insert(KeyHandle handle) { + auto* key = static_cast<char*>(handle); + assert(!Contains(key)); + auto transformed = transform_->Transform(UserKey(key)); + auto bucket = GetInitializedBucket(transformed); + bucket->Insert(key); +} + +bool HashSkipListRep::Contains(const char* key) const { + auto transformed = transform_->Transform(UserKey(key)); + auto bucket = GetBucket(transformed); + if (bucket == nullptr) { + return false; + } + return bucket->Contains(key); +} + +size_t HashSkipListRep::ApproximateMemoryUsage() { + return 0; +} + +void HashSkipListRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + auto transformed = transform_->Transform(k.user_key()); + auto bucket = GetBucket(transformed); + if (bucket != nullptr) { + Bucket::Iterator iter(bucket); + for (iter.Seek(k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); + iter.Next()) { + } + } +} + +MemTableRep::Iterator* HashSkipListRep::GetIterator(Arena* arena) { + // allocate a new arena of similar size to the one currently in use + Arena* new_arena = new Arena(allocator_->BlockSize()); + auto list = new Bucket(compare_, new_arena); + for (size_t i = 0; i < bucket_size_; ++i) { + auto bucket = GetBucket(i); + if (bucket != nullptr) { + Bucket::Iterator itr(bucket); + for (itr.SeekToFirst(); itr.Valid(); itr.Next()) { + list->Insert(itr.key()); + } + } + } + if (arena == nullptr) { + return new Iterator(list, true, new_arena); + } else { + auto mem = arena->AllocateAligned(sizeof(Iterator)); + return new (mem) Iterator(list, true, new_arena); + } +} + +MemTableRep::Iterator* HashSkipListRep::GetDynamicPrefixIterator(Arena* arena) { + if (arena == nullptr) { + return new DynamicIterator(*this); + } else { + auto mem = arena->AllocateAligned(sizeof(DynamicIterator)); + return new (mem) DynamicIterator(*this); + } +} + +} // anon namespace + +MemTableRep* HashSkipListRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* /*logger*/) { + return new HashSkipListRep(compare, allocator, transform, bucket_count_, + skiplist_height_, skiplist_branching_factor_); +} + +MemTableRepFactory* NewHashSkipListRepFactory( + size_t bucket_count, int32_t skiplist_height, + int32_t skiplist_branching_factor) { + return new HashSkipListRepFactory(bucket_count, skiplist_height, + skiplist_branching_factor); +} + +} // namespace rocksdb +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/hash_skiplist_rep.h b/src/rocksdb/memtable/hash_skiplist_rep.h new file mode 100644 index 00000000..5d1e04f3 --- /dev/null +++ b/src/rocksdb/memtable/hash_skiplist_rep.h @@ -0,0 +1,44 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#pragma once +#ifndef ROCKSDB_LITE +#include "rocksdb/slice_transform.h" +#include "rocksdb/memtablerep.h" + +namespace rocksdb { + +class HashSkipListRepFactory : public MemTableRepFactory { + public: + explicit HashSkipListRepFactory( + size_t bucket_count, + int32_t skiplist_height, + int32_t skiplist_branching_factor) + : bucket_count_(bucket_count), + skiplist_height_(skiplist_height), + skiplist_branching_factor_(skiplist_branching_factor) { } + + virtual ~HashSkipListRepFactory() {} + + using MemTableRepFactory::CreateMemTableRep; + virtual MemTableRep* CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* logger) override; + + virtual const char* Name() const override { + return "HashSkipListRepFactory"; + } + + private: + const size_t bucket_count_; + const int32_t skiplist_height_; + const int32_t skiplist_branching_factor_; +}; + +} +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/inlineskiplist.h b/src/rocksdb/memtable/inlineskiplist.h new file mode 100644 index 00000000..1ef8f2b6 --- /dev/null +++ b/src/rocksdb/memtable/inlineskiplist.h @@ -0,0 +1,965 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found +// in the LICENSE file. See the AUTHORS file for names of contributors. +// +// InlineSkipList is derived from SkipList (skiplist.h), but it optimizes +// the memory layout by requiring that the key storage be allocated through +// the skip list instance. For the common case of SkipList<const char*, +// Cmp> this saves 1 pointer per skip list node and gives better cache +// locality, at the expense of wasted padding from using AllocateAligned +// instead of Allocate for the keys. The unused padding will be from +// 0 to sizeof(void*)-1 bytes, and the space savings are sizeof(void*) +// bytes, so despite the padding the space used is always less than +// SkipList<const char*, ..>. +// +// Thread safety ------------- +// +// Writes via Insert require external synchronization, most likely a mutex. +// InsertConcurrently can be safely called concurrently with reads and +// with other concurrent inserts. Reads require a guarantee that the +// InlineSkipList will not be destroyed while the read is in progress. +// Apart from that, reads progress without any internal locking or +// synchronization. +// +// Invariants: +// +// (1) Allocated nodes are never deleted until the InlineSkipList is +// destroyed. This is trivially guaranteed by the code since we never +// delete any skip list nodes. +// +// (2) The contents of a Node except for the next/prev pointers are +// immutable after the Node has been linked into the InlineSkipList. +// Only Insert() modifies the list, and it is careful to initialize a +// node and use release-stores to publish the nodes in one or more lists. +// +// ... prev vs. next pointer ordering ... +// + +#pragma once +#include <assert.h> +#include <stdlib.h> +#include <algorithm> +#include <atomic> +#include <type_traits> +#include "port/likely.h" +#include "port/port.h" +#include "rocksdb/slice.h" +#include "util/allocator.h" +#include "util/coding.h" +#include "util/random.h" + +namespace rocksdb { + +template <class Comparator> +class InlineSkipList { + private: + struct Node; + struct Splice; + + public: + using DecodedKey = \ + typename std::remove_reference<Comparator>::type::DecodedType; + + static const uint16_t kMaxPossibleHeight = 32; + + // Create a new InlineSkipList object that will use "cmp" for comparing + // keys, and will allocate memory using "*allocator". Objects allocated + // in the allocator must remain allocated for the lifetime of the + // skiplist object. + explicit InlineSkipList(Comparator cmp, Allocator* allocator, + int32_t max_height = 12, + int32_t branching_factor = 4); + + // Allocates a key and a skip-list node, returning a pointer to the key + // portion of the node. This method is thread-safe if the allocator + // is thread-safe. + char* AllocateKey(size_t key_size); + + // Allocate a splice using allocator. + Splice* AllocateSplice(); + + // Inserts a key allocated by AllocateKey, after the actual key value + // has been filled in. + // + // REQUIRES: nothing that compares equal to key is currently in the list. + // REQUIRES: no concurrent calls to any of inserts. + bool Insert(const char* key); + + // Inserts a key allocated by AllocateKey with a hint of last insert + // position in the skip-list. If hint points to nullptr, a new hint will be + // populated, which can be used in subsequent calls. + // + // It can be used to optimize the workload where there are multiple groups + // of keys, and each key is likely to insert to a location close to the last + // inserted key in the same group. One example is sequential inserts. + // + // REQUIRES: nothing that compares equal to key is currently in the list. + // REQUIRES: no concurrent calls to any of inserts. + bool InsertWithHint(const char* key, void** hint); + + // Like Insert, but external synchronization is not required. + bool InsertConcurrently(const char* key); + + // Inserts a node into the skip list. key must have been allocated by + // AllocateKey and then filled in by the caller. If UseCAS is true, + // then external synchronization is not required, otherwise this method + // may not be called concurrently with any other insertions. + // + // Regardless of whether UseCAS is true, the splice must be owned + // exclusively by the current thread. If allow_partial_splice_fix is + // true, then the cost of insertion is amortized O(log D), where D is + // the distance from the splice to the inserted key (measured as the + // number of intervening nodes). Note that this bound is very good for + // sequential insertions! If allow_partial_splice_fix is false then + // the existing splice will be ignored unless the current key is being + // inserted immediately after the splice. allow_partial_splice_fix == + // false has worse running time for the non-sequential case O(log N), + // but a better constant factor. + template <bool UseCAS> + bool Insert(const char* key, Splice* splice, bool allow_partial_splice_fix); + + // Returns true iff an entry that compares equal to key is in the list. + bool Contains(const char* key) const; + + // Return estimated number of entries smaller than `key`. + uint64_t EstimateCount(const char* key) const; + + // Validate correctness of the skip-list. + void TEST_Validate() const; + + // Iteration over the contents of a skip list + class Iterator { + public: + // Initialize an iterator over the specified list. + // The returned iterator is not valid. + explicit Iterator(const InlineSkipList* list); + + // Change the underlying skiplist used for this iterator + // This enables us not changing the iterator without deallocating + // an old one and then allocating a new one + void SetList(const InlineSkipList* list); + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const; + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const; + + // Advances to the next position. + // REQUIRES: Valid() + void Next(); + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev(); + + // Advance to the first entry with a key >= target + void Seek(const char* target); + + // Retreat to the last entry with a key <= target + void SeekForPrev(const char* target); + + // Position at the first entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst(); + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast(); + + private: + const InlineSkipList* list_; + Node* node_; + // Intentionally copyable + }; + + private: + const uint16_t kMaxHeight_; + const uint16_t kBranching_; + const uint32_t kScaledInverseBranching_; + + Allocator* const allocator_; // Allocator used for allocations of nodes + // Immutable after construction + Comparator const compare_; + Node* const head_; + + // Modified only by Insert(). Read racily by readers, but stale + // values are ok. + std::atomic<int> max_height_; // Height of the entire list + + // seq_splice_ is a Splice used for insertions in the non-concurrent + // case. It caches the prev and next found during the most recent + // non-concurrent insertion. + Splice* seq_splice_; + + inline int GetMaxHeight() const { + return max_height_.load(std::memory_order_relaxed); + } + + int RandomHeight(); + + Node* AllocateNode(size_t key_size, int height); + + bool Equal(const char* a, const char* b) const { + return (compare_(a, b) == 0); + } + + bool LessThan(const char* a, const char* b) const { + return (compare_(a, b) < 0); + } + + // Return true if key is greater than the data stored in "n". Null n + // is considered infinite. n should not be head_. + bool KeyIsAfterNode(const char* key, Node* n) const; + bool KeyIsAfterNode(const DecodedKey& key, Node* n) const; + + // Returns the earliest node with a key >= key. + // Return nullptr if there is no such node. + Node* FindGreaterOrEqual(const char* key) const; + + // Return the latest node with a key < key. + // Return head_ if there is no such node. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [0..max_height_-1], if prev is non-null. + Node* FindLessThan(const char* key, Node** prev = nullptr) const; + + // Return the latest node with a key < key on bottom_level. Start searching + // from root node on the level below top_level. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [bottom_level..top_level-1], if prev is non-null. + Node* FindLessThan(const char* key, Node** prev, Node* root, int top_level, + int bottom_level) const; + + // Return the last node in the list. + // Return head_ if list is empty. + Node* FindLast() const; + + // Traverses a single level of the list, setting *out_prev to the last + // node before the key and *out_next to the first node after. Assumes + // that the key is not present in the skip list. On entry, before should + // point to a node that is before the key, and after should point to + // a node that is after the key. after should be nullptr if a good after + // node isn't conveniently available. + template<bool prefetch_before> + void FindSpliceForLevel(const DecodedKey& key, Node* before, Node* after, int level, + Node** out_prev, Node** out_next); + + // Recomputes Splice levels from highest_level (inclusive) down to + // lowest_level (inclusive). + void RecomputeSpliceLevels(const DecodedKey& key, Splice* splice, + int recompute_level); + + // No copying allowed + InlineSkipList(const InlineSkipList&); + InlineSkipList& operator=(const InlineSkipList&); +}; + +// Implementation details follow + +template <class Comparator> +struct InlineSkipList<Comparator>::Splice { + // The invariant of a Splice is that prev_[i+1].key <= prev_[i].key < + // next_[i].key <= next_[i+1].key for all i. That means that if a + // key is bracketed by prev_[i] and next_[i] then it is bracketed by + // all higher levels. It is _not_ required that prev_[i]->Next(i) == + // next_[i] (it probably did at some point in the past, but intervening + // or concurrent operations might have inserted nodes in between). + int height_ = 0; + Node** prev_; + Node** next_; +}; + +// The Node data type is more of a pointer into custom-managed memory than +// a traditional C++ struct. The key is stored in the bytes immediately +// after the struct, and the next_ pointers for nodes with height > 1 are +// stored immediately _before_ the struct. This avoids the need to include +// any pointer or sizing data, which reduces per-node memory overheads. +template <class Comparator> +struct InlineSkipList<Comparator>::Node { + // Stores the height of the node in the memory location normally used for + // next_[0]. This is used for passing data from AllocateKey to Insert. + void StashHeight(const int height) { + assert(sizeof(int) <= sizeof(next_[0])); + memcpy(static_cast<void*>(&next_[0]), &height, sizeof(int)); + } + + // Retrieves the value passed to StashHeight. Undefined after a call + // to SetNext or NoBarrier_SetNext. + int UnstashHeight() const { + int rv; + memcpy(&rv, &next_[0], sizeof(int)); + return rv; + } + + const char* Key() const { return reinterpret_cast<const char*>(&next_[1]); } + + // Accessors/mutators for links. Wrapped in methods so we can add + // the appropriate barriers as necessary, and perform the necessary + // addressing trickery for storing links below the Node in memory. + Node* Next(int n) { + assert(n >= 0); + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return ((&next_[0] - n)->load(std::memory_order_acquire)); + } + + void SetNext(int n, Node* x) { + assert(n >= 0); + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + (&next_[0] - n)->store(x, std::memory_order_release); + } + + bool CASNext(int n, Node* expected, Node* x) { + assert(n >= 0); + return (&next_[0] - n)->compare_exchange_strong(expected, x); + } + + // No-barrier variants that can be safely used in a few locations. + Node* NoBarrier_Next(int n) { + assert(n >= 0); + return (&next_[0] - n)->load(std::memory_order_relaxed); + } + + void NoBarrier_SetNext(int n, Node* x) { + assert(n >= 0); + (&next_[0] - n)->store(x, std::memory_order_relaxed); + } + + // Insert node after prev on specific level. + void InsertAfter(Node* prev, int level) { + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "this" in prev. + NoBarrier_SetNext(level, prev->NoBarrier_Next(level)); + prev->SetNext(level, this); + } + + private: + // next_[0] is the lowest level link (level 0). Higher levels are + // stored _earlier_, so level 1 is at next_[-1]. + std::atomic<Node*> next_[1]; +}; + +template <class Comparator> +inline InlineSkipList<Comparator>::Iterator::Iterator( + const InlineSkipList* list) { + SetList(list); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SetList( + const InlineSkipList* list) { + list_ = list; + node_ = nullptr; +} + +template <class Comparator> +inline bool InlineSkipList<Comparator>::Iterator::Valid() const { + return node_ != nullptr; +} + +template <class Comparator> +inline const char* InlineSkipList<Comparator>::Iterator::key() const { + assert(Valid()); + return node_->Key(); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Next() { + assert(Valid()); + node_ = node_->Next(0); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Prev() { + // Instead of using explicit "prev" links, we just search for the + // last node that falls before key. + assert(Valid()); + node_ = list_->FindLessThan(node_->Key()); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::Seek(const char* target) { + node_ = list_->FindGreaterOrEqual(target); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekForPrev( + const char* target) { + Seek(target); + if (!Valid()) { + SeekToLast(); + } + while (Valid() && list_->LessThan(target, key())) { + Prev(); + } +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekToFirst() { + node_ = list_->head_->Next(0); +} + +template <class Comparator> +inline void InlineSkipList<Comparator>::Iterator::SeekToLast() { + node_ = list_->FindLast(); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template <class Comparator> +int InlineSkipList<Comparator>::RandomHeight() { + auto rnd = Random::GetTLSInstance(); + + // Increase height with probability 1 in kBranching + int height = 1; + while (height < kMaxHeight_ && height < kMaxPossibleHeight && + rnd->Next() < kScaledInverseBranching_) { + height++; + } + assert(height > 0); + assert(height <= kMaxHeight_); + assert(height <= kMaxPossibleHeight); + return height; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::KeyIsAfterNode(const char* key, + Node* n) const { + // nullptr n is considered infinite + assert(n != head_); + return (n != nullptr) && (compare_(n->Key(), key) < 0); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::KeyIsAfterNode(const DecodedKey& key, + Node* n) const { + // nullptr n is considered infinite + assert(n != head_); + return (n != nullptr) && (compare_(n->Key(), key) < 0); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindGreaterOrEqual(const char* key) const { + // Note: It looks like we could reduce duplication by implementing + // this function as FindLessThan(key)->Next(0), but we wouldn't be able + // to exit early on equality and the result wouldn't even be correct. + // A concurrent insert might occur after FindLessThan(key) but before + // we get a chance to call Next(0). + Node* x = head_; + int level = GetMaxHeight() - 1; + Node* last_bigger = nullptr; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + // Make sure the lists are sorted + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x)); + // Make sure we haven't overshot during our search + assert(x == head_ || KeyIsAfterNode(key_decoded, x)); + int cmp = (next == nullptr || next == last_bigger) + ? 1 + : compare_(next->Key(), key_decoded); + if (cmp == 0 || (cmp > 0 && level == 0)) { + return next; + } else if (cmp < 0) { + // Keep searching in this list + x = next; + } else { + // Switch to next list, reuse compare_() result + last_bigger = next; + level--; + } + } +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev) const { + return FindLessThan(key, prev, head_, GetMaxHeight(), 0); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLessThan(const char* key, Node** prev, + Node* root, int top_level, + int bottom_level) const { + assert(top_level > bottom_level); + int level = top_level - 1; + Node* x = root; + // KeyIsAfter(key, last_not_after) is definitely false + Node* last_not_after = nullptr; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->Key(), x)); + assert(x == head_ || KeyIsAfterNode(key_decoded, x)); + if (next != last_not_after && KeyIsAfterNode(key_decoded, next)) { + // Keep searching in this list + assert(next != nullptr); + x = next; + } else { + if (prev != nullptr) { + prev[level] = x; + } + if (level == bottom_level) { + return x; + } else { + // Switch to next list, reuse KeyIsAfterNode() result + last_not_after = next; + level--; + } + } + } +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::FindLast() const { + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + Node* next = x->Next(level); + if (next == nullptr) { + if (level == 0) { + return x; + } else { + // Switch to next list + level--; + } + } else { + x = next; + } + } +} + +template <class Comparator> +uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const { + uint64_t count = 0; + + Node* x = head_; + int level = GetMaxHeight() - 1; + const DecodedKey key_decoded = compare_.decode_key(key); + while (true) { + assert(x == head_ || compare_(x->Key(), key_decoded) < 0); + Node* next = x->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + if (next == nullptr || compare_(next->Key(), key_decoded) >= 0) { + if (level == 0) { + return count; + } else { + // Switch to next list + count *= kBranching_; + level--; + } + } else { + x = next; + count++; + } + } +} + +template <class Comparator> +InlineSkipList<Comparator>::InlineSkipList(const Comparator cmp, + Allocator* allocator, + int32_t max_height, + int32_t branching_factor) + : kMaxHeight_(static_cast<uint16_t>(max_height)), + kBranching_(static_cast<uint16_t>(branching_factor)), + kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_), + allocator_(allocator), + compare_(cmp), + head_(AllocateNode(0, max_height)), + max_height_(1), + seq_splice_(AllocateSplice()) { + assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height)); + assert(branching_factor > 1 && + kBranching_ == static_cast<uint32_t>(branching_factor)); + assert(kScaledInverseBranching_ > 0); + + for (int i = 0; i < kMaxHeight_; ++i) { + head_->SetNext(i, nullptr); + } +} + +template <class Comparator> +char* InlineSkipList<Comparator>::AllocateKey(size_t key_size) { + return const_cast<char*>(AllocateNode(key_size, RandomHeight())->Key()); +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Node* +InlineSkipList<Comparator>::AllocateNode(size_t key_size, int height) { + auto prefix = sizeof(std::atomic<Node*>) * (height - 1); + + // prefix is space for the height - 1 pointers that we store before + // the Node instance (next_[-(height - 1) .. -1]). Node starts at + // raw + prefix, and holds the bottom-mode (level 0) skip list pointer + // next_[0]. key_size is the bytes for the key, which comes just after + // the Node. + char* raw = allocator_->AllocateAligned(prefix + sizeof(Node) + key_size); + Node* x = reinterpret_cast<Node*>(raw + prefix); + + // Once we've linked the node into the skip list we don't actually need + // to know its height, because we can implicitly use the fact that we + // traversed into a node at level h to known that h is a valid level + // for that node. We need to convey the height to the Insert step, + // however, so that it can perform the proper links. Since we're not + // using the pointers at the moment, StashHeight temporarily borrow + // storage from next_[0] for that purpose. + x->StashHeight(height); + return x; +} + +template <class Comparator> +typename InlineSkipList<Comparator>::Splice* +InlineSkipList<Comparator>::AllocateSplice() { + // size of prev_ and next_ + size_t array_size = sizeof(Node*) * (kMaxHeight_ + 1); + char* raw = allocator_->AllocateAligned(sizeof(Splice) + array_size * 2); + Splice* splice = reinterpret_cast<Splice*>(raw); + splice->height_ = 0; + splice->prev_ = reinterpret_cast<Node**>(raw + sizeof(Splice)); + splice->next_ = reinterpret_cast<Node**>(raw + sizeof(Splice) + array_size); + return splice; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::Insert(const char* key) { + return Insert<false>(key, seq_splice_, false); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::InsertConcurrently(const char* key) { + Node* prev[kMaxPossibleHeight]; + Node* next[kMaxPossibleHeight]; + Splice splice; + splice.prev_ = prev; + splice.next_ = next; + return Insert<true>(key, &splice, false); +} + +template <class Comparator> +bool InlineSkipList<Comparator>::InsertWithHint(const char* key, void** hint) { + assert(hint != nullptr); + Splice* splice = reinterpret_cast<Splice*>(*hint); + if (splice == nullptr) { + splice = AllocateSplice(); + *hint = reinterpret_cast<void*>(splice); + } + return Insert<false>(key, splice, true); +} + +template <class Comparator> +template <bool prefetch_before> +void InlineSkipList<Comparator>::FindSpliceForLevel(const DecodedKey& key, + Node* before, Node* after, + int level, Node** out_prev, + Node** out_next) { + while (true) { + Node* next = before->Next(level); + if (next != nullptr) { + PREFETCH(next->Next(level), 0, 1); + } + if (prefetch_before == true) { + if (next != nullptr && level>0) { + PREFETCH(next->Next(level-1), 0, 1); + } + } + assert(before == head_ || next == nullptr || + KeyIsAfterNode(next->Key(), before)); + assert(before == head_ || KeyIsAfterNode(key, before)); + if (next == after || !KeyIsAfterNode(key, next)) { + // found it + *out_prev = before; + *out_next = next; + return; + } + before = next; + } +} + +template <class Comparator> +void InlineSkipList<Comparator>::RecomputeSpliceLevels(const DecodedKey& key, + Splice* splice, + int recompute_level) { + assert(recompute_level > 0); + assert(recompute_level <= splice->height_); + for (int i = recompute_level - 1; i >= 0; --i) { + FindSpliceForLevel<true>(key, splice->prev_[i + 1], splice->next_[i + 1], i, + &splice->prev_[i], &splice->next_[i]); + } +} + +template <class Comparator> +template <bool UseCAS> +bool InlineSkipList<Comparator>::Insert(const char* key, Splice* splice, + bool allow_partial_splice_fix) { + Node* x = reinterpret_cast<Node*>(const_cast<char*>(key)) - 1; + const DecodedKey key_decoded = compare_.decode_key(key); + int height = x->UnstashHeight(); + assert(height >= 1 && height <= kMaxHeight_); + + int max_height = max_height_.load(std::memory_order_relaxed); + while (height > max_height) { + if (max_height_.compare_exchange_weak(max_height, height)) { + // successfully updated it + max_height = height; + break; + } + // else retry, possibly exiting the loop because somebody else + // increased it + } + assert(max_height <= kMaxPossibleHeight); + + int recompute_height = 0; + if (splice->height_ < max_height) { + // Either splice has never been used or max_height has grown since + // last use. We could potentially fix it in the latter case, but + // that is tricky. + splice->prev_[max_height] = head_; + splice->next_[max_height] = nullptr; + splice->height_ = max_height; + recompute_height = max_height; + } else { + // Splice is a valid proper-height splice that brackets some + // key, but does it bracket this one? We need to validate it and + // recompute a portion of the splice (levels 0..recompute_height-1) + // that is a superset of all levels that don't bracket the new key. + // Several choices are reasonable, because we have to balance the work + // saved against the extra comparisons required to validate the Splice. + // + // One strategy is just to recompute all of orig_splice_height if the + // bottom level isn't bracketing. This pessimistically assumes that + // we will either get a perfect Splice hit (increasing sequential + // inserts) or have no locality. + // + // Another strategy is to walk up the Splice's levels until we find + // a level that brackets the key. This strategy lets the Splice + // hint help for other cases: it turns insertion from O(log N) into + // O(log D), where D is the number of nodes in between the key that + // produced the Splice and the current insert (insertion is aided + // whether the new key is before or after the splice). If you have + // a way of using a prefix of the key to map directly to the closest + // Splice out of O(sqrt(N)) Splices and we make it so that splices + // can also be used as hints during read, then we end up with Oshman's + // and Shavit's SkipTrie, which has O(log log N) lookup and insertion + // (compare to O(log N) for skip list). + // + // We control the pessimistic strategy with allow_partial_splice_fix. + // A good strategy is probably to be pessimistic for seq_splice_, + // optimistic if the caller actually went to the work of providing + // a Splice. + while (recompute_height < max_height) { + if (splice->prev_[recompute_height]->Next(recompute_height) != + splice->next_[recompute_height]) { + // splice isn't tight at this level, there must have been some inserts + // to this + // location that didn't update the splice. We might only be a little + // stale, but if + // the splice is very stale it would be O(N) to fix it. We haven't used + // up any of + // our budget of comparisons, so always move up even if we are + // pessimistic about + // our chances of success. + ++recompute_height; + } else if (splice->prev_[recompute_height] != head_ && + !KeyIsAfterNode(key_decoded, + splice->prev_[recompute_height])) { + // key is from before splice + if (allow_partial_splice_fix) { + // skip all levels with the same node without more comparisons + Node* bad = splice->prev_[recompute_height]; + while (splice->prev_[recompute_height] == bad) { + ++recompute_height; + } + } else { + // we're pessimistic, recompute everything + recompute_height = max_height; + } + } else if (KeyIsAfterNode(key_decoded, + splice->next_[recompute_height])) { + // key is from after splice + if (allow_partial_splice_fix) { + Node* bad = splice->next_[recompute_height]; + while (splice->next_[recompute_height] == bad) { + ++recompute_height; + } + } else { + recompute_height = max_height; + } + } else { + // this level brackets the key, we won! + break; + } + } + } + assert(recompute_height <= max_height); + if (recompute_height > 0) { + RecomputeSpliceLevels(key_decoded, splice, recompute_height); + } + + bool splice_is_valid = true; + if (UseCAS) { + for (int i = 0; i < height; ++i) { + while (true) { + // Checking for duplicate keys on the level 0 is sufficient + if (UNLIKELY(i == 0 && splice->next_[i] != nullptr && + compare_(x->Key(), splice->next_[i]->Key()) >= 0)) { + // duplicate key + return false; + } + if (UNLIKELY(i == 0 && splice->prev_[i] != head_ && + compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) { + // duplicate key + return false; + } + assert(splice->next_[i] == nullptr || + compare_(x->Key(), splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), x->Key()) < 0); + x->NoBarrier_SetNext(i, splice->next_[i]); + if (splice->prev_[i]->CASNext(i, splice->next_[i], x)) { + // success + break; + } + // CAS failed, we need to recompute prev and next. It is unlikely + // to be helpful to try to use a different level as we redo the + // search, because it should be unlikely that lots of nodes have + // been inserted between prev[i] and next[i]. No point in using + // next[i] as the after hint, because we know it is stale. + FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i, + &splice->prev_[i], &splice->next_[i]); + + // Since we've narrowed the bracket for level i, we might have + // violated the Splice constraint between i and i-1. Make sure + // we recompute the whole thing next time. + if (i > 0) { + splice_is_valid = false; + } + } + } + } else { + for (int i = 0; i < height; ++i) { + if (i >= recompute_height && + splice->prev_[i]->Next(i) != splice->next_[i]) { + FindSpliceForLevel<false>(key_decoded, splice->prev_[i], nullptr, i, + &splice->prev_[i], &splice->next_[i]); + } + // Checking for duplicate keys on the level 0 is sufficient + if (UNLIKELY(i == 0 && splice->next_[i] != nullptr && + compare_(x->Key(), splice->next_[i]->Key()) >= 0)) { + // duplicate key + return false; + } + if (UNLIKELY(i == 0 && splice->prev_[i] != head_ && + compare_(splice->prev_[i]->Key(), x->Key()) >= 0)) { + // duplicate key + return false; + } + assert(splice->next_[i] == nullptr || + compare_(x->Key(), splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), x->Key()) < 0); + assert(splice->prev_[i]->Next(i) == splice->next_[i]); + x->NoBarrier_SetNext(i, splice->next_[i]); + splice->prev_[i]->SetNext(i, x); + } + } + if (splice_is_valid) { + for (int i = 0; i < height; ++i) { + splice->prev_[i] = x; + } + assert(splice->prev_[splice->height_] == head_); + assert(splice->next_[splice->height_] == nullptr); + for (int i = 0; i < splice->height_; ++i) { + assert(splice->next_[i] == nullptr || + compare_(key, splice->next_[i]->Key()) < 0); + assert(splice->prev_[i] == head_ || + compare_(splice->prev_[i]->Key(), key) <= 0); + assert(splice->prev_[i + 1] == splice->prev_[i] || + splice->prev_[i + 1] == head_ || + compare_(splice->prev_[i + 1]->Key(), splice->prev_[i]->Key()) < + 0); + assert(splice->next_[i + 1] == splice->next_[i] || + splice->next_[i + 1] == nullptr || + compare_(splice->next_[i]->Key(), splice->next_[i + 1]->Key()) < + 0); + } + } else { + splice->height_ = 0; + } + return true; +} + +template <class Comparator> +bool InlineSkipList<Comparator>::Contains(const char* key) const { + Node* x = FindGreaterOrEqual(key); + if (x != nullptr && Equal(key, x->Key())) { + return true; + } else { + return false; + } +} + +template <class Comparator> +void InlineSkipList<Comparator>::TEST_Validate() const { + // Interate over all levels at the same time, and verify nodes appear in + // the right order, and nodes appear in upper level also appear in lower + // levels. + Node* nodes[kMaxPossibleHeight]; + int max_height = GetMaxHeight(); + assert(max_height > 0); + for (int i = 0; i < max_height; i++) { + nodes[i] = head_; + } + while (nodes[0] != nullptr) { + Node* l0_next = nodes[0]->Next(0); + if (l0_next == nullptr) { + break; + } + assert(nodes[0] == head_ || compare_(nodes[0]->Key(), l0_next->Key()) < 0); + nodes[0] = l0_next; + + int i = 1; + while (i < max_height) { + Node* next = nodes[i]->Next(i); + if (next == nullptr) { + break; + } + auto cmp = compare_(nodes[0]->Key(), next->Key()); + assert(cmp <= 0); + if (cmp == 0) { + assert(next == nodes[0]); + nodes[i] = next; + } else { + break; + } + i++; + } + } + for (int i = 1; i < max_height; i++) { + assert(nodes[i] != nullptr && nodes[i]->Next(i) == nullptr); + } +} + +} // namespace rocksdb diff --git a/src/rocksdb/memtable/inlineskiplist_test.cc b/src/rocksdb/memtable/inlineskiplist_test.cc new file mode 100644 index 00000000..b416ef7c --- /dev/null +++ b/src/rocksdb/memtable/inlineskiplist_test.cc @@ -0,0 +1,645 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "memtable/inlineskiplist.h" +#include <set> +#include <unordered_set> +#include "rocksdb/env.h" +#include "util/concurrent_arena.h" +#include "util/hash.h" +#include "util/random.h" +#include "util/testharness.h" + +namespace rocksdb { + +// Our test skip list stores 8-byte unsigned integers +typedef uint64_t Key; + +static const char* Encode(const uint64_t* key) { + return reinterpret_cast<const char*>(key); +} + +static Key Decode(const char* key) { + Key rv; + memcpy(&rv, key, sizeof(Key)); + return rv; +} + +struct TestComparator { + typedef Key DecodedType; + + static DecodedType decode_key(const char* b) { + return Decode(b); + } + + int operator()(const char* a, const char* b) const { + if (Decode(a) < Decode(b)) { + return -1; + } else if (Decode(a) > Decode(b)) { + return +1; + } else { + return 0; + } + } + + int operator()(const char* a, const DecodedType b) const { + if (Decode(a) < b) { + return -1; + } else if (Decode(a) > b) { + return +1; + } else { + return 0; + } + } +}; + +typedef InlineSkipList<TestComparator> TestInlineSkipList; + +class InlineSkipTest : public testing::Test { + public: + void Insert(TestInlineSkipList* list, Key key) { + char* buf = list->AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + list->Insert(buf); + keys_.insert(key); + } + + bool InsertWithHint(TestInlineSkipList* list, Key key, void** hint) { + char* buf = list->AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + bool res = list->InsertWithHint(buf, hint); + keys_.insert(key); + return res; + } + + void Validate(TestInlineSkipList* list) { + // Check keys exist. + for (Key key : keys_) { + ASSERT_TRUE(list->Contains(Encode(&key))); + } + // Iterate over the list, make sure keys appears in order and no extra + // keys exist. + TestInlineSkipList::Iterator iter(list); + ASSERT_FALSE(iter.Valid()); + Key zero = 0; + iter.Seek(Encode(&zero)); + for (Key key : keys_) { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(key, Decode(iter.key())); + iter.Next(); + } + ASSERT_FALSE(iter.Valid()); + // Validate the list is well-formed. + list->TEST_Validate(); + } + + private: + std::set<Key> keys_; +}; + +TEST_F(InlineSkipTest, Empty) { + Arena arena; + TestComparator cmp; + InlineSkipList<TestComparator> list(cmp, &arena); + Key key = 10; + ASSERT_TRUE(!list.Contains(Encode(&key))); + + InlineSkipList<TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToFirst(); + ASSERT_TRUE(!iter.Valid()); + key = 100; + iter.Seek(Encode(&key)); + ASSERT_TRUE(!iter.Valid()); + iter.SeekForPrev(Encode(&key)); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToLast(); + ASSERT_TRUE(!iter.Valid()); +} + +TEST_F(InlineSkipTest, InsertAndLookup) { + const int N = 2000; + const int R = 5000; + Random rnd(1000); + std::set<Key> keys; + ConcurrentArena arena; + TestComparator cmp; + InlineSkipList<TestComparator> list(cmp, &arena); + for (int i = 0; i < N; i++) { + Key key = rnd.Next() % R; + if (keys.insert(key).second) { + char* buf = list.AllocateKey(sizeof(Key)); + memcpy(buf, &key, sizeof(Key)); + list.Insert(buf); + } + } + + for (Key i = 0; i < R; i++) { + if (list.Contains(Encode(&i))) { + ASSERT_EQ(keys.count(i), 1U); + } else { + ASSERT_EQ(keys.count(i), 0U); + } + } + + // Simple iterator tests + { + InlineSkipList<TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + + uint64_t zero = 0; + iter.Seek(Encode(&zero)); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), Decode(iter.key())); + + uint64_t max_key = R - 1; + iter.SeekForPrev(Encode(&max_key)); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), Decode(iter.key())); + + iter.SeekToFirst(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), Decode(iter.key())); + + iter.SeekToLast(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), Decode(iter.key())); + } + + // Forward iteration test + for (Key i = 0; i < R; i++) { + InlineSkipList<TestComparator>::Iterator iter(&list); + iter.Seek(Encode(&i)); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.lower_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.end()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*model_iter, Decode(iter.key())); + ++model_iter; + iter.Next(); + } + } + } + + // Backward iteration test + for (Key i = 0; i < R; i++) { + InlineSkipList<TestComparator>::Iterator iter(&list); + iter.SeekForPrev(Encode(&i)); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.upper_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.begin()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*--model_iter, Decode(iter.key())); + iter.Prev(); + } + } + } +} + +TEST_F(InlineSkipTest, InsertWithHint_Sequential) { + const int N = 100000; + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hint = nullptr; + for (int i = 0; i < N; i++) { + Key key = i; + InsertWithHint(&list, key, &hint); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_MultipleHints) { + const int N = 100000; + const int S = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hints[S]; + Key last_key[S]; + for (int i = 0; i < S; i++) { + hints[i] = nullptr; + last_key[i] = 0; + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S); + Key key = (s << 32) + (++last_key[s]); + InsertWithHint(&list, key, &hints[s]); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_MultipleHintsRandom) { + const int N = 100000; + const int S = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + void* hints[S]; + for (int i = 0; i < S; i++) { + hints[i] = nullptr; + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S); + Key key = (s << 32) + rnd.Next(); + InsertWithHint(&list, key, &hints[s]); + } + Validate(&list); +} + +TEST_F(InlineSkipTest, InsertWithHint_CompatibleWithInsertWithoutHint) { + const int N = 100000; + const int S1 = 100; + const int S2 = 100; + Random rnd(534); + Arena arena; + TestComparator cmp; + TestInlineSkipList list(cmp, &arena); + std::unordered_set<Key> used; + Key with_hint[S1]; + Key without_hint[S2]; + void* hints[S1]; + for (int i = 0; i < S1; i++) { + hints[i] = nullptr; + while (true) { + Key s = rnd.Next(); + if (used.insert(s).second) { + with_hint[i] = s; + break; + } + } + } + for (int i = 0; i < S2; i++) { + while (true) { + Key s = rnd.Next(); + if (used.insert(s).second) { + without_hint[i] = s; + break; + } + } + } + for (int i = 0; i < N; i++) { + Key s = rnd.Uniform(S1 + S2); + if (s < S1) { + Key key = (with_hint[s] << 32) + rnd.Next(); + InsertWithHint(&list, key, &hints[s]); + } else { + Key key = (without_hint[s - S1] << 32) + rnd.Next(); + Insert(&list, key); + } + } + Validate(&list); +} + +#ifndef ROCKSDB_VALGRIND_RUN +// We want to make sure that with a single writer and multiple +// concurrent readers (with no synchronization other than when a +// reader's iterator is created), the reader always observes all the +// data that was present in the skip list when the iterator was +// constructor. Because insertions are happening concurrently, we may +// also observe new values that were inserted since the iterator was +// constructed, but we should never miss any values that were present +// at iterator construction time. +// +// We generate multi-part keys: +// <key,gen,hash> +// where: +// key is in range [0..K-1] +// gen is a generation number for key +// hash is hash(key,gen) +// +// The insertion code picks a random key, sets gen to be 1 + the last +// generation number inserted for that key, and sets hash to Hash(key,gen). +// +// At the beginning of a read, we snapshot the last inserted +// generation number for each key. We then iterate, including random +// calls to Next() and Seek(). For every key we encounter, we +// check that it is either expected given the initial snapshot or has +// been concurrently added since the iterator started. +class ConcurrentTest { + public: + static const uint32_t K = 8; + + private: + static uint64_t key(Key key) { return (key >> 40); } + static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; } + static uint64_t hash(Key key) { return key & 0xff; } + + static uint64_t HashNumbers(uint64_t k, uint64_t g) { + uint64_t data[2] = {k, g}; + return Hash(reinterpret_cast<char*>(data), sizeof(data), 0); + } + + static Key MakeKey(uint64_t k, uint64_t g) { + assert(sizeof(Key) == sizeof(uint64_t)); + assert(k <= K); // We sometimes pass K to seek to the end of the skiplist + assert(g <= 0xffffffffu); + return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff)); + } + + static bool IsValidKey(Key k) { + return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff); + } + + static Key RandomTarget(Random* rnd) { + switch (rnd->Next() % 10) { + case 0: + // Seek to beginning + return MakeKey(0, 0); + case 1: + // Seek to end + return MakeKey(K, 0); + default: + // Seek to middle + return MakeKey(rnd->Next() % K, 0); + } + } + + // Per-key generation + struct State { + std::atomic<int> generation[K]; + void Set(int k, int v) { + generation[k].store(v, std::memory_order_release); + } + int Get(int k) { return generation[k].load(std::memory_order_acquire); } + + State() { + for (unsigned int k = 0; k < K; k++) { + Set(k, 0); + } + } + }; + + // Current state of the test + State current_; + + ConcurrentArena arena_; + + // InlineSkipList is not protected by mu_. We just use a single writer + // thread to modify it. + InlineSkipList<TestComparator> list_; + + public: + ConcurrentTest() : list_(TestComparator(), &arena_) {} + + // REQUIRES: No concurrent calls to WriteStep or ConcurrentWriteStep + void WriteStep(Random* rnd) { + const uint32_t k = rnd->Next() % K; + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + char* buf = list_.AllocateKey(sizeof(Key)); + memcpy(buf, &new_key, sizeof(Key)); + list_.Insert(buf); + current_.Set(k, g); + } + + // REQUIRES: No concurrent calls for the same k + void ConcurrentWriteStep(uint32_t k) { + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + char* buf = list_.AllocateKey(sizeof(Key)); + memcpy(buf, &new_key, sizeof(Key)); + list_.InsertConcurrently(buf); + ASSERT_EQ(g, current_.Get(k) + 1); + current_.Set(k, g); + } + + void ReadStep(Random* rnd) { + // Remember the initial committed state of the skiplist. + State initial_state; + for (unsigned int k = 0; k < K; k++) { + initial_state.Set(k, current_.Get(k)); + } + + Key pos = RandomTarget(rnd); + InlineSkipList<TestComparator>::Iterator iter(&list_); + iter.Seek(Encode(&pos)); + while (true) { + Key current; + if (!iter.Valid()) { + current = MakeKey(K, 0); + } else { + current = Decode(iter.key()); + ASSERT_TRUE(IsValidKey(current)) << current; + } + ASSERT_LE(pos, current) << "should not go backwards"; + + // Verify that everything in [pos,current) was not present in + // initial_state. + while (pos < current) { + ASSERT_LT(key(pos), K) << pos; + + // Note that generation 0 is never inserted, so it is ok if + // <*,0,*> is missing. + ASSERT_TRUE((gen(pos) == 0U) || + (gen(pos) > static_cast<uint64_t>(initial_state.Get( + static_cast<int>(key(pos)))))) + << "key: " << key(pos) << "; gen: " << gen(pos) + << "; initgen: " << initial_state.Get(static_cast<int>(key(pos))); + + // Advance to next key in the valid key space + if (key(pos) < key(current)) { + pos = MakeKey(key(pos) + 1, 0); + } else { + pos = MakeKey(key(pos), gen(pos) + 1); + } + } + + if (!iter.Valid()) { + break; + } + + if (rnd->Next() % 2) { + iter.Next(); + pos = MakeKey(key(pos), gen(pos) + 1); + } else { + Key new_target = RandomTarget(rnd); + if (new_target > pos) { + pos = new_target; + iter.Seek(Encode(&new_target)); + } + } + } + } +}; +const uint32_t ConcurrentTest::K; + +// Simple test that does single-threaded testing of the ConcurrentTest +// scaffolding. +TEST_F(InlineSkipTest, ConcurrentReadWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + test.WriteStep(&rnd); + } +} + +TEST_F(InlineSkipTest, ConcurrentInsertWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + uint32_t base = rnd.Next(); + for (int j = 0; j < 4; ++j) { + test.ConcurrentWriteStep((base + j) % ConcurrentTest::K); + } + } +} + +class TestState { + public: + ConcurrentTest t_; + int seed_; + std::atomic<bool> quit_flag_; + std::atomic<uint32_t> next_writer_; + + enum ReaderState { STARTING, RUNNING, DONE }; + + explicit TestState(int s) + : seed_(s), + quit_flag_(false), + state_(STARTING), + pending_writers_(0), + state_cv_(&mu_) {} + + void Wait(ReaderState s) { + mu_.Lock(); + while (state_ != s) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + void Change(ReaderState s) { + mu_.Lock(); + state_ = s; + state_cv_.Signal(); + mu_.Unlock(); + } + + void AdjustPendingWriters(int delta) { + mu_.Lock(); + pending_writers_ += delta; + if (pending_writers_ == 0) { + state_cv_.Signal(); + } + mu_.Unlock(); + } + + void WaitForPendingWriters() { + mu_.Lock(); + while (pending_writers_ != 0) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + private: + port::Mutex mu_; + ReaderState state_; + int pending_writers_; + port::CondVar state_cv_; +}; + +static void ConcurrentReader(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + Random rnd(state->seed_); + int64_t reads = 0; + state->Change(TestState::RUNNING); + while (!state->quit_flag_.load(std::memory_order_acquire)) { + state->t_.ReadStep(&rnd); + ++reads; + } + state->Change(TestState::DONE); +} + +static void ConcurrentWriter(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + uint32_t k = state->next_writer_++ % ConcurrentTest::K; + state->t_.ConcurrentWriteStep(k); + state->AdjustPendingWriters(-1); +} + +static void RunConcurrentRead(int run) { + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + Env::Default()->SetBackgroundThreads(1); + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; ++k) { + state.t_.WriteStep(&rnd); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +static void RunConcurrentInsert(int run, int write_parallelism = 4) { + Env::Default()->SetBackgroundThreads(1 + write_parallelism, + Env::Priority::LOW); + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; k += write_parallelism) { + state.next_writer_ = rnd.Next(); + state.AdjustPendingWriters(write_parallelism); + for (int p = 0; p < write_parallelism; ++p) { + Env::Default()->Schedule(ConcurrentWriter, &state); + } + state.WaitForPendingWriters(); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +TEST_F(InlineSkipTest, ConcurrentRead1) { RunConcurrentRead(1); } +TEST_F(InlineSkipTest, ConcurrentRead2) { RunConcurrentRead(2); } +TEST_F(InlineSkipTest, ConcurrentRead3) { RunConcurrentRead(3); } +TEST_F(InlineSkipTest, ConcurrentRead4) { RunConcurrentRead(4); } +TEST_F(InlineSkipTest, ConcurrentRead5) { RunConcurrentRead(5); } +TEST_F(InlineSkipTest, ConcurrentInsert1) { RunConcurrentInsert(1); } +TEST_F(InlineSkipTest, ConcurrentInsert2) { RunConcurrentInsert(2); } +TEST_F(InlineSkipTest, ConcurrentInsert3) { RunConcurrentInsert(3); } + +#endif // ROCKSDB_VALGRIND_RUN +} // namespace rocksdb + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/rocksdb/memtable/memtablerep_bench.cc b/src/rocksdb/memtable/memtablerep_bench.cc new file mode 100644 index 00000000..51ff11a0 --- /dev/null +++ b/src/rocksdb/memtable/memtablerep_bench.cc @@ -0,0 +1,682 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +#ifndef GFLAGS +#include <cstdio> +int main() { + fprintf(stderr, "Please install gflags to run rocksdb tools\n"); + return 1; +} +#else + +#include <atomic> +#include <iostream> +#include <memory> +#include <thread> +#include <type_traits> +#include <vector> + +#include "db/dbformat.h" +#include "db/memtable.h" +#include "port/port.h" +#include "port/stack_trace.h" +#include "rocksdb/comparator.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/options.h" +#include "rocksdb/slice_transform.h" +#include "rocksdb/write_buffer_manager.h" +#include "util/arena.h" +#include "util/gflags_compat.h" +#include "util/mutexlock.h" +#include "util/stop_watch.h" +#include "util/testutil.h" + +using GFLAGS_NAMESPACE::ParseCommandLineFlags; +using GFLAGS_NAMESPACE::RegisterFlagValidator; +using GFLAGS_NAMESPACE::SetUsageMessage; + +DEFINE_string(benchmarks, "fillrandom", + "Comma-separated list of benchmarks to run. Options:\n" + "\tfillrandom -- write N random values\n" + "\tfillseq -- write N values in sequential order\n" + "\treadrandom -- read N values in random order\n" + "\treadseq -- scan the DB\n" + "\treadwrite -- 1 thread writes while N - 1 threads " + "do random\n" + "\t reads\n" + "\tseqreadwrite -- 1 thread writes while N - 1 threads " + "do scans\n"); + +DEFINE_string(memtablerep, "skiplist", + "Which implementation of memtablerep to use. See " + "include/memtablerep.h for\n" + " more details. Options:\n" + "\tskiplist -- backed by a skiplist\n" + "\tvector -- backed by an std::vector\n" + "\thashskiplist -- backed by a hash skip list\n" + "\thashlinklist -- backed by a hash linked list\n" + "\tcuckoo -- backed by a cuckoo hash table"); + +DEFINE_int64(bucket_count, 1000000, + "bucket_count parameter to pass into NewHashSkiplistRepFactory or " + "NewHashLinkListRepFactory"); + +DEFINE_int32( + hashskiplist_height, 4, + "skiplist_height parameter to pass into NewHashSkiplistRepFactory"); + +DEFINE_int32( + hashskiplist_branching_factor, 4, + "branching_factor parameter to pass into NewHashSkiplistRepFactory"); + +DEFINE_int32( + huge_page_tlb_size, 0, + "huge_page_tlb_size parameter to pass into NewHashLinkListRepFactory"); + +DEFINE_int32(bucket_entries_logging_threshold, 4096, + "bucket_entries_logging_threshold parameter to pass into " + "NewHashLinkListRepFactory"); + +DEFINE_bool(if_log_bucket_dist_when_flash, true, + "if_log_bucket_dist_when_flash parameter to pass into " + "NewHashLinkListRepFactory"); + +DEFINE_int32( + threshold_use_skiplist, 256, + "threshold_use_skiplist parameter to pass into NewHashLinkListRepFactory"); + +DEFINE_int64(write_buffer_size, 256, + "write_buffer_size parameter to pass into WriteBufferManager"); + +DEFINE_int32( + num_threads, 1, + "Number of concurrent threads to run. If the benchmark includes writes,\n" + "then at most one thread will be a writer"); + +DEFINE_int32(num_operations, 1000000, + "Number of operations to do for write and random read benchmarks"); + +DEFINE_int32(num_scans, 10, + "Number of times for each thread to scan the memtablerep for " + "sequential read " + "benchmarks"); + +DEFINE_int32(item_size, 100, "Number of bytes each item should be"); + +DEFINE_int32(prefix_length, 8, + "Prefix length to pass into NewFixedPrefixTransform"); + +/* VectorRep settings */ +DEFINE_int64(vectorrep_count, 0, + "Number of entries to reserve on VectorRep initialization"); + +DEFINE_int64(seed, 0, + "Seed base for random number generators. " + "When 0 it is deterministic."); + +namespace rocksdb { + +namespace { +struct CallbackVerifyArgs { + bool found; + LookupKey* key; + MemTableRep* table; + InternalKeyComparator* comparator; +}; +} // namespace + +// Helper for quickly generating random data. +class RandomGenerator { + private: + std::string data_; + unsigned int pos_; + + public: + RandomGenerator() { + Random rnd(301); + auto size = (unsigned)std::max(1048576, FLAGS_item_size); + test::RandomString(&rnd, size, &data_); + pos_ = 0; + } + + Slice Generate(unsigned int len) { + assert(len <= data_.size()); + if (pos_ + len > data_.size()) { + pos_ = 0; + } + pos_ += len; + return Slice(data_.data() + pos_ - len, len); + } +}; + +enum WriteMode { SEQUENTIAL, RANDOM, UNIQUE_RANDOM }; + +class KeyGenerator { + public: + KeyGenerator(Random64* rand, WriteMode mode, uint64_t num) + : rand_(rand), mode_(mode), num_(num), next_(0) { + if (mode_ == UNIQUE_RANDOM) { + // NOTE: if memory consumption of this approach becomes a concern, + // we can either break it into pieces and only random shuffle a section + // each time. Alternatively, use a bit map implementation + // (https://reviews.facebook.net/differential/diff/54627/) + values_.resize(num_); + for (uint64_t i = 0; i < num_; ++i) { + values_[i] = i; + } + std::shuffle( + values_.begin(), values_.end(), + std::default_random_engine(static_cast<unsigned int>(FLAGS_seed))); + } + } + + uint64_t Next() { + switch (mode_) { + case SEQUENTIAL: + return next_++; + case RANDOM: + return rand_->Next() % num_; + case UNIQUE_RANDOM: + return values_[next_++]; + } + assert(false); + return std::numeric_limits<uint64_t>::max(); + } + + private: + Random64* rand_; + WriteMode mode_; + const uint64_t num_; + uint64_t next_; + std::vector<uint64_t> values_; +}; + +class BenchmarkThread { + public: + explicit BenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits) + : table_(table), + key_gen_(key_gen), + bytes_written_(bytes_written), + bytes_read_(bytes_read), + sequence_(sequence), + num_ops_(num_ops), + read_hits_(read_hits) {} + + virtual void operator()() = 0; + virtual ~BenchmarkThread() {} + + protected: + MemTableRep* table_; + KeyGenerator* key_gen_; + uint64_t* bytes_written_; + uint64_t* bytes_read_; + uint64_t* sequence_; + uint64_t num_ops_; + uint64_t* read_hits_; + RandomGenerator generator_; +}; + +class FillBenchmarkThread : public BenchmarkThread { + public: + FillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + void FillOne() { + char* buf = nullptr; + auto internal_key_size = 16; + auto encoded_len = + FLAGS_item_size + VarintLength(internal_key_size) + internal_key_size; + KeyHandle handle = table_->Allocate(encoded_len, &buf); + assert(buf != nullptr); + char* p = EncodeVarint32(buf, internal_key_size); + auto key = key_gen_->Next(); + EncodeFixed64(p, key); + p += 8; + EncodeFixed64(p, ++(*sequence_)); + p += 8; + Slice bytes = generator_.Generate(FLAGS_item_size); + memcpy(p, bytes.data(), FLAGS_item_size); + p += FLAGS_item_size; + assert(p == buf + encoded_len); + table_->Insert(handle); + *bytes_written_ += encoded_len; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + FillOne(); + } + } +}; + +class ConcurrentFillBenchmarkThread : public FillBenchmarkThread { + public: + ConcurrentFillBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits, + std::atomic_int* threads_done) + : FillBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + // # of read threads will be total threads - write threads (always 1). Loop + // while all reads complete. + while ((*threads_done_).load() < (FLAGS_num_threads - 1)) { + FillOne(); + } + } + + private: + std::atomic_int* threads_done_; +}; + +class ReadBenchmarkThread : public BenchmarkThread { + public: + ReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + static bool callback(void* arg, const char* entry) { + CallbackVerifyArgs* callback_args = static_cast<CallbackVerifyArgs*>(arg); + assert(callback_args != nullptr); + uint32_t key_length; + const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); + if ((callback_args->comparator) + ->user_comparator() + ->Equal(Slice(key_ptr, key_length - 8), + callback_args->key->user_key())) { + callback_args->found = true; + } + return false; + } + + void ReadOne() { + std::string user_key; + auto key = key_gen_->Next(); + PutFixed64(&user_key, key); + LookupKey lookup_key(user_key, *sequence_); + InternalKeyComparator internal_key_comp(BytewiseComparator()); + CallbackVerifyArgs verify_args; + verify_args.found = false; + verify_args.key = &lookup_key; + verify_args.table = table_; + verify_args.comparator = &internal_key_comp; + table_->Get(lookup_key, &verify_args, callback); + if (verify_args.found) { + *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size; + ++*read_hits_; + } + } + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOne(); + } + } +}; + +class SeqReadBenchmarkThread : public BenchmarkThread { + public: + SeqReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits) + : BenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) {} + + void ReadOneSeq() { + std::unique_ptr<MemTableRep::Iterator> iter(table_->GetIterator()); + for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { + // pretend to read the value + *bytes_read_ += VarintLength(16) + 16 + FLAGS_item_size; + } + ++*read_hits_; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + { ReadOneSeq(); } + } + } +}; + +class ConcurrentReadBenchmarkThread : public ReadBenchmarkThread { + public: + ConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, uint64_t* bytes_read, + uint64_t* sequence, uint64_t num_ops, + uint64_t* read_hits, + std::atomic_int* threads_done) + : ReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, sequence, + num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOne(); + } + ++*threads_done_; + } + + private: + std::atomic_int* threads_done_; +}; + +class SeqConcurrentReadBenchmarkThread : public SeqReadBenchmarkThread { + public: + SeqConcurrentReadBenchmarkThread(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* bytes_written, + uint64_t* bytes_read, uint64_t* sequence, + uint64_t num_ops, uint64_t* read_hits, + std::atomic_int* threads_done) + : SeqReadBenchmarkThread(table, key_gen, bytes_written, bytes_read, + sequence, num_ops, read_hits) { + threads_done_ = threads_done; + } + + void operator()() override { + for (unsigned int i = 0; i < num_ops_; ++i) { + ReadOneSeq(); + } + ++*threads_done_; + } + + private: + std::atomic_int* threads_done_; +}; + +class Benchmark { + public: + explicit Benchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence, uint32_t num_threads) + : table_(table), + key_gen_(key_gen), + sequence_(sequence), + num_threads_(num_threads) {} + + virtual ~Benchmark() {} + virtual void Run() { + std::cout << "Number of threads: " << num_threads_ << std::endl; + std::vector<port::Thread> threads; + uint64_t bytes_written = 0; + uint64_t bytes_read = 0; + uint64_t read_hits = 0; + StopWatchNano timer(Env::Default(), true); + RunThreads(&threads, &bytes_written, &bytes_read, true, &read_hits); + auto elapsed_time = static_cast<double>(timer.ElapsedNanos() / 1000); + std::cout << "Elapsed time: " << static_cast<int>(elapsed_time) << " us" + << std::endl; + + if (bytes_written > 0) { + auto MiB_written = static_cast<double>(bytes_written) / (1 << 20); + auto write_throughput = MiB_written / (elapsed_time / 1000000); + std::cout << "Total bytes written: " << MiB_written << " MiB" + << std::endl; + std::cout << "Write throughput: " << write_throughput << " MiB/s" + << std::endl; + auto us_per_op = elapsed_time / num_write_ops_per_thread_; + std::cout << "write us/op: " << us_per_op << std::endl; + } + if (bytes_read > 0) { + auto MiB_read = static_cast<double>(bytes_read) / (1 << 20); + auto read_throughput = MiB_read / (elapsed_time / 1000000); + std::cout << "Total bytes read: " << MiB_read << " MiB" << std::endl; + std::cout << "Read throughput: " << read_throughput << " MiB/s" + << std::endl; + auto us_per_op = elapsed_time / num_read_ops_per_thread_; + std::cout << "read us/op: " << us_per_op << std::endl; + } + } + + virtual void RunThreads(std::vector<port::Thread>* threads, + uint64_t* bytes_written, uint64_t* bytes_read, + bool write, uint64_t* read_hits) = 0; + + protected: + MemTableRep* table_; + KeyGenerator* key_gen_; + uint64_t* sequence_; + uint64_t num_write_ops_per_thread_; + uint64_t num_read_ops_per_thread_; + const uint32_t num_threads_; +}; + +class FillBenchmark : public Benchmark { + public: + explicit FillBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, 1) { + num_write_ops_per_thread_ = FLAGS_num_operations; + } + + void RunThreads(std::vector<port::Thread>* /*threads*/, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + FillBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, sequence_, + num_write_ops_per_thread_, read_hits)(); + } +}; + +class ReadBenchmark : public Benchmark { + public: + explicit ReadBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = FLAGS_num_operations / FLAGS_num_threads; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + for (int i = 0; i < FLAGS_num_threads; ++i) { + threads->emplace_back( + ReadBenchmarkThread(table_, key_gen_, bytes_written, bytes_read, + sequence_, num_read_ops_per_thread_, read_hits)); + } + for (auto& thread : *threads) { + thread.join(); + } + std::cout << "read hit%: " + << (static_cast<double>(*read_hits) / FLAGS_num_operations) * 100 + << std::endl; + } +}; + +class SeqReadBenchmark : public Benchmark { + public: + explicit SeqReadBenchmark(MemTableRep* table, uint64_t* sequence) + : Benchmark(table, nullptr, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = FLAGS_num_scans; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + for (int i = 0; i < FLAGS_num_threads; ++i) { + threads->emplace_back(SeqReadBenchmarkThread( + table_, key_gen_, bytes_written, bytes_read, sequence_, + num_read_ops_per_thread_, read_hits)); + } + for (auto& thread : *threads) { + thread.join(); + } + } +}; + +template <class ReadThreadType> +class ReadWriteBenchmark : public Benchmark { + public: + explicit ReadWriteBenchmark(MemTableRep* table, KeyGenerator* key_gen, + uint64_t* sequence) + : Benchmark(table, key_gen, sequence, FLAGS_num_threads) { + num_read_ops_per_thread_ = + FLAGS_num_threads <= 1 + ? 0 + : (FLAGS_num_operations / (FLAGS_num_threads - 1)); + num_write_ops_per_thread_ = FLAGS_num_operations; + } + + void RunThreads(std::vector<port::Thread>* threads, uint64_t* bytes_written, + uint64_t* bytes_read, bool /*write*/, + uint64_t* read_hits) override { + std::atomic_int threads_done; + threads_done.store(0); + threads->emplace_back(ConcurrentFillBenchmarkThread( + table_, key_gen_, bytes_written, bytes_read, sequence_, + num_write_ops_per_thread_, read_hits, &threads_done)); + for (int i = 1; i < FLAGS_num_threads; ++i) { + threads->emplace_back( + ReadThreadType(table_, key_gen_, bytes_written, bytes_read, sequence_, + num_read_ops_per_thread_, read_hits, &threads_done)); + } + for (auto& thread : *threads) { + thread.join(); + } + } +}; + +} // namespace rocksdb + +void PrintWarnings() { +#if defined(__GNUC__) && !defined(__OPTIMIZE__) + fprintf(stdout, + "WARNING: Optimization is disabled: benchmarks unnecessarily slow\n"); +#endif +#ifndef NDEBUG + fprintf(stdout, + "WARNING: Assertions are enabled; benchmarks unnecessarily slow\n"); +#endif +} + +int main(int argc, char** argv) { + rocksdb::port::InstallStackTraceHandler(); + SetUsageMessage(std::string("\nUSAGE:\n") + std::string(argv[0]) + + " [OPTIONS]..."); + ParseCommandLineFlags(&argc, &argv, true); + + PrintWarnings(); + + rocksdb::Options options; + + std::unique_ptr<rocksdb::MemTableRepFactory> factory; + if (FLAGS_memtablerep == "skiplist") { + factory.reset(new rocksdb::SkipListFactory); +#ifndef ROCKSDB_LITE + } else if (FLAGS_memtablerep == "vector") { + factory.reset(new rocksdb::VectorRepFactory); + } else if (FLAGS_memtablerep == "hashskiplist") { + factory.reset(rocksdb::NewHashSkipListRepFactory( + FLAGS_bucket_count, FLAGS_hashskiplist_height, + FLAGS_hashskiplist_branching_factor)); + options.prefix_extractor.reset( + rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length)); + } else if (FLAGS_memtablerep == "hashlinklist") { + factory.reset(rocksdb::NewHashLinkListRepFactory( + FLAGS_bucket_count, FLAGS_huge_page_tlb_size, + FLAGS_bucket_entries_logging_threshold, + FLAGS_if_log_bucket_dist_when_flash, FLAGS_threshold_use_skiplist)); + options.prefix_extractor.reset( + rocksdb::NewFixedPrefixTransform(FLAGS_prefix_length)); +#endif // ROCKSDB_LITE + } else { + fprintf(stdout, "Unknown memtablerep: %s\n", FLAGS_memtablerep.c_str()); + exit(1); + } + + rocksdb::InternalKeyComparator internal_key_comp( + rocksdb::BytewiseComparator()); + rocksdb::MemTable::KeyComparator key_comp(internal_key_comp); + rocksdb::Arena arena; + rocksdb::WriteBufferManager wb(FLAGS_write_buffer_size); + uint64_t sequence; + auto createMemtableRep = [&] { + sequence = 0; + return factory->CreateMemTableRep(key_comp, &arena, + options.prefix_extractor.get(), + options.info_log.get()); + }; + std::unique_ptr<rocksdb::MemTableRep> memtablerep; + rocksdb::Random64 rng(FLAGS_seed); + const char* benchmarks = FLAGS_benchmarks.c_str(); + while (benchmarks != nullptr) { + std::unique_ptr<rocksdb::KeyGenerator> key_gen; + const char* sep = strchr(benchmarks, ','); + rocksdb::Slice name; + if (sep == nullptr) { + name = benchmarks; + benchmarks = nullptr; + } else { + name = rocksdb::Slice(benchmarks, sep - benchmarks); + benchmarks = sep + 1; + } + std::unique_ptr<rocksdb::Benchmark> benchmark; + if (name == rocksdb::Slice("fillseq")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL, + FLAGS_num_operations)); + benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(), + key_gen.get(), &sequence)); + } else if (name == rocksdb::Slice("fillrandom")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::UNIQUE_RANDOM, + FLAGS_num_operations)); + benchmark.reset(new rocksdb::FillBenchmark(memtablerep.get(), + key_gen.get(), &sequence)); + } else if (name == rocksdb::Slice("readrandom")) { + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM, + FLAGS_num_operations)); + benchmark.reset(new rocksdb::ReadBenchmark(memtablerep.get(), + key_gen.get(), &sequence)); + } else if (name == rocksdb::Slice("readseq")) { + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::SEQUENTIAL, + FLAGS_num_operations)); + benchmark.reset( + new rocksdb::SeqReadBenchmark(memtablerep.get(), &sequence)); + } else if (name == rocksdb::Slice("readwrite")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM, + FLAGS_num_operations)); + benchmark.reset(new rocksdb::ReadWriteBenchmark< + rocksdb::ConcurrentReadBenchmarkThread>(memtablerep.get(), + key_gen.get(), &sequence)); + } else if (name == rocksdb::Slice("seqreadwrite")) { + memtablerep.reset(createMemtableRep()); + key_gen.reset(new rocksdb::KeyGenerator(&rng, rocksdb::RANDOM, + FLAGS_num_operations)); + benchmark.reset(new rocksdb::ReadWriteBenchmark< + rocksdb::SeqConcurrentReadBenchmarkThread>(memtablerep.get(), + key_gen.get(), &sequence)); + } else { + std::cout << "WARNING: skipping unknown benchmark '" << name.ToString() + << std::endl; + continue; + } + std::cout << "Running " << name.ToString() << std::endl; + benchmark->Run(); + } + + return 0; +} + +#endif // GFLAGS diff --git a/src/rocksdb/memtable/skiplist.h b/src/rocksdb/memtable/skiplist.h new file mode 100644 index 00000000..47a89034 --- /dev/null +++ b/src/rocksdb/memtable/skiplist.h @@ -0,0 +1,497 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. +// +// Thread safety +// ------------- +// +// Writes require external synchronization, most likely a mutex. +// Reads require a guarantee that the SkipList will not be destroyed +// while the read is in progress. Apart from that, reads progress +// without any internal locking or synchronization. +// +// Invariants: +// +// (1) Allocated nodes are never deleted until the SkipList is +// destroyed. This is trivially guaranteed by the code since we +// never delete any skip list nodes. +// +// (2) The contents of a Node except for the next/prev pointers are +// immutable after the Node has been linked into the SkipList. +// Only Insert() modifies the list, and it is careful to initialize +// a node and use release-stores to publish the nodes in one or +// more lists. +// +// ... prev vs. next pointer ordering ... +// + +#pragma once +#include <assert.h> +#include <atomic> +#include <stdlib.h> +#include "port/port.h" +#include "util/allocator.h" +#include "util/random.h" + +namespace rocksdb { + +template<typename Key, class Comparator> +class SkipList { + private: + struct Node; + + public: + // Create a new SkipList object that will use "cmp" for comparing keys, + // and will allocate memory using "*allocator". Objects allocated in the + // allocator must remain allocated for the lifetime of the skiplist object. + explicit SkipList(Comparator cmp, Allocator* allocator, + int32_t max_height = 12, int32_t branching_factor = 4); + + // Insert key into the list. + // REQUIRES: nothing that compares equal to key is currently in the list. + void Insert(const Key& key); + + // Returns true iff an entry that compares equal to key is in the list. + bool Contains(const Key& key) const; + + // Return estimated number of entries smaller than `key`. + uint64_t EstimateCount(const Key& key) const; + + // Iteration over the contents of a skip list + class Iterator { + public: + // Initialize an iterator over the specified list. + // The returned iterator is not valid. + explicit Iterator(const SkipList* list); + + // Change the underlying skiplist used for this iterator + // This enables us not changing the iterator without deallocating + // an old one and then allocating a new one + void SetList(const SkipList* list); + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const; + + // Returns the key at the current position. + // REQUIRES: Valid() + const Key& key() const; + + // Advances to the next position. + // REQUIRES: Valid() + void Next(); + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev(); + + // Advance to the first entry with a key >= target + void Seek(const Key& target); + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Key& target); + + // Position at the first entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst(); + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast(); + + private: + const SkipList* list_; + Node* node_; + // Intentionally copyable + }; + + private: + const uint16_t kMaxHeight_; + const uint16_t kBranching_; + const uint32_t kScaledInverseBranching_; + + // Immutable after construction + Comparator const compare_; + Allocator* const allocator_; // Allocator used for allocations of nodes + + Node* const head_; + + // Modified only by Insert(). Read racily by readers, but stale + // values are ok. + std::atomic<int> max_height_; // Height of the entire list + + // Used for optimizing sequential insert patterns. Tricky. prev_[i] for + // i up to max_height_ is the predecessor of prev_[0] and prev_height_ + // is the height of prev_[0]. prev_[0] can only be equal to head before + // insertion, in which case max_height_ and prev_height_ are 1. + Node** prev_; + int32_t prev_height_; + + inline int GetMaxHeight() const { + return max_height_.load(std::memory_order_relaxed); + } + + Node* NewNode(const Key& key, int height); + int RandomHeight(); + bool Equal(const Key& a, const Key& b) const { return (compare_(a, b) == 0); } + bool LessThan(const Key& a, const Key& b) const { + return (compare_(a, b) < 0); + } + + // Return true if key is greater than the data stored in "n" + bool KeyIsAfterNode(const Key& key, Node* n) const; + + // Returns the earliest node with a key >= key. + // Return nullptr if there is no such node. + Node* FindGreaterOrEqual(const Key& key) const; + + // Return the latest node with a key < key. + // Return head_ if there is no such node. + // Fills prev[level] with pointer to previous node at "level" for every + // level in [0..max_height_-1], if prev is non-null. + Node* FindLessThan(const Key& key, Node** prev = nullptr) const; + + // Return the last node in the list. + // Return head_ if list is empty. + Node* FindLast() const; + + // No copying allowed + SkipList(const SkipList&); + void operator=(const SkipList&); +}; + +// Implementation details follow +template<typename Key, class Comparator> +struct SkipList<Key, Comparator>::Node { + explicit Node(const Key& k) : key(k) { } + + Key const key; + + // Accessors/mutators for links. Wrapped in methods so we can + // add the appropriate barriers as necessary. + Node* Next(int n) { + assert(n >= 0); + // Use an 'acquire load' so that we observe a fully initialized + // version of the returned Node. + return (next_[n].load(std::memory_order_acquire)); + } + void SetNext(int n, Node* x) { + assert(n >= 0); + // Use a 'release store' so that anybody who reads through this + // pointer observes a fully initialized version of the inserted node. + next_[n].store(x, std::memory_order_release); + } + + // No-barrier variants that can be safely used in a few locations. + Node* NoBarrier_Next(int n) { + assert(n >= 0); + return next_[n].load(std::memory_order_relaxed); + } + void NoBarrier_SetNext(int n, Node* x) { + assert(n >= 0); + next_[n].store(x, std::memory_order_relaxed); + } + + private: + // Array of length equal to the node height. next_[0] is lowest level link. + std::atomic<Node*> next_[1]; +}; + +template<typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* +SkipList<Key, Comparator>::NewNode(const Key& key, int height) { + char* mem = allocator_->AllocateAligned( + sizeof(Node) + sizeof(std::atomic<Node*>) * (height - 1)); + return new (mem) Node(key); +} + +template<typename Key, class Comparator> +inline SkipList<Key, Comparator>::Iterator::Iterator(const SkipList* list) { + SetList(list); +} + +template<typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SetList(const SkipList* list) { + list_ = list; + node_ = nullptr; +} + +template<typename Key, class Comparator> +inline bool SkipList<Key, Comparator>::Iterator::Valid() const { + return node_ != nullptr; +} + +template<typename Key, class Comparator> +inline const Key& SkipList<Key, Comparator>::Iterator::key() const { + assert(Valid()); + return node_->key; +} + +template<typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Next() { + assert(Valid()); + node_ = node_->Next(0); +} + +template<typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Prev() { + // Instead of using explicit "prev" links, we just search for the + // last node that falls before key. + assert(Valid()); + node_ = list_->FindLessThan(node_->key); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template<typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::Seek(const Key& target) { + node_ = list_->FindGreaterOrEqual(target); +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekForPrev( + const Key& target) { + Seek(target); + if (!Valid()) { + SeekToLast(); + } + while (Valid() && list_->LessThan(target, key())) { + Prev(); + } +} + +template <typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekToFirst() { + node_ = list_->head_->Next(0); +} + +template<typename Key, class Comparator> +inline void SkipList<Key, Comparator>::Iterator::SeekToLast() { + node_ = list_->FindLast(); + if (node_ == list_->head_) { + node_ = nullptr; + } +} + +template<typename Key, class Comparator> +int SkipList<Key, Comparator>::RandomHeight() { + auto rnd = Random::GetTLSInstance(); + + // Increase height with probability 1 in kBranching + int height = 1; + while (height < kMaxHeight_ && rnd->Next() < kScaledInverseBranching_) { + height++; + } + assert(height > 0); + assert(height <= kMaxHeight_); + return height; +} + +template<typename Key, class Comparator> +bool SkipList<Key, Comparator>::KeyIsAfterNode(const Key& key, Node* n) const { + // nullptr n is considered infinite + return (n != nullptr) && (compare_(n->key, key) < 0); +} + +template<typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>:: + FindGreaterOrEqual(const Key& key) const { + // Note: It looks like we could reduce duplication by implementing + // this function as FindLessThan(key)->Next(0), but we wouldn't be able + // to exit early on equality and the result wouldn't even be correct. + // A concurrent insert might occur after FindLessThan(key) but before + // we get a chance to call Next(0). + Node* x = head_; + int level = GetMaxHeight() - 1; + Node* last_bigger = nullptr; + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + // Make sure the lists are sorted + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x)); + // Make sure we haven't overshot during our search + assert(x == head_ || KeyIsAfterNode(key, x)); + int cmp = (next == nullptr || next == last_bigger) + ? 1 : compare_(next->key, key); + if (cmp == 0 || (cmp > 0 && level == 0)) { + return next; + } else if (cmp < 0) { + // Keep searching in this list + x = next; + } else { + // Switch to next list, reuse compare_() result + last_bigger = next; + level--; + } + } +} + +template<typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* +SkipList<Key, Comparator>::FindLessThan(const Key& key, Node** prev) const { + Node* x = head_; + int level = GetMaxHeight() - 1; + // KeyIsAfter(key, last_not_after) is definitely false + Node* last_not_after = nullptr; + while (true) { + assert(x != nullptr); + Node* next = x->Next(level); + assert(x == head_ || next == nullptr || KeyIsAfterNode(next->key, x)); + assert(x == head_ || KeyIsAfterNode(key, x)); + if (next != last_not_after && KeyIsAfterNode(key, next)) { + // Keep searching in this list + x = next; + } else { + if (prev != nullptr) { + prev[level] = x; + } + if (level == 0) { + return x; + } else { + // Switch to next list, reuse KeyIUsAfterNode() result + last_not_after = next; + level--; + } + } + } +} + +template<typename Key, class Comparator> +typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast() + const { + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + Node* next = x->Next(level); + if (next == nullptr) { + if (level == 0) { + return x; + } else { + // Switch to next list + level--; + } + } else { + x = next; + } + } +} + +template <typename Key, class Comparator> +uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const { + uint64_t count = 0; + + Node* x = head_; + int level = GetMaxHeight() - 1; + while (true) { + assert(x == head_ || compare_(x->key, key) < 0); + Node* next = x->Next(level); + if (next == nullptr || compare_(next->key, key) >= 0) { + if (level == 0) { + return count; + } else { + // Switch to next list + count *= kBranching_; + level--; + } + } else { + x = next; + count++; + } + } +} + +template <typename Key, class Comparator> +SkipList<Key, Comparator>::SkipList(const Comparator cmp, Allocator* allocator, + int32_t max_height, + int32_t branching_factor) + : kMaxHeight_(static_cast<uint16_t>(max_height)), + kBranching_(static_cast<uint16_t>(branching_factor)), + kScaledInverseBranching_((Random::kMaxNext + 1) / kBranching_), + compare_(cmp), + allocator_(allocator), + head_(NewNode(0 /* any key will do */, max_height)), + max_height_(1), + prev_height_(1) { + assert(max_height > 0 && kMaxHeight_ == static_cast<uint32_t>(max_height)); + assert(branching_factor > 0 && + kBranching_ == static_cast<uint32_t>(branching_factor)); + assert(kScaledInverseBranching_ > 0); + // Allocate the prev_ Node* array, directly from the passed-in allocator. + // prev_ does not need to be freed, as its life cycle is tied up with + // the allocator as a whole. + prev_ = reinterpret_cast<Node**>( + allocator_->AllocateAligned(sizeof(Node*) * kMaxHeight_)); + for (int i = 0; i < kMaxHeight_; i++) { + head_->SetNext(i, nullptr); + prev_[i] = head_; + } +} + +template<typename Key, class Comparator> +void SkipList<Key, Comparator>::Insert(const Key& key) { + // fast path for sequential insertion + if (!KeyIsAfterNode(key, prev_[0]->NoBarrier_Next(0)) && + (prev_[0] == head_ || KeyIsAfterNode(key, prev_[0]))) { + assert(prev_[0] != head_ || (prev_height_ == 1 && GetMaxHeight() == 1)); + + // Outside of this method prev_[1..max_height_] is the predecessor + // of prev_[0], and prev_height_ refers to prev_[0]. Inside Insert + // prev_[0..max_height - 1] is the predecessor of key. Switch from + // the external state to the internal + for (int i = 1; i < prev_height_; i++) { + prev_[i] = prev_[0]; + } + } else { + // TODO(opt): we could use a NoBarrier predecessor search as an + // optimization for architectures where memory_order_acquire needs + // a synchronization instruction. Doesn't matter on x86 + FindLessThan(key, prev_); + } + + // Our data structure does not allow duplicate insertion + assert(prev_[0]->Next(0) == nullptr || !Equal(key, prev_[0]->Next(0)->key)); + + int height = RandomHeight(); + if (height > GetMaxHeight()) { + for (int i = GetMaxHeight(); i < height; i++) { + prev_[i] = head_; + } + //fprintf(stderr, "Change height from %d to %d\n", max_height_, height); + + // It is ok to mutate max_height_ without any synchronization + // with concurrent readers. A concurrent reader that observes + // the new value of max_height_ will see either the old value of + // new level pointers from head_ (nullptr), or a new value set in + // the loop below. In the former case the reader will + // immediately drop to the next level since nullptr sorts after all + // keys. In the latter case the reader will use the new node. + max_height_.store(height, std::memory_order_relaxed); + } + + Node* x = NewNode(key, height); + for (int i = 0; i < height; i++) { + // NoBarrier_SetNext() suffices since we will add a barrier when + // we publish a pointer to "x" in prev[i]. + x->NoBarrier_SetNext(i, prev_[i]->NoBarrier_Next(i)); + prev_[i]->SetNext(i, x); + } + prev_[0] = x; + prev_height_ = height; +} + +template<typename Key, class Comparator> +bool SkipList<Key, Comparator>::Contains(const Key& key) const { + Node* x = FindGreaterOrEqual(key); + if (x != nullptr && Equal(key, x->key)) { + return true; + } else { + return false; + } +} + +} // namespace rocksdb diff --git a/src/rocksdb/memtable/skiplist_test.cc b/src/rocksdb/memtable/skiplist_test.cc new file mode 100644 index 00000000..50c3588b --- /dev/null +++ b/src/rocksdb/memtable/skiplist_test.cc @@ -0,0 +1,388 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "memtable/skiplist.h" +#include <set> +#include "rocksdb/env.h" +#include "util/arena.h" +#include "util/hash.h" +#include "util/random.h" +#include "util/testharness.h" + +namespace rocksdb { + +typedef uint64_t Key; + +struct TestComparator { + int operator()(const Key& a, const Key& b) const { + if (a < b) { + return -1; + } else if (a > b) { + return +1; + } else { + return 0; + } + } +}; + +class SkipTest : public testing::Test {}; + +TEST_F(SkipTest, Empty) { + Arena arena; + TestComparator cmp; + SkipList<Key, TestComparator> list(cmp, &arena); + ASSERT_TRUE(!list.Contains(10)); + + SkipList<Key, TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToFirst(); + ASSERT_TRUE(!iter.Valid()); + iter.Seek(100); + ASSERT_TRUE(!iter.Valid()); + iter.SeekForPrev(100); + ASSERT_TRUE(!iter.Valid()); + iter.SeekToLast(); + ASSERT_TRUE(!iter.Valid()); +} + +TEST_F(SkipTest, InsertAndLookup) { + const int N = 2000; + const int R = 5000; + Random rnd(1000); + std::set<Key> keys; + Arena arena; + TestComparator cmp; + SkipList<Key, TestComparator> list(cmp, &arena); + for (int i = 0; i < N; i++) { + Key key = rnd.Next() % R; + if (keys.insert(key).second) { + list.Insert(key); + } + } + + for (int i = 0; i < R; i++) { + if (list.Contains(i)) { + ASSERT_EQ(keys.count(i), 1U); + } else { + ASSERT_EQ(keys.count(i), 0U); + } + } + + // Simple iterator tests + { + SkipList<Key, TestComparator>::Iterator iter(&list); + ASSERT_TRUE(!iter.Valid()); + + iter.Seek(0); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), iter.key()); + + iter.SeekForPrev(R - 1); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), iter.key()); + + iter.SeekToFirst(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.begin()), iter.key()); + + iter.SeekToLast(); + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*(keys.rbegin()), iter.key()); + } + + // Forward iteration test + for (int i = 0; i < R; i++) { + SkipList<Key, TestComparator>::Iterator iter(&list); + iter.Seek(i); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.lower_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.end()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*model_iter, iter.key()); + ++model_iter; + iter.Next(); + } + } + } + + // Backward iteration test + for (int i = 0; i < R; i++) { + SkipList<Key, TestComparator>::Iterator iter(&list); + iter.SeekForPrev(i); + + // Compare against model iterator + std::set<Key>::iterator model_iter = keys.upper_bound(i); + for (int j = 0; j < 3; j++) { + if (model_iter == keys.begin()) { + ASSERT_TRUE(!iter.Valid()); + break; + } else { + ASSERT_TRUE(iter.Valid()); + ASSERT_EQ(*--model_iter, iter.key()); + iter.Prev(); + } + } + } +} + +// We want to make sure that with a single writer and multiple +// concurrent readers (with no synchronization other than when a +// reader's iterator is created), the reader always observes all the +// data that was present in the skip list when the iterator was +// constructor. Because insertions are happening concurrently, we may +// also observe new values that were inserted since the iterator was +// constructed, but we should never miss any values that were present +// at iterator construction time. +// +// We generate multi-part keys: +// <key,gen,hash> +// where: +// key is in range [0..K-1] +// gen is a generation number for key +// hash is hash(key,gen) +// +// The insertion code picks a random key, sets gen to be 1 + the last +// generation number inserted for that key, and sets hash to Hash(key,gen). +// +// At the beginning of a read, we snapshot the last inserted +// generation number for each key. We then iterate, including random +// calls to Next() and Seek(). For every key we encounter, we +// check that it is either expected given the initial snapshot or has +// been concurrently added since the iterator started. +class ConcurrentTest { + private: + static const uint32_t K = 4; + + static uint64_t key(Key key) { return (key >> 40); } + static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; } + static uint64_t hash(Key key) { return key & 0xff; } + + static uint64_t HashNumbers(uint64_t k, uint64_t g) { + uint64_t data[2] = { k, g }; + return Hash(reinterpret_cast<char*>(data), sizeof(data), 0); + } + + static Key MakeKey(uint64_t k, uint64_t g) { + assert(sizeof(Key) == sizeof(uint64_t)); + assert(k <= K); // We sometimes pass K to seek to the end of the skiplist + assert(g <= 0xffffffffu); + return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff)); + } + + static bool IsValidKey(Key k) { + return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff); + } + + static Key RandomTarget(Random* rnd) { + switch (rnd->Next() % 10) { + case 0: + // Seek to beginning + return MakeKey(0, 0); + case 1: + // Seek to end + return MakeKey(K, 0); + default: + // Seek to middle + return MakeKey(rnd->Next() % K, 0); + } + } + + // Per-key generation + struct State { + std::atomic<int> generation[K]; + void Set(int k, int v) { + generation[k].store(v, std::memory_order_release); + } + int Get(int k) { return generation[k].load(std::memory_order_acquire); } + + State() { + for (unsigned int k = 0; k < K; k++) { + Set(k, 0); + } + } + }; + + // Current state of the test + State current_; + + Arena arena_; + + // SkipList is not protected by mu_. We just use a single writer + // thread to modify it. + SkipList<Key, TestComparator> list_; + + public: + ConcurrentTest() : list_(TestComparator(), &arena_) {} + + // REQUIRES: External synchronization + void WriteStep(Random* rnd) { + const uint32_t k = rnd->Next() % K; + const int g = current_.Get(k) + 1; + const Key new_key = MakeKey(k, g); + list_.Insert(new_key); + current_.Set(k, g); + } + + void ReadStep(Random* rnd) { + // Remember the initial committed state of the skiplist. + State initial_state; + for (unsigned int k = 0; k < K; k++) { + initial_state.Set(k, current_.Get(k)); + } + + Key pos = RandomTarget(rnd); + SkipList<Key, TestComparator>::Iterator iter(&list_); + iter.Seek(pos); + while (true) { + Key current; + if (!iter.Valid()) { + current = MakeKey(K, 0); + } else { + current = iter.key(); + ASSERT_TRUE(IsValidKey(current)) << current; + } + ASSERT_LE(pos, current) << "should not go backwards"; + + // Verify that everything in [pos,current) was not present in + // initial_state. + while (pos < current) { + ASSERT_LT(key(pos), K) << pos; + + // Note that generation 0 is never inserted, so it is ok if + // <*,0,*> is missing. + ASSERT_TRUE((gen(pos) == 0U) || + (gen(pos) > static_cast<uint64_t>(initial_state.Get( + static_cast<int>(key(pos)))))) + << "key: " << key(pos) << "; gen: " << gen(pos) + << "; initgen: " << initial_state.Get(static_cast<int>(key(pos))); + + // Advance to next key in the valid key space + if (key(pos) < key(current)) { + pos = MakeKey(key(pos) + 1, 0); + } else { + pos = MakeKey(key(pos), gen(pos) + 1); + } + } + + if (!iter.Valid()) { + break; + } + + if (rnd->Next() % 2) { + iter.Next(); + pos = MakeKey(key(pos), gen(pos) + 1); + } else { + Key new_target = RandomTarget(rnd); + if (new_target > pos) { + pos = new_target; + iter.Seek(new_target); + } + } + } + } +}; +const uint32_t ConcurrentTest::K; + +// Simple test that does single-threaded testing of the ConcurrentTest +// scaffolding. +TEST_F(SkipTest, ConcurrentWithoutThreads) { + ConcurrentTest test; + Random rnd(test::RandomSeed()); + for (int i = 0; i < 10000; i++) { + test.ReadStep(&rnd); + test.WriteStep(&rnd); + } +} + +class TestState { + public: + ConcurrentTest t_; + int seed_; + std::atomic<bool> quit_flag_; + + enum ReaderState { + STARTING, + RUNNING, + DONE + }; + + explicit TestState(int s) + : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {} + + void Wait(ReaderState s) { + mu_.Lock(); + while (state_ != s) { + state_cv_.Wait(); + } + mu_.Unlock(); + } + + void Change(ReaderState s) { + mu_.Lock(); + state_ = s; + state_cv_.Signal(); + mu_.Unlock(); + } + + private: + port::Mutex mu_; + ReaderState state_; + port::CondVar state_cv_; +}; + +static void ConcurrentReader(void* arg) { + TestState* state = reinterpret_cast<TestState*>(arg); + Random rnd(state->seed_); + int64_t reads = 0; + state->Change(TestState::RUNNING); + while (!state->quit_flag_.load(std::memory_order_acquire)) { + state->t_.ReadStep(&rnd); + ++reads; + } + state->Change(TestState::DONE); +} + +static void RunConcurrent(int run) { + const int seed = test::RandomSeed() + (run * 100); + Random rnd(seed); + const int N = 1000; + const int kSize = 1000; + for (int i = 0; i < N; i++) { + if ((i % 100) == 0) { + fprintf(stderr, "Run %d of %d\n", i, N); + } + TestState state(seed + 1); + Env::Default()->SetBackgroundThreads(1); + Env::Default()->Schedule(ConcurrentReader, &state); + state.Wait(TestState::RUNNING); + for (int k = 0; k < kSize; k++) { + state.t_.WriteStep(&rnd); + } + state.quit_flag_.store(true, std::memory_order_release); + state.Wait(TestState::DONE); + } +} + +TEST_F(SkipTest, Concurrent1) { RunConcurrent(1); } +TEST_F(SkipTest, Concurrent2) { RunConcurrent(2); } +TEST_F(SkipTest, Concurrent3) { RunConcurrent(3); } +TEST_F(SkipTest, Concurrent4) { RunConcurrent(4); } +TEST_F(SkipTest, Concurrent5) { RunConcurrent(5); } + +} // namespace rocksdb + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/rocksdb/memtable/skiplistrep.cc b/src/rocksdb/memtable/skiplistrep.cc new file mode 100644 index 00000000..32870b12 --- /dev/null +++ b/src/rocksdb/memtable/skiplistrep.cc @@ -0,0 +1,271 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +#include "memtable/inlineskiplist.h" +#include "db/memtable.h" +#include "rocksdb/memtablerep.h" +#include "util/arena.h" + +namespace rocksdb { +namespace { +class SkipListRep : public MemTableRep { + InlineSkipList<const MemTableRep::KeyComparator&> skip_list_; + const MemTableRep::KeyComparator& cmp_; + const SliceTransform* transform_; + const size_t lookahead_; + + friend class LookaheadIterator; +public: + explicit SkipListRep(const MemTableRep::KeyComparator& compare, + Allocator* allocator, const SliceTransform* transform, + const size_t lookahead) + : MemTableRep(allocator), + skip_list_(compare, allocator), + cmp_(compare), + transform_(transform), + lookahead_(lookahead) {} + + KeyHandle Allocate(const size_t len, char** buf) override { + *buf = skip_list_.AllocateKey(len); + return static_cast<KeyHandle>(*buf); + } + + // Insert key into the list. + // REQUIRES: nothing that compares equal to key is currently in the list. + void Insert(KeyHandle handle) override { + skip_list_.Insert(static_cast<char*>(handle)); + } + + bool InsertKey(KeyHandle handle) override { + return skip_list_.Insert(static_cast<char*>(handle)); + } + + void InsertWithHint(KeyHandle handle, void** hint) override { + skip_list_.InsertWithHint(static_cast<char*>(handle), hint); + } + + bool InsertKeyWithHint(KeyHandle handle, void** hint) override { + return skip_list_.InsertWithHint(static_cast<char*>(handle), hint); + } + + void InsertConcurrently(KeyHandle handle) override { + skip_list_.InsertConcurrently(static_cast<char*>(handle)); + } + + bool InsertKeyConcurrently(KeyHandle handle) override { + return skip_list_.InsertConcurrently(static_cast<char*>(handle)); + } + + // Returns true iff an entry that compares equal to key is in the list. + bool Contains(const char* key) const override { + return skip_list_.Contains(key); + } + + size_t ApproximateMemoryUsage() override { + // All memory is allocated through allocator; nothing to report here + return 0; + } + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override { + SkipListRep::Iterator iter(&skip_list_); + Slice dummy_slice; + for (iter.Seek(dummy_slice, k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) { + } + } + + uint64_t ApproximateNumEntries(const Slice& start_ikey, + const Slice& end_ikey) override { + std::string tmp; + uint64_t start_count = + skip_list_.EstimateCount(EncodeKey(&tmp, start_ikey)); + uint64_t end_count = skip_list_.EstimateCount(EncodeKey(&tmp, end_ikey)); + return (end_count >= start_count) ? (end_count - start_count) : 0; + } + + ~SkipListRep() override {} + + // Iteration over the contents of a skip list + class Iterator : public MemTableRep::Iterator { + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_; + + public: + // Initialize an iterator over the specified list. + // The returned iterator is not valid. + explicit Iterator( + const InlineSkipList<const MemTableRep::KeyComparator&>* list) + : iter_(list) {} + + ~Iterator() override {} + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override { return iter_.Valid(); } + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override { return iter_.key(); } + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override { iter_.Next(); } + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override { iter_.Prev(); } + + // Advance to the first entry with a key >= target + void Seek(const Slice& user_key, const char* memtable_key) override { + if (memtable_key != nullptr) { + iter_.Seek(memtable_key); + } else { + iter_.Seek(EncodeKey(&tmp_, user_key)); + } + } + + // Retreat to the last entry with a key <= target + void SeekForPrev(const Slice& user_key, const char* memtable_key) override { + if (memtable_key != nullptr) { + iter_.SeekForPrev(memtable_key); + } else { + iter_.SeekForPrev(EncodeKey(&tmp_, user_key)); + } + } + + // Position at the first entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToFirst() override { iter_.SeekToFirst(); } + + // Position at the last entry in list. + // Final state of iterator is Valid() iff list is not empty. + void SeekToLast() override { iter_.SeekToLast(); } + + protected: + std::string tmp_; // For passing to EncodeKey + }; + + // Iterator over the contents of a skip list which also keeps track of the + // previously visited node. In Seek(), it examines a few nodes after it + // first, falling back to O(log n) search from the head of the list only if + // the target key hasn't been found. + class LookaheadIterator : public MemTableRep::Iterator { + public: + explicit LookaheadIterator(const SkipListRep& rep) : + rep_(rep), iter_(&rep_.skip_list_), prev_(iter_) {} + + ~LookaheadIterator() override {} + + bool Valid() const override { return iter_.Valid(); } + + const char* key() const override { + assert(Valid()); + return iter_.key(); + } + + void Next() override { + assert(Valid()); + + bool advance_prev = true; + if (prev_.Valid()) { + auto k1 = rep_.UserKey(prev_.key()); + auto k2 = rep_.UserKey(iter_.key()); + + if (k1.compare(k2) == 0) { + // same user key, don't move prev_ + advance_prev = false; + } else if (rep_.transform_) { + // only advance prev_ if it has the same prefix as iter_ + auto t1 = rep_.transform_->Transform(k1); + auto t2 = rep_.transform_->Transform(k2); + advance_prev = t1.compare(t2) == 0; + } + } + + if (advance_prev) { + prev_ = iter_; + } + iter_.Next(); + } + + void Prev() override { + assert(Valid()); + iter_.Prev(); + prev_ = iter_; + } + + void Seek(const Slice& internal_key, const char* memtable_key) override { + const char *encoded_key = + (memtable_key != nullptr) ? + memtable_key : EncodeKey(&tmp_, internal_key); + + if (prev_.Valid() && rep_.cmp_(encoded_key, prev_.key()) >= 0) { + // prev_.key() is smaller or equal to our target key; do a quick + // linear search (at most lookahead_ steps) starting from prev_ + iter_ = prev_; + + size_t cur = 0; + while (cur++ <= rep_.lookahead_ && iter_.Valid()) { + if (rep_.cmp_(encoded_key, iter_.key()) <= 0) { + return; + } + Next(); + } + } + + iter_.Seek(encoded_key); + prev_ = iter_; + } + + void SeekForPrev(const Slice& internal_key, + const char* memtable_key) override { + const char* encoded_key = (memtable_key != nullptr) + ? memtable_key + : EncodeKey(&tmp_, internal_key); + iter_.SeekForPrev(encoded_key); + prev_ = iter_; + } + + void SeekToFirst() override { + iter_.SeekToFirst(); + prev_ = iter_; + } + + void SeekToLast() override { + iter_.SeekToLast(); + prev_ = iter_; + } + + protected: + std::string tmp_; // For passing to EncodeKey + + private: + const SkipListRep& rep_; + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator iter_; + InlineSkipList<const MemTableRep::KeyComparator&>::Iterator prev_; + }; + + MemTableRep::Iterator* GetIterator(Arena* arena = nullptr) override { + if (lookahead_ > 0) { + void *mem = + arena ? arena->AllocateAligned(sizeof(SkipListRep::LookaheadIterator)) + : operator new(sizeof(SkipListRep::LookaheadIterator)); + return new (mem) SkipListRep::LookaheadIterator(*this); + } else { + void *mem = + arena ? arena->AllocateAligned(sizeof(SkipListRep::Iterator)) + : operator new(sizeof(SkipListRep::Iterator)); + return new (mem) SkipListRep::Iterator(&skip_list_); + } + } +}; +} + +MemTableRep* SkipListFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform* transform, Logger* /*logger*/) { + return new SkipListRep(compare, allocator, transform, lookahead_); +} + +} // namespace rocksdb diff --git a/src/rocksdb/memtable/stl_wrappers.h b/src/rocksdb/memtable/stl_wrappers.h new file mode 100644 index 00000000..0287f4f8 --- /dev/null +++ b/src/rocksdb/memtable/stl_wrappers.h @@ -0,0 +1,33 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +#pragma once + +#include <map> +#include <string> + +#include "rocksdb/comparator.h" +#include "rocksdb/memtablerep.h" +#include "rocksdb/slice.h" +#include "util/coding.h" + +namespace rocksdb { +namespace stl_wrappers { + +class Base { + protected: + const MemTableRep::KeyComparator& compare_; + explicit Base(const MemTableRep::KeyComparator& compare) + : compare_(compare) {} +}; + +struct Compare : private Base { + explicit Compare(const MemTableRep::KeyComparator& compare) : Base(compare) {} + inline bool operator()(const char* a, const char* b) const { + return compare_(a, b) < 0; + } +}; + +} +} diff --git a/src/rocksdb/memtable/vectorrep.cc b/src/rocksdb/memtable/vectorrep.cc new file mode 100644 index 00000000..827ab8a5 --- /dev/null +++ b/src/rocksdb/memtable/vectorrep.cc @@ -0,0 +1,301 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +#ifndef ROCKSDB_LITE +#include "rocksdb/memtablerep.h" + +#include <unordered_set> +#include <set> +#include <memory> +#include <algorithm> +#include <type_traits> + +#include "util/arena.h" +#include "db/memtable.h" +#include "memtable/stl_wrappers.h" +#include "port/port.h" +#include "util/mutexlock.h" + +namespace rocksdb { +namespace { + +using namespace stl_wrappers; + +class VectorRep : public MemTableRep { + public: + VectorRep(const KeyComparator& compare, Allocator* allocator, size_t count); + + // Insert key into the collection. (The caller will pack key and value into a + // single buffer and pass that in as the parameter to Insert) + // REQUIRES: nothing that compares equal to key is currently in the + // collection. + void Insert(KeyHandle handle) override; + + // Returns true iff an entry that compares equal to key is in the collection. + bool Contains(const char* key) const override; + + void MarkReadOnly() override; + + size_t ApproximateMemoryUsage() override; + + void Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) override; + + ~VectorRep() override {} + + class Iterator : public MemTableRep::Iterator { + class VectorRep* vrep_; + std::shared_ptr<std::vector<const char*>> bucket_; + std::vector<const char*>::const_iterator mutable cit_; + const KeyComparator& compare_; + std::string tmp_; // For passing to EncodeKey + bool mutable sorted_; + void DoSort() const; + public: + explicit Iterator(class VectorRep* vrep, + std::shared_ptr<std::vector<const char*>> bucket, + const KeyComparator& compare); + + // Initialize an iterator over the specified collection. + // The returned iterator is not valid. + // explicit Iterator(const MemTableRep* collection); + ~Iterator() override{}; + + // Returns true iff the iterator is positioned at a valid node. + bool Valid() const override; + + // Returns the key at the current position. + // REQUIRES: Valid() + const char* key() const override; + + // Advances to the next position. + // REQUIRES: Valid() + void Next() override; + + // Advances to the previous position. + // REQUIRES: Valid() + void Prev() override; + + // Advance to the first entry with a key >= target + void Seek(const Slice& user_key, const char* memtable_key) override; + + // Advance to the first entry with a key <= target + void SeekForPrev(const Slice& user_key, const char* memtable_key) override; + + // Position at the first entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToFirst() override; + + // Position at the last entry in collection. + // Final state of iterator is Valid() iff collection is not empty. + void SeekToLast() override; + }; + + // Return an iterator over the keys in this representation. + MemTableRep::Iterator* GetIterator(Arena* arena) override; + + private: + friend class Iterator; + typedef std::vector<const char*> Bucket; + std::shared_ptr<Bucket> bucket_; + mutable port::RWMutex rwlock_; + bool immutable_; + bool sorted_; + const KeyComparator& compare_; +}; + +void VectorRep::Insert(KeyHandle handle) { + auto* key = static_cast<char*>(handle); + WriteLock l(&rwlock_); + assert(!immutable_); + bucket_->push_back(key); +} + +// Returns true iff an entry that compares equal to key is in the collection. +bool VectorRep::Contains(const char* key) const { + ReadLock l(&rwlock_); + return std::find(bucket_->begin(), bucket_->end(), key) != bucket_->end(); +} + +void VectorRep::MarkReadOnly() { + WriteLock l(&rwlock_); + immutable_ = true; +} + +size_t VectorRep::ApproximateMemoryUsage() { + return + sizeof(bucket_) + sizeof(*bucket_) + + bucket_->size() * + sizeof( + std::remove_reference<decltype(*bucket_)>::type::value_type + ); +} + +VectorRep::VectorRep(const KeyComparator& compare, Allocator* allocator, + size_t count) + : MemTableRep(allocator), + bucket_(new Bucket()), + immutable_(false), + sorted_(false), + compare_(compare) { + bucket_.get()->reserve(count); +} + +VectorRep::Iterator::Iterator(class VectorRep* vrep, + std::shared_ptr<std::vector<const char*>> bucket, + const KeyComparator& compare) +: vrep_(vrep), + bucket_(bucket), + cit_(bucket_->end()), + compare_(compare), + sorted_(false) { } + +void VectorRep::Iterator::DoSort() const { + // vrep is non-null means that we are working on an immutable memtable + if (!sorted_ && vrep_ != nullptr) { + WriteLock l(&vrep_->rwlock_); + if (!vrep_->sorted_) { + std::sort(bucket_->begin(), bucket_->end(), Compare(compare_)); + cit_ = bucket_->begin(); + vrep_->sorted_ = true; + } + sorted_ = true; + } + if (!sorted_) { + std::sort(bucket_->begin(), bucket_->end(), Compare(compare_)); + cit_ = bucket_->begin(); + sorted_ = true; + } + assert(sorted_); + assert(vrep_ == nullptr || vrep_->sorted_); +} + +// Returns true iff the iterator is positioned at a valid node. +bool VectorRep::Iterator::Valid() const { + DoSort(); + return cit_ != bucket_->end(); +} + +// Returns the key at the current position. +// REQUIRES: Valid() +const char* VectorRep::Iterator::key() const { + assert(sorted_); + return *cit_; +} + +// Advances to the next position. +// REQUIRES: Valid() +void VectorRep::Iterator::Next() { + assert(sorted_); + if (cit_ == bucket_->end()) { + return; + } + ++cit_; +} + +// Advances to the previous position. +// REQUIRES: Valid() +void VectorRep::Iterator::Prev() { + assert(sorted_); + if (cit_ == bucket_->begin()) { + // If you try to go back from the first element, the iterator should be + // invalidated. So we set it to past-the-end. This means that you can + // treat the container circularly. + cit_ = bucket_->end(); + } else { + --cit_; + } +} + +// Advance to the first entry with a key >= target +void VectorRep::Iterator::Seek(const Slice& user_key, + const char* memtable_key) { + DoSort(); + // Do binary search to find first value not less than the target + const char* encoded_key = + (memtable_key != nullptr) ? memtable_key : EncodeKey(&tmp_, user_key); + cit_ = std::equal_range(bucket_->begin(), + bucket_->end(), + encoded_key, + [this] (const char* a, const char* b) { + return compare_(a, b) < 0; + }).first; +} + +// Advance to the first entry with a key <= target +void VectorRep::Iterator::SeekForPrev(const Slice& /*user_key*/, + const char* /*memtable_key*/) { + assert(false); +} + +// Position at the first entry in collection. +// Final state of iterator is Valid() iff collection is not empty. +void VectorRep::Iterator::SeekToFirst() { + DoSort(); + cit_ = bucket_->begin(); +} + +// Position at the last entry in collection. +// Final state of iterator is Valid() iff collection is not empty. +void VectorRep::Iterator::SeekToLast() { + DoSort(); + cit_ = bucket_->end(); + if (bucket_->size() != 0) { + --cit_; + } +} + +void VectorRep::Get(const LookupKey& k, void* callback_args, + bool (*callback_func)(void* arg, const char* entry)) { + rwlock_.ReadLock(); + VectorRep* vector_rep; + std::shared_ptr<Bucket> bucket; + if (immutable_) { + vector_rep = this; + } else { + vector_rep = nullptr; + bucket.reset(new Bucket(*bucket_)); // make a copy + } + VectorRep::Iterator iter(vector_rep, immutable_ ? bucket_ : bucket, compare_); + rwlock_.ReadUnlock(); + + for (iter.Seek(k.user_key(), k.memtable_key().data()); + iter.Valid() && callback_func(callback_args, iter.key()); iter.Next()) { + } +} + +MemTableRep::Iterator* VectorRep::GetIterator(Arena* arena) { + char* mem = nullptr; + if (arena != nullptr) { + mem = arena->AllocateAligned(sizeof(Iterator)); + } + ReadLock l(&rwlock_); + // Do not sort here. The sorting would be done the first time + // a Seek is performed on the iterator. + if (immutable_) { + if (arena == nullptr) { + return new Iterator(this, bucket_, compare_); + } else { + return new (mem) Iterator(this, bucket_, compare_); + } + } else { + std::shared_ptr<Bucket> tmp; + tmp.reset(new Bucket(*bucket_)); // make a copy + if (arena == nullptr) { + return new Iterator(nullptr, tmp, compare_); + } else { + return new (mem) Iterator(nullptr, tmp, compare_); + } + } +} +} // anon namespace + +MemTableRep* VectorRepFactory::CreateMemTableRep( + const MemTableRep::KeyComparator& compare, Allocator* allocator, + const SliceTransform*, Logger* /*logger*/) { + return new VectorRep(compare, allocator, count_); +} +} // namespace rocksdb +#endif // ROCKSDB_LITE diff --git a/src/rocksdb/memtable/write_buffer_manager.cc b/src/rocksdb/memtable/write_buffer_manager.cc new file mode 100644 index 00000000..7f2e664a --- /dev/null +++ b/src/rocksdb/memtable/write_buffer_manager.cc @@ -0,0 +1,130 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "rocksdb/write_buffer_manager.h" +#include <mutex> +#include "util/coding.h" + +namespace rocksdb { +#ifndef ROCKSDB_LITE +namespace { +const size_t kSizeDummyEntry = 1024 * 1024; +// The key will be longer than keys for blocks in SST files so they won't +// conflict. +const size_t kCacheKeyPrefix = kMaxVarint64Length * 4 + 1; +} // namespace + +struct WriteBufferManager::CacheRep { + std::shared_ptr<Cache> cache_; + std::mutex cache_mutex_; + std::atomic<size_t> cache_allocated_size_; + // The non-prefix part will be updated according to the ID to use. + char cache_key_[kCacheKeyPrefix + kMaxVarint64Length]; + uint64_t next_cache_key_id_ = 0; + std::vector<Cache::Handle*> dummy_handles_; + + explicit CacheRep(std::shared_ptr<Cache> cache) + : cache_(cache), cache_allocated_size_(0) { + memset(cache_key_, 0, kCacheKeyPrefix); + size_t pointer_size = sizeof(const void*); + assert(pointer_size <= kCacheKeyPrefix); + memcpy(cache_key_, static_cast<const void*>(this), pointer_size); + } + + Slice GetNextCacheKey() { + memset(cache_key_ + kCacheKeyPrefix, 0, kMaxVarint64Length); + char* end = + EncodeVarint64(cache_key_ + kCacheKeyPrefix, next_cache_key_id_++); + return Slice(cache_key_, static_cast<size_t>(end - cache_key_)); + } +}; +#else +struct WriteBufferManager::CacheRep {}; +#endif // ROCKSDB_LITE + +WriteBufferManager::WriteBufferManager(size_t _buffer_size, + std::shared_ptr<Cache> cache) + : buffer_size_(_buffer_size), + mutable_limit_(buffer_size_ * 7 / 8), + memory_used_(0), + memory_active_(0), + cache_rep_(nullptr) { +#ifndef ROCKSDB_LITE + if (cache) { + // Construct the cache key using the pointer to this. + cache_rep_.reset(new CacheRep(cache)); + } +#else + (void)cache; +#endif // ROCKSDB_LITE +} + +WriteBufferManager::~WriteBufferManager() { +#ifndef ROCKSDB_LITE + if (cache_rep_) { + for (auto* handle : cache_rep_->dummy_handles_) { + cache_rep_->cache_->Release(handle, true); + } + } +#endif // ROCKSDB_LITE +} + +// Should only be called from write thread +void WriteBufferManager::ReserveMemWithCache(size_t mem) { +#ifndef ROCKSDB_LITE + assert(cache_rep_ != nullptr); + // Use a mutex to protect various data structures. Can be optimized to a + // lock-free solution if it ends up with a performance bottleneck. + std::lock_guard<std::mutex> lock(cache_rep_->cache_mutex_); + + size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) + mem; + memory_used_.store(new_mem_used, std::memory_order_relaxed); + while (new_mem_used > cache_rep_->cache_allocated_size_) { + // Expand size by at least 1MB. + // Add a dummy record to the cache + Cache::Handle* handle; + cache_rep_->cache_->Insert(cache_rep_->GetNextCacheKey(), nullptr, + kSizeDummyEntry, nullptr, &handle); + cache_rep_->dummy_handles_.push_back(handle); + cache_rep_->cache_allocated_size_ += kSizeDummyEntry; + } +#else + (void)mem; +#endif // ROCKSDB_LITE +} + +void WriteBufferManager::FreeMemWithCache(size_t mem) { +#ifndef ROCKSDB_LITE + assert(cache_rep_ != nullptr); + // Use a mutex to protect various data structures. Can be optimized to a + // lock-free solution if it ends up with a performance bottleneck. + std::lock_guard<std::mutex> lock(cache_rep_->cache_mutex_); + size_t new_mem_used = memory_used_.load(std::memory_order_relaxed) - mem; + memory_used_.store(new_mem_used, std::memory_order_relaxed); + // Gradually shrink memory costed in the block cache if the actual + // usage is less than 3/4 of what we reserve from the block cache. + // We do this because: + // 1. we don't pay the cost of the block cache immediately a memtable is + // freed, as block cache insert is expensive; + // 2. eventually, if we walk away from a temporary memtable size increase, + // we make sure shrink the memory costed in block cache over time. + // In this way, we only shrink costed memory showly even there is enough + // margin. + if (new_mem_used < cache_rep_->cache_allocated_size_ / 4 * 3 && + cache_rep_->cache_allocated_size_ - kSizeDummyEntry > new_mem_used) { + assert(!cache_rep_->dummy_handles_.empty()); + cache_rep_->cache_->Release(cache_rep_->dummy_handles_.back(), true); + cache_rep_->dummy_handles_.pop_back(); + cache_rep_->cache_allocated_size_ -= kSizeDummyEntry; + } +#else + (void)mem; +#endif // ROCKSDB_LITE +} +} // namespace rocksdb diff --git a/src/rocksdb/memtable/write_buffer_manager_test.cc b/src/rocksdb/memtable/write_buffer_manager_test.cc new file mode 100644 index 00000000..0fc9fd06 --- /dev/null +++ b/src/rocksdb/memtable/write_buffer_manager_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#include "rocksdb/write_buffer_manager.h" +#include "util/testharness.h" + +namespace rocksdb { + +class WriteBufferManagerTest : public testing::Test {}; + +#ifndef ROCKSDB_LITE +TEST_F(WriteBufferManagerTest, ShouldFlush) { + // A write buffer manager of size 10MB + std::unique_ptr<WriteBufferManager> wbf( + new WriteBufferManager(10 * 1024 * 1024)); + + wbf->ReserveMem(8 * 1024 * 1024); + ASSERT_FALSE(wbf->ShouldFlush()); + // 90% of the hard limit will hit the condition + wbf->ReserveMem(1 * 1024 * 1024); + ASSERT_TRUE(wbf->ShouldFlush()); + // Scheduling for freeing will release the condition + wbf->ScheduleFreeMem(1 * 1024 * 1024); + ASSERT_FALSE(wbf->ShouldFlush()); + + wbf->ReserveMem(2 * 1024 * 1024); + ASSERT_TRUE(wbf->ShouldFlush()); + + wbf->ScheduleFreeMem(4 * 1024 * 1024); + // 11MB total, 6MB mutable. hard limit still hit + ASSERT_TRUE(wbf->ShouldFlush()); + + wbf->ScheduleFreeMem(2 * 1024 * 1024); + // 11MB total, 4MB mutable. hard limit stills but won't flush because more + // than half data is already being flushed. + ASSERT_FALSE(wbf->ShouldFlush()); + + wbf->ReserveMem(4 * 1024 * 1024); + // 15 MB total, 8MB mutable. + ASSERT_TRUE(wbf->ShouldFlush()); + + wbf->FreeMem(7 * 1024 * 1024); + // 9MB total, 8MB mutable. + ASSERT_FALSE(wbf->ShouldFlush()); +} + +TEST_F(WriteBufferManagerTest, CacheCost) { + // 1GB cache + std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4); + // A write buffer manager of size 50MB + std::unique_ptr<WriteBufferManager> wbf( + new WriteBufferManager(50 * 1024 * 1024, cache)); + + // Allocate 1.5MB will allocate 2MB + wbf->ReserveMem(1536 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 2 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 2 * 1024 * 1024 + 10000); + + // Allocate another 2MB + wbf->ReserveMem(2 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 4 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 4 * 1024 * 1024 + 10000); + + // Allocate another 20MB + wbf->ReserveMem(20 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 24 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 24 * 1024 * 1024 + 10000); + + // Free 2MB will not cause any change in cache cost + wbf->FreeMem(2 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 24 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 24 * 1024 * 1024 + 10000); + + ASSERT_FALSE(wbf->ShouldFlush()); + + // Allocate another 30MB + wbf->ReserveMem(30 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 52 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 52 * 1024 * 1024 + 10000); + ASSERT_TRUE(wbf->ShouldFlush()); + + ASSERT_TRUE(wbf->ShouldFlush()); + + wbf->ScheduleFreeMem(20 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 52 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 52 * 1024 * 1024 + 10000); + + // Still need flush as the hard limit hits + ASSERT_TRUE(wbf->ShouldFlush()); + + // Free 20MB will releae 1MB from cache + wbf->FreeMem(20 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 51 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 51 * 1024 * 1024 + 10000); + + ASSERT_FALSE(wbf->ShouldFlush()); + + // Every free will release 1MB if still not hit 3/4 + wbf->FreeMem(16 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 50 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 50 * 1024 * 1024 + 10000); + + wbf->FreeMem(16 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 49 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 49 * 1024 * 1024 + 10000); + + // Free 2MB will not cause any change in cache cost + wbf->ReserveMem(2 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 49 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 49 * 1024 * 1024 + 10000); + + wbf->FreeMem(16 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 48 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 48 * 1024 * 1024 + 10000); + + // Destory write buffer manger should free everything + wbf.reset(); + ASSERT_LT(cache->GetPinnedUsage(), 1024 * 1024); +} + +TEST_F(WriteBufferManagerTest, NoCapCacheCost) { + // 1GB cache + std::shared_ptr<Cache> cache = NewLRUCache(1024 * 1024 * 1024, 4); + // A write buffer manager of size 256MB + std::unique_ptr<WriteBufferManager> wbf(new WriteBufferManager(0, cache)); + // Allocate 1.5MB will allocate 2MB + wbf->ReserveMem(10 * 1024 * 1024); + ASSERT_GE(cache->GetPinnedUsage(), 10 * 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 10 * 1024 * 1024 + 10000); + ASSERT_FALSE(wbf->ShouldFlush()); + + wbf->FreeMem(9 * 1024 * 1024); + for (int i = 0; i < 10; i++) { + wbf->FreeMem(16 * 1024); + } + ASSERT_GE(cache->GetPinnedUsage(), 1024 * 1024); + ASSERT_LT(cache->GetPinnedUsage(), 1024 * 1024 + 10000); +} +#endif // ROCKSDB_LITE +} // namespace rocksdb + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |