Skip to content
8 changes: 8 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
0.5.2
=====

- Fixed a regression: `AttributeError` when loading pickles that hold a
  reference to a dynamically defined class from the `__main__` module
  ([issue #131](https://github.com/cloudpipe/cloudpickle/issues/131)).


0.5.1
=====

Expand Down
6 changes: 5 additions & 1 deletion cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,12 +628,16 @@ def save_global(self, obj, name=None, pack=struct.pack):
The name of this method is somewhat misleading: all types get
dispatched here.
"""
if obj.__module__ == "__main__":
return self.save_dynamic_class(obj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean save_dynamic_class is also used for functions? If so, could you rename that method and update its docstring?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think functions get dispatched here. The dispatch entries for save_global are:

     dispatch[type] = save_global
     dispatch[types.ClassType] = save_global

whereas function-likes are dispatched via:

dispatch[types.FunctionType] = save_function
...
dispatch[types.MethodType] = save_instancemethod
...
dispatch[types.BuiltinFunctionType] = save_builtin_function
...
dispatch[classmethod] = save_classmethod
dispatch[staticmethod] = save_classmethod

I believe the name save_global is a holdover from the base Pickler class, which calls self.save_global in some base class methods that we call into via super, so renaming this isn't straightforward.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have pushed a new commit to simplify that function even further, and all tests still pass. However, I cannot be sure that we are not breaking edge cases in third-party libraries and applications.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have run the test suites of both loky and joblib against this branch, and all tests pass there as well, so I am pretty confident that the changes are fine.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check whether save_dynamic_class really gets a function as input and, if so, change the name or docstring?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pitrou I'm pretty confident save_dynamic_class can never be called with a function as input. It accesses obj.__bases__ unconditionally, which doesn't exist on function objects.


try:
return Pickler.save_global(self, obj, name=name)
except Exception:
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)

typ = type(obj)
if typ is not obj and isinstance(obj, (type, types.ClassType)):
Expand Down
65 changes: 62 additions & 3 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set

from .testutils import subprocess_pickle_echo
from .testutils import assert_run_python_script


HAVE_WEAKSET = hasattr(weakref, 'WeakSet')
Expand Down Expand Up @@ -287,19 +288,21 @@ def some_method(self, x):
clone_class = pickle_depickle(SomeClass, protocol=self.protocol)
self.assertEqual(clone_class(1).one(), 1)
self.assertEqual(clone_class(5).some_method(41), 7)
clone_class = subprocess_pickle_echo(SomeClass)
clone_class = subprocess_pickle_echo(SomeClass, protocol=self.protocol)
self.assertEqual(clone_class(5).some_method(41), 7)

# pickle the class instances
self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1)
self.assertEqual(pickle_depickle(SomeClass(5)).some_method(41), 7)
new_instance = subprocess_pickle_echo(SomeClass(5))
new_instance = subprocess_pickle_echo(SomeClass(5),
protocol=self.protocol)
self.assertEqual(new_instance.some_method(41), 7)

# pickle the method instances
self.assertEqual(pickle_depickle(SomeClass(1).one)(), 1)
self.assertEqual(pickle_depickle(SomeClass(5).some_method)(41), 7)
new_method = subprocess_pickle_echo(SomeClass(5).some_method)
new_method = subprocess_pickle_echo(SomeClass(5).some_method,
protocol=self.protocol)
self.assertEqual(new_method(41), 7)

def test_partial(self):
Expand Down Expand Up @@ -748,6 +751,62 @@ def test_builtin_type__new__(self):
for t in list, tuple, set, frozenset, dict, object:
self.assertTrue(pickle_depickle(t.__new__) is t.__new__)

def test_interactively_defined_function(self):
    # Check that callables defined in the __main__ module of a Python
    # script (or jupyter kernel) can be pickled / unpickled / executed.
    #
    # The script below is formatted with the protocol under test, dedented
    # and handed to assert_run_python_script, so every class and function
    # it defines lives in the __main__ module of that script. Each object
    # is round-tripped with subprocess_pickle_echo and the clone is checked
    # to behave like the original.
    code = """\
    from testutils import subprocess_pickle_echo

    CONSTANT = 42

    class Foo(object):

        def method(self, x):
            return x

    foo = Foo()

    def f0(x):
        return x ** 2

    def f1():
        return Foo

    def f2(x):
        return Foo().method(x)

    def f3():
        return Foo().method(CONSTANT)

    def f4(x):
        return foo.method(x)

    cloned = subprocess_pickle_echo(lambda x: x**2, protocol={protocol})
    assert cloned(3) == 9

    cloned = subprocess_pickle_echo(f0, protocol={protocol})
    assert cloned(3) == 9

    cloned = subprocess_pickle_echo(Foo, protocol={protocol})
    assert cloned().method(2) == Foo().method(2)

    cloned = subprocess_pickle_echo(Foo(), protocol={protocol})
    assert cloned.method(2) == Foo().method(2)

    cloned = subprocess_pickle_echo(f1, protocol={protocol})
    assert cloned()().method('a') == f1()().method('a')

    cloned = subprocess_pickle_echo(f2, protocol={protocol})
    assert cloned(2) == f2(2)

    cloned = subprocess_pickle_echo(f3, protocol={protocol})
    assert cloned() == f3()

    cloned = subprocess_pickle_echo(f4, protocol={protocol})
    assert cloned(2) == f4(2)
    """.format(protocol=self.protocol)
    # textwrap.dedent strips the common leading indentation so the script
    # starts at column 0 when written to disk.
    assert_run_python_script(textwrap.dedent(code))

@pytest.mark.skipif(sys.version_info >= (3, 0),
reason="hardcoded pickle bytes for 2.7")
def test_function_pickle_compat_0_4_0(self):
Expand Down
60 changes: 51 additions & 9 deletions tests/testutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
import os
from subprocess import Popen
from subprocess import PIPE
import os.path as op
import tempfile
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError

from cloudpickle import dumps
from pickle import loads
Expand All @@ -16,7 +17,7 @@ class TimeoutExpired(Exception):
timeout_supported = False


def subprocess_pickle_echo(input_data):
def subprocess_pickle_echo(input_data, protocol=None):
"""Echo function with a child Python process

Pickle the input data into a buffer, send it to a subprocess via
Expand All @@ -27,10 +28,14 @@ def subprocess_pickle_echo(input_data):
[1, 'a', None]

"""
pickled_input_data = dumps(input_data)
cmd = [sys.executable, __file__]
cwd = os.getcwd()
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd)
pickled_input_data = dumps(input_data, protocol=protocol)
cmd = [sys.executable, __file__]  # run pickle_echo() in __main__
cloudpickle_repo_folder = op.normpath(
op.join(op.dirname(__file__), '..'))
cwd = cloudpickle_repo_folder
pythonpath = "{src}/tests:{src}".format(src=cloudpickle_repo_folder)
env = {'PYTHONPATH': pythonpath}
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env)
try:
comm_kwargs = {}
if timeout_supported:
Expand All @@ -48,7 +53,7 @@ def subprocess_pickle_echo(input_data):
raise RuntimeError(message)


def pickle_echo(stream_in=None, stream_out=None):
def pickle_echo(stream_in=None, stream_out=None, protocol=None):
"""Read a pickle from stdin and pickle it back to stdout"""
if stream_in is None:
stream_in = sys.stdin
Expand All @@ -64,9 +69,46 @@ def pickle_echo(stream_in=None, stream_out=None):
input_bytes = stream_in.read()
stream_in.close()
unpickled_content = loads(input_bytes)
stream_out.write(dumps(unpickled_content))
stream_out.write(dumps(unpickled_content, protocol=protocol))
stream_out.close()


def assert_run_python_script(source_code, timeout=5):
    """Utility to help check pickleability of objects defined in __main__

    The source code is written (UTF-8 encoded) to a temporary file and run
    with ``sys.executable``. The working directory is the parent folder of
    this file's directory, and ``PYTHONPATH`` holds that folder and its
    ``tests`` sub-folder. The temporary file is always removed afterwards.

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout.

    ``timeout`` (in seconds) is only passed to ``check_output`` when the
    module-level ``timeout_supported`` flag is true.

    Raises ``RuntimeError`` if the script exits with a non-zero status or
    times out, and ``AssertionError`` if it prints anything.
    """
    fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
    # Only the path is needed: the file is reopened by name just below.
    os.close(fd)
    try:
        with open(source_file, 'wb') as f:
            f.write(source_code.encode('utf-8'))
        cmd = [sys.executable, source_file]
        cloudpickle_repo_folder = op.normpath(
            op.join(op.dirname(__file__), '..'))
        # Build PYTHONPATH with the platform's own separators instead of a
        # hard-coded "/" and ":" so that the entries are also split
        # correctly where os.pathsep is not ":" (e.g. ";" on Windows).
        pythonpath = os.pathsep.join([
            op.join(cloudpickle_repo_folder, 'tests'),
            cloudpickle_repo_folder,
        ])
        kwargs = {
            'cwd': cloudpickle_repo_folder,
            # Merge stderr into the captured output so that anything the
            # script prints on either stream is reported.
            'stderr': STDOUT,
            'env': {'PYTHONPATH': pythonpath},
        }
        if timeout_supported:
            kwargs['timeout'] = timeout
        try:
            try:
                out = check_output(cmd, **kwargs)
            except CalledProcessError as e:
                raise RuntimeError(u"script errored with output:\n%s"
                                   % e.output.decode('utf-8'))
            if out != b"":
                raise AssertionError(out.decode('utf-8'))
        except TimeoutExpired as e:
            raise RuntimeError(u"script timeout, output so far:\n%s"
                               % e.output.decode('utf-8'))
    finally:
        os.unlink(source_file)


if __name__ == '__main__':
    # Script entry point: subprocess_pickle_echo() launches this file with
    # sys.executable, and pickle_echo() reads a pickle from stdin and
    # pickles it back to stdout.
    pickle_echo()