Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,36 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/*
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
* Correlation and Spearman Correlation Coefficient.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @return The Pearson Correlation Coefficient as a Double.
* @param params the parameters for calculating the correlation
* @return The Correlation Coefficient as a Double.
*/
def corr(col1: String, col2: String, method: String, params: Map[String, Any] = Map()): Double = {
require(method == "pearson" || method == "spearman",
"Currently only the calculation of the Pearson Correlation and Spearman correlation " +
"coefficient are supported.")
method match {
case "pearson" =>
StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
case "spearman" =>
val tie = params.getOrElse("tie", true).asInstanceOf[Boolean]
StatFunctions.spearmanCorrelation(df, Seq(col1, col2), tie)
}
}

/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation and Spearman Correlation Coefficient.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @return The Correlation Coefficient as a Double.
*/
def corr(col1: String, col2: String, method: String): Double = {
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
corr(col1, col2, method, Map())
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,48 @@ import org.apache.spark.sql.types._

private[sql] object StatFunctions extends Logging {

/** Calculate Spearman Correlation Coefficient for the given columns */
private[sql] def spearmanCorrelation(df: DataFrame, cols: Seq[String], tie: Boolean): Double = {
require(cols.length == 2, "Spearman correlation can only be calculated " +
"between two columns.")
val rankDF = calRanks(df, cols)
if (tie) {
// Calculate Spearman Correlation Coefficient with tie correction
pearsonCorrelation(rankDF, Seq("avgDX", "avgDY"))
} else {
// Calculate Spearman Correlation Coefficient with no tie correction
val n = df.count().toDouble
val sumOfRankingsDiff = calSumOfRankingsDiff(rankDF)
1.0 - 6.0 * sumOfRankingsDiff / (n * (n * n - 1.0))
}
}

private[sql] def calRanks(df: DataFrame, cols: Seq[String]): DataFrame = {
import df.sqlContext.implicits._
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
val doubleDF = df.select(columns: _*).toDF(cols(0), cols(1))

val sortOnCol1 = doubleDF.select(cols(0)).sort(cols(0)).rdd.zipWithIndex()
.map(kv => (kv._1.getDouble(0), kv._2 + 1.0)).toDF(cols(0), "dX")
val avgRankings1 = sortOnCol1.groupBy(cols(0)).avg("dX").toDF(cols(0), "avgDX")
val avg1 = sortOnCol1.join(avgRankings1, sortOnCol1(cols(0)) === avgRankings1(cols(0)))
.select(sortOnCol1(cols(0)), avgRankings1("avgDX"))

val sortOnCol2 = doubleDF.select(cols(1)).sort(cols(1)).rdd.zipWithIndex()
.map(kv => (kv._1.getDouble(0), kv._2 + 1.0)).toDF(cols(1), "dY")
val avgRankings2 = sortOnCol2.groupBy(cols(1)).avg("dY").toDF(cols(1), "avgDY")
val avg2 = sortOnCol2.join(avgRankings2, sortOnCol2(cols(1)) === avgRankings2(cols(1)))
.select(sortOnCol2(cols(1)), avgRankings2("avgDY"))

doubleDF.join(avg1, doubleDF(cols(0)) === avg1(cols(0)))
.join(avg2, doubleDF(cols(1)) === avg2(cols(1))).distinct.select("avgDX", "avgDY")
}

private[sql] def calSumOfRankingsDiff(rankDF: DataFrame): Double = {
import rankDF.sqlContext.implicits._
rankDF.select(sum(($"avgDX" - $"avgDY") * ($"avgDX" - $"avgDY"))).collect()(0).getDouble(0)
}

/** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols)
Expand All @@ -38,7 +80,7 @@ private[sql] object StatFunctions extends Logging {
var yAvg = 0.0 // the mean of all examples seen so far in col2
var Ck = 0.0 // the co-moment after k examples
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
var count = 0L // count of observed examples
// add an example to the calculation
def add(x: Double, y: Double): this.type = {
Expand All @@ -55,15 +97,17 @@ private[sql] object StatFunctions extends Logging {
// merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
val totalCount = count + other.count
val deltaX = xAvg - other.xAvg
val deltaY = yAvg - other.yAvg
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
count = totalCount
if (other.count > 0) {
val totalCount = count + other.count
val deltaX = xAvg - other.xAvg
val deltaY = yAvg - other.yAvg
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
count = totalCount
}
this
}
// return the sample covariance for the observed examples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,45 @@ class DataFrameStatSuite extends FunSuite {
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

test("spearman correlation") {
val x = Seq(5.05, 6.75, 3.21, 2.66)
val y = Seq(1.65, 26.5, -5.93, 7.96)
val z = Seq(1.65, 2.64, 2.64, 6.95)

// To calculate the Spearman Correlation in R:
// > x <- c(5.05, 6.75, 3.21, 2.66)
// > y <- c(1.65, 26.5, -5.93, 7.96)
// > z <- c(1.65, 2.64, 2.64, 6.95)
// No tie correction is needed
// > cor(x, y, method="spearman")
// [1] 0.4
// Tie correction is needed
// > cor(x, z, method="spearman")
// [1] -0.6324555
// > cor(y, z, method="spearman")
// [1] 0.3162278

// No tie-correction
val params1: Map[String, Any] = Map("tie" -> false)

val df1 = x.zip(y).toDF("a", "b")
val corr1 = df1.stat.corr("a", "b", "spearman", params1)

// With tie-correction
val params2: Map[String, Any] = Map("tie" -> true)

val df2 = x.zip(z).toDF("a", "c")
val corr2 = df2.stat.corr("a", "c", "spearman", params2)

// Default is with tie-correction
val df3 = y.zip(z).toDF("b", "c")
val corr3 = df3.stat.corr("b", "c", "spearman")

assert(corr1 - 0.4 < 1e-12)
assert(corr2 + 0.6324555 < 1e-12)
assert(corr3 - 0.3162278 < 1e-12)
}

test("covariance") {
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

Expand Down