diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 1f85b07d8385d..24731c4b8577c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -137,10 +137,11 @@ case class AggregateExpression(
 
   @transient
   override lazy val references: AttributeSet = {
-    mode match {
-      case Partial | Complete => aggregateFunction.references ++ filterAttributes
+    val aggAttributes = mode match {
+      case Partial | Complete => aggregateFunction.references
       case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
     }
+    aggAttributes ++ filterAttributes
   }
 
   override def toString: String = {
@@ -149,10 +150,20 @@ case class AggregateExpression(
       case PartialMerge => "merge_"
       case Final | Complete => ""
     }
-    prefix + aggregateFunction.toAggString(isDistinct)
+    val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct)
+    filter match {
+      case Some(predicate) => s"$aggFuncStr FILTER (WHERE $predicate)"
+      case _ => aggFuncStr
+    }
   }
 
-  override def sql: String = aggregateFunction.sql(isDistinct)
+  override def sql: String = {
+    val aggFuncStr = aggregateFunction.sql(isDistinct)
+    filter match {
+      case Some(predicate) => s"$aggFuncStr FILTER (WHERE ${predicate.sql})"
+      case _ => aggFuncStr
+    }
+  }
 }
 
 /**
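[Editor's note, not part of the patch: the toString/sql changes above affect only how an aggregate with a FILTER clause is rendered. A minimal sketch of the new rendering, assuming a Count aggregate whose filter predicate prints as (a >= 2):

    // toString (mode = Partial) -> "partial_count(b) FILTER (WHERE (a >= 2))"
    // sql                       -> "count(b) FILTER (WHERE (a >= 2))"

The references change, by contrast, is behavioral: the filter's input attributes are now added in every mode, not only Partial/Complete, presumably so those columns cannot be pruned away while a filter is still attached to the expression.]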
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index e729fa278e9f3..56a287d4d0279 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
 
@@ -27,6 +26,22 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object AggUtils {
+
+  private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
+    exprs.map { ae =>
+      if (ae.filter.isDefined) {
+        ae.mode match {
+          // Aggregate filters are applicable only in partial/complete modes;
+          // in the other modes, this method removes them.
+          case Partial | Complete => ae
+          case _ => ae.copy(filter = None)
+        }
+      } else {
+        ae
+      }
+    }
+  }
+
   private def createAggregate(
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
@@ -41,7 +56,7 @@ object AggUtils {
         HashAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
           groupingExpressions = groupingExpressions,
-          aggregateExpressions = aggregateExpressions,
+          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
           resultExpressions = resultExpressions,
@@ -54,7 +69,7 @@
         ObjectHashAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
           groupingExpressions = groupingExpressions,
-          aggregateExpressions = aggregateExpressions,
+          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
           resultExpressions = resultExpressions,
@@ -63,7 +78,7 @@
         SortAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
           groupingExpressions = groupingExpressions,
-          aggregateExpressions = aggregateExpressions,
+          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
           resultExpressions = resultExpressions,
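[Editor's note, not part of the patch: a sketch of what mayRemoveAggFilters does in a typical two-phase plan. The construction is schematic — `count` stands for any AggregateFunction, `pred` for its filter predicate, `id` for a result ExprId:

    // The Partial expression reads input rows, so its filter must stay;
    // the Final expression reads partial aggregation buffers, to which the
    // filter was already applied, so its filter is dropped:
    //   mayRemoveAggFilters(Seq(
    //     AggregateExpression(count, Partial, isDistinct = false, filter = Some(pred), id),
    //     AggregateExpression(count, Final, isDistinct = false, filter = Some(pred), id)))
    //   => the Partial expression unchanged, the Final one with filter = None
]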
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
index 5d266c980a49a..fbb66878f891f 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out
@@ -51,13 +51,13 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
 struct<>
 -- !query 3 output
 org.apache.spark.sql.AnalysisException
-grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
+grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
 
 
 -- !query 4
 SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
 -- !query 4 schema
-struct<count(a):bigint,count(b):bigint>
+struct<count(a) FILTER (WHERE (a = 1)):bigint,count(b) FILTER (WHERE (a > 1)):bigint>
 -- !query 4 output
 2 4
 
@@ -65,7 +65,7 @@ struct<count(a):bigint,count(b):bigint>
 -- !query 5
 SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
 -- !query 5 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
 -- !query 5 output
 2
 
@@ -73,7 +73,7 @@ struct<count(id):bigint>
 -- !query 6
 SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
 -- !query 6 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (hiredate = to_date('2001-01-01 00:00:00'))):bigint>
 -- !query 6 output
 2
 
@@ -81,7 +81,7 @@ struct<count(id):bigint>
 -- !query 7
 SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
 -- !query 7 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00'))):bigint>
 -- !query 7 output
 2
 
@@ -89,7 +89,7 @@ struct<count(id):bigint>
 -- !query 8
 SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
 -- !query 8 schema
-struct<count(id):bigint>
+struct<count(id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01)):bigint>
 -- !query 8 output
 2
 
@@ -97,7 +97,7 @@ struct<count(id):bigint>
 -- !query 9
 SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
 -- !query 9 schema
-struct<a:int,count(b):bigint>
+struct<a:int,count(b) FILTER (WHERE (a >= 2)):bigint>
 -- !query 9 output
 1 0
 2 2
@@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
 -- !query 11
 SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
 -- !query 11 schema
-struct<count(a):bigint,count(b):bigint>
+struct<count(a) FILTER (WHERE (a >= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint>
 -- !query 11 output
 0 0
 2 0
@@ -128,7 +128,7 @@ struct<count(a):bigint,count(b):bigint>
 -- !query 12
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
 -- !query 12 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > DATE '2003-01-01')):double>
 -- !query 12 output
 10 200.0
 100 400.0
@@ -141,7 +141,7 @@ NULL NULL
 -- !query 13
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
 -- !query 13 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > to_date('2003-01-01'))):double>
 -- !query 13 output
 10 200.0
 100 400.0
@@ -154,7 +154,7 @@ NULL NULL
 -- !query 14
 SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
 -- !query 14 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00'))):double>
 -- !query 14 output
 10 200.0
 100 400.0
@@ -167,7 +167,7 @@ NULL NULL
 -- !query 15
 SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
 -- !query 15 schema
-struct<dept_id:int,sum(salary):double>
+struct<dept_id:int,sum(salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01)):double>
 -- !query 15 output
 10 200.0
 100 400.0
@@ -180,7 +180,7 @@ NULL NULL
 -- !query 16
 SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
 -- !query 16 schema
-struct<foo:string,count(a):bigint>
+struct<foo:string,count(a) FILTER (WHERE (b <= 2)):bigint>
 -- !query 16 output
 foo 6
 
@@ -188,7 +188,7 @@ foo 6
 -- !query 17
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
 -- !query 17 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= DATE '2003-01-01')):double>
 -- !query 17 output
 foo 1350.0
 
@@ -196,7 +196,7 @@ foo 1350.0
 -- !query 18
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
 -- !query 18 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= to_date('2003-01-01'))):double>
 -- !query 18 output
 foo 1350.0
 
@@ -204,7 +204,7 @@ foo 1350.0
 -- !query 19
 SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
 -- !query 19 schema
-struct<foo:string,sum(salary):double>
+struct<foo:string,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01'))):double>
 -- !query 19 output
 foo 1350.0
 
@@ -212,7 +212,7 @@ foo 1350.0
 -- !query 20
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
 -- !query 20 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double>
 -- !query 20 output
 10 2 2 400.0 NULL
 100 2 2 800.0 800.0
@@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0
 -- !query 21
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
 -- !query 21 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
 -- !query 21 output
 10 2 2 400.0 NULL
 100 2 2 800.0 800.0
@@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL
 -- !query 22
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
 -- !query 22 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double>
 -- !query 22 output
 10 2 2 400.0 NULL
 100 2 2 NULL 800.0
@@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0
 -- !query 23
 select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
 -- !query 23 schema
-struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
+struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
 -- !query 23 output
 10 2 2 400.0 NULL
 100 2 2 NULL 800.0
@@ -264,7 +264,7 @@ NULL 1 1 NULL NULL
 -- !query 24
 SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
 -- !query 24 schema
-struct<foo:string,approx_count_distinct(a):bigint>
+struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
 -- !query 24 output
 
 
@@ -272,7 +272,7 @@ struct<foo:string,approx_count_distinct(a):bigint>
 -- !query 25
 SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
 -- !query 25 schema
-struct<foo:string,max(named_struct(a, a)):struct<a:int>>
+struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
 -- !query 25 output
 
 
@@ -280,7 +280,7 @@ struct<foo:string,max(named_struct(a, a)):struct<a:int>>
 -- !query 26
 SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
 -- !query 26 schema
-struct<(a + b):int,count(b):bigint>
+struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint>
 -- !query 26 output
 2 0
 3 1
@@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
 -- !query 28
 SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
 -- !query 28 schema
-struct<((a + 1) + 1):int,count(b):bigint>
+struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint>
 -- !query 28 output
 3 2
 4 2
@@ -312,7 +312,7 @@ NULL 1
 -- !query 29
 SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
 -- !query 29 schema
-struct<k:int,count(b):bigint>
+struct<k:int,count(b) FILTER (WHERE (b > 0)):bigint>
 -- !query 29 output
 1 2
 2 2
@@ -327,7 +327,7 @@ SELECT emp.dept_id,
 FROM emp
 GROUP BY dept_id
 -- !query 30 schema
-struct<dept_id:int,avg(salary):double,avg(salary):double>
+struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id > scalarsubquery())):double>
 -- !query 30 output
 10 133.33333333333334 NULL
 100 400.0 400.0
@@ -344,7 +344,7 @@ SELECT emp.dept_id,
 FROM emp
 GROUP BY dept_id
 -- !query 31 schema
-struct<dept_id:int,avg(salary):double,avg(salary):double>
+struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id < scalarsubquery())):double>
 -- !query 31 output
 10 133.33333333333334 133.33333333333334
 100 400.0 NULL
@@ -366,7 +366,7 @@ GROUP BY dept_id
 struct<>
 -- !query 32 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x]
 :  +- Project [state#x]
 :     +- Filter (dept_id#x = outer(dept_id#x))
 :        +- SubqueryAlias `dept`
@@ -392,7 +392,7 @@ GROUP BY dept_id
 struct<>
 -- !query 33 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x]
 :  +- Project [state#x]
 :     +- Filter (dept_id#x = outer(dept_id#x))
 :        +- SubqueryAlias `dept`
@@ -417,7 +417,7 @@ GROUP BY dept_id
 struct<>
 -- !query 34 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x]
 :  +- Distinct
 :     +- Project [dept_id#x]
 :        +- SubqueryAlias `dept`
@@ -442,7 +442,7 @@ GROUP BY dept_id
 struct<>
 -- !query 35 output
 org.apache.spark.sql.AnalysisException
-IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
+IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x]
 :  +- Distinct
 :     +- Project [dept_id#x]
 :        +- SubqueryAlias `dept`
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
index 9678b2e8966bc..d2ab138efcdae 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out
@@ -14,7 +14,7 @@ It is not allowed to use an aggregate function in the argument of another aggreg
 -- !query 1
 select min(unique1) filter (where unique1 > 100) from tenk1
 -- !query 1 schema
-struct<min(unique1):int>
+struct<min(unique1) FILTER (WHERE (unique1 > 100)):int>
 -- !query 1 output
 101
 
@@ -22,7 +22,7 @@ struct<min(unique1):int>
 -- !query 2
 select sum(1/ten) filter (where ten > 0) from tenk1
 -- !query 2 schema
-struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))):double>
+struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (WHERE (ten > 0)):double>
 -- !query 2 output
 2828.9682539682954
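[Editor's note, not part of the patch: the .sql.out changes above are golden-file updates only; the row results are unchanged, and only the auto-generated column names and a few analysis error messages now carry the FILTER clause. For example, the first changed query of aggregates_part3.sql:

    -- result column is now named min(unique1) FILTER (WHERE (unique1 > 100))
    select min(unique1) filter (where unique1 > 100) from tenk1;
]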
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e472ceac7c1a6..59e34e23198de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -27,6 +27,7 @@ import scala.collection.parallel.immutable.ParVector
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
@@ -2843,16 +2844,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
       val df = sql(query)
       val physical = df.queryExecution.sparkPlan
-      val aggregateExpressions = physical.collectFirst {
+      val aggregateExpressions = physical.collect {
         case agg: HashAggregateExec => agg.aggregateExpressions
         case agg: ObjectHashAggregateExec => agg.aggregateExpressions
+      }.flatten
+      aggregateExpressions.foreach { expr =>
+        if (expr.mode == Complete || expr.mode == Partial) {
+          assert(expr.filter.isDefined)
+        } else {
+          assert(expr.filter.isEmpty)
+        }
       }
-      assert(aggregateExpressions.isDefined)
-      assert(aggregateExpressions.get.size == 1)
-      aggregateExpressions.get.foreach { expr =>
-        assert(expr.filter.isDefined)
-      }
-      checkAnswer(df, Row(funcToResult._2) :: Nil)
+      checkAnswer(df, Row(funcToResult._2))
     }
   }
 
@@ -2860,15 +2863,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
       val df = sql("SELECT PERCENTILE(a, 1) FILTER (WHERE b > 1) FROM testData2")
       val physical = df.queryExecution.sparkPlan
-      val aggregateExpressions = physical.collectFirst {
+      val aggregateExpressions = physical.collect {
         case agg: SortAggregateExec => agg.aggregateExpressions
+      }.flatten
+      aggregateExpressions.foreach { expr =>
+        if (expr.mode == Complete || expr.mode == Partial) {
+          assert(expr.filter.isDefined)
+        } else {
+          assert(expr.filter.isEmpty)
+        }
       }
-      assert(aggregateExpressions.isDefined)
-      assert(aggregateExpressions.get.size == 1)
-      aggregateExpressions.get.foreach { expr =>
-        assert(expr.filter.isDefined)
-      }
-      checkAnswer(df, Row(3) :: Nil)
+      checkAnswer(df, Row(3))
     }
   }
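[Editor's note, not part of the patch: a hedged sketch of what the updated assertions verify, assuming the same imports and test tables as SQLQuerySuite (testData2, HashAggregateExec, the aggregate mode objects):

    val df = sql("SELECT COUNT(b) FILTER (WHERE b > 1) FROM testData2")
    df.queryExecution.sparkPlan.collect {
      case agg: HashAggregateExec => agg.aggregateExpressions
    }.flatten.foreach { ae =>
      // Partial/Complete expressions keep the filter; AggUtils strips it from
      // PartialMerge/Final expressions, which only merge aggregation buffers.
      assert(ae.filter.isDefined == (ae.mode == Partial || ae.mode == Complete))
    }
]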