18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3785,6 +3785,12 @@
],
"sqlState" : "42802"
},
"STATE_STORE_DUPLICATE_STATE_VARIABLE_DEFINED" : {
"message" : [
"State variable with name <stateName> has already been defined in the StatefulProcessor."
],
"sqlState" : "42802"
},
"STATE_STORE_HANDLE_NOT_INITIALIZED" : {
"message" : [
"The handle has not been initialized for this StatefulProcessor.",
@@ -3804,12 +3810,24 @@
],
"sqlState" : "42802"
},
"STATE_STORE_INVALID_CONFIG_AFTER_RESTART" : {
"message" : [
"<configName> <oldConfig> is not equal to <newConfig>. Please set <configName> to <oldConfig>, or restart with a new checkpoint directory."
],
"sqlState" : "42K06"
},
"STATE_STORE_INVALID_PROVIDER" : {
"message" : [
"The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.StateStoreProvider."
],
"sqlState" : "42K06"
},
"STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE" : {
"message" : [
"Cannot change <stateName> to <newType> between query restarts. Please set <stateName> to <oldType>, or restart with a new checkpoint directory."
],
"sqlState" : "42K06"
},
"STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
"message" : [
"The streaming query failed to validate written state for key row.",
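The three new error conditions above are surfaced through helper methods on StateStoreErrors, which later hunks in this diff call (duplicateStateVariableDefined, invalidConfigChangedAfterRestart, invalidVariableTypeChange). A minimal sketch of how such helpers could map onto these message templates follows; the exception type and exact signatures here are assumptions for illustration, not the PR's verbatim code (the real helpers presumably build Spark exceptions carrying these error classes):

object StateStoreErrorsSketch {
  // Hypothetical helpers mirroring the call sites later in this diff.
  def duplicateStateVariableDefined(stateName: String): RuntimeException =
    new RuntimeException(
      s"[STATE_STORE_DUPLICATE_STATE_VARIABLE_DEFINED] State variable with name " +
        s"$stateName has already been defined in the StatefulProcessor.")

  def invalidConfigChangedAfterRestart(
      configName: String,
      oldConfig: String,
      newConfig: String): RuntimeException =
    new RuntimeException(
      s"[STATE_STORE_INVALID_CONFIG_AFTER_RESTART] Cannot change $configName from " +
        s"$oldConfig to $newConfig between query restarts. Please set $configName " +
        s"to $oldConfig, or restart with a new checkpoint directory.")

  def invalidVariableTypeChange(
      stateName: String,
      oldType: String,
      newType: String): RuntimeException =
    new RuntimeException(
      s"[STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE] Cannot change $stateName to " +
        s"$newType between query restarts. Please set $stateName to $oldType, " +
        s"or restart with a new checkpoint directory.")
}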
@@ -243,6 +243,12 @@ class IncrementalExecution(
checkpointLocation, tws.getStateInfo.operatorId.toString))
val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession,
metadataPath.toString)
// If metadata from a previous run is present, validate the new metadata against it.
operatorStateMetadataLog.getLatest() match {
case Some((_, oldMetadata)) =>
tws.validateMetadatas(oldMetadata, metadata)
case None => // first run of this query, nothing to validate against
}
operatorStateMetadataLog.add(currentBatchId, metadata)
case _ =>
val metadataWriter = new OperatorStateMetadataWriter(new Path(
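The hunk above wires restart validation into the metadata commit path for transformWithState: before appending the current batch's metadata, the latest entry from any previous run is validated against it. A condensed, self-contained sketch of that validate-then-append pattern, using simplified stand-ins (assumed types, not the real OperatorStateMetadataV2 or OperatorStateMetadataLog classes):

case class MetadataSketch(timeMode: String, outputMode: String)

class MetadataLogSketch {
  private var entries = List.empty[(Long, MetadataSketch)]
  def getLatest(): Option[(Long, MetadataSketch)] = entries.headOption
  def add(batchId: Long, metadata: MetadataSketch): Unit =
    entries = (batchId, metadata) :: entries
}

object ValidateThenAppend {
  def commit(log: MetadataLogSketch, batchId: Long, newMeta: MetadataSketch): Unit = {
    log.getLatest() match {
      case Some((_, oldMeta)) =>
        // Restart path: immutable operator properties must match the prior run.
        require(oldMeta.timeMode == newMeta.timeMode, "timeMode changed across restart")
        require(oldMeta.outputMode == newMeta.outputMode, "outputMode changed across restart")
      case None => // first run: nothing to validate against
    }
    log.add(batchId, newMeta)
  }
}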
@@ -302,18 +302,35 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi

private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1

private[sql] val stateVariableUtils = TransformWithStateVariableUtils

// Because this handle is used only on the driver side, at most one task
// modifies and reads these maps at a time
private[sql] val columnFamilySchemas: mutable.Map[String, ColumnFamilySchema] =
new mutable.HashMap[String, ColumnFamilySchema]()

private[sql] val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] =
new mutable.HashMap[String, TransformWithStateVariableInfo]()

def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap

def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap

def checkIfDuplicateVariableDefined(stateName: String): Unit = {
if (columnFamilySchemas.contains(stateName)) {
throw StateStoreErrors.duplicateStateVariableDefined(stateName)
}
}

override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getValueState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ValueState[T]]
}

@@ -324,15 +341,23 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getValueState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ValueState[T]]
}

override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getListState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

@@ -343,7 +368,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, true)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getListState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

@@ -354,7 +383,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getMapState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

@@ -366,7 +399,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = columnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = stateVariableUtils.
getMapState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

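Each getter in this class runs on the driver at planning time (PRE_INIT), records both the column family schema and a TransformWithStateVariableInfo for the variable, and returns null because no state is actually materialized on the driver. The new checkIfDuplicateVariableDefined guard means a StatefulProcessor that registers two variables under one name now fails fast instead of silently overwriting the first schema. A standalone sketch of that registration pattern, with simplified stand-in types rather than the real handle API:

import scala.collection.mutable

object DriverHandleSketch {
  // stateName -> schema (simplified to a string here)
  private val schemas = mutable.HashMap[String, String]()
  // stateName -> (variable type, ttlEnabled)
  private val varInfos = mutable.HashMap[String, (String, Boolean)]()

  def registerValueState(stateName: String, ttlEnabled: Boolean): Unit = {
    if (schemas.contains(stateName)) {
      throw new IllegalStateException(
        s"State variable with name $stateName has already been defined in the StatefulProcessor.")
    }
    schemas.put(stateName, s"value-state-schema($stateName)")
    varInfos.put(stateName, ("ValueState", ttlEnabled))
  }

  def main(args: Array[String]): Unit = {
    registerValueState("countState", ttlEnabled = false)
    registerValueState("countState", ttlEnabled = false) // throws: duplicate name
  }
}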
@@ -21,9 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JString}
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.JString
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.broadcast.Broadcast
@@ -122,6 +123,12 @@ case class TransformWithStateExec(
columnFamilySchemas
}

private def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos
closeProcessorHandle()
stateVariableInfos
}

/**
* This method is used for the driver-side stateful processor after we
* have collected all the necessary schemas.
@@ -431,6 +438,66 @@
}
}

private def checkOperatorPropEquality[T](
fieldName: String,
oldMetadataV2: OperatorStateMetadataV2,
newMetadataV2: OperatorStateMetadataV2): Unit = {
val oldJsonString = oldMetadataV2.operatorPropertiesJson
val newJsonString = newMetadataV2.operatorPropertiesJson
// parse both property maps and verify the given field (timeMode, outputMode) is unchanged
implicit val formats: DefaultFormats.type = DefaultFormats
val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]]
val newJsonProps = JsonMethods.parse(newJsonString).extract[Map[String, Any]]
val oldProp = oldJsonProps(fieldName).asInstanceOf[T]
val newProp = newJsonProps(fieldName).asInstanceOf[T]
if (oldProp != newProp) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
fieldName,
oldProp.toString,
newProp.toString
)
}
}

private def checkStateVariableEquality(oldMetadataV2: OperatorStateMetadataV2): Unit = {
val oldJsonString = oldMetadataV2.operatorPropertiesJson
implicit val formats: DefaultFormats.type = DefaultFormats
val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]]
// compare state variable infos
val oldStateVariableInfos = oldJsonProps("stateVariables").
asInstanceOf[List[Map[String, Any]]]
.map(TransformWithStateVariableInfo.fromMap)
val newStateVariableInfos = getStateVariableInfos()
oldStateVariableInfos.foreach { oldInfo =>
val newInfo = newStateVariableInfos.get(oldInfo.stateName)
newInfo match {
case Some(stateVarInfo) =>
if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
throw StateStoreErrors.invalidVariableTypeChange(
stateVarInfo.stateName,
oldInfo.stateVariableType.toString,
stateVarInfo.stateVariableType.toString
)
}
case None => // variable no longer used after the restart; nothing to validate
}
}
}

def validateMetadatas(
oldMetadata: OperatorStateMetadata,
newMetadata: OperatorStateMetadata): Unit = {
(oldMetadata, newMetadata) match {
case (
oldMetadataV2: OperatorStateMetadataV2,
newMetadataV2: OperatorStateMetadataV2) =>
checkOperatorPropEquality[String]("timeMode", oldMetadataV2, newMetadataV2)
checkOperatorPropEquality[String]("outputMode", oldMetadataV2, newMetadataV2)
checkStateVariableEquality(oldMetadataV2)
case (_, _) => // only OperatorStateMetadataV2 carries operator properties to validate
}
}

/** Metadata of this stateful operator and its state stores. */
override def operatorStateMetadata(
stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = {
@@ -443,7 +510,10 @@

val operatorPropertiesJson: JValue =
("timeMode" -> JString(timeMode.toString)) ~
("outputMode" -> JString(outputMode.toString))
("outputMode" -> JString(outputMode.toString)) ~
("stateVariables" -> getStateVariableInfos().map { case (_, stateInfo) =>
stateInfo.jsonValue
}.arr)

val json = compact(render(operatorPropertiesJson))
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
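With the stateVariables field added above, the operator properties JSON embedded in OperatorStateMetadataV2 carries everything validateMetadatas needs on restart. A small runnable sketch of what that payload looks like, using json4s the same way as the code above; the timeMode/outputMode/variable values are illustrative assumptions, not output captured from a real query:

import org.json4s.JString
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object OperatorPropsJsonSketch {
  def main(args: Array[String]): Unit = {
    val props =
      ("timeMode" -> JString("ProcessingTime")) ~
      ("outputMode" -> JString("Update")) ~
      ("stateVariables" -> List(
        ("stateName" -> "countState") ~
          ("stateVariableType" -> "ValueState") ~
          ("ttlEnabled" -> false)))
    println(compact(render(props)))
    // {"timeMode":"ProcessingTime","outputMode":"Update",
    //  "stateVariables":[{"stateName":"countState","stateVariableType":"ValueState","ttlEnabled":false}]}
  }
}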
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType

// Enum of possible State Variable types
object StateVariableType extends Enumeration {
type StateVariableType = Value
val ValueState, ListState, MapState = Value
}

case class TransformWithStateVariableInfo(
stateName: String,
stateVariableType: StateVariableType,
ttlEnabled: Boolean) {
def jsonValue: JValue = {
("stateName" -> JString(stateName)) ~
("stateVariableType" -> JString(stateVariableType.toString)) ~
("ttlEnabled" -> JBool(ttlEnabled))
}

def json: String = {
compact(render(jsonValue))
}
}

object TransformWithStateVariableInfo {

def fromJson(json: String): TransformWithStateVariableInfo = {
implicit val formats: DefaultFormats.type = DefaultFormats
val parsed = JsonMethods.parse(json).extract[Map[String, Any]]
fromMap(parsed)
}

def fromMap(map: Map[String, Any]): TransformWithStateVariableInfo = {
val stateName = map("stateName").asInstanceOf[String]
val stateVariableType = StateVariableType.withName(
map("stateVariableType").asInstanceOf[String])
val ttlEnabled = map("ttlEnabled").asInstanceOf[Boolean]
TransformWithStateVariableInfo(stateName, stateVariableType, ttlEnabled)
}
}

object TransformWithStateVariableUtils {
def getValueState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.ValueState, ttlEnabled)
}

def getListState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.ListState, ttlEnabled)
}

def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled)
}
}
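The companion's fromJson/fromMap mirror jsonValue field for field, so a variable's info round-trips through its JSON form. A quick sketch of that round trip using the classes defined in this new file:

object VariableInfoRoundTrip {
  def main(args: Array[String]): Unit = {
    val info = TransformWithStateVariableUtils.getListState("itemsState", ttlEnabled = true)
    val json = info.json
    // {"stateName":"itemsState","stateVariableType":"ListState","ttlEnabled":true}
    val restored = TransformWithStateVariableInfo.fromJson(json)
    assert(restored == info) // case class equality holds after the round trip
  }
}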
@@ -26,7 +26,6 @@ import scala.io.{Source => IOSource}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion

@@ -40,7 +39,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers
*/
class StateSchemaV3File(
hadoopConf: Configuration,
path: String) extends Logging {
path: String) {

val metadataPath = new Path(path)

@@ -299,8 +299,8 @@ object KeyStateEncoderSpec {
asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt)
RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals)
case "PrefixKeyScanStateEncoderSpec" =>
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt].toInt
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey)
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt]
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey.toInt)
}
}
}