diff options
Diffstat (limited to 'src/arrow/r/R/dplyr-collect.R')
-rw-r--r-- | src/arrow/r/R/dplyr-collect.R | 121 |
1 files changed, 121 insertions, 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

# Execute an arrow_dplyr_query and hand back the result.
#
# x:             the query to run.
# as_data_frame: if TRUE the executed result is converted with
#                as.data.frame(); otherwise the object returned by
#                do_exec_plan() is passed on without conversion.
# ...:           accepted for S3 compatibility; not used in this method.
#
# In both cases the result goes through restore_dplyr_features() so grouping
# information held on the query is carried over.
collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
  # head and tail are not ExecNodes; at best they can be handled via the sink
  # node. So when further steps were stacked on top of a head/tail, the inner
  # query is evaluated first and the remaining steps run as a new query on
  # that result.
  needs_materialize <- is_collapsed(x) && has_head_tail(x$.data)
  if (needs_materialize) {
    materialized <- dplyr::compute(x$.data)
    x$.data <- as_adq(materialized)$.data
  }

  # See query-engine.R for ExecPlan/Nodes
  tab <- do_exec_plan(x)
  if (as_data_frame) {
    df <- as.data.frame(tab)
    # NOTE(review): presumably releases the intermediate result now that it
    # has been converted -- confirm against the invalidate() definition.
    tab$invalidate()
    return(restore_dplyr_features(df, x))
  }
  restore_dplyr_features(tab, x)
}

# collect() for objects that are already tabular: convert to a data.frame
# (forwarding ... to as.data.frame()) or return the input untouched.
collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) {
  if (as_data_frame) {
    return(as.data.frame(x, ...))
  }
  x
}

# collect() on a Dataset: wrap it with as_adq() and re-dispatch.
collect.Dataset <- function(x, ...) {
  dplyr::collect(as_adq(x), ...)
}

# compute() is collect() without the data.frame conversion.
# ... is accepted but not forwarded.
compute.arrow_dplyr_query <- function(x, ...) {
  dplyr::collect(x, as_data_frame = FALSE)
}

# Nothing to compute for an object that is already tabular.
compute.ArrowTabular <- function(x, ...) {
  x
}

compute.Dataset <- compute.arrow_dplyr_query

# Extract a single column from the collected query result.
#
# .data: a query, or anything as_adq() accepts.
# var:   column specification, resolved by vars_pull() against names(.data);
#        defaults to -1.
pull.arrow_dplyr_query <- function(.data, var = -1) {
  query <- as_adq(.data)
  # enquo() captures the caller's expression for `var`, so it has to run
  # before anything forces the argument.
  col_name <- vars_pull(names(query), !!enquo(var))
  # Narrow the selection to the one requested column, then take it out of the
  # collected result.
  query$selected_columns <- set_names(
    query$selected_columns[col_name],
    col_name
  )
  collected <- dplyr::collect(query)
  collected[[1]]
}
pull.ArrowTabular <- pull.arrow_dplyr_query
pull.Dataset <- pull.arrow_dplyr_query

# An arrow_dplyr_query holds some attributes that Arrow doesn't know about.
# After calling collect(), make sure these features are carried over.
#
# df:    the collected result; either a data.frame or something as_adq()
#        accepts (the original comments describe it as a Table coming from
#        compute() or collect(as_data_frame = FALSE)).
# query: the arrow_dplyr_query the result was produced from.
#
# Returns df unchanged when the query has no grouping variables.
restore_dplyr_features <- function(df, query) {
  if (length(query$group_by_vars) == 0) {
    return(df)
  }

  # Preserve groupings, if present
  if (is.data.frame(df)) {
    return(
      dplyr::grouped_df(
        df,
        dplyr::group_vars(query),
        drop = dplyr::group_by_drop_default(query)
      )
    )
  }

  # Not a data.frame: wrap it in a query and copy the grouping fields over
  regrouped <- as_adq(df)
  regrouped$group_by_vars <- query$group_by_vars
  regrouped$drop_empty_groups <- query$drop_empty_groups
  regrouped
}

# Nest the current query inside a new arrow_dplyr_query, keeping groups.
collapse.arrow_dplyr_query <- function(x, ...) {
  # Figure out what schema will result from the query
  x$schema <- implicit_schema(x)
  nested <- arrow_dplyr_query(x)
  restore_dplyr_features(nested, x)
}

# For non-query inputs, collapse() just wraps the object in a query.
collapse.ArrowTabular <- function(x, ...) {
  arrow_dplyr_query(x)
}
collapse.Dataset <- collapse.ArrowTabular

# Derive the schema the query would produce, built from the type() of each
# output expression evaluated against the input schema.
#
# .data: an arrow_dplyr_query (passed through ensure_group_vars() first).
# Returns the value of schema() called on the resulting named fields.
implicit_schema <- function(.data) {
  query <- ensure_group_vars(.data)
  input_schema <- query$.data$schema

  if (is.null(query$aggregations)) {
    fields <- map(
      query$selected_columns,
      function(expr) expr$type(input_schema)
    )

    # Per the original comment, JoinType[1:4] are the semi/anti join types,
    # which contribute no columns from the right-hand side.
    adds_right_cols <- !is.null(query$join) &&
      !(query$join$type %in% JoinType[1:4])
    if (adds_right_cols) {
      # Add cols from right side, leaving out the join keys
      right_cols <- query$join$right_data$selected_columns
      extra_names <- setdiff(names(right_cols), query$join$by)
      right_fields <- map(
        right_cols[extra_names],
        function(expr) expr$type(query$join$right_data$.data$schema)
      )
      fields <- c(fields, right_fields)
    }
  } else {
    projected <- map(
      summarize_projection(query),
      function(expr) expr$type(input_schema)
    )
    # * Put group_by_vars first (this can't be done by summarize,
    #   they have to be last per the aggregate node signature,
    #   and they get projected to this order after aggregation)
    # * Infer the output types from the aggregations
    group_fields <- projected[query$group_by_vars]
    hash <- length(query$group_by_vars) > 0
    agg_names <- setdiff(names(projected), query$group_by_vars)
    agg_fields <- imap(
      projected[agg_names],
      function(input_type, agg_name) {
        output_type(query$aggregations[[agg_name]][["fun"]], input_type, hash)
      }
    )
    fields <- c(group_fields, agg_fields)
  }

  schema(!!!fields)
}