Commit a8943d1

cloud-fan authored and root committed
[SPARK-37392][SQL] Fix the performance bug when inferring constraints for Generate
### What changes were proposed in this pull request?

This is a performance regression since Spark 3.1, caused by https://issues.apache.org/jira/browse/SPARK-32295

If you run the query in the JIRA ticket
```
Seq(
  (1, "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x", "x")
).toDF()
  .checkpoint() // or save and reload to truncate lineage
  .createOrReplaceTempView("sub")

session.sql("""
  SELECT
    *
  FROM
  (
    SELECT
      EXPLODE( ARRAY( * ) ) result
    FROM
    (
      SELECT
        _1 a, _2 b, _3 c, _4 d, _5 e, _6 f, _7 g, _8 h, _9 i, _10 j, _11 k, _12 l, _13 m, _14 n, _15 o, _16 p, _17 q, _18 r, _19 s, _20 t, _21 u
      FROM
        sub
    )
  )
  WHERE
    result != ''
""").show()
```
You will hit OOM. The reason is that:
1. We infer additional predicates with `Generate`. In this case, it's `size(array(cast(_1#21 as string), _2#22, _3#23, ...)) > 0`
2. Because of the cast, the `ConstantFolding` rule can't optimize this `size(array(...))`.
3. We end up with a plan containing this part
```
+- Project [_1#21 AS a#106, _2#22 AS b#107, _3#23 AS c#108, _4#24 AS d#109, _5#25 AS e#110, _6#26 AS f#111, _7#27 AS g#112, _8#28 AS h#113, _9#29 AS i#114, _10#30 AS j#115, _11#31 AS k#116, _12#32 AS l#117, _13#33 AS m#118, _14#34 AS n#119, _15#35 AS o#120, _16#36 AS p#121, _17#37 AS q#122, _18#38 AS r#123, _19#39 AS s#124, _20#40 AS t#125, _21#41 AS u#126]
   +- Filter (size(array(cast(_1#21 as string), _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41), true) > 0)
      +- LogicalRDD [_1#21, _2#22, _3#23, _4#24, _5#25, _6#26, _7#27, _8#28, _9#29, _10#30, _11#31, _12#32, _13#33, _14#34, _15#35, _16#36, _17#37, _18#38, _19#39, _20#40, _21#41]
```
When calculating the constraints of the `Project`, we generate around 2^20 expressions, due to this code
```
var allConstraints = child.constraints
projectList.foreach {
  case a @ Alias(l: Literal, _) =>
    allConstraints += EqualNullSafe(a.toAttribute, l)
  case a @ Alias(e, _) =>
    // For every alias in `projectList`, replace the reference in constraints by its attribute.
    allConstraints ++= allConstraints.map(_ transform {
      case expr: Expression if expr.semanticEquals(e) =>
        a.toAttribute
    })
    allConstraints += EqualNullSafe(e, a.toAttribute)
  case _ => // Don't change.
}
```

There are 3 issues here:
1. We may infer complicated predicates from `Generate`.
2. The `ConstantFolding` rule is too conservative. At least `Cast` has no side effect with ANSI-off.
3. When calculating constraints, we should have an upper bound to avoid generating too many expressions.

This fixes the first 2 issues, and leaves the third one for the future.

### Why are the changes needed?

Fix a performance issue.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests, and running the query in the JIRA ticket locally.

Closes apache#34823 from cloud-fan/perf.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 1fac7a9)
Signed-off-by: Wenchen Fan <[email protected]>
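To make the blow-up concrete, here is a minimal, Spark-free sketch of the constraint loop quoted in the commit message. The `Expr` types, `replace`, `projectConstraints`, and `constraintCount` are hypothetical stand-ins for Catalyst's expressions, `transform` with `semanticEquals`, and the `Project` constraint code; only the shape of the loop comes from the message. Under these assumptions, one predicate that references n aliased columns grows to 2^n + n constraints, because every alias doubles the number of variants of that predicate.

```scala
object ConstraintBlowupSketch {
  // Hypothetical toy expression tree, not Catalyst's Expression hierarchy.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  // Stands in for `size(array(...)) > 0`.
  case class SizeOfArrayGtZero(elems: Seq[Expr]) extends Expr
  case class EqNullSafe(left: Expr, right: Expr) extends Expr

  // Replace every occurrence of `from` inside `in` by `to`.
  def replace(in: Expr, from: Expr, to: Expr): Expr =
    if (in == from) to
    else in match {
      case SizeOfArrayGtZero(es) => SizeOfArrayGtZero(es.map(replace(_, from, to)))
      case EqNullSafe(l, r) => EqNullSafe(replace(l, from, to), replace(r, from, to))
      case other => other
    }

  // Same shape as the loop in the commit message: for each alias, add a
  // rewritten copy of every existing constraint, then add `e <=> alias`.
  def projectConstraints(child: Set[Expr], aliases: Seq[(Expr, Attr)]): Set[Expr] = {
    var all = child
    aliases.foreach { case (e, a) =>
      all ++= all.map(replace(_, e, a))
      all += EqNullSafe(e, a)
    }
    all
  }

  // Number of constraints for one predicate over n aliased columns.
  def constraintCount(n: Int): Int = {
    val cols: Seq[Expr] = (1 to n).map(i => Attr(s"_$i"))
    val aliases = cols.zipWithIndex.map { case (c, i) => (c, Attr(s"a$i")) }
    projectConstraints(Set[Expr](SizeOfArrayGtZero(cols)), aliases).size
  }

  def main(args: Array[String]): Unit =
    println(constraintCount(10)) // 2^10 + 10 = 1034
}
```

With the 21 columns from the JIRA query, the same arithmetic gives on the order of two million expressions in this toy model, which is in line with the "around 2^20" estimate above.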
1 parent afa1a16 commit a8943d1

3 files changed, +66 -72 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 22 additions & 17 deletions
@@ -1103,24 +1103,29 @@ object TransposeWindow extends Rule[LogicalPlan] {
 object InferFiltersFromGenerate extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(GENERATE)) {
-    // This rule does not infer filters from foldable expressions to avoid constant filters
-    // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
-    // then the idempotence will break.
-    case generate @ Generate(e, _, _, _, _, _)
-      if !e.deterministic || e.children.forall(_.foldable) ||
-        e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
-
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
-      // Exclude child's constraints to guarantee idempotency
-      val inferredFilters = ExpressionSet(
-        Seq(
-          GreaterThan(Size(g.children.head), Literal(0)),
-          IsNotNull(g.children.head)
-        )
-      ) -- generate.child.constraints
-
-      if (inferredFilters.nonEmpty) {
-        generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+      assert(g.children.length == 1)
+      val input = g.children.head
+      // Generating extra predicates here has overheads/risks:
+      //   - We may evaluate expensive input expressions multiple times.
+      //   - We may infer too many constraints later.
+      //   - The input expression may fail to be evaluated under ANSI mode. If we reorder the
+      //     predicates and evaluate the input expression first, we may fail the query unexpectedly.
+      // To be safe, here we only generate extra predicates if the input is an attribute.
+      // Note that, foldable input is also excluded here, to avoid constant filters like
+      // 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and then the
+      // idempotence will break.
+      if (input.isInstanceOf[Attribute]) {
+        // Exclude child's constraints to guarantee idempotency
+        val inferredFilters = ExpressionSet(
+          Seq(GreaterThan(Size(input), Literal(0)), IsNotNull(input))
+        ) -- generate.child.constraints
+
+        if (inferredFilters.nonEmpty) {
+          generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+        } else {
+          generate
+        }
       } else {
         generate
       }
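The guard added in this hunk can be illustrated with a small, Spark-free sketch. Everything here is a hypothetical simplification: `Expr`, `Call`, `SizeGtZero`, and `inferFilters` are invented for illustration and are not Catalyst classes. It only aims to show the two properties the comments describe: filters are inferred solely for attribute inputs, and subtracting the child's constraints makes a second application a no-op.

```scala
object InferFiltersSketch {
  // Hypothetical toy expression tree, not Catalyst's Expression hierarchy.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  // Any non-attribute generator input, e.g. array(...), from_json(...), a UDF.
  case class Call(fn: String, args: Seq[Expr]) extends Expr
  case class SizeGtZero(input: Expr) extends Expr
  case class IsNotNull(input: Expr) extends Expr

  // Infer `size(input) > 0` and `input IS NOT NULL` only for attribute inputs,
  // and drop whatever the child already guarantees.
  def inferFilters(input: Expr, childConstraints: Set[Expr]): Set[Expr] = input match {
    case a: Attr => Set[Expr](SizeGtZero(a), IsNotNull(a)) -- childConstraints
    case _ => Set.empty[Expr]
  }

  def main(args: Array[String]): Unit = {
    val first = inferFilters(Attr("a"), Set.empty)
    println(first.size)                              // two filters inferred
    println(inferFilters(Attr("a"), first).isEmpty)  // nothing new on a second pass
    println(inferFilters(Call("array", Seq(Attr("c1"))), Set.empty).isEmpty)
  }
}
```

In the real rule the inferred filters become a `Filter` node under the `Generate`, so they show up in `generate.child.constraints` on the next optimizer pass; the second call in `main` mimics that situation.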

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
   private def hasNoSideEffect(e: Expression): Boolean = e match {
     case _: Attribute => true
     case _: Literal => true
+    case c: Cast if !conf.ansiEnabled => hasNoSideEffect(c.child)
     case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
     case _ => false
   }
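As a rough illustration of why this one-line change matters, the sketch below models the side-effect check on a toy expression tree. The `Expr` types and `foldSize` are hypothetical, and the assumption that a `size(array(...))` call is folded to the element count when every element is side-effect free is a simplification of what `ConstantFolding` does, based on the commit message. With ANSI mode off, a `Cast` over an attribute no longer blocks the fold.

```scala
object SideEffectSketch {
  // Hypothetical toy expression tree, not Catalyst's Expression hierarchy.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Any) extends Expr
  case class Cast(child: Expr, to: String) extends Expr
  case class ArrayOf(children: Seq[Expr]) extends Expr
  case class Size(child: Expr) extends Expr

  // Mirrors the patched check: a cast is treated as side-effect free only
  // when ANSI mode is off and its child is side-effect free.
  def hasNoSideEffect(e: Expr, ansiEnabled: Boolean): Boolean = e match {
    case _: Attr => true
    case _: Lit => true
    case Cast(child, _) if !ansiEnabled => hasNoSideEffect(child, ansiEnabled)
    case _ => false
  }

  // Assumed simplification: size(array(e1, ..., en)) folds to the literal n
  // when no element can have a side effect; otherwise it is left untouched.
  def foldSize(e: Expr, ansiEnabled: Boolean): Expr = e match {
    case Size(ArrayOf(cs)) if cs.forall(hasNoSideEffect(_, ansiEnabled)) => Lit(cs.length)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val expr = Size(ArrayOf(Seq(Cast(Attr("_1"), "string"), Attr("_2"))))
    println(foldSize(expr, ansiEnabled = false)) // folded to a literal
    println(foldSize(expr, ansiEnabled = true))  // left as-is
  }
}
```

Under ANSI mode a cast may throw at runtime, so dropping it could change behavior; that is why the sketch, like the patch, keeps the expression unfolded in that case.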

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala

Lines changed: 43 additions & 55 deletions
@@ -18,10 +18,8 @@
 package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,7 +34,7 @@ class InferFiltersFromGenerateSuite extends PlanTest {
   val testRelation = LocalRelation('a.array(StructType(Seq(
     StructField("x", IntegerType),
     StructField("y", IntegerType)
-  ))), 'c1.string, 'c2.string)
+  ))), 'c1.string, 'c2.string, 'c3.int)

   Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
     val generator = f('a)
@@ -74,63 +72,53 @@ class InferFiltersFromGenerateSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
     }
-  }

-  // setup rules to test inferFilters with ConstantFolding to make sure
-  // the Filter rule added in inferFilters is removed again when doing
-  // explode with CreateArray/CreateMap
-  object OptimizeInferAndConstantFold extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("AnalysisNodes", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Infer Filters", Once, InferFiltersFromGenerate) ::
-      Batch("ConstantFolding after", FixedPoint(4),
-        ConstantFolding,
-        NullPropagation,
-        PruneFilters) :: Nil
+    val generatorWithFromJson = f(JsonToStructs(
+      ArrayType(new StructType().add("s", "string")),
+      Map.empty,
+      'c1))
+    test("SPARK-37392: Don't infer filters from " + generatorWithFromJson) {
+      val originalQuery = testRelation.generate(generatorWithFromJson).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+
+    val returnSchema = ArrayType(StructType(Seq(
+      StructField("x", IntegerType),
+      StructField("y", StringType)
+    )))
+    val fakeUDF = ScalaUDF(
+      (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
+      returnSchema, 'c3 :: Nil, Nil)
+    val generatorWithUDF = f(fakeUDF)
+    test("SPARK-36715: Don't infer filters from " + generatorWithUDF) {
+      val originalQuery = testRelation.generate(generatorWithUDF).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
   }

   Seq(Explode(_), PosExplode(_)).foreach { f =>
-    val createArrayExplode = f(CreateArray(Seq('c1)))
-    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) {
-      val originalQuery = testRelation.generate(createArrayExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-    val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
-    test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) {
-      val originalQuery = testRelation.generate(createMapExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-  }
-
-  Seq(Inline(_)).foreach { f =>
-    val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
-    test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) {
-      val originalQuery = testRelation.generate(createArrayStructExplode).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      comparePlans(optimized, originalQuery)
-    }
-  }
+    val createArrayExplode = f(CreateArray(Seq('c1)))
+    test("SPARK-33544: Don't infer filters from " + createArrayExplode) {
+      val originalQuery = testRelation.generate(createArrayExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+    val createMapExplode = f(CreateMap(Seq('c1, 'c2)))
+    test("SPARK-33544: Don't infer filters from " + createMapExplode) {
+      val originalQuery = testRelation.generate(createMapExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }

-  test("SPARK-36715: Don't infer filters from udf") {
-    Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
-      val returnSchema = ArrayType(StructType(Seq(
-        StructField("x", IntegerType),
-        StructField("y", StringType)
-      )))
-      val fakeUDF = ScalaUDF(
-        (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
-        returnSchema, Literal(8) :: Nil,
-        Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
-      val generator = f(fakeUDF)
-      val originalQuery = OneRowRelation().generate(generator).analyze
-      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
-      val correctAnswer = OneRowRelation()
-        .generate(generator)
-        .analyze
-      comparePlans(optimized, correctAnswer)
+  Seq(Inline(_)).foreach { f =>
+    val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1)))))
+    test("SPARK-33544: Don't infer filters from " + createArrayStructExplode) {
+      val originalQuery = testRelation.generate(createArrayStructExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
     }
   }
 }
