@@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
- f, key, value, grouping, data, output, _, _, _, _, child) =>
- execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
+ f, key, value, grouping, data, output, _, _, _, timeout, child) =>
+ execution.MapGroupsExec(
+ f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
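For reference, a minimal, self-contained sketch of the kind of batch query that produces the `logical.FlatMapGroupsWithState` node matched above; with this change the user-supplied `GroupStateTimeout` is forwarded into `MapGroupsExec` instead of being ignored. Not part of this PR: the object name, the local SparkSession setup, and the chosen output mode are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object FlatMapGroupsWithStateBatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Emit one (key, count) row per group; the state type is Int.
    val func = (key: String, values: Iterator[String], state: GroupState[Int]) =>
      Iterator((key, values.size))

    Seq("a", "a", "b").toDS()
      .groupByKey(identity)
      // The timeout argument here is what `timeout` in the planner rule above carries.
      .flatMapGroupsWithState(OutputMode.Append, GroupStateTimeout.ProcessingTimeTimeout)(func)
      .show()

    spark.stop()
  }
}
```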
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
import org.apache.spark.sql.execution.streaming.GroupStateImpl
+ import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -361,8 +362,11 @@ object MapGroupsExec {
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
+ timeoutConf: GroupStateTimeout,
child: SparkPlan): MapGroupsExec = {
- val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None))
+ val f = (key: Any, values: Iterator[Any]) => {
+ func(key, values, GroupStateImpl.createForBatch(timeoutConf))
+ }
new MapGroupsExec(f, keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, outputObjAttr, child)
}
@@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec(
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
- val keyedState = new GroupStateImpl(
+ val keyedState = GroupStateImpl.createForStreaming(
stateObjOption,
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
@@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval
* @param hasTimedOut Whether the key for which this state wrapped is being created is
* getting timed out or not.
*/
- private[sql] class GroupStateImpl[S](
+ private[sql] class GroupStateImpl[S] private(
optionalValue: Option[S],
batchProcessingTimeMs: Long,
eventTimeWatermarkMs: Long,
timeoutConf: GroupStateTimeout,
override val hasTimedOut: Boolean) extends GroupState[S] {

- // Constructor to create dummy state when using mapGroupsWithState in a batch query
- def this(optionalValue: Option[S]) = this(
- optionalValue,
- batchProcessingTimeMs = NO_TIMESTAMP,
- eventTimeWatermarkMs = NO_TIMESTAMP,
- timeoutConf = GroupStateTimeout.NoTimeout,
- hasTimedOut = false)
private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
private var defined: Boolean = optionalValue.isDefined
private var updated: Boolean = false // whether value has been updated (but not removed)
@@ -102,12 +95,7 @@ private[sql] class GroupStateImpl[S](
if (durationMs <= 0) {
throw new IllegalArgumentException("Timeout duration must be positive")
}
- if (batchProcessingTimeMs != NO_TIMESTAMP) {
- timeoutTimestamp = durationMs + batchProcessingTimeMs
- } else {
- // This is being called in a batch query, hence no processing timestamp.
- // Just ignore any attempts to set timeout.
- }
+ timeoutTimestamp = durationMs + batchProcessingTimeMs
}

override def setTimeoutDuration(duration: String): Unit = {
@@ -128,12 +116,7 @@
s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
s"current watermark ($eventTimeWatermarkMs)")
}
- if (batchProcessingTimeMs != NO_TIMESTAMP) {
- timeoutTimestamp = timestampMs
- } else {
- // This is being called in a batch query, hence no processing timestamp.
- // Just ignore any attempts to set timeout.
- }
+ timeoutTimestamp = timestampMs
}

@throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
@@ -213,4 +196,23 @@
private[sql] object GroupStateImpl {
// Value used represent the lack of valid timestamp as a long
val NO_TIMESTAMP = -1L

+ def createForStreaming[S](
+ optionalValue: Option[S],
+ batchProcessingTimeMs: Long,
+ eventTimeWatermarkMs: Long,
+ timeoutConf: GroupStateTimeout,
+ hasTimedOut: Boolean): GroupStateImpl[S] = {
+ new GroupStateImpl[S](
+ optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut)
+ }
+
+ def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = {
+ new GroupStateImpl[Any](
+ optionalValue = None,
+ batchProcessingTimeMs = NO_TIMESTAMP,
+ eventTimeWatermarkMs = NO_TIMESTAMP,
+ timeoutConf,
+ hasTimedOut = false)
+ }
}
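For reference, a small sketch of how the two new entry points are meant to be called. Not part of this PR: the GroupStateImplSketch object and the chosen values are illustrative, and the package declaration is only there because GroupStateImpl is private[sql]. The expected values mirror the assertions in FlatMapGroupsWithStateSuite below.

```scala
package org.apache.spark.sql.execution.streaming // GroupStateImpl is private[sql]

import org.apache.spark.sql.streaming.GroupStateTimeout

object GroupStateImplSketch {
  def main(args: Array[String]): Unit = {
    // Streaming path (FlatMapGroupsWithStateExec): real processing time and watermark.
    val streamingState = GroupStateImpl.createForStreaming(
      optionalValue = Some(5),
      batchProcessingTimeMs = 1000L,
      eventTimeWatermarkMs = 1000L,
      timeoutConf = GroupStateTimeout.ProcessingTimeTimeout,
      hasTimedOut = false)
    streamingState.setTimeoutDuration(500)
    assert(streamingState.getTimeoutTimestamp == 1500L) // 1000 + 500, as in the suite below

    // Batch path (MapGroupsExec): placeholder timestamps; timeout calls are accepted,
    // but no timeout ever fires in a batch query since each group is processed once.
    val batchState = GroupStateImpl.createForBatch(GroupStateTimeout.ProcessingTimeTimeout)
    batchState.setTimeoutDuration(500) // must not throw (SPARK-20792)
    assert(!batchState.hasTimedOut)
  }
}
```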
@@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(state.hasRemoved === shouldBeRemoved)
}

+ // === Tests for state in streaming queries ===
// Updating empty state
- state = new GroupStateImpl[String](None)
+ state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false)
testState(None)
state.update("")
testState(Some(""), shouldBeUpdated = true)

// Updating exiting state
- state = new GroupStateImpl[String](Some("2"))
+ state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false)
testState(Some("2"))
state.update("3")
testState(Some("3"), shouldBeUpdated = true)
@@ -99,48 +100,73 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}

test("GroupState - setTimeout**** with NoTimeout") {
- for (initState <- Seq(None, Some(5))) {
- // for different initial state
- implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false)
- testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
- testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+ for (initValue <- Seq(None, Some(5))) {
+ val states = Seq(
+ GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false),
+ GroupStateImpl.createForBatch(NoTimeout)
+ )
+ for (state <- states) {
+ // for streaming queries
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ // for batch queries
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+ }
}
}

test("GroupState - setTimeout**** with ProcessingTimeTimeout") {
- implicit var state: GroupStateImpl[Int] = null
-
- state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+ // for streaming queries
+ var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+ None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
state.setTimeoutDuration(500)
- assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
+ assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)

state.update(5)
- assert(state.getTimeoutTimestamp === 1500) // does not change
+ assert(state.getTimeoutTimestamp === 1500) // does not change
state.setTimeoutDuration(1000)
assert(state.getTimeoutTimestamp === 2000)
state.setTimeoutDuration("2 second")
assert(state.getTimeoutTimestamp === 3000)
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)

state.remove()
- assert(state.getTimeoutTimestamp === 3000) // does not change
- state.setTimeoutDuration(500) // can still be set
+ assert(state.getTimeoutTimestamp === 3000) // does not change
+ state.setTimeoutDuration(500) // can still be set
assert(state.getTimeoutTimestamp === 1500)
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)

+ // for batch queries
+ state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+ assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+ state.setTimeoutDuration(500)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ state.update(5)
+ state.setTimeoutDuration(1000)
+ state.setTimeoutDuration("2 second")
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ state.remove()
+ state.setTimeoutDuration(500)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
}

test("GroupState - setTimeout**** with EventTimeTimeout") {
- implicit val state = new GroupStateImpl[Int](
- None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+ var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+ None, 1000, 1000, EventTimeTimeout, false)
+
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
state.setTimeoutTimestamp(5000)
- assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state
+ assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state

state.update(5)
- assert(state.getTimeoutTimestamp === 5000) // does not change
+ assert(state.getTimeoutTimestamp === 5000) // does not change
state.setTimeoutTimestamp(10000)
assert(state.getTimeoutTimestamp === 10000)
state.setTimeoutTimestamp(new Date(20000))
@@ -150,7 +176,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
state.remove()
assert(state.getTimeoutTimestamp === 20000)
state.setTimeoutTimestamp(5000)
- assert(state.getTimeoutTimestamp === 5000) // can be set after removing state
+ assert(state.getTimeoutTimestamp === 5000) // can be set after removing state
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)

+ // for batch queries
+ state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
+ assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+ state.setTimeoutTimestamp(5000)
+
+ state.update(5)
+ state.setTimeoutTimestamp(10000)
+ state.setTimeoutTimestamp(new Date(20000))
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+ state.remove()
+ state.setTimeoutTimestamp(5000)
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
}

@@ -165,7 +206,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
}

- state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+ state = GroupStateImpl.createForStreaming(
+ Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
testIllegalTimeout {
state.setTimeoutDuration(-1000)
}
@@ -182,7 +224,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
state.setTimeoutDuration("1 month -1 day")
}

- state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+ state = GroupStateImpl.createForStreaming(
+ Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
testIllegalTimeout {
state.setTimeoutTimestamp(-10000)
}
@@ -211,23 +254,32 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf

test("GroupState - hasTimedOut") {
for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) {
+ // for streaming queries
for (initState <- Seq(None, Some(5))) {
- val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false)
+ val state1 = GroupStateImpl.createForStreaming(
+ initState, 1000, 1000, timeoutConf, hasTimedOut = false)
assert(state1.hasTimedOut === false)
- val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true)
+
+ val state2 = GroupStateImpl.createForStreaming(
+ initState, 1000, 1000, timeoutConf, hasTimedOut = true)
assert(state2.hasTimedOut === true)
}

+ // for batch queries
+ assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false)
}
}

test("GroupState - primitive type") {
- var intState = new GroupStateImpl[Int](None)
+ var intState = GroupStateImpl.createForStreaming[Int](
+ None, 1000, 1000, NoTimeout, hasTimedOut = false)
intercept[NoSuchElementException] {
intState.get
}
assert(intState.getOption === None)

- intState = new GroupStateImpl[Int](Some(10))
+ intState = GroupStateImpl.createForStreaming[Int](
+ Some(10), 1000, 1000, NoTimeout, hasTimedOut = false)
assert(intState.get == 10)
intState.update(0)
assert(intState.get == 0)
@@ -243,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val beforeTimeoutThreshold = 999
val afterTimeoutThreshold = 1001


// Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout
for (priorState <- Seq(None, Some(0))) {
val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
@@ -748,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}

test("mapGroupsWithState - batch") {
- val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+ // Test the following
+ // - no initial state
+ // - timeouts operations work, does not throw any error [SPARK-20792]
+ // - works with primitive state type
+ val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ state.setTimeoutTimestamp(0, "1 hour")
+ state.update(10)
(key, values.size)
}

checkAnswer(
spark.createDataset(Seq("a", "a", "b"))
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc)
+ .mapGroupsWithState(EventTimeTimeout)(stateFunc)
.toDF,
spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
}
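For reference, a self-contained version of the scenario this test covers: in a batch query, mapGroupsWithState accepts a GroupStateTimeout and timeout calls on the state no longer fail. Not part of this PR: the MapGroupsWithStateBatchExample object and the local SparkSession setup are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

object MapGroupsWithStateBatchExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("example").getOrCreate()
    import spark.implicits._

    // Same shape as the test above: state starts empty, timeout calls are accepted,
    // and the state type is a primitive (Int).
    val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => {
      val count = values.size
      state.setTimeoutTimestamp(0, "1 hour") // no error in a batch query (SPARK-20792)
      state.update(count)
      (key, count)
    }

    Seq("a", "a", "b").toDS()
      .groupByKey(identity)
      .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout)(stateFunc)
      .show() // expected rows: (a, 2) and (b, 1)

    spark.stop()
  }
}
```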