Skip to content

Commit 6ffdaa1

Browse files
committed
Null handling improvements in UnsafeRow.
1 parent 31eaabc commit 6ffdaa1

File tree

3 files changed

+95
-17
lines changed

3 files changed

+95
-17
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ private void assertIndexIsValid(int index) {
145145
public void setNullAt(int i) {
146146
assertIndexIsValid(i);
147147
BitSetMethods.set(baseObject, baseOffset, i);
148+
// To preserve row equality, zero out the value when setting the column to null.
149+
// Since this row does does not currently support updates to variable-length values, we don't
150+
// have to worry about zeroing out that data.
151+
PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
148152
}
149153

150154
private void setNotNullAt(int i) {
@@ -288,13 +292,21 @@ public long getLong(int i) {
288292
@Override
289293
public float getFloat(int i) {
290294
assertIndexIsValid(i);
291-
return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
295+
if (isNullAt(i)) {
296+
return Float.NaN;
297+
} else {
298+
return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
299+
}
292300
}
293301

294302
@Override
295303
public double getDouble(int i) {
296304
assertIndexIsValid(i);
297-
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
305+
if (isNullAt(i)) {
306+
return Float.NaN;
307+
} else {
308+
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
309+
}
298310
}
299311

300312
public UTF8String getUTF8String(int i) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
7474
while (fieldNumber < writers.length) {
7575
if (row.isNullAt(fieldNumber)) {
7676
unsafeRow.setNullAt(fieldNumber)
77-
// TODO: type-specific null value writing?
7877
} else {
7978
appendCursor += writers(fieldNumber).write(
8079
row(fieldNumber),
@@ -122,11 +121,6 @@ private abstract class UnsafeColumnWriter[T] {
122121
}
123122

124123
private object UnsafeColumnWriter {
125-
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
126-
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
127-
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
128-
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
129-
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
130124

131125
def forType(dataType: DataType): UnsafeColumnWriter[_] = {
132126
dataType match {
@@ -143,6 +137,12 @@ private object UnsafeColumnWriter {
143137

144138
// ------------------------------------------------------------------------------------------------
145139

140+
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
141+
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
142+
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
143+
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
144+
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
145+
146146
private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
147147
def getSize(value: T): Int = 0
148148
}
@@ -205,12 +205,12 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8
205205
}
206206

207207
override def write(
208-
value: UTF8String,
209-
columnNumber: Int,
210-
row: UnsafeRow,
211-
baseObject: Object,
212-
baseOffset: Long,
213-
appendCursor: Int): Int = {
208+
value: UTF8String,
209+
columnNumber: Int,
210+
row: UnsafeRow,
211+
baseObject: Object,
212+
baseOffset: Long,
213+
appendCursor: Int): Int = {
214214
val numBytes = value.getBytes.length
215215
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
216216
PlatformDependent.copyMemory(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,31 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.util.Arrays
21+
2022
import org.scalatest.{FunSuite, Matchers}
2123

22-
import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
24+
import org.apache.spark.sql.types._
2325
import org.apache.spark.unsafe.PlatformDependent
2426
import org.apache.spark.unsafe.array.ByteArrayMethods
2527

2628
class UnsafeRowConverterSuite extends FunSuite with Matchers {
2729

2830
test("basic conversion with only primitive types") {
2931
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
32+
val converter = new UnsafeRowConverter(fieldTypes)
33+
3034
val row = new SpecificMutableRow(fieldTypes)
3135
row.setLong(0, 0)
3236
row.setLong(1, 1)
3337
row.setInt(2, 2)
34-
val converter = new UnsafeRowConverter(fieldTypes)
38+
3539
val sizeRequired: Int = converter.getSizeRequirement(row)
3640
sizeRequired should be (8 + (3 * 8))
3741
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
3842
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
3943
numBytesWritten should be (sizeRequired)
44+
4045
val unsafeRow = new UnsafeRow()
4146
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
4247
unsafeRow.getLong(0) should be (0)
@@ -46,22 +51,83 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
4651

4752
test("basic conversion with primitive and string types") {
  // Schema under test: one long column followed by two string columns.
  val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
  val converter = new UnsafeRowConverter(fieldTypes)

  val row = new SpecificMutableRow(fieldTypes)
  row.setLong(0, 0)
  row.setString(1, "Hello")
  row.setString(2, "World")

  // Each string contributes its byte length plus 8 extra bytes, rounded up to a whole word.
  val helloSize =
    ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8)
  val worldSize =
    ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)

  val sizeRequired: Int = converter.getSizeRequirement(row)
  sizeRequired should be (8 + (8 * 3) + helloSize + worldSize)

  // Serialize into a word-aligned buffer; the converter must fill exactly the reported size.
  val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
  val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
  numBytesWritten should be (sizeRequired)

  // Read the values back through an UnsafeRow pointed at the same buffer.
  val unsafeRow = new UnsafeRow()
  unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
  unsafeRow.getLong(0) should be (0)
  unsafeRow.getString(1) should be ("Hello")
  unsafeRow.getString(2) should be ("World")
}
75+
76+
test("null handling") {
  // One column of each primitive type whose null read-back value is checked below.
  val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
  val converter = new UnsafeRowConverter(fieldTypes)

  val rowWithAllNullColumns: Row = {
    val r = new SpecificMutableRow(fieldTypes)
    for (i <- fieldTypes.indices) {
      r.setNullAt(i)
    }
    r
  }

  val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
  val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
  val numBytesWritten = converter.writeRow(
    rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
  numBytesWritten should be (sizeRequired)

  val createdFromNull = new UnsafeRow()
  createdFromNull.pointTo(
    createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
  for (i <- fieldTypes.indices) {
    assert(createdFromNull.isNullAt(i))
  }
  // Null integral columns read back as 0; null floating-point columns read back as NaN.
  createdFromNull.getInt(0) should be (0)
  createdFromNull.getLong(1) should be (0)
  assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
  // Column 3 is DoubleType, so it must be read with getDouble. Reading it with getFloat
  // would leave the double accessor's null branch untested.
  assert(java.lang.Double.isNaN(createdFromNull.getDouble(3)))

  // If we have an UnsafeRow with columns that are initially non-null and we null out those
  // columns, then the serialized row representation should be identical to what we would get by
  // creating an entirely null row via the converter
  val rowWithNoNullColumns: Row = {
    val r = new SpecificMutableRow(fieldTypes)
    r.setInt(0, 100)
    r.setLong(1, 200)
    r.setFloat(2, 300)
    r.setDouble(3, 400)
    r
  }
  val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
  converter.writeRow(
    rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
  val setToNullAfterCreation = new UnsafeRow()
  setToNullAfterCreation.pointTo(
    setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
  // Sanity check: before nulling, every column round-trips its original value.
  setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
  setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
  setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
  setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))

  for (i <- fieldTypes.indices) {
    setToNullAfterCreation.setNullAt(i)
  }
  // setNullAt zeroes the stored value, so both buffers must now match word for word.
  assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
}
132+
67133
}

0 commit comments

Comments
 (0)