summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/src/recordbatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/arrow/r/src/recordbatch.cpp309
1 files changed, 309 insertions, 0 deletions
diff --git a/src/arrow/r/src/recordbatch.cpp b/src/arrow/r/src/recordbatch.cpp
new file mode 100644
index 000000000..81e20e9ec
--- /dev/null
+++ b/src/arrow/r/src/recordbatch.cpp
@@ -0,0 +1,309 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "./arrow_types.h"
+
+#if defined(ARROW_R_WITH_ARROW)
+#include <arrow/array/array_base.h>
+#include <arrow/io/file.h>
+#include <arrow/io/memory.h>
+#include <arrow/ipc/reader.h>
+#include <arrow/ipc/writer.h>
+#include <arrow/type.h>
+#include <arrow/util/key_value_metadata.h>
+
+// Returns the column count reported by the record batch `x`.
+// [[arrow::export]]
+int RecordBatch__num_columns(const std::shared_ptr<arrow::RecordBatch>& x) {
+  const int n_columns = x->num_columns();
+  return n_columns;
+}
+
+// Returns the row count of the record batch `x` as an `int`.
+//
+// The row count is handled as a 64-bit value elsewhere in this file, whereas
+// the exported return type is `int`. A count that does not fit raises an R
+// error instead of being silently truncated.
+// [[arrow::export]]
+int RecordBatch__num_rows(const std::shared_ptr<arrow::RecordBatch>& x) {
+  const int64_t n_rows = x->num_rows();
+  const int narrowed = static_cast<int>(n_rows);
+  // Round-trip check: if widening the narrowed value does not give back the
+  // original count, the count is out of range for `int`.
+  if (static_cast<int64_t>(narrowed) != n_rows) {
+    cpp11::stop("RecordBatch has too many rows to be represented as an R integer");
+  }
+  return narrowed;
+}
+
+// Returns the schema attached to the record batch `x`.
+// [[arrow::export]]
+std::shared_ptr<arrow::Schema> RecordBatch__schema(
+    const std::shared_ptr<arrow::RecordBatch>& x) {
+  std::shared_ptr<arrow::Schema> batch_schema = x->schema();
+  return batch_schema;
+}
+
+// Builds a new record batch with the same columns and row count as `batch`,
+// but whose fields are renamed to `names` (matched by position).
+//
+// Raises an R error when the number of names differs from the number of
+// columns. The new schema is constructed from the renamed fields only.
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__RenameColumns(
+    const std::shared_ptr<arrow::RecordBatch>& batch,
+    const std::vector<std::string>& names) {
+  const int n = batch->num_columns();
+  if (names.size() != static_cast<size_t>(n)) {
+    // `%d` consumes an int: cast the size_t so the argument matches the
+    // format specifier.
+    cpp11::stop("RecordBatch has %d columns but %d names were provided", n,
+                static_cast<int>(names.size()));
+  }
+
+  // Look the schema up once instead of once per field.
+  const std::shared_ptr<arrow::Schema> old_schema = batch->schema();
+  std::vector<std::shared_ptr<arrow::Field>> fields(n);
+  for (int i = 0; i < n; i++) {
+    fields[i] = old_schema->field(i)->WithName(names[i]);
+  }
+
+  auto schema = std::make_shared<arrow::Schema>(std::move(fields));
+  return arrow::RecordBatch::Make(schema, batch->num_rows(), batch->columns());
+}
+
+// Returns the result of replacing the schema metadata of `x` with the
+// key/value pairs described by `metadata`: the names of the character vector
+// are the keys and its elements are the values.
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__ReplaceSchemaMetadata(
+    const std::shared_ptr<arrow::RecordBatch>& x, cpp11::strings metadata) {
+  std::vector<std::string> values = cpp11::as_cpp<std::vector<std::string>>(metadata);
+  std::vector<std::string> keys =
+      cpp11::as_cpp<std::vector<std::string>>(metadata.names());
+  return x->ReplaceSchemaMetadata(
+      std::make_shared<arrow::KeyValueMetadata>(std::move(keys), std::move(values)));
+}
+
+// Collects every column of `batch`, in order, and hands them to
+// arrow::r::to_r_list() to build the returned R list.
+// [[arrow::export]]
+cpp11::list RecordBatch__columns(const std::shared_ptr<arrow::RecordBatch>& batch) {
+  const int n_columns = batch->num_columns();
+  arrow::ArrayVector columns;
+  columns.reserve(n_columns);
+  for (int j = 0; j < n_columns; ++j) {
+    columns.push_back(batch->column(j));
+  }
+  return arrow::r::to_r_list(columns);
+}
+
+// Returns the column at position `i` of `batch`.
+// The index is first checked against the column count by
+// arrow::r::validate_index().
+// [[arrow::export]]
+std::shared_ptr<arrow::Array> RecordBatch__column(
+    const std::shared_ptr<arrow::RecordBatch>& batch, R_xlen_t i) {
+  const int n_columns = batch->num_columns();
+  arrow::r::validate_index(i, n_columns);
+  std::shared_ptr<arrow::Array> selected = batch->column(i);
+  return selected;
+}
+
+// Looks up a column of `batch` by field name and returns whatever
+// RecordBatch::GetColumnByName() yields for `name`.
+// [[arrow::export]]
+std::shared_ptr<arrow::Array> RecordBatch__GetColumnByName(
+    const std::shared_ptr<arrow::RecordBatch>& batch, const std::string& name) {
+  std::shared_ptr<arrow::Array> found = batch->GetColumnByName(name);
+  return found;
+}
+
+// Returns the record batch produced by selecting the columns of `batch` at
+// the given `indices`. A failed selection is reported through ValueOrStop().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__SelectColumns(
+    const std::shared_ptr<arrow::RecordBatch>& batch, const std::vector<int>& indices) {
+  auto selection = batch->SelectColumns(indices);
+  return ValueOrStop(std::move(selection));
+}
+
+// Compares `self` against `other` with RecordBatch::Equals(), forwarding
+// `check_metadata` unchanged.
+// [[arrow::export]]
+bool RecordBatch__Equals(const std::shared_ptr<arrow::RecordBatch>& self,
+                         const std::shared_ptr<arrow::RecordBatch>& other,
+                         bool check_metadata) {
+  const arrow::RecordBatch& rhs = *other;
+  const bool same = self->Equals(rhs, check_metadata);
+  return same;
+}
+
+// Returns the record batch produced by adding `column`, described by `field`,
+// at position `i` of `batch`. A failure is reported through ValueOrStop().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__AddColumn(
+    const std::shared_ptr<arrow::RecordBatch>& batch, R_xlen_t i,
+    const std::shared_ptr<arrow::Field>& field,
+    const std::shared_ptr<arrow::Array>& column) {
+  auto extended = batch->AddColumn(i, field, column);
+  return ValueOrStop(std::move(extended));
+}
+
+// Returns the record batch produced by setting position `i` of `batch` to
+// `column`, described by `field`. A failure is reported through ValueOrStop().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__SetColumn(
+    const std::shared_ptr<arrow::RecordBatch>& batch, R_xlen_t i,
+    const std::shared_ptr<arrow::Field>& field,
+    const std::shared_ptr<arrow::Array>& column) {
+  auto updated = batch->SetColumn(i, field, column);
+  return ValueOrStop(std::move(updated));
+}
+
+// Returns the record batch produced by removing the column at position `i`.
+// The index is first checked against the column count by
+// arrow::r::validate_index(); a failed removal is reported through
+// ValueOrStop().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__RemoveColumn(
+    const std::shared_ptr<arrow::RecordBatch>& batch, R_xlen_t i) {
+  const int n_columns = batch->num_columns();
+  arrow::r::validate_index(i, n_columns);
+  auto trimmed = batch->RemoveColumn(i);
+  return ValueOrStop(std::move(trimmed));
+}
+
+// Returns the name of the column at position `i` of `batch`.
+// The index is first checked against the column count by
+// arrow::r::validate_index().
+// [[arrow::export]]
+std::string RecordBatch__column_name(const std::shared_ptr<arrow::RecordBatch>& batch,
+                                     R_xlen_t i) {
+  const int n_columns = batch->num_columns();
+  arrow::r::validate_index(i, n_columns);
+  std::string label = batch->column_name(i);
+  return label;
+}
+
+// Returns a character vector holding the name of every column of `batch`,
+// in column order.
+// [[arrow::export]]
+cpp11::writable::strings RecordBatch__names(
+    const std::shared_ptr<arrow::RecordBatch>& batch) {
+  const int n_columns = batch->num_columns();
+  cpp11::writable::strings column_names(n_columns);
+  int j = 0;
+  while (j < n_columns) {
+    column_names[j] = batch->column_name(j);
+    ++j;
+  }
+  return column_names;
+}
+
+// Slices `self` starting at row `offset`, using the single-argument
+// RecordBatch::Slice(). The offset is first checked against the row count by
+// arrow::r::validate_slice_offset().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__Slice1(
+    const std::shared_ptr<arrow::RecordBatch>& self, R_xlen_t offset) {
+  const auto total_rows = self->num_rows();
+  arrow::r::validate_slice_offset(offset, total_rows);
+  std::shared_ptr<arrow::RecordBatch> tail = self->Slice(offset);
+  return tail;
+}
+
+// Slices `self` starting at row `offset` for `length` rows.
+// The offset is checked against the row count, and the length against the
+// number of rows remaining after the offset, before slicing.
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__Slice2(
+    const std::shared_ptr<arrow::RecordBatch>& self, R_xlen_t offset, R_xlen_t length) {
+  const auto total_rows = self->num_rows();
+  arrow::r::validate_slice_offset(offset, total_rows);
+
+  const auto rows_after_offset = total_rows - offset;
+  arrow::r::validate_slice_length(length, rows_after_offset);
+
+  std::shared_ptr<arrow::RecordBatch> window = self->Slice(offset, length);
+  return window;
+}
+
+// Serializes `batch` with the default IPC write options and returns the
+// bytes as an R raw vector.
+//
+// The raw vector is sized up front from GetRecordBatchSize(), then wrapped in
+// an RBuffer that a FixedSizeBufferWriter writes into. Any non-OK status is
+// reported through StopIfNotOk().
+// [[arrow::export]]
+cpp11::raws ipc___SerializeRecordBatch__Raw(
+    const std::shared_ptr<arrow::RecordBatch>& batch) {
+  // Step 1: ask how many bytes the serialized batch needs.
+  int64_t n_bytes;
+  StopIfNotOk(arrow::ipc::GetRecordBatchSize(*batch, &n_bytes));
+
+  // Step 2: allocate the R raw vector that is returned to the caller.
+  cpp11::writable::raws payload(n_bytes);
+
+  // Step 3: write the batch through a stream backed by that raw vector.
+  auto sink = std::make_shared<arrow::r::RBuffer<cpp11::raws>>(payload);
+  arrow::io::FixedSizeBufferWriter writer(sink);
+  const auto write_options = arrow::ipc::IpcWriteOptions::Defaults();
+  StopIfNotOk(arrow::ipc::SerializeRecordBatch(*batch, write_options, &writer));
+  StopIfNotOk(writer.Close());
+
+  return payload;
+}
+
+// Reads one record batch from `stream` using `schema` and the default IPC
+// read options.
+//
+// A fresh DictionaryMemo is created for the call and the schema's fields are
+// registered with it before reading. Failures are reported through
+// StopIfNotOk() / ValueOrStop().
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> ipc___ReadRecordBatch__InputStream__Schema(
+    const std::shared_ptr<arrow::io::InputStream>& stream,
+    const std::shared_ptr<arrow::Schema>& schema) {
+  // TODO: promote to function arg
+  arrow::ipc::DictionaryMemo dictionary_memo;
+  StopIfNotOk(dictionary_memo.fields().AddSchemaFields(*schema));
+
+  const auto read_options = arrow::ipc::IpcReadOptions::Defaults();
+  auto batch = arrow::ipc::ReadRecordBatch(schema, &dictionary_memo, read_options,
+                                           stream.get());
+  return ValueOrStop(std::move(batch));
+}
+
+namespace arrow {
+namespace r {
+
+// Checks that every array in `arrays` has the same length.
+//
+// When `arrays` is non-empty, `*num_rows` is set to the length of the first
+// array; when it is empty, `*num_rows` is left untouched. Returns
+// Status::Invalid if any array has a different length, Status::OK otherwise.
+arrow::Status check_consistent_array_size(
+    const std::vector<std::shared_ptr<arrow::Array>>& arrays, int64_t* num_rows) {
+  if (arrays.empty()) {
+    return arrow::Status::OK();
+  }
+
+  *num_rows = arrays.front()->length();
+
+  // The first array defines the reference length, so start comparing at the
+  // second one.
+  for (size_t k = 1; k < arrays.size(); ++k) {
+    if (arrays[k]->length() != *num_rows) {
+      return arrow::Status::Invalid("All arrays must have the same length");
+    }
+  }
+
+  return arrow::Status::OK();
+}
+
+// Counts how many fields the list `lst` contributes and stores the total in
+// `*out`.
+//
+// An element with a non-empty name counts as one field. An unnamed element
+// must inherit from "data.frame" and counts as XLENGTH(element) fields (it is
+// spliced). Any other unnamed element yields an RError status, in which case
+// `*out` is not written.
+//
+// A list without a names attribute is treated as if every element were
+// unnamed, instead of indexing into a NULL names vector.
+Status count_fields(SEXP lst, int* out) {
+  int res = 0;
+  const R_xlen_t n = XLENGTH(lst);
+  SEXP names = Rf_getAttrib(lst, R_NamesSymbol);
+  // Rf_getAttrib() gives R_NilValue when there is no names attribute;
+  // STRING_ELT() must not be called on it.
+  const bool has_names = !Rf_isNull(names);
+
+  for (R_xlen_t i = 0; i < n; i++) {
+    if (has_names && LENGTH(STRING_ELT(names, i)) > 0) {
+      ++res;
+      continue;
+    }
+
+    SEXP x = VECTOR_ELT(lst, i);
+    if (!Rf_inherits(x, "data.frame")) {
+      return Status::RError(
+          "only data frames are allowed as unnamed arguments to be auto spliced");
+    }
+    // XLENGTH() is an R_xlen_t; make the narrowing to the int counter explicit.
+    res += static_cast<int>(XLENGTH(x));
+  }
+
+  *out = res;
+  return Status::OK();
+}
+
+} // namespace r
+} // namespace arrow
+
+// Builds a record batch from the R list `lst` using an explicitly supplied
+// `schema`.
+//
+// Raises an R error when the number of fields counted in `lst` differs from
+// the schema's field count, or when a traversed element's name differs from
+// the schema field at the same position. Each element is converted with
+// vec_to_arrow() to the type of the matching schema field, and the resulting
+// arrays must all have the same length.
+std::shared_ptr<arrow::RecordBatch> RecordBatch__from_arrays__known_schema(
+    const std::shared_ptr<arrow::Schema>& schema, SEXP lst) {
+  int n_supplied;
+  StopIfNotOk(arrow::r::count_fields(lst, &n_supplied));
+
+  const int n_expected = schema->num_fields();
+  if (n_expected != n_supplied) {
+    cpp11::stop("incompatible. schema has %d fields, and %d arrays are supplied",
+                n_expected, n_supplied);
+  }
+
+  // One slot per field; filled in by the traversal callback below.
+  std::vector<std::shared_ptr<arrow::Array>> columns(n_supplied);
+
+  auto convert_field = [&columns, &schema](int j, SEXP x, std::string name) {
+    const auto& field = schema->field(j);
+    const std::string& expected_name = field->name();
+    if (expected_name != name) {
+      // The index is reported 1-based.
+      cpp11::stop("field at index %d has name '%s' != '%s'", j + 1,
+                  expected_name.c_str(), name.c_str());
+    }
+    columns[j] = arrow::r::vec_to_arrow(x, field->type(), false);
+  };
+  arrow::r::TraverseDots(lst, n_supplied, convert_field);
+
+  int64_t n_rows = 0;
+  StopIfNotOk(arrow::r::check_consistent_array_size(columns, &n_rows));
+  return arrow::RecordBatch::Make(schema, n_rows, columns);
+}
+
+namespace arrow {
+namespace r {
+
+// Converts every element traversed in `lst` to an Arrow array and stores it
+// in `arrays` at the position given by the traversal.
+//
+// Each element is converted with vec_to_arrow() to the type of the schema
+// field at the same position, and `inferred` is forwarded unchanged.
+// `arrays` is indexed directly, so the caller must have sized it already.
+// Always returns Status::OK().
+arrow::Status CollectRecordBatchArrays(
+    SEXP lst, const std::shared_ptr<arrow::Schema>& schema, int num_fields, bool inferred,
+    std::vector<std::shared_ptr<arrow::Array>>& arrays) {
+  auto store_converted = [&arrays, &schema, inferred](int j, SEXP x, cpp11::r_string) {
+    const auto& target_type = schema->field(j)->type();
+    arrays[j] = arrow::r::vec_to_arrow(x, target_type, inferred);
+  };
+
+  arrow::r::TraverseDots(lst, num_fields, store_converted);
+  return arrow::Status::OK();
+}
+
+} // namespace r
+} // namespace arrow
+
+// Builds a record batch from the R list `lst`.
+//
+// `schema_sxp` counts as an explicit schema when it inherits from "Schema";
+// the work is then delegated to RecordBatch__from_arrays__known_schema().
+// Otherwise the arrays are collected with the `inferred` flag set. In both
+// cases the schema object comes from InferSchemaFromDots() followed by
+// AddMetadataFromDots(), both defined outside this file.
+// [[arrow::export]]
+std::shared_ptr<arrow::RecordBatch> RecordBatch__from_arrays(SEXP schema_sxp, SEXP lst) {
+  const bool schema_given = Rf_inherits(schema_sxp, "Schema");
+
+  int n_fields;
+  StopIfNotOk(arrow::r::count_fields(lst, &n_fields));
+
+  // Schema, then metadata.
+  std::shared_ptr<arrow::Schema> schema;
+  StopIfNotOk(arrow::r::InferSchemaFromDots(lst, schema_sxp, n_fields, schema));
+  StopIfNotOk(arrow::r::AddMetadataFromDots(lst, n_fields, schema));
+
+  // Explicit schema: the known-schema path also checks the field names.
+  if (schema_given) {
+    return RecordBatch__from_arrays__known_schema(schema, lst);
+  }
+
+  // Inferred schema: convert each element to the inferred field type.
+  std::vector<std::shared_ptr<arrow::Array>> columns(n_fields);
+  StopIfNotOk(arrow::r::CollectRecordBatchArrays(lst, schema, n_fields,
+                                                 /*inferred=*/true, columns));
+
+  // All arrays must agree on the row count.
+  int64_t n_rows = 0;
+  StopIfNotOk(arrow::r::check_consistent_array_size(columns, &n_rows));
+
+  return arrow::RecordBatch::Make(schema, n_rows, columns);
+}
+
+#endif