
Commit 905eaa1

improve structure and styling.
1 parent 909d2cd commit 905eaa1

3 files changed: 61 additions, 86 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 42 additions & 42 deletions
@@ -68,36 +68,34 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     case _ => Seq.empty[Attribute]
   }
 
+  // Collect aliases from expressions, so we may avoid producing recursive constraints.
+  private lazy val aliasMap = AttributeMap(
+    (expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => (a.toAttribute, a.child)
+    })
+
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
-   * additional constraint of the form `b = 5`
+   * additional constraint of the form `b = 5`.
+   *
+   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
+   * as they are often useless and can lead to a non-converging set of constraints.
    */
   private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
-    // Collect alias from expressions to avoid producing non-converging set of constraints
-    // for recursive functions.
-    //
-    // Don't apply transform on constraints if the replacement will cause an recursive deduction,
-    // when that happens a non-converging set of constraints will be created and finally throw
-    // OOM Exception.
-    // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
-    val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
-      case a: Alias => (a.toAttribute, a.child)
-    })
-
-    val equalExprSets = generateEqualExpressionSets(constraints, aliasMap)
+    val constraintClasses = generateEquivalentConstraintClasses(constraints)
 
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         val candidateConstraints = constraints - eq
         inferredConstraints ++= candidateConstraints.map(_ transform {
           case a: Attribute if a.semanticEquals(l) &&
-            !isRecursiveDeduction(r, aliasMap, equalExprSets) => r
+            !isRecursiveDeduction(r, constraintClasses) => r
         })
         inferredConstraints ++= candidateConstraints.map(_ transform {
           case a: Attribute if a.semanticEquals(r) &&
-            !isRecursiveDeduction(l, aliasMap, equalExprSets) => l
+            !isRecursiveDeduction(l, constraintClasses) => l
         })
       case _ => // No inference
     }
@@ -110,58 +108,60 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
    * to an selected attribute.
    */
-  private def generateEqualExpressionSets(
-      constraints: Set[Expression],
-      aliasMap: AttributeMap[Expression]): Seq[Set[Expression]] = {
-    var equalExprSets = Seq.empty[Set[Expression]]
+  private def generateEquivalentConstraintClasses(
+      constraints: Set[Expression]): Seq[Set[Expression]] = {
+    var constraintClasses = Seq.empty[Set[Expression]]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         // Transform [[Alias]] to its child.
         val left = aliasMap.getOrElse(l, l)
         val right = aliasMap.getOrElse(r, r)
-        // Get the expression set for equivalence class of expressions.
-        val leftEqualSet = getEqualExprSet(left, equalExprSets).getOrElse(Set.empty[Expression])
-        val rightEqualSet = getEqualExprSet(right, equalExprSets).getOrElse(Set.empty[Expression])
-        if (!leftEqualSet.isEmpty && !rightEqualSet.isEmpty) {
+        // Get the expression set for an equivalence constraint class.
+        val leftConstraintClass = getConstraintClass(left, constraintClasses)
+        val rightConstraintClass = getConstraintClass(right, constraintClasses)
+        if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
           // Combine the two sets.
-          equalExprSets = equalExprSets.diff(leftEqualSet :: rightEqualSet :: Nil) :+
-            (leftEqualSet ++ rightEqualSet)
-        } else if (!leftEqualSet.isEmpty) { // && rightEqualSet.isEmpty
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
+            (leftConstraintClass ++ rightConstraintClass)
+        } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
           // Update equivalence class of `left` expression.
-          equalExprSets = equalExprSets.diff(leftEqualSet :: Nil) :+ (leftEqualSet + right)
-        } else if (!rightEqualSet.isEmpty) { // && leftEqualSet.isEmpty
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
+        } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
           // Update equivalence class of `right` expression.
-          equalExprSets = equalExprSets.diff(rightEqualSet :: Nil) :+ (rightEqualSet + left)
-        } else { // leftEqualSet.isEmpty && rightEqualSet.isEmpty
-          // Create new equivalence class since both expression don't present in any classes.
-          equalExprSets = equalExprSets :+ Set(left, right)
+          constraintClasses = constraintClasses
+            .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
+        } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
+          // Create new equivalence constraint class since neither expression presents
+          // in any classes.
+          constraintClasses = constraintClasses :+ Set(left, right)
         }
       case _ => // Skip
     }
 
-    equalExprSets
+    constraintClasses
   }
 
   /*
    * Get all expressions equivalent to the selected expression.
    */
-  private def getEqualExprSet(
+  private def getConstraintClass(
       expr: Expression,
-      equalExprSets: Seq[Set[Expression]]): Option[Set[Expression]] =
-    equalExprSets.filter(_.contains(expr)).headOption
+      constraintClasses: Seq[Set[Expression]]): Set[Expression] =
+    constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
 
   /*
-   * Check whether replace by an [[Attribute]] will cause an recursive deduction. Generally it
-   * has an form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is an function.
+   * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
+   * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
    * Here we first get all expressions equal to `attr` and then check whether at least one of them
-   * is child of the referenced expression.
+   * is a child of the referenced expression.
    */
   private def isRecursiveDeduction(
       attr: Attribute,
-      aliasMap: AttributeMap[Expression],
-      equalExprSets: Seq[Set[Expression]]): Boolean = {
+      constraintClasses: Seq[Set[Expression]]): Boolean = {
     val expr = aliasMap.getOrElse(attr, attr)
-    getEqualExprSet(expr, equalExprSets).getOrElse(Set.empty[Expression]).exists { e =>
+    getConstraintClass(expr, constraintClasses).exists { e =>
       expr.children.exists(_.semanticEquals(e))
     }
   }
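
For readers skimming the diff, here is a minimal standalone sketch of the idea the refactored methods implement: group attributes related by equality constraints into equivalence classes (resolving aliases first), then refuse substitutions that would re-introduce a class member inside its own definition, i.e. a recursive constraint like `a = f(a, b)`. The `Expr`, `buildClasses`, and `isRecursive` names below are hypothetical illustrations, not Spark's API.

// Standalone sketch (hypothetical types/helpers, not Spark code).
object ConstraintClassesSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Func(name: String, children: Seq[Expr]) extends Expr
  case class EqualTo(left: Expr, right: Expr)

  // `a = b` and `b = c` end up in a single class: Set(a, b, c).
  def buildClasses(constraints: Seq[EqualTo], aliases: Map[Expr, Expr]): Seq[Set[Expr]] =
    constraints.foldLeft(Seq.empty[Set[Expr]]) { case (classes, EqualTo(l, r)) =>
      val left = aliases.getOrElse(l, l)
      val right = aliases.getOrElse(r, r)
      // Merge every existing class that touches either side with the new pair.
      val (touching, rest) = classes.partition(c => c.contains(left) || c.contains(right))
      rest :+ touching.fold(Set(left, right))(_ ++ _)
    }

  // Substituting via `attr` is recursive if the expression it aliases has a
  // child that already belongs to the same equivalence class.
  def isRecursive(attr: Expr, aliases: Map[Expr, Expr], classes: Seq[Set[Expr]]): Boolean = {
    val expr = aliases.getOrElse(attr, attr)
    val cls = classes.find(_.contains(expr)).getOrElse(Set.empty[Expr])
    expr match {
      case Func(_, children) => children.exists(cls.contains)
      case _ => false
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Attr("a"); val b = Attr("b"); val intCol = Attr("int_col")
    val aliases = Map[Expr, Expr](intCol -> Func("coalesce", Seq(a, b)))
    val classes = buildClasses(Seq(EqualTo(a, intCol), EqualTo(intCol, b)), aliases)
    // classes == Seq(Set(a, coalesce(a, b), b))
    println(isRecursive(intCol, aliases, classes)) // true: coalesce(a, b) contains `a`
    println(isRecursive(a, aliases, classes))      // false: `a` has no children
  }
}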

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 19 additions & 17 deletions
@@ -131,14 +131,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
       .analyze
-    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b)))
-      &&'a === Coalesce(Seq('a, 'b)))
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
       .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2.where(IsNotNull('a)), Inner,
         Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
-    comparePlans(optimized, currectAnswer)
+    comparePlans(optimized, correctAnswer)
   }
 
   test("inner join with alias: alias contains single attributes") {
@@ -148,14 +148,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val originalQuery = t1.select('a, 'b.as('d)).as("t")
       .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
       .analyze
-    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull('b)
-      && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
       .select('a, 'b.as('d)).as("t")
       .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
         Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
-    comparePlans(optimized, currectAnswer)
+    comparePlans(optimized, correctAnswer)
   }
 
   test("inner join with alias: don't generate constraints for recursive functions") {
@@ -168,18 +168,20 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         && "t.d".attr === "t2.a".attr
         && "t.int_col".attr === "t2.a".attr))
       .analyze
-    val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
-      && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
-      && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
-      && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
-      && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
-      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
-      && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+        && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
+        && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
+        && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
+        && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
+        && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
+        && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
-      .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
-        && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
+      .join(t2
+        .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+          && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+          && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
           && "t.int_col".attr === "t2.a".attr

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 0 additions & 27 deletions
@@ -2679,31 +2679,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
-
-  test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
-    withTempView("tmpv") {
-      spark.range(10).toDF("a").createTempView("tmpv")
-
-      // Just ensure the following query will successfully execute complete.
-      val query =
-        """
-          |SELECT
-          |  *
-          |FROM (
-          |  SELECT
-          |    COALESCE(t1.a, t2.a) AS int_col,
-          |    t1.a,
-          |    t2.a AS b
-          |  FROM tmpv t1
-          |  CROSS JOIN tmpv t2
-          |) t1
-          |INNER JOIN tmpv t2
-          |ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
-        """.stripMargin
-
-      eventually(timeout(60 seconds)) {
-        assert(sql(query).count() > 0)
-      }
-    }
-  }
 }
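
The regression test removed above exercised the SPARK-17733 query end to end. As a rough, standalone illustration of why unconstrained substitution diverges (hypothetical `Expr`/`subst` helpers, not Spark code), treating an inferred constraint like `a = coalesce(a, b)` as a rewrite rule grows the expression on every pass, so the constraint set never reaches a fixed point.

object DivergenceSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Func(name: String, children: Seq[Expr]) extends Expr

  // Pretty-print an expression tree, e.g. coalesce(a, b).
  def show(e: Expr): String = e match {
    case Attr(n) => n
    case Func(n, cs) => s"$n(${cs.map(show).mkString(", ")})"
  }

  // Replace every occurrence of `from` with `to`.
  def subst(e: Expr, from: Expr, to: Expr): Expr = e match {
    case `from` => to
    case Func(n, cs) => Func(n, cs.map(subst(_, from, to)))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val a = Attr("a"); val b = Attr("b")
    var rhs: Expr = Func("coalesce", Seq(a, b)) // start from a = coalesce(a, b)
    for (_ <- 1 to 3) {
      rhs = subst(rhs, a, Func("coalesce", Seq(a, b)))
      println(s"a = ${show(rhs)}") // the right-hand side grows on every substitution pass
    }
  }
}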
