
Commit 9974341

timers
1 parent d2d7f75 commit 9974341

5 files changed: +68 -23 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 38 additions & 2 deletions
```diff
@@ -22,12 +22,30 @@ import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer,
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
 import org.apache.spark.sql.execution.streaming.state.{AvroEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStoreColFamilySchema}
-import org.apache.spark.sql.types.{NullType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StructField, StructType}
 
 object StateStoreColumnFamilySchemaUtils {
 
   def apply(initializeAvroSerde: Boolean): StateStoreColumnFamilySchemaUtils =
     new StateStoreColumnFamilySchemaUtils(initializeAvroSerde)
+
+
+  def convertForRangeScan(schema: StructType): StructType = {
+    StructType(schema.fields.map { field =>
+      if (isFixedSize(field.dataType)) {
+        // Convert numeric types to BinaryType while preserving nullability
+        field.copy(dataType = BinaryType)
+      } else {
+        field
+      }
+    })
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
+         _: FloatType | _: DoubleType => true
+    case _ => false
+  }
 }
 
 class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Logging {

@@ -119,7 +137,8 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
   def getTtlStateSchema(
       stateName: String,
       keyEncoder: ExpressionEncoder[Any]): StateStoreColFamilySchema = {
-    val ttlKeySchema = getSingleKeyTTLAvroRowSchema(keyEncoder.schema)
+    val ttlKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+      getSingleKeyTTLRowSchema(keyEncoder.schema))
     val ttlValSchema = StructType(
       Array(StructField("__dummy__", NullType)))
     StateStoreColFamilySchema(

@@ -170,4 +189,21 @@ class StateStoreColumnFamilySchemaUtils(initializeAvroSerde: Boolean) extends Lo
       Some(StructType(keySchema.drop(1)))
     ))
   }
+
+  def getTimerStateSchemaForSecIndex(
+      stateName: String,
+      keySchema: StructType,
+      valSchema: StructType): StateStoreColFamilySchema = {
+    val avroKeySchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(keySchema)
+    StateStoreColFamilySchema(
+      stateName,
+      keySchema,
+      valSchema,
+      Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))),
+      avroEnc = getAvroSerde(
+        StructType(avroKeySchema.take(1)),
+        valSchema,
+        Some(StructType(avroKeySchema.drop(1)))
+      ))
+  }
 }
```
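As a reading aid (not part of the commit), here is a minimal sketch of what the relocated `convertForRangeScan` helper does to a hypothetical timer key schema: fixed-size numeric fields are remapped to `BinaryType` so the range-scan encoder can persist them as ordered byte arrays, while variable-size fields such as nested structs pass through unchanged.

```scala
// Illustrative only: a hypothetical key schema run through the helper
// added above. Fixed-size numerics map to BinaryType; everything else
// passes through.
import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
import org.apache.spark.sql.types._

val timerKeySchema = new StructType()
  .add("expirationMs", LongType)
  .add("groupingKey", new StructType().add("id", StringType))

val converted = StateStoreColumnFamilySchemaUtils.convertForRangeScan(timerKeySchema)
// expirationMs: LongType   -> BinaryType (fixed-size)
// groupingKey:  StructType -> unchanged (variable-size)
```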

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala

Lines changed: 0 additions & 5 deletions
```diff
@@ -43,11 +43,6 @@ object TransformWithStateKeyValueRowSchemaUtils {
       .add("expirationMs", LongType)
       .add("groupingKey", keySchema)
 
-  def getSingleKeyTTLAvroRowSchema(keySchema: StructType): StructType =
-    new StructType()
-      .add("expirationMs", BinaryType)
-      .add("groupingKey", keySchema)
-
   def getCompositeKeyTTLRowSchema(
       groupingKeySchema: StructType,
       userKeySchema: StructType): StructType =
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala

Lines changed: 13 additions & 1 deletion
```diff
@@ -165,7 +165,13 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
-  private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
+  private lazy val timerStateName = TimerStateUtils.getTimerStateVarName(
+    timeMode.toString)
+  private lazy val timerSecIndexColFamily = TimerStateUtils.getSecIndexColFamilyName(
+    timeMode.toString)
+  private lazy val timerState = new TimerStateImpl(
+    store, timeMode, keyEncoder, schemas(timerStateName).avroEnc,
+    schemas(timerSecIndexColFamily).avroEnc)
 
   /**
    * Function to register a timer for the given expiryTimestampMs

@@ -355,10 +361,16 @@ class DriverStatefulProcessorHandleImpl(
 
   private def addTimerColFamily(): Unit = {
     val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
+    val secIndexColFamilyName = TimerStateUtils.getSecIndexColFamilyName(timeMode.toString)
     val timerEncoder = new TimerKeyEncoder(keyExprEnc)
     val colFamilySchema = schemaUtils.
       getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow)
+    val secIndexColFamilySchema = schemaUtils.
+      getTimerStateSchemaForSecIndex(secIndexColFamilyName,
+        timerEncoder.keySchemaForSecIndex,
+        timerEncoder.schemaForValueRow)
     columnFamilySchemas.put(stateName, colFamilySchema)
+    columnFamilySchemas.put(secIndexColFamilyName, secIndexColFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName)
     stateVariableInfos.put(stateName, stateVariableInfo)
   }
```
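Taken together, the two hunks wire the same pair of column-family names on both sides: the driver registers schemas for the primary family and the secondary index, and the executor-side handle resolves both by name to hand Avro encoder specs to `TimerStateImpl`. A hedged sketch of that lookup follows; the `store`, `keyEncoder`, and `schemas` bindings are assumed from the surrounding handle, and `TimeMode.ProcessingTime()` stands in for whichever time mode the query runs with.

```scala
// Sketch only: how the names line up across driver and executor.
// `schemas` is the column-family schema map the driver populated.
val mode = TimeMode.ProcessingTime().toString
val primaryName = TimerStateUtils.getTimerStateVarName(mode)
val secIndexName = TimerStateUtils.getSecIndexColFamilyName(mode)

// Executor side: both encoder specs are resolved by the same names.
val timerState = new TimerStateImpl(
  store, TimeMode.ProcessingTime(), keyEncoder,
  avroEnc = schemas(primaryName).avroEnc,
  secIndexAvroEnc = schemas(secIndexName).avroEnc)
```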

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala

Lines changed: 14 additions & 3 deletions
```diff
@@ -43,6 +43,15 @@ object TimerStateUtils {
       TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF
     }
   }
+
+  def getSecIndexColFamilyName(timeMode: String): String = {
+    assert(timeMode == TimeMode.EventTime.toString || timeMode == TimeMode.ProcessingTime.toString)
+    if (timeMode == TimeMode.EventTime.toString) {
+      TimerStateUtils.EVENT_TIMERS_STATE_NAME + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+    } else {
+      TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+    }
+  }
 }
 
 /**

@@ -55,7 +64,9 @@ object TimerStateUtils {
 class TimerStateImpl(
     store: StateStore,
     timeMode: TimeMode,
-    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
+    keyExprEnc: ExpressionEncoder[Any],
+    avroEnc: Option[AvroEncoderSpec] = None,
+    secIndexAvroEnc: Option[AvroEncoderSpec] = None) extends Logging {
 
   private val EMPTY_ROW =
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))

@@ -75,15 +86,15 @@ class TimerStateImpl(
   private val keyToTsCFName = timerCFName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
   store.createColFamilyIfAbsent(keyToTsCFName, schemaForKeyRow,
     schemaForValueRow, PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1),
-    useMultipleValuesPerKey = false, isInternal = true)
+    useMultipleValuesPerKey = false, isInternal = true, avroEncoderSpec = avroEnc)
 
   // We maintain a secondary index that inverts the ordering of the timestamp
   // and grouping key
   private val keySchemaForSecIndex = rowEncoder.keySchemaForSecIndex
   private val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
   store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
     schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
-    useMultipleValuesPerKey = false, isInternal = true)
+    useMultipleValuesPerKey = false, isInternal = true, avroEncoderSpec = secIndexAvroEnc)
 
   private def getGroupingKey(cfName: String): Any = {
     val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
```
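The new naming helper mirrors the existing `getTimerStateVarName`: one suffix per index direction, per time mode. A small illustrative sketch of the resulting pairs (the constants' string values live in `TimerStateUtils` and are elided here; only the concatenation structure is shown):

```scala
// Illustrative pairs produced by the two helpers, per time mode.
for (mode <- Seq(TimeMode.ProcessingTime().toString, TimeMode.EventTime().toString)) {
  val keyToTs = TimerStateUtils.getTimerStateVarName(mode)     // <state name> + KEY_TO_TIMESTAMP_CF
  val tsToKey = TimerStateUtils.getSecIndexColFamilyName(mode) // <state name> + TIMESTAMP_TO_KEY_CF
  println(s"$mode -> primary: $keyToTs, secondary index: $tsToKey")
}
```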

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala

Lines changed: 3 additions & 12 deletions
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConver
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform

@@ -394,6 +395,7 @@ class RangeKeyScanStateEncoder(
   extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
 
   import RocksDBStateEncoder._
+  logError(s"### avroEnc.isDefined: ${avroEnc.isDefined}")
 
   private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
     orderingOrdinals.map { ordinal =>

@@ -460,18 +462,7 @@
     UnsafeProjection.create(refs)
   }
 
-  def convertForRangeScan(schema: StructType): StructType = {
-    StructType(schema.fields.map { field =>
-      if (isFixedSize(field.dataType)) {
-        // Convert numeric types to BinaryType while preserving nullability
-        field.copy(dataType = BinaryType)
-      } else {
-        field
-      }
-    })
-  }
-
-  private val rangeScanAvroSchema = convertForRangeScan(
+  private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
     StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
 
   private val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)
```
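Context for the `BinaryType` conversion that moved out of this encoder: a range-scan encoder stores fixed-size numerics as byte arrays whose unsigned lexicographic order matches numeric order, which is what lets the state store return timers sorted by timestamp. Below is a standalone sketch of that idea for a `Long`, using plain big-endian encoding and ignoring the sign-bit handling a real encoder needs; it is not code from this diff.

```scala
// Standalone sketch: big-endian bytes preserve ordering for
// non-negative Longs under unsigned lexicographic comparison.
import java.nio.ByteBuffer

def toBigEndianBytes(v: Long): Array[Byte] =
  ByteBuffer.allocate(java.lang.Long.BYTES).putLong(v).array()

def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int =
  a.zip(b).collectFirst {
    case (x, y) if x != y => java.lang.Byte.compareUnsigned(x, y)
  }.getOrElse(a.length - b.length)

// 1000L sorts before 2000L both numerically and byte-wise.
assert(unsignedCompare(toBigEndianBytes(1000L), toBigEndianBytes(2000L)) < 0)
```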
