Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy

override def simpleString: String = statePrefix + super.simpleString

override def treeChildren: Seq[PlanType] = {
val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
/**
* All the subqueries of current plan.
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is treeChildren? Its doc in TreeNode does not really show the difference between it and children.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the doc string of treeChildren.

def subqueries: Seq[PlanType] = {
expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
}

override def innerChildren: Seq[PlanType] = subqueries
}
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,52 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

/**
* All the nodes that will be used to generate tree string.
*
* For example:
*
* WholeStageCodegen
* +-- SortMergeJoin
* |-- InputAdapter
* | +-- Sort
* +-- InputAdapter
* +-- Sort
*
* the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
* like this:
*
* WholeStageCodegen
* : +- SortMergeJoin
* : :- INPUT
* : :- INPUT
* :- Sort
* :- Sort
*/
protected def treeChildren: Seq[BaseType] = children

/**
* All the nodes that are parts of this node.
*
* For example:
*
* WholeStageCodegen
* +- SortMergeJoin
* |-- InputAdapter
* | +-- Sort
* +-- InputAdapter
* +-- Sort
*
* the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
* string like this:
*
* WholeStageCodegen
* : +- SortMergeJoin
* : :- INPUT
* : :- INPUT
* :- Sort
* :- Sort
*/
protected def innerChildren: Seq[BaseType] = Nil

/**
* Appends the string represent of this node and its children to the given StringBuilder.
*
Expand All @@ -475,6 +518,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
builder.append(simpleString)
builder.append("\n")

if (innerChildren.nonEmpty) {
innerChildren.init.foreach(_.generateTreeString(
depth + 2, lastChildren :+ false :+ false, builder))
innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
}

if (treeChildren.nonEmpty) {
treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,8 @@ class SparkPlanInfo(
private[sql] object SparkPlanInfo {

def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
val children = plan match {
case WholeStageCodegen(child, _) => child :: Nil
case InputAdapter(child) => child :: Nil
case plan => plan.children
}

val children = plan.children ++ plan.subqueries
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
Utils.getFormattedClassName(metric.param))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.execution

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
Expand All @@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue

/**
Expand Down Expand Up @@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
* This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
* an RDD iterator of InternalRow.
*/
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def doPrepare(): Unit = {
child.prepare()
}

override def doExecute(): RDD[InternalRow] = {
child.execute()
}
Expand All @@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
child.doExecuteBroadcast()
}

override def supportCodegen: Boolean = false

override def upstreams(): Seq[RDD[InternalRow]] = {
child.execute() :: Nil
}
Expand Down Expand Up @@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
}

override def simpleString: String = "INPUT"

override def treeChildren: Seq[SparkPlan] = Nil
}

/**
Expand Down Expand Up @@ -243,30 +237,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
* doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
* used to generated code for BoundReference.
*/
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {

override def supportCodegen: Boolean = false

override def output: Seq[Attribute] = plan.output
override def outputPartitioning: Partitioning = plan.outputPartitioning
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {

override def doPrepare(): Unit = {
plan.prepare()
}
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
val code = plan.produce(ctx, this)
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}

/** Codegened pipeline for:
* ${toCommentSafeString(plan.treeString.trim)}
* ${toCommentSafeString(child.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

Expand Down Expand Up @@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// println(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)

val rdds = plan.upstreams()
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
rdds.head.mapPartitions { iter =>
Expand Down Expand Up @@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}

private[sql] override def resetMetrics(): Unit = {
plan.foreach(_.resetMetrics())
override def innerChildren: Seq[SparkPlan] = {
child :: Nil
}

override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
builder: StringBuilder): StringBuilder = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
val prefixFragment = if (isLast) " " else ": "
builder.append(prefixFragment)
}

val branch = if (lastChildren.last) "+- " else ":- "
builder.append(branch)
}

builder.append(simpleString)
builder.append("\n")

plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
}
private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
case InputAdapter(c) => c :: Nil
case other => other.children.flatMap(collectInputs)
}

builder
override def treeChildren: Seq[SparkPlan] = {
collectInputs(child)
}

override def simpleString: String = "WholeStageCodegen"
Expand Down Expand Up @@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case _ => false
}

/**
* Inserts a InputAdapter on top of those that do not support codegen.
*/
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
case j @ SortMergeJoin(_, _, _, left, right) =>
// The children of SortMergeJoin should do codegen separately.
j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
right = InputAdapter(insertWholeStageCodegen(right)))
case p if !supportCodegen(p) =>
// collapse them recursively
InputAdapter(insertWholeStageCodegen(p))
case p =>
p.withNewChildren(p.children.map(insertInputAdapter))
}

/**
* Inserts a WholeStageCodegen on top of those that support codegen.
*/
private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
case plan: CodegenSupport if supportCodegen(plan) =>
WholeStageCodegen(insertInputAdapter(plan))
case other =>
other.withNewChildren(other.children.map(insertWholeStageCodegen))
}

def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
plan.transform {
case plan: CodegenSupport if supportCodegen(plan) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case j @ SortMergeJoin(_, _, _, left, right) =>
// The children of SortMergeJoin should do codegen separately.
j.copy(left = apply(left), right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
insertWholeStageCodegen(plan)
} else {
plan
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf

Expand Down Expand Up @@ -68,7 +69,7 @@ package object debug {
}
}

private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
def output: Seq[Attribute] = child.output

implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
Expand All @@ -86,10 +87,11 @@ package object debug {
/**
* A collection of metrics for each column of output.
* @param elementTypes the actual runtime types for the output. Useful when there are bugs
* causing the wrong data to be projected.
* causing the wrong data to be projected.
*/
case class ColumnMetrics(
elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))

val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)

val numColumns: Int = child.output.size
Expand All @@ -98,7 +100,7 @@ package object debug {
def dumpStats(): Unit = {
logDebug(s"== ${child.simpleString} ==")
logDebug(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case(attr, metric) =>
child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
Expand All @@ -108,6 +110,7 @@ package object debug {
child.execute().mapPartitions { iter =>
new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext

def next(): InternalRow = {
val currentRow = iter.next()
tupleCount += 1
Expand All @@ -124,5 +127,17 @@ package object debug {
}
}
}

override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}

override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
consume(ctx, input)
}
}
}
Loading