@@ -23,8 +23,9 @@ import scala.collection.mutable.ArrayBuffer
2323import org .apache .spark .sql .AnalysisException
2424import org .apache .spark .sql .catalyst .expressions .{Expression , PlanExpression }
2525import org .apache .spark .sql .catalyst .plans .QueryPlan
26+ import org .apache .spark .sql .execution .adaptive .{AdaptiveSparkPlanExec , AdaptiveSparkPlanHelper , QueryStageExec }
2627
27- object ExplainUtils {
28+ object ExplainUtils extends AdaptiveSparkPlanHelper {
2829 /**
2930 * Given a input physical plan, performs the following tasks.
3031 * 1. Computes the operator id for current operator and records it in the operaror
@@ -144,15 +145,26 @@ object ExplainUtils {
144145 case p : WholeStageCodegenExec =>
145146 case p : InputAdapter =>
146147 case other : QueryPlan [_] =>
147- if (! other.getTagValue(QueryPlan .OP_ID_TAG ).isDefined) {
148+
149+ def setOpId (): Unit = if (other.getTagValue(QueryPlan .OP_ID_TAG ).isEmpty) {
148150 currentOperationID += 1
149151 other.setTagValue(QueryPlan .OP_ID_TAG , currentOperationID)
150152 operatorIDs += ((currentOperationID, other))
151153 }
152- other.innerChildren.foreach { plan =>
153- currentOperationID = generateOperatorIDs(plan,
154- currentOperationID,
155- operatorIDs)
154+
155+ other match {
156+ case p : AdaptiveSparkPlanExec =>
157+ currentOperationID =
158+ generateOperatorIDs(p.executedPlan, currentOperationID, operatorIDs)
159+ setOpId()
160+ case p : QueryStageExec =>
161+ currentOperationID = generateOperatorIDs(p.plan, currentOperationID, operatorIDs)
162+ setOpId()
163+ case _ =>
164+ setOpId()
165+ other.innerChildren.foldLeft(currentOperationID) {
166+ (curId, plan) => generateOperatorIDs(plan, curId, operatorIDs)
167+ }
156168 }
157169 }
158170 currentOperationID
@@ -163,21 +175,25 @@ object ExplainUtils {
163175 * whole stage code gen id in the plan via setting a tag.
164176 */
165177 private def generateWholeStageCodegenIds (plan : QueryPlan [_]): Unit = {
178+ var currentCodegenId = - 1
179+
180+ def setCodegenId (p : QueryPlan [_], children : Seq [QueryPlan [_]]): Unit = {
181+ if (currentCodegenId != - 1 ) {
182+ p.setTagValue(QueryPlan .CODEGEN_ID_TAG , currentCodegenId)
183+ }
184+ children.foreach(generateWholeStageCodegenIds)
185+ }
186+
166187 // Skip the subqueries as they are not printed as part of main query block.
167188 if (plan.isInstanceOf [BaseSubqueryExec ]) {
168189 return
169190 }
170- var currentCodegenId = - 1
171191 plan.foreach {
172192 case p : WholeStageCodegenExec => currentCodegenId = p.codegenStageId
173193 case _ : InputAdapter => currentCodegenId = - 1
174- case other : QueryPlan [_] =>
175- if (currentCodegenId != - 1 ) {
176- other.setTagValue(QueryPlan .CODEGEN_ID_TAG , currentCodegenId)
177- }
178- other.innerChildren.foreach { plan =>
179- generateWholeStageCodegenIds(plan)
180- }
194+ case p : AdaptiveSparkPlanExec => setCodegenId(p, Seq (p.executedPlan))
195+ case p : QueryStageExec => setCodegenId(p, Seq (p.plan))
196+ case other : QueryPlan [_] => setCodegenId(other, other.innerChildren)
181197 }
182198 }
183199
@@ -232,13 +248,16 @@ object ExplainUtils {
232248 }
233249
234250 def removeTags (plan : QueryPlan [_]): Unit = {
251+ def remove (p : QueryPlan [_], children : Seq [QueryPlan [_]]): Unit = {
252+ p.unsetTagValue(QueryPlan .OP_ID_TAG )
253+ p.unsetTagValue(QueryPlan .CODEGEN_ID_TAG )
254+ children.foreach(removeTags)
255+ }
256+
235257 plan foreach {
236- case plan : QueryPlan [_] =>
237- plan.unsetTagValue(QueryPlan .OP_ID_TAG )
238- plan.unsetTagValue(QueryPlan .CODEGEN_ID_TAG )
239- plan.innerChildren.foreach { p =>
240- removeTags(p)
241- }
258+ case p : AdaptiveSparkPlanExec => remove(p, Seq (p.executedPlan))
259+ case p : QueryStageExec => remove(p, Seq (p.plan))
260+ case plan : QueryPlan [_] => remove(plan, plan.innerChildren)
242261 }
243262 }
244263}
0 commit comments