diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 172b394689ca9..a839de0a91448 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -17,9 +17,12 @@
package org.apache.spark.util.sketch;
+import java.io.BufferedInputStream;
+import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.nio.ByteBuffer;
/**
* A Bloom filter is a space-efficient probabilistic data structure that offers an approximate
@@ -51,7 +54,22 @@ public enum Version {
*
* - The words/longs (numWords * 64 bit)
*
*/
- V1(1);
+ V1(1),
+
+ /**
+ * {@code BloomFilter} binary format version 2.
+ * Fixes the int32 truncation issue of V1 bit indexes; since this changes the resulting
+ * bit pattern, V2 is incompatible with V1 serializations.
+ * All values written in big-endian order:
+ *
+ * - Version number, always 2 (32 bit)
+ * - Number of hash functions (32 bit)
+ * - Integer seed to initialize hash functions (32 bit)
+ * - Total number of words of the underlying bit array (32 bit)
+ * - The words/longs (numWords * 64 bit)
+ *
+ */
+ V2(2);
private final int versionNumber;
@@ -175,14 +193,26 @@ public long cardinality() {
* the stream.
*/
public static BloomFilter readFrom(InputStream in) throws IOException {
- return BloomFilterImpl.readFrom(in);
+ // peek at the first 4 bytes of the InputStream (the big-endian version field)
+ // to determine the serialization version
+ BufferedInputStream bin = new BufferedInputStream(in);
+ bin.mark(4);
+ int version = ByteBuffer.wrap(bin.readNBytes(4)).getInt();
+ bin.reset();
+
+ return switch (version) {
+ case 1 -> BloomFilterImpl.readFrom(bin);
+ case 2 -> BloomFilterImplV2.readFrom(bin);
+ default -> throw new IllegalArgumentException("Unknown BloomFilter version: " + version);
+ };
}
/**
* Reads in a {@link BloomFilter} from a byte array.
*/
public static BloomFilter readFrom(byte[] bytes) throws IOException {
- return BloomFilterImpl.readFrom(bytes);
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) {
+ return readFrom(bis);
+ }
}
/**
@@ -256,6 +286,19 @@ public static BloomFilter create(long expectedNumItems, double fpp) {
* pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter.
*/
public static BloomFilter create(long expectedNumItems, long numBits) {
+ return create(Version.V2, expectedNumItems, numBits, BloomFilterImplV2.DEFAULT_SEED);
+ }
+
+ public static BloomFilter create(long expectedNumItems, long numBits, int seed) {
+ return create(Version.V2, expectedNumItems, numBits, seed);
+ }
+
+ public static BloomFilter create(
+ Version version,
+ long expectedNumItems,
+ long numBits,
+ int seed
+ ) {
if (expectedNumItems <= 0) {
throw new IllegalArgumentException("Expected insertions must be positive");
}
@@ -264,6 +307,11 @@ public static BloomFilter create(long expectedNumItems, long numBits) {
throw new IllegalArgumentException("Number of bits must be positive");
}
- return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits);
+ int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits);
+
+ return switch (version) {
+ case V1 -> new BloomFilterImpl(numHashFunctions, numBits);
+ case V2 -> new BloomFilterImplV2(numHashFunctions, numBits, seed);
+ };
}
}
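A minimal usage sketch of the API surface changed above, relying only on methods visible in this diff (the seeded create overload, writeTo, and the version-dispatching readFrom); the concrete item count, fpp, and seed are illustrative:

// Build a V2 filter with an explicit seed and round-trip it through readFrom.
long numItems = 1_000_000L;
BloomFilter original = BloomFilter.create(
    BloomFilter.Version.V2,
    numItems,
    BloomFilter.optimalNumOfBits(numItems, 0.03),
    42); // deterministic seed
original.putLong(12345L);

java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
original.writeTo(out); // big-endian header: version=2, numHashFunctions, seed, numWords

// readFrom peeks at the first 4 bytes and dispatches to the V1 or V2 parser
BloomFilter copy = BloomFilter.readFrom(out.toByteArray());
assert copy.mightContainLong(12345L);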
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java
new file mode 100644
index 0000000000000..b5b321cba0407
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.util.Objects;
+
+abstract class BloomFilterBase extends BloomFilter {
+
+ public static final int DEFAULT_SEED = 0;
+
+ protected int seed;
+ protected int numHashFunctions;
+ protected BitArray bits;
+
+ protected BloomFilterBase(int numHashFunctions, long numBits) {
+ this(numHashFunctions, numBits, DEFAULT_SEED);
+ }
+
+ protected BloomFilterBase(int numHashFunctions, long numBits, int seed) {
+ this(new BitArray(numBits), numHashFunctions, seed);
+ }
+
+ protected BloomFilterBase(BitArray bits, int numHashFunctions, int seed) {
+ this.bits = bits;
+ this.numHashFunctions = numHashFunctions;
+ this.seed = seed;
+ }
+
+ protected BloomFilterBase() {}
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) {
+ return true;
+ }
+
+ if (!(other instanceof BloomFilterBase that)) {
+ return false;
+ }
+
+ return
+ this.getClass() == that.getClass()
+ && this.numHashFunctions == that.numHashFunctions
+ && this.seed == that.seed
+ // TODO: this.bits can be null temporarily, during deserialization,
+ // should we worry about this?
+ && this.bits.equals(that.bits);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(numHashFunctions, seed, bits);
+ }
+
+ @Override
+ public double expectedFpp() {
+ return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions);
+ }
+
+ @Override
+ public long bitSize() {
+ return bits.bitSize();
+ }
+
+ @Override
+ public boolean put(Object item) {
+ if (item instanceof String str) {
+ return putString(str);
+ } else if (item instanceof byte[] bytes) {
+ return putBinary(bytes);
+ } else {
+ return putLong(Utils.integralToLong(item));
+ }
+ }
+
+ protected HiLoHash hashLongToIntPair(long item, int seed) {
+ // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
+ // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+ // Note that `CountMinSketch` uses a different strategy: it hashes the input long element with
+ // every i to produce n hash values.
+ // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
+ int h1 = Murmur3_x86_32.hashLong(item, seed);
+ int h2 = Murmur3_x86_32.hashLong(item, h1);
+ return new HiLoHash(h1, h2);
+ }
+
+ protected HiLoHash hashBytesToIntPair(byte[] item, int seed) {
+ int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, seed);
+ int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
+ return new HiLoHash(h1, h2);
+ }
+
+ protected abstract boolean scatterHashAndSetAllBits(HiLoHash inputHash);
+
+ protected abstract boolean scatterHashAndGetAllBits(HiLoHash inputHash);
+
+ @Override
+ public boolean putString(String item) {
+ return putBinary(Utils.getBytesFromUTF8String(item));
+ }
+
+ @Override
+ public boolean putBinary(byte[] item) {
+ HiLoHash hiLoHash = hashBytesToIntPair(item, seed);
+ return scatterHashAndSetAllBits(hiLoHash);
+ }
+
+ @Override
+ public boolean mightContainString(String item) {
+ return mightContainBinary(Utils.getBytesFromUTF8String(item));
+ }
+
+ @Override
+ public boolean mightContainBinary(byte[] item) {
+ HiLoHash hiLoHash = hashBytesToIntPair(item, seed);
+ return scatterHashAndGetAllBits(hiLoHash);
+ }
+
+ public boolean putLong(long item) {
+ HiLoHash hiLoHash = hashLongToIntPair(item, seed);
+ return scatterHashAndSetAllBits(hiLoHash);
+ }
+
+ @Override
+ public boolean mightContainLong(long item) {
+ HiLoHash hiLoHash = hashLongToIntPair(item, seed);
+ return scatterHashAndGetAllBits(hiLoHash);
+ }
+
+ @Override
+ public boolean mightContain(Object item) {
+ if (item instanceof String str) {
+ return mightContainString(str);
+ } else if (item instanceof byte[] bytes) {
+ return mightContainBinary(bytes);
+ } else {
+ return mightContainLong(Utils.integralToLong(item));
+ }
+ }
+
+ @Override
+ public boolean isCompatible(BloomFilter other) {
+ if (other == null) {
+ return false;
+ }
+
+ if (!(other instanceof BloomFilterBase that)) {
+ return false;
+ }
+
+ return
+ this.getClass() == that.getClass()
+ && this.bitSize() == that.bitSize()
+ && this.numHashFunctions == that.numHashFunctions
+ && this.seed == that.seed;
+ }
+
+ @Override
+ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException {
+ BloomFilterBase otherImplInstance = checkCompatibilityForMerge(other);
+
+ this.bits.putAll(otherImplInstance.bits);
+ return this;
+ }
+
+ @Override
+ public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException {
+ BloomFilterBase otherImplInstance = checkCompatibilityForMerge(other);
+
+ this.bits.and(otherImplInstance.bits);
+ return this;
+ }
+
+ @Override
+ public long cardinality() {
+ return this.bits.cardinality();
+ }
+
+ protected abstract BloomFilterBase checkCompatibilityForMerge(BloomFilter other)
+ throws IncompatibleMergeException;
+
+ public record HiLoHash(int hi, int lo) {}
+
+}
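The `h1 + i * h2` scheme described in `hashLongToIntPair` is the classic double-hashing (Kirsch-Mitzenmacher) construction. A standalone sketch of how one (h1, h2) pair expands into probe positions under the V1 rules; the helper name is hypothetical and only mirrors what the `scatterHash*` implementations in `BloomFilterImpl` do:

// Hypothetical helper mirroring the V1 scatter logic: derive numHashFunctions
// bit positions from the two 32-bit hashes of a single input element.
static long[] v1ProbePositions(int h1, int h2, int numHashFunctions, long bitSize) {
  long[] positions = new long[numHashFunctions];
  for (int i = 1; i <= numHashFunctions; i++) {
    int combinedHash = h1 + (i * h2); // combined in 32-bit arithmetic (the V1 limitation)
    if (combinedHash < 0) {
      combinedHash = ~combinedHash;   // flip the bits to get a non-negative value
    }
    positions[i - 1] = combinedHash % bitSize;
  }
  return positions;
}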
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 3bd04a531fe75..743fd9fb6738e 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -19,71 +19,17 @@
import java.io.*;
-class BloomFilterImpl extends BloomFilter implements Serializable {
-
- private int numHashFunctions;
-
- private BitArray bits;
+class BloomFilterImpl extends BloomFilterBase implements Serializable {
BloomFilterImpl(int numHashFunctions, long numBits) {
- this(new BitArray(numBits), numHashFunctions);
- }
-
- private BloomFilterImpl(BitArray bits, int numHashFunctions) {
- this.bits = bits;
- this.numHashFunctions = numHashFunctions;
+ super(numHashFunctions, numBits);
}
private BloomFilterImpl() {}
- @Override
- public boolean equals(Object other) {
- if (other == this) {
- return true;
- }
-
- if (!(other instanceof BloomFilterImpl that)) {
- return false;
- }
-
- return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
- }
-
- @Override
- public int hashCode() {
- return bits.hashCode() * 31 + numHashFunctions;
- }
-
- @Override
- public double expectedFpp() {
- return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions);
- }
-
- @Override
- public long bitSize() {
- return bits.bitSize();
- }
-
- @Override
- public boolean put(Object item) {
- if (item instanceof String str) {
- return putString(str);
- } else if (item instanceof byte[] bytes) {
- return putBinary(bytes);
- } else {
- return putLong(Utils.integralToLong(item));
- }
- }
-
- @Override
- public boolean putString(String item) {
- return putBinary(Utils.getBytesFromUTF8String(item));
- }
-
- @Override
- public boolean putBinary(byte[] item) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
+ protected boolean scatterHashAndSetAllBits(HiLoHash inputHash) {
+ int h1 = inputHash.hi();
+ int h2 = inputHash.lo();
long bitSize = bits.bitSize();
boolean bitsChanged = false;
@@ -98,15 +44,9 @@ public boolean putBinary(byte[] item) {
return bitsChanged;
}
- @Override
- public boolean mightContainString(String item) {
- return mightContainBinary(Utils.getBytesFromUTF8String(item));
- }
-
- @Override
- public boolean mightContainBinary(byte[] item) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
+ protected boolean scatterHashAndGetAllBits(HiLoHash inputHash) {
+ int h1 = inputHash.hi();
+ int h2 = inputHash.lo();
long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
@@ -122,94 +62,7 @@ public boolean mightContainBinary(byte[] item) {
return true;
}
- @Override
- public boolean putLong(long item) {
- // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
- // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
- // Note that `CountMinSketch` use a different strategy, it hash the input long element with
- // every i to produce n hash values.
- // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
- int h1 = Murmur3_x86_32.hashLong(item, 0);
- int h2 = Murmur3_x86_32.hashLong(item, h1);
-
- long bitSize = bits.bitSize();
- boolean bitsChanged = false;
- for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- bitsChanged |= bits.set(combinedHash % bitSize);
- }
- return bitsChanged;
- }
-
- @Override
- public boolean mightContainLong(long item) {
- int h1 = Murmur3_x86_32.hashLong(item, 0);
- int h2 = Murmur3_x86_32.hashLong(item, h1);
-
- long bitSize = bits.bitSize();
- for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- if (!bits.get(combinedHash % bitSize)) {
- return false;
- }
- }
- return true;
- }
-
- @Override
- public boolean mightContain(Object item) {
- if (item instanceof String str) {
- return mightContainString(str);
- } else if (item instanceof byte[] bytes) {
- return mightContainBinary(bytes);
- } else {
- return mightContainLong(Utils.integralToLong(item));
- }
- }
-
- @Override
- public boolean isCompatible(BloomFilter other) {
- if (other == null) {
- return false;
- }
-
- if (!(other instanceof BloomFilterImpl that)) {
- return false;
- }
-
- return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions;
- }
-
- @Override
- public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException {
- BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other);
-
- this.bits.putAll(otherImplInstance.bits);
- return this;
- }
-
- @Override
- public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException {
- BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other);
-
- this.bits.and(otherImplInstance.bits);
- return this;
- }
-
- @Override
- public long cardinality() {
- return this.bits.cardinality();
- }
-
- private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
+ protected BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
throws IncompatibleMergeException {
// Duplicates the logic of `isCompatible` here to provide better error message.
if (other == null) {
@@ -240,6 +93,7 @@ public void writeTo(OutputStream out) throws IOException {
dos.writeInt(Version.V1.getVersionNumber());
dos.writeInt(numHashFunctions);
+ // the V1 format has no seed field, so the seed is not serialized
bits.writeTo(dos);
}
@@ -252,6 +106,7 @@ private void readFrom0(InputStream in) throws IOException {
}
this.numHashFunctions = dis.readInt();
+ this.seed = DEFAULT_SEED;
this.bits = BitArray.readFrom(dis);
}
@@ -261,16 +116,18 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException {
return filter;
}
- public static BloomFilterImpl readFrom(byte[] bytes) throws IOException {
- try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) {
- return readFrom(bis);
- }
+ // no longer necessary, but can't remove without triggering MIMA violations
+ @Deprecated
+ public static BloomFilter readFrom(byte[] bytes) throws IOException {
+ return BloomFilter.readFrom(bytes);
}
+ @Serial
private void writeObject(ObjectOutputStream out) throws IOException {
writeTo(out);
}
+ @Serial
private void readObject(ObjectInputStream in) throws IOException {
readFrom0(in);
}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java
new file mode 100644
index 0000000000000..fa0a1df384865
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.*;
+
+class BloomFilterImplV2 extends BloomFilterBase implements Serializable {
+
+ BloomFilterImplV2(int numHashFunctions, long numBits, int seed) {
+ this(new BitArray(numBits), numHashFunctions, seed);
+ }
+
+ private BloomFilterImplV2(BitArray bits, int numHashFunctions, int seed) {
+ super(bits, numHashFunctions, seed);
+ }
+
+ private BloomFilterImplV2() {}
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) {
+ return true;
+ }
+
+ if (!(other instanceof BloomFilterImplV2 that)) {
+ return false;
+ }
+
+ return
+ this.numHashFunctions == that.numHashFunctions
+ && this.seed == that.seed
+ && this.bits.equals(that.bits);
+ }
+
+ protected boolean scatterHashAndSetAllBits(HiLoHash inputHash) {
+ int h1 = inputHash.hi();
+ int h2 = inputHash.lo();
+
+ long bitSize = bits.bitSize();
+ boolean bitsChanged = false;
+
+ // multiplying by Integer.MAX_VALUE spreads h1 into the upper four bytes of the 64-bit combinedHash
+ long combinedHash = (long) h1 * Integer.MAX_VALUE;
+ for (long i = 0; i < numHashFunctions; i++) {
+ combinedHash += h2;
+
+ // Flip all the bits if it's negative (guaranteed positive number)
+ long combinedIndex = combinedHash < 0 ? ~combinedHash : combinedHash;
+
+ bitsChanged |= bits.set(combinedIndex % bitSize);
+ }
+ return bitsChanged;
+ }
+
+ protected boolean scatterHashAndGetAllBits(HiLoHash inputHash) {
+ int h1 = inputHash.hi();
+ int h2 = inputHash.lo();
+
+ long bitSize = bits.bitSize();
+
+ // multiplying by Integer.MAX_VALUE spreads h1 into the upper four bytes of the 64-bit combinedHash
+ long combinedHash = (long) h1 * Integer.MAX_VALUE;
+ for (long i = 0; i < numHashFunctions; i++) {
+ combinedHash += h2;
+
+ // Flip all the bits if it's negative (guaranteed positive number)
+ long combinedIndex = combinedHash < 0 ? ~combinedHash : combinedHash;
+
+ if (!bits.get(combinedIndex % bitSize)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ protected BloomFilterImplV2 checkCompatibilityForMerge(BloomFilter other)
+ throws IncompatibleMergeException {
+ // Duplicates the logic of `isCompatible` here to provide better error message.
+ if (other == null) {
+ throw new IncompatibleMergeException("Cannot merge null bloom filter");
+ }
+
+ if (!(other instanceof BloomFilterImplV2 that)) {
+ throw new IncompatibleMergeException(
+ "Cannot merge bloom filter of class " + other.getClass().getName()
+ );
+ }
+
+ if (this.bitSize() != that.bitSize()) {
+ throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size");
+ }
+
+ if (this.seed != that.seed) {
+ throw new IncompatibleMergeException(
+ "Cannot merge bloom filters with different seeds"
+ );
+ }
+
+ if (this.numHashFunctions != that.numHashFunctions) {
+ throw new IncompatibleMergeException(
+ "Cannot merge bloom filters with different number of hash functions"
+ );
+ }
+ return that;
+ }
+
+ @Override
+ public void writeTo(OutputStream out) throws IOException {
+ DataOutputStream dos = new DataOutputStream(out);
+
+ dos.writeInt(Version.V2.getVersionNumber());
+ dos.writeInt(numHashFunctions);
+ dos.writeInt(seed);
+ bits.writeTo(dos);
+ }
+
+ private void readFrom0(InputStream in) throws IOException {
+ DataInputStream dis = new DataInputStream(in);
+
+ int version = dis.readInt();
+ if (version != Version.V2.getVersionNumber()) {
+ throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+ }
+
+ this.numHashFunctions = dis.readInt();
+ this.seed = dis.readInt();
+ this.bits = BitArray.readFrom(dis);
+ }
+
+ public static BloomFilterImplV2 readFrom(InputStream in) throws IOException {
+ BloomFilterImplV2 filter = new BloomFilterImplV2();
+ filter.readFrom0(in);
+ return filter;
+ }
+
+ @Serial
+ private void writeObject(ObjectOutputStream out) throws IOException {
+ writeTo(out);
+ }
+
+ @Serial
+ private void readObject(ObjectInputStream in) throws IOException {
+ readFrom0(in);
+ }
+}
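The practical difference between the V1 and V2 scatter schemes shows up with bit arrays larger than Integer.MAX_VALUE bits. A small illustration (the values are arbitrary) of why V1's 32-bit combined hash can never address the upper part of such an array, while V2's long accumulator can:

long bitSize = 5L * 1024 * 1024 * 1024; // a ~5 Gibit array, wider than Integer.MAX_VALUE

// V1: combinedHash is an int, so after the sign flip the largest reachable
// index is Integer.MAX_VALUE; bits above that are never set or probed.
int v1Combined = Integer.MAX_VALUE;
long v1MaxIndex = v1Combined % bitSize; // == 2147483647, always

// V2: the accumulator stays a long, so indexes can cover the whole array.
long v2Combined = (long) Integer.MAX_VALUE * 12345 + 67890;
long v2Index = (v2Combined < 0 ? ~v2Combined : v2Combined) % bitSize; // may exceed 2^31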
diff --git a/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java b/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java
new file mode 100644
index 0000000000000..ac574525d36b3
--- /dev/null
+++ b/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import org.junit.jupiter.api.*;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.io.PrintStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.LongAdder;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
+
+
+public class SparkBloomFilterSuite {
+
+ // the implemented fpp is only an approximation of the configured hard boundary,
+ // so the assertions allow a small error margin
+ final double FPP_EVEN_ODD_ERROR_FACTOR = 0.10;
+ final double FPP_RANDOM_ERROR_FACTOR = 0.10;
+
+ final long ONE_GB = 1024L * 1024L * 1024L;
+ final long REQUIRED_HEAP_UPPER_BOUND_IN_BYTES = 4 * ONE_GB;
+
+ private static Instant START;
+ private static boolean strict;
+ private static boolean verbose;
+
+ private Instant start;
+ private final Map<String, PrintStream> testOutMap = new ConcurrentHashMap<>();
+
+ @BeforeAll
+ public static void beforeAll() {
+ START = Instant.now();
+ String testClassName = SparkBloomFilterSuite.class.getName();
+ strict = Boolean.parseBoolean(System.getProperty(testClassName + ".strict", "true"));
+ verbose = Boolean.parseBoolean(System.getProperty(testClassName + ".verbose", "false"));
+ }
+
+ @AfterAll
+ public static void afterAll() {
+ Duration duration = Duration.between(START, Instant.now());
+ if (verbose) {
+ System.err.println(duration + " TOTAL");
+ }
+ }
+
+ @BeforeEach
+ public void beforeEach(
+ TestInfo testInfo
+ ) throws Exception {
+ start = Instant.now();
+
+ String testName = testInfo.getDisplayName();
+
+ String testClassName = SparkBloomFilterSuite.class.getName();
+ String logDir = System.getProperty(testClassName + ".logDir", "./target/tmp");
+ Path logDirPath = Path.of(logDir);
+ Files.createDirectories(logDirPath);
+ Path testLogPath = Path.of(logDir, testName + ".log");
+ Files.deleteIfExists(testLogPath);
+
+ PrintStream testOut = new PrintStream(Files.newOutputStream(testLogPath));
+ testOutMap.put(testName, testOut);
+
+ testOut.println("testName: " + testName);
+ }
+
+ @AfterEach
+ public void afterEach(TestInfo testInfo) {
+ Duration duration = Duration.between(start, Instant.now());
+
+ String testName = testInfo.getDisplayName();
+ PrintStream testOut = testOutMap.get(testName);
+
+ testOut.println("duration: " + duration );
+ testOut.close();
+ }
+
+ private static Stream<Arguments> dataPointProvider() {
+ // temporary workaround:
+ // to reduce running time to acceptable levels, we test only one case,
+ // with the default FPP and the default seed only.
+ return Stream.of(
+ Arguments.of(1_000_000_000L, 0.03, BloomFilterImplV2.DEFAULT_SEED)
+ );
+ // preferable minimum parameter space for tests:
+ // {1_000_000L, 1_000_000_000L} for: long numItems
+ // {0.05, 0.03, 0.01, 0.001} for: double expectedFpp
+ // {BloomFilterBase.DEFAULT_SEED, 1, 127} for: int deterministicSeed
+ }
+
+ /**
+ * This test, over N iterations, inserts the N even numbers (2*i) into the tested
+ * BloomFilter instance and leaves out the N odd numbers (2*i+1).
+ *
+ * It checks that mightContain=true holds for all of the inserted even items,
+ * and measures the mightContain=true (false positive) rate on the not-inserted odd numbers.
+ *
+ * @param numItems the number of items to be inserted
+ * @param expectedFpp the expected fpp rate of the tested BloomFilter instance
+ * @param deterministicSeed the deterministic seed to use to initialize
+ * the primary BloomFilter instance.
+ */
+ @ParameterizedTest(name = "testAccuracyEvenOdd.n{0}_fpp{1}_seed{2}")
+ @MethodSource("dataPointProvider")
+ public void testAccuracyEvenOdd(
+ long numItems,
+ double expectedFpp,
+ int deterministicSeed,
+ TestInfo testInfo
+ ) {
+ String testName = testInfo.getDisplayName();
+ PrintStream testOut = testOutMap.get(testName);
+
+ long optimalNumOfBits = BloomFilter.optimalNumOfBits(numItems, expectedFpp);
+ testOut.printf(
+ "optimal bitArray: %d (%d MB)\n",
+ optimalNumOfBits,
+ optimalNumOfBits / Byte.SIZE / 1024 / 1024
+ );
+ Assumptions.assumeTrue(
+ optimalNumOfBits / Byte.SIZE < REQUIRED_HEAP_UPPER_BOUND_IN_BYTES,
+ "this testcase would require allocating more than 4GB of heap mem ("
+ + optimalNumOfBits
+ + " bits)"
+ );
+
+ BloomFilter bloomFilter =
+ BloomFilter.create(
+ BloomFilter.Version.V2,
+ numItems,
+ optimalNumOfBits,
+ deterministicSeed
+ );
+
+ testOut.printf(
+ "allocated bitArray: %d (%d MB)\n",
+ bloomFilter.bitSize(),
+ bloomFilter.bitSize() / Byte.SIZE / 1024 / 1024
+ );
+
+ for (long i = 0; i < numItems; i++) {
+ if (verbose && i % 10_000_000 == 0) {
+ System.err.printf("i: %d\n", i);
+ }
+
+ bloomFilter.putLong(2 * i);
+ }
+
+ testOut.printf("bitCount: %d\nsaturation: %f\n",
+ bloomFilter.cardinality(),
+ (double) bloomFilter.cardinality() / bloomFilter.bitSize()
+ );
+
+ LongAdder mightContainEven = new LongAdder();
+ LongAdder mightContainOdd = new LongAdder();
+
+ LongStream inputStream = LongStream.range(0, numItems).parallel();
+ inputStream.forEach(
+ i -> {
+ long even = 2 * i;
+ if (bloomFilter.mightContainLong(even)) {
+ mightContainEven.increment();
+ }
+
+ long odd = 2 * i + 1;
+ if (bloomFilter.mightContainLong(odd)) {
+ mightContainOdd.increment();
+ }
+ }
+ );
+
+ Assertions.assertEquals(
+ numItems, mightContainEven.longValue(),
+ "mightContainLong must return true for all inserted numbers"
+ );
+
+ double actualFpp = mightContainOdd.doubleValue() / numItems;
+ double acceptableFpp = expectedFpp * (1 + FPP_EVEN_ODD_ERROR_FACTOR);
+
+ testOut.printf("expectedFpp: %f %%\n", 100 * expectedFpp);
+ testOut.printf("acceptableFpp: %f %%\n", 100 * acceptableFpp);
+ testOut.printf("actualFpp: %f %%\n", 100 * actualFpp);
+
+ if (!strict) {
+ Assumptions.assumeTrue(
+ actualFpp <= acceptableFpp,
+ String.format(
+ "acceptableFpp(%f %%) < actualFpp (%f %%)",
+ 100 * acceptableFpp,
+ 100 * actualFpp
+ )
+ );
+ } else {
+ Assertions.assertTrue(
+ actualFpp <= acceptableFpp,
+ String.format(
+ "acceptableFpp(%f %%) < actualFpp (%f %%)",
+ 100 * acceptableFpp,
+ 100 * actualFpp
+ )
+ );
+ }
+ }
+
+ /**
+ * This test inserts N pseudorandomly generated numbers, over 2N iterations, into two
+ * differently seeded (theoretically independent) BloomFilter instances. Every number
+ * generated in an even iteration is inserted into both filters; every number generated
+ * in an odd iteration is left out of both.
+ *
+ * The test checks that 'mightContain=true' holds for all items inserted in the even
+ * iterations. It counts as false positives the odd-iteration items for which the
+ * primary filter reports 'mightContain=true' but the secondary reports
+ * 'mightContain=false'. Since both instances received the same insertions, and the
+ * secondary confirms non-insertion, a 'mightContain=true' from the primary can only
+ * be a false positive.
+ *
+ * @param numItems the number of items to be inserted
+ * @param expectedFpp the expected fpp rate of the tested BloomFilter instance
+ * @param deterministicSeed the deterministic seed to use to initialize
+ * the primary BloomFilter instance. (The secondary will be
+ * initialized with the constant seed of 0xCAFEBABE)
+ */
+ @ParameterizedTest(name = "testAccuracyRandom.n{0}_fpp{1}_seed{2}")
+ @MethodSource("dataPointProvider")
+ public void testAccuracyRandomDistribution(
+ long numItems,
+ double expectedFpp,
+ int deterministicSeed,
+ TestInfo testInfo
+ ) {
+ String testName = testInfo.getDisplayName();
+ PrintStream testOut = testOutMap.get(testName);
+
+ long optimalNumOfBits = BloomFilter.optimalNumOfBits(numItems, expectedFpp);
+ testOut.printf(
+ "optimal bitArray: %d (%d MB)\n",
+ optimalNumOfBits,
+ optimalNumOfBits / Byte.SIZE / 1024 / 1024
+ );
+ Assumptions.assumeTrue(
+ 2 * optimalNumOfBits / Byte.SIZE < REQUIRED_HEAP_UPPER_BOUND_IN_BYTES,
+ "this testcase would require allocating more than 4GB of heap mem (2x "
+ + optimalNumOfBits
+ + " bits)"
+ );
+
+ BloomFilter bloomFilterPrimary =
+ BloomFilter.create(
+ BloomFilter.Version.V2,
+ numItems,
+ optimalNumOfBits,
+ deterministicSeed
+ );
+
+ BloomFilter bloomFilterSecondary =
+ BloomFilter.create(
+ BloomFilter.Version.V2,
+ numItems,
+ optimalNumOfBits,
+ 0xCAFEBABE
+ );
+
+ testOut.printf(
+ "allocated bitArray: %d (%d MB)\n",
+ bloomFilterPrimary.bitSize(),
+ bloomFilterPrimary.bitSize() / Byte.SIZE / 1024 / 1024
+ );
+
+ long iterationCount = 2 * numItems;
+
+ for (long i = 0; i < iterationCount; i++) {
+ if (verbose && i % 10_000_000 == 0) {
+ System.err.printf("i: %d\n", i);
+ }
+
+ long candidate = scramble(i);
+ if (i % 2 == 0) {
+ bloomFilterPrimary.putLong(candidate);
+ bloomFilterSecondary.putLong(candidate);
+ }
+ }
+ testOut.printf("bitCount: %d\nsaturation: %f\n",
+ bloomFilterPrimary.cardinality(),
+ (double) bloomFilterPrimary.cardinality() / bloomFilterPrimary.bitSize()
+ );
+
+ LongAdder mightContainEvenIndexed = new LongAdder();
+ LongAdder mightContainOddIndexed = new LongAdder();
+ LongAdder confirmedAsNotInserted = new LongAdder();
+ LongStream inputStream = LongStream.range(0, iterationCount).parallel();
+ inputStream.forEach(
+ i -> {
+ if (verbose && i % (iterationCount / 100) == 0) {
+ System.err.printf("%s: %2d %%\n", testName, 100 * i / iterationCount);
+ }
+
+ long candidate = scramble(i);
+
+ if (i % 2 == 0) { // EVEN
+ // every inserted item must be reported as contained
+ if (bloomFilterPrimary.mightContainLong(candidate)) {
+ mightContainEvenIndexed.increment();
+ }
+ } else { // ODD
+ // for fpp estimation, only consider the odd indexes
+ // (to avoid querying the secondary with elements known to be inserted)
+
+ // since here we avoided all the even indexes,
+ // most of these secondary queries will return false
+ if (!bloomFilterSecondary.mightContainLong(candidate)) {
+ // from the odd indexes, we consider only those items
+ // where the secondary confirms the non-insertion
+
+ // anything on which the primary and the secondary
+ // disagrees here is a false positive
+ if (bloomFilterPrimary.mightContainLong(candidate)) {
+ mightContainOddIndexed.increment();
+ }
+ // count the total number of considered items for a baseline
+ confirmedAsNotInserted.increment();
+ }
+ }
+ }
+ );
+
+ Assertions.assertEquals(
+ numItems, mightContainEvenIndexed.longValue(),
+ "mightContainLong must return true for all inserted numbers"
+ );
+
+ double actualFpp =
+ mightContainOddIndexed.doubleValue() / confirmedAsNotInserted.doubleValue();
+ double acceptableFpp = expectedFpp * (1 + FPP_RANDOM_ERROR_FACTOR);
+
+ testOut.printf("mightContainOddIndexed: %10d\n", mightContainOddIndexed.longValue());
+ testOut.printf("confirmedAsNotInserted: %10d\n", confirmedAsNotInserted.longValue());
+ testOut.printf("numItems: %10d\n", numItems);
+ testOut.printf("expectedFpp: %f %%\n", 100 * expectedFpp);
+ testOut.printf("acceptableFpp: %f %%\n", 100 * acceptableFpp);
+ testOut.printf("actualFpp: %f %%\n", 100 * actualFpp);
+
+ if (!strict) {
+ Assumptions.assumeTrue(
+ actualFpp <= acceptableFpp,
+ String.format(
+ "acceptableFpp(%f %%) < actualFpp (%f %%)",
+ 100 * acceptableFpp,
+ 100 * actualFpp
+ )
+ );
+ } else {
+ Assertions.assertTrue(
+ actualFpp <= acceptableFpp,
+ String.format(
+ "acceptableFpp(%f %%) < actualFpp (%f %%)",
+ 100 * acceptableFpp,
+ 100 * actualFpp
+ )
+ );
+ }
+ }
+
+ // quick scrambling logic lifted from java.util.Random;
+ // its range is only 48 bits (out of the 64 bits of a long value),
+ // but that is enough for the purposes of this test.
+ private static final long multiplier = 0x5DEECE66DL;
+ private static final long addend = 0xBL;
+ private static final long mask = (1L << 48) - 1;
+ private static long scramble(long value) {
+ return (value * multiplier + addend) & mask;
+ }
+}
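A worked instance of the acceptance bound used by both tests, with the single data point currently exercised by dataPointProvider and the error factors defined at the top of the suite:

double expectedFpp = 0.03;  // from dataPointProvider
double errorFactor = 0.10;  // FPP_EVEN_ODD_ERROR_FACTOR / FPP_RANDOM_ERROR_FACTOR
double acceptableFpp = expectedFpp * (1 + errorFactor); // = 0.033
// in strict mode the assertion fails only when the measured actualFpp exceeds 3.3 %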
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index 4d0ba66637b46..ba8f97a51aecf 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -46,7 +46,9 @@ class BloomFilterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
val fpp = 0.05
val numInsertion = numItems / 10
- val allItems = Array.fill(numItems)(itemGen(r))
+ // using a Set to avoid duplicates; generating twice as many random values as needed
+ // to compensate for the ones lost to deduplication
+ val allItems = Set.fill(2 * numItems)(itemGen(r)).take(numItems)
val filter = BloomFilter.create(numInsertion, fpp)
@@ -158,5 +160,11 @@ class BloomFilterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
val filter2 = BloomFilter.create(2000, 6400)
filter1.mergeInPlace(filter2)
}
+
+ intercept[IncompatibleMergeException] {
+ val filter1 = BloomFilter.create(BloomFilter.Version.V1, 1000L, 6400L, 0)
+ val filter2 = BloomFilter.create(BloomFilter.Version.V2, 1000L, 6400L, 0)
+ filter1.mergeInPlace(filter2)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
index af97856fd222e..fb279b1db6fc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -41,16 +41,19 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
override def beforeAll(): Unit = {
super.beforeAll()
// Register 'bloom_filter_agg' to builtin.
- spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg,
+ spark.sessionState.functionRegistry.registerFunction(
+ funcId_bloom_filter_agg,
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
- (children: Seq[Expression]) => children.size match {
- case 1 => new BloomFilterAggregate(children.head)
- case 2 => new BloomFilterAggregate(children.head, children(1))
- case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
- })
+ (children: Seq[Expression]) =>
+ children.size match {
+ case 1 => new BloomFilterAggregate(children.head)
+ case 2 => new BloomFilterAggregate(children.head, children(1))
+ case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+ })
// Register 'might_contain' to builtin.
- spark.sessionState.functionRegistry.registerFunction(funcId_might_contain,
+ spark.sessionState.functionRegistry.registerFunction(
+ funcId_might_contain,
new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
(children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
}
@@ -64,10 +67,22 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
test("Test bloom_filter_agg and might_contain") {
val conf = SQLConf.get
val table = "bloom_filter_test"
- for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
- conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
- for ((numBits, index) <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
- conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) {
+ for (numEstimatedItems <- Seq(
+ Long.MinValue,
+ -10L,
+ 0L,
+ 4096L,
+ 4194304L,
+ Long.MaxValue,
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
+ for ((numBits, index) <- Seq(
+ Long.MinValue,
+ -10L,
+ 0L,
+ 4096L,
+ 4194304L,
+ Long.MaxValue,
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) {
val sqlString = s"""
|SELECT every(might_contain(
| (SELECT bloom_filter_agg(col,
@@ -85,7 +100,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
""".stripMargin
withTempView(table) {
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L))
- .toDF("col").createOrReplaceTempView(table)
+ .toDF("col")
+ .createOrReplaceTempView(table)
// Validate error messages as well as answers when there's no error.
if (numEstimatedItems <= 0) {
val exception = intercept[AnalysisException] {
@@ -104,16 +120,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"valueRange" -> "[0, positive]",
"currentValue" -> toSQLValue(numEstimatedItems, LongType),
"sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ +
- s"""CAST($numBits AS BIGINT))"""")
- ),
+ s"""CAST($numBits AS BIGINT))"""")),
context = ExpectedContext(
fragment = "bloom_filter_agg(col,\n" +
s" cast($numEstimatedItems as long),\n" +
s" cast($numBits as long))",
start = 49,
- stop = stop(index)
- )
- )
+ stop = stop(index)))
} else if (numBits <= 0) {
val exception = intercept[AnalysisException] {
spark.sql(sqlString)
@@ -132,16 +145,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"valueRange" -> "[0, positive]",
"currentValue" -> toSQLValue(numBits, LongType),
"sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ +
- s"""CAST($numBits AS BIGINT))"""")
- ),
+ s"""CAST($numBits AS BIGINT))"""")),
context = ExpectedContext(
fragment = "bloom_filter_agg(col,\n" +
s" cast($numEstimatedItems as long),\n" +
s" cast($numBits as long))",
start = 49,
- stop = stop(index)
- )
- )
+ stop = stop(index)))
} else {
checkAnswer(spark.sql(sqlString), Row(true, false))
}
@@ -154,8 +164,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
val exception1 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a)
- |FROM values (1.2), (2.5) as t(a)"""
- .stripMargin)
+ |FROM values (1.2), (2.5) as t(a)""".stripMargin)
}
checkError(
exception = exception1,
@@ -165,20 +174,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"sqlExpr" -> "\"bloom_filter_agg(a, 1000000, 8388608)\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
- "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\""
- ),
- context = ExpectedContext(
- fragment = "bloom_filter_agg(a)",
- start = 8,
- stop = 26
- )
- )
+ "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\""),
+ context = ExpectedContext(fragment = "bloom_filter_agg(a)", start = 8, stop = 26))
val exception2 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, 2)
- |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin)
}
checkError(
exception = exception2,
@@ -188,20 +190,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"sqlExpr" -> "\"bloom_filter_agg(a, 2, (2 * 8))\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
- "actual" -> "\"BIGINT\", \"INT\", \"BIGINT\""
- ),
- context = ExpectedContext(
- fragment = "bloom_filter_agg(a, 2)",
- start = 8,
- stop = 29
- )
- )
+ "actual" -> "\"BIGINT\", \"INT\", \"BIGINT\""),
+ context = ExpectedContext(fragment = "bloom_filter_agg(a, 2)", start = 8, stop = 29))
val exception3 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, cast(2 as long), 5)
- |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin)
}
checkError(
exception = exception3,
@@ -211,60 +206,42 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"sqlExpr" -> "\"bloom_filter_agg(a, CAST(2 AS BIGINT), 5)\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
- "actual" -> "\"BIGINT\", \"BIGINT\", \"INT\""
- ),
+ "actual" -> "\"BIGINT\", \"BIGINT\", \"INT\""),
context = ExpectedContext(
fragment = "bloom_filter_agg(a, cast(2 as long), 5)",
start = 8,
- stop = 46
- )
- )
+ stop = 46))
val exception4 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, null, 5)
- |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin)
}
checkError(
exception = exception4,
condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
parameters = Map(
"exprName" -> "estimatedNumItems or numBits",
- "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\""
- ),
- context = ExpectedContext(
- fragment = "bloom_filter_agg(a, null, 5)",
- start = 8,
- stop = 35
- )
- )
+ "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\""),
+ context = ExpectedContext(fragment = "bloom_filter_agg(a, null, 5)", start = 8, stop = 35))
val exception5 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, 5, null)
- |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin)
}
checkError(
exception = exception5,
condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
parameters = Map(
"exprName" -> "estimatedNumItems or numBits",
- "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\""
- ),
- context = ExpectedContext(
- fragment = "bloom_filter_agg(a, 5, null)",
- start = 8,
- stop = 35
- )
- )
+ "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\""),
+ context = ExpectedContext(fragment = "bloom_filter_agg(a, 5, null)", start = 8, stop = 35))
}
test("Test that might_contain errors out disallowed input value types") {
val exception1 = intercept[AnalysisException] {
- spark.sql("""|SELECT might_contain(1.0, 1L)"""
- .stripMargin)
+ spark.sql("""|SELECT might_contain(1.0, 1L)""".stripMargin)
}
checkError(
exception = exception1,
@@ -274,18 +251,11 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"functionName" -> "`might_contain`",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
- "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\""
- ),
- context = ExpectedContext(
- fragment = "might_contain(1.0, 1L)",
- start = 7,
- stop = 28
- )
- )
+ "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\""),
+ context = ExpectedContext(fragment = "might_contain(1.0, 1L)", start = 7, stop = 28))
val exception2 = intercept[AnalysisException] {
- spark.sql("""|SELECT might_contain(NULL, 0.1)"""
- .stripMargin)
+ spark.sql("""|SELECT might_contain(NULL, 0.1)""".stripMargin)
}
checkError(
exception = exception2,
@@ -295,22 +265,15 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
"functionName" -> "`might_contain`",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
- "actual" -> "\"VOID\", \"DECIMAL(1,1)\""
- ),
- context = ExpectedContext(
- fragment = "might_contain(NULL, 0.1)",
- start = 7,
- stop = 30
- )
- )
+ "actual" -> "\"VOID\", \"DECIMAL(1,1)\""),
+ context = ExpectedContext(fragment = "might_contain(NULL, 0.1)", start = 7, stop = 30))
}
test("Test that might_contain errors out non-constant Bloom filter") {
val exception1 = intercept[AnalysisException] {
spark.sql("""
|SELECT might_contain(cast(a as binary), cast(5 as long))
- |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""".stripMargin)
}
checkError(
exception = exception1,
@@ -318,20 +281,16 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
parameters = Map(
"sqlExpr" -> "\"might_contain(CAST(a AS BINARY), CAST(5 AS BIGINT))\"",
"functionName" -> "`might_contain`",
- "actual" -> "\"CAST(a AS BINARY)\""
- ),
+ "actual" -> "\"CAST(a AS BINARY)\""),
context = ExpectedContext(
fragment = "might_contain(cast(a as binary), cast(5 as long))",
start = 8,
- stop = 56
- )
- )
+ stop = 56))
val exception2 = intercept[AnalysisException] {
spark.sql("""
|SELECT might_contain((select cast(a as binary)), cast(5 as long))
- |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)"""
- .stripMargin)
+ |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""".stripMargin)
}
checkError(
exception = exception2,
@@ -339,32 +298,38 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
parameters = Map(
"sqlExpr" -> "\"might_contain(scalarsubquery(a), CAST(5 AS BIGINT))\"",
"functionName" -> "`might_contain`",
- "actual" -> "\"scalarsubquery(a)\""
- ),
+ "actual" -> "\"scalarsubquery(a)\""),
context = ExpectedContext(
fragment = "might_contain((select cast(a as binary)), cast(5 as long))",
start = 8,
- stop = 65
- )
- )
+ stop = 65))
}
- test("Test that might_contain can take a constant value input") {
- checkAnswer(spark.sql(
- """SELECT might_contain(
+ test("Test that might_contain can take a constant value input (seedless version)") {
+ checkAnswer(
+ spark.sql("""SELECT might_contain(
|X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
|cast(201 as long))""".stripMargin),
Row(false))
}
+ test("Test that might_contain can take a constant value input (seeded version)") {
+ checkAnswer(
+ spark.sql("""SELECT might_contain(
+ |X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
+ |cast(201 as long))""".stripMargin),
+ Row(false))
+ }
+
test("Test that bloom_filter_agg produces a NULL with empty input") {
- checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""),
+ checkAnswer(
+ spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""),
Row(null))
}
test("Test NULL inputs for might_contain") {
- checkAnswer(spark.sql(
- s"""
+ checkAnswer(
+ spark.sql(s"""
|SELECT might_contain(null, null) both_null,
| might_contain(null, 1L) null_bf,
| might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)),
@@ -374,9 +339,15 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
}
test("Test that a query with bloom_filter_agg has partial aggregates") {
- assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""")
- .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
- .collect({case agg: BaseAggregateExec => agg}).size == 2)
+ assert(
+ spark
+ .sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""")
+ .queryExecution
+ .executedPlan
+ .asInstanceOf[AdaptiveSparkPlanExec]
+ .inputPlan
+ .collect({ case agg: BaseAggregateExec => agg })
+ .size == 2)
}
test("Test numBitsExpression") {
@@ -385,7 +356,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
assert(agg.numBitsExpression === Literal(numBits))
}
- checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) * 100,
+ checkNumBits(
+ conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) * 100,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) + 10, 29193836)
checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS), 29193763)