summary | refs | log | tree | commit | diff | stats
path: root/src/arrow/r/R/dplyr-mutate.R
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- src/arrow/r/R/dplyr-mutate.R 140
1 files changed, 140 insertions, 0 deletions
diff --git a/src/arrow/r/R/dplyr-mutate.R b/src/arrow/r/R/dplyr-mutate.R
new file mode 100644
index 000000000..2e5239484
--- /dev/null
+++ b/src/arrow/r/R/dplyr-mutate.R
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+# mutate() method for arrow_dplyr_query objects (aliased below for Dataset and
+# ArrowTabular).
+#
+# Each expression in `...` is evaluated with arrow_eval() against an
+# arrow_mask() built from `.data`, and the results are written into
+# `.data$selected_columns`. When an expression cannot be handled, the call is
+# handed over to abandon_ship() together with an explanatory message.
+#
+# Arguments:
+#   .data     Object to mutate; converted with as_adq() once there is work to
+#             do.
+#   ...       Expressions defining columns. Unnamed ones are named by
+#             ensure_named_exprs(). An expression that evaluates to NULL
+#             removes the column of that name.
+#   .keep     One of "all", "used", "unused", "none": which pre-existing
+#             columns to retain. Columns assigned by `...` are always kept.
+#   .before,
+#   .after    Optional positions for newly created columns, forwarded to
+#             dplyr::relocate().
+#
+# Returns `.data` untouched when there is nothing to do, the value of
+# abandon_ship() on the fallback paths, and otherwise the result of
+# ensure_group_vars() applied to the updated query.
+mutate.arrow_dplyr_query <- function(.data,
+                                     ...,
+                                     .keep = c("all", "used", "unused", "none"),
+                                     .before = NULL,
+                                     .after = NULL) {
+  call <- match.call()
+  exprs <- ensure_named_exprs(quos(...))
+
+  .keep <- match.arg(.keep)
+  .before <- enquo(.before)
+  .after <- enquo(.after)
+
+  if (.keep %in% c("all", "unused") && length(exprs) == 0) {
+    # Nothing to do
+    return(.data)
+  }
+
+  .data <- as_adq(.data)
+
+  # Restrict the cases we support for now
+  has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% names(agg_funcs))
+  if (has_aggregations) {
+    # ARROW-13926
+    # mutate() on a grouped dataset does calculations within groups
+    # This doesn't matter on scalar ops (arithmetic etc.) but it does
+    # for things with aggregations (e.g. subtracting the mean)
+    return(abandon_ship(call, .data, "window functions not currently supported in Arrow"))
+  }
+
+  mask <- arrow_mask(.data)
+  results <- list()
+  for (i in seq_along(exprs)) {
+    # Iterate over the indices and not the names because names may be repeated
+    # (which overwrites the previous name)
+    new_var <- names(exprs)[i]
+    results[[new_var]] <- arrow_eval(exprs[[i]], mask)
+    if (inherits(results[[new_var]], "try-error")) {
+      msg <- handle_arrow_not_supported(
+        results[[new_var]],
+        format_expr(exprs[[i]])
+      )
+      return(abandon_ship(call, .data, msg))
+    } else if (!inherits(results[[new_var]], "Expression") &&
+      !is.null(results[[new_var]])) {
+      # We need some wrapping to handle literal values
+      if (length(results[[new_var]]) != 1) {
+        msg <- paste0(
+          "In ", new_var, " = ", format_expr(exprs[[i]]),
+          ", only values of size one are recycled"
+        )
+        return(abandon_ship(call, .data, msg))
+      }
+      results[[new_var]] <- Expression$scalar(results[[new_var]])
+    }
+    # Put it in the data mask too so later expressions can refer to it
+    mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
+  }
+
+  old_vars <- names(.data$selected_columns)
+  # Note that this is names(exprs) not names(results):
+  # if results$new_var is NULL, that means we are supposed to remove it
+  new_vars <- names(exprs)
+
+  # Assign the new columns into the .data$selected_columns
+  for (new_var in new_vars) {
+    .data$selected_columns[[new_var]] <- results[[new_var]]
+  }
+
+  # Deduplicate new_vars and remove NULL columns from new_vars
+  new_vars <- intersect(new_vars, names(.data$selected_columns))
+
+  # Respect .before and .after
+  if (!quo_is_null(.before) || !quo_is_null(.after)) {
+    new <- setdiff(new_vars, old_vars)
+    .data <- dplyr::relocate(.data, all_of(new), .before = !!.before, .after = !!.after)
+  }
+
+  # Respect .keep
+  if (.keep == "none") {
+    .data$selected_columns <- .data$selected_columns[new_vars]
+  } else if (.keep != "all") {
+    # "used" or "unused"
+    used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE)
+    # Columns assigned by `...` are always kept, as in dplyr, even when they
+    # overwrite a pre-existing column; only untouched old columns may be dropped
+    droppable <- setdiff(old_vars, new_vars)
+    if (.keep == "used") {
+      .data$selected_columns[setdiff(droppable, used_vars)] <- NULL
+    } else {
+      # "unused"
+      .data$selected_columns[intersect(droppable, used_vars)] <- NULL
+    }
+  }
+  # Even if "none", we still keep group vars
+  ensure_group_vars(.data)
+}
+mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query
+
+# transmute() method: validates the arguments with check_transmute_args(),
+# then delegates to dplyr::mutate() with `.keep = "none"`.
+transmute.arrow_dplyr_query <- function(.data, ...) {
+  captured <- check_transmute_args(...)
+  dplyr::mutate(.data, !!!captured, .keep = "none")
+}
+# Dataset and ArrowTabular share the same implementation
+transmute.ArrowTabular <- transmute.arrow_dplyr_query
+transmute.Dataset <- transmute.arrow_dplyr_query
+
+# This function is a copy of dplyr:::check_transmute_args at
+# https://github.com/tidyverse/dplyr/blob/master/R/mutate.R
+#
+# Aborts if any of `.keep`, `.before` or `.after` was supplied (checked in
+# that order); otherwise returns the expressions in `...` captured by
+# enquos().
+check_transmute_args <- function(..., .keep, .before, .after) {
+  keep_given <- !missing(.keep)
+  before_given <- !missing(.before)
+  after_given <- !missing(.after)
+
+  if (keep_given || before_given || after_given) {
+    # Report the first offending argument only
+    complaint <- if (keep_given) {
+      "`transmute()` does not support the `.keep` argument"
+    } else if (before_given) {
+      "`transmute()` does not support the `.before` argument"
+    } else {
+      "`transmute()` does not support the `.after` argument"
+    }
+    abort(complaint)
+  }
+
+  enquos(...)
+}
+
+# Give every expression in `exprs` a name: entries whose name is the empty
+# string are labelled with the text format_expr() produces for them.
+# Returns `exprs` with the completed names.
+ensure_named_exprs <- function(exprs) {
+  labels <- names(exprs)
+  needs_label <- !nzchar(labels)
+  # Only the unnamed entries are formatted; existing names are left alone
+  labels[needs_label] <- map_chr(exprs[needs_label], format_expr)
+  names(exprs) <- labels
+  exprs
+}