Commit 2776aca

First passing test for ExternalSorter.

1 parent 5e100b2 commit 2776aca
8 files changed: +209 −123 lines changed

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java

Lines changed: 7 additions & 3 deletions

@@ -30,7 +30,7 @@ public final class UnsafeExternalSortSpillMerger {
   public static abstract class MergeableIterator {
     public abstract boolean hasNext();
 
-    public abstract void advanceRecord();
+    public abstract void loadNextRecord();
 
     public abstract long getPrefix();
 
@@ -68,6 +68,9 @@ public int compare(MergeableIterator left, MergeableIterator right) {
   }
 
   public void addSpill(MergeableIterator spillReader) {
+    if (spillReader.hasNext()) {
+      spillReader.loadNextRecord();
+    }
     priorityQueue.add(spillReader);
   }
 
@@ -79,17 +82,18 @@ public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {
 
       @Override
       public boolean hasNext() {
-        return spillReader.hasNext() || !priorityQueue.isEmpty();
+        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
      }
 
       @Override
       public RecordAddressAndKeyPrefix next() {
         if (spillReader != null) {
           if (spillReader.hasNext()) {
+            spillReader.loadNextRecord();
             priorityQueue.add(spillReader);
           }
         }
-        spillReader = priorityQueue.poll();
+        spillReader = priorityQueue.remove();
         record.baseObject = spillReader.getBaseObject();
         record.baseOffset = spillReader.getBaseOffset();
         record.keyPrefix = spillReader.getPrefix();
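
The change above fixes the merge ordering: an iterator must have its first record loaded before it enters the priority queue, since the queue's comparator orders entries by their current key prefix, which is only valid after loadNextRecord(). The switch from poll() to remove() also makes an empty queue fail fast with NoSuchElementException instead of returning null. A minimal sketch of this "prime before enqueue" pattern, with a hypothetical Source class standing in for Spark's MergeableIterator:

import java.util.Comparator;
import java.util.PriorityQueue;

// Sketch of the "prime before enqueue" k-way merge pattern; Source and its
// fields are hypothetical stand-ins for MergeableIterator, not Spark's API.
final class KWayMergeSketch {
  static abstract class Source {
    long currentKey;                  // valid only after loadNext()
    abstract boolean hasNext();
    abstract void loadNext();         // advances and refreshes currentKey
  }

  private final PriorityQueue<Source> queue =
      new PriorityQueue<>(Comparator.comparingLong((Source s) -> s.currentKey));

  void addSource(Source s) {
    if (s.hasNext()) {
      s.loadNext();                   // prime: the comparator needs a real key
      queue.add(s);
    }
  }

  long nextSmallestKey() {
    final Source s = queue.remove();  // remove() fails fast on an empty queue
    final long key = s.currentKey;
    if (s.hasNext()) {
      s.loadNext();
      queue.add(s);                   // re-enqueue under its next key
    }
    return key;
  }
}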

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 4 additions & 1 deletion

@@ -38,7 +38,6 @@
 public final class UnsafeExternalSorter {
 
   private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
-  private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
 
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
@@ -92,6 +91,7 @@ private void openSorter() {
   public void spill() throws IOException {
     final UnsafeSorterSpillWriter spillWriter =
       new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics);
+    spillWriters.add(spillWriter);
     final Iterator<RecordPointerAndKeyPrefix> sortedRecords = sorter.getSortedIterator();
     while (sortedRecords.hasNext()) {
       final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
@@ -110,8 +110,11 @@ private void freeMemory() {
     final Iterator<MemoryBlock> iter = allocatedPages.iterator();
     while (iter.hasNext()) {
       memoryManager.freePage(iter.next());
+      shuffleMemoryManager.release(PAGE_SIZE);
       iter.remove();
     }
+    currentPage = null;
+    currentPagePosition = -1;
   }
 
   private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
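
Three bookkeeping fixes here: each new spill writer is tracked in spillWriters so its file can be merged later, every page freed through the memory manager is also released back to the shuffleMemoryManager, and the current-page cursor is reset so nothing can keep writing into a freed page. A hedged sketch of the paired acquire/release pattern (Ledger and the long[] pages are stand-ins, not Spark classes):

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Sketch of paired free/release bookkeeping: every page handed back to the
// allocator is also returned to the byte-count ledger, and the bump-allocator
// cursor is cleared so stale state cannot outlive the pages it pointed into.
final class PageBookkeepingSketch {
  static final int PAGE_SIZE = 1024 * 1024;

  interface Ledger {
    void release(long bytes);         // stand-in for ShuffleMemoryManager
  }

  private final List<long[]> allocatedPages = new ArrayList<>();
  private final Ledger ledger;
  private long[] currentPage = null;
  private long currentPagePosition = -1;

  PageBookkeepingSketch(Ledger ledger) {
    this.ledger = ledger;
  }

  void freeMemory() {
    final Iterator<long[]> iter = allocatedPages.iterator();
    while (iter.hasNext()) {
      iter.next();                    // the real code frees the page here
      ledger.release(PAGE_SIZE);      // give the bytes back to the ledger
      iter.remove();
    }
    currentPage = null;               // reset the cursor so freed pages
    currentPagePosition = -1;         // can never be written to again
  }
}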

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java

Lines changed: 6 additions & 4 deletions

@@ -169,7 +169,8 @@ public void remove() {
 
   public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() {
     sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
-    return new UnsafeExternalSortSpillMerger.MergeableIterator() {
+    UnsafeExternalSortSpillMerger.MergeableIterator iter =
+      new UnsafeExternalSortSpillMerger.MergeableIterator() {
 
       private int position = 0;
       private Object baseObject;
@@ -182,12 +183,12 @@ public boolean hasNext() {
       }
 
       @Override
-      public void advanceRecord() {
+      public void loadNextRecord() {
         final long recordPointer = sortBuffer[position];
-        baseObject = memoryManager.getPage(recordPointer);
-        baseOffset = memoryManager.getOffsetInPage(recordPointer);
         keyPrefix = sortBuffer[position + 1];
         position += 2;
+        baseObject = memoryManager.getPage(recordPointer);
+        baseOffset = memoryManager.getOffsetInPage(recordPointer);
       }
 
       @Override
@@ -205,5 +206,6 @@ public long getBaseOffset() {
         return baseOffset;
       }
     };
+    return iter;
   }
 }
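
Judging from the iterator above, the in-memory sort buffer is a long[] holding (recordPointer, keyPrefix) pairs, which is why sortBufferInsertPosition / 2 is the record count passed to the sorter. A small illustrative walk over that layout (the pointer and prefix values below are made up):

// Walks a sort buffer of (recordPointer, keyPrefix) long pairs, the layout
// implied by the iterator above; array contents are purely illustrative.
final class SortBufferLayoutSketch {
  public static void main(String[] args) {
    final long[] sortBuffer = { 0L, 42L, 8L, 7L };    // two records
    final int sortBufferInsertPosition = 4;           // records * 2 longs
    for (int position = 0; position < sortBufferInsertPosition; position += 2) {
      final long recordPointer = sortBuffer[position];
      final long keyPrefix = sortBuffer[position + 1];
      // the real code decodes recordPointer into (page, offset) here
      System.out.println("pointer=" + recordPointer + " prefix=" + keyPrefix);
    }
  }
}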

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 9 additions & 16 deletions

@@ -18,15 +18,9 @@
 package org.apache.spark.unsafe.sort;
 
 import com.google.common.io.ByteStreams;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.serializer.JavaSerializerInstance;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
-import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
-import scala.Tuple2;
 
 import java.io.*;
 
@@ -39,18 +33,19 @@ public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger
   private long keyPrefix;
   private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
   private final Object baseObject = arr;
+  private int nextRecordLength;
   private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
       BlockManager blockManager,
       File file,
       BlockId blockId) throws IOException {
     this.file = file;
+    assert (file.length() > 0);
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
     this.in = blockManager.wrapForCompression(blockId, bs);
     this.din = new DataInputStream(this.in);
-    assert (file.length() > 0);
-    advanceRecord();
+    nextRecordLength = din.readInt();
   }
 
   @Override
@@ -59,21 +54,19 @@ public boolean hasNext() {
   }
 
   @Override
-  public void advanceRecord() {
+  public void loadNextRecord() {
     try {
-      final int recordLength = din.readInt();
-      if (recordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+      keyPrefix = din.readLong();
+      ByteStreams.readFully(in, arr, 0, nextRecordLength);
+      nextRecordLength = din.readInt();
+      if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
         in.close();
         in = null;
-        return;
+        din = null;
       }
-      keyPrefix = din.readLong();
-      ByteStreams.readFully(in, arr, 0, recordLength);
-
     } catch (Exception e) {
       PlatformDependent.throwException(e);
    }
-    throw new IllegalStateException();
   }
 
   @Override
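
The reader now reads each record's length one step ahead: the constructor reads the first length, and loadNextRecord() reads the current record's prefix and payload followed by the next record's length, closing the stream when the look-ahead hits the EOF marker. This also removes the old advanceRecord()'s unconditional throw new IllegalStateException(), which fired after every successful read. A sketch of the framing protocol, assuming EOF_MARKER is a sentinel length (the value below is illustrative, the real constant lives in UnsafeSorterSpillWriter):

import java.io.DataInputStream;
import java.io.IOException;

// Sketch of the spill-file framing the reader above consumes:
// [int length][long keyPrefix][length bytes] ... [int EOF_MARKER].
final class SpillFramingReadSketch {
  static final int EOF_MARKER = -1;   // assumption: a length no record can have

  static void readAll(DataInputStream din, byte[] buf) throws IOException {
    int nextRecordLength = din.readInt();       // read ahead: primes hasNext()
    while (nextRecordLength != EOF_MARKER) {
      final long keyPrefix = din.readLong();
      din.readFully(buf, 0, nextRecordLength);
      // hand (keyPrefix, buf[0..nextRecordLength)) to the merger here
      nextRecordLength = din.readInt();         // peek at the next record
    }
  }
}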

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java

Lines changed: 48 additions & 6 deletions

@@ -18,18 +18,20 @@
 package org.apache.spark.unsafe.sort;
 
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DeserializationStream;
 import org.apache.spark.serializer.JavaSerializerInstance;
+import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.BlockObjectWriter;
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
 import scala.Tuple2;
+import scala.reflect.ClassTag;
 
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.IOException;
+import java.io.*;
+import java.nio.ByteBuffer;
 
 public final class UnsafeSorterSpillWriter {
 
@@ -51,7 +53,47 @@ public UnsafeSorterSpillWriter(
     this.file = spilledFileInfo._2();
     this.blockId = spilledFileInfo._1();
     // Dummy serializer:
-    final SerializerInstance ser = new JavaSerializerInstance(0, false, null);
+    final SerializerInstance ser = new SerializerInstance() {
+      @Override
+      public SerializationStream serializeStream(OutputStream s) {
+        return new SerializationStream() {
+          @Override
+          public void flush() {
+
+          }
+
+          @Override
+          public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+            return null;
+          }
+
+          @Override
+          public void close() {
+
+          }
+        };
+      }
+
+      @Override
+      public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+        return null;
+      }
+
+      @Override
+      public DeserializationStream deserializeStream(InputStream s) {
+        return null;
+      }
+
+      @Override
+      public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+        return null;
+      }
+
+      @Override
+      public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+        return null;
+      }
+    };
     writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
     dos = new DataOutputStream(writer);
   }
@@ -61,14 +103,14 @@ public void write(
       long baseOffset,
       int recordLength,
       long keyPrefix) throws IOException {
+    dos.writeInt(recordLength);
+    dos.writeLong(keyPrefix);
     PlatformDependent.copyMemory(
       baseObject,
       baseOffset + 4,
       arr,
       PlatformDependent.BYTE_ARRAY_OFFSET,
       recordLength);
-    dos.writeInt(recordLength);
-    dos.writeLong(keyPrefix);
     writer.write(arr, 0, recordLength);
     // TODO: add a test that detects whether we leave this call out:
     writer.recordWritten();
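
On the write side, the header now goes out before the payload, matching the reader's read-ahead: an int length, a long key prefix, then the raw record bytes. The no-op SerializerInstance exists only because blockManager.getDiskWriter requires a serializer, while this writer frames raw bytes itself. A compact sketch of the write path (the stream wiring is a simplified stand-in for the DiskBlockObjectWriter/DataOutputStream pair above):

import java.io.DataOutputStream;
import java.io.IOException;

// Sketch of the write side of the spill framing: fixed-width header first,
// then the raw record bytes, so the reader can always read the length ahead.
final class SpillFramingWriteSketch {
  static void writeRecord(DataOutputStream dos, byte[] record, long keyPrefix)
      throws IOException {
    dos.writeInt(record.length);   // length first: enables the reader's look-ahead
    dos.writeLong(keyPrefix);      // fixed-width prefix for cheap comparisons
    dos.write(record, 0, record.length);
  }
}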

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 7 additions & 1 deletion

@@ -223,7 +223,13 @@ private[spark] class DiskBlockObjectWriter(
     }
   }
 
-  override def write(b: Int): Unit = throw new UnsupportedOperationException()
+  override def write(b: Int): Unit = {
+    if (!initialized) {
+      open()
+    }
+
+    bs.write(b)
+  }
 
   override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (!initialized) {
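
This override matters because the spill writer wraps the DiskBlockObjectWriter in a DataOutputStream, and in the stock JDK implementation DataOutputStream.writeInt emits its four bytes through single-byte write(int) calls on the wrapped stream; with the old throwing implementation, the header writes above would fail. A minimal demonstration (not Spark code):

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Shows why the single-byte write path must work: wrapping a stream whose
// write(int) throws breaks DataOutputStream.writeInt, which delegates to it.
final class SingleByteWriteSketch {
  public static void main(String[] args) throws IOException {
    final OutputStream throwing = new OutputStream() {
      @Override
      public void write(int b) {
        throw new UnsupportedOperationException("no single-byte writes");
      }
    };
    final DataOutputStream dos = new DataOutputStream(throwing);
    dos.writeInt(42);   // throws: each of the four bytes goes through write(int)
  }
}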
