@@ -23,6 +23,7 @@ import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions = parent.partitions
 
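
`spark.python.worker.reuse` defaults to `true`, so reuse is on unless a job opts out. A minimal sketch of turning it off from the driver, assuming a throwaway local setup (the app name and master below are placeholders, not part of this change):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative driver configuration: fall back to one Python worker per task
// by disabling reuse. Only the "spark.python.worker.reuse" key comes from
// this patch; everything else here is a local example.
val conf = new SparkConf()
  .setAppName("worker-reuse-demo")            // placeholder name
  .setMaster("local[2]")                      // placeholder master
  .set("spark.python.worker.reuse", "false")
val sc = new SparkContext(conf)
```
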
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
     envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+    if (reuse_worker) {
+      envVars += ("SPARK_REUSE_WORKER" -> "1")
+    }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-
-      // Cleanup the worker socket. This will also cause the Python worker to exit.
-      try {
-        worker.close()
-      } catch {
-        case e: Exception => logWarning("Failed to close worker socket", e)
+      if (reuse_worker && complete_cleanly) {
+        env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+      } else {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
       }
     }
 
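
The completion listener now recycles the worker only when reuse is enabled and the task drained the stream cleanly (`complete_cleanly` is set further down, after the accumulator updates are read); any other outcome still closes the socket, which makes the Python process exit. A condensed, standalone sketch of that rule, with the two actions passed in as thunks (`cleanupWorker` and both thunk bodies are illustrative, not part of the patch):

```scala
// Sketch of the cleanup decision, abstracted over the two side effects so it
// can be exercised in isolation. `release` stands in for
// env.releasePythonWorker(...) and `close` for worker.close().
def cleanupWorker(reuseEnabled: Boolean, completedCleanly: Boolean)
                 (release: () => Unit, close: () => Unit): Unit = {
  if (reuseEnabled && completedCleanly) release() else close()
}

// A task that failed part-way through never returns its worker to the pool.
cleanupWorker(reuseEnabled = true, completedCleanly = false)(
  release = () => println("worker returned to the pool"),
  close = () => println("worker socket closed"))
```
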
@@ -115,6 +124,10 @@ private[spark] class PythonRDD(
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                 init, finish))
+              val memoryBytesSpilled = stream.readLong()
+              val diskBytesSpilled = stream.readLong()
+              context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+              context.taskMetrics.diskBytesSpilled += diskBytesSpilled
               read()
             case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
@@ -133,6 +146,7 @@ private[spark] class PythonRDD(
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
               }
+              complete_cleanly = true
               null
           }
         } catch {
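
With this change the worker is expected to append two longs, memory bytes spilled then disk bytes spilled, right after the timing block, and the reader folds them into the task metrics. A toy round-trip of just those two fields, showing the ordering the reader above assumes (the values are made up):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Simulate the extra fields the Python worker now writes after its timing
// data: memoryBytesSpilled, then diskBytesSpilled, both as longs.
val buf = new ByteArrayOutputStream()
val out = new DataOutputStream(buf)
out.writeLong(1L << 20)   // pretend 1 MiB of in-memory data was spilled
out.writeLong(0L)         // and nothing was spilled to disk
out.flush()

// The reader side consumes them in the same order, as in the hunk above.
val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
val memoryBytesSpilled = in.readLong()
val diskBytesSpilled = in.readLong()
assert(memoryBytesSpilled == (1L << 20) && diskBytesSpilled == 0L)
```
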
@@ -195,29 +209,45 @@ private[spark] class PythonRDD(
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- oldBids) {
+          if (!newBids.contains(bid)) {
+            // remove the broadcast from worker
+            dataOut.writeLong(- bid - 1)  // bid >= 0
+            oldBids.remove(bid)
+          }
+        }
         for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+            oldBids.add(broadcast.id)
+          }
         }
         dataOut.flush()
         // Serialized command:
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          worker.shutdownOutput()
 
         case e: Exception =>
           // We must avoid throwing exceptions here, because the thread uncaught exception handler
           // will kill the whole executor (see org.apache.spark.executor.Executor).
           _exception = e
-      } finally {
-        Try(worker.shutdownOutput()) // kill Python worker process
+          worker.shutdownOutput()
       }
     }
   }
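
The broadcast section of the handshake is now a delta against what the worker already holds: ids the worker no longer needs are encoded as `-bid - 1` (the sign distinguishes removals, since real ids are >= 0), and only previously unseen broadcasts are shipped in full. A self-contained sketch of that id-level bookkeeping, separated from the socket writes (`encodeBroadcastIds` is an illustrative name, not part of the patch):

```scala
import scala.collection.mutable

// Given the ids cached on a worker and the ids needed by the next task,
// return the longs that would be written for the broadcast section:
// negative values (encoded as -bid - 1) mean "drop this broadcast",
// non-negative values mean "a full broadcast payload follows".
def encodeBroadcastIds(cached: mutable.Set[Long], needed: Set[Long]): Seq[Long] = {
  val removals = (cached.toSet -- needed).toSeq.map(bid => -bid - 1)
  val additions = needed.toSeq.filterNot(cached.contains)
  cached --= removals.map(enc => -enc - 1)   // keep the per-worker view in sync
  cached ++= additions
  removals ++ additions
}

// A worker holding broadcasts 0 and 1 while the next task needs 1 and 2:
// broadcast 0 is dropped (encoded as -1) and broadcast 2 is sent in full.
val cached = mutable.HashSet[Long](0L, 1L)
assert(encodeBroadcastIds(cached, Set(1L, 2L)) == Seq(-1L, 2L))
assert(cached == Set(1L, 2L))
```
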
@@ -278,6 +308,14 @@ private object SpecialLengths {
 private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
+  // remember the broadcasts sent to each worker
+  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+  private def getWorkerBroadcasts(worker: Socket) = {
+    synchronized {
+      workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+    }
+  }
+
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
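
The per-worker id sets live in a `mutable.WeakHashMap` keyed by the worker's `Socket`, so an entry goes away together with its socket and never needs explicit cleanup, while `getOrElseUpdate` under `synchronized` lazily creates one set per worker. A minimal standalone illustration of that pattern, with a `String` key standing in for the socket (names here are illustrative):

```scala
import scala.collection.mutable

// One lazily created id set per worker; weak keys mean the entry can be
// collected once nothing else references the worker (socket) object.
val broadcastsByWorker = new mutable.WeakHashMap[String, mutable.Set[Long]]()

def trackedBroadcasts(worker: String): mutable.Set[Long] =
  broadcastsByWorker.synchronized {
    broadcastsByWorker.getOrElseUpdate(worker, new mutable.HashSet[Long]())
  }

trackedBroadcasts("worker-1") += 7L
assert(trackedBroadcasts("worker-1") == Set(7L))   // same set is returned next time
assert(trackedBroadcasts("worker-2").isEmpty)      // a fresh set per worker
```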