@@ -878,6 +878,9 @@ private[spark] abstract class PythonServer[T](
 
 private[spark] object PythonServer {
 
+  // visible for testing
+  private[spark] var timeout = 15000
+
   /**
    * Create a socket server and run user function on the socket in a background thread.
    *
@@ -896,7 +899,7 @@ private[spark] object PythonServer {
       (func: Socket => Unit): (Int, String) = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
     // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
+    serverSocket.setSoTimeout(timeout)
 
     new Thread(threadName) {
       setDaemon(true)
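The hunk above makes the accept timeout of `PythonServer.setupOneConnectionServer` overridable for tests instead of hard-coding 15 seconds. The method binds a loopback `ServerSocket`, waits for exactly one connection in a daemon thread, and runs the caller's function on that socket. A minimal standalone sketch of that pattern with a configurable timeout (illustrative names only, not Spark's actual implementation):

```scala
import java.net.{InetAddress, ServerSocket, Socket}

// Minimal sketch of a one-shot loopback socket server with a configurable
// accept timeout; illustrative only, not Spark's PythonServer code.
object OneConnectionServerSketch {
  var timeout = 15000 // milliseconds; a test can lower this, as DatasetSuite does below

  def serve(threadName: String)(func: Socket => Unit): Int = {
    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
    // If no client connects within `timeout` ms, accept() throws SocketTimeoutException.
    serverSocket.setSoTimeout(timeout)

    new Thread(threadName) {
      setDaemon(true)
      override def run(): Unit = {
        val sock = serverSocket.accept()
        try func(sock) finally {
          sock.close()
          serverSocket.close()
        }
      }
    }.start()
    serverSocket.getLocalPort
  }
}
```

If no client connects before the timeout expires, `accept()` throws `SocketTimeoutException`, which is the failure mode the rest of this change addresses.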
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (3 additions, 4 deletions)
@@ -3257,12 +3257,11 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Array[Any] = {
     EvaluatePython.registerPicklers()
-    withAction("collectToPython", queryExecution) { plan =>
+    val iter = withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
-        plan.executeCollect().iterator.map(toJava))
-      PythonRDD.serveIterator(iter, "serve-DataFrame")
+      new SerDeUtil.AutoBatchedPickler(plan.executeCollect().iterator.map(toJava))
     }
+    PythonRDD.serveIterator(iter, "serve-DataFrame")
   }
 
   private[sql] def getRowsToPython(
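The `Dataset.scala` change reorders the work so the one-connection server is started only after `withAction` has returned. Because `withAction` also notifies `QueryExecutionListener`s before returning, the old code let a slow listener burn through the server's accept timeout before the Python side ever received the port number. A simplified sketch of the ordering, with hypothetical names (`runWithListeners`, `startServer`) standing in for Spark's internals:

```scala
// Hypothetical sketch of the ordering fixed here; names are illustrative, not Spark's.
object ServeAfterListenersSketch {
  // Mimics withAction: run the body, fire listener callbacks, then return the result.
  def runWithListeners[A](body: => A)(listeners: Seq[() => Unit]): A = {
    val result = body
    listeners.foreach(_.apply()) // a slow QueryExecutionListener would sleep here
    result
  }

  def main(args: Array[String]): Unit = {
    val slowListeners = Seq(() => Thread.sleep(2000))

    // Old shape: the socket server (and its accept timeout) started inside the body,
    // so the sleep above ran before any client could connect:
    //   runWithListeners { startServer(pickledIterator) } (slowListeners)

    // New shape: only the pickled iterator is built inside the body; the server is
    // started afterwards, so listener latency no longer eats into the accept timeout.
    val iter = runWithListeners { Iterator("row1", "row2") } (slowListeners)
    println(s"serve ${iter.size} batches after listeners complete") // startServer(iter) here
  }
}
```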
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (39 additions, 2 deletions)
@@ -18,19 +18,25 @@
 package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.net.{InetAddress, Socket}
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.SparkException
+import scala.io.Source
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.api.python.PythonServer
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
+import org.apache.spark.sql.execution.{LogicalRDD, QueryExecution, RDDScanExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1586,6 +1592,37 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-34726: Fix collectToPython timeouts") {
+    // Lower `PythonServer.setupOneConnectionServer` timeout for this test
+    val oldTimeout = PythonServer.timeout
+    PythonServer.timeout = 1000
+
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        // Wait longer than `PythonServer.setupOneConnectionServer` timeout
+        Thread.sleep(PythonServer.timeout + 1000)
+      }
+    }
+    try {
+      spark.listenerManager.register(listener)
+
+      val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
+
+      // Mimic Python side
+      val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
+      val authHelper = new SocketAuthHelper(new SparkConf()) {
+        override val secret: String = secretToPython
+      }
+      authHelper.authToServer(socket)
+      Source.fromInputStream(socket.getInputStream)
+    } finally {
+      spark.listenerManager.unregister(listener)
+      PythonServer.timeout = oldTimeout
+    }
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)