Skip to content

Commit 32f511d

Browse files
committed
Add Spearman correlation support for DataFrames.
1 parent 8aa5aea commit 32f511d

File tree

3 files changed

+113
-16
lines changed

3 files changed

+113
-16
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,36 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
3939

4040
/**
4141
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
42-
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
43-
* MLlib's Statistics.
42+
* Correlation and Spearman Correlation Coefficient.
4443
*
4544
* @param col1 the name of the column
4645
* @param col2 the name of the column to calculate the correlation against
47-
* @return The Pearson Correlation Coefficient as a Double.
46+
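* @param method the correlation method to use, either "pearson" or "spearman"
* @param params optional settings for the chosen method; for "spearman" the Boolean entry
*               "tie" (default true) selects whether tie correction is applied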
47+
* @return The Correlation Coefficient as a Double.
48+
*/
49+
def corr(col1: String, col2: String, method: String, params: Map[String, Any] = Map()): Double = {
  // Reject unknown methods up front so the match below only sees supported values.
  require(method == "pearson" || method == "spearman",
    "Currently only the calculation of the Pearson Correlation and Spearman correlation " +
    "coefficient are supported.")
  method match {
    case "pearson" =>
      StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
    case "spearman" =>
      // `params` is an untyped Map[String, Any]; validate the "tie" entry explicitly instead of
      // casting, so a wrong value type produces a descriptive error rather than a
      // ClassCastException (or a silent `false` for null).
      val tie = params.get("tie") match {
        case None => true // tie correction is on unless the caller disables it
        case Some(flag: Boolean) => flag
        case Some(other) =>
          throw new IllegalArgumentException(
            s"The 'tie' parameter must be a Boolean, but got: $other")
      }
      StatFunctions.spearmanCorrelation(df, Seq(col1, col2), tie)
  }
}
61+
62+
/**
63+
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
64+
* and Spearman Correlation Coefficients.
65+
*
66+
* @param col1 the name of the column
67+
* @param col2 the name of the column to calculate the correlation against
68+
* @return The Correlation Coefficient as a Double.
4869
*/
4970
def corr(col1: String, col2: String, method: String): Double = {
50-
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
51-
"coefficient is supported.")
52-
StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
71+
corr(col1, col2, method, Map())
5372
}
5473

5574
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,48 @@ import org.apache.spark.sql.types._
2626

2727
private[sql] object StatFunctions extends Logging {
2828

29+
/**
 * Calculate the Spearman Correlation Coefficient for the given pair of columns.
 *
 * The columns are first replaced by their average ranks (see `calRanks`). With `tie` set, the
 * coefficient is the Pearson correlation of those rank columns; otherwise the closed-form
 * expression `1 - 6 * sum(d^2) / (n * (n^2 - 1))` is used, where `d` is the per-row rank
 * difference and `n` is the row count of `df`.
 *
 * @param df the DataFrame holding the data
 * @param cols the names of exactly two columns
 * @param tie whether to apply tie correction
 * @return the Spearman Correlation Coefficient as a Double
 */
private[sql] def spearmanCorrelation(df: DataFrame, cols: Seq[String], tie: Boolean): Double = {
  require(cols.length == 2, "Spearman correlation can only be calculated " +
    "between two columns.")
  val ranked = calRanks(df, cols)
  if (!tie) {
    // No tie correction: closed-form expression over the squared rank differences.
    val rowCount = df.count().toDouble
    val squaredDiffTotal = calSumOfRankingsDiff(ranked)
    1.0 - 6.0 * squaredDiffTotal / (rowCount * (rowCount * rowCount - 1.0))
  } else {
    // Tie correction: Pearson correlation computed on the average ranks.
    pearsonCorrelation(ranked, Seq("avgDX", "avgDY"))
  }
}
44+
45+
/**
 * Replace the values of the two given columns by their average ranks.
 *
 * Both columns are cast to Double. Within each column the rows are sorted and numbered from 1,
 * and every distinct value is assigned the mean of the positions it occupies, so tied values
 * share one rank. The result has one row per joined input row, with the rank of the first
 * column in "avgDX" and the rank of the second column in "avgDY".
 *
 * @param df the DataFrame holding the data
 * @param cols the names of the two columns to rank
 * @return a DataFrame with the columns "avgDX" and "avgDY"
 */
private[sql] def calRanks(df: DataFrame, cols: Seq[String]): DataFrame = {
  import df.sqlContext.implicits._
  val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
  val doubleDF = df.select(columns: _*).toDF(cols(0), cols(1))

  // Average rank of every distinct value of `col`: one output row per distinct value.
  def averageRanks(col: String, rankCol: String, avgCol: String): DataFrame = {
    doubleDF.select(col).sort(col).rdd.zipWithIndex()
      .map(kv => (kv._1.getDouble(0), kv._2 + 1.0)).toDF(col, rankCol)
      .groupBy(col).avg(rankCol).toDF(col, avgCol)
  }

  val avgRankings1 = averageRanks(cols(0), "dX", "avgDX")
  val avgRankings2 = averageRanks(cols(1), "dY", "avgDY")

  // Join against the per-distinct-value tables so each input row matches exactly one rank row
  // per column. Joining against per-row rank tables would multiply rows for tied values, and
  // de-duplicating afterwards would also drop genuinely repeated (x, y) observations.
  doubleDF.join(avgRankings1, doubleDF(cols(0)) === avgRankings1(cols(0)))
    .join(avgRankings2, doubleDF(cols(1)) === avgRankings2(cols(1)))
    .select("avgDX", "avgDY")
}
65+
66+
/**
 * Sum the squared differences between the "avgDX" and "avgDY" columns over all rows.
 *
 * @param rankDF a DataFrame with the rank columns "avgDX" and "avgDY"
 * @return the sum of squared rank differences as a Double
 */
private[sql] def calSumOfRankingsDiff(rankDF: DataFrame): Double = {
  import rankDF.sqlContext.implicits._
  val rankDiff = $"avgDX" - $"avgDY"
  val aggregated = rankDF.select(sum(rankDiff * rankDiff)).collect()
  aggregated(0).getDouble(0)
}
70+
2971
/** Calculate the Pearson Correlation Coefficient for the given columns */
3072
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
3173
val counts = collectStatisticalData(df, cols)
@@ -38,7 +80,7 @@ private[sql] object StatFunctions extends Logging {
3880
var yAvg = 0.0 // the mean of all examples seen so far in col2
3981
var Ck = 0.0 // the co-moment after k examples
4082
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
41-
var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
83+
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
4284
var count = 0L // count of observed examples
4385
// add an example to the calculation
4486
def add(x: Double, y: Double): this.type = {
@@ -55,15 +97,17 @@ private[sql] object StatFunctions extends Logging {
5597
// merge counters from other partitions. Formula can be found at:
5698
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
5799
def merge(other: CovarianceCounter): this.type = {
58-
val totalCount = count + other.count
59-
val deltaX = xAvg - other.xAvg
60-
val deltaY = yAvg - other.yAvg
61-
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
62-
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
63-
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
64-
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
65-
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
66-
count = totalCount
100+
if (other.count > 0) {
101+
val totalCount = count + other.count
102+
val deltaX = xAvg - other.xAvg
103+
val deltaY = yAvg - other.yAvg
104+
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
105+
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
106+
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
107+
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
108+
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
109+
count = totalCount
110+
}
67111
this
68112
}
69113
// return the sample covariance for the observed examples

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,40 @@ class DataFrameStatSuite extends FunSuite {
5151
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
5252
}
5353

54+
test("spearman correlation") {
  val x = Seq(5.05, 6.75, 3.21, 2.66)
  val y = Seq(1.65, 26.5, -5.93, 7.96)
  val z = Seq(1.65, 2.64, 2.64, 6.95)

  // To calculate the Spearman Correlation in R:
  // > x <- c(5.05, 6.75, 3.21, 2.66)
  // > y <- c(1.65, 26.5, -5.93, 7.96)
  // > z <- c(1.65, 2.64, 2.64, 6.95)
  // No tie correction is needed
  // > cor(x, y, method="spearman")
  // [1] 0.4
  // Tie correction is needed
  // > cor(x, z, method="spearman")
  // [1] -0.6324555
  // > cor(y, z, method="spearman")
  // [1] 0.3162278

  val df1 = x.zip(y).toDF("a", "b")
  val corr1 = df1.stat.corr("a", "b", "spearman")
  // x and y contain no ties, so the uncorrected formula must give the same coefficient.
  val corr1NoTie = df1.stat.corr("a", "b", "spearman", Map("tie" -> false))

  val params: Map[String, Any] = Map("tie" -> true)

  val df2 = x.zip(z).toDF("a", "c")
  val corr2 = df2.stat.corr("a", "c", "spearman", params)

  val df3 = y.zip(z).toDF("b", "c")
  val corr3 = df3.stat.corr("b", "c", "spearman", params)

  // Compare absolute differences: a one-sided check would accept any value below the target.
  assert(math.abs(corr1 - 0.4) < 1e-12)
  assert(math.abs(corr1NoTie - 0.4) < 1e-12)
  // The R reference values are printed with 7 significant digits, so use a matching tolerance.
  assert(math.abs(corr2 + 0.6324555) < 1e-6)
  assert(math.abs(corr3 - 0.3162278) < 1e-6)
}
87+
5488
test("covariance") {
5589
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
5690

0 commit comments

Comments
 (0)