diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a74b288cb22ce..69b8aeef99f04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -230,8 +230,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy override def simpleString: String = statePrefix + super.simpleString - override def treeChildren: Seq[PlanType] = { - val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e}) - children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType]) + /** + * All the subqueries of current plan. + */ + def subqueries: Seq[PlanType] = { + expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]}) } + + override def innerChildren: Seq[PlanType] = subqueries } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index e46ce1cee7c6c..c463d2b6246f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -450,9 +450,52 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * All the nodes that will be used to generate tree string. + * + * For example: + * + * WholeStageCodegen + * +-- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string + * like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort */ protected def treeChildren: Seq[BaseType] = children + /** + * All the nodes that are parts of this node. + * + * For example: + * + * WholeStageCodegen + * +- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree + * string like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort + */ + protected def innerChildren: Seq[BaseType] = Nil + /** * Appends the string represent of this node and its children to the given StringBuilder. * @@ -475,6 +518,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(simpleString) builder.append("\n") + if (innerChildren.nonEmpty) { + innerChildren.init.foreach(_.generateTreeString( + depth + 2, lastChildren :+ false :+ false, builder)) + innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + } + if (treeChildren.nonEmpty) { treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 4dd9928244197..9019e5dfd66c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -36,11 +36,8 @@ class SparkPlanInfo( private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - val children = plan match { - case WholeStageCodegen(child, _) => child :: Nil - case InputAdapter(child) => child :: Nil - case plan => plan.children - } + + val children = plan.children ++ plan.subqueries val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index cb68ca6ada366..6d231bf74a0e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext @@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue /** @@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan { * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes * an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { +case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { child.execute() } @@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { child.doExecuteBroadcast() } - override def supportCodegen: Boolean = false - override def upstreams(): Seq[RDD[InternalRow]] = { child.execute() :: Nil } @@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { } override def simpleString: String = "INPUT" + + override def treeChildren: Seq[SparkPlan] = Nil } /** @@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, * used to generated code for BoundReference. */ -case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) - extends SparkPlan with CodegenSupport { - - override def supportCodegen: Boolean = false - - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering +case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport { - override def doPrepare(): Unit = { - plan.prepare() - } + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { @@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } /** Codegened pipeline for: - * ${toCommentSafeString(plan.treeString.trim)} + * ${toCommentSafeString(child.treeString.trim)} */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { @@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // println(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) - val rdds = plan.upstreams() + val rdds = child.asInstanceOf[CodegenSupport].upstreams() assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitions { iter => @@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - private[sql] override def resetMetrics(): Unit = { - plan.foreach(_.resetMetrics()) + override def innerChildren: Seq[SparkPlan] = { + child :: Nil } - override def generateTreeString( - depth: Int, - lastChildren: Seq[Boolean], - builder: StringBuilder): StringBuilder = { - if (depth > 0) { - lastChildren.init.foreach { isLast => - val prefixFragment = if (isLast) " " else ": " - builder.append(prefixFragment) - } - - val branch = if (lastChildren.last) "+- " else ":- " - builder.append(branch) - } - - builder.append(simpleString) - builder.append("\n") - - plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) - if (children.nonEmpty) { - children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) - } + private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match { + case InputAdapter(c) => c :: Nil + case other => other.children.flatMap(collectInputs) + } - builder + override def treeChildren: Seq[SparkPlan] = { + collectInputs(child) } override def simpleString: String = "WholeStageCodegen" @@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case _ => false } + /** + * Inserts a InputAdapter on top of those that do not support codegen. + */ + private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { + case j @ SortMergeJoin(_, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = InputAdapter(insertWholeStageCodegen(left)), + right = InputAdapter(insertWholeStageCodegen(right))) + case p if !supportCodegen(p) => + // collapse them recursively + InputAdapter(insertWholeStageCodegen(p)) + case p => + p.withNewChildren(p.children.map(insertInputAdapter)) + } + + /** + * Inserts a WholeStageCodegen on top of those that support codegen. + */ + private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match { + case plan: CodegenSupport if supportCodegen(plan) => + WholeStageCodegen(insertInputAdapter(plan)) + case other => + other.withNewChildren(other.children.map(insertWholeStageCodegen)) + } + def apply(plan: SparkPlan): SparkPlan = { if (sqlContext.conf.wholeStageEnabled) { - plan.transform { - case plan: CodegenSupport if supportCodegen(plan) => - var inputs = ArrayBuffer[SparkPlan]() - val combined = plan.transform { - // The build side can't be compiled together - case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) => - b.copy(left = apply(left)) - case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => - b.copy(right = apply(right)) - case j @ SortMergeJoin(_, _, _, left, right) => - // The children of SortMergeJoin should do codegen separately. - j.copy(left = apply(left), right = apply(right)) - case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) - }.asInstanceOf[CodegenSupport] - WholeStageCodegen(combined, inputs) - } + insertWholeStageCodegen(plan) } else { plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 95d033bc57548..fed88b8c0a117 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ package object debug { } } - private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { @@ -86,10 +87,11 @@ package object debug { /** * A collection of metrics for each column of output. * @param elementTypes the actual runtime types for the output. Useful when there are bugs - * causing the wrong data to be projected. + * causing the wrong data to be projected. */ case class ColumnMetrics( - elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) val numColumns: Int = child.output.size @@ -98,7 +100,7 @@ package object debug { def dumpStats(): Unit = { logDebug(s"== ${child.simpleString} ==") logDebug(s"Tuples output: ${tupleCount.value}") - child.output.zip(columnStats).foreach { case(attr, metric) => + child.output.zip(columnStats).foreach { case (attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } @@ -108,6 +110,7 @@ package object debug { child.execute().mapPartitions { iter => new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext + def next(): InternalRow = { val currentRow = iter.next() tupleCount += 1 @@ -124,5 +127,17 @@ package object debug { } } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + consume(ctx, input) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 4eb248569b281..12e586ada5976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen} +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph { edges: mutable.ArrayBuffer[SparkPlanGraphEdge], parent: SparkPlanGraphNode, subgraph: SparkPlanGraphCluster): Unit = { - if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) { - val cluster = new SparkPlanGraphCluster( - nodeIdGenerator.getAndIncrement(), - planInfo.nodeName, - planInfo.simpleString, - mutable.ArrayBuffer[SparkPlanGraphNode]()) - nodes += cluster - buildSparkPlanGraphNode( - planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) - } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) { - buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) - } else { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) - } - val node = new SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, - planInfo.simpleString, planInfo.metadata, metrics) - if (subgraph == null) { - nodes += node - } else { - subgraph.nodes += node - } - - if (parent != null) { - edges += SparkPlanGraphEdge(node.id, parent.id) - } - planInfo.children.foreach( - buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) + planInfo.nodeName match { + case "WholeStageCodegen" => + val cluster = new SparkPlanGraphCluster( + nodeIdGenerator.getAndIncrement(), + planInfo.nodeName, + planInfo.simpleString, + mutable.ArrayBuffer[SparkPlanGraphNode]()) + nodes += cluster + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + case "InputAdapter" => + buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) + case "Subquery" if subgraph != null => + // Subquery should not be included in WholeStageCodegen + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null) + case _ => + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = new SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + if (subgraph == null) { + nodes += node + } else { + subgraph.nodes += node + } + + if (parent != null) { + edges += SparkPlanGraphEdge(node.id, parent.id) + } + planInfo.children.foreach( + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index de371d85d9fd7..e00c762c67054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined) - - checkThatPlansAgree( - sqlContext.range(100), - (p: SparkPlan) => - WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()), - (p: SparkPlan) => Filter('a == 1, p), - sortAnswers = false - ) + assert(df.collect() === Array(Row(2))) } test("Aggregate should be included in WholeStageCodegen") { @@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } @@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } @@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) assert(df.queryExecution.executedPlan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } @@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index f8a9a95c873ad..6a1f35105ace9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -189,8 +189,8 @@ class JDBCSuite extends SparkFunSuite // the plan only has PhysicalRDD to scan JDBCRelation. assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] - assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) - assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) + assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) df } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) @@ -227,7 +227,7 @@ class JDBCSuite extends SparkFunSuite // cannot compile given predicates. assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] - assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter]) + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter]) df } assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 15a95623d1e5c..e7d2b5ad96821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { val metric = qe.executedPlan match { - case w: WholeStageCodegen => w.plan.longMetric("numOutputRows") + case w: WholeStageCodegen => w.child.longMetric("numOutputRows") case other => other.longMetric("numOutputRows") } metrics += metric.value.value