Skip to content

Commit af55a08

Browse files
committed
Only infer foldable constraints
1 parent 248e3cc commit af55a08

File tree

2 files changed

+24
-39
lines changed

2 files changed

+24
-39
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -107,49 +107,28 @@ trait ConstraintHelper {
107107
}
108108

109109
val greaterThans = binaryComparisons.map {
110+
case EqualTo(l, r) if l.foldable => EqualTo(r, l)
110111
case LessThan(l, r) => GreaterThan(r, l)
111112
case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l)
112113
case other => other
113114
}
114115

115116
val lessThans = binaryComparisons.map {
117+
case EqualTo(l, r) if l.foldable => EqualTo(r, l)
116118
case GreaterThan(l, r) => LessThan(r, l)
117119
case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l)
118120
case other => other
119121
}
120122

121123
var inferredConstraints = Set.empty[Expression]
122-
123-
greaterThans.foreach {
124-
case gt @ GreaterThan(l: Attribute, r: Attribute) =>
125-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
126-
case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) =>
127-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
128-
case gt @ GreaterThan(l @ Cast(_: Attribute, _, _), r: Attribute) =>
129-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
130-
case gt @ GreaterThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) =>
131-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
132-
case gt @ GreaterThan(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
133-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
134-
case gt @ GreaterThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
135-
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
136-
case _ => // No inference
137-
}
138-
139-
lessThans.foreach {
140-
case lt @ LessThan(l: Attribute, r: Attribute) =>
141-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
142-
case lt @ LessThanOrEqual(l: Attribute, r: Attribute) =>
143-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
144-
case lt @ LessThan(l @ Cast(_: Attribute, _, _), r: Attribute) =>
145-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
146-
case lt @ LessThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) =>
147-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
148-
case lt @ LessThan(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
149-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
150-
case lt @ LessThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
151-
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
152-
case _ => // No inference
124+
Seq(greaterThans, lessThans).foreach { comparisons =>
125+
comparisons.foreach {
126+
case b @ BinaryComparison(l: Attribute, r: Expression) if r.foldable =>
127+
inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b)
128+
case b @ BinaryComparison(l @ Cast(_: Attribute, _, _), r: Expression) if r.foldable =>
129+
inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b)
130+
case _ => // No inference
131+
}
153132
}
154133
(inferredConstraints -- constraints -- greaterThans -- lessThans)
155134
.filterNot(i => constraints.exists(_.semanticEquals(i)))
@@ -167,11 +146,18 @@ trait ConstraintHelper {
167146
source: Expression,
168147
destination: Expression,
169148
op: BinaryComparison): Set[Expression] = (constraints - op).map {
170-
case EqualTo(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r))
171-
case EqualTo(l, r) if r.semanticEquals(source) => op.makeCopy(Array(destination, l))
172-
case gt @ GreaterThan(l, r) if l.semanticEquals(source) => gt.makeCopy(Array(destination, r))
173-
case lt @ LessThan(l, r) if l.semanticEquals(source) => lt.makeCopy(Array(destination, r))
174-
case BinaryComparison(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r))
149+
case gt @ GreaterThan(l, r) if r.semanticEquals(source) =>
150+
gt.copy(l, destination)
151+
case GreaterThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] =>
152+
op.makeCopy(Array(l, destination))
153+
case gt @ GreaterThanOrEqual(l, r) if r.semanticEquals(source) =>
154+
gt.copy(l, destination)
155+
case lt @ LessThan(l, r) if r.semanticEquals(source) =>
156+
lt.copy(l, destination)
157+
case LessThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[LessThan] =>
158+
op.makeCopy(Array(l, destination))
159+
case lt @ LessThanOrEqual(l, r) if r.semanticEquals(source) =>
160+
lt.copy(l, destination)
175161
case other => other
176162
}
177163

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
402402
test("Constraints inferred from inequality attributes: case2") {
403403
val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5)
404404
val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c)
405-
&& 'a < 'b && 'b < 'c && 'c > 'a && 'a < 5 && 'b < 5 && 'c < 5)
405+
&& 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5)
406406
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
407407
}
408408

@@ -413,8 +413,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
413413
val optimizedLeft = testRelation.where(IsNotNull('a) && IsNotNull('b)
414414
&& 'b >= 3 && 'b <= 13).as("x")
415415
val optimizedRight = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c)
416-
&& 'b < 'c && 'c > 3 && 'b <= 13).as("y")
417-
416+
&& 'c > 3 && 'b <= 13).as("y")
418417
val condition = Some("x.a".attr === "y.a".attr
419418
&& "x.b".attr >= "y.b".attr && "x.b".attr < "y.c".attr)
420419
val original = left.join(right, Inner, condition)

0 commit comments

Comments
 (0)