@@ -22,11 +22,11 @@ import java.util.concurrent.TimeoutException
2222
2323import scala .concurrent .duration .FiniteDuration
2424import scala .concurrent .duration ._
25- import scala .concurrent .{Await , Future }
25+ import scala .concurrent .{Awaitable , Await , Future }
2626import scala .language .postfixOps
2727
2828import org .apache .spark .{SecurityManager , SparkConf }
29- import org .apache .spark .util .{RpcUtils , Utils }
29+ import org .apache .spark .util .{ThreadUtils , RpcUtils , Utils }
3030
3131
3232/**
@@ -187,6 +187,13 @@ private[spark] object RpcAddress {
187187}
188188
189189
190+ /**
191+ * An exception thrown if RpcTimeout modifies a [[TimeoutException ]].
192+ */
193+ private [rpc] class RpcTimeoutException (message : String )
194+ extends TimeoutException (message)
195+
196+
190197/**
191198 * Associates a timeout with a description so that a when a TimeoutException occurs, additional
192199 * context about the timeout can be amended to the exception message.
@@ -202,17 +209,44 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
202209 def message : String = description
203210
204211 /** Amends the standard message of TimeoutException to include the description */
205- def amend (te : TimeoutException ): TimeoutException = {
206- new TimeoutException (te.getMessage() + " " + description)
212+ def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213+ new RpcTimeoutException (te.getMessage() + " " + description)
214+ }
215+
216+ /**
217+ * Add a callback to the given Future so that if it completes as failed with a TimeoutException
218+ * then the timeout description is added to the message
219+ */
220+ def addMessageIfTimeout [T ](future : Future [T ]): Future [T ] = {
221+ future.recover {
222+ // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223+ case rte : RpcTimeoutException => throw new RpcTimeoutException (rte.getMessage() +
224+ " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)" )
225+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
226+ case te : TimeoutException => throw createRpcTimeoutException(te)
227+ }(ThreadUtils .sameThread)
228+ }
229+
230+ /** Applies the duration to create future before calling addMessageIfTimeout*/
231+ def addMessageIfTimeout [T ](f : FiniteDuration => Future [T ]): Future [T ] = {
232+ addMessageIfTimeout(f(duration))
207233 }
208234
209- /** Wait on a future result to catch and amend a TimeoutException */
210- def awaitResult [T ](future : Future [T ]): T = {
235+ /**
236+ * Waits for a completed result to catch and amend a TimeoutException message
237+ * @param awaitable the `Awaitable` to be awaited
238+ * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
239+ * is still not ready
240+ */
241+ def awaitResult [T ](awaitable : Awaitable [T ]): T = {
211242 try {
212- Await .result(future , duration)
243+ Await .result(awaitable , duration)
213244 }
214245 catch {
215- case te : TimeoutException => throw amend(te)
246+ // The exception has already been converted to a RpcTimeoutException so just raise it
247+ case rte : RpcTimeoutException => throw rte
248+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
249+ case te : TimeoutException => throw createRpcTimeoutException(te)
216250 }
217251 }
218252}
0 commit comments