Skip to content

Commit e99a26c

Browse files
committed
Refactor And/Or optimization to make it more readable and clean
1 parent ac82785 commit e99a26c

File tree

5 files changed

+288
-1
lines changed

5 files changed

+288
-1
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,11 @@
824824
<artifactId>jackson-mapper-asl</artifactId>
825825
<version>1.8.8</version>
826826
</dependency>
827+
<dependency>
828+
<groupId>org.spire-math</groupId>
829+
<artifactId>spire_2.10</artifactId>
830+
<version>0.9.0</version>
831+
</dependency>
827832
</dependencies>
828833
</dependencyManagement>
829834

sql/catalyst/pom.xml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@
4444
<groupId>org.scala-lang</groupId>
4545
<artifactId>scala-reflect</artifactId>
4646
</dependency>
47-
47+
<dependency>
48+
<groupId>org.spire-math</groupId>
49+
<artifactId>spire_2.10</artifactId>
50+
</dependency>
4851
<dependency>
4952
<groupId>org.apache.spark</groupId>
5053
<artifactId>spark-core_${scala.binary.version}</artifactId>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ trait PredicateHelper {
4848
}
4949
}
5050

51+
/**
 * Flattens a tree of nested `Or` nodes into the sequence of its operands, left to right.
 * An expression that is not an `Or` is returned as a single-element sequence.
 */
protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] =
  condition match {
    case Or(lhs, rhs) => Seq(lhs, rhs).flatMap(splitDisjunctivePredicates)
    case leaf => Seq(leaf)
  }
58+
5159
/**
5260
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
5361
* can be used to determine when it is acceptable to move expression evaluation within a query

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.immutable.HashSet
21+
import spire.implicits._
22+
import spire.math._
23+
2124
import org.apache.spark.sql.catalyst.expressions._
2225
import org.apache.spark.sql.catalyst.plans.Inner
2326
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -293,6 +296,116 @@ object OptimizeIn extends Rule[LogicalPlan] {
293296
}
294297
}
295298

299+
/**
 * Simplifies conditions built from `And` / `Or` when they can be optimized.
 *
 * Every rewrite below strictly reduces the number of operands, so applying the rule repeatedly
 * always reaches a fixed point:
 *  - duplicates:      `a && a => a`, `a || a => a`
 *  - merging of two comparisons of the same attribute against numeric literals:
 *    `a > 3 && a > 5 => a > 5`, `a > 3 || a > 5 => a > 3`,
 *    `a < 2 && a > 2 => false`, `a < 3 || a > 2 => true`
 *  - common factors:  `(a && b) || (a && c) => a && (b || c)`, `a || (a && b) => a`,
 *                     `(a || b) && (a || c) => a || (b && c)`, `a && (a || b) => a`
 *  - regrouping:      `(a < 2 || b > 3) || a < 5 => a < 5 || b > 3`
 *  - pruning:         `(a < 3 && b > 5) || a > 2 => b > 5 || a > 2`,
 *                     `(a < 2 || b > 5) && a > 3 => b > 5 && a > 3`
 *
 * A rewrite that produces the constant `true` or `false` is only applied when the attribute is
 * not nullable: for a NULL input the original comparison evaluates to NULL, not to a constant.
 *
 * NOTE(review): literals are compared after a cast to `Double`, so integral literals that are
 * not exactly representable as a `Double` may be treated as equal, and NaN values in floating
 * point columns are not considered by the constant rewrites - confirm this is acceptable.
 */
