|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.analysis |
19 | 19 |
|
| 20 | +import scala.collection.mutable.ArrayBuffer |
| 21 | + |
20 | 22 | import org.apache.spark.util.collection.OpenHashSet |
21 | 23 | import org.apache.spark.sql.AnalysisException |
22 | 24 | import org.apache.spark.sql.catalyst.expressions._ |
@@ -61,6 +63,7 @@ class Analyzer( |
61 | 63 | ResolveGenerate :: |
62 | 64 | ImplicitGenerate :: |
63 | 65 | ResolveFunctions :: |
| 66 | + ExtractWindowExpressions :: |
64 | 67 | GlobalAggregates :: |
65 | 68 | UnresolvedHavingClauseAttributes :: |
66 | 69 | TrimGroupingAliases :: |
@@ -529,6 +532,203 @@ class Analyzer( |
529 | 532 | makeGeneratorOutput(p.generator, p.generatorOutput), p.child) |
530 | 533 | } |
531 | 534 | } |
| 535 | + |
/**
 * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
 * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
 * operators for every distinct [[WindowSpecDefinition]].
 *
 * This rule handles three cases:
 *  - A [[Project]] having [[WindowExpression]]s in its projectList;
 *  - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions;
 *  - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
 *    clause, where the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
 *
 * Note: If there is a GROUP BY clause in the query, aggregations and corresponding
 * filters (expressions in the HAVING clause) should be evaluated before any
 * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
 * evaluated after all [[WindowExpression]]s.
 *
 * For every case, the transformation works as follows:
 *  1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
 *     it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
 *     all regular expressions.
 *  2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
 *  3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
 *     it into the plan tree.
 */
object ExtractWindowExpressions extends Rule[LogicalPlan] {
  /** Returns true if at least one expression of `projectList` contains a [[WindowExpression]]. */
  def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
    projectList.exists(hasWindowFunction)

  /** Returns true if a [[WindowExpression]] is found in the tree of `expr`. */
  def hasWindowFunction(expr: NamedExpression): Boolean = {
    expr.find {
      case _: WindowExpression => true
      case _ => false
    }.isDefined
  }

  /**
   * From a Seq of [[NamedExpression]]s, extracts window expressions and
   * other regular expressions.
   *
   * @param expressions a projectList or aggregateExpressions to split.
   * @return a pair whose first element holds the expressions containing
   *         [[WindowExpression]]s, rewritten to refer to their inputs through attributes, and
   *         whose second element holds the regular expressions followed by every expression
   *         that had to be extracted from the window expressions.
   */
  def extract(
      expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
    // First, we simply partition the input expressions into two parts, one having
    // WindowExpressions and another one without WindowExpressions.
    val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)

    // Then, we need to extract those regular expressions used in the WindowExpression.
    // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
    // we need to make sure that col1 to col5 are all projected from the child of the Window
    // operator.
    val extractedExprBuffer = new ArrayBuffer[NamedExpression]()

    // Replaces `expr` with something the Window operator can read from its child. The buffer is
    // appended to as a side effect, and its current length is used to number internal aliases,
    // so the order of calls determines the generated "_wN" names.
    def extractExpr(expr: Expression): Expression = expr match {
      case ne: NamedExpression =>
        // If a named expression is not in regularExpressions, extract it and replace it
        // with an AttributeReference.
        val missingExpr =
          AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
        if (missingExpr.nonEmpty) {
          extractedExprBuffer += ne
        }
        ne.toAttribute
      case e: Expression if e.foldable =>
        e // No need to create an attribute reference if it will be evaluated as a Literal.
      case e: Expression =>
        // For other expressions, we extract it and replace it with an AttributeReference (with
        // an internal column name, e.g. "_w0").
        val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
        extractedExprBuffer += withName
        withName.toAttribute
    }

    // Now, we extract expressions from windowExpressions by using extractExpr.
    val newWindowExpressions = windowExpressions.map { windowExpr =>
      windowExpr.transform {
        // Extracts children expressions of a WindowFunction (input parameters of
        // a WindowFunction).
        case wf: WindowFunction =>
          val newChildren = wf.children.map(extractExpr(_))
          wf.withNewChildren(newChildren)

        // Extracts expressions from the partition spec and order spec.
        case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
          val newPartitionSpec = partitionSpec.map(extractExpr(_))
          val newOrderSpec = orderSpec.map { so =>
            so.copy(child = extractExpr(so.child))
          }
          wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

        // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
        // we need to extract SUM(x).
        case agg: AggregateExpression =>
          val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
          extractedExprBuffer += withName
          withName.toAttribute
      }.asInstanceOf[NamedExpression]
    }

