Infer additional filter constraints from equalities that involve a cast.

ConstraintHelper.inferAdditionalConstraints:
  * IsNotNull constraints are filtered out before inference; per the in-code
    note they are expected to come from `constructIsNotNullConstraints`.
  * For an equality between `Cast(attribute)` and a plain attribute (either
    operand order), occurrences of the plain attribute in the remaining
    predicates are rewritten to the cast expression. The opposite rewrite
    (cast expression -> plain attribute) is not performed.
  * replaceConstraints now accepts any Expression as `destination` so that a
    Cast can be substituted in.

InferFiltersFromConstraintsSuite:
  * testConstraintsAfterJoin takes an optional `condition`; its default is the
    previously hard-coded `x.a === y.a`, so existing callers are unaffected.
  * Two tests cover an int column joined to a long column through a cast, with
    the literal filter on the long side and on the int side respectively.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 1355003358b9f..4c4ec000d0930 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -62,11 +62,17 @@ trait ConstraintHelper {
    */
   def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
     var inferredConstraints = Set.empty[Expression]
-    constraints.foreach {
+    // IsNotNull should be constructed by `constructIsNotNullConstraints`.
+    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
+    predicates.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
-        val candidateConstraints = constraints - eq
+        val candidateConstraints = predicates - eq
         inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
         inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+      case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
+        inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
+      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
+        inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
       case _ => // No inference
     }
     inferredConstraints -- constraints
@@ -75,7 +81,7 @@ trait ConstraintHelper {
   private def replaceConstraints(
       constraints: Set[Expression],
       source: Expression,
-      destination: Attribute): Set[Expression] = constraints.map(_ transform {
+      destination: Expression): Set[Expression] = constraints.map(_ transform {
     case e: Expression if e.semanticEquals(source) => destination
   })
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 974bc781d36ab..79bd573f1d84a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
@@ -46,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       y: LogicalPlan,
       expectedLeft: LogicalPlan,
       expectedRight: LogicalPlan,
-      joinType: JoinType) = {
-    val condition = Some("x.a".attr === "y.a".attr)
+      joinType: JoinType,
+      condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
     val originalQuery = x.join(y, joinType, condition).analyze
     val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
     val optimized = Optimize.execute(originalQuery)
@@ -263,4 +264,56 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val y = testRelation.subquery('y)
     testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
   }
+
+  test("Constraints should be inferred from cast equality constraint(filter higher data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.subquery('left)
+    val originalRight = testRelation2.where('b === 1L).subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
+    val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        testRelation1.where(IsNotNull('a)).subquery('left),
+        right,
+        Inner,
+        condition)
+    }
+  }
+
+  test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.where('a === 1).subquery('left)
+    val originalRight = testRelation2.subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
+    val right = testRelation2.where(IsNotNull('b)).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        left,
+        testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
+        Inner,
+        condition)
+    }
+  }
 }