diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index ab020aaf6fa4f..5b33c110154d6 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -52,6 +52,18 @@ private[spark] abstract class StreamFileInputFormat[T]
     val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum
     val bytesPerCore = totalBytes / defaultParallelism
     val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
+
+    // For small files we need to ensure the min split size per node & rack <= maxSplitSize
+    val jobConfig = context.getConfiguration
+    val minSplitSizePerNode = jobConfig.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L)
+    val minSplitSizePerRack = jobConfig.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L)
+
+    if (maxSplitSize < minSplitSizePerNode) {
+      super.setMinSplitSizeNode(maxSplitSize)
+    }
+    if (maxSplitSize < minSplitSizePerRack) {
+      super.setMinSplitSizeRack(maxSplitSize)
+    }
     super.setMaxSplitSize(maxSplitSize)
   }
 
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 81b18c71f30ee..34efcdf4bc886 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -320,6 +320,19 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("minimum split size per node and per rack should be less than or equal to maxSplitSize") {
+    sc = new SparkContext("local", "test")
+    val testOutput = Array[Byte](1, 2, 3, 4, 5)
+    val outFile = writeBinaryData(testOutput, 1)
+    sc.hadoopConfiguration.setLong(
+      "mapreduce.input.fileinputformat.split.minsize.per.node", 5123456)
+    sc.hadoopConfiguration.setLong(
+      "mapreduce.input.fileinputformat.split.minsize.per.rack", 5123456)
+
+    val (_, data) = sc.binaryFiles(outFile.getAbsolutePath).collect().head
+    assert(data.toArray === testOutput)
+  }
+
   test("fixed record length binary file as byte array") {
     sc = new SparkContext("local", "test")
     val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
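
For reference, a minimal standalone sketch of the scenario the new test exercises. The object name and file path below are hypothetical placeholders, and it assumes a local master; without the clamp added above, Hadoop's CombineFileInputFormat can reject the job when the configured per-node/per-rack minimum split sizes exceed the maxSplitSize computed for a small input.

import org.apache.spark.{SparkConf, SparkContext}

object SmallBinaryFilesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("binaryFiles-min-split-size"))
    try {
      // Per-node / per-rack minimum split sizes larger than the maxSplitSize Spark
      // computes for a tiny input (same keys the added test sets).
      sc.hadoopConfiguration.setLong(
        "mapreduce.input.fileinputformat.split.minsize.per.node", 5123456L)
      sc.hadoopConfiguration.setLong(
        "mapreduce.input.fileinputformat.split.minsize.per.rack", 5123456L)

      // Hypothetical path to a small binary file; substitute a real file when running.
      val path = "/tmp/small.bin"
      sc.binaryFiles(path)
        .mapValues(_.toArray())
        .collect()
        .foreach { case (name, bytes) => println(s"$name -> ${bytes.length} bytes") }
    } finally {
      sc.stop()
    }
  }
}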