    (newWindowExpressions, regularExpressions ++ extractedExprBuffer)
  }

  /**
   * Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
   *
   * @param windowExpressions expressions that each contain a [[WindowExpression]].
   * @param child the plan on top of which the [[Window]] operators are stacked.
   * @return the topmost [[Window]] operator, or `child` itself when `windowExpressions`
   *         is empty.
   */
  def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
    // First, we group window expressions based on their Window Spec.
    val groupedWindowExpressions = windowExpressions.groupBy { expr =>
      val windowSpec = expr.find {
        case _: WindowExpression => true
        case _ => false
      }.map(_.asInstanceOf[WindowExpression].windowSpec)
      // Report the single offending expression rather than the whole input list.
      windowSpec.getOrElse(
        failAnalysis(s"$expr does not have any WindowExpression."))
    }.toSeq

    // For every Window Spec, we add a Window operator on top of the plan built so far and
    // return the top operator.
    groupedWindowExpressions.foldLeft(child) {
      case (currentChild, (windowSpec, expressionsForSpec)) =>
        Window(currentChild.output, expressionsForSpec, windowSpec, currentChild)
    }
  }

  // We have to use transformDown at here to make sure the rule of
  // "Aggregate with Having clause" will be triggered.
  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // Lookup WindowSpecDefinitions. This rule works with unresolved children.
    case WithWindowDefinition(windowDefinitions, child) =>
      child.transform {
        case operator => operator.transformExpressions {
          case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
            val errorMessage =
              s"Window specification $windowName is not defined in the WINDOW clause."
            val windowSpecDefinition =
              windowDefinitions
                .get(windowName)
                .getOrElse(failAnalysis(errorMessage))
            WindowExpression(c, windowSpecDefinition)
        }
      }

    // Aggregate with Having clause. This rule works with an unresolved Aggregate because
    // a resolved Aggregate will not have Window Functions.
    case Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
      if child.resolved &&
         hasWindowFunction(aggregateExprs) &&
         a.expressions.forall(_.resolved) =>
      val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
      // Create an Aggregate operator to evaluate aggregation functions.
      val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
      // Add a Filter operator for conditions in the Having clause.
      val withFilter = Filter(condition, withAggregate)
      val withWindow = addWindow(windowExpressions, withFilter)

      // Finally, generate output columns according to the original aggregateExpressions.
      val finalProjectList = aggregateExprs.map(_.toAttribute)
      Project(finalProjectList, withWindow)

    // Every case below requires resolved children, so leave other operators untouched for now.
    case p: LogicalPlan if !p.childrenResolved => p

    // Aggregate without Having clause.
    case a @ Aggregate(groupingExprs, aggregateExprs, child)
      if hasWindowFunction(aggregateExprs) &&
         a.expressions.forall(_.resolved) =>
      val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
      // Create an Aggregate operator to evaluate aggregation functions.
      val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
      // Add Window operators.
      val withWindow = addWindow(windowExpressions, withAggregate)

      // Finally, generate output columns according to the original aggregateExpressions.
      val finalProjectList = aggregateExprs.map(_.toAttribute)
      Project(finalProjectList, withWindow)

    // We only extract Window Expressions after all expressions of the Project
    // have been resolved.
    case p @ Project(projectList, child)
      if hasWindowFunction(projectList) && p.expressions.forall(_.resolved) =>
      val (windowExpressions, regularExpressions) = extract(projectList)
      // We add a project to get all needed expressions for window expressions from the child
      // of the original Project operator.
      val withProject = Project(regularExpressions, child)
      // Add Window operators.
      val withWindow = addWindow(windowExpressions, withProject)

      // Finally, generate output columns according to the original projectList.
      val finalProjectList = projectList.map(_.toAttribute)
      Project(finalProjectList, withWindow)
  }
}
532 | 732 | } |
533 | 733 |
|
534 | 734 | /** |
|
0 commit comments