
Commit 03c6e56

Refactor getAllValidConstraints
1 parent 6dca2e5 commit 03c6e56

2 files changed: 78 additions & 10 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 47 additions & 10 deletions
@@ -175,20 +175,57 @@ abstract class UnaryNode extends LogicalPlan {
    */
   protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
     var allConstraints = child.constraints
-    projectList.foreach {
-      case a @ Alias(l: Literal, _) =>
-        allConstraints += EqualNullSafe(a.toAttribute, l)
-      case a @ Alias(e, _) =>
-        // For every alias in `projectList`, replace the reference in constraints by its attribute.
-        allConstraints ++= allConstraints.map(_ transform {
-          case expr: Expression if expr.semanticEquals(e) =>
-            a.toAttribute
+
+    // For each expression collect its aliases
+    val aliasMap = projectList.collect {
+      case alias @ Alias(expr, _) if !expr.foldable => (expr.canonicalized, alias)
+    }.groupBy(_._1).mapValues(_.map(_._2))
+    val remainingExpressions = aliasMap.keySet.to[collection.mutable.Set]
+
+    /**
+     * Filtering allConstraints between each iteration is necessary, because otherwise
+     * collecting valid constraints could in the worst case have exponential time and
+     * memory complexity. Each replaced alias could double the number of constraints,
+     * because we would keep both the original constraint and the one with the alias.
+     */
+    def shouldBeKept(expr: Expression): Boolean = {
+      expr.references.subsetOf(outputSet) ||
+        remainingExpressions.contains(expr.canonicalized) ||
+        (expr.children.nonEmpty && expr.children.forall(shouldBeKept))
+    }
+
+    // Replace expressions with their aliases
+    for ((expr, aliases) <- aliasMap) {
+      allConstraints ++= allConstraints.flatMap(constraint => {
+        aliases.map(alias => {
+          constraint transform {
+            case e: Expression if e.semanticEquals(expr) =>
+              alias.toAttribute
+          }
         })
-        allConstraints += EqualNullSafe(e, a.toAttribute)
+      })
+
+      for { alias1 <- aliases; alias2 <- aliases } {
+        if (!alias1.fastEquals(alias2)) {
+          allConstraints += EqualNullSafe(alias1.toAttribute, alias2.toAttribute)
+        }
+      }
+
+      remainingExpressions.remove(expr)
+      allConstraints = allConstraints.filter(shouldBeKept)
+    }
+
+    /**
+     * We keep the child constraints and the equality between original and aliased attributes,
+     * so [[ConstraintHelper.inferAdditionalConstraints]] would have the full information available.
+     */
+    projectList.foreach {
+      case alias @ Alias(expr, _) =>
+        allConstraints += EqualNullSafe(alias.toAttribute, expr)
       case _ => // Don't change.
     }
 
-    allConstraints
+    allConstraints ++ child.constraints
   }
 
   override protected lazy val validConstraints: ExpressionSet = child.constraints
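
The Scaladoc added above for shouldBeKept is the core of the change: allConstraints is pruned between iterations so that alias substitution cannot blow up the constraint set. As a minimal, self-contained sketch of that behaviour (plain Scala, not Spark code: Constraint is simplified here to the list of columns a predicate references, and ConstraintGrowthSketch is a hypothetical name, not part of this commit or of Catalyst):

object ConstraintGrowthSketch {
  // A "constraint" is reduced to the columns it references, e.g. a + b + c > 0 becomes List("a", "b", "c").
  type Constraint = List[String]

  // Rewrite one referenced column to its alias, mirroring the `constraint transform` step.
  def substitute(c: Constraint, from: String, to: String): Constraint =
    c.map(col => if (col == from) to else col)

  def main(args: Array[String]): Unit = {
    val aliases = List("a" -> "a1", "b" -> "b1", "c" -> "c1")
    val output = aliases.map(_._2).toSet // columns visible after the projection

    // Naive approach: keep both the original constraint and the rewritten copy at every step.
    var naive = Set[Constraint](List("a", "b", "c"))
    for ((expr, alias) <- aliases) {
      naive ++= naive.map(substitute(_, expr, alias))
    }
    println(s"without pruning: ${naive.size} constraints") // 8 = 2^3, one per subset of applied aliases

    // Pruned approach, the idea behind shouldBeKept: after each substitution, drop any
    // constraint that still references a column that is neither in the output nor
    // waiting for its own alias to be applied.
    var pruned = Set[Constraint](List("a", "b", "c"))
    var remaining = aliases.map(_._1).toSet
    for ((expr, alias) <- aliases) {
      pruned ++= pruned.map(substitute(_, expr, alias))
      remaining -= expr
      pruned = pruned.filter(_.forall(col => output(col) || remaining(col)))
    }
    println(s"with pruning: ${pruned.size} constraint(s)") // stays at 1
  }
}

In the real method the same pruning runs over Catalyst expression trees, so shouldBeKept recurses through children and compares canonicalized sub-expressions against outputSet and remainingExpressions instead of flat column lists.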

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 31 additions & 0 deletions
@@ -422,4 +422,35 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
       assert(aliasedRelation.analyze.constraints.isEmpty)
     }
   }
+
+  test("SPARK-33152: Avoid exponential growth of constraints") {
+    val relation = LocalRelation('a.int, 'b.int, 'c.int)
+
+    val plan1 = relation
+      .where('a.attr + 'b.attr > intToLiteral(0))
+      .select('a as 'a1, 'b as 'b1)
+      .analyze
+
+    verifyConstraints(plan1.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan1, "a1")),
+        IsNotNull(resolveColumn(plan1, "b1")),
+        resolveColumn(plan1, "a1") + resolveColumn(plan1, "b1") > 0
+      )))
+
+    val plan2 = relation
+      .where('a.attr + 'b.attr > intToLiteral(0))
+      .select('a as 'a1, 'b as 'b1, ('a.attr + 'b.attr) as 'c1)
+      .analyze
+
+    verifyConstraints(plan2.constraints,
+      ExpressionSet(Seq(
+        IsNotNull(resolveColumn(plan2, "a1")),
+        IsNotNull(resolveColumn(plan2, "b1")),
+        IsNotNull(resolveColumn(plan2, "c1")),
+        resolveColumn(plan2, "a1") + resolveColumn(plan2, "b1") > 0,
+        resolveColumn(plan2, "c1") > 0
+      )))
+  }
 }
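
The second plan in the test aliases the compound expression 'a + 'b as 'c1 in addition to the individual columns, so the expected set checks that the filter condition is propagated both through the column aliases (a1 + b1 > 0) and through the expression alias (c1 > 0). Not part of the commit, but the new test should be runnable on its own with Spark's usual sbt invocation, assuming the standard catalyst project name:

build/sbt "catalyst/testOnly *ConstraintPropagationSuite"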
