diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1ed098c076ae5..4f9c928a57d97 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@
 import socket
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
-from threading import Thread
+from threading import Thread, Lock
 from collections import defaultdict
 from itertools import chain
 from functools import reduce
@@ -57,6 +57,9 @@

 __all__ = ["RDD"]

+# Lock that makes sure dependent broadcast variables are pickled along
+# with their PythonRDD-wrapped function when using multiple threads (SPARK-12717).
+_lock = Lock()

 def portable_hash(x):
     """
@@ -2374,14 +2377,15 @@ def _jrdd(self):
         else:
             profiler = None

-        command = (self.func, profiler, self._prev_jrdd_deserializer,
-                   self._jrdd_deserializer)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
-        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_cmd),
-                                             env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec, self.ctx.pythonVer,
-                                             bvars, self.ctx._javaAccumulator)
+        with _lock:
+            command = (self.func, profiler, self._prev_jrdd_deserializer,
+                       self._jrdd_deserializer)
+            pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
+            python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
+                                                 bytearray(pickled_cmd),
+                                                 env, includes, self.preservesPartitioning,
+                                                 self.ctx.pythonExec, self.ctx.pythonVer,
+                                                 bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()

         if profiler:
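
For context, here is a minimal sketch of the multi-threaded driver pattern this patch guards against (the app name, `run_job` helper, and job bodies are illustrative assumptions, not part of the patch): each thread creates its own broadcast variable and then triggers `_jrdd` through an action. Without the lock, a broadcast registered by one thread in between another thread's `_prepare_for_python_RDD` call and its `PythonRDD` construction could be attached to the wrong thread's pickled command.

```python
# Illustrative reproduction of the SPARK-12717 scenario; the names below
# (app name, run_job, the tag values) are hypothetical, for demonstration only.
from threading import Thread
from pyspark import SparkContext

sc = SparkContext("local[2]", "spark-12717-demo")

def run_job(tag):
    # Each driver thread captures its own broadcast variable in the closure.
    b = sc.broadcast(tag)
    rdd = sc.parallelize(range(4)).map(lambda x: (b.value, x))
    # collect() forces evaluation of _jrdd, i.e. pickling the command and
    # constructing the JVM-side PythonRDD -- the critical section in the diff.
    print(rdd.collect())

threads = [Thread(target=run_job, args=(t,)) for t in ("a", "b")]
for t in threads:
    t.start()
for t in threads:
    t.join()
sc.stop()
```

Serializing the whole prepare-and-construct step behind one module-level Lock trades a little driver-side concurrency for correctness; the cost is small because the guarded section only pickles the command and makes a single JVM call.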