object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case q: LogicalPlan => q transformExpressionsDown {
      case Simplified(simpler) => simpler
    }
  }

  /** Matches an `And` / `Or` that can be rewritten and yields the simpler expression. */
  private object Simplified {
    def unapply(e: Expression): Option[Expression] = e match {
      // a && a => a
      case And(left, right) if left.fastEquals(right) => Some(left)
      // a || a => a
      case Or(left, right) if left.fastEquals(right) => Some(left)
      case And(left, right) =>
        mergeAnd(left, right)
          .orElse(factorAnd(left, right))
          .orElse(regroupAnd(left, right))
          .orElse(pruneAnd(left, right))
      case Or(left, right) =>
        mergeOr(left, right)
          .orElse(factorOr(left, right))
          .orElse(regroupOr(left, right))
          .orElse(pruneOr(left, right))
      // Anything else is left untouched.
      case _ => None
    }
  }

  /**
   * Returns the shared attribute and both intervals when `e1` and `e2` are comparisons of the
   * same attribute against numeric literals.
   */
  private def sameAttributeIntervals(
      e1: Expression,
      e2: Expression): Option[(NamedExpression, Interval[Double], Interval[Double])] = {
    (e1, e2) match {
      case (NumericLiteralBinaryComparison(n1, i1), NumericLiteralBinaryComparison(n2, i2))
          if n1 == n2 => Some((n1, i1, i2))
      case _ => None
    }
  }

  /** True when every value accepted by `e1` is also accepted by `e2`. */
  private def implies(e1: Expression, e2: Expression): Boolean =
    sameAttributeIntervals(e1, e2).exists { case (_, i1, i2) => i1.isSubsetOf(i2) }

  /** True when `e1 && e2` is always false (disjoint ranges over a non-nullable attribute). */
  private def contradictory(e1: Expression, e2: Expression): Boolean =
    sameAttributeIntervals(e1, e2).exists { case (n, i1, i2) =>
      !n.nullable && !i1.intersects(i2)
    }

  /**
   * True when `e1 || e2` is always true. The extractor only produces one-sided rays and single
   * points, so two intervals that overlap without one containing the other must be rays pointing
   * in opposite directions, which together cover every value.
   */
  private def exhaustive(e1: Expression, e2: Expression): Boolean =
    sameAttributeIntervals(e1, e2).exists { case (n, i1, i2) =>
      !n.nullable && i1.intersects(i2) && !i1.isSubsetOf(i2) && !i1.isSupersetOf(i2)
    }

  /** a < 2 && a > 2 => false, a > 3 && a > 5 => a > 5 */
  private def mergeAnd(e1: Expression, e2: Expression): Option[Expression] = {
    if (contradictory(e1, e2)) Some(Literal(false))
    else if (implies(e1, e2)) Some(e1)
    else if (implies(e2, e1)) Some(e2)
    else None
  }

  /** a > 3 || a > 5 => a > 3, a < 3 || a > 2 => true */
  private def mergeOr(e1: Expression, e2: Expression): Option[Expression] = {
    // Containment has to be tested before overlap: `a > 3` and `a > 5` overlap as well,
    // but their disjunction is `a > 3`, not `true`.
    if (implies(e1, e2)) Some(e2)
    else if (implies(e2, e1)) Some(e1)
    else if (exhaustive(e1, e2)) Some(Literal(true))
    else None
  }

  /**
   * (a && b && c) || (a && b && d) => a && b && (c || d)
   * a || (a && b) => a   (one side consists of common factors only)
   */
  private def factorOr(left: Expression, right: Expression): Option[Expression] = {
    val lhsSet = splitConjunctivePredicates(left).toSet
    val rhsSet = splitConjunctivePredicates(right).toSet
    val common = lhsSet.intersect(rhsSet)
    if (common.isEmpty) {
      None
    } else {
      val lhsRest = lhsSet.diff(common)
      val rhsRest = rhsSet.diff(common)
      val shared = common.reduce(And)
      if (lhsRest.isEmpty || rhsRest.isEmpty) Some(shared)
      else Some(And(shared, Or(lhsRest.reduce(And), rhsRest.reduce(And))))
    }
  }

  /**
   * (a || b || c) && (a || b || d) => a || b || (c && d)
   * a && (a || b) => a   (one side consists of common factors only)
   */
  private def factorAnd(left: Expression, right: Expression): Option[Expression] = {
    val lhsSet = splitDisjunctivePredicates(left).toSet
    val rhsSet = splitDisjunctivePredicates(right).toSet
    val common = lhsSet.intersect(rhsSet)
    if (common.isEmpty) {
      None
    } else {
      val lhsRest = lhsSet.diff(common)
      val rhsRest = rhsSet.diff(common)
      val shared = common.reduce(Or)
      if (lhsRest.isEmpty || rhsRest.isEmpty) Some(shared)
      else Some(Or(shared, And(lhsRest.reduce(Or), rhsRest.reduce(Or))))
    }
  }

  /** (a < 2 || b > 5) || a < 3 => a < 3 || b > 5; only fires when two operands really merge. */
  private def regroupOr(left: Expression, right: Expression): Option[Expression] = left match {
    case Or(e1, e2) =>
      mergeOr(e1, right).map(Or(_, e2))
        .orElse(mergeOr(e2, right).map(Or(e1, _)))
    case _ => None
  }

  /** (a < 2 && b > 5) && a < 3 => a < 2 && b > 5; only fires when two operands really merge. */
  private def regroupAnd(left: Expression, right: Expression): Option[Expression] = left match {
    case And(e1, e2) =>
      mergeAnd(e1, right).map(And(_, e2))
        .orElse(mergeAnd(e2, right).map(And(e1, _)))
    case _ => None
  }

  /**
   * (a < 3 && b > 5) || a > 2 => b > 5 || a > 2   (a < 3 || a > 2 is always true)
   * (a > 5 && b > 5) || a > 2 => a > 2            (a > 5 already implies a > 2)
   */
  private def pruneOr(left: Expression, right: Expression): Option[Expression] = left match {
    case And(l, r) =>
      if (implies(l, right) || implies(r, right)) Some(right)
      else if (exhaustive(l, right)) Some(Or(r, right))
      else if (exhaustive(r, right)) Some(Or(l, right))
      else None
    case _ => None
  }

  /**
   * (a < 2 || b > 5) && a > 3 => b > 5 && a > 3   (a < 2 && a > 3 is always false)
   * (a > 2 || b > 5) && a > 3 => a > 3            (a > 3 already implies a > 2)
   */
  private def pruneAnd(left: Expression, right: Expression): Option[Expression] = left match {
    case Or(l, r) =>
      if (implies(right, l) || implies(right, r)) Some(right)
      else if (contradictory(l, right)) Some(And(r, right))
      else if (contradictory(r, right)) Some(And(l, right))
      else None
    case _ => None
  }

  private implicit class NumericLiteral(e: Literal) {
    def toDouble: Double = Cast(e, DoubleType).eval().asInstanceOf[Double]
  }

  /** Matches a non-null literal of a numeric type and yields its value as a `Double`. */
  private object NumericValue {
    def unapply(e: Expression): Option[Double] = e match {
      // A null value would otherwise be unboxed to 0.0 by the cast in `toDouble`.
      case l @ Literal(v, _: NumericType) if v != null => Some(l.toDouble)
      case _ => None
    }
  }

  /**
   * Matches a comparison between a named expression and a numeric literal and yields the named
   * expression together with the interval of values that satisfy the comparison.
   * When the literal is on the left-hand side the comparison is mirrored, e.g. `1 < a` is
   * treated as `a > 1`.
   */
  object NumericLiteralBinaryComparison {
    def unapply(e: Expression): Option[(NamedExpression, Interval[Double])] = e match {
      case LessThan(n: NamedExpression, NumericValue(v)) => Some((n, Interval.below(v)))
      case LessThan(NumericValue(v), n: NamedExpression) => Some((n, Interval.above(v)))

      case GreaterThan(n: NamedExpression, NumericValue(v)) => Some((n, Interval.above(v)))
      case GreaterThan(NumericValue(v), n: NamedExpression) => Some((n, Interval.below(v)))

      case LessThanOrEqual(n: NamedExpression, NumericValue(v)) =>
        Some((n, Interval.atOrBelow(v)))
      case LessThanOrEqual(NumericValue(v), n: NamedExpression) =>
        Some((n, Interval.atOrAbove(v)))

      case GreaterThanOrEqual(n: NamedExpression, NumericValue(v)) =>
        Some((n, Interval.atOrAbove(v)))
      case GreaterThanOrEqual(NumericValue(v), n: NamedExpression) =>
        Some((n, Interval.atOrBelow(v)))

      case EqualTo(n: NamedExpression, NumericValue(v)) => Some((n, Interval.point(v)))
      case EqualTo(NumericValue(v), n: NamedExpression) => Some((n, Interval.point(v)))

      // Not a comparison against a numeric literal: report "no match" instead of throwing.
      case _ => None
    }
  }
}
408+
296409
/**
297410
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
298411
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
21+
import org.apache.spark.sql.catalyst.expressions.{Literal, Expression}
22+
import org.apache.spark.sql.catalyst.plans.logical._
23+
import org.apache.spark.sql.catalyst.plans.PlanTest
24+
import org.apache.spark.sql.catalyst.rules._
25+
import org.apache.spark.sql.catalyst.dsl.plans._
26+
import org.apache.spark.sql.catalyst.dsl.expressions._
27+
28+
/** Checks the rewrites performed by `ConditionSimplification` on filter conditions. */
class ConditionSimplificationSuite extends PlanTest {

