Skip to content

Commit 01f5cca

Browse files
committed
Merge remote-tracking branch 'origin/master' into SPARK-6152
2 parents 002d1ea + 1dde39d commit 01f5cca

File tree

332 files changed

+13186
-5913
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

332 files changed

+13186
-5913
lines changed

R/pkg/DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ Collate:
3434
'serialize.R'
3535
'sparkR.R'
3636
'stats.R'
37+
'types.R'
3738
'utils.R'

R/pkg/NAMESPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ export("setJobGroup",
2323
exportClasses("DataFrame")
2424

2525
exportMethods("arrange",
26+
"as.data.frame",
2627
"attach",
2728
"cache",
2829
"collect",
30+
"coltypes",
2931
"columns",
3032
"count",
3133
"cov",
@@ -262,6 +264,4 @@ export("structField",
262264
"structType",
263265
"structType.jobj",
264266
"structType.structField",
265-
"print.structType")
266-
267-
export("as.data.frame")
267+
"print.structType")

R/pkg/R/DataFrame.R

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,9 +1944,9 @@ setMethod("describe",
19441944
#' @rdname summary
19451945
#' @name summary
19461946
setMethod("summary",
1947-
signature(x = "DataFrame"),
1948-
function(x) {
1949-
describe(x)
1947+
signature(object = "DataFrame"),
1948+
function(object, ...) {
1949+
describe(object)
19501950
})
19511951

19521952

@@ -2152,3 +2152,52 @@ setMethod("with",
21522152
newEnv <- assignNewEnv(data)
21532153
eval(substitute(expr), envir = newEnv, enclos = newEnv)
21542154
})
2155+
#' Returns the column types of a DataFrame.
#'
#' Each Spark SQL column type reported by \code{dtypes} is translated to the
#' name of the matching R type via \code{PRIMITIVE_TYPES}. Complex types
#' (those whose name starts with an entry of \code{COMPLEX_TYPES}) have no R
#' equivalent and are returned as the original Spark type string.
#'
#' @name coltypes
#' @title Get column types of a DataFrame
#' @family dataframe_funcs
#' @param x (DataFrame)
#' @return value (character) A character vector with the column types of the given DataFrame
#' @rdname coltypes
#' @examples \dontrun{
#' irisDF <- createDataFrame(sqlContext, iris)
#' coltypes(irisDF)
#' }
setMethod("coltypes",
          signature(x = "DataFrame"),
          function(x) {
            # dtypes() yields one (name, type) pair per column; keep only the type.
            # vapply() guarantees a character vector even for a zero-column input.
            sparkTypes <- vapply(dtypes(x), function(field) field[[2]], character(1))

            # Translate every Spark type; complex types come back as NA for now.
            rTypes <- vapply(sparkTypes, USE.NAMES = FALSE, FUN.VALUE = character(1),
                             FUN = function(sparkType) {
              # Parameterised primitives such as "decimal(10,0)" are keyed by
              # their base name only, so drop a trailing "(...)" before lookup.
              baseType <- sub("\\(.*\\)$", "", sparkType)
              rType <- PRIMITIVE_TYPES[[baseType]]
              if (!is.null(rType)) {
                return(rType)
              }

              # A complex type is recognised by its prefix ("map<...", "array<...").
              prefixes <- names(COMPLEX_TYPES)
              matched <- prefixes[substring(sparkType, 1, nchar(prefixes)) == prefixes]
              if (length(matched) == 0) {
                stop(paste("Unsupported data type: ", sparkType))
              }

              # Coerce so that the NA placeholder is a character NA for vapply().
              as.character(COMPLEX_TYPES[[matched[1]]])
            })

            # Types without an R mapping keep their original Spark type string.
            unmapped <- is.na(rTypes)
            rTypes[unmapped] <- sparkTypes[unmapped]

            rTypes
          })

R/pkg/R/functions.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"),
13391339
#' @export
13401340
setMethod("approxCountDistinct",
13411341
signature(x = "Column"),
1342-
function(x, rsd = 0.95) {
1342+
function(x, rsd = 0.05) {
13431343
jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
13441344
column(jc)
13451345
})

R/pkg/R/generics.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
561561

562562
#' @rdname summary
563563
#' @export
564-
setGeneric("summary", function(x, ...) { standardGeneric("summary") })
564+
setGeneric("summary", function(object, ...) { standardGeneric("summary") })
565565

566566
# @rdname tojson
567567
# @export
@@ -1047,3 +1047,7 @@ setGeneric("attach")
10471047
#' @rdname with
10481048
#' @export
10491049
setGeneric("with")
1050+
# S4 generic with a single dispatch argument `x`.
#' @rdname coltypes
#' @export
setGeneric(
  "coltypes",
  function(x) {
    standardGeneric("coltypes")
  }
)

R/pkg/R/mllib.R

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
4848
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
4949
standardize = TRUE, solver = "auto") {
5050
family <- match.arg(family)
51+
formula <- paste(deparse(formula), collapse="")
5152
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
52-
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
53+
"fitRModelFormula", formula, data@sdf, family, lambda,
5354
alpha, standardize, solver)
5455
return(new("PipelineModel", model = model))
5556
})
@@ -88,14 +89,28 @@ setMethod("predict", signature(object = "PipelineModel"),
8889
#' model <- glm(y ~ x, trainingData)
8990
#' summary(model)
9091
#'}
91-
setMethod("summary", signature(x = "PipelineModel"),
92-
function(x, ...) {
92+
setMethod("summary", signature(object = "PipelineModel"),
93+
function(object, ...) {
94+
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
95+
"getModelName", object@model)
9396
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
94-
"getModelFeatures", x@model)
97+
"getModelFeatures", object@model)
9598
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
96-
"getModelCoefficients", x@model)
97-
coefficients <- as.matrix(unlist(coefficients))
98-
colnames(coefficients) <- c("Estimate")
99-
rownames(coefficients) <- unlist(features)
100-
return(list(coefficients = coefficients))
99+
"getModelCoefficients", object@model)
100+
if (modelName == "LinearRegressionModel") {
101+
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
102+
"getModelDevianceResiduals", object@model)
103+
devianceResiduals <- matrix(devianceResiduals, nrow = 1)
104+
colnames(devianceResiduals) <- c("Min", "Max")
105+
rownames(devianceResiduals) <- rep("", times = 1)
106+
coefficients <- matrix(coefficients, ncol = 4)
107+
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
108+
rownames(coefficients) <- unlist(features)
109+
return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
110+
} else {
111+
coefficients <- as.matrix(unlist(coefficients))
112+
colnames(coefficients) <- c("Estimate")
113+
rownames(coefficients) <- unlist(features)
114+
return(list(coefficients = coefficients))
115+
}
101116
})

