Commit 1a3d20a

make sure test is stable
1 parent eddb3c7 commit 1a3d20a

4 files changed: 130 additions, 97 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 2 additions & 2 deletions
@@ -296,15 +296,15 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   private def getLoadedMapForStore(startVersion: Long, endVersion: Long):
     HDFSBackedStateStoreMap = synchronized {
     try {
-      if (startVersion < 0) {
+      if (startVersion < 1) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
       }
       if (endVersion < startVersion || endVersion < 0) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }

       val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
-      if (!(startVersion == 0 && endVersion == 0)) {
+      if (!(endVersion == 0)) {
         newMap.putAll(loadMap(startVersion, endVersion))
       }
       newMap
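For orientation, this validation guards the version-range read path that the new tests in this commit exercise. A rough usage sketch (not part of the diff; provider is assumed to be an initialized state store provider, and dataToKeyRow/dataToValueRow come from StateStoreTestsHelper):

    // Commit versions 1..4, then open a read-only view built from versions 2..3.
    for (i <- 1 to 4) {
      val store = provider.getStore(i - 1)       // committing the store opened at i - 1 produces version i
      store.put(dataToKeyRow("a", i), dataToValueRow(i))
      store.commit()
      provider.doMaintenance()                   // may write a snapshot, depending on minDeltasForSnapshot
    }
    val readStore = provider.getReadStore(2, 3)  // startVersion < 1 is now rejected up front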

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 30 additions & 2 deletions
@@ -180,10 +180,38 @@ class RocksDB(
     logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)}")
     try {
       if (loadedVersion != version) {
+        closeDB()
         val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version)
-        loadFromCheckpoint(latestSnapshotVersion, version)
+        val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion, workingDir)
+        loadedVersion = latestSnapshotVersion
+
+        // reset last snapshot version
+        if (lastSnapshotVersion > latestSnapshotVersion) {
+          // discard any newer snapshots
+          lastSnapshotVersion = 0L
+          latestSnapshot = None
+        }
+        openDB()
+
+        numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) {
+          // we don't track the total number of rows - discard the number being track
+          -1L
+        } else if (metadata.numKeys < 0) {
+          // we track the total number of rows, but the snapshot doesn't have tracking number
+          // need to count keys now
+          countKeys()
+        } else {
+          metadata.numKeys
+        }
+        if (loadedVersion != version) replayChangelog(version)
+        // After changelog replay the numKeysOnWritingVersion will be updated to
+        // the correct number of keys in the loaded version.
+        numKeysOnLoadedVersion = numKeysOnWritingVersion
+        fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
+      }
+      if (conf.resetStatsOnLoad) {
+        nativeStats.reset
       }
-
       logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)}")
     } catch {
       case t: Throwable =>
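Summarizing the inlined load path above: close the current DB instance, find the newest snapshot at or below the requested version, load that checkpoint from DFS, reopen the DB, restore or recount the key count, and replay the changelog up to the requested version. A condensed sketch of that control flow (a simplified, hypothetical helper; the authoritative logic is the hunk itself):

    def loadVersionSketch(version: Long): Unit = {
      closeDB()
      val snapshotVersion = fileManager.getLatestSnapshotVersion(version)
      val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, workingDir)
      openDB()
      numKeysOnWritingVersion = metadata.numKeys   // or countKeys() when the snapshot carries no count
      // If the snapshot is older than the requested version, apply the remaining changelog files.
      if (snapshotVersion != version) replayChangelog(version)
    }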

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ private[sql] class RocksDBStateStoreProvider

   override def getReadStore(startVersion: Long, endVersion: Long): StateStore = {
     try {
-      if (startVersion < 0) {
+      if (startVersion < 1) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
       }
       if (endVersion < startVersion) {
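A minimal illustration of the tightened precondition (hypothetical call sites against an initialized provider, not taken from the commit):

    provider.getReadStore(1, 3)  // accepted: startVersion >= 1 and endVersion >= startVersion
    provider.getReadStore(0, 3)  // now rejected via QueryExecutionErrors.unexpectedStateStoreVersion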

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala

Lines changed: 97 additions & 92 deletions
@@ -20,7 +20,6 @@ import java.io.{File, FileWriter}

 import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions
-
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
@@ -380,10 +379,28 @@ class HDFSBackedStateDataSourceReadSuite
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
       classOf[HDFSBackedStateStoreProvider].getName)
+    // make sure we have a snapshot for every two delta files
+    spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 1)
   }

   override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
     new HDFSBackedStateStoreProvider
+
+  test("ERROR: snapshot partition not found") {
+    testPartitionNotFound()
+  }
+
+  test("provider.getReadStore(startVersion, endVersion)") {
+    testGetReadStoreWithStart()
+  }
+
+  test("option snapshotPartitionId") {
+    testSnapshotPartitionId()
+  }
+
+  test("snapshotStartBatchId and snapshotPartitionId end to end") {
+    testSnapshotEndToEnd()
+  }
 }

 class RocksDBStateDataSourceReadSuite
@@ -408,43 +425,66 @@ class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite
       classOf[RocksDBStateStoreProvider].getName)
     spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
       "true")
+    // make sure we have a snapshot for every other checkpoint
+    spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 2)
   }

   override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
     new RocksDBStateStoreProvider
+
+  test("ERROR: snapshot partition not found") {
+    testPartitionNotFound()
+  }
+
+  test("provider.getReadStore(startVersion, endVersion)") {
+    testGetReadStoreWithStart()
+  }
+
+  test("option snapshotPartitionId") {
+    testSnapshotPartitionId()
+  }
 }

 abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
   extends StateDataSourceTestBase with Assertions {

+  import testImplicits._
   import StateStoreTestsHelper._

   protected val keySchema: StructType = StateStoreTestsHelper.keySchema
   protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema

   protected def newStateStoreProvider(): storeProvider

-  protected def getNewStateStoreProvider(checkpointDir: String): storeProvider = {
-    val minDeltasForSnapshot = 1 // overwrites the default 10
-    val numOfVersToRetainInMemory = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
-    val sqlConf = new SQLConf()
-    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
-    sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    sqlConf.setConf(SQLConf.STATE_STORE_COMPRESSION_CODEC, SQLConf.get.stateStoreCompressionCodec)
+  def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
+    store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
+  }

+  def get(store: ReadStateStore, key1: String, key2: Int): Option[Int] = {
+    Option(store.get(dataToKeyRow(key1, key2))).map(valueRowToData)
+  }
+
+  /**
+   * Calls the overridable [[newStateStoreProvider]] to create the state store provider instance.
+   * Initialize it with default settings except for STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.
+   *
+   * @param checkpointDir path to store state information
+   * @param minDeltasForSnapshot one snapshot for minDeltasForSnapshot+1 delta files
+   * @return
+   */
+  private def getNewStateStoreProvider(checkpointDir: String): storeProvider = {
     val provider = newStateStoreProvider()
     provider.init(
       StateStoreId(checkpointDir, 0, 0),
       keySchema,
       valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema),
       useColumnFamilies = false,
-      StateStoreConf(sqlConf),
+      StateStoreConf(spark.sessionState.conf),
       new Configuration)
     provider
   }
-
+
   test("simple aggregation, state ver 1") {
     testStreamingAggregation(1)
   }
@@ -912,97 +952,62 @@ abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
     }
   }

-  def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
-    store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
-  }
+  protected def testPartitionNotFound(): Unit = {
+    withTempDir(tempDir => {
+      val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
+      for (i <- 1 to 4) {
+        val store = provider.getStore(i - 1)
+        put(store, "a", i, i)
+        store.commit()
+        provider.doMaintenance() // create a snapshot every other delta file
+      }

-  test("ERROR: snapshot partition not found") {
-    withTempDir(tempDir1 => {
-      val tempDir = new java.io.File("/tmp/state/test/")
       val exc = intercept[SparkException] {
-        val provider = getNewStateStoreProvider(tempDir.getAbsolutePath + "/state/")
-        // val checker = new StateSchemaCompatibilityChecker(
-        //   new StateStoreProviderId(provider.stateStoreId, UUID.randomUUID()), new Configuration())
-        // checker.createSchemaFile(keySchema, valueSchema)
-        for (i <- 1 to 4) {
-          val store = provider.getStore(i - 1)
-          put(store, "a", 0, i)
-          store.commit()
-          provider.doMaintenance() // do cleanup
-        }
-        // val stateStore = provider.getStore(0)
-
-        // put(stateStore, "a", 1, 1)
-        // put(stateStore, "b", 2, 2)
-        // println(stateStore.hasCommitted)
-        // println(stateStore.getClass.toString)
-
-        // stateStore.commit()
-        provider.close()
-
-        // println(stateStore.hasCommitted)
+        provider.getReadStore(1, 2)
+      }
+      checkError(exc, "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED")
+    })
+  }

-        val df = spark.read.format("statestore")
-          .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
-          .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)
-          .option(StateSourceOptions.BATCH_ID, 0)
-          .load(tempDir.getAbsolutePath)
+  protected def testGetReadStoreWithStart(): Unit = {
+    withTempDir(tempDir => {
+      val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
+      for (i <- 1 to 4) {
+        val store = provider.getStore(i - 1)
+        put(store, "a", i, i)
+        store.commit()
+        provider.doMaintenance()
+      }

-        println(df.rdd.getNumPartitions)
+      val result = provider.getReadStore(2, 3)

+      assert(get(result, "a", 1).get == 1)
+      assert(get(result, "a", 2).get == 2)
+      assert(get(result, "a", 3).get == 3)
+      assert(get(result, "a", 4).isEmpty)

-        val result = provider.getReadStore(0, 1)
+      provider.close()
+    })
+  }

+  protected def testSnapshotPartitionId(): Unit = {
+    withTempDir(tempDir => {
+      val inputData = MemoryStream[Int]
+      val df = inputData.toDF().limit(10)

-      }
-      assert(exc.getCause.getMessage.contains(
-        "CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS"))
-    })
+      testStream(df)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 1, 2, 3, 4),
+        CheckLastBatch(1, 2, 3, 4)
+      )

-    val exc = intercept[SparkException] {
-      val checkpointPath = this.getClass.getResource(
-        "/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
-      spark.read.format("statestore")
+      val stateDf = spark.read.format("statestore")
         .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
         .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)
-        .load(checkpointPath).show()
-    }
-    assert(exc.getCause.getMessage.contains(
-      "CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS"))
-  }
-
-  test("reconstruct state from specific snapshot and partition") {
-    val checkpointPath = this.getClass.getResource(
-      "/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
-    val stateFromBatch11 = spark.read.format("statestore")
-      .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 11)
-      .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
-      .load(checkpointPath)
-    val stateFromBatch23 = spark.read.format("statestore")
-      .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 23)
-      .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
-      .load(checkpointPath)
-    val stateFromLatestBatch = spark.read.format("statestore").load(checkpointPath)
-    val stateFromLatestBatchPartition1 = stateFromLatestBatch.filter(
-      stateFromLatestBatch("partition_id") === 1)
-
-    checkAnswer(stateFromBatch23, stateFromLatestBatchPartition1)
-    checkAnswer(stateFromBatch11, stateFromLatestBatchPartition1)
-  }
-
-  test("use snapshotStartBatchId together with batchId") {
-    val checkpointPath = this.getClass.getResource(
-      "/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
-    val stateFromBatch11 = spark.read.format("statestore")
-      .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 11)
-      .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
-      .option(StateSourceOptions.BATCH_ID, 20)
-      .load(checkpointPath)
-    val stateFromLatestBatch = spark.read.format("statestore")
-      .option(StateSourceOptions.BATCH_ID, 20).load(checkpointPath)
-    val stateFromLatestBatchPartition1 = stateFromLatestBatch.filter(
-      stateFromLatestBatch("partition_id") === 1)
-
-    checkAnswer(stateFromBatch11, stateFromLatestBatchPartition1)
+        .option(StateSourceOptions.BATCH_ID, 0)
+        .load(tempDir.getAbsolutePath)
+
+      assert(stateDf.rdd.getNumPartitions == 1)
+    })
   }
 }
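For reference, the snapshot-based read these tests drive through the state data source looks roughly like this (a hedged sketch; the checkpoint path and batch ids are placeholders, while the format name and StateSourceOptions keys are the ones used in the suite above):

    val stateDf = spark.read.format("statestore")
      .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)  // batch whose snapshot is the starting point
      .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)    // single state partition to reconstruct
      .option(StateSourceOptions.BATCH_ID, 0)                 // batch to replay up to
      .load("/path/to/checkpoint")
    stateDf.show()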
