Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -61,6 +63,7 @@ class Analyzer(
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
Expand Down Expand Up @@ -529,6 +532,203 @@ class Analyzer(
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}

/**
* Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
* aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
* operators for every distinct [[WindowSpecDefinition]].
*
* This rule handles three cases:
* - A [[Project]] having [[WindowExpression]]s in its projectList;
* - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
* - An [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
* clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
* Note: If there is a GROUP BY clause in the query, aggregations and corresponding
* filters (expressions in the HAVING clause) should be evaluated before any
* [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
* evaluated after all [[WindowExpression]]s.
*
* For every case, the transformation works as follows:
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
* it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
* all regular expressions.
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
* 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private def, below the same

projectList.exists(hasWindowFunction)

def hasWindowFunction(expr: NamedExpression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
}.isDefined
}

/**
* From a Seq of [[NamedExpression]]s, extract window expressions and
* other regular expressions.
*/
def extract(
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
// First, we simple partition the input expressions to two part, one having
// WindowExpressions and another one without WindowExpressions.
val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)

// Then, we need to extract those regular expressions used in the WindowExpression.
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
// we need to make sure that col1 to col5 are all projected from the child of the Window
// operator.
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
// If a named expression is not in regularExpressions, add extract it and replace it
// with an AttributeReference.
val missingExpr =
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
if (missingExpr.nonEmpty) {
extractedExprBuffer += ne
}
ne.toAttribute
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: Expression =>
// For other expressions, we extract it and replace it with an AttributeReference (with
// an interal column name, e.g. "_w0").
val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
}

// Now, we extract expressions from windowExpressions by using extractExpr.
val newWindowExpressions = windowExpressions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
case wf : WindowFunction =>
val newChildren = wf.children.map(extractExpr(_))
wf.withNewChildren(newChildren)

// Extracts expressions from the partition spec and order spec.
case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
val newPartitionSpec = partitionSpec.map(extractExpr(_))
val newOrderSpec = orderSpec.map { so =>
val newChild = extractExpr(so.child)
so.copy(child = newChild)
}
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression =>
val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
}.asInstanceOf[NamedExpression]
}

(newWindowExpressions, regularExpressions ++ extractedExprBuffer)
}

/**
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
*/
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
// First, we group window expressions based on their Window Spec.
val groupedWindowExpression = windowExpressions.groupBy { expr =>
val windowExpression = expr.find {
case window: WindowExpression => true
case other => false
}.map(_.asInstanceOf[WindowExpression].windowSpec)
windowExpression.getOrElse(
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
}.toSeq

// For every Window Spec, we add a Window operator and set currentChild as the child of it.
var currentChild = child
var i = 0
while (i < groupedWindowExpression.size) {
val (windowSpec, windowExpressions) = groupedWindowExpression(i)
// Set currentChild to the newly created Window operator.
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)

// Move to next WindowExpression.
i += 1
}

// We return the top operator.
currentChild
}

// We have to use transformDown at here to make sure the rule of
// "Aggregate with Having clause" will be triggered.
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
case plan => plan.transformExpressions {
case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
val errorMessage =
s"Window specification $windowName is not defined in the WINDOW clause."
val windowSpecDefinition =
windowDefinitions
.get(windowName)
.getOrElse(failAnalysis(errorMessage))
WindowExpression(c, windowSpecDefinition)
}
}

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about cube/rollup, we probably should consider them here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will turn them to Aggregate in the analysis.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right

if child.resolved &&
hasWindowFunction(aggregateExprs) &&
!a.expressions.exists(!_.resolved) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: a.expressions.forall(_.resolved) more readable

val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add a Filter operator for conditions in the Having clause.
val withFilter = Filter(condition, withAggregate)
val withWindow = addWindow(windowExpressions, withFilter)

// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map (_.toAttribute)
Project(finalProjectList, withWindow)

case p: LogicalPlan if !p.childrenResolved => p

// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
!a.expressions.exists(!_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withAggregate)

// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map (_.toAttribute)
Project(finalProjectList, withWindow)

// We only extract Window Expressions after all expressions of the Project
// have been resolved.
case p @ Project(projectList, child)
if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
val (windowExpressions, regularExpressions) = extract(projectList)
// We add a project to get all needed expressions for window expressions from the child
// of the original Project operator.
val withProject = Project(regularExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withProject)

// Finally, generate output columns according to the original projectList.
val finalProjectList = projectList.map (_.toAttribute)
Project(finalProjectList, withWindow)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ trait CheckAnalysis {
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.simpleString} and ${b.right.simpleString}")

case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty =>
// The window spec is not valid.
val reason = windowSpec.validate.get
failAnalysis(s"Window specification $windowSpec is not valid because $reason")
}

operator match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,97 @@ class JoinedRow5 extends Row {
}
}
}

/**
* JIT HACK: Replace with macros
*/
class JoinedRow6 extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _

def this(left: Row, right: Row) = {
this()
row1 = left
row2 = right
}

/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
row2 = r2
this
}

/** Updates this JoinedRow by updating its left base row. Returns itself. */
def withLeft(newLeft: Row): Row = {
row1 = newLeft
this
}

/** Updates this JoinedRow by updating its right base row. Returns itself. */
def withRight(newRight: Row): Row = {
row2 = newRight
this
}

override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq

override def length: Int = row1.length + row2.length

override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)

override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)

override def getInt(i: Int): Int =
if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)

override def getLong(i: Int): Long =
if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)

override def getDouble(i: Int): Double =
if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)

override def getBoolean(i: Int): Boolean =
if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)

override def getShort(i: Int): Short =
if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)

override def getByte(i: Int): Byte =
if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)

override def getFloat(i: Int): Float =
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)

override def getString(i: Int): String =
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)

override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)

override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
}

override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
} else if (row1 eq null) {
row2.mkString("[", ",", "]")
} else if (row2 eq null) {
row1.mkString("[", ",", "]")
} else {
mkString("[", ",", "]")
}
}
}
Loading