diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..3dfca3f0561cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,10 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[EveryAgg]("every"), + expression[AnyAgg]("any"), + expression[SomeAgg]("some"), + // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 71099eba0fc75..88ccba31b4a4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -23,10 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the maximum value of `expr`.") -case class Max(child: Expression) extends DeclarativeAggregate { - +abstract class MaxBase(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true @@ -57,3 +54,31 @@ case class Max(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = max } + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the maximum value of `expr`.") +case class Max(child: Expression) extends MaxBase(child) + +abstract class AnyAggBase(child: Expression) extends MaxBase(child) with ImplicitCastInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + override def checkInputDataTypes(): TypeCheckResult = { + child.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${child.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class AnyAgg(child: Expression) extends AnyAggBase (child) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class SomeAgg(child: Expression) extends AnyAggBase(child) { + override def nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 8c4ba93231cbe..9f86ce92f6c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -23,10 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the minimum value of `expr`.") -case class Min(child: Expression) extends DeclarativeAggregate { - +abstract class MinBase(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true @@ -57,3 +54,23 @@ case class Min(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = min } + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the minimum value of `expr`.") +case class Min(child: Expression) extends MinBase(child) + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") +case class EveryAgg(child: Expression) extends MinBase(child) with ImplicitCastInputTypes { + override def nodeName: String = "Every" + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + override def checkInputDataTypes(): TypeCheckResult = { + child.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${child.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..b8643d0243f59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(new EveryAgg('booleanField)) + assertSuccess(new AnyAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 433db71527437..6bb3454839eca 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -80,3 +80,49 @@ SELECT 1 FROM range(10) HAVING true; SELECT 1 FROM range(10) HAVING MAX(id) > 0; SELECT id FROM range(10) HAVING id > 0; + +-- Test data +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v); + +-- empty table +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0; + +-- all null values +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4; + +-- aggregates are null Filtering +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5; + +-- group by +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + +-- having +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false; +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL; + +-- input type checking Int +SELECT every(1); + +-- input type checking Short +SELECT some(1S); + +-- input type checking Long +SELECT any(1L); + +-- input type checking String +SELECT every("true"); + +-- every/some/any aggregates are not supported as windows expression. +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; + +-- simple explain of queries having every/some/any agregates. Analyzed +-- plans should not have reference to replace expression. +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f9d1ee8a6bcdb..72bf56062eb0f 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 45 -- !query 0 @@ -275,3 +275,185 @@ struct<> -- !query 29 output org.apache.spark.sql.AnalysisException grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 28 +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v) +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0 +-- !query 29 schema +struct +-- !query 29 output +NULL NULL NULL + + +-- !query 30 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4 +-- !query 30 schema +struct +-- !query 30 output +NULL NULL NULL + + +-- !query 31 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5 +-- !query 31 schema +struct +-- !query 31 output +false true true + + +-- !query 32 +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 32 schema +struct +-- !query 32 output +1 false true true +2 true true true +3 false false false +4 NULL NULL NULL +5 false true true + + +-- !query 33 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false +-- !query 33 schema +struct +-- !query 33 output +1 false +3 false +5 false + + +-- !query 34 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL +-- !query 34 schema +struct +-- !query 34 output +4 NULL + + +-- !query 35 +SELECT every(1) +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 + + +-- !query 36 +SELECT some(1S) +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 + + +-- !query 37 +SELECT any(1L) +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 + + +-- !query 38 +SELECT every("true") +-- !query 38 schema +struct<> +-- !query 38 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 + + +-- !query 39 +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 39 schema +struct +-- !query 39 output +1 false false +1 true false +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true false + + +-- !query 40 +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 40 schema +struct +-- !query 40 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 41 +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 41 schema +struct +-- !query 41 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 42 +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 42 schema +struct +-- !query 42 output +== Parsed Logical Plan == +'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] ++- 'UnresolvedRelation `test_agg` + +== Analyzed Logical Plan == +k: int, every(v): boolean, some(v): boolean, any(v): boolean +Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] ++- SubqueryAlias `test_agg` + +- Project [k#x, v#x] + +- SubqueryAlias `test_agg` + +- LocalRelation [k#x, v#x] + +== Optimized Logical Plan == +Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] ++- LocalRelation [k#x, v#x] + +== Physical Plan == +*HashAggregate(keys=[k#x], functions=[every(v#x), some(v#x), any(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) ++- Exchange hashpartitioning(k#x, 200) + +- *HashAggregate(keys=[k#x], functions=[partial_every(v#x), partial_some(v#x), partial_any(v#x)], output=[k#x, min#x, max#x, max#x]) + +- LocalTableScan [k#x, v#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db2..b5ab5cf737e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.Matchers.the +import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyAgg, EveryAgg} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -727,4 +728,67 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + def getEveryAggColumn(columnName: String): Column = { + Column(new EveryAgg(Column(columnName).expr).toAggregateExpression(false)) + } + + def getAnyAggColumn(columnName: String): Column = { + Column(new AnyAgg(Column(columnName).expr).toAggregateExpression(false)) + } + + test("every") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + + checkAnswer( + df.groupBy("a").agg(getEveryAggColumn("b")), + Seq(Row(1, false), Row(2, true), Row(3, false))) + } + + test("every null values") { + val df = Seq[(java.lang.Integer, java.lang.Boolean)]( + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getEveryAggColumn("b")), + Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null))) + } + + test("every empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(getEveryAggColumn("b")), + Seq(Row(null))) + } + + test("any") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getAnyAggColumn("b")), + Seq(Row(1, true), Row(2, true), Row(3, false))) + } + + test("any empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(getAnyAggColumn("b")), + Seq(Row(null))) + } + + test("any null values") { + val df = Seq[(java.lang.Integer, java.lang.Boolean)]( + (1, true), (1, false), + (2, true), + (3, true), (3, false), (3, null), + (4, null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getAnyAggColumn("b")), + Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null))) + } }