Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,16 @@ exportMethods("%<=>%",
"approxCountDistinct",
"approxQuantile",
"array_contains",
"array_distinct",
"array_join",
"array_max",
"array_min",
"array_position",
"array_remove",
"array_repeat",
"array_sort",
"arrays_overlap",
"arrays_zip",
"asc",
"ascii",
"asin",
Expand Down Expand Up @@ -306,6 +309,7 @@ exportMethods("%<=>%",
"lpad",
"ltrim",
"map_entries",
"map_from_arrays",
"map_keys",
"map_values",
"max",
Expand Down
70 changes: 66 additions & 4 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ NULL
#' \itemize{
#' \item \code{array_contains}: a value to be checked if contained in the column.
#' \item \code{array_position}: a value to locate in the given array.
#' \item \code{array_remove}: a value to remove in the given array.
#' }
#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
#' additional named properties to control how it is converted, accepts the same
#' options as the JSON data source.
#' options as the JSON data source. In \code{arrays_zip}, this contains additional
#' Columns of arrays to be merged.
#' @name column_collection_functions
#' @rdname column_collection_functions
#' @family collection functions
Expand All @@ -207,9 +209,9 @@ NULL
#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1)))
#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
#' head(select(tmp, posexplode(tmp$v1)))
Expand All @@ -221,6 +223,7 @@ NULL
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5)))
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))}
Expand Down Expand Up @@ -1978,7 +1981,7 @@ setMethod("levenshtein", signature(y = "Column"),
})

#' @details
#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}.
#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}.
#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x}
#' are on the same day of month, or both are the last day of month, time of day will be ignored.
#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits.
Expand Down Expand Up @@ -3008,6 +3011,19 @@ setMethod("array_contains",
column(jc)
})

#' @details
#' \code{array_distinct}: Removes duplicate values from the array.
#'
#' @rdname column_collection_functions
#' @aliases array_distinct array_distinct,Column-method
#' @note array_distinct since 2.4.0
setMethod("array_distinct",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc)
column(jc)
})

#' @details
#' \code{array_join}: Concatenates the elements of column using the delimiter.
#' Null values are replaced with nullReplacement if set, otherwise they are ignored.
Expand Down Expand Up @@ -3071,6 +3087,19 @@ setMethod("array_position",
column(jc)
})

#' @details
#' \code{array_remove}: Removes all elements that equal to element from the given array.
#'
#' @rdname column_collection_functions
#' @aliases array_remove array_remove,Column-method
#' @note array_remove since 2.4.0
setMethod("array_remove",
signature(x = "Column", value = "ANY"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review. I will add in doc.

function(x, value) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value)
column(jc)
})

#' @details
#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times
#' given by \code{count}.
Expand Down Expand Up @@ -3120,6 +3149,24 @@ setMethod("arrays_overlap",
column(jc)
})

#' @details
#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th
#' values of input arrays.
#'
#' @rdname column_collection_functions
#' @aliases arrays_zip arrays_zip,Column-method
#' @note arrays_zip since 2.4.0
setMethod("arrays_zip",
signature(x = "Column"),
function(x, ...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jcols <- lapply(list(x, ...), function(arg) {
stopifnot(class(arg) == "Column")
arg@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols)
column(jc)
})

#' @details
#' \code{flatten}: Creates a single array from an array of arrays.
#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
Expand Down Expand Up @@ -3147,6 +3194,21 @@ setMethod("map_entries",
column(jc)
})

#' @details
#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for
#' keys. The array in the second column is used for values. All elements in the array for key
#' should not be null.
#'
#' @rdname column_collection_functions
#' @aliases map_from_arrays map_from_arrays,Column-method
#' @note map_from_arrays since 2.4.0
setMethod("map_from_arrays",
signature(x = "Column", y = "Column"),
function(x, y) {
jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc)
column(jc)
})

#' @details
#' \code{map_keys}: Returns an unordered array containing the keys of the map.
#'
Expand Down
16 changes: 16 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun
#' @name NULL
setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") })
Expand All @@ -773,6 +777,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") })
#' @name NULL
setGeneric("array_position", function(x, value) { standardGeneric("array_position") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") })
Expand All @@ -785,6 +793,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
#' @name NULL
setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })
Expand Down Expand Up @@ -1050,6 +1062,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })
#' @name NULL
setGeneric("map_entries", function(x) { standardGeneric("map_entries") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
Expand Down
21 changes: 21 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,27 @@ test_that("column functions", {
result <- collect(select(df2, reverse(df2[[1]])))[[1]]
expect_equal(result, "cba")

# Test array_distinct() and array_remove()
df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L))))
result <- collect(select(df, array_distinct(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L)))

result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]]
expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L)))

# Test arrays_zip()
df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2"))
result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]]
expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)),
listToStruct(list(c1 = 2L, c2 = 4L)))
expect_equal(result, list(expected_entries))

# Test map_from_arrays()
df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v"))
result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]]
expected_entries <- list(as.environment(list(x = 1, y = 2)))
expect_equal(result, expected_entries)

# Test array_repeat()
df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
Expand Down