Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1170,24 +1170,29 @@ object TransposeWindow extends Rule[LogicalPlan] {
object InferFiltersFromGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(GENERATE)) {
// This rule does not infer filters from foldable expressions to avoid constant filters
// like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
// then the idempotence will break.
case generate @ Generate(e, _, _, _, _, _)
if !e.deterministic || e.children.forall(_.foldable) ||
e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate

case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
// Exclude child's constraints to guarantee idempotency
val inferredFilters = ExpressionSet(
Seq(
GreaterThan(Size(g.children.head), Literal(0)),
IsNotNull(g.children.head)
)
) -- generate.child.constraints

if (inferredFilters.nonEmpty) {
generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
assert(g.children.length == 1)
val input = g.children.head
// Generating extra predicates here has overheads/risks:
// - We may evaluate expensive input expressions multiple times.
// - We may infer too many constraints later.
// - The input expression may fail to be evaluated under ANSI mode. If we reorder the
// predicates and evaluate the input expression first, we may fail the query unexpectedly.
// To be safe, here we only generate extra predicates if the input is an attribute.
// Note that, foldable input is also excluded here, to avoid constant filters like
// 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and then the
// idempotence will break.
if (input.isInstanceOf[Attribute]) {
Copy link
Contributor Author

@cloud-fan cloud-fan Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's almost useless to generate a predicate from CreateArray/CreateMap: Size(CreateArray(...)) > 0 is always true unless you create an empty array.

// Exclude child's constraints to guarantee idempotency
val inferredFilters = ExpressionSet(
Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
) -- generate.child.constraints

if (inferredFilters.nonEmpty) {
generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
} else {
generate
}
} else {
generate
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
private def hasNoSideEffect(e: Expression): Boolean = e match {
case _: Attribute => true
case _: Literal => true
case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either this change or the change in InferFiltersFromGenerate can fix the perf issue. But I keep both fixes to be super safe.

case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
Expand All @@ -36,7 +34,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val testRelation = LocalRelation('a.array(StructType(Seq(
StructField("x", IntegerType),
StructField("y", IntegerType)
))), 'c1.string, 'c2.string)
))), 'c1.string, 'c2.string, 'c3.int)

Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
val generator = f('a)
Expand Down Expand Up @@ -74,63 +72,53 @@ class InferFiltersFromGenerateSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

// setup rules to test inferFilters with ConstantFolding to make sure
// the Filter rule added in inferFilters is removed again when doing
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this now, because we no longer infer the filters in the first place, rather than inferring them and then removing them later.

// explode with CreateArray/CreateMap
object OptimizeInferAndConstantFold extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Infer Filters", Once, InferFiltersFromGenerate) ::
Batch("ConstantFolding after", FixedPoint(4),
ConstantFolding,
NullPropagation,
PruneFilters) :: Nil
val generatorWithFromJson = f(JsonToStructs(
ArrayType(new StructType().add("s", "string")),
Map.empty,
'c1))
test("SPARK-37392: Don't infer filters from " + generatorWithFromJson) {
val originalQuery = testRelation.generate(generatorWithFromJson).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}

val returnSchema = ArrayType(StructType(Seq(
StructField("x", IntegerType),
StructField("y", StringType)
)))
val fakeUDF = ScalaUDF(
(i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
returnSchema, 'c3 :: Nil, Nil)
val generatorWithUDF = f(fakeUDF)
test("SPARK-36715: Don't infer filters from " + generatorWithUDF) {
val originalQuery = testRelation.generate(generatorWithUDF).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

Seq(Explode(_), PosExplode(_)).foreach { f =>
val createArrayExplode = f(CreateArray(Seq('c1)))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
val originalQuery = testRelation.generate(createArrayExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
val originalQuery = testRelation.generate(createMapExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

Seq(Inline(_)).foreach { f =>
val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
val originalQuery = testRelation.generate(createArrayStructExplode).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}
val createArrayExplode = f(CreateArray(Seq('c1)))
test("SPARK-33544: Don't infer filters from " + createArrayExplode) {
val originalQuery = testRelation.generate(createArrayExplode).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
test("SPARK-33544: Don't infer filters from " + createMapExplode) {
val originalQuery = testRelation.generate(createMapExplode).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}

test("SPARK-36715: Don't infer filters from udf") {
Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
val returnSchema = ArrayType(StructType(Seq(
StructField("x", IntegerType),
StructField("y", StringType)
)))
val fakeUDF = ScalaUDF(
(i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
returnSchema, Literal(8) :: Nil,
Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
val generator = f(fakeUDF)
val originalQuery = OneRowRelation().generate(generator).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
val correctAnswer = OneRowRelation()
.generate(generator)
.analyze
comparePlans(optimized, correctAnswer)
Seq(Inline(_)).foreach { f =>
val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) {
val originalQuery = testRelation.generate(createArrayStructExplode).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
}
}