Skip to content

Commit 4fb96e5

Browse files
jiangxb1987gatorsmile
authored andcommitted
[SPARK-25114][CORE] Fix RecordBinaryComparator when subtraction between two words is divisible by Integer.MAX_VALUE.
## What changes were proposed in this pull request? apache#22079 (comment) It is possible for two objects to be unequal and yet we consider them as equal with this code, if the long values are separated by Int.MaxValue. This PR fixes the issue. ## How was this patch tested? Add new test cases in `RecordBinaryComparatorSuite`. Closes apache#22101 from jiangxb1987/fix-rbc. Authored-by: Xingbo Jiang <[email protected]> Signed-off-by: Xiao Li <[email protected]>
1 parent f984ec7 commit 4fb96e5

File tree

2 files changed

+81
-11
lines changed

2 files changed

+81
-11
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222

2323
public final class RecordBinaryComparator extends RecordComparator {
2424

25-
// TODO(jiangxb) Add test suite for this.
2625
@Override
2726
public int compare(
2827
Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
2928
int i = 0;
30-
int res = 0;
3129

3230
// If the arrays have different length, the longer one is larger.
3331
if (leftLen != rightLen) {
@@ -40,27 +38,33 @@ public int compare(
4038
// check if stars align and we can get both offsets to be aligned
4139
if ((leftOff % 8) == (rightOff % 8)) {
4240
while ((leftOff + i) % 8 != 0 && i < leftLen) {
43-
res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
44-
(Platform.getByte(rightObj, rightOff + i) & 0xff);
45-
if (res != 0) return res;
41+
final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
42+
final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
43+
if (v1 != v2) {
44+
return v1 > v2 ? 1 : -1;
45+
}
4646
i += 1;
4747
}
4848
}
4949
// for architectures that support unaligned accesses, chew it up 8 bytes at a time
5050
if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
5151
while (i <= leftLen - 8) {
52-
res = (int) ((Platform.getLong(leftObj, leftOff + i) -
53-
Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE);
54-
if (res != 0) return res;
52+
final long v1 = Platform.getLong(leftObj, leftOff + i);
53+
final long v2 = Platform.getLong(rightObj, rightOff + i);
54+
if (v1 != v2) {
55+
return v1 > v2 ? 1 : -1;
56+
}
5557
i += 8;
5658
}
5759
}
5860
// this will finish off the unaligned comparisons, or do the entire aligned comparison
5961
// whichever is needed.
6062
while (i < leftLen) {
61-
res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
62-
(Platform.getByte(rightObj, rightOff + i) & 0xff);
63-
if (res != 0) return res;
63+
final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
64+
final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
65+
if (v1 != v2) {
66+
return v1 > v2 ? 1 : -1;
67+
}
6468
i += 1;
6569
}
6670

sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,4 +253,70 @@ public void testBinaryComparatorForNullColumns() throws Exception {
253253
assert(compare(0, 0) == 0);
254254
assert(compare(0, 1) > 0);
255255
}
256+
257+
@Test
258+
public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception {
259+
int numFields = 1;
260+
261+
UnsafeRow row1 = new UnsafeRow(numFields);
262+
byte[] data1 = new byte[100];
263+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
264+
row1.setLong(0, 11);
265+
266+
UnsafeRow row2 = new UnsafeRow(numFields);
267+
byte[] data2 = new byte[100];
268+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
269+
row2.setLong(0, 11L + Integer.MAX_VALUE);
270+
271+
insertRow(row1);
272+
insertRow(row2);
273+
274+
assert(compare(0, 1) < 0);
275+
}
276+
277+
@Test
278+
public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception {
279+
int numFields = 1;
280+
281+
UnsafeRow row1 = new UnsafeRow(numFields);
282+
byte[] data1 = new byte[100];
283+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
284+
row1.setLong(0, Long.MIN_VALUE);
285+
286+
UnsafeRow row2 = new UnsafeRow(numFields);
287+
byte[] data2 = new byte[100];
288+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
289+
row2.setLong(0, 1);
290+
291+
insertRow(row1);
292+
insertRow(row2);
293+
294+
assert(compare(0, 1) < 0);
295+
}
296+
297+
@Test
298+
public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception {
299+
int numFields = 4;
300+
301+
UnsafeRow row1 = new UnsafeRow(numFields);
302+
byte[] data1 = new byte[100];
303+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
304+
row1.setInt(0, 11);
305+
row1.setDouble(1, 3.14);
306+
row1.setInt(2, -1);
307+
row1.setLong(3, 0);
308+
309+
UnsafeRow row2 = new UnsafeRow(numFields);
310+
byte[] data2 = new byte[100];
311+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
312+
row2.setInt(0, 11);
313+
row2.setDouble(1, 3.14);
314+
row2.setInt(2, -1);
315+
row2.setLong(3, 1);
316+
317+
insertRow(row1);
318+
insertRow(row2);
319+
320+
assert(compare(0, 1) < 0);
321+
}
256322
}

0 commit comments

Comments
 (0)