@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {
 
-  val rowOut = new DataOutputStream(out)
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,41 +86,55 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging {
 
-  val rowIn = new DataInputStream(new BufferedInputStream(in))
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }
 
-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
     rowIn.close()
   }
 }
 
-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
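The new hasKeyOrdering flag exists because a shuffle that sorts by key buffers the key objects it has already deserialized while ordering them; if every record is deserialized into the same reused mutable row, all buffered references alias one object and earlier keys are silently overwritten. A minimal, self-contained sketch of that hazard, using plain Scala arrays rather than Spark's row classes (ReuseHazardSketch and everything in it are illustrative, not part of this patch):

import scala.collection.mutable.ArrayBuffer

object ReuseHazardSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for a sort-based reader: it retains references to rows it has seen.
    val buffered = ArrayBuffer.empty[Array[Any]]

    // Unsafe pattern: one mutable row reused for every record.
    val reused = new Array[Any](1)
    for (key <- Seq(3, 1, 2)) {
      reused(0) = key        // deserialize into the shared buffer
      buffered += reused     // the sorter keeps a reference, not a copy
    }
    println(buffered.map(_(0)))  // ArrayBuffer(2, 2, 2): earlier keys were overwritten

    buffered.clear()

    // Safe pattern: allocate a fresh row per record, as rowGenerator does when hasKeyOrdering is true.
    for (key <- Seq(3, 1, 2)) {
      val fresh = new Array[Any](1)
      fresh(0) = key
      buffered += fresh
    }
    println(buffered.map(_(0)))  // ArrayBuffer(3, 1, 2)
  }
}

When no key ordering is set, each row is consumed before the next is read and never retained, so reusing a single SpecificMutableRow remains safe; that is the branch rowGenerator keeps for the common case.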
@@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }
 
@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
  * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable {
 
-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
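The code that constructs SparkSqlSerializer2 is outside these hunks; the extra constructor argument only requires that the caller know whether the ShuffledRDD the serializer serves has a key ordering set. A hedged sketch of such a call site (chooseSerializer and keyOrderingDefined are invented names used for illustration, not identifiers from this commit):

import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.types.DataType

// Hypothetical helper: records whether the downstream shuffle will sort, and therefore
// buffer, the deserialized keys.
def chooseSerializer(
    keySchema: Array[DataType],
    valueSchema: Array[DataType],
    keyOrderingDefined: Boolean): Serializer =
  new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering = keyOrderingDefined)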
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
         var i = 0
         while (i < schema.length) {
           schema(i) match {
@@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
           }
           i += 1
         }
+
+        mutableRow
       }
     }
   }
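The per-type match arms of createDeserializationFunction sit between the last two hunks and are unchanged, so the diff hides them. A minimal sketch of the shape the new signature implies, restricted to IntegerType columns and ignoring null markers (intOnlyDeserializationFunction is an invented name; the real method covers every supported DataType):

import java.io.DataInputStream
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.types._

// Sketch only: same structure as the patched method, but handling a single column type.
def intOnlyDeserializationFunction(
    schema: Array[DataType],
    in: DataInputStream): (MutableRow) => Row = {
  if (schema == null) {
    // No schema means there is nothing to read; the row passes through as null.
    (mutableRow: MutableRow) => null
  } else {
    (mutableRow: MutableRow) => {
      var i = 0
      while (i < schema.length) {
        schema(i) match {
          case IntegerType => mutableRow.setInt(i, in.readInt())
          case other => sys.error(s"Type not handled in this sketch: $other")
        }
        i += 1
      }
      // Return whichever row the caller supplied: a reused SpecificMutableRow or a
      // fresh GenericMutableRow, depending on hasKeyOrdering.
      mutableRow
    }
  }
}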