Commit 3af423c

[SPARK-6986] [SQL] Use Serializer2 in more cases.
With 0a2b15c, the serialization stream and deserialization stream have enough information to determine whether they are handling a key-value pair, a key, or a value. It is therefore safe to use `SparkSqlSerializer2` in more cases.

Author: Yin Huai <[email protected]>

Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits:

53a5eaa [Yin Huai] Josh's comments.
487f540 [Yin Huai] Use BufferedOutputStream.
8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join.
c7e2129 [Yin Huai] Update tests.
4513d13 [Yin Huai] Use Serializer2 in more places.
1 parent 92f8f80 commit 3af423c
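One of the squashed commits above ("Always create a new row at the deserialization side to work with sort merge join") addresses a reuse hazard: if a downstream operator buffers the row objects a deserialization stream hands out, as a sort merge join can, returning the same mutable row on every read silently rewrites every buffered reference. A minimal Scala sketch of that failure mode, using an illustrative MutableRecord class rather than Spark's row types:

import scala.collection.mutable.ArrayBuffer

object MutableRowReuseSketch {
  // Illustrative stand-in for a mutable row: a single reusable field.
  final class MutableRecord(var value: Int)

  def main(args: Array[String]): Unit = {
    val input = Seq(1, 2, 3)

    // Unsafe: reuse one record for every element while a consumer buffers them.
    val reused = new MutableRecord(0)
    val bufferedUnsafe = ArrayBuffer.empty[MutableRecord]
    input.foreach { v => reused.value = v; bufferedUnsafe += reused }
    println(bufferedUnsafe.map(_.value)) // ArrayBuffer(3, 3, 3): every reference saw the last value

    // Safe: allocate a fresh record per element, analogous to handing out a new
    // GenericMutableRow per call when a key ordering is in effect.
    val bufferedSafe = ArrayBuffer.empty[MutableRecord]
    input.foreach { v => bufferedSafe += new MutableRecord(v) }
    println(bufferedSafe.map(_.value)) // ArrayBuffer(1, 2, 3)
  }
}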

File tree

3 files changed: +69 additions, −58 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 7 additions & 16 deletions
@@ -84,18 +84,8 @@ case class Exchange(
   def serializer(
       keySchema: Array[DataType],
       valueSchema: Array[DataType],
+      hasKeyOrdering: Boolean,
       numPartitions: Int): Serializer = {
-    // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
-    // through write(key) and then write(value) instead of write((key, value)). Because
-    // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
-    // it when spillToMergeableFile in ExternalSorter will be used.
-    // So, we will not use SparkSqlSerializer2 when
-    //  - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
-    //    then the bypassMergeThreshold; or
-    //  - newOrdering is defined.
-    val cannotUseSqlSerializer2 =
-      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
-
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
     val noField =
@@ -104,14 +94,13 @@ case class Exchange(

     val useSqlSerializer2 =
       child.sqlContext.conf.useSqlSerializer2 &&    // SparkSqlSerializer2 is enabled.
-        !cannotUseSqlSerializer2 &&                 // Safe to use Serializer2.
        SparkSqlSerializer2.support(keySchema) &&    // The schema of key is supported.
        SparkSqlSerializer2.support(valueSchema) &&  // The schema of value is supported.
        !noField

     val serializer = if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
@@ -154,7 +143,8 @@ case class Exchange(
        }
        val keySchema = expressions.map(_.dataType).toArray
        val valueSchema = child.output.map(_.dataType).toArray
-       shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+       shuffled.setSerializer(
+         serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))

        shuffled.map(_._2)

@@ -179,7 +169,8 @@ case class Exchange(
          new ShuffledRDD[Row, Null, Null](rdd, part)
        }
        val keySchema = child.output.map(_.dataType).toArray
-       shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+       shuffled.setSerializer(
+         serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))

        shuffled.map(_._1)

@@ -199,7 +190,7 @@ case class Exchange(
        val partitioner = new HashPartitioner(1)
        val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
        val valueSchema = child.output.map(_.dataType).toArray
-       shuffled.setSerializer(serializer(null, valueSchema, 1))
+       shuffled.setSerializer(serializer(null, valueSchema, false, 1))
        shuffled.map(_._2)

      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
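With the cannotUseSqlSerializer2 guard removed, the choice above comes down to the conf flag, schema support for both key and value, and there being at least one field; the ordering concern is instead passed into the serializer as hasKeyOrdering (newOrdering.nonEmpty). A toy restatement of that predicate, using placeholder type-name strings and a stubbed supports check rather than the real SparkSqlSerializer2.support:

object SerializerChoiceSketch {
  // Stand-in for SparkSqlSerializer2.support: pretend a small set of type names is supported.
  private def supports(schema: Array[String]): Boolean =
    schema == null || schema.forall(Set("int", "long", "double", "string"))

  // Mirrors the post-patch predicate: enabled, both schemas supported, and at least one field.
  def useSerializer2(
      enabled: Boolean,
      keySchema: Array[String],
      valueSchema: Array[String]): Boolean = {
    val noField =
      (keySchema == null || keySchema.isEmpty) &&
        (valueSchema == null || valueSchema.isEmpty)
    enabled && supports(keySchema) && supports(valueSchema) && !noField
  }

  def main(args: Array[String]): Unit = {
    println(useSerializer2(true, Array("int"), Array("string", "long"))) // true
    println(useSerializer2(true, null, null))                            // false: nothing to write
  }
}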

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

Lines changed: 47 additions & 27 deletions
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._

 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {

-  val rowOut = new DataOutputStream(out)
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,41 +86,55 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging {

-  val rowIn = new DataInputStream(new BufferedInputStream(in))
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }

-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }

   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
   }

   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }

   override def close(): Unit = {
     rowIn.close()
   }
 }

-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {

   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
   }

   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }

@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
  * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable{

-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)

   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
        var i = 0
        while (i < schema.length) {
          schema(i) match {
@@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
          }
          i += 1
        }
+
+        mutableRow
      }
    }
  }
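The structural change above is that createDeserializationFunction no longer closes over a fixed SpecificMutableRow; it returns a MutableRow => Row function, and the stream decides (via rowGenerator) whether each call gets a fresh row or a reused one. A simplified sketch of that caller-supplies-the-row pattern using plain JDK streams; IntRow and the fixed field layout are illustrative stand-ins, not Spark's row format:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object RowCodecSketch {
  // Illustrative stand-in for a row: a fixed-size array of Int fields.
  type IntRow = Array[Int]

  // Write each field of the row to the stream, in schema order.
  def write(out: DataOutputStream, row: IntRow): Unit = row.foreach(out.writeInt)

  // As in the patched createDeserializationFunction, the reader takes the row to fill
  // and returns it, so the caller decides whether rows are fresh or reused.
  def reader(numFields: Int, in: DataInputStream): IntRow => IntRow =
    (row: IntRow) => {
      var i = 0
      while (i < numFields) { row(i) = in.readInt(); i += 1 }
      row
    }

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    write(out, Array(1, 2, 3))
    write(out, Array(4, 5, 6))

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val read = reader(3, in)

    // Reusing one row is fine when nothing downstream keeps references; pass a fresh
    // Array per call otherwise, mirroring the hasKeyOrdering branch in the patch.
    val row = new Array[Int](3)
    println(read(row).mkString(", ")) // 1, 2, 3
    println(read(row).mkString(", ")) // 4, 5, 6
  }
}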

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala

Lines changed: 15 additions & 15 deletions
@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       table("shuffle").collect())
   }

+  test("key schema is null") {
+    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+  }
+
   test("value schema is null") {
     val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Sort merge will not be triggered.
-    sql("set spark.sql.shuffle.partitions = 200")
-  }
-
-  test("key schema is null") {
-    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sql(s"SELECT $aggregations FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
   }
 }

 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {

-  // We are expecting SparkSqlSerializer.
-  override val serializerClass: Class[Serializer] =
-    classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     // To trigger the sort merge.
-    sql("set spark.sql.shuffle.partitions = 201")
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
   }
 }
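The suites above now derive spark.sql.shuffle.partitions from the configured spark.shuffle.sort.bypassMergeThreshold (one below it to avoid sort merge, one above it to trigger it) instead of hard-coding 200 and 201, so each suite keeps exercising the intended shuffle path even if the default threshold changes. A small sketch of that derivation outside Spark; the conf map and getInt helper merely stand in for SparkConf:

object ThresholdDrivenPartitionsSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for SparkConf: look up an Int setting with a default.
    val conf = Map("spark.shuffle.sort.bypassMergeThreshold" -> "200")
    def getInt(key: String, default: Int): Int =
      conf.get(key).map(_.toInt).getOrElse(default)

    val bypassMergeThreshold = getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

    // Stay under the threshold to keep the bypass-merge path (no sort merge),
    // go over it to force the sort-merge path, as the two suites do.
    println(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold - 1}")
    println(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
  }
}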
