Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ case class AggregateExpression(

@transient
override lazy val references: AttributeSet = {
mode match {
case Partial | Complete => aggregateFunction.references ++ filterAttributes
val aggAttributes = mode match {
case Partial | Complete => aggregateFunction.references
case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
}
aggAttributes ++ filterAttributes
}

override def toString: String = {
Expand All @@ -149,10 +150,20 @@ case class AggregateExpression(
case PartialMerge => "merge_"
case Final | Complete => ""
}
prefix + aggregateFunction.toAggString(isDistinct)
val aggFuncStr = prefix + aggregateFunction.toAggString(isDistinct)
filter match {
case Some(predicate) => s"$aggFuncStr FILTER (WHERE $predicate)"
case _ => aggFuncStr
}
}

override def sql: String = aggregateFunction.sql(isDistinct)
/**
 * SQL text of this aggregate expression: the aggregate function's own SQL, followed by a
 * `FILTER (WHERE ...)` clause when a filter predicate is present.
 */
override def sql: String = {
  val functionSql = aggregateFunction.sql(isDistinct)
  // With no filter, the function SQL is returned as-is.
  filter.fold(functionSql) { condition =>
    s"$functionSql FILTER (WHERE ${condition.sql})"
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,29 @@ package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object AggUtils {

/**
 * Strips the `filter` from every aggregate expression whose mode is not `Partial` or `Complete`.
 *
 * Expressions that have no filter, and filtered expressions in `Partial`/`Complete` mode, are
 * returned unchanged.
 */
private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
exprs.map { ae =>
if (ae.filter.isDefined) {
ae.mode match {
// Aggregate filters are applicable only in partial/complete modes;
// in every other mode this method removes them.
case Partial | Complete => ae
case _ => ae.copy(filter = None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also simplify AggregateExpression.references now

}
} else {
ae
}
}
}

private def createAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
groupingExpressions: Seq[NamedExpression] = Nil,
Expand All @@ -41,7 +56,7 @@ object AggUtils {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
Expand All @@ -54,7 +69,7 @@ object AggUtils {
ObjectHashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
Expand All @@ -63,7 +78,7 @@ object AggUtils {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,53 +51,53 @@ SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;
grouping expressions sequence is empty, and 'testdata.`a`' is not an aggregate function. Wrap '(count(testdata.`b`) FILTER (WHERE (testdata.`a` >= 2)) AS `count(b) FILTER (WHERE (a >= 2))`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.;


-- !query 4
SELECT COUNT(a) FILTER (WHERE a = 1), COUNT(b) FILTER (WHERE a > 1) FROM testData
-- !query 4 schema
struct<count(a):bigint,count(b):bigint>
struct<count(a) FILTER (WHERE (a = 1)):bigint,count(b) FILTER (WHERE (a > 1)):bigint>
-- !query 4 output
2 4


-- !query 5
SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp
-- !query 5 schema
struct<count(id):bigint>
struct<count(id) FILTER (WHERE (hiredate = DATE '2001-01-01')):bigint>
-- !query 5 output
2


-- !query 6
SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp
-- !query 6 schema
struct<count(id):bigint>
struct<count(id) FILTER (WHERE (hiredate = to_date('2001-01-01 00:00:00'))):bigint>
-- !query 6 output
2


-- !query 7
SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp
-- !query 7 schema
struct<count(id):bigint>
struct<count(id) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) = to_timestamp('2001-01-01 00:00:00'))):bigint>
-- !query 7 output
2


-- !query 8
SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp
-- !query 8 schema
struct<count(id):bigint>
struct<count(id) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) = 2001-01-01)):bigint>
-- !query 8 output
2


-- !query 9
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a
-- !query 9 schema
struct<a:int,count(b):bigint>
struct<a:int,count(b) FILTER (WHERE (a >= 2)):bigint>
-- !query 9 output
1 0
2 2
Expand All @@ -117,7 +117,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
-- !query 11
SELECT COUNT(a) FILTER (WHERE a >= 0), COUNT(b) FILTER (WHERE a >= 3) FROM testData GROUP BY a
-- !query 11 schema
struct<count(a):bigint,count(b):bigint>
struct<count(a) FILTER (WHERE (a >= 0)):bigint,count(b) FILTER (WHERE (a >= 3)):bigint>
-- !query 11 output
0 0
2 0
Expand All @@ -128,7 +128,7 @@ struct<count(a):bigint,count(b):bigint>
-- !query 12
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp GROUP BY dept_id
-- !query 12 schema
struct<dept_id:int,sum(salary):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > DATE '2003-01-01')):double>
-- !query 12 output
10 200.0
100 400.0
Expand All @@ -141,7 +141,7 @@ NULL NULL
-- !query 13
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id
-- !query 13 schema
struct<dept_id:int,sum(salary):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (hiredate > to_date('2003-01-01'))):double>
-- !query 13 output
10 200.0
100 400.0
Expand All @@ -154,7 +154,7 @@ NULL NULL
-- !query 14
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id
-- !query 14 schema
struct<dept_id:int,sum(salary):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) > to_timestamp('2003-01-01 00:00:00'))):double>
-- !query 14 output
10 200.0
100 400.0
Expand All @@ -167,7 +167,7 @@ NULL NULL
-- !query 15
SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id
-- !query 15 schema
struct<dept_id:int,sum(salary):double>
struct<dept_id:int,sum(salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd) > 2003-01-01)):double>
-- !query 15 output
10 200.0
100 400.0
Expand All @@ -180,39 +180,39 @@ NULL NULL
-- !query 16
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1
-- !query 16 schema
struct<foo:string,count(a):bigint>
struct<foo:string,count(a) FILTER (WHERE (b <= 2)):bigint>
-- !query 16 output
foo 6


