Skip to content

Commit 2df6ca8

Browse files
Davies Liu authored and committed
[SPARK-15327] [SQL] fix split expression in whole stage codegen
## What changes were proposed in this pull request? Right now, we will split the code for expressions into multiple functions when it exceeds 64k, which requires that the expressions are using a Row object, but this is not true for whole-stage codegen, which will fail to compile after being split. This PR will not split the code in whole-stage codegen. ## How was this patch tested? Added regression tests. Author: Davies Liu <[email protected]> Closes #13235 from davies/fix_nested_codegen.
1 parent 594484c commit 2df6ca8

File tree

4 files changed

+31
-0
lines changed

4 files changed

+31
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,10 @@ class CodegenContext {
560560
* @param expressions the codes to evaluate expressions.
561561
*/
562562
def splitExpressions(row: String, expressions: Seq[String]): String = {
563+
if (row == null) {
564+
// Cannot split these expressions because they are not created from a row object.
565+
return expressions.mkString("\n")
566+
}
563567
val blocks = new ArrayBuffer[String]()
564568
val blockBuilder = new StringBuilder()
565569
for (code <- expressions) {

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ trait CodegenSupport extends SparkPlan {
130130
}
131131
val evaluateInputs = evaluateVariables(outputVars)
132132
// generate the code to create a UnsafeRow
133+
ctx.INPUT_ROW = row
133134
ctx.currentVars = outputVars
134135
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
135136
val code = s"""

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,8 @@ case class TungstenAggregate(
599599

600600
// create grouping key
601601
ctx.currentVars = input
602+
// make sure that the generated code will not be splitted as multiple functions
603+
ctx.INPUT_ROW = null
602604
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
603605
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
604606
val vectorizedRowKeys = ctx.generateExpressions(

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,6 +2483,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
24832483
}
24842484
}
24852485

2486+
test("SPARK-15327: fail to compile generated code with complex data structure") {
2487+
withTempDir{ dir =>
2488+
val json =
2489+
"""
2490+
|{"h": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}],
2491+
|"b": [{"e": "test", "count": 1}]}}, "d": {"b": {"c": [{"e": "adfgd"}],
2492+
|"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}},
2493+
|"c": {"b": {"c": [{"e": "adfgd"}], "a": [{"count": 3}],
2494+
|"b": [{"e": "test", "count": 1}]}}, "a": {"b": {"c": [{"e": "adfgd"}],
2495+
|"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}},
2496+
|"e": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}],
2497+
|"b": [{"e": "test", "count": 1}]}}, "g": {"b": {"c": [{"e": "adfgd"}],
2498+
|"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}},
2499+
|"f": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}],
2500+
|"b": [{"e": "test", "count": 1}]}}, "b": {"b": {"c": [{"e": "adfgd"}],
2501+
|"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}}'
2502+
|
2503+
""".stripMargin
2504+
val rdd = sparkContext.parallelize(Array(json))
2505+
spark.read.json(rdd).write.mode("overwrite").parquet(dir.toString)
2506+
spark.read.parquet(dir.toString).collect()
2507+
}
2508+
}
2509+
24862510
test("SPARK-14986: Outer lateral view with empty generate expression") {
24872511
checkAnswer(
24882512
sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),

0 commit comments

Comments
 (0)