From 899ad8d35fadf174cd086bca4a079f09396f68e1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 5 Mar 2019 20:58:25 -0800 Subject: [PATCH 01/12] toLocalIterator is working with request at each iteration, not the most efficient --- .../apache/spark/api/python/PythonRDD.scala | 56 +++++++++++-------- .../spark/api/python/PythonRunner.scala | 2 +- python/pyspark/rdd.py | 45 +++++++++++++-- python/pyspark/sql/dataframe.py | 6 +- .../execution/python/PythonUDFRunner.scala | 2 +- 5 files changed, 81 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 182c383cec2c1..1161283a7a5da 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -164,7 +164,22 @@ private[spark] object PythonRDD extends Logging { } def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { - serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") + val (port, secret) = PythonServer.setupOneConnectionServer( + authHelper, "serve toLocalIterator") { s => + val out = new DataOutputStream(s.getOutputStream) + val in = new DataInputStream(s.getInputStream) + Utils.tryWithSafeFinally { + val iter = rdd.toLocalIterator + while (iter.hasNext && in.readInt() != 0) { + writeObjectToStream(iter.next(), out) + out.flush() + } + } { + out.close() + in.close() + } + } + Array(port, secret) } def readRDDFromFile( @@ -185,26 +200,23 @@ private[spark] object PythonRDD extends Logging { new PythonBroadcast(path) } - def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { - - def write(obj: Any): Unit = obj match { - case null => - dataOut.writeInt(SpecialLengths.NULL) - case arr: Array[Byte] => - dataOut.writeInt(arr.length) - dataOut.write(arr) - case str: String => - writeUTF(str, dataOut) - case stream: PortableDataStream => - write(stream.toArray()) - case (key, value) => - write(key) - write(value) - case other => - throw new SparkException("Unexpected element type " + other.getClass) - } - - iter.foreach(write) + def writeObjectToStream(obj: Any, dataOut: DataOutputStream): Unit = obj match { + case iter: Iterator[Any] => + iter.foreach(writeObjectToStream(_, dataOut)) + case null => + dataOut.writeInt(SpecialLengths.NULL) + case arr: Array[Byte] => + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => + writeUTF(str, dataOut) + case stream: PortableDataStream => + writeObjectToStream(stream.toArray(), dataOut) + case (key, value) => + writeObjectToStream(key, dataOut) + writeObjectToStream(value, dataOut) + case other => + throw new SparkException("Unexpected element type " + other.getClass) } /** @@ -393,7 +405,7 @@ private[spark] object PythonRDD extends Logging { */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { serveToStream(threadName) { out => - writeIteratorToStream(items, new DataOutputStream(out)) + writeObjectToStream(items, new DataOutputStream(out)) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index b7f14e062b437..ee9492e10214a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -554,7 +554,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) } protected override def writeIteratorToStream(dataOut: 
DataOutputStream): Unit = { - PythonRDD.writeIteratorToStream(inputIterator, dataOut) + PythonRDD.writeObjectToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ef382f37ea212..e268446231394 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -138,15 +138,52 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(sock_info, serializer): +def _create_local_socket(sock_info): (sockfile, sock) = local_connect_and_auth(*sock_info) - # The RDD materialization time is unpredicable, if we set a timeout for socket reading + # The RDD materialization time is unpredictable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + return sockfile, sock + + +def _load_from_socket(sock_info, serializer): + (sockfile, _) = _create_local_socket(sock_info) # The socket will be automatically closed when garbage-collected. return serializer.load_stream(sockfile) +class _PyLocalIterator(object): + + def __init__(self, sock_info, serializer): + (self.sockfile, self.sock) = _create_local_socket(sock_info) + self.read_iter = serializer.load_stream(self.sockfile) + + def __iter__(self): + return self + + def __next__(self): + # Request more data from Java, then read from stream + write_int(1, self.sockfile) + self.sockfile.flush() + return next(self.read_iter) + + def __del__(self): + try: + # Tell Java to stop sending data + write_int(0, self.sockfile) + finally: + try: + # Attempt to close the socket + self.sockfile.close() + self.sock.close() + except Exception: + pass + + def next(self): + """For python2 compatibility.""" + return self.__next__() + + def ignore_unicode_prefix(f): """ Ignore the 'u' prefix of string in doc tests, to make it works @@ -2386,7 +2423,7 @@ def toLocalIterator(self): """ with SCCallSiteSync(self.context) as css: sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(sock_info, self._jrdd_deserializer) + return _PyLocalIterator(sock_info, self._jrdd_deserializer) def barrier(self): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f8aeb62a27fb9..9a07e09c94caa 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -24,11 +24,12 @@ from functools import reduce else: from itertools import imap as map +from itertools import chain import warnings from pyspark import copy_func, since, _NoValue -from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix, _PyLocalIterator from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -528,7 +529,8 @@ def toLocalIterator(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator() - return 
_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) + batch_iter = _PyLocalIterator(sock_info, PickleSerializer()) + return chain.from_iterable(batch_iter) @ignore_unicode_prefix @since(1.3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 752d271c4cc35..3b2d7cedb3d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -47,7 +47,7 @@ class PythonUDFRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - PythonRDD.writeIteratorToStream(inputIterator, dataOut) + PythonRDD.writeObjectToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } } From 8c309c5a293524d2873e45d14d04638b06979c83 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 7 Mar 2019 11:13:17 -0800 Subject: [PATCH 02/12] fix in protocol to check for next --- .../apache/spark/api/python/PythonRDD.scala | 18 +++++++++++++++--- python/pyspark/rdd.py | 12 ++++++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 1161283a7a5da..0470549775651 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -170,9 +170,21 @@ private[spark] object PythonRDD extends Logging { val in = new DataInputStream(s.getInputStream) Utils.tryWithSafeFinally { val iter = rdd.toLocalIterator - while (iter.hasNext && in.readInt() != 0) { - writeObjectToStream(iter.next(), out) - out.flush() + var stop = false + while (!stop) { + stop = !iter.hasNext + if (!stop) { + out.writeInt(1) + out.flush() + stop = in.readInt() == 0 + if (!stop) { + writeObjectToStream(iter.next(), out) + out.flush() + } + } else { + out.writeInt(0) + out.flush() + } } } { out.close() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e268446231394..cc16549985381 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer, write_int + PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -162,6 +162,9 @@ def __iter__(self): return self def __next__(self): + has_next = read_int(self.sockfile) + if has_next == 0: + raise StopIteration # Request more data from Java, then read from stream write_int(1, self.sockfile) self.sockfile.flush() @@ -170,7 +173,12 @@ def __next__(self): def __del__(self): try: # Tell Java to stop sending data - write_int(0, self.sockfile) + has_next = read_int(self.sockfile) + if has_next != 0: + write_int(0, self.sockfile) + self.sockfile.flush() + except Exception: + pass finally: try: # Attempt to close the socket From 866d585f9293623c54850e3c3fd55393dbbb97af Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Mar 2019 15:59:39 -0700 Subject: [PATCH 03/12] fixed issue with RDD 
toLocalIterator, able to use any serializer now --- .../apache/spark/api/python/PythonRDD.scala | 23 ++++------- python/pyspark/rdd.py | 41 ++++++++----------- python/pyspark/sql/dataframe.py | 6 +-- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0470549775651..17532152a2eab 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -170,21 +170,14 @@ private[spark] object PythonRDD extends Logging { val in = new DataInputStream(s.getInputStream) Utils.tryWithSafeFinally { val iter = rdd.toLocalIterator - var stop = false - while (!stop) { - stop = !iter.hasNext - if (!stop) { - out.writeInt(1) - out.flush() - stop = in.readInt() == 0 - if (!stop) { - writeObjectToStream(iter.next(), out) - out.flush() - } - } else { - out.writeInt(0) - out.flush() - } + + // Send data while request to stop is nonzero and iter has next + while (in.readInt() != 0 && iter.hasNext) { + + // Write the next object and signal end of data for this iteration + writeObjectToStream(iter.next(), out) + out.writeInt(SpecialLengths.END_OF_DATA_SECTION) + out.flush() } } { out.close() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cc16549985381..7cabbc2546df1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int + PickleSerializer, pack_long, AutoBatchedSerializer, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -152,45 +152,40 @@ def _load_from_socket(sock_info, serializer): return serializer.load_stream(sockfile) -class _PyLocalIterator(object): +class _PyLocalIterable(object): def __init__(self, sock_info, serializer): (self.sockfile, self.sock) = _create_local_socket(sock_info) - self.read_iter = serializer.load_stream(self.sockfile) + self.serializer = serializer def __iter__(self): - return self + while True: + try: + # Request data from Java, if no more then connection is closed + write_int(1, self.sockfile) + self.sockfile.flush() - def __next__(self): - has_next = read_int(self.sockfile) - if has_next == 0: - raise StopIteration - # Request more data from Java, then read from stream - write_int(1, self.sockfile) - self.sockfile.flush() - return next(self.read_iter) + # Read one item from Java, if using BatchedSerializer batches have multiple items + for item in self.serializer.load_stream(self.sockfile): + yield item + except Exception: # TODO: more specific error, ConnectionError / socket.error + break def __del__(self): try: - # Tell Java to stop sending data - has_next = read_int(self.sockfile) - if has_next != 0: - write_int(0, self.sockfile) - self.sockfile.flush() + # Tell Java to stop sending data and close connection + write_int(0, self.sockfile) + self.sockfile.flush() except Exception: pass finally: try: - # Attempt to close the socket + # Attempt to close the socket, ignore any errors self.sockfile.close() self.sock.close() except Exception: pass - def next(self): - """For python2 compatibility.""" - return self.__next__() 
- def ignore_unicode_prefix(f): """ @@ -2431,7 +2426,7 @@ def toLocalIterator(self): """ with SCCallSiteSync(self.context) as css: sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _PyLocalIterator(sock_info, self._jrdd_deserializer) + return iter(_PyLocalIterable(sock_info, self._jrdd_deserializer)) def barrier(self): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9a07e09c94caa..056cbcafa5838 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -24,12 +24,11 @@ from functools import reduce else: from itertools import imap as map -from itertools import chain import warnings from pyspark import copy_func, since, _NoValue -from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix, _PyLocalIterator +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix, _PyLocalIterable from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -529,8 +528,7 @@ def toLocalIterator(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator() - batch_iter = _PyLocalIterator(sock_info, PickleSerializer()) - return chain.from_iterable(batch_iter) + return iter(_PyLocalIterable(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(1.3) From 9ad3a77866052a07bd8739b15cfc52b12f881aa7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Mar 2019 23:38:48 -0700 Subject: [PATCH 04/12] change to read one partiton at a time --- .../apache/spark/api/python/PythonRDD.scala | 53 +++++++++++-------- .../spark/api/python/PythonRunner.scala | 2 +- python/pyspark/rdd.py | 21 ++++---- .../execution/python/PythonUDFRunner.scala | 2 +- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 17532152a2eab..b9af9681f86e7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -164,18 +164,24 @@ private[spark] object PythonRDD extends Logging { } def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { + val (port, secret) = PythonServer.setupOneConnectionServer( authHelper, "serve toLocalIterator") { s => val out = new DataOutputStream(s.getOutputStream) val in = new DataInputStream(s.getInputStream) Utils.tryWithSafeFinally { - val iter = rdd.toLocalIterator - // Send data while request to stop is nonzero and iter has next - while (in.readInt() != 0 && iter.hasNext) { + // Collects a partition on each iteration + val collectPartitionIter = rdd.partitions.indices.iterator.map { i => + rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head + } + + // Read request for data and send next partition if nonzero + while (in.readInt() != 0 && collectPartitionIter.hasNext) { // Write the next object and signal end of data for this iteration - writeObjectToStream(iter.next(), out) + val partitionArray = collectPartitionIter.next() + writeIteratorToStream(partitionArray.toIterator, out) out.writeInt(SpecialLengths.END_OF_DATA_SECTION) out.flush() } @@ -205,23 +211,26 @@ private[spark] object PythonRDD extends Logging { new PythonBroadcast(path) } - def writeObjectToStream(obj: Any, dataOut: DataOutputStream): Unit = obj match { - case iter: Iterator[Any] => - 
iter.foreach(writeObjectToStream(_, dataOut)) - case null => - dataOut.writeInt(SpecialLengths.NULL) - case arr: Array[Byte] => - dataOut.writeInt(arr.length) - dataOut.write(arr) - case str: String => - writeUTF(str, dataOut) - case stream: PortableDataStream => - writeObjectToStream(stream.toArray(), dataOut) - case (key, value) => - writeObjectToStream(key, dataOut) - writeObjectToStream(value, dataOut) - case other => - throw new SparkException("Unexpected element type " + other.getClass) + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { + + def write(obj: Any): Unit = obj match { + case null => + dataOut.writeInt(SpecialLengths.NULL) + case arr: Array[Byte] => + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => + writeUTF(str, dataOut) + case stream: PortableDataStream => + write(stream.toArray()) + case (key, value) => + write(key) + write(value) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + + iter.foreach(write) } /** @@ -410,7 +419,7 @@ private[spark] object PythonRDD extends Logging { */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { serveToStream(threadName) { out => - writeObjectToStream(items, new DataOutputStream(out)) + writeIteratorToStream(items, new DataOutputStream(out)) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ee9492e10214a..b7f14e062b437 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -554,7 +554,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - PythonRDD.writeObjectToStream(inputIterator, dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7cabbc2546df1..b5091f2650fd7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -153,38 +153,37 @@ def _load_from_socket(sock_info, serializer): class _PyLocalIterable(object): + """ Create a synchronous local iterable over a socket """ def __init__(self, sock_info, serializer): (self.sockfile, self.sock) = _create_local_socket(sock_info) self.serializer = serializer + self.read_iter = iter([]) # Initialize as empty iterator def __iter__(self): while True: try: - # Request data from Java, if no more then connection is closed + # Request next partition data from Java, if no more then connection is closed write_int(1, self.sockfile) self.sockfile.flush() - # Read one item from Java, if using BatchedSerializer batches have multiple items - for item in self.serializer.load_stream(self.sockfile): + # Load the partition data as a stream and read each item + self.read_iter = self.serializer.load_stream(self.sockfile) + for item in self.read_iter: yield item except Exception: # TODO: more specific error, ConnectionError / socket.error break def __del__(self): try: + # Finish consuming partition data stream + for _ in self.read_iter: + pass # Tell Java to stop sending data and close connection write_int(0, self.sockfile) self.sockfile.flush() except Exception: - pass - finally: - try: - # Attempt to close the socket, ignore any errors - self.sockfile.close() - self.sock.close() - except Exception: - pass + pass # Ignore any errors, socket 
will be automatically closed when garbage-collected def ignore_unicode_prefix(f): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 3b2d7cedb3d68..752d271c4cc35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -47,7 +47,7 @@ class PythonUDFRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - PythonRDD.writeObjectToStream(inputIterator, dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } } From 3415ff12fc2d978ea49d620fc8f8ea4d5f4a596b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 7 Mar 2019 08:49:06 -0800 Subject: [PATCH 05/12] added tests for DataFrame toLocalIterator --- .../apache/spark/api/python/PythonRDD.scala | 1 - python/pyspark/sql/tests/test_dataframe.py | 28 +++++++++++++++++++ python/pyspark/tests/test_rdd.py | 3 -- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b9af9681f86e7..f0ca193ad2556 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -164,7 +164,6 @@ private[spark] object PythonRDD extends Logging { } def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { - val (port, secret) = PythonServer.setupOneConnectionServer( authHelper, "serve toLocalIterator") { s => val out = new DataOutputStream(s.getOutputStream) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index eb34bbb779253..446ef0a10e43c 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -677,6 +677,34 @@ def test_repr_behaviors(self): self.assertEquals(None, df._repr_html_()) self.assertEquals(expected, df.__repr__()) + def test_to_local_iterator(self): + df = self.spark.range(8, numPartitions=4) + expected = df.collect() + it = df.toLocalIterator() + self.assertEqual(expected, list(it)) + + # Test DataFrame with empty partition + df = self.spark.range(3, numPartitions=4) + it = df.toLocalIterator() + expected = df.collect() + self.assertEqual(expected, list(it)) + + def test_to_local_iterator_not_fully_consumed(self): + # SPARK-23961: toLocalIterator throws exception when not fully consumed + # Create a DataFrame large enough so that write to socket will eventually block + df = self.spark.range(1 << 20, numPartitions=2) + it = df.toLocalIterator() + self.assertEqual(df.take(1)[0], next(it)) + with QuietTest(self.sc): + it = None # remove iterator from scope, socket is closed when cleaned up + # Make sure normal df operations still work + result = [] + for i, row in enumerate(df.toLocalIterator()): + result.append(row) + if i == 7: + break + self.assertEqual(df.take(8), result) + class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index e789cbe90d27c..190c25db1425c 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -60,15 +60,12 @@ def 
test_sum(self): self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) def test_to_localiterator(self): - from time import sleep rdd = self.sc.parallelize([1, 2, 3]) it = rdd.toLocalIterator() - sleep(5) self.assertEqual([1, 2, 3], sorted(it)) rdd2 = rdd.repartition(1000) it2 = rdd2.toLocalIterator() - sleep(5) self.assertEqual([1, 2, 3], sorted(it2)) def test_save_as_textfile_with_unicode(self): From 57d251c473c1fce48a8f0b95042dc69f6550ee0c Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Mar 2019 11:33:58 -0700 Subject: [PATCH 06/12] rebased to fix SocketAuthServer name --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f0ca193ad2556..f43bf59a01305 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -164,7 +164,7 @@ private[spark] object PythonRDD extends Logging { } def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { - val (port, secret) = PythonServer.setupOneConnectionServer( + val (port, secret) = SocketAuthServer.setupOneConnectionServer( authHelper, "serve toLocalIterator") { s => val out = new DataOutputStream(s.getOutputStream) val in = new DataInputStream(s.getInputStream) From 600a906151fe7fdf6040143f3de069d61ae1b1c9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 29 Mar 2019 11:31:30 -0700 Subject: [PATCH 07/12] remove reference to sock, only need sockfile --- python/pyspark/rdd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b5091f2650fd7..89a0d74cbff1a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -143,11 +143,11 @@ def _create_local_socket(sock_info): # The RDD materialization time is unpredictable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) - return sockfile, sock + return sockfile def _load_from_socket(sock_info, serializer): - (sockfile, _) = _create_local_socket(sock_info) + sockfile = _create_local_socket(sock_info) # The socket will be automatically closed when garbage-collected. 
return serializer.load_stream(sockfile) @@ -156,7 +156,7 @@ class _PyLocalIterable(object): """ Create a synchronous local iterable over a socket """ def __init__(self, sock_info, serializer): - (self.sockfile, self.sock) = _create_local_socket(sock_info) + self.sockfile = _create_local_socket(sock_info) self.serializer = serializer self.read_iter = iter([]) # Initialize as empty iterator From 0a796d74b253c2e6b9037fd782f09cfb7d05c1ef Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 29 Mar 2019 13:32:56 -0700 Subject: [PATCH 08/12] Use response to indicate no more partitions to read, add method _local_iterator_from_socket to get a local iterator --- .../apache/spark/api/python/PythonRDD.scala | 24 +++++--- python/pyspark/rdd.py | 61 ++++++++++--------- python/pyspark/sql/dataframe.py | 4 +- 3 files changed, 52 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f43bf59a01305..8ba1f6086b9c3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -176,13 +176,23 @@ private[spark] object PythonRDD extends Logging { } // Read request for data and send next partition if nonzero - while (in.readInt() != 0 && collectPartitionIter.hasNext) { - - // Write the next object and signal end of data for this iteration - val partitionArray = collectPartitionIter.next() - writeIteratorToStream(partitionArray.toIterator, out) - out.writeInt(SpecialLengths.END_OF_DATA_SECTION) - out.flush() + var complete = false + while (in.readInt() != 0 && !complete) { + if (collectPartitionIter.hasNext) { + + // Send response there is a partition to read + out.writeInt(1) + + // Write the next object and signal end of data for this iteration + val partitionArray = collectPartitionIter.next() + writeIteratorToStream(partitionArray.toIterator, out) + out.writeInt(SpecialLengths.END_OF_DATA_SECTION) + out.flush() + } else { + // Send response there are no more partitions to read and close + out.writeInt(0) + complete = true + } } } { out.close() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 89a0d74cbff1a..db9d8a4d81e37 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer, write_int + PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -152,38 +152,43 @@ def _load_from_socket(sock_info, serializer): return serializer.load_stream(sockfile) -class _PyLocalIterable(object): - """ Create a synchronous local iterable over a socket """ +def _local_iterator_from_socket(sock_info, serializer): - def __init__(self, sock_info, serializer): - self.sockfile = _create_local_socket(sock_info) - self.serializer = serializer - self.read_iter = iter([]) # Initialize as empty iterator + class PyLocalIterable(object): + """ Create a synchronous local iterable over a socket """ - def __iter__(self): - while True: - try: - # Request next partition data from Java, if no more then connection is closed - write_int(1, self.sockfile) - 
self.sockfile.flush() + def __init__(self, _sock_info, _serializer): + self._sockfile = _create_local_socket(_sock_info) + self._serializer = _serializer + self._read_iter = iter([]) # Initialize as empty iterator + + def __iter__(self): + while True: + # Request next partition data from Java + write_int(1, self._sockfile) + self._sockfile.flush() + + # If nonzero response, then there is a partition to read + if read_int(self._sockfile) == 0: + break # Load the partition data as a stream and read each item - self.read_iter = self.serializer.load_stream(self.sockfile) - for item in self.read_iter: + self._read_iter = self._serializer.load_stream(self._sockfile) + for item in self._read_iter: yield item - except Exception: # TODO: more specific error, ConnectionError / socket.error - break - def __del__(self): - try: - # Finish consuming partition data stream - for _ in self.read_iter: - pass - # Tell Java to stop sending data and close connection - write_int(0, self.sockfile) - self.sockfile.flush() - except Exception: - pass # Ignore any errors, socket will be automatically closed when garbage-collected + def __del__(self): + try: + # Finish consuming partition data stream + for _ in self._read_iter: + pass + # Tell Java to stop sending data and close connection + write_int(0, self._sockfile) + self._sockfile.flush() + except Exception: + pass # Ignore any errors, socket is automatically closed when garbage-collected + + return iter(PyLocalIterable(sock_info, serializer)) def ignore_unicode_prefix(f): @@ -2425,7 +2430,7 @@ def toLocalIterator(self): """ with SCCallSiteSync(self.context) as css: sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return iter(_PyLocalIterable(sock_info, self._jrdd_deserializer)) + return _local_iterator_from_socket(sock_info, self._jrdd_deserializer) def barrier(self): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 056cbcafa5838..9f4a082d8ed5c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -28,7 +28,7 @@ import warnings from pyspark import copy_func, since, _NoValue -from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix, _PyLocalIterable +from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -528,7 +528,7 @@ def toLocalIterator(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator() - return iter(_PyLocalIterable(sock_info, BatchedSerializer(PickleSerializer()))) + return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) From 7847a14e169470ed94b2b1d6e12553cffbe293db Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 23 Apr 2019 11:43:08 -0700 Subject: [PATCH 09/12] Add test to verify only one partition is collected at a time --- python/pyspark/tests/test_rdd.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 190c25db1425c..67add1984a122 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -733,6 +733,20 @@ def test_overwritten_global_func(self): global_func = lambda: "Yeah" self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah") + def 
test_to_local_iterator_collects_single_partition(self): + + def fail_last(x): + if x == 9: + raise RuntimeError("This should not be hit") + return x + + rdd = self.sc.parallelize(range(10), numSlices=2).map(fail_last) + it = rdd.toLocalIterator() + + # Only consume 3 elements in first partition, this should not trigger an error + for i in range(3): + self.assertEqual(i, next(it)) + if __name__ == "__main__": import unittest From a1f811a4d3c27522ef39e4cfdb6dc26cf09f3346 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Apr 2019 17:29:01 -0700 Subject: [PATCH 10/12] added failure test --- python/pyspark/tests/test_rdd.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 67add1984a122..de62fe16814f8 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -733,14 +733,27 @@ def test_overwritten_global_func(self): global_func = lambda: "Yeah" self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah") + def test_to_local_iterator_failure(self): + # SPARK-27548 toLocalIterator tasks failure not propagated to driver + + def fail(_): + raise RuntimeError("local iterator error") + + rdd = self.sc.range(10).map(fail) + + with self.assertRaisesRegexp(Exception, "local iterator error"): + for _ in rdd.toLocalIterator(): + pass + def test_to_local_iterator_collects_single_partition(self): + # Test that partitions are not computed until requested by iteration def fail_last(x): if x == 9: raise RuntimeError("This should not be hit") return x - rdd = self.sc.parallelize(range(10), numSlices=2).map(fail_last) + rdd = self.sc.range(10, numSlices=4).map(fail_last) it = rdd.toLocalIterator() # Only consume 3 elements in first partition, this should not trigger an error From 29b8ab67ef36d78529ae7fa095039c97efc6f27e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Apr 2019 17:36:00 -0700 Subject: [PATCH 11/12] Added handling of task failures and send error msg to python --- .../apache/spark/api/python/PythonRDD.scala | 40 ++++++++++++---- python/pyspark/rdd.py | 48 +++++++++++-------- 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8ba1f6086b9c3..bc843827acaf8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -163,6 +163,17 @@ private[spark] object PythonRDD extends Logging { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } + /** + * A helper function to create a local RDD iterator and serve it via socket. Partitions are + * are collected as separate jobs, by order of index. Partition data is first requested by a + * non-zero integer to start a collection job. The response is prefaced by an integer with 1 + * meaning partition data will be served, 0 meaning the local iterator has been consumed, + * and -1 meaining an error occurred during collection. This function is used by + * pyspark.rdd._local_iterator_from_socket(). + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from these jobs, and the secret for authentication. 
+ */ def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { val (port, secret) = SocketAuthServer.setupOneConnectionServer( authHelper, "serve toLocalIterator") { s => @@ -177,17 +188,26 @@ private[spark] object PythonRDD extends Logging { // Read request for data and send next partition if nonzero var complete = false - while (in.readInt() != 0 && !complete) { + while (!complete && in.readInt() != 0) { if (collectPartitionIter.hasNext) { - - // Send response there is a partition to read - out.writeInt(1) - - // Write the next object and signal end of data for this iteration - val partitionArray = collectPartitionIter.next() - writeIteratorToStream(partitionArray.toIterator, out) - out.writeInt(SpecialLengths.END_OF_DATA_SECTION) - out.flush() + try { + // Attempt to collect the next partition + val partitionArray = collectPartitionIter.next() + + // Send response there is a partition to read + out.writeInt(1) + + // Write the next object and signal end of data for this iteration + writeIteratorToStream(partitionArray.toIterator, out) + out.writeInt(SpecialLengths.END_OF_DATA_SECTION) + out.flush() + } catch { + case e: SparkException => + // Send response that an error occurred followed by error message + out.writeInt(-1) + writeUTF(e.getMessage, out) + complete = true + } } else { // Send response there are no more partitions to read and close out.writeInt(0) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index db9d8a4d81e37..9dc5b26bccb67 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,9 @@ from itertools import imap as map, ifilter as filter from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer, read_int, write_int +from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \ + CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \ + UTF8Deserializer, pack_long, read_int, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -161,32 +161,42 @@ def __init__(self, _sock_info, _serializer): self._sockfile = _create_local_socket(_sock_info) self._serializer = _serializer self._read_iter = iter([]) # Initialize as empty iterator + self._read_status = 1 def __iter__(self): - while True: + while self._read_status == 1: # Request next partition data from Java write_int(1, self._sockfile) self._sockfile.flush() - # If nonzero response, then there is a partition to read - if read_int(self._sockfile) == 0: - break + # If response is 1, then there is a partition to read + self._read_status = read_int(self._sockfile) + if self._read_status == 1: - # Load the partition data as a stream and read each item - self._read_iter = self._serializer.load_stream(self._sockfile) - for item in self._read_iter: - yield item + # Load the partition data as a stream and read each item + self._read_iter = self._serializer.load_stream(self._sockfile) + for item in self._read_iter: + yield item + + # An error occurred, read error message and raise it + elif self._read_status == -1: + error_msg = UTF8Deserializer().loads(self._sockfile) + raise RuntimeError("An error occurred while reading the next element from " + "toLocalIterator: {}".format(error_msg)) def __del__(self): - try: - # Finish 
consuming partition data stream - for _ in self._read_iter: + # If local iterator is not fully consumed, + if self._read_status == 1: + try: + # Finish consuming partition data stream + for _ in self._read_iter: + pass + # Tell Java to stop sending data and close connection + write_int(0, self._sockfile) + self._sockfile.flush() + except Exception: + # Ignore any errors, socket is automatically closed when garbage-collected pass - # Tell Java to stop sending data and close connection - write_int(0, self._sockfile) - self._sockfile.flush() - except Exception: - pass # Ignore any errors, socket is automatically closed when garbage-collected return iter(PyLocalIterable(sock_info, serializer)) From 4f842dcb812dc5b199a4b150fea21d8015eedea0 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 6 May 2019 15:17:07 -0700 Subject: [PATCH 12/12] cleaned up some comments --- python/pyspark/rdd.py | 2 +- python/pyspark/tests/test_rdd.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9dc5b26bccb67..f0682e71a1780 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -169,7 +169,7 @@ def __iter__(self): write_int(1, self._sockfile) self._sockfile.flush() - # If response is 1, then there is a partition to read + # If response is 1 then there is a partition to read, if 0 then fully consumed self._read_status = read_int(self._sockfile) if self._read_status == 1: diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index de62fe16814f8..448fcd36b5189 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -734,7 +734,7 @@ def test_overwritten_global_func(self): self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah") def test_to_local_iterator_failure(self): - # SPARK-27548 toLocalIterator tasks failure not propagated to driver + # SPARK-27548 toLocalIterator task failure not propagated to Python driver def fail(_): raise RuntimeError("local iterator error") @@ -753,11 +753,12 @@ def fail_last(x): raise RuntimeError("This should not be hit") return x - rdd = self.sc.range(10, numSlices=4).map(fail_last) + rdd = self.sc.range(12, numSlices=4).map(fail_last) it = rdd.toLocalIterator() - # Only consume 3 elements in first partition, this should not trigger an error - for i in range(3): + # Only consume first 4 elements from partitions 1 and 2, this should not collect the last + # partition which would trigger the error + for i in range(4): self.assertEqual(i, next(it))
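
After the final patch in this series, the Python driver side of the toLocalIterator protocol works as follows: the driver writes a nonzero int to request the next partition, and the JVM answers with 1 (serialized partition data follows), 0 (all partitions have been consumed), or -1 (a collection job failed; a UTF-8 error message follows). Below is a condensed sketch of that loop, simplified from the PyLocalIterable added in python/pyspark/rdd.py; the helper name _iterate_partitions and the bare sockfile/serializer arguments are illustrative only, and socket setup, authentication, and the __del__ cleanup path are omitted.

    from pyspark.serializers import UTF8Deserializer, read_int, write_int

    def _iterate_partitions(sockfile, serializer):
        while True:
            # Ask the JVM to collect and send the next partition
            write_int(1, sockfile)
            sockfile.flush()

            # 1: partition data follows, 0: iterator exhausted, -1: error
            status = read_int(sockfile)
            if status == 1:
                # Stream one collected partition and yield its items
                for item in serializer.load_stream(sockfile):
                    yield item
            elif status == 0:
                return
            else:
                raise RuntimeError(
                    "toLocalIterator failed on the JVM side: %s"
                    % UTF8Deserializer().loads(sockfile))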