@@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, OperatorStateMetadataLog}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,7 +46,8 @@ case class StateMetadataTableEntry(
numPartitions: Int,
minBatchId: Long,
maxBatchId: Long,
numColsPrefixKey: Int) {
numColsPrefixKey: Int,
operatorPropertiesJson: String) {
def toRow(): InternalRow = {
new GenericInternalRow(
Array[Any](operatorId,
@@ -55,7 +56,8 @@
numPartitions,
minBatchId,
maxBatchId,
numColsPrefixKey))
numColsPrefixKey,
UTF8String.fromString(operatorPropertiesJson)))
}
}

@@ -110,7 +112,14 @@ class StateMetadataTable extends Table with SupportsRead with SupportsMetadataCo
override def comment: String = "Number of columns in prefix key of the state store instance"
}

override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn)
private object OperatorPropertiesColumn extends MetadataColumn {
override def name: String = "_operatorProperties"
override def dataType: DataType = StringType
override def comment: String = "Json string storing operator properties"
}

override val metadataColumns: Array[MetadataColumn] =
Array(NumColsPrefixKeyColumn, OperatorPropertiesColumn)
}
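
For context, a minimal read-side sketch (not part of this change) of how the new column surfaces: _operatorProperties is a metadata column, so it only appears when selected explicitly. The short name "state-metadata", the non-metadata column names, and the checkpoint path below are assumptions based on the existing source.

// Read-side sketch; `spark` stands for an active SparkSession.
val stateMetadataDf = spark.read
  .format("state-metadata")                 // assumed registered short name of this source
  .load("/tmp/streaming-query/checkpoint")  // illustrative checkpoint root

// Metadata columns are excluded from the default projection, so select them explicitly.
stateMetadataDf
  .select("operatorId", "operatorName", "maxBatchId", "_operatorProperties")
  .show(truncate = false)
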

case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition
@@ -188,28 +197,55 @@ class StateMetadataPartitionReader(
} else Array.empty
}

private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
private[sql] def allOperatorStateMetadata: Array[(OperatorStateMetadata, Long)] = {
val stateDir = new Path(checkpointLocation, "state")
val opIds = fileManager
.list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
opIds.map { opId =>
new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
opIds.flatMap { opId =>
val operatorIdPath = new Path(stateDir, opId.toString)
// Check whether the OperatorStateMetadataV2 path exists; if it does, read it.
// Otherwise, fall back to OperatorStateMetadataV1.
val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataFilePath(operatorIdPath)
if (fileManager.exists(operatorStateMetadataV2Path)) {
val operatorStateMetadataLog = new OperatorStateMetadataLog(
hadoopConf, operatorStateMetadataV2Path.toString)
operatorStateMetadataLog.listBatchesOnDisk.flatMap { batchId =>
operatorStateMetadataLog.get(batchId).map((_, batchId))
}
} else {
Array((new OperatorStateMetadataReader(operatorIdPath, hadoopConf).read(), -1L))
}
}
}

private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
allOperatorStateMetadata.flatMap { operatorStateMetadata =>
require(operatorStateMetadata.version == 1)
val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId,
operatorStateMetadataV1.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
if (batchIds.nonEmpty) batchIds.head else -1,
if (batchIds.nonEmpty) batchIds.last else -1,
stateStoreMetadata.numColsPrefixKey
)
allOperatorStateMetadata.flatMap { case (operatorStateMetadata, batchId) =>
require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
operatorStateMetadata match {
case v1: OperatorStateMetadataV1 =>
v1.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(v1.operatorInfo.operatorId,
v1.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
if (batchIds.nonEmpty) batchIds.head else -1,
if (batchIds.nonEmpty) batchIds.last else -1,
stateStoreMetadata.numColsPrefixKey,
""
)
}
case v2: OperatorStateMetadataV2 =>
v2.stateStoreInfo.map { stateStoreMetadata =>
StateMetadataTableEntry(v2.operatorInfo.operatorId,
v2.operatorInfo.operatorName,
stateStoreMetadata.storeName,
stateStoreMetadata.numPartitions,
batchId,
batchId,
stateStoreMetadata.numColsPrefixKey,
v2.operatorPropertiesJson
)
}
}
}
}.iterator
@@ -17,68 +17,70 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema._
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec}

trait ColumnFamilySchemaUtils {
def getValueStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchema


def getListStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchema


def getMapStateSchema[K, V](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
hasTtl: Boolean): ColumnFamilySchema
}

object ColumnFamilySchemaUtilsV1 extends ColumnFamilySchemaUtils {

def getValueStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
getKeySchema(keyEncoder.schema),
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}

def getListStateSchema[T](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
valEncoder: Encoder[T],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
ColumnFamilySchemaV1(
stateName,
KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
getKeySchema(keyEncoder.schema),
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
}

def getMapStateSchema[K, V](
stateName: String,
keyEncoder: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
hasTtl: Boolean): ColumnFamilySchemaV1 = {
new ColumnFamilySchemaV1(
val compositeKeySchema = getCompositeKeySchema(keyEncoder.schema, userKeyEnc.schema)
ColumnFamilySchemaV1(
stateName,
COMPOSITE_KEY_ROW_SCHEMA,
if (hasTtl) {
VALUE_ROW_SCHEMA_WITH_TTL
} else {
VALUE_ROW_SCHEMA
},
compositeKeySchema,
getValueSchemaWithTTL(valEncoder.schema, hasTtl),
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1),
Some(userKeyEnc.schema))
}
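
To make the switch from the fixed KEY_ROW_SCHEMA / VALUE_ROW_SCHEMA constants to encoder-derived schemas concrete, here is a hypothetical sketch of the helpers called above. The real implementations live in TransformWithStateKeyValueRowSchema and may differ; only the signatures are inferred from the call sites, and the field names are illustrative.

import org.apache.spark.sql.types.{LongType, StructType}

// Hypothetical shapes of the schema helpers; signatures inferred from usage above.
object SchemaHelpersSketch {
  // Derive the state store key schema from the grouping-key encoder schema.
  def getKeySchema(keyEncSchema: StructType): StructType =
    new StructType().add("key", keyEncSchema)

  // Value schema, with an extra expiration timestamp column when TTL is enabled.
  def getValueSchemaWithTTL(valEncSchema: StructType, hasTtl: Boolean): StructType = {
    val base = new StructType().add("value", valEncSchema)
    if (hasTtl) base.add("ttlExpirationMs", LongType) else base
  }

  // Composite key for MapState: grouping key followed by the user-provided map key.
  def getCompositeKeySchema(keyEncSchema: StructType, userKeySchema: StructType): StructType =
    new StructType().add("key", keyEncSchema).add("userKey", userKeySchema)
}
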
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization
@@ -47,10 +48,25 @@ import org.apache.spark.util.ArrayImplicits._
*
* Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing
* files in a directory always shows the latest files.
* @param hadoopConf Hadoop configuration that is used to read / write metadata files.
* @param path Path to the directory that will be used for writing metadata.
* @param metadataCacheEnabled Whether to cache the batches' metadata in memory.
*/
class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: String)
class HDFSMetadataLog[T <: AnyRef : ClassTag](
hadoopConf: Configuration,
path: String,
val metadataCacheEnabled: Boolean = false)
extends MetadataLog[T] with Logging {

def this(sparkSession: SparkSession, path: String) = {
this(
sparkSession.sessionState.newHadoopConf(),
path,
metadataCacheEnabled = sparkSession.sessionState.conf.getConf(
SQLConf.STREAMING_METADATA_CACHE_ENABLED)
)
}

private implicit val formats: Formats = Serialization.formats(NoTypeHints)

/** Needed to serialize type T into JSON when using Jackson */
@@ -64,15 +80,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
val metadataPath = new Path(path)

protected val fileManager =
CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf())
CheckpointFileManager.create(metadataPath, hadoopConf)

if (!fileManager.exists(metadataPath)) {
fileManager.mkdirs(metadataPath)
}

protected val metadataCacheEnabled: Boolean
= sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_METADATA_CACHE_ENABLED)

/**
* Cache the latest two batches. [[StreamExecution]] usually just accesses the latest two batches
* when committing offsets, this cache will save some file system operations.
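
Construction sketch for the reworked constructors (paths and the spark session below are illustrative): the Configuration-based primary constructor no longer needs a SparkSession, which is what allows OperatorStateMetadataLog to be instantiated from the state metadata reader, while the auxiliary constructor keeps the previous driver-side behaviour, including reading STREAMING_METADATA_CACHE_ENABLED from the session conf.

import org.apache.hadoop.conf.Configuration

// No SparkSession required; metadata caching defaults to off.
val confBackedLog = new HDFSMetadataLog[String](new Configuration(), "/tmp/metadata-log")
confBackedLog.add(0, "first entry")
val firstEntry = confBackedLog.get(0)  // Some("first entry")

// The original API still works and picks up the cache flag from the session conf.
val sessionBackedLog = new HDFSMetadataLog[String](spark, "/tmp/metadata-log")
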
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataWriter}
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -85,6 +85,10 @@ class IncrementalExecution(
.map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
.getOrElse(sparkSession.sessionState.conf.numShufflePartitions)

/**
 * This value dictates the state schema format version that is written for all
 * operators other than TransformWithState.
*/
private val STATE_SCHEMA_DEFAULT_VERSION: Int = 2

/**
@@ -206,11 +210,21 @@
// write out the state schema paths to the metadata file
statefulOp match {
case stateStoreWriter: StateStoreWriter =>
val metadata = stateStoreWriter.operatorStateMetadata()
// TODO: Populate metadata with stateSchemaPaths if metadata version is v2
val metadataWriter = new OperatorStateMetadataWriter(new Path(
checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
metadataWriter.write(metadata)
val metadata = stateStoreWriter.operatorStateMetadata(stateSchemaPaths)
stateStoreWriter match {
case tws: TransformWithStateExec =>
val metadataPath = OperatorStateMetadataV2.metadataFilePath(new Path(
checkpointLocation, tws.getStateInfo.operatorId.toString))
val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession,
metadataPath.toString)
operatorStateMetadataLog.add(currentBatchId, metadata)
case _ =>
val metadataWriter = new OperatorStateMetadataWriter(new Path(
checkpointLocation,
stateStoreWriter.getStateInfo.operatorId.toString),
hadoopConf)
metadataWriter.write(metadata)
}
case _ =>
}
statefulOp
@@ -452,11 +466,11 @@
new Path(checkpointLocation).getParent.toString,
new SerializableConfiguration(hadoopConf))
val opMetadataList = reader.allOperatorStateMetadata
ret = opMetadataList.map { operatorMetadata =>
val metadataInfoV1 = operatorMetadata
.asInstanceOf[OperatorStateMetadataV1]
.operatorInfo
metadataInfoV1.operatorId -> metadataInfoV1.operatorName
ret = opMetadataList.map {
case (OperatorStateMetadataV1(operatorInfo, _), _) =>
operatorInfo.operatorId -> operatorInfo.operatorName
case (OperatorStateMetadataV2(operatorInfo, _, _), _) =>
operatorInfo.operatorId -> operatorInfo.operatorName
}.toMap
} catch {
case e: Exception =>
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataOutputStream

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV1, OperatorStateMetadataV2}
import org.apache.spark.sql.internal.SQLConf


class OperatorStateMetadataLog(
hadoopConf: Configuration,
path: String,
metadataCacheEnabled: Boolean = false)
extends HDFSMetadataLog[OperatorStateMetadata](hadoopConf, path, metadataCacheEnabled) {

def this(sparkSession: SparkSession, path: String) = {
this(
sparkSession.sessionState.newHadoopConf(),
path,
metadataCacheEnabled = sparkSession.sessionState.conf.getConf(
SQLConf.STREAMING_METADATA_CACHE_ENABLED)
)
}

override protected def serialize(metadata: OperatorStateMetadata, out: OutputStream): Unit = {
val fsDataOutputStream = out.asInstanceOf[FSDataOutputStream]
fsDataOutputStream.write(s"v${metadata.version}\n".getBytes(StandardCharsets.UTF_8))
metadata.version match {
case 1 =>
OperatorStateMetadataV1.serialize(fsDataOutputStream, metadata)
case 2 =>
OperatorStateMetadataV2.serialize(fsDataOutputStream, metadata)
}
}

override protected def deserialize(in: InputStream): OperatorStateMetadata = {
// called inside a try-finally where the underlying stream is closed in the caller
// create buffered reader from input stream
val bufferedReader = new BufferedReader(new InputStreamReader(in, UTF_8))
// read first line for version number, in the format "v{version}"
val version = bufferedReader.readLine()
version match {
case "v1" => OperatorStateMetadataV1.deserialize(bufferedReader)
case "v2" => OperatorStateMetadataV2.deserialize(bufferedReader)
}
}
}
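
A usage sketch tying the pieces together (illustrative; in the real code path IncrementalExecution performs the write and the state metadata reader performs the listing). One entry is added per batch under checkpointLocation/state/<operatorId>, and readers enumerate batch ids before fetching entries.

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataV2}

// Hypothetical helper: commit one metadata entry for a batch and read back the latest one.
def commitAndReadLatest(
    spark: SparkSession,
    checkpointLocation: String,
    operatorId: Long,
    batchId: Long,
    metadata: OperatorStateMetadata): Option[OperatorStateMetadata] = {
  val metadataPath = OperatorStateMetadataV2.metadataFilePath(
    new Path(new Path(checkpointLocation, "state"), operatorId.toString))
  val metadataLog = new OperatorStateMetadataLog(spark, metadataPath.toString)
  metadataLog.add(batchId, metadata)  // returns false if this batch was already written
  metadataLog.listBatchesOnDisk.lastOption.flatMap(metadataLog.get)
}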