From b4b425eb3a7a5da56ee4faefa8e2babab146f727 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Fri, 12 Mar 2021 09:55:48 +0100
Subject: [PATCH 1/4] [SPARK-34726][SQL] Fix collectToPython timeouts

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e4577591dab3..bfa49a8d81a41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3257,12 +3257,11 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Array[Any] = {
     EvaluatePython.registerPicklers()
-    withAction("collectToPython", queryExecution) { plan =>
+    val iter = withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
-        plan.executeCollect().iterator.map(toJava))
-      PythonRDD.serveIterator(iter, "serve-DataFrame")
+      new SerDeUtil.AutoBatchedPickler(plan.executeCollect().iterator.map(toJava))
     }
+    PythonRDD.serveIterator(iter, "serve-DataFrame")
   }
 
   private[sql] def getRowsToPython(
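
Why this reordering matters: `withAction` also notifies QueryExecutionListeners before it returns (the test added in PATCH 2 below relies on exactly that), so with `serveIterator` inside the block the one-connection socket server began its 15-second accept countdown while slow listeners were still running, and the Python side could miss the window. Below is a minimal, self-contained sketch of the fixed ordering using plain sockets only; `serveOnce` and the sleeping stand-in "action" are hypothetical, not Spark API.

import java.net.{InetAddress, ServerSocket, Socket}

object ServeAfterActionSketch {
  // Accept exactly one connection, or give up after timeoutMs
  // (the same SO_TIMEOUT idea as PythonServer.setupOneConnectionServer).
  def serveOnce(timeoutMs: Int, payload: String): Int = {
    val server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress)
    server.setSoTimeout(timeoutMs)
    new Thread("serve-once") {
      setDaemon(true)
      override def run(): Unit = {
        val sock = server.accept() // throws SocketTimeoutException if nobody connects in time
        try sock.getOutputStream.write(payload.getBytes("UTF-8"))
        finally { sock.close(); server.close() }
      }
    }.start()
    server.getLocalPort
  }

  def main(args: Array[String]): Unit = {
    // Fixed ordering: run the action (and anything slow that it triggers) to completion first...
    val result = { Thread.sleep(2000); "rows" } // stands in for withAction + listeners
    // ...and only then start the server, so the client gets the full timeout window.
    val port = serveOnce(timeoutMs = 15000, payload = result)
    val client = new Socket(InetAddress.getLoopbackAddress, port)
    println(scala.io.Source.fromInputStream(client.getInputStream).mkString)
    client.close()
  }
}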
From 6b18cc752379387dd13c009490b54be823cf2437 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Fri, 12 Mar 2021 15:48:21 +0100
Subject: [PATCH 2/4] add UT

Change-Id: If8017ce93dae2ead782567afea1eeba5438d73ba
---
 .../org/apache/spark/sql/DatasetSuite.scala | 33 +++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08ebf8b10fefc..64eeca64d827c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -18,19 +18,24 @@ package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.net.{InetAddress, Socket}
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.SparkException
+import scala.io.Source
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
+import org.apache.spark.sql.execution.{LogicalRDD, QueryExecution, RDDScanExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
 
@@ -1586,6 +1591,30 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-34726: Fix collectToPython timeouts") {
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        // Longer than 15s in `PythonServer.setupOneConnectionServer`
+        Thread.sleep(20 * 1000)
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
+
+    // Mimic Python side
+    val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
+    val authHelper = new SocketAuthHelper(new SparkConf()) {
+      override val secret: String = secretToPython
+    }
+    authHelper.authToServer(socket)
+    Source.fromInputStream(socket.getInputStream)
+
+    spark.listenerManager.unregister(listener)
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)

From aae1992aebe464978184d10a152b7a66acaab5ce Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Fri, 12 Mar 2021 16:04:28 +0100
Subject: [PATCH 3/4] speed up UT

Change-Id: I61cd8a8ea5eadbc72b38a3ecbdf1cd41556d5de6
---
 .../scala/org/apache/spark/api/python/PythonRDD.scala  |  4 +++-
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b775e40367601..378b1c864a8e4 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -878,6 +878,8 @@ private[spark] abstract class PythonServer[T](
 
 private[spark] object PythonServer {
 
+  private[spark] var timeout = 15000
+
   /**
    * Create a socket server and run user function on the socket in a background thread.
    *
@@ -896,7 +898,7 @@ private[spark] object PythonServer {
       (func: Socket => Unit): (Int, String) = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
     // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
+    serverSocket.setSoTimeout(timeout)
 
     new Thread(threadName) {
       setDaemon(true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 64eeca64d827c..477c3dc4a40b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import scala.io.Source
 
 import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.api.python.PythonServer
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
@@ -1593,26 +1594,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-34726: Fix collectToPython timeouts") {
+    // Lower `PythonServer.setupOneConnectionServer` timeout for this test
+    val oldTimeout = PythonServer.timeout
+    PythonServer.timeout = 1000
+
     val listener = new QueryExecutionListener {
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        // Longer than 15s in `PythonServer.setupOneConnectionServer`
-        Thread.sleep(20 * 1000)
+        // Wait longer than `PythonServer.setupOneConnectionServer` timeout
+        Thread.sleep(PythonServer.timeout + 1000)
       }
     }
     spark.listenerManager.register(listener)
 
     val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
 
     // Mimic Python side
     val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
     val authHelper = new SocketAuthHelper(new SparkConf()) {
       override val secret: String = secretToPython
     }
     authHelper.authToServer(socket)
     Source.fromInputStream(socket.getInputStream)
 
     spark.listenerManager.unregister(listener)
+    PythonServer.timeout = oldTimeout
   }
 }
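
What PATCH 3 buys: with the hard-coded 15000 ms widened into `PythonServer.timeout`, the test can shrink the accept window to one second and sleep just past it, instead of stalling the suite for 20+ seconds against the production timeout. Below is a minimal sketch of the same override-and-restore pattern in isolation; the `Server` object is a hypothetical stand-in for `PythonServer`, not Spark code.

import java.net.{InetAddress, ServerSocket, SocketTimeoutException}

object TimeoutOverrideSketch {
  object Server {
    var timeout = 15000 // production default, widened to a var for tests

    // Bind, then wait up to `timeout` ms for a single connection.
    def acceptOne(): Unit = {
      val server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress)
      server.setSoTimeout(timeout)
      try server.accept() finally server.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val oldTimeout = Server.timeout
    Server.timeout = 1000 // what the test does: fail fast instead of after 15s
    try {
      Server.acceptOne() // nothing ever connects here
    } catch {
      case _: SocketTimeoutException => println("timed out after ~1s, as intended")
    } finally {
      Server.timeout = oldTimeout // restore for whoever runs next
    }
  }
}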
From 8f6b811a28d1813d72151fe9063d5794b99fa481 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Fri, 12 Mar 2021 16:12:56 +0100
Subject: [PATCH 4/4] address comments

Change-Id: I3d8077ba9e65c0367e401b3fdd8590fd8a83cf73
---
 .../apache/spark/api/python/PythonRDD.scala |  1 +
 .../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++-----------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 378b1c864a8e4..a280e2225e816 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -878,6 +878,7 @@ private[spark] abstract class PythonServer[T](
 
 private[spark] object PythonServer {
 
+  // visible for testing
   private[spark] var timeout = 15000
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 477c3dc4a40b2..aa33f79e21919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1606,20 +1606,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
         Thread.sleep(PythonServer.timeout + 1000)
       }
     }
-    spark.listenerManager.register(listener)
+    try {
+      spark.listenerManager.register(listener)
 
-    val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
+      val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
 
-    // Mimic Python side
-    val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
-    val authHelper = new SocketAuthHelper(new SparkConf()) {
-      override val secret: String = secretToPython
+      // Mimic Python side
+      val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
+      val authHelper = new SocketAuthHelper(new SparkConf()) {
+        override val secret: String = secretToPython
+      }
+      authHelper.authToServer(socket)
+      Source.fromInputStream(socket.getInputStream)
+    } finally {
+      spark.listenerManager.unregister(listener)
+      PythonServer.timeout = oldTimeout
     }
-    authHelper.authToServer(socket)
-    Source.fromInputStream(socket.getInputStream)
-
-    spark.listenerManager.unregister(listener)
-    PythonServer.timeout = oldTimeout
   }
 }
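
The try/finally added in PATCH 4 is what keeps the test hermetic: the test mutates global state (a registered listener and `PythonServer.timeout`), and without the finally block a failing assertion would leave later tests with a 1-second accept timeout and a listener that sleeps on every query. The same guarantee can be packaged as a loan-pattern helper, common in Spark's test code; `withTimeout` below is a hypothetical sketch, not Spark API.

object RestoreGlobalsSketch {
  // Hypothetical global, standing in for PythonServer.timeout.
  var timeout = 15000

  // Loan pattern: set a temporary value, run the body, always restore.
  def withTimeout[A](tempMs: Int)(body: => A): A = {
    val old = timeout
    timeout = tempMs
    try body finally timeout = old // runs even if body throws
  }

  def main(args: Array[String]): Unit = {
    withTimeout(1000) {
      println(s"inside: $timeout ms") // 1000
    }
    println(s"after: $timeout ms") // back to 15000
  }
}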