@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
3636import org .apache .spark .sql .execution .streaming .StreamingSymmetricHashJoinHelper .{LeftSide , RightSide }
3737import org .apache .spark .sql .execution .streaming .state .{StateSchemaCompatibilityChecker , StateStore , StateStoreConf , StateStoreId , StateStoreProviderId }
3838import org .apache .spark .sql .sources .DataSourceRegister
39- import org .apache .spark .sql .types .{IntegerType , StructType }
39+ import org .apache .spark .sql .types .{IntegerType , LongType , StringType , StructType }
4040import org .apache .spark .sql .util .CaseInsensitiveStringMap
4141import org .apache .spark .util .SerializableConfiguration
4242
@@ -94,10 +94,20 @@ class StateDataSource extends TableProvider with DataSourceRegister {
9494 manager.readSchemaFile()
9595 }
9696
97- new StructType ()
98- .add(" key" , keySchema)
99- .add(" value" , valueSchema)
100- .add(" partition_id" , IntegerType )
97+ if (sourceOptions.readChangeFeed) {
98+ new StructType ()
99+ .add(" batch_id" , LongType )
100+ .add(" change_type" , StringType )
101+ .add(" key" , keySchema)
102+ .add(" value" , valueSchema)
103+ .add(" partition_id" , IntegerType )
104+ } else {
105+ new StructType ()
106+ .add(" key" , keySchema)
107+ .add(" value" , valueSchema)
108+ .add(" partition_id" , IntegerType )
109+ }
110+
101111 } catch {
102112 case NonFatal (e) =>
103113 throw StateDataSourceErrors .failedToReadStateSchema(sourceOptions, e)
@@ -125,21 +135,38 @@ class StateDataSource extends TableProvider with DataSourceRegister {
125135 override def supportsExternalMetadata (): Boolean = false
126136}
127137
138+ case class FromSnapshotOptions (
139+ snapshotStartBatchId : Long ,
140+ snapshotPartitionId : Int )
141+
142+ case class ReadChangeFeedOptions (
143+ changeStartBatchId : Long ,
144+ changeEndBatchId : Long
145+ )
146+
128147case class StateSourceOptions (
129148 resolvedCpLocation : String ,
130149 batchId : Long ,
131150 operatorId : Int ,
132151 storeName : String ,
133152 joinSide : JoinSideValues ,
134- snapshotStartBatchId : Option [Long ],
135- snapshotPartitionId : Option [Int ]) {
153+ readChangeFeed : Boolean ,
154+ fromSnapshotOptions : Option [FromSnapshotOptions ],
155+ readChangeFeedOptions : Option [ReadChangeFeedOptions ]) {
136156 def stateCheckpointLocation : Path = new Path (resolvedCpLocation, DIR_NAME_STATE )
137157
138158 override def toString : String = {
139- s " StateSourceOptions(checkpointLocation= $resolvedCpLocation, batchId= $batchId, " +
140- s " operatorId= $operatorId, storeName= $storeName, joinSide= $joinSide, " +
141- s " snapshotStartBatchId= ${snapshotStartBatchId.getOrElse(" None" )}, " +
142- s " snapshotPartitionId= ${snapshotPartitionId.getOrElse(" None" )}) "
159+ var desc = s " StateSourceOptions(checkpointLocation= $resolvedCpLocation, batchId= $batchId, " +
160+ s " operatorId= $operatorId, storeName= $storeName, joinSide= $joinSide"
161+ if (fromSnapshotOptions.isDefined) {
162+ desc += s " , snapshotStartBatchId= ${fromSnapshotOptions.get.snapshotStartBatchId}"
163+ desc += s " , snapshotPartitionId= ${fromSnapshotOptions.get.snapshotPartitionId}"
164+ }
165+ if (readChangeFeedOptions.isDefined) {
166+ desc += s " , changeStartBatchId= ${readChangeFeedOptions.get.changeStartBatchId}"
167+ desc += s " , changeEndBatchId= ${readChangeFeedOptions.get.changeEndBatchId}"
168+ }
169+ desc + " )"
143170 }
144171}
145172
@@ -151,6 +178,9 @@ object StateSourceOptions extends DataSourceOptions {
151178 val JOIN_SIDE = newOption(" joinSide" )
152179 val SNAPSHOT_START_BATCH_ID = newOption(" snapshotStartBatchId" )
153180 val SNAPSHOT_PARTITION_ID = newOption(" snapshotPartitionId" )
181+ val READ_CHANGE_FEED = newOption(" readChangeFeed" )
182+ val CHANGE_START_BATCH_ID = newOption(" changeStartBatchId" )
183+ val CHANGE_END_BATCH_ID = newOption(" changeEndBatchId" )
154184
155185 object JoinSideValues extends Enumeration {
156186 type JoinSideValues = Value
@@ -172,16 +202,6 @@ object StateSourceOptions extends DataSourceOptions {
172202 throw StateDataSourceErrors .requiredOptionUnspecified(PATH )
173203 }.get
174204
175- val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation)
176-
177- val batchId = Option (options.get(BATCH_ID )).map(_.toLong).orElse {
178- Some (getLastCommittedBatch(sparkSession, resolvedCpLocation))
179- }.get
180-
181- if (batchId < 0 ) {
182- throw StateDataSourceErrors .invalidOptionValueIsNegative(BATCH_ID )
183- }
184-
185205 val operatorId = Option (options.get(OPERATOR_ID )).map(_.toInt)
186206 .orElse(Some (0 )).get
187207
@@ -210,30 +230,97 @@ object StateSourceOptions extends DataSourceOptions {
210230 throw StateDataSourceErrors .conflictOptions(Seq (JOIN_SIDE , STORE_NAME ))
211231 }
212232
213- val snapshotStartBatchId = Option (options.get(SNAPSHOT_START_BATCH_ID )).map(_.toLong)
214- if (snapshotStartBatchId.exists(_ < 0 )) {
215- throw StateDataSourceErrors .invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID )
216- } else if (snapshotStartBatchId.exists(_ > batchId)) {
217- throw StateDataSourceErrors .invalidOptionValue(
218- SNAPSHOT_START_BATCH_ID , s " value should be less than or equal to $batchId" )
219- }
233+ val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation)
234+
235+ var batchId = Option (options.get(BATCH_ID )).map(_.toLong)
220236
237+ val snapshotStartBatchId = Option (options.get(SNAPSHOT_START_BATCH_ID )).map(_.toLong)
221238 val snapshotPartitionId = Option (options.get(SNAPSHOT_PARTITION_ID )).map(_.toInt)
222- if (snapshotPartitionId.exists(_ < 0 )) {
223- throw StateDataSourceErrors .invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID )
224- }
225239
226- // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
227- // each partition may have different checkpoint status
228- if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
229- throw StateDataSourceErrors .requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID )
230- } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
231- throw StateDataSourceErrors .requiredOptionUnspecified(SNAPSHOT_PARTITION_ID )
240+ val readChangeFeed = Option (options.get(READ_CHANGE_FEED )).exists(_.toBoolean)
241+
242+ val changeStartBatchId = Option (options.get(CHANGE_START_BATCH_ID )).map(_.toLong)
243+ var changeEndBatchId = Option (options.get(CHANGE_END_BATCH_ID )).map(_.toLong)
244+
245+ var fromSnapshotOptions : Option [FromSnapshotOptions ] = None
246+ var readChangeFeedOptions : Option [ReadChangeFeedOptions ] = None
247+
248+ if (readChangeFeed) {
249+ if (joinSide != JoinSideValues .none) {
250+ throw StateDataSourceErrors .conflictOptions(Seq (JOIN_SIDE , READ_CHANGE_FEED ))
251+ }
252+ if (batchId.isDefined) {
253+ throw StateDataSourceErrors .conflictOptions(Seq (BATCH_ID , READ_CHANGE_FEED ))
254+ }
255+ if (snapshotStartBatchId.isDefined) {
256+ throw StateDataSourceErrors .conflictOptions(Seq (SNAPSHOT_START_BATCH_ID , READ_CHANGE_FEED ))
257+ }
258+ if (snapshotPartitionId.isDefined) {
259+ throw StateDataSourceErrors .conflictOptions(Seq (SNAPSHOT_PARTITION_ID , READ_CHANGE_FEED ))
260+ }
261+
262+ if (changeStartBatchId.isEmpty) {
263+ throw StateDataSourceErrors .requiredOptionUnspecified(CHANGE_START_BATCH_ID )
264+ }
265+ changeEndBatchId = Some (
266+ changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
267+
268+ // changeStartBatchId and changeEndBatchId must all be defined at this point
269+ if (changeStartBatchId.get < 0 ) {
270+ throw StateDataSourceErrors .invalidOptionValueIsNegative(CHANGE_START_BATCH_ID )
271+ }
272+ if (changeEndBatchId.get < changeStartBatchId.get) {
273+ throw StateDataSourceErrors .invalidOptionValue(CHANGE_END_BATCH_ID ,
274+ s " $CHANGE_END_BATCH_ID cannot be smaller than $CHANGE_START_BATCH_ID. " +
275+ s " Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " +
276+ s " value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}. " )
277+ }
278+
279+ batchId = Some (changeEndBatchId.get)
280+
281+ readChangeFeedOptions = Option (
282+ ReadChangeFeedOptions (changeStartBatchId.get, changeEndBatchId.get))
283+ } else {
284+ if (changeStartBatchId.isDefined) {
285+ throw StateDataSourceErrors .invalidOptionValue(CHANGE_START_BATCH_ID ,
286+ s " Only specify this option when $READ_CHANGE_FEED is set to true. " )
287+ }
288+ if (changeEndBatchId.isDefined) {
289+ throw StateDataSourceErrors .invalidOptionValue(CHANGE_END_BATCH_ID ,
290+ s " Only specify this option when $READ_CHANGE_FEED is set to true. " )
291+ }
292+
293+ batchId = Some (batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
294+
295+ if (batchId.get < 0 ) {
296+ throw StateDataSourceErrors .invalidOptionValueIsNegative(BATCH_ID )
297+ }
298+ if (snapshotStartBatchId.exists(_ < 0 )) {
299+ throw StateDataSourceErrors .invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID )
300+ } else if (snapshotStartBatchId.exists(_ > batchId.get)) {
301+ throw StateDataSourceErrors .invalidOptionValue(
302+ SNAPSHOT_START_BATCH_ID , s " value should be less than or equal to ${batchId.get}" )
303+ }
304+ if (snapshotPartitionId.exists(_ < 0 )) {
305+ throw StateDataSourceErrors .invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID )
306+ }
307+ // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
308+ // each partition may have different checkpoint status
309+ if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
310+ throw StateDataSourceErrors .requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID )
311+ } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
312+ throw StateDataSourceErrors .requiredOptionUnspecified(SNAPSHOT_PARTITION_ID )
313+ }
314+
315+ if (snapshotStartBatchId.isDefined && snapshotPartitionId.isDefined) {
316+ fromSnapshotOptions = Some (
317+ FromSnapshotOptions (snapshotStartBatchId.get, snapshotPartitionId.get))
318+ }
232319 }
233320
234321 StateSourceOptions (
235- resolvedCpLocation, batchId, operatorId, storeName,
236- joinSide, snapshotStartBatchId, snapshotPartitionId )
322+ resolvedCpLocation, batchId.get , operatorId, storeName, joinSide ,
323+ readChangeFeed, fromSnapshotOptions, readChangeFeedOptions )
237324 }
238325
239326 private def resolvedCheckpointLocation (
0 commit comments