
Commit 6922595

Commit message: stage
1 parent: cf84d50 · commit: 6922595

File tree: 5 files changed, +200 −17 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 6 additions & 6 deletions
@@ -301,6 +301,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     new HDFSBackedReadStateStore(endVersion, newMap)
   }

+  override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
+    new HDFSBackedStateStoreCDCReader(fm, baseDir, startVersion, endVersion,
+      CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
+      keySchema, valueSchema)
+  }
+
   private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
     try {
       if (version < 0) {
@@ -338,12 +344,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     }
   }

-  override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
-    new HDFSBackedStateStoreCDCReader(fm, baseDir, startVersion, endVersion,
-      CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
-      keySchema, valueSchema)
-  }
-
   // Run bunch of validations specific to HDFSBackedStateStoreProvider
   private def runValidation(
       useColumnFamilies: Boolean,
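
For context, the reader returned by this new provider API is a NextIterator over (record type, key, value, version) tuples. A minimal, hypothetical driver-side sketch of how it could be consumed (not part of this commit; `provider` is assumed to be an already initialized HDFSBackedStateStoreProvider):

  // Hypothetical usage sketch, not part of this commit.
  val reader = provider.getStateStoreCDCReader(startVersion = 1L, endVersion = 5L)
  try {
    reader.foreach { case (recordType, keyRow, valueRow, version) =>
      // Each element is one changelog record: operation type, key and value as
      // UnsafeRows, and the version (batch) that produced the change.
      println(s"$recordType @ version $version: key=$keyRow value=$valueRow")
    }
  } finally {
    reader.closeIfNeeded()  // NextIterator's public close hook
  }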

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 15 additions & 0 deletions
@@ -26,8 +26,10 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils

@@ -354,6 +356,19 @@ private[sql] class RocksDBStateStoreProvider
     }
   }

+  override def getStateStoreCDCReader(startVersion: Long, endVersion: Long): StateStoreCDCReader = {
+    val statePath = stateStoreId.storeCheckpointLocation()
+    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+    new RocksDBStateStoreCDCReader(
+      CheckpointFileManager.create(statePath, hadoopConf),
+      statePath,
+      startVersion,
+      endVersion,
+      CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
+      keySchema,
+      valueSchema)
+  }
+
   override def doMaintenance(): Unit = {
     try {
       rocksDB.doMaintenance()
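
A note on the `sparkConf` line above: the compression codec must be created even when no SparkEnv has been initialized (for example in some test or tooling contexts), so the provider falls back to a default SparkConf. A minimal sketch of that defensive pattern, assuming only SparkConf and SparkEnv from Spark core:

  import org.apache.spark.{SparkConf, SparkEnv}

  // Use the live environment's conf when it exists, otherwise a default conf,
  // so downstream calls such as CompressionCodec.createCodec still succeed.
  val sparkConf: SparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)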

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCDCReader.scala

Lines changed: 50 additions & 5 deletions
@@ -34,13 +34,10 @@ import org.apache.spark.util.NextIterator
  */
 abstract class StateStoreCDCReader(
     fm: CheckpointFileManager,
-    // fileToRead: Path,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
-    compressionCodec: CompressionCodec,
-    keySchema: StructType,
-    valueSchema: StructType)
+    compressionCodec: CompressionCodec)
   extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging {

   class ChangeLogFileIterator(
@@ -89,7 +86,7 @@ class HDFSBackedStateStoreCDCReader(
     valueSchema: StructType
   )
   extends StateStoreCDCReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec, keySchema, valueSchema) {
+    fm, stateLocation, startVersion, endVersion, compressionCodec) {
   override protected var changelogSuffix: String = "delta"

   private var currentChangelogReader: StateStoreChangelogReader = null
@@ -125,3 +122,51 @@ class HDFSBackedStateStoreCDCReader(
     }
   }
 }
+
+class RocksDBStateStoreCDCReader(
+    fm: CheckpointFileManager,
+    stateLocation: Path,
+    startVersion: Long,
+    endVersion: Long,
+    compressionCodec: CompressionCodec,
+    keySchema: StructType,
+    valueSchema: StructType
+  )
+  extends StateStoreCDCReader(
+    fm, stateLocation, startVersion, endVersion, compressionCodec) {
+  override protected var changelogSuffix: String = "changelog"
+
+  private var currentChangelogReader: StateStoreChangelogReader = null
+
+  override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
+    while (currentChangelogReader == null || !currentChangelogReader.hasNext) {
+      if (currentChangelogReader != null) {
+        currentChangelogReader.close()
+        currentChangelogReader = null
+      }
+      if (!fileIterator.hasNext) {
+        finished = true
+        return null
+      }
+      currentChangelogReader =
+        new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec)
+    }
+
+    val readResult = currentChangelogReader.next()
+    val keyRow = new UnsafeRow(keySchema.fields.length)
+    keyRow.pointTo(readResult._2, readResult._2.length)
+    val valueRow = new UnsafeRow(valueSchema.fields.length)
+    // If valueSize in existing file is not multiple of 8, floor it to multiple of 8.
+    // This is a workaround for the following:
+    // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
+    // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
+    valueRow.pointTo(readResult._3, (readResult._3.length / 8) * 8)
+    (readResult._1, keyRow, valueRow, fileIterator.getVersion - 1)
+  }
+
+  override def close(): Unit = {
+    if (currentChangelogReader != null) {
+      currentChangelogReader.close()
+    }
+  }
+}
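
Both CDC readers follow Spark's internal NextIterator contract: `getNext()` either returns the next element or sets `finished = true` (the returned value is then discarded), and `close()` runs once when iteration ends. A minimal sketch of that contract, assuming it is compiled inside a Spark package since NextIterator is `private[spark]`; the line-reading example is illustrative only:

  import scala.io.Source

  import org.apache.spark.util.NextIterator

  // Illustrative only: an iterator built on the same contract the CDC readers use.
  class LineIterator(path: String) extends NextIterator[String] {
    private val source = Source.fromFile(path)
    private val lines = source.getLines()

    override protected def getNext(): String = {
      if (lines.hasNext) {
        lines.next()
      } else {
        finished = true   // tells hasNext to stop; the returned null is ignored
        null
      }
    }

    override protected def close(): Unit = source.close()
  }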

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala

Lines changed: 13 additions & 6 deletions
@@ -444,8 +444,8 @@ class SymmetricHashJoinStateManager(
   private val keySchema = StructType(
     joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
   private val keyAttributes = toAttributes(keySchema)
-  private val keyToNumValues = new KeyToNumValuesStore()
-  private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)
+  private lazy val keyToNumValues = new KeyToNumValuesStore()
+  private lazy val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)

   // Clean up any state store resources if necessary at the end of the task
   Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } }
@@ -476,6 +476,16 @@ class SymmetricHashJoinStateManager(

   def metrics: StateStoreMetrics = stateStore.metrics

+  private def initializeStateStoreProvider(keySchema: StructType, valueSchema: StructType):
+    Unit = {
+    val storeProviderId = StateStoreProviderId(
+      stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
+    stateStoreProvider = StateStoreProvider.createAndInit(
+      storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+      useColumnFamilies = false, storeConf, hadoopConf,
+      useMultipleValuesPerKey = false)
+  }
+
   /** Get the StateStore with the given schema */
   protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
     val storeProviderId = StateStoreProviderId(
@@ -488,10 +498,7 @@ class SymmetricHashJoinStateManager(
         stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf)
     } else {
       // This class will manage the state store provider by itself.
-      stateStoreProvider = StateStoreProvider.createAndInit(
-        storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
-        useColumnFamilies = false, storeConf, hadoopConf,
-        useMultipleValuesPerKey = false)
+      initializeStateStoreProvider(keySchema, valueSchema)
       if (snapshotStartVersion.isDefined) {
         stateStoreProvider.getStore(snapshotStartVersion.get, stateInfo.get.storeVersion)
       } else {
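
The switch to `lazy val` above defers construction of the two join-state stores until first access, which is standard Scala semantics; presumably this lets code paths that never touch them avoid opening the underlying state stores eagerly. A minimal, self-contained illustration of the semantics:

  class Stores {
    lazy val expensive: Int = {
      println("initializing")  // runs at most once, on first access
      42
    }
  }

  val s = new Stores()  // nothing initialized yet
  s.expensive           // prints "initializing" and returns 42
  s.expensive           // cached: returns 42 without re-running the block
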
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.File
+
+import org.scalatest.Assertions
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SQLConf
+
+
+
+class HDFSBackedStateDataSourceReadCDCSuite extends StateDataSourceCDCReadSuite {
+  override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
+    new HDFSBackedStateStoreProvider
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      newStateStoreProvider().getClass.getName)
+    // make sure we have a snapshot for every two delta files
+    // HDFS maintenance task will not count the latest delta file, which has the same version
+    // as the snapshot version
+    spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 1)
+  }
+}
+
+class RocksDBStateDataSourceCDCReadSuite extends StateDataSourceCDCReadSuite {
+  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
+    new RocksDBStateStoreProvider
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      newStateStoreProvider().getClass.getName)
+    spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+      "false")
+  }
+}
+
+class RocksDBWithChangelogCheckpointStateDataSourceCDCReaderSuite extends
+  StateDataSourceCDCReadSuite {
+  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
+    new RocksDBStateStoreProvider
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      newStateStoreProvider().getClass.getName)
+    spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+      "true")
+    // make sure we have a snapshot for every other checkpoint
+    // RocksDB maintenance task will count the latest checkpoint, so we need to set it to 2
+    spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 2)
+  }
+}
+
+abstract class StateDataSourceCDCReadSuite extends StateDataSourceTestBase with Assertions {
+  protected def newStateStoreProvider(): StateStoreProvider
+
+  test("cdc read limit state") {
+    withTempDir(tempDir => {
+      val tempDir2 = new File("/tmp/state/rand")
+      import testImplicits._
+      spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, 500)
+      val inputData = MemoryStream[Int]
+      val df = inputData.toDF().limit(10)
+      testStream(df)(
+        StartStream(checkpointLocation = tempDir2.getAbsolutePath),
+        AddData(inputData, 1, 2, 3, 4),
+        CheckLastBatch(1, 2, 3, 4),
+        AddData(inputData, 5, 6, 7, 8),
+        CheckLastBatch(5, 6, 7, 8),
+        AddData(inputData, 9, 10, 11, 12),
+        CheckLastBatch(9, 10)
+      )
+
+      val stateDf = spark.read.format("statestore")
+        .option(StateSourceOptions.MODE_TYPE, "cdc")
+        .option(StateSourceOptions.CDC_START_BATCH_ID, 0)
+        .option(StateSourceOptions.CDC_END_BATCH_ID, 2)
+        .load(tempDir2.getAbsolutePath)
+      stateDf.show()
+
+      val expectedDf = spark.createDataFrame()
+    })
+  }
+
+  test("cdc read aggregate state") {
+
+  }
+
+  test("cdc read deduplication state") {
+
+  }
+
+  test("cdc read stream-stream join state") {
+
+  }
+}
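
The abstract suite above exercises the user-facing entry point for CDC reads: the existing `statestore` data source with a mode option and a batch-id range. A usage sketch mirroring the test (the option constants come from StateSourceOptions in this patch; the checkpoint path is a placeholder):

  val changes = spark.read
    .format("statestore")
    .option(StateSourceOptions.MODE_TYPE, "cdc")
    .option(StateSourceOptions.CDC_START_BATCH_ID, 0)
    .option(StateSourceOptions.CDC_END_BATCH_ID, 2)
    .load("/path/to/streaming/checkpoint")  // placeholder checkpoint location

  changes.show()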
