Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,18 @@ def test_udf_in_join_condition(self):
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_left_outer_join_condition(self):
# Regression test for SPARK-26147: a Python UDF used inside a left outer join condition.
from pyspark.sql.functions import udf, col
# Single-row inputs: `left` has column `a`, `right` has column `b`.
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
# UDF that stringifies its argument; below it is applied to the left-side column `a` only.
f = udf(lambda a: str(a), StringType())
# The join condition can't be pushed down, as it refers to attributes from both sides.
# The Python UDF only refers to attributes from one side, so it's evaluable.
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
# str(1) equals the right-side value cast to string, so the single matching row is expected.
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
* and pull them out from join condition. For python udf accessing attributes from only one side,
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
* down rules), we need to pull them out in this rule too.
* A PythonUDF in a join condition can't be evaluated if it refers to attributes from both join
* sides. See `ExtractPythonUDFs` for details. This rule detects such un-evaluable PythonUDFs and
* pulls them out of the join condition.
*/
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
def hasPythonUDF(expression: Expression): Boolean = {
expression.collectFirst { case udf: PythonUDF => udf }.isDefined

private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
expr.find { e =>
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need a comment to explain why we only pull out the Scalar PythonUDF.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only possible to have a scalar UDF in a join condition, so changing it to `e.isInstanceOf[PythonUDF]` would be equivalent.

}.isDefined
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case j @ Join(_, _, joinType, condition)
if condition.isDefined && hasPythonUDF(condition.get) =>
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the rule changes, we need to modify the tests in PullOutPythonUDFInJoinConditionSuite; they should also construct the dummy Python UDF so that it refers to attributes from both sides.

if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy only supports InnerLike and LeftSemi joins because for other types,
// it breaks SQL semantics if we run the join condition as a filter after the join. If we pass
Expand All @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
}
// If condition expression contains python udf, it will be moved out from
// the new join conditions.
val (udf, rest) =
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
val newCondition = if (rest.isEmpty) {
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
s" it will be moved out and the join plan will be turned to cross join.")
None
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.optimizer

import org.scalatest.Matchers._

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
Expand All @@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.{BooleanType, IntegerType}

class PullOutPythonUDFInJoinConditionSuite extends PlanTest {

Expand All @@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
CheckCartesianProducts) :: Nil
}

val testRelationLeft = LocalRelation('a.int, 'b.int)
val testRelationRight = LocalRelation('c.int, 'd.int)
val attrA = 'a.int
val attrB = 'b.int
val attrC = 'c.int
val attrD = 'd.int

val testRelationLeft = LocalRelation(attrA, attrB)
val testRelationRight = LocalRelation(attrC, attrD)

// A join condition that spans both tables (it compares against `attrC` from the right
// relation), while the PythonUDF inside it only reads `attrA` from the left relation.
val evaluableJoinCond = {
  val leftOnlyUDF = PythonUDF(
    "evaluable",
    null,
    IntegerType,
    Seq(attrA),
    PythonEvalType.SQL_BATCHED_UDF,
    udfDeterministic = true)
  leftOnlyUDF === attrC
}

// Dummy python UDF for testing. Unable to execute.
val pythonUDF = PythonUDF("pythonUDF", null,
// This join condition is a PythonUDF which refers to attributes from 2 tables.
val unevaluableJoinCond = PythonUDF("unevaluable", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)

Expand All @@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
}
}

test("inner join condition with python udf only") {
val query = testRelationLeft.join(
test("inner join condition with python udf") {
val query1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF))
val expected = testRelationLeft.join(
condition = Some(unevaluableJoinCond))
val expected1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF).analyze
comparePlanWithCrossJoinEnable(query, expected)
condition = None).where(unevaluableJoinCond).analyze
comparePlanWithCrossJoinEnable(query1, expected1)

// evaluable PythonUDF will not be touched
val query2 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}

test("left semi join condition with python udf only") {
val query = testRelationLeft.join(
test("left semi join condition with python udf") {
val query1 = testRelationLeft.join(
testRelationRight,
joinType = LeftSemi,
condition = Some(pythonUDF))
val expected = testRelationLeft.join(
condition = Some(unevaluableJoinCond))
val expected1 = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF).select('a, 'b).analyze
comparePlanWithCrossJoinEnable(query, expected)
condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
comparePlanWithCrossJoinEnable(query1, expected1)

// evaluable PythonUDF will not be touched
val query2 = testRelationLeft.join(
testRelationRight,
joinType = LeftSemi,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}

test("python udf and common condition") {
test("unevaluable python udf and common condition") {
val query = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF && 'a.attr === 'c.attr))
condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, expected)
}

test("python udf or common condition") {
test("unevaluable python udf or common condition") {
val query = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some(pythonUDF || 'a.attr === 'c.attr))
condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze
condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
comparePlanWithCrossJoinEnable(query, expected)
}

test("pull out whole complex condition with multiple python udf") {
test("pull out whole complex condition with multiple unevaluable python udf") {
val pythonUDF1 = PythonUDF("pythonUDF1", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1
val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1

val query = testRelationLeft.join(
testRelationRight,
Expand All @@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
comparePlanWithCrossJoinEnable(query, expected)
}

test("partial pull out complex condition with multiple python udf") {
test("partial pull out complex condition with multiple unevaluable python udf") {
val pythonUDF1 = PythonUDF("pythonUDF1", null,
BooleanType,
Seq.empty,
Seq(attrA, attrC),
PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)
val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr
val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr

val query = testRelationLeft.join(
testRelationRight,
Expand All @@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
val expected = testRelationLeft.join(
testRelationRight,
joinType = Inner,
condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, expected)
}

test("pull out unevaluable python udf when it's mixed with evaluable one") {
  // The evaluable conjunct is expected to remain in the join condition, while the
  // unevaluable one is applied on top of the join via `where`.
  val mixedCondition = evaluableJoinCond && unevaluableJoinCond
  val joined = testRelationLeft.join(testRelationRight, Inner, condition = Some(mixedCondition))
  val expectedPlan = testRelationLeft
    .join(testRelationRight, Inner, condition = Some(evaluableJoinCond))
    .where(unevaluableJoinCond)
    .analyze
  comparePlans(Optimize.execute(joined.analyze), expectedPlan)
}

test("throw an exception for not support join type") {
for (joinType <- unsupportedJoinTypes) {
val thrownException = the [AnalysisException] thrownBy {
val e = intercept[AnalysisException] {
val query = testRelationLeft.join(
testRelationRight,
joinType,
condition = Some(pythonUDF))
condition = Some(unevaluableJoinCond))
Optimize.execute(query.analyze)
}
assert(thrownException.message.contentEquals(
assert(e.message.contentEquals(
s"Using PythonUDF in join condition of join type $joinType is not supported."))

val query2 = testRelationLeft.join(
testRelationRight,
joinType,
condition = Some(evaluableJoinCond))
comparePlans(Optimize.execute(query2), query2)
}
}
}