Commit 162c5be

JoshRosen authored and cloud-fan committed
[SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel
## What changes were proposed in this pull request?

This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor and then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.**

For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel.

Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned.

In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds a `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method.
It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads.

The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best case, this fails as a read of a closed file descriptor; in the worst case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread.

The solution here is to remove the `source.close()` call in `setError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`.

After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded.

This patch also makes a slight cleanup to a dodgy-looking `case e: Exception` catch block to use a safer `try-finally` error handling idiom.

This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+.

## How was this patch tested?

This fix was tested manually against a workload which non-deterministically hit this bug.
Author: Josh Rosen <[email protected]>

Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel.

(cherry picked from commit edf0a48)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 20a8c88 commit 162c5be
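
The commit message relies on one property of `java.nio.channels.Pipe`: closing the sink end is enough to unblock a thread that is blocked in `read()` on the source end, which then sees end-of-stream. The following is a minimal standalone sketch of that behavior, not Spark code; the object and method names (`PipeEofSketch`, `readAfterSinkClose`) are made up for illustration, and the short sleep is only there to make it likely that the reader is already blocked when the sink is closed.

```scala
import java.nio.ByteBuffer
import java.nio.channels.Pipe
import java.util.concurrent.atomic.AtomicInteger

object PipeEofSketch {
  /** Blocks a reader on the pipe source, closes the sink from the calling
    * thread, and returns whatever the reader's read() call returned. */
  def readAfterSinkClose(): Int = {
    val pipe = Pipe.open()
    val result = new AtomicInteger(Int.MinValue)

    // The reading thread owns the source channel: it alone reads from it
    // and closes it, in a finally block, as the commit message describes.
    val reader = new Thread(new Runnable {
      override def run(): Unit = {
        try {
          result.set(pipe.source().read(ByteBuffer.allocate(16)))
        } finally {
          pipe.source().close()
        }
      }
    })
    reader.start()

    Thread.sleep(100)   // give the reader time to block in read()
    pipe.sink().close() // closing the sink signals EOF to the reader
    reader.join()
    result.get()
  }
}
```

Under these assumptions the reader's `read()` returns the end-of-stream value `-1`, and no thread other than the reader ever touches the source channel.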

2 files changed: +39 -19 lines changed

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 22 additions & 15 deletions
@@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv(

     val pipe = Pipe.open()
     val source = new FileDownloadChannel(pipe.source())
-    try {
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
       val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
       val callback = new FileDownloadCallback(pipe.sink(), source, client)
       client.stream(parsedUri.getPath(), callback)
-    } catch {
-      case e: Exception =>
-        pipe.sink().close()
-        source.close()
-        throw e
-    }
+    })(catchBlock = {
+      pipe.sink().close()
+      source.close()
+    })

     source
   }
@@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv(
     fileDownloadFactory.createClient(host, port)
   }

-  private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+  private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {

     @volatile private var error: Throwable = _

     def setError(e: Throwable): Unit = {
+      // This setError callback is invoked by internal RPC threads in order to propagate remote
+      // exceptions to application-level threads which are reading from this channel. When an
+      // RPC error occurs, the RPC system will call setError() and then will close the
+      // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe
+      // sink will cause `source.read()` operations to return EOF, unblocking the application-level
+      // reading thread. Thus there is no need to actually call `source.close()` here in the
+      // onError() callback and, in fact, calling it here would be dangerous because the close()
+      // would be asynchronous with respect to the read() call and could trigger race-conditions
+      // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic.
       error = e
-      source.close()
     }

     override def read(dst: ByteBuffer): Int = {
       Try(source.read(dst)) match {
+        // See the documentation above in setError(): if an RPC error has occurred then setError()
+        // will be called to propagate the RPC error and then `source`'s corresponding
+        // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate
+        // the remote RPC exception (and not any exceptions triggered by the pipe close, such as
+        // ChannelClosedException), hence this `error != null` check:
+        case _ if error != null => throw error
         case Success(bytesRead) => bytesRead
-        case Failure(readErr) =>
-          if (error != null) {
-            throw error
-          } else {
-            throw readErr
-          }
+        case Failure(readErr) => throw readErr
       }
     }
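
The error-propagation pattern in the `FileDownloadChannel` change can be tried outside Spark. The sketch below is a simplified stand-in under stated assumptions: `ErrorAwareChannel` and `demo` are hypothetical names, the "RPC callback" is simulated by a plain thread, and the real class has more surrounding machinery. It records the error without closing the source, closes only the sink, and has `read()` check the recorded error whether or not the underlying read succeeded.

```scala
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.{Pipe, ReadableByteChannel}
import scala.util.Try

object ErrorAwareChannelSketch {
  // Hypothetical, simplified stand-in for the patched FileDownloadChannel.
  final class ErrorAwareChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {
    @volatile private var error: Throwable = null

    // Called from the callback thread: only records the error and
    // deliberately does NOT close `source`.
    def setError(e: Throwable): Unit = { error = e }

    override def read(dst: ByteBuffer): Int = {
      val result = Try(source.read(dst))
      // Prefer the recorded remote error over EOF or a local read failure.
      if (error != null) throw error
      result.get // returns the byte count, or rethrows the read failure
    }

    override def isOpen(): Boolean = source.isOpen()
    override def close(): Unit = source.close()
  }

  /** Simulates a failed download and returns the message seen by the reader. */
  def demo(): String = {
    val pipe = Pipe.open()
    val channel = new ErrorAwareChannel(pipe.source())

    // Simulated failure callback: record the error, then close the sink.
    val callbackThread = new Thread(new Runnable {
      override def run(): Unit = {
        channel.setError(new IOException("stream does not exist"))
        pipe.sink().close()
      }
    })
    callbackThread.start()

    try {
      channel.read(ByteBuffer.allocate(16))
      "no error"
    } catch {
      case e: IOException => e.getMessage
    } finally {
      channel.close() // only the reading thread closes the source
      callbackThread.join()
    }
  }
}
```

In this sketch the reader is unblocked by the sink close but reports the recorded error rather than a plain end-of-stream, which is the behavior the added `case _ if error != null` branch aims for.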

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

Lines changed: 17 additions & 4 deletions
@@ -18,8 +18,8 @@
 package org.apache.spark.shuffle

 import java.io._
-
-import com.google.common.io.ByteStreams
+import java.nio.channels.Channels
+import java.nio.file.Files

 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
     // find out the consolidated file, then the offset within that from our index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

-    val in = new DataInputStream(new FileInputStream(indexFile))
+    // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
+    // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
+    // (which may cause a reducer to be sent a different reducer's data). The explicit position
+    // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
+    // class of issue from re-occurring in the future which is why they are left here even though
+    // SPARK-22982 is fixed.
+    val channel = Files.newByteChannel(indexFile.toPath)
+    channel.position(blockId.reduceId * 8)
+    val in = new DataInputStream(Channels.newInputStream(channel))
     try {
-      ByteStreams.skipFully(in, blockId.reduceId * 8)
       val offset = in.readLong()
       val nextOffset = in.readLong()
+      val actualPosition = channel.position()
+      val expectedPosition = blockId.reduceId * 8 + 16
+      if (actualPosition != expectedPosition) {
+        throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
+          s"expected $expectedPosition but actual position was $actualPosition.")
+      }
       new FileSegmentManagedBuffer(
         transportConf,
         getDataFile(blockId.shuffleId, blockId.mapId),
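
The position check in this hunk can be exercised on its own with a small index file of consecutive `Long` offsets. The sketch below is an illustration under assumptions, not the resolver's actual code: `IndexReadSketch`, `writeIndex`, and `readSegment` are hypothetical names, and it returns an `(offset, length)` pair instead of a `FileSegmentManagedBuffer`. It seeks the channel to `reduceId * 8`, reads two longs, and verifies that the channel advanced by exactly 16 bytes.

```scala
import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream}
import java.nio.channels.Channels
import java.nio.file.Files

object IndexReadSketch {
  /** Writes the given offsets as consecutive 8-byte longs (test helper). */
  def writeIndex(file: File, offsets: Seq[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(file))
    try offsets.foreach(o => out.writeLong(o)) finally out.close()
  }

  /** Returns (offset, length) of the segment for `reduceId`, failing loudly
    * if the channel position is not where two 8-byte reads should leave it. */
  def readSegment(file: File, reduceId: Int): (Long, Long) = {
    val channel = Files.newByteChannel(file.toPath)
    try {
      channel.position(reduceId * 8L)
      val in = new DataInputStream(Channels.newInputStream(channel))
      val offset = in.readLong()
      val nextOffset = in.readLong()

      val actualPosition = channel.position()
      val expectedPosition = reduceId * 8L + 16
      if (actualPosition != expectedPosition) {
        throw new IllegalStateException(
          s"Incorrect channel position after index file reads: " +
            s"expected $expectedPosition but actual position was $actualPosition.")
      }
      (offset, nextOffset - offset)
    } finally {
      channel.close()
    }
  }
}
```

For example, with an index containing the offsets 0, 10, 25, 40, reading `reduceId = 1` yields offset 10 and length 15. If some other code moved the position of the underlying descriptor between the seek and the check, the sketch would throw instead of silently returning another reducer's range.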
