@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
1919
2020import org .apache .spark .annotation .DeveloperApi
2121import org .apache .spark .shuffle .sort .SortShuffleManager
22- import org .apache .spark .{SparkEnv , HashPartitioner , RangePartitioner }
22+ import org .apache .spark .{HashPartitioner , Partitioner , RangePartitioner , SparkEnv }
2323import org .apache .spark .rdd .{RDD , ShuffledRDD }
2424import org .apache .spark .serializer .Serializer
2525import org .apache .spark .sql .{SQLContext , Row }
@@ -81,21 +81,25 @@ case class Exchange(
8181 *
8282 * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
8383 *
84- * @param numPartitions the number of output partitions produced by the shuffle
84+ * @param partitioner the partitioner for the shuffle
8585 * @param serializer the serializer that will be used to write rows
8686 * @return true if rows should be copied before being shuffled, false otherwise
8787 */
8888 private def needToCopyObjectsBeforeShuffle (
89- numPartitions : Int ,
89+ partitioner : Partitioner ,
9090 serializer : Serializer ): Boolean = {
91+ // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
92+ // passed instead of directly passing the number of partitions in order to guard against
93+ // corner-cases where a partitioner constructed with `numPartitions` partitions may output
94+ // fewer partitions (like RangeParittioner, for example).
9195 if (newOrdering.nonEmpty) {
9296 // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
9397 // which requires a defensive copy.
9498 true
9599 } else if (sortBasedShuffleOn) {
96100 // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
97101 // However, there are two special cases where we can avoid the copy, described below:
98- if (numPartitions <= bypassMergeThreshold) {
102+ if (partitioner. numPartitions <= bypassMergeThreshold) {
99103 // If the number of output partitions is sufficiently small, then Spark will fall back to
100104 // the old hash-based shuffle write path which doesn't buffer deserialized records.
101105 // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
@@ -177,8 +181,9 @@ case class Exchange(
177181 val keySchema = expressions.map(_.dataType).toArray
178182 val valueSchema = child.output.map(_.dataType).toArray
179183 val serializer = getSerializer(keySchema, valueSchema, numPartitions)
184+ val part = new HashPartitioner (numPartitions)
180185
181- val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions , serializer)) {
186+ val rdd = if (needToCopyObjectsBeforeShuffle(part , serializer)) {
182187 child.execute().mapPartitions { iter =>
183188 val hashExpressions = newMutableProjection(expressions, child.output)()
184189 iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -190,55 +195,59 @@ case class Exchange(
190195 iter.map(r => mutablePair.update(hashExpressions(r), r))
191196 }
192197 }
193- val part = new HashPartitioner (numPartitions)
194- val shuffled =
195- if (newOrdering.nonEmpty) {
196- new ShuffledRDD [Row , Row , Row ](rdd, part).setKeyOrdering(keyOrdering)
197- } else {
198- new ShuffledRDD [Row , Row , Row ](rdd, part)
199- }
198+ val shuffled = new ShuffledRDD [Row , Row , Row ](rdd, part)
199+ if (newOrdering.nonEmpty) {
200+ shuffled.setKeyOrdering(keyOrdering)
201+ }
200202 shuffled.setSerializer(serializer)
201203 shuffled.map(_._2)
202204
203205 case RangePartitioning (sortingExpressions, numPartitions) =>
204206 val keySchema = child.output.map(_.dataType).toArray
205207 val serializer = getSerializer(keySchema, null , numPartitions)
206208
207- val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
208- child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null ))}
209+ val childRdd = child.execute()
210+ val part : Partitioner = {
211+ // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
212+ // partition bounds. To get accurate samples, we need to copy the mutable keys.
213+ val rddForSampling = childRdd.mapPartitions { iter =>
214+ val mutablePair = new MutablePair [Row , Null ]()
215+ iter.map(row => mutablePair.update(row.copy(), null ))
216+ }
217+ // TODO: RangePartitioner should take an Ordering.
218+ implicit val ordering = new RowOrdering (sortingExpressions, child.output)
219+ new RangePartitioner (numPartitions, rddForSampling, ascending = true )
220+ }
221+
222+ val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
223+ childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null ))}
209224 } else {
210- child.execute() .mapPartitions { iter =>
211- val mutablePair = new MutablePair [Row , Null ](null , null )
225+ childRdd .mapPartitions { iter =>
226+ val mutablePair = new MutablePair [Row , Null ]()
212227 iter.map(row => mutablePair.update(row, null ))
213228 }
214229 }
215230
216- // TODO: RangePartitioner should take an Ordering.
217- implicit val ordering = new RowOrdering (sortingExpressions, child.output)
218-
219- val part = new RangePartitioner (numPartitions, rdd, ascending = true )
220- val shuffled =
221- if (newOrdering.nonEmpty) {
222- new ShuffledRDD [Row , Null , Null ](rdd, part).setKeyOrdering(keyOrdering)
223- } else {
224- new ShuffledRDD [Row , Null , Null ](rdd, part)
225- }
231+ val shuffled = new ShuffledRDD [Row , Null , Null ](rdd, part)
232+ if (newOrdering.nonEmpty) {
233+ shuffled.setKeyOrdering(keyOrdering)
234+ }
226235 shuffled.setSerializer(serializer)
227236 shuffled.map(_._1)
228237
229238 case SinglePartition =>
230239 val valueSchema = child.output.map(_.dataType).toArray
231240 val serializer = getSerializer(null , valueSchema, 1 )
241+ val partitioner = new HashPartitioner (1 )
232242
233- val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1 , serializer)) {
243+ val rdd = if (needToCopyObjectsBeforeShuffle(partitioner , serializer)) {
234244 child.execute().mapPartitions { iter => iter.map(r => (null , r.copy())) }
235245 } else {
236246 child.execute().mapPartitions { iter =>
237247 val mutablePair = new MutablePair [Null , Row ]()
238248 iter.map(r => mutablePair.update(null , r))
239249 }
240250 }
241- val partitioner = new HashPartitioner (1 )
242251 val shuffled = new ShuffledRDD [Null , Row , Row ](rdd, partitioner)
243252 shuffled.setSerializer(serializer)
244253 shuffled.map(_._2)
0 commit comments