
Commit b615476

Scaffolding for sort-based shuffle

1 parent e3d85b7 · commit b615476

7 files changed: +164 -3 lines changed

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
  * A ShuffleManager using hashing, that creates one output file per reduce partition on each
  * mapper (possibly reusing these across waves of tasks).
  */
-class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
   /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
   override def registerShuffle[K, V, C](
       shuffleId: Int,

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 
-class HashShuffleReader[K, C](
+private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 
-class HashShuffleWriter[K, V](
+private[spark] class HashShuffleWriter[K, V](
     handle: BaseShuffleHandle[K, V, _],
     mapId: Int,
     context: TaskContext)
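
Editor's note: all three changes above apply the same Scala qualified access modifier. A `private[spark]` class stays visible everywhere under the `org.apache.spark` package but disappears from the public API, so these shuffle internals can keep evolving without compatibility guarantees. A minimal sketch of how package-qualified visibility behaves (package and class names below are illustrative, not from the commit):

// Illustrative only: demonstrates the visibility rule behind private[spark].
package org.example.internal {
  // Accessible anywhere under org.example, invisible outside it.
  private[example] class InternalThing {
    def describe(): String = "implementation detail"
  }
}

package org.example {
  object SamePackageTree {
    // Compiles: org.example encloses org.example.internal.
    val thing = new org.example.internal.InternalThing
  }
}

// From an unrelated package, `new org.example.internal.InternalThing`
// fails to compile with a "cannot be accessed" error.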
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.apache.spark.shuffle._
+import org.apache.spark.{TaskContext, ShuffleDependency}
+import org.apache.spark.shuffle.hash.HashShuffleReader
+
+private[spark] class SortShuffleManager extends ShuffleManager {
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    new BaseShuffleHandle(shuffleId, numMaps, dependency)
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    // We currently use the same block store shuffle fetcher as the hash-based shuffle.
+    new HashShuffleReader(
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+      : ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {}
+}
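
Editor's note: SortShuffleManager already fills in the full ShuffleManager contract even though its writer is a stub: the driver registers a shuffle once to obtain a handle, map tasks ask for writers, and reduce tasks ask for readers. A hedged sketch of that call pattern (driver/executor wiring heavily simplified; this illustrates how the interface above is meant to be exercised, it is not code from the commit):

import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.shuffle.ShuffleManager

// Sketch of the ShuffleManager lifecycle; a real TaskContext and
// ShuffleDependency come from the scheduler, not from user code.
def demoShuffleLifecycle[K, V](
    manager: ShuffleManager,
    dep: ShuffleDependency[K, V, V],
    ctx: TaskContext): Unit = {
  val numMaps = 4

  // Driver side: register once, then ship the handle inside every task.
  val handle = manager.registerShuffle(0, numMaps, dep)

  // Executor side, in each map task: write that task's output records.
  for (mapId <- 0 until numMaps) {
    val writer = manager.getWriter[K, V](handle, mapId, ctx)
    writer.write(Iterator.empty)  // records normally come from the RDD iterator
    writer.stop(success = true)   // returns Option[MapStatus]
  }

  // Executor side, in a reduce task: read partitions [0, 1).
  val reader = manager.getReader[K, V](handle, 0, 1, ctx)
  reader.read().foreach(_ => ())  // consume the fetched key/value pairs
}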
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.{SparkEnv, Logging, TaskContext}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+
+private[spark] class SortShuffleWriter[K, V](
+    handle: BaseShuffleHandle[K, V, _],
+    mapId: Int,
+    context: TaskContext)
+  extends ShuffleWriter[K, V] with Logging {
+
+  private val dep = handle.dependency
+  private val numOutputPartitions = dep.partitioner.numPartitions
+  private val metrics = context.taskMetrics
+
+  private val blockManager = SparkEnv.get.blockManager
+  private val diskBlockManager = blockManager.diskBlockManager
+  private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
+
+  /** Write a bunch of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    val iter = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        // TODO: This does an external merge-sort if the data is highly combinable, and then we
+        // do another one later to sort them by output partition. We can improve this by doing
+        // the merging as part of the SortedFileWriter.
+        dep.aggregator.get.combineValuesByKey(records, context)
+      } else {
+        records
+      }
+    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
+      throw new IllegalStateException("Aggregator is empty for map-side combine")
+    } else {
+      records
+    }
+
+    ???
+  }
+
+  /** Close this writer, passing along whether the map completed */
+  override def stop(success: Boolean): Option[MapStatus] = ???
+}
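
Editor's note: the `???` placeholder marks where the sorted write path will go. Conceptually, once the optional map-side combine has run, the remaining work is to order records by output partition so one file per map task can serve every reducer through contiguous segments. A hedged sketch of that missing step (an in-memory stand-in of my own; the TODO above implies the real version will merge spilled sorted runs instead):

// Hypothetical stand-in for the deferred step: tag each record with its
// output partition and sort by that tag. Ordering *within* a partition
// does not matter for a plain shuffle, so sorting on partition id suffices.
def sortByOutputPartition[K, V](
    records: Iterator[Product2[K, V]],
    getPartition: K => Int): Iterator[(Int, Product2[K, V])] = {
  records
    .map(kv => (getPartition(kv._1), kv))
    .toArray       // real code would spill to disk rather than buffer everything
    .sortBy(_._1)  // stable sort by partition id
    .iterator
}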
core/src/main/scala/org/apache/spark/shuffle/sort/SortedFileWriter.scala

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.apache.spark.{Logging, Partitioner}
+import org.apache.spark.storage.DiskBlockManager
+
+/**
+ *
+ */
+private[spark] class SortedFileWriter[K, V](
+    partitioner: Partitioner,
+    diskBlockManager: DiskBlockManager) extends Logging {
+
+  val numPartitions = partitioner.numPartitions
+
+  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+
+  }
+
+  def stop() {
+
+  }
+}
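
Editor's note: SortedFileWriter is still an empty shell (even its scaladoc is blank), but the TODO in SortShuffleWriter indicates its intended job: accept a stream of records, bucket them by `partitioner`, and produce per-partition file segments via the `DiskBlockManager`. A self-contained toy version of that contract, using plain java.io instead of Spark's block storage (all names and design choices here are assumptions, not the commit's design):

import java.io.{BufferedOutputStream, File, FileOutputStream}

// Toy model of a sorted file writer: buffer records per partition, then
// flush partitions 0..n-1 back-to-back into one file, reporting each
// segment's byte length so a reader could seek straight to its range.
class ToySortedFileWriter(numPartitions: Int, file: File) {
  private val buffers = Array.fill(numPartitions)(new StringBuilder)

  def write(records: Iterator[(String, String)], partition: String => Int): Unit = {
    for ((k, v) <- records) {
      buffers(partition(k)).append(k).append('\t').append(v).append('\n')
    }
  }

  /** Flush buffers in partition order; return each segment's length in bytes. */
  def stop(): Array[Long] = {
    val out = new BufferedOutputStream(new FileOutputStream(file))
    try {
      buffers.map { buf =>
        val bytes = buf.toString.getBytes("UTF-8")
        out.write(bytes)
        bytes.length.toLong
      }
    } finally {
      out.close()
    }
  }
}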

project/SparkBuild.scala

Lines changed: 1 addition & 0 deletions
@@ -292,6 +292,7 @@ object Unidoc {
       .map(_.filterNot(_.getCanonicalPath.contains("akka")))
       .map(_.filterNot(_.getCanonicalPath.contains("deploy")))
       .map(_.filterNot(_.getCanonicalPath.contains("network")))
+      .map(_.filterNot(_.getCanonicalPath.contains("shuffle")))
       .map(_.filterNot(_.getCanonicalPath.contains("executor")))
       .map(_.filterNot(_.getCanonicalPath.contains("python")))
       .map(_.filterNot(_.getCanonicalPath.contains("collection")))
