diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThreadRunner.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThreadRunner.scala
new file mode 100644
index 000000000000..18108aa819db
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThreadRunner.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.Executors
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+
+/**
+ * [[UninterruptibleThreadRunner]] ensures that all tasks are running in an
+ * [[UninterruptibleThread]]. A good example is Kafka consumer usage.
+ */
+private[spark] class UninterruptibleThreadRunner(threadName: String) {
+  private val thread = Executors.newSingleThreadExecutor((r: Runnable) => {
+    val t = new UninterruptibleThread(threadName) {
+      override def run(): Unit = {
+        r.run()
+      }
+    }
+    t.setDaemon(true)
+    t
+  })
+  private val execContext = ExecutionContext.fromExecutorService(thread)
+
+  def runUninterruptibly[T](body: => T): T = {
+    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
+      val future = Future {
+        body
+      }(execContext)
+      ThreadUtils.awaitResult(future, Duration.Inf)
+    } else {
+      body
+    }
+  }
+
+  def shutdown(): Unit = {
+    thread.shutdown()
+  }
+}
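For orientation before the test file, a minimal usage sketch of the new helper. The RunnerDemo object and its package are hypothetical; since the class is private[spark], a real caller must live under org.apache.spark.

    package org.apache.spark.demo

    import org.apache.spark.util.UninterruptibleThreadRunner

    object RunnerDemo {
      def main(args: Array[String]): Unit = {
        val runner = new UninterruptibleThreadRunner("demo-runner")
        try {
          // From an ordinary thread the body hops to the dedicated
          // UninterruptibleThread; the result is awaited and returned.
          val answer = runner.runUninterruptibly { 21 * 2 }
          println(answer) // 42
        } finally {
          runner.shutdown() // release the single-thread daemon executor
        }
      }
    }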
diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadRunnerSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadRunnerSuite.scala
new file mode 100644
index 000000000000..40312beacdff
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadRunnerSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.SparkFunSuite
+
+class UninterruptibleThreadRunnerSuite extends SparkFunSuite {
+  private var runner: UninterruptibleThreadRunner = null
+
+  override def beforeEach(): Unit = {
+    runner = new UninterruptibleThreadRunner("ThreadName")
+  }
+
+  override def afterEach(): Unit = {
+    runner.shutdown()
+  }
+
+  test("runUninterruptibly should switch to UninterruptibleThread") {
+    assert(!Thread.currentThread().isInstanceOf[UninterruptibleThread])
+    var isUninterruptibleThread = false
+    runner.runUninterruptibly {
+      isUninterruptibleThread = Thread.currentThread().isInstanceOf[UninterruptibleThread]
+    }
+    assert(isUninterruptibleThread, "The runner task must run in UninterruptibleThread")
+  }
+
+  test("runUninterruptibly should not add new UninterruptibleThread") {
+    var isInitialUninterruptibleThread = false
+    var isRunnerUninterruptibleThread = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        runUninterruptibly {
+          val initialThread = Thread.currentThread()
+          isInitialUninterruptibleThread = initialThread.isInstanceOf[UninterruptibleThread]
+          runner.runUninterruptibly {
+            val runnerThread = Thread.currentThread()
+            isRunnerUninterruptibleThread = runnerThread.isInstanceOf[UninterruptibleThread]
+            assert(runnerThread.eq(initialThread))
+          }
+        }
+      }
+    }
+    t.start()
+    t.join()
+    assert(isInitialUninterruptibleThread,
+      "The initiator must already run in UninterruptibleThread")
+    assert(isRunnerUninterruptibleThread, "The runner task must run in UninterruptibleThread")
+  }
+}
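The second test pins down a subtlety worth spelling out: when the caller is already an UninterruptibleThread, runUninterruptibly runs the body inline instead of submitting it to the executor. A sketch of why that inline branch matters (object and package names are hypothetical):

    package org.apache.spark.demo

    import org.apache.spark.util.UninterruptibleThreadRunner

    object NestedRunnerSketch {
      def main(args: Array[String]): Unit = {
        val runner = new UninterruptibleThreadRunner("nested-demo")
        val result = runner.runUninterruptibly {
          // This body runs on the runner's single UninterruptibleThread, so
          // the nested call below takes the inline branch. If it were instead
          // submitted to the same single-thread executor, awaiting it here
          // would deadlock: the only worker thread is already occupied.
          runner.runUninterruptibly { 1 } + 1
        }
        assert(result == 2)
        runner.shutdown()
      }
    }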
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 5ab786267495..6d30bd2a6d2c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -18,12 +18,9 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
-import java.util.concurrent.Executors
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal
 
 import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp}
@@ -33,7 +30,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
+import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}
 
 /**
  * This class uses Kafka's own [[KafkaConsumer]] API to read data offsets from Kafka.
@@ -51,19 +48,13 @@ private[kafka010] class KafkaOffsetReader(
     val driverKafkaParams: ju.Map[String, Object],
     readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {
+
   /**
-   * Used to ensure execute fetch operations execute in an UninterruptibleThread
+   * [[UninterruptibleThreadRunner]] ensures that all [[KafkaConsumer]] communication happens in
+   * an [[UninterruptibleThread]]. In the case of streaming queries, we are already running in an
+   * [[UninterruptibleThread]]; for batch queries, however, this is not the case.
    */
-  val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
-    val t = new UninterruptibleThread("Kafka Offset Reader") {
-      override def run(): Unit = {
-        r.run()
-      }
-    }
-    t.setDaemon(true)
-    t
-  })
-  val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
+  val uninterruptibleThreadRunner = new UninterruptibleThreadRunner("Kafka Offset Reader")
 
   /**
    * Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
@@ -126,14 +117,14 @@ private[kafka010] class KafkaOffsetReader(
    * Closes the connection to Kafka, and cleans up state.
    */
   def close(): Unit = {
-    if (_consumer != null) runUninterruptibly { stopConsumer() }
-    kafkaReaderThread.shutdown()
+    if (_consumer != null) uninterruptibleThreadRunner.runUninterruptibly { stopConsumer() }
+    uninterruptibleThreadRunner.shutdown()
   }
 
   /**
    * @return The Set of TopicPartitions for a given topic
    */
-  def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
+  def fetchTopicPartitions(): Set[TopicPartition] = uninterruptibleThreadRunner.runUninterruptibly {
     assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
     // Poll to get the latest assigned partitions
     consumer.poll(0)
@@ -531,7 +522,7 @@ private[kafka010] class KafkaOffsetReader(
   private def partitionsAssignedToConsumer(
       body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
       fetchingEarliestOffset: Boolean = false)
-    : Map[TopicPartition, Long] = runUninterruptibly {
+    : Map[TopicPartition, Long] = uninterruptibleThreadRunner.runUninterruptibly {
 
     withRetriesWithoutInterrupt {
       // Poll to get the latest assigned partitions
@@ -551,23 +542,6 @@ private[kafka010] class KafkaOffsetReader(
     }
   }
 
-  /**
-   * This method ensures that the closure is called in an [[UninterruptibleThread]].
-   * This is required when communicating with the [[KafkaConsumer]]. In the case
-   * of streaming queries, we are already running in an [[UninterruptibleThread]],
-   * however for batch mode this is not the case.
-   */
-  private def runUninterruptibly[T](body: => T): T = {
-    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
-      val future = Future {
-        body
-      }(execContext)
-      ThreadUtils.awaitResult(future, Duration.Inf)
-    } else {
-      body
-    }
-  }
-
   /**
    * Helper function that does multiple retries on a body of code that returns offsets.
    * Retries are needed to handle transient failures. For e.g. race conditions between getting
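To make the scaladoc's batch-versus-streaming distinction concrete, a sketch of both call shapes. The names are hypothetical and the Kafka plumbing is elided; only the threading behavior mirrors the diff.

    package org.apache.spark.demo

    import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}

    object BatchVsStreamingSketch {
      def main(args: Array[String]): Unit = {
        val runner = new UninterruptibleThreadRunner("Kafka Offset Reader")

        // "Batch" shape: an ordinary thread, so the body hops to the runner thread.
        val batchThread = runner.runUninterruptibly { Thread.currentThread().getName }
        println(s"batch ran on: $batchThread") // the "Kafka Offset Reader" thread

        // "Streaming" shape: the caller is already an UninterruptibleThread
        // (as the micro-batch thread is), so the body runs inline on it.
        val t = new UninterruptibleThread("micro-batch") {
          override def run(): Unit = {
            val name = runner.runUninterruptibly { Thread.currentThread().getName }
            println(s"streaming ran on: $name") // "micro-batch" itself
          }
        }
        t.start()
        t.join()
        runner.shutdown()
      }
    }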