summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/R/dplyr-collect.R
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/r/R/dplyr-collect.R')
-rw-r--r--src/arrow/r/R/dplyr-collect.R121
1 files changed, 121 insertions, 0 deletions
diff --git a/src/arrow/r/R/dplyr-collect.R b/src/arrow/r/R/dplyr-collect.R
new file mode 100644
index 000000000..13e68f3f4
--- /dev/null
+++ b/src/arrow/r/R/dplyr-collect.R
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
+ # head and tail are not ExecNodes, at best we can handle them via sink node
+ # so if there are any steps done after head/tail, we need to
+ # evaluate the query up to then and then do a new query for the rest
+ if (is_collapsed(x) && has_head_tail(x$.data)) {
+ x$.data <- as_adq(dplyr::compute(x$.data))$.data
+ }
+
+ # See query-engine.R for ExecPlan/Nodes
+ tab <- do_exec_plan(x)
+ if (as_data_frame) {
+ df <- as.data.frame(tab)
+ tab$invalidate()
+ restore_dplyr_features(df, x)
+ } else {
+ restore_dplyr_features(tab, x)
+ }
+}
+collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) {
+ if (as_data_frame) {
+ as.data.frame(x, ...)
+ } else {
+ x
+ }
+}
+collect.Dataset <- function(x, ...) dplyr::collect(as_adq(x), ...)
+
+compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE)
+compute.ArrowTabular <- function(x, ...) x
+compute.Dataset <- compute.arrow_dplyr_query
+
+pull.arrow_dplyr_query <- function(.data, var = -1) {
+ .data <- as_adq(.data)
+ var <- vars_pull(names(.data), !!enquo(var))
+ .data$selected_columns <- set_names(.data$selected_columns[var], var)
+ dplyr::collect(.data)[[1]]
+}
+pull.Dataset <- pull.ArrowTabular <- pull.arrow_dplyr_query
+
+restore_dplyr_features <- function(df, query) {
+ # An arrow_dplyr_query holds some attributes that Arrow doesn't know about
+ # After calling collect(), make sure these features are carried over
+
+ if (length(query$group_by_vars) > 0) {
+ # Preserve groupings, if present
+ if (is.data.frame(df)) {
+ df <- dplyr::grouped_df(
+ df,
+ dplyr::group_vars(query),
+ drop = dplyr::group_by_drop_default(query)
+ )
+ } else {
+ # This is a Table, via compute() or collect(as_data_frame = FALSE)
+ df <- as_adq(df)
+ df$group_by_vars <- query$group_by_vars
+ df$drop_empty_groups <- query$drop_empty_groups
+ }
+ }
+ df
+}
+
+collapse.arrow_dplyr_query <- function(x, ...) {
+ # Figure out what schema will result from the query
+ x$schema <- implicit_schema(x)
+ # Nest inside a new arrow_dplyr_query (and keep groups)
+ restore_dplyr_features(arrow_dplyr_query(x), x)
+}
+collapse.Dataset <- collapse.ArrowTabular <- function(x, ...) {
+ arrow_dplyr_query(x)
+}
+
+implicit_schema <- function(.data) {
+ .data <- ensure_group_vars(.data)
+ old_schm <- .data$.data$schema
+
+ if (is.null(.data$aggregations)) {
+ new_fields <- map(.data$selected_columns, ~ .$type(old_schm))
+ if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) {
+ # Add cols from right side, except for semi/anti joins
+ right_cols <- .data$join$right_data$selected_columns
+ new_fields <- c(new_fields, map(
+ right_cols[setdiff(names(right_cols), .data$join$by)],
+ ~ .$type(.data$join$right_data$.data$schema)
+ ))
+ }
+ } else {
+ new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
+ # * Put group_by_vars first (this can't be done by summarize,
+ # they have to be last per the aggregate node signature,
+ # and they get projected to this order after aggregation)
+ # * Infer the output types from the aggregations
+ group_fields <- new_fields[.data$group_by_vars]
+ hash <- length(.data$group_by_vars) > 0
+ agg_fields <- imap(
+ new_fields[setdiff(names(new_fields), .data$group_by_vars)],
+ ~ output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
+ )
+ new_fields <- c(group_fields, agg_fields)
+ }
+ schema(!!!new_fields)
+}