@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
@@ -45,6 +46,27 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
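+  // Decide which serializer to use for a shuffle with the given key/value schemas.
+  // SparkSqlSerializer2 is used only when both schemas are supported, the feature is
+  // enabled, and this is not a sort-based shuffle with more partitions than the
+  // bypass-merge threshold; otherwise fall back to the generic SparkSqlSerializer.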
+  def serializer(
+      keySchema: Array[DataType],
+      valueSchema: Array[DataType],
+      numPartitions: Int): Serializer = {
+    val useSqlSerializer2 =
+      !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) &&
+        child.sqlContext.conf.useSqlSerializer2 &&
+        SparkSqlSerializer2.support(keySchema) &&
+        SparkSqlSerializer2.support(valueSchema)
+
+    val serializer = if (useSqlSerializer2) {
+      logInfo("Use ShuffleSerializer")
+      new SparkSqlSerializer2(keySchema, valueSchema)
+    } else {
+      logInfo("Use SparkSqlSerializer")
+      new SparkSqlSerializer(new SparkConf(false))
+    }
+
+    serializer
+  }
+
   override def execute(): RDD[Row] = attachTree(this, "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -70,7 +92,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
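+        // Keys come from the hash expressions and values are the child's output rows,
+        // so derive both schemas and let serializer() pick the implementation.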
+        val keySchema = expressions.map(_.dataType).toArray
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -88,7 +114,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
         val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
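+        // Only keys are shuffled here (the value type is Null), so no value schema
+        // is passed to serializer().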
+        val keySchema = sortingExpressions.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
 
         shuffled.map(_._1)
 
@@ -107,7 +135,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
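+        // Everything collapses into a single partition with a Null key, so only the
+        // value schema is needed.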
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(null, valueSchema, 1))
+
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")