diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
commit | e6918187568dbd01842d8d1d2c808ce16a894239 (patch) | |
tree | 64f88b554b444a49f656b6c656111a145cbbaa28 /src/arrow/r/R/arrow-datum.R | |
parent | Initial commit. (diff) | |
download | ceph-upstream/18.2.2.tar.xz ceph-upstream/18.2.2.zip |
Adding upstream version 18.2.2.upstream/18.2.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | src/arrow/r/R/arrow-datum.R | 266 |
1 files changed, 266 insertions, 0 deletions
diff --git a/src/arrow/r/R/arrow-datum.R b/src/arrow/r/R/arrow-datum.R new file mode 100644 index 000000000..557321f68 --- /dev/null +++ b/src/arrow/r/R/arrow-datum.R @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#' @include arrow-package.R + +# Base class for Array, ChunkedArray, and Scalar, for S3 method dispatch only. +# Does not exist in C++ class hierarchy +ArrowDatum <- R6Class("ArrowDatum", + inherit = ArrowObject, + public = list( + cast = function(target_type, safe = TRUE, ...) { + opts <- cast_options(safe, ...) + opts$to_type <- as_type(target_type) + call_function("cast", self, options = opts) + } + ) +) + +#' @export +length.ArrowDatum <- function(x) x$length() + +#' @export +is.finite.ArrowDatum <- function(x) { + is_fin <- call_function("is_finite", x) + # for compatibility with base::is.finite(), return FALSE for NA_real_ + is_fin & !is.na(is_fin) +} + +#' @export +is.infinite.ArrowDatum <- function(x) { + is_inf <- call_function("is_inf", x) + # for compatibility with base::is.infinite(), return FALSE for NA_real_ + is_inf & !is.na(is_inf) +} + +#' @export +is.na.ArrowDatum <- function(x) { + call_function("is_null", x, options = list(nan_is_null = TRUE)) +} + +#' @export +is.nan.ArrowDatum <- function(x) { + if (x$type_id() %in% TYPES_WITH_NAN) { + # TODO: if an option is added to the is_nan kernel to treat NA as NaN, + # use that to simplify the code here (ARROW-13366) + call_function("is_nan", x) & call_function("is_valid", x) + } else { + Scalar$create(FALSE)$as_array(length(x)) + } +} + +#' @export +as.vector.ArrowDatum <- function(x, mode) { + x$as_vector() +} + +#' @export +Ops.ArrowDatum <- function(e1, e2) { + if (.Generic == "!") { + eval_array_expression(.Generic, e1) + } else if (.Generic %in% names(.array_function_map)) { + eval_array_expression(.Generic, e1, e2) + } else { + stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call. = FALSE) + } +} + +# Wrapper around call_function that: +# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked") +# (2) wraps R input args as Array or Scalar +eval_array_expression <- function(FUN, + ..., + args = list(...), + options = empty_named_list()) { + if (FUN == "-" && length(args) == 1L) { + if (inherits(args[[1]], "ArrowObject")) { + return(eval_array_expression("negate_checked", args[[1]])) + } else { + return(-args[[1]]) + } + } + args <- lapply(args, .wrap_arrow, FUN) + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (FUN == "/") { + # TODO: omg so many ways it's wrong to assume these types + args <- map(args, ~ .$cast(float64())) + } else if (FUN == "%/%") { + # In R, integer division works like floor(float division) + out <- eval_array_expression("/", args = args, options = options) + return(out$cast(int32(), allow_float_truncate = TRUE)) + } else if (FUN == "%%") { + # We can't simply do {e1 - e2 * ( e1 %/% e2 )} since Ops.Array evaluates + # eagerly, but we can build that up + quotient <- eval_array_expression("%/%", args = args) + base <- eval_array_expression("*", quotient, args[[2]]) + # this cast is to ensure that the result of this and e1 are the same + # (autocasting only applies to scalars) + base <- base$cast(args[[1]]$type) + return(eval_array_expression("-", args[[1]], base)) + } + + call_function( + .array_function_map[[FUN]] %||% FUN, + args = args, + options = options + ) +} + +.wrap_arrow <- function(arg, fun) { + if (!inherits(arg, "ArrowObject")) { + # TODO: Array$create if lengths are equal? + if (fun == "%in%") { + arg <- Array$create(arg) + } else { + arg <- Scalar$create(arg) + } + } + arg +} + +#' @export +na.omit.ArrowDatum <- function(object, ...) { + object$Filter(!is.na(object)) +} + +#' @export +na.exclude.ArrowDatum <- na.omit.ArrowDatum + +#' @export +na.fail.ArrowDatum <- function(object, ...) { + if (object$null_count > 0) { + stop("missing values in object", call. = FALSE) + } + object +} + +filter_rows <- function(x, i, keep_na = TRUE, ...) { + # General purpose function for [ row subsetting with R semantics + # Based on the input for `i`, calls x$Filter, x$Slice, or x$Take + nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like + if (is.logical(i)) { + if (isTRUE(i)) { + # Shortcut without doing any work + x + } else { + i <- rep_len(i, nrows) # For R recycling behavior; consider vctrs::vec_recycle() + x$Filter(i, keep_na) + } + } else if (is.numeric(i)) { + if (all(i < 0)) { + # in R, negative i means "everything but i" + i <- setdiff(seq_len(nrows), -1 * i) + } + if (is.sliceable(i)) { + x$Slice(i[1] - 1, length(i)) + } else if (all(i > 0)) { + x$Take(i - 1) + } else { + stop("Cannot mix positive and negative indices", call. = FALSE) + } + } else if (is.Array(i, INTEGER_TYPES)) { + # NOTE: this doesn't do the - 1 offset + x$Take(i) + } else if (is.Array(i, "bool")) { + x$Filter(i, keep_na) + } else { + # Unsupported cases + if (is.Array(i)) { + stop("Cannot extract rows with an Array of type ", i$type$ToString(), call. = FALSE) + } + stop("Cannot extract rows with an object of class ", class(i), call. = FALSE) + } +} + +#' @export +`[.ArrowDatum` <- filter_rows + +#' @importFrom utils head +#' @export +head.ArrowDatum <- function(x, n = 6L, ...) { + assert_is(n, c("numeric", "integer")) + assert_that(length(n) == 1) + len <- NROW(x) + if (n < 0) { + # head(x, negative) means all but the last n rows + n <- max(len + n, 0) + } else { + n <- min(len, n) + } + if (n == len) { + return(x) + } + x$Slice(0, n) +} + +#' @importFrom utils tail +#' @export +tail.ArrowDatum <- function(x, n = 6L, ...) { + assert_is(n, c("numeric", "integer")) + assert_that(length(n) == 1) + len <- NROW(x) + if (n < 0) { + # tail(x, negative) means all but the first n rows + n <- min(-n, len) + } else { + n <- max(len - n, 0) + } + if (n == 0) { + return(x) + } + x$Slice(n) +} + +is.sliceable <- function(i) { + # Determine whether `i` can be expressed as a $Slice() command + is.numeric(i) && + length(i) > 0 && + all(i > 0) && + i[1] <= i[length(i)] && + identical(as.integer(i), i[1]:i[length(i)]) +} + +#' @export +as.double.ArrowDatum <- function(x, ...) as.double(as.vector(x), ...) + +#' @export +as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...) + +#' @export +as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...) + +#' @export +sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) { + # Arrow always sorts nulls at the end of the array. This corresponds to + # sort(na.last = TRUE). For the other two cases (na.last = NA and + # na.last = FALSE) we need to use workarounds. + # TODO: Implement this more cleanly after ARROW-12063 + if (is.na(na.last)) { + # Filter out NAs before sorting + x <- x$Filter(!is.na(x)) + x$Take(x$SortIndices(descending = decreasing)) + } else if (na.last) { + x$Take(x$SortIndices(descending = decreasing)) + } else { + # Create a new array that encodes missing values as 1 and non-missing values + # as 0. Sort descending by that array first to get the NAs at the beginning + tbl <- Table$create(x = x, `is_na` = as.integer(is.na(x))) + tbl$x$Take(tbl$SortIndices(names = c("is_na", "x"), descending = c(TRUE, decreasing))) + } +} |