diff --git a/core/src/main/java/org/apache/spark/io/FileUtility.java b/core/src/main/java/org/apache/spark/io/FileUtility.java
new file mode 100644
index 0000000000000..5e21c6218b2f7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/io/FileUtility.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.io;
+
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
+import org.apache.commons.compress.utils.IOUtils;
+
+import java.io.*;
+
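+/**
+ * Helper methods to create and extract uncompressed tar archives. A possible usage, with
+ * purely illustrative paths:
+ *
+ * <pre>{@code
+ * FileUtility.createTarFile("/tmp/state-db", "/tmp/state-db.tar");
+ * FileUtility.extractTarFile("/tmp/state-db.tar", "/tmp/state-db-restored");
+ * }</pre>
+ */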
+public class FileUtility {
+
+ public static final String ENCODING = "utf-8";
+
+ /**
+ * Extracts the given tar file into a destination directory.
+ *
+ * @param inputTarFileLoc the location of the input tar file
+ * @param destDirLoc the destination directory for the extracted files
+ * @throws IllegalStateException if the input is not a tar file or the extraction fails
+ */
+ public static void extractTarFile(String inputTarFileLoc, String destDirLoc)
+ throws IllegalStateException {
+ File inputFile = new File(inputTarFileLoc);
+ if (!inputTarFileLoc.endsWith(".tar")) {
+ throw new IllegalStateException(String.format(
+ "Input file [%s] should have a .tar extension.", inputTarFileLoc));
+ }
+ File destDir = new File(destDirLoc);
+ if (destDir.exists() && !destDir.delete()) {
+ throw new IllegalStateException(String.format(
+ "Couldn't delete the existing destination directory [%s].", destDirLoc));
+ } else if (!destDir.mkdir()) {
+ throw new IllegalStateException(String.format(
+ "Couldn't create destination directory [%s].", destDirLoc));
+ }
+
+ try (InputStream is = new FileInputStream(inputFile);
+ TarArchiveInputStream debInputStream = new TarArchiveInputStream(is, ENCODING)) {
+ TarArchiveEntry entry;
+ while ((entry = (TarArchiveEntry) debInputStream.getNextEntry()) != null) {
+ final File outputFile = new File(destDirLoc, entry.getName());
+ if (entry.isDirectory()) {
+ if (!outputFile.exists() && !outputFile.mkdirs()) {
+ throw new IllegalStateException(String.format(
+ "Couldn't create directory %s.", outputFile.getAbsolutePath()));
+ }
+ } else {
+ try (OutputStream outputFileStream = new FileOutputStream(outputFile)) {
+ IOUtils.copy(debInputStream, outputFileStream);
+ }
+ }
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException(String.format(
+ "extractTarFile failed with exception %s.", e.getMessage()), e);
+ }
+ }
+
+ /**
+ * Creates a tar file from the contents of a source directory.
+ *
+ * @param source the source directory location
+ * @param destFileLoc the destination of the created tarball
+ * @throws IllegalStateException if the source does not exist or the tarball cannot be created
+ */
+ public static void createTarFile(String source, String destFileLoc)
+ throws IllegalStateException {
+ File f = new File(destFileLoc);
+ if (f.exists() && !f.delete()) {
+ throw new IllegalStateException(String.format(
+ "Couldn't delete the destination file location[%s]", destFileLoc));
+ }
+ File folder = new File(source);
+ if (!folder.exists()) {
+ throw new IllegalStateException(String.format(
+ "Source folder[%s] does not exist", source));
+ }
+
+ try (FileOutputStream fos = new FileOutputStream(destFileLoc);
+ TarArchiveOutputStream tarOs = new TarArchiveOutputStream(fos, ENCODING)) {
+ File[] files = folder.listFiles();
+ for (File file : files) {
+ TarArchiveEntry tarEntry = new TarArchiveEntry(file.getName());
+ tarEntry.setSize(file.length());
+ tarOs.putArchiveEntry(tarEntry);
+ try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(file))) {
+ IOUtils.copy(bis, tarOs);
+ tarOs.closeArchiveEntry();
+ }
+ }
+ tarOs.finish();
+ } catch (IOException e) {
+ throw new IllegalStateException(String.format(
+ "createTarFile failed with exception %s.", e.getMessage()));
+ }
+ }
+
+}
diff --git a/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java b/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java
new file mode 100644
index 0000000000000..a9a21a302eca8
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/io/FileUtilitySuite.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.io;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Tests functionality of {@link FileUtility}
+ */
+public class FileUtilitySuite {
+
+ protected File sourceFolder;
+ protected File destTarLoc;
+ protected File destFolder;
+
+ @Before
+ public void setUp() throws IOException {
+ String tmpDir = System.getProperty("java.io.tmpdir");
+ sourceFolder = Utils.createTempDir(tmpDir, "FileUtilTest-src-" + RandomUtils.nextLong());
+ destFolder = Utils.createTempDir(tmpDir, "FileUtilTest-dest-" + RandomUtils.nextLong());
+ destTarLoc = File.createTempFile("dest-tar", ".tar");
+ }
+
+ @After
+ public void tearDown() {
+ destTarLoc.delete();
+ }
+
+ @Test
+ public void testCreationAndExtraction() throws IllegalStateException, IOException {
+ // Create a temp file in the source folder
+ Assert.assertEquals(0, sourceFolder.listFiles().length);
+ File inputFile = File.createTempFile("source-file", ".tmp", sourceFolder);
+ // Create a byte array of size 1 KB with random bytes
+ byte[] randomBytes = RandomUtils.nextBytes(1 * 1024);
+ FileUtils.writeByteArrayToFile(inputFile, randomBytes);
+
+ // Create the tarball
+ destTarLoc.delete();
+ Assert.assertFalse(destTarLoc.exists());
+ FileUtility.createTarFile(sourceFolder.toString(), destTarLoc.getAbsolutePath());
+ Assert.assertTrue(destTarLoc.exists());
+
+ // Extract the tarball
+ Assert.assertEquals(0, destFolder.listFiles().length);
+ FileUtility.extractTarFile(destTarLoc.getAbsolutePath(), destFolder.getAbsolutePath());
+
+ // Verify that the extraction was successful
+ Assert.assertTrue(destFolder.exists());
+ Assert.assertEquals(1, destFolder.listFiles().length);
+ Assert.assertArrayEquals(randomBytes, FileUtils.readFileToByteArray(destFolder.listFiles()[0]));
+ }
+
+}
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 2da4c9e44b29e..97b713cd1403f 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -175,6 +175,7 @@ parquet-jackson-1.10.1.jar
protobuf-java-2.5.0.jar
py4j-0.10.8.1.jar
pyrolite-4.30.jar
+rocksdbjni-6.2.2.jar
scala-collection-compat_2.12-2.1.1.jar
scala-compiler-2.12.10.jar
scala-library-2.12.10.jar
diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2
index 2226baeadfba1..c0ab22d32abcb 100644
--- a/dev/deps/spark-deps-hadoop-3.2
+++ b/dev/deps/spark-deps-hadoop-3.2
@@ -194,6 +194,7 @@ protobuf-java-2.5.0.jar
py4j-0.10.8.1.jar
pyrolite-4.30.jar
re2j-1.1.jar
+rocksdbjni-6.2.2.jar
scala-collection-compat_2.12-2.1.1.jar
scala-compiler-2.12.10.jar
scala-library-2.12.10.jar
diff --git a/pom.xml b/pom.xml
index 2396c5168b166..99514754a1c3b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -193,6 +193,7 @@
1.1
2.52.0
2.22
+    <rocksdb.version>6.2.2</rocksdb.version>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 02ed6f8adaa62..2112543cc9b82 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -147,6 +147,11 @@
      <artifactId>mockito-core</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.rocksdb</groupId>
+      <artifactId>rocksdbjni</artifactId>
+      <version>${rocksdb.version}</version>
+    </dependency>
target/scala-${scala.binary.version}/classes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala
new file mode 100644
index 0000000000000..2b5e6751317d0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbInstance.scala
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.BufferedWriter
+import java.io.File
+import java.io.FileWriter
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.rocksdb._
+import org.rocksdb.RocksDB
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
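+/**
+ * A thin wrapper around a RocksDB instance used by the RocksDB-backed state store. Keys and
+ * values are stored as the raw bytes of [[UnsafeRow]]s described by the given key and value
+ * schemas; `version` identifies the state store version served by this instance.
+ */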
+class RocksDbInstance(keySchema: StructType, valueSchema: StructType, version: String)
+ extends Logging {
+
+ import RocksDbInstance._
+ RocksDB.loadLibrary()
+
+ protected var db: RocksDB = null
+ protected var dbPath: String = _
+ protected val readOptions: ReadOptions = new ReadOptions()
+ protected val writeOptions: WriteOptions = new WriteOptions()
+ protected val tableOptions = new BlockBasedTableConfig
+ protected val options: Options = new Options()
+
+ private def isOpen(): Boolean = {
+ db != null
+ }
+
+ def open(path: String, readOnly: Boolean): Unit = {
+ require(db == null, "Another rocksDb instance is already active")
+ try {
+ setOptions()
+ db = if (readOnly) {
+ options.setCreateIfMissing(false)
+ RocksDB.openReadOnly(options, path)
+ } else {
+ options.setCreateIfMissing(true)
+ RocksDB.open(options, path)
+ }
+ dbPath = path
+ } catch {
+ case e: Throwable =>
+ throw new IllegalStateException(
+ s"Error while creating rocksDb instance ${e.getMessage}",
+ e)
+ }
+ }
+
+ def get(key: UnsafeRow): UnsafeRow = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ Option(db.get(readOptions, key.getBytes)) match {
+ case Some(valueInBytes) =>
+ val value = new UnsafeRow(valueSchema.fields.length)
+ value.pointTo(valueInBytes, valueInBytes.length)
+ value
+ case None => null
+ }
+ }
+
+ def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ db.put(key.getBytes, value.getBytes)
+ }
+
+ def remove(key: UnsafeRow): Unit = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ db.delete(key.getBytes)
+ }
+
+ def commit(checkPointPath: Option[String] = None): Unit = {
+ checkPointPath.foreach(f => createCheckpoint(db, f))
+ }
+
+ def abort(): Unit = {
+ // no-op
+ }
+
+ def close(): Unit = {
+ logDebug("Closing the db")
+ try {
+ db.close()
+ } finally {
+ db = null
+ options.close()
+ readOptions.close()
+ writeOptions.close()
+ }
+ }
+
+ def iterator(closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ Option(db.getSnapshot) match {
+ case Some(snapshot) =>
+ val snapshotReadOptions: ReadOptions =
+ new ReadOptions().setSnapshot(snapshot).setFillCache(false)
+ val itr = db.newIterator(snapshotReadOptions)
+ createUnsafeRowPairIterator(itr, snapshotReadOptions, snapshot, closeDbOnCompletion)
+ case None =>
+ Iterator.empty
+ }
+ }
+
+ protected def createUnsafeRowPairIterator(
+ itr: RocksIterator,
+ itrReadOptions: ReadOptions,
+ snapshot: Snapshot,
+ closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = {
+
+ itr.seekToFirst()
+
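+ // The returned iterator owns the snapshot and its read options: the first time hasNext()
+ // returns false, they are released, the RocksIterator is closed and, if requested, the whole
+ // DB is closed as well.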
+ new Iterator[UnsafeRowPair] {
+ @volatile var isClosed = false
+ override def hasNext: Boolean = {
+ if (!isClosed && itr.isValid) {
+ true
+ } else {
+ if (!isClosed) {
+ isClosed = true
+ itrReadOptions.close()
+ db.releaseSnapshot(snapshot)
+ if (closeDbOnCompletion) {
+ close()
+ }
+ itr.close()
+ logDebug(s"read from DB completed")
+ }
+ false
+ }
+ }
+
+ override def next(): UnsafeRowPair = {
+ val keyBytes = itr.key
+ val key = new UnsafeRow(keySchema.fields.length)
+ key.pointTo(keyBytes, keyBytes.length)
+ val valueBytes = itr.value
+ val value = new UnsafeRow(valueSchema.fields.length)
+ value.pointTo(valueBytes, valueBytes.length)
+ itr.next()
+ new UnsafeRowPair(key, value)
+ }
+ }
+ }
+
+ protected def printMemoryStats(db: RocksDB): Unit = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ val usage = MemoryUtil
+ .getApproximateMemoryUsageByType(
+ List(db).asJava,
+ Set(rocksDbLRUCache.asInstanceOf[Cache]).asJava)
+ .asScala
+ val numKeys = db.getProperty(db.getDefaultColumnFamily, "rocksdb.estimate-num-keys")
+ logDebug(s"""
+ | rocksdb.estimate-num-keys = $numKeys
+ | ApproximateMemoryUsageByType = ${usage.toString}
+ | """.stripMargin)
+ }
+
+ protected def printStats: Unit = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ try {
+ val stats = db.getProperty("rocksdb.stats")
+ logInfo(s"Stats = $stats")
+ } catch {
+ case e: Exception =>
+ logWarning("Exception while getting stats", e)
+ }
+ }
+
+ private val dataBlockSize = RocksDbStateStoreConf.blockSizeInKB
+ private val memTableMemoryBudget = RocksDbStateStoreConf.memtableBudgetInMB
+ private val enableStats = RocksDbStateStoreConf.enableStats
+
+ protected def setOptions(): Unit = {
+
+ // Read options
+ readOptions.setFillCache(true)
+
+ // Write options
+ writeOptions.setSync(false)
+ writeOptions.setDisableWAL(true)
+
+ /*
+ Table configs:
+ - Use partitioned index filters
+ https://github.com/facebook/rocksdb/wiki/Partitioned-Index-Filters
+ - Use format version 4
+ https://rocksdb.org/blog/2019/03/08/format-version-4.html
+ */
+ tableOptions
+ .setBlockSize(dataBlockSize * 1024)
+ .setFormatVersion(4)
+ .setDataBlockIndexType(DataBlockIndexType.kDataBlockBinaryAndHash)
+ .setBlockCache(rocksDbLRUCache)
+ .setFilterPolicy(new BloomFilter(10, false))
+ .setPinTopLevelIndexAndFilter(false) // Don't pin anything in the cache
+ .setIndexType(IndexType.kTwoLevelIndexSearch)
+ .setPartitionFilters(true)
+
+ options
+ .setTableFormatConfig(tableOptions)
+ .optimizeLevelStyleCompaction(memTableMemoryBudget * 1024 * 1024)
+ .setBytesPerSync(1048576)
+ .setMaxOpenFiles(5000)
+ .setIncreaseParallelism(4)
+
+ if (enableStats) {
+ options
+ .setStatistics(new Statistics())
+ .setStatsDumpPeriodSec(30)
+ }
+ }
+
+ protected def createCheckpoint(rocksDb: RocksDB, dir: String): Unit = {
+ require(isOpen(), "Open rocksDb instance before any operation")
+ val (result, elapsedMs) = Utils.timeTakenMs {
+ val c = Checkpoint.create(rocksDb)
+ val f: File = new File(dir)
+ if (f.exists()) {
+ FileUtils.deleteDirectory(f)
+ }
+ c.createCheckpoint(dir)
+ c.close()
+ }
+ logInfo(s"Creating Checkpoint at $dir took $elapsedMs ms.")
+ }
+}
+
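+/**
+ * A [[RocksDbInstance]] backed by an OptimisticTransactionDB: updates are staged in a
+ * transaction that is either committed (optionally writing a checkpoint and the commit file)
+ * or rolled back via abort().
+ */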
+class OptimisticTransactionDbInstance(
+ keySchema: StructType,
+ valueSchema: StructType,
+ version: String)
+ extends RocksDbInstance(keySchema, valueSchema, version) {
+ import RocksDbInstance._
+ RocksDB.loadLibrary()
+
+ private var otdb: OptimisticTransactionDB = null
+ private var txn: Transaction = null
+
+ private def isOpen(): Boolean = {
+ otdb != null
+ }
+
+ def open(path: String): Unit = {
+ open(path, false)
+ }
+
+ override def open(path: String, readOnly: Boolean): Unit = {
+ require(otdb == null, "Another OptimisticTransactionDbInstance instance is already active")
+ require(!readOnly, "Cannot open OptimisticTransactionDbInstance in read-only mode")
+ try {
+ setOptions()
+ options.setCreateIfMissing(true)
+ otdb = OptimisticTransactionDB.open(options, path)
+ db = otdb.getBaseDB
+ dbPath = path
+ } catch {
+ case e: Throwable =>
+ throw new IllegalStateException(
+ s"Error while creating OptimisticTransactionDb instance" +
+ s" ${e.getMessage}",
+ e)
+ }
+ }
+
+ def startTransactions(): Unit = {
+ require(isOpen(), "Open OptimisticTransactionDbInstance before performing any operation")
+ Option(txn) match {
+ case None =>
+ val optimisticTransactionOptions = new OptimisticTransactionOptions()
+ txn = otdb.beginTransaction(writeOptions, optimisticTransactionOptions)
+ txn.setSavePoint()
+ case Some(x) =>
+ throw new IllegalStateException(s"Already started a transaction")
+ }
+ }
+
+ override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+ require(txn != null, "Start Transaction before inserting any key")
+ txn.put(key.getBytes, value.getBytes)
+ }
+
+ override def remove(key: UnsafeRow): Unit = {
+ require(txn != null, "Start Transaction before deleting any key")
+ txn.delete(key.getBytes)
+ }
+
+ override def get(key: UnsafeRow): UnsafeRow = {
+ require(txn != null, "Start Transaction before fetching any key-value")
+ Option(txn.get(readOptions, key.getBytes)) match {
+ case Some(valueInBytes) =>
+ val value = new UnsafeRow(valueSchema.fields.length)
+ value.pointTo(valueInBytes, valueInBytes.length)
+ value
+ case None =>
+ null
+ }
+ }
+
+ override def commit(checkPointPath: Option[String] = None): Unit = {
+ require(txn != null, "Start Transaction before fetching any key-value")
+ try {
+ txn.commit()
+ txn.close()
+ updateVersionInCommitFile()
+ checkPointPath.foreach(f => createCheckpoint(otdb.asInstanceOf[RocksDB], f))
+ } catch {
+ case e: Exception =>
+ log.error(s"Unable to commit the transactions. Error message = ${e.getMessage}")
+ throw e
+ } finally {
+ txn = null
+ }
+ }
+
+ override def abort(): Unit = {
+ require(txn != null, "No Transaction to abort")
+ txn.rollbackToSavePoint()
+ txn.close()
+ txn = null
+ }
+
+ override def close(): Unit = {
+ require(isOpen(), "No DB to close")
+ require(txn == null, "Transaction should be closed before closing the DB connection")
+ printMemoryStats(otdb.asInstanceOf[RocksDB])
+ logDebug("Closing the transaction db")
+ try {
+ otdb.close()
+ db.close()
+ otdb = null
+ db = null
+ } finally {
+ options.close()
+ readOptions.close()
+ writeOptions.close()
+ }
+ }
+
+ override def iterator(closeDbOnCompletion: Boolean): Iterator[UnsafeRowPair] = {
+ require(txn != null, "Transaction is not set")
+ require(
+ !closeDbOnCompletion,
+ "Cannot close a DB without aborting/committing the transactions")
+ val snapshot = db.getSnapshot
+ val readOptions = new ReadOptions()
+ .setSnapshot(snapshot)
+ .setFillCache(false) // for range lookup, we should not fill cache
+ val itr: RocksIterator = txn.getIterator(readOptions)
+ Option(itr) match {
+ case Some(i) =>
+ logDebug(s"creating iterator from a transactional DB")
+ createUnsafeRowPairIterator(i, readOptions, snapshot, false)
+ case None =>
+ Iterator.empty
+ }
+ }
+
+ def getApproxEntriesInDb(): Long = {
+ require(isOpen(), "No DB to find Database Entries")
+ otdb.getProperty("rocksdb.estimate-num-keys").toLong
+ }
+
+ protected def updateVersionInCommitFile(): Unit = {
+ val file = new File(dbPath, COMMIT_FILE_NAME)
+ val bw = new BufferedWriter(new FileWriter(file))
+ try {
+ bw.write(version)
+ } finally {
+ bw.close()
+ }
+ }
+
+}
+
+object RocksDbInstance {
+
+ RocksDB.loadLibrary()
+
+ val COMMIT_FILE_NAME = "commit"
+
+ lazy val rocksDbLRUCache = new LRUCache(RocksDbStateStoreConf.cacheSize * 1024 * 1024, 6, false)
+
+ def destroyDB(path: String): Unit = {
+ val f: File = new File(path)
+ val destroyOptions: Options = new Options()
+ if (f.exists()) {
+ RocksDB.destroyDB(path, destroyOptions)
+ FileUtils.deleteDirectory(f)
+ }
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala
new file mode 100644
index 0000000000000..9d55b6c786b52
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreConf.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config.ConfigBuilder
+
+object RocksDbStateStoreConf {
+
+ private[spark] val ROCKSDB_STATE_STORE_DATA_BLOCK_SIZE =
+ ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.blockSizeInKB")
+ .doc(
+ "The maximum size (in KB) of packed data in a block of a table file. " +
+ "When reading from a table, an entire block is loaded into memory")
+ .intConf
+ .createWithDefault(32)
+
+ private[spark] val ROCKSDB_STATE_STORE_MEMTABLE_BUDGET =
+ ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.memtableBudgetInMB")
+ .doc("The maximum size (in MB) of memory to be used to optimize level style compaction")
+ .intConf
+ .createWithDefault(1024)
+
+ private[spark] val ROCKSDB_STATE_STORE_CACHE_SIZE =
+ ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.cacheSizeInMB")
+ .doc("The maximum size (in MB) of in-memory LRU cache for RocksDB operations")
+ .intConf
+ .createWithDefault(512)
+
+ private[spark] val ROCKSDB_STATE_STORE_ENABLE_STATS =
+ ConfigBuilder("spark.sql.streaming.stateStore.rocksDb.enableDbStats")
+ .doc("Enable statistics for rocksdb for debugging and reporting")
+ .booleanConf
+ .createWithDefault(false)
+
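+ /*
+ * The values below are read once from the SparkConf via SparkEnv and fall back to the defaults
+ * above when no SparkEnv is available (e.g. in some tests). Illustrative override:
+ *
+ * spark-submit --conf spark.sql.streaming.stateStore.rocksDb.blockSizeInKB=64 \
+ * --conf spark.sql.streaming.stateStore.rocksDb.cacheSizeInMB=1024 ...
+ */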
+ val blockSizeInKB: Int = Option(SparkEnv.get)
+ .map(_.conf.get(ROCKSDB_STATE_STORE_DATA_BLOCK_SIZE))
+ .getOrElse(32)
+
+ val memtableBudgetInMB: Int = Option(SparkEnv.get)
+ .map(_.conf.get(ROCKSDB_STATE_STORE_MEMTABLE_BUDGET))
+ .getOrElse(1024)
+
+ val cacheSize: Int = Option(SparkEnv.get)
+ .map(_.conf.get(ROCKSDB_STATE_STORE_CACHE_SIZE))
+ .getOrElse(512)
+
+ val enableStats: Boolean = Option(SparkEnv.get)
+ .map(_.conf.get(ROCKSDB_STATE_STORE_ENABLE_STATS))
+ .getOrElse(false)
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala
new file mode 100644
index 0000000000000..cdc803b860a19
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreProvider.scala
@@ -0,0 +1,646 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io._
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.io.Source
+import scala.util.control.NonFatal
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs._
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.FileUtility
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+/*
+ * An implementation of [[StateStoreProvider]] and [[StateStore]] that uses RocksDB as the
+ * storage engine. In RocksDB, new writes are inserted into a memtable which is flushed to local
+ * storage when it fills up. This improves scalability compared to
+ * [[HDFSBackedStateStoreProvider]]: state data that previously had to fit entirely in executor
+ * memory can now be spread across the memtable and local storage. The data is backed up to an
+ * HDFS-compatible file system just like [[HDFSBackedStateStoreProvider]].
+ *
+ * Fault-tolerance model:
+ * - Every set of updates is written to a delta file before committing.
+ * - The state store is responsible for managing, collapsing and cleaning up delta files.
+ * - Updates are committed to the db atomically.
+ *
+ * Backup model:
+ * - A delta file is written to an HDFS-compatible file system on batch commit.
+ * - The RocksDB state is check-pointed into a separate folder on batch commit.
+ * - A maintenance thread periodically takes a snapshot of the latest check-pointed version of
+ * the RocksDB state and writes it to an HDFS-compatible file system.
+ *
+ * Isolation guarantees:
+ * - Writes are committed in a transaction.
+ * - The writer thread that started the transaction can read all uncommitted updates.
+ * - Any other reader thread cannot read uncommitted updates.
+ */
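+/*
+ * Illustrative usage (assuming the standard SQL conf key for choosing the state store provider):
+ *
+ * spark.conf.set(
+ * "spark.sql.streaming.stateStore.providerClass",
+ * classOf[RocksDbStateStoreProvider].getCanonicalName)
+ */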
+private[sql] class RocksDbStateStoreProvider extends StateStoreProvider with Logging {
+
+ /* Internal fields and methods */
+ @volatile private var stateStoreId_ : StateStoreId = _
+ @volatile private var keySchema: StructType = _
+ @volatile private var valueSchema: StructType = _
+ @volatile private var storeConf: StateStoreConf = _
+ @volatile private var hadoopConf: Configuration = _
+ @volatile private var numberOfVersionsToRetain: Int = _
+ @volatile private var localDir: String = _
+
+ private lazy val baseDir: Path = stateStoreId.storeCheckpointLocation()
+ private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
+ private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+
+ private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
+
+ import WALUtils._
+
+ /** Implementation of [[StateStore]] API which is backed by RocksDB and HDFS */
+ class RocksDbStateStore(val version: Long) extends StateStore with Logging {
+
+ /** Trait and classes representing the internal state of the store */
+ trait STATE
+
+ case object LOADED extends STATE
+
+ case object UPDATING extends STATE
+
+ case object COMMITTED extends STATE
+
+ case object ABORTED extends STATE
+
+ private val newVersion = version + 1
+ @volatile private var state: STATE = LOADED
+ private val finalDeltaFile: Path = deltaFile(baseDir, newVersion)
+ private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true)
+ private lazy val compressedStream = compressStream(deltaFileStream, sparkConf)
+
+ override def id: StateStoreId = RocksDbStateStoreProvider.this.stateStoreId
+
+ var rocksDbWriteInstance: OptimisticTransactionDbInstance = null
+
+ /*
+ * numEntriesInDb and bytesUsedByDb are estimated values
+ * due to the nature of the RocksDB implementation.
+ * See https://github.com/facebook/rocksdb/wiki/RocksDB-FAQ for more details.
+ */
+ var numEntriesInDb: Long = 0L
+ var bytesUsedByDb: Long = 0L
+
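+ // The transactional DB is opened lazily on the first get/put/remove; until then the store
+ // stays in the LOADED state and iterator() serves data from the check-pointed DB of `version`.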
+ private def initTransaction(): Unit = {
+ if (state == LOADED && rocksDbWriteInstance == null) {
+ logDebug(s"Creating Transactional DB for batch $version")
+ rocksDbWriteInstance =
+ new OptimisticTransactionDbInstance(keySchema, valueSchema, newVersion.toString)
+ rocksDbWriteInstance.open(rocksDbPath)
+ rocksDbWriteInstance.startTransactions()
+ state = UPDATING
+ }
+ }
+
+ override def get(key: UnsafeRow): UnsafeRow = {
+ initTransaction()
+ rocksDbWriteInstance.get(key)
+ }
+
+ override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+ initTransaction()
+ require(state == UPDATING, s"Cannot put after already committed or aborted")
+ val keyCopy = key.copy()
+ val valueCopy = value.copy()
+ rocksDbWriteInstance.put(keyCopy, valueCopy)
+ writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy)
+ }
+
+ override def remove(key: UnsafeRow): Unit = {
+ initTransaction()
+ require(state == UPDATING, "Cannot remove after already committed or aborted")
+ rocksDbWriteInstance.remove(key)
+ writeRemoveToDeltaFile(compressedStream, key)
+ }
+
+ override def getRange(
+ start: Option[UnsafeRow],
+ end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+ require(state == UPDATING, "Cannot getRange after already committed or aborted")
+ iterator()
+ }
+
+ /** Commit all the updates that have been made to the store, and return the new version. */
+ override def commit(): Long = {
+ initTransaction()
+ require(state == UPDATING, s"Cannot commit after already committed or aborted")
+ try {
+ synchronized {
+ rocksDbWriteInstance.commit(Some(getCheckpointPath(newVersion)))
+ finalizeDeltaFile(compressedStream)
+ }
+ state = COMMITTED
+ numEntriesInDb = rocksDbWriteInstance.getApproxEntriesInDb()
+ bytesUsedByDb = numEntriesInDb * (keySchema.defaultSize + valueSchema.defaultSize)
+ newVersion
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalStateException(s"Error committing version $newVersion into $this", e)
+ } finally {
+ storeMap.remove(version)
+ close()
+ }
+ }
+
+ /*
+ * Abort all the updates made on this store. This store will not be usable any more.
+ */
+ override def abort(): Unit = {
+ // This if statement is to ensure that files are deleted only if there are changes to the
+ // StateStore. We have two StateStores for each task, one which is used only for reading, and
+ // the other used for read+write. We don't want the read-only to delete state files.
+ try {
+ if (state == UPDATING) {
+ state = ABORTED
+ synchronized {
+ rocksDbWriteInstance.abort()
+ cancelDeltaFile(compressedStream, deltaFileStream)
+ }
+ logInfo(s"Aborted version $newVersion for $this")
+ } else {
+ state = ABORTED
+ }
+ } catch {
+ case NonFatal(e) =>
+ throw new IllegalStateException(s"Error aborting version $newVersion into $this", e)
+ } finally {
+ storeMap.remove(version)
+ close()
+ }
+ }
+
+ def close(): Unit = {
+ if (rocksDbWriteInstance != null) {
+ rocksDbWriteInstance.close()
+ rocksDbWriteInstance = null
+ }
+ }
+
+ /*
+ * Get an iterator of all the store data.
+ * This can be called only after committing all the updates made in the current thread.
+ */
+ override def iterator(): Iterator[UnsafeRowPair] = {
+ state match {
+ case UPDATING =>
+ logDebug("state = updating using transaction DB")
+ // We need to use current db to read uncommitted transactions
+ rocksDbWriteInstance.iterator(closeDbOnCompletion = false)
+
+ case LOADED | ABORTED =>
+ // use check-pointed db for previous version
+ logDebug(s"state = loaded/aborted using check-pointed DB with version $version")
+ if (version == 0) {
+ Iterator.empty
+ } else {
+ val path = getCheckpointPath(version)
+ val r: RocksDbInstance =
+ new RocksDbInstance(keySchema, valueSchema, version.toString)
+ r.open(path, readOnly = true)
+ r.iterator(closeDbOnCompletion = true)
+ }
+ case COMMITTED =>
+ logDebug(s"state = committed using check-pointed DB with version $newVersion")
+ // use check-pointed db for current updated version
+ val path = getCheckpointPath(newVersion)
+ val r: RocksDbInstance =
+ new RocksDbInstance(keySchema, valueSchema, newVersion.toString)
+ r.open(path, readOnly = true)
+ r.iterator(closeDbOnCompletion = true)
+
+ case _ => Iterator.empty
+ }
+ }
+
+ override def metrics: StateStoreMetrics = {
+ val metricsFromProvider: Map[String, Long] = getMetricsForProvider()
+ val customMetrics = metricsFromProvider.flatMap {
+ case (name, value) =>
+ // just allow searching from the list because the list is small enough
+ supportedCustomMetrics.find(_.name == name).map(_ -> value)
+ }
+ StateStoreMetrics(Math.max(numEntriesInDb, 0), Math.max(bytesUsedByDb, 0), customMetrics)
+ }
+
+ /*
+ * Whether all updates have been committed
+ */
+ override def hasCommitted: Boolean = {
+ state == COMMITTED
+ }
+
+ override def toString(): String = {
+ s"RocksDbStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
+ }
+
+ }
+
+ /*
+ * Initialize the provider with more contextual information from the SQL operator.
+ * This method will be called first after creating an instance of the StateStoreProvider by
+ * reflection.
+ *
+ * @param stateStoreId Id of the versioned StateStores that this provider will generate
+ * @param keySchema Schema of keys to be stored
+ * @param valueSchema Schema of value to be stored
+ * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by
+ * which the StateStore implementation could index the data.
+ * @param storeConfs Configurations used by the StateStores
+ * @param hadoopConf Hadoop configuration that could be used by StateStore
+ * to save state data
+ */
+ override def init(
+ stateStoreId: StateStoreId,
+ keySchema: StructType,
+ valueSchema: StructType,
+ keyIndexOrdinal: Option[Int], // for sorting the data by their keys
+ storeConfs: StateStoreConf,
+ hadoopConf: Configuration): Unit = {
+ this.stateStoreId_ = stateStoreId
+ this.keySchema = keySchema
+ this.valueSchema = valueSchema
+ this.storeConf = storeConfs
+ this.hadoopConf = hadoopConf
+ this.numberOfVersionsToRetain = storeConfs.maxVersionsToRetainInMemory
+ fm.mkdirs(baseDir)
+ this.localDir = storeConfs.confs
+ .getOrElse(
+ "spark.sql.streaming.stateStore.rocksDb.localDir",
+ Utils.createTempDir().getAbsolutePath)
+ }
+
+ /*
+ * Return the id of the StateStores this provider will generate.
+ * Should be the same as the one passed in init().
+ */
+ override def stateStoreId: StateStoreId = stateStoreId_
+
+ /*
+ * Called when the provider instance is unloaded from the executor
+ */
+ override def close(): Unit = {
+ storeMap.values.asScala.foreach(_.close)
+ storeMap.clear()
+ }
+
+ private val storeMap = new util.HashMap[Long, RocksDbStateStore]()
+
+ /*
+ * Optional custom metrics that the implementation may want to report.
+ *
+ * @note The StateStore objects created by this provider must report the same custom metrics
+ * (specifically, same names) through `StateStore.metrics`.
+ */
+ // TODO
+ override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
+ Nil
+ }
+
+ override def toString(): String = {
+ s"RocksDbStateStoreProvider[" +
+ s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]"
+ }
+
+ def getMetricsForProvider(): Map[String, Long] = synchronized {
+ Map.empty[String, Long]
+ }
+
+ /*
+ * Return an instance of [[StateStore]] representing state data of the given version
+ */
+ override def getStore(version: Long): StateStore = synchronized {
+ logInfo(s"get Store for version $version")
+ require(version >= 0, "Version cannot be less than 0")
+ if (storeMap.containsKey(version)) {
+ storeMap.get(version)
+ } else {
+ val store = createStore(version)
+ storeMap.put(version, store)
+ store
+ }
+ }
+
+ private def createStore(version: Long): RocksDbStateStore = {
+ val newStore = new RocksDbStateStore(version)
+ if (version > 0) {
+ // load the data into the rocksDB
+ logInfo(
+ s"Loading state into the db for $version and partition ${stateStoreId_.partitionId}")
+ loadIntoRocksDB(version)
+ }
+ newStore
+ }
+
+ private def loadIntoRocksDB(version: Long): Unit = {
+ /*
+ 1. Get last available/committed Rocksdb version in local folder
+ 2. If last committed version = version, we already have loaded rocksdb state.
+ 3. If last committed version = version - 1,
+ we have to apply delta for version in the existing rocksdb
+ 4. Otherwise we have to recreate a new rocksDB store by using Snapshots/Delta
+ */
+ val (_, elapsedMs) = Utils.timeTakenMs {
+ var lastAvailableVersion = getLastCommittedVersion()
+ if (lastAvailableVersion == -1L || lastAvailableVersion <= version - 2) {
+ // Destroy existing DB so that we can reconstruct it using snapshot and delta files
+ RocksDbInstance.destroyDB(rocksDbPath)
+ var lastAvailableSnapShotVersion: Long = version + 1
+ // load from snapshot
+ var found = false
+ while (!found && lastAvailableSnapShotVersion > 0) {
+ try {
+ lastAvailableSnapShotVersion = lastAvailableSnapShotVersion - 1
+ found = loadSnapshotFile(lastAvailableSnapShotVersion)
+ logDebug(
+ s"Snapshot for version $lastAvailableSnapShotVersion " +
+ "and partition ${stateStoreId_.partitionId}: found = $found")
+ } catch {
+ case e: Exception =>
+ logError(s"$e while reading snapshot file")
+ throw e
+ }
+ }
+ lastAvailableVersion = lastAvailableSnapShotVersion
+ }
+ if (lastAvailableVersion < version) {
+ applyDelta(version, lastAvailableVersion)
+ }
+ }
+ logInfo(
+ s"Loading state for $version and partition ${stateStoreId_.partitionId} took $elapsedMs ms.")
+ }
+
+ private def getLastCommittedVersion(): Long = {
+ val f = new File(rocksDbPath, RocksDbInstance.COMMIT_FILE_NAME)
+ if (f.exists()) {
+ try {
+ val source = Source.fromFile(f.getAbsolutePath)
+ try {
+ return source.getLines.mkString.toLong
+ } finally {
+ source.close()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Exception while reading committed file", e)
+ }
+ }
+ -1L
+ }
+
+ private def loadSnapshotFile(version: Long): Boolean = {
+ val fileToRead = snapshotFile(baseDir, version)
+ if (version == 0 || !fm.exists(fileToRead)) {
+ return false
+ }
+ val versionTempPath = getTempPath(version)
+ val tmpLocDir: File = new File(versionTempPath)
+ val tmpLocFile: File = new File(s"${versionTempPath}.tar")
+ try {
+ logInfo(s"Will download $fileToRead at location ${tmpLocFile.toString()}")
+ if (downloadFile(fm, fileToRead, new Path(tmpLocFile.getAbsolutePath), sparkConf)) {
+ FileUtility.extractTarFile(tmpLocFile.getAbsolutePath, versionTempPath)
+ if (!tmpLocDir.list().exists(_.endsWith(".sst"))) {
+ logWarning("Snapshot files are corrupted")
+ throw new IOException(
+ s"Error reading snapshot file $fileToRead of $this:" +
+ s" No SST files found")
+ }
+ FileUtils.moveDirectory(tmpLocDir, new File(rocksDbPath))
+ true
+ } else {
+ false
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while loading snapshot file $e")
+ throw e
+ } finally {
+ if (tmpLocFile.exists()) {
+ tmpLocFile.delete()
+ }
+ FileUtils.deleteDirectory(tmpLocDir)
+ }
+ }
+
+ private def applyDelta(version: Long, lastAvailableVersion: Long): Unit = {
+ var rocksDbWriteInstance: OptimisticTransactionDbInstance = null
+ try {
+ rocksDbWriteInstance =
+ new OptimisticTransactionDbInstance(keySchema, valueSchema, version.toString)
+ rocksDbWriteInstance.open(rocksDbPath)
+ rocksDbWriteInstance.startTransactions()
+ // Load all the deltas from the version after the last available
+ // one up to the target version.
+ // The last available version is the one with a full snapshot, so it doesn't need deltas.
+ for (deltaVersion <- (lastAvailableVersion + 1) to version) {
+ val fileToRead = deltaFile(baseDir, deltaVersion)
+ updateFromDeltaFile(
+ fm,
+ fileToRead,
+ keySchema,
+ valueSchema,
+ rocksDbWriteInstance,
+ sparkConf)
+ logInfo(s"Read delta file for version $version of $this from $fileToRead")
+ }
+ rocksDbWriteInstance.commit(Some(getCheckpointPath(version)))
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while loading state ${e.getMessage}")
+ if (rocksDbWriteInstance != null) {
+ rocksDbWriteInstance.abort()
+ }
+ throw e
+ } finally {
+ if (rocksDbWriteInstance != null) {
+ rocksDbWriteInstance.close()
+ }
+ }
+ }
+
+ /** Optional method for providers to allow for background maintenance (e.g. compactions) */
+ override def doMaintenance(): Unit = {
+ try {
+ val (files: Seq[WALUtils.StoreFile], e1) = Utils.timeTakenMs(fetchFiles(fm, baseDir))
+ logDebug(s"fetchFiles() took $e1 ms.")
+ doSnapshot(files)
+ cleanup(files)
+ cleanRocksDBCheckpoints(files)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error performing snapshot and cleaning up $this")
+ }
+ }
+
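+ /*
+ * Tars the RocksDB checkpoint of the latest version and uploads it to the checkpoint location
+ * as a `.snapshot` file, once more than `minDeltasForSnapshot` delta files have accumulated
+ * since the last snapshot.
+ */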
+ private def doSnapshot(files: Seq[WALUtils.StoreFile]): Unit = {
+ if (files.nonEmpty) {
+ val lastVersion = files.last.version
+ val deltaFilesForLastVersion =
+ filesForVersion(files, lastVersion).filter(_.isSnapshot == false)
+ if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) {
+ val dbPath = getCheckpointPath(lastVersion)
+ val snapShotFileName = s"{getTempPath(lastVersion)}.snapshot"
+ val f = new File(snapShotFileName)
+ try {
+ val (_, t1) = Utils.timeTakenMs {
+ FileUtility.createTarFile(dbPath, snapShotFileName)
+ val targetFile = snapshotFile(baseDir, lastVersion)
+ uploadFile(fm, new Path(snapShotFileName), targetFile, sparkConf)
+ }
+ logInfo(s"Creating snapshot file for ${stateStoreId_.partitionId} took $t1 ms.")
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while creating snapshot $e")
+ throw e
+ } finally {
+ f.delete() // delete the tarball
+ }
+ }
+ }
+ }
+
+ /*
+ * Clean up old snapshots and delta files that are no longer needed. This ensures that the
+ * last few versions of the store can be recovered from the files, so that re-executed RDD
+ * operations can re-apply updates on past versions of the store.
+ */
+ private[state] def cleanup(files: Seq[WALUtils.StoreFile]): Unit = {
+ try {
+ if (files.nonEmpty) {
+ val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
+ if (earliestVersionToRetain > 0) {
+ val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head
+ val filesToDelete = files.filter(_.version < earliestFileToRetain.version)
+ val (_, e2) = Utils.timeTakenMs {
+ filesToDelete.foreach { f =>
+ fm.delete(f.path)
+ val file = new File(rocksDbPath, f.version.toString)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ }
+ logDebug(s"deleting files took $e2 ms.")
+ logInfo(
+ s"Deleted files older than ${earliestFileToRetain.version} for $this: " +
+ filesToDelete.mkString(", "))
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error cleaning up files for $this", e)
+ }
+ }
+
+ private def cleanRocksDBCheckpoints(files: Seq[WALUtils.StoreFile]): Unit = {
+ try {
+ val (_, e2) = Utils.timeTakenMs {
+ if (files.nonEmpty) {
+ val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
+ if (earliestVersionToRetain > 0) {
+ new File(getCheckpointPath(earliestVersionToRetain)).getParentFile
+ .listFiles(new FileFilter {
+ def accept(f: File): Boolean = {
+ try {
+ f.getName.toLong < earliestVersionToRetain
+ } catch {
+ case _: NumberFormatException => false
+ }
+ }
+ })
+ .foreach(p => RocksDbInstance.destroyDB(p.getAbsolutePath))
+ logInfo(
+ s"Deleted rocksDB checkpoints older than ${earliestVersionToRetain} for $this: ")
+ }
+ }
+ }
+ logDebug(s"deleting rocksDB checkpoints took $e2 ms.")
+ } catch {
+ case NonFatal(e) => logWarning(s"Error cleaning up files for $this", e)
+ }
+ }
+
+ // Used only for unit tests
+ private[sql] def latestIterator(): Iterator[UnsafeRowPair] = synchronized {
+ val versionsInFiles = fetchFiles(fm, baseDir).map(_.version).toSet
+ if (versionsInFiles.nonEmpty) {
+ val maxVersion = versionsInFiles.max
+ if (maxVersion > 0) {
+ loadIntoRocksDB(maxVersion)
+ val r: RocksDbInstance = new RocksDbInstance(keySchema, valueSchema, maxVersion.toString)
+ try {
+ r.open(rocksDbPath, readOnly = true)
+ return r.iterator(false)
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception ${e.getMessage} while getting latest Iterator")
+ }
+ }
+ }
+ Iterator.empty
+ }
+
+ private[sql] def getLocalDir: String = localDir
+
+ private[sql] lazy val rocksDbPath: String = {
+ getPath("db")
+ }
+
+ private def getCheckpointPath(version: Long): String = {
+ getPath("checkpoint", Some(version.toString))
+ }
+
+ private def getTempPath(version: Long): String = {
+ getPath("tmp", Some(version.toString))
+ }
+
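+ /*
+ * All local RocksDB state lives under `localDir`, laid out as
+ * <localDir>/<subFolderName>/<checkpointRootName_hash>/<operatorId>/<partitionId>[/<version>],
+ * so different queries, operators and partitions never share a directory.
+ */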
+ private def getPath(subFolderName: String, version: Option[String] = None): String = {
+ val checkpointRootLocationPath = new Path(stateStoreId.checkpointRootLocation)
+
+ val dirPath = new Path(
+ localDir,
+ new Path(
+ new Path(
+ subFolderName,
+ checkpointRootLocationPath.getName + "_" + checkpointRootLocationPath.hashCode()),
+ new Path(stateStoreId_.operatorId.toString, stateStoreId_.partitionId.toString)))
+
+ val f: File = new File(dirPath.toString)
+ if (!f.exists() && !f.mkdirs()) {
+ throw new IllegalStateException(s"Couldn't create directory ${dirPath.toString}")
+ }
+
+ if (version.isEmpty) {
+ dirPath.toString
+ } else {
+ new Path(dirPath, version.get).toString
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala
new file mode 100644
index 0000000000000..1da7e8d8bbeed
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/WALUtils.scala
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io._
+import java.util.Locale
+
+import scala.collection.mutable
+
+import com.google.common.io.ByteStreams
+import org.apache.commons.io.IOUtils
+import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, FSError, Path}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.io.LZ4CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
+import org.apache.spark.sql.types.StructType
+
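+/**
+ * Helpers for the write-ahead-log side of the RocksDB state store: listing, reading and writing
+ * the per-version `.delta` files and `.snapshot` tarballs kept in the checkpoint directory, plus
+ * small upload/download utilities built on top of [[CheckpointFileManager]].
+ */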
+object WALUtils {
+
+ case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
+ private val EOF_MARKER = -1
+
+ /** Files needed to recover the given version of the store */
+ def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = {
+ require(version >= 0)
+ require(allFiles.exists(_.version == version))
+
+ val latestSnapshotFileBeforeVersion = allFiles
+ .filter(_.isSnapshot)
+ .takeWhile(_.version <= version)
+ .lastOption
+ val deltaBatchFiles = latestSnapshotFileBeforeVersion match {
+ case Some(snapshotFile) =>
+ val deltaFiles = allFiles.filter { file =>
+ file.version > snapshotFile.version && file.version <= version
+ }.toList
+ require(
+ deltaFiles.size == version - snapshotFile.version,
+ s"Unexpected list of delta files for version $version for $this: $deltaFiles")
+ deltaFiles
+
+ case None =>
+ allFiles.takeWhile(_.version <= version)
+ }
+ latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles
+ }
+
+ /** Fetch all the files that back the store */
+ def fetchFiles(fm: CheckpointFileManager, baseDir: Path): Seq[StoreFile] = {
+ val files: Seq[FileStatus] = try {
+ fm.list(baseDir)
+ } catch {
+ case _: java.io.FileNotFoundException =>
+ Seq.empty
+ }
+ val versionToFiles = new mutable.HashMap[Long, StoreFile]
+ files.foreach { status =>
+ val path = status.getPath
+ val nameParts = path.getName.split("\\.")
+ if (nameParts.size == 2) {
+ val version = nameParts(0).toLong
+ nameParts(1).toLowerCase(Locale.ROOT) match {
+ case "delta" =>
+ // ignore the file otherwise, snapshot file already exists for that batch id
+ if (!versionToFiles.contains(version)) {
+ versionToFiles.put(version, StoreFile(version, path, isSnapshot = false))
+ }
+ case "snapshot" =>
+ versionToFiles.put(version, StoreFile(version, path, isSnapshot = true))
+ case _ =>
+ // ignore files that are neither delta nor snapshot files
+ }
+ }
+ }
+ val storeFiles = versionToFiles.values.toSeq.sortBy(_.version)
+ storeFiles
+ }
+
+ def compressStream(outputStream: DataOutputStream, sparkConf: SparkConf): DataOutputStream = {
+ val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream)
+ new DataOutputStream(compressed)
+ }
+
+ def decompressStream(inputStream: DataInputStream, sparkConf: SparkConf): DataInputStream = {
+ val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream)
+ new DataInputStream(compressed)
+ }
+
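+ /*
+ * Delta file record layout (inside the LZ4-compressed stream):
+ * a put is [keySize: Int][keyBytes][valueSize: Int][valueBytes], a remove is
+ * [keySize: Int][keyBytes] followed by a negative value size, and a trailing key size of
+ * EOF_MARKER (-1) marks the end of the file.
+ */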
+ def writeUpdateToDeltaFile(output: DataOutputStream, key: UnsafeRow, value: UnsafeRow): Unit = {
+ val keyBytes = key.getBytes()
+ val valueBytes = value.getBytes()
+ output.writeInt(keyBytes.size)
+ output.write(keyBytes)
+ output.writeInt(valueBytes.size)
+ output.write(valueBytes)
+ }
+
+ def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = {
+ val keyBytes = key.getBytes()
+ output.writeInt(keyBytes.size)
+ output.write(keyBytes)
+ output.writeInt(EOF_MARKER)
+ }
+
+ def finalizeDeltaFile(output: DataOutputStream): Unit = {
+ output.writeInt(EOF_MARKER) // Write this magic number to signify end of file
+ output.close()
+ }
+
+ def updateFromDeltaFile(
+ fm: CheckpointFileManager,
+ fileToRead: Path,
+ keySchema: StructType,
+ valueSchema: StructType,
+ newRocksDb: OptimisticTransactionDbInstance,
+ sparkConf: SparkConf): Unit = {
+ var input: DataInputStream = null
+ val sourceStream = try {
+ fm.open(fileToRead)
+ } catch {
+ case f: FileNotFoundException =>
+ throw new IllegalStateException(
+ s"Error reading delta file $fileToRead of $this: $fileToRead does not exist",
+ f)
+ }
+ try {
+ input = decompressStream(sourceStream, sparkConf)
+ var eof = false
+
+ while (!eof) {
+ val keySize = input.readInt()
+ if (keySize == EOF_MARKER) {
+ eof = true
+ } else if (keySize < 0) {
+ newRocksDb.abort()
+ newRocksDb.close()
+ throw new IOException(
+ s"Error reading delta file $fileToRead of $this: key size cannot be $keySize")
+ } else {
+ val keyRowBuffer = new Array[Byte](keySize)
+ ByteStreams.readFully(input, keyRowBuffer, 0, keySize)
+
+ val keyRow = new UnsafeRow(keySchema.fields.length)
+ keyRow.pointTo(keyRowBuffer, keySize)
+
+ val valueSize = input.readInt()
+ if (valueSize < 0) {
+ newRocksDb.remove(key = keyRow)
+ } else {
+ val valueRowBuffer = new Array[Byte](valueSize)
+ ByteStreams.readFully(input, valueRowBuffer, 0, valueSize)
+ val valueRow = new UnsafeRow(valueSchema.fields.length)
+ // If valueSize in an existing file is not a multiple of 8, floor it to a multiple of 8.
+ // This is a workaround for the following issue:
+ // prior to Spark 2.3, `RowBasedKeyValueBatch` mistakenly appended 4 bytes to the value row,
+ // which gets persisted into the checkpoint data
+ valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+ newRocksDb.put(keyRow, valueRow)
+ }
+ }
+ }
+ } finally {
+ if (input != null) input.close()
+ }
+ }
+
+ /*
+ * Try to cancel the underlying stream and safely close the compressed stream.
+ *
+ * @param compressedStream the compressed stream.
+ * @param rawStream the underlying stream which needs to be cancelled.
+ */
+ def cancelDeltaFile(
+ compressedStream: DataOutputStream,
+ rawStream: CancellableFSDataOutputStream): Unit = {
+ try {
+ if (rawStream != null) rawStream.cancel()
+ IOUtils.closeQuietly(compressedStream)
+ } catch {
+ case e: FSError if e.getCause.isInstanceOf[IOException] =>
+ // Closing the compressedStream causes the stream to write/flush its data into the
+ // rawStream. Since the rawStream is already closed, there may be errors.
+ // Usually it's an IOException. However, Hadoop's RawLocalFileSystem wraps
+ // IOException into FSError.
+ }
+ }
+
+ def uploadFile(
+ fm: CheckpointFileManager,
+ sourceFile: Path,
+ targetFile: Path,
+ sparkConf: SparkConf): Unit = {
+ var output: CancellableFSDataOutputStream = null
+ var in: BufferedInputStream = null
+ try {
+ in = new BufferedInputStream(new FileInputStream(sourceFile.toString))
+ output = fm.createAtomic(targetFile, overwriteIfPossible = true)
+ val buffer = new Array[Byte](1024)
+ var len = in.read(buffer)
+ while (len > 0) {
+ output.write(buffer, 0, len)
+ len = in.read(buffer)
+ }
+ output.close()
+ } catch {
+ case e: Throwable =>
+ if (output != null) output.cancel()
+ throw e
+ } finally {
+ if (in != null) {
+ in.close()
+ }
+ }
+ }
+
+ def downloadFile(
+ fm: CheckpointFileManager,
+ sourceFile: Path,
+ targetFile: Path,
+ sparkConf: SparkConf): Boolean = {
+ var in: FSDataInputStream = null
+ var output: BufferedOutputStream = null
+ try {
+ in = fm.open(sourceFile)
+ output = new BufferedOutputStream(new FileOutputStream(targetFile.toString))
+ val buffer = new Array[Byte](1024)
+ var eof = false
+ while (!eof) {
+ val len = in.read(buffer)
+ if (len > 0) {
+ output.write(buffer, 0, len)
+ } else {
+ eof = true
+ }
+ }
+ output.close()
+ } catch {
+ case e: Throwable =>
+ new File(targetFile.toString).delete()
+ throw e
+ } finally {
+ if (output != null) {
+ output.close()
+ }
+ if (in != null) {
+ in.close()
+ }
+ }
+ true
+ }
+
+ def deltaFile(baseDir: Path, version: Long): Path = {
+ new Path(baseDir, s"$version.delta")
+ }
+
+ def snapshotFile(baseDir: Path, version: Long): Path = {
+ new Path(baseDir, s"$version.snapshot")
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala
new file mode 100644
index 0000000000000..435b5830c3466
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDbStateStoreSuite.scala
@@ -0,0 +1,606 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually.{eventually, timeout}
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+class RocksDbStateStoreSuite
+ extends StateStoreSuiteBase[RocksDbStateStoreProvider]
+ with BeforeAndAfter
+ with PrivateMethodTester {
+ type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
+ type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
+
+ import StateStoreCoordinatorSuite._
+ import StateStoreTestsHelper._
+
+ val keySchema = StructType(Seq(StructField("key", StringType, true)))
+ val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+ before {
+ StateStore.stop()
+ require(!StateStore.isMaintenanceRunning)
+ }
+
+ after {
+ StateStore.stop()
+ require(!StateStore.isMaintenanceRunning)
+ }
+
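+ // Test helpers: commit a single-key update per version until the store reaches targetVersion.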
+ def updateVersionTo(
+ provider: StateStoreProvider,
+ currentVersion: Int,
+ targetVersion: Int): Int = {
+ var newCurrentVersion = currentVersion
+ for (i <- newCurrentVersion until targetVersion) {
+ newCurrentVersion = incrementVersion(provider, i)
+ }
+ require(newCurrentVersion === targetVersion)
+ newCurrentVersion
+ }
+
+ def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = {
+ val store = provider.getStore(currentVersion)
+ put(store, "a", currentVersion + 1)
+ store.commit()
+ currentVersion + 1
+ }
+
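+ // Assertion helpers over a raw RocksDbInstance: checkLoadedVersions verifies only the
+ // entry count, while checkVersion compares the full key-value contents against a map.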
+ def checkLoadedVersions(
+ rocksDbWriteInstance: RocksDbInstance,
+ count: Int,
+ earliestKey: Long,
+ latestKey: Long): Unit = {
+ assert(rocksDbWriteInstance.iterator(false).length === count)
+ }
+
+ def checkVersion(
+ rocksDbWriteInstance: RocksDbInstance,
+ version: Long,
+ expectedData: Map[String, Int]): Unit = {
+
+ val originValueMap = rocksDbWriteInstance
+ .iterator(false)
+ .map { row =>
+ rowToString(row.key) -> rowToInt(row.value)
+ }
+ .toMap[String, Int]
+
+ assert(originValueMap === expectedData)
+ }
+
+ test("get, put, remove, commit, and all data iterator") {
+ val provider = newStoreProvider()
+
+ // Verify state before starting a new set of updates
+ assert(getLatestData(provider).isEmpty)
+
+ val store = provider.getStore(0)
+ assert(!store.hasCommitted)
+ assert(get(store, "a") === None)
+ assert(store.iterator().isEmpty)
+
+ // Verify state after updating
+ put(store, "a", 1)
+ assert(get(store, "a") === Some(1))
+
+ assert(store.iterator().nonEmpty)
+ assert(getLatestData(provider).isEmpty)
+
+ // Make updates, commit and then verify state
+ put(store, "b", 2)
+ put(store, "aa", 3)
+ remove(store, _.startsWith("a"))
+ assert(store.commit() === 1)
+
+ assert(store.hasCommitted)
+ assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+ assert(getLatestData(provider) === Set("b" -> 2))
+
+ // Trying to get newer versions should fail
+ intercept[Exception] {
+ provider.getStore(2)
+ }
+ intercept[Exception] {
+ getData(provider, 2)
+ }
+
+ // New updates go to the reloaded store under a new version and do not change the old version
+ val reloadedProvider = newStoreProvider(store.id, provider.getLocalDir)
+ val reloadedStore = reloadedProvider.getStore(1)
+ put(reloadedStore, "c", 4)
+ assert(reloadedStore.commit() === 2)
+ assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+ assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+ assert(getData(provider, version = 1) === Set("b" -> 2))
+ }
+
+ test("snapshotting") {
+ val provider =
+ newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
+
+ var currentVersion = 0
+
+ currentVersion = updateVersionTo(provider, currentVersion, 2)
+ require(getData(provider) === Set("a" -> 2))
+ provider.doMaintenance() // should not generate snapshot files
+ assert(getData(provider) === Set("a" -> 2))
+
+ for (i <- 1 to currentVersion) {
+ assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
+ assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
+ }
+
+ // After version 6, snapshotting should generate one snapshot file
+ currentVersion = updateVersionTo(provider, currentVersion, 6)
+ require(getData(provider) === Set("a" -> 6), "store not updated correctly")
+ provider.doMaintenance() // should generate snapshot files
+
+ val snapshotVersion =
+ (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+ assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+ deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+ assert(
+ getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+ "snapshotting messed up the data of the snapshotted version")
+ assert(
+ getData(provider) === Set("a" -> 6),
+ "snapshotting messed up the data of the final version")
+
+ // After version 20, snapshotting should generate newer snapshot files
+ currentVersion = updateVersionTo(provider, currentVersion, 20)
+ require(getData(provider) === Set("a" -> 20), "store not updated correctly")
+ provider.doMaintenance() // do snapshot
+
+ val latestSnapshotVersion =
+ (0 to 20).filter(version => fileExists(provider, version, isSnapshot = true)).lastOption
+ assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+ assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+ deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+ assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data")
+ }
+
+ test("cleaning") {
+ val provider =
+ newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
+
+ for (i <- 1 to 20) {
+ val store = provider.getStore(i - 1)
+ put(store, "a", i)
+ store.commit()
+ provider.doMaintenance() // do cleanup
+ }
+ require(
+ rowsToSet(provider.latestIterator()) === Set("a" -> 20),
+ "store not updated correctly")
+
+ assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
+
+ // last couple of versions should be retrievable
+ assert(getData(provider, 20) === Set("a" -> 20))
+ assert(getData(provider, 19) === Set("a" -> 19))
+ }
+
+ testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
+ val conf = new Configuration()
+ conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
+ conf.set("fs.defaultFS", "fake:///")
+
+ val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)
+ provider.getStore(0).commit()
+ provider.getStore(0).commit()
+
+ // Verify we don't leak temp files
+ val tempFiles = FileUtils
+ .listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true)
+ .asScala
+ .filter(_.getName.startsWith("temp-"))
+ assert(tempFiles.isEmpty)
+ }
+
+ test("corrupted file handling") {
+ val provider =
+ newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
+ for (i <- 1 to 6) {
+ val store = provider.getStore(i - 1)
+ put(store, "a", i)
+ store.commit()
+ provider.doMaintenance() // do cleanup
+ }
+ val snapshotVersion = (0 to 10)
+ .find(version => fileExists(provider, version, isSnapshot = true))
+ .getOrElse(fail("snapshot file not found"))
+
+ // Corrupt the snapshot file and verify that reading it throws an error
+ provider.close()
+ assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+ RocksDbInstance.destroyDB(provider.rocksDbPath)
+
+ corruptFile(provider, snapshotVersion, isSnapshot = true)
+ intercept[Exception] {
+ provider.close()
+ RocksDbInstance.destroyDB(provider.rocksDbPath)
+ getData(provider, snapshotVersion)
+ }
+
+ // Corrupt the delta file and verify that reading it throws an error
+ provider.close()
+ RocksDbInstance.destroyDB(provider.rocksDbPath)
+ assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+
+ corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
+ intercept[Exception] {
+ provider.close()
+ RocksDbInstance.destroyDB(provider.rocksDbPath)
+ getData(provider, snapshotVersion - 1)
+ }
+
+ // Delete the delta file and verify that reading it throws an error
+ deleteFilesEarlierThanVersion(provider, snapshotVersion)
+ intercept[Exception] {
+ provider.close()
+ RocksDbInstance.destroyDB(provider.rocksDbPath)
+ getData(provider, snapshotVersion - 1)
+ }
+ }
+
+ test("StateStore.get") {
+ quietly {
+ val dir = newDir()
+ val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID)
+ val sqlConf = new SQLConf
+ sqlConf.setConfString(
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider")
+ val storeConf = new StateStoreConf(sqlConf)
+ assert(
+ storeConf.providerClass ===
+ "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider")
+ val hadoopConf = new Configuration()
+
+ // Verify that trying to get incorrect versions throws errors
+ intercept[IllegalArgumentException] {
+ StateStore.get(storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
+ }
+ assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
+
+ intercept[IllegalStateException] {
+ StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+ }
+
+ // Increase version of the store and try to get again
+ val store0 = StateStore.get(storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+ assert(store0.version === 0)
+ put(store0, "a", 1)
+ store0.commit()
+
+ val store1 = StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeId))
+ assert(store1.version === 1)
+ assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+
+ // Verify that you can also load an older version
+ val store0reloaded =
+ StateStore.get(storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+ assert(store0reloaded.version === 0)
+ assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
+
+ // Verify that you can remove the store and still reload and use it
+ StateStore.unload(storeId)
+ assert(!StateStore.isLoaded(storeId))
+
+ val store1reloaded =
+ StateStore.get(storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeId))
+ assert(store1reloaded.version === 1)
+ put(store1reloaded, "a", 2)
+ assert(store1reloaded.commit() === 2)
+ assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
+ }
+ }
+
+ test("maintenance") {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ // Make maintenance thread do snapshots and cleanups very fast
+ .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms")
+ // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
+ // fails to talk to the StateStoreCoordinator and unloads all the StateStores
+ .set("spark.rpc.numRetries", "1")
+ val opId = 0
+ val dir = newDir()
+ val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID)
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider")
+ sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+ val storeConf = StateStoreConf(sqlConf)
+ val hadoopConf = new Configuration()
+ val provider = newStoreProvider(storeProviderId.storeId)
+
+ var latestStoreVersion = 0
+
+ def generateStoreVersions(): Unit = {
+ for (i <- 1 to 20) {
+ val store = StateStore.get(
+ storeProviderId,
+ keySchema,
+ valueSchema,
+ None,
+ latestStoreVersion,
+ storeConf,
+ hadoopConf)
+ put(store, "a", i)
+ store.commit()
+ latestStoreVersion += 1
+ }
+ }
+
+ val timeoutDuration = 60.seconds
+
+ quietly {
+ withSpark(new SparkContext(conf)) { sc =>
+ withCoordinatorRef(sc) { coordinatorRef =>
+ require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
+
+ // Generate sufficient versions of store for snapshots
+ generateStoreVersions()
+
+ eventually(timeout(timeoutDuration)) {
+ // Store should have been reported to the coordinator
+ assert(
+ coordinatorRef.getLocation(storeProviderId).nonEmpty,
+ "active instance was not reported")
+
+ // Background maintenance should clean up and generate snapshots
+ assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
+
+ // Some snapshots should have been generated
+ val snapshotVersions = (1 to latestStoreVersion).filter { version =>
+ fileExists(provider, version, isSnapshot = true)
+ }
+ assert(snapshotVersions.nonEmpty, "no snapshot file found")
+ }
+
+ // Generate more versions such that there is another snapshot and
+ // the earliest delta file will be cleaned up
+ generateStoreVersions()
+
+ // Earliest delta file should get cleaned up
+ eventually(timeout(timeoutDuration)) {
+ assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+ }
+
+ // If the driver deactivates all stores related to a query run,
+ // this instance should be unloaded
+ coordinatorRef.deactivateInstances(storeProviderId.queryRunId)
+ eventually(timeout(timeoutDuration)) {
+ assert(!StateStore.isLoaded(storeProviderId))
+ }
+
+ // Reload the store and verify
+ StateStore.get(
+ storeProviderId,
+ keySchema,
+ valueSchema,
+ indexOrdinal = None,
+ latestStoreVersion,
+ storeConf,
+ hadoopConf)
+ assert(StateStore.isLoaded(storeProviderId))
+
+ // If some other executor loads the store, then this instance should be unloaded
+ coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec")
+ eventually(timeout(timeoutDuration)) {
+ assert(!StateStore.isLoaded(storeProviderId))
+ }
+
+ // Reload the store and verify
+ StateStore.get(
+ storeProviderId,
+ keySchema,
+ valueSchema,
+ indexOrdinal = None,
+ latestStoreVersion,
+ storeConf,
+ hadoopConf)
+ assert(StateStore.isLoaded(storeProviderId))
+ }
+ }
+
+ // Verify that the instance is unloaded when the SparkContext is stopped
+ eventually(timeout(timeoutDuration)) {
+ require(SparkEnv.get === null)
+ assert(!StateStore.isLoaded(storeProviderId))
+ assert(!StateStore.isMaintenanceRunning)
+ }
+ }
+ }
+
+ test("SPARK-21145: Restarted queries create new provider instances") {
+ try {
+ val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+ val spark = SparkSession.builder().master("local[2]").getOrCreate()
+ SparkSession.setActiveSession(spark)
+ implicit val sqlContext = spark.sqlContext
+ spark.conf.set("spark.sql.shuffle.partitions", "1")
+ spark.conf.set(
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ "org.apache.spark.sql.execution.streaming.state.RocksDbStateStoreProvider")
+ import spark.implicits._
+ val inputData = MemoryStream[Int]
+
+ def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = {
+ val aggregated = inputData.toDF().groupBy("value").agg(count("*"))
+ // stateful query
+ val query = aggregated.writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("query")
+ .option("checkpointLocation", checkpointLocation.toString)
+ .start()
+ inputData.addData(1, 2, 3)
+ query.processAllAvailable()
+ require(query.lastProgress != null) // at least one batch processed after start
+ val loadedProvidersMethod =
+ PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]](
+ 'loadedProviders)
+ val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod()
+ val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq }
+ query.stop()
+ loadedProviders
+ }
+
+ val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders()
+ require(loadedProvidersAfterRun1.length === 1)
+
+ val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders()
+ assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs
+
+ // Both providers should have the same StateStoreId, but they should be different objects
+ assert(
+ loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId)
+ assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1))
+
+ } finally {
+ SparkSession.getActiveSession.foreach { spark =>
+ spark.streams.active.foreach(_.stop())
+ spark.stop()
+ }
+ }
+ }
+
+ override def newStoreProvider(): RocksDbStateStoreProvider = {
+ newStoreProvider(opId = Random.nextInt(), partition = 0)
+ }
+
+ override def newStoreProvider(storeId: StateStoreId): RocksDbStateStoreProvider = {
+ newStoreProvider(
+ storeId.operatorId,
+ storeId.partitionId,
+ dir = storeId.checkpointRootLocation)
+ }
+
+ def newStoreProvider(storeId: StateStoreId, localDir: String): RocksDbStateStoreProvider = {
+ newStoreProvider(
+ storeId.operatorId,
+ storeId.partitionId,
+ dir = storeId.checkpointRootLocation,
+ localDir = localDir)
+ }
+
+ override def getLatestData(storeProvider: RocksDbStateStoreProvider): Set[(String, Int)] = {
+ getData(storeProvider)
+ }
+
+ override def getData(
+ provider: RocksDbStateStoreProvider,
+ version: Int = -1): Set[(String, Int)] = {
+ val reloadedProvider = newStoreProvider(provider.stateStoreId, provider.getLocalDir)
+ if (version < 0) {
+ reloadedProvider.latestIterator().map(rowsToStringInt).toSet
+ } else {
+ reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
+ }
+ }
+
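+ // Builds a provider with a separate checkpoint dir and local RocksDB dir; localDir is passed
+ // via "spark.sql.streaming.stateStore.rocksDb.localDir", and MIN_BATCHES_TO_RETAIN is set to 2
+ // so only the most recent versions are kept around after cleanup.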
+ def newStoreProvider(
+ opId: Long,
+ partition: Int,
+ dir: String = newDir(),
+ localDir: String = newDir(),
+ minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+ numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get,
+ hadoopConf: Configuration = new Configuration): RocksDbStateStoreProvider = {
+ val sqlConf = new SQLConf()
+ sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+ sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory)
+ sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+ sqlConf.setConfString("spark.sql.streaming.stateStore.rocksDb.localDir", localDir)
+ val provider = new RocksDbStateStoreProvider
+ provider.init(
+ StateStoreId(dir, opId, partition),
+ keySchema,
+ valueSchema,
+ keyIndexOrdinal = None,
+ new StateStoreConf(sqlConf),
+ hadoopConf)
+ provider
+ }
+
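+ // Uses the provider's private baseDir to check whether the delta or snapshot file for the
+ // given version exists in the checkpoint location.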
+ def fileExists(
+ provider: RocksDbStateStoreProvider,
+ version: Long,
+ isSnapshot: Boolean): Boolean = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+ val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+ val filePath = new File(basePath.toString, fileName)
+ filePath.exists
+ }
+
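+ // Deletes every delta and snapshot file strictly older than `version` from the
+ // checkpoint location.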
+ def deleteFilesEarlierThanVersion(provider: RocksDbStateStoreProvider, version: Long): Unit = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+ for (v <- 0 until version.toInt) {
+ for (isSnapshot <- Seq(false, true)) {
+ val fileName = if (isSnapshot) s"$v.snapshot" else s"$v.delta"
+ val filePath = new File(basePath.toString, fileName)
+ if (filePath.exists) filePath.delete()
+ }
+ }
+ }
+
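+ // Replaces the delta or snapshot file for `version` with an empty file so that reading it fails.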
+ def corruptFile(
+ provider: RocksDbStateStoreProvider,
+ version: Long,
+ isSnapshot: Boolean): Unit = {
+ val method = PrivateMethod[Path]('baseDir)
+ val basePath = provider invokePrivate method()
+ val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+ val filePath = new File(basePath.toString, fileName)
+ filePath.delete()
+ filePath.createNewFile()
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index a84d107f2cbc0..c110c6c0a62f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -198,6 +198,58 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
assert(loadedMaps.size() === 0)
}
+ test("get, put, remove, commit, and all data iterator") {
+ val provider = newStoreProvider()
+
+ // Verify state before starting a new set of updates
+ assert(getLatestData(provider).isEmpty)
+
+ val store = provider.getStore(0)
+ assert(!store.hasCommitted)
+ assert(get(store, "a") === None)
+ assert(store.iterator().isEmpty)
+ assert(store.metrics.numKeys === 0)
+
+ // Verify state after updating
+ put(store, "a", 1)
+ assert(get(store, "a") === Some(1))
+ assert(store.metrics.numKeys === 1)
+
+ assert(store.iterator().nonEmpty)
+ assert(getLatestData(provider).isEmpty)
+
+ // Make updates, commit and then verify state
+ put(store, "b", 2)
+ put(store, "aa", 3)
+ assert(store.metrics.numKeys === 3)
+ remove(store, _.startsWith("a"))
+ assert(store.metrics.numKeys === 1)
+ assert(store.commit() === 1)
+
+ assert(store.hasCommitted)
+ assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+ assert(getLatestData(provider) === Set("b" -> 2))
+
+ // Trying to get newer versions should fail
+ intercept[Exception] {
+ provider.getStore(2)
+ }
+ intercept[Exception] {
+ getData(provider, 2)
+ }
+
+ // New updates go to the reloaded store under a new version and do not change the old version
+ val reloadedProvider = newStoreProvider(store.id)
+ val reloadedStore = reloadedProvider.getStore(1)
+ assert(reloadedStore.metrics.numKeys === 1)
+ put(reloadedStore, "c", 4)
+ assert(reloadedStore.metrics.numKeys === 2)
+ assert(reloadedStore.commit() === 2)
+ assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+ assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+ assert(getData(provider, version = 1) === Set("b" -> 2))
+ }
+
test("snapshotting") {
val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
@@ -817,58 +869,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
extends SparkFunSuite {
import StateStoreTestsHelper._
- test("get, put, remove, commit, and all data iterator") {
- val provider = newStoreProvider()
-
- // Verify state before starting a new set of updates
- assert(getLatestData(provider).isEmpty)
-
- val store = provider.getStore(0)
- assert(!store.hasCommitted)
- assert(get(store, "a") === None)
- assert(store.iterator().isEmpty)
- assert(store.metrics.numKeys === 0)
-
- // Verify state after updating
- put(store, "a", 1)
- assert(get(store, "a") === Some(1))
- assert(store.metrics.numKeys === 1)
-
- assert(store.iterator().nonEmpty)
- assert(getLatestData(provider).isEmpty)
-
- // Make updates, commit and then verify state
- put(store, "b", 2)
- put(store, "aa", 3)
- assert(store.metrics.numKeys === 3)
- remove(store, _.startsWith("a"))
- assert(store.metrics.numKeys === 1)
- assert(store.commit() === 1)
-
- assert(store.hasCommitted)
- assert(rowsToSet(store.iterator()) === Set("b" -> 2))
- assert(getLatestData(provider) === Set("b" -> 2))
-
- // Trying to get newer versions should fail
- intercept[Exception] {
- provider.getStore(2)
- }
- intercept[Exception] {
- getData(provider, 2)
- }
-
- // New updates to the reloaded store with new version, and does not change old version
- val reloadedProvider = newStoreProvider(store.id)
- val reloadedStore = reloadedProvider.getStore(1)
- assert(reloadedStore.metrics.numKeys === 1)
- put(reloadedStore, "c", 4)
- assert(reloadedStore.metrics.numKeys === 2)
- assert(reloadedStore.commit() === 2)
- assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
- assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
- assert(getData(provider, version = 1) === Set("b" -> 2))
- }
-
test("removing while iterating") {
val provider = newStoreProvider()