diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d4eb516534f19..02c929ca063cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -886,8 +886,15 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] left: LogicalPlan, right: LogicalPlan, conditionOpt: Option[Expression]): Set[Expression] = { - val baseConstraints = left.constraints.union(right.constraints) - .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + val conjunctivePredicates = conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet + val inferMorePredicates = conjunctivePredicates.flatMap { + case or @ Or(_, _) => + commonPredicatesInOr(or) + .filter(c => canEvaluate(c, left) || canEvaluate(c, right)) + case c => Seq(c) + } + + val baseConstraints = left.constraints.union(right.constraints).union(inferMorePredicates) baseConstraints.union(inferAdditionalConstraints(baseConstraints)) } @@ -903,6 +910,50 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] Filter(newPredicates.reduce(And), plan) } } + + /** + * for example, A join B, condition is + * ( + * ( + * x = y && (a IN (HZ,BJ) && (b >= 0)) && (b <= 20) + * ) + * || + * ( + * x && y && (a IN (SH,SZ) && (b >= 15)) && (b <= 30) + * ) + * ) + * ===> infer two predicates which can be push down + * 1) a IN (VA,TX,IA) || a IN (VA,TX,IA) + * 2) ((b >= 0) && (b <= 20)) || ((b >= 15) && (b <= 30)) + * then, 1) can be pushed to table A and 2) can be pushed to table B + */ + private def commonPredicatesInOr(e: Expression): Seq[Expression] = { + e match { + case Or(l, r) => + val left = commonPredicatesInOr(l) + val right = commonPredicatesInOr(r) + val commonPredicates = left.filter(p => right.exists(_.references == p.references)) + if (commonPredicates.nonEmpty) { + commonPredicates.map { e => + (Seq(e) ++ right.filter(_.references == e.references)).reduce(Or) + } + } else { + Seq() + } + case And(l, r) => + val p = splitConjunctivePredicates(l) ++ splitConjunctivePredicates(r) + processAndPredicates(p) + case _ => Seq(e) + } + } + + private def processAndPredicates(es: Seq[Expression]): Seq[Expression] = { + if (es.size > 1) { + val (first, rest) = es.splitAt(1) + val (hold, other) = rest.partition(_.references == first.head.references) + Seq((first ++ hold).reduce(And)) ++ processAndPredicates(other) + } else es + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a40ba2dc38b70..6ee822209838c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -40,6 +40,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) private def testConstraintsAfterJoin( x: LogicalPlan, @@ -263,4 +264,51 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val y = testRelation.subquery('y) testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } + + test("infer filters from inner join disjunctive condition") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation2.subquery('t2) + + val originalQuery = t1 + .join(t2, Inner, Some( + ('a === 'd && 'b === 1 && 'e === 3) || ('a === 'd && 'b === 5 && 'e === 7))) + .analyze + val correctAnswer = t1 + .where(('b === 1 || 'b === 5) && IsNotNull('a)) + .join(t2.where(IsNotNull('d) && ('e === 3 || 'e === 7)), Inner, + Some(('a === 'd) && (('b === 1 && 'e === 3) || ('b === 5 && 'e === 7)))) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("infer filters from outer join disjunctive condition") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation2.subquery('t2) + + // left outer join + val originalQuery = t1 + .join(t2, LeftOuter, Some( + ('a === 'd && 'b === 1 && 'e === 3) || ('a === 'd && 'b === 5 && 'e === 7))) + .analyze + val correctAnswer = t1 + .join(t2.where(IsNotNull('d) && ('e === 3 || 'e === 7)), LeftOuter, + Some(('a === 'd) && (('b === 1 && 'e === 3) || ('b === 5 && 'e === 7)))) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + + // right outer join + val originalQuery1 = t1 + .join(t2, RightOuter, Some( + ('a === 'd && 'b === 1 && 'e === 3) || ('a === 'd && 'b === 5 && 'e === 7))) + .analyze + val correctAnswer1 = t1 + .where(('b === 1 || 'b === 5) && IsNotNull('a)) + .join(t2, RightOuter, + Some(('a === 'd) && (('b === 1 && 'e === 3) || ('b === 5 && 'e === 7)))) + .analyze + val optimized1 = Optimize.execute(originalQuery1) + comparePlans(optimized1, correctAnswer1) + } }