  // Runs the rule under test together with the folding rules listed below.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateAnalysisOperators) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        ConditionSimplification,
        BooleanSimplification,
        SimplifyFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

  /**
   * Asserts that filtering `testRelation` by `originCondition` is optimized into filtering it
   * by `optimizedCondition`.
   */
  def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = {
    val actualPlan = Optimize(testRelation.where(originCondition).analyze)
    val expectedPlan = testRelation.where(optimizedCondition).analyze
    comparePlans(actualPlan, expectedPlan)
  }

  /**
   * Asserts that filtering `testRelation` by `originCondition` is optimized into the bare
   * relation, i.e. that the filter disappears completely.
   * NOTE(review): presumably an always-false filter also collapses to this same plan because
   * `testRelation` holds no rows, so this check may not tell `true` from `false` - confirm.
   */
  def checkCondition(originCondition: Expression): Unit = {
    val actualPlan = Optimize(testRelation.where(originCondition).analyze)
    comparePlans(actualPlan, testRelation)
  }

  test("literal in front of attribute") {
    checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1)
  }

  test("combine the same condition") {
    checkCondition('a < 1 || 'a < 1, 'a < 1)
    checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1)
    checkCondition('a > 2 && 'a > 2, 'a > 2)
    checkCondition('a > 2 && 'a > 2 && 'a > 2 && 'a > 2, 'a > 2)
    checkCondition(('a < 1 && 'a < 2) || ('a < 1 && 'a < 2), 'a < 1)
  }

  test("combine literal binary comparison") {
    // equality combined with a range
    checkCondition('a === 1 && 'a < 1)
    checkCondition('a === 1 || 'a < 1, 'a <= 1)

    // two different equalities
    checkCondition('a === 1 && 'a === 2)
    checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2)

    // complementary ranges
    checkCondition('a <= 1 && 'a > 1)
    checkCondition('a <= 1 || 'a > 1)

    checkCondition('a < 1 && 'a >= 1)
    checkCondition('a < 1 || 'a >= 1)

    // nested ranges
    checkCondition('a > 3 && 'a > 2, 'a > 3)
    checkCondition('a > 3 || 'a > 2, 'a > 2)

    // ranges meeting in a single point
    checkCondition('a >= 1 && 'a <= 1, 'a === 1)
  }

  test("different data type comparison") {
    checkCondition('a > "abc")
    checkCondition('a > "a" && 'a < "b")

    checkCondition('a > "a" || 'a < "b")

    checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0)
    checkCondition('d > 9 && 'd < 1, 'd > 9.0 && 'd < 1.0 )

    checkCondition('a > "9" || 'a < "0", 'a > 9.0 || 'a < 0.0)
  }

  test("combine predicate : 2 same combine") {
    checkCondition('a < 1 || 'b > 2 || 'a >= 1)
    checkCondition('a < 1 && 'b > 2 && 'a >= 1)

    checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2)
    checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3)

    checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2)
    checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2)

    checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5)
  }

  test("combine predicate : 2 difference combine") {
    checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4)
    checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2)

    checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2)
    checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1))

    checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2)
  }

  test("multi left, single right") {
    checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
  }

  test("multi left, multi right") {
    checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))

    // a common disjunct shared by three conjuncts
    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
      'a === 'b || ('b > 3 && 'a > 3 && 'a < 5))

    // same, with a redundant range in the remainder
    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1),
      'a === 'b || ('b > 3 && 'a > 3))

    // a common conjunct shared by three disjuncts
    checkCondition(
      ('a === 'b && 'b > 3 && 'c > 2) ||
        ('a === 'b && 'c < 1 && 'a === 5) ||
        ('a === 'b && 'b < 5 && 'a > 1),
      ('a === 'b) &&
        (((('b > 3) && ('c > 2)) ||
          (('c < 1) && ('a === 5))) ||
          (('b < 5) && ('a > 1))))

    // the right-hand side is entirely contained in the left-hand side
    checkCondition(
      ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1),
      'a < 2 || 'b > 1)

    checkCondition(
      ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1),
      ('a === 'b) || ('c > 3 && 'b > 5))
  }
}

0 commit comments

Comments
 (0)