diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
index aa393211a1c15..831aca8513bc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -40,8 +41,15 @@ trait AsyncLogPurge extends Logging {
 
   private val purgeRunning = new AtomicBoolean(false)
 
+  private val purgeOldestRunning = new AtomicBoolean(false)
+
   protected def purge(threshold: Long): Unit
 
+  // This method is used to purge the oldest OperatorStateMetadata and StateSchema files
+  // which are written per run. Unlike purge(), which is called per microbatch, this is
+  // called at the planning stage of each query run.
+  protected def purgeOldest(plan: SparkPlan): Unit
+
   protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)
 
   protected def purgeAsync(batchId: Long): Unit = {
@@ -62,6 +70,24 @@ trait AsyncLogPurge extends Logging {
     }
   }
 
+  protected def purgeOldestAsync(plan: SparkPlan): Unit = {
+    if (purgeOldestRunning.compareAndSet(false, true)) {
+      asyncPurgeExecutorService.execute(() => {
+        try {
+          purgeOldest(plan)
+        } catch {
+          case throwable: Throwable =>
+            logError("Encountered error while performing async log purge", throwable)
+            errorNotifier.markError(throwable)
+        } finally {
+          purgeOldestRunning.set(false)
+        }
+      })
+    } else {
+      log.debug("Skipped log purging since there is already one in progress.")
+    }
+  }
+
   protected def asyncLogPurgeShutdown(): Unit = {
     ThreadUtils.shutdown(asyncPurgeExecutorService)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 567fb1b98f14c..70d613747c70c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataV2FileManager, OperatorStateMetadataWriter, StateStoreId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -58,7 +58,7 @@ class IncrementalExecution(
     val offsetSeqMetadata: OffsetSeqMetadata,
     val watermarkPropagator: WatermarkPropagator,
     val isFirstBatch: Boolean)
-  extends QueryExecution(sparkSession, logicalPlan) with Logging {
+  extends QueryExecution(sparkSession, logicalPlan) with Logging with AsyncLogPurge {
 
   // Modified planner with stateful operations.
   override val planner: SparkPlanner = new SparkPlanner(
@@ -79,6 +79,38 @@ class IncrementalExecution(
       StreamingTransformWithStateStrategy :: Nil
   }
 
+  // Methods to enable the use of AsyncLogPurge
+  protected val minLogEntriesToMaintain: Int =
+    sparkSession.sessionState.conf.minBatchesToRetain
+
+  val errorNotifier: ErrorNotifier = new ErrorNotifier()
+
+  override protected def purge(threshold: Long): Unit = {}
+
+  override protected def purgeOldest(planWithStateOpId: SparkPlan): Unit = {
+    planWithStateOpId.collect {
+      case ssw: StateStoreWriter =>
+        ssw.operatorStateMetadataVersion match {
+          case 2 =>
+            val fileManager = new OperatorStateMetadataV2FileManager(
+              new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
+              ssw.stateSchemaFilePath(Some(StateStoreId.DEFAULT_STORE_NAME)),
+              hadoopConf)
+            fileManager.keepNEntries(minLogEntriesToMaintain)
+          case _ =>
+        }
+      case _ =>
+    }
+  }
+
+  private def purgeMetadataFiles(planWithStateOpId: SparkPlan): Unit = {
+    if (useAsyncPurge) {
+      purgeOldestAsync(planWithStateOpId)
+    } else {
+      purgeOldest(planWithStateOpId)
+    }
+  }
+
   private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
 
   private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -502,6 +534,7 @@ class IncrementalExecution(
     // The rule below doesn't change the plan but can cause the side effect that
     // metadata/schema is written in the checkpoint directory of stateful operator.
     planWithStateOpId transform StateSchemaAndOperatorMetadataRule.rule
+    purgeMetadataFiles(planWithStateOpId)
 
     simulateWatermarkPropagation(planWithStateOpId)
     planWithStateOpId transform WatermarkPropagationRule.rule
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 6fd58e13366e0..f612ee9b2a120 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.StreamingExplainCommand
 import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException
 import org.apache.spark.sql.internal.SQLConf
@@ -690,6 +691,9 @@ abstract class StreamExecution(
     offsetLog.purge(threshold)
     commitLog.purge(threshold)
   }
+
+  // This is to fulfill the interface of AsyncLogPurge
+  protected def purgeOldest(plan: SparkPlan): Unit = {}
 }
 
 object StreamExecution {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index 3be4f95b3af92..434a60d3aa9e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -23,11 +23,12 @@ import java.nio.charset.StandardCharsets
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path}
+import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter}
 import org.json4s.{Formats, NoTypeHints}
 import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import org.apache.spark.internal.LogKeys.BATCH_ID
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
 import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter}
@@ -312,3 +313,89 @@ class OperatorStateMetadataV2Reader(
     OperatorStateMetadataUtils.readMetadata(inputStream)
   }
 }
+
+class OperatorStateMetadataV2FileManager(
+    stateCheckpointPath: Path,
+    stateSchemaPath: Path,
+    hadoopConf: Configuration) extends Logging {
+
+  private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateCheckpointPath)
+  private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf)
+
+  protected def isBatchFile(path: Path) = {
+    try {
+      path.getName.toLong
+      true
+    } catch {
+      case _: NumberFormatException => false
+    }
+  }
+
+  /**
+   * A `PathFilter` to filter only batch files
+   */
+  protected val batchFilesFilter = new PathFilter {
+    override def accept(path: Path): Boolean = isBatchFile(path)
+  }
+
+  /** List the available batches on file system. */
+  protected def listBatches: Array[Long] = {
+    val batchIds = fm.list(metadataDirPath, batchFilesFilter)
+      // Batches must be files
+      .filter(f => f.isFile)
+      .map(f => pathToBatchId(f.getPath))
+    logInfo(log"BatchIds found from listing: ${MDC(BATCH_ID, batchIds.sorted.mkString(", "))}")
+
+    batchIds.sorted
+  }
+
+  private def pathToBatchId(path: Path): Long = {
+    path.getName.toLong
+  }
+
+  def keepNEntries(minLogEntriesToMaintain: Int): Unit = {
+    val thresholdBatchId = findThresholdBatchId(minLogEntriesToMaintain)
+    if (thresholdBatchId != -1) {
+      deleteSchemaFiles(thresholdBatchId)
+      deleteMetadataFiles(thresholdBatchId)
+    }
+  }
+
+  private def findThresholdBatchId(minLogEntriesToMaintain: Int): Long = {
+    val metadataFiles = listBatches
+    if (metadataFiles.length > minLogEntriesToMaintain) {
+      metadataFiles.sorted.take(metadataFiles.length - minLogEntriesToMaintain).last + 1
+    } else {
+      -1
+    }
+  }
+
+  private def deleteSchemaFiles(thresholdBatchId: Long): Unit = {
+    val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath)
+    val filesBeforeThreshold = schemaFiles.filter { path =>
+      val batchIdInPath = path.getName.split("_").head.toLong
+      batchIdInPath < thresholdBatchId
+    }
+    filesBeforeThreshold.foreach { path =>
+      fm.delete(path)
+    }
+  }
+
+  private def deleteMetadataFiles(thresholdBatchId: Long): Unit = {
+    val metadataFiles = fm.list(metadataDirPath, batchFilesFilter)
+    metadataFiles.foreach { batchFile =>
+      val batchId = pathToBatchId(batchFile.getPath)
+      if (batchId < thresholdBatchId) {
+        fm.delete(batchFile.getPath)
+      }
+    }
+  }
+
+  private[sql] def listSchemaFiles(): Array[Path] = {
+    fm.list(stateSchemaPath).sorted.map(_.getPath)
+  }
+
+  private[sql] def listMetadataFiles(): Array[Path] = {
+    fm.list(metadataDirPath, batchFilesFilter).sorted.map(_.getPath)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index d55a16a60eac0..6a1edf533dfea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -983,6 +983,57 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  private def getOperatorStateMetadataFileManager(
+      stateCheckpointPath: Path,
+      stateSchemaPath: Path): OperatorStateMetadataV2FileManager = {
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    new OperatorStateMetadataV2FileManager(stateCheckpointPath, stateSchemaPath, hadoopConf)
+  }
+
+  private def getStateSchemaPath(stateCheckpointPath: Path): Path = {
+    new Path(stateCheckpointPath, "default/_metadata/schema")
+  }
+
+  test("transformWithState - verify that metadata and schema logs are purged") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+      SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") {
+      withTempDir { chkptDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream,
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "2")),
+          StopStream,
+          StartStream(checkpointLocation = chkptDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(),
+          StopStream
+        )
+        val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0")
+        val stateSchemaPath = getStateSchemaPath(stateOpIdPath)
+
+        val fm = getOperatorStateMetadataFileManager(
+          stateOpIdPath,
+          stateSchemaPath)
+        assert(fm.listSchemaFiles().length == 1)
+        assert(fm.listMetadataFiles().length == 1)
+      }
+    }
+  }
 }
 
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
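For context, below is an illustrative sketch (not part of the patch) of how the new OperatorStateMetadataV2FileManager is driven. The directory layout mirrors the one constructed in the test above; the /tmp checkpoint location is hypothetical. keepNEntries computes a threshold batch id and deletes operator metadata and state schema files written before it.

// Illustrative usage sketch; the checkpoint root below is a made-up example path.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager

val checkpointRoot = "/tmp/streaming-checkpoint"                 // hypothetical checkpoint location
// State checkpoint directory of operator 0: <checkpoint>/state/0
val stateOpIdPath = new Path(new Path(checkpointRoot, "state"), "0")
// State schema files live under <operator dir>/default/_metadata/schema
val stateSchemaPath = new Path(stateOpIdPath, "default/_metadata/schema")

val fileManager = new OperatorStateMetadataV2FileManager(
  stateOpIdPath, stateSchemaPath, new Configuration())
// Keep only the most recent entry; older metadata and schema files are deleted,
// matching what IncrementalExecution does per run with minBatchesToRetain.
fileManager.keepNEntries(1)

In the patch itself this call is issued from IncrementalExecution.purgeOldest (synchronously or via purgeOldestAsync, depending on ASYNC_LOG_PURGE) for every StateStoreWriter whose operator metadata version is 2.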