
Commit b890fdc

gaborgsomogyi authored and dongjoon-hyun committed
[SPARK-32387][SS] Extract UninterruptibleThread runner logic from KafkaOffsetReader
### What changes were proposed in this pull request?

`UninterruptibleThread` running functionality is baked into `KafkaOffsetReader`, where it can be extracted into its own class. The main intention is to simplify `KafkaOffsetReader` in order to make it easier to solve SPARK-32032. In this PR I've made this extraction without any functionality change.

### Why are the changes needed?

`UninterruptibleThread` running functionality is baked into `KafkaOffsetReader`; extracting it simplifies the class.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing + additional unit tests.

Closes #29187 from gaborgsomogyi/SPARK-32387.

Authored-by: Gabor Somogyi <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent e6ef27b · commit b890fdc

File tree: 3 files changed (+129, −36 lines)
core/src/main/scala/org/apache/spark/util/UninterruptibleThreadRunner.scala

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

/**
 * [[UninterruptibleThreadRunner]] ensures that all tasks are running in an
 * [[UninterruptibleThread]]. A good example is Kafka consumer usage.
 */
private[spark] class UninterruptibleThreadRunner(threadName: String) {
  private val thread = Executors.newSingleThreadExecutor((r: Runnable) => {
    val t = new UninterruptibleThread(threadName) {
      override def run(): Unit = {
        r.run()
      }
    }
    t.setDaemon(true)
    t
  })
  private val execContext = ExecutionContext.fromExecutorService(thread)

  def runUninterruptibly[T](body: => T): T = {
    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
      val future = Future {
        body
      }(execContext)
      ThreadUtils.awaitResult(future, Duration.Inf)
    } else {
      body
    }
  }

  def shutdown(): Unit = {
    thread.shutdown()
  }
}
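
For context, here is a minimal usage sketch of the extracted class. It is not part of the commit, and the object name, thread name, and body are illustrative only. Note that `UninterruptibleThreadRunner` is `private[spark]`, so any caller has to live under the `org.apache.spark` package.

// Illustrative sketch, not part of the commit: shows the intended calling
// pattern of UninterruptibleThreadRunner (hypothetical names throughout).
package org.apache.spark.example

import org.apache.spark.util.UninterruptibleThreadRunner

object RunnerExample {
  def main(args: Array[String]): Unit = {
    val runner = new UninterruptibleThreadRunner("Example Runner")
    try {
      // The body is re-dispatched onto the runner's dedicated
      // UninterruptibleThread only when the current thread is not already one.
      val answer = runner.runUninterruptibly { 40 + 2 }
      println(answer) // prints 42
    } finally {
      runner.shutdown() // releases the backing single executor thread
    }
  }
}

Because the runner is backed by a single-thread executor, all dispatched bodies run serially on one daemon `UninterruptibleThread`, which is why callers like `KafkaOffsetReader` can treat it as a serialized, interrupt-safe execution context.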
core/src/test/scala/org/apache/spark/util/UninterruptibleThreadRunnerSuite.scala

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import org.apache.spark.SparkFunSuite

class UninterruptibleThreadRunnerSuite extends SparkFunSuite {
  private var runner: UninterruptibleThreadRunner = null

  override def beforeEach(): Unit = {
    runner = new UninterruptibleThreadRunner("ThreadName")
  }

  override def afterEach(): Unit = {
    runner.shutdown()
  }

  test("runUninterruptibly should switch to UninterruptibleThread") {
    assert(!Thread.currentThread().isInstanceOf[UninterruptibleThread])
    var isUninterruptibleThread = false
    runner.runUninterruptibly {
      isUninterruptibleThread = Thread.currentThread().isInstanceOf[UninterruptibleThread]
    }
    assert(isUninterruptibleThread, "The runner task must run in UninterruptibleThread")
  }

  test("runUninterruptibly should not add new UninterruptibleThread") {
    var isInitialUninterruptibleThread = false
    var isRunnerUninterruptibleThread = false
    val t = new UninterruptibleThread("test") {
      override def run(): Unit = {
        runUninterruptibly {
          val initialThread = Thread.currentThread()
          isInitialUninterruptibleThread = initialThread.isInstanceOf[UninterruptibleThread]
          runner.runUninterruptibly {
            val runnerThread = Thread.currentThread()
            isRunnerUninterruptibleThread = runnerThread.isInstanceOf[UninterruptibleThread]
            assert(runnerThread.eq(initialThread))
          }
        }
      }
    }
    t.start()
    t.join()
    assert(isInitialUninterruptibleThread,
      "The initiator must already run in UninterruptibleThread")
    assert(isRunnerUninterruptibleThread, "The runner task must run in UninterruptibleThread")
  }
}

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala

Lines changed: 10 additions & 36 deletions
@@ -18,12 +18,9 @@
 package org.apache.spark.sql.kafka010

 import java.{util => ju}
-import java.util.concurrent.Executors

 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal

 import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp}
@@ -33,7 +30,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
+import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}

 /**
  * This class uses Kafka's own [[KafkaConsumer]] API to read data offsets from Kafka.
@@ -51,19 +48,13 @@ private[kafka010] class KafkaOffsetReader(
     val driverKafkaParams: ju.Map[String, Object],
     readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {
+
   /**
-   * Used to ensure execute fetch operations execute in an UninterruptibleThread
+   * [[UninterruptibleThreadRunner]] ensures that all [[KafkaConsumer]] communication called in an
+   * [[UninterruptibleThread]]. In the case of streaming queries, we are already running in an
+   * [[UninterruptibleThread]], however for batch mode this is not the case.
    */
-  val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
-    val t = new UninterruptibleThread("Kafka Offset Reader") {
-      override def run(): Unit = {
-        r.run()
-      }
-    }
-    t.setDaemon(true)
-    t
-  })
-  val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
+  val uninterruptibleThreadRunner = new UninterruptibleThreadRunner("Kafka Offset Reader")

   /**
    * Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
@@ -126,14 +117,14 @@
    * Closes the connection to Kafka, and cleans up state.
    */
   def close(): Unit = {
-    if (_consumer != null) runUninterruptibly { stopConsumer() }
-    kafkaReaderThread.shutdown()
+    if (_consumer != null) uninterruptibleThreadRunner.runUninterruptibly { stopConsumer() }
+    uninterruptibleThreadRunner.shutdown()
   }

   /**
    * @return The Set of TopicPartitions for a given topic
    */
-  def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
+  def fetchTopicPartitions(): Set[TopicPartition] = uninterruptibleThreadRunner.runUninterruptibly {
     assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
     // Poll to get the latest assigned partitions
     consumer.poll(0)
@@ -531,7 +522,7 @@
   private def partitionsAssignedToConsumer(
       body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
       fetchingEarliestOffset: Boolean = false)
-    : Map[TopicPartition, Long] = runUninterruptibly {
+    : Map[TopicPartition, Long] = uninterruptibleThreadRunner.runUninterruptibly {

     withRetriesWithoutInterrupt {
       // Poll to get the latest assigned partitions
@@ -551,23 +542,6 @@
     }
   }

-  /**
-   * This method ensures that the closure is called in an [[UninterruptibleThread]].
-   * This is required when communicating with the [[KafkaConsumer]]. In the case
-   * of streaming queries, we are already running in an [[UninterruptibleThread]],
-   * however for batch mode this is not the case.
-   */
-  private def runUninterruptibly[T](body: => T): T = {
-    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
-      val future = Future {
-        body
-      }(execContext)
-      ThreadUtils.awaitResult(future, Duration.Inf)
-    } else {
-      body
-    }
-  }
-
   /**
    * Helper function that does multiple retries on a body of code that returns offsets.
    * Retries are needed to handle transient failures. For e.g. race conditions between getting
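
To make the streaming-versus-batch remark in the new scaladoc concrete: when the caller is already on an `UninterruptibleThread` (the streaming-query case), `runUninterruptibly` executes the body inline rather than hopping to the runner's thread, so only batch callers pay the dispatch. A minimal sketch, not part of the commit and with hypothetical names, mirroring the second test above:

// Illustrative sketch, not part of the commit (hypothetical names).
package org.apache.spark.example

import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}

object PassThroughExample {
  def main(args: Array[String]): Unit = {
    val runner = new UninterruptibleThreadRunner("Pass Through Example")
    val caller = new UninterruptibleThread("already-uninterruptible") {
      override def run(): Unit = {
        val outer = Thread.currentThread()
        // Already uninterruptible, so the body runs inline on this same
        // thread instead of being dispatched to the runner's executor.
        runner.runUninterruptibly {
          assert(Thread.currentThread() eq outer)
        }
      }
    }
    caller.start()
    caller.join()
    runner.shutdown()
  }
}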
