diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 51d9e2e967990..b775e40367601 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -440,6 +440,29 @@ private[spark] object PythonRDD extends Logging {
     Array(port, secret)
   }
 
+  /**
+   * Create a socket server object and a background thread to execute writeFunc
+   * with the given OutputStream.
+   *
+   * This is the same as serveToStream, only it returns a server object that
+   * can be used to sync in Python.
+   */
+  private[spark] def serveToStreamWithSync(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
+      }
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
+  }
+
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
       baseConf: Configuration): Configuration = {
     val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -957,3 +980,17 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
   }
 }
 
+/**
+ * Create a socket server class and run a user function on the socket in a background thread.
+ * This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
+ * a server object that can then be synced from Python.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends PythonServer[Unit](authHelper, threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 45b90b5bd97a3..4c4dd278dfa24 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2175,9 +2175,13 @@ def _collectAsArrow(self):
 
         .. note:: Experimental.
         """
-        with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.collectAsArrowToPython()
-            return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+        with SCCallSiteSync(self._sc):
+            from pyspark.rdd import _load_from_socket
+            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
+            try:
+                return list(_load_from_socket((port, auth_secret), ArrowStreamSerializer()))
+            finally:
+                jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions
 
     ##########################################################################################
     # Pandas compatibility
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f65fe885ef7d7..949ce3b824011 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -80,7 +80,7 @@
 _have_pyarrow = _pyarrow_requirement_message is None
 _test_compiled = _test_not_compiled_message is None
 
-from pyspark import SparkContext
+from pyspark import SparkContext, SparkConf
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -4550,6 +4550,32 @@ def test_timestamp_dst(self):
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
 
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate because 'spark.driver.maxResultSize' is a static
+    # configuration of the SparkContext.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession(SparkContext(
+            'local[4]', cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")))
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
 class EncryptionArrowTests(ArrowTests):
 
     @classmethod
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index be0617fdacb21..c90b2e857e664 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3284,7 +3284,7 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStreamWithSync("serve-Arrow") { out =>
        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
        val arrowBatchRdd = toArrowBatchRdd(plan)
        val numPartitions = arrowBatchRdd.partitions.length