Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,8 @@ private class BufferReleasingInputStream(
}
}

override def available(): Int = delegate.available()
override def available(): Int =
tryOrFetchFailedException(delegate.available())

override def mark(readlimit: Int): Unit = delegate.mark(readlimit)

Expand All @@ -1369,12 +1370,13 @@ private class BufferReleasingInputStream(
override def read(b: Array[Byte], off: Int, len: Int): Int =
tryOrFetchFailedException(delegate.read(b, off, len))

override def reset(): Unit = delegate.reset()
override def reset(): Unit = tryOrFetchFailedException(delegate.reset())

/**
* Execute a block of code that returns a value, close this stream quietly and re-throwing
* IOException as FetchFailedException when detectCorruption is true. This method is only
* used by the `read` and `skip` methods inside `BufferReleasingInputStream` currently.
* used by the `available`, `read`, `reset` and `skip` methods inside
* `BufferReleasingInputStream` currently.
*/
private def tryOrFetchFailedException[T](block: => T): T = {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
taskContext: Option[TaskContext] = None,
streamWrapperLimitSize: Option[Long] = None,
corruptAtAvailableReset: Boolean = false,
blockManager: Option[BlockManager] = None,
maxBytesInFlight: Long = Long.MaxValue,
maxReqsInFlight: Int = Int.MaxValue,
Expand All @@ -201,7 +202,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
blockManager.getOrElse(createMockBlockManager()),
mapOutputTracker,
blocksByAddress.iterator,
(_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in),
(_, in) => {
val limited = streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in)
if (corruptAtAvailableReset) {
new CorruptAvailableResetStream(limited)
} else {
limited
}
},
maxBytesInFlight,
maxReqsInFlight,
maxBlocksInFlightPerAddress,
Expand Down Expand Up @@ -712,6 +720,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
corruptBuffer
}

/**
 * Test-only stream wrapper that forwards reads to `in` but always fails `available()` and
 * `reset()` with an [[IOException]], so callers can exercise their error handling on those
 * two methods.
 *
 * @param in the wrapped stream; it serves all `read` calls and is closed with this wrapper
 */
private class CorruptAvailableResetStream(in: InputStream) extends InputStream {
  override def read(): Int = in.read()

  override def read(dest: Array[Byte], off: Int, len: Int): Int = in.read(dest, off, len)

  // Always fails, regardless of the state of the wrapped stream.
  override def available(): Int = throw new IOException("corrupt at available")

  // Always fails, regardless of the state of the wrapped stream.
  override def reset(): Unit = throw new IOException("corrupt at reset")

  // `InputStream.close()` does nothing by default; forward it so that closing the wrapper
  // also closes the wrapped stream instead of leaving it open.
  override def close(): Unit = in.close()
}

private class CorruptStream(corruptAt: Long = 0L) extends InputStream {
var pos = 0
var closed = false
Expand Down Expand Up @@ -1879,4 +1897,48 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
verifyLocalBlocksFromFallback(iterator)
}

test("SPARK-45678: retry corrupt blocks on available() and reset()") {
  // One remote block manager serving a single shuffle block.
  val remoteAddress = BlockManagerId("test-client-1", "test-client-1", 2)
  val shuffleBlockId = ShuffleBlockId(0, 0, 0)
  val remoteBlocks = Map[BlockId, ManagedBuffer](shuffleBlockId -> createMockManagedBuffer())

  // Released by the fetch callback once the block has been handed to the listener, so the
  // test thread only starts consuming the iterator after the fetch has completed.
  val fetchDone = new Semaphore(0)

  answerFetchBlocks { invocation =>
    val fetchListener = invocation.getArgument[BlockFetchingListener](4)
    Future {
      fetchListener.onBlockFetchSuccess(shuffleBlockId.toString, createMockManagedBuffer())
      fetchDone.release()
    }
  }

  val iterator = createShuffleBlockIteratorWithDefaults(
    Map(remoteAddress -> toBlockList(remoteBlocks.keys, 1L, 0)),
    streamWrapperLimitSize = Some(100),
    detectCorruptUseExtraMemory = false, // Don't use `ChunkedByteBufferInputStream`.
    corruptAtAvailableReset = true,
    checksumEnabled = false
  )

  fetchDone.acquire()

  val (fetchedId, fetchedStream) = iterator.next()
  assert(fetchedId === shuffleBlockId)

  // The IOException thrown by the wrapped stream's `available()` must surface as a
  // FetchFailedException that carries the original message.
  val availableFailure = intercept[FetchFailedException](fetchedStream.available())
  assert(availableFailure.getMessage.contains("corrupt at available"))

  // Same expectation for `reset()`.
  val resetFailure = intercept[FetchFailedException](fetchedStream.reset())
  assert(resetFailure.getMessage.contains("corrupt at reset"))
}
}