Skip to content

Commit 2488007

Browse files
committed
Merge remote-tracking branch 'apache/master' into SPARK-23951
# Conflicts: # sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
2 parents 45d3ed8 + 3323b15 commit 2488007

File tree

5 files changed

+210
-248
lines changed

5 files changed

+210
-248
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ private[spark] class NodeIdCache(
9595
splits: Array[Array[Split]]): Unit = {
9696
if (prevNodeIdsForInstances != null) {
9797
// Unpersist the previous one if one exists.
98-
prevNodeIdsForInstances.unpersist()
98+
prevNodeIdsForInstances.unpersist(false)
9999
}
100100

101101
prevNodeIdsForInstances = nodeIdsForInstances
@@ -166,9 +166,13 @@ private[spark] class NodeIdCache(
166166
}
167167
}
168168
}
169+
if (nodeIdsForInstances != null) {
170+
// Unpersist current one if one exists.
171+
nodeIdsForInstances.unpersist(false)
172+
}
169173
if (prevNodeIdsForInstances != null) {
170174
// Unpersist the previous one if one exists.
171-
prevNodeIdsForInstances.unpersist()
175+
prevNodeIdsForInstances.unpersist(false)
172176
}
173177
}
174178
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.expressions.codegen;
1818

19+
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
20+
import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
21+
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
1922
import org.apache.spark.sql.types.Decimal;
2023
import org.apache.spark.unsafe.Platform;
2124
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -103,42 +106,27 @@ protected final void zeroOutPaddingBytes(int numBytes) {
103106
public abstract void write(int ordinal, Decimal input, int precision, int scale);
104107

105108
public final void write(int ordinal, UTF8String input) {
106-
final int numBytes = input.numBytes();
107-
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
108-
109-
// grow the global buffer before writing data.
110-
grow(roundedSize);
111-
112-
zeroOutPaddingBytes(numBytes);
113-
114-
// Write the bytes to the variable length portion.
115-
input.writeToMemory(getBuffer(), cursor());
116-
117-
setOffsetAndSize(ordinal, numBytes);
118-
119-
// move the cursor forward.
120-
increaseCursor(roundedSize);
109+
writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes());
121110
}
122111

123112
public final void write(int ordinal, byte[] input) {
124113
write(ordinal, input, 0, input.length);
125114
}
126115

127116
public final void write(int ordinal, byte[] input, int offset, int numBytes) {
128-
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
117+
writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes);
118+
}
129119

130-
// grow the global buffer before writing data.
120+
private void writeUnalignedBytes(
121+
int ordinal,
122+
Object baseObject,
123+
long baseOffset,
124+
int numBytes) {
125+
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
131126
grow(roundedSize);
132-
133127
zeroOutPaddingBytes(numBytes);
134-
135-
// Write the bytes to the variable length portion.
136-
Platform.copyMemory(
137-
input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes);
138-
128+
Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes);
139129
setOffsetAndSize(ordinal, numBytes);
140-
141-
// move the cursor forward.
142130
increaseCursor(roundedSize);
143131
}
144132

@@ -156,6 +144,40 @@ public final void write(int ordinal, CalendarInterval input) {
156144
increaseCursor(16);
157145
}
158146

147+
public final void write(int ordinal, UnsafeRow row) {
148+
writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes());
149+
}
150+
151+
public final void write(int ordinal, UnsafeMapData map) {
152+
writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes());
153+
}
154+
155+
public final void write(UnsafeArrayData array) {
156+
// Unsafe arrays can be written both as a regular array field and as part of a map. This makes
157+
// updating the offset and size dependent on the code path, which is why we currently do not
158+
// provide a method for writing unsafe arrays that also updates the size and offset.
159+
int numBytes = array.getSizeInBytes();
160+
grow(numBytes);
161+
Platform.copyMemory(
162+
array.getBaseObject(),
163+
array.getBaseOffset(),
164+
getBuffer(),
165+
cursor(),
166+
numBytes);
167+
increaseCursor(numBytes);
168+
}
169+
170+
private void writeAlignedBytes(
171+
int ordinal,
172+
Object baseObject,
173+
long baseOffset,
174+
int numBytes) {
175+
grow(numBytes);
176+
Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes);
177+
setOffsetAndSize(ordinal, numBytes);
178+
increaseCursor(numBytes);
179+
}
180+
159181
protected final void writeBoolean(long offset, boolean value) {
160182
Platform.putBoolean(getBuffer(), offset, value);
161183
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
173173
val rowWriter = new UnsafeRowWriter(writer, numFields)
174174
val structWriter = generateStructWriter(rowWriter, fields)
175175
(v, i) => {
176-
val previousCursor = writer.cursor()
177176
v.getStruct(i, fields.length) match {
178177
case row: UnsafeRow =>
179-
writeUnsafeData(
180-
rowWriter,
181-
row.getBaseObject,
182-
row.getBaseOffset,
183-
row.getSizeInBytes)
178+
writer.write(i, row)
184179
case row =>
180+
val previousCursor = writer.cursor()
185181
// Nested struct. We don't know where this will start because a row can be
186182
// variable length, so we need to update the offsets and zero out the bit mask.
187183
rowWriter.resetRowWriter()
188184
structWriter.apply(row)
185+
writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
189186
}
190-
writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
191187
}
192188

193189
case ArrayType(elementType, containsNull) =>
@@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
214210
valueType,
215211
valueContainsNull)
216212
(v, i) => {
217-
val previousCursor = writer.cursor()
218213
v.getMap(i) match {
219214
case map: UnsafeMapData =>
220-
writeUnsafeData(
221-
valueArrayWriter,
222-
map.getBaseObject,
223-
map.getBaseOffset,
224-
map.getSizeInBytes)
215+
writer.write(i, map)
225216
case map =>
217+
val previousCursor = writer.cursor()
218+
226219
// preserve 8 bytes to write the key array numBytes later.
227220
valueArrayWriter.grow(8)
228221
valueArrayWriter.increaseCursor(8)
@@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
237230

238231
// Write the values.
239232
writeArray(valueArrayWriter, valueWriter, map.valueArray())
233+
writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
240234
}
241-
writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
242235
}
243236

244237
case udt: UserDefinedType[_] =>
@@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
318311
elementWriter: (SpecializedGetters, Int) => Unit,
319312
array: ArrayData): Unit = array match {
320313
case unsafe: UnsafeArrayData =>
321-
writeUnsafeData(
322-
arrayWriter,
323-
unsafe.getBaseObject,
324-
unsafe.getBaseOffset,
325-
unsafe.getSizeInBytes)
314+
arrayWriter.write(unsafe)
326315
case _ =>
327316
val numElements = array.numElements()
328317
arrayWriter.initialize(numElements)
@@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
332321
i += 1
333322
}
334323
}
335-
336-
/**
337-
* Write an opaque block of data to the buffer. This is used to copy
338-
* [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
339-
*/
340-
private def writeUnsafeData(
341-
writer: UnsafeWriter,
342-
baseObject: AnyRef,
343-
baseOffset: Long,
344-
sizeInBytes: Int) : Unit = {
345-
writer.grow(sizeInBytes)
346-
Platform.copyMemory(
347-
baseObject,
348-
baseOffset,
349-
writer.getBuffer,
350-
writer.cursor,
351-
sizeInBytes)
352-
writer.increaseCursor(sizeInBytes)
353-
}
354324
}

0 commit comments

Comments
 (0)