-- !query 17
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= date "2003-01-01") FROM emp GROUP BY 1
-- !query 17 schema
struct<foo:string,sum(salary):double>
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= DATE '2003-01-01')):double>
-- !query 17 output
foo 1350.0


-- !query 18
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_date("2003-01-01")) FROM emp GROUP BY 1
-- !query 18 schema
struct<foo:string,sum(salary):double>
struct<foo:string,sum(salary) FILTER (WHERE (hiredate >= to_date('2003-01-01'))):double>
-- !query 18 output
foo 1350.0


-- !query 19
SELECT 'foo', SUM(salary) FILTER (WHERE hiredate >= to_timestamp("2003-01-01")) FROM emp GROUP BY 1
-- !query 19 schema
struct<foo:string,sum(salary):double>
struct<foo:string,sum(salary) FILTER (WHERE (CAST(hiredate AS TIMESTAMP) >= to_timestamp('2003-01-01'))):double>
-- !query 19 output
foo 1350.0


-- !query 20
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id
-- !query 20 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double>
-- !query 20 output
10 2 2 400.0 NULL
100 2 2 800.0 800.0
Expand All @@ -225,7 +225,7 @@ NULL 1 1 400.0 400.0
-- !query 21
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
-- !query 21 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
-- !query 21 output
10 2 2 400.0 NULL
100 2 2 800.0 800.0
Expand All @@ -238,7 +238,7 @@ NULL 1 1 400.0 NULL
-- !query 22
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id
-- !query 22 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double>
-- !query 22 output
10 2 2 400.0 NULL
100 2 2 NULL 800.0
Expand All @@ -251,7 +251,7 @@ NULL 1 1 NULL 400.0
-- !query 23
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id
-- !query 23 schema
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary):double,sum(salary):double>
struct<dept_id:int,count(DISTINCT emp_name):bigint,count(DISTINCT hiredate):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE ((id + dept_id) > 500)):double>
-- !query 23 output
10 2 2 400.0 NULL
100 2 2 NULL 800.0
Expand All @@ -264,23 +264,23 @@ NULL 1 1 NULL NULL
-- !query 24
SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1
-- !query 24 schema
struct<foo:string,approx_count_distinct(a):bigint>
struct<foo:string,approx_count_distinct(a) FILTER (WHERE (b >= 0)):bigint>
-- !query 24 output



