Commit aefd508

use broadcast automatically for large closure

1 parent: febafef

File tree

4 files changed: +18 -2 lines

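The idea: the driver pickles each task's (func, deserializer, serializer) command with CloudPickleSerializer. If the result is larger than 1 MB (1 << 20 bytes), it is shipped as a broadcast variable instead, and only the pickled Broadcast handle rides along with the task. A minimal sketch of the driver-side pattern, assuming a live SparkContext named sc and using the standalone cloudpickle module as a stand-in for pyspark's CloudPickleSerializer (both names here are illustrative, not part of the diff):

    import cloudpickle  # stand-in for pyspark.serializers.CloudPickleSerializer

    ONE_MB = 1 << 20  # 1048576 bytes, the threshold used in this commit

    def prepare_command(sc, command):
        """Pickle `command`; if the result is large, ship it via broadcast."""
        pickled = cloudpickle.dumps(command)
        if len(pickled) > ONE_MB:
            # A Broadcast pickles down to a tiny handle, so the second
            # dumps() is cheap; workers fetch the real bytes out of band.
            broadcast = sc.broadcast(pickled)
            pickled = cloudpickle.dumps(broadcast)
        return pickled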

python/pyspark/rdd.py

Lines changed: 4 additions & 0 deletions
@@ -2061,8 +2061,12 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
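The double dumps() is the trick here: a Broadcast object pickles to a small handle, so when the closure crosses the 1 MB threshold the task carries only that handle, and the worker pulls the real bytes through the broadcast machinery, which (per the comment in the hunk above) also compresses them.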

python/pyspark/sql.py

Lines changed: 6 additions & 2 deletions
@@ -27,7 +27,7 @@
 from array import array
 from operator import itemgetter

-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel

@@ -974,7 +974,11 @@ def registerFunction(self, name, f, returnType=StringType()):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)
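registerFunction gets the identical guard for UDF commands, routed through self._sc; the only other change is the import line, which drops PipelinedRDD now that sql.py no longer references it.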

python/pyspark/tests.py

Lines changed: 6 additions & 0 deletions
@@ -434,6 +434,12 @@ def test_large_broadcast(self):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)

+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
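The new test captures a list of one million floats in the map closure; its pickled form is well past the 1 << 20 threshold, so the assertion only passes if the broadcast path round-trips the command correctly.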

python/pyspark/worker.py

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,8 @@ def main(infile, outfile):

     _accumulatorRegistry.clear()
     command = pickleSer._read_with_length(infile)
+    if isinstance(command, Broadcast):
+        command = pickleSer.loads(command.value)
     (func, deserializer, serializer) = command
     init_time = time.time()
     iterator = deserializer.load_stream(infile)
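For symmetry, a hedged sketch of the decode path the worker now performs. Broadcast is pyspark.broadcast.Broadcast (which worker.py must import for the isinstance check); decode_command and pickle_ser are illustrative names standing in for the inline code and the module-level pickleSer:

    from pyspark.broadcast import Broadcast
    from pyspark.serializers import PickleSerializer

    def decode_command(raw_bytes):
        """Invert the driver-side encoding: unwrap a Broadcast if present."""
        pickle_ser = PickleSerializer()
        command = pickle_ser.loads(raw_bytes)
        if isinstance(command, Broadcast):
            # The broadcast's value holds the original pickled command;
            # unpickle it to recover the real tuple.
            command = pickle_ser.loads(command.value)
        return command  # (func, deserializer, serializer)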
