diff --git a/core/src/main/java/org/apache/spark/io/FileUtility.java b/core/src/main/java/org/apache/spark/io/FileUtility.java new file mode 100644 index 0000000000000..5e21c6218b2f7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/FileUtility.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import org.apache.commons.compress.archivers.tar.TarArchiveEntry; +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream; +import org.apache.commons.compress.utils.IOUtils; + +import java.io.*; + +public class FileUtility { + + /** + * Extract an input tar file into an output files and directories. + * inputTarFileLoc: the input file location for the tar file + * destDirLoc: destination for the extracted files + * + * throws IllegalStateException + */ + public static final String ENCODING = "utf-8"; + + public static void extractTarFile(String inputTarFileLoc, String destDirLoc) + throws IllegalStateException { + File inputFile = new File(inputTarFileLoc); + if (!inputTarFileLoc.endsWith(".tar")) { + throw new IllegalStateException(String.format( + "Input File[%s] should end with tar extension.", inputTarFileLoc)); + } + File destDir = new File(destDirLoc); + if (destDir.exists() && !destDir.delete()) { + throw new IllegalStateException(String.format( + "Couldn't delete the existing destination directory[%s] ", destDirLoc)); + } else if (!destDir.mkdir()) { + throw new IllegalStateException(String.format( + "Couldn't create directory %s ", destDirLoc)); + } + + try (InputStream is = new FileInputStream(inputFile); + TarArchiveInputStream debInputStream = new TarArchiveInputStream(is, ENCODING)) { + TarArchiveEntry entry; + while ((entry = (TarArchiveEntry) debInputStream.getNextEntry()) != null) { + final File outputFile = new File(destDirLoc, entry.getName()); + if (entry.isDirectory()) { + if (!outputFile.exists() && !outputFile.mkdirs()) { + throw new IllegalStateException(String.format( + "Couldn't create directory %s.", outputFile.getAbsolutePath())); + } + } else { + try (OutputStream outputFileStream = new FileOutputStream(outputFile)) { + IOUtils.copy(debInputStream, outputFileStream); + } + } + } + } catch (IOException e){ + throw new IllegalStateException(String.format( + "extractTarFile failed with exception %s.", e.getMessage())); + } + } + + /** + * create a tar file for input source directory location . 
+ * source: the source directory location + * destFileLoc: destination of the created tarball + * + * throws IllegalStateException + */ + + public static void createTarFile(String source, String destFileLoc) + throws IllegalStateException { + File f = new File(destFileLoc); + if (f.exists() && !f.delete()) { + throw new IllegalStateException(String.format( + "Couldn't delete the destination file location[%s]", destFileLoc)); + } + File folder = new File(source); + if (!folder.exists()) { + throw new IllegalStateException(String.format( + "Source folder[%s] does not exist", source)); + } + + try (FileOutputStream fos = new FileOutputStream(destFileLoc); + TarArchiveOutputStream tarOs = new TarArchiveOutputStream(fos, ENCODING)) { + File[] fileNames = folder.listFiles(); + for (File file : fileNames) { + TarArchiveEntry tar_file = new TarArchiveEntry(file.getName()); + tar_file.setSize(file.length()); + tarOs.putArchiveEntry(tar_file); + try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) { + IOUtils.copy(bis, tarOs); + tarOs.closeArchiveEntry(); + } + } + tarOs.finish(); + } catch (IOException e) { + throw new IllegalStateException(String.format( + "createTarFile failed with exception %s.", e.getMessage())); + } + } + +} diff --git a/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java b/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java new file mode 100644 index 0000000000000..a9a21a302eca8 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
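As a usage sketch (not part of the patch itself), the round trip offered by the FileUtility added above looks like this; the paths are illustrative placeholders:

```scala
import org.apache.spark.io.FileUtility

// Pack a directory into a tarball, then unpack it somewhere else.
val sourceDir = "/tmp/state-snapshot-src"       // hypothetical directory to archive
val tarball   = "/tmp/state-snapshot.tar"       // extractTarFile requires a ".tar" suffix
val restored  = "/tmp/state-snapshot-restored"

FileUtility.createTarFile(sourceDir, tarball)   // throws IllegalStateException on failure
FileUtility.extractTarFile(tarball, restored)   // recreates the archived files under `restored`
```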
+ */ +package org.apache.spark.io; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.RandomUtils; +import org.apache.spark.util.Utils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; + +/** + * Tests functionality of {@link FileUtility} + */ +public class FileUtilitySuite { + + protected File sourceFolder; + protected File destTarLoc; + protected File destFolder; + + @Before + public void setUp() throws IOException { + String tmpDir = System.getProperty("java.io.tmpdir"); + sourceFolder = Utils.createTempDir(tmpDir, "FileUtilTest-src-" + RandomUtils.nextLong()); + destFolder = Utils.createTempDir(tmpDir, "FileUtilTest-dest-" + RandomUtils.nextLong()); + destTarLoc= File.createTempFile("dest-tar", ".tar"); + } + + @After + public void tearDown() { + destTarLoc.delete(); + } + + @Test + public void testCreationAndExtraction() throws IllegalStateException, IOException { + // Create a temp file in the source folder + Assert.assertEquals(sourceFolder.listFiles().length , 0); + File inputFile = File.createTempFile("source-file", ".tmp", sourceFolder); + // Create a byte array of size 1 KB with random bytes + byte[] randomBytes = RandomUtils.nextBytes(1 * 1024); + FileUtils.writeByteArrayToFile(inputFile, randomBytes); + + // Create the tarball + destTarLoc.delete(); + Assert.assertFalse(destTarLoc.exists()); + FileUtility.createTarFile(sourceFolder.toString(), destTarLoc.getAbsolutePath()); + Assert.assertTrue(destTarLoc.exists()); + + // Extract the tarball + Assert.assertEquals(destFolder.listFiles().length , 0); + FileUtility.extractTarFile(destTarLoc.getAbsolutePath(), destFolder.getAbsolutePath()); + + // Verify that the extraction was successful + Assert.assertTrue(destFolder.exists()); + Assert.assertEquals(destFolder.listFiles().length , 1); + Assert.assertArrayEquals(randomBytes, FileUtils.readFileToByteArray(destFolder.listFiles()[0])); + } + +} diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2da4c9e44b29e..97b713cd1403f 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -175,6 +175,7 @@ parquet-jackson-1.10.1.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.30.jar +rocksdbjni-6.2.2.jar scala-collection-compat_2.12-2.1.1.jar scala-compiler-2.12.10.jar scala-library-2.12.10.jar diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2 index 2226baeadfba1..c0ab22d32abcb 100644 --- a/dev/deps/spark-deps-hadoop-3.2 +++ b/dev/deps/spark-deps-hadoop-3.2 @@ -194,6 +194,7 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.30.jar re2j-1.1.jar +rocksdbjni-6.2.2.jar scala-collection-compat_2.12-2.1.1.jar scala-compiler-2.12.10.jar scala-library-2.12.10.jar diff --git a/pom.xml b/pom.xml index 2396c5168b166..99514754a1c3b 100644 --- a/pom.xml +++ b/pom.xml @@ -193,6 +193,7 @@ 1.1 2.52.0 2.22 + 6.2.2 diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 02ed6f8adaa62..2112543cc9b82 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -147,6 +147,11 @@ mockito-core test + + org.rocksdb + rocksdbjni + ${rocksdb.version} + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala new file mode 100644 index 0000000000000..2b5e6751317d0 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.BufferedWriter +import java.io.File +import java.io.FileWriter + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.rocksdb._ +import org.rocksdb.RocksDB + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +class RocksDbInstance(keySchema: StructType, valueSchema: StructType, version: String) + extends Logging { + + import RocksDbInstance._ + RocksDB.loadLibrary() + + protected var db: RocksDB = null + protected var dbPath: String = _ + protected val readOptions: ReadOptions = new ReadOptions() + protected val writeOptions: WriteOptions = new WriteOptions() + protected val table_options = new BlockBasedTableConfig + protected val options: Options = new Options() + + private def isOpen(): Boolean = { + db != null + } + + def open(path: String, readOnly: Boolean): Unit = { + require(db == null, "Another rocksDb instance is already active") + try { + setOptions + db = if (readOnly) { + options.setCreateIfMissing(false) + RocksDB.openReadOnly(options, path) + } else { + options.setCreateIfMissing(true) + RocksDB.open(options, path) + } + dbPath = path + } catch { + case e: Throwable => + throw new IllegalStateException( + s"Error while creating rocksDb instance ${e.getMessage}", + e) + } + } + + def get(key: UnsafeRow): UnsafeRow = { + require(isOpen(), "Open rocksDb instance before any operation") + Option(db.get(readOptions, key.getBytes)) match { + case Some(valueInBytes) => + val value = new UnsafeRow(valueSchema.fields.length) + value.pointTo(valueInBytes, valueInBytes.length) + value + case None => null + } + } + + def put(key: UnsafeRow, value: UnsafeRow): Unit = { + require(isOpen(), "Open rocksDb instance before any operation") + db.put(key.getBytes, value.getBytes) + } + + def remove(key: UnsafeRow): Unit = { + require(isOpen(), "Open rocksDb instance before any operation") + db.delete(key.getBytes) + } + + def commit(checkPointPath: Option[String] = None): Unit = { + checkPointPath.foreach(f => createCheckpoint(db, f)) + } + + def abort: Unit = { + // no-op + } + + def close(): Unit = { + logDebug("Closing the db") + try { + db.close() + } finally { + db = null + options.close() + readOptions.close() + writeOptions.close() + } + } + + def iterator(closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = { + require(isOpen(), "Open rocksDb instance before any operation") + Option(db.getSnapshot) match { + case Some(snapshot) 
=> + var snapshotReadOptions: ReadOptions = + new ReadOptions().setSnapshot(snapshot).setFillCache(false) + val itr = db.newIterator(snapshotReadOptions) + createUnsafeRowPairIterator(itr, snapshotReadOptions, snapshot, closeDbOnCompletion) + case None => + Iterator.empty + } + } + + protected def createUnsafeRowPairIterator( + itr: RocksIterator, + itrReadOptions: ReadOptions, + snapshot: Snapshot, + closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = { + + itr.seekToFirst() + + new Iterator[UnsafeRowPair] { + @volatile var isClosed = false + override def hasNext: Boolean = { + if (!isClosed && itr.isValid) { + true + } else { + if (!isClosed) { + isClosed = true + itrReadOptions.close() + db.releaseSnapshot(snapshot) + if (closeDbOnCompletion) { + close() + } + itr.close() + logDebug(s"read from DB completed") + } + false + } + } + + override def next(): UnsafeRowPair = { + val keyBytes = itr.key + val key = new UnsafeRow(keySchema.fields.length) + key.pointTo(keyBytes, keyBytes.length) + val valueBytes = itr.value + val value = new UnsafeRow(valueSchema.fields.length) + value.pointTo(valueBytes, valueBytes.length) + itr.next() + new UnsafeRowPair(key, value) + } + } + } + + protected def printMemoryStats(db: RocksDB): Unit = { + require(isOpen(), "Open rocksDb instance before any operation") + val usage = MemoryUtil + .getApproximateMemoryUsageByType( + List(db).asJava, + Set(rocksDbLRUCache.asInstanceOf[Cache]).asJava) + .asScala + val numKeys = db.getProperty(db.getDefaultColumnFamily, "rocksdb.estimate-num-keys") + logDebug(s""" + | rocksdb.estimate-num-keys = $numKeys + | ApproximateMemoryUsageByType = ${usage.toString} + | """.stripMargin) + } + + protected def printStats: Unit = { + require(isOpen(), "Open rocksDb instance before any operation") + try { + val stats = db.getProperty("rocksdb.stats") + logInfo(s"Stats = $stats") + } catch { + case e: Exception => + logWarning("Exception while getting stats") + } + } + + private val dataBlockSize = RocksDbStateStoreConf.blockSizeInKB + private val memTableMemoryBudget = RocksDbStateStoreConf.memtableBudgetInMB + private val enableStats = RocksDbStateStoreConf.enableStats + + protected def setOptions(): Unit = { + + // Read options + readOptions.setFillCache(true) + + // Write options + writeOptions.setSync(false) + writeOptions.setDisableWAL(true) + + /* + Table configs + Use Partitioned Index Filters + https://github.com/facebook/rocksdb/wiki/Partitioned-Index-Filters + Use format Verion = 4 + https://rocksdb.org/blog/2019/03/08/format-version-4.html + */ + table_options + .setBlockSize(dataBlockSize * 1024) + .setFormatVersion(4) + .setDataBlockIndexType(DataBlockIndexType.kDataBlockBinaryAndHash) + .setBlockCache(rocksDbLRUCache) + .setFilterPolicy(new BloomFilter(10, false)) + .setPinTopLevelIndexAndFilter(false) // Dont pin anything in cache + .setIndexType(IndexType.kTwoLevelIndexSearch) + .setPartitionFilters(true) + + options + .setTableFormatConfig(table_options) + .optimizeLevelStyleCompaction(memTableMemoryBudget * 1024 * 1024) + .setBytesPerSync(1048576) + .setMaxOpenFiles(5000) + .setIncreaseParallelism(4) + + if (enableStats) { + options + .setStatistics(new Statistics()) + .setStatsDumpPeriodSec(30) + } + } + + protected def createCheckpoint(rocksDb: RocksDB, dir: String): Unit = { + require(isOpen(), "Open rocksDb instance before any operation") + val (result, elapsedMs) = Utils.timeTakenMs { + val c = Checkpoint.create(rocksDb) + val f: File = new File(dir) + if (f.exists()) { + 
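+ // RocksDB's Checkpoint.createCheckpoint fails if the target directory already exists,
+ // so any stale checkpoint left behind by a previous attempt is removed first.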
FileUtils.deleteDirectory(f) + } + c.createCheckpoint(dir) + c.close() + } + logInfo(s"Creating Checkpoint at $dir took $elapsedMs ms.") + } +} + +class OptimisticTransactionDbInstance( + keySchema: StructType, + valueSchema: StructType, + version: String) + extends RocksDbInstance(keySchema: StructType, valueSchema: StructType, version: String) { + import RocksDbInstance._ + RocksDB.loadLibrary() + + private var otdb: OptimisticTransactionDB = null + private var txn: Transaction = null + + private def isOpen(): Boolean = { + otdb != null + } + + def open(path: String): Unit = { + open(path, false) + } + + override def open(path: String, readOnly: Boolean): Unit = { + require(otdb == null, "Another OptimisticTransactionDbInstance instance is already active") + require(readOnly == false, "Cannot open OptimisticTransactionDbInstance in Readonly mode") + try { + setOptions() + options.setCreateIfMissing(true) + otdb = OptimisticTransactionDB.open(options, path) + db = otdb.getBaseDB + dbPath = path + } catch { + case e: Throwable => + throw new IllegalStateException( + s"Error while creating OptimisticTransactionDb instance" + + s" ${e.getMessage}", + e) + } + } + + def startTransactions(): Unit = { + require(isOpen(), "Open OptimisticTransactionDbInstance before performing any operation") + Option(txn) match { + case None => + val optimisticTransactionOptions = new OptimisticTransactionOptions() + txn = otdb.beginTransaction(writeOptions, optimisticTransactionOptions) + txn.setSavePoint() + case Some(x) => + throw new IllegalStateException(s"Already started a transaction") + } + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + require(txn != null, "Start Transaction before inserting any key") + txn.put(key.getBytes, value.getBytes) + } + + override def remove(key: UnsafeRow): Unit = { + require(txn != null, "Start Transaction before deleting any key") + txn.delete(key.getBytes) + } + + override def get(key: UnsafeRow): UnsafeRow = { + require(txn != null, "Start Transaction before fetching any key-value") + Option(txn.get(readOptions, key.getBytes)) match { + case Some(valueInBytes) => + val value = new UnsafeRow(valueSchema.fields.length) + value.pointTo(valueInBytes, valueInBytes.length) + value + case None => + null + } + } + + override def commit(checkPointPath: Option[String] = None): Unit = { + require(txn != null, "Start Transaction before fetching any key-value") + try { + txn.commit() + txn.close() + updateVersionInCommitFile() + checkPointPath.foreach(f => createCheckpoint(otdb.asInstanceOf[RocksDB], f)) + } catch { + case e: Exception => + log.error(s"Unable to commit the transactions. 
Error message = ${e.getMessage}") + throw e + } finally { + txn = null + } + } + + override def abort(): Unit = { + require(txn != null, "No Transaction to abort") + txn.rollbackToSavePoint() + txn.close() + txn = null + } + + override def close(): Unit = { + require(isOpen(), "No DB to close") + require(txn == null, "Transaction should be closed before closing the DB connection") + printMemoryStats(otdb.asInstanceOf[RocksDB]) + logDebug("Closing the transaction db") + try { + otdb.close() + db.close() + otdb = null + db = null + } finally { + options.close() + readOptions.close() + writeOptions.close() + } + } + + override def iterator(closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = { + require(txn != null, "Transaction is not set") + require( + closeDbOnCompletion == false, + "Cannot close a DB without aborting/committing the transactions") + val snapshot = db.getSnapshot + val readOptions = new ReadOptions() + .setSnapshot(snapshot) + .setFillCache(false) // for range lookup, we should not fill cache + val itr: RocksIterator = txn.getIterator(readOptions) + Option(itr) match { + case Some(i) => + logDebug(s"creating iterator from a transactional DB") + createUnsafeRowPairIterator(i, readOptions, snapshot, false) + case None => + Iterator.empty + } + } + + def getApproxEntriesInDb(): Long = { + require(isOpen(), "No DB to find Database Entries") + otdb.getProperty("rocksdb.estimate-num-keys").toLong + } + + protected def updateVersionInCommitFile(): Unit = { + val file = new File(dbPath, COMMIT_FILE_NAME) + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(version.toString) + bw.close() + } + +} + +object RocksDbInstance { + + RocksDB.loadLibrary() + + val COMMIT_FILE_NAME = "commit" + + lazy val rocksDbLRUCache = new LRUCache(RocksDbStateStoreConf.cacheSize * 1024 * 1024, 6, false) + + def destroyDB(path: String): Unit = { + val f: File = new File(path) + val destroyOptions: Options = new Options() + if (f.exists()) { + RocksDB.destroyDB(path, destroyOptions) + FileUtils.deleteDirectory(f) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala new file mode 100644 index 0000000000000..9d55b6c786b52 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
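As an illustration of how the transactional wrapper above is meant to be driven (a sketch only, not part of the patch; the key/value rows and paths are placeholders):

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Typical write path through OptimisticTransactionDbInstance: every mutation is staged
// in an optimistic transaction and made durable only on commit(), which also writes the
// commit file and a RocksDB checkpoint for the new version.
val keySchema = StructType(Seq(StructField("key", StringType)))
val valueSchema = StructType(Seq(StructField("value", IntegerType)))

val writer = new OptimisticTransactionDbInstance(keySchema, valueSchema, version = "5")
writer.open("/tmp/state/db")                    // opens an OptimisticTransactionDB
writer.startTransactions()                      // begin the transaction and set a savepoint
writer.put(keyRow, valueRow)                    // keyRow/valueRow: assumed UnsafeRow instances
writer.remove(staleKeyRow)                      // staged, not visible to other readers yet
writer.commit(Some("/tmp/state/checkpoint/5"))  // commit + commit file + checkpoint
writer.close()                                  // release native handles
```

On failure, calling abort() instead of commit() rolls the transaction back to the savepoint taken in startTransactions().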
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config.ConfigBuilder + +object RocksDbStateStoreConf { + + private[spark] val ROCKSDB_STATE_STORE_DATA_BLOCK_SIZE = + ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.blockSizeInKB") + .doc( + "The maximum size (in KB) of packed data in a block of a table file. " + + "When reading from a table, an entire block is loaded into memory") + .intConf + .createWithDefault(32) + + private[spark] val ROCKSDB_STATE_STORE_MEMTABLE_BUDGET = + ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.memtableBudgetInMB") + .doc("The maximum size (in MB) of memory to be used to optimize level style compaction") + .intConf + .createWithDefault(1024) + + private[spark] val ROCKSDB_STATE_STORE_CACHE_SIZE = + ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.cacheSizeInMB") + .doc("The maximum size (in MB) of in-memory LRU cache for RocksDB operations") + .intConf + .createWithDefault(512) + + private[spark] val ROCKSDB_STATE_STORE_ENABLE_STATS = + ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.enableDbStats") + .doc("Enable statistics for rocksdb for debugging and reporting") + .booleanConf + .createWithDefault(false) + + val blockSizeInKB: Int = Option(SparkEnv.get) + .map(_.conf.get(ROCKSDB_STATE_STORE_DATA_BLOCK_SIZE)) + .getOrElse(32) + + val memtableBudgetInMB: Int = Option(SparkEnv.get) + .map(_.conf.get(ROCKSDB_STATE_STORE_MEMTABLE_BUDGET)) + .getOrElse(1024) + + val cacheSize: Int = Option(SparkEnv.get) + .map(_.conf.get(RocksDbStateStoreConf.ROCKSDB_STATE_STORE_CACHE_SIZE)) + .getOrElse(512) + + val enableStats: Boolean = Option(SparkEnv.get) + .map(_.conf.get(ROCKSDB_STATE_STORE_ENABLE_STATS)) + .getOrElse(false) + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala new file mode 100644 index 0000000000000..cdc803b860a19 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala @@ -0,0 +1,646 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
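The knobs defined in RocksDbStateStoreConf are read from the active SparkConf via SparkEnv, so they can be set when the session is built; a sketch with illustrative values:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative tuning of the RocksDB state-store settings introduced above.
val spark = SparkSession.builder()
  .appName("rocksdb-state-store-demo")
  .config("spark.sql.streaming.stateStore.rocksDb.blockSizeInKB", "64")        // table block size
  .config("spark.sql.streaming.stateStore.rocksDb.memtableBudgetInMB", "2048") // compaction budget
  .config("spark.sql.streaming.stateStore.rocksDb.cacheSizeInMB", "1024")      // shared LRU cache
  .config("spark.sql.streaming.stateStore.rocksDb.enableDbStats", "true")      // dump stats to logs
  .getOrCreate()
```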
+ */ + +package org.apache.spark.sql.execution.streaming.state; + +import java.io._ +import java.util + +import scala.collection.JavaConverters._ +import scala.io.Source +import scala.util.control.NonFatal + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.io.FileUtility +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/* + * An implementation of [[StateStoreProvider]] and [[StateStore]] using RocksDB as the storage + * engine. In RocksDB, new writes are inserted into a memtable which is flushed into local storage + * when the memtable fills up. It improves scalability as compared to + * [[HDFSBackedStateStoreProvider]] since now the state data which was large enough to fit in the + * executor memory can be written into the combination of memtable and local storage.The data is + * backed in a HDFS-compatible file system just like [[HDFSBackedStateStoreProvider]] + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Updates are committed in the db atomically + * + * Backup Model: + * - Delta file is written in a HDFS-compatible file system on batch commit + * - RocksDB state is check-pointed into a separate folder on batch commit + * - Maintenance thread periodically takes a snapshot of the latest check-pointed version of + * rocksDB state which is written to a HDFS-compatible file system. + * + * Isolation Guarantee: + * - writes are committed in the transaction. 
+ * - writer thread which started the transaction can read all un-committed updates + * - any other reader thread cannot read any un-committed updates + */ +private[sql] class RocksDbStateStoreProvider extends StateStoreProvider with Logging { + + /* Internal fields and methods */ + @volatile private var stateStoreId_ : StateStoreId = _ + @volatile private var keySchema: StructType = _ + @volatile private var valueSchema: StructType = _ + @volatile private var storeConf: StateStoreConf = _ + @volatile private var hadoopConf: Configuration = _ + @volatile private var numberOfVersionsToRetain: Int = _ + @volatile private var localDir: String = _ + + private lazy val baseDir: Path = stateStoreId.storeCheckpointLocation() + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) + private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + import WALUtils._ + + /** Implementation of [[StateStore]] API which is backed by RocksDB and HDFS */ + class RocksDbStateStore(val version: Long) extends StateStore with Logging { + + /** Trait and classes representing the internal state of the store */ + trait STATE + + case object LOADED extends STATE + + case object UPDATING extends STATE + + case object COMMITTED extends STATE + + case object ABORTED extends STATE + + private val newVersion = version + 1 + @volatile private var state: STATE = LOADED + private val finalDeltaFile: Path = deltaFile(baseDir, newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream, sparkConf) + + override def id: StateStoreId = RocksDbStateStoreProvider.this.stateStoreId + + var rocksDbWriteInstance: OptimisticTransactionDbInstance = null + + /* + * numEntriesInDb and bytesUsedByDb are estimated value + * due to the nature of RocksDB implementation. + * see https://github.com/facebook/rocksdb/wiki/RocksDB-FAQ for more details + */ + var numEntriesInDb: Long = 0L + var bytesUsedByDb: Long = 0L + + private def initTransaction(): Unit = { + if (state == LOADED && rocksDbWriteInstance == null) { + logDebug(s"Creating Transactional DB for batch $version") + rocksDbWriteInstance = + new OptimisticTransactionDbInstance(keySchema, valueSchema, newVersion.toString) + rocksDbWriteInstance.open(rocksDbPath) + rocksDbWriteInstance.startTransactions() + state = UPDATING + } + } + + override def get(key: UnsafeRow): UnsafeRow = { + initTransaction() + rocksDbWriteInstance.get(key) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + initTransaction() + require(state == UPDATING, s"Cannot put after already committed or aborted") + val keyCopy = key.copy() + val valueCopy = value.copy() + rocksDbWriteInstance.put(keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) + } + + override def remove(key: UnsafeRow): Unit = { + initTransaction() + require(state == UPDATING, "Cannot remove after already committed or aborted") + rocksDbWriteInstance.remove(key) + writeRemoveToDeltaFile(compressedStream, key) + } + + override def getRange( + start: Option[UnsafeRow], + end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + require(state == UPDATING, "Cannot getRange after already committed or aborted") + iterator() + } + + /** Commit all the updates that have been made to the store, and return the new version. 
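+ * Commit is implemented by committing the underlying RocksDB transaction, taking a
+ * checkpoint of the database for the new version, and finalizing the delta file that
+ * was written alongside the updates.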
*/ + override def commit(): Long = { + initTransaction() + require(state == UPDATING, s"Cannot commit after already committed or aborted") + try { + synchronized { + rocksDbWriteInstance.commit(Some(getCheckpointPath(newVersion))) + finalizeDeltaFile(compressedStream) + } + state = COMMITTED + numEntriesInDb = rocksDbWriteInstance.getApproxEntriesInDb() + bytesUsedByDb = numEntriesInDb * (keySchema.defaultSize + valueSchema.defaultSize) + newVersion + } catch { + case NonFatal(e) => + throw new IllegalStateException(s"Error committing version $newVersion into $this", e) + } finally { + storeMap.remove(version) + close() + } + } + + /* + * Abort all the updates made on this store. This store will not be usable any more. + */ + override def abort(): Unit = { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + try { + if (state == UPDATING) { + state = ABORTED + synchronized { + rocksDbWriteInstance.abort() + cancelDeltaFile(compressedStream, deltaFileStream) + } + logInfo(s"Aborted version $newVersion for $this") + } else { + state = ABORTED + } + } catch { + case NonFatal(e) => + throw new IllegalStateException(s"Error aborting version $newVersion into $this", e) + } finally { + storeMap.remove(version) + close() + } + } + + def close(): Unit = { + if (rocksDbWriteInstance != null) { + rocksDbWriteInstance.close() + rocksDbWriteInstance = null + } + } + + /* + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. + */ + override def iterator(): Iterator[UnsafeRowPair] = { + state match { + case UPDATING => + logDebug("state = updating using transaction DB") + // We need to use current db to read uncommitted transactions + rocksDbWriteInstance.iterator(closeDbOnCompletion = false) + + case LOADED | ABORTED => + // use check-pointed db for previous version + logDebug(s"state = loaded/aborted using check-pointed DB with version $version") + if (version == 0) { + Iterator.empty + } else { + val path = getCheckpointPath(version) + val r: RocksDbInstance = + new RocksDbInstance(keySchema, valueSchema, version.toString) + r.open(path, readOnly = true) + r.iterator(closeDbOnCompletion = true) + } + case COMMITTED => + logDebug(s"state = committed using check-pointed DB with version $newVersion") + // use check-pointed db for current updated version + val path = getCheckpointPath(newVersion) + val r: RocksDbInstance = + new RocksDbInstance(keySchema, valueSchema, newVersion.toString) + r.open(path, readOnly = true) + r.iterator(closeDbOnCompletion = true) + + case _ => Iterator.empty + } + } + + override def metrics: StateStoreMetrics = { + val metricsFromProvider: Map[String, Long] = getMetricsForProvider() + val customMetrics = metricsFromProvider.flatMap { + case (name, value) => + // just allow searching from list cause the list is small enough + supportedCustomMetrics.find(_.name == name).map(_ -> value) + } + StateStoreMetrics(Math.max(numEntriesInDb, 0), Math.max(bytesUsedByDb, 0), customMetrics) + } + + /* + * Whether all updates have been committed + */ + override def hasCommitted: Boolean = { + state == COMMITTED + } + + override def toString(): String = { + s"RocksDbStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" + } + + } + + /* + * Initialize the provider with more 
contextual information from the SQL operator. + * This method will be called first after creating an instance of the StateStoreProvider by + * reflection. + * + * @param stateStoreId Id of the versioned StateStores that this provider will generate + * @param keySchema Schema of keys to be stored + * @param valueSchema Schema of value to be stored + * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by + * which the StateStore implementation could index the data. + * @param storeConfs Configurations used by the StateStores + * @param hadoopConf Hadoop configuration that could be used by StateStore + * to save state data + */ + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyIndexOrdinal: Option[Int], // for sorting the data by their keys + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit = { + this.stateStoreId_ = stateStoreId + this.keySchema = keySchema + this.valueSchema = valueSchema + this.storeConf = storeConfs + this.hadoopConf = hadoopConf + this.numberOfVersionsToRetain = storeConfs.maxVersionsToRetainInMemory + fm.mkdirs(baseDir) + this.localDir = storeConfs.confs + .getOrElse( + "spark.sql.streaming.stateStore.rocksDb.localDir", + Utils.createTempDir().getAbsolutePath) + } + + /* + * Return the id of the StateStores this provider will generate. + * Should be the same as the one passed in init(). + */ + override def stateStoreId: StateStoreId = stateStoreId_ + + /* + * Called when the provider instance is unloaded from the executor + */ + override def close(): Unit = { + storeMap.values.asScala.foreach(_.close) + storeMap.clear() + } + + private val storeMap = new util.HashMap[Long, RocksDbStateStore]() + + /* + * Optional custom metrics that the implementation may want to report. + * + * @note The StateStore objects created by this provider must report the same custom metrics + * (specifically, same names) through `StateStore.metrics`. + */ + // TODO + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { + Nil + } + + override def toString(): String = { + s"RocksDbStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" + } + + def getMetricsForProvider(): Map[String, Long] = synchronized { + Map.empty[String, Long] + } + + /* + * Return an instance of [[StateStore]] representing state data of the given version + */ + override def getStore(version: Long): StateStore = synchronized { + logInfo(s"get Store for version $version") + require(version >= 0, "Version cannot be less than 0") + if (storeMap.containsKey(version)) { + storeMap.get(version) + } else { + val store = createStore(version) + storeMap.put(version, store) + store + } + } + + private def createStore(version: Long): RocksDbStateStore = { + val newStore = new RocksDbStateStore(version) + if (version > 0) { + // load the data into the rocksDB + logInfo( + s"Loading state into the db for $version and partition ${stateStoreId_.partitionId}") + loadIntoRocksDB(version) + } + newStore + } + + private def loadIntoRocksDB(version: Long): Unit = { + /* + 1. Get last available/committed Rocksdb version in local folder + 2. If last committed version = version, we already have loaded rocksdb state. + 3. If last committed version = version - 1, + we have to apply delta for version in the existing rocksdb + 4. 
Otherwise we have to recreate a new rocksDB store by using Snapshots/Delta + */ + val (_, elapsedMs) = Utils.timeTakenMs { + var lastAvailableVersion = getLastCommittedVersion() + if (lastAvailableVersion == -1L || lastAvailableVersion <= version - 2) { + // Destroy existing DB so that we can reconstruct it using snapshot and delta files + RocksDbInstance.destroyDB(rocksDbPath) + var lastAvailableSnapShotVersion: Long = version + 1 + // load from snapshot + var found = false + while (!found && lastAvailableSnapShotVersion > 0) { + try { + lastAvailableSnapShotVersion = lastAvailableSnapShotVersion - 1 + found = loadSnapshotFile(lastAvailableSnapShotVersion) + logDebug( + s"Snapshot for version $lastAvailableSnapShotVersion " + + "and partition ${stateStoreId_.partitionId}: found = $found") + } catch { + case e: Exception => + logError(s"$e while reading snapshot file") + throw e + } + } + lastAvailableVersion = lastAvailableSnapShotVersion + } + if (lastAvailableVersion < version) { + applyDelta(version, lastAvailableVersion) + } + } + logInfo( + s"Loading state for $version and partition ${stateStoreId_.partitionId} took $elapsedMs ms.") + } + + private def getLastCommittedVersion(): Long = { + val f = new File(rocksDbPath, RocksDbInstance.COMMIT_FILE_NAME) + if (f.exists()) { + try { + val fileContents = Source.fromFile(f.getAbsolutePath).getLines.mkString + return fileContents.toLong + } catch { + case e: Exception => + logWarning("Exception while reading committed file") + } + } + return -1L + } + + private def loadSnapshotFile(version: Long): Boolean = { + val fileToRead = snapshotFile(baseDir, version) + if (version == 0 || !fm.exists(fileToRead)) { + return false + } + val versionTempPath = getTempPath(version) + val tmpLocDir: File = new File(versionTempPath) + val tmpLocFile: File = new File(s"${versionTempPath}.tar") + try { + logInfo(s"Will download $fileToRead at location ${tmpLocFile.toString()}") + if (downloadFile(fm, fileToRead, new Path(tmpLocFile.getAbsolutePath), sparkConf)) { + FileUtility.extractTarFile(tmpLocFile.getAbsolutePath, versionTempPath) + if (!tmpLocDir.list().exists(_.endsWith(".sst"))) { + logWarning("Snapshot files are corrupted") + throw new IOException( + s"Error reading snapshot file $fileToRead of $this:" + + s" No SST files found") + } + FileUtils.moveDirectory(tmpLocDir, new File(rocksDbPath)) + true + } else { + false + } + } catch { + case e: Exception => + logError(s"Exception while loading snapshot file $e") + throw e + } finally { + if (tmpLocFile.exists()) { + tmpLocFile.delete() + } + FileUtils.deleteDirectory(tmpLocDir) + } + } + + private def applyDelta(version: Long, lastAvailableVersion: Long): Unit = { + var rocksDbWriteInstance: OptimisticTransactionDbInstance = null + try { + rocksDbWriteInstance = + new OptimisticTransactionDbInstance(keySchema, valueSchema, version.toString) + rocksDbWriteInstance.open(rocksDbPath) + rocksDbWriteInstance.startTransactions() + // Load all the deltas from the version after the last available + // one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. 
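+ // Replay the write-ahead log: every <version>.delta file after the last available
+ // version is decompressed and its put/remove records are re-applied inside a single
+ // RocksDB transaction, which is committed (and checkpointed) once all deltas are applied.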
+ for (deltaVersion <- (lastAvailableVersion + 1) to version) { + val fileToRead = deltaFile(baseDir, deltaVersion) + updateFromDeltaFile( + fm, + fileToRead, + keySchema, + valueSchema, + rocksDbWriteInstance, + sparkConf) + logInfo(s"Read delta file for version $version of $this from $fileToRead") + } + rocksDbWriteInstance.commit(Some(getCheckpointPath(version))) + } catch { + case e: Exception => + logError(s"Exception while loading state ${e.getMessage}") + if (rocksDbWriteInstance != null) { + rocksDbWriteInstance.abort() + } + throw e + } finally { + if (rocksDbWriteInstance != null) { + rocksDbWriteInstance.close() + } + } + } + + /** Optional method for providers to allow for background maintenance (e.g. compactions) */ + override def doMaintenance(): Unit = { + try { + val (files: Seq[WALUtils.StoreFile], e1) = Utils.timeTakenMs(fetchFiles(fm, baseDir)) + logDebug(s"fetchFiles() took $e1 ms.") + doSnapshot(files) + cleanup(files) + cleanRocksDBCheckpoints(files) + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $this") + } + } + + private def doSnapshot(files: Seq[WALUtils.StoreFile]): Unit = { + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { + val dbPath = getCheckpointPath(lastVersion) + val snapShotFileName = s"{getTempPath(lastVersion)}.snapshot" + val f = new File(snapShotFileName) + try { + val (_, t1) = Utils.timeTakenMs { + FileUtility.createTarFile(dbPath, snapShotFileName) + val targetFile = snapshotFile(baseDir, lastVersion) + uploadFile(fm, new Path(snapShotFileName), targetFile, sparkConf) + } + logInfo(s"Creating snapshot file for ${stateStoreId_.partitionId} took $t1 ms.") + } catch { + case e: Exception => + logError(s"Exception while creating snapshot $e") + throw e + } finally { + f.delete() // delete the tarball + } + } + } + } + + /* + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. 
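+ * The number of versions kept is bounded by the store's minVersionsToRetain setting;
+ * files backing older versions are removed from both the checkpoint location and the
+ * local RocksDB directory.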
+ */ + private[state] def cleanup(files: Seq[WALUtils.StoreFile]): Unit = { + try { + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain + if (earliestVersionToRetain > 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + val filesToDelete = files.filter(_.version < earliestFileToRetain.version) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + val file = new File(rocksDbPath, f.version.toString) + if (file.exists()) { + file.delete() + } + } + } + logDebug(s"deleting files took $e2 ms.") + logInfo( + s"Deleted files older than ${earliestFileToRetain.version} for $this: " + + filesToDelete.mkString(", ")) + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this", e) + } + } + + private def cleanRocksDBCheckpoints(files: Seq[WALUtils.StoreFile]): Unit = { + try { + val (_, e2) = Utils.timeTakenMs { + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain + if (earliestVersionToRetain > 0) { + new File(getCheckpointPath(earliestVersionToRetain)).getParentFile + .listFiles(new FileFilter { + def accept(f: File): Boolean = { + try { + f.getName.toLong < earliestVersionToRetain + } catch { + case _: NumberFormatException => false + } + } + }) + .foreach(p => RocksDbInstance.destroyDB(p.getAbsolutePath)) + logInfo( + s"Deleted rocksDB checkpoints older than ${earliestVersionToRetain} for $this: ") + } + } + } + logDebug(s"deleting rocksDB checkpoints took $e2 ms.") + } catch { + case NonFatal(e) => logWarning(s"Error cleaning up files for $this", e) + } + } + + // Used only for unit tests + private[sql] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { + val versionsInFiles = fetchFiles(fm, baseDir).map(_.version).toSet + if (versionsInFiles.nonEmpty) { + val maxVersion = versionsInFiles.max + if (maxVersion > 0) { + loadIntoRocksDB(maxVersion) + val r: RocksDbInstance = new RocksDbInstance(keySchema, valueSchema, maxVersion.toString) + try { + r.open(rocksDbPath, readOnly = true) + return r.iterator(false) + } catch { + case e: Exception => + logWarning(s"Exception ${e.getMessage} while getting latest Iterator") + } + } + } + Iterator.empty + } + + private[sql] def getLocalDir: String = localDir + + private[sql] lazy val rocksDbPath: String = { + getPath("db") + } + + private def getCheckpointPath(version: Long): String = { + getPath("checkpoint", Some(version.toString)) + } + + private def getTempPath(version: Long): String = { + getPath("tmp", Some(version.toString)) + } + + private def getPath(subFolderName: String, version: Option[String] = None): String = { + val checkpointRootLocationPath = new Path(stateStoreId.checkpointRootLocation) + + val dirPath = new Path( + localDir, + new Path( + new Path( + subFolderName, + checkpointRootLocationPath.getName + "_" + checkpointRootLocationPath.hashCode()), + new Path(stateStoreId_.operatorId.toString, stateStoreId_.partitionId.toString))) + + val f: File = new File(dirPath.toString) + if (!f.exists() && !f.mkdirs()) { + throw new IllegalStateException(s"Couldn't create directory ${dirPath.toString}") + } + + if (version.isEmpty) { + dirPath.toString + } else { + new Path(dirPath, version.get).toString + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala new file mode 100644 index 
0000000000000..1da7e8d8bbeed --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io._ +import java.util.Locale + +import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, FSError, Path} +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.io.LZ4CompressionCodec +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream +import org.apache.spark.sql.types.StructType + +object WALUtils { + + case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + private val EOF_MARKER = -1 + + /** Files needed to recover the given version of the store */ + def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + }.toList + require( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version for $this: $deltaFiles") + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + /** Fetch all the files that back the store */ + def fetchFiles(fm: CheckpointFileManager, baseDir: Path): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fm.list(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase(Locale.ROOT) match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + // logWarning(s"Could not identify file $path for $this") + } + } + } + val storeFiles = 
versionToFiles.values.toSeq.sortBy(_.version) + storeFiles + } + + def compressStream(outputStream: DataOutputStream, sparkConf: SparkConf): DataOutputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream) + new DataOutputStream(compressed) + } + + def decompressStream(inputStream: DataInputStream, sparkConf: SparkConf): DataInputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream) + new DataInputStream(compressed) + } + + def writeUpdateToDeltaFile(output: DataOutputStream, key: UnsafeRow, value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + + def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(EOF_MARKER) + } + + def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(EOF_MARKER) // Write this magic number to signify end of file + output.close() + } + + def updateFromDeltaFile( + fm: CheckpointFileManager, + fileToRead: Path, + keySchema: StructType, + valueSchema: StructType, + newRocksDb: OptimisticTransactionDbInstance, + sparkConf: SparkConf): Unit = { + var input: DataInputStream = null + val sourceStream = try { + fm.open(fileToRead) + } catch { + case f: FileNotFoundException => + throw new IllegalStateException( + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist", + f) + } + try { + input = decompressStream(sourceStream, sparkConf) + var eof = false + + while (!eof) { + val keySize = input.readInt() + if (keySize == EOF_MARKER) { + eof = true + } else if (keySize < 0) { + newRocksDb.abort + newRocksDb.close() + throw new IOException( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + newRocksDb.remove(key = keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + // If valueSize in existing file is not multiple of 8, floor it to multiple of 8. + // This is a workaround for the following: + // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in + // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data + valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) + newRocksDb.put(keyRow, valueRow) + } + } + } + } finally { + if (input != null) input.close() + } + } + + /* + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. 
Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. + } + } + + def uploadFile( + fm: CheckpointFileManager, + sourceFile: Path, + targetFile: Path, + sparkConf: SparkConf): Unit = { + var output: CancellableFSDataOutputStream = null + var in: BufferedInputStream = null + try { + in = new BufferedInputStream(new FileInputStream(sourceFile.toString)) + output = fm.createAtomic(targetFile, overwriteIfPossible = true) + val buffer = new Array[Byte](1024) + var len = in.read(buffer) + while (len > 0) { + output.write(buffer, 0, len) + len = in.read(buffer) + } + output.close() + } catch { + case e: Throwable => + if (output != null) output.cancel() + throw e + } finally { + if (in != null) { + in.close() + } + } + } + + def downloadFile( + fm: CheckpointFileManager, + sourceFile: Path, + targetFile: Path, + sparkConf: SparkConf): Boolean = { + var in: FSDataInputStream = null + var output: BufferedOutputStream = null + try { + in = fm.open(sourceFile) + output = new BufferedOutputStream(new FileOutputStream(targetFile.toString)) + val buffer = new Array[Byte](1024) + var eof = false + while (!eof) { + val len = in.read(buffer) + if (len > 0) { + output.write(buffer, 0, len) + } else { + eof = true + } + } + output.close() + } catch { + case e: Throwable => + new File(targetFile.toString).delete() + throw e + } finally { + output.close() + if (in != null) { + in.close() + } + } + return true + } + + def deltaFile(baseDir: Path, version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + def snapshotFile(baseDir: Path, version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala new file mode 100644 index 0000000000000..435b5830c3466 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
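For reference, the delta files written by WALUtils above are a length-prefixed record stream (LZ4-compressed in the real path); a minimal sketch of producing two records into an in-memory stream, with the key/value rows built from the same schemas the test suite uses:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Build one key/value UnsafeRow pair ("a" -> 1).
val keyProj = UnsafeProjection.create(Array[DataType](StringType))
val valueProj = UnsafeProjection.create(Array[DataType](IntegerType))
val key: UnsafeRow = keyProj(InternalRow(UTF8String.fromString("a"))).copy()
val value: UnsafeRow = valueProj(InternalRow(1)).copy()

val bytes = new ByteArrayOutputStream()
val out = new DataOutputStream(bytes)              // compression omitted for clarity
WALUtils.writeUpdateToDeltaFile(out, key, value)   // [keyLen][key][valueLen][value]
WALUtils.writeRemoveToDeltaFile(out, key)          // [keyLen][key][-1]
WALUtils.finalizeDeltaFile(out)                    // [-1] end-of-file marker, then close
```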
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.util.UUID + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually.{eventually, timeout} +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv} +import org.apache.spark.LocalSparkContext.withSpark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.Utils + +class RocksDbStateStoreSuite + extends StateStoreSuiteBase[RocksDbStateStoreProvider] + with BeforeAndAfter + with PrivateMethodTester { + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] + + import StateStoreCoordinatorSuite._ + import StateStoreTestsHelper._ + + val keySchema = StructType(Seq(StructField("key", StringType, true))) + val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + before { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + after { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + def updateVersionTo( + provider: StateStoreProvider, + currentVersion: Int, + targetVersion: Int): Int = { + var newCurrentVersion = currentVersion + for (i <- newCurrentVersion until targetVersion) { + newCurrentVersion = incrementVersion(provider, i) + } + require(newCurrentVersion === targetVersion) + newCurrentVersion + } + + def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = { + val store = provider.getStore(currentVersion) + put(store, "a", currentVersion + 1) + store.commit() + currentVersion + 1 + } + + def checkLoadedVersions( + rocksDbWriteInstance: RocksDbInstance, + count: Int, + earliestKey: Long, + latestKey: Long): Unit = { + assert(rocksDbWriteInstance.iterator(false).length === count) + } + + def checkVersion( + rocksDbWriteInstance: RocksDbInstance, + version: Long, + expectedData: Map[String, Int]): Unit = { + + val originValueMap = rocksDbWriteInstance + .iterator(false) + .map { row => + rowToString(row.key) -> rowToInt(row.value) + } + .toMap[String, Int] + + assert(originValueMap === expectedData) + } + + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a") === None) + assert(store.iterator().isEmpty) + + // Verify state after updating + put(store, "a", 1) + assert(get(store, "a") === Some(1)) + + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + put(store, "aa", 3) + remove(store, _.startsWith("a")) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + 
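+ // The committed rows must also be visible when the latest version is re-read from
+ // the provider's files, not just through the in-memory store handle.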
assert(getLatestData(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = newStoreProvider(store.id, provider.getLocalDir) + val reloadedStore = reloadedProvider.getStore(1) + put(reloadedStore, "c", 4) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) + assert(getData(provider, version = 1) === Set("b" -> 2)) + } + + test("snapshotting") { + val provider = + newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) + + var currentVersion = 0 + + currentVersion = updateVersionTo(provider, currentVersion, 2) + require(getData(provider) === Set("a" -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getData(provider) === Set("a" -> 2)) + + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + currentVersion = updateVersionTo(provider, currentVersion, 6) + require(getData(provider) === Set("a" -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + + val snapshotVersion = + (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getData(provider) === Set("a" -> 6), + "snapshotting messed up the data of the final version") + + // After version 20, snapshotting should generate newer snapshot files + currentVersion = updateVersionTo(provider, currentVersion, 20) + require(getData(provider) === Set("a" -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + + val latestSnapshotVersion = + (0 to 20).filter(version => fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) + assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data") + } + + test("cleaning") { + val provider = + newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) + + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + put(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowsToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of versions should be retrievable + assert(getData(provider, 20) === Set("a" -> 20)) + assert(getData(provider, 19) === Set("a" -> 19)) + } + + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + val conf = new Configuration() + conf.set("fs.fake.impl", 
classOf[RenameLikeHDFSFileSystem].getName) + conf.set("fs.defaultFS", "fake:///") + + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf) + provider.getStore(0).commit() + provider.getStore(0).commit() + + // Verify we don't leak temp files + val tempFiles = FileUtils + .listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true) + .asScala + .filter(_.getName.startsWith("temp-")) + assert(tempFiles.isEmpty) + } + + test("corrupted file handling") { + val provider = + newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + put(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10) + .find(version => fileExists(provider, version, isSnapshot = true)) + .getOrElse(fail("snapshot file not found")) + + // Corrupt snapshot file and verify that it throws error + provider.close() + assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + RocksDbInstance.destroyDB(provider.rocksDbPath) + + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + provider.close() + RocksDbInstance.destroyDB(provider.rocksDbPath) + getData(provider, snapshotVersion) + } + + // Corrupt delta file and verify that it throws error + provider.close() + RocksDbInstance.destroyDB(provider.rocksDbPath) + assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + + corruptFile(provider, snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + provider.close() + RocksDbInstance.destroyDB(provider.rocksDbPath) + getData(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + provider.close() + RocksDbInstance.destroyDB(provider.rocksDbPath) + getData(provider, snapshotVersion - 1) + } + } + + test("StateStore.get") { + quietly { + val dir = newDir() + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) + val sqlConf = new SQLConf + sqlConf.setConfString( + SQLConf.STATE_STORE_PROVIDER_CLASS.key, + "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider") + val storeConf = new StateStoreConf(sqlConf) + assert( + storeConf.providerClass === + "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider") + val hadoopConf = new Configuration() + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf) + } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + + intercept[IllegalStateException] { + StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + } + + // Increase version of the store and try to get again + val store0 = StateStore.get(storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) + assert(store0.version === 0) + put(store0, "a", 1) + store0.commit() + + val store1 = StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + assert(store1.version === 1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + + // Verify that you can also load older version + val store0reloaded = + StateStore.get(storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) + assert(store0reloaded.version === 0) + 
assert(rowsToSet(store0reloaded.iterator()) === Set.empty) + + // Verify that you can remove the store and still reload and use it + StateStore.unload(storeId) + assert(!StateStore.isLoaded(storeId)) + + val store1reloaded = + StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + assert(store1reloaded.version === 1) + put(store1reloaded, "a", 2) + assert(store1reloaded.commit() === 2) + assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2)) + } + } + + test("maintenance") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + // Make maintenance thread do snapshots and cleanups very fast + .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly' + // fails to talk to the StateStoreCoordinator and unloads all the StateStores + .set("spark.rpc.numRetries", "1") + val opId = 0 + val dir = newDir() + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) + val sqlConf = new SQLConf() + sqlConf.setConfString( + SQLConf.STATE_STORE_PROVIDER_CLASS.key, + "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider") + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + val storeConf = StateStoreConf(sqlConf) + val hadoopConf = new Configuration() + val provider = newStoreProvider(storeProviderId.storeId) + + var latestStoreVersion = 0 + + def generateStoreVersions() { + for (i <- 1 to 20) { + val store = StateStore.get( + storeProviderId, + keySchema, + valueSchema, + None, + latestStoreVersion, + storeConf, + hadoopConf) + put(store, "a", i) + store.commit() + latestStoreVersion += 1 + } + } + + val timeoutDuration = 60 seconds + + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running") + + // Generate sufficient versions of store for snapshots + generateStoreVersions() + + eventually(timeout(timeoutDuration)) { + // Store should have been reported to the coordinator + assert( + coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") + + // Background maintenance should clean up and generate snapshots + assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") + + // Some snapshots should have been generated + val snapshotVersions = (1 to latestStoreVersion).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + + // Generate more versions such that there is another snapshot and + // the earliest delta file will be cleaned up + generateStoreVersions() + + // Earliest delta file should get cleaned up + eventually(timeout(timeoutDuration)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + } + + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) + eventually(timeout(timeoutDuration)) { + assert(!StateStore.isLoaded(storeProviderId)) + } + + // Reload the store and verify + StateStore.get( + storeProviderId, + keySchema, + valueSchema, + indexOrdinal = None, + latestStoreVersion, + storeConf, + hadoopConf) + assert(StateStore.isLoaded(storeProviderId)) + + // If some other executor loads the store, then this instance should be unloaded + 
coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") + eventually(timeout(timeoutDuration)) { + assert(!StateStore.isLoaded(storeProviderId)) + } + + // Reload the store and verify + StateStore.get( + storeProviderId, + keySchema, + valueSchema, + indexOrdinal = None, + latestStoreVersion, + storeConf, + hadoopConf) + assert(StateStore.isLoaded(storeProviderId)) + } + } + + // Verify if instance is unloaded if SparkContext is stopped + eventually(timeout(timeoutDuration)) { + require(SparkEnv.get === null) + assert(!StateStore.isLoaded(storeProviderId)) + assert(!StateStore.isMaintenanceRunning) + } + } + } + + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + spark.conf.set( + SQLConf.STATE_STORE_PROVIDER_CLASS.key, + "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]( + 'loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but they should be different objects + assert( + loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + + override def newStoreProvider(): RocksDbStateStoreProvider = { + newStoreProvider(opId = Random.nextInt(), partition = 0) + } + + override def newStoreProvider(storeId: StateStoreId): RocksDbStateStoreProvider = { + newStoreProvider( + storeId.operatorId, + storeId.partitionId, + dir = storeId.checkpointRootLocation) + } + + def newStoreProvider(storeId: StateStoreId, localDir: String): RocksDbStateStoreProvider = { + newStoreProvider( + storeId.operatorId, + storeId.partitionId, + dir = storeId.checkpointRootLocation, + localDir = localDir) + } + + override def getLatestData(storeProvider: RocksDbStateStoreProvider): Set[(String, Int)] = { + getData(storeProvider) + } + + override def getData( + provider: RocksDbStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = newStoreProvider(provider.stateStoreId, provider.getLocalDir) + if 
(version < 0) { + reloadedProvider.latestIterator().map(rowsToStringInt).toSet + } else { + reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet + } + } + + def newStoreProvider( + opId: Long, + partition: Int, + dir: String = newDir(), + localDir: String = newDir(), + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get, + hadoopConf: Configuration = new Configuration): RocksDbStateStoreProvider = { + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + sqlConf.setConfString("spark.sql.streaming.stateStore.rocksDb.localDir", localDir) + val provider = new RocksDbStateStoreProvider + provider.init( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + keyIndexOrdinal = None, + new StateStoreConf(sqlConf), + hadoopConf) + provider + } + + def fileExists( + provider: RocksDbStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def deleteFilesEarlierThanVersion(provider: RocksDbStateStoreProvider, version: Long): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + for (version <- 0 until version.toInt) { + for (isSnapshot <- Seq(false, true)) { + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + if (filePath.exists) filePath.delete() + } + } + } + + def corruptFile( + provider: RocksDbStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.delete() + filePath.createNewFile() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index a84d107f2cbc0..c110c6c0a62f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -198,6 +198,58 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(loadedMaps.size() === 0) } + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a") === None) + assert(store.iterator().isEmpty) + assert(store.metrics.numKeys === 0) + + // Verify state after updating + put(store, "a", 1) + assert(get(store, "a") === Some(1)) + assert(store.metrics.numKeys === 1) + + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + 
put(store, "aa", 3) + assert(store.metrics.numKeys === 3) + remove(store, _.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(getLatestData(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = newStoreProvider(store.id) + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) + put(reloadedStore, "c", 4) + assert(reloadedStore.metrics.numKeys === 2) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) + assert(getData(provider, version = 1) === Set("b" -> 2)) + } + test("snapshotting") { val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) @@ -817,58 +869,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] extends SparkFunSuite { import StateStoreTestsHelper._ - test("get, put, remove, commit, and all data iterator") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) - - val store = provider.getStore(0) - assert(!store.hasCommitted) - assert(get(store, "a") === None) - assert(store.iterator().isEmpty) - assert(store.metrics.numKeys === 0) - - // Verify state after updating - put(store, "a", 1) - assert(get(store, "a") === Some(1)) - assert(store.metrics.numKeys === 1) - - assert(store.iterator().nonEmpty) - assert(getLatestData(provider).isEmpty) - - // Make updates, commit and then verify state - put(store, "b", 2) - put(store, "aa", 3) - assert(store.metrics.numKeys === 3) - remove(store, _.startsWith("a")) - assert(store.metrics.numKeys === 1) - assert(store.commit() === 1) - - assert(store.hasCommitted) - assert(rowsToSet(store.iterator()) === Set("b" -> 2)) - assert(getLatestData(provider) === Set("b" -> 2)) - - // Trying to get newer versions should fail - intercept[Exception] { - provider.getStore(2) - } - intercept[Exception] { - getData(provider, 2) - } - - // New updates to the reloaded store with new version, and does not change old version - val reloadedProvider = newStoreProvider(store.id) - val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.metrics.numKeys === 1) - put(reloadedStore, "c", 4) - assert(reloadedStore.metrics.numKeys === 2) - assert(reloadedStore.commit() === 2) - assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) - assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) - assert(getData(provider, version = 1) === Set("b" -> 2)) - } - test("removing while iterating") { val provider = newStoreProvider()