@@ -23,6 +23,7 @@ import java.nio.charset.Charset
2323import  java .util .{List  =>  JList , ArrayList  =>  JArrayList , Map  =>  JMap , Collections }
2424
2525import  scala .collection .JavaConversions ._ 
26+ import  scala .collection .mutable 
2627import  scala .language .existentials 
2728import  scala .reflect .ClassTag 
2829import  scala .util .{Try , Success , Failure }
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
5253  extends  RDD [Array [Byte ]](parent) {
5354
5455  val  bufferSize  =  conf.getInt(" spark.buffer.size"  , 65536 )
56+   val  reuse_worker  =  conf.getBoolean(" spark.python.worker.reuse"  , true )
5557
5658  override  def  getPartitions  =  parent.partitions
5759
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
6365    val  localdir  =  env.blockManager.diskBlockManager.localDirs.map(
6466      f =>  f.getPath()).mkString(" ,"  )
6567    envVars +=  (" SPARK_LOCAL_DIRS"   ->  localdir) //  it's also used in monitor thread
68+     if  (reuse_worker) {
69+       envVars +=  (" SPARK_REUSE_WORKER"   ->  " 1"  )
70+     }
6671    val  worker :  Socket  =  env.createPythonWorker(pythonExec, envVars.toMap)
6772
6873    //  Start a thread to feed the process input from our parent's iterator
6974    val  writerThread  =  new  WriterThread (env, worker, split, context)
7075
76+     var  complete_cleanly  =  false 
7177    context.addTaskCompletionListener { context => 
7278      writerThread.shutdownOnTaskCompletion()
73- 
74-       //  Cleanup the worker socket. This will also cause the Python worker to exit.
75-       try  {
76-         worker.close()
77-       } catch  {
78-         case  e : Exception  =>  logWarning(" Failed to close worker socket"  , e)
79+       if  (reuse_worker &&  complete_cleanly) {
80+         env.releasePythonWorker(pythonExec, envVars.toMap, worker)
81+       } else  {
82+         try  {
83+           worker.close()
84+         } catch  {
85+           case  e : Exception  => 
86+             logWarning(" Failed to close worker socket"  , e)
87+         }
7988      }
8089    }
8190
@@ -115,6 +124,10 @@ private[spark] class PythonRDD(
115124              val  total  =  finishTime -  startTime
116125              logInfo(" Times: total = %s, boot = %s, init = %s, finish = %s"  .format(total, boot,
117126                init, finish))
127+               val  memoryBytesSpilled  =  stream.readLong()
128+               val  diskBytesSpilled  =  stream.readLong()
129+               context.taskMetrics.memoryBytesSpilled +=  memoryBytesSpilled
130+               context.taskMetrics.diskBytesSpilled +=  diskBytesSpilled
118131              read()
119132            case  SpecialLengths .PYTHON_EXCEPTION_THROWN  => 
120133              //  Signals that an exception has been thrown in python
@@ -133,6 +146,7 @@ private[spark] class PythonRDD(
133146                stream.readFully(update)
134147                accumulator +=  Collections .singletonList(update)
135148              }
149+                complete_cleanly =  true 
136150              null 
137151          }
138152        } catch  {
@@ -195,29 +209,45 @@ private[spark] class PythonRDD(
195209          PythonRDD .writeUTF(include, dataOut)
196210        }
197211        //  Broadcast variables
198-         dataOut.writeInt(broadcastVars.length)
212+         val  oldBids  =  PythonRDD .getWorkerBroadcasts(worker)
213+         val  newBids  =  broadcastVars.map(_.id).toSet
214+         //  number of different broadcasts
215+         val  cnt  =  oldBids.diff(newBids).size +  newBids.diff(oldBids).size
216+         dataOut.writeInt(cnt)
217+         for  (bid <-  oldBids) {
218+           if  (! newBids.contains(bid)) {
219+             //  remove the broadcast from worker
220+             dataOut.writeLong(-  bid -  1 )  //  bid >= 0
221+             oldBids.remove(bid)
222+           }
223+         }
199224        for  (broadcast <-  broadcastVars) {
200-           dataOut.writeLong(broadcast.id)
201-           dataOut.writeInt(broadcast.value.length)
202-           dataOut.write(broadcast.value)
225+           if  (! oldBids.contains(broadcast.id)) {
226+             //  send new broadcast
227+             dataOut.writeLong(broadcast.id)
228+             dataOut.writeInt(broadcast.value.length)
229+             dataOut.write(broadcast.value)
230+             oldBids.add(broadcast.id)
231+           }
203232        }
204233        dataOut.flush()
205234        //  Serialized command:
206235        dataOut.writeInt(command.length)
207236        dataOut.write(command)
208237        //  Data values
209238        PythonRDD .writeIteratorToStream(parent.iterator(split, context), dataOut)
239+         dataOut.writeInt(SpecialLengths .END_OF_DATA_SECTION )
210240        dataOut.flush()
211241      } catch  {
212242        case  e : Exception  if  context.isCompleted ||  context.isInterrupted => 
213243          logDebug(" Exception thrown after task completion (likely due to cleanup)"  , e)
244+           worker.shutdownOutput()
214245
215246        case  e : Exception  => 
216247          //  We must avoid throwing exceptions here, because the thread uncaught exception handler
217248          //  will kill the whole executor (see org.apache.spark.executor.Executor).
218249          _exception =  e
219-       } finally  {
220-         Try (worker.shutdownOutput()) //  kill Python worker process
250+           worker.shutdownOutput()
221251      }
222252    }
223253  }
@@ -278,6 +308,14 @@ private object SpecialLengths {
278308private [spark] object  PythonRDD  extends  Logging  {
279309  val  UTF8  =  Charset .forName(" UTF-8"  )
280310
311+   //  Tracks, per worker socket, the IDs of the broadcasts already sent to that worker
312+   private  val  workerBroadcasts  =  new  mutable.WeakHashMap [Socket , mutable.Set [Long ]]()  // weak map keyed by worker socket
313+   private  def  getWorkerBroadcasts (worker : Socket ) =  {  // returns the mutable ID set for `worker`, creating an empty one on first lookup
314+     synchronized  {  // locks on the enclosing PythonRDD object while the shared map is read/updated
315+       workerBroadcasts.getOrElseUpdate(worker, new  mutable.HashSet [Long ]())
316+     }  // only the map lookup is synchronized; the returned set itself is not guarded
317+   }
318+ 
281319  /**  
282320   * Adapter for calling SparkContext#runJob from Python. 
283321   * 
0 commit comments