Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
ResolveNumericReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveSortReferences ::
Expand Down Expand Up @@ -178,6 +179,46 @@ class Analyzer(
}
}

/**
 * Replaces ordinal positions in GROUP BY and ORDER BY clauses with the select-list entry
 * they point at, i.e. rewrites "SELECT expr FROM A GROUP BY 1 ORDER BY 1"
 * into "SELECT expr FROM A GROUP BY expr ORDER BY expr".
 *
 * Ordinals are 1-based. Only `Int`-valued literals of an integral type are treated as
 * ordinals; the value is read by pattern matching rather than by parsing the literal's
 * string form. A literal that cannot be mapped to a select-list entry (position out of
 * range, or the plan is not ready yet) is left in place unchanged instead of failing with
 * an index, cast or match error.
 */
object ResolveNumericReferences extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case Aggregate(groups, aggs, child) =>
      // Ordinals count positions in the expanded select list. While any `*` is still
      // unexpanded the positions are not final, so leave the ordinals for a later pass.
      val hasUnexpandedStar = aggs.exists {
        case UnresolvedStar(_) => true
        case u: UnresolvedAlias => u.child.isInstanceOf[UnresolvedStar]
        case _ => false
      }
      val newGroups = groups.map {
        case group @ Literal(index: Int, _: IntegralType) if !hasUnexpandedStar =>
          // `lift` yields None for 0, negative or too-large positions.
          val target = aggs.lift(index - 1).map {
            // Group by the aliased expression itself, not by the alias wrapper.
            case u: UnresolvedAlias => u.child
            case a: Alias => a.child
            case other => other
          }
          target match {
            // Never swap one literal for another: the replacement would look like an
            // ordinal again the next time this rule runs, and grouping by any constant
            // is equivalent anyway.
            case Some(_: Literal) | None => group
            case Some(expr) => expr
          }
        case group => group
      }
      Aggregate(newGroups, aggs, child)

    // A Sort's child is not necessarily a Project (it may be an Aggregate, Filter, ...), so
    // the ordinal is looked up in the child's output attributes. Those are only requested
    // once the child is resolved.
    case Sort(ordering, global, child) if child.resolved =>
      val newOrdering = ordering.map { o =>
        o.child match {
          case Literal(index: Int, _: IntegralType) =>
            child.output.lift(index - 1).map(attr => o.copy(child = attr)).getOrElse(o)
          case _ => o
        }
      }
      Sort(newOrdering, global, child)
  }
}

object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
/*
* GROUP BY a, b, c WITH ROLLUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ trait CheckAnalysis {
s"data type.")
}

if (expr.isInstanceOf[AggregateExpression] || expr.isInstanceOf[AggregateFunction]) {
// Aggregate function in group by clause; this fails to execute
failAnalysis(s"aggregate expression ${expr.prettyString} should not " +
s"appear in grouping expression.")
}

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
Expand Down
32 changes: 13 additions & 19 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,25 +472,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}

test("literal in agg grouping expressions") {
  // A constant in the GROUP BY list must not change the per-`a` row counts.
  val countsPerA = Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
  checkAnswer(sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), countsPerA)
  checkAnswer(sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), countsPerA)

  // Grouping by an extra constant (plain or computed) matches grouping without it.
  checkAnswer(
    sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
    sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
  checkAnswer(
    sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
    sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))

  // Grouping only by constants matches a global aggregate.
  checkAnswer(
    sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
    sql("SELECT 1, 2, sum(b) FROM testData2"))
}

test("aggregates with nulls") {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
Expand Down Expand Up @@ -2028,4 +2009,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(false) :: Row(true) :: Nil)
}

test("SPARK-12063: Group by Columns Number") {
  // "GROUP BY 1" refers to the first select-list entry, i.e. column `a`.
  val expected = Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
  checkAnswer(sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"), expected)
}

test("SPARK-12063: Order by Column Number") {
  val rows = Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5))
  rows.toDF("k", "v").registerTempTable("ord")

  // "order by 1" refers to the first select-list entry, i.e. column `v`.
  val expected = Seq(Row(5), Row(3), Row(2), Row(1))
  checkAnswer(sql("SELECT v from ord order by 1 desc"), expected)
}

}