@@ -26,6 +26,48 @@ import org.apache.spark.sql.types._
2626
2727private [sql] object StatFunctions extends Logging {
2828
/**
 * Calculate the Spearman rank correlation coefficient for the given columns.
 *
 * @param df   the input data
 * @param cols names of exactly two columns of `df`
 * @param tie  if true, the coefficient is the Pearson correlation of the average ranks
 *             (tie-corrected); if false, the closed form
 *             `1 - 6 * sum(d^2) / (n * (n^2 - 1))` is used, where `d` is the per-row
 *             difference between the two average ranks and `n` is the row count of `df`
 * @return the coefficient; `Double.NaN` in the uncorrected branch when `df` has fewer
 *         than two rows
 */
private[sql] def spearmanCorrelation(df: DataFrame, cols: Seq[String], tie: Boolean): Double = {
  require(cols.length == 2, " Spearman correlation can only be calculated " +
    " between two columns.")
  if (tie) {
    // Calculate Spearman Correlation Coefficient with tie correction
    pearsonCorrelation(calRanks(df, cols), Seq(" avgDX", " avgDY"))
  } else {
    // Calculate Spearman Correlation Coefficient with no tie correction
    val n = df.count().toDouble
    if (n < 2.0) {
      // The denominator n * (n^2 - 1) is zero for n = 0 and n = 1, so the coefficient is
      // undefined. Return NaN explicitly instead of ranking and aggregating an empty input.
      Double.NaN
    } else {
      val sumOfRankingsDiff = calSumOfRankingsDiff(calRanks(df, cols))
      1.0 - 6.0 * sumOfRankingsDiff / (n * (n * n - 1.0))
    }
  }
}
44+
/**
 * Compute, for every row of `df`, the average (tie-aware) rank of each of the two given
 * columns.
 *
 * Both columns are cast to double. Each column is sorted on its own and given 1-based
 * positions; rows sharing a value receive the mean of their positions. The per-value
 * average ranks are then joined back onto the input rows.
 *
 * @param df   the input data
 * @param cols the two column names to rank (only `cols(0)` and `cols(1)` are used)
 * @return a DataFrame holding only the two average-rank columns, one row per input row
 *         whose values matched in both joins
 */
private[sql] def calRanks(df: DataFrame, cols: Seq[String]): DataFrame = {
  import df.sqlContext.implicits._
  val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
  val doubleDF = df.select(columns: _*).toDF(cols(0), cols(1))

  // Average 1-based sort position per distinct value of `colName`: one output row per
  // distinct value, with columns (colName, avgName).
  // NOTE(review): `getDouble` is applied to every value, so values that are null after
  // the cast are not handled here -- confirm callers never pass such rows.
  def averageRanks(colName: String, rankName: String, avgName: String): DataFrame = {
    val ranked = doubleDF.select(colName).sort(colName).rdd.zipWithIndex()
      .map(kv => (kv._1.getDouble(0), kv._2 + 1.0)).toDF(colName, rankName)
    ranked.groupBy(colName).avg(rankName).toDF(colName, avgName)
  }

  val avgRankings1 = averageRanks(cols(0), " dX", " avgDX")
  val avgRankings2 = averageRanks(cols(1), " dY", " avgDY")

  // Join against the per-distinct-value tables so each input row matches exactly one rank
  // row per column. Joining against a per-row rank table would multiply tied rows, and
  // de-duplicating afterwards would also drop genuinely repeated (x, y) observations.
  doubleDF.join(avgRankings1, doubleDF(cols(0)) === avgRankings1(cols(0)))
    .join(avgRankings2, doubleDF(cols(1)) === avgRankings2(cols(1)))
    .select(" avgDX", " avgDY")
}
65+
/**
 * Sum, over all rows of `rankDF`, of the squared difference between its two average-rank
 * columns.
 *
 * @param rankDF a DataFrame containing the two average-rank columns selected below
 * @return the sum of squared rank differences; 0.0 when the aggregate is null, which is
 *         what a SQL sum over zero rows yields
 */
private[sql] def calSumOfRankingsDiff(rankDF: DataFrame): Double = {
  import rankDF.sqlContext.implicits._
  val diff = $" avgDX" - $" avgDY"
  // A global aggregate always produces exactly one row, so first() is safe here.
  val total = rankDF.select(sum(diff * diff)).first()
  // Guard the null aggregate instead of unboxing it.
  if (total.isNullAt(0)) 0.0 else total.getDouble(0)
}
70+
2971 /** Calculate the Pearson Correlation Coefficient for the given columns */
3072 private [sql] def pearsonCorrelation (df : DataFrame , cols : Seq [String ]): Double = {
3173 val counts = collectStatisticalData(df, cols)
@@ -38,7 +80,7 @@ private[sql] object StatFunctions extends Logging {
3880 var yAvg = 0.0 // the mean of all examples seen so far in col2
3981 var Ck = 0.0 // the co-moment after k examples
4082 var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
41- var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
83+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
4284 var count = 0L // count of observed examples
4385 // add an example to the calculation
4486 def add (x : Double , y : Double ): this .type = {
@@ -55,15 +97,17 @@ private[sql] object StatFunctions extends Logging {
// Merge the counters accumulated by another instance (e.g. another partition) into this
// one and return this counter. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
  // An empty `other` contributes nothing. Skipping it also avoids dividing by a zero
  // totalCount (which would turn the running values into NaN) when both counters are empty.
  if (other.count > 0) {
    val totalCount = count + other.count
    val deltaX = xAvg - other.xAvg
    val deltaY = yAvg - other.yAvg
    // The co-moment and the squared-deviation sums must be updated with the pre-merge
    // `count`, so `count` itself is reassigned last.
    Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
    xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
    yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
    MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
    MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
    count = totalCount
  }
  this
}
69113 // return the sample covariance for the observed examples
0 commit comments