Commit 90527f5

viirya authored and mengxr committed
[SPARK-7390] [SQL] Only merge other CovarianceCounter when its count is greater than zero
JIRA: https://issues.apache.org/jira/browse/SPARK-7390

Also fix a minor typo.

Author: Liang-Chi Hsieh <[email protected]>

Closes #5931 from viirya/fix_covariancecounter and squashes the following commits:

352eda6 [Liang-Chi Hsieh] Only merge other CovarianceCounter when its count is greater than zero.

1 parent 5467c34 commit 90527f5

1 file changed, +12 -10 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 12 additions & 10 deletions
@@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging {
     var yAvg = 0.0 // the mean of all examples seen so far in col2
     var Ck = 0.0 // the co-moment after k examples
     var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
-    var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
+    var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
     var count = 0L // count of observed examples
     // add an example to the calculation
     def add(x: Double, y: Double): this.type = {
@@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging {
     // merge counters from other partitions. Formula can be found at:
     // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
     def merge(other: CovarianceCounter): this.type = {
-      val totalCount = count + other.count
-      val deltaX = xAvg - other.xAvg
-      val deltaY = yAvg - other.yAvg
-      Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
-      xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
-      yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
-      MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
-      MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
-      count = totalCount
+      if (other.count > 0) {
+        val totalCount = count + other.count
+        val deltaX = xAvg - other.xAvg
+        val deltaY = yAvg - other.yAvg
+        Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
+        xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+        yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+        MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+        MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
+        count = totalCount
+      }
       this
     }
     // return the sample covariance for the observed examples
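
The commit message does not spell out the failure mode, so the following is only a sketch of one case where the guard appears to matter: when both counters are empty, `totalCount` is zero and the floating-point division `0.0 / 0` yields NaN, which then poisons every later update. The class below is a standalone simplification written for illustration, not the Spark source: it drops `MkX` and `MkY`, and the names `mergeUnguarded` and `cov` are hypothetical helpers introduced here to contrast the old and new behavior.

```scala
// Illustrative sketch only: a trimmed-down counter mirroring the fields in the diff.
final class CovarianceCounter {
  var xAvg = 0.0 // the mean of all examples seen so far in col1
  var yAvg = 0.0 // the mean of all examples seen so far in col2
  var Ck = 0.0   // the co-moment after k examples
  var count = 0L // count of observed examples

  // Add one example using the standard online co-moment update.
  def add(x: Double, y: Double): this.type = {
    count += 1
    val deltaX = x - xAvg // difference from the old mean of col1
    xAvg += deltaX / count
    yAvg += (y - yAvg) / count
    Ck += deltaX * (y - yAvg) // uses the new mean of col2
    this
  }

  // Hypothetical helper: the pre-fix merge body, applied unconditionally.
  // If count == 0 and other.count == 0, totalCount is 0 and the
  // divisions below evaluate 0.0 / 0, which is NaN.
  def mergeUnguarded(other: CovarianceCounter): this.type = {
    val totalCount = count + other.count
    val deltaX = xAvg - other.xAvg
    val deltaY = yAvg - other.yAvg
    Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
    xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
    yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
    count = totalCount
    this
  }

  // Post-fix behavior: skip the merge when the other counter is empty.
  def merge(other: CovarianceCounter): this.type = {
    if (other.count > 0) mergeUnguarded(other)
    this
  }

  // Hypothetical helper: sample covariance of the observed examples.
  def cov: Double = Ck / (count - 1)
}
```

With this sketch, `new CovarianceCounter().mergeUnguarded(new CovarianceCounter())` leaves `Ck`, `xAvg`, and `yAvg` as NaN, while `merge` leaves the counter untouched. Merging two non-empty counters behaves the same either way, since the guard only skips empty inputs.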
