Skip to content

Commit 9b23be2

Browse files
schintap authored and tgravescs committed
[SPARK-26201] Fix python broadcast with encryption
## What changes were proposed in this pull request?

With RPC and disk encryption both enabled, creating a Python broadcast variable and then simply reading its value back on the driver side caused the job to fail with:

    Traceback (most recent call last):
      File "broadcast.py", line 37, in <module>
        words_new.value
      File "/pyspark.zip/pyspark/broadcast.py", line 137, in value
      File "pyspark.zip/pyspark/broadcast.py", line 122, in load_from_path
      File "pyspark.zip/pyspark/broadcast.py", line 128, in load
    EOFError: Ran out of input

This patch makes the driver read the value back through a JVM-side decryption server when encryption is enabled, instead of loading the encrypted file directly.

To reproduce, use these configs:

    --conf spark.network.crypto.enabled=true
    --conf spark.io.encryption.enabled=true

Code:

    words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"])
    words_new.value
    print(words_new.value)

## How was this patch tested?

    words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"])
    textFile = sc.textFile("README.md")
    wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word + words_new.value[1], 1)).reduceByKey(lambda a, b: a+b)
    count = wordCounts.count()
    print(count)
    words_new.value
    print(words_new.value)

Closes #23166 from redsanket/SPARK-26201.

Authored-by: schintap <[email protected]>
Signed-off-by: Thomas Graves <[email protected]>
1 parent c3f27b2 commit 9b23be2

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
660660
with Logging {
661661

662662
private var encryptionServer: PythonServer[Unit] = null
663+
private var decryptionServer: PythonServer[Unit] = null
663664

664665
/**
665666
* Read data from disks, then copy it to `out`
@@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
708709
override def handleConnection(sock: Socket): Unit = {
709710
val env = SparkEnv.get
710711
val in = sock.getInputStream()
711-
val dir = new File(Utils.getLocalDir(env.conf))
712-
val file = File.createTempFile("broadcast", "", dir)
713-
path = file.getAbsolutePath
714-
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
712+
val abspath = new File(path).getAbsolutePath
713+
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath))
715714
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
716715
}
717716
}
718717
Array(encryptionServer.port, encryptionServer.secret)
719718
}
720719

720+
def setupDecryptionServer(): Array[Any] = {
721+
decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
722+
override def handleConnection(sock: Socket): Unit = {
723+
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
724+
Utils.tryWithSafeFinally {
725+
val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path))
726+
Utils.tryWithSafeFinally {
727+
Utils.copyStream(in, out, false)
728+
} {
729+
in.close()
730+
}
731+
out.flush()
732+
} {
733+
JavaUtils.closeQuietly(out)
734+
}
735+
}
736+
}
737+
Array(decryptionServer.port, decryptionServer.secret)
738+
}
739+
740+
def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
741+
721742
def waitTillDataReceived(): Unit = encryptionServer.getResult()
722743
}
723744
// scalastyle:on no.finalize

python/pyspark/broadcast.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,24 +77,27 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
7777
# we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
7878
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
7979
self._path = f.name
80-
python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
80+
self._sc = sc
81+
self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
8182
if sc._encryption_enabled:
8283
# with encryption, we ask the jvm to do the encryption for us, we send it data
8384
# over a socket
84-
port, auth_secret = python_broadcast.setupEncryptionServer()
85+
port, auth_secret = self._python_broadcast.setupEncryptionServer()
8586
(encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
8687
broadcast_out = ChunkedStream(encryption_sock_file, 8192)
8788
else:
8889
# no encryption, we can just write pickled data directly to the file from python
8990
broadcast_out = f
9091
self.dump(value, broadcast_out)
9192
if sc._encryption_enabled:
92-
python_broadcast.waitTillDataReceived()
93-
self._jbroadcast = sc._jsc.broadcast(python_broadcast)
93+
self._python_broadcast.waitTillDataReceived()
94+
self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
9495
self._pickle_registry = pickle_registry
9596
else:
9697
# we're on an executor
9798
self._jbroadcast = None
99+
self._sc = None
100+
self._python_broadcast = None
98101
if sock_file is not None:
99102
# the jvm is doing decryption for us. Read the value
100103
# immediately from the sock_file
@@ -134,7 +137,15 @@ def value(self):
134137
""" Return the broadcasted value
135138
"""
136139
if not hasattr(self, "_value") and self._path is not None:
137-
self._value = self.load_from_path(self._path)
140+
# we only need to decrypt it here when encryption is enabled and
141+
# if it's on the driver, since executor decryption is handled already
142+
if self._sc is not None and self._sc._encryption_enabled:
143+
port, auth_secret = self._python_broadcast.setupDecryptionServer()
144+
(decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
145+
self._python_broadcast.waitTillBroadcastDataSent()
146+
return self.load(decrypted_sock_file)
147+
else:
148+
self._value = self.load_from_path(self._path)
138149
return self._value
139150

140151
def unpersist(self, blocking=False):

python/pyspark/tests/test_broadcast.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,21 @@ def test_broadcast_with_encryption(self):
6767
def test_broadcast_no_encryption(self):
6868
self._test_multiple_broadcasts()
6969

70+
def _test_broadcast_on_driver(self, *extra_confs):
71+
conf = SparkConf()
72+
for key, value in extra_confs:
73+
conf.set(key, value)
74+
conf.setMaster("local-cluster[2,1,1024]")
75+
self.sc = SparkContext(conf=conf)
76+
bs = self.sc.broadcast(value=5)
77+
self.assertEqual(5, bs.value)
78+
79+
def test_broadcast_value_driver_no_encryption(self):
80+
self._test_broadcast_on_driver()
81+
82+
def test_broadcast_value_driver_encryption(self):
83+
self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))
84+
7085

7186
class BroadcastFrameProtocolTest(unittest.TestCase):
7287

0 commit comments

Comments
 (0)