R/pkg/R/schema.R

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,7 @@ structField.jobj <- function(x) {
115115
}
116116

117117
checkType <- function(type) {
118-
primtiveTypes <- c("byte",
119-
"integer",
120-
"float",
121-
"double",
122-
"numeric",
123-
"character",
124-
"string",
125-
"binary",
126-
"raw",
127-
"logical",
128-
"boolean",
129-
"timestamp",
130-
"date")
131-
if (type %in% primtiveTypes) {
118+
if (!is.null(PRIMITIVE_TYPES[[type]])) {
132119
return()
133120
} else {
134121
# Check complex types

R/pkg/R/types.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
# types.R. This file handles the data type mapping between Spark and R
17+
# The primitive data types: names(PRIMITIVE_TYPES) are Spark SQL type names and
# the values are the equivalent R types. The table is kept in an environment so
# that a type can be looked up by name with `[[`, which yields NULL (rather
# than an error) for an unknown name.
# NOTE(review): "int" is the short name Spark SQL uses for IntegerType; it is
# listed next to "integer" so that a lookup by either spelling succeeds.
PRIMITIVE_TYPES <- as.environment(list(
  "byte" = "integer",
  "tinyint" = "integer",
  "smallint" = "integer",
  "int" = "integer",
  "integer" = "integer",
  "bigint" = "numeric",
  "float" = "numeric",
  "double" = "numeric",
  "decimal" = "numeric",
  "string" = "character",
  "binary" = "raw",
  "boolean" = "logical",
  "timestamp" = "POSIXct",
  "date" = "Date"))

# The complex data types. These do not have any direct mapping to R's types,
# which is recorded by mapping each of them to NA.
COMPLEX_TYPES <- list(
  "map" = NA,
  "array" = NA,
  "struct" = NA)

# The full list of data types: every primitive entry plus every complex entry.
DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES))

R/pkg/inst/tests/test_mllib.R

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ test_that("glm and predict", {
3333
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
3434
})
3535

# Fits a Spark glm whose formula text is far longer than a single deparse()
# line and checks the predictions against a local glm fit on the same data.
test_that("glm should work with long formula", {
  sparkIris <- createDataFrame(sqlContext, iris)

  # Alias three existing columns under deliberately long names.
  sparkIris$LongLongLongLongLongName <- sparkIris$Sepal_Width
  sparkIris$VeryLongLongLongLonLongName <- sparkIris$Sepal_Length
  sparkIris$AnotherLongLongLongLongName <- sparkIris$Species

  sparkModel <- glm(
    LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName,
    data = sparkIris
  )
  sparkPredictions <- predict(sparkModel, sparkIris)
  vals <- collect(select(sparkPredictions, "prediction"))

  # Reference predictions from the equivalent local model.
  localModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
  rVals <- predict(localModel, iris)

  residualGap <- rVals - vals
  expect_true(all(abs(residualGap) < 1e-6), residualGap)
})
47+
3648
test_that("predictions match with native glm", {
3749
training <- createDataFrame(sqlContext, iris)
3850
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
@@ -59,12 +71,18 @@ test_that("feature interaction vs native glm", {
5971

6072
test_that("summary coefficients match with native glm", {
6173
training <- createDataFrame(sqlContext, iris)
62-
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
63-
coefs <- as.vector(stats$coefficients)
64-
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
65-
expect_true(all(abs(rCoefs - coefs) < 1e-6))
74+
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
75+
coefs <- unlist(stats$coefficients)
76+
devianceResiduals <- unlist(stats$devianceResiduals)
77+
78+
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
79+
rCoefs <- unlist(rStats$coefficients)
80+
rDevianceResiduals <- c(-0.95096, 0.72918)
81+
82+
expect_true(all(abs(rCoefs - coefs) < 1e-5))
83+
expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
6684
expect_true(all(
67-
as.character(stats$features) ==
85+
rownames(stats$coefficients) ==
6886
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
6987
})
7088

@@ -73,14 +91,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
7391
training <- filter(df, df$Species != "setosa")
7492
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
7593
family = "binomial"))
76-
coefs <- as.vector(stats$coefficients)
94+
coefs <- as.vector(stats$coefficients[,1])
7795

7896
rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
7997
rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
8098
family = binomial(link = "logit"))))
8199

82100
expect_true(all(abs(rCoefs - coefs) < 1e-4))
83101
expect_true(all(
84-
as.character(stats$features) ==
102+
rownames(stats$coefficients) ==
85103
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
86104
})
105+
# summary() must still work for a model fitted by stats::glm on a local
# data.frame; the deviance of the fit is compared with a fixed reference value.
test_that("summary works on base GLM models", {
  localFit <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
  fitDeviance <- summary(localFit)$deviance
  expect_true(abs(fitDeviance - 12.19313) < 1e-4)
})

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,11 @@ test_that("sample on a DataFrame", {
647647
sampled <- sample(df, FALSE, 1.0)
648648
expect_equal(nrow(collect(sampled)), count(df))
649649
expect_is(sampled, "DataFrame")
650-
sampled2 <- sample(df, FALSE, 0.1)
650+
sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
651651
expect_true(count(sampled2) < 3)
652652

653653
# Also test sample_frac
654-
sampled3 <- sample_frac(df, FALSE, 0.1)
654+
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
655655
expect_true(count(sampled3) < 3)
656656
})
657657

@@ -875,9 +875,9 @@ test_that("column binary mathfunctions", {
875875
expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4)
876876
expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4)
877877
expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric")
878-
expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01)
878+
expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01)
879879
expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric")
880-
expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01)
880+
expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01)
881881
})
882882

