Skip to content

Commit cb195cf

Browse files
committed
only pull out unevaluable python udf from join condition
1 parent 6a064ba commit cb195cf

File tree

2 files changed

+23
-11
lines changed
  • python/pyspark/sql/tests
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer

2 files changed

+23
-11
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def test_udf_in_join_condition(self):
209209
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
210210
self.assertEqual(df.collect(), [Row(a=1, b=1)])
211211

212+
def test_udf_in_left_outer_join_condition(self):
213+
# regression test for SPARK-26147
214+
from pyspark.sql.functions import udf, col
215+
left = self.spark.createDataFrame([Row(a=1)])
216+
right = self.spark.createDataFrame([Row(b=1)])
217+
f = udf(lambda a: str(a), StringType())
218+
# The join condition can't be pushed down, as it refers to attributes from both sides.
219+
# The Python UDF only refers to attributes from one side, so it's evaluable.
220+
df = left.join(right, f("a") == col("b").cast("string"), how = "left_outer")
221+
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
222+
self.assertEqual(df.collect(), [Row(a=1, b=1)])
223+
212224
def test_udf_in_left_semi_join_condition(self):
213225
# regression test for SPARK-25314
214226
from pyspark.sql.functions import udf

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
155155
}
156156

157157
/**
158-
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
159-
* and pull them out from join condition. For python udf accessing attributes from only one side,
160-
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
161-
* down rules), we need to pull them out in this rule too.
158+
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
159+
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them
160+
* out from join condition.
162161
*/
163162
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
164-
def hasPythonUDF(expression: Expression): Boolean = {
165-
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
163+
164+
private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
165+
expr.find { e =>
166+
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
167+
}.isDefined
166168
}
167169

168170
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
169-
case j @ Join(_, _, joinType, condition)
170-
if condition.isDefined && hasPythonUDF(condition.get) =>
171+
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
171172
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
172173
// The current strategy only supports InnerLike and LeftSemi joins because for other types,
173174
// it breaks SQL semantic if we run the join condition as a filter after join. If we pass
@@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
179180
}
180181
// If condition expression contains python udf, it will be moved out from
181182
// the new join conditions.
182-
val (udf, rest) =
183-
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
183+
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
184184
val newCondition = if (rest.isEmpty) {
185-
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
185+
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
186186
s" it will be moved out and the join plan will be turned to cross join.")
187187
None
188188
} else {

0 commit comments

Comments
 (0)