diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 7f54a77c94a0f..8c3dd81d8542f 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3785,6 +3785,12 @@ ], "sqlState" : "42802" }, + "STATE_STORE_DUPLICATE_STATE_VARIABLE_DEFINED" : { + "message" : [ + "State variable with name <stateName> has already been defined in the StatefulProcessor." + ], + "sqlState" : "42802" + }, "STATE_STORE_HANDLE_NOT_INITIALIZED" : { "message" : [ "The handle has not been initialized for this StatefulProcessor.", @@ -3804,12 +3810,24 @@ ], "sqlState" : "42802" }, + "STATE_STORE_INVALID_CONFIG_AFTER_RESTART" : { + "message" : [ + "<newConfig> is not equal to <oldConfig>. Please set <configName> to <oldConfig>, or restart with a new checkpoint directory." + ], + "sqlState" : "42K06" + }, "STATE_STORE_INVALID_PROVIDER" : { "message" : [ "The given State Store Provider does not extend org.apache.spark.sql.execution.streaming.state.StateStoreProvider." ], "sqlState" : "42K06" }, + "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE" : { + "message" : [ + "Cannot change <stateName> to <newType> between query restarts. Please set <stateName> to <oldType>, or restart with a new checkpoint directory." + ], + "sqlState" : "42K06" + }, "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : { "message" : [ "The streaming query failed to validate written state for key row.", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 4ce762714e864..6762b1d120d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -243,6 +243,12 @@ class IncrementalExecution( checkpointLocation, tws.getStateInfo.operatorId.toString)) val operatorStateMetadataLog = new OperatorStateMetadataLog(sparkSession, metadataPath.toString) + // check if old metadata is present.
if it is, validate with this metadata + operatorStateMetadataLog.getLatest() match { + case Some((_, oldMetadata)) => + tws.validateMetadatas(oldMetadata, metadata) + case None => + } operatorStateMetadataLog.add(currentBatchId, metadata) case _ => val metadataWriter = new OperatorStateMetadataWriter(new Path( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 65b435b5c692c..ec2df52ad6e84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -302,18 +302,35 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi private[sql] val columnFamilySchemaUtils = ColumnFamilySchemaUtilsV1 + private[sql] val stateVariableUtils = TransformWithStateVariableUtils + // Because this is only happening on the driver side, there is only // one task modifying and accessing this map at a time private[sql] val columnFamilySchemas: mutable.Map[String, ColumnFamilySchema] = new mutable.HashMap[String, ColumnFamilySchema]() + private[sql] val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] = + new mutable.HashMap[String, TransformWithStateVariableInfo]() + def getColumnFamilySchemas: Map[String, ColumnFamilySchema] = columnFamilySchemas.toMap + def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap + + def checkIfDuplicateVariableDefined(stateName: String): Unit = { + if (columnFamilySchemas.contains(stateName)) { + throw StateStoreErrors.duplicateStateVariableDefined(stateName) + } + } + override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, false) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. + getValueState(stateName, ttlEnabled = false) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ValueState[T]] } @@ -324,7 +341,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getValueStateSchema(stateName, keyExprEnc, valEncoder, true) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. + getValueState(stateName, ttlEnabled = true) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ValueState[T]] } @@ -332,7 +353,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getListStateSchema(stateName, keyExprEnc, valEncoder, false) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. 
+ getListState(stateName, ttlEnabled = false) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ListState[T]] } @@ -343,7 +368,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_list_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getListStateSchema(stateName, keyExprEnc, valEncoder, true) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. + getListState(stateName, ttlEnabled = true) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[ListState[T]] } @@ -354,7 +383,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. + getMapState(stateName, ttlEnabled = false) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[MapState[K, V]] } @@ -366,7 +399,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi verifyStateVarOperations("get_map_state", PRE_INIT) val colFamilySchema = columnFamilySchemaUtils. getMapStateSchema(stateName, keyExprEnc, valEncoder, userKeyEnc, true) + checkIfDuplicateVariableDefined(stateName) columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = stateVariableUtils. + getMapState(stateName, ttlEnabled = true) + stateVariableInfos.put(stateName, stateVariableInfo) null.asInstanceOf[MapState[K, V]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 0f43256e3b6c2..56a1387701a4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -21,9 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JString} import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.json4s.JString +import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.broadcast.Broadcast @@ -122,6 +123,12 @@ case class TransformWithStateExec( columnFamilySchemas } + private def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = { + val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos + closeProcessorHandle() + stateVariableInfos + } + /** * This method is used for the driver-side stateful processor after we * have collected all the necessary schemas. 
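For context on how the collected state-variable infos end up in the operator metadata, here is a minimal, self-contained json4s sketch of serializing the registered variables into the operator properties JSON and reading them back as a Map, which is the shape the validation code in the next hunk works with. The object and case-class names below are illustrative stand-ins, not the classes added by this patch.

```scala
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

object StateVariableInfoSketch {
  // Illustrative stand-in for the per-variable record; the real class also models the
  // variable type as an Enumeration (ValueState / ListState / MapState).
  case class VariableInfo(stateName: String, stateVariableType: String, ttlEnabled: Boolean) {
    def jsonValue: JValue =
      ("stateName" -> JString(stateName)) ~
        ("stateVariableType" -> JString(stateVariableType)) ~
        ("ttlEnabled" -> JBool(ttlEnabled))
  }

  def main(args: Array[String]): Unit = {
    // Variables registered on the driver handle, keyed by name so duplicates can be rejected.
    val infos = Map("countState" -> VariableInfo("countState", "ValueState", ttlEnabled = false))

    // Serialize them into the operator properties next to timeMode/outputMode.
    val props: JValue =
      ("timeMode" -> JString("NoTime")) ~
        ("outputMode" -> JString("Update")) ~
        ("stateVariables" -> JArray(infos.values.map(_.jsonValue).toList))
    val json = compact(render(props))
    println(json)

    // On restart the persisted JSON is parsed back into a Map for field-by-field comparison.
    implicit val formats: Formats = DefaultFormats
    val parsed = parse(json).extract[Map[String, Any]]
    assert(parsed("outputMode") == "Update")
  }
}
```

Running this prints a JSON layout matching the `_operatorProperties` value asserted in the updated TransformWithStateSuite test further down.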
@@ -431,6 +438,66 @@ case class TransformWithStateExec( } } + private def checkOperatorPropEquality[T]( + fieldName: String, + oldMetadataV2: OperatorStateMetadataV2, + newMetadataV2: OperatorStateMetadataV2): Unit = { + val oldJsonString = oldMetadataV2.operatorPropertiesJson + val newJsonString = newMetadataV2.operatorPropertiesJson + // verify that timeMode, outputMode are the same + implicit val formats: DefaultFormats.type = DefaultFormats + val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]] + val newJsonProps = JsonMethods.parse(newJsonString).extract[Map[String, Any]] + val oldProp = oldJsonProps(fieldName).asInstanceOf[T] + val newProp = newJsonProps(fieldName).asInstanceOf[T] + if (oldProp != newProp) { + throw StateStoreErrors.invalidConfigChangedAfterRestart( + fieldName, + oldProp.toString, + newProp.toString + ) + } + } + + private def checkStateVariableEquality(oldMetadataV2: OperatorStateMetadataV2): Unit = { + val oldJsonString = oldMetadataV2.operatorPropertiesJson + implicit val formats: DefaultFormats.type = DefaultFormats + val oldJsonProps = JsonMethods.parse(oldJsonString).extract[Map[String, Any]] + // compare state variable infos + val oldStateVariableInfos = oldJsonProps("stateVariables"). + asInstanceOf[List[Map[String, Any]]] + .map(TransformWithStateVariableInfo.fromMap) + val newStateVariableInfos = getStateVariableInfos() + oldStateVariableInfos.foreach { oldInfo => + val newInfo = newStateVariableInfos.get(oldInfo.stateName) + newInfo match { + case Some(stateVarInfo) => + if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) { + throw StateStoreErrors.invalidVariableTypeChange( + stateVarInfo.stateName, + oldInfo.stateVariableType.toString, + stateVarInfo.stateVariableType.toString + ) + } + case None => + } + } + } + + def validateMetadatas( + oldMetadata: OperatorStateMetadata, + newMetadata: OperatorStateMetadata): Unit = { + (oldMetadata, newMetadata) match { + case ( + oldMetadataV2: OperatorStateMetadataV2, + newMetadataV2: OperatorStateMetadataV2) => + checkOperatorPropEquality[String]("timeMode", oldMetadataV2, newMetadataV2) + checkOperatorPropEquality[String]("outputMode", oldMetadataV2, newMetadataV2) + checkStateVariableEquality(oldMetadataV2) + case (_, _) => + } + } + /** Metadata of this stateful operator and its states stores. */ override def operatorStateMetadata( stateSchemaPaths: Array[String] = Array.empty): OperatorStateMetadata = { @@ -443,7 +510,10 @@ case class TransformWithStateExec( val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ - ("outputMode" -> JString(outputMode.toString)) + ("outputMode" -> JString(outputMode.toString)) ~ + ("stateVariables" -> getStateVariableInfos().map { case (_, stateInfo) => + stateInfo.jsonValue + }.arr) val json = compact(render(operatorPropertiesJson)) OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala new file mode 100644 index 0000000000000..3c8f0796c322f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.json4s.DefaultFormats +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType + +// Enum of possible State Variable types +object StateVariableType extends Enumeration { + type StateVariableType = Value + val ValueState, ListState, MapState = Value +} + +case class TransformWithStateVariableInfo( + stateName: String, + stateVariableType: StateVariableType, + ttlEnabled: Boolean) { + def jsonValue: JValue = { + ("stateName" -> JString(stateName)) ~ + ("stateVariableType" -> JString(stateVariableType.toString)) ~ + ("ttlEnabled" -> JBool(ttlEnabled)) + } + + def json: String = { + compact(render(jsonValue)) + } +} + +object TransformWithStateVariableInfo { + + def fromJson(json: String): TransformWithStateVariableInfo = { + implicit val formats: DefaultFormats.type = DefaultFormats + val parsed = JsonMethods.parse(json).extract[Map[String, Any]] + fromMap(parsed) + } + + def fromMap(map: Map[String, Any]): TransformWithStateVariableInfo = { + val stateName = map("stateName").asInstanceOf[String] + val stateVariableType = StateVariableType.withName( + map("stateVariableType").asInstanceOf[String]) + val ttlEnabled = map("ttlEnabled").asInstanceOf[Boolean] + TransformWithStateVariableInfo(stateName, stateVariableType, ttlEnabled) + } +} +object TransformWithStateVariableUtils { + def getValueState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = { + TransformWithStateVariableInfo(stateName, StateVariableType.ValueState, ttlEnabled) + } + + def getListState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = { + TransformWithStateVariableInfo(stateName, StateVariableType.ListState, ttlEnabled) + } + + def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = { + TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala index e7e94ae193664..4bd49d41c14a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaV3File.scala @@ -26,7 +26,6 @@ import scala.io.{Source => IOSource} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.CheckpointFileManager import 
org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVersion @@ -40,7 +39,7 @@ import org.apache.spark.sql.execution.streaming.MetadataVersionUtil.validateVers */ class StateSchemaV3File( hadoopConf: Configuration, - path: String) extends Logging { + path: String) { val metadataPath = new Path(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 053112ebaa9ec..39937b6cdd6e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -299,8 +299,8 @@ object KeyStateEncoderSpec { asInstanceOf[List[_]].map(_.asInstanceOf[BigInt].toInt) RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) case "PrefixKeyScanStateEncoderSpec" => - val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt].toInt - PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey) + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt] + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, numColsPrefixKey.toInt) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 4ac813291c00b..6057832aca694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -173,7 +173,49 @@ object StateStoreErrors { StateStoreProviderDoesNotSupportFineGrainedReplay = { new StateStoreProviderDoesNotSupportFineGrainedReplay(inputClass) } + + def invalidConfigChangedAfterRestart(configName: String, oldConfig: String, newConfig: String): + StateStoreInvalidConfigAfterRestart = { + new StateStoreInvalidConfigAfterRestart(configName, oldConfig, newConfig) + } + + def duplicateStateVariableDefined(stateName: String): + StateStoreDuplicateStateVariableDefined = { + new StateStoreDuplicateStateVariableDefined(stateName) + } + + def invalidVariableTypeChange(stateName: String, oldType: String, newType: String): + StateStoreInvalidVariableTypeChange = { + new StateStoreInvalidVariableTypeChange(stateName, oldType, newType) + } } +class StateStoreDuplicateStateVariableDefined(stateName: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_DUPLICATE_STATE_VARIABLE_DEFINED", + messageParameters = Map( + "stateName" -> stateName + ) + ) + +class StateStoreInvalidConfigAfterRestart(configName: String, oldConfig: String, newConfig: String) + extends SparkUnsupportedOperationException( + errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + messageParameters = Map( + "configName" -> configName, + "oldConfig" -> oldConfig, + "newConfig" -> newConfig + ) + ) + +class StateStoreInvalidVariableTypeChange(stateName: String, oldType: String, newType: String) + extends SparkUnsupportedOperationException( + errorClass = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE", + messageParameters = Map( + "stateName" -> stateName, + "oldType" -> oldType, + "newType" -> newType + ) + ) class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index fd621fc1161f1..18cbb53d3b29d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, KEY_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, TestClass} +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, ColumnFamilySchema, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, OperatorInfoV1, OperatorStateMetadataV2, POJOTestClass, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StatefulProcessorCannotPerformOperationWithInvalidHandleState, StateSchemaV3File, StateStoreInvalidConfigAfterRestart, StateStoreInvalidVariableTypeChange, StateStoreMetadataV2, StateStoreMultipleColumnFamiliesNotSupportedException, StateStoreValueSchemaNotCompatible, TestClass} import org.apache.spark.sql.functions.timestamp_seconds import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -64,6 +64,29 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S } } +// Class to test that changing between Value and List State fails +// between query runs +class RunningCountListStatefulProcessor + extends StatefulProcessor[String, String, (String, String)] + with Logging { + @transient protected var _countState: ListState[Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getListState[Long]( + "countState", Encoders.scalaLong) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + Iterator.empty + } +} + class RunningCountStatefulProcessorInt extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Int] = _ @@ -884,7 +907,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest val columnFamilySchemas = fetchColumnFamilySchemas(chkptDir.getCanonicalPath, 0) assert(columnFamilySchemas.length == 1) - val expected = ColumnFamilySchemaV1( "countState", new StructType().add("key", @@ -951,7 +973,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } test("transformWithState - verify that OperatorStateMetadataV2" + - " file is being written correctly") { + " integrates with state-metadata source") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> @@ -980,17 +1002,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest Row(0, "transformWithStateExec", "default", 5, 0L, 0L), Row(0, 
"transformWithStateExec", "default", 5, 1L, 1L) )) + // need line to be unbroken, otherwise the test will fail. + // scalastyle:off + val expectedAnswer = """{"timeMode":"NoTime","outputMode":"Update","stateVariables":[{"stateName":"countState","stateVariableType":"ValueState","ttlEnabled":false}]}""" + // scalastyle:on checkAnswer(df.select(df.metadataColumn("_operatorProperties")), Seq( - Row("""{"timeMode":"NoTime","outputMode":"Update"}"""), - Row("""{"timeMode":"NoTime","outputMode":"Update"}""") + Row(expectedAnswer), + Row(expectedAnswer) ) ) } } } - test("transformWithState - verify that metadata logs are purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, @@ -1176,6 +1201,77 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("test that different outputMode after query restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Append()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreInvalidConfigAfterRestart] { t => + assert(t.getMessage.contains("outputMode")) + assert(t.getMessage.contains("is not equal")) + } + ) + } + } + } + + test("test that changing between different state variable types fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountListStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreInvalidVariableTypeChange] { t => + assert(t.getMessage.contains("Cannot change countState")) + } + ) + } + } + } + test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,