
Commit 8385f95

Always create a new row at the deserialization side to work with sort merge join.
1 parent c7e2129 commit 8385f95

1 file changed, 14 additions and 17 deletions:

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**
@@ -91,34 +91,27 @@ private[sql] class Serializer2DeserializationStream(
 
   val rowIn = new DataInputStream(new BufferedInputStream(in))
 
-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(), readValueFunc()).asInstanceOf[T]
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc().asInstanceOf[T]
  }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc().asInstanceOf[T]
   }
 
   override def close(): Unit = {
     rowIn.close()
   }
 }
 
-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
     valueSchema: Array[DataType])
   extends SerializerInstance {
@@ -153,7 +146,7 @@ private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema:
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(keySchema, valueSchema)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,12 +316,12 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
+      in: DataInputStream): () => Row = {
     () => {
       // If the schema is null, the returned function does nothing when it get called.
       if (schema != null) {
         var i = 0
+        val mutableRow = new GenericMutableRow(schema.length)
         while (i < schema.length) {
           schema(i) match {
             // When we read values from the underlying stream, we also first read the null byte
@@ -440,6 +433,10 @@ private[sql] object SparkSqlSerializer2 {
           }
           i += 1
         }
+
+        mutableRow
+      } else {
+        null
       }
     }
   }
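
Why a fresh GenericMutableRow per record matters: a buffering consumer such as sort merge join holds on to the row references it has already read, so if the deserializer kept reusing a single mutable row, every buffered reference would point at the same object and show only the last record's values. The following is a minimal, self-contained sketch of that failure mode under assumed names (ReuseVsFreshRows, reusingReader, freshReader, bufferAll are illustrative and not part of the Spark code in this commit):

object ReuseVsFreshRows {

  // Stand-in for a deserialized record; not Spark's Row type.
  type DemoRow = Array[Any]

  // Reusing deserializer: fills the same array on every call,
  // like the old SpecificMutableRow that was shared across reads.
  def reusingReader(input: Iterator[Int]): Iterator[DemoRow] = {
    val shared = new Array[Any](1)
    input.map { v => shared(0) = v; shared }
  }

  // Fresh-row deserializer: allocates a new array per record,
  // mirroring the new `new GenericMutableRow(schema.length)` per call.
  def freshReader(input: Iterator[Int]): Iterator[DemoRow] =
    input.map { v =>
      val row = new Array[Any](1)
      row(0) = v
      row
    }

  // A consumer that buffers row references before using them,
  // the way sort merge join buffers rows sharing a join key.
  def bufferAll(rows: Iterator[DemoRow]): List[Seq[Any]] =
    rows.toList.map(_.toSeq)

  def main(args: Array[String]): Unit = {
    // Reused row: all buffered entries alias one array, so every
    // entry shows the last value written (3, 3, 3).
    println(bufferAll(reusingReader(Seq(1, 2, 3).iterator)))
    // Fresh rows: each buffered entry keeps its own value (1, 2, 3).
    println(bufferAll(freshReader(Seq(1, 2, 3).iterator)))
  }
}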
