Skip to content

Commit 20402cd

Browse files
committed
use ZipEntry
1 parent 9d87c3f commit 20402cd

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

project/SparkBuild.scala

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ object Assembly {
354354
object PySparkAssembly {
355355
import sbtassembly.Plugin._
356356
import AssemblyKeys._
357+
import java.util.zip.{ZipOutputStream, ZipEntry}
357358

358359
lazy val settings = Seq(
359360
unmanagedJars in Compile += { BuildCommons.sparkHome / "python/lib/py4j-0.8.2.1-src.zip" },
@@ -365,12 +366,7 @@ object PySparkAssembly {
365366

366367
val zipFile = new File(BuildCommons.sparkHome , "python/lib/pyspark.zip")
367368
zipFile.delete()
368-
def entries(f: File): List[File] =
369-
f :: (if (f.isDirectory) IO.listFiles(f).toList.flatMap(entries(_)) else Nil)
370-
val sources = entries(src).map { d =>
371-
(d, d.getAbsolutePath.substring(src.getParent.length +1))
372-
}
373-
IO.zip(sources, zipFile)
369+
zipRecursive(src, zipFile)
374370

375371
val dst = new File(outDir, "pyspark")
376372
if (!dst.isDirectory()) {
@@ -380,6 +376,34 @@ object PySparkAssembly {
380376
}
381377
)
382378

379+
/**
 * Writes `source` (a file, or a directory together with everything below it) into a new
 * zip archive at `destZipFile`.
 *
 * @param source      file or directory to archive; its own name becomes the top-level entry
 * @param destZipFile archive to create; any existing file at that path is overwritten
 */
private def zipRecursive(source: File, destZipFile: File): Unit = {
  val destOutput = new ZipOutputStream(new FileOutputStream(destZipFile))
  try {
    addFilesToZipStream("", source, destOutput)
  } finally {
    // close() finishes the archive and releases the file handle; doing it in `finally`
    // ensures the handle is not leaked when adding an entry fails part-way through.
    destOutput.close()
  }
}
385+
386+
/**
 * Recursively adds `source` to `output`.
 *
 * Every entry is named `parent` followed by the file's own name. Directories are written
 * as explicit directory entries (name ending in "/") and then descended into.
 *
 * @param parent entry-name prefix for `source`: either empty or a path ending in "/"
 * @param source file or directory to add
 * @param output open zip stream to append to; it is left open for the caller to close
 */
private def addFilesToZipStream(parent: String, source: File, output: ZipOutputStream): Unit = {
  if (source.isDirectory()) {
    // Zip entry names always use "/" (never the platform separator), and an entry is only
    // treated as a directory when its name ends with "/".
    val dirEntryName = parent + source.getName() + "/"
    output.putNextEntry(new ZipEntry(dirEntryName))
    output.closeEntry()
    // listFiles() returns null when the directory cannot be read; fail with a clear message
    // instead of a bare NullPointerException.
    val children = Option(source.listFiles()).getOrElse(
      throw new java.io.IOException("Unable to list directory: " + source))
    children.foreach(child => addFilesToZipStream(dirEntryName, child, output))
  } else {
    val in = new FileInputStream(source)
    try {
      output.putNextEntry(new ZipEntry(parent + source.getName()))
      val buf = new Array[Byte](8192)
      // Copy until read() signals end-of-stream with -1.
      Iterator.continually(in.read(buf)).takeWhile(_ != -1).foreach { n =>
        output.write(buf, 0, n)
      }
      output.closeEntry()
    } finally {
      // Always release the input handle, even if reading or writing fails.
      in.close()
    }
  }
}
406+
383407
private def copy(src: File, dst: File): Seq[File] = {
384408
src.listFiles().flatMap { f =>
385409
val child = new File(dst, f.getName())

0 commit comments

Comments
 (0)