diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 172b394689ca9..a839de0a91448 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -17,9 +17,12 @@ package org.apache.spark.util.sketch; +import java.io.BufferedInputStream; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; /** * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate @@ -51,7 +54,22 @@ public enum Version { *
   *   <li>The words/longs (numWords * 64 bit)</li>
  • * */ - V1(1); + V1(1), + + /** + * {@code BloomFilter} binary format version 2. + * Fixes the int32 truncation issue with V1 indexes, but by changing the bit pattern, + * it will become incompatible with V1 serializations. + * All values written in big-endian order: + * + */ + V2(2); private final int versionNumber; @@ -175,14 +193,26 @@ public long cardinality() { * the stream. */ public static BloomFilter readFrom(InputStream in) throws IOException { - return BloomFilterImpl.readFrom(in); + // peek into the InputStream so we can determine the version + BufferedInputStream bin = new BufferedInputStream(in); + bin.mark(4); + int version = ByteBuffer.wrap(bin.readNBytes(4)).getInt(); + bin.reset(); + + return switch (version) { + case 1 -> BloomFilterImpl.readFrom(bin); + case 2 -> BloomFilterImplV2.readFrom(bin); + default -> throw new IllegalArgumentException("Unknown BloomFilter version: " + version); + }; } /** * Reads in a {@link BloomFilter} from a byte array. */ public static BloomFilter readFrom(byte[] bytes) throws IOException { - return BloomFilterImpl.readFrom(bytes); + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) { + return readFrom(bis); + } } /** @@ -256,6 +286,19 @@ public static BloomFilter create(long expectedNumItems, double fpp) { * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. */ public static BloomFilter create(long expectedNumItems, long numBits) { + return create(Version.V2, expectedNumItems, numBits, BloomFilterImplV2.DEFAULT_SEED); + } + + public static BloomFilter create(long expectedNumItems, long numBits, int seed) { + return create(Version.V2, expectedNumItems, numBits, seed); + } + + public static BloomFilter create( + Version version, + long expectedNumItems, + long numBits, + int seed + ) { if (expectedNumItems <= 0) { throw new IllegalArgumentException("Expected insertions must be positive"); } @@ -264,6 +307,11 @@ public static BloomFilter create(long expectedNumItems, long numBits) { throw new IllegalArgumentException("Number of bits must be positive"); } - return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); + int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); + + return switch (version) { + case V1 -> new BloomFilterImpl(numHashFunctions, numBits); + case V2 -> new BloomFilterImplV2(numHashFunctions, numBits, seed); + }; } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java new file mode 100644 index 0000000000000..b5b321cba0407 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterBase.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.util.Objects; + +abstract class BloomFilterBase extends BloomFilter { + + public static final int DEFAULT_SEED = 0; + + protected int seed; + protected int numHashFunctions; + protected BitArray bits; + + protected BloomFilterBase(int numHashFunctions, long numBits) { + this(numHashFunctions, numBits, DEFAULT_SEED); + } + + protected BloomFilterBase(int numHashFunctions, long numBits, int seed) { + this(new BitArray(numBits), numHashFunctions, seed); + } + + protected BloomFilterBase(BitArray bits, int numHashFunctions, int seed) { + this.bits = bits; + this.numHashFunctions = numHashFunctions; + this.seed = seed; + } + + protected BloomFilterBase() {} + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (!(other instanceof BloomFilterBase that)) { + return false; + } + + return + this.getClass() == that.getClass() + && this.numHashFunctions == that.numHashFunctions + && this.seed == that.seed + // TODO: this.bits can be null temporarily, during deserialization, + // should we worry about this? + && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return Objects.hash(numHashFunctions, seed, bits); + } + + @Override + public double expectedFpp() { + return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + @Override + public boolean put(Object item) { + if (item instanceof String str) { + return putString(str); + } else if (item instanceof byte[] bytes) { + return putBinary(bytes); + } else { + return putLong(Utils.integralToLong(item)); + } + } + + protected HiLoHash hashLongToIntPair(long item, int seed) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
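+ // Note: the `h1 + i * h2` construction is the standard double-hashing trick (also used by
+ // Guava's Bloom filter); two base hashes are enough to derive all numHashFunctions probe
+ // positions. The subclasses only differ in how they fold the pair into bit indices: the V1
+ // implementation folds into an `int`, capping indices at 31 bits, while the V2
+ // implementation accumulates in a `long`, avoiding the int32 truncation that V2 fixes.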
+ int h1 = Murmur3_x86_32.hashLong(item, seed); + int h2 = Murmur3_x86_32.hashLong(item, h1); + return new HiLoHash(h1, h2); + } + + protected HiLoHash hashBytesToIntPair(byte[] item, int seed) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, seed); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + return new HiLoHash(h1, h2); + } + + protected abstract boolean scatterHashAndSetAllBits(HiLoHash inputHash); + + protected abstract boolean scatterHashAndGetAllBits(HiLoHash inputHash); + + @Override + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean putBinary(byte[] item) { + HiLoHash hiLoHash = hashBytesToIntPair(item, seed); + return scatterHashAndSetAllBits(hiLoHash); + } + + @Override + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean mightContainBinary(byte[] item) { + HiLoHash hiLoHash = hashBytesToIntPair(item, seed); + return scatterHashAndGetAllBits(hiLoHash); + } + + public boolean putLong(long item) { + HiLoHash hiLoHash = hashLongToIntPair(item, seed); + return scatterHashAndSetAllBits(hiLoHash); + } + + @Override + public boolean mightContainLong(long item) { + HiLoHash hiLoHash = hashLongToIntPair(item, seed); + return scatterHashAndGetAllBits(hiLoHash); + } + + @Override + public boolean mightContain(Object item) { + if (item instanceof String str) { + return mightContainString(str); + } else if (item instanceof byte[] bytes) { + return mightContainBinary(bytes); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean isCompatible(BloomFilter other) { + if (other == null) { + return false; + } + + if (!(other instanceof BloomFilterBase that)) { + return false; + } + + return + this.getClass() == that.getClass() + && this.bitSize() == that.bitSize() + && this.numHashFunctions == that.numHashFunctions + && this.seed == that.seed; + } + + @Override + public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { + BloomFilterBase otherImplInstance = checkCompatibilityForMerge(other); + + this.bits.putAll(otherImplInstance.bits); + return this; + } + + @Override + public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException { + BloomFilterBase otherImplInstance = checkCompatibilityForMerge(other); + + this.bits.and(otherImplInstance.bits); + return this; + } + + @Override + public long cardinality() { + return this.bits.cardinality(); + } + + protected abstract BloomFilterBase checkCompatibilityForMerge(BloomFilter other) + throws IncompatibleMergeException; + + public record HiLoHash(int hi, int lo) {} + +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 3bd04a531fe75..743fd9fb6738e 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,71 +19,17 @@ import java.io.*; -class BloomFilterImpl extends BloomFilter implements Serializable { - - private int numHashFunctions; - - private BitArray bits; +class BloomFilterImpl extends BloomFilterBase implements Serializable { BloomFilterImpl(int numHashFunctions, long numBits) { - this(new BitArray(numBits), numHashFunctions); 
- } - - private BloomFilterImpl(BitArray bits, int numHashFunctions) { - this.bits = bits; - this.numHashFunctions = numHashFunctions; + super(numHashFunctions, numBits); } private BloomFilterImpl() {} - @Override - public boolean equals(Object other) { - if (other == this) { - return true; - } - - if (!(other instanceof BloomFilterImpl that)) { - return false; - } - - return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); - } - - @Override - public int hashCode() { - return bits.hashCode() * 31 + numHashFunctions; - } - - @Override - public double expectedFpp() { - return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); - } - - @Override - public long bitSize() { - return bits.bitSize(); - } - - @Override - public boolean put(Object item) { - if (item instanceof String str) { - return putString(str); - } else if (item instanceof byte[] bytes) { - return putBinary(bytes); - } else { - return putLong(Utils.integralToLong(item)); - } - } - - @Override - public boolean putString(String item) { - return putBinary(Utils.getBytesFromUTF8String(item)); - } - - @Override - public boolean putBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + protected boolean scatterHashAndSetAllBits(HiLoHash inputHash) { + int h1 = inputHash.hi(); + int h2 = inputHash.lo(); long bitSize = bits.bitSize(); boolean bitsChanged = false; @@ -98,15 +44,9 @@ public boolean putBinary(byte[] item) { return bitsChanged; } - @Override - public boolean mightContainString(String item) { - return mightContainBinary(Utils.getBytesFromUTF8String(item)); - } - - @Override - public boolean mightContainBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + protected boolean scatterHashAndGetAllBits(HiLoHash inputHash) { + int h1 = inputHash.hi(); + int h2 = inputHash.lo(); long bitSize = bits.bitSize(); for (int i = 1; i <= numHashFunctions; i++) { @@ -122,94 +62,7 @@ public boolean mightContainBinary(byte[] item) { return true; } - @Override - public boolean putLong(long item) { - // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n - // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. - // Note that `CountMinSketch` use a different strategy, it hash the input long element with - // every i to produce n hash values. - // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
- int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); - - long bitSize = bits.bitSize(); - boolean bitsChanged = false; - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - bitsChanged |= bits.set(combinedHash % bitSize); - } - return bitsChanged; - } - - @Override - public boolean mightContainLong(long item) { - int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); - - long bitSize = bits.bitSize(); - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - if (!bits.get(combinedHash % bitSize)) { - return false; - } - } - return true; - } - - @Override - public boolean mightContain(Object item) { - if (item instanceof String str) { - return mightContainString(str); - } else if (item instanceof byte[] bytes) { - return mightContainBinary(bytes); - } else { - return mightContainLong(Utils.integralToLong(item)); - } - } - - @Override - public boolean isCompatible(BloomFilter other) { - if (other == null) { - return false; - } - - if (!(other instanceof BloomFilterImpl that)) { - return false; - } - - return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions; - } - - @Override - public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { - BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other); - - this.bits.putAll(otherImplInstance.bits); - return this; - } - - @Override - public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException { - BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other); - - this.bits.and(otherImplInstance.bits); - return this; - } - - @Override - public long cardinality() { - return this.bits.cardinality(); - } - - private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) + protected BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) throws IncompatibleMergeException { // Duplicates the logic of `isCompatible` here to provide better error message. 
if (other == null) { @@ -240,6 +93,7 @@ public void writeTo(OutputStream out) throws IOException { dos.writeInt(Version.V1.getVersionNumber()); dos.writeInt(numHashFunctions); + // ignore seed bits.writeTo(dos); } @@ -252,6 +106,7 @@ private void readFrom0(InputStream in) throws IOException { } this.numHashFunctions = dis.readInt(); + this.seed = DEFAULT_SEED; this.bits = BitArray.readFrom(dis); } @@ -261,16 +116,18 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException { return filter; } - public static BloomFilterImpl readFrom(byte[] bytes) throws IOException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) { - return readFrom(bis); - } + // no longer necessary, but can't remove without triggering MIMA violations + @Deprecated + public static BloomFilter readFrom(byte[] bytes) throws IOException { + return BloomFilter.readFrom(bytes); } + @Serial private void writeObject(ObjectOutputStream out) throws IOException { writeTo(out); } + @Serial private void readObject(ObjectInputStream in) throws IOException { readFrom0(in); } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java new file mode 100644 index 0000000000000..fa0a1df384865 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImplV2.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.*; + +class BloomFilterImplV2 extends BloomFilterBase implements Serializable { + + BloomFilterImplV2(int numHashFunctions, long numBits, int seed) { + this(new BitArray(numBits), numHashFunctions, seed); + } + + private BloomFilterImplV2(BitArray bits, int numHashFunctions, int seed) { + super(bits, numHashFunctions, seed); + } + + private BloomFilterImplV2() {} + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (!(other instanceof BloomFilterImplV2 that)) { + return false; + } + + return + this.numHashFunctions == that.numHashFunctions + && this.seed == that.seed + && this.bits.equals(that.bits); + } + + protected boolean scatterHashAndSetAllBits(HiLoHash inputHash) { + int h1 = inputHash.hi(); + int h2 = inputHash.lo(); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + + // Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash + long combinedHash = (long) h1 * Integer.MAX_VALUE; + for (long i = 0; i < numHashFunctions; i++) { + combinedHash += h2; + + // Flip all the bits if it's negative (guaranteed positive number) + long combinedIndex = combinedHash < 0 ? 
~combinedHash : combinedHash; + + bitsChanged |= bits.set(combinedIndex % bitSize); + } + return bitsChanged; + } + + protected boolean scatterHashAndGetAllBits(HiLoHash inputHash) { + int h1 = inputHash.hi(); + int h2 = inputHash.lo(); + + long bitSize = bits.bitSize(); + + // Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash + long combinedHash = (long) h1 * Integer.MAX_VALUE; + for (long i = 0; i < numHashFunctions; i++) { + combinedHash += h2; + + // Flip all the bits if it's negative (guaranteed positive number) + long combinedIndex = combinedHash < 0 ? ~combinedHash : combinedHash; + + if (!bits.get(combinedIndex % bitSize)) { + return false; + } + } + return true; + } + + protected BloomFilterImplV2 checkCompatibilityForMerge(BloomFilter other) + throws IncompatibleMergeException { + // Duplicates the logic of `isCompatible` here to provide better error message. + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null bloom filter"); + } + + if (!(other instanceof BloomFilterImplV2 that)) { + throw new IncompatibleMergeException( + "Cannot merge bloom filter of class " + other.getClass().getName() + ); + } + + if (this.bitSize() != that.bitSize()) { + throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size"); + } + + if (this.seed != that.seed) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different seeds" + ); + } + + if (this.numHashFunctions != that.numHashFunctions) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different number of hash functions" + ); + } + return that; + } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V2.getVersionNumber()); + dos.writeInt(numHashFunctions); + dos.writeInt(seed); + bits.writeTo(dos); + } + + private void readFrom0(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V2.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + this.numHashFunctions = dis.readInt(); + this.seed = dis.readInt(); + this.bits = BitArray.readFrom(dis); + } + + public static BloomFilterImplV2 readFrom(InputStream in) throws IOException { + BloomFilterImplV2 filter = new BloomFilterImplV2(); + filter.readFrom0(in); + return filter; + } + + @Serial + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + @Serial + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); + } +} diff --git a/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java b/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java new file mode 100644 index 0000000000000..ac574525d36b3 --- /dev/null +++ b/common/sketch/src/test/java/org/apache/spark/util/sketch/SparkBloomFilterSuite.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.LongAdder; +import java.util.stream.LongStream; +import java.util.stream.Stream; + + +public class SparkBloomFilterSuite { + + // the implemented fpp limit is only approximating the hard boundary, + // so we'll need an error threshold for the assertion + final double FPP_EVEN_ODD_ERROR_FACTOR = 0.10; + final double FPP_RANDOM_ERROR_FACTOR = 0.10; + + final long ONE_GB = 1024L * 1024L * 1024L; + final long REQUIRED_HEAP_UPPER_BOUND_IN_BYTES = 4 * ONE_GB; + + private static Instant START; + private static boolean strict; + private static boolean verbose; + + private Instant start; + private final Map testOutMap = new ConcurrentHashMap<>(); + + @BeforeAll + public static void beforeAll() { + START = Instant.now(); + String testClassName = SparkBloomFilterSuite.class.getName(); + strict = Boolean.parseBoolean(System.getProperty(testClassName+ ".strict", "true")); + verbose = Boolean.parseBoolean(System.getProperty(testClassName+ ".verbose", "false")); + } + + @AfterAll + public static void afterAll() { + Duration duration = Duration.between(START, Instant.now()); + if (verbose) { + System.err.println(duration + " TOTAL"); + } + } + + @BeforeEach + public void beforeEach( + TestInfo testInfo + ) throws Exception { + start = Instant.now(); + + String testName = testInfo.getDisplayName(); + + String testClassName = SparkBloomFilterSuite.class.getName(); + String logDir = System.getProperty(testClassName+ ".logDir", "./target/tmp"); + Path logDirPath = Path.of(logDir); + Files.createDirectories(logDirPath); + Path testLogPath = Path.of(logDir,testName + ".log"); + Files.deleteIfExists(testLogPath); + + PrintStream testOut = new PrintStream(Files.newOutputStream(testLogPath)); + testOutMap.put(testName, testOut); + + testOut.println("testName: " + testName); + } + + @AfterEach + public void afterEach(TestInfo testInfo) { + Duration duration = Duration.between(start, Instant.now()); + + String testName = testInfo.getDisplayName(); + PrintStream testOut = testOutMap.get(testName); + + testOut.println("duration: " + duration ); + testOut.close(); + } + + private static Stream dataPointProvider() { + // temporary workaround: + // to reduce running time to acceptable levels, we test only one case, + // with the default FPP and the default seed only. 
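+ // each Arguments tuple below is (numItems, expectedFpp, deterministicSeed), matching the
+ // parameters of the @ParameterizedTest methods in this suite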
+ return Stream.of( + Arguments.of(1_000_000_000L, 0.03, BloomFilterImplV2.DEFAULT_SEED) + ); + // preferable minimum parameter space for tests: + // {1_000_000L, 1_000_000_000L} for: long numItems + // {0.05, 0.03, 0.01, 0.001} for: double expectedFpp + // {BloomFilterImpl.DEFAULT_SEED, 1, 127} for: int deterministicSeed + } + + /** + * This test, in N number of iterations, inserts N even numbers (2*i) int, + * and leaves out N odd numbers (2*i+1) from the tested BloomFilter instance. + * + * It checks the 100% accuracy of mightContain=true on all of the even items, + * and measures the mightContain=true (false positive) rate on the not-inserted odd numbers. + * + * @param numItems the number of items to be inserted + * @param expectedFpp the expected fpp rate of the tested BloomFilter instance + * @param deterministicSeed the deterministic seed to use to initialize + * the primary BloomFilter instance. + */ + @ParameterizedTest(name = "testAccuracyEvenOdd.n{0}_fpp{1}_seed{2}") + @MethodSource("dataPointProvider") + public void testAccuracyEvenOdd( + long numItems, + double expectedFpp, + int deterministicSeed, + TestInfo testInfo + ) { + String testName = testInfo.getDisplayName(); + PrintStream testOut = testOutMap.get(testName); + + long optimalNumOfBits = BloomFilter.optimalNumOfBits(numItems, expectedFpp); + testOut.printf( + "optimal bitArray: %d (%d MB)\n", + optimalNumOfBits, + optimalNumOfBits / Byte.SIZE / 1024 / 1024 + ); + Assumptions.assumeTrue( + optimalNumOfBits / Byte.SIZE < REQUIRED_HEAP_UPPER_BOUND_IN_BYTES, + "this testcase would require allocating more than 4GB of heap mem (" + + optimalNumOfBits + + " bits)" + ); + + BloomFilter bloomFilter = + BloomFilter.create( + BloomFilter.Version.V2, + numItems, + optimalNumOfBits, + deterministicSeed + ); + + testOut.printf( + "allocated bitArray: %d (%d MB)\n", + bloomFilter.bitSize(), + bloomFilter.bitSize() / Byte.SIZE / 1024 / 1024 + ); + + for (long i = 0; i < numItems; i++) { + if (verbose && i % 10_000_000 == 0) { + System.err.printf("i: %d\n", i); + } + + bloomFilter.putLong(2 * i); + } + + testOut.printf("bitCount: %d\nsaturation: %f\n", + bloomFilter.cardinality(), + (double) bloomFilter.cardinality() / bloomFilter.bitSize() + ); + + LongAdder mightContainEven = new LongAdder(); + LongAdder mightContainOdd = new LongAdder(); + + LongStream inputStream = LongStream.range(0, numItems).parallel(); + inputStream.forEach( + i -> { + long even = 2 * i; + if (bloomFilter.mightContainLong(even)) { + mightContainEven.increment(); + } + + long odd = 2 * i + 1; + if (bloomFilter.mightContainLong(odd)) { + mightContainOdd.increment(); + } + } + ); + + Assertions.assertEquals( + numItems, mightContainEven.longValue(), + "mightContainLong must return true for all inserted numbers" + ); + + double actualFpp = mightContainOdd.doubleValue() / numItems; + double acceptableFpp = expectedFpp * (1 + FPP_EVEN_ODD_ERROR_FACTOR); + + testOut.printf("expectedFpp: %f %%\n", 100 * expectedFpp); + testOut.printf("acceptableFpp: %f %%\n", 100 * acceptableFpp); + testOut.printf("actualFpp: %f %%\n", 100 * actualFpp); + + if (!strict) { + Assumptions.assumeTrue( + actualFpp <= acceptableFpp, + String.format( + "acceptableFpp(%f %%) < actualFpp (%f %%)", + 100 * acceptableFpp, + 100 * actualFpp + ) + ); + } else { + Assertions.assertTrue( + actualFpp <= acceptableFpp, + String.format( + "acceptableFpp(%f %%) < actualFpp (%f %%)", + 100 * acceptableFpp, + 100 * actualFpp + ) + ); + } + } + + /** + * This test inserts N pseudorandomly generated 
numbers in 2N number of iterations in two + * differently seeded (theoretically independent) BloomFilter instances. All the random + * numbers generated in an even-iteration will be inserted into both filters, all the + * random numbers generated in an odd-iteration will be left out from both. + * + * The test checks the 100% accuracy of 'mightContain=true' for all the items inserted + * in an even-loop. It counts the false positives as the number of odd-loop items for + * which the primary filter reports 'mightContain=true', but secondary reports + * 'mightContain=false'. Since we inserted the same elements into both instances, + * and the secondary reports non-insertion, the 'mightContain=true' from the primary + * can only be a false positive. + * + * @param numItems the number of items to be inserted + * @param expectedFpp the expected fpp rate of the tested BloomFilter instance + * @param deterministicSeed the deterministic seed to use to initialize + * the primary BloomFilter instance. (The secondary will be + * initialized with the constant seed of 0xCAFEBABE) + */ + @ParameterizedTest(name = "testAccuracyRandom.n{0}_fpp{1}_seed{2}") + @MethodSource("dataPointProvider") + public void testAccuracyRandomDistribution( + long numItems, + double expectedFpp, + int deterministicSeed, + TestInfo testInfo + ) { + String testName = testInfo.getDisplayName(); + PrintStream testOut = testOutMap.get(testName); + + long optimalNumOfBits = BloomFilter.optimalNumOfBits(numItems, expectedFpp); + testOut.printf( + "optimal bitArray: %d (%d MB)\n", + optimalNumOfBits, + optimalNumOfBits / Byte.SIZE / 1024 / 1024 + ); + Assumptions.assumeTrue( + 2 * optimalNumOfBits / Byte.SIZE < REQUIRED_HEAP_UPPER_BOUND_IN_BYTES, + "this testcase would require allocating more than 4GB of heap mem (2x " + + optimalNumOfBits + + " bits)" + ); + + BloomFilter bloomFilterPrimary = + BloomFilter.create( + BloomFilter.Version.V2, + numItems, + optimalNumOfBits, + deterministicSeed + ); + + BloomFilter bloomFilterSecondary = + BloomFilter.create( + BloomFilter.Version.V2, + numItems, + optimalNumOfBits, + 0xCAFEBABE + ); + + testOut.printf( + "allocated bitArray: %d (%d MB)\n", + bloomFilterPrimary.bitSize(), + bloomFilterPrimary.bitSize() / Byte.SIZE / 1024 / 1024 + ); + + long iterationCount = 2 * numItems; + + for (long i = 0; i < iterationCount; i++) { + if (verbose && i % 10_000_000 == 0) { + System.err.printf("i: %d\n", i); + } + + long candidate = scramble(i); + if (i % 2 == 0) { + bloomFilterPrimary.putLong(candidate); + bloomFilterSecondary.putLong(candidate); + } + } + testOut.printf("bitCount: %d\nsaturation: %f\n", + bloomFilterPrimary.cardinality(), + (double) bloomFilterPrimary.cardinality() / bloomFilterPrimary.bitSize() + ); + + LongAdder mightContainEvenIndexed = new LongAdder(); + LongAdder mightContainOddIndexed = new LongAdder(); + LongAdder confirmedAsNotInserted = new LongAdder(); + LongStream inputStream = LongStream.range(0, iterationCount).parallel(); + inputStream.forEach( + i -> { + if (verbose && i % (iterationCount / 100) == 0) { + System.err.printf("%s: %2d %%\n", testName, 100 * i / iterationCount); + } + + long candidate = scramble(i); + + if (i % 2 == 0) { // EVEN + mightContainEvenIndexed.increment(); + } else { // ODD + // for fpp estimation, only consider the odd indexes + // (to avoid querying the secondary with elements known to be inserted) + + // since here we avoided all the even indexes, + // most of these secondary queries will return false + if 
(!bloomFilterSecondary.mightContainLong(candidate)) { + // from the odd indexes, we consider only those items + // where the secondary confirms the non-insertion + + // anything on which the primary and the secondary + // disagrees here is a false positive + if (bloomFilterPrimary.mightContainLong(candidate)) { + mightContainOddIndexed.increment(); + } + // count the total number of considered items for a baseline + confirmedAsNotInserted.increment(); + } + } + } + ); + + Assertions.assertEquals( + numItems, mightContainEvenIndexed.longValue(), + "mightContainLong must return true for all inserted numbers" + ); + + double actualFpp = + mightContainOddIndexed.doubleValue() / confirmedAsNotInserted.doubleValue(); + double acceptableFpp = expectedFpp * (1 + FPP_RANDOM_ERROR_FACTOR); + + testOut.printf("mightContainOddIndexed: %10d\n", mightContainOddIndexed.longValue()); + testOut.printf("confirmedAsNotInserted: %10d\n", confirmedAsNotInserted.longValue()); + testOut.printf("numItems: %10d\n", numItems); + testOut.printf("expectedFpp: %f %%\n", 100 * expectedFpp); + testOut.printf("acceptableFpp: %f %%\n", 100 * acceptableFpp); + testOut.printf("actualFpp: %f %%\n", 100 * actualFpp); + + if (!strict) { + Assumptions.assumeTrue( + actualFpp <= acceptableFpp, + String.format( + "acceptableFpp(%f %%) < actualFpp (%f %%)", + 100 * acceptableFpp, + 100 * actualFpp + ) + ); + } else { + Assertions.assertTrue( + actualFpp <= acceptableFpp, + String.format( + "acceptableFpp(%f %%) < actualFpp (%f %%)", + 100 * acceptableFpp, + 100 * actualFpp + ) + ); + } + } + + // quick scrambling logic hacked out from java.util.Random + // its range is only 48bits (out of the 64bits of a Long value), + // but it should be enough for the purposes of this test. + private static final long multiplier = 0x5DEECE66DL; + private static final long addend = 0xBL; + private static final long mask = (1L << 48) - 1; + private static long scramble(long value) { + return (value * multiplier + addend) & mask; + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala index 4d0ba66637b46..ba8f97a51aecf 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -46,7 +46,9 @@ class BloomFilterSuite extends AnyFunSuite { // scalastyle:ignore funsuite val fpp = 0.05 val numInsertion = numItems / 10 - val allItems = Array.fill(numItems)(itemGen(r)) + // using a Set to avoid duplicates, + // inserting twice as many random values as used, to compensate for lost dupes + val allItems = Set.fill(2 * numItems)(itemGen(r)).take(numItems) val filter = BloomFilter.create(numInsertion, fpp) @@ -158,5 +160,11 @@ class BloomFilterSuite extends AnyFunSuite { // scalastyle:ignore funsuite val filter2 = BloomFilter.create(2000, 6400) filter1.mergeInPlace(filter2) } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(BloomFilter.Version.V1, 1000L, 6400L, 0) + val filter2 = BloomFilter.create(BloomFilter.Version.V2, 1000L, 6400L, 0) + filter1.mergeInPlace(filter2) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index af97856fd222e..fb279b1db6fc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -41,16 +41,19 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { override def beforeAll(): Unit = { super.beforeAll() // Register 'bloom_filter_agg' to builtin. - spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg, + spark.sessionState.functionRegistry.registerFunction( + funcId_bloom_filter_agg, new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), - (children: Seq[Expression]) => children.size match { - case 1 => new BloomFilterAggregate(children.head) - case 2 => new BloomFilterAggregate(children.head, children(1)) - case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) - }) + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) // Register 'might_contain' to builtin. - spark.sessionState.functionRegistry.registerFunction(funcId_might_contain, + spark.sessionState.functionRegistry.registerFunction( + funcId_might_contain, new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) } @@ -64,10 +67,22 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { test("Test bloom_filter_agg and might_contain") { val conf = SQLConf.get val table = "bloom_filter_test" - for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, - conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) { - for ((numBits, index) <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, - conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) { + for (numEstimatedItems <- Seq( + Long.MinValue, + -10L, + 0L, + 4096L, + 4194304L, + Long.MaxValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) { + for ((numBits, index) <- Seq( + Long.MinValue, + -10L, + 0L, + 4096L, + 4194304L, + Long.MaxValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) { val sqlString = s""" |SELECT every(might_contain( | (SELECT bloom_filter_agg(col, @@ -85,7 +100,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { """.stripMargin withTempView(table) { (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L)) - .toDF("col").createOrReplaceTempView(table) + .toDF("col") + .createOrReplaceTempView(table) // Validate error messages as well as answers when there's no error. 
if (numEstimatedItems <= 0) { val exception = intercept[AnalysisException] { @@ -104,16 +120,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "valueRange" -> "[0, positive]", "currentValue" -> toSQLValue(numEstimatedItems, LongType), "sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ + - s"""CAST($numBits AS BIGINT))"""") - ), + s"""CAST($numBits AS BIGINT))"""")), context = ExpectedContext( fragment = "bloom_filter_agg(col,\n" + s" cast($numEstimatedItems as long),\n" + s" cast($numBits as long))", start = 49, - stop = stop(index) - ) - ) + stop = stop(index))) } else if (numBits <= 0) { val exception = intercept[AnalysisException] { spark.sql(sqlString) @@ -132,16 +145,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "valueRange" -> "[0, positive]", "currentValue" -> toSQLValue(numBits, LongType), "sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ + - s"""CAST($numBits AS BIGINT))"""") - ), + s"""CAST($numBits AS BIGINT))"""")), context = ExpectedContext( fragment = "bloom_filter_agg(col,\n" + s" cast($numEstimatedItems as long),\n" + s" cast($numBits as long))", start = 49, - stop = stop(index) - ) - ) + stop = stop(index))) } else { checkAnswer(spark.sql(sqlString), Row(true, false)) } @@ -154,8 +164,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { val exception1 = intercept[AnalysisException] { spark.sql(""" |SELECT bloom_filter_agg(a) - |FROM values (1.2), (2.5) as t(a)""" - .stripMargin) + |FROM values (1.2), (2.5) as t(a)""".stripMargin) } checkError( exception = exception1, @@ -165,20 +174,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "sqlExpr" -> "\"bloom_filter_agg(a, 1000000, 8388608)\"", "expectedLeft" -> "\"BINARY\"", "expectedRight" -> "\"BIGINT\"", - "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\"" - ), - context = ExpectedContext( - fragment = "bloom_filter_agg(a)", - start = 8, - stop = 26 - ) - ) + "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\""), + context = ExpectedContext(fragment = "bloom_filter_agg(a)", start = 8, stop = 26)) val exception2 = intercept[AnalysisException] { spark.sql(""" |SELECT bloom_filter_agg(a, 2) - |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin) } checkError( exception = exception2, @@ -188,20 +190,13 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "sqlExpr" -> "\"bloom_filter_agg(a, 2, (2 * 8))\"", "expectedLeft" -> "\"BINARY\"", "expectedRight" -> "\"BIGINT\"", - "actual" -> "\"BIGINT\", \"INT\", \"BIGINT\"" - ), - context = ExpectedContext( - fragment = "bloom_filter_agg(a, 2)", - start = 8, - stop = 29 - ) - ) + "actual" -> "\"BIGINT\", \"INT\", \"BIGINT\""), + context = ExpectedContext(fragment = "bloom_filter_agg(a, 2)", start = 8, stop = 29)) val exception3 = intercept[AnalysisException] { spark.sql(""" |SELECT bloom_filter_agg(a, cast(2 as long), 5) - |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin) } checkError( exception = exception3, @@ -211,60 +206,42 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "sqlExpr" -> "\"bloom_filter_agg(a, CAST(2 AS BIGINT), 5)\"", "expectedLeft" -> "\"BINARY\"", "expectedRight" -> "\"BIGINT\"", - "actual" 
-> "\"BIGINT\", \"BIGINT\", \"INT\"" - ), + "actual" -> "\"BIGINT\", \"BIGINT\", \"INT\""), context = ExpectedContext( fragment = "bloom_filter_agg(a, cast(2 as long), 5)", start = 8, - stop = 46 - ) - ) + stop = 46)) val exception4 = intercept[AnalysisException] { spark.sql(""" |SELECT bloom_filter_agg(a, null, 5) - |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin) } checkError( exception = exception4, condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "estimatedNumItems or numBits", - "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\"" - ), - context = ExpectedContext( - fragment = "bloom_filter_agg(a, null, 5)", - start = 8, - stop = 35 - ) - ) + "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\""), + context = ExpectedContext(fragment = "bloom_filter_agg(a, null, 5)", start = 8, stop = 35)) val exception5 = intercept[AnalysisException] { spark.sql(""" |SELECT bloom_filter_agg(a, 5, null) - |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""".stripMargin) } checkError( exception = exception5, condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "estimatedNumItems or numBits", - "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\"" - ), - context = ExpectedContext( - fragment = "bloom_filter_agg(a, 5, null)", - start = 8, - stop = 35 - ) - ) + "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\""), + context = ExpectedContext(fragment = "bloom_filter_agg(a, 5, null)", start = 8, stop = 35)) } test("Test that might_contain errors out disallowed input value types") { val exception1 = intercept[AnalysisException] { - spark.sql("""|SELECT might_contain(1.0, 1L)""" - .stripMargin) + spark.sql("""|SELECT might_contain(1.0, 1L)""".stripMargin) } checkError( exception = exception1, @@ -274,18 +251,11 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "functionName" -> "`might_contain`", "expectedLeft" -> "\"BINARY\"", "expectedRight" -> "\"BIGINT\"", - "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\"" - ), - context = ExpectedContext( - fragment = "might_contain(1.0, 1L)", - start = 7, - stop = 28 - ) - ) + "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\""), + context = ExpectedContext(fragment = "might_contain(1.0, 1L)", start = 7, stop = 28)) val exception2 = intercept[AnalysisException] { - spark.sql("""|SELECT might_contain(NULL, 0.1)""" - .stripMargin) + spark.sql("""|SELECT might_contain(NULL, 0.1)""".stripMargin) } checkError( exception = exception2, @@ -295,22 +265,15 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { "functionName" -> "`might_contain`", "expectedLeft" -> "\"BINARY\"", "expectedRight" -> "\"BIGINT\"", - "actual" -> "\"VOID\", \"DECIMAL(1,1)\"" - ), - context = ExpectedContext( - fragment = "might_contain(NULL, 0.1)", - start = 7, - stop = 30 - ) - ) + "actual" -> "\"VOID\", \"DECIMAL(1,1)\""), + context = ExpectedContext(fragment = "might_contain(NULL, 0.1)", start = 7, stop = 30)) } test("Test that might_contain errors out non-constant Bloom filter") { val exception1 = intercept[AnalysisException] { spark.sql(""" |SELECT might_contain(cast(a as binary), cast(5 as long)) - |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""".stripMargin) } checkError( exception = exception1, @@ -318,20 +281,16 @@ class 
BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"might_contain(CAST(a AS BINARY), CAST(5 AS BIGINT))\"", "functionName" -> "`might_contain`", - "actual" -> "\"CAST(a AS BINARY)\"" - ), + "actual" -> "\"CAST(a AS BINARY)\""), context = ExpectedContext( fragment = "might_contain(cast(a as binary), cast(5 as long))", start = 8, - stop = 56 - ) - ) + stop = 56)) val exception2 = intercept[AnalysisException] { spark.sql(""" |SELECT might_contain((select cast(a as binary)), cast(5 as long)) - |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""" - .stripMargin) + |FROM values (cast(1 as string)), (cast(2 as string)) as t(a)""".stripMargin) } checkError( exception = exception2, @@ -339,32 +298,38 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"might_contain(scalarsubquery(a), CAST(5 AS BIGINT))\"", "functionName" -> "`might_contain`", - "actual" -> "\"scalarsubquery(a)\"" - ), + "actual" -> "\"scalarsubquery(a)\""), context = ExpectedContext( fragment = "might_contain((select cast(a as binary)), cast(5 as long))", start = 8, - stop = 65 - ) - ) + stop = 65)) } - test("Test that might_contain can take a constant value input") { - checkAnswer(spark.sql( - """SELECT might_contain( + test("Test that might_contain can take a constant value input (seedless version)") { + checkAnswer( + spark.sql("""SELECT might_contain( |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', |cast(201 as long))""".stripMargin), Row(false)) } + test("Test that might_contain can take a constant value input (seeded version)") { + checkAnswer( + spark.sql("""SELECT might_contain( + |X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', + |cast(201 as long))""".stripMargin), + Row(false)) + } + test("Test that bloom_filter_agg produces a NULL with empty input") { - checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""), + checkAnswer( + spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""), Row(null)) } test("Test NULL inputs for might_contain") { - checkAnswer(spark.sql( - s""" + checkAnswer( + spark.sql(s""" |SELECT might_contain(null, null) both_null, | might_contain(null, 1L) null_bf, | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)), @@ -374,9 +339,15 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } test("Test that a query with bloom_filter_agg has partial aggregates") { - assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") - .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan - .collect({case agg: BaseAggregateExec => agg}).size == 2) + assert( + spark + .sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") + .queryExecution + .executedPlan + .asInstanceOf[AdaptiveSparkPlanExec] + .inputPlan + .collect({ case agg: BaseAggregateExec => agg }) + .size == 2) } test("Test numBitsExpression") { @@ -385,7 +356,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { assert(agg.numBitsExpression === Literal(numBits)) } - checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) * 100, + checkNumBits( + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) * 100, conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS) + 
10, 29193836) checkNumBits(conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS), 29193763)
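For reviewers who want to exercise the new API outside the test suites, here is a minimal round-trip sketch (not part of this patch; the class name and seed value are illustrative) that uses only methods visible in this diff: create(Version, long, long, int), writeTo, and the version-detecting readFrom.

import java.io.ByteArrayOutputStream;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterV2RoundTripExample {
  public static void main(String[] args) throws Exception {
    long numItems = 1_000_000L;
    long numBits = BloomFilter.optimalNumOfBits(numItems, 0.03);

    // Build a V2 filter with an explicit seed (new in this change).
    BloomFilter filter = BloomFilter.create(BloomFilter.Version.V2, numItems, numBits, 42);
    filter.putLong(123L);
    filter.putString("spark");

    // Serialize, then read back: readFrom peeks at the 4-byte version header
    // and dispatches to the V1 or V2 deserializer.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    BloomFilter restored = BloomFilter.readFrom(out.toByteArray());

    System.out.println(restored.mightContainLong(123L));      // expected: true
    System.out.println(restored.mightContainString("spark")); // expected: true
  }
}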