
Commit 9011bc5

Merge remote-tracking branch 'origin/master' into docker-jdbc-tests
2 parents: 6db2c1c + c4e19b3

202 files changed (+8,659 / -1,985 lines)

Large commits have some content hidden by default; only a subset of the 202 changed files is shown below.

R/pkg/R/DataFrame.R

Lines changed: 3 additions & 3 deletions

@@ -1944,9 +1944,9 @@ setMethod("describe",
 #' @rdname summary
 #' @name summary
 setMethod("summary",
-          signature(x = "DataFrame"),
-          function(x) {
-            describe(x)
+          signature(object = "DataFrame"),
+          function(object, ...) {
+            describe(object)
           })

R/pkg/R/generics.R

Lines changed: 1 addition & 1 deletion

@@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") })

 #' @rdname summary
 #' @export
-setGeneric("summary", function(x, ...) { standardGeneric("summary") })
+setGeneric("summary", function(object, ...) { standardGeneric("summary") })

 # @rdname tojson
 # @export
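
The rename from x to object in DataFrame.R and generics.R matches the argument list of base R's own summary generic, which is why summary() on plain R objects (see the new "summary works on base GLM models" test further down) keeps working after SparkR is attached. A minimal base-R sketch, not part of the commit:

# Base R already defines summary with the signature summary(object, ...):
formals(base::summary)
# $object
#
# $...
#
# Defining SparkR's S4 generic with the same first argument name keeps the two
# signatures compatible, so one generic can serve both SparkR classes and
# ordinary R objects. (Hypothetical calls, assuming an active SparkR session:)
# summary(df)      # SparkR DataFrame    -> describe(df)
# summary(model)   # SparkR PipelineModel -> coefficient summary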

R/pkg/R/mllib.R

Lines changed: 22 additions & 8 deletions

@@ -89,14 +89,28 @@ setMethod("predict", signature(object = "PipelineModel"),
 #' model <- glm(y ~ x, trainingData)
 #' summary(model)
 #'}
-setMethod("summary", signature(x = "PipelineModel"),
-          function(x, ...) {
+setMethod("summary", signature(object = "PipelineModel"),
+          function(object, ...) {
+            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getModelName", object@model)
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                    "getModelFeatures", x@model)
+                                    "getModelFeatures", object@model)
             coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                        "getModelCoefficients", x@model)
-            coefficients <- as.matrix(unlist(coefficients))
-            colnames(coefficients) <- c("Estimate")
-            rownames(coefficients) <- unlist(features)
-            return(list(coefficients = coefficients))
+                                        "getModelCoefficients", object@model)
+            if (modelName == "LinearRegressionModel") {
+              devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                               "getModelDevianceResiduals", object@model)
+              devianceResiduals <- matrix(devianceResiduals, nrow = 1)
+              colnames(devianceResiduals) <- c("Min", "Max")
+              rownames(devianceResiduals) <- rep("", times = 1)
+              coefficients <- matrix(coefficients, ncol = 4)
+              colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
+              rownames(coefficients) <- unlist(features)
+              return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients))
+            } else {
+              coefficients <- as.matrix(unlist(coefficients))
+              colnames(coefficients) <- c("Estimate")
+              rownames(coefficients) <- unlist(features)
+              return(list(coefficients = coefficients))
+            }
           })
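
For the LinearRegressionModel branch above, a short usage sketch mirroring the roxygen example and the updated tests (assumes a running SparkR session exposing sqlContext; the output shapes come from the code above, actual values are not shown):

training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")
stats <- summary(model)
stats$Coefficients       # matrix with columns Estimate, Std. Error, t value, Pr(>|t|)
stats$DevianceResiduals  # 1 x 2 matrix with columns Min and Max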

R/pkg/inst/tests/test_mllib.R

Lines changed: 30 additions & 7 deletions

@@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", {

 test_that("summary coefficients match with native glm", {
   training <- createDataFrame(sqlContext, iris)
-  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
-  coefs <- as.vector(stats$coefficients)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
+  coefs <- unlist(stats$Coefficients)
+  devianceResiduals <- unlist(stats$DevianceResiduals)
+
   rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
-  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331)
+  rTValue <- c(7.123, 7.557, -13.644, -10.798)
+  rPValue <- c(0.0, 0.0, 0.0, 0.0)
+  rDevianceResiduals <- c(-0.95096, 0.72918)
+
+  expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6))
+  expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5))
+  expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3))
+  expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6))
+  expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
   expect_true(all(
-    as.character(stats$features) ==
+    rownames(stats$Coefficients) ==
     c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
 })

@@ -85,14 +96,26 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
   training <- filter(df, df$Species != "setosa")
   stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
                        family = "binomial"))
-  coefs <- as.vector(stats$coefficients)
+  coefs <- as.vector(stats$Coefficients)

   rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
   rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
                                family = binomial(link = "logit"))))
+  rStdError <- c(3.0974, 0.5169, 0.8628)
+  rTValue <- c(-4.212, 3.680, 0.469)
+  rPValue <- c(0.000, 0.000, 0.639)

-  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4))
+  expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4))
+  expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3))
+  expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3))
   expect_true(all(
-    as.character(stats$features) ==
+    rownames(stats$Coefficients) ==
     c("(Intercept)", "Sepal_Length", "Sepal_Width")))
 })
+
+test_that("summary works on base GLM models", {
+  baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
+  baseSummary <- summary(baseModel)
+  expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+})
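
The index ranges in the assertions above (coefs[1:4], coefs[5:8], and so on) rely on R's column-major storage of matrices; a tiny standalone sketch in plain base R (no Spark needed, m is a hypothetical matrix):

# R stores matrices column by column, and single-bracket indexing with a plain
# integer vector walks that underlying storage, so each block of nrow values is
# one column of the Coefficients matrix (Estimate, Std. Error, t value, Pr(>|t|)).
m <- matrix(1:8, nrow = 4, dimnames = list(NULL, c("Estimate", "Std. Error")))
m[1:4]  # 1 2 3 4 -> the Estimate column
m[5:8]  # 5 6 7 8 -> the Std. Error column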

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 6 additions & 6 deletions

@@ -647,11 +647,11 @@ test_that("sample on a DataFrame", {
   sampled <- sample(df, FALSE, 1.0)
   expect_equal(nrow(collect(sampled)), count(df))
   expect_is(sampled, "DataFrame")
-  sampled2 <- sample(df, FALSE, 0.1)
+  sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
   expect_true(count(sampled2) < 3)

   # Also test sample_frac
-  sampled3 <- sample_frac(df, FALSE, 0.1)
+  sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
   expect_true(count(sampled3) < 3)
 })

@@ -875,9 +875,9 @@ test_that("column binary mathfunctions", {
   expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4)
   expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4)
   expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric")
-  expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01)
+  expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01)
   expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric")
-  expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01)
+  expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01)
 })

 test_that("string operators", {

@@ -1458,8 +1458,8 @@ test_that("sampleBy() on a DataFrame", {
   fractions <- list("0" = 0.1, "1" = 0.2)
   sample <- sampleBy(df, "key", fractions, 0)
   result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
-  expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
-  expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
+  expect_identical(as.list(result[1, ]), list(key = "0", count = 3))
+  expect_identical(as.list(result[2, ]), list(key = "1", count = 7))
 })

 test_that("SQL error message is returned from JVM", {
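
The sampling tests now pass an explicit seed as the fourth argument so the asserted counts stay stable between runs; a small sketch of that idea (assumes an active SparkR session and an existing SparkR DataFrame df):

# With a fixed seed the same rows are drawn every time, so counts derived from
# the sample are deterministic; without a seed they vary from run to run.
s1 <- sample(df, FALSE, 0.1, 0)  # withReplacement = FALSE, fraction = 0.1, seed = 0
s2 <- sample(df, FALSE, 0.1, 0)
count(s1) == count(s2)           # TRUE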

core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java

Lines changed: 29 additions & 0 deletions (new file)

@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A function that returns zero or more output records from each grouping key and its values from 2
+ * Datasets.
+ */
+public interface CoGroupFunction<K, V1, V2, R> extends Serializable {
+  Iterable<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception;
+}

core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java

Lines changed: 29 additions & 0 deletions (new file)

@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for a function used in Dataset's filter function.
+ *
+ * If the function returns true, the element is included in the returned Dataset.
+ */
+public interface FilterFunction<T> extends Serializable {
+  boolean call(T value) throws Exception;
+}

core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java

Lines changed: 1 addition & 1 deletion

@@ -23,5 +23,5 @@
  * A function that returns zero or more output records from each input record.
  */
 public interface FlatMapFunction<T, R> extends Serializable {
-  public Iterable<R> call(T t) throws Exception;
+  Iterable<R> call(T t) throws Exception;
 }

core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java

Lines changed: 1 addition & 1 deletion

@@ -23,5 +23,5 @@
  * A function that takes two inputs and returns zero or more output records.
  */
 public interface FlatMapFunction2<T1, T2, R> extends Serializable {
-  public Iterable<R> call(T1 t1, T2 t2) throws Exception;
+  Iterable<R> call(T1 t1, T2 t2) throws Exception;
 }

core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java

Lines changed: 28 additions & 0 deletions (new file)

@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A function that returns zero or more output records from each grouping key and its values.
+ */
+public interface FlatMapGroupFunction<K, V, R> extends Serializable {
+  Iterable<R> call(K key, Iterator<V> values) throws Exception;
+}
