Skip to content

Commit b91873d

Browse files
cloud-fanhvanhovell
authored andcommitted
[SPARK-20409][SQL] fail early if aggregate function in GROUP BY
## What changes were proposed in this pull request? It's illegal to have aggregate function in GROUP BY, and we should fail at analysis phase, if this happens. ## How was this patch tested? new regression test Author: Wenchen Fan <[email protected]> Closes #17704 from cloud-fan/minor.
1 parent c6f62c5 commit b91873d

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ class Analyzer(
966966
case p if !p.childrenResolved => p
967967
// Replace the index with the related attribute for ORDER BY,
968968
// which is a 1-base position of the projection list.
969-
case s @ Sort(orders, global, child)
969+
case Sort(orders, global, child)
970970
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
971971
val newOrders = orders map {
972972
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
@@ -983,17 +983,11 @@ class Analyzer(
983983

984984
// Replace the index with the corresponding expression in aggregateExpressions. The index is
985985
// a 1-base position of aggregateExpressions, which is output columns (select expression)
986-
case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
986+
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
987987
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
988988
val newGroups = groups.map {
989-
case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
990-
aggs(index - 1) match {
991-
case e if ResolveAggregateFunctions.containsAggregate(e) =>
992-
ordinal.failAnalysis(
993-
s"GROUP BY position $index is an aggregate function, and " +
994-
"aggregate functions are not allowed in GROUP BY")
995-
case o => o
996-
}
989+
case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
990+
aggs(index - 1)
997991
case ordinal @ UnresolvedOrdinal(index) =>
998992
ordinal.failAnalysis(
999993
s"GROUP BY position $index is not in select list " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper {
254254
}
255255

256256
def checkValidGroupingExprs(expr: Expression): Unit = {
257+
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
258+
failAnalysis(
259+
"aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
260+
}
261+
257262
// Check if the data type of expr is orderable.
258263
if (!RowOrdering.isOrderable(expr.dataType)) {
259264
failAnalysis(
@@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper {
271276
}
272277
}
273278

274-
aggregateExprs.foreach(checkValidAggregateExpression)
275279
groupingExprs.foreach(checkValidGroupingExprs)
280+
aggregateExprs.foreach(checkValidAggregateExpression)
276281

277282
case Sort(orders, _, _) =>
278283
orders.foreach { order =>

sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
122122
struct<>
123123
-- !query 11 output
124124
org.apache.spark.sql.AnalysisException
125-
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39
125+
aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT));
126126

127127

128128
-- !query 12
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
131131
struct<>
132132
-- !query 12 output
133133
org.apache.spark.sql.AnalysisException
134-
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43
134+
aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT));
135135

136136

137137
-- !query 13

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
538538
Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0))
539539
)
540540
}
541+
542+
test("aggregate function in GROUP BY") {
543+
val e = intercept[AnalysisException] {
544+
testData.groupBy(sum($"key")).count()
545+
}
546+
assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
547+
}
541548
}

0 commit comments

Comments
 (0)