summaryrefslogtreecommitdiffstats
path: root/src/rocksdb/utilities/agg_merge
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge.cc238
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge.h49
-rw-r--r--src/rocksdb/utilities/agg_merge/agg_merge_test.cc135
-rw-r--r--src/rocksdb/utilities/agg_merge/test_agg_merge.cc104
-rw-r--r--src/rocksdb/utilities/agg_merge/test_agg_merge.h47
5 files changed, 573 insertions, 0 deletions
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.cc b/src/rocksdb/utilities/agg_merge/agg_merge.cc
new file mode 100644
index 000000000..a7eab1f12
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge.cc
@@ -0,0 +1,238 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "utilities/agg_merge/agg_merge.h"
+
+#include <assert.h>
+
+#include <deque>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "port/lang.h"
+#include "port/likely.h"
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "rocksdb/utilities/options_type.h"
+#include "util/coding.h"
+#include "utilities/merge_operators.h"
+
+namespace ROCKSDB_NAMESPACE {
+static std::unordered_map<std::string, std::unique_ptr<Aggregator>> func_map;
+const std::string kUnnamedFuncName = "";
+const std::string kErrorFuncName = "kErrorFuncName";
+
+Status AddAggregator(const std::string& function_name,
+ std::unique_ptr<Aggregator>&& agg) {
+ if (function_name == kErrorFuncName) {
+ return Status::InvalidArgument(
+ "Cannot register function name kErrorFuncName");
+ }
+ func_map.emplace(function_name, std::move(agg));
+ return Status::OK();
+}
+
+AggMergeOperator::AggMergeOperator() {}
+
+std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
+ const Slice& value) {
+ std::string result;
+ PutLengthPrefixedSlice(&result, function_name);
+ result += value.ToString();
+ return result;
+}
+
+Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload,
+ std::string& output) {
+ if (function_name == kErrorFuncName) {
+ return Status::InvalidArgument("Cannot use error function name");
+ }
+ if (function_name != kUnnamedFuncName &&
+ func_map.find(function_name.ToString()) == func_map.end()) {
+ return Status::InvalidArgument("Function name not registered");
+ }
+ output = EncodeAggFuncAndPayloadNoCheck(function_name, payload);
+ return Status::OK();
+}
+
+bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) {
+ value = op;
+ return GetLengthPrefixedSlice(&value, &func);
+}
+
+bool ExtractList(const Slice& encoded_list, std::vector<Slice>& decoded_list) {
+ decoded_list.clear();
+ Slice list_slice = encoded_list;
+ Slice item;
+ while (GetLengthPrefixedSlice(&list_slice, &item)) {
+ decoded_list.push_back(item);
+ }
+ return list_slice.empty();
+}
+
+class AggMergeOperator::Accumulator {
+ public:
+ bool Add(const Slice& op, bool is_partial_aggregation) {
+ if (ignore_operands_) {
+ return true;
+ }
+ Slice my_func;
+ Slice my_value;
+ bool ret = ExtractAggFuncAndValue(op, my_func, my_value);
+ if (!ret) {
+ ignore_operands_ = true;
+ return true;
+ }
+
+ // Determine whether we need to do partial merge.
+ if (is_partial_aggregation && !my_func.empty()) {
+ auto f = func_map.find(my_func.ToString());
+ if (f == func_map.end() || !f->second->DoPartialAggregate()) {
+ return false;
+ }
+ }
+
+ if (!func_valid_) {
+ if (my_func != kUnnamedFuncName) {
+ func_ = my_func;
+ func_valid_ = true;
+ }
+ } else if (func_ != my_func) {
+ // User switched aggregation function. Need to aggregate the older
+ // one first.
+
+ // Previous aggreagion can't be done in partial merge
+ if (is_partial_aggregation) {
+ func_valid_ = false;
+ ignore_operands_ = true;
+ return false;
+ }
+
+ // We could consider stashing an iterator into the hash of aggregators
+ // to avoid repeated lookups when the aggregator doesn't change.
+ auto f = func_map.find(func_.ToString());
+ if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) {
+ func_valid_ = false;
+ ignore_operands_ = true;
+ return true;
+ }
+ std::swap(scratch_, aggregated_);
+ values_.clear();
+ values_.push_back(aggregated_);
+ func_ = my_func;
+ }
+ values_.push_back(my_value);
+ return true;
+ }
+
+ // Return false if aggregation fails.
+ // One possible reason
+ bool GetResult(std::string& result) {
+ if (!func_valid_) {
+ return false;
+ }
+ auto f = func_map.find(func_.ToString());
+ if (f == func_map.end()) {
+ return false;
+ }
+ if (!f->second->Aggregate(values_, scratch_)) {
+ return false;
+ }
+ result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_);
+ return true;
+ }
+
+ void Clear() {
+ func_.clear();
+ values_.clear();
+ aggregated_.clear();
+ scratch_.clear();
+ ignore_operands_ = false;
+ func_valid_ = false;
+ }
+
+ private:
+ Slice func_;
+ std::vector<Slice> values_;
+ std::string aggregated_;
+ std::string scratch_;
+ bool ignore_operands_ = false;
+ bool func_valid_ = false;
+};
+
+// Creating and using a new Accumulator might invoke multiple malloc and is
+// expensive if it needs to be done when processing each merge operation.
+// AggMergeOperator's merge operators can be invoked concurrently by multiple
+// threads so we cannot simply create one Aggregator and reuse.
+// We use thread local instances instead.
+AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() {
+ static thread_local Accumulator tls_acc;
+ tls_acc.Clear();
+ return tls_acc;
+}
+
+void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in,
+ MergeOperationOutput& merge_out) {
+ merge_out.new_value = "";
+ PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName);
+ if (merge_in.existing_value != nullptr) {
+ PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value);
+ }
+ for (const Slice& op : merge_in.operand_list) {
+ PutLengthPrefixedSlice(&merge_out.new_value, op);
+ }
+}
+
+bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const {
+ Accumulator& agg = GetTLSAccumulator();
+ if (merge_in.existing_value != nullptr) {
+ agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false);
+ }
+ for (const Slice& e : merge_in.operand_list) {
+ agg.Add(e, /*is_partial_aggregation=*/false);
+ }
+
+ bool succ = agg.GetResult(merge_out->new_value);
+ if (!succ) {
+ // If aggregation can't happen, pack all merge operands. In contrast to
+ // merge operator, we don't want to fail the DB. If users insert wrong
+ // format or call unregistered an aggregation function, we still hope
+ // the DB can continue functioning with other keys.
+ PackAllMergeOperands(merge_in, *merge_out);
+ }
+ agg.Clear();
+ return true;
+}
+
+bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value,
+ Logger* /*logger*/) const {
+ Accumulator& agg = GetTLSAccumulator();
+ bool do_aggregation = true;
+ for (const Slice& item : operand_list) {
+ do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true);
+ if (!do_aggregation) {
+ break;
+ }
+ }
+ if (do_aggregation) {
+ do_aggregation = agg.GetResult(*new_value);
+ }
+ agg.Clear();
+ return do_aggregation;
+}
+
+std::shared_ptr<MergeOperator> GetAggMergeOperator() {
+ STATIC_AVOID_DESTRUCTION(std::shared_ptr<MergeOperator>, instance)
+ (std::make_shared<AggMergeOperator>());
+ assert(instance);
+ return instance;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge.h b/src/rocksdb/utilities/agg_merge/agg_merge.h
new file mode 100644
index 000000000..00e58de08
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <unordered_map>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+class AggMergeOperator : public MergeOperator {
+ public:
+ explicit AggMergeOperator();
+
+ bool FullMergeV2(const MergeOperationInput& merge_in,
+ MergeOperationOutput* merge_out) const override;
+
+ bool PartialMergeMulti(const Slice& key,
+ const std::deque<Slice>& operand_list,
+ std::string* new_value, Logger* logger) const override;
+
+ const char* Name() const override { return kClassName(); }
+ static const char* kClassName() { return "AggMergeOperator.v1"; }
+
+ bool AllowSingleOperand() const override { return true; }
+
+ bool ShouldMerge(const std::vector<Slice>&) const override { return false; }
+
+ private:
+ class Accumulator;
+
+ // Pack all merge operands into one value. This is called when aggregation
+ // fails. The existing values are preserved and returned so that users can
+ // debug the problem.
+ static void PackAllMergeOperands(const MergeOperationInput& merge_in,
+ MergeOperationOutput& merge_out);
+ static Accumulator& GetTLSAccumulator();
+};
+
+extern std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name,
+ const Slice& value);
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/agg_merge_test.cc b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc
new file mode 100644
index 000000000..a65441cd0
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/agg_merge_test.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "rocksdb/utilities/agg_merge.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "db/db_test_util.h"
+#include "rocksdb/options.h"
+#include "test_util/testharness.h"
+#include "utilities/agg_merge/agg_merge.h"
+#include "utilities/agg_merge/test_agg_merge.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+class AggMergeTest : public DBTestBase {
+ public:
+ AggMergeTest() : DBTestBase("agg_merge_db_test", /*env_do_fsync=*/true) {}
+};
+
+TEST_F(AggMergeTest, TestUsingMergeOperator) {
+ ASSERT_OK(AddAggregator("sum", std::make_unique<SumAggregator>()));
+ ASSERT_OK(AddAggregator("last3", std::make_unique<Last3Aggregator>()));
+ ASSERT_OK(AddAggregator("mul", std::make_unique<MultipleAggregator>()));
+
+ Options options = CurrentOptions();
+ options.merge_operator = GetAggMergeOperator();
+ Reopen(options);
+ std::string v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("foo", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 15);
+ ASSERT_OK(Merge("foo", v));
+
+ v = EncodeHelper::EncodeFuncAndList("last3", {"a", "b"});
+ ASSERT_OK(Merge("bar", v));
+ v = EncodeHelper::EncodeFuncAndList("last3", {"c", "d", "e"});
+ ASSERT_OK(Merge("bar", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndList("last3", {"f"});
+ ASSERT_OK(Merge("bar", v));
+
+ // Test Put() without aggregation type.
+ v = EncodeHelper::EncodeFuncAndInt(kUnnamedFuncName, 30);
+ ASSERT_OK(Put("foo2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("foo2", v));
+
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 45), Get("foo"));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndList("last3", {"f", "c", "d"}),
+ Get("bar"));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 60), Get("foo2"));
+
+ // Test changing aggregation type
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Put("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("mul", 20);
+ ASSERT_OK(Merge("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 30);
+ ASSERT_OK(Merge("bar2", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 40);
+ ASSERT_OK(Merge("bar2", v));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 20 + 30 + 40),
+ Get("bar2"));
+
+ // Changing aggregation type with partial merge
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ v = EncodeHelper::EncodeFuncAndInt("mul", 10);
+ ASSERT_OK(Merge("foo3", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 10);
+ ASSERT_OK(Merge("foo3", v));
+ ASSERT_OK(Flush());
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 10 * 10 + 10),
+ Get("foo3"));
+
+ // Merge after full merge
+ v = EncodeHelper::EncodeFuncAndInt("sum", 1);
+ ASSERT_OK(Merge("foo4", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 2);
+ ASSERT_OK(Merge("foo4", v));
+ ASSERT_OK(Flush());
+ v = EncodeHelper::EncodeFuncAndInt("sum", 3);
+ ASSERT_OK(Merge("foo4", v));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 4);
+ ASSERT_OK(Merge("foo4", v));
+ ASSERT_OK(Flush());
+ ASSERT_OK(db_->CompactRange(CompactRangeOptions(), nullptr, nullptr));
+ v = EncodeHelper::EncodeFuncAndInt("sum", 5);
+ ASSERT_OK(Merge("foo4", v));
+ EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 15), Get("foo4"));
+
+ // Test unregistered function name
+ v = EncodeAggFuncAndPayloadNoCheck("non_existing", "1");
+ ASSERT_OK(Merge("bar3", v));
+ std::string v1;
+ v1 = EncodeAggFuncAndPayloadNoCheck("non_existing", "invalid");
+ ;
+ ASSERT_OK(Merge("bar3", v1));
+ EXPECT_EQ(EncodeAggFuncAndPayloadNoCheck(kErrorFuncName,
+ EncodeHelper::EncodeList({v, v1})),
+ Get("bar3"));
+
+ // invalidate input
+ ASSERT_OK(EncodeAggFuncAndPayload("sum", "invalid", v));
+ ASSERT_OK(Merge("bar4", v));
+ v1 = EncodeHelper::EncodeFuncAndInt("sum", 20);
+ ASSERT_OK(Merge("bar4", v1));
+ std::string aggregated_value = Get("bar4");
+ Slice func, payload;
+ ASSERT_TRUE(ExtractAggFuncAndValue(aggregated_value, func, payload));
+ EXPECT_EQ(kErrorFuncName, func);
+ std::vector<Slice> decoded_list;
+ ASSERT_TRUE(ExtractList(payload, decoded_list));
+ ASSERT_EQ(2, decoded_list.size());
+ ASSERT_EQ(v, decoded_list[0]);
+ ASSERT_EQ(v1, decoded_list[1]);
+}
+} // namespace ROCKSDB_NAMESPACE
+
+int main(int argc, char** argv) {
+ ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.cc b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc
new file mode 100644
index 000000000..06e5b5697
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#include "test_agg_merge.h"
+
+#include <assert.h>
+
+#include <deque>
+#include <vector>
+
+#include "util/coding.h"
+#include "utilities/agg_merge/agg_merge.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+std::string EncodeHelper::EncodeFuncAndInt(const Slice& function_name,
+ int64_t value) {
+ std::string encoded_value;
+ PutVarsignedint64(&encoded_value, value);
+ std::string ret;
+ Status s = EncodeAggFuncAndPayload(function_name, encoded_value, ret);
+ assert(s.ok());
+ return ret;
+}
+
+std::string EncodeHelper::EncodeInt(int64_t value) {
+ std::string encoded_value;
+ PutVarsignedint64(&encoded_value, value);
+ return encoded_value;
+}
+
+std::string EncodeHelper::EncodeFuncAndList(const Slice& function_name,
+ const std::vector<Slice>& list) {
+ std::string ret;
+ Status s = EncodeAggFuncAndPayload(function_name, EncodeList(list), ret);
+ assert(s.ok());
+ return ret;
+}
+
+std::string EncodeHelper::EncodeList(const std::vector<Slice>& list) {
+ std::string result;
+ for (const Slice& entity : list) {
+ PutLengthPrefixedSlice(&result, entity);
+ }
+ return result;
+}
+
+bool SumAggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ int64_t sum = 0;
+ for (const Slice& item : item_list) {
+ int64_t ivalue;
+ Slice v = item;
+ if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) {
+ return false;
+ }
+ sum += ivalue;
+ }
+ result = EncodeHelper::EncodeInt(sum);
+ return true;
+}
+
+bool MultipleAggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ int64_t mresult = 1;
+ for (const Slice& item : item_list) {
+ int64_t ivalue;
+ Slice v = item;
+ if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) {
+ return false;
+ }
+ mresult *= ivalue;
+ }
+ result = EncodeHelper::EncodeInt(mresult);
+ return true;
+}
+
+bool Last3Aggregator::Aggregate(const std::vector<Slice>& item_list,
+ std::string& result) const {
+ std::vector<Slice> last3;
+ last3.reserve(3);
+ for (auto it = item_list.rbegin(); it != item_list.rend(); ++it) {
+ Slice input = *it;
+ Slice entity;
+ bool ret;
+ while ((ret = GetLengthPrefixedSlice(&input, &entity)) == true) {
+ last3.push_back(entity);
+ if (last3.size() >= 3) {
+ break;
+ }
+ }
+ if (last3.size() >= 3) {
+ break;
+ }
+ if (!ret) {
+ continue;
+ }
+ }
+ result = EncodeHelper::EncodeList(last3);
+ return true;
+}
+} // namespace ROCKSDB_NAMESPACE
diff --git a/src/rocksdb/utilities/agg_merge/test_agg_merge.h b/src/rocksdb/utilities/agg_merge/test_agg_merge.h
new file mode 100644
index 000000000..5bdf8b9cc
--- /dev/null
+++ b/src/rocksdb/utilities/agg_merge/test_agg_merge.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+#include <unordered_map>
+
+#include "rocksdb/merge_operator.h"
+#include "rocksdb/slice.h"
+#include "rocksdb/utilities/agg_merge.h"
+#include "utilities/cassandra/cassandra_options.h"
+
+namespace ROCKSDB_NAMESPACE {
+class SumAggregator : public Aggregator {
+ public:
+ ~SumAggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+ bool DoPartialAggregate() const override { return true; }
+};
+
+class MultipleAggregator : public Aggregator {
+ public:
+ ~MultipleAggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+ bool DoPartialAggregate() const override { return true; }
+};
+
+class Last3Aggregator : public Aggregator {
+ public:
+ ~Last3Aggregator() override {}
+ bool Aggregate(const std::vector<Slice>&, std::string& result) const override;
+};
+
+class EncodeHelper {
+ public:
+ static std::string EncodeFuncAndInt(const Slice& function_name,
+ int64_t value);
+ static std::string EncodeInt(int64_t value);
+ static std::string EncodeList(const std::vector<Slice>& list);
+ static std::string EncodeFuncAndList(const Slice& function_name,
+ const std::vector<Slice>& list);
+};
+} // namespace ROCKSDB_NAMESPACE