summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/R/dplyr-summarize.R
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/arrow/r/R/dplyr-summarize.R289
1 files changed, 289 insertions, 0 deletions
diff --git a/src/arrow/r/R/dplyr-summarize.R b/src/arrow/r/R/dplyr-summarize.R
new file mode 100644
index 000000000..a6b7a3592
--- /dev/null
+++ b/src/arrow/r/R/dplyr-summarize.R
@@ -0,0 +1,289 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+summarise.arrow_dplyr_query <- function(.data, ...) {
+ call <- match.call()
+ .data <- as_adq(.data)
+ exprs <- quos(...)
+ # Only retain the columns we need to do our aggregations
+ vars_to_keep <- unique(c(
+ unlist(lapply(exprs, all.vars)), # vars referenced in summarise
+ dplyr::group_vars(.data) # vars needed for grouping
+ ))
+ # If exprs rely on the results of previous exprs
+ # (total = sum(x), mean = total / n())
+ # then not all vars will correspond to columns in the data,
+ # so don't try to select() them (use intersect() to exclude them)
+ # Note that this select() isn't useful for the Arrow summarize implementation
+ # because it will effectively project to keep what it needs anyway,
+ # but the data.frame fallback version does benefit from select here
+ .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data)))
+
+ # Try stuff, if successful return()
+ out <- try(do_arrow_summarize(.data, ...), silent = TRUE)
+ if (inherits(out, "try-error")) {
+ return(abandon_ship(call, .data, format(out)))
+ } else {
+ return(out)
+ }
+}
+summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query
+
+# This is the Arrow summarize implementation
+do_arrow_summarize <- function(.data, ..., .groups = NULL) {
+ exprs <- ensure_named_exprs(quos(...))
+
+ # Create a stateful environment for recording our evaluated expressions
+ # It's more complex than other places because a single summarize() expr
+ # may result in multiple query nodes (Aggregate, Project),
+ # and we have to walk through the expressions to disentangle them.
+ ctx <- env(
+ mask = arrow_mask(.data, aggregation = TRUE),
+ aggregations = empty_named_list(),
+ post_mutate = empty_named_list()
+ )
+ for (i in seq_along(exprs)) {
+ # Iterate over the indices and not the names because names may be repeated
+ # (which overwrites the previous name)
+ summarize_eval(
+ names(exprs)[i],
+ exprs[[i]],
+ ctx,
+ length(.data$group_by_vars) > 0
+ )
+ }
+
+ # Apply the results to the .data object.
+ # First, the aggregations
+ .data$aggregations <- ctx$aggregations
+ # Then collapse the query so that the resulting query object can have
+ # additional operations applied to it
+ out <- collapse.arrow_dplyr_query(.data)
+ # The expressions may have been translated into
+ # "first, aggregate, then transform the result further"
+ # nolint start
+ # For example,
+ # summarize(mean = sum(x) / n())
+ # is effectively implemented as
+ # summarize(..temp0 = sum(x), ..temp1 = n()) %>%
+ # mutate(mean = ..temp0 / ..temp1) %>%
+ # select(-starts_with("..temp"))
+ # If this is the case, there will be expressions in post_mutate
+ # nolint end
+ if (length(ctx$post_mutate)) {
+ # Append post_mutate, and make sure order is correct
+ # according to input exprs (also dropping ..temp columns)
+ out$selected_columns <- c(
+ out$selected_columns,
+ ctx$post_mutate
+ )[c(.data$group_by_vars, names(exprs))]
+ }
+
+ # If the object has .drop = FALSE and any group vars are dictionaries,
+ # we can't (currently) preserve the empty rows that dplyr does,
+ # so give a warning about that.
+ if (!dplyr::group_by_drop_default(.data)) {
+ group_by_exprs <- .data$selected_columns[.data$group_by_vars]
+ if (any(map_lgl(group_by_exprs, ~ inherits(.$type(), "DictionaryType")))) {
+ warning(
+ ".drop = FALSE currently not supported in Arrow aggregation",
+ call. = FALSE
+ )
+ }
+ }
+
+ # Handle .groups argument
+ if (length(.data$group_by_vars)) {
+ if (is.null(.groups)) {
+ # dplyr docs say:
+ # When ‘.groups’ is not specified, it is chosen based on the
+ # number of rows of the results:
+ # • If all the results have 1 row, you get "drop_last".
+ # • If the number of rows varies, you get "keep".
+ #
+ # But we don't support anything that returns multiple rows now
+ .groups <- "drop_last"
+ } else {
+ assert_that(is.string(.groups))
+ }
+ if (.groups == "drop_last") {
+ out$group_by_vars <- head(.data$group_by_vars, -1)
+ } else if (.groups == "keep") {
+ out$group_by_vars <- .data$group_by_vars
+ } else if (.groups == "rowwise") {
+ stop(arrow_not_supported('.groups = "rowwise"'))
+ } else if (.groups == "drop") {
+ # collapse() preserves groups so remove them
+ out <- dplyr::ungroup(out)
+ } else {
+ stop(paste("Invalid .groups argument:", .groups))
+ }
+ # TODO: shouldn't we be doing something with `drop_empty_groups` in summarize? (ARROW-14044)
+ out$drop_empty_groups <- .data$drop_empty_groups
+ }
+ out
+}
+
+arrow_eval_or_stop <- function(expr, mask) {
+ # TODO: change arrow_eval error handling behavior?
+ out <- arrow_eval(expr, mask)
+ if (inherits(out, "try-error")) {
+ msg <- handle_arrow_not_supported(out, format_expr(expr))
+ stop(msg, call. = FALSE)
+ }
+ out
+}
+
+summarize_projection <- function(.data) {
+ c(
+ map(.data$aggregations, ~ .$data),
+ .data$selected_columns[.data$group_by_vars]
+ )
+}
+
+format_aggregation <- function(x) {
+ paste0(x$fun, "(", x$data$ToString(), ")")
+}
+
+# This function handles each summarize expression and turns it into the
+# appropriate combination of (1) aggregations (possibly temporary) and
+# (2) post-aggregation transformations (mutate)
+# The function returns nothing: it assigns into the `ctx` environment
+summarize_eval <- function(name, quosure, ctx, hash, recurse = FALSE) {
+ expr <- quo_get_expr(quosure)
+ ctx$quo_env <- quo_get_env(quosure)
+
+ funs_in_expr <- all_funs(expr)
+ if (length(funs_in_expr) == 0) {
+ # If it is a scalar or field ref, no special handling required
+ ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
+ return()
+ }
+
+ # For the quantile() binding in the hash aggregation case, we need to mutate
+ # the list output from the Arrow hash_tdigest kernel to flatten it into a
+ # column of type float64. We do that by modifying the unevaluated expression
+ # to replace quantile(...) with arrow_list_element(quantile(...), 0L)
+ if (hash && "quantile" %in% funs_in_expr) {
+ expr <- wrap_hash_quantile(expr)
+ funs_in_expr <- all_funs(expr)
+ }
+
+ # Start inspecting the expr to see what aggregations it involves
+ agg_funs <- names(agg_funcs)
+ outer_agg <- funs_in_expr[1] %in% agg_funs
+ inner_agg <- funs_in_expr[-1] %in% agg_funs
+
+ # First, pull out any aggregations wrapped in other function calls
+ if (any(inner_agg)) {
+ expr <- extract_aggregations(expr, ctx)
+ }
+
+ # By this point, there are no more aggregation functions in expr
+ # except for possibly the outer function call:
+ # they've all been pulled out to ctx$aggregations, and in their place in expr
+ # there are variable names, which will correspond to field refs in the
+ # query object after aggregation and collapse().
+ # So if we want to know if there are any aggregations inside expr,
+ # we have to look for them by their new var names
+ inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+
+ if (outer_agg) {
+ # This is something like agg(fun(x, y)
+ # It just works by normal arrow_eval, unless there's a mix of aggs and
+ # columns in the original data like agg(fun(x, agg(x)))
+ # (but that will have been caught in extract_aggregations())
+ ctx$aggregations[[name]] <- arrow_eval_or_stop(
+ as_quosure(expr, ctx$quo_env),
+ ctx$mask
+ )
+ return()
+ } else if (all(inner_agg_exprs)) {
+ # Something like: fun(agg(x), agg(y))
+ # So based on the aggregations that have been extracted, mutate after
+ mutate_mask <- arrow_mask(
+ list(selected_columns = make_field_refs(names(ctx$aggregations)))
+ )
+ ctx$post_mutate[[name]] <- arrow_eval_or_stop(
+ as_quosure(expr, ctx$quo_env),
+ mutate_mask
+ )
+ return()
+ }
+
+ # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
+ # or aggregation functions that aren't supported in Arrow (not in agg_funcs)
+ stop(
+ handle_arrow_not_supported(quo_get_expr(quosure), format_expr(quosure)),
+ call. = FALSE
+ )
+}
+
+# This function recurses through expr, pulls out any aggregation expressions,
+# and inserts a variable name (field ref) in place of the aggregation
+extract_aggregations <- function(expr, ctx) {
+ # Keep the input in case we need to raise an error message with it
+ original_expr <- expr
+ funs <- all_funs(expr)
+ if (length(funs) == 0) {
+ return(expr)
+ } else if (length(funs) > 1) {
+ # Recurse more
+ expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
+ }
+ if (funs[1] %in% names(agg_funcs)) {
+ inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
+ if (any(inner_agg_exprs) & !all(inner_agg_exprs)) {
+ # We can't aggregate over a combination of dataset columns and other
+ # aggregations (e.g. sum(x - mean(x)))
+ # TODO: support in ARROW-13926
+ # TODO: Add "because" arg to explain _why_ it's not supported?
+ # TODO: this message could also say "not supported in summarize()"
+ # since some of these expressions may be legal elsewhere
+ stop(
+ handle_arrow_not_supported(original_expr, format_expr(original_expr)),
+ call. = FALSE
+ )
+ }
+
+ # We have an aggregation expression with no other aggregations inside it,
+ # so arrow_eval the expression on the data and give it a ..temp name prefix,
+ # then insert that name (symbol) back into the expression so that we can
+ # mutate() on the result of the aggregation and reference this field.
+ tmpname <- paste0("..temp", length(ctx$aggregations))
+ ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask)
+ expr <- as.symbol(tmpname)
+ }
+ expr
+}
+
+# This function recurses through expr and wraps each call to quantile() with a
+# call to arrow_list_element()
+wrap_hash_quantile <- function(expr) {
+ if (length(expr) == 1) {
+ return(expr)
+ } else {
+ if (is.call(expr) && expr[[1]] == quote(quantile)) {
+ return(str2lang(paste0("arrow_list_element(", deparse1(expr), ", 0L)")))
+ } else {
+ return(as.call(lapply(expr, wrap_hash_quantile)))
+ }
+ }
+}