1717
1818package org .apache .spark .sql .catalyst .optimizer
1919
20- import org .scalatest .Matchers ._
21-
2220import org .apache .spark .api .python .PythonEvalType
2321import org .apache .spark .sql .AnalysisException
2422import org .apache .spark .sql .catalyst .dsl .expressions ._
@@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
2826import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan }
2927import org .apache .spark .sql .catalyst .rules .RuleExecutor
3028import org .apache .spark .sql .internal .SQLConf ._
31- import org .apache .spark .sql .types .BooleanType
29+ import org .apache .spark .sql .types .{ BooleanType , IntegerType }
3230
3331class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
3432
@@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
4038 CheckCartesianProducts ) :: Nil
4139 }
4240
// Shared test fixtures. The attributes are bound to named vals so that the very same
// attribute values are used both to build the two relations and as PythonUDF arguments:
// attrA/attrB make up the left relation, attrC/attrD make up the right relation.
43- val testRelationLeft = LocalRelation (' a .int, ' b .int)
44- val testRelationRight = LocalRelation (' c .int, ' d .int)
41+ val attrA = ' a .int
42+ val attrB = ' b .int
43+ val attrC = ' c .int
44+ val attrD = ' d .int
45+
46+ val testRelationLeft = LocalRelation (attrA, attrB)
47+ val testRelationRight = LocalRelation (attrC, attrD)
48+
49+ // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
50+ // refers to attributes from one side.
// The UDF is IntegerType and takes only attrA (left relation); the condition itself is the
// equality between the UDF result and attrC (right relation).
51+ val evaluableJoinCond = {
52+ val pythonUDF = PythonUDF (" evaluable" , null ,
53+ IntegerType ,
54+ Seq (attrA),
55+ PythonEvalType .SQL_BATCHED_UDF ,
56+ udfDeterministic = true )
57+ pythonUDF === attrC
58+ }
4559
// BooleanType PythonUDF used directly as a join condition by the tests in this suite.
// Its arguments are attrA (from testRelationLeft) and attrC (from testRelationRight),
// replacing the earlier argument-less dummy UDF (Seq.empty).
46- // Dummy python UDF for testing. Unable to execute .
47- val pythonUDF = PythonUDF (" pythonUDF " , null ,
60+ // This join condition is a PythonUDF which refers to attributes from 2 tables .
61+ val unevaluableJoinCond = PythonUDF (" unevaluable " , null ,
4862 BooleanType ,
49- Seq .empty ,
63+ Seq (attrA, attrC) ,
5064 PythonEvalType .SQL_BATCHED_UDF ,
5165 udfDeterministic = true )
5266
@@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
6680 }
6781 }
6882
// Inner join, two scenarios:
//  1. The join condition is unevaluableJoinCond alone: the expected plan is the same join
//     with condition = None, followed by where(unevaluableJoinCond).
//  2. The join condition is evaluableJoinCond: the optimized plan is expected to be equal
//     to the input plan (query2).
69- test(" inner join condition with python udf only " ) {
70- val query = testRelationLeft.join(
83+ test(" inner join condition with python udf" ) {
84+ val query1 = testRelationLeft.join(
7185 testRelationRight,
7286 joinType = Inner ,
73- condition = Some (pythonUDF ))
74- val expected = testRelationLeft.join(
87+ condition = Some (unevaluableJoinCond ))
88+ val expected1 = testRelationLeft.join(
7589 testRelationRight,
7690 joinType = Inner ,
77- condition = None ).where(pythonUDF).analyze
78- comparePlanWithCrossJoinEnable(query, expected)
91+ condition = None ).where(unevaluableJoinCond).analyze
92+ comparePlanWithCrossJoinEnable(query1, expected1)
93+
94+ // evaluable PythonUDF will not be touched
95+ val query2 = testRelationLeft.join(
96+ testRelationRight,
97+ joinType = Inner ,
98+ condition = Some (evaluableJoinCond))
99+ comparePlans(Optimize .execute(query2), query2)
79100 }
80101
// LeftSemi join, two scenarios:
//  1. The join condition is unevaluableJoinCond alone: the expected plan is an Inner join
//     with condition = None, then where(unevaluableJoinCond), then a select of 'a and 'b
//     (the columns of testRelationLeft).
//  2. The join condition is evaluableJoinCond: the optimized plan is expected to be equal
//     to the input plan (query2).
81- test(" left semi join condition with python udf only " ) {
82- val query = testRelationLeft.join(
102+ test(" left semi join condition with python udf" ) {
103+ val query1 = testRelationLeft.join(
83104 testRelationRight,
84105 joinType = LeftSemi ,
85- condition = Some (pythonUDF ))
86- val expected = testRelationLeft.join(
106+ condition = Some (unevaluableJoinCond ))
107+ val expected1 = testRelationLeft.join(
87108 testRelationRight,
88109 joinType = Inner ,
89- condition = None ).where(pythonUDF).select(' a , ' b ).analyze
90- comparePlanWithCrossJoinEnable(query, expected)
110+ condition = None ).where(unevaluableJoinCond).select(' a , ' b ).analyze
111+ comparePlanWithCrossJoinEnable(query1, expected1)
112+
113+ // evaluable PythonUDF will not be touched
114+ val query2 = testRelationLeft.join(
115+ testRelationRight,
116+ joinType = LeftSemi ,
117+ condition = Some (evaluableJoinCond))
118+ comparePlans(Optimize .execute(query2), query2)
91119 }
92120
// Conjunction of unevaluableJoinCond and the equality 'a === 'c on an Inner join.
// Expected plan: the equality stays as the join condition, while unevaluableJoinCond is
// applied afterwards through where(...).
93- test(" python udf and common condition" ) {
121+ test(" unevaluable python udf and common condition" ) {
94122 val query = testRelationLeft.join(
95123 testRelationRight,
96124 joinType = Inner ,
97- condition = Some (pythonUDF && ' a .attr === ' c .attr))
125+ condition = Some (unevaluableJoinCond && ' a .attr === ' c .attr))
98126 val expected = testRelationLeft.join(
99127 testRelationRight,
100128 joinType = Inner ,
101- condition = Some (' a .attr === ' c .attr)).where(pythonUDF ).analyze
129+ condition = Some (' a .attr === ' c .attr)).where(unevaluableJoinCond ).analyze
102130 val optimized = Optimize .execute(query.analyze)
103131 comparePlans(optimized, expected)
104132 }
105133
// Disjunction of unevaluableJoinCond and the equality 'a === 'c on an Inner join.
// Expected plan: the join keeps no condition (None) and the whole disjunction is applied
// afterwards through where(...).
106- test(" python udf or common condition" ) {
134+ test(" unevaluable python udf or common condition" ) {
107135 val query = testRelationLeft.join(
108136 testRelationRight,
109137 joinType = Inner ,
110- condition = Some (pythonUDF || ' a .attr === ' c .attr))
138+ condition = Some (unevaluableJoinCond || ' a .attr === ' c .attr))
111139 val expected = testRelationLeft.join(
112140 testRelationRight,
113141 joinType = Inner ,
114- condition = None ).where(pythonUDF || ' a .attr === ' c .attr).analyze
142+ condition = None ).where(unevaluableJoinCond || ' a .attr === ' c .attr).analyze
115143 comparePlanWithCrossJoinEnable(query, expected)
116144 }
117145
118- test(" pull out whole complex condition with multiple python udf" ) {
146+ test(" pull out whole complex condition with multiple unevaluable python udf" ) {
119147 val pythonUDF1 = PythonUDF (" pythonUDF1" , null ,
120148 BooleanType ,
121- Seq .empty ,
149+ Seq (attrA, attrC) ,
122150 PythonEvalType .SQL_BATCHED_UDF ,
123151 udfDeterministic = true )
124- val condition = (pythonUDF || ' a .attr === ' c .attr) && pythonUDF1
152+ val condition = (unevaluableJoinCond || ' a .attr === ' c .attr) && pythonUDF1
125153
126154 val query = testRelationLeft.join(
127155 testRelationRight,
@@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
134162 comparePlanWithCrossJoinEnable(query, expected)
135163 }
136164
137- test(" partial pull out complex condition with multiple python udf" ) {
165+ test(" partial pull out complex condition with multiple unevaluable python udf" ) {
138166 val pythonUDF1 = PythonUDF (" pythonUDF1" , null ,
139167 BooleanType ,
140- Seq .empty ,
168+ Seq (attrA, attrC) ,
141169 PythonEvalType .SQL_BATCHED_UDF ,
142170 udfDeterministic = true )
143- val condition = (pythonUDF || pythonUDF1) && ' a .attr === ' c .attr
171+ val condition = (unevaluableJoinCond || pythonUDF1) && ' a .attr === ' c .attr
144172
145173 val query = testRelationLeft.join(
146174 testRelationRight,
@@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
149177 val expected = testRelationLeft.join(
150178 testRelationRight,
151179 joinType = Inner ,
152- condition = Some (' a .attr === ' c .attr)).where(pythonUDF || pythonUDF1).analyze
180+ condition = Some (' a .attr === ' c .attr)).where(unevaluableJoinCond || pythonUDF1).analyze
181+ val optimized = Optimize .execute(query.analyze)
182+ comparePlans(optimized, expected)
183+ }
184+
// Conjunction of evaluableJoinCond and unevaluableJoinCond on an Inner join.
// Expected plan: evaluableJoinCond stays as the join condition, while unevaluableJoinCond
// is applied afterwards through where(...).
185+ test(" pull out unevaluable python udf when it's mixed with evaluable one" ) {
186+ val query = testRelationLeft.join(
187+ testRelationRight,
188+ joinType = Inner ,
189+ condition = Some (evaluableJoinCond && unevaluableJoinCond))
190+ val expected = testRelationLeft.join(
191+ testRelationRight,
192+ joinType = Inner ,
193+ condition = Some (evaluableJoinCond)).where(unevaluableJoinCond).analyze
153194 val optimized = Optimize .execute(query.analyze)
154195 comparePlans(optimized, expected)
155196 }
156197
// For every join type in unsupportedJoinTypes (defined outside this view):
//  1. Optimizing a join whose condition is unevaluableJoinCond must raise an
//     AnalysisException whose message names the join type.
//  2. Optimizing a join whose condition is evaluableJoinCond is expected to return a plan
//     equal to the input plan (query2), i.e. no exception is expected in that case.
157198 test(" throw an exception for not support join type" ) {
158199 for (joinType <- unsupportedJoinTypes) {
159- val thrownException = the [AnalysisException ] thrownBy {
200+ val e = intercept [AnalysisException ] {
160201 val query = testRelationLeft.join(
161202 testRelationRight,
162203 joinType,
163- condition = Some (pythonUDF ))
204+ condition = Some (unevaluableJoinCond ))
164205 Optimize .execute(query.analyze)
165206 }
166- assert(thrownException .message.contentEquals(
207+ assert(e .message.contentEquals(
167208 s " Using PythonUDF in join condition of join type $joinType is not supported. " ))
209+
210+ val query2 = testRelationLeft.join(
211+ testRelationRight,
212+ joinType,
213+ condition = Some (evaluableJoinCond))
214+ comparePlans(Optimize .execute(query2), query2)
168215 }
169216 }
170217}
171-
0 commit comments