diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e9b1aa81895f5..f5f77b03c2b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRe import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -670,7 +670,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), - planLater(child), noUserSpecifiedNumPartition = false) :: Nil + planLater(child), REPARTITION_WITH_NUM) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -703,10 +703,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => - exchange.ShuffleExchangeExec( - r.partitioning, - planLater(r.child), - noUserSpecifiedNumPartition = r.optNumPartitions.isEmpty) :: Nil + val shuffleOrigin = if (r.optNumPartitions.isEmpty) { + REPARTITION + } else { + REPARTITION_WITH_NUM + } + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 89ff528d7a188..0cf3ab0cca49a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike} import org.apache.spark.sql.internal.SQLConf /** @@ -47,7 +49,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl val shuffleStages = collectShuffleStages(plan) // ShuffleExchanges introduced by repartition do not support changing the number of partitions. // We change the number of partitions in the stage only if all the ShuffleExchanges support it. 
- if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) { + if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) { plan } else { // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions, @@ -82,4 +84,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl } } } + + private def supportCoalesce(s: ShuffleExchangeLike): Boolean = { + s.outputPartitioning != SinglePartition && + (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 8db2827beaf43..8f57947cb6396 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf @@ -136,9 +137,13 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] { def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match { case s: ShuffleQueryStageExec => - s.shuffle.canChangeNumPartitions && s.mapStats.isDefined + s.mapStats.isDefined && supportLocalReader(s.shuffle) case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) => - s.shuffle.canChangeNumPartitions && s.mapStats.isDefined && partitionSpecs.nonEmpty + s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle) case _ => false } + + private def supportLocalReader(s: ShuffleExchangeLike): Boolean = { + s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 6af4b098bee2f..affa92de693af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange { def numPartitions: Int /** - * Returns whether the shuffle partition number can be changed. + * The origin of this shuffle operator. */ - def canChangeNumPartitions: Boolean + def shuffleOrigin: ShuffleOrigin /** * The asynchronous job that materializes the shuffle. @@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange { def runtimeStatistics: Statistics } +// Describes where the shuffle operator comes from. +sealed trait ShuffleOrigin + +// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. 
It +// means that the shuffle operator is used to ensure internal data partitioning requirements and +// Spark is free to optimize it as long as the requirements are still ensured. +case object ENSURE_REQUIREMENTS extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark +// can still optimize it via changing shuffle partition number, as data partitioning won't change. +case object REPARTITION extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator with +// a certain partition number. Spark can't optimize it. +case object REPARTITION_WITH_NUM extends ShuffleOrigin + /** * Performs a shuffle that will result in the desired partitioning. */ case class ShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, - noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike { - - // If users specify the num partitions via APIs like `repartition`, we shouldn't change it. - // For `SinglePartition`, it requires exactly one partition and we can't change it either. - override def canChangeNumPartitions: Boolean = - noUserSpecifiedNumPartition && outputPartitioning != SinglePartition + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeLike { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 567e0eabe1805..578b0a807fc52 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 23 +-- Number of queries: 24 -- !query @@ -67,10 +67,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL] == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL]) - +- Exchange SinglePartition, true, [id=#x] + +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x] +- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL]) +- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) - +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x] + +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x] +- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct @@ -116,7 +116,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ -127,7 +127,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x] (6) Exchange Input [2]: [key#x, max(val)#x] -Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x] +Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x] (7) Sort Input [2]: [key#x, max(val)#x] @@ 
-179,7 +179,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ -254,7 +254,7 @@ Results [2]: [key#x, val#x] (7) Exchange Input [2]: [key#x, val#x] -Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x] (8) HashAggregate Input [2]: [key#x, val#x] @@ -576,7 +576,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ -605,7 +605,7 @@ Results [2]: [key#x, max#x] (9) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate Input [2]: [key#x, max#x] @@ -687,7 +687,7 @@ Results [3]: [count#xL, sum#xL, count#xL] (3) Exchange Input [3]: [count#xL, sum#xL, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (4) HashAggregate Input [3]: [count#xL, sum#xL, count#xL] @@ -732,7 +732,7 @@ Results [2]: [key#x, buf#x] (3) Exchange Input [2]: [key#x, buf#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (4) ObjectHashAggregate Input [2]: [key#x, buf#x] @@ -783,7 +783,7 @@ Results [2]: [key#x, min#x] (4) Exchange Input [2]: [key#x, min#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) Sort Input [2]: [key#x, min#x] diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index fcd69549f2c6e..886b98e538d28 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 23 +-- Number of queries: 24 -- !query @@ -66,10 +66,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL] == Physical Plan == *HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL]) -+- Exchange SinglePartition, true, [id=#x] ++- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x] +- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL]) +- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) - +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x] + +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x] +- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) +- *ColumnarToRow +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct @@ -119,7 +119,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 
2] Input [2]: [key#x, max#x] @@ -130,7 +130,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x] (7) Exchange Input [2]: [key#x, max(val)#x] -Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x] +Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x] (8) Sort [codegen id : 3] Input [2]: [key#x, max(val)#x] @@ -181,7 +181,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 2] Input [2]: [key#x, max#x] @@ -259,7 +259,7 @@ Results [2]: [key#x, val#x] (9) Exchange Input [2]: [key#x, val#x] -Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 4] Input [2]: [key#x, val#x] @@ -452,7 +452,7 @@ Results [1]: [max#x] (9) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -498,7 +498,7 @@ Results [1]: [max#x] (16) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (17) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -580,7 +580,7 @@ Results [1]: [max#x] (9) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -626,7 +626,7 @@ Results [2]: [sum#x, count#xL] (16) Exchange Input [2]: [sum#x, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (17) HashAggregate [codegen id : 2] Input [2]: [sum#x, count#xL] @@ -690,7 +690,7 @@ Results [2]: [sum#x, count#xL] (7) Exchange Input [2]: [sum#x, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (8) HashAggregate [codegen id : 2] Input [2]: [sum#x, count#xL] @@ -810,7 +810,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 4] Input [2]: [key#x, max#x] @@ -901,7 +901,7 @@ Results [3]: [count#xL, sum#xL, count#xL] (4) Exchange Input [3]: [count#xL, sum#xL, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate [codegen id : 2] Input [3]: [count#xL, sum#xL, count#xL] @@ -945,7 +945,7 @@ Results [2]: [key#x, buf#x] (4) Exchange Input [2]: [key#x, buf#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) ObjectHashAggregate Input [2]: [key#x, buf#x] @@ -995,7 +995,7 @@ Results [2]: [key#x, min#x] (5) Exchange Input [2]: [key#x, min#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) Sort [codegen id : 2] Input [2]: [key#x, min#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 951b72a863483..12abd31b99e93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE @@ -766,7 +766,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike { override def numMappers: Int = delegate.numMappers override def numPartitions: Int = delegate.numPartitions - override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions + override def shuffleOrigin: ShuffleOrigin = { + delegate.shuffleOrigin + } override def mapOutputStatisticsFuture: Future[MapOutputStatistics] = delegate.mapOutputStatisticsFuture override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 38a323b1c057e..758965954b374 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1307,4 +1307,14 @@ class AdaptiveQueryExecSuite spark.listenerManager.unregister(listener) } } + + test("SPARK-33494: Do not use local shuffle reader for repartition") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = spark.table("testData").repartition('key) + df.collect() + // local shuffle reader breaks partitioning and shouldn't be used for repartition operation + // which is specified by users. + checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1) + } + } }
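
For readers of this patch, the following is a standalone sketch, not part of the diff above; the names ShuffleOriginSketch, ShuffleInfo, supportsCoalesce and supportsLocalReader are illustrative only. It mirrors how the new ShuffleOrigin tag replaces the old canChangeNumPartitions boolean: CoalesceShufflePartitions may still reduce the partition count for ENSURE_REQUIREMENTS and REPARTITION shuffles, while OptimizeLocalShuffleReader, which rewrites the output partitioning, is restricted to ENSURE_REQUIREMENTS, matching the SPARK-33494 test added above.

// Standalone illustration of the decision logic in this patch (hypothetical names,
// not Spark APIs). Run as a plain Scala application.
object ShuffleOriginSketch extends App {
  // Same three origins introduced in ShuffleExchangeExec.scala above.
  sealed trait ShuffleOrigin
  case object ENSURE_REQUIREMENTS extends ShuffleOrigin  // added by the EnsureRequirements rule
  case object REPARTITION extends ShuffleOrigin           // user repartition without a partition number
  case object REPARTITION_WITH_NUM extends ShuffleOrigin  // user repartition with an explicit number

  // Simplified stand-in for what the AQE rules read off a shuffle stage.
  final case class ShuffleInfo(origin: ShuffleOrigin, isSinglePartition: Boolean)

  // CoalesceShufflePartitions: coalescing only changes the partition count, so it is
  // allowed unless the user pinned a number or the shuffle requires a single partition.
  def supportsCoalesce(s: ShuffleInfo): Boolean =
    !s.isSinglePartition && (s.origin == ENSURE_REQUIREMENTS || s.origin == REPARTITION)

  // OptimizeLocalShuffleReader: a local reader changes the data partitioning itself,
  // so it is only safe for shuffles Spark inserted on its own.
  def supportsLocalReader(s: ShuffleInfo): Boolean =
    !s.isSinglePartition && s.origin == ENSURE_REQUIREMENTS

  // A user-specified repartition by key can still be coalesced, but must not be
  // rewritten into a local shuffle reader.
  val repartitionByKey = ShuffleInfo(REPARTITION, isSinglePartition = false)
  assert(supportsCoalesce(repartitionByKey))
  assert(!supportsLocalReader(repartitionByKey))

  // A repartition with an explicit partition number is left entirely untouched.
  val repartitionWithNum = ShuffleInfo(REPARTITION_WITH_NUM, isSinglePartition = false)
  assert(!supportsCoalesce(repartitionWithNum))
  assert(!supportsLocalReader(repartitionWithNum))
}

Under these assumptions, a query like df.repartition('key) (origin REPARTITION) can have its shuffle partitions coalesced by AQE but never converted to a local shuffle reader, while df.repartition(10) (origin REPARTITION_WITH_NUM) is not optimized at all, which is the behavior the patch encodes.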