@@ -38,7 +38,7 @@ private[sql] object StatFunctions extends Logging {
3838 var yAvg = 0.0 // the mean of all examples seen so far in col2
3939 var Ck = 0.0 // the co-moment after k examples
4040 var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
41- var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
41+ var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
4242 var count = 0L // count of observed examples
4343 // add an example to the calculation
4444 def add (x : Double , y : Double ): this .type = {
@@ -55,15 +55,17 @@ private[sql] object StatFunctions extends Logging {
/**
 * Merge the running statistics of another counter (e.g. one built on a different
 * partition) into this counter, in place.
 *
 * Formula can be found at:
 * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
 *
 * @param other the counter whose statistics are folded into this one; it is only read
 * @return this counter, so that calls can be chained
 */
def merge(other: CovarianceCounter): this.type = {
  // Nothing to fold in when `other` has seen no examples. Skipping also avoids a
  // division by a zero `totalCount` (which would produce NaN) when both counters are empty.
  if (other.count > 0) {
    val totalCount = count + other.count
    // The deltas must use the means from *before* the merge, so compute them first.
    val deltaX = xAvg - other.xAvg
    val deltaY = yAvg - other.yAvg
    // Co-moment: both partial co-moments plus a correction for the gap between the means.
    Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
    // Combined means, weighted by the number of examples each side has seen.
    xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
    yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
    // Sums of squared differences from the mean, corrected the same way as Ck.
    MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
    MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
    // Update the count last: every formula above relies on the pre-merge `count`.
    count = totalCount
  }
  this
}
6971 // return the sample covariance for the observed examples
0 commit comments