|  | 
|  | 1 | +/* | 
|  | 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more | 
|  | 3 | + * contributor license agreements.  See the NOTICE file distributed with | 
|  | 4 | + * this work for additional information regarding copyright ownership. | 
|  | 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 | 
|  | 6 | + * (the "License"); you may not use this file except in compliance with | 
|  | 7 | + * the License.  You may obtain a copy of the License at | 
|  | 8 | + * | 
|  | 9 | + *    http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 10 | + * | 
|  | 11 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 12 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 14 | + * See the License for the specific language governing permissions and | 
|  | 15 | + * limitations under the License. | 
|  | 16 | + */ | 
|  | 17 | + | 
|  | 18 | +package org.apache.spark.sql.catalyst.optimizer | 
|  | 19 | + | 
|  | 20 | +import scala.collection.mutable.ArrayBuffer | 
|  | 21 | + | 
|  | 22 | +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or} | 
|  | 23 | +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | 
|  | 24 | +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter} | 
|  | 25 | +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject} | 
|  | 26 | +import org.apache.spark.sql.catalyst.rules.Rule | 
|  | 27 | +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, ARRAY_CONTAINS, ARRAYS_OVERLAP, AT_LEAST_N_NON_NULLS, BLOOM_FILTER, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF, STRING_PREDICATE} | 
|  | 28 | + | 
|  | 29 | +/** | 
|  | 30 | + * This rule eliminates the [[Join]] if all the join side are [[Aggregate]]s by combine these | 
|  | 31 | + * [[Aggregate]]s. This rule also support the nested [[Join]], as long as all the join sides for | 
|  | 32 | + * every [[Join]] are [[Aggregate]]s. | 
|  | 33 | + * | 
|  | 34 | + * Note: this rule doesn't support following cases: | 
|  | 35 | + * 1. The [[Aggregate]]s to be merged if at least one of them does not have a predicate or | 
|  | 36 | + *    has low predicate selectivity. | 
|  | 37 | + * 2. The upstream node of these [[Aggregate]]s to be merged exists [[Join]]. | 
|  | 38 | + */ | 
|  | 39 | +object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper { | 
|  | 40 | + | 
|  | 41 | +  private def isSupportedJoinType(joinType: JoinType): Boolean = | 
|  | 42 | +    Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType) | 
|  | 43 | + | 
|  | 44 | +  private def isCheapPredicate(e: Expression): Boolean = { | 
|  | 45 | +    !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, | 
|  | 46 | +      REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION, | 
|  | 47 | +      HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY, STRING_PREDICATE, | 
|  | 48 | +      AT_LEAST_N_NON_NULLS, BLOOM_FILTER, ARRAY_CONTAINS, ARRAYS_OVERLAP) && | 
|  | 49 | +      Option(e.apply(conf.maxTreeNodeNumOfPredicate)).isEmpty | 
|  | 50 | +  } | 
|  | 51 | + | 
|  | 52 | +  /** | 
|  | 53 | +   * Try to merge two `Aggregate`s by traverse down recursively. | 
|  | 54 | +   * | 
|  | 55 | +   * @return The optional tuple as follows: | 
|  | 56 | +   *         1. the merged plan | 
|  | 57 | +   *         2. the attribute mapping from the old to the merged version | 
|  | 58 | +   *         3. optional filters of both plans that need to be propagated and merged in an | 
|  | 59 | +   *         ancestor `Aggregate` node if possible. | 
|  | 60 | +   */ | 
|  | 61 | +  private def mergePlan( | 
|  | 62 | +      left: LogicalPlan, | 
|  | 63 | +      right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = { | 
|  | 64 | +    (left, right) match { | 
|  | 65 | +      case (la: Aggregate, ra: Aggregate) => | 
|  | 66 | +        mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) => | 
|  | 67 | +          val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap)) | 
|  | 68 | + | 
|  | 69 | +          val mergedAggregateExprs = if (filters.length == 2) { | 
|  | 70 | +            Seq( | 
|  | 71 | +              (la.aggregateExpressions, filters.head), | 
|  | 72 | +              (rightAggregateExprs, filters.last) | 
|  | 73 | +            ).flatMap { case (aggregateExpressions, propagatedFilter) => | 
|  | 74 | +              aggregateExpressions.map { ne => | 
|  | 75 | +                ne.transform { | 
|  | 76 | +                  case ae @ AggregateExpression(_, _, _, filterOpt, _) => | 
|  | 77 | +                    val newFilter = filterOpt.map { filter => | 
|  | 78 | +                      And(propagatedFilter, filter) | 
|  | 79 | +                    }.orElse(Some(propagatedFilter)) | 
|  | 80 | +                    ae.copy(filter = newFilter) | 
|  | 81 | +                }.asInstanceOf[NamedExpression] | 
|  | 82 | +              } | 
|  | 83 | +            } | 
|  | 84 | +          } else { | 
|  | 85 | +            la.aggregateExpressions ++ rightAggregateExprs | 
|  | 86 | +          } | 
|  | 87 | + | 
|  | 88 | +          (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty) | 
|  | 89 | +        } | 
|  | 90 | +      case (lp: Project, rp: Project) => | 
|  | 91 | +        val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*) | 
|  | 92 | + | 
|  | 93 | +        mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) => | 
|  | 94 | +          val allFilterReferences = filters.flatMap(_.references) | 
|  | 95 | +          val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne => | 
|  | 96 | +            val mapped = mapAttributes(ne, outputMap) | 
|  | 97 | + | 
|  | 98 | +            val withoutAlias = mapped match { | 
|  | 99 | +              case Alias(child, _) => child | 
|  | 100 | +              case e => e | 
|  | 101 | +            } | 
|  | 102 | + | 
|  | 103 | +            val outputAttr = mergedProjectList.find { | 
|  | 104 | +              case Alias(child, _) => child semanticEquals withoutAlias | 
|  | 105 | +              case e => e semanticEquals withoutAlias | 
|  | 106 | +            }.getOrElse { | 
|  | 107 | +              mergedProjectList += mapped | 
|  | 108 | +              mapped | 
|  | 109 | +            }.toAttribute | 
|  | 110 | +            ne.toAttribute -> outputAttr | 
|  | 111 | +          }) | 
|  | 112 | + | 
|  | 113 | +          (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters) | 
|  | 114 | +        } | 
|  | 115 | +      case (lf: Filter, rf: Filter) | 
|  | 116 | +        if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) => | 
|  | 117 | +        mergePlan(lf.child, rf.child).map { | 
|  | 118 | +          case (newChild, outputMap, filters) => | 
|  | 119 | +            val mappedRightCondition = mapAttributes(rf.condition, outputMap) | 
|  | 120 | +            val (newLeftCondition, newRightCondition) = if (filters.length == 2) { | 
|  | 121 | +              (And(lf.condition, filters.head), And(mappedRightCondition, filters.last)) | 
|  | 122 | +            } else { | 
|  | 123 | +              (lf.condition, mappedRightCondition) | 
|  | 124 | +            } | 
|  | 125 | +          val newCondition = Or(newLeftCondition, newRightCondition) | 
|  | 126 | + | 
|  | 127 | +          (Filter(newCondition, newChild), outputMap, Seq(newLeftCondition, newRightCondition)) | 
|  | 128 | +        } | 
|  | 129 | +      case (ll: LeafNode, rl: LeafNode) => | 
|  | 130 | +        checkIdenticalPlans(rl, ll).map { outputMap => | 
|  | 131 | +          (ll, outputMap, Seq.empty) | 
|  | 132 | +        } | 
|  | 133 | +      case (ls: SerializeFromObject, rs: SerializeFromObject) => | 
|  | 134 | +        checkIdenticalPlans(rs, ls).map { outputMap => | 
|  | 135 | +          (ls, outputMap, Seq.empty) | 
|  | 136 | +        } | 
|  | 137 | +      case _ => None | 
|  | 138 | +    } | 
|  | 139 | +  } | 
|  | 140 | + | 
|  | 141 | +  def apply(plan: LogicalPlan): LogicalPlan = { | 
|  | 142 | +    if (!conf.combineJoinedAggregatesEnabled) return plan | 
|  | 143 | + | 
|  | 144 | +    plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) { | 
|  | 145 | +      case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _) | 
|  | 146 | +        if isSupportedJoinType(joinType) && | 
|  | 147 | +          left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty => | 
|  | 148 | +        val mergedAggregate = mergePlan(left, right) | 
|  | 149 | +        mergedAggregate.map(_._1).getOrElse(j) | 
|  | 150 | +    } | 
|  | 151 | +  } | 
|  | 152 | +} | 
0 commit comments