/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17+
18+ package org .apache .spark .sql .catalyst .optimizer
19+
20+ import org .apache .spark .sql .catalyst .analysis .EliminateAnalysisOperators
21+ import org .apache .spark .sql .catalyst .expressions .{Literal , Expression }
22+ import org .apache .spark .sql .catalyst .plans .logical ._
23+ import org .apache .spark .sql .catalyst .plans .PlanTest
24+ import org .apache .spark .sql .catalyst .rules ._
25+ import org .apache .spark .sql .catalyst .dsl .plans ._
26+ import org .apache .spark .sql .catalyst .dsl .expressions ._
27+
/**
 * Checks how filter conditions are rewritten by the optimizer batches defined in `Optimize`
 * (which include `ConditionSimplification`).
 *
 * Every case wraps a predicate in a filter over `testRelation`, runs the analyzed plan through
 * `Optimize`, and compares the outcome with an expected plan via `comparePlans`.
 */
class ConditionSimplificationSuite extends PlanTest {

  /** Rule executor used by every test: one analysis-cleanup batch, then the folding batch. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch(" AnalysisNodes", Once,
        EliminateAnalysisOperators) ::
      Batch(" Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        ConditionSimplification,
        BooleanSimplification,
        SimplifyFilters) :: Nil
  }

  /** Relation with three int columns (`a`, `b`, `c`) and one string column (`d`). */
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

  /**
   * Filters `testRelation` by `originCondition`, optimizes the analyzed plan and compares the
   * result with `expectedPlan`. Shared by both `checkCondition` overloads.
   */
  private def checkOptimizedPlan(originCondition: Expression, expectedPlan: LogicalPlan): Unit = {
    val optimized = Optimize(testRelation.where(originCondition).analyze)
    comparePlans(optimized, expectedPlan)
  }

  /**
   * Asserts that a filter on `originCondition` is optimized into a filter on
   * `optimizedCondition`.
   *
   * The expected plan is only analyzed, never optimized, so `optimizedCondition` must be written
   * in exactly the shape the optimizer is expected to produce.
   *
   * @param originCondition    predicate given to the optimizer
   * @param optimizedCondition predicate expected to remain after optimization
   */
  def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit =
    checkOptimizedPlan(originCondition, testRelation.where(optimizedCondition).analyze)

  /**
   * Asserts that optimizing a filter on `originCondition` yields a plan that compares equal
   * (via `comparePlans`) to the bare `testRelation`, i.e. no filter is left.
   *
   * NOTE(review): the tests below use this form both for conditions that look unsatisfiable
   * (e.g. `'a <= 1 && 'a > 1`) and for ones that look always true (e.g. `'a <= 1 || 'a > 1`).
   * Presumably both outcomes compare equal to `testRelation`, so this check may not tell them
   * apart -- confirm against `SimplifyFilters`.
   *
   * @param originCondition predicate given to the optimizer
   */
  def checkCondition(originCondition: Expression): Unit =
    checkOptimizedPlan(originCondition, testRelation)

  test(" literal in front of attribute") {
    checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1)
  }

  test(" combine the same condition") {
    checkCondition('a < 1 || 'a < 1, 'a < 1)
    checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1)
    checkCondition('a > 2 && 'a > 2, 'a > 2)
    checkCondition('a > 2 && 'a > 2 && 'a > 2 && 'a > 2, 'a > 2)
    checkCondition(('a < 1 && 'a < 2) || ('a < 1 && 'a < 2), 'a < 1)
  }

  test(" combine literal binary comparison") {
    checkCondition('a === 1 && 'a < 1)
    checkCondition('a === 1 || 'a < 1, 'a <= 1)

    checkCondition('a === 1 && 'a === 2)
    checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2)

    checkCondition('a <= 1 && 'a > 1)
    checkCondition('a <= 1 || 'a > 1)

    checkCondition('a < 1 && 'a >= 1)
    checkCondition('a < 1 || 'a >= 1)

    checkCondition('a > 3 && 'a > 2, 'a > 3)
    checkCondition('a > 3 || 'a > 2, 'a > 2)

    checkCondition('a >= 1 && 'a <= 1, 'a === 1)
  }

  test(" different data type comparison") {
    // `a` is an int column compared against string literals.
    checkCondition('a > " abc")
    checkCondition('a > " a" && 'a < " b")

    checkCondition('a > " a" || 'a < " b")

    // Mixed string/number comparisons are expected to end up as comparisons on doubles.
    checkCondition('a > " 9" || 'a < " 0", 'a > 9.0 || 'a < 0.0)
    checkCondition('d > 9 && 'd < 1, 'd > 9.0 && 'd < 1.0)
  }

  test(" combine predicate : 2 same combine") {
    checkCondition('a < 1 || 'b > 2 || 'a >= 1)
    checkCondition('a < 1 && 'b > 2 && 'a >= 1)

    checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2)
    checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3)

    checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2)
    checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2)

    checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5)
  }

  test(" combine predicate : 2 difference combine") {
    checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4)
    checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2)

    checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2)
    checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1))

    checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2)
  }

  test(" multi left, single right") {
    checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
  }

  test(" multi left, multi right") {
    checkCondition(
      ('a < 2 || 'b > 3) && ('a < 2 || 'c > 5),
      'a < 2 || ('b > 3 && 'c > 5))

    // A disjunct shared by every conjunct is expected to be pulled out in front.
    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
      'a === 'b || ('b > 3 && 'a > 3 && 'a < 5))

    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1),
      'a === 'b || ('b > 3 && 'a > 3))

    // Dual case: a conjunct shared by every disjunct.
    checkCondition(
      ('a === 'b && 'b > 3 && 'c > 2) ||
        ('a === 'b && 'c < 1 && 'a === 5) ||
        ('a === 'b && 'b < 5 && 'a > 1),
      ('a === 'b) &&
        (((('b > 3) && ('c > 2)) ||
          (('c < 1) && ('a === 5))) ||
          (('b < 5) && ('a > 1))))

    checkCondition(
      ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1),
      'a < 2 || 'b > 1)

    checkCondition(
      ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1),
      ('a === 'b) || ('c > 3 && 'b > 5))
  }
}
0 commit comments