Skip to content

Commit 54ff1f9

Browse files
zsxwing authored and davies committed
[SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunctionSerializer.loads
## What changes were proposed in this pull request? Set the function's module name to `__main__` if it's missing in `TransformFunctionSerializer.loads`. ## How was this patch tested? Manually test in the shell. Before this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction(<function <lambda> at 0x106ac8b18>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) None >>> print(func2.rdd_wrap_func.__module__) None >>> ``` After this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction(<function <lambda> at 0x108bf1b90>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) __main__ >>> print(func2.rdd_wrap_func.__module__) __main__ >>> ``` Author: Shixiong Zhu <[email protected]> Closes #11535 from zsxwing/loads-module. (cherry picked from commit ee913e6) Signed-off-by: Davies Liu <[email protected]>
1 parent 9d8404b commit 54ff1f9

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

python/pyspark/cloudpickle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def save_function_tuple(self, func):
241241
save(f_globals)
242242
save(defaults)
243243
save(dct)
244+
save(func.__module__)
244245
write(pickle.TUPLE)
245246
write(pickle.REDUCE) # applies _fill_function on the tuple
246247

@@ -674,13 +675,14 @@ def _genpartial(func, args, kwds):
674675
return partial(func, *args, **kwds)
675676

676677

677-
def _fill_function(func, globals, defaults, dict):
678+
def _fill_function(func, globals, defaults, dict, module):
678679
""" Fills in the rest of function data into the skeleton function object
679680
that were created via _make_skel_func().
680681
"""
681682
func.__globals__.update(globals)
682683
func.__defaults__ = defaults
683684
func.__dict__ = dict
685+
func.__module__ = module
684686

685687
return func
686688

python/pyspark/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ def test_itemgetter(self):
230230
getter2 = ser.loads(ser.dumps(getter))
231231
self.assertEqual(getter(d), getter2(d))
232232

233+
def test_function_module_name(self):
234+
ser = CloudPickleSerializer()
235+
func = lambda x: x
236+
func2 = ser.loads(ser.dumps(func))
237+
self.assertEqual(func.__module__, func2.__module__)
238+
233239
def test_attrgetter(self):
234240
from operator import attrgetter
235241
ser = CloudPickleSerializer()

0 commit comments

Comments (0)