diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 182c383cec2c1..bc843827acaf8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -163,8 +163,63 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+   * collected as separate jobs, by order of index. Partition data is first requested by a
+   * non-zero integer to start a collection job. The response is prefaced by an integer with 1
+   * meaning partition data will be served, 0 meaning the local iterator has been consumed,
+   * and -1 meaning an error occurred during collection. This function is used by
+   * pyspark.rdd._local_iterator_from_socket().
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, and the secret for authentication.
+   */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
+      authHelper, "serve toLocalIterator") { s =>
+      val out = new DataOutputStream(s.getOutputStream)
+      val in = new DataInputStream(s.getInputStream)
+      Utils.tryWithSafeFinally {
+
+        // Collects a partition on each iteration
+        val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
+          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+        }
+
+        // Read request for data and send next partition if nonzero
+        var complete = false
+        while (!complete && in.readInt() != 0) {
+          if (collectPartitionIter.hasNext) {
+            try {
+              // Attempt to collect the next partition
+              val partitionArray = collectPartitionIter.next()
+
+              // Send response that there is a partition to read
+              out.writeInt(1)
+
+              // Write the next object and signal end of data for this iteration
+              writeIteratorToStream(partitionArray.toIterator, out)
+              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+              out.flush()
+            } catch {
+              case e: SparkException =>
+                // Send response that an error occurred followed by error message
+                out.writeInt(-1)
+                writeUTF(e.getMessage, out)
+                complete = true
+            }
+          } else {
+            // Send response that there are no more partitions to read and close
+            out.writeInt(0)
+            complete = true
+          }
+        }
+      } {
+        out.close()
+        in.close()
+      }
+    }
+    Array(port, secret)
   }
 
   def readRDDFromFile(
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index ef382f37ea212..f0682e71a1780 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,9 +39,9 @@
     from itertools import imap as map, ifilter as filter
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long, AutoBatchedSerializer
+from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
+    CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
+    UTF8Deserializer, pack_long, read_int, write_int
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -138,15 +138,69 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
-def _load_from_socket(sock_info, serializer):
+def _create_local_socket(sock_info):
     (sockfile, sock) = local_connect_and_auth(*sock_info)
-    # The RDD materialization time is unpredicable, if we set a timeout for socket reading
+    # The RDD materialization time is unpredictable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
+    return sockfile
+
+
+def _load_from_socket(sock_info, serializer):
+    sockfile = _create_local_socket(sock_info)
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
 
 
+def _local_iterator_from_socket(sock_info, serializer):
+
+    class PyLocalIterable(object):
+        """ Create a synchronous local iterable over a socket """
+
+        def __init__(self, _sock_info, _serializer):
+            self._sockfile = _create_local_socket(_sock_info)
+            self._serializer = _serializer
+            self._read_iter = iter([])  # Initialize as empty iterator
+            self._read_status = 1
+
+        def __iter__(self):
+            while self._read_status == 1:
+                # Request next partition data from Java
+                write_int(1, self._sockfile)
+                self._sockfile.flush()
+
+                # If response is 1 then there is a partition to read, if 0 then fully consumed
+                self._read_status = read_int(self._sockfile)
+                if self._read_status == 1:
+
+                    # Load the partition data as a stream and read each item
+                    self._read_iter = self._serializer.load_stream(self._sockfile)
+                    for item in self._read_iter:
+                        yield item
+
+                # An error occurred, read error message and raise it
+                elif self._read_status == -1:
+                    error_msg = UTF8Deserializer().loads(self._sockfile)
+                    raise RuntimeError("An error occurred while reading the next element from "
+                                       "toLocalIterator: {}".format(error_msg))
+
+        def __del__(self):
+            # If local iterator is not fully consumed,
+            if self._read_status == 1:
+                try:
+                    # Finish consuming partition data stream
+                    for _ in self._read_iter:
+                        pass
+                    # Tell Java to stop sending data and close connection
+                    write_int(0, self._sockfile)
+                    self._sockfile.flush()
+                except Exception:
+                    # Ignore any errors, socket is automatically closed when garbage-collected
+                    pass
+
+    return iter(PyLocalIterable(sock_info, serializer))
+
+
 def ignore_unicode_prefix(f):
     """
     Ignore the 'u' prefix of string in doc tests, to make it works
@@ -2386,7 +2440,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(sock_info, self._jrdd_deserializer)
+        return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
 
     def barrier(self):
         """
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f8aeb62a27fb9..9f4a082d8ed5c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -28,7 +28,7 @@
 import warnings
 
 from pyspark import copy_func, since, _NoValue
-from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
+from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, ignore_unicode_prefix
 from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -528,7 +528,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self._sc) as css:
            sock_info = self._jdf.toPythonIterator()
-        return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
+        return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
 
     @ignore_unicode_prefix
     @since(1.3)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index eb34bbb779253..446ef0a10e43c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -677,6 +677,34 @@ def test_repr_behaviors(self):
             self.assertEquals(None, df._repr_html_())
             self.assertEquals(expected, df.__repr__())
 
+    def test_to_local_iterator(self):
+        df = self.spark.range(8, numPartitions=4)
+        expected = df.collect()
+        it = df.toLocalIterator()
+        self.assertEqual(expected, list(it))
+
+        # Test DataFrame with empty partition
+        df = self.spark.range(3, numPartitions=4)
+        it = df.toLocalIterator()
+        expected = df.collect()
+        self.assertEqual(expected, list(it))
+
+    def test_to_local_iterator_not_fully_consumed(self):
+        # SPARK-23961: toLocalIterator throws exception when not fully consumed
+        # Create a DataFrame large enough so that write to socket will eventually block
+        df = self.spark.range(1 << 20, numPartitions=2)
+        it = df.toLocalIterator()
+        self.assertEqual(df.take(1)[0], next(it))
+        with QuietTest(self.sc):
+            it = None  # remove iterator from scope, socket is closed when cleaned up
+            # Make sure normal df operations still work
+            result = []
+            for i, row in enumerate(df.toLocalIterator()):
+                result.append(row)
+                if i == 7:
+                    break
+            self.assertEqual(df.take(8), result)
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e789cbe90d27c..448fcd36b5189 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -60,15 +60,12 @@ def test_sum(self):
         self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
 
     def test_to_localiterator(self):
-        from time import sleep
         rdd = self.sc.parallelize([1, 2, 3])
         it = rdd.toLocalIterator()
-        sleep(5)
         self.assertEqual([1, 2, 3], sorted(it))
 
         rdd2 = rdd.repartition(1000)
         it2 = rdd2.toLocalIterator()
-        sleep(5)
         self.assertEqual([1, 2, 3], sorted(it2))
 
     def test_save_as_textfile_with_unicode(self):
@@ -736,6 +733,34 @@ def test_overwritten_global_func(self):
         global_func = lambda: "Yeah"
         self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")
 
+    def test_to_local_iterator_failure(self):
+        # SPARK-27548 toLocalIterator task failure not propagated to Python driver
+
+        def fail(_):
+            raise RuntimeError("local iterator error")
+
+        rdd = self.sc.range(10).map(fail)
+
+        with self.assertRaisesRegexp(Exception, "local iterator error"):
+            for _ in rdd.toLocalIterator():
+                pass
+
+    def test_to_local_iterator_collects_single_partition(self):
+        # Test that partitions are not computed until requested by iteration
+
+        def fail_last(x):
+            if x == 9:
+                raise RuntimeError("This should not be hit")
+            return x
+
+        rdd = self.sc.range(12, numSlices=4).map(fail_last)
+        it = rdd.toLocalIterator()
+
+        # Only consume first 4 elements from partitions 1 and 2, this should not collect the last
+        # partition which would trigger the error
+        for i in range(4):
+            self.assertEqual(i, next(it))
+
 
 if __name__ == "__main__":
     import unittest
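
A minimal usage sketch of the behaviour this patch enables (not part of the patch; the local-mode session setup, app name, and variable names below are illustrative only). It shows partitions being collected one job at a time as the iterator advances, and a partially consumed iterator being dropped without the error described in SPARK-23961:

    from pyspark.sql import SparkSession

    # Illustrative local session; any existing SparkSession behaves the same way.
    spark = SparkSession.builder.master("local[2]").appName("toLocalIterator-demo").getOrCreate()

    # Ten partitions; no data is collected until the iterator asks for it.
    df = spark.range(100, numPartitions=10)

    # Reading the first few rows should only trigger a job for the first partition,
    # since toLocalIteratorAndServe now collects partitions on demand.
    it = df.toLocalIterator()
    print([next(it) for _ in range(5)])

    # Dropping a partially consumed iterator is safe: PyLocalIterable.__del__ tells
    # the JVM to stop serving partitions when the iterable is garbage-collected.
    del it

    spark.stop()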