summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/tests/testthat/test-dplyr-mutate.R
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/arrow/r/tests/testthat/test-dplyr-mutate.R522
1 files changed, 522 insertions, 0 deletions
diff --git a/src/arrow/r/tests/testthat/test-dplyr-mutate.R b/src/arrow/r/tests/testthat/test-dplyr-mutate.R
new file mode 100644
index 000000000..886ec9e42
--- /dev/null
+++ b/src/arrow/r/tests/testthat/test-dplyr-mutate.R
@@ -0,0 +1,522 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+skip_if_not_available("dataset")
+
+library(dplyr, warn.conflicts = FALSE)
+library(stringr)
+
+tbl <- example_data # shared fixture for the tests below; example_data is defined outside this file
+# Add some better string data
+tbl$verses <- verses[[1]]
+# c(" a ", "  b  ", "   c   ", ...) increasing padding on both sides
+# nchar = 3, 5, 7, 9, 11, 13, 15, 17, 19, 21 (width = 2 * i + 1 for i in 1:10)
+tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2 * (1:10) + 1, side = "both")
+
+test_that("mutate() is lazy", {
+  # Build the query first, then check its class: mutate() on a RecordBatch
+  # is expected to return an "arrow_dplyr_query" object.
+  lazy_query <- tbl %>%
+    record_batch() %>%
+    mutate(int = int + 6L)
+  expect_s3_class(lazy_query, "arrow_dplyr_query")
+})
+
+test_that("basic mutate", { # select -> filter -> mutate pipeline that overwrites `int`
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably runs the pipeline with `tbl` in place of .input in dplyr and Arrow and compares the results -- confirm
+ .input %>%
+ select(int, chr) %>%
+ filter(int > 5) %>%
+ mutate(int = int + 6L) %>% # 6L is an integer literal
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate() with NULL inputs", { # assigns NULL to an existing column inside mutate()
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ mutate(int = NULL) %>% # NULL on the right-hand side is how mutate() removes a column
+ collect(),
+ tbl
+ )
+})
+
+test_that("empty mutate()", { # mutate() called with no expressions at all
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ mutate() %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("transmute", { # same pipeline as "basic mutate" but with transmute() as the last verb
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, chr) %>%
+ filter(int > 5) %>%
+ transmute(int = int + 6L) %>% # 6L is an integer literal
+ collect(),
+ tbl
+ )
+})
+
+test_that("transmute() with NULL inputs", { # the only expression given to transmute() is a NULL assignment
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ transmute(int = NULL) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("empty transmute()", { # transmute() called with no expressions at all
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ transmute() %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("transmute() with unsupported arguments", {
+  # Every case below feeds the same Arrow Table into transmute() together
+  # with one argument that transmute() is expected to reject, so build the
+  # Table once up front.
+  arrow_tbl <- Table$create(tbl)
+
+  # `.keep`
+  expect_error(
+    arrow_tbl %>% transmute(int = int + 42L, .keep = "all"),
+    "`transmute()` does not support the `.keep` argument",
+    fixed = TRUE
+  )
+
+  # `.before`
+  expect_error(
+    arrow_tbl %>% transmute(int = int + 42L, .before = lgl),
+    "`transmute()` does not support the `.before` argument",
+    fixed = TRUE
+  )
+
+  # `.after`
+  expect_error(
+    arrow_tbl %>% transmute(int = int + 42L, .after = chr),
+    "`transmute()` does not support the `.after` argument",
+    fixed = TRUE
+  )
+})
+
+test_that("transmute() defuses dots arguments (ARROW-13262)", {
+  # The warning is expected to quote the user's expression verbatim, so the
+  # message is matched literally (fixed = TRUE).
+  expected_warning <-
+    "Expression stringr::str_c(chr, chr) not supported in Arrow; pulling data into R"
+  arrow_tbl <- Table$create(tbl)
+
+  expect_warning(
+    arrow_tbl %>%
+      transmute(stringr::str_c(chr, chr)) %>%
+      collect(),
+    expected_warning,
+    fixed = TRUE
+  )
+})
+
+test_that("mutate and refer to previous mutants", { # a later expression uses a column created earlier in the same mutate()
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, verses) %>%
+ mutate(
+ line_lengths = nchar(verses),
+ longer = line_lengths * 10 # refers to line_lengths defined on the previous line
+ ) %>%
+ filter(line_lengths > 15) %>% # the new column is also used by a later verb
+ collect(),
+ tbl
+ )
+})
+
+test_that("nchar() arguments", { # first call: type = "bytes" with no warning expected; second call: allowNA = TRUE with a warning expected
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, verses) %>%
+ mutate(
+ line_lengths = nchar(verses, type = "bytes"),
+ longer = line_lengths * 10
+ ) %>%
+ filter(line_lengths > 15) %>%
+ collect(),
+ tbl
+ )
+ # This tests the whole abandon_ship() machinery
+ compare_dplyr_binding(
+ .input %>%
+ select(int, verses) %>%
+ mutate(
+ line_lengths = nchar(verses, type = "bytes", allowNA = TRUE), # allowNA = TRUE is what the expected warning below complains about
+ longer = line_lengths * 10
+ ) %>%
+ filter(line_lengths > 15) %>%
+ collect(),
+ tbl,
+ warning = paste0( # parentheses are escaped, so this is presumably matched as a regular expression -- confirm in the helper
+ "In nchar\\(verses, type = \"bytes\", allowNA = TRUE\\), ",
+ "allowNA = TRUE not supported by Arrow; pulling data into R"
+ )
+ )
+})
+
+test_that("mutate with .data pronoun", { # refers to a freshly created column through .data$
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, verses) %>%
+ mutate(
+ line_lengths = str_length(verses), # stringr::str_length(); stringr is attached at the top of the file
+ longer = .data$line_lengths * 10 # column created on the previous line, accessed via the .data pronoun
+ ) %>%
+ filter(line_lengths > 15) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with unnamed expressions", { # neither argument to mutate() has a `name =` part
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ int, # bare column name
+ nchar(padded_strings) # expression
+ ) %>%
+ filter(int > 5) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with reassigning same name", { # NOTE(review): the title says mutate but the body exercises transmute()
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ transmute(
+ new = lgl,
+ new = chr # same output name `new` assigned a second time
+ ) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("mutate with single value for recycling", { # a length-one constant is assigned as a new column
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, padded_strings) %>%
+ mutate(
+ dr_bronner = 1 # ALL ONE!
+ ) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("dplyr::mutate's examples", { # NOTE(review): compare_dplyr_binding() is defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ # Newly created variables are available immediately
+ compare_dplyr_binding(
+ .input %>%
+ select(name, mass) %>%
+ mutate(
+ mass2 = mass * 2,
+ mass2_squared = mass2 * mass2
+ ) %>%
+ collect(),
+ starwars # this is a test tibble that ships with dplyr
+ )
+
+ # As well as adding new variables, you can use mutate() to
+ # remove variables and modify existing variables.
+ compare_dplyr_binding(
+ .input %>%
+ select(name, height, mass, homeworld) %>%
+ mutate(
+ mass = NULL,
+ height = height * 0.0328084 # convert to feet
+ ) %>%
+ collect(),
+ starwars
+ )
+
+ # Examples we don't support should succeed
+ # but warn that they're pulling data into R to do so
+
+ # across and autosplicing: ARROW-11699
+ compare_dplyr_binding(
+ .input %>%
+ select(name, homeworld, species) %>%
+ mutate(across(!name, as.factor)) %>%
+ collect(),
+ starwars,
+ warning = "Expression across.*not supported in Arrow" # contains `.*`, so presumably matched as a regular expression -- confirm in the helper
+ )
+
+ # group_by then mutate
+ compare_dplyr_binding(
+ .input %>%
+ select(name, mass, homeworld) %>%
+ group_by(homeworld) %>%
+ mutate(rank = min_rank(desc(mass))) %>%
+ collect(),
+ starwars,
+ warning = TRUE # NOTE(review): presumably means "any warning" -- confirm in the helper
+ )
+
+ # `.before` and `.after` experimental args: ARROW-11701
+ df <- tibble(x = 1, y = 2)
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y) %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #>       x     y     z
+ #>   <dbl> <dbl> <dbl>
+ #> 1     1     2     3
+
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .before = 1) %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #>       z     x     y
+ #>   <dbl> <dbl> <dbl>
+ #> 1     3     1     2
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .after = x) %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #>       x     z     y
+ #>   <dbl> <dbl> <dbl>
+ #> 1     1     3     2
+
+ # By default, mutate() keeps all columns from the input data.
+ # Experimental: You can override with `.keep`
+ df <- tibble(x = 1, y = 2, a = "a", b = "b")
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .keep = "all") %>% collect(), # the default
+ df
+ )
+ #> # A tibble: 1 x 5
+ #>       x     y a     b         z
+ #>   <dbl> <dbl> <chr> <chr> <dbl>
+ #> 1     1     2 a     b         3
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .keep = "used") %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #>       x     y     z
+ #>   <dbl> <dbl> <dbl>
+ #> 1     1     2     3
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .keep = "unused") %>% collect(),
+ df
+ )
+ #> # A tibble: 1 x 3
+ #>   a     b         z
+ #>   <chr> <chr> <dbl>
+ #> 1 a     b         3
+ compare_dplyr_binding(
+ .input %>% mutate(z = x + y, .keep = "none") %>% collect(), # same as transmute()
+ df
+ )
+ #> # A tibble: 1 x 1
+ #>       z
+ #>   <dbl>
+ #> 1     3
+
+ # Grouping ----------------------------------------
+ # The mutate operation may yield different results on grouped
+ # tibbles because the expressions are computed within groups.
+ # The following normalises `mass` by the global average:
+ # TODO: ARROW-13926
+ compare_dplyr_binding(
+ .input %>%
+ select(name, mass, species) %>%
+ mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>%
+ collect(),
+ starwars,
+ warning = "window function"
+ )
+})
+
+test_that("Can mutate after group_by as long as there are no aggregations", { # two cases compared against dplyr, then two cases that must warn
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ select(int, chr) %>%
+ group_by(chr) %>%
+ mutate(int = int + 6L) %>% # plain arithmetic, no aggregate function involved
+ collect(),
+ tbl
+ )
+ compare_dplyr_binding(
+ .input %>%
+ select(mean = int, chr) %>%
+ # rename `int` to `mean` and use `mean` in `mutate()` to test that
+ # `all_funs()` does not incorrectly identify it as an aggregate function
+ group_by(chr) %>%
+ mutate(mean = mean + 6L) %>% # here `mean` is a column, not a call to mean()
+ collect(),
+ tbl
+ )
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ select(int, chr) %>%
+ group_by(chr) %>%
+ mutate(avg_int = mean(int)) %>% # aggregate inside a grouped mutate()
+ collect(),
+ "window functions not currently supported in Arrow; pulling data into R",
+ fixed = TRUE
+ )
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ select(mean = int, chr) %>%
+ # rename `int` to `mean` and use `mean(mean)` in `mutate()` to test that
+ # `all_funs()` detects `mean()` despite the collision with a column name
+ group_by(chr) %>%
+ mutate(avg_int = mean(mean)) %>%
+ collect(),
+ "window functions not currently supported in Arrow; pulling data into R",
+ fixed = TRUE
+ )
+})
+
+test_that("handle bad expressions", {
+ # TODO: search for functions other than mean() (see above test)
+ # that need to be forced to fail because they error ambiguously
+
+ with_language("fr", { # NOTE(review): with_language() is defined outside this file; presumably switches R's message language so the French error below is produced -- confirm
+ # expect_warning(., NA) asserts that no warning is raised. The usual behavior
+ # when mutate() hits an expression it can't evaluate is to raise a warning,
+ # collect() to R, and retry. But we want this to error the first time because
+ # it's a user error (an unknown variable), not solvable by retrying in R
+ expect_warning(
+ expect_error(
+ Table$create(tbl) %>% mutate(newvar = NOTAVAR + 2),
+ "objet 'NOTAVAR' introuvable"
+ ),
+ NA
+ )
+ })
+})
+
+test_that("Can't just add a vector column with mutate()", {
+  # A length-10 vector is not a size-one value, so the Arrow path is expected
+  # to warn that it is pulling data into R; the result is then compared with
+  # a plain tibble built from the same inputs.
+  expected <- tibble::tibble(int = tbl$int, again = 1:10)
+  recycling_warning <-
+    "In again = 1:10, only values of size one are recycled; pulling data into R"
+
+  expect_warning(
+    expect_equal(
+      Table$create(tbl) %>%
+        select(int) %>%
+        mutate(again = 1:10),
+      expected
+    ),
+    recycling_warning
+  )
+})
+
+test_that("print a mutated table", { # checks the printed form of a query that has not been collected
+ expect_output(
+ Table$create(tbl) %>%
+ select(int) %>%
+ mutate(twice = int * 2) %>%
+ print(), # no collect(): the query object itself is printed
+ "InMemoryDataset (query)
+int: int32
+twice: double (multiply_checked(int, 2))
+
+See $.data for the source Arrow object",
+ fixed = TRUE # the expected text contains `(`, `$` and `.`, so it is matched literally
+ )
+})
+
+test_that("mutate and write_dataset", {
+  skip_if_not_available("dataset")
+  # See related test in test-dataset.R
+  #
+  # Round trip: mutate() a RecordBatch, write it as a feather dataset
+  # partitioned by `int`, reopen the dataset and query it again.
+
+  first_date <- lubridate::ymd_hms("2015-04-29 03:12:39")
+  df1 <- tibble(
+    int = 1:10,
+    dbl = as.numeric(1:10),
+    lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2),
+    chr = letters[1:10],
+    fct = factor(LETTERS[1:10]),
+    ts = first_date + lubridate::days(1:10)
+  )
+
+  second_date <- lubridate::ymd_hms("2017-03-09 07:01:02")
+  df2 <- tibble(
+    int = 101:110,
+    dbl = c(as.numeric(51:59), NaN),
+    lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2),
+    chr = letters[10:1],
+    fct = factor(LETTERS[10:1]),
+    ts = second_date + lubridate::days(10:1)
+  )
+
+  dst_dir <- tempfile()
+  # The write below creates one "int=<value>" directory per partition under
+  # dst_dir. Remove the whole tree when the test finishes, pass or fail, so
+  # the test does not leave files behind in the session temp directory.
+  on.exit(unlink(dst_dir, recursive = TRUE), add = TRUE)
+
+  stacked <- record_batch(rbind(df1, df2))
+  stacked %>%
+    mutate(twice = int * 2) %>%
+    group_by(int) %>%
+    write_dataset(dst_dir, format = "feather")
+  expect_true(dir.exists(dst_dir))
+  expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep = "=")))
+
+  new_ds <- open_dataset(dst_dir, format = "feather")
+
+  # The dataset holds rows from both df1 (int 1:10) and df2 (int 101:110), so
+  # the Arrow side needs the upper bound `integer < 11` to keep only df1's
+  # rows; the reference side starts from df1 alone and only needs the lower
+  # bound. Only `mean(integer)` is compared, so `twice` is selected but its
+  # values are not asserted here.
+  expect_equal(
+    new_ds %>%
+      select(string = chr, integer = int, twice) %>%
+      filter(integer > 6 & integer < 11) %>%
+      collect() %>%
+      summarize(mean = mean(integer)),
+    df1 %>%
+      select(string = chr, integer = int) %>%
+      mutate(twice = integer * 2) %>%
+      filter(integer > 6) %>%
+      summarize(mean = mean(integer))
+  )
+})
+
+test_that("mutate and pmin/pmax", { # pmax()/pmin() inside mutate(), with and without na.rm = TRUE
+ df <- tibble(
+ city = c("Chillan", "Valdivia", "Osorno"),
+ val1 = c(200, 300, NA), # row 1: no NA; row 2: val1 only; row 3: all NA
+ val2 = c(100, NA, NA),
+ val3 = c(0, NA, NA)
+ )
+
+ compare_dplyr_binding( # NOTE(review): helper defined outside this file; presumably compares dplyr and Arrow results -- confirm
+ .input %>%
+ mutate(
+ max_val_1 = pmax(val1, val2, val3),
+ max_val_2 = pmax(val1, val2, val3, na.rm = TRUE),
+ min_val_1 = pmin(val1, val2, val3),
+ min_val_2 = pmin(val1, val2, val3, na.rm = TRUE)
+ ) %>%
+ collect(),
+ df
+ )
+
+ compare_dplyr_binding(
+ .input %>%
+ mutate(
+ max_val_1 = pmax(val1 - 100, 200, val1 * 100, na.rm = TRUE), # mixes column expressions with a scalar constant
+ min_val_1 = pmin(val1 - 100, 100, val1 * 100, na.rm = TRUE), # trailing comma is in the original call
+ ) %>%
+ collect(),
+ df
+ )
+})