
Commit 1ed28f2 ("refine")
1 parent 5f8f4ed

7 files changed: +64 -66 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 4 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRe
 import org.apache.spark.sql.execution.aggregate.AggUtils
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleOrigin}
+import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
@@ -670,7 +670,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
           ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
-            planLater(child), ShuffleOrigin.REPARTITION_WITH_NUM) :: Nil
+            planLater(child), REPARTITION_WITH_NUM) :: Nil
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
@@ -704,9 +704,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
         val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
-          ShuffleOrigin.REPARTITION
+          REPARTITION
         } else {
-          ShuffleOrigin.REPARTITION_WITH_NUM
+          REPARTITION_WITH_NUM
         }
         exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
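
For context (not part of the commit): the planner cases above are what tag each shuffle with its origin. A REPL-style sketch, assuming a local session, showing which Dataset API hits which case:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[2]").appName("origin-demo").getOrCreate()
  import spark.implicits._
  val df = spark.range(0, 100).toDF("key")

  // logical.Repartition (round-robin) with an explicit count: tagged REPARTITION_WITH_NUM.
  df.repartition(10).explain()
  // RepartitionByExpression with optNumPartitions set: tagged REPARTITION_WITH_NUM.
  df.repartition(10, $"key").explain()
  // RepartitionByExpression with optNumPartitions empty: tagged REPARTITION.
  df.repartition($"key").explain()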

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala

Lines changed: 8 additions & 1 deletion
@@ -18,8 +18,10 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -47,7 +49,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
     val shuffleStages = collectShuffleStages(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
-    if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
+    if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
       plan
     } else {
       // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
@@ -82,4 +84,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
       }
     }
   }
+
+  private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition &&
+      (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
+  }
 }
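
Continuing the sketch session above (illustrative, not from the patch): the practical effect of supportCoalesce is that AQE may now shrink a REPARTITION shuffle but must keep a REPARTITION_WITH_NUM count exact. The config keys are real AQE settings:

  spark.conf.set("spark.sql.adaptive.enabled", "true")
  spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

  // Origin REPARTITION: the hash partitioning survives, but AQE may coalesce
  // the partition count down toward the advisory partition size.
  println(df.repartition($"key").rdd.getNumPartitions)
  // Origin REPARTITION_WITH_NUM: the user pinned the count, so exactly 50 remain.
  println(df.repartition(50, $"key").rdd.getNumPartitions)  // 50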

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala

Lines changed: 8 additions & 3 deletions
@@ -18,9 +18,10 @@
 package org.apache.spark.sql.execution.adaptive
 
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.internal.SQLConf
 
@@ -136,9 +137,13 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] {
 
   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
-      s.shuffle.canChangePartitioning && s.mapStats.isDefined
+      s.mapStats.isDefined && supportLocalReader(s.shuffle)
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
-      s.shuffle.canChangePartitioning && s.mapStats.isDefined && partitionSpecs.nonEmpty
+      s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
     case _ => false
   }
+
+  private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
+  }
 }
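
Side by side, the two new private predicates encode different policies. A standalone recap (the names coalescible and localReadable are hypothetical summaries, and both assume the output partitioning is not SinglePartition); the sealed ShuffleOrigin also lets them be written as compiler-checked exhaustive matches:

  import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, REPARTITION_WITH_NUM, ShuffleOrigin}

  // Coalescing only changes the partition count, so a user-requested
  // partitioning (REPARTITION) tolerates it; a user-pinned count does not.
  def coalescible(o: ShuffleOrigin): Boolean = o match {
    case ENSURE_REQUIREMENTS | REPARTITION => true
    case REPARTITION_WITH_NUM => false
  }

  // A local reader replaces the output partitioning entirely, so only shuffles
  // Spark inserted for its own requirements qualify.
  def localReadable(o: ShuffleOrigin): Boolean = o match {
    case ENSURE_REQUIREMENTS => true
    case REPARTITION | REPARTITION_WITH_NUM => false
  }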

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 17 additions & 31 deletions
@@ -56,25 +56,10 @@ trait ShuffleExchangeLike extends Exchange {
    */
   def numPartitions: Int
 
-  def shuffleOrigin: ShuffleOrigin.Value
-
-  /**
-   * Returns whether the shuffle partition number can be changed.
-   */
-  final def canChangeNumPartitions: Boolean = {
-    // If users specify the num partitions via APIs like `repartition(5, col)`, we shouldn't change
-    // it. For `SinglePartition`, it requires exactly one partition and we can't change it either.
-    shuffleOrigin != ShuffleOrigin.REPARTITION_WITH_NUM && outputPartitioning != SinglePartition
-  }
-
   /**
-   * Returns whether the shuffle output data partitioning can be changed.
+   * The origin of this shuffle operator.
    */
-  final def canChangePartitioning: Boolean = {
-    // If users specify the partitioning via APIs like `repartition(col)`, we shouldn't change it.
-    // For `SinglePartition`, itself is a special partitioning and we can't change it either.
-    shuffleOrigin == ShuffleOrigin.ENSURE_REQUIREMENTS && outputPartitioning != SinglePartition
-  }
+  def shuffleOrigin: ShuffleOrigin
 
   /**
    * The asynchronous job that materializes the shuffle.
@@ -93,27 +78,28 @@ trait ShuffleExchangeLike extends Exchange {
 }
 
 // Describes where the shuffle operator comes from.
-object ShuffleOrigin extends Enumeration {
-  type ShuffleOrigin = Value
-  // Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
-  // means that the shuffle operator is used to ensure internal data partitioning requirements and
-  // Spark is free to optimize it as long as the requirements are still ensured.
-  val ENSURE_REQUIREMENTS = Value
-  // Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
-  // can still optimize it via changing shuffle partition number, as data partitioning won't change.
-  val REPARTITION = Value
-  // Indicates that the shuffle operator was added by the user-specified repartition operator with
-  // a certain partition number. Spark can't optimize it.
-  val REPARTITION_WITH_NUM = Value
-}
+sealed trait ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
+// means that the shuffle operator is used to ensure internal data partitioning requirements and
+// Spark is free to optimize it as long as the requirements are still ensured.
+case object ENSURE_REQUIREMENTS extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
+// can still optimize it via changing shuffle partition number, as data partitioning won't change.
+case object REPARTITION extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator with
+// a certain partition number. Spark can't optimize it.
+case object REPARTITION_WITH_NUM extends ShuffleOrigin
 
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    shuffleOrigin: ShuffleOrigin.Value = ShuffleOrigin.ENSURE_REQUIREMENTS)
+    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
   extends ShuffleExchangeLike {
 
   private lazy val writeMetrics =
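
The motivation for replacing the Enumeration with a sealed trait, shown as a self-contained sketch (the toy Origin type stands in for ShuffleOrigin): Enumeration values all share the single type ShuffleOrigin.Value, so the compiler cannot check matches over them, while a sealed hierarchy is a closed set of cases:

  sealed trait Origin
  case object EnsureRequirements extends Origin
  case object Repartition extends Origin
  case object RepartitionWithNum extends Origin

  def describe(o: Origin): String = o match {
    case EnsureRequirements => "inserted by Spark; freely optimizable"
    case Repartition        => "user partitioning; count may change, partitioning may not"
    case RepartitionWithNum => "user partitioning and count; nothing may change"
    // Dropping any case above produces a compile-time "match may not be
    // exhaustive" warning, which Enumeration-based code never gets.
  }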

sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out

Lines changed: 11 additions & 11 deletions
@@ -67,10 +67,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
 == Physical Plan ==
 AdaptiveSparkPlan isFinalPlan=false
 +- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
-   +- Exchange SinglePartition, UNSPECIFIED, [id=#x]
+   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
       +- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
          +- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
-            +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), UNSPECIFIED, [id=#x]
+            +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
               +- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
                  +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
@@ -116,7 +116,7 @@ Results [2]: [key#x, max#x]
 
 (4) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (5) HashAggregate
 Input [2]: [key#x, max#x]
@@ -127,7 +127,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
 
 (6) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), UNSPECIFIED, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (7) Sort
 Input [2]: [key#x, max(val)#x]
@@ -179,7 +179,7 @@ Results [2]: [key#x, max#x]
 
 (4) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (5) HashAggregate
 Input [2]: [key#x, max#x]
@@ -254,7 +254,7 @@ Results [2]: [key#x, val#x]
 
 (7) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (8) HashAggregate
 Input [2]: [key#x, val#x]
@@ -576,7 +576,7 @@ Results [2]: [key#x, max#x]
 
 (4) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (5) HashAggregate
 Input [2]: [key#x, max#x]
@@ -605,7 +605,7 @@ Results [2]: [key#x, max#x]
 
 (9) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (10) HashAggregate
 Input [2]: [key#x, max#x]
@@ -687,7 +687,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
 
 (3) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (4) HashAggregate
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -732,7 +732,7 @@ Results [2]: [key#x, buf#x]
 
 (3) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (4) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -783,7 +783,7 @@ Results [2]: [key#x, min#x]
 
 (4) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (5) Sort
 Input [2]: [key#x, min#x]

sql/core/src/test/resources/sql-tests/results/explain.sql.out

Lines changed: 15 additions & 15 deletions
@@ -66,10 +66,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
 
 == Physical Plan ==
 *HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
-+- Exchange SinglePartition, UNSPECIFIED, [id=#x]
++- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
    +- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
       +- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
-         +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), UNSPECIFIED, [id=#x]
+         +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
            +- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
               +- *ColumnarToRow
                  +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
@@ -119,7 +119,7 @@ Results [2]: [key#x, max#x]
 
 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (6) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -130,7 +130,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
 
 (7) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), UNSPECIFIED, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (8) Sort [codegen id : 3]
 Input [2]: [key#x, max(val)#x]
@@ -181,7 +181,7 @@ Results [2]: [key#x, max#x]
 
 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (6) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -259,7 +259,7 @@ Results [2]: [key#x, val#x]
 
 (9) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (10) HashAggregate [codegen id : 4]
 Input [2]: [key#x, val#x]
@@ -452,7 +452,7 @@ Results [1]: [max#x]
 
 (9) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (10) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -498,7 +498,7 @@ Results [1]: [max#x]
 
 (16) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (17) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -580,7 +580,7 @@ Results [1]: [max#x]
 
 (9) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (10) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -626,7 +626,7 @@ Results [2]: [sum#x, count#xL]
 
 (16) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (17) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -690,7 +690,7 @@ Results [2]: [sum#x, count#xL]
 
 (7) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (8) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -810,7 +810,7 @@ Results [2]: [key#x, max#x]
 
 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (6) HashAggregate [codegen id : 4]
 Input [2]: [key#x, max#x]
@@ -901,7 +901,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
 
 (4) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, UNSPECIFIED, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
 
 (5) HashAggregate [codegen id : 2]
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -945,7 +945,7 @@ Results [2]: [key#x, buf#x]
 
 (4) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (5) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -995,7 +995,7 @@ Results [2]: [key#x, min#x]
 
 (5) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), UNSPECIFIED, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
 
 (6) Sort [codegen id : 2]
 Input [2]: [key#x, min#x]
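
Both golden files above change only because every Exchange now prints its actual origin where it previously printed UNSPECIFIED. A query of roughly this shape (explain_temp1 is defined by the explain.sql test fixtures; the #x placeholders follow the golden files' normalization, and real ids and partition counts vary by environment) reproduces such a line:

  spark.sql("SELECT key, max(val) FROM explain_temp1 GROUP BY key ORDER BY key")
    .explain("formatted")
  // ...
  // (4) Exchange
  // Input [2]: [key#x, max#x]
  // Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
  // ...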

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -766,7 +766,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
 case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
   override def numMappers: Int = delegate.numMappers
   override def numPartitions: Int = delegate.numPartitions
-  override def shuffleOrigin: ShuffleOrigin.Value = {
+  override def shuffleOrigin: ShuffleOrigin = {
     delegate.shuffleOrigin
   }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
