Commit 8bfc3b7

liancheng authored and yhuai committed
[SPARK-17972][SQL] Add Dataset.checkpoint() to truncate large query plans
## What changes were proposed in this pull request?

### Problem

Iterative ML code may easily create query plans that grow exponentially. We found that query planning time also increases exponentially, even when all the sub-plan trees are cached. The following snippet illustrates the problem:

```scala
(0 until 6).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")
  val time0 = System.currentTimeMillis()
  val joined = plan.join(plan, "value").join(plan, "value").join(plan, "value").join(plan, "value")
  joined.cache()
  println(s"Query planning takes ${System.currentTimeMillis() - time0} ms")
  joined.as[Int]
}

// == Iteration 0 ==
// Query planning takes 9 ms
// == Iteration 1 ==
// Query planning takes 26 ms
// == Iteration 2 ==
// Query planning takes 53 ms
// == Iteration 3 ==
// Query planning takes 163 ms
// == Iteration 4 ==
// Query planning takes 700 ms
// == Iteration 5 ==
// Query planning takes 3418 ms
```

This happens because, when building a new Dataset, the new plan is always built upon `QueryExecution.analyzed`, which doesn't leverage existing cached plans. On the other hand, caching every few iterations is usually not the right direction for this problem, since caching is too memory consuming (imagine computing connected components over a graph with 50 billion nodes). What we really need here is to truncate both the query plan (to minimize query planning time) and the lineage of the underlying RDD (to avoid stack overflow).

### Changes introduced in this PR

This PR tries to fix this issue by introducing a `checkpoint()` method into `Dataset[T]`, which does exactly what is described above: it truncates both the query plan and the RDD lineage. The micro benchmark below, which is essentially the same as the snippet above but invokes `checkpoint()` instead of `cache()`, shows the effect of this PR.

One key point is that the checkpointed Dataset should preserve the partitioning and ordering information of the original Dataset, so that we can avoid unnecessary shuffling (similar to reading from a pre-bucketed table). This is done by adding `outputPartitioning` and `outputOrdering` to `LogicalRDD` and `RDDScanExec`.

### Micro benchmark

```scala
spark.sparkContext.setCheckpointDir("/tmp/cp")

(0 until 100).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")

  val time0 = System.currentTimeMillis()
  val cp = plan.checkpoint()
  cp.count()
  System.out.println(s"Checkpointing takes ${System.currentTimeMillis() - time0} ms")

  val time1 = System.currentTimeMillis()
  val joined = cp.join(cp, "value").join(cp, "value").join(cp, "value").join(cp, "value")
  val result = joined.as[Int]
  println(s"Query planning takes ${System.currentTimeMillis() - time1} ms")

  result
}

// == Iteration 0 ==
// Checkpointing takes 591 ms
// Query planning takes 13 ms
// == Iteration 1 ==
// Checkpointing takes 1605 ms
// Query planning takes 16 ms
// == Iteration 2 ==
// Checkpointing takes 782 ms
// Query planning takes 8 ms
// == Iteration 3 ==
// Checkpointing takes 729 ms
// Query planning takes 10 ms
// == Iteration 4 ==
// Checkpointing takes 734 ms
// Query planning takes 9 ms
// == Iteration 5 ==
// ...
// == Iteration 50 ==
// Checkpointing takes 571 ms
// Query planning takes 7 ms
// == Iteration 51 ==
// Checkpointing takes 548 ms
// Query planning takes 7 ms
// == Iteration 52 ==
// Checkpointing takes 596 ms
// Query planning takes 8 ms
// == Iteration 53 ==
// Checkpointing takes 568 ms
// Query planning takes 7 ms
// ...
```

You may see that although checkpointing is a more heavyweight operation, the combined cost of checkpointing and query planning stays roughly constant across iterations.

### Open question

mengxr mentioned that it would be more convenient if we could make `Dataset.checkpoint()` eager, i.e., always perform an `RDD.count()` after calling `RDD.checkpoint()`. Not quite sure whether this is a universal requirement. Maybe we can add an `eager: Boolean` argument for `Dataset.checkpoint()` to support that.

## How was this patch tested?

Unit test added in `DatasetSuite`.

Author: Cheng Lian <[email protected]>
Author: Yin Huai <[email protected]>

Closes #15651 from liancheng/ds-checkpoint.
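As merged in the `Dataset.scala` diff below, `checkpoint()` is eager by default and an overload exposes the `eager` flag. A minimal usage sketch, assuming a spark-shell style session (a `SparkSession` named `spark`) and a hypothetical checkpoint directory `/tmp/cp`:

```scala
// Checkpoint files are written under the directory configured on the SparkContext.
spark.sparkContext.setCheckpointDir("/tmp/cp")  // hypothetical path

import spark.implicits._

val ds = Seq(1, 2, 3).toDS()

// Eager (default): the checkpointed RDD is materialized immediately.
val eagerCp = ds.checkpoint()

// Lazy: nothing is written until an action runs on the result.
val lazyCp = ds.checkpoint(eager = false)
lazyCp.count()  // first action materializes the checkpoint

// Subsequent queries are planned against a truncated plan rooted at a LogicalRDD,
// so planning time no longer grows with the number of iterations.
eagerCp.join(eagerCp, "value").explain()
```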
1 parent 26b07f1 commit 8bfc3b7

File tree

4 files changed (+157, -12 lines changed)


sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 55 additions & 2 deletions
```diff
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
+import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -482,6 +483,58 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming
 
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager When true, materializes the underlying checkpointed RDD eagerly.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(eager: Boolean): Dataset[T] = {
+    val internalRdd = queryExecution.toRdd.map(_.copy())
+    internalRdd.checkpoint()
+
+    if (eager) {
+      internalRdd.count()
+    }
+
+    val physicalPlan = queryExecution.executedPlan
+
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(
+        logicalPlan.output,
+        internalRdd,
+        outputPartitioning,
+        physicalPlan.outputOrdering
+      )(sparkSession)).as[T]
+  }
+
   /**
    * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
```

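The comment inside `checkpoint(eager: Boolean)` above explains that only the first leaf of a `PartitioningCollection` is kept, because collections can grow very large for deep join trees. The following is a minimal, standalone sketch of that flattening, assuming Spark 2.1 catalyst classes on the classpath; the object name and the `value` attribute are made up for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
import org.apache.spark.sql.types.IntegerType

object FirstLeafPartitioningSketch {
  // Same recursion as in Dataset.checkpoint(): drill into the head of any
  // PartitioningCollection until a leaf partitioning is reached.
  def firstLeafPartitioning(partitioning: Partitioning): Partitioning = partitioning match {
    case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
    case p => p
  }

  def main(args: Array[String]): Unit = {
    val value = AttributeReference("value", IntegerType)()

    // A join of two compatibly partitioned children reports both partitionings; nesting
    // such collections across deep join trees is what can blow up in size.
    val joined = PartitioningCollection(Seq(
      HashPartitioning(value :: Nil, 4),
      HashPartitioning(value :: Nil, 4)))

    // Only the first leaf (here, the left child's HashPartitioning) would be
    // recorded in the checkpointed LogicalRDD.
    println(firstLeafPartitioning(joined))
  }
}
```

Any single leaf is still a valid description of the checkpointed data's physical layout, so keeping just the first one trades some optimizer knowledge for a bounded plan size.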
sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 31 additions & 6 deletions
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
 /** Logical plan node for scanning data from an RDD of InternalRow. */
 case class LogicalRDD(
     output: Seq[Attribute],
-    rdd: RDD[InternalRow])(session: SparkSession)
+    rdd: RDD[InternalRow],
+    outputPartitioning: Partitioning = UnknownPartitioning(0),
+    outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {
 
   override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
 
-  override def newInstance(): LogicalRDD.this.type =
-    LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
+  override def newInstance(): LogicalRDD.this.type = {
+    val rewrite = output.zip(output.map(_.newInstance())).toMap
+
+    val rewrittenPartitioning = outputPartitioning match {
+      case p: Expression =>
+        p.transform {
+          case e: Attribute => rewrite.getOrElse(e, e)
+        }.asInstanceOf[Partitioning]
+
+      case p => p
+    }
+
+    val rewrittenOrdering = outputOrdering.map(_.transform {
+      case e: Attribute => rewrite.getOrElse(e, e)
+    }.asInstanceOf[SortOrder])
+
+    LogicalRDD(
+      output.map(rewrite),
+      rdd,
+      rewrittenPartitioning,
+      rewrittenOrdering
+    )(session).asInstanceOf[this.type]
+  }
 
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan.canonicalized match {
-      case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
       case _ => false
     }
   }
@@ -158,7 +181,9 @@ case class LogicalRDD(
 case class RDDScanExec(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    override val nodeName: String) extends LeafExecNode {
+    override val nodeName: String,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
```

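Because `LogicalRDD` is a `MultiInstanceRelation`, self-joins on a checkpointed Dataset go through the `newInstance()` path above, which now also rewrites the attributes referenced by the preserved partitioning and ordering. A short usage sketch, assuming a running `SparkSession` named `spark` and a hypothetical checkpoint directory:

```scala
import spark.implicits._

spark.sparkContext.setCheckpointDir("/tmp/cp")  // hypothetical path

// Checkpointing yields a Dataset backed by a LogicalRDD.
val cp = Seq(1, 2, 3).toDS().checkpoint()

// Both sides of this join refer to the same LogicalRDD; the analyzer deduplicates
// them via newInstance(), so the rewritten partitioning/ordering must track the
// freshly generated attribute IDs.
val selfJoined = cp.join(cp, "value")
selfJoined.show()
```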
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 4 deletions
```diff
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
 
 /**
  * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case r : logical.Range =>
+      case r: logical.Range =>
        execution.RangeExec(r) :: Nil
      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
        exchange.ShuffleExchange(HashPartitioning(
          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
      case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
-      case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
+      case r: LogicalRDD =>
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
      case BroadcastHint(child) => planLater(child) :: Nil
      case _ => Nil
    }
```

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 68 additions & 0 deletions
```diff
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
@@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  Seq(true, false).foreach { eager =>
+    def testCheckpointing(testName: String)(f: => Unit): Unit = {
+      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
+        withTempDir { dir =>
+          val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+          try {
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+            f
+          } finally {
+            // Since the original checkpointDir can be None, we need
+            // to set the variable directly.
+            spark.sparkContext.checkpointDir = originalCheckpointDir
+          }
+        }
+      }
+    }
+
+    testCheckpointing("basic") {
+      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+      val cp = ds.checkpoint(eager)
+
+      val logicalRDD = cp.logicalPlan match {
+        case plan: LogicalRDD => plan
+        case _ =>
+          val treeString = cp.logicalPlan.treeString(verbose = true)
+          fail(s"Expecting a LogicalRDD, but got\n$treeString")
+      }
+
+      val dsPhysicalPlan = ds.queryExecution.executedPlan
+      val cpPhysicalPlan = cp.queryExecution.executedPlan
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+
+      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+
+      // Reads back from checkpointed data and check again.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+    }
+
+    testCheckpointing("should preserve partitioning information") {
+      val ds = spark.range(10).repartition('id % 2)
+      val cp = ds.checkpoint(eager)
+
+      val agg = cp.groupBy('id % 2).agg(count('id))
+
+      agg.queryExecution.executedPlan.collectFirst {
+        case ShuffleExchange(_, _: RDDScanExec, _) =>
+        case BroadcastExchangeExec(_, _: RDDScanExec) =>
+      }.foreach { _ =>
+        fail(
+          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+        )
+      }
+
+      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+    }
+  }
 }
 
 case class Generic[T](id: T, value: Double)
```
