1717
1818package org .apache .spark .sql .catalyst .expressions
1919
20+ import java .util .Arrays
21+
2022import org .scalatest .{FunSuite , Matchers }
2123
22- import org .apache .spark .sql .types .{ StringType , DataType , LongType , IntegerType }
24+ import org .apache .spark .sql .types ._
2325import org .apache .spark .unsafe .PlatformDependent
2426import org .apache .spark .unsafe .array .ByteArrayMethods
2527
2628class UnsafeRowConverterSuite extends FunSuite with Matchers {
2729
2830 test(" basic conversion with only primitive types" ) {
2931 val fieldTypes : Array [DataType ] = Array (LongType , LongType , IntegerType )
32+ val converter = new UnsafeRowConverter (fieldTypes)
33+
3034 val row = new SpecificMutableRow (fieldTypes)
3135 row.setLong(0 , 0 )
3236 row.setLong(1 , 1 )
3337 row.setInt(2 , 2 )
34- val converter = new UnsafeRowConverter (fieldTypes)
38+
3539 val sizeRequired : Int = converter.getSizeRequirement(row)
3640 sizeRequired should be (8 + (3 * 8 ))
3741 val buffer : Array [Long ] = new Array [Long ](sizeRequired / 8 )
3842 val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent .LONG_ARRAY_OFFSET )
3943 numBytesWritten should be (sizeRequired)
44+
4045 val unsafeRow = new UnsafeRow ()
4146 unsafeRow.pointTo(buffer, PlatformDependent .LONG_ARRAY_OFFSET , fieldTypes.length, null )
4247 unsafeRow.getLong(0 ) should be (0 )
@@ -46,22 +51,83 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
4651
// Writes a row holding one long column and two string columns through UnsafeRowConverter,
// then reads each field back through an UnsafeRow pointed at the same buffer.
test(" basic conversion with primitive and string types" ) {
  val schema: Array[DataType] = Array(LongType, StringType, StringType)
  val rowConverter = new UnsafeRowConverter(schema)

  val input = new SpecificMutableRow(schema)
  input.setLong(0, 0)
  input.setString(1, " Hello" )
  input.setString(2, " World" )

  // Size this test expects a string to contribute to the variable-length section: the string's
  // byte length plus 8, rounded to a whole number of words.
  def paddedStringSize(s: String): Int =
    ByteArrayMethods.roundNumberOfBytesToNearestWord(s.getBytes.length + 8)

  // Expected total: one leading 8-byte word, one 8-byte slot per column, then both strings.
  val expectedSize = 8 + (8 * 3) + paddedStringSize(" Hello" ) + paddedStringSize(" World" )

  val requiredBytes: Int = rowConverter.getSizeRequirement(input)
  requiredBytes shouldBe expectedSize

  // The buffer is a long array, so its length is the byte size divided by 8.
  val words: Array[Long] = new Array[Long](requiredBytes / 8)
  val written = rowConverter.writeRow(input, words, PlatformDependent.LONG_ARRAY_OFFSET)
  written shouldBe requiredBytes

  val readBack = new UnsafeRow()
  readBack.pointTo(words, PlatformDependent.LONG_ARRAY_OFFSET, schema.length, null)
  readBack.getLong(0) shouldBe 0
  readBack.getString(1) shouldBe " Hello"
  readBack.getString(2) shouldBe " World"
}
75+
// Checks two things about null primitive columns:
//   1. a row whose columns are all null round-trips through UnsafeRowConverter, reports every
//      column as null, and yields the placeholder values asserted below when read back;
//   2. nulling out the columns of an UnsafeRow that was written with non-null values leaves its
//      buffer identical to the buffer produced from the all-null row.
test(" null handling" ) {
  val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
  val converter = new UnsafeRowConverter(fieldTypes)

  val rowWithAllNullColumns: Row = {
    val r = new SpecificMutableRow(fieldTypes)
    // Iterate over the schema rather than a hard-coded range so the loop tracks fieldTypes.
    fieldTypes.indices.foreach(i => r.setNullAt(i))
    r
  }

  val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
  val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
  val numBytesWritten = converter.writeRow(
    rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
  numBytesWritten should be (sizeRequired)

  val createdFromNull = new UnsafeRow()
  createdFromNull.pointTo(
    createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
  fieldTypes.indices.foreach { i =>
    assert(createdFromNull.isNullAt(i))
  }
  // Values this test expects the typed getters to return for null columns.
  createdFromNull.getInt(0) should be (0)
  createdFromNull.getLong(1) should be (0)
  assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
  // Column 3 is a DoubleType column, so it must be read with getDouble (not getFloat).
  assert(java.lang.Double.isNaN(createdFromNull.getDouble(3)))

  // If we have an UnsafeRow with columns that are initially non-null and we null out those
  // columns, then the serialized row representation should be identical to what we would get by
  // creating an entirely null row via the converter.
  val rowWithNoNullColumns: Row = {
    val r = new SpecificMutableRow(fieldTypes)
    r.setInt(0, 100)
    r.setLong(1, 200)
    r.setFloat(2, 300)
    r.setDouble(3, 400)
    r
  }
  // Reuses sizeRequired from the all-null row for the second buffer.
  val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
  converter.writeRow(
    rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
  val setToNullAfterCreation = new UnsafeRow()
  setToNullAfterCreation.pointTo(
    setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
  // Sanity check: the non-null values round-trip before we null them out.
  setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
  setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
  setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
  setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))

  fieldTypes.indices.foreach(i => setToNullAfterCreation.setNullAt(i))
  // Compare the backing buffers element by element.
  assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
}
132+
67133}
0 commit comments