Skip to content

Commit 352eda6

Browse files
committed
Only merge other CovarianceCounter when its count is greater than zero.
1 parent 51b3d41 commit 352eda6

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging {
3838
var yAvg = 0.0 // the mean of all examples seen so far in col2
3939
var Ck = 0.0 // the co-moment after k examples
4040
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
41-
var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
41+
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
4242
var count = 0L // count of observed examples
4343
// add an example to the calculation
4444
def add(x: Double, y: Double): this.type = {
@@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging {
5555
// merge counters from other partitions. Formula can be found at:
5656
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
5757
def merge(other: CovarianceCounter): this.type = {
58-
val totalCount = count + other.count
59-
val deltaX = xAvg - other.xAvg
60-
val deltaY = yAvg - other.yAvg
61-
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
62-
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
63-
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
64-
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
65-
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
66-
count = totalCount
58+
if (other.count > 0) {
59+
val totalCount = count + other.count
60+
val deltaX = xAvg - other.xAvg
61+
val deltaY = yAvg - other.yAvg
62+
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
63+
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
64+
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
65+
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
66+
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
67+
count = totalCount
68+
}
6769
this
6870
}
6971
// return the sample covariance for the observed examples

0 commit comments

Comments
 (0)