diff options
Diffstat (limited to 'src/rocksdb/utilities/agg_merge')
-rw-r--r-- | src/rocksdb/utilities/agg_merge/agg_merge.cc | 238 | ||||
-rw-r--r-- | src/rocksdb/utilities/agg_merge/agg_merge.h | 49 | ||||
-rw-r--r-- | src/rocksdb/utilities/agg_merge/agg_merge_test.cc | 135 | ||||
-rw-r--r-- | src/rocksdb/utilities/agg_merge/test_agg_merge.cc | 104 | ||||
-rw-r--r-- | src/rocksdb/utilities/agg_merge/test_agg_merge.h | 47 |
5 files changed, 573 insertions, 0 deletions
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.cc b/src/rocksdb/utilities/agg_merge/agg_merge.cc new file mode 100644 index 000000000..a7eab1f12 --- /dev/null +++ b/src/rocksdb/utilities/agg_merge/agg_merge.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/agg_merge/agg_merge.h" + +#include <assert.h> + +#include <deque> +#include <memory> +#include <type_traits> +#include <utility> +#include <vector> + +#include "port/lang.h" +#include "port/likely.h" +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "rocksdb/utilities/options_type.h" +#include "util/coding.h" +#include "utilities/merge_operators.h" + +namespace ROCKSDB_NAMESPACE { +static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map; +const std::string kUnnamedFuncName = ""; +const std::string kErrorFuncName = "kErrorFuncName"; + +Status AddAggregator(const std::string& function_name, + std::unique_ptr<Aggregator>&& agg) { + if (function_name == kErrorFuncName) { + return Status::InvalidArgument( + "Cannot register function name kErrorFuncName"); + } + func_map.emplace(function_name, std::move(agg)); + return Status::OK(); +} + +AggMergeOperator::AggMergeOperator() {} + +std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name, + const Slice& value) { + std::string result; + PutLengthPrefixedSlice(&result, function_name); + result += value.ToString(); + return result; +} + +Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload, + std::string& output) { + if (function_name == kErrorFuncName) { + return Status::InvalidArgument("Cannot use error function name"); + } + if (function_name != kUnnamedFuncName && + func_map.find(function_name.ToString()) == func_map.end()) { + return Status::InvalidArgument("Function name not registered"); + } + output = EncodeAggFuncAndPayloadNoCheck(function_name, payload); + return Status::OK(); +} + +bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) { + value = op; + return GetLengthPrefixedSlice(&value, &func); +} + +bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) { + decoded_list.clear(); + Slice list_slice = encoded_list; + Slice item; + while (GetLengthPrefixedSlice(&list_slice, &item)) { + decoded_list.push_back(item); + } + return list_slice.empty(); +} + +class AggMergeOperator::Accumulator { + public: + bool Add(const Slice& op, bool is_partial_aggregation) { + if (ignore_operands_) { + return true; + } + Slice my_func; + Slice my_value; + bool ret = ExtractAggFuncAndValue(op, my_func, my_value); + if (!ret) { + ignore_operands_ = true; + return true; + } + + // Determine whether we need to do partial merge. + if (is_partial_aggregation && !my_func.empty()) { + auto f = func_map.find(my_func.ToString()); + if (f == func_map.end() || !f->second->DoPartialAggregate()) { + return false; + } + } + + if (!func_valid_) { + if (my_func != kUnnamedFuncName) { + func_ = my_func; + func_valid_ = true; + } + } else if (func_ != my_func) { + // User switched aggregation function. Need to aggregate the older + // one first. + + // Previous aggreagion can't be done in partial merge + if (is_partial_aggregation) { + func_valid_ = false; + ignore_operands_ = true; + return false; + } + + // We could consider stashing an iterator into the hash of aggregators + // to avoid repeated lookups when the aggregator doesn't change. + auto f = func_map.find(func_.ToString()); + if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) { + func_valid_ = false; + ignore_operands_ = true; + return true; + } + std::swap(scratch_, aggregated_); + values_.clear(); + values_.push_back(aggregated_); + func_ = my_func; + } + values_.push_back(my_value); + return true; + } + + // Return false if aggregation fails. + // One possible reason + bool GetResult(std::string& result) { + if (!func_valid_) { + return false; + } + auto f = func_map.find(func_.ToString()); + if (f == func_map.end()) { + return false; + } + if (!f->second->Aggregate(values_, scratch_)) { + return false; + } + result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_); + return true; + } + + void Clear() { + func_.clear(); + values_.clear(); + aggregated_.clear(); + scratch_.clear(); + ignore_operands_ = false; + func_valid_ = false; + } + + private: + Slice func_; + std::vector<Slice> values_; + std::string aggregated_; + std::string scratch_; + bool ignore_operands_ = false; + bool func_valid_ = false; +}; + +// Creating and using a new Accumulator might invoke multiple malloc and is +// expensive if it needs to be done when processing each merge operation. +// AggMergeOperator's merge operators can be invoked concurrently by multiple +// threads so we cannot simply create one Aggregator and reuse. +// We use thread local instances instead. +AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() { + static thread_local Accumulator tls_acc; + tls_acc.Clear(); + return tls_acc; +} + +void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in, + MergeOperationOutput& merge_out) { + merge_out.new_value = ""; + PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName); + if (merge_in.existing_value != nullptr) { + PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value); + } + for (const Slice& op : merge_in.operand_list) { + PutLengthPrefixedSlice(&merge_out.new_value, op); + } +} + +bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in, + MergeOperationOutput* merge_out) const { + Accumulator& agg = GetTLSAccumulator(); + if (merge_in.existing_value != nullptr) { + agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false); + } + for (const Slice& e : merge_in.operand_list) { + agg.Add(e, /*is_partial_aggregation=*/false); + } + + bool succ = agg.GetResult(merge_out->new_value); + if (!succ) { + // If aggregation can't happen, pack all merge operands. In contrast to + // merge operator, we don't want to fail the DB. If users insert wrong + // format or call unregistered an aggregation function, we still hope + // the DB can continue functioning with other keys. + PackAllMergeOperands(merge_in, *merge_out); + } + agg.Clear(); + return true; +} + +bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/, + const std::deque<Slice>& operand_list, + std::string* new_value, + Logger* /*logger*/) const { + Accumulator& agg = GetTLSAccumulator(); + bool do_aggregation = true; + for (const Slice& item : operand_list) { + do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true); + if (!do_aggregation) { + break; + } + } + if (do_aggregation) { + do_aggregation = agg.GetResult(*new_value); + } + agg.Clear(); + return do_aggregation; +} + +std::shared_ptr<MergeOperator> GetAggMergeOperator() { + STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance) + (std::make_shared<AggMergeOperator>()); + assert(instance); + return instance; +} +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.h b/src/rocksdb/utilities/agg_merge/agg_merge.h new file mode 100644 index 000000000..00e58de08 --- /dev/null +++ b/src/rocksdb/utilities/agg_merge/agg_merge.h @@ -0,0 +1,49 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once +#include <algorithm> +#include <cstddef> +#include <memory> +#include <unordered_map> + +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "utilities/cassandra/cassandra_options.h" + +namespace ROCKSDB_NAMESPACE { +class AggMergeOperator : public MergeOperator { + public: + explicit AggMergeOperator(); + + bool FullMergeV2(const MergeOperationInput& merge_in, + MergeOperationOutput* merge_out) const override; + + bool PartialMergeMulti(const Slice& key, + const std::deque<Slice>& operand_list, + std::string* new_value, Logger* logger) const override; + + const char* Name() const override { return kClassName(); } + static const char* kClassName() { return "AggMergeOperator.v1"; } + + bool AllowSingleOperand() const override { return true; } + + bool ShouldMerge(const std::vector<Slice>&) const override { return false; } + + private: + class Accumulator; + + // Pack all merge operands into one value. This is called when aggregation + // fails. The existing values are preserved and returned so that users can + // debug the problem. + static void PackAllMergeOperands(const MergeOperationInput& merge_in, + MergeOperationOutput& merge_out); + static Accumulator& GetTLSAccumulator(); +}; + +extern std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name, + const Slice& value); +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/utilities/agg_merge/agg_merge_test.cc b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc new file mode 100644 index 000000000..a65441cd0 --- /dev/null +++ b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "rocksdb/utilities/agg_merge.h" + +#include <gtest/gtest.h> + +#include <memory> + +#include "db/db_test_util.h" +#include "rocksdb/options.h" +#include "test_util/testharness.h" +#include "utilities/agg_merge/agg_merge.h" +#include "utilities/agg_merge/test_agg_merge.h" + +namespace ROCKSDB_NAMESPACE { + +class AggMergeTest : public DBTestBase { + public: + AggMergeTest() : DBTestBase("agg_merge_db_test", /*env_do_fsync=*/true) {} +}; + +TEST_F(AggMergeTest, TestUsingMergeOperator) { + ASSERT_OK(AddAggregator("sum", std::make_unique<SumAggregator>())); + ASSERT_OK(AddAggregator("last3", std::make_unique<Last3Aggregator>())); + ASSERT_OK(AddAggregator("mul", std::make_unique<MultipleAggregator>())); + + Options options = CurrentOptions(); + options.merge_operator = GetAggMergeOperator(); + Reopen(options); + std::string v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("foo", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 15); + ASSERT_OK(Merge("foo", v)); + + v = EncodeHelper::EncodeFuncAndList("last3", {"a", "b"}); + ASSERT_OK(Merge("bar", v)); + v = EncodeHelper::EncodeFuncAndList("last3", {"c", "d", "e"}); + ASSERT_OK(Merge("bar", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndList("last3", {"f"}); + ASSERT_OK(Merge("bar", v)); + + // Test Put() without aggregation type. + v = EncodeHelper::EncodeFuncAndInt(kUnnamedFuncName, 30); + ASSERT_OK(Put("foo2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("foo2", v)); + + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 45), Get("foo")); + EXPECT_EQ(EncodeHelper::EncodeFuncAndList("last3", {"f", "c", "d"}), + Get("bar")); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 60), Get("foo2")); + + // Test changing aggregation type + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Put("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("mul", 20); + ASSERT_OK(Merge("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 30); + ASSERT_OK(Merge("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 40); + ASSERT_OK(Merge("bar2", v)); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 20 + 30 + 40), + Get("bar2")); + + // Changing aggregation type with partial merge + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo3", v)); + ASSERT_OK(Flush()); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 10 * 10 + 10), + Get("foo3")); + + // Merge after full merge + v = EncodeHelper::EncodeFuncAndInt("sum", 1); + ASSERT_OK(Merge("foo4", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 2); + ASSERT_OK(Merge("foo4", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndInt("sum", 3); + ASSERT_OK(Merge("foo4", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 4); + ASSERT_OK(Merge("foo4", v)); + ASSERT_OK(Flush()); + ASSERT_OK(db_->CompactRange(CompactRangeOptions(), nullptr, nullptr)); + v = EncodeHelper::EncodeFuncAndInt("sum", 5); + ASSERT_OK(Merge("foo4", v)); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 15), Get("foo4")); + + // Test unregistered function name + v = EncodeAggFuncAndPayloadNoCheck("non_existing", "1"); + ASSERT_OK(Merge("bar3", v)); + std::string v1; + v1 = EncodeAggFuncAndPayloadNoCheck("non_existing", "invalid"); + ; + ASSERT_OK(Merge("bar3", v1)); + EXPECT_EQ(EncodeAggFuncAndPayloadNoCheck(kErrorFuncName, + EncodeHelper::EncodeList({v, v1})), + Get("bar3")); + + // invalidate input + ASSERT_OK(EncodeAggFuncAndPayload("sum", "invalid", v)); + ASSERT_OK(Merge("bar4", v)); + v1 = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("bar4", v1)); + std::string aggregated_value = Get("bar4"); + Slice func, payload; + ASSERT_TRUE(ExtractAggFuncAndValue(aggregated_value, func, payload)); + EXPECT_EQ(kErrorFuncName, func); + std::vector<Slice> decoded_list; + ASSERT_TRUE(ExtractList(payload, decoded_list)); + ASSERT_EQ(2, decoded_list.size()); + ASSERT_EQ(v, decoded_list[0]); + ASSERT_EQ(v1, decoded_list[1]); +} +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.cc b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc new file mode 100644 index 000000000..06e5b5697 --- /dev/null +++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "test_agg_merge.h" + +#include <assert.h> + +#include <deque> +#include <vector> + +#include "util/coding.h" +#include "utilities/agg_merge/agg_merge.h" + +namespace ROCKSDB_NAMESPACE { + +std::string EncodeHelper::EncodeFuncAndInt(const Slice& function_name, + int64_t value) { + std::string encoded_value; + PutVarsignedint64(&encoded_value, value); + std::string ret; + Status s = EncodeAggFuncAndPayload(function_name, encoded_value, ret); + assert(s.ok()); + return ret; +} + +std::string EncodeHelper::EncodeInt(int64_t value) { + std::string encoded_value; + PutVarsignedint64(&encoded_value, value); + return encoded_value; +} + +std::string EncodeHelper::EncodeFuncAndList(const Slice& function_name, + const std::vector<Slice>& list) { + std::string ret; + Status s = EncodeAggFuncAndPayload(function_name, EncodeList(list), ret); + assert(s.ok()); + return ret; +} + +std::string EncodeHelper::EncodeList(const std::vector<Slice>& list) { + std::string result; + for (const Slice& entity : list) { + PutLengthPrefixedSlice(&result, entity); + } + return result; +} + +bool SumAggregator::Aggregate(const std::vector<Slice>& item_list, + std::string& result) const { + int64_t sum = 0; + for (const Slice& item : item_list) { + int64_t ivalue; + Slice v = item; + if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) { + return false; + } + sum += ivalue; + } + result = EncodeHelper::EncodeInt(sum); + return true; +} + +bool MultipleAggregator::Aggregate(const std::vector<Slice>& item_list, + std::string& result) const { + int64_t mresult = 1; + for (const Slice& item : item_list) { + int64_t ivalue; + Slice v = item; + if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) { + return false; + } + mresult *= ivalue; + } + result = EncodeHelper::EncodeInt(mresult); + return true; +} + +bool Last3Aggregator::Aggregate(const std::vector<Slice>& item_list, + std::string& result) const { + std::vector<Slice> last3; + last3.reserve(3); + for (auto it = item_list.rbegin(); it != item_list.rend(); ++it) { + Slice input = *it; + Slice entity; + bool ret; + while ((ret = GetLengthPrefixedSlice(&input, &entity)) == true) { + last3.push_back(entity); + if (last3.size() >= 3) { + break; + } + } + if (last3.size() >= 3) { + break; + } + if (!ret) { + continue; + } + } + result = EncodeHelper::EncodeList(last3); + return true; +} +} // namespace ROCKSDB_NAMESPACE diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.h b/src/rocksdb/utilities/agg_merge/test_agg_merge.h new file mode 100644 index 000000000..5bdf8b9cc --- /dev/null +++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.h @@ -0,0 +1,47 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once +#include <algorithm> +#include <cstddef> +#include <memory> +#include <unordered_map> + +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "utilities/cassandra/cassandra_options.h" + +namespace ROCKSDB_NAMESPACE { +class SumAggregator : public Aggregator { + public: + ~SumAggregator() override {} + bool Aggregate(const std::vector<Slice>&, std::string& result) const override; + bool DoPartialAggregate() const override { return true; } +}; + +class MultipleAggregator : public Aggregator { + public: + ~MultipleAggregator() override {} + bool Aggregate(const std::vector<Slice>&, std::string& result) const override; + bool DoPartialAggregate() const override { return true; } +}; + +class Last3Aggregator : public Aggregator { + public: + ~Last3Aggregator() override {} + bool Aggregate(const std::vector<Slice>&, std::string& result) const override; +}; + +class EncodeHelper { + public: + static std::string EncodeFuncAndInt(const Slice& function_name, + int64_t value); + static std::string EncodeInt(int64_t value); + static std::string EncodeList(const std::vector<Slice>& list); + static std::string EncodeFuncAndList(const Slice& function_name, + const std::vector<Slice>& list); +}; +} // namespace ROCKSDB_NAMESPACE |