summaryrefslogtreecommitdiffstats
path: root/src/arrow/r/R/compute.R
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/arrow/r/R/compute.R309
1 files changed, 309 insertions, 0 deletions
diff --git a/src/arrow/r/R/compute.R b/src/arrow/r/R/compute.R
new file mode 100644
index 000000000..0a7d77a09
--- /dev/null
+++ b/src/arrow/r/R/compute.R
@@ -0,0 +1,309 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#' Call an Arrow compute function
+#'
+#' This function provides a lower-level API for calling Arrow functions by their
+#' string function name. You won't use it directly for most applications.
+#' Many Arrow compute functions are mapped to R methods,
+#' and in a `dplyr` evaluation context, [all Arrow functions][list_compute_functions()]
+#' are callable with an `arrow_` prefix.
+#' @param function_name string Arrow compute function name
+#' @param ... Function arguments, which may include `Array`, `ChunkedArray`, `Scalar`,
+#' `RecordBatch`, or `Table`.
+#' @param args list arguments as an alternative to specifying in `...`
+#' @param options named list of C++ function options.
+#' @details When passing indices in `...`, `args`, or `options`, express them as
+#' 0-based integers (consistent with C++).
+#' @return An `Array`, `ChunkedArray`, `Scalar`, `RecordBatch`, or `Table`, whatever the compute function results in.
+#' @seealso [Arrow C++ documentation](https://arrow.apache.org/docs/cpp/compute.html) for
+#' the functions and their respective options.
+#' @examplesIf arrow_available()
+#' a <- Array$create(c(1L, 2L, 3L, NA, 5L))
+#' s <- Scalar$create(4L)
+#' call_function("coalesce", a, s)
+#'
+#' a <- Array$create(rnorm(10000))
+#' call_function("quantile", a, options = list(q = seq(0, 1, 0.25)))
+#' @export
+#' @include array.R
+#' @include chunked-array.R
+#' @include scalar.R
+call_function <- function(function_name, ..., args = list(...), options = empty_named_list()) {
+ assert_that(is.string(function_name))
+ assert_that(is.list(options), !is.null(names(options)))
+
+ datum_classes <- c("Array", "ChunkedArray", "RecordBatch", "Table", "Scalar")
+ valid_args <- map_lgl(args, ~ inherits(., datum_classes))
+ if (!all(valid_args)) {
+ # Lame, just pick one to report
+ first_bad <- min(which(!valid_args))
+ stop(
+ "Argument ", first_bad, " is of class ", head(class(args[[first_bad]]), 1),
+ " but it must be one of ", oxford_paste(datum_classes, "or"),
+ call. = FALSE
+ )
+ }
+
+ compute__CallFunction(function_name, args, options)
+}
+
+#' List available Arrow C++ compute functions
+#'
+#' This function lists the names of all available Arrow C++ library compute functions.
+#' These can be called by passing to [call_function()], or they can be
+#' called by name with an `arrow_` prefix inside a `dplyr` verb.
+#'
+#' The resulting list describes the capabilities of your `arrow` build.
+#' Some functions, such as string and regular expression functions,
+#' require optional build-time C++ dependencies. If your `arrow` package
+#' was not compiled with those features enabled, those functions will
+#' not appear in this list.
+#'
+#' Some functions take options that need to be passed when calling them
+#' (in a list called `options`). These options require custom handling
+#' in C++; many functions already have that handling set up but not all do.
+#' If you encounter one that needs special handling for options, please
+#' report an issue.
+#'
+#' Note that this list does *not* enumerate all of the R bindings for these functions.
+#' The package includes Arrow methods for many base R functions that can
+#' be called directly on Arrow objects, as well as some tidyverse-flavored versions
+#' available inside `dplyr` verbs.
+#'
+#' @param pattern Optional regular expression to filter the function list
+#' @param ... Additional parameters passed to `grep()`
+#' @return A character vector of available Arrow C++ function names
+#' @examplesIf arrow_available()
+#' available_funcs <- list_compute_functions()
+#' utf8_funcs <- list_compute_functions(pattern = "^UTF8", ignore.case = TRUE)
+#' @export
+list_compute_functions <- function(pattern = NULL, ...) {
+ funcs <- compute__GetFunctionNames()
+ if (!is.null(pattern)) {
+ funcs <- grep(pattern, funcs, value = TRUE, ...)
+ }
+ # TODO: Filtering of hash funcs will already happen in C++ with ARROW-13943
+ funcs <- grep(
+ "^hash_",
+ funcs,
+ value = TRUE,
+ invert = TRUE
+ )
+ funcs
+}
+
+#' @export
+sum.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("sum", ..., na.rm = na.rm)
+}
+
+#' @export
+mean.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("mean", ..., na.rm = na.rm)
+}
+
+#' @export
+min.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("min_max", ..., na.rm = na.rm)$GetFieldByName("min")
+}
+
+#' @export
+max.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("min_max", ..., na.rm = na.rm)$GetFieldByName("max")
+}
+
+scalar_aggregate <- function(FUN, ..., na.rm = FALSE, min_count = 0L) {
+ a <- collect_arrays_from_dots(list(...))
+ if (FUN == "min_max" && na.rm && a$null_count == length(a)) {
+ Array$create(data.frame(min = Inf, max = -Inf))
+ # If na.rm == TRUE and all values in array are NA, R returns
+ # Inf/-Inf, which are type double. Since Arrow is type-stable
+ # and does not do that, we handle this special case here.
+ } else {
+ call_function(FUN, a, options = list(skip_nulls = na.rm, min_count = min_count))
+ }
+}
+
+collect_arrays_from_dots <- function(dots) {
+ # Given a list that may contain both Arrays and ChunkedArrays,
+ # return a single ChunkedArray containing all of those chunks
+ # (may return a regular Array if there is only one element in dots)
+ # If there is only one element and it is a scalar, it returns the scalar
+ if (length(dots) == 1) {
+ return(dots[[1]])
+ }
+
+ assert_that(all(map_lgl(dots, is.Array)))
+ arrays <- unlist(lapply(dots, function(x) {
+ if (inherits(x, "ChunkedArray")) {
+ x$chunks
+ } else {
+ x
+ }
+ }))
+ ChunkedArray$create(!!!arrays)
+}
+
+#' @export
+quantile.ArrowDatum <- function(x,
+ probs = seq(0, 1, 0.25),
+ na.rm = FALSE,
+ type = 7,
+ interpolation = c("linear", "lower", "higher", "nearest", "midpoint"),
+ ...) {
+ if (inherits(x, "Scalar")) x <- Array$create(x)
+ assert_is(probs, c("numeric", "integer"))
+ assert_that(length(probs) > 0)
+ assert_that(all(probs >= 0 & probs <= 1))
+ if (!na.rm && x$null_count > 0) {
+ stop("Missing values not allowed if 'na.rm' is FALSE", call. = FALSE)
+ }
+ if (type != 7) {
+ stop(
+ "Argument `type` not supported in Arrow. To control the quantile ",
+ "interpolation algorithm, set argument `interpolation` to one of: ",
+ "\"linear\" (the default), \"lower\", \"higher\", \"nearest\", or ",
+ "\"midpoint\".",
+ call. = FALSE
+ )
+ }
+ interpolation <- QuantileInterpolation[[toupper(match.arg(interpolation))]]
+ out <- call_function("quantile", x, options = list(q = probs, interpolation = interpolation))
+ if (length(out) == 0) {
+ # When there are no non-missing values in the data, the Arrow quantile
+ # function returns an empty Array, but for consistency with the R quantile
+ # function, we want an Array of NA_real_ with the same length as probs
+ out <- Array$create(rep(NA_real_, length(probs)))
+ }
+ out
+}
+
+#' @export
+median.ArrowDatum <- function(x, na.rm = FALSE, ...) {
+ if (!na.rm && x$null_count > 0) {
+ Scalar$create(NA_real_)
+ } else {
+ Scalar$create(quantile(x, probs = 0.5, na.rm = TRUE, ...))
+ }
+}
+
+#' @export
+unique.ArrowDatum <- function(x, incomparables = FALSE, ...) {
+ call_function("unique", x)
+}
+
+#' @export
+any.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("any", ..., na.rm = na.rm)
+}
+
+#' @export
+all.ArrowDatum <- function(..., na.rm = FALSE) {
+ scalar_aggregate("all", ..., na.rm = na.rm)
+}
+
+#' `match` and `%in%` for Arrow objects
+#'
+#' `base::match()` is not a generic, so we can't just define Arrow methods for
+#' it. This function exposes the analogous functions in the Arrow C++ library.
+#'
+#' @param x `Scalar`, `Array` or `ChunkedArray`
+#' @param table `Scalar`, Array`, `ChunkedArray`, or R vector lookup table.
+#' @param ... additional arguments, ignored
+#' @return `match_arrow()` returns an `int32`-type Arrow object of the same length
+#' and type as `x` with the (0-based) indexes into `table`. `is_in()` returns a
+#' `boolean`-type Arrow object of the same length and type as `x` with values indicating
+#' per element of `x` it it is present in `table`.
+#' @examplesIf arrow_available()
+#' # note that the returned value is 0-indexed
+#' cars_tbl <- arrow_table(name = rownames(mtcars), mtcars)
+#' match_arrow(Scalar$create("Mazda RX4 Wag"), cars_tbl$name)
+#'
+#' is_in(Array$create("Mazda RX4 Wag"), cars_tbl$name)
+#'
+#' # Although there are multiple matches, you are returned the index of the first
+#' # match, as with the base R equivalent
+#' match(4, mtcars$cyl) # 1-indexed
+#' match_arrow(Scalar$create(4), cars_tbl$cyl) # 0-indexed
+#'
+#' # If `x` contains multiple values, you are returned the indices of the first
+#' # match for each value.
+#' match(c(4, 6, 8), mtcars$cyl)
+#' match_arrow(Array$create(c(4, 6, 8)), cars_tbl$cyl)
+#'
+#' # Return type matches type of `x`
+#' is_in(c(4, 6, 8), mtcars$cyl) # returns vector
+#' is_in(Scalar$create(4), mtcars$cyl) # returns Scalar
+#' is_in(Array$create(c(4, 6, 8)), cars_tbl$cyl) # returns Array
+#' is_in(ChunkedArray$create(c(4, 6), 8), cars_tbl$cyl) # returns ChunkedArray
+#' @export
+match_arrow <- function(x, table, ...) {
+ if (!inherits(x, "ArrowDatum")) {
+ x <- Array$create(x)
+ }
+
+ if (!inherits(table, c("Array", "ChunkedArray"))) {
+ table <- Array$create(table)
+ }
+ call_function("index_in_meta_binary", x, table)
+}
+
+#' @rdname match_arrow
+#' @export
+is_in <- function(x, table, ...) {
+ if (!inherits(x, "ArrowDatum")) {
+ x <- Array$create(x)
+ }
+
+ if (!inherits(table, c("Array", "DictionaryArray", "ChunkedArray"))) {
+ table <- Array$create(table)
+ }
+ call_function("is_in_meta_binary", x, table)
+}
+
+#' `table` for Arrow objects
+#'
+#' This function tabulates the values in the array and returns a table of counts.
+#' @param x `Array` or `ChunkedArray`
+#' @return A `StructArray` containing "values" (same type as `x`) and "counts"
+#' `Int64`.
+#' @examplesIf arrow_available()
+#' cyl_vals <- Array$create(mtcars$cyl)
+#' counts <- value_counts(cyl_vals)
+#' @export
+value_counts <- function(x) {
+ call_function("value_counts", x)
+}
+
+#' Cast options
+#'
+#' @param safe logical: enforce safe conversion? Default `TRUE`
+#' @param ... additional cast options, such as `allow_int_overflow`,
+#' `allow_time_truncate`, and `allow_float_truncate`, which are set to `!safe`
+#' by default
+#' @return A list
+#' @export
+#' @keywords internal
+cast_options <- function(safe = TRUE, ...) {
+ opts <- list(
+ allow_int_overflow = !safe,
+ allow_time_truncate = !safe,
+ allow_float_truncate = !safe
+ )
+ modifyList(opts, list(...))
+}