diff options
Diffstat (limited to 'src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc')
-rw-r--r-- | src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc new file mode 100644 index 000000000..0fddf38f5 --- /dev/null +++ b/src/arrow/cpp/src/arrow/compute/kernels/aggregate_tdigest.cc @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/tdigest.h" + +namespace arrow { +namespace compute { +namespace internal { + +namespace { + +using arrow::internal::TDigest; +using arrow::internal::VisitSetBitRunsVoid; + +template <typename ArrowType> +struct TDigestImpl : public ScalarAggregator { + using ThisType = TDigestImpl<ArrowType>; + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using CType = typename ArrowType::c_type; + + explicit TDigestImpl(const TDigestOptions& options) + : options{options}, + tdigest{options.delta, options.buffer_size}, + count{0}, + all_valid{true} {} + + Status Consume(KernelContext*, const ExecBatch& batch) override { + if (!this->all_valid) return Status::OK(); + if (!options.skip_nulls && batch[0].null_count() > 0) { + this->all_valid = false; + return Status::OK(); + } + if (batch[0].is_array()) { + const ArrayData& data = *batch[0].array(); + const CType* values = data.GetValues<CType>(1); + + if (data.length > data.GetNullCount()) { + this->count += data.length - data.GetNullCount(); + VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length, + [&](int64_t pos, int64_t len) { + for (int64_t i = 0; i < len; ++i) { + this->tdigest.NanAdd(values[pos + i]); + } + }); + } + } else { + const CType value = UnboxScalar<ArrowType>::Unbox(*batch[0].scalar()); + if (batch[0].scalar()->is_valid) { + this->count += 1; + for (int64_t i = 0; i < batch.length; i++) { + this->tdigest.NanAdd(value); + } + } + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other = checked_cast<const ThisType&>(src); + if (!this->all_valid || !other.all_valid) { + this->all_valid = false; + return Status::OK(); + } + this->tdigest.Merge(other.tdigest); + this->count += other.count; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum* out) override { + const int64_t out_length = options.q.size(); + auto out_data = ArrayData::Make(float64(), out_length, 0); + out_data->buffers.resize(2, nullptr); + ARROW_ASSIGN_OR_RAISE(out_data->buffers[1], + ctx->Allocate(out_length * sizeof(double))); + double* out_buffer = out_data->template GetMutableValues<double>(1); + + if (this->tdigest.is_empty() || !this->all_valid || this->count < options.min_count) { + ARROW_ASSIGN_OR_RAISE(out_data->buffers[0], ctx->AllocateBitmap(out_length)); + std::memset(out_data->buffers[0]->mutable_data(), 0x00, + out_data->buffers[0]->size()); + std::fill(out_buffer, out_buffer + out_length, 0.0); + out_data->null_count = out_length; + } else { + for (int64_t i = 0; i < out_length; ++i) { + out_buffer[i] = this->tdigest.Quantile(this->options.q[i]); + } + } + *out = Datum(std::move(out_data)); + return Status::OK(); + } + + const TDigestOptions options; + TDigest tdigest; + int64_t count; + bool all_valid; +}; + +struct TDigestInitState { + std::unique_ptr<KernelState> state; + KernelContext* ctx; + const DataType& in_type; + const TDigestOptions& options; + + TDigestInitState(KernelContext* ctx, const DataType& in_type, + const TDigestOptions& options) + : ctx(ctx), in_type(in_type), options(options) {} + + Status Visit(const DataType&) { + return Status::NotImplemented("No tdigest implemented"); + } + + Status Visit(const HalfFloatType&) { + return Status::NotImplemented("No tdigest implemented"); + } + + template <typename Type> + enable_if_t<is_number_type<Type>::value, Status> Visit(const Type&) { + state.reset(new TDigestImpl<Type>(options)); + return Status::OK(); + } + + Result<std::unique_ptr<KernelState>> Create() { + RETURN_NOT_OK(VisitTypeInline(in_type, this)); + return std::move(state); + } +}; + +Result<std::unique_ptr<KernelState>> TDigestInit(KernelContext* ctx, + const KernelInitArgs& args) { + TDigestInitState visitor(ctx, *args.inputs[0].type, + static_cast<const TDigestOptions&>(*args.options)); + return visitor.Create(); +} + +void AddTDigestKernels(KernelInit init, + const std::vector<std::shared_ptr<DataType>>& types, + ScalarAggregateFunction* func) { + for (const auto& ty : types) { + auto sig = KernelSignature::Make({InputType(ty)}, float64()); + AddAggKernel(std::move(sig), init, func); + } +} + +const FunctionDoc tdigest_doc{ + "Approximate quantiles of a numeric array with T-Digest algorithm", + ("By default, 0.5 quantile (median) is returned.\n" + "Nulls and NaNs are ignored.\n" + "An array of nulls is returned if there is no valid data point."), + {"array"}, + "TDigestOptions"}; + +const FunctionDoc approximate_median_doc{ + "Approximate median of a numeric array with T-Digest algorithm", + ("Nulls and NaNs are ignored.\n" + "A null scalar is returned if there is no valid data point."), + {"array"}, + "ScalarAggregateOptions"}; + +std::shared_ptr<ScalarAggregateFunction> AddTDigestAggKernels() { + static auto default_tdigest_options = TDigestOptions::Defaults(); + auto func = std::make_shared<ScalarAggregateFunction>( + "tdigest", Arity::Unary(), &tdigest_doc, &default_tdigest_options); + AddTDigestKernels(TDigestInit, NumericTypes(), func.get()); + return func; +} + +std::shared_ptr<ScalarAggregateFunction> AddApproximateMedianAggKernels( + const ScalarAggregateFunction* tdigest_func) { + static ScalarAggregateOptions default_scalar_aggregate_options; + + auto median = std::make_shared<ScalarAggregateFunction>( + "approximate_median", Arity::Unary(), &approximate_median_doc, + &default_scalar_aggregate_options); + + auto sig = + KernelSignature::Make({InputType(ValueDescr::ANY)}, ValueDescr::Scalar(float64())); + + auto init = [tdigest_func]( + KernelContext* ctx, + const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> { + std::vector<ValueDescr> inputs = args.inputs; + ARROW_ASSIGN_OR_RAISE(auto kernel, tdigest_func->DispatchBest(&inputs)); + const auto& scalar_options = + checked_cast<const ScalarAggregateOptions&>(*args.options); + TDigestOptions options; + // Default q = 0.5 + options.min_count = scalar_options.min_count; + options.skip_nulls = scalar_options.skip_nulls; + KernelInitArgs new_args{kernel, inputs, &options}; + return kernel->init(ctx, new_args); + }; + + auto finalize = [](KernelContext* ctx, Datum* out) -> Status { + Datum temp; + RETURN_NOT_OK(checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, &temp)); + const auto arr = temp.make_array(); + DCHECK_EQ(arr->length(), 1); + return arr->GetScalar(0).Value(out); + }; + + AddAggKernel(std::move(sig), std::move(init), std::move(finalize), median.get()); + return median; +} + +} // namespace + +void RegisterScalarAggregateTDigest(FunctionRegistry* registry) { + auto tdigest = AddTDigestAggKernels(); + DCHECK_OK(registry->AddFunction(tdigest)); + + auto approx_median = AddApproximateMedianAggKernels(tdigest.get()); + DCHECK_OK(registry->AddFunction(approx_median)); +} + +} // namespace internal +} // namespace compute +} // namespace arrow |