summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/R/query-engine.R
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/r/R/query-engine.R')
-rw-r--r--src/arrow/r/R/query-engine.R298
1 files changed, 298 insertions, 0 deletions
diff --git a/src/arrow/r/R/query-engine.R b/src/arrow/r/R/query-engine.R
new file mode 100644
index 000000000..234aaf569
--- /dev/null
+++ b/src/arrow/r/R/query-engine.R
@@ -0,0 +1,298 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+do_exec_plan <- function(.data) {
+ plan <- ExecPlan$create()
+ final_node <- plan$Build(.data)
+ tab <- plan$Run(final_node)
+
+ # TODO (ARROW-14289): make the head/tail methods return RBR not Table
+ if (inherits(tab, "RecordBatchReader")) {
+ tab <- tab$read_table()
+ }
+
+ # If arrange() created $temp_columns, make sure to omit them from the result
+ # We can't currently handle this in the ExecPlan itself because sorting
+ # happens in the end (SinkNode) so nothing comes after it.
+ if (length(final_node$sort$temp_columns) > 0) {
+ tab <- tab[, setdiff(names(tab), final_node$sort$temp_columns), drop = FALSE]
+ }
+
+ if (ncol(tab)) {
+ # Apply any column metadata from the original schema, where appropriate
+ original_schema <- source_data(.data)$schema
+ # TODO: do we care about other (non-R) metadata preservation?
+ # How would we know if it were meaningful?
+ r_meta <- original_schema$r_metadata
+ if (!is.null(r_meta)) {
+ # Filter r_metadata$columns on columns with name _and_ type match
+ new_schema <- tab$schema
+ common_names <- intersect(names(r_meta$columns), names(tab))
+ keep <- common_names[
+ map_lgl(common_names, ~ original_schema[[.]] == new_schema[[.]])
+ ]
+ r_meta$columns <- r_meta$columns[keep]
+ if (has_aggregation(.data)) {
+ # dplyr drops top-level attributes if you do summarize
+ r_meta$attributes <- NULL
+ }
+ tab$r_metadata <- r_meta
+ }
+ }
+
+ tab
+}
+
+ExecPlan <- R6Class("ExecPlan",
+ inherit = ArrowObject,
+ public = list(
+ Scan = function(dataset) {
+ # Handle arrow_dplyr_query
+ if (inherits(dataset, "arrow_dplyr_query")) {
+ if (inherits(dataset$.data, "RecordBatchReader")) {
+ return(ExecNode_ReadFromRecordBatchReader(self, dataset$.data))
+ }
+
+ filter <- dataset$filtered_rows
+ if (isTRUE(filter)) {
+ filter <- Expression$scalar(TRUE)
+ }
+ # Use FieldsInExpression to find all from dataset$selected_columns
+ colnames <- unique(unlist(map(
+ dataset$selected_columns,
+ field_names_in_expression
+ )))
+ dataset <- dataset$.data
+ assert_is(dataset, "Dataset")
+ } else {
+ if (inherits(dataset, "ArrowTabular")) {
+ dataset <- InMemoryDataset$create(dataset)
+ }
+ assert_is(dataset, "Dataset")
+ # Set some defaults
+ filter <- Expression$scalar(TRUE)
+ colnames <- names(dataset)
+ }
+ # ScanNode needs the filter to do predicate pushdown and skip partitions,
+ # and it needs to know which fields to materialize (and which are unnecessary)
+ ExecNode_Scan(self, dataset, filter, colnames %||% character(0))
+ },
+ Build = function(.data) {
+ # This method takes an arrow_dplyr_query and chains together the
+ # ExecNodes that they produce. It does not evaluate them--that is Run().
+ group_vars <- dplyr::group_vars(.data)
+ grouped <- length(group_vars) > 0
+
+ # Collect the target names first because we have to add back the group vars
+ target_names <- names(.data)
+ .data <- ensure_group_vars(.data)
+ .data <- ensure_arrange_vars(.data) # this sets .data$temp_columns
+
+ if (inherits(.data$.data, "arrow_dplyr_query")) {
+ # We have a nested query. Recurse.
+ node <- self$Build(.data$.data)
+ } else {
+ node <- self$Scan(.data)
+ }
+
+ # ARROW-13498: Even though Scan takes the filter, apparently we have to do it again
+ if (inherits(.data$filtered_rows, "Expression")) {
+ node <- node$Filter(.data$filtered_rows)
+ }
+
+ if (!is.null(.data$aggregations)) {
+ # Project to include just the data required for each aggregation,
+ # plus group_by_vars (last)
+ # TODO: validate that none of names(aggregations) are the same as names(group_by_vars)
+ # dplyr does not error on this but the result it gives isn't great
+ node <- node$Project(summarize_projection(.data))
+
+ if (grouped) {
+ # We need to prefix all of the aggregation function names with "hash_"
+ .data$aggregations <- lapply(.data$aggregations, function(x) {
+ x[["fun"]] <- paste0("hash_", x[["fun"]])
+ x
+ })
+ }
+
+ node <- node$Aggregate(
+ options = map(.data$aggregations, ~ .[c("fun", "options")]),
+ target_names = names(.data$aggregations),
+ out_field_names = names(.data$aggregations),
+ key_names = group_vars
+ )
+
+ if (grouped) {
+ # The result will have result columns first then the grouping cols.
+ # dplyr orders group cols first, so adapt the result to meet that expectation.
+ node <- node$Project(
+ make_field_refs(c(group_vars, names(.data$aggregations)))
+ )
+ if (getOption("arrow.summarise.sort", FALSE)) {
+ # Add sorting instructions for the rows too to match dplyr
+ # (see below about why sorting isn't itself a Node)
+ node$sort <- list(
+ names = group_vars,
+ orders = rep(0L, length(group_vars))
+ )
+ }
+ }
+ } else {
+ # If any columns are derived, reordered, or renamed we need to Project
+ # If there are aggregations, the projection was already handled above
+ # We have to project at least once to eliminate some junk columns
+ # that the ExecPlan adds:
+ # __fragment_index, __batch_index, __last_in_fragment
+ # Presumably extraneous repeated projection of the same thing
+ # (as when we've done collapse() and not projected after) is cheap/no-op
+ projection <- c(.data$selected_columns, .data$temp_columns)
+ node <- node$Project(projection)
+
+ if (!is.null(.data$join)) {
+ node <- node$Join(
+ type = .data$join$type,
+ right_node = self$Build(.data$join$right_data),
+ by = .data$join$by,
+ left_output = names(.data),
+ right_output = setdiff(names(.data$join$right_data), .data$join$by)
+ )
+ }
+ }
+
+ # Apply sorting: this is currently not an ExecNode itself, it is a
+ # sink node option.
+ # TODO: handle some cases:
+ # (1) arrange > summarize > arrange
+ # (2) ARROW-13779: arrange then operation where order matters (e.g. cumsum)
+ if (length(.data$arrange_vars)) {
+ node$sort <- list(
+ names = names(.data$arrange_vars),
+ orders = .data$arrange_desc,
+ temp_columns = names(.data$temp_columns)
+ )
+ }
+
+ # This is only safe because we are going to evaluate queries that end
+ # with head/tail first, then evaluate any subsequent query as a new query
+ if (!is.null(.data$head)) {
+ node$head <- .data$head
+ }
+ if (!is.null(.data$tail)) {
+ node$tail <- .data$tail
+ }
+
+ node
+ },
+ Run = function(node) {
+ assert_is(node, "ExecNode")
+
+ # Sorting and head/tail (if sorted) are handled in the SinkNode,
+ # created in ExecPlan_run
+ sorting <- node$sort %||% list()
+ select_k <- node$head %||% -1L
+ has_sorting <- length(sorting) > 0
+ if (has_sorting) {
+ if (!is.null(node$tail)) {
+ # Reverse the sort order and take the top K, then after we'll reverse
+ # the resulting rows so that it is ordered as expected
+ sorting$orders <- !sorting$orders
+ select_k <- node$tail
+ }
+ sorting$orders <- as.integer(sorting$orders)
+ }
+
+ out <- ExecPlan_run(self, node, sorting, select_k)
+
+ if (!has_sorting) {
+ # Since ExecPlans don't scan in deterministic order, head/tail are both
+ # essentially taking a random slice from somewhere in the dataset.
+ # And since the head() implementation is way more efficient than tail(),
+ # just use it to take the random slice
+ slice_size <- node$head %||% node$tail
+ if (!is.null(slice_size)) {
+ # TODO (ARROW-14289): make the head methods return RBR not Table
+ out <- head(out, slice_size)
+ }
+ # Can we now tell `self$Stop()` to StopProducing? We already have
+ # everything we need for the head (but it seems to segfault: ARROW-14329)
+ } else if (!is.null(node$tail)) {
+ # Reverse the row order to get back what we expect
+ # TODO: don't return Table, return RecordBatchReader
+ out <- out$read_table()
+ out <- out[rev(seq_len(nrow(out))), , drop = FALSE]
+ }
+
+ out
+ },
+ Stop = function() ExecPlan_StopProducing(self)
+ )
+)
+ExecPlan$create <- function(use_threads = option_use_threads()) {
+ ExecPlan_create(use_threads)
+}
+
+ExecNode <- R6Class("ExecNode",
+ inherit = ArrowObject,
+ public = list(
+ # `sort` is a slight hack to be able to keep around arrange() params,
+ # which don't currently yield their own ExecNode but rather are consumed
+ # in the SinkNode (in ExecPlan$run())
+ sort = NULL,
+ # Similar hacks for head and tail
+ head = NULL,
+ tail = NULL,
+ preserve_sort = function(new_node) {
+ new_node$sort <- self$sort
+ new_node$head <- self$head
+ new_node$tail <- self$tail
+ new_node
+ },
+ Project = function(cols) {
+ if (length(cols)) {
+ assert_is_list_of(cols, "Expression")
+ self$preserve_sort(ExecNode_Project(self, cols, names(cols)))
+ } else {
+ self$preserve_sort(ExecNode_Project(self, character(0), character(0)))
+ }
+ },
+ Filter = function(expr) {
+ assert_is(expr, "Expression")
+ self$preserve_sort(ExecNode_Filter(self, expr))
+ },
+ Aggregate = function(options, target_names, out_field_names, key_names) {
+ self$preserve_sort(
+ ExecNode_Aggregate(self, options, target_names, out_field_names, key_names)
+ )
+ },
+ Join = function(type, right_node, by, left_output, right_output) {
+ self$preserve_sort(
+ ExecNode_Join(
+ self,
+ type,
+ right_node,
+ left_keys = names(by),
+ right_keys = by,
+ left_output = left_output,
+ right_output = right_output
+ )
+ )
+ }
+ ),
+ active = list(
+ schema = function() ExecNode_output_schema(self)
+ )
+)