From 88c6d0b4bc9eaea2a00a26149eba75a201bb0458 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 21 Feb 2023 08:11:33 +0800 Subject: [PATCH] SPARK-42500: ConstantPropagation support more case --- .../sql/catalyst/optimizer/expressions.scala | 16 ++++++--- .../optimizer/ConstantPropagationSuite.scala | 33 +++++++++++++++++-- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d756a2dcb744..0ab6106c14c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -200,14 +200,20 @@ object ConstantPropagation extends Rule[LogicalPlan] { private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { + val allConstantsMap = AttributeMap(equalityPredicates.map(_._1)) + val allPredicates = equalityPredicates.map(_._2).toSet + def replaceConstants0( + expression: Expression, constantsMap: AttributeMap[Literal]) = expression transform { case a: AttributeReference => constantsMap.getOrElse(a, a) } condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) + case b: BinaryComparison => + if (!allPredicates.contains(b)) { + replaceConstants0(b, allConstantsMap) + } else { + val excludedEqualityPredicates = equalityPredicates.filterNot(_._2.semanticEquals(b)) + replaceConstants0(b, AttributeMap(excludedEqualityPredicates.map(_._1))) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index f5f1455f94611..f4787a55d89ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -159,8 +160,9 @@ class ConstantPropagationSuite extends PlanTest { columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + .select(columnA, columnB) + .where(FalseLiteral) + .select(columnA).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +188,31 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("SPARK-42500: ConstantPropagation supports more cases") { + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze), + testRelation.where(columnA === 1 && columnB > 3).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze), + testRelation.where(FalseLiteral).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze), + testRelation.where(FalseLiteral).analyze) + + comparePlans( + Optimize.execute( + testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze), + testRelation.where(columnA === 1 && columnB === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze), + testRelation.where(columnA === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze), + testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze) + } }