Skip to content

Commit 1480ae3

Browse files
committed
address review comments
1 parent ace2944 commit 1480ae3

File tree

2 files changed

+41
-36
lines changed

2 files changed

+41
-36
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
2626
import org.apache.spark.sql._
2727
import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.catalyst.expressions.Attribute
29-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodegenContext, ExprCode}
29+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
3030
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
3131
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
3232
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
@@ -68,8 +68,13 @@ package object debug {
6868
output
6969
}
7070

71-
private def codegenSubtreeSourceSeq(plan: SparkPlan):
72-
Seq[(WholeStageCodegenExec, CodeAndComment)] = {
71+
/**
72+
* Get WholeStageCodegenExec subtrees and the codegen in a query plan
73+
*
74+
* @param plan the query plan for codegen
75+
* @return Sequence of WholeStageCodegen subtrees and corresponding codegen
76+
*/
77+
def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = {
7378
val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
7479
plan transform {
7580
case s: WholeStageCodegenExec =>
@@ -79,31 +84,10 @@ package object debug {
7984
}
8085
codegenSubtrees.toSeq.map { subtree =>
8186
val (_, source) = subtree.doCodeGen()
82-
(subtree, source)
87+
(subtree.toString, CodeFormatter.format(source))
8388
}
8489
}
8590

86-
/**
87-
* Get WholeStageCodegenExec subtrees and the codegen in a query plan
88-
*
89-
* @param plan the query plan for codegen
90-
* @return Sequence of WholeStageCodegen subtrees and corresponding codegen
91-
*/
92-
def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = {
93-
codegenSubtreeSourceSeq(plan).map(s => (s._1.toString, CodeFormatter.format(s._2)))
94-
}
95-
96-
97-
/**
98-
* Get WholeStageCodegenExec subtrees' CodeAndComment in a query plan
99-
*
100-
* @param plan the query plan for CodeAndComment
101-
* @return Sequence of WholeStageCodegen subtrees' `CodeAndComment`
102-
*/
103-
def codegenCodeAndCommentSeq(plan: SparkPlan): Seq[CodeAndComment] = {
104-
codegenSubtreeSourceSeq(plan).map(_._2)
105-
}
106-
10791
/**
10892
* Augments [[Dataset]]s with debug methods.
10993
*/

sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.BeforeAndAfterAll
2121

22-
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
22+
import org.apache.spark.internal.Logging
23+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
2324
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2425
import org.apache.spark.sql.catalyst.util.resourceToString
25-
import org.apache.spark.sql.execution.debug
26+
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
2627
import org.apache.spark.sql.internal.SQLConf
2728
import org.apache.spark.sql.test.SharedSQLContext
2829
import org.apache.spark.util.Utils
@@ -31,7 +32,7 @@ import org.apache.spark.util.Utils
3132
* This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized
3233
* without hitting the max iteration threshold.
3334
*/
34-
class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll {
35+
class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll with Logging {
3536

3637
// When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting
3738
// the max iteration of analyzer/optimizer batches.
@@ -350,16 +351,38 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte
350351
"q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90",
351352
"q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
352353

354+
/**
 * Checks that the code generated for every [[WholeStageCodegenExec]] subtree found in
 * `plan` can be compiled.
 *
 * When compilation of a subtree fails, the failure is logged at error level, the subtree
 * together with its formatted generated code is logged at debug level, and the original
 * exception is rethrown unchanged.
 *
 * @param plan the physical plan whose whole-stage codegen subtrees are checked
 * @throws Exception the exception raised by `CodeGenerator.compile` for the first subtree
 *                   that fails to compile
 */
private def checkGeneratedCode(plan: SparkPlan): Unit = {
  // Gather the codegen subtrees into a set, so subtrees that compare equal are checked once.
  val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
  plan foreach {
    case s: WholeStageCodegenExec =>
      codegenSubtrees += s
    case _ => // any other plan node is of no interest here
  }
  // `foreach` rather than `map`: compilation is performed only for its side effect, so
  // there is no point in building a result collection that is immediately discarded.
  codegenSubtrees.foreach { subtree =>
    val code = subtree.doCodeGen()._2
    try {
      // Just check the generated code can be properly compiled
      CodeGenerator.compile(code)
    } catch {
      case e: Exception =>
        logError(s"failed to compile: $e", e)
        // The subtree and its formatted source are only emitted at debug level.
        val msg =
          s"Subtree:\n$subtree\n" +
          s"Generated code:\n${CodeFormatter.format(code)}\n"
        logDebug(msg)
        throw e
    }
  }
}
377+
353378
tpcdsQueries.foreach { name =>
354379
val queryString = resourceToString(s"tpcds/$name.sql",
355380
classLoader = Thread.currentThread().getContextClassLoader)
356381
test(name) {
357382
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
358383
// check the plans can be properly generated
359-
val p = sql(queryString).queryExecution.executedPlan
360-
// check the generated code can be properly compiled
361-
val codes = debug.codegenCodeAndCommentSeq(p)
362-
codes.map(c => CodeGenerator.compile(c))
384+
val plan = sql(queryString).queryExecution.executedPlan
385+
checkGeneratedCode(plan)
363386
}
364387
}
365388
}
@@ -374,10 +397,8 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte
374397
classLoader = Thread.currentThread().getContextClassLoader)
375398
test(s"modified-$name") {
376399
// check the plans can be properly generated
377-
val p = sql(queryString).queryExecution.executedPlan
378-
// check the generated code can be properly compiled
379-
val codes = debug.codegenCodeAndCommentSeq(p)
380-
codes.map(c => CodeGenerator.compile(c))
400+
val plan = sql(queryString).queryExecution.executedPlan
401+
checkGeneratedCode(plan)
381402
}
382403
}
383404
}

0 commit comments

Comments
 (0)