Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1744,10 +1744,10 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
exprs.exists(hasWindowFunction)

private def hasWindowFunction(expr: NamedExpression): Boolean = {
private def hasWindowFunction(expr: Expression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
Expand Down Expand Up @@ -1830,6 +1830,10 @@ class Analyzer(
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)

case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
failAnalysis("It is not allowed to use a window function inside an aggregate " +
"function. Please use the inner window function in a sub-query.")

// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.sql

import scala.util.Random

import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.scalatest.Matchers.the

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
Expand Down Expand Up @@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-21896: Window functions inside aggregate functions") {
def checkWindowError(df: => DataFrame): Unit = {
val thrownException = the [AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("not allowed to use a window function"))
}

checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a)))))
checkWindowError(
testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3))
checkAnswer(
testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3),
Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil)

checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2"))
checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a"))
checkWindowError(
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
checkAnswer(
sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)
}

}