summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/R/dplyr-mutate.R
blob: 2e5239484f633423e9f2bd7d16e16c0c709cdf80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

mutate.arrow_dplyr_query <- function(.data,
                                     ...,
                                     .keep = c("all", "used", "unused", "none"),
                                     .before = NULL,
                                     .after = NULL) {
  call <- match.call()
  exprs <- ensure_named_exprs(quos(...))

  .keep <- match.arg(.keep)
  .before <- enquo(.before)
  .after <- enquo(.after)

  if (.keep %in% c("all", "unused") && length(exprs) == 0) {
    # Nothing to do
    return(.data)
  }

  .data <- as_adq(.data)

  # Restrict the cases we support for now
  has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% names(agg_funcs))
  if (has_aggregations) {
    # ARROW-13926
    # mutate() on a grouped dataset does calculations within groups
    # This doesn't matter on scalar ops (arithmetic etc.) but it does
    # for things with aggregations (e.g. subtracting the mean)
    return(abandon_ship(call, .data, "window functions not currently supported in Arrow"))
  }

  mask <- arrow_mask(.data)
  results <- list()
  for (i in seq_along(exprs)) {
    # Iterate over the indices and not the names because names may be repeated
    # (which overwrites the previous name)
    new_var <- names(exprs)[i]
    results[[new_var]] <- arrow_eval(exprs[[i]], mask)
    if (inherits(results[[new_var]], "try-error")) {
      msg <- handle_arrow_not_supported(
        results[[new_var]],
        format_expr(exprs[[i]])
      )
      return(abandon_ship(call, .data, msg))
    } else if (!inherits(results[[new_var]], "Expression") &&
      !is.null(results[[new_var]])) {
      # We need some wrapping to handle literal values
      if (length(results[[new_var]]) != 1) {
        msg <- paste0("In ", new_var, " = ", format_expr(exprs[[i]]), ", only values of size one are recycled")
        return(abandon_ship(call, .data, msg))
      }
      results[[new_var]] <- Expression$scalar(results[[new_var]])
    }
    # Put it in the data mask too
    mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
  }

  old_vars <- names(.data$selected_columns)
  # Note that this is names(exprs) not names(results):
  # if results$new_var is NULL, that means we are supposed to remove it
  new_vars <- names(exprs)

  # Assign the new columns into the .data$selected_columns
  for (new_var in new_vars) {
    .data$selected_columns[[new_var]] <- results[[new_var]]
  }

  # Deduplicate new_vars and remove NULL columns from new_vars
  new_vars <- intersect(new_vars, names(.data$selected_columns))

  # Respect .before and .after
  if (!quo_is_null(.before) || !quo_is_null(.after)) {
    new <- setdiff(new_vars, old_vars)
    .data <- dplyr::relocate(.data, all_of(new), .before = !!.before, .after = !!.after)
  }

  # Respect .keep
  if (.keep == "none") {
    .data$selected_columns <- .data$selected_columns[new_vars]
  } else if (.keep != "all") {
    # "used" or "unused"
    used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE)
    if (.keep == "used") {
      .data$selected_columns[setdiff(old_vars, used_vars)] <- NULL
    } else {
      # "unused"
      .data$selected_columns[intersect(old_vars, used_vars)] <- NULL
    }
  }
  # Even if "none", we still keep group vars
  ensure_group_vars(.data)
}
mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query

transmute.arrow_dplyr_query <- function(.data, ...) {
  dots <- check_transmute_args(...)
  dplyr::mutate(.data, !!!dots, .keep = "none")
}
transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query

# This function is a copy of dplyr:::check_transmute_args at
# https://github.com/tidyverse/dplyr/blob/master/R/mutate.R
check_transmute_args <- function(..., .keep, .before, .after) {
  if (!missing(.keep)) {
    abort("`transmute()` does not support the `.keep` argument")
  }
  if (!missing(.before)) {
    abort("`transmute()` does not support the `.before` argument")
  }
  if (!missing(.after)) {
    abort("`transmute()` does not support the `.after` argument")
  }
  enquos(...)
}

ensure_named_exprs <- function(exprs) {
  # Check for unnamed expressions and fix if any
  unnamed <- !nzchar(names(exprs))
  # Deparse and take the first element in case they're long expressions
  names(exprs)[unnamed] <- map_chr(exprs[unnamed], format_expr)
  exprs
}