diff --git a/CHANGES.md b/CHANGES.md index 5d4a5a096..dfd571686 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +0.5.2 +===== + +- Fixed a regression: `AttributeError` when loading pickles that hold a + reference to a dynamically defined class from the `__main__` module. + ([issue #131]( https://github.com/cloudpipe/cloudpickle/issues/131)). + + 0.5.1 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 598834ee6..a185a4e4b 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -628,12 +628,16 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) + try: return Pickler.save_global(self, obj, name=name) except Exception: if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) typ = type(obj) if typ is not obj and isinstance(obj, (type, types.ClassType)): diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index efe0ebff0..af208df66 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -44,6 +44,7 @@ from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set from .testutils import subprocess_pickle_echo +from .testutils import assert_run_python_script HAVE_WEAKSET = hasattr(weakref, 'WeakSet') @@ -287,19 +288,21 @@ def some_method(self, x): clone_class = pickle_depickle(SomeClass, protocol=self.protocol) self.assertEqual(clone_class(1).one(), 1) self.assertEqual(clone_class(5).some_method(41), 7) - clone_class = subprocess_pickle_echo(SomeClass) + clone_class = subprocess_pickle_echo(SomeClass, protocol=self.protocol) self.assertEqual(clone_class(5).some_method(41), 7) # pickle the class instances 
self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1) self.assertEqual(pickle_depickle(SomeClass(5)).some_method(41), 7) - new_instance = subprocess_pickle_echo(SomeClass(5)) + new_instance = subprocess_pickle_echo(SomeClass(5), + protocol=self.protocol) self.assertEqual(new_instance.some_method(41), 7) # pickle the method instances self.assertEqual(pickle_depickle(SomeClass(1).one)(), 1) self.assertEqual(pickle_depickle(SomeClass(5).some_method)(41), 7) - new_method = subprocess_pickle_echo(SomeClass(5).some_method) + new_method = subprocess_pickle_echo(SomeClass(5).some_method, + protocol=self.protocol) self.assertEqual(new_method(41), 7) def test_partial(self): @@ -748,6 +751,62 @@ def test_builtin_type__new__(self): for t in list, tuple, set, frozenset, dict, object: self.assertTrue(pickle_depickle(t.__new__) is t.__new__) + def test_interactively_defined_function(self): + # Check that callables defined in the __main__ module of a Python + # script (or jupyter kernel) can be pickled / unpickled / executed. 
+ code = """\ + from testutils import subprocess_pickle_echo + + CONSTANT = 42 + + class Foo(object): + + def method(self, x): + return x + + foo = Foo() + + def f0(x): + return x ** 2 + + def f1(): + return Foo + + def f2(x): + return Foo().method(x) + + def f3(): + return Foo().method(CONSTANT) + + def f4(x): + return foo.method(x) + + cloned = subprocess_pickle_echo(lambda x: x**2, protocol={protocol}) + assert cloned(3) == 9 + + cloned = subprocess_pickle_echo(f0, protocol={protocol}) + assert cloned(3) == 9 + + cloned = subprocess_pickle_echo(Foo, protocol={protocol}) + assert cloned().method(2) == Foo().method(2) + + cloned = subprocess_pickle_echo(Foo(), protocol={protocol}) + assert cloned.method(2) == Foo().method(2) + + cloned = subprocess_pickle_echo(f1, protocol={protocol}) + assert cloned()().method('a') == f1()().method('a') + + cloned = subprocess_pickle_echo(f2, protocol={protocol}) + assert cloned(2) == f2(2) + + cloned = subprocess_pickle_echo(f3, protocol={protocol}) + assert cloned() == f3() + + cloned = subprocess_pickle_echo(f4, protocol={protocol}) + assert cloned(2) == f4(2) + """.format(protocol=self.protocol) + assert_run_python_script(textwrap.dedent(code)) + @pytest.mark.skipif(sys.version_info >= (3, 0), reason="hardcoded pickle bytes for 2.7") def test_function_pickle_compat_0_4_0(self): diff --git a/tests/testutils.py b/tests/testutils.py index 110e2f78d..a8187baf3 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -1,7 +1,8 @@ import sys import os -from subprocess import Popen -from subprocess import PIPE +import os.path as op +import tempfile +from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError from cloudpickle import dumps from pickle import loads @@ -16,7 +17,7 @@ class TimeoutExpired(Exception): timeout_supported = False -def subprocess_pickle_echo(input_data): +def subprocess_pickle_echo(input_data, protocol=None): """Echo function with a child Python process Pickle the input data into a buffer, 
send it to a subprocess via @@ -27,10 +28,14 @@ def subprocess_pickle_echo(input_data): [1, 'a', None] """ - pickled_input_data = dumps(input_data) - cmd = [sys.executable, __file__] - cwd = os.getcwd() - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd) + pickled_input_data = dumps(input_data, protocol=protocol) + cmd = [sys.executable, __file__] # runs pickle_echo() in __main__ + cloudpickle_repo_folder = op.normpath( + op.join(op.dirname(__file__), '..')) + cwd = cloudpickle_repo_folder + pythonpath = "{src}/tests:{src}".format(src=cloudpickle_repo_folder) + env = {'PYTHONPATH': pythonpath} + proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env) try: comm_kwargs = {} if timeout_supported: @@ -48,7 +53,7 @@ def subprocess_pickle_echo(input_data): raise RuntimeError(message) -def pickle_echo(stream_in=None, stream_out=None): +def pickle_echo(stream_in=None, stream_out=None, protocol=None): """Read a pickle from stdin and pickle it back to stdout""" if stream_in is None: stream_in = sys.stdin @@ -64,9 +69,46 @@ def pickle_echo(stream_in=None, stream_out=None): input_bytes = stream_in.read() stream_in.close() unpickled_content = loads(input_bytes) - stream_out.write(dumps(unpickled_content)) + stream_out.write(dumps(unpickled_content, protocol=protocol)) stream_out.close() +def assert_run_python_script(source_code, timeout=5): + """Utility to help check pickleability of objects defined in __main__ + + The script provided in the source code should return 0 and not print + anything on stderr or stdout. 
+ """ + fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py') + os.close(fd) + try: + with open(source_file, 'wb') as f: + f.write(source_code.encode('utf-8')) + cmd = [sys.executable, source_file] + cloudpickle_repo_folder = op.normpath( + op.join(op.dirname(__file__), '..')) + pythonpath = "{src}/tests:{src}".format(src=cloudpickle_repo_folder) + kwargs = { + 'cwd': cloudpickle_repo_folder, + 'stderr': STDOUT, + 'env': {'PYTHONPATH': pythonpath}, + } + if timeout_supported: + kwargs['timeout'] = timeout + try: + try: + out = check_output(cmd, **kwargs) + except CalledProcessError as e: + raise RuntimeError(u"script errored with output:\n%s" + % e.output.decode('utf-8')) + if out != b"": + raise AssertionError(out.decode('utf-8')) + except TimeoutExpired as e: + raise RuntimeError(u"script timeout, output so far:\n%s" + % e.output.decode('utf-8')) + finally: + os.unlink(source_file) + + if __name__ == '__main__': pickle_echo()