From 5094db08121a4a1cefc29d736e7e3e7580b866d2 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Thu, 29 Oct 2015 23:46:03 -0700 Subject: [PATCH 1/8] SPARK-11420: stddev via Imperative Aggregate --- R/pkg/inst/tests/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 36 +-- .../expressions/aggregate/functions.scala | 210 ++++++------------ .../org/apache/spark/sql/GroupedData.scala | 4 +- .../org/apache/spark/sql/functions.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 6 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +- 8 files changed, 100 insertions(+), 183 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index e1d4499925fe..f18d68277d90 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1271,7 +1271,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "7.7781745930520225") + expect_equal(collect(stats)[3, "age"], "5.5") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3baff8147753..746ad7345a27 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -661,25 +661,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. >>> df.describe().show() - +-------+------------------+ - |summary| age| - +-------+------------------+ - | count| 2| - | mean| 3.5| - | stddev|2.1213203435596424| - | min| 2| - | max| 5| - +-------+------------------+ + +-------+---+ + |summary|age| + +-------+---+ + | count| 2| + | mean|3.5| + | stddev|1.5| + | min| 2| + | max| 5| + +-------+---+ >>> df.describe(['age', 'name']).show() - +-------+------------------+-----+ - |summary| age| name| - +-------+------------------+-----+ - | count| 2| 2| - | mean| 3.5| null| - | stddev|2.1213203435596424| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+------------------+-----+ + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | mean|3.5| null| + | stddev|1.5| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+---+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 281404f285a9..8169e0f45ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -327,149 +327,6 @@ case class Min(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = min } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev" -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = false - override def prettyName: String = "stddev_pop" -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev_samp" -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - def isSample: Boolean - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = DoubleType - - private val preCount = AttributeReference("preCount", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() - private val preAvg = AttributeReference("preAvg", resultType)() - private val currentAvg = AttributeReference("currentAvg", resultType)() - private val currentMk = AttributeReference("currentMk", resultType)() - - override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil - - override val initialValues = Seq( - /* preCount = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType), - /* preAvg = */ Cast(Literal(0), resultType), - /* currentAvg = */ Cast(Literal(0), resultType), - /* currentMk = */ Cast(Literal(0), resultType) - ) - - override val updateExpressions = { - - // update average - // avg = avg + (value - avg)/count - def avgAdd: Expression = { - currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) - } - - // update sum of square of difference from mean - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - def mkAdd: Expression = { - val delta1 = Cast(child, resultType) - preAvg - val delta2 = Cast(child, resultType) - currentAvg - currentMk + (delta1 * delta2) - } - - Seq( - /* preCount = */ If(IsNull(child), preCount, currentCount), - /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), - /* preAvg = */ If(IsNull(child), preAvg, currentAvg), - /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), - /* currentMk = */ If(IsNull(child), currentMk, mkAdd) - ) - } - - override val mergeExpressions = { - - // count merge - def countMerge: Expression = { - currentCount.left + currentCount.right - } - - // average merge - def avgMerge: Expression = { - ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) - } - - // update sum of square differences - def mkMerge: Expression = { - val avgDelta = currentAvg.right - preAvg - val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / - (preCount + currentCount.right) - - currentMk.left + currentMk.right + mkDelta - } - - Seq( - /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), - /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), - /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), - /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) - ) - } - - override val evaluateExpression = { - // when currentCount == 0, return null - // when currentCount == 1, return 0 - // when currentCount >1 - // stddev_samp = sqrt (currentMk/(currentCount -1)) - // stddev_pop = sqrt (currentMk/currentCount) - val varCol = { - if (isSample) { - currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) - } - else { - currentMk / currentCount - } - } - - If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(varCol), resultType))) - } -} - case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -1139,6 +996,73 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } +case class Stddev(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "stddev" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + + if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + } +} + + +case class StddevPop(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "stddev_pop" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + + if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + } +} + +case class StddevSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "stddev_samp" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + + if (n == 0.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + } +} + case class Variance(child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index dc96384a4d28..bdcefd3136d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -316,7 +316,7 @@ class GroupedData protected[sql]( } /** - * Compute the sample standard deviation for each numeric columns for each group. + * Compute the population standard deviation for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. * When specified columns are given, only compute the stddev for them. * @@ -364,7 +364,7 @@ class GroupedData protected[sql]( } /** - * Compute the sample variance for each numeric columns for each group. + * Compute the population variance for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. * When specified columns are given, only compute the variance for them. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c1737b1ef663..8d2280c3dcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -327,7 +327,7 @@ object functions { def skewness(columnName: String): Column = skewness(Column(columnName)) /** - * Aggregate function: returns the unbiased sample standard deviation of + * Aggregate function: returns the population standard deviation of * the expression in a group. * * @group agg_funcs @@ -336,7 +336,7 @@ object functions { def stddev(e: Column): Column = Stddev(e.expr) /** - * Aggregate function: returns the unbiased sample standard deviation of + * Aggregate function: returns the population standard deviation of * the expression in a group. * * @group agg_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9b23977c765d..8b3b676d8651 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -176,7 +176,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("stddev") { - val testData2ADev = math.sqrt(4/5.0) + val testData2ADev = math.sqrt(4 / 6.0) checkAnswer( testData2.agg(stddev('a)), @@ -184,11 +184,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( testData2.agg(stddev_pop('a)), - Row(math.sqrt(4/6.0))) + Row(testData2ADev)) checkAnswer( testData2.agg(stddev_samp('a)), - Row(testData2ADev)) + Row(math.sqrt(4 / 5.0)) } test("zero stddev") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c9d6e19d2ce9..1ca1e73a21ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -443,7 +443,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "19.148542155126762", "11.547005383792516"), + Row("stddev", "16.583123951777", "10.0"), Row("min", "16", "164"), Row("max", "60", "192")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a616fac0bc2..6325f18ef5a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -329,13 +329,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) - // STDDEV - testCodeGen( - "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) - testCodeGen( - "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", - Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -356,8 +349,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", - Row(null, null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -525,7 +518,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3) + Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, math.sqrt(2.0 / 3.0), 6, 3) ) } @@ -718,7 +711,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev") { checkAnswer( sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4.0 / 5.0)) + Row(math.sqrt(4.0 / 6.0)) ) } @@ -732,7 +725,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev_samp") { checkAnswer( sql("SELECT STDDEV_SAMP(a) FROM testData2"), - Row(math.sqrt(4/5.0)) + Row(math.sqrt(4 / 5.0)) ) } @@ -774,7 +767,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev agg") { checkAnswer( sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) + (1 to 3).map(i => Row(i, math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) } test("variance agg") { From 0113626243184292d2d0d1d4bdaa7a2a621fcc53 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Fri, 30 Oct 2015 10:11:12 -0700 Subject: [PATCH 2/8] handle null --- .../expressions/aggregate/functions.scala | 6 ++++-- .../spark/sql/DataFrameAggregateSuite.scala | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 8169e0f45ae4..ee3ac7116e56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -992,7 +992,9 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w moments(4) = buffer.getDouble(fourthMomentOffset) } - getStatistic(n, mean, moments) + if (n == 0.0) null + else if (n == 1.0) 0.0 + else getStatistic(n, mean, moments) } } @@ -1059,7 +1061,7 @@ case class StddevSamp(child: Expression, require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8b3b676d8651..4719f1357ea1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -188,7 +188,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( testData2.agg(stddev_samp('a)), - Row(math.sqrt(4 / 5.0)) + Row(math.sqrt(4 / 5.0))) } test("zero stddev") { @@ -255,7 +255,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) + Row(0.0)) checkAnswer( emptyTableData.agg(var_pop('a)), @@ -263,11 +263,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(skewness('a)), - Row(Double.NaN)) + Row(0.0)) checkAnswer( emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + Row(0.0)) } test("null moments") { @@ -276,22 +276,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a)), - Row(Double.NaN)) + Row(null)) checkAnswer( emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) + Row(null)) checkAnswer( emptyTableData.agg(var_pop('a)), - Row(Double.NaN)) + Row(null)) checkAnswer( emptyTableData.agg(skewness('a)), - Row(Double.NaN)) + Row(null)) checkAnswer( emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + Row(null)) } } From 4ca8b19cb6af7a4f203b594644053c79e571339b Mon Sep 17 00:00:00 2001 From: JihongMa Date: Tue, 3 Nov 2015 14:23:26 -0800 Subject: [PATCH 3/8] minor fix --- .../src/main/scala/org/apache/spark/sql/functions.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6d1ac73279d..1297b093f5fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -328,14 +328,6 @@ object functions { */ def skewness(e: Column): Column = Skewness(e.expr) - /** - * Aggregate function: returns the skewness of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def skewness(columnName: String): Column = skewness(Column(columnName)) - /** * Aggregate function: returns the population standard deviation of * the expression in a group. From 402971cf655f69218277a64162223655efdec898 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 4 Nov 2015 11:57:09 -0800 Subject: [PATCH 4/8] style fix --- .../spark/sql/catalyst/expressions/aggregate/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index ba9362a625d4..7269184bec05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -1263,7 +1263,7 @@ case class Skewness(child: Expression, val m2 = moments(2) val m3 = moments(3) if (n == 0.0 || m2 == 0.0) { - null + null } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } @@ -1291,7 +1291,7 @@ case class Kurtosis(child: Expression, val m2 = moments(2) val m4 = moments(4) if (n == 0.0 || m2 == 0.0) { - null + null } else { n * m4 / (m2 * m2) - 3.0 } From b69d1e68d52739df83a9ebe797c83f4b8c5ef0dd Mon Sep 17 00:00:00 2001 From: JihongMa Date: Fri, 6 Nov 2015 12:56:51 -0800 Subject: [PATCH 5/8] address comment --- .../sql/catalyst/expressions/aggregate/Kurtosis.scala | 7 ++++--- .../sql/catalyst/expressions/aggregate/Skewness.scala | 7 ++++--- .../spark/sql/catalyst/expressions/aggregate/Stddev.scala | 4 +++- .../sql/catalyst/expressions/aggregate/Variance.scala | 4 +++- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 4 ++-- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 65dec152bc36..0c91cf7a76c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -40,9 +40,10 @@ case class Kurtosis(child: Expression, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { - null - } else { + + if (n == 0.0) null + else if (m2 == 0.0) Double.NaN + else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index 3396d04b47ab..d4340c6e15be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -39,9 +39,10 @@ case class Skewness(child: Expression, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { - null - } else { + + if (n == 0.0) null + else if (m2 == 0.0) Double.NaN + else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index dd759bde8131..d4cdfdd614e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -59,6 +59,8 @@ case class StddevSamp(child: Expression, require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - if (n == 0.0 || n == 1.0) null else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) null + else if (n == 1.0) Double.NaN + else math.sqrt(moments(2) / (n - 1.0)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index 4545023f1138..f1636bc9a65a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -38,7 +38,9 @@ case class VarianceSamp(child: Expression, require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) null else moments(2) / (n - 1.0) + if (n == 0.0) null + else if (n== 1.0) Double.NaN + else moments(2) / (n - 1.0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b85e3c669462..fd82b70c0714 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -220,7 +220,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(null, null, 0.0, null, null)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) checkAnswer( input.agg( @@ -229,7 +229,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(null, null, 0.0, null, null)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) } test("null moments") { From dc0558ba6b658ca8b64202cea47ae1dad6294120 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 11 Nov 2015 14:13:02 -0800 Subject: [PATCH 6/8] fix tests --- python/pyspark/sql/dataframe.py | 2 +- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6 ++++-- .../catalyst/expressions/aggregate/CentralMomentAgg.scala | 2 +- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0dd75ba7ca82..ad6ad0235a90 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| null| + | stddev|2.1213203435596424| NaN| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index bf2bff0243fa..92188ee54fd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,8 +297,10 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) - case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index f3fceb64fffa..de5872ab11eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 703ad097861b..f4da56b819ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", Double.NaN, Double.NaN), + Row("stddev", "NaN", "NaN"), Row("min", null, null), Row("max", null, null)) From ca407bc5edda37b28596c0859399f07a058f1c8a Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 11 Nov 2015 19:22:46 -0800 Subject: [PATCH 7/8] fix R test --- R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 9e453a1e7c2f..64fc2fdd5f74 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1007,7 +1007,7 @@ test_that("group by, agg functions", { df3 <- agg(gd, age = "stddev") expect_is(df3, "DataFrame") df3_local <- collect(df3) - expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2]) + expect_equal(NaN, df3_local[df3_local$name == "Andy",][1, 2]) df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "DataFrame") @@ -1038,7 +1038,7 @@ test_that("group by, agg functions", { df7 <- agg(gd2, value = "stddev") df7_local <- collect(df7) expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) - expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2]) + expect_equal(NaN, df7_local[df7_local$name == "ID2",][1, 2]) mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Andy\", \"age\":30}", From 7a239ecad765f7a570741bd086302a35f7897742 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 11 Nov 2015 19:30:51 -0800 Subject: [PATCH 8/8] fix test_sparkSQL.R --- R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 64fc2fdd5f74..af024e6183a3 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1007,7 +1007,7 @@ test_that("group by, agg functions", { df3 <- agg(gd, age = "stddev") expect_is(df3, "DataFrame") df3_local <- collect(df3) - expect_equal(NaN, df3_local[df3_local$name == "Andy",][1, 2]) + expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2])) df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "DataFrame") @@ -1038,7 +1038,7 @@ test_that("group by, agg functions", { df7 <- agg(gd2, value = "stddev") df7_local <- collect(df7) expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) - expect_equal(NaN, df7_local[df7_local$name == "ID2",][1, 2]) + expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2])) mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Andy\", \"age\":30}",