883883
test_that("string operators", {
@@ -1458,17 +1458,18 @@ test_that("sampleBy() on a DataFrame", {
14581458
fractions <- list("0" = 0.1, "1" = 0.2)
14591459
sample <- sampleBy(df, "key", fractions, 0)
14601460
result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
1461-
expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
1462-
expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
1461+
expect_identical(as.list(result[1, ]), list(key = "0", count = 3))
1462+
expect_identical(as.list(result[2, ]), list(key = "1", count = 7))
14631463
})
14641464

14651465
test_that("SQL error message is returned from JVM", {
14661466
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
14671467
expect_equal(grepl("Table not found: blah", retError), TRUE)
14681468
})
14691469

1470+
irisDF <- createDataFrame(sqlContext, iris)
1471+
14701472
test_that("Method as.data.frame as a synonym for collect()", {
1471-
irisDF <- createDataFrame(sqlContext, iris)
14721473
expect_equal(as.data.frame(irisDF), collect(irisDF))
14731474
irisDF2 <- irisDF[irisDF$Species == "setosa", ]
14741475
expect_equal(as.data.frame(irisDF2), collect(irisDF2))
@@ -1503,6 +1504,27 @@ test_that("with() on a DataFrame", {
15031504
expect_equal(nrow(sum2), 35)
15041505
})
15051506

# Checks coltypes() for primitive columns (mapped to R type names) and for a
# complex column (reported as the original Spark type string).
test_that("Method coltypes() to get R's data types of a DataFrame", {
  # irisDF is created at file level, before the as.data.frame test.
  expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character"))

  data <- data.frame(c1 = c(1, 2, 3),
                     c2 = c(TRUE, FALSE, TRUE),
                     c3 = c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00"))

  # Field names follow the data columns; only the declared types are asserted on.
  schema <- structType(structField("c1", "byte"),
                       structField("c2", "boolean"),
                       structField("c3", "timestamp"))

  # Test primitive types
  DF <- createDataFrame(sqlContext, data, schema)
  expect_equal(coltypes(DF), c("integer", "logical", "POSIXct"))

  # Test complex types
  x <- createDataFrame(sqlContext, list(list(as.environment(
    list("a" = "b", "c" = "d", "e" = "f")))))
  expect_equal(coltypes(x), "map<string,string>")
})
1527+
15061528
unlink(parquetPath)
15071529
unlink(jsonPath)
15081530
unlink(jsonPathNa)

0 commit comments

Comments
 (0)