 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.streaming.util import rddToFileName, RDDFunction
+from pyspark.rdd import portable_hash, _parse_memory
+from pyspark.shuffle import get_used_memory
 from pyspark.traceback_utils import SCCallSiteSync

 from py4j.java_collections import ListConverter, MapConverter
@@ -40,6 +41,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
+        self._partitionFunc = None

     def context(self):
         """
@@ -161,32 +163,71 @@ def _mergeCombiners(iterator):

         return shuffled.mapPartitions(_mergeCombiners)

-    def partitionBy(self, numPartitions, partitionFunc=None):
+    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         """
         Return a copy of the DStream partitioned using the specified partitioner.
         """
         if numPartitions is None:
             numPartitions = self.ctx._defaultReducePartitions()

-        if partitionFunc is None:
-            partitionFunc = lambda x: 0 if x is None else hash(x)
-
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+
         outputSerializer = self.ctx._unbatched_serializer
+#
+#        def add_shuffle_key(split, iterator):
+#            buckets = defaultdict(list)
+#
+#            for (k, v) in iterator:
+#                buckets[partitionFunc(k) % numPartitions].append((k, v))
+#            for (split, items) in buckets.iteritems():
+#                yield pack_long(split)
+#                yield outputSerializer.dumps(items)
+#        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)
         def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)

-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
-            for (split, items) in buckets.iteritems():
+                c += 1
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
-        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        keyed = self._mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with SCCallSiteSync(self.context) as css:
+        with SCCallSiteSync(self.ctx) as css:
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                           id(partitionFunc))
             jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
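For reference, a minimal standalone sketch of the bucketing strategy introduced above: keys are hashed into num_partitions buckets in Python, and buckets are flushed early in serialized chunks, with the batch size tuned so the average chunk stays roughly between 1 MB and 10 MB. The helper name bucket_and_flush, the use of pickle, and the count-based stand-in for the get_used_memory() check are illustrative assumptions, not part of the patch.

# Standalone sketch (hypothetical helper, not PySpark code): hash keys into
# num_partitions buckets and flush them in serialized chunks, adapting the
# batch size so the average chunk stays between about 1 MB and 10 MB.
import pickle
from collections import defaultdict


def bucket_and_flush(records, num_partitions, partition_func=hash):
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)
    for k, v in records:
        buckets[partition_func(k) % num_partitions].append((k, v))
        count += 1
        if count > batch:                    # stand-in for the memory check
            n, size = len(buckets), 0
            for split in list(buckets.keys()):
                chunk = pickle.dumps(buckets.pop(split))
                size += len(chunk)
                yield split, chunk
            avg_mb = (size // n) >> 20       # average chunk size in MB
            if avg_mb < 1:                   # chunks too small: batch more
                batch *= 1.5
            elif avg_mb > 10:                # chunks too big: batch less
                batch = max(batch / 1.5, 1)
            count = 0
    for split, items in buckets.items():     # flush whatever is left
        yield split, pickle.dumps(items)


# Toy usage: O(num_partitions) chunks per flush instead of O(n) records.
for split, blob in bucket_and_flush(((i, str(i)) for i in range(5000)), 4):
    assert isinstance(pickle.loads(blob), list)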
@@ -428,6 +469,10 @@ def get_output(rdd, time):


 class PipelinedDStream(DStream):
+    """
+    Since PipelinedDStream is the same as PipelinedRDD, if PipelinedRDD is
+    changed, this code should be changed in the same way.
+    """
     def __init__(self, prev, func, preservesPartitioning=False):
         if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
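As a side note on the docstring above, the pipelining that PipelinedDStream borrows from PipelinedRDD amounts to composing partition-level functions in Python so chained transformations execute as a single pass over each partition. A rough, hypothetical illustration (compose() and the toy lambdas below are not from the patch):

# Hypothetical illustration: pipelining composes partition-level functions so
# chained transformations execute as one pass over each partition.
def compose(prev_func, func):
    def pipeline_func(split, iterator):
        # apply the earlier stage first, then feed its output to the next one
        return func(split, prev_func(split, iterator))
    return pipeline_func


double = lambda split, it: (x * 2 for x in it)
add_one = lambda split, it: (x + 1 for x in it)
fused = compose(double, add_one)
print(list(fused(0, iter([1, 2, 3]))))  # [3, 5, 7]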
@@ -453,19 +498,22 @@ def pipeline_func(split, iterator):
         self._jdstream_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None

     @property
     def _jdstream(self):
         if self._jdstream_val:
             return self._jdstream_val
         if self._bypass_serializer:
-            serializer = NoOpSerializer()
-        else:
-            serializer = self.ctx.serializer
-
-        command = (self.func, self._prev_jrdd_deserializer, serializer)
-        ser = CompressedSerializer(CloudPickleSerializer())
+            self._jrdd_deserializer = NoOpSerializer()
+        command = (self.func, self._prev_jrdd_deserializer,
+                   self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
+        ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
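The broadcast fallback added above follows a simple pattern: if the pickled command exceeds 1 MB, ship a small serialized handle instead of the payload itself. A hypothetical sketch of that pattern, with a stand-in broadcast() function in place of SparkContext.broadcast:

# Hypothetical sketch of the pattern (broadcast() here is a stand-in, not
# SparkContext.broadcast): oversized serialized commands are replaced by a
# small serialized handle so only a little data crosses the Py4J gateway.
import pickle

ONE_MB = 1 << 20


def broadcast(payload):
    # A real implementation would register the payload with the cluster and
    # return a lightweight reference that workers resolve lazily.
    return {"broadcast_id": id(payload), "size": len(payload)}


def prepare_command(command):
    pickled = pickle.dumps(command)
    if len(pickled) > ONE_MB:  # 1M, mirroring the check in the diff
        pickled = pickle.dumps(broadcast(pickled))
    return pickled


small = prepare_command(("func", "deser", "ser"))
large = prepare_command(("func", b"x" * (2 * ONE_MB)))
print(len(small), len(large))  # the large command shrinks to a tiny handle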