-- !query 25
SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1
-- !query 25 schema
struct<foo:string,max(named_struct(a, a)):struct<a:int>>
struct<foo:string,max(named_struct(a, a)) FILTER (WHERE (b >= 1)):struct<a:int>>
-- !query 25 output



-- !query 26
SELECT a + b, COUNT(b) FILTER (WHERE b >= 2) FROM testData GROUP BY a + b
-- !query 26 schema
struct<(a + b):int,count(b):bigint>
struct<(a + b):int,count(b) FILTER (WHERE (b >= 2)):bigint>
-- !query 26 output
2 0
3 1
Expand All @@ -301,7 +301,7 @@ expression 'testdata.`a`' is neither present in the group by, nor is it an aggre
-- !query 28
SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1
-- !query 28 schema
struct<((a + 1) + 1):int,count(b):bigint>
struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint>
-- !query 28 output
3 2
4 2
Expand All @@ -312,7 +312,7 @@ NULL 1
-- !query 29
SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k
-- !query 29 schema
struct<k:int,count(b):bigint>
struct<k:int,count(b) FILTER (WHERE (b > 0)):bigint>
-- !query 29 output
1 2
2 2
Expand All @@ -327,7 +327,7 @@ SELECT emp.dept_id,
FROM emp
GROUP BY dept_id
-- !query 30 schema
struct<dept_id:int,avg(salary):double,avg(salary):double>
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (id > scalarsubquery())):double>
-- !query 30 output
10 133.33333333333334 NULL
100 400.0 400.0
Expand All @@ -344,7 +344,7 @@ SELECT emp.dept_id,
FROM emp
GROUP BY dept_id
-- !query 31 schema
struct<dept_id:int,avg(salary):double,avg(salary):double>
struct<dept_id:int,avg(salary):double,avg(salary) FILTER (WHERE (dept_id = scalarsubquery())):double>
-- !query 31 output
10 133.33333333333334 133.33333333333334
100 400.0 NULL
Expand All @@ -366,7 +366,7 @@ GROUP BY dept_id
struct<>
-- !query 32 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x]
: +- Project [state#x]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias `dept`
Expand All @@ -392,7 +392,7 @@ GROUP BY dept_id
struct<>
-- !query 33 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x]
: +- Project [state#x]
: +- Filter (dept_id#x = outer(dept_id#x))
: +- SubqueryAlias `dept`
Expand All @@ -417,7 +417,7 @@ GROUP BY dept_id
struct<>
-- !query 34 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) AS avg(salary)#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x]
: +- Distinct
: +- Project [dept_id#x]
: +- SubqueryAlias `dept`
Expand All @@ -442,7 +442,7 @@ GROUP BY dept_id
struct<>
-- !query 35 output
org.apache.spark.sql.AnalysisException
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) AS sum(salary)#x]
IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x]
: +- Distinct
: +- Project [dept_id#x]
: +- SubqueryAlias `dept`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ It is not allowed to use an aggregate function in the argument of another aggreg
-- !query 1
select min(unique1) filter (where unique1 > 100) from tenk1
-- !query 1 schema
struct<min(unique1):int>
struct<min(unique1) FILTER (WHERE (unique1 > 100)):int>
-- !query 1 output
101


-- !query 2
select sum(1/ten) filter (where ten > 0) from tenk1
-- !query 2 schema
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))):double>
struct<sum((CAST(1 AS DOUBLE) / CAST(ten AS DOUBLE))) FILTER (WHERE (ten > 0)):double>
-- !query 2 output
2828.9682539682954

Expand Down
Loading