diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47962ebe6ef8..700207bbc149 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -70,6 +70,7 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveNumericReferences :: ResolveGroupingAnalytics :: ResolvePivot :: ResolveSortReferences :: @@ -178,6 +179,46 @@ class Analyzer( } } + /** + * Replaces queries of the form "SELECT expr FROM A GROUP BY 1 ORDER BY 1" + * with a query of the form "SELECT expr FROM A GROUP BY expr ORDER BY expr" + */ + object ResolveNumericReferences extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Aggregate(groups, aggs, child) => + val newGroups = groups.map { + case group if group.isInstanceOf[Literal] && group.dataType.isInstanceOf[IntegralType] => + aggs(group.toString.toInt - 1) match { + case u: UnresolvedAlias => + u.child match { + case UnresolvedStar(_) => // Can't replace literal with column yet + group + case _ => u.child + } + case a: Alias => a.child + case a: AttributeReference => a + } + case group => group + } + Aggregate(newGroups, aggs, child) + case Sort(ordering, global, child) => + val newOrdering = ordering.map { + case o if o.child.isInstanceOf[Literal] && o.dataType.isInstanceOf[IntegralType] => + val newExpr = child.asInstanceOf[Project].projectList(o.child.toString.toInt - 1) + match { + case u: UnresolvedAlias => + u.child + case a: Alias => + a.child + } + SortOrder(newExpr, o.direction) + case other => other + } + Sort(newOrdering, global, child) + } + } + object ResolveGroupingAnalytics extends Rule[LogicalPlan] { /* * GROUP BY a, b, c WITH ROLLUP diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d67..aeccedb5f214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -146,6 +146,12 @@ trait CheckAnalysis { s"data type.") } + if (expr.isInstanceOf[AggregateExpression] || expr.isInstanceOf[AggregateFunction]) { + // Aggregate function in group by clause; this fails to execute + failAnalysis(s"aggregate expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } + if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bb82b562aaaa..863a214078f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -472,25 +472,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } - test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + @@ -2028,4 +2009,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } + test("SPARK-12063: Group by Columns Number") { + checkAnswer( + sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) + } + + test("SPARK-12063: Order by Column Number") { + Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("ord") + checkAnswer( + sql("SELECT v from ord order by 1 desc"), + Row(5) :: Row(3) :: Row(2) :: Row(1) :: Nil) + } + }