@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.ThreadUtils

@@ -40,8 +41,15 @@ trait AsyncLogPurge extends Logging {

  private val purgeRunning = new AtomicBoolean(false)

  private val purgeOldestRunning = new AtomicBoolean(false)

  protected def purge(threshold: Long): Unit

  // Purges the oldest OperatorStateMetadata and StateSchema files, which are written once per
  // query run. Unlike purge(), which is called for every microbatch, this is called at the
  // planning stage of each query run.
  protected def purgeOldest(plan: SparkPlan): Unit

  protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)

  protected def purgeAsync(batchId: Long): Unit = {
@@ -62,6 +70,24 @@ }
    }
  }

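  // Async counterpart of purgeOldest(): the purge is handed off to the shared async purge
  // executor, and purgeOldestRunning guards against more than one oldest-entry purge running
  // at a time. Failures are recorded via errorNotifier so the query can surface them later.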
  protected def purgeOldestAsync(plan: SparkPlan): Unit = {
    if (purgeOldestRunning.compareAndSet(false, true)) {
      asyncPurgeExecutorService.execute(() => {
        try {
          purgeOldest(plan)
        } catch {
          case throwable: Throwable =>
            logError("Encountered error while performing async log purge", throwable)
            errorNotifier.markError(throwable)
        } finally {
          purgeOldestRunning.set(false)
        }
      })
    } else {
      log.debug("Skipped log purging since there is already one in progress.")
    }
  }

  protected def asyncLogPurgeShutdown(): Unit = {
    ThreadUtils.shutdown(asyncPurgeExecutorService)
  }
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataV2FileManager, OperatorStateMetadataWriter, StateStoreId}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -58,7 +58,7 @@ class IncrementalExecution(
    val offsetSeqMetadata: OffsetSeqMetadata,
    val watermarkPropagator: WatermarkPropagator,
    val isFirstBatch: Boolean)
  extends QueryExecution(sparkSession, logicalPlan) with Logging {
  extends QueryExecution(sparkSession, logicalPlan) with Logging with AsyncLogPurge {

  // Modified planner with stateful operations.
  override val planner: SparkPlanner = new SparkPlanner(
@@ -79,6 +79,38 @@ class IncrementalExecution(
      StreamingTransformWithStateStrategy :: Nil
  }

  // Members needed to implement the AsyncLogPurge trait.
  protected val minLogEntriesToMaintain: Int =
    sparkSession.sessionState.conf.minBatchesToRetain

  val errorNotifier: ErrorNotifier = new ErrorNotifier()

  override protected def purge(threshold: Long): Unit = {}

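  // Walks the executed plan and, for every stateful operator that writes operator metadata in
  // the v2 format, trims that operator's OperatorStateMetadata and StateSchema files down to
  // minLogEntriesToMaintain entries (SQLConf.MIN_BATCHES_TO_RETAIN).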
  override protected def purgeOldest(planWithStateOpId: SparkPlan): Unit = {
    planWithStateOpId.collect {
      case ssw: StateStoreWriter =>
        ssw.operatorStateMetadataVersion match {
          case 2 =>
            val fileManager = new OperatorStateMetadataV2FileManager(
              new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
              ssw.stateSchemaFilePath(Some(StateStoreId.DEFAULT_STORE_NAME)),
              hadoopConf)
            fileManager.keepNEntries(minLogEntriesToMaintain)
          case _ =>
        }
      case _ =>
    }
  }

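  // Runs the per-run metadata purge asynchronously when SQLConf.ASYNC_LOG_PURGE is enabled,
  // otherwise inline on the planning thread.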
  private def purgeMetadataFiles(planWithStateOpId: SparkPlan): Unit = {
    if (useAsyncPurge) {
      purgeOldestAsync(planWithStateOpId)
    } else {
      purgeOldest(planWithStateOpId)
    }
  }

  private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()

  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -502,6 +534,7 @@ class IncrementalExecution(
      // The rule below doesn't change the plan but can cause the side effect that
      // metadata/schema is written in the checkpoint directory of stateful operator.
      planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule
      purgeMetadataFiles(planWithStateOpId)

      simulateWatermarkPropagation(planWithStateOpId)
      planWithStateOpId transform WatermarkPropagationRule.rule
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException
import org.apache.spark.sql.internal.SQLConf
@@ -690,6 +691,9 @@ abstract class StreamExecution(
    offsetLog.purge(threshold)
    commitLog.purge(threshold)
  }

  // No-op implementation to satisfy the AsyncLogPurge interface; purging of the oldest
  // per-run metadata files is handled by IncrementalExecution.
  protected def purgeOldest(plan: SparkPlan): Unit = {}
}

object StreamExecution {
@@ -23,11 +23,12 @@ import java.nio.charset.StandardCharsets
import scala.reflect.ClassTag

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path}
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter}
import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization

import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.BATCH_ID
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter}
@@ -312,3 +313,89 @@ class OperatorStateMetadataV2Reader(
    OperatorStateMetadataUtils.readMetadata(inputStream)
  }
}

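// Manages the OperatorStateMetadataV2 and state schema files written under an operator's
// state checkpoint directory: it can list the batch-numbered metadata files and delete the
// oldest metadata/schema files so that only the newest N entries are retained.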
class OperatorStateMetadataV2FileManager(
    stateCheckpointPath: Path,
    stateSchemaPath: Path,
    hadoopConf: Configuration) extends Logging {

  private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateCheckpointPath)
  private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf)

  protected def isBatchFile(path: Path) = {
    try {
      path.getName.toLong
      true
    } catch {
      case _: NumberFormatException => false
    }
  }

  /**
   * A `PathFilter` to filter only batch files
   */
  protected val batchFilesFilter = new PathFilter {
    override def accept(path: Path): Boolean = isBatchFile(path)
  }

  /** List the available batches on the file system. */
  protected def listBatches: Array[Long] = {
    val batchIds = fm.list(metadataDirPath, batchFilesFilter)
      // Batches must be files
      .filter(f => f.isFile)
      .map(f => pathToBatchId(f.getPath))
    logInfo(log"BatchIds found from listing: ${MDC(BATCH_ID, batchIds.sorted.mkString(", "))}")

    batchIds.sorted
  }

  private def pathToBatchId(path: Path): Long = {
    path.getName.toLong
  }

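  // Keeps only the newest `minLogEntriesToMaintain` metadata entries, deleting the schema and
  // metadata files written by older batches.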
  def keepNEntries(minLogEntriesToMaintain: Int): Unit = {
    val thresholdBatchId = findThresholdBatchId(minLogEntriesToMaintain)
    if (thresholdBatchId != -1) {
      deleteSchemaFiles(thresholdBatchId)
      deleteMetadataFiles(thresholdBatchId)
    }
  }

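  // Returns the first batch id to keep, or -1 if nothing needs to be purged.
  // For example, with batch files [0, 1, 2, 3] and minLogEntriesToMaintain = 2, the two oldest
  // entries [0, 1] are dropped, so this returns 1 + 1 = 2 and callers delete everything with
  // batchId < 2.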
  private def findThresholdBatchId(minLogEntriesToMaintain: Int): Long = {
    val metadataFiles = listBatches
    if (metadataFiles.length > minLogEntriesToMaintain) {
      metadataFiles.sorted.take(metadataFiles.length - minLogEntriesToMaintain).last + 1
    } else {
      -1
    }
  }

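  // Deletes schema files written before the threshold batch. This relies on the schema file
  // name starting with the batch id that wrote it, followed by an underscore
  // (e.g. "<batchId>_<uniqueId>").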
  private def deleteSchemaFiles(thresholdBatchId: Long): Unit = {
    val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath)
    val filesBeforeThreshold = schemaFiles.filter { path =>
      val batchIdInPath = path.getName.split("_").head.toLong
      batchIdInPath < thresholdBatchId
    }
    filesBeforeThreshold.foreach { path =>
      fm.delete(path)
    }
  }

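  // Deletes metadata files for batches older than the threshold. Metadata files are named by
  // batch id (see isBatchFile), so the id is parsed directly from the file name.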
  private def deleteMetadataFiles(thresholdBatchId: Long): Unit = {
    val metadataFiles = fm.list(metadataDirPath, batchFilesFilter)
    metadataFiles.foreach { batchFile =>
      val batchId = pathToBatchId(batchFile.getPath)
      if (batchId < thresholdBatchId) {
        fm.delete(batchFile.getPath)
      }
    }
  }

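  // Listing helpers exposed so tests can assert which schema/metadata files remain after
  // purging.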
  private[sql] def listSchemaFiles(): Array[Path] = {
    fm.list(stateSchemaPath).sorted.map(_.getPath)
  }

  private[sql] def listMetadataFiles(): Array[Path] = {
    fm.list(metadataDirPath, batchFilesFilter).sorted.map(_.getPath)
  }
}
@@ -983,6 +983,57 @@ class TransformWithStateSuite extends StateStoreMetricsTest
      }
    }
  }

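  // Helpers for building an OperatorStateMetadataV2FileManager rooted at an operator's state
  // checkpoint directory and its schema directory.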
  private def getOperatorStateMetadataFileManager(
      stateCheckpointPath: Path,
      stateSchemaPath: Path): OperatorStateMetadataV2FileManager = {
    val hadoopConf = spark.sessionState.newHadoopConf()
    new OperatorStateMetadataV2FileManager(stateCheckpointPath, stateSchemaPath, hadoopConf)
  }

  private def getStateSchemaPath(stateCheckpointPath: Path): Path = {
    new Path(stateCheckpointPath, "default/_metadata/schema")
  }

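  // With MIN_BATCHES_TO_RETAIN = 1, every query run writes a new metadata and schema file, and
  // the purge triggered during planning should leave exactly one of each after the three runs
  // in this test.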
  test("transformWithState - verify that metadata and schema logs are purged") {
    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
      classOf[RocksDBStateStoreProvider].getName,
      SQLConf.SHUFFLE_PARTITIONS.key ->
        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
      SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") {
      withTempDir { chkptDir =>
        val inputData = MemoryStream[String]
        val result = inputData.toDS()
          .groupByKey(x => x)
          .transformWithState(new RunningCountStatefulProcessor(),
            TimeMode.None(),
            OutputMode.Update())

        testStream(result, OutputMode.Update())(
          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
          AddData(inputData, "a"),
          CheckNewAnswer(("a", "1")),
          StopStream,
          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
          AddData(inputData, "a"),
          CheckNewAnswer(("a", "2")),
          StopStream,
          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
          AddData(inputData, "a"),
          CheckNewAnswer(),
          StopStream
        )
        val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0")
        val stateSchemaPath = getStateSchemaPath(stateOpIdPath)

        val fm = getOperatorStateMetadataFileManager(
          stateOpIdPath,
          stateSchemaPath)
        assert(fm.listSchemaFiles().length == 1)
        assert(fm.listMetadataFiles().length == 1)
      }
    }
  }
}

class TransformWithStateValidationSuite extends StateStoreMetricsTest {