@@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
155155}
156156
157157/**
158- * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
159- * and pull them out from join condition. For python udf accessing attributes from only one side,
160- * they are pushed down by operation push down rules. If not (e.g. user disables filter push
161- * down rules), we need to pull them out in this rule too.
158+ * PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
159+ * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDFs and pull them
160+ * out of the join condition.
162161 */
163162object PullOutPythonUDFInJoinCondition extends Rule [LogicalPlan ] with PredicateHelper {
164- def hasPythonUDF (expression : Expression ): Boolean = {
165- expression.collectFirst { case udf : PythonUDF => udf }.isDefined
163+
164+ private def hasUnevaluablePythonUDF (expr : Expression , j : Join ): Boolean = {
165+ expr.find { e =>
166+ PythonUDF .isScalarPythonUDF(e) && ! canEvaluate(e, j.left) && ! canEvaluate(e, j.right)
167+ }.isDefined
166168 }
167169
168170 override def apply (plan : LogicalPlan ): LogicalPlan = plan transformUp {
169- case j @ Join (_, _, joinType, condition)
170- if condition.isDefined && hasPythonUDF(condition.get) =>
171+ case j @ Join (_, _, joinType, Some (cond)) if hasUnevaluablePythonUDF(cond, j) =>
171172 if (! joinType.isInstanceOf [InnerLike ] && joinType != LeftSemi ) {
172173 // The current strategy only supports InnerLike and LeftSemi join because for other types,
173174 // it breaks SQL semantics if we run the join condition as a filter after join. If we pass
@@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
179180 }
180181 // If the condition expression contains an unevaluable Python UDF, it will be moved out of
181182 // the new join condition.
182- val (udf, rest) =
183- splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
183+ val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
184184 val newCondition = if (rest.isEmpty) {
185- logWarning(s " The join condition: $condition of the join plan contains PythonUDF only, " +
185+ logWarning(s " The join condition: $cond of the join plan contains PythonUDF only, " +
186186 s " it will be moved out and the join plan will be turned to cross join. " )
187187 None
188188 } else {
0 commit comments