From 1ade4424f93aa88fe68510a0e134d6b71a632f44 Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 2 Jul 2024 14:05:31 -0700
Subject: [PATCH 01/18] Squashed commit of the following:

commit 261c671072f7fa13f46f3db5478bf7141a3249dc
Author: Yuchen Liu
Date:   Tue Jul 2 13:57:57 2024 -0700

    solve conflict

commit 39d0b1714731d7553988d5bf1e93906aac6b7fe1
Merge: 9af25f17f0e c2d59b09559
Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com>
Date:   Tue Jul 2 13:45:12 2024 -0700

    rebase to master

commit c2d59b09559f0aa89bc3339c5b13aca81cd160b7
Merge: 9cf8b252a4f 9af25f17f0e
Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com>
Date:   Tue Jul 2 13:44:50 2024 -0700

    Merge branch 'skipSnapshotAtBatch' into state-cdc

commit 9af25f17f0ed4ab9158add6fff0f69980bae0c2b
Merge: 8fa9ef584e9 fea930ab8d6
Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com>
Date:   Tue Jul 2 13:23:25 2024 -0700

    Merge branch 'apache:master' into skipSnapshotAtBatch

commit fea930ab8d6e524dd0125cd76596d38659a58050
Author: Anish Shrigondekar
Date:   Wed Jul 3 05:21:50 2024 +0900

    [SPARK-48770][SS] Change to read operator metadata once on driver to check if we can find info for numColsPrefixKey used for session window agg queries

    ### What changes were proposed in this pull request?

    Change to read operator metadata once on driver to check if we can find info for numColsPrefixKey used for session window agg queries

    ### Why are the changes needed?

    Avoid reading the operator metadata file multiple times on the executors

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Existing unit tests

    ```
    ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.datasources.v2.state.RocksDBStateDataSourceReadSuite, threads: ForkJoinPool.commonPool-worker-6 (daemon=true), ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPool-worker-8 (daemon=true), shuffle-boss-6-1 (daemon=tru...
    [info] Run completed in 1 minute, 39 seconds.
    [info] Total number of tests run: 14
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 14, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #47167 from anishshri-db/task/SPARK-48770.
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim commit 8fa9ef584e9dc96e5eba8e8218b5f6b50b81e85b Merge: 9dbe2959ee1 ee0d30686c4 Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com> Date: Tue Jul 2 13:21:01 2024 -0700 Merge branch 'apache:master' into skipSnapshotAtBatch commit 9cf8b252a4fbbd56e372c30642bf60c58c353359 Author: Yuchen Liu Date: Tue Jul 2 10:53:53 2024 -0700 add input error tests commit 7354408fa7a6b32419972566ce7701fdc47866c8 Merge: 6d6d511e5d5 9dbe2959ee1 Author: Yuchen Liu Date: Tue Jul 2 10:17:34 2024 -0700 Merge branch 'skipSnapshotAtBatch' into state-cdc commit 9dbe2959ee1c046b0502d579a090983ce7242ebc Author: Yuchen Liu Date: Mon Jul 1 21:54:33 2024 -0700 minor commit 6d6d511e5d555e77002978857f1c6ff9b14812c9 Author: Yuchen Liu Date: Mon Jul 1 15:53:04 2024 -0700 move StateStoreChangeDataReader to other files and delete it commit 104ba9c4dd01aca924c5ed9e83fa0f4bb3d854a5 Author: Yuchen Liu Date: Mon Jul 1 15:36:08 2024 -0700 rename PUT to update commit 12298b24cf7ac930d3fd5ba92e4a0da4f267974a Author: Yuchen Liu Date: Mon Jul 1 13:09:02 2024 -0700 minor commit 75839acc59c7f09ad5207bded78c0c19edbc9388 Author: Yuchen Liu Date: Mon Jul 1 13:03:59 2024 -0700 name all cdc to changeData commit ace711ce85da7ed14deebdc3ad457fc742c74c2e Author: Yuchen Liu Date: Mon Jul 1 12:49:07 2024 -0700 check validity of input to options commit 3834cc9a60141f925650aa4ebd700079aae074f6 Author: Yuchen Liu Date: Fri Jun 28 17:51:16 2024 -0700 solve format issue commit 337785dac6e06a9b75751ab370df989a135d6968 Author: Yuchen Liu Date: Fri Jun 28 17:07:18 2024 -0700 address comments from Anish commit 15a8316c991c0afcb1ffbf73271c2fef1146d2f4 Author: Yuchen Liu Date: Fri Jun 28 16:46:57 2024 -0700 refactor StateStoreChangeDataReader commit b1eb8c4c34a92504397589e0163a359d58398092 Author: Yuchen Liu Date: Fri Jun 28 15:03:09 2024 -0700 add integration tests to the new features commit 7c6cdadb21c4d029b1c00d8920811fd39bfbd108 Author: Yuchen Liu Date: Thu Jun 27 16:35:46 2024 -0700 unify the two traits commit cd6a39bf7b611d1e81dd413255c319521f05ef9f Merge: 271b98e71c5 d14070806ff Author: Yuchen Liu Date: Thu Jun 27 16:22:45 2024 -0700 Merge branch 'skipSnapshotAtBatch' into state-cdc commit d14070806ff7769c3fa362033e63931ab2fc58ea Author: Yuchen Liu Date: Thu Jun 27 15:17:06 2024 -0700 provide the script to regenerate golden files commit 4deb63eb972cfe37a0086abd2fdaf0f145fecf1e Author: Yuchen Liu Date: Thu Jun 27 14:22:00 2024 -0700 throw the exception commit 6f1425d1ae577c620c6d5ff3e3baa360e000fff7 Author: Yuchen Liu Date: Thu Jun 27 12:09:54 2024 -0700 reflect more comments from Jungtaek commit 42d952f2954f4cce19c85d192cdd25f1fbeb4a45 Author: Yuchen Liu Date: Thu Jun 27 11:11:33 2024 -0700 rename SupportsFineGrainedReplayFromSnapshot to SupportsFineGrainedReplay commit e15213eb64cb2be57adcbfe7c0cf12b404a58b14 Author: Yuchen Liu Date: Thu Jun 27 11:05:50 2024 -0700 rename to startVersion to snapshotVersion to make its function clear commit 271b98e71c5eef687243dc34c6f003f0858a14af Author: Yuchen Liu Date: Wed Jun 26 15:46:33 2024 -0700 make sure StateStoreChangeData is used everywhere commit ff5bff2e18c97f48047edb81ca38f7c883c7541d Merge: 6922595654f 40b6dc61c0f Author: Yuchen Liu Date: Wed Jun 26 15:22:19 2024 -0700 Merge branch 'skipSnapshotAtBatch' into state-cdc commit 40b6dc61c0fce4c8b4c6099d3983d0157690ddf6 Author: Yuchen Liu Date: Wed Jun 26 10:59:17 2024 -0700 move error to StateStoreErrors commit 23639f46d5686dc91b6941cb2bdd33616c79da5a Author: Yuchen Liu Date: Wed 
Jun 26 10:44:22 2024 -0700 create new error for SupportsFineGrainedReplayFromSnapshot commit 97ee3efb39f1a5a7c9405adefed14c5b59c66345 Author: Yuchen Liu Date: Wed Jun 26 10:25:57 2024 -0700 some naming and formatting comments from Anish and Jungtaek commit 1a23abbef8983ac06c5511d8ea0946b9e8db0f65 Author: Yuchen Liu Date: Tue Jun 25 14:56:07 2024 -0700 refactor the code to isolate from current state stores used by streaming queries commit 876256e10a9191a3170b0747260ecde87aa16e49 Author: Yuchen Liu Date: Tue Jun 25 12:29:40 2024 -0700 reflect comments from Jungtaek commit ef9b095e1c8f75d3462356dde5cf790a2583f608 Author: Yuchen Liu Date: Tue Jun 25 12:08:34 2024 -0700 create integration test against golden files commit 6922595654fa960e9dbf3648cbc6df79a57a53a7 Author: Yuchen Liu Date: Mon Jun 24 13:44:19 2024 -0700 stage commit 3ece6f2ea1b693272458ca5d83b72357940b5678 Author: Yuchen Liu Date: Fri Jun 21 21:22:50 2024 -0700 resort error-conditions commit be30817178691b70d5e657d6ca4a777fbe062381 Author: Yuchen Liu Date: Fri Jun 21 17:30:12 2024 -0700 Reflect more comments from Anish commit cf84d50a5563269440791d5617b2bab568240a6d Author: Yuchen Liu Date: Fri Jun 21 14:02:58 2024 -0700 support hdfs state store provider commit 752cdc7b8630836dd96d645c8424c27b79de3e4e Author: Yuchen Liu Date: Thu Jun 20 17:51:33 2024 -0700 separate CDCPartitionReader from StatePartitionReader commit bd870552e3e746273bf9f86fbfcd87037bec03ab Merge: 2184396eba4 2eb66468acb Author: Yuchen Liu Date: Thu Jun 20 17:29:31 2024 -0700 Merge branch 'skipSnapshotAtBatch' into state-cdc commit 2eb66468acbf73f3b6e4980aaaa6380659415ae2 Author: Yuchen Liu Date: Thu Jun 20 17:10:45 2024 -0700 also update the name of StateTable commit 2184396eba486f7416035218ee33353ec7ca4b05 Author: Yuchen Liu Date: Thu Jun 20 17:03:18 2024 -0700 hdfs initial implementation commit 3f266c193240243101e5d7cbaf5384d6b904a37f Author: Yuchen Liu Date: Mon Jun 17 09:46:07 2024 -0700 style commit fe9cea16e4836573106bc378d5294058bf39880c Author: Yuchen Liu Date: Fri Jun 14 12:50:21 2024 -0700 address more comments from Anish commit 1870b354a4adee9cc47ac318f1e9459dd37f3abb Merge: 4d4cd70b617 9eb6c76e86b Author: Yuchen Liu Date: Thu Jun 13 14:25:23 2024 -0700 Merge branch 'skipSnapshotAtBatch' of https://github.com/eason-yuchen-liu/spark into skipSnapshotAtBatch commit 4d4cd70b6173d98cf2e781f9f565e2e4e8e13952 Author: Yuchen Liu Date: Thu Jun 13 14:24:55 2024 -0700 log StateSourceOptions optionally commit 9eb6c76e86b39fa27a46e63ab409fcdb7f7b369b Merge: 20e1b9c605e 08e741b92b8 Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com> Date: Thu Jun 13 14:18:50 2024 -0700 Merge branch 'master' into skipSnapshotAtBatch commit 20e1b9c605e8e3360fb1a7e3ebdbf1e1291a04d2 Author: Yuchen Liu Date: Thu Jun 13 14:16:14 2024 -0700 address comments from Anish & Wei commit 4825215a1e80028f7e2deccf502ecabae15e8766 Author: Yuchen Liu Date: Thu Jun 13 11:45:55 2024 -0700 address reviews by Wei partially commit 5229152751c9469d49f9abd7b2b42d01bb2079e0 Author: Yuchen Liu Date: Wed Jun 12 11:29:46 2024 -0700 support reading join states commit 61dea356fe8b77a691f2e5f26babf6837144215e Author: Yuchen Liu Date: Tue Jun 11 13:16:56 2024 -0700 minor commit 1656580dcff44c66e1874e7c411def4d41c390d6 Author: Yuchen Liu Date: Tue Jun 11 12:07:06 2024 -0700 improve doc commit 4ebd078e6ebcf41b15a5efdd1a2a1a10c2ea00ab Author: Yuchen Liu Date: Tue Jun 11 11:48:30 2024 -0700 move partition error commit dfa712e3b3ecf0c313a8e8dd7574bc8f69073fad Author: Yuchen Liu Date: Tue Jun 
11 11:42:09 2024 -0700 clean up and format commit aa337c193628a14c21d7f67c905cc7ceab101de3 Author: Yuchen Liu Date: Tue Jun 11 10:22:59 2024 -0700 add new test on partition not found error commit 292ec5d34a7f40a7d12c420cc3e085113b00a086 Author: Yuchen Liu Date: Mon Jun 10 16:54:38 2024 -0700 delete useless test files commit 1a3d20aa36c3d97fe48eaa55c13e623305252fbd Author: Yuchen Liu Date: Mon Jun 10 16:52:22 2024 -0700 make sure test is stable commit eddb3c7c8479a5f236a93ca8202d8ab39886d200 Merge: 9d902d76327 5a2f374a208 Author: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com> Date: Mon Jun 10 11:43:03 2024 -0700 Merge branch 'apache:master' into skipSnapshotAtBatch commit 9d902d763279f037035a03a8a58e167cbd445be7 Author: Yuchen Liu Date: Mon Jun 10 11:13:02 2024 -0700 test directly on the method instead of end to end commit 07267b599370637bff0b745198af2e1d34e9eee9 Author: Yuchen Liu Date: Fri Jun 7 16:43:45 2024 -0700 allow rocksdb to reconstruct state from a specific checkpoint commit 24751730288c9aa83c3d5777c46cc0423024d899 Author: Yuchen Liu Date: Thu Jun 6 10:32:56 2024 -0700 add test cases for two options in HDFS state store commit 7dad0c19755d6f2ea8b3c4baabf8dbd7b462fb67 Merge: 6db0e3dd38b 8a0927c07a1 Author: Yuchen Liu Date: Tue Jun 4 15:30:20 2024 -0700 Merge branch 'skipSnapshotAtBatch' of https://github.com/eason-yuchen-liu/spark into skipSnapshotAtBatch commit 6db0e3dd38b089efe691fe4f4c33880ea2995580 Author: Yuchen Liu Date: Tue Jun 4 15:28:49 2024 -0700 initial implementation --- .../v2/state/StateDataSource.scala | 92 ++++- .../v2/state/StatePartitionReader.scala | 75 +++- .../v2/state/StateScanBuilder.scala | 16 +- .../datasources/v2/state/StateTable.scala | 30 +- .../execution/streaming/HDFSMetadataLog.scala | 2 + .../state/HDFSBackedStateStoreProvider.scala | 43 +++ .../state/RocksDBStateStoreProvider.scala | 48 +++ .../streaming/state/StateStore.scala | 15 +- .../streaming/state/StateStoreChangelog.scala | 90 +++++ .../state/SymmetricHashJoinStateManager.scala | 19 +- .../StateDataSourceChangeDataReadSuite.scala | 324 ++++++++++++++++++ .../v2/state/StateDataSourceTestBase.scala | 2 +- 12 files changed, 715 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index e2724cb59754d..2e9f300eae7c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -30,13 +30,15 @@ import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration /** * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source. @@ -46,6 +48,8 @@ class StateDataSource extends TableProvider with DataSourceRegister { private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf() + private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf) + override def shortName(): String = "statestore" override def getTable( @@ -54,7 +58,17 @@ class StateDataSource extends TableProvider with DataSourceRegister { properties: util.Map[String, String]): Table = { val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - new StateTable(session, schema, sourceOptions, stateConf) + // Read the operator metadata once to see if we can find the information for prefix scan + // encoder used in session window aggregation queries. + val allStateStoreMetadata = new StateMetadataPartitionReader( + sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf) + .stateMetadata.toArray + val stateStoreMetadata = allStateStoreMetadata.filter { entry => + entry.operatorId == sourceOptions.operatorId && + entry.stateStoreName == sourceOptions.storeName + } + + new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { @@ -80,10 +94,21 @@ class StateDataSource extends TableProvider with DataSourceRegister { manager.readSchemaFile() } - new StructType() - .add("key", keySchema) - .add("value", valueSchema) - .add("partition_id", IntegerType) + if (sourceOptions.readChangeFeed) { + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + .add("change_type", StringType) + .add("batch_id", LongType) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + .add("partition_id", IntegerType) + } + + } catch { case NonFatal(e) => throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e) @@ -118,7 +143,10 @@ case class StateSourceOptions( storeName: String, joinSide: JoinSideValues, snapshotStartBatchId: Option[Long], - snapshotPartitionId: Option[Int]) { + snapshotPartitionId: Option[Int], + readChangeFeed: Boolean, + changeStartBatchId: Option[Long], + changeEndBatchId: Option[Long]) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { @@ -137,6 +165,9 @@ object StateSourceOptions extends DataSourceOptions { val JOIN_SIDE = newOption("joinSide") val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId") val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId") + val READ_CHANGE_FEED = newOption("readChangeFeed") + val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") + val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") object JoinSideValues extends Enumeration { type 
JoinSideValues = Value @@ -217,9 +248,45 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID) } + val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) + + val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) + var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) + + if (readChangeFeed) { + if (joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED)) + } + if (changeStartBatchId.isEmpty) { + throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID) + } + changeEndBatchId = Option(changeEndBatchId.getOrElse(batchId)) + + // changeStartBatchId and changeEndBatchId must all be defined at this point + if (changeStartBatchId.get < 0) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(CHANGE_START_BATCH_ID) + } + if (changeEndBatchId.get < changeStartBatchId.get) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID, + s"$CHANGE_END_BATCH_ID cannot be smaller than $CHANGE_START_BATCH_ID. " + + s"Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " + + s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.") + } + } else { + if (changeStartBatchId.isDefined) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_START_BATCH_ID, + s"Only specify this option when $READ_CHANGE_FEED is set to true.") + } + if (changeEndBatchId.isDefined) { + throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID, + s"Only specify this option when $READ_CHANGE_FEED is set to true.") + } + } + StateSourceOptions( resolvedCpLocation, batchId, operatorId, storeName, - joinSide, snapshotStartBatchId, snapshotPartitionId) + joinSide, snapshotStartBatchId, snapshotPartitionId, + readChangeFeed, changeStartBatchId, changeEndBatchId) } private def resolvedCheckpointLocation( @@ -238,4 +305,13 @@ object StateSourceOptions extends DataSourceOptions { case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) } } + + private def getFirstCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { + val commitLog = new CommitLog(session, + new Path(checkpointLocation, DIR_NAME_COMMITS).toString) + commitLog.getEarliestBatchId() match { + case Some(firstId) => firstId + case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index f09a2763031e0..663d06f53b176 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -20,10 +20,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import 
org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration /** @@ -33,11 +35,18 @@ import org.apache.spark.util.SerializableConfiguration class StatePartitionReaderFactory( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, - schema: StructType) extends PartitionReaderFactory { + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - new StatePartitionReader(storeConf, hadoopConf, - partition.asInstanceOf[StateStoreInputPartition], schema) + val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] + if (stateStoreInputPartition.sourceOptions.readChangeFeed) { + new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, + partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata) + } else { + new StatePartitionReader(storeConf, hadoopConf, + partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata) + } } } @@ -49,22 +58,17 @@ class StatePartitionReader( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, - schema: StructType) extends PartitionReader[InternalRow] with Logging { + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) + extends PartitionReader[InternalRow] with Logging { private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] - private lazy val provider: StateStoreProvider = { + protected lazy val provider: StateStoreProvider = { val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - val allStateStoreMetadata = new StateMetadataPartitionReader( - partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf) - .stateMetadata.toArray - val stateStoreMetadata = allStateStoreMetadata.filter { entry => - entry.operatorId == partition.sourceOptions.operatorId && - entry.stateStoreName == partition.sourceOptions.storeName - } val numColsPrefixKey = if (stateStoreMetadata.isEmpty) { logWarning("Metadata for state store not found, possible cause is this checkpoint " + "is created by older version of spark. 
If the query has session window aggregation, " + @@ -108,11 +112,11 @@ class StatePartitionReader( } } - private lazy val iter: Iterator[InternalRow] = { + protected lazy val iter: Iterator[InternalRow] = { store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) } - private var current: InternalRow = _ + protected var current: InternalRow = _ override def next(): Boolean = { if (iter.hasNext) { @@ -140,3 +144,44 @@ class StatePartitionReader( row } } + +class StateStoreChangeDataPartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) + extends StatePartitionReader(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + + private lazy val changeDataReader: StateStoreChangeDataReader = { + if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { + throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( + provider.getClass.toString) + } + provider.asInstanceOf[SupportsFineGrainedReplay] + .getStateStoreChangeDataReader( + partition.sourceOptions.changeStartBatchId.get + 1, + partition.sourceOptions.changeEndBatchId.get + 1) + } + + override protected lazy val iter: Iterator[InternalRow] = { + changeDataReader.iterator.map(unifyStateChangeDataRow) + } + + override def close(): Unit = { + current = null + changeDataReader.close() + provider.close() + } + + private def unifyStateChangeDataRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)): + InternalRow = { + val result = new GenericInternalRow(5) + result.update(0, row._2) + result.update(1, row._3) + result.update(2, UTF8String.fromString(getRecordTypeAsString(row._1))) + result.update(3, row._4) + result.update(4, partition.partition) + result + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index ffcbcd0872e10..821a36977fed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors} import org.apache.spark.sql.types.StructType @@ -35,8 +36,10 @@ class StateScanBuilder( session: SparkSession, schema: StructType, sourceOptions: StateSourceOptions, - stateStoreConf: StateStoreConf) extends ScanBuilder { - override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf) + stateStoreConf: StateStoreConf, + stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder { + override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf, + stateStoreMetadata) } /** An implementation of [[InputPartition]] for State Store data source. 
*/ @@ -50,7 +53,8 @@ class StateScan( session: SparkSession, schema: StructType, sourceOptions: StateSourceOptions, - stateStoreConf: StateStoreConf) extends Scan with Batch { + stateStoreConf: StateStoreConf, + stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val hadoopConfBroadcast = session.sparkContext.broadcast( @@ -62,7 +66,8 @@ class StateScan( val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value) val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() { override def accept(path: Path): Boolean = { - fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 + fs.getFileStatus(path).isDirectory && + Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 } }) @@ -116,7 +121,8 @@ class StateScan( hadoopConfBroadcast.value, userFacingSchema, stateSchema) case JoinSideValues.none => - new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema) + new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema, + stateStoreMetadata) } override def toBatch: Batch = this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index dbd39f519e500..0750d7549ec30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -24,9 +24,10 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.state.StateStoreConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -35,7 +36,8 @@ class StateTable( session: SparkSession, override val schema: StructType, sourceOptions: StateSourceOptions, - stateConf: StateStoreConf) + stateConf: StateStoreConf, + stateStoreMetadata: Array[StateMetadataTableEntry]) extends Table with SupportsRead with SupportsMetadataColumns { import StateTable._ @@ -69,11 +71,14 @@ class StateTable( override def capabilities(): util.Set[TableCapability] = CAPABILITY override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = - new StateScanBuilder(session, schema, sourceOptions, stateConf) + new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata) override def properties(): util.Map[String, String] = Map.empty[String, String].asJava private def isValidSchema(schema: StructType): Boolean = { + if (sourceOptions.readChangeFeed) { + return isValidChangeDataSchema(schema) + } if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { false } else if (!SchemaUtil.getSchemaAsDataType(schema, 
"key").isInstanceOf[StructType]) { @@ -87,6 +92,25 @@ class StateTable( } } + private def isValidChangeDataSchema(schema: StructType): Boolean = { + if (schema.fieldNames.toImmutableArraySeq != + Seq("key", "value", "change_type", "batch_id", "partition_id")) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { + false + } else { + true + } + } + override def metadataColumns(): Array[MetadataColumn] = Array.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 251cc16acdf43..2ae838581f6f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -264,6 +264,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** Return the latest batch id without reading the file. */ def getLatestBatchId(): Option[Long] = listBatches.sorted.lastOption + def getEarliestBatchId(): Option[Long] = listBatches.sorted.headOption + override def getLatest(): Option[(Long, T)] = { listBatches.sorted.lastOption.map { batchId => logInfo(log"Getting latest batch ${MDC(BATCH_ID, batchId)}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index c4a41ceb4caf4..869c4f3c2e5e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -978,4 +978,47 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with result } + + override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + StateStoreChangeDataReader = { + new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, endVersion, + CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), + keySchema, valueSchema) + } +} + +/** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */ +class HDFSBackedStateStoreChangeDataReader( + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec, + keySchema: StructType, + valueSchema: StructType) + extends StateStoreChangeDataReader( + fm, stateLocation, startVersion, endVersion, compressionCodec) { + + override protected var changelogSuffix: String = "delta" + + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + val (recordType, keyArray, valueArray, _) = reader.next() + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyArray, keyArray.length) + if (valueArray == null) { + (recordType, keyRow, null, currentChangelogVersion - 1) + } 
else { + val valueRow = new UnsafeRow(valueSchema.fields.length) + // If valueSize in existing file is not multiple of 8, floor it to multiple of 8. + // This is a workaround for the following: + // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in + // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data + valueRow.pointTo(valueArray, (valueArray.length / 8) * 8) + (recordType, keyRow, valueRow, currentChangelogVersion - 1) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index a555f9a40044a..3d2606f3ed0e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -18,16 +18,20 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ +import java.util.concurrent.ConcurrentHashMap import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -392,6 +396,19 @@ private[sql] class RocksDBStateStoreProvider case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e) } } + + override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + StateStoreChangeDataReader = { + val statePath = stateStoreId.storeCheckpointLocation() + val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + new RocksDBStateStoreChangeDataReader( + CheckpointFileManager.create(statePath, hadoopConf), + statePath, + startVersion, + endVersion, + CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), + keyValueEncoderMap) + } } object RocksDBStateStoreProvider { @@ -487,3 +504,34 @@ object RocksDBStateStoreProvider { CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES) } + +/** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */ +class RocksDBStateStoreChangeDataReader( + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec, + keyValueEncoderMap: + ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]) + extends StateStoreChangeDataReader( + fm, stateLocation, startVersion, endVersion, compressionCodec) { + + override protected var changelogSuffix: String = "changelog" + + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + val (recordType, keyArray, valueArray, columnFamily) = reader.next() + val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) = keyValueEncoderMap.get(columnFamily) + val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray) + if (valueArray == null) { + (recordType, keyRow, null, currentChangelogVersion - 1) + } else { + val valueRow = 
rocksDBValueStateEncoder.decodeValue(valueArray) + (recordType, keyRow, valueRow, currentChangelogVersion - 1) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 76fd36bd726a6..507facacb9cda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -439,9 +439,9 @@ object StateStoreProvider { } /** - * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read fine - * grained state data which is replayed from a specific snapshot version. It is used by the - * snapshotStartBatchId option in state data source. + * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change + * of state store over batches. This is used by State Data Source with additional options like + * snapshotStartBatchId or readChangeFeed. */ trait SupportsFineGrainedReplay { @@ -469,6 +469,15 @@ trait SupportsFineGrainedReplay { def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = { new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion)) } + + /** + * + * @param startVersion + * @param endVersion + * @return + */ + def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + StateStoreChangeDataReader } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index b1860be41ac44..62548ef0ec568 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.{FSError, Path} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream @@ -55,6 +56,15 @@ object RecordType extends Enumeration { } } + def getRecordTypeAsString(recordType: RecordType): String = { + recordType match { + case PUT_RECORD => "update" + case DELETE_RECORD => "delete" + case _ => throw StateStoreErrors.unsupportedOperationException( + "getRecordTypeAsString", recordType.toString) + } + } + // Generate record type from byte representation def getRecordTypeFromByte(byte: Byte): RecordType = { byte match { @@ -390,3 +400,83 @@ class StateStoreChangelogReaderV2( } } } + +/** + * Base class representing a iterator that iterates over a range of changelog files in a state + * store. 
In each iteration, it will return a tuple of (changeType: [[RecordType]],
+ * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]])
+ *
+ * @param fm checkpoint file manager used to manage streaming query checkpoint
+ * @param stateLocation location of the state store
+ * @param startVersion start version of the changelog file to read
+ * @param endVersion end version of the changelog file to read
+ * @param compressionCodec decompression codec used for reading the changelog file
+ */
+abstract class StateStoreChangeDataReader(
+    fm: CheckpointFileManager,
+    stateLocation: Path,
+    startVersion: Long,
+    endVersion: Long,
+    compressionCodec: CompressionCodec)
+  extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging {
+
+  assert(startVersion >= 1)
+  assert(endVersion >= startVersion)
+
+  /**
+   * Iterator that iterates over the changelog files in the state store.
+   */
+  private class ChangeLogFileIterator extends Iterator[Path] {
+
+    private var currentVersion = StateStoreChangeDataReader.this.startVersion - 1
+
+    /** returns the version of the changelog returned by the latest [[next]] function call */
+    def getVersion: Long = currentVersion
+
+    override def hasNext: Boolean = currentVersion < StateStoreChangeDataReader.this.endVersion
+
+    override def next(): Path = {
+      currentVersion += 1
+      getChangelogPath(currentVersion)
+    }
+
+    private def getChangelogPath(version: Long): Path =
+      new Path(
+        StateStoreChangeDataReader.this.stateLocation,
+        s"$version.${StateStoreChangeDataReader.this.changelogSuffix}")
+  }
+
+  /** file format of the changelog files */
+  protected var changelogSuffix: String
+  private lazy val fileIterator = new ChangeLogFileIterator
+  private var changelogReader: StateStoreChangelogReader = null
+
+  /**
+   * Get a changelog reader that has at least one record left to read. If there are no readers
+   * left, return null.
+ */ + protected def currentChangelogReader(): StateStoreChangelogReader = { + while (changelogReader == null || !changelogReader.hasNext) { + if (changelogReader != null) { + changelogReader.close() + } + if (!fileIterator.hasNext) { + finished = true + return null + } + // Todo: Does not support StateStoreChangelogReaderV2 + changelogReader = + new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) + } + changelogReader + } + + /** get the version of the current changelog reader */ + protected def currentChangelogVersion: Long = fileIterator.getVersion + + override def close(): Unit = { + if (changelogReader != null) { + changelogReader.close() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 4de3170f5db33..3675325780269 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -444,8 +444,8 @@ class SymmetricHashJoinStateManager( private val keySchema = StructType( joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) private val keyAttributes = toAttributes(keySchema) - private val keyToNumValues = new KeyToNumValuesStore() - private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) + private lazy val keyToNumValues = new KeyToNumValuesStore() + private lazy val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) // Clean up any state store resources if necessary at the end of the task Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } @@ -476,6 +476,16 @@ class SymmetricHashJoinStateManager( def metrics: StateStoreMetrics = stateStore.metrics + private def initializeStateStoreProvider(keySchema: StructType, valueSchema: StructType): + Unit = { + val storeProviderId = StateStoreProviderId( + stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType)) + stateStoreProvider = StateStoreProvider.createAndInit( + storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), + useColumnFamilies = false, storeConf, hadoopConf, + useMultipleValuesPerKey = false) + } + /** Get the StateStore with the given schema */ protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { val storeProviderId = StateStoreProviderId( @@ -488,10 +498,7 @@ class SymmetricHashJoinStateManager( stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf) } else { // This class will manage the state store provider by itself. 
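A note on the batch-id arithmetic above: the change feed path converts between batch ids and changelog file versions with a fixed offset, since committing batch N produces state store version N + 1. StateStoreChangeDataPartitionReader therefore requests versions [changeStartBatchId + 1, changeEndBatchId + 1], and the getNext() implementations report the batch id as currentChangelogVersion - 1. The standalone sketch below (not part of the diff; the object and helper names are hypothetical) spells out that mapping.

```scala
// Standalone sketch (hypothetical names): the +1/-1 offset between batch ids and
// changelog file versions used by the change feed readers in this patch.
object ChangeFeedVersionMappingSketch {
  // Mirrors the "+ 1" applied when StateStoreChangeDataPartitionReader builds the reader.
  def changelogVersionsFor(changeStartBatchId: Long, changeEndBatchId: Long): Seq[Long] =
    (changeStartBatchId + 1) to (changeEndBatchId + 1)

  // Mirrors the "- 1" applied by the getNext() implementations when emitting records.
  def batchIdOf(changelogVersion: Long): Long = changelogVersion - 1

  def main(args: Array[String]): Unit = {
    // Reading changes for batches 0..2 scans changelog versions 1..3,
    // i.e. files "1.delta".."3.delta" (HDFS) or "1.changelog".."3.changelog" (RocksDB).
    assert(changelogVersionsFor(0, 2) == Seq(1L, 2L, 3L))
    // A record read from changelog version 3 is reported with batch_id 2.
    assert(batchIdOf(3) == 2L)
  }
}
```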
- stateStoreProvider = StateStoreProvider.createAndInit( - storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - useColumnFamilies = false, storeConf, hadoopConf, - useMultipleValuesPerKey = false) + initializeStateStoreProvider(keySchema, valueSchema) if (snapshotStartVersion.isDefined) { if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) { throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala new file mode 100644 index 0000000000000..da8b59ba23458 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.hadoop.conf.Configuration +import org.scalatest.Assertions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +class HDFSBackedStateDataSourceChangeDataReaderSuite extends StateDataSourceChangeDataReaderSuite { + override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider = + new HDFSBackedStateStoreProvider +} + +class RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite extends + StateDataSourceChangeDataReaderSuite { + override protected def newStateStoreProvider(): RocksDBStateStoreProvider = + new RocksDBStateStoreProvider + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", + "true") + } +} + +abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestBase + with Assertions { + + import testImplicits._ + import StateStoreTestsHelper._ + + protected val keySchema: StructType = StateStoreTestsHelper.keySchema + protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema + + protected def newStateStoreProvider(): StateStoreProvider + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED, false) + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, newStateStoreProvider().getClass.getName) + } + + /** + * Calls the overridable [[newStateStoreProvider]] to create the state store provider instance. 
+ * Initialize it with the configuration set by child classes. + * + * @param checkpointDir path to store state information + * @return instance of class extending [[StateStoreProvider]] + */ + private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = { + val provider = newStateStoreProvider() + provider.init( + StateStoreId(checkpointDir, 0, 0), + keySchema, + valueSchema, + NoPrefixKeyStateEncoderSpec(keySchema), + useColumnFamilies = false, + StateStoreConf(spark.sessionState.conf), + new Configuration) + provider + } + + test("ERROR: specify changeStartBatchId in normal mode") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + } + } + + test("ERROR: changeStartBatchId is set to negative") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValueIsNegative] { + spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, -1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") + } + } + + test("ERROR: changeEndBatchId is set to less than changeStartBatchId") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + } + } + + test("ERROR: joinSide option is used together with readChangeFeed") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceConflictOptions] { + spark.read.format("statestore") + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.JOIN_SIDE, "left") + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + assert(exc.getErrorClass === "STDS_CONFLICT_OPTIONS") + } + } + + test("getChangeDataReader of state store provider") { + def withNewStateStore(provider: StateStoreProvider, version: Int)(f: StateStore => Unit): + Unit = { + val stateStore = provider.getStore(version) + f(stateStore) + stateStore.commit() + } + + withTempDir { tempDir => + val provider = getNewStateStoreProvider(tempDir.getAbsolutePath) + withNewStateStore(provider, 0) { stateStore => + put(stateStore, "a", 1, 1) } + withNewStateStore(provider, 1) { stateStore => + put(stateStore, "b", 2, 2) } + withNewStateStore(provider, 2) { stateStore => + stateStore.remove(dataToKeyRow("a", 1)) } + withNewStateStore(provider, 3) { stateStore => + stateStore.remove(dataToKeyRow("b", 2)) } + + val reader = + provider.asInstanceOf[SupportsFineGrainedReplay].getStateStoreChangeDataReader(1, 4) + + assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("a", 1), dataToValueRow(1), 0L)) + assert(reader.next() === (RecordType.PUT_RECORD, 
dataToKeyRow("b", 2), dataToValueRow(2), 1L)) + assert(reader.next() === + (RecordType.DELETE_RECORD, dataToKeyRow("a", 1), null, 2L)) + assert(reader.next() === + (RecordType.DELETE_RECORD, dataToKeyRow("b", 2), null, 3L)) + } + } + + test("read limit state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().limit(10) + testStream(df)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 5, 6, 7, 8), + ProcessAllAvailable(), + AddData(inputData, 9, 10, 11, 12), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(Row(null), Row(4), "update", 0L, 0), + Row(Row(null), Row(8), "update", 1L, 0), + Row(Row(null), Row(10), "update", 2L, 0) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read aggregate state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().groupBy("value").count() + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(Row(3), Row(1), "update", 0L, 1), + Row(Row(3), Row(2), "update", 1L, 1), + Row(Row(5), Row(1), "update", 1L, 1), + Row(Row(3), Row(3), "update", 2L, 1), + Row(Row(5), Row(2), "update", 2L, 1), + Row(Row(4), Row(1), "update", 0L, 2), + Row(Row(4), Row(2), "update", 1L, 2), + Row(Row(4), Row(3), "update", 2L, 2), + Row(Row(1), Row(1), "update", 0L, 3), + Row(Row(2), Row(1), "update", 0L, 4), + Row(Row(2), Row(2), "update", 1L, 4), + Row(Row(6), Row(1), "update", 2L, 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read deduplication state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().dropDuplicates("value") + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4), + ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(Row(1), Row(null), "update", 0L, 3), + Row(Row(2), Row(null), "update", 0L, 4), + Row(Row(3), Row(null), "update", 0L, 1), + Row(Row(4), Row(null), "update", 0L, 2), + Row(Row(5), Row(null), "update", 1L, 1), + Row(Row(6), Row(null), "update", 2L, 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } + + test("read stream-stream join state change feed") { + withTempDir { tempDir => + val inputData = MemoryStream[(Int, Long)] + val leftDf = + 
inputData.toDF().select(col("_1").as("leftKey"), col("_2").as("leftValue")) + val rightDf = + inputData.toDF().select((col("_1") * 2).as("rightKey"), col("_2").as("rightValue")) + val df = leftDf.join(rightDf).where("leftKey == rightKey") + + testStream(df)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, (1, 1L), (2, 2L)), + ProcessAllAvailable(), + AddData(inputData, (3, 3L), (4, 4L)), + ProcessAllAvailable() + ) + + val keyWithIndexToValueDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-keyWithIndexToValue") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + val keyWithIndexToValueExpectedDf = Seq( + Row(Row(3, 0L), Row(3, 3L, false), "update", 1L, 1), + Row(Row(4, 0L), Row(4, 4L, true), "update", 1L, 2), + Row(Row(1, 0L), Row(1, 1L, false), "update", 0L, 3), + Row(Row(2, 0L), Row(2, 2L, false), "update", 0L, 4), + Row(Row(2, 0L), Row(2, 2L, true), "update", 0L, 4) + ) + + checkAnswer(keyWithIndexToValueDf, keyWithIndexToValueExpectedDf) + + val keyToNumValuesDf = spark.read.format("statestore") + .option(StateSourceOptions.STORE_NAME, "left-keyToNumValues") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load(tempDir.getAbsolutePath) + + val keyToNumValuesDfExpectedDf = Seq( + Row(Row(3), Row(1L), "update", 1L, 1), + Row(Row(4), Row(1L), "update", 1L, 2), + Row(Row(1), Row(1L), "update", 0L, 3), + Row(Row(2), Row(1L), "update", 0L, 4) + ) + + checkAnswer(keyToNumValuesDf, keyToNumValuesDfExpectedDf) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala index f5392cc823f78..705d9f125964f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala @@ -383,7 +383,7 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest { } } - private def getStreamStreamJoinQuery(inputStream: MemoryStream[(Int, Long)]): DataFrame = { + protected def getStreamStreamJoinQuery(inputStream: MemoryStream[(Int, Long)]): DataFrame = { val df = inputStream.toDS() .select(col("_1").as("value"), timestamp_seconds($"_2").as("timestamp")) From 98bf8ec30fe38842c90385bf73d7eb81aecd578c Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 2 Jul 2024 14:19:54 -0700 Subject: [PATCH 02/18] revert unnecessary changes --- .../v2/state/StateDataSource.scala | 9 --------- .../execution/streaming/HDFSMetadataLog.scala | 2 -- .../state/SymmetricHashJoinStateManager.scala | 19 ++++++------------- 3 files changed, 6 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 2e9f300eae7c2..1c0d2de400dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -305,13 +305,4 
@@ object StateSourceOptions extends DataSourceOptions { case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) } } - - private def getFirstCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { - val commitLog = new CommitLog(session, - new Path(checkpointLocation, DIR_NAME_COMMITS).toString) - commitLog.getEarliestBatchId() match { - case Some(firstId) => firstId - case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 2ae838581f6f7..251cc16acdf43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -264,8 +264,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: /** Return the latest batch id without reading the file. */ def getLatestBatchId(): Option[Long] = listBatches.sorted.lastOption - def getEarliestBatchId(): Option[Long] = listBatches.sorted.headOption - override def getLatest(): Option[(Long, T)] = { listBatches.sorted.lastOption.map { batchId => logInfo(log"Getting latest batch ${MDC(BATCH_ID, batchId)}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 3675325780269..4de3170f5db33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -444,8 +444,8 @@ class SymmetricHashJoinStateManager( private val keySchema = StructType( joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) private val keyAttributes = toAttributes(keySchema) - private lazy val keyToNumValues = new KeyToNumValuesStore() - private lazy val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) + private val keyToNumValues = new KeyToNumValuesStore() + private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) // Clean up any state store resources if necessary at the end of the task Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } @@ -476,16 +476,6 @@ class SymmetricHashJoinStateManager( def metrics: StateStoreMetrics = stateStore.metrics - private def initializeStateStoreProvider(keySchema: StructType, valueSchema: StructType): - Unit = { - val storeProviderId = StateStoreProviderId( - stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType)) - stateStoreProvider = StateStoreProvider.createAndInit( - storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), - useColumnFamilies = false, storeConf, hadoopConf, - useMultipleValuesPerKey = false) - } - /** Get the StateStore with the given schema */ protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { val storeProviderId = StateStoreProviderId( @@ -498,7 +488,10 @@ class SymmetricHashJoinStateManager( stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf) } else { // This class will manage the state store provider by 
itself. - initializeStateStoreProvider(keySchema, valueSchema) + stateStoreProvider = StateStoreProvider.createAndInit( + storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema), + useColumnFamilies = false, storeConf, hadoopConf, + useMultipleValuesPerKey = false) if (snapshotStartVersion.isDefined) { if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) { throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( From db45c6f5f55b9ce8c4fb2f6fd0ca31ca20bca475 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 2 Jul 2024 14:49:40 -0700 Subject: [PATCH 03/18] Add comments --- common/utils/src/main/resources/error/error-conditions.json | 2 +- .../execution/datasources/v2/state/StatePartitionReader.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 78269d29fa2cd..ef788125dd77b 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3812,7 +3812,7 @@ "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : { "message" : [ "The given State Store Provider does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.", - "Therefore, it does not support option snapshotStartBatchId in state data source." + "Therefore, it does not support option snapshotStartBatchId or readChangeFeed in state data source." ], "sqlState" : "42K06" }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 663d06f53b176..7fa52844ebbab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -145,6 +145,10 @@ class StatePartitionReader( } } +/** + * An implementation of [[PartitionReader]] for the readChangeFeed mode of State Data Source. + * It reads the change of state over batches of a particular partition. + */ class StateStoreChangeDataPartitionReader( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, From 24c03514a2d0fce06a5ad004a305d8b06eb97755 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 2 Jul 2024 14:51:56 -0700 Subject: [PATCH 04/18] minor --- .../v2/state/StateDataSourceChangeDataReadSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index da8b59ba23458..822e866cbdf55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.sql.execution.datasources.v2.state import org.apache.hadoop.conf.Configuration From d4a4b808182db86113e02cdc86c920df144169a9 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 14:27:34 -0700 Subject: [PATCH 05/18] group options & make options in changeFeed mode isolate from some options --- .../v2/state/StateDataSource.scala | 113 +++++++++++------- .../v2/state/StatePartitionReader.scala | 10 +- .../v2/state/StateScanBuilder.scala | 30 +++-- .../datasources/v2/state/StateTable.scala | 14 ++- ...StreamStreamJoinStatePartitionReader.scala | 3 +- .../StateDataSourceChangeDataReadSuite.scala | 4 - 6 files changed, 111 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 1c0d2de400dfb..cbea4f6d17165 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -136,24 +136,38 @@ class StateDataSource extends TableProvider with DataSourceRegister { override def supportsExternalMetadata(): Boolean = false } +case class FromSnapshotOptions( + snapshotStartBatchId: Long, + snapshotPartitionId: Int) + +case class ReadChangeFeedOptions( + changeStartBatchId: Long, + changeEndBatchId: Long +) + case class StateSourceOptions( resolvedCpLocation: String, batchId: Long, operatorId: Int, storeName: String, joinSide: JoinSideValues, - snapshotStartBatchId: Option[Long], - snapshotPartitionId: Option[Int], readChangeFeed: Boolean, - changeStartBatchId: Option[Long], - changeEndBatchId: Option[Long]) { + fromSnapshotOptions: Option[FromSnapshotOptions], + readChangeFeedOptions: Option[ReadChangeFeedOptions]) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { - s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + - s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " + - s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})" + var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide" + if (fromSnapshotOptions.isDefined) { + desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" + desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" + } + if (readChangeFeedOptions.isDefined) { + desc += s", changeStartBatchId=${readChangeFeedOptions.get.changeStartBatchId}" + desc += s", changeEndBatchId=${readChangeFeedOptions.get.changeEndBatchId}" + } + desc + ")" } } @@ -189,16 +203,6 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.requiredOptionUnspecified(PATH) }.get - val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) - - val batchId = Option(options.get(BATCH_ID)).map(_.toLong).orElse { - Some(getLastCommittedBatch(sparkSession, resolvedCpLocation)) - }.get - - if (batchId < 0) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID) - } - val operatorId = Option(options.get(OPERATOR_ID)).map(_.toInt) .orElse(Some(0)).get @@ -227,40 +231,40 @@ object StateSourceOptions extends DataSourceOptions { 
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } - val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong) - if (snapshotStartBatchId.exists(_ < 0)) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID) - } else if (snapshotStartBatchId.exists(_ > batchId)) { - throw StateDataSourceErrors.invalidOptionValue( - SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId") - } + val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) - val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt) - if (snapshotPartitionId.exists(_ < 0)) { - throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID) - } + var batchId = Option(options.get(BATCH_ID)).map(_.toLong) - // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because - // each partition may have different checkpoint status - if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) { - throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID) - } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) { - throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID) - } + val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong) + val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt) val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) + var fromSnapshotOptions: Option[FromSnapshotOptions] = None + var readChangeFeedOptions: Option[ReadChangeFeedOptions] = None + if (readChangeFeed) { if (joinSide != JoinSideValues.none) { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED)) } + if (batchId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(BATCH_ID, READ_CHANGE_FEED)) + } + if (snapshotStartBatchId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_START_BATCH_ID, READ_CHANGE_FEED)) + } + if (snapshotPartitionId.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_PARTITION_ID, READ_CHANGE_FEED)) + } + if (changeStartBatchId.isEmpty) { throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID) } - changeEndBatchId = Option(changeEndBatchId.getOrElse(batchId)) + changeEndBatchId = Option( + changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation))) // changeStartBatchId and changeEndBatchId must all be defined at this point if (changeStartBatchId.get < 0) { @@ -272,6 +276,11 @@ object StateSourceOptions extends DataSourceOptions { s"Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " + s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.") } + + batchId = Option(changeEndBatchId.get) + + readChangeFeedOptions = Option( + ReadChangeFeedOptions(changeStartBatchId.get, changeEndBatchId.get)) } else { if (changeStartBatchId.isDefined) { throw StateDataSourceErrors.invalidOptionValue(CHANGE_START_BATCH_ID, @@ -281,12 +290,36 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID, s"Only specify this option when $READ_CHANGE_FEED is set to true.") } + + batchId = 
Option(batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation))) + + if (batchId.get < 0) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID) + } + if (snapshotStartBatchId.exists(_ < 0)) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID) + } else if (snapshotStartBatchId.exists(_ > batchId.get)) { + throw StateDataSourceErrors.invalidOptionValue( + SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId") + } + if (snapshotPartitionId.exists(_ < 0)) { + throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID) + } + // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because + // each partition may have different checkpoint status + if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) { + throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID) + } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) { + throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID) + } + + fromSnapshotOptions = Option( + FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get)) } StateSourceOptions( - resolvedCpLocation, batchId, operatorId, storeName, - joinSide, snapshotStartBatchId, snapshotPartitionId, - readChangeFeed, changeStartBatchId, changeEndBatchId) + resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 7fa52844ebbab..befcb8cdbc1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -97,17 +97,17 @@ class StatePartitionReader( } private lazy val store: ReadStateStore = { - partition.sourceOptions.snapshotStartBatchId match { + partition.sourceOptions.fromSnapshotOptions match { case None => provider.getReadStore(partition.sourceOptions.batchId + 1) - case Some(snapshotStartBatchId) => + case Some(fromSnapshotOptions) => if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( provider.getClass.toString) } provider.asInstanceOf[SupportsFineGrainedReplay] .replayReadStateFromSnapshot( - snapshotStartBatchId + 1, + fromSnapshotOptions.snapshotStartBatchId + 1, partition.sourceOptions.batchId + 1) } } @@ -164,8 +164,8 @@ class StateStoreChangeDataPartitionReader( } provider.asInstanceOf[SupportsFineGrainedReplay] .getStateStoreChangeDataReader( - partition.sourceOptions.changeStartBatchId.get + 1, - partition.sourceOptions.changeEndBatchId.get + 1) + partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, + partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1) } override protected lazy val iter: Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index 821a36977fed1..01f966ae948ac 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -86,17 +86,18 @@ class StateScan( assert((tail - head + 1) == partitionNums.length, s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}") - sourceOptions.snapshotPartitionId match { + sourceOptions.fromSnapshotOptions match { case None => partitionNums.map { pn => new StateStoreInputPartition(pn, queryId, sourceOptions) }.toArray - case Some(snapshotPartitionId) => - if (partitionNums.contains(snapshotPartitionId)) { - Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions)) + case Some(fromSnapshotOptions) => + if (partitionNums.contains(fromSnapshotOptions.snapshotPartitionId)) { + Array(new StateStoreInputPartition( + fromSnapshotOptions.snapshotPartitionId, queryId, sourceOptions)) } else { throw StateStoreErrors.stateStoreSnapshotPartitionNotFound( - snapshotPartitionId, sourceOptions.operatorId, + fromSnapshotOptions.snapshotPartitionId, sourceOptions.operatorId, sourceOptions.stateCheckpointLocation.toString) } } @@ -128,16 +129,27 @@ class StateScan( override def toBatch: Batch = this override def description(): String = { - val desc = s"StateScan " + + var desc = s"StateScan " + s"[stateCkptLocation=${sourceOptions.stateCheckpointLocation}]" + s"[batchId=${sourceOptions.batchId}][operatorId=${sourceOptions.operatorId}]" + s"[storeName=${sourceOptions.storeName}]" if (sourceOptions.joinSide != JoinSideValues.none) { - desc + s"[joinSide=${sourceOptions.joinSide}]" - } else { - desc + desc += s"[joinSide=${sourceOptions.joinSide}]" + } + sourceOptions.fromSnapshotOptions match { + case Some(fromSnapshotOptions) => + desc += s"[snapshotStartBatchId=${fromSnapshotOptions.snapshotStartBatchId}]" + desc += s"[snapshotPartitionId=${fromSnapshotOptions.snapshotPartitionId}]" + case _ => + } + sourceOptions.readChangeFeedOptions match { + case Some(fromSnapshotOptions) => + desc += s"[changeStartBatchId=${fromSnapshotOptions.changeStartBatchId}" + desc += s"[changeEndBatchId=${fromSnapshotOptions.changeEndBatchId}" + case _ => } + desc } private def stateCheckpointPartitionsLocation: Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 0750d7549ec30..6d2646c7ba5ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -59,11 +59,17 @@ class StateTable( if (sourceOptions.joinSide != JoinSideValues.none) { desc += s"[joinSide=${sourceOptions.joinSide}]" } - if (sourceOptions.snapshotStartBatchId.isDefined) { - desc += s"[snapshotStartBatchId=${sourceOptions.snapshotStartBatchId}]" + sourceOptions.fromSnapshotOptions match { + case Some(fromSnapshotOptions) => + desc += s"[snapshotStartBatchId=${fromSnapshotOptions.snapshotStartBatchId}]" + desc += s"[snapshotPartitionId=${fromSnapshotOptions.snapshotPartitionId}]" + case _ => } - if (sourceOptions.snapshotPartitionId.isDefined) { - desc += s"[snapshotPartitionId=${sourceOptions.snapshotPartitionId}]" + sourceOptions.readChangeFeedOptions match { + case Some(fromSnapshotOptions) => + desc += s"[changeStartBatchId=${fromSnapshotOptions.changeStartBatchId}" + desc += 
s"[changeEndBatchId=${fromSnapshotOptions.changeEndBatchId}" + case _ => } desc } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 91f42db46dfb0..673ec3414c237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -117,7 +117,8 @@ class StreamStreamJoinStatePartitionReader( formatVersion, skippedNullValueCount = None, useStateStoreCoordinator = false, - snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1) + snapshotStartVersion = + partition.sourceOptions.fromSnapshotOptions.map(_.snapshotStartBatchId + 1) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index 822e866cbdf55..2c8d93264d457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -86,7 +86,6 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB withTempDir { tempDir => val exc = intercept[StateDataSourceInvalidOptionValue] { spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) .load(tempDir.getAbsolutePath) @@ -99,7 +98,6 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB withTempDir { tempDir => val exc = intercept[StateDataSourceInvalidOptionValueIsNegative] { spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) .option(StateSourceOptions.READ_CHANGE_FEED, value = true) .option(StateSourceOptions.CHANGE_START_BATCH_ID, -1) .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) @@ -113,7 +111,6 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB withTempDir { tempDir => val exc = intercept[StateDataSourceInvalidOptionValue] { spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) .option(StateSourceOptions.READ_CHANGE_FEED, value = true) .option(StateSourceOptions.CHANGE_START_BATCH_ID, 1) .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) @@ -127,7 +124,6 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB withTempDir { tempDir => val exc = intercept[StateDataSourceConflictOptions] { spark.read.format("statestore") - .option(StateSourceOptions.BATCH_ID, 0) .option(StateSourceOptions.READ_CHANGE_FEED, value = true) .option(StateSourceOptions.JOIN_SIDE, "left") .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) From 42552acd2188f389226ec2fba21810d615e82127 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 14:53:11 -0700 Subject: [PATCH 06/18] reorder the columns in the result --- .../v2/state/StateDataSource.scala | 4 +- .../v2/state/StatePartitionReader.scala | 8 +-- .../datasources/v2/state/StateTable.scala | 10 ++-- .../StateDataSourceChangeDataReadSuite.scala | 60 +++++++++---------- 4 files 
changed, 41 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index cbea4f6d17165..361e3e619c436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -96,10 +96,10 @@ class StateDataSource extends TableProvider with DataSourceRegister { if (sourceOptions.readChangeFeed) { new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) .add("key", keySchema) .add("value", valueSchema) - .add("change_type", StringType) - .add("batch_id", LongType) .add("partition_id", IntegerType) } else { new StructType() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index befcb8cdbc1eb..4c32454cd49da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -181,10 +181,10 @@ class StateStoreChangeDataPartitionReader( private def unifyStateChangeDataRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)): InternalRow = { val result = new GenericInternalRow(5) - result.update(0, row._2) - result.update(1, row._3) - result.update(2, UTF8String.fromString(getRecordTypeAsString(row._1))) - result.update(3, row._4) + result.update(0, row._4) + result.update(1, UTF8String.fromString(getRecordTypeAsString(row._1))) + result.update(2, row._2) + result.update(3, row._3) result.update(4, partition.partition) result } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 6d2646c7ba5ed..efe4396061261 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -100,15 +100,15 @@ class StateTable( private def isValidChangeDataSchema(schema: StructType): Boolean = { if (schema.fieldNames.toImmutableArraySeq != - Seq("key", "value", "change_type", "batch_id", "partition_id")) { + Seq("batch_id", "change_type", "key", "value", "partition_id")) { false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { + } else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) { false } else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) { false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) { + } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { + false + } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { false } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { false diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index 2c8d93264d457..b219598ea4624 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -186,9 +186,9 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .load(tempDir.getAbsolutePath) val expectedDf = Seq( - Row(Row(null), Row(4), "update", 0L, 0), - Row(Row(null), Row(8), "update", 1L, 0), - Row(Row(null), Row(10), "update", 2L, 0) + Row(0L, "update", Row(null), Row(4), 0), + Row(1L, "update", Row(null), Row(8), 0), + Row(2L, "update", Row(null), Row(10), 0) ) checkAnswer(stateDf, expectedDf) @@ -216,18 +216,18 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .load(tempDir.getAbsolutePath) val expectedDf = Seq( - Row(Row(3), Row(1), "update", 0L, 1), - Row(Row(3), Row(2), "update", 1L, 1), - Row(Row(5), Row(1), "update", 1L, 1), - Row(Row(3), Row(3), "update", 2L, 1), - Row(Row(5), Row(2), "update", 2L, 1), - Row(Row(4), Row(1), "update", 0L, 2), - Row(Row(4), Row(2), "update", 1L, 2), - Row(Row(4), Row(3), "update", 2L, 2), - Row(Row(1), Row(1), "update", 0L, 3), - Row(Row(2), Row(1), "update", 0L, 4), - Row(Row(2), Row(2), "update", 1L, 4), - Row(Row(6), Row(1), "update", 2L, 4) + Row(0L, "update", Row(3), Row(1), 1), + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(0L, "update", Row(4), Row(1), 2), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(0L, "update", Row(1), Row(1), 3), + Row(0L, "update", Row(2), Row(1), 4), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) ) checkAnswer(stateDf, expectedDf) @@ -255,12 +255,12 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .load(tempDir.getAbsolutePath) val expectedDf = Seq( - Row(Row(1), Row(null), "update", 0L, 3), - Row(Row(2), Row(null), "update", 0L, 4), - Row(Row(3), Row(null), "update", 0L, 1), - Row(Row(4), Row(null), "update", 0L, 2), - Row(Row(5), Row(null), "update", 1L, 1), - Row(Row(6), Row(null), "update", 2L, 4) + Row(0L, "update", Row(1), Row(null), 3), + Row(0L, "update", Row(2), Row(null), 4), + Row(0L, "update", Row(3), Row(null), 1), + Row(0L, "update", Row(4), Row(null), 2), + Row(1L, "update", Row(5), Row(null), 1), + Row(2L, "update", Row(6), Row(null), 4) ) checkAnswer(stateDf, expectedDf) @@ -292,11 +292,11 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .load(tempDir.getAbsolutePath) val keyWithIndexToValueExpectedDf = Seq( - Row(Row(3, 0L), Row(3, 3L, false), "update", 1L, 1), - Row(Row(4, 0L), Row(4, 4L, true), "update", 1L, 2), - Row(Row(1, 0L), Row(1, 1L, false), "update", 0L, 3), - Row(Row(2, 0L), Row(2, 2L, false), "update", 0L, 4), - Row(Row(2, 0L), Row(2, 2L, true), "update", 0L, 4) + Row(1L, "update", Row(3, 0L), Row(3, 3L, false), 1), + Row(1L, "update", Row(4, 0L), Row(4, 4L, true), 2), + Row(0L, "update", Row(1, 0L), Row(1, 1L, false), 3), + Row(0L, "update", Row(2, 0L), Row(2, 2L, false), 4), + Row(0L, "update", Row(2, 0L), Row(2, 2L, true), 4) ) 
checkAnswer(keyWithIndexToValueDf, keyWithIndexToValueExpectedDf) @@ -309,10 +309,10 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .load(tempDir.getAbsolutePath) val keyToNumValuesDfExpectedDf = Seq( - Row(Row(3), Row(1L), "update", 1L, 1), - Row(Row(4), Row(1L), "update", 1L, 2), - Row(Row(1), Row(1L), "update", 0L, 3), - Row(Row(2), Row(1L), "update", 0L, 4) + Row(1L, "update", Row(3), Row(1L), 1), + Row(1L, "update", Row(4), Row(1L), 2), + Row(0L, "update", Row(1), Row(1L), 3), + Row(0L, "update", Row(2), Row(1L), 4) ) checkAnswer(keyToNumValuesDf, keyToNumValuesDfExpectedDf) From 24db837a3ec8b332615b0cdfaa5bba2b9ce39649 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 15:34:00 -0700 Subject: [PATCH 07/18] address comments from Jungtaek --- .../datasources/v2/state/StateDataSource.scala | 1 - .../v2/state/StatePartitionReader.scala | 4 ++-- .../datasources/v2/state/StateTable.scala | 17 +++++------------ .../execution/streaming/state/StateStore.scala | 6 ++++-- .../streaming/state/StateStoreChangelog.scala | 10 +++++----- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 361e3e619c436..27bbf96266191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -108,7 +108,6 @@ class StateDataSource extends TableProvider with DataSourceRegister { .add("partition_id", IntegerType) } - } catch { case NonFatal(e) => throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 4c32454cd49da..45990323b72ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -42,10 +42,10 @@ class StatePartitionReaderFactory( val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, - partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata) + stateStoreInputPartition, schema, stateStoreMetadata) } else { new StatePartitionReader(storeConf, hadoopConf, - partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata) + stateStoreInputPartition, schema, stateStoreMetadata) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index efe4396061261..03d73d1c64fce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -82,17 +82,16 @@ class StateTable( override def properties(): util.Map[String, String] = Map.empty[String, String].asJava private def isValidSchema(schema: StructType): Boolean = { - if 
(sourceOptions.readChangeFeed) { - return isValidChangeDataSchema(schema) - } - if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { + if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { false } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { false } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { false + } else if (sourceOptions.readChangeFeed) { + isValidChangeDataSchema(schema) + } else if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { + false } else { true } @@ -106,12 +105,6 @@ class StateTable( false } else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) { false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { - false } else { true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 507facacb9cda..f184dc61ce7ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -471,9 +471,11 @@ trait SupportsFineGrainedReplay { } /** + * Return a [[StateStoreChangeDataReader]] that reads the changelogs entries from startVersion to + * endVersion. 
* - * @param startVersion - * @param endVersion + * @param startVersion starting changelog version + * @param endVersion ending changelog version * @return */ def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 62548ef0ec568..7e63a0ba3e0e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -413,11 +413,11 @@ class StateStoreChangelogReaderV2( * @param compressionCodec de-compression method using for reading changelog file */ abstract class StateStoreChangeDataReader( - fm: CheckpointFileManager, - stateLocation: Path, - startVersion: Long, - endVersion: Long, - compressionCodec: CompressionCodec) + fm: CheckpointFileManager, + stateLocation: Path, + startVersion: Long, + endVersion: Long, + compressionCodec: CompressionCodec) extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging { assert(startVersion >= 1) From adde991fb0ec51ee1b487bb9dadafd24985a329c Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 15:37:59 -0700 Subject: [PATCH 08/18] minor --- .../v2/state/StateDataSourceChangeDataReadSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index b219598ea4624..2858d356d4c9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -165,7 +165,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } } - test("read limit state change feed") { + test("read global streaming limit state change feed") { withTempDir { tempDir => val inputData = MemoryStream[Int] val df = inputData.toDF().limit(10) @@ -195,7 +195,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } } - test("read aggregate state change feed") { + test("read streaming aggregate state change feed") { withTempDir { tempDir => val inputData = MemoryStream[Int] val df = inputData.toDF().groupBy("value").count() @@ -234,7 +234,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } } - test("read deduplication state change feed") { + test("read streaming deduplication state change feed") { withTempDir { tempDir => val inputData = MemoryStream[Int] val df = inputData.toDF().dropDuplicates("value") From d3ca86cafd4181b425e3b5984fd1ac51f374013a Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 16:38:35 -0700 Subject: [PATCH 09/18] refactor StatePartitionReader for both modes --- .../v2/state/StatePartitionReader.scala | 67 ++++++++++++------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 
45990323b72ef..94f016e9a9800 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -54,14 +54,13 @@ class StatePartitionReaderFactory( * An implementation of [[PartitionReader]] for State data source. This is used to support * general read from a state store instance, rather than specific to the operator. */ -class StatePartitionReader( +abstract class StatePartitionReaderBase( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReader[InternalRow] with Logging { - private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] @@ -96,6 +95,40 @@ class StatePartitionReader( useMultipleValuesPerKey = false) } + protected val iter: Iterator[InternalRow] + + private var current: InternalRow = _ + + override def next(): Boolean = { + if (iter.hasNext) { + current = iter.next() + true + } else { + current = null + false + } + } + + override def get(): InternalRow = current + + override def close(): Unit = { + current = null + provider.close() + } +} + +/** + * An implementation of [[StatePartitionReaderBase]] for the normal mode of State Data + * Source. It reads the state at a particular batchId. + */ +class StatePartitionReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + stateStoreMetadata: Array[StateMetadataTableEntry]) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { case None => provider.getReadStore(partition.sourceOptions.batchId + 1) @@ -112,28 +145,13 @@ class StatePartitionReader( } } - protected lazy val iter: Iterator[InternalRow] = { + override lazy val iter: Iterator[InternalRow] = { store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) } - protected var current: InternalRow = _ - - override def next(): Boolean = { - if (iter.hasNext) { - current = iter.next() - true - } else { - current = null - false - } - } - - override def get(): InternalRow = current - - override def close(): Unit = { - current = null store.abort() - provider.close() + super.close() } private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = { @@ -146,8 +164,8 @@ class StatePartitionReader( } /** - * An implementation of [[PartitionReader]] for the readChangeFeed mode of State Data Source. - * It reads the change of state over batches of a particular partition. + * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data + * Source. It reads the change of state over batches of a particular partition.
*/ class StateStoreChangeDataPartitionReader( storeConf: StateStoreConf, @@ -155,7 +173,7 @@ class StateStoreChangeDataPartitionReader( partition: StateStoreInputPartition, schema: StructType, stateStoreMetadata: Array[StateMetadataTableEntry]) - extends StatePartitionReader(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { private lazy val changeDataReader: StateStoreChangeDataReader = { if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { @@ -168,14 +186,13 @@ class StateStoreChangeDataPartitionReader( partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1) } - override protected lazy val iter: Iterator[InternalRow] = { + override lazy val iter: Iterator[InternalRow] = { changeDataReader.iterator.map(unifyStateChangeDataRow) } override def close(): Unit = { - current = null changeDataReader.close() - provider.close() + super.close() } private def unifyStateChangeDataRow(row: (RecordType, UnsafeRow, UnsafeRow, Long)): From 5199c56b62b8eae8131c14472566d836c289f9b1 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Mon, 8 Jul 2024 16:51:08 -0700 Subject: [PATCH 10/18] minor --- .../execution/datasources/v2/state/StateDataSource.scala | 8 ++++---- .../sql/execution/datasources/v2/state/StateTable.scala | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 27bbf96266191..f920dd3d5414b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -136,12 +136,12 @@ class StateDataSource extends TableProvider with DataSourceRegister { } case class FromSnapshotOptions( - snapshotStartBatchId: Long, - snapshotPartitionId: Int) + snapshotStartBatchId: Long, + snapshotPartitionId: Int) case class ReadChangeFeedOptions( - changeStartBatchId: Long, - changeEndBatchId: Long + changeStartBatchId: Long, + changeEndBatchId: Long ) case class StateSourceOptions( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 03d73d1c64fce..a69010a9ea96c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -90,10 +90,12 @@ class StateTable( false } else if (sourceOptions.readChangeFeed) { isValidChangeDataSchema(schema) - } else if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { - false } else { - true + if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { + false + } else { + true + } } } From ce751331242adbcf95f7035eaaa21ba28cf2acd9 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 9 Jul 2024 09:50:38 -0700 Subject: [PATCH 11/18] Use NextIterator as the interface rather than StateStoreChangeDataReader & change the existing StateStoreChangelogReader to correctly implement the interface --- .../datasources/v2/state/StatePartitionReader.scala | 7 ++++--- .../spark/sql/execution/streaming/state/RocksDB.scala | 2 +- 
.../spark/sql/execution/streaming/state/StateStore.scala | 4 ++-- .../execution/streaming/state/StateStoreChangelog.scala | 7 ++++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 94f016e9a9800..6201cf1157ab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{NextIterator, SerializableConfiguration} /** * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support @@ -175,7 +175,8 @@ class StateStoreChangeDataPartitionReader( stateStoreMetadata: Array[StateMetadataTableEntry]) extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { - private lazy val changeDataReader: StateStoreChangeDataReader = { + private lazy val changeDataReader: + NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = { if (!provider.isInstanceOf[SupportsFineGrainedReplay]) { throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( provider.getClass.toString) @@ -191,7 +192,7 @@ class StateStoreChangeDataPartitionReader( } override def close(): Unit = { - changeDataReader.close() + changeDataReader.closeIfNeeded() super.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 28ad197ffb4af..aa54d3d3e9bac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -335,7 +335,7 @@ class RocksDB( } } } finally { - if (changelogReader != null) changelogReader.close() + if (changelogReader != null) changelogReader.closeIfNeeded() } } loadedVersion = endVersion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f184dc61ce7ae..559390a77519f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} /** * Base trait for a versioned key-value store which provides read operations. 
Each instance of a @@ -479,7 +479,7 @@ trait SupportsFineGrainedReplay { * @return */ def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): - StateStoreChangeDataReader + NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 7e63a0ba3e0e7..b05ecbb9c012b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -299,7 +299,7 @@ abstract class StateStoreChangelogReader( } protected val input: DataInputStream = decompressStream(sourceStream) - def close(): Unit = { if (input != null) input.close() } + override protected def close(): Unit = { if (input != null) input.close() } override def getNext(): (RecordType.Value, Array[Byte], Array[Byte], String) } @@ -458,7 +458,8 @@ abstract class StateStoreChangeDataReader( protected def currentChangelogReader(): StateStoreChangelogReader = { while (changelogReader == null || !changelogReader.hasNext) { if (changelogReader != null) { - changelogReader.close() + changelogReader.closeIfNeeded() + changelogReader = null } if (!fileIterator.hasNext) { finished = true @@ -476,7 +477,7 @@ abstract class StateStoreChangeDataReader( override def close(): Unit = { if (changelogReader != null) { - changelogReader.close() + changelogReader.closeIfNeeded() } } } From 84dcf155346073e8a24666150be7eac91c27a091 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 9 Jul 2024 09:57:27 -0700 Subject: [PATCH 12/18] more doc --- .../spark/sql/execution/streaming/state/StateStore.scala | 9 +++++++-- .../execution/streaming/state/StateStoreChangelog.scala | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 559390a77519f..e5291b3de3248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -471,12 +471,17 @@ trait SupportsFineGrainedReplay { } /** - * Return a [[StateStoreChangeDataReader]] that reads the changelogs entries from startVersion to + * Return an iterator that reads all the entries of changelogs from startVersion to * endVersion. 
+ * Each record is represented by a tuple of (recordType: [[RecordType.Value]], key: [[UnsafeRow]], + * value: [[UnsafeRow]], batchId: [[Long]]) + * A put record is returned as a tuple(recordType, key, value, batchId) + * A delete record is return as a tuple(recordType, key, null, batchId) * * @param startVersion starting changelog version * @param endVersion ending changelog version - * @return + * @return tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]], + * nested value: [[UnsafeRow]], batchId: [[Long]]) */ def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index b05ecbb9c012b..c2ed0d725528b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -403,7 +403,7 @@ class StateStoreChangelogReaderV2( /** * Base class representing a iterator that iterates over a range of changelog files in a state - * store. In each iteration, it will return a tuple of (changeType: [[RecordType]], + * store. In each iteration, it will return a ByteArrayPair of (changeType: [[RecordType]], * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]]) * * @param fm checkpoint file manager used to manage streaming query checkpoint From c797d0b58f149a0b9f2786dc083b08dbf5968825 Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 9 Jul 2024 10:34:22 -0700 Subject: [PATCH 13/18] use Jungtaek's advice in checking schema validity --- .../datasources/v2/state/StateTable.scala | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index a69010a9ea96c..2fc85cd8aa968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -82,33 +82,26 @@ class StateTable( override def properties(): util.Map[String, String] = Map.empty[String, String].asJava private def isValidSchema(schema: StructType): Boolean = { - if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) { - false - } else if (sourceOptions.readChangeFeed) { - isValidChangeDataSchema(schema) - } else { - if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) { - false + val expectedFieldNames = + if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "value", "partition_id") } else { - true + Seq("key", "value", "partition_id") } - } - } + val expectedTypes = Map( + "batch_id" -> classOf[LongType], + "change_type" -> classOf[StringType], + "key" -> classOf[StructType], + "value" -> classOf[StructType], + "partition_id" -> classOf[IntegerType]) - private def isValidChangeDataSchema(schema: StructType): Boolean = { - if (schema.fieldNames.toImmutableArraySeq != 
- Seq("batch_id", "change_type", "key", "value", "partition_id")) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) { - false - } else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) { + if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) { false } else { - true + schema.fieldNames.forall { fieldName => + expectedTypes(fieldName).isAssignableFrom( + SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass) + } } } From e5674cfa7d08b10fcccd4372c7b2d5700b46ccba Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Tue, 9 Jul 2024 10:59:51 -0700 Subject: [PATCH 14/18] solve column family merge conflict --- .../streaming/state/HDFSBackedStateStoreProvider.scala | 2 +- .../streaming/state/RocksDBStateStoreProvider.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 869c4f3c2e5e2..2ec36166f9f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -1006,7 +1006,7 @@ class HDFSBackedStateStoreChangeDataReader( if (reader == null) { return null } - val (recordType, keyArray, valueArray, _) = reader.next() + val (recordType, keyArray, valueArray) = reader.next() val keyRow = new UnsafeRow(keySchema.fields.length) keyRow.pointTo(keyArray, keyArray.length) if (valueArray == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 28c90e164f987..a5a8d27116ce0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -476,6 +476,7 @@ private[sql] class RocksDBStateStoreProvider endVersion, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keyValueEncoderMap) + } /** * Class for column family related utility functions. 
From e5674cfa7d08b10fcccd4372c7b2d5700b46ccba Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 9 Jul 2024 10:59:51 -0700
Subject: [PATCH 14/18] solve column family merge conflict

---
 .../streaming/state/HDFSBackedStateStoreProvider.scala   | 2 +-
 .../streaming/state/RocksDBStateStoreProvider.scala      | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 869c4f3c2e5e2..2ec36166f9f22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -1006,7 +1006,7 @@ class HDFSBackedStateStoreChangeDataReader(
     if (reader == null) {
       return null
     }
-    val (recordType, keyArray, valueArray, _) = reader.next()
+    val (recordType, keyArray, valueArray) = reader.next()
     val keyRow = new UnsafeRow(keySchema.fields.length)
     keyRow.pointTo(keyArray, keyArray.length)
     if (valueArray == null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 28c90e164f987..a5a8d27116ce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -476,6 +476,7 @@ private[sql] class RocksDBStateStoreProvider
       endVersion,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
       keyValueEncoderMap)
+  }
 
   /**
    * Class for column family related utility functions.
@@ -706,8 +707,10 @@ class RocksDBStateStoreChangeDataReader(
     if (reader == null) {
       return null
     }
-    val (recordType, keyArray, valueArray, columnFamily) = reader.next()
-    val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) = keyValueEncoderMap.get(columnFamily)
+    val (recordType, keyArray, valueArray) = reader.next()
+    // Todo: does not support multiple virtual column families
+    val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) =
+      keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME)
     val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray)
     if (valueArray == null) {
       (recordType, keyRow, null, currentChangelogVersion - 1)

From c012e1ad8a14d05a013f2f8a10cc5b1723e28514 Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 9 Jul 2024 13:48:04 -0700
Subject: [PATCH 15/18] pass tests

---
 .../execution/datasources/v2/state/StateDataSource.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index f920dd3d5414b..975e0e1394c72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -312,8 +312,10 @@ object StateSourceOptions extends DataSourceOptions {
         throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
       }
 
-      fromSnapshotOptions = Option(
-        FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get))
+      if (snapshotStartBatchId.isDefined && snapshotPartitionId.isDefined) {
+        fromSnapshotOptions = Option(
+          FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get))
+      }
     }
 
     StateSourceOptions(
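PATCH 15 only builds FromSnapshotOptions when both snapshot options are supplied. The same "both defined" requirement can be expressed with a for-comprehension over the two Options; the snippet below is only an illustration with a simplified stand-in case class, not the code the patch uses.

```
// Simplified stand-in for the real options class, for illustration only.
case class FromSnapshotOptions(snapshotStartBatchId: Long, snapshotPartitionId: Int)

val snapshotStartBatchId: Option[Long] = Some(3L)
val snapshotPartitionId: Option[Int] = None

// Yields None unless both options are defined, mirroring the isDefined && isDefined guard.
val fromSnapshotOptions: Option[FromSnapshotOptions] =
  for {
    start <- snapshotStartBatchId
    partition <- snapshotPartitionId
  } yield FromSnapshotOptions(start, partition)

assert(fromSnapshotOptions.isEmpty)
```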
From ff0cd43488ce1d6b8aaf44dc42fc6d2545071c23 Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 9 Jul 2024 15:53:25 -0700
Subject: [PATCH 16/18] continue

---
 .../datasources/v2/state/StateDataSource.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 975e0e1394c72..e2c5499fe439d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -262,7 +262,7 @@ object StateSourceOptions extends DataSourceOptions {
       if (changeStartBatchId.isEmpty) {
         throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID)
       }
-      changeEndBatchId = Option(
+      changeEndBatchId = Some(
         changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
 
       // changeStartBatchId and changeEndBatchId must all be defined at this point
@@ -276,7 +276,7 @@ object StateSourceOptions extends DataSourceOptions {
           s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.")
       }
 
-      batchId = Option(changeEndBatchId.get)
+      batchId = Some(changeEndBatchId.get)
       readChangeFeedOptions = Option(
         ReadChangeFeedOptions(changeStartBatchId.get, changeEndBatchId.get))
 
@@ -290,7 +290,7 @@ object StateSourceOptions extends DataSourceOptions {
           s"Only specify this option when $READ_CHANGE_FEED is set to true.")
       }
 
-      batchId = Option(batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
+      batchId = Some(batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
 
       if (batchId.get < 0) {
         throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID)
@@ -299,7 +299,7 @@ object StateSourceOptions extends DataSourceOptions {
         throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
       } else if (snapshotStartBatchId.exists(_ > batchId.get)) {
         throw StateDataSourceErrors.invalidOptionValue(
-          SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
+          SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to ${batchId.get}")
       }
       if (snapshotPartitionId.exists(_ < 0)) {
         throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
@@ -313,7 +313,7 @@ object StateSourceOptions extends DataSourceOptions {
       }
 
       if (snapshotStartBatchId.isDefined && snapshotPartitionId.isDefined) {
-        fromSnapshotOptions = Option(
+        fromSnapshotOptions = Some(
          FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get))
       }
     }
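PATCH 16 replaces Option(...) with Some(...) wherever the wrapped value is already known to be non-null. This is plain Scala semantics rather than anything specific to the state source: Option(x) is the null-safe factory, while Some(x) records the assumption that x is present. A quick illustration:

```
// Option(...) guards against null and collapses it to None;
// Some(...) asserts the value is there.
val absent: Option[String] = Option(null)        // None
val present: Option[String] = Some("batch-42")   // Some("batch-42")

assert(absent.isEmpty)
assert(present.contains("batch-42"))
```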
From 2ad7590805649e4816029d1bdca3dd2412116e2b Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 9 Jul 2024 20:16:32 -0700
Subject: [PATCH 17/18] make the doc consistent

---
 .../sql/execution/streaming/state/StateStore.scala    |  2 +-
 .../streaming/state/StateStoreChangelog.scala         | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e5291b3de3248..0dc5414b7398a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -480,7 +480,7 @@ trait SupportsFineGrainedReplay {
    *
    * @param startVersion starting changelog version
    * @param endVersion ending changelog version
-   * @return tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]],
+   * @return iterator that gives tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]],
    * nested value: [[UnsafeRow]], batchId: [[Long]])
    */
  def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index 1634f8985fd6b..728d35e929417 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -277,10 +277,10 @@ abstract class StateStoreChangelogReader(
 
 /**
  * Read an iterator of change record from the changelog file.
- * A record is represented by ByteArrayPair(recordType: RecordType.Value,
+ * A record is represented by tuple(recordType: RecordType.Value,
 * key: Array[Byte], value: Array[Byte], colFamilyName: String)
- * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName)
- * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName)
+ * A put record is returned as a tuple(recordType, key, value)
+ * A delete record is returned as a tuple(recordType, key, null)
 */
 class StateStoreChangelogReaderV1(
     fm: CheckpointFileManager,
@@ -317,10 +317,10 @@ class StateStoreChangelogReaderV1(
 
 /**
  * Read an iterator of change record from the changelog file.
- * A record is represented by ByteArrayPair(recordType: RecordType.Value,
+ * A record is represented by tuple(recordType: RecordType.Value,
 * key: Array[Byte], value: Array[Byte], colFamilyName: String)
- * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName)
- * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName)
+ * A put record is returned as a tuple(recordType, key, value)
+ * A delete record is returned as a tuple(recordType, key, null)
 */
 class StateStoreChangelogReaderV2(
     fm: CheckpointFileManager,
@@ -368,7 +368,7 @@ class StateStoreChangelogReaderV2(
 
 /**
  * Base class representing a iterator that iterates over a range of changelog files in a state
- * store. In each iteration, it will return a ByteArrayPair of (changeType: [[RecordType]],
+ * store. In each iteration, it will return a tuple of (changeType: [[RecordType]],
 * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]])
 *
 * @param fm checkpoint file manager used to manage streaming query checkpoint

From 43420f67c9102503857e29086d32d07be08615d0 Mon Sep 17 00:00:00 2001
From: Yuchen Liu
Date: Tue, 9 Jul 2024 20:19:07 -0700
Subject: [PATCH 18/18] continue

---
 .../sql/execution/streaming/state/StateStoreChangelog.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index 728d35e929417..d189daa6e841b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -278,7 +278,7 @@ abstract class StateStoreChangelogReader(
 /**
  * Read an iterator of change record from the changelog file.
  * A record is represented by tuple(recordType: RecordType.Value,
- * key: Array[Byte], value: Array[Byte], colFamilyName: String)
+ * key: Array[Byte], value: Array[Byte])
 * A put record is returned as a tuple(recordType, key, value)
 * A delete record is returned as a tuple(recordType, key, null)
 */
@@ -318,7 +318,7 @@ class StateStoreChangelogReaderV1(
 /**
  * Read an iterator of change record from the changelog file.
  * A record is represented by tuple(recordType: RecordType.Value,
- * key: Array[Byte], value: Array[Byte], colFamilyName: String)
+ * key: Array[Byte], value: Array[Byte])
 * A put record is returned as a tuple(recordType, key, value)
 * A delete record is returned as a tuple(recordType, key, null)
 */
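The record shape the last two patches settle on, tuple(recordType, key: Array[Byte], value: Array[Byte]) with a null value marking a delete, lends itself to simple pattern matching. The helper below is a hypothetical consumer written against that shape only; it is not part of the series and makes no assumption about which reader produced the records.

```
// Hypothetical tally of a changelog record stream shaped like
// (recordType, keyBytes, valueBytes), where a null value marks a delete.
def countChanges[R](records: Iterator[(R, Array[Byte], Array[Byte])]): (Int, Int) = {
  var puts = 0
  var deletes = 0
  records.foreach {
    case (_, _, null) => deletes += 1   // delete record
    case _            => puts += 1      // put/update record
  }
  (puts, deletes)
}

// Example: two puts and one delete.
val sample = Iterator(
  ("put", Array[Byte](1), Array[Byte](10)),
  ("delete", Array[Byte](1), null),
  ("put", Array[Byte](2), Array[Byte](20)))
assert(countChanges(sample) == (2, 1))
```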