diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 687711a389ac..12e68882b0a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1170,24 +1170,29 @@ object TransposeWindow extends Rule[LogicalPlan] { object InferFiltersFromGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsPattern(GENERATE)) { - // This rule does not infer filters from foldable expressions to avoid constant filters - // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and - // then the idempotence will break. - case generate @ Generate(e, _, _, _, _, _) - if !e.deterministic || e.children.forall(_.foldable) || - e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate - case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) => - // Exclude child's constraints to guarantee idempotency - val inferredFilters = ExpressionSet( - Seq( - GreaterThan(Size(g.children.head), Literal(0)), - IsNotNull(g.children.head) - ) - ) -- generate.child.constraints - - if (inferredFilters.nonEmpty) { - generate.copy(child = Filter(inferredFilters.reduce(And), generate.child)) + assert(g.children.length == 1) + val input = g.children.head + // Generating extra predicates here has overheads/risks: + // - We may evaluate expensive input expressions multiple times. + // - We may infer too many constraints later. + // - The input expression may fail to be evaluated under ANSI mode. If we reorder the + // predicates and evaluate the input expression first, we may fail the query unexpectedly. + // To be safe, here we only generate extra predicates if the input is an attribute. 
+ // Note that foldable input is also excluded here, to avoid constant filters like + // 'size([1, 2, 3]) > 0'. Such filters never show up in the child's constraints, so + // idempotence would break. + if (input.isInstanceOf[Attribute]) { + // Exclude child's constraints to guarantee idempotency + val inferredFilters = ExpressionSet( + Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input)) + ) -- generate.child.constraints + + if (inferredFilters.nonEmpty) { + generate.copy(child = Filter(inferredFilters.reduce(And), generate.child)) + } else { + generate + } } else { generate } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3969ac1836f1..b00293039122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -47,6 +47,7 @@ object ConstantFolding extends Rule[LogicalPlan] { private def hasNoSideEffect(e: Expression): Boolean = e match { case _: Attribute => true case _: Literal => true + case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child) case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala index 800d37eaa0d4..61ab4f027ed2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -36,7 +34,7 @@ class InferFiltersFromGenerateSuite extends PlanTest { val testRelation = LocalRelation('a.array(StructType(Seq( StructField("x", IntegerType), StructField("y", IntegerType) - ))), 'c1.string, 'c2.string) + ))), 'c1.string, 'c2.string, 'c3.int) Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f => val generator = f('a) @@ -74,63 +72,53 @@ class InferFiltersFromGenerateSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) } - } - // setup rules to test inferFilters with ConstantFolding to make sure - // the Filter rule added in inferFilters is removed again when doing - // explode with CreateArray/CreateMap - object OptimizeInferAndConstantFold extends RuleExecutor[LogicalPlan] { - val batches = - Batch("AnalysisNodes", Once, - EliminateSubqueryAliases) :: - Batch("Infer Filters", Once, InferFiltersFromGenerate) :: - Batch("ConstantFolding after", FixedPoint(4), - ConstantFolding, - NullPropagation, - PruneFilters) :: Nil + val generatorWithFromJson = f(JsonToStructs( + ArrayType(new StructType().add("s", "string")), + Map.empty, + 'c1)) + test("SPARK-37392: Don't infer filters from " + generatorWithFromJson) { + val originalQuery = testRelation.generate(generatorWithFromJson).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + + val returnSchema = ArrayType(StructType(Seq( + StructField("x", IntegerType), + StructField("y", StringType) + ))) + val fakeUDF = ScalaUDF( + (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))), + returnSchema, 'c3 :: Nil, Nil) + val generatorWithUDF = f(fakeUDF) + test("SPARK-36715: 
Don't infer filters from " + generatorWithUDF) { + val originalQuery = testRelation.generate(generatorWithUDF).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } } Seq(Explode(_), PosExplode(_)).foreach { f => - val createArrayExplode = f(CreateArray(Seq('c1))) - test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) { - val originalQuery = testRelation.generate(createArrayExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - val createMapExplode = f(CreateMap(Seq('c1, 'c2))) - test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) { - val originalQuery = testRelation.generate(createMapExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - } - - Seq(Inline(_)).foreach { f => - val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) - test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) { - val originalQuery = testRelation.generate(createArrayStructExplode).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - comparePlans(optimized, originalQuery) - } - } + val createArrayExplode = f(CreateArray(Seq('c1))) + test("SPARK-33544: Don't infer filters from " + createArrayExplode) { + val originalQuery = testRelation.generate(createArrayExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + val createMapExplode = f(CreateMap(Seq('c1, 'c2))) + test("SPARK-33544: Don't infer filters from " + createMapExplode) { + val originalQuery = testRelation.generate(createMapExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + } - test("SPARK-36715: Don't infer filters from udf") { - Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f 
=> - val returnSchema = ArrayType(StructType(Seq( - StructField("x", IntegerType), - StructField("y", StringType) - ))) - val fakeUDF = ScalaUDF( - (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))), - returnSchema, Literal(8) :: Nil, - Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil) - val generator = f(fakeUDF) - val originalQuery = OneRowRelation().generate(generator).analyze - val optimized = OptimizeInferAndConstantFold.execute(originalQuery) - val correctAnswer = OneRowRelation() - .generate(generator) - .analyze - comparePlans(optimized, correctAnswer) + Seq(Inline(_)).foreach { f => + val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) + test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) { + val originalQuery = testRelation.generate(createArrayStructExplode).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) } } }