Commit 14e99b4
[SPARK-44571][SQL] Eliminate the Join by combine multiple Aggregates
Parent: 4c74061
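
In short, the new CombineJoinedAggregates rule rewrites a join whose sides are grouping-less Aggregates over the same source into a single Aggregate whose functions carry the original predicates as aggregate filters, eliminating both the join and one scan. A minimal sketch of an eligible query shape (table `t` and its columns are invented for illustration, and the rule only fires when its feature flag is enabled):

    // spark-shell sketch: two global aggregates over the same table,
    // joined without a condition -- the shape the rule matches.
    import org.apache.spark.sql.functions.{avg, sum}

    val left  = spark.table("t").where("c1 = 1").agg(sum("a").as("s"))
    val right = spark.table("t").where("c2 = 2").agg(avg("b").as("m"))

    // Conceptually rewritten by the rule into a single aggregate:
    //   Aggregate [sum(a) FILTER (WHERE c1 = 1) AS s,
    //              avg(b) FILTER (WHERE c2 = 2) AS m]
    //   +- Filter ((c1 = 1) OR (c2 = 2))
    //      +- Scan t
    left.crossJoin(right).explain(true)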

14 files changed: +986, −22 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala (3 additions, 1 deletion)

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
-import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BLOOM_FILTER, OUTER_REFERENCE, TreePattern}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.sketch.BloomFilter

@@ -47,6 +47,8 @@ case class BloomFilterMightContain(
   override def right: Expression = valueExpression
   override def prettyName: String = "might_contain"

+  final override val nodePatterns: Seq[TreePattern] = Seq(BLOOM_FILTER)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     (left.dataType, right.dataType) match {
       case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
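
The nodePatterns overrides added in this file and the four below serve one purpose: the new rule's isCheapPredicate check can then reject expensive predicates via precomputed pattern bits instead of traversing every expression. (LIKE_FAMLIY is the identifier's actual spelling in TreePattern.) A minimal sketch of the mechanism, with the pattern set trimmed down for brevity:

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_CONTAINS, BLOOM_FILTER, STRING_PREDICATE}

    // containsAnyPattern consults bit sets cached on each tree node, so this
    // check stays cheap even for large predicate trees.
    def looksCheap(expr: Expression): Boolean =
      !expr.containsAnyPattern(BLOOM_FILTER, STRING_PREDICATE, ARRAY_CONTAINS)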

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala (5 additions, 1 deletion)

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_CONTAINS, ARRAYS_OVERLAP, ARRAYS_ZIP, CONCAT, TreePattern}
 import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._

@@ -1295,6 +1295,8 @@ case class ArrayContains(left: Expression, right: Expression)
   extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate
   with QueryErrorsBase {

+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_CONTAINS)
+
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(right.dataType)

@@ -1518,6 +1520,8 @@ case class ArrayAppend(left: Expression, right: Expression) extends ArrayPendBas
 case class ArraysOverlap(left: Expression, right: Expression)
   extends BinaryArrayExpressionWithImplicitCast with NullIntolerant with Predicate {

+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_OVERLAP)
+
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
     case TypeCheckResult.TypeCheckSuccess =>
       TypeUtils.checkForOrderingExpr(elementType, prettyName)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala (3 additions, 1 deletion)

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AT_LEAST_N_NON_NULLS, COALESCE, NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._

@@ -412,6 +412,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
   override def nullable: Boolean = false
   override def foldable: Boolean = children.forall(_.foldable)

+  final override val nodePatterns: Seq[TreePattern] = Seq(AT_LEAST_N_NON_NULLS)
+
   private[this] val childrenArray = children.toArray

   override def eval(input: InternalRow): Boolean = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala (2 additions, 0 deletions)

@@ -466,6 +466,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
   override def toString: String = s"RLIKE($left, $right)"
   override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}(${left.sql}, ${right.sql})"

+  final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala (3 additions, 1 deletion)

@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{STRING_PREDICATE, TreePattern, UPPER_OR_LOWER}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf

@@ -498,6 +498,8 @@ abstract class StringPredicate extends BinaryExpression

   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

+  final override val nodePatterns: Seq[TreePattern] = Seq(STRING_PREDICATE)
+
   protected override def nullSafeEval(input1: Any, input2: Any): Any =
     compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala (new file, 152 additions)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, ARRAY_CONTAINS, ARRAYS_OVERLAP, AT_LEAST_N_NON_NULLS, BLOOM_FILTER, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF, STRING_PREDICATE}

/**
 * This rule eliminates a [[Join]] whose sides are all [[Aggregate]]s by combining those
 * [[Aggregate]]s into one. It also supports nested [[Join]]s, as long as both sides of
 * every [[Join]] are [[Aggregate]]s.
 *
 * Note: this rule doesn't support the following cases:
 * 1. At least one of the [[Aggregate]]s to be merged does not have a predicate, or its
 *    predicate has low selectivity.
 * 2. The plan below the [[Aggregate]]s to be merged contains a [[Join]].
 */
object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {

  private def isSupportedJoinType(joinType: JoinType): Boolean =
    Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)

  private def isCheapPredicate(e: Expression): Boolean = {
    !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
      REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION,
      HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY, STRING_PREDICATE,
      AT_LEAST_N_NON_NULLS, BLOOM_FILTER, ARRAY_CONTAINS, ARRAYS_OVERLAP) &&
      // TreeNode.apply returns null when the tree has no node with this number,
      // i.e. the predicate tree has fewer than `maxTreeNodeNumOfPredicate` nodes.
      Option(e.apply(conf.maxTreeNodeNumOfPredicate)).isEmpty
  }

  /**
   * Try to merge two `Aggregate`s by traversing down recursively.
   *
   * @return The optional tuple as follows:
   *         1. the merged plan
   *         2. the attribute mapping from the old to the merged version
   *         3. optional filters of both plans that need to be propagated and merged in an
   *            ancestor `Aggregate` node if possible.
   */
  private def mergePlan(
      left: LogicalPlan,
      right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = {
    (left, right) match {
      case (la: Aggregate, ra: Aggregate) =>
        mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) =>
          val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap))

          val mergedAggregateExprs = if (filters.length == 2) {
            Seq(
              (la.aggregateExpressions, filters.head),
              (rightAggregateExprs, filters.last)
            ).flatMap { case (aggregateExpressions, propagatedFilter) =>
              aggregateExpressions.map { ne =>
                ne.transform {
                  case ae @ AggregateExpression(_, _, _, filterOpt, _) =>
                    val newFilter = filterOpt.map { filter =>
                      And(propagatedFilter, filter)
                    }.orElse(Some(propagatedFilter))
                    ae.copy(filter = newFilter)
                }.asInstanceOf[NamedExpression]
              }
            }
          } else {
            la.aggregateExpressions ++ rightAggregateExprs
          }

          (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty)
        }
      case (lp: Project, rp: Project) =>
        val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*)

        mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) =>
          val allFilterReferences = filters.flatMap(_.references)
          val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne =>
            val mapped = mapAttributes(ne, outputMap)

            val withoutAlias = mapped match {
              case Alias(child, _) => child
              case e => e
            }

            val outputAttr = mergedProjectList.find {
              case Alias(child, _) => child semanticEquals withoutAlias
              case e => e semanticEquals withoutAlias
            }.getOrElse {
              mergedProjectList += mapped
              mapped
            }.toAttribute
            ne.toAttribute -> outputAttr
          })

          (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters)
        }
      case (lf: Filter, rf: Filter)
          if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) =>
        mergePlan(lf.child, rf.child).map {
          case (newChild, outputMap, filters) =>
            val mappedRightCondition = mapAttributes(rf.condition, outputMap)
            val (newLeftCondition, newRightCondition) = if (filters.length == 2) {
              (And(lf.condition, filters.head), And(mappedRightCondition, filters.last))
            } else {
              (lf.condition, mappedRightCondition)
            }
            val newCondition = Or(newLeftCondition, newRightCondition)

            (Filter(newCondition, newChild), outputMap, Seq(newLeftCondition, newRightCondition))
        }
      case (ll: LeafNode, rl: LeafNode) =>
        checkIdenticalPlans(rl, ll).map { outputMap =>
          (ll, outputMap, Seq.empty)
        }
      case (ls: SerializeFromObject, rs: SerializeFromObject) =>
        checkIdenticalPlans(rs, ls).map { outputMap =>
          (ls, outputMap, Seq.empty)
        }
      case _ => None
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.combineJoinedAggregatesEnabled) return plan

    plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
      case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _)
          if isSupportedJoinType(joinType) &&
            left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty =>
        val mergedAggregate = mergePlan(left, right)
        mergedAggregate.map(_._1).getOrElse(j)
    }
  }
}
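
A hedged end-to-end sketch of the rule. Note the SQL conf key below is an assumption: the SQLConf entry backing `conf.combineJoinedAggregatesEnabled` is added in a part of this commit not shown here, so check SQLConf for the actual key and default:

    // spark-shell sketch; the conf key is assumed, not confirmed by this diff.
    spark.conf.set("spark.sql.optimizer.combineJoinedAggregates.enabled", "true")

    spark.range(0, 100)
      .selectExpr("id AS a", "id % 7 AS b", "id % 2 AS c1", "id % 3 AS c2")
      .createOrReplaceTempView("t")

    // Both derived tables are grouping-less aggregates over the same view,
    // joined without a join condition, so mergePlan can combine them.
    spark.sql(
      """SELECT * FROM
        |  (SELECT sum(a) AS s FROM t WHERE c1 = 1) l,
        |  (SELECT avg(b) AS m FROM t WHERE c2 = 2) r
        |""".stripMargin).explain(true)

If the rewrite applies, the optimized plan shows one Aggregate with FILTER-ed aggregate functions over a single scan of `t` instead of a Join of two Aggregates.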

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala (1 addition, 18 deletions)

@@ -101,7 +101,7 @@ import org.apache.spark.sql.types.DataType
  *    :  +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
  *    +- *(1) Scan OneRowRelation[]
  */
-object MergeScalarSubqueries extends Rule[LogicalPlan] {
+object MergeScalarSubqueries extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
   def apply(plan: LogicalPlan): LogicalPlan = {
     plan match {
       // Subquery reuse needs to be enabled for this optimization.

@@ -212,17 +212,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
     }
   }

-  // If 2 plans are identical return the attribute mapping from the new to the cached version.
-  private def checkIdenticalPlans(
-      newPlan: LogicalPlan,
-      cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
-    if (newPlan.canonicalized == cachedPlan.canonicalized) {
-      Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
-    } else {
-      None
-    }
-  }
-
   // Recursively traverse down and try merging 2 plans. If merge is possible then return the merged
   // plan with the attribute mapping from the new to the merged version.
   // Please note that merging arbitrary plans can be complicated, the current version supports only

@@ -314,12 +303,6 @@
       plan)
   }

-  private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
-    expr.transform {
-      case a: Attribute => outputMap.getOrElse(a, a)
-    }.asInstanceOf[T]
-  }
-
   // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into
   // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to
   // the merged version that can be propagated up during merging nodes.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala (new file, 43 additions)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
 * The helper trait used to merge scalar subqueries.
 */
trait MergeScalarSubqueriesHelper {

  // If two plans are identical, return the attribute mapping from the left to the right.
  protected def checkIdenticalPlans(
      left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = {
    if (left.canonicalized == right.canonicalized) {
      Some(AttributeMap(left.output.zip(right.output)))
    } else {
      None
    }
  }

  protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
    expr.transform {
      case a: Attribute => outputMap.getOrElse(a, a)
    }.asInstanceOf[T]
  }
}
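
The helper's behavior is easy to check in isolation: `canonicalized` normalizes expression IDs and other cosmetic differences, so two independently analyzed plans of the same query compare equal, and zipping their outputs yields exactly the mapping that `mapAttributes` later applies. A minimal spark-shell sketch (reusing the temp view `t` from the example above):

    val p1 = spark.table("t").where("c1 = 1").queryExecution.analyzed
    val p2 = spark.table("t").where("c1 = 1").queryExecution.analyzed

    // Equal after canonicalization, although the attribute instances
    // (and their expression IDs) differ between the two plans.
    assert(p1.canonicalized == p2.canonicalized)

    // The mapping checkIdenticalPlans would build: left attr -> right attr.
    val mapping = p1.output.zip(p2.output)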

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (1 addition, 0 deletions)

@@ -92,6 +92,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateOffsets,
       EliminateLimits,
       CombineUnions,
+      CombineJoinedAggregates,
       // Constant folding and strength reduction
       OptimizeRepartition,
       TransposeWindow,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala (1 addition, 0 deletions)

@@ -111,6 +111,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.ColumnPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineJoinedAggregates" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
