From 81fff20471bd2aded08380c8dd99c09fe34d2c79 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 May 2017 07:07:40 -0700 Subject: [PATCH 001/194] Start of work on adventures --- .../org/apache/spark/deploy/ExecutorState.scala | 2 +- .../org/apache/spark/deploy/master/Master.scala | 14 ++++++++++++++ .../org/apache/spark/deploy/worker/Worker.scala | 2 ++ .../apache/spark/scheduler/BlacklistTracker.scala | 8 ++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 69c98e28931d7..be29f3893b5eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED, DECOMMISSIONED = Value type ExecutorState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f10a41286c52f..bd6e6f5ec87a0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,6 +231,9 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) + case WorkerShutdown(id, workerHost, workerPort) => + logInfo("Recording worker %d shutdown %s:%d".format(id, workerHost, workerPort)) + case RegisterWorker( id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( @@ -771,6 +774,17 @@ private[deploy] class Master( true } + private def decommissionWorker(worker: WorkerInfo) { + logInfo("Decommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) + worker.setState(WorkerState.DECOMMISSIONED) + for (exec <- worker.executors.values) { + logInfo("Telling app of decomission executors") + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) + exec.state = ExecutorState.DECOMMISSIONED + } + } + private def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1198e3cb05eaa..07c4e5f36a26d 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -468,6 +468,8 @@ private[deploy] class Worker( case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") + } else if (shuttingDown) { + logWarning("Asked to launch an executor while shutting down. 
Not launching executor.") } else { try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index e130e609e4f63..6d6b7dbf08029 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -40,6 +40,8 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * stage, but still many failures over the entire application * * "flaky" executors -- they don't fail every task, but are still faulty enough to merit * blacklisting + * * shutting down executors -- executors which are shutting down and should not have new tasks + * scheduled. * * See the design doc on SPARK-8425 for a more in-depth discussion. * @@ -145,6 +147,12 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + def updateBlackListForNodeShutdown(executors: Set[String]): Unit = { + // We allow timeout on the node shutdown so if the node ends up not actually shutting down + // or is migrated in a way we don't notice we can start scheduling tasks on it. + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + + } def updateBlacklistForSuccessfulTaskSet( stageId: Int, From e470bac53151418d02dd5f03f243d635900376a9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 2 Jun 2017 05:16:39 -0700 Subject: [PATCH 002/194] Mini progresss --- .../apache/spark/deploy/DeployMessage.scala | 11 +++++++ .../apache/spark/deploy/ExecutorState.scala | 6 +++- .../apache/spark/deploy/master/Master.scala | 19 ++++++++++-- .../apache/spark/deploy/worker/Worker.scala | 16 ++++++++-- .../org/apache/spark/executor/Executor.scala | 30 +++++++++++++++++-- .../spark/scheduler/BlacklistTracker.scala | 18 ++++++++--- 6 files changed, 87 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index c1a91c27eef2d..d485d843f70b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -58,6 +58,17 @@ private[deploy] object DeployMessages { assert (port > 0) } + case class WorkerDecommission( + id: String, + host: String, + port: Int, + worker: RpcEndpointRef) + extends DeployMessage { + Utils.checkHost(host) + assert (port > 0) + } + + case class ExecutorStateChanged( appId: String, execId: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index be29f3893b5eb..f47c37b4f925b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -23,5 +23,9 @@ private[deploy] object ExecutorState extends Enumeration { type ExecutorState = Value - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST, EXITED).contains(state) + // Decommissioned is not included as a finished state since the executor is still running but + // will soon become finished. 
+ private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED) + + def isFinished(state: ExecutorState): Boolean = finishedStates.contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index bd6e6f5ec87a0..aba8bba745af7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,8 +231,15 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case WorkerShutdown(id, workerHost, workerPort) => - logInfo("Recording worker %d shutdown %s:%d".format(id, workerHost, workerPort)) + case WorkerDecommission(id, workerHost, workerPort, workerRef) => + logInfo("Recording worker %d decomissioning %s:%d".format( + id, workerHost, workerPort)) + if (state == RecoveryState.STANDBY) { + workerRef.send(MasterInStandby) + } else { + // If a worker attempts to decomission that isn't registered ignore it. + idToWorker.get(id).foreach(decommissionWorker) + } case RegisterWorker( id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => @@ -303,6 +310,7 @@ private[deploy] class Master( // Important note: this code path is not exercised by tests, so be very careful when // changing this `if` condition. if (!normalExit + && oldState != ExecutorState.DECOMMISSIONED && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES && MAX_EXECUTOR_RETRIES >= 0) { // < 0 disables this application-killing path val execs = appInfo.executors.values @@ -780,11 +788,16 @@ private[deploy] class Master( for (exec <- worker.executors.values) { logInfo("Telling app of decomission executors") exec.application.driver.send(ExecutorUpdated( - exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) + exec.id, ExecutorState.DECOMMISSIONED, + Some("worker decommissioned"), None, workerLost = false)) exec.state = ExecutorState.DECOMMISSIONED } } + private def recommissionWorker(worker: WorkerInfo) { + logInfo("Recommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) + } + private def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 07c4e5f36a26d..7d0a5f43b7d0b 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,6 +71,9 @@ private[deploy] class Worker( private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 + // Decommissioning timeout + private val DECOMMISSIONING_TIMEOUT_MILLIS = + conf.getLong("spark.decommissioning.timeout", 420) * 1000 // Model retries to connect to the master, after Hadoop's model. 
// The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) @@ -119,6 +122,7 @@ private[deploy] class Worker( private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false + private var decommissioned = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -468,8 +472,8 @@ private[deploy] class Worker( case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") - } else if (shuttingDown) { - logWarning("Asked to launch an executor while shutting down. Not launching executor.") + } else if (decommissioned) { + logWarning("Asked to launch an executor while decommissioned. Not launching executor.") } else { try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) @@ -588,6 +592,9 @@ private[deploy] class Worker( case ApplicationFinished(id) => finishedApps += id maybeCleanupApplication(id) + + case DecommissionSelf => + decommissionSelf() } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -681,6 +688,11 @@ private[deploy] class Worker( } } + private def decommission(): Unit = { + decommissioned = true + + } + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { val driverId = driverStateChanged.driverId val exception = driverStateChanged.exception diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 19e7eb086f413..bdcf80f696651 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -163,14 +163,38 @@ private[spark] class Executor( */ private var heartbeatFailures = 0 + /** + * Flag to prevent launching new tasks while decommissioned. There could be a race condition + * accessing this, but decommissioning is only intended to help not be a hard stop. + */ + private var decommissioned = false + startDriverHeartbeater() private[executor] def numRunningTasks: Int = runningTasks.size() + /** + * Mark an executor for decommissioning and avoid launching new tasks. + */ + private[spark] def decommission(): Unit = { + decommissioned = true + } + + /** + * Restore the executor to a running state. This can happen on decommissioning timeout. 
+ */ + private[spark] def recomission(): Unit = { + decommissioned = false + } + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { - val tr = new TaskRunner(context, taskDescription) - runningTasks.put(taskDescription.taskId, tr) - threadPool.execute(tr) + if (!decommissioned) { + val tr = new TaskRunner(context, taskDescription) + runningTasks.put(taskDescription.taskId, tr) + threadPool.execute(tr) + } else { + log.info(s"Not launching task, executor is in decommissioned state") + } } def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 6d6b7dbf08029..073d4b2bd0f40 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -63,6 +63,8 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + private val DECOMMISSIONING_TIMEOUT_MILLIS = + conf.getLong("spark.decommissioning.timeout", 420) * 1000 /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -147,11 +149,19 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } - def updateBlackListForNodeShutdown(executors: Set[String]): Unit = { - // We allow timeout on the node shutdown so if the node ends up not actually shutting down + def updateBlacklistForDecommission(node: String): Unit = { + // We allow timeout on the node decommission so if the node ends up not actually shutting down // or is migrated in a way we don't notice we can start scheduling tasks on it. - val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS - + val now = clock.getTimeMillis() + val expiryTimeForNewBlacklists = now + DECOMMISSIONING_TIMEOUT_MILLIS + logInfo(s"Blacklisting node $node because it is decommissioning") + + // Note: we do this even if the node is already blacklisted. 
+ val blacklistedExecsOnNode = + nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]()) + nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) } def updateBlacklistForSuccessfulTaskSet( From a00c707cd707c6ca2003c4d53ee51735dda3a96e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 2 Jun 2017 06:29:41 -0700 Subject: [PATCH 003/194] Go down the path of handling as lost but urgh lets just blacklist instead maybe --- .../org/apache/spark/deploy/DeployMessage.scala | 2 ++ .../org/apache/spark/deploy/ExecutorState.scala | 4 +--- .../org/apache/spark/deploy/master/Master.scala | 6 ++---- .../org/apache/spark/deploy/worker/Worker.scala | 7 +++---- .../org/apache/spark/executor/Executor.scala | 7 ------- .../spark/scheduler/BlacklistTracker.scala | 17 ----------------- .../spark/scheduler/ExecutorLossReason.scala | 8 ++++++++ 7 files changed, 16 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index d485d843f70b9..f3e2bb4f0ae3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -142,6 +142,8 @@ private[deploy] object DeployMessages { case object ReregisterWithMaster // used when a worker attempts to reconnect to a master + case object DecommissionSelf // Mark self for decommissioning. + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index f47c37b4f925b..7cd0e544d111d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -23,9 +23,7 @@ private[deploy] object ExecutorState extends Enumeration { type ExecutorState = Value - // Decommissioned is not included as a finished state since the executor is still running but - // will soon become finished. 
- private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED) + private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED, DECOMISSIONED) def isFinished(state: ExecutorState): Boolean = finishedStates.contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index aba8bba745af7..37fe68eee9d6d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -791,11 +791,9 @@ private[deploy] class Master( exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) exec.state = ExecutorState.DECOMMISSIONED + exec.application.removeExecutor(exec) } - } - - private def recommissionWorker(worker: WorkerInfo) { - logInfo("Recommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) + persistenceEngine.removeWorker(worker) } private def removeWorker(worker: WorkerInfo) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7d0a5f43b7d0b..40d21de3fa44c 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,9 +71,6 @@ private[deploy] class Worker( private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 - // Decommissioning timeout - private val DECOMMISSIONING_TIMEOUT_MILLIS = - conf.getLong("spark.decommissioning.timeout", 420) * 1000 // Model retries to connect to the master, after Hadoop's model. // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) @@ -688,8 +685,10 @@ private[deploy] class Worker( } } - private def decommission(): Unit = { + private def decommissionSelf(): Unit = { decommissioned = true + // TODO: Send decommission notification to executors & shuffle service. + // Also send message to master program. } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bdcf80f696651..fc6f613aad164 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -180,13 +180,6 @@ private[spark] class Executor( decommissioned = true } - /** - * Restore the executor to a running state. This can happen on decommissioning timeout. 
- */ - private[spark] def recomission(): Unit = { - decommissioned = false - } - def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { if (!decommissioned) { val tr = new TaskRunner(context, taskDescription) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 073d4b2bd0f40..3abc93147e65b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -63,8 +63,6 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) - private val DECOMMISSIONING_TIMEOUT_MILLIS = - conf.getLong("spark.decommissioning.timeout", 420) * 1000 /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -149,21 +147,6 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } - def updateBlacklistForDecommission(node: String): Unit = { - // We allow timeout on the node decommission so if the node ends up not actually shutting down - // or is migrated in a way we don't notice we can start scheduling tasks on it. - val now = clock.getTimeMillis() - val expiryTimeForNewBlacklists = now + DECOMMISSIONING_TIMEOUT_MILLIS - logInfo(s"Blacklisting node $node because it is decommissioning") - - // Note: we do this even if the node is already blacklisted. - val blacklistedExecsOnNode = - nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]()) - nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) - listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) - _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - } - def updateBlacklistForSuccessfulTaskSet( stageId: Int, stageAttemptId: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 46a35b6a2eaf9..18579e25da013 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -58,3 +58,11 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los private[spark] case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false) extends ExecutorLossReason(_message) + +/** + * A loss reason that means the worker is marked for decommissioning. + * + * This is used by the task scheduler to remove state associated with the executor, but + * not yet fail any tasks that were running in the executor before the executor is "fully" lost. 
+ */ +private [spark] object WorkerDecommission extends ExecutorLossReason("Worker Decommission.") From 74ade447ec94b600f5447a9269e66e47ae78fb11 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 9 Jun 2017 11:59:26 -0700 Subject: [PATCH 004/194] Plumb through executor loss to the scheduables --- .../apache/spark/deploy/DeployMessage.scala | 8 +--- .../apache/spark/deploy/ExecutorState.scala | 4 +- .../deploy/client/StandaloneAppClient.scala | 3 ++ .../client/StandaloneAppClientListener.scala | 4 +- .../apache/spark/deploy/master/Master.scala | 31 ++++++++------- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../org/apache/spark/scheduler/Pool.scala | 4 ++ .../apache/spark/scheduler/Schedulable.scala | 1 + .../spark/scheduler/TaskScheduler.scala | 5 +++ .../spark/scheduler/TaskSchedulerImpl.scala | 5 +++ .../spark/scheduler/TaskSetManager.scala | 6 +++ .../cluster/CoarseGrainedClusterMessage.scala | 2 + .../CoarseGrainedSchedulerBackend.scala | 39 ++++++++++++++++++- .../cluster/StandaloneSchedulerBackend.scala | 6 +++ 15 files changed, 98 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index f3e2bb4f0ae3b..55faf6f58234e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -60,14 +60,8 @@ private[deploy] object DeployMessages { case class WorkerDecommission( id: String, - host: String, - port: Int, worker: RpcEndpointRef) - extends DeployMessage { - Utils.checkHost(host) - assert (port > 0) - } - + extends DeployMessage case class ExecutorStateChanged( appId: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 7cd0e544d111d..0751bcf221f86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -23,7 +23,9 @@ private[deploy] object ExecutorState extends Enumeration { type ExecutorState = Value - private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED, DECOMISSIONED) + // DECOMMISSIONED isn't listed as finished since we don't want to remove the executor from + // the worker and the executor still exists - but we do want to avoid scheduling new tasks on it. + private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED) def isFinished(state: ExecutorState): Boolean = finishedStates.contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 93f58ce63799f..4d2e750541181 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} * Takes a master URL, an app description, and a listener for cluster events, and calls * back the listener when various events occur. * + * * @param masterUrls Each url should look like spark://host:port. 
*/ private[spark] class StandaloneAppClient( @@ -180,6 +181,8 @@ private[spark] class StandaloneAppClient( logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) + } else if (state == ExecutorState.DECOMMISSIONED) { + listener.executorDecommissioned(fullId, message.getOrElse("")) } case MasterChanged(masterRef, masterWebUiUrl) => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 64255ec92b72a..4cb329b2c13a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -37,5 +37,7 @@ private[spark] trait StandaloneAppClientListener { fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved( - fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + + def executorDecommissioned(fullId: String, message: String): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 37fe68eee9d6d..cb7c646a79247 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,9 +231,8 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case WorkerDecommission(id, workerHost, workerPort, workerRef) => - logInfo("Recording worker %d decomissioning %s:%d".format( - id, workerHost, workerPort)) + case WorkerDecommission(id, workerRef) => + logInfo("Recording worker %d decommissioning".format(id)) if (state == RecoveryState.STANDBY) { workerRef.send(MasterInStandby) } else { @@ -309,6 +308,7 @@ private[deploy] class Master( // Only retry certain number of times so we don't go into an infinite loop. // Important note: this code path is not exercised by tests, so be very careful when // changing this `if` condition. + // We also don't count failures from decommissioned workers since they are "expected." 
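+          // (A decommissioned executor exits as part of a planned shutdown, so its exit should
+          // not push the application towards the MAX_EXECUTOR_RETRIES kill threshold.)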
if (!normalExit && oldState != ExecutorState.DECOMMISSIONED && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES @@ -783,17 +783,22 @@ private[deploy] class Master( } private def decommissionWorker(worker: WorkerInfo) { - logInfo("Decommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) - worker.setState(WorkerState.DECOMMISSIONED) - for (exec <- worker.executors.values) { - logInfo("Telling app of decomission executors") - exec.application.driver.send(ExecutorUpdated( - exec.id, ExecutorState.DECOMMISSIONED, - Some("worker decommissioned"), None, workerLost = false)) - exec.state = ExecutorState.DECOMMISSIONED - exec.application.removeExecutor(exec) + if (worker.state != WorkerState.DECOMMISSIONED) { + logInfo("Decommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) + worker.setState(WorkerState.DECOMMISSIONED) + for (exec <- worker.executors.values) { + logInfo("Telling app of decomission executors") + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.DECOMMISSIONED, + Some("worker decommissioned"), None, workerLost = false)) + exec.state = ExecutorState.DECOMMISSIONED + exec.application.removeExecutor(exec) + } + persistenceEngine.removeWorker(worker) + } else { + logWarning("Skipping decommissioning worker %d on %s:%d as worker is already decommissioned". + format(worker.id, worker.host, worker.port)) } - persistenceEngine.removeWorker(worker) } private def removeWorker(worker: WorkerInfo) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 40d21de3fa44c..fb83bd9005dec 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -689,7 +689,7 @@ private[deploy] class Worker( decommissioned = true // TODO: Send decommission notification to executors & shuffle service. // Also send message to master program. - + sendToMaster(WorkerDecommission(workerId, self)) } private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2985c90119468..366f8038be4d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -335,6 +335,7 @@ abstract class RDD[T: ClassTag]( readCachedBlock = false computeOrReadCheckpoint(partition, context) }) match { + // Block hit. case Left(blockResult) => if (readCachedBlock) { val existingMetrics = context.taskMetrics().inputMetrics @@ -348,6 +349,7 @@ abstract class RDD[T: ClassTag]( } else { new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) } + // Need to compute the block. 
case Right(iter) => new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]]) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 1181371ab425a..35829af5e7057 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -88,6 +88,10 @@ private[spark] class Pool( schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } + override def executorDecommission(executorId: String): Unit = { + schedulableQueue.asScala.foreach(_.executorDecommission(executorId)) + } + override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { var shouldRevive = false for (schedulable <- schedulableQueue.asScala) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index b6f88ed0a93aa..8cc239c81d11a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -43,6 +43,7 @@ private[spark] trait Schedulable { def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit + def executorDecommission(executorId: String): Unit def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 3de7d1f7de22b..d36dba4071c79 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -84,6 +84,11 @@ private[spark] trait TaskScheduler { */ def applicationId(): String = appId + /** + * Process a decommissioning executor. + */ + def executorDecommission(executorId: String): Unit + /** * Process a lost executor */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 629cfc7c7a8ce..66c91e270b005 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -535,6 +535,11 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def executorDecommission(executorId: String): Unit = { + rootPool.executorDecommission(executorId) + backend.reviveOffers() + } + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7dec..46b43f70268d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1006,6 +1006,12 @@ private[spark] class TaskSetManager( levels.toArray } + def executorDecommission(execId: String) { + recomputeLocality() + // Future consideration: if an executor is decommissioned it may make sense to add the current + // tasks to the spec exec queue. 
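+    // For now this only recomputes locality preferences; tasks already running on the
+    // decommissioned executor are left to finish rather than being resubmitted here.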
+ } + def recomputeLocality() { val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6b49bd699a13a..0c298b5f97ef4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -85,6 +85,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage + case class DecomissionExecutor(executorId: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index dc82bb7704727..6a3928a8e8e34 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -101,6 +101,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. protected val executorsPendingLossReason = new HashSet[String] + // Executors which are being decommissioned + protected val executorsPendingDecommission = new HashSet[String] protected val addressToExecutorId = new HashMap[RpcAddress, String] @@ -270,7 +272,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def executorIsAlive(executorId: String): Boolean = synchronized { !executorsPendingToRemove.contains(executorId) && - !executorsPendingLossReason.contains(executorId) + !executorsPendingLossReason.contains(executorId) && + !executorsPendingDecommission.contains(executorId) } // Launch tasks returned by a set of resource offers @@ -331,6 +334,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Stop making resource offers for the given executor. The executor is marked as lost with the + * loss reason as WorkerDecommission. + * + */ + private def decommissionExecutor(executorId: String): Boolean = { + val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { + // Only bother decommissioning executors which are alive. + if (executorIsAlive(executorId)) { + executorsPendingDecommission += executorId + true + } else { + false + } + } + + if (shouldDisable) { + logInfo(s"Decommissioning executor $executorId.") + scheduler.executorDecommission(executorId) + } + + shouldDisable + } + /** * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. @@ -454,6 +481,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp }(ThreadUtils.sameThread) } + /** + * Called by subclasses when notified of a decommissioning worker. + */ + protected def decommissionExecutor(executorId: String): Unit = { + // Only log the failure since we don't care about the result. 
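+    // Decommissioning is best-effort: the DecomissionExecutor ask is asynchronous and a failed
+    // reply is logged rather than retried.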
+ driverEndpoint.ask[Boolean](DecomissionExecutor(executorId)).onFailure { case t => + logError(t.getMessage, t) + }(ThreadUtils.sameThread) + } + def sufficientResourcesRegistered(): Boolean = true override def isReady(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 0529fe9eed4da..b1cde2d6f0348 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -161,6 +161,12 @@ private[spark] class StandaloneSchedulerBackend( removeExecutor(fullId.split("/")(1), reason) } + override def executorDecommissioned(fullId: String, message: String) { + logInfo("Executor %s decommissioned: %s".format(fullId, message)) + decommissionExecutor(fullId.split("/")(1)) + } + + override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } From a880177f9bf45a2f0644229fbff863f80d058161 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 21 Jun 2017 06:00:03 -0700 Subject: [PATCH 005/194] AppClient suite works! yay --- .../deploy/client/StandaloneAppClient.scala | 6 +++ .../apache/spark/deploy/master/Master.scala | 15 +++++-- .../apache/spark/deploy/worker/Worker.scala | 24 +++++++---- .../spark/internal/config/package.scala | 10 +++++ .../spark/scheduler/TaskSchedulerImpl.scala | 2 + .../spark/deploy/client/AppClientSuite.scala | 40 +++++++++++++++++-- .../spark/scheduler/DAGSchedulerSuite.scala | 2 + .../ExternalClusterManagerSuite.scala | 1 + 8 files changed, 86 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 4d2e750541181..921baf7112644 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -171,6 +171,7 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id + println("Executor added on worker " + workerId) logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) @@ -179,10 +180,15 @@ private[spark] class StandaloneAppClient( val fullId = appId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) + println("Executor updated, also yay! state is " + state) if (ExecutorState.isFinished(state)) { + println("executor is finished. 
good bye!") listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } else if (state == ExecutorState.DECOMMISSIONED) { + println("Propegating decommission to my listener") listener.executorDecommissioned(fullId, message.getOrElse("")) + } else { + println("not doing anything about that eh :p") } case MasterChanged(masterRef, masterWebUiUrl) => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cb7c646a79247..63c558595ee32 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -232,7 +232,9 @@ private[deploy] class Master( System.exit(0) case WorkerDecommission(id, workerRef) => - logInfo("Recording worker %d decommissioning".format(id)) + println("pandas!!") + logInfo("Recording worker %s decommissioning".format(id)) + println("Decommisioning worker :) " + id) if (state == RecoveryState.STANDBY) { workerRef.send(MasterInStandby) } else { @@ -335,6 +337,7 @@ private[deploy] class Master( } case Heartbeat(workerId, worker) => + println("Heartbeat received " + workerId) idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -752,6 +755,7 @@ private[deploy] class Master( } private def registerWorker(worker: WorkerInfo): Boolean = { + println("Registering worker..." + worker) // There may be one or more refs to dead workers on this same node (w/ different ID's), // remove them. workers.filter { w => @@ -783,22 +787,27 @@ private[deploy] class Master( } private def decommissionWorker(worker: WorkerInfo) { + println("Found worker info for decomissioning worker " + worker) if (worker.state != WorkerState.DECOMMISSIONED) { - logInfo("Decommissioning worker %d on %s:%d".format(worker.id, worker.host, worker.port)) + logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port)) worker.setState(WorkerState.DECOMMISSIONED) + println("Worker executors are " + worker.executors.values.toList) for (exec <- worker.executors.values) { logInfo("Telling app of decomission executors") + println("Telling the drivers") exec.application.driver.send(ExecutorUpdated( exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) exec.state = ExecutorState.DECOMMISSIONED exec.application.removeExecutor(exec) } + // On recovery do not add a decommissioned executor persistenceEngine.removeWorker(worker) } else { - logWarning("Skipping decommissioning worker %d on %s:%d as worker is already decommissioned". + logWarning("Skipping decommissioning worker %s on %s:%d as worker is already decommissioned". 
format(worker.id, worker.host, worker.port)) } + println("Finished decommissioning worker from master") } private def removeWorker(worker: WorkerInfo) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index fb83bd9005dec..7103210704bb4 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -35,7 +35,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -641,7 +641,9 @@ private[deploy] class Worker( */ private def sendToMaster(message: Any): Unit = { master match { - case Some(masterRef) => masterRef.send(message) + case Some(masterRef) => + println("Sending message to master " + message) + masterRef.send(message) case None => logWarning( s"Dropping $message because the connection to master has not yet been established") @@ -685,11 +687,19 @@ private[deploy] class Worker( } } - private def decommissionSelf(): Unit = { - decommissioned = true - // TODO: Send decommission notification to executors & shuffle service. - // Also send message to master program. - sendToMaster(WorkerDecommission(workerId, self)) + private[deploy] def decommissionSelf(): Unit = { + println("decommission self called") + if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { + println("propegating") + logDebug("Decommissioning self") + decommissioned = true + // TODO: Send decommission notification to executors & shuffle service. + // Also send message to master program. 
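+      // Telling the master lets it mark this worker DECOMMISSIONED and notify the affected
+      // applications so that no new executors or tasks are scheduled here.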
+ sendToMaster(WorkerDecommission(workerId, self)) + } else { + println("skipping") + logWarning("Asked to decommission self, but decommissioning not enabled") + } } private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84ef57f2d271b..f06691231cfe8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -151,6 +151,16 @@ package object config { .createOptional // End blacklist confs + private[spark] val WORKER_DECOMMISSION_ENABLED = + ConfigBuilder("spark.worker.decommission.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val WORKER_DECOMMISSION_EC2_POLE = + ConfigBuilder("spark.worker.decommission.ec2Pole") + .booleanConf + .createWithDefault(false) + private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost") .doc("Whether to un-register all the outputs on the host in condition that we receive " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 66c91e270b005..89d2b5b0ad46a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -536,6 +536,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } override def executorDecommission(executorId: String): Unit = { + println("Decommissioning " + executorId) rootPool.executorDecommission(executorId) backend.reviveOffers() } @@ -544,6 +545,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var failedExecutor: Option[String] = None synchronized { + println("Recording executor loss " + executorId) if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 936639b845789..378a795e75f54 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.{ApplicationInfo, Master} import org.apache.spark.deploy.worker.Worker -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -43,13 +43,13 @@ class AppClientSuite with Eventually with ScalaFutures { private val numWorkers = 2 - private val conf = new SparkConf() - private val securityManager = new SecurityManager(conf) + private var conf: SparkConf = null private var masterRpcEnv: RpcEnv = null private var workerRpcEnvs: Seq[RpcEnv] = null private var master: Master = null private var workers: Seq[Worker] = null + private var securityManager: SecurityManager = null /** * Start the local cluster. 
@@ -57,6 +57,8 @@ class AppClientSuite */ override def beforeAll(): Unit = { super.beforeAll() + conf = new SparkConf().set(config.WORKER_DECOMMISSION_ENABLED.key, "true") + securityManager = new SecurityManager(conf) masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) workerRpcEnvs = (0 until numWorkers).map { i => RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) @@ -110,8 +112,23 @@ class AppClientSuite assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") } + // Save the executor id before decommissioning so we can kill it + val application = getApplications().head + val executors = application.executors + val executorId: String = executors.head._2.fullId + + // Send a decommission self to all the workers + // Note: normally the worker would send this on their own. + println("Workers are " + workers) + workers.foreach(worker => worker.decommissionSelf()) + + // Decommissioning is async. + eventually(timeout(1.seconds), interval(10.millis)) { + // We only record decommissioning for the executor we've requested + assert(ci.listener.execDecommissionedList.size === 1) + } + // Send request to kill executor, verify request was made - val executorId: String = getApplications().head.executors.head._2.fullId whenReady( ci.client.killExecutors(Seq(executorId)), timeout(10.seconds), @@ -119,6 +136,15 @@ class AppClientSuite assert(acknowledged) } + // Verify that asking for executors on the decommissioned workers fails + whenReady( + ci.client.requestTotalExecutors(numExecutorsRequested), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } + assert(getApplications().head.executors.size === 0) + // Issue stop command for Client to disconnect from Master ci.client.stop() @@ -186,6 +212,7 @@ class AppClientSuite val deadReasonList = new ConcurrentLinkedQueue[String]() val execAddedList = new ConcurrentLinkedQueue[String]() val execRemovedList = new ConcurrentLinkedQueue[String]() + val execDecommissionedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { connectedIdList.add(id) @@ -214,6 +241,11 @@ class AppClientSuite id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } + + def executorDecommissioned(id: String, message: String): Unit = { + println("Decommission executor " + id) + execDecommissionedList.add(id) + } } /** Create AppClient and supporting objects */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index ddd3281106745..9a9f13796abfc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -130,6 +130,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorDecommission(executorId: String) = {} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None } @@ -631,6 +632,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: 
BlockManagerId): Boolean = true + override def executorDecommission(executorId: String): Unit = {} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index ba56af8215cd7..831abd08f65fb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -83,6 +83,7 @@ private class DummyTaskScheduler extends TaskScheduler { taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 + override def executorDecommission(executorId: String): Unit = {} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None def executorHeartbeatReceived( From b9704038e96b0bb862b824cb9723e68633e18c06 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 21 Jun 2017 10:04:53 -0700 Subject: [PATCH 006/194] Decomissioning now works in the coarse grained scheduler, yay.... --- .../cluster/CoarseGrainedSchedulerBackend.scala | 13 +++++++++++-- .../cluster/StandaloneSchedulerBackend.scala | 1 + 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6a3928a8e8e34..be4fd55374052 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -221,6 +221,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) + case DecomissionExecutor(executorId) => + decommissionExecutor(executorId) + context.reply(true) + case RetrieveSparkAppConfig => val reply = SparkAppConfig(sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey()) @@ -340,12 +344,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ private def decommissionExecutor(executorId: String): Boolean = { + println("sched backend received exec decom req for " + executorId) val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { // Only bother decommissioning executors which are alive. if (executorIsAlive(executorId)) { + println("Adding exec to pending decom list") executorsPendingDecommission += executorId true } else { + println("Exec already decommed") false } } @@ -354,7 +361,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Decommissioning executor $executorId.") scheduler.executorDecommission(executorId) } - + println("Coarse grained scheduler decomissioning executor " + shouldDisable) shouldDisable } @@ -484,9 +491,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Called by subclasses when notified of a decommissioning worker. 
*/ - protected def decommissionExecutor(executorId: String): Unit = { + private[spark] def decommissionExecutor(executorId: String): Unit = { + println("scheduler asked to decom exec " + executorId) // Only log the failure since we don't care about the result. driverEndpoint.ask[Boolean](DecomissionExecutor(executorId)).onFailure { case t => + println("Decommissioning executor failed? idk w/e") logError(t.getMessage, t) }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index b1cde2d6f0348..2b80e01fca5f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -162,6 +162,7 @@ private[spark] class StandaloneSchedulerBackend( } override def executorDecommissioned(fullId: String, message: String) { + println("Decommission executor inside of scheduler") logInfo("Executor %s decommissioned: %s".format(fullId, message)) decommissionExecutor(fullId.split("/")(1)) } From ded6bbc8d056f9f82302450aa27dbc0d94fdbccd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 22 Jun 2017 09:10:06 -0700 Subject: [PATCH 007/194] Remove sketchy println debugging --- .../apache/spark/deploy/client/StandaloneAppClient.scala | 6 ------ .../scala/org/apache/spark/deploy/master/Master.scala | 8 -------- .../scala/org/apache/spark/deploy/worker/Worker.scala | 4 ---- .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 -- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6 ------ .../scheduler/cluster/StandaloneSchedulerBackend.scala | 1 - .../org/apache/spark/deploy/client/AppClientSuite.scala | 2 -- 7 files changed, 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 921baf7112644..4d2e750541181 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -171,7 +171,6 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - println("Executor added on worker " + workerId) logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) @@ -180,15 +179,10 @@ private[spark] class StandaloneAppClient( val fullId = appId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) - println("Executor updated, also yay! state is " + state) if (ExecutorState.isFinished(state)) { - println("executor is finished. 
good bye!") listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } else if (state == ExecutorState.DECOMMISSIONED) { - println("Propegating decommission to my listener") listener.executorDecommissioned(fullId, message.getOrElse("")) - } else { - println("not doing anything about that eh :p") } case MasterChanged(masterRef, masterWebUiUrl) => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 63c558595ee32..0305f3590108a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -232,9 +232,7 @@ private[deploy] class Master( System.exit(0) case WorkerDecommission(id, workerRef) => - println("pandas!!") logInfo("Recording worker %s decommissioning".format(id)) - println("Decommisioning worker :) " + id) if (state == RecoveryState.STANDBY) { workerRef.send(MasterInStandby) } else { @@ -337,7 +335,6 @@ private[deploy] class Master( } case Heartbeat(workerId, worker) => - println("Heartbeat received " + workerId) idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -755,7 +752,6 @@ private[deploy] class Master( } private def registerWorker(worker: WorkerInfo): Boolean = { - println("Registering worker..." + worker) // There may be one or more refs to dead workers on this same node (w/ different ID's), // remove them. workers.filter { w => @@ -787,14 +783,11 @@ private[deploy] class Master( } private def decommissionWorker(worker: WorkerInfo) { - println("Found worker info for decomissioning worker " + worker) if (worker.state != WorkerState.DECOMMISSIONED) { logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port)) worker.setState(WorkerState.DECOMMISSIONED) - println("Worker executors are " + worker.executors.values.toList) for (exec <- worker.executors.values) { logInfo("Telling app of decomission executors") - println("Telling the drivers") exec.application.driver.send(ExecutorUpdated( exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) @@ -807,7 +800,6 @@ private[deploy] class Master( logWarning("Skipping decommissioning worker %s on %s:%d as worker is already decommissioned". format(worker.id, worker.host, worker.port)) } - println("Finished decommissioning worker from master") } private def removeWorker(worker: WorkerInfo) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7103210704bb4..59150c9859719 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -642,7 +642,6 @@ private[deploy] class Worker( private def sendToMaster(message: Any): Unit = { master match { case Some(masterRef) => - println("Sending message to master " + message) masterRef.send(message) case None => logWarning( @@ -688,16 +687,13 @@ private[deploy] class Worker( } private[deploy] def decommissionSelf(): Unit = { - println("decommission self called") if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { - println("propegating") logDebug("Decommissioning self") decommissioned = true // TODO: Send decommission notification to executors & shuffle service. // Also send message to master program. 
sendToMaster(WorkerDecommission(workerId, self)) } else { - println("skipping") logWarning("Asked to decommission self, but decommissioning not enabled") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 89d2b5b0ad46a..66c91e270b005 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -536,7 +536,6 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } override def executorDecommission(executorId: String): Unit = { - println("Decommissioning " + executorId) rootPool.executorDecommission(executorId) backend.reviveOffers() } @@ -545,7 +544,6 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var failedExecutor: Option[String] = None synchronized { - println("Recording executor loss " + executorId) if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index be4fd55374052..891ecf461f11a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -344,15 +344,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ private def decommissionExecutor(executorId: String): Boolean = { - println("sched backend received exec decom req for " + executorId) val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { // Only bother decommissioning executors which are alive. if (executorIsAlive(executorId)) { - println("Adding exec to pending decom list") executorsPendingDecommission += executorId true } else { - println("Exec already decommed") false } } @@ -361,7 +358,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Decommissioning executor $executorId.") scheduler.executorDecommission(executorId) } - println("Coarse grained scheduler decomissioning executor " + shouldDisable) shouldDisable } @@ -492,10 +488,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Called by subclasses when notified of a decommissioning worker. */ private[spark] def decommissionExecutor(executorId: String): Unit = { - println("scheduler asked to decom exec " + executorId) // Only log the failure since we don't care about the result. driverEndpoint.ask[Boolean](DecomissionExecutor(executorId)).onFailure { case t => - println("Decommissioning executor failed? 
idk w/e") logError(t.getMessage, t) }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 2b80e01fca5f1..b1cde2d6f0348 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -162,7 +162,6 @@ private[spark] class StandaloneSchedulerBackend( } override def executorDecommissioned(fullId: String, message: String) { - println("Decommission executor inside of scheduler") logInfo("Executor %s decommissioned: %s".format(fullId, message)) decommissionExecutor(fullId.split("/")(1)) } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 378a795e75f54..acbc77adf67fa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -119,7 +119,6 @@ class AppClientSuite // Send a decommission self to all the workers // Note: normally the worker would send this on their own. - println("Workers are " + workers) workers.foreach(worker => worker.decommissionSelf()) // Decommissioning is async. @@ -243,7 +242,6 @@ class AppClientSuite } def executorDecommissioned(id: String, message: String): Unit = { - println("Decommission executor " + id) execDecommissionedList.add(id) } } From 16c855ad9eb7b961d805bc2f459d86f3b3d31108 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 5 Jul 2017 23:13:41 -0700 Subject: [PATCH 008/194] Add a worker decommissioning suite --- .../scheduler/WorkerDecommissionSuite.scala | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala new file mode 100644 index 0000000000000..891d749a7d77f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend +import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils} + +class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { + + + override def beforeEach(): Unit = { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.WORKER_DECOMMISSION_ENABLED.key, "true") + + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) + } + + test("verify task with no decommissioning works as expected") { + val input = sc.parallelize(1 to 10) + input.count() + val sleepyRdd = input.mapPartitions{ x => + Thread.sleep(100) + x + } + assert(sleepyRdd.count() === 10) + } + + test("verify a task with all workers decommissioned succeeds") { + val input = sc.parallelize(1 to 10) + // Do a count to wait for the executors to be registered. + input.count() + val sleepyRdd = input.mapPartitions{ x => + Thread.sleep(100) + x + } + // Start the task + val asyncCount = sleepyRdd.countAsync() + Thread.sleep(10) + // Decommission all the executors, this should not halt the current task. + // The master passing message is tested with + val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend] + val execs = sched.getExecutorIds() + println("execs are " + execs) + execs.foreach(execId => sched.decommissionExecutor(execId)) + assert(asyncCount.get() === 10) + // Try and launch task after decommissioning, this should fail + println("post decom execs are " + sched.getExecutorIds()) + val postDecommissioned = input.map(x => x) + val postDecomAsyncCount = postDecommissioned.countAsync() + val thrown = intercept[java.util.concurrent.TimeoutException]{ + val result = ThreadUtils.awaitResult(postDecomAsyncCount, 1.seconds) + } + assert(postDecomAsyncCount.isCompleted === false, + "After exec decomission new task could not launch") + } +} From c2a0ad87dc3220eb5154a6d0a117ce0260bd2695 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Aug 2017 13:28:24 -0700 Subject: [PATCH 009/194] Add decommissioning script for whatever process is running locally on host to call --- .../apache/spark/deploy/worker/Worker.scala | 11 ++++- sbin/decommission-slave.sh | 44 +++++++++++++++++++ sbin/spark-daemon.sh | 15 +++++++ 3 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 sbin/decommission-slave.sh diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f84e9da9b94cf..c20e3ee092c3e 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -38,7 +38,7 @@ import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ -import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} +import org.apache.spark.util.{SignalUtils, SparkUncaughtExceptionHandler, ThreadUtils, Utils} private[deploy] class Worker( override val rpcEnv: RpcEnv, @@ -58,6 +58,11 @@ private[deploy] class Worker( Utils.checkHost(host) assert (port > 0) + // If worker decommissioning is enabled register a handler on SIGPWR to shutdown. 
+ if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { + SignalUtils.register("SIGPWR")(decommissionSelf) + } + // A scheduled executor used to send messages at the specified time. private val forwordMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") @@ -688,7 +693,7 @@ private[deploy] class Worker( } } - private[deploy] def decommissionSelf(): Unit = { + private[deploy] def decommissionSelf(): Boolean = { if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { logDebug("Decommissioning self") decommissioned = true @@ -698,6 +703,8 @@ private[deploy] class Worker( } else { logWarning("Asked to decommission self, but decommissioning not enabled") } + // Return true since can be called as a signal handler + true } private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { diff --git a/sbin/decommission-slave.sh b/sbin/decommission-slave.sh new file mode 100644 index 0000000000000..0251b01199c7a --- /dev/null +++ b/sbin/decommission-slave.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to decommission all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: decommission-slave.sh +# Decommissions all slaves on this worker machine + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. 
"${SPARK_HOME}/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 6de67e039b48f..81f2fd40a706f 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -215,6 +215,21 @@ case $option in fi ;; + (decommission) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo "decommissioning $command" + kill -s SIGPWR "$TARGET_ID" + else + echo "no $command to decommission" + fi + else + echo "no $command to decommission" + fi + ;; + (status) if [ -f $pid ]; then From 672c3b6f79400cce867ce273199ccdcf995b6ed6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Aug 2017 14:38:51 -0700 Subject: [PATCH 010/194] Leave polling mechanism up to the cloud vendors --- .../scala/org/apache/spark/internal/config/package.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 67e20bc985f93..7bc765d612c81 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -161,11 +161,6 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val WORKER_DECOMMISSION_EC2_POLE = - ConfigBuilder("spark.worker.decommission.ec2Pole") - .booleanConf - .createWithDefault(false) - private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost") .doc("Whether to un-register all the outputs on the host in condition that we receive " + From 9cfdb7fc36691bf0c627080de5c2008fe83ba3bd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Aug 2017 14:55:12 -0700 Subject: [PATCH 011/194] Remove legacy comment and remove some unecessary blank lines --- .../scala/org/apache/spark/scheduler/BlacklistTracker.scala | 2 -- .../spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 2 -- 2 files changed, 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 713c67e2c066d..cd8e61d6d0208 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -40,8 +40,6 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * stage, but still many failures over the entire application * * "flaky" executors -- they don't fail every task, but are still faulty enough to merit * blacklisting - * * shutting down executors -- executors which are shutting down and should not have new tasks - * scheduled. * * See the design doc on SPARK-8425 for a more in-depth discussion. 
* diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 00ad422b2bc4f..7afdd87fd5ddd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -172,13 +172,11 @@ private[spark] class StandaloneSchedulerBackend( decommissionExecutor(fullId.split("/")(1)) } - override def workerRemoved(workerId: String, host: String, message: String): Unit = { logInfo("Worker %s removed: %s".format(workerId, message)) removeWorker(workerId, host, message) } - override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } From 65a29c12c1740c285ff7b06f3788cd2a92ce87f1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Aug 2017 14:59:24 -0700 Subject: [PATCH 012/194] Remove manually debugging printlns (oops) --- .../org/apache/spark/scheduler/WorkerDecommissionSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala index 891d749a7d77f..4dc9a9dc4109c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala @@ -60,11 +60,9 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { // The master passing message is tested with val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend] val execs = sched.getExecutorIds() - println("execs are " + execs) execs.foreach(execId => sched.decommissionExecutor(execId)) assert(asyncCount.get() === 10) // Try and launch task after decommissioning, this should fail - println("post decom execs are " + sched.getExecutorIds()) val postDecommissioned = input.map(x => x) val postDecomAsyncCount = postDecommissioned.countAsync() val thrown = intercept[java.util.concurrent.TimeoutException]{ From 258a116964731294b939bb4fca9f6b7fbf90560d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 27 Aug 2018 12:12:59 -0700 Subject: [PATCH 013/194] Update and add blocking for K8s --- .../org/apache/spark/deploy/worker/Worker.scala | 3 +-- .../cluster/CoarseGrainedSchedulerBackend.scala | 8 ++++---- .../k8s/features/BasicExecutorFeatureStep.scala | 12 ++++++++++++ sbin/decommission-slave.sh | 13 ++++++++++++- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7514305ae627e..512d5f4203945 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -666,8 +666,7 @@ private[deploy] class Worker( */ private def sendToMaster(message: Any): Unit = { master match { - case Some(masterRef) => - masterRef.send(message) + case Some(masterRef) => masterRef.send(message) case None => logWarning( s"Dropping $message because the connection to master has not yet been established") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index b27224d1908c7..b2fd13893054b 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -165,10 +165,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) - case DecomissionExecutor(executorId) => - decommissionExecutor(executorId) - context.reply(true) - } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -233,6 +229,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeWorker(workerId, host, message) context.reply(true) + case DecomissionExecutor(executorId) => + decommissionExecutor(executorId) + context.reply(true) + case RetrieveSparkAppConfig => val reply = SparkAppConfig( sparkProperties, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index c37f713c56de1..ea1b6fa06cecc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -151,6 +151,18 @@ private[spark] class BasicExecutorFeatureStep( .endResources() .build() }.getOrElse(executorContainer) + val containerWithLifecycle = kubernetesConf.workerDecomissioning() match { + case true => + new ContainerBuilder(executorContainer).editOrNewLifecycle() + .withNewPreStop() + .withNewExec() + .withCommand( + List("/opt/spark/sbin/decommission-slave.sh", "--block-until-exit").asJava) + .endExec() + .endPreStop() + .build() + case false => containerWithLimitCores + } val driverPod = kubernetesConf.roleSpecificConf.driverPod val ownerReference = driverPod.map(pod => new OwnerReferenceBuilder() diff --git a/sbin/decommission-slave.sh b/sbin/decommission-slave.sh index 0251b01199c7a..53048931d69b4 100644 --- a/sbin/decommission-slave.sh +++ b/sbin/decommission-slave.sh @@ -24,7 +24,7 @@ # SPARK_WORKER_INSTANCES The number of worker instances that should be # running on this slave. Default is 1. -# Usage: decommission-slave.sh +# Usage: decommission-slave.sh [--block-until-exit] # Decommissions all slaves on this worker machine if [ -z "${SPARK_HOME}" ]; then @@ -42,3 +42,14 @@ else "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker $(( $i + 1 )) done fi + +# Check if --block-until-exit is set. +# This is done for systems which block on the decomissioning script and on exit +# shut down the entire system (e.g. K8s). +if [ "$1" == "--block-until-exit" ]; then + shift + # For now we only block on the 0th instance if there multiple instances. 
+ instance=$1 + pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" + wait $pid +fi From c40fac5de1e976e73ff3df29fac83ca1f4d1ce8e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 27 Aug 2018 12:13:22 -0700 Subject: [PATCH 014/194] Add workerDecomissioning to K8s conf --- .../scala/org/apache/spark/deploy/k8s/KubernetesConf.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 866ba3cbaa9c3..f016cc0a9ddfe 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -91,6 +91,9 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( } } + def workerDecomissioning(): Boolean = + sparkConf.get(org.apache.spark.internal.config.WORKER_DECOMMISSION_ENABLED) + def nodeSelector(): Map[String, String] = KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) From 5877c16e20559122847ed5ea21c74214fc024c9d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 8 Sep 2018 11:04:04 -0700 Subject: [PATCH 015/194] Tidy up small things. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 889bf6630955a..fa8f234568696 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -164,7 +164,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // automatically, so try to tell the executor to stop itself. See SPARK-13519. executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) - } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -360,7 +359,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Stop making resource offers for the given executor. The executor is marked as lost with the * loss reason as WorkerDecommission. 
- * */ private def decommissionExecutor(executorId: String): Boolean = { val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { From 4e6572f8a7798298fe4787fe5913ee94c2b97359 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 19 Sep 2018 13:07:43 -0700 Subject: [PATCH 016/194] Fix missing endLifecycle --- .../spark/deploy/k8s/features/BasicExecutorFeatureStep.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index af24ebe320244..90382bcf1bf18 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -170,6 +170,7 @@ private[spark] class BasicExecutorFeatureStep( List("/opt/spark/sbin/decommission-slave.sh", "--block-until-exit").asJava) .endExec() .endPreStop() + .endLifecycle() .build() case false => containerWithLimitCores } From 42a29abf4d4479f5195eee6324efd181f118535b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 20 Sep 2018 22:45:29 -0700 Subject: [PATCH 017/194] Add a WIP Decom suite work --- .../integrationtest/DecommissionSuite.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala new file mode 100644 index 0000000000000..f5f001a176f9a --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, SecretBuilder} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.output.ByteArrayOutputStream +import org.scalatest.concurrent.Eventually + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ + +private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => + + test("Run SparkPi with env and mount secrets.", k8sTestTag) { + createTestSecret() + sparkAppConf + .set("spark.worker.decommission.enabled", "true") + .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.pyspark.pythonVersion", "2") + // We should be able to run Spark PI now + runSparkPiAndVerifyCompletion() + // Now we manually trigger decomissioning + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Spark is working before decommissioning: True", + "Called decommissionWorker on on success: True", + "Spark decommissioned worker is not scheduled: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + // Now we expect this to fail + runSparkPiAndVerifyCompletion() + } +} + +prviate[spark] object DecommissionSuite { + val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" + val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "trigger_decomissioning.py" +} From 05941daa889d661fb3ce12aa11a8e831ccc901d3 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:28:15 -0700 Subject: [PATCH 018/194] Attempt at making decomissioning integration test for Spark on K8s compiles, still needs more work. 
--- .../integrationtest/DecommissionSuite.scala | 31 +++++++------------ .../k8s/integrationtest/KubernetesSuite.scala | 12 +++++-- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index f5f001a176f9a..826b6eb7e1f1f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -18,43 +18,36 @@ package org.apache.spark.deploy.k8s.integrationtest import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{Pod, SecretBuilder} -import org.apache.commons.codec.binary.Base64 -import org.apache.commons.io.output.ByteArrayOutputStream -import org.scalatest.concurrent.Eventually - import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => + import DecommissionSuite._ + + private val pySparkDockerImage = + s"${getTestImageRepo}/spark-py:${getTestImageTag}" + test("Run SparkPi with env and mount secrets.", k8sTestTag) { - createTestSecret() sparkAppConf .set("spark.worker.decommission.enabled", "true") .set("spark.kubernetes.container.image", pySparkDockerImage) .set("spark.kubernetes.pyspark.pythonVersion", "2") - // We should be able to run Spark PI now - runSparkPiAndVerifyCompletion() - // Now we manually trigger decomissioning + runSparkApplicationAndVerifyCompletion( - appResource = PYSPARK_FILES, + appResource = PYSPARK_DECOMISSIONING, mainClass = "", - expectedLogOnCompletion = Seq( - "Spark is working before decommissioning: True", - "Called decommissionWorker on on success: True", - "Spark decommissioned worker is not scheduled: True"), + expectedLogOnCompletion = Seq("Decommissioning worker"), appArgs = Array("python"), driverPodChecker = doBasicDriverPyPodCheck, executorPodChecker = doBasicExecutorPyPodCheck, appLocator = appLocator, isJVM = false, - pyFiles = Some(PYSPARK_CONTAINER_TESTS)) - // Now we expect this to fail - runSparkPiAndVerifyCompletion() + decomissioningTest = true) } } -prviate[spark] object DecommissionSuite { +private[spark] object DecommissionSuite { val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" - val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "trigger_decomissioning.py" + val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decomissioning_waiter.py" } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index c99a907f98d0a..dd44c3563e10c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -207,7 +207,8 @@ private[spark] class KubernetesSuite extends SparkFunSuite executorPodChecker: Pod => Unit, appLocator: String, 
isJVM: Boolean, - pyFiles: Option[String] = None): Unit = { + pyFiles: Option[String] = None, + decomissioningTest: Boolean = false): Unit = { val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, @@ -242,12 +243,19 @@ private[spark] class KubernetesSuite extends SparkFunSuite action match { case Action.ADDED | Action.MODIFIED => execPods(name) = resource + // If testing decomissioning delete the node 10 seconds after + if (decomissioningTest) { + Thread.sleep(1000) + kubernetesTestComponents.kubernetesClient.pods().withName(name).delete() + } case Action.DELETED | Action.ERROR => execPods.remove(name) } } }) - Eventually.eventually(TIMEOUT, INTERVAL) { execPods.values.nonEmpty should be (true) } + // If we're testing decomissioning we delete all the executors + Eventually.eventually(TIMEOUT, INTERVAL) { + execPods.values.nonEmpty || decomissioningTest should be (true) } execWatcher.close() execPods.values.foreach(executorPodChecker(_)) Eventually.eventually(TIMEOUT, INTERVAL) { From cb61f45be45dfeeda58b0644d21baead9082d7d4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:31:24 -0700 Subject: [PATCH 019/194] Add initial decomissioning_water helper script --- .../tests/decomissioning_water.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py diff --git a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py new file mode 100644 index 0000000000000..486a6d26b855d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import time + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: decomissioning_water + """ + spark = SparkSession \ + .builder \ + .appName("PyMemoryTest") \ + .getOrCreate() + sc = spark.SparkContext + rdd = sc.parallelize(1.to(10)) + rdd.collect() + time.sleep(15) + sys.exit(0) + From 963a289cb493d9442a832ee28921258d0b519917 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:38:47 -0700 Subject: [PATCH 020/194] We don't use the JavaConverters in this test suite. 
--- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 826b6eb7e1f1f..3bb64047f25a8 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import scala.collection.JavaConverters._ - import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} From 1cc14365cb988e6bea67bae30617f54e88787b5c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:41:12 -0700 Subject: [PATCH 021/194] 1.to(10) is scala code, use range since we're in Python. --- .../kubernetes/integration-tests/tests/decomissioning_water.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py index 486a6d26b855d..b250f3af64e6c 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py +++ b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py @@ -31,7 +31,7 @@ .appName("PyMemoryTest") \ .getOrCreate() sc = spark.SparkContext - rdd = sc.parallelize(1.to(10)) + rdd = sc.parallelize(range(10)) rdd.collect() time.sleep(15) sys.exit(0) From 9036b4474162e63f4fa6042c1244dd6e24a794c9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:42:12 -0700 Subject: [PATCH 022/194] Fix style issue with blank line at end of file. 
--- .../kubernetes/integration-tests/tests/decomissioning_water.py | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py index b250f3af64e6c..75535e4737c9b 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py +++ b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py @@ -35,4 +35,3 @@ rdd.collect() time.sleep(15) sys.exit(0) - From d58f2a6adc3176490da4cecf6547ac21ae1bbd0b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 12 Oct 2018 15:46:43 -0700 Subject: [PATCH 023/194] Remove unneeded appArgs --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 3bb64047f25a8..7789633ca495b 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -36,7 +36,6 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => appResource = PYSPARK_DECOMISSIONING, mainClass = "", expectedLogOnCompletion = Seq("Decommissioning worker"), - appArgs = Array("python"), driverPodChecker = doBasicDriverPyPodCheck, executorPodChecker = doBasicExecutorPyPodCheck, appLocator = appLocator, From 5ae1bd76ca0786bf0289b88a21e72777fa7aa62d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 16 Oct 2018 12:26:56 -0700 Subject: [PATCH 024/194] Add missing sys import --- .../kubernetes/integration-tests/tests/decomissioning_water.py | 1 + 1 file changed, 1 insertion(+) diff --git a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py index 75535e4737c9b..4352f099c1158 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py +++ b/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py @@ -17,6 +17,7 @@ from __future__ import print_function +import sys import time from pyspark.sql import SparkSession From be38dab943899f7adce1859d07d5f4efc4908108 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 26 Oct 2018 09:37:39 -0700 Subject: [PATCH 025/194] Add back appArgs since despite Ilan's comments during the stream they aren't optional for this code path :p --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 7789633ca495b..ea83e211f4452 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -36,6 +36,7 @@ private[spark] trait DecommissionSuite { k8sSuite: 
KubernetesSuite => appResource = PYSPARK_DECOMISSIONING, mainClass = "", expectedLogOnCompletion = Seq("Decommissioning worker"), + appArgs = Array.empty[String], driverPodChecker = doBasicDriverPyPodCheck, executorPodChecker = doBasicExecutorPyPodCheck, appLocator = appLocator, From bbeceb9cba664869e690b36438d720507a241814 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 31 Oct 2018 10:09:04 -0700 Subject: [PATCH 026/194] Extend DecommissionSuite so the tests are triggered. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index dd44c3563e10c..b9557d6ebe2a7 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite + with PythonTestsSuite with ClientModeTestsSuite with DecommissionSuite with Logging with Eventually with Matchers { import KubernetesSuite._ From e5fb644a47b62b89a3a2db08a9c91f2f57a35b71 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 31 Oct 2018 10:51:45 -0700 Subject: [PATCH 027/194] Check all containers --- .../k8s/integrationtest/KubernetesSuite.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index b9557d6ebe2a7..1b57893d61d25 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.api.model.{ContainerStateRunning, Pod} import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} @@ -243,10 +243,21 @@ private[spark] class KubernetesSuite extends SparkFunSuite action match { case Action.ADDED | Action.MODIFIED => execPods(name) = resource - // If testing decomissioning delete the node 10 seconds after + // If testing decomissioning delete the node 5 seconds after it starts running if (decomissioningTest) { - Thread.sleep(1000) - kubernetesTestComponents.kubernetesClient.pods().withName(name).delete() + // Wait for all the containers in the pod to be running + Eventually.eventually(TIMEOUT, INTERVAL) { + val containerStatuses = p.getStatus.getContainerStatuses.asScala + val runningContainers = containerStatuses.filter(_ == 
ContainerStateRunning) + val nonRunningContainers = containerStatuses.filter(_ != ContainerStateRunning) + + runningContainers > 0 && nonRunningContainers == 0 + } + // Sleep a small interval to ensure everything is registered. + Thread.sleep(500) + // Delete the pod to simulate cluster scale down/migration. + val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) + pod.delete() } case Action.DELETED | Action.ERROR => execPods.remove(name) From 164fa2a68515bd54af0aeea516acae128f425a4a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 31 Oct 2018 11:01:34 -0700 Subject: [PATCH 028/194] Wait for the pod to become ready for before we kill it. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 1b57893d61d25..7023d9da87ef3 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.{ContainerStateRunning, Pod} +import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} @@ -247,11 +247,9 @@ private[spark] class KubernetesSuite extends SparkFunSuite if (decomissioningTest) { // Wait for all the containers in the pod to be running Eventually.eventually(TIMEOUT, INTERVAL) { - val containerStatuses = p.getStatus.getContainerStatuses.asScala - val runningContainers = containerStatuses.filter(_ == ContainerStateRunning) - val nonRunningContainers = containerStatuses.filter(_ != ContainerStateRunning) - - runningContainers > 0 && nonRunningContainers == 0 + resource.getStatus.getConditions().asScala + .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") + .headOption.getOrElse(false) } // Sleep a small interval to ensure everything is registered. 
Thread.sleep(500) From ca448d13c523e4658720ed3bf9b7cfa9f03ec260 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 31 Oct 2018 15:02:50 -0700 Subject: [PATCH 029/194] import the test tag idk why it won't run --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index ea83e211f4452..a747398cb5c91 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import DecommissionSuite._ + import KubernetesSuite.k8sTestTag private val pySparkDockerImage = s"${getTestImageRepo}/spark-py:${getTestImageTag}" From 8d504b23f95722be9eb53aeef84ee71d44a6013e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 7 Nov 2018 18:06:03 -0800 Subject: [PATCH 030/194] Remove import --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index a747398cb5c91..265980b5b6134 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => From 7b0023a85e787bdbfe4a0b3ed7fe3a9e6058eaf9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Dec 2018 11:21:46 -0800 Subject: [PATCH 031/194] Maybe we don't need to explicitly set the docker image since the Python tests don't either --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 265980b5b6134..24858ab280ceb 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -23,13 +23,9 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import DecommissionSuite._ import KubernetesSuite.k8sTestTag - private val pySparkDockerImage = - 
s"${getTestImageRepo}/spark-py:${getTestImageTag}" - test("Run SparkPi with env and mount secrets.", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") - .set("spark.kubernetes.container.image", pySparkDockerImage) .set("spark.kubernetes.pyspark.pythonVersion", "2") runSparkApplicationAndVerifyCompletion( From cca9948e8f105eef2f961d99a0a458445d880ee2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Dec 2018 11:37:42 -0800 Subject: [PATCH 032/194] We don't use () anymore on the properties in kubeconf --- .../spark/deploy/k8s/features/BasicExecutorFeatureStep.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index c80454bff9a35..bba8b6ccb7f92 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -157,7 +157,7 @@ private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutor .endResources() .build() }.getOrElse(executorContainer) - val containerWithLifecycle = kubernetesConf.workerDecomissioning() match { + val containerWithLifecycle = kubernetesConf.workerDecomissioning match { case true => new ContainerBuilder(executorContainer).editOrNewLifecycle() .withNewPreStop() From 1bbb69b95a4d87719a9edd8dfe831d366a16c00c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Dec 2018 12:04:02 -0800 Subject: [PATCH 033/194] Remove unused imports --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 24858ab280ceb..b20cbcc160258 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} - private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import DecommissionSuite._ From af048f5753cd99b68d2e5f8d268c52a119a2d84a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Dec 2018 16:55:11 -0800 Subject: [PATCH 034/194] Maybe the class loading issue is from two traits with the same test name. 
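For context on the hypothesis above: ScalaTest registers tests by name when a suite is constructed, so two mixed-in traits that each declare a test under the identical name cause the combined suite to throw DuplicateTestNameException before any test runs, which can surface as an apparent class-loading failure. A minimal, illustrative sketch of that collision follows -- the trait and class names are made up for illustration and the plain org.scalatest.FunSuite import assumes the ScalaTest line in use at the time; it is not code from the Spark test tree:

import org.scalatest.FunSuite

// Stand-in for the secrets test trait: registers a test by name at construction time.
trait SecretsLikeTests { self: FunSuite =>
  test("Run SparkPi with env and mount secrets.") { assert(true) }
}

// Stand-in for the decommissioning test trait: registers a test under the very same name.
trait DecommissionLikeTests { self: FunSuite =>
  // The second registration of an identical name makes constructing the combined suite
  // throw org.scalatest.exceptions.DuplicateTestNameException, so none of its tests run --
  // hence the rename of the decommissioning test in the diff below.
  test("Run SparkPi with env and mount secrets.") { assert(true) }
}

// Mixing both traits in reproduces the failure; renaming either test resolves it.
class CombinedSuite extends FunSuite with SecretsLikeTests with DecommissionLikeTests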
--- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index b20cbcc160258..7a12c4775511d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -21,7 +21,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import DecommissionSuite._ import KubernetesSuite.k8sTestTag - test("Run SparkPi with env and mount secrets.", k8sTestTag) { + test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") .set("spark.kubernetes.pyspark.pythonVersion", "2") From 43abf98dc1eff6bcc3d5c9ba8700a74c7922de2c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Dec 2018 17:35:52 -0800 Subject: [PATCH 035/194] Configure the image for Python. --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 7a12c4775511d..1319d9163df98 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -24,7 +24,8 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") - .set("spark.kubernetes.pyspark.pythonVersion", "2") + .set("spark.kubernetes.pyspark.pythonVersion", "3") + .set("spark.kubernetes.container.image", pyImage) runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_DECOMISSIONING, From abb060949797cec31ec02080c5720ccec9df3f5c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Dec 2018 18:08:53 -0800 Subject: [PATCH 036/194] Mispelled decommissioning in the python test file. 
--- .../tests/{decomissioning_water.py => decommissioning_water.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename resource-managers/kubernetes/integration-tests/tests/{decomissioning_water.py => decommissioning_water.py} (100%) diff --git a/resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning_water.py similarity index 100% rename from resource-managers/kubernetes/integration-tests/tests/decomissioning_water.py rename to resource-managers/kubernetes/integration-tests/tests/decommissioning_water.py From 9d4fc238d2d4e0984d4880d42c4cf6a1f1a52f8b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Dec 2018 18:49:24 -0800 Subject: [PATCH 037/194] Change ref --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 +- .../tests/{decommissioning_water.py => decommissioning.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename resource-managers/kubernetes/integration-tests/tests/{decommissioning_water.py => decommissioning.py} (100%) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 1319d9163df98..792d8bb0ec639 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -42,5 +42,5 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => private[spark] object DecommissionSuite { val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" - val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decomissioning_waiter.py" + val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decommissioning.py" } diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning_water.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py similarity index 100% rename from resource-managers/kubernetes/integration-tests/tests/decommissioning_water.py rename to resource-managers/kubernetes/integration-tests/tests/decommissioning.py From 7f3fd5fa424751386207ee4dba20c44d1334e600 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 25 Dec 2018 09:41:08 -0800 Subject: [PATCH 038/194] The spark context on the session object is stored as _sc --- .../kubernetes/integration-tests/tests/decommissioning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py index 4352f099c1158..a3ba08b979624 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py +++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py @@ -31,7 +31,7 @@ .builder \ .appName("PyMemoryTest") \ .getOrCreate() - sc = spark.SparkContext + sc = spark._sc rdd = sc.parallelize(range(10)) rdd.collect() time.sleep(15) From 8306827f31f77a87d89d8324235c3e643abce8f1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Dec 2018 18:26:48 -0800 Subject: [PATCH 039/194] Speed up running the kubernetes integration tests locally by allowing folks to skip the tgz dist build and extraction --- .../scripts/setup-integration-test-env.sh 
| 12 ++++++------ .../deploy/k8s/integrationtest/KubernetesSuite.scala | 12 ++++++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index 36e30d7b2cffb..e70101bbc08cf 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -58,15 +58,15 @@ while (( "$#" )); do shift done -if [[ $SPARK_TGZ == "N/A" ]]; +rm -rf $UNPACKED_SPARK_TGZ +if [[ $SPARK_TGZ == "N/A" && $IMAGE_TAG == "N/A" ]]; then - echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz OR image with --image-tag." && exit 1; +else + mkdir -p $UNPACKED_SPARK_TGZ + tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; fi -rm -rf $UNPACKED_SPARK_TGZ -mkdir -p $UNPACKED_SPARK_TGZ -tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; - if [[ $IMAGE_TAG == "N/A" ]]; then IMAGE_TAG=$(uuidgen); diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 2638233d07a3b..572a27eb23e70 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -103,8 +103,16 @@ class KubernetesSuite extends SparkFunSuite System.clearProperty(key) } - val sparkDirProp = System.getProperty(CONFIG_KEY_UNPACK_DIR) - require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + val possible_spark_dirs = List( + // If someone specified the tgz for the tests look at the extraction dir + System.getProperty(CONFIG_KEY_UNPACK_DIR), + // If otherwise use my working dir + 3 up + new File(Paths.get(System.getProperty("user.dir")).toFile, ("../" * 3)).getAbsolutePath() + ) + val sparkDirProp = possible_spark_dirs.filter(x => + new File(Paths.get(x).toFile, "bin/spark-submit").exists).headOption.getOrElse(null) + require(sparkDirProp != null, + s"Spark home directory must be provided in system properties tested $possible_spark_dirs") sparkHomeDir = Paths.get(sparkDirProp) require(sparkHomeDir.toFile.isDirectory, s"No directory found for spark home specified at $sparkHomeDir.") From 154c8b94128c4ff71f5e003a11f5cd1d66faa965 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 25 Dec 2018 10:23:32 -0800 Subject: [PATCH 040/194] Log exec decom for test --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7b17c644a3fc4..9ddae2d42f98b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -241,6 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: 
TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case DecomissionExecutor(executorId) => + logInfo(s"Decommissioning executor ${executorId}") decommissionExecutor(executorId) context.reply(true) From 3020ef8f030d584282f342831f808ad9e1031a70 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 25 Dec 2018 10:23:45 -0800 Subject: [PATCH 041/194] Fix log msg check --- .../spark/deploy/k8s/integrationtest/DecommissionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 792d8bb0ec639..897a85c97ea45 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -30,7 +30,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_DECOMISSIONING, mainClass = "", - expectedLogOnCompletion = Seq("Decommissioning worker"), + expectedLogOnCompletion = Seq("Decommissioning executor"), appArgs = Array.empty[String], driverPodChecker = doBasicDriverPyPodCheck, executorPodChecker = doBasicExecutorPyPodCheck, From 27b4edd6e83fba62f117bc8cf0fb47dc78b35006 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 25 Dec 2018 10:24:03 -0800 Subject: [PATCH 042/194] 30 seconds why not --- .../kubernetes/integration-tests/tests/decommissioning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py index a3ba08b979624..d025ff6e4455d 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py +++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py @@ -34,5 +34,5 @@ sc = spark._sc rdd = sc.parallelize(range(10)) rdd.collect() - time.sleep(15) + time.sleep(30) sys.exit(0) From 00310f93d3afeb30f7e570a6e942859292484807 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 25 Dec 2018 10:24:25 -0800 Subject: [PATCH 043/194] Some temporary printlns for debugging in Jenkins --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 572a27eb23e70..e7204f9c09e88 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -280,6 +280,7 @@ class KubernetesSuite extends SparkFunSuite // If testing decomissioning delete the node 5 seconds after it starts running if (decomissioningTest) { // Wait for all the containers in the pod to be running + println("Waiting for pod to become OK then delete.") Eventually.eventually(TIMEOUT, INTERVAL) { resource.getStatus.getConditions().asScala .map(cond => cond.getStatus() 
== "True" && cond.getType() == "Ready") @@ -290,6 +291,7 @@ class KubernetesSuite extends SparkFunSuite // Delete the pod to simulate cluster scale down/migration. val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() + println(s"Pod: $name deleted") } case Action.DELETED | Action.ERROR => execPods.remove(name) @@ -301,6 +303,7 @@ class KubernetesSuite extends SparkFunSuite execPods.values.nonEmpty || decomissioningTest should be (true) } execWatcher.close() execPods.values.foreach(executorPodChecker(_)) + println(s"Exec pods are $execPods") Eventually.eventually(TIMEOUT, INTERVAL) { expectedLogOnCompletion.foreach { e => assert(kubernetesTestComponents.kubernetesClient From c3c0e3a013692caabea0118b6ef1e141601853d6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Dec 2018 12:25:54 -0800 Subject: [PATCH 044/194] Just run the decom suite. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index e7204f9c09e88..8f6251677ad05 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite - with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with DecommissionSuite + with BeforeAndAfterAll with BeforeAndAfter + //with BasicTestsSuite with SecretsTestsSuite + // with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite + with DecommissionSuite with Logging with Eventually with Matchers { import KubernetesSuite._ From 0bf027a16d3da55f9de33b04951a9b481609b6fc Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Dec 2018 12:53:51 -0800 Subject: [PATCH 045/194] Try and debug the tests some more. 
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 4 ++++ .../kubernetes/integration-tests/tests/decommissioning.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 8f6251677ad05..b511cb6e2a115 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -257,6 +257,7 @@ class KubernetesSuite extends SparkFunSuite isJVM, pyFiles) + println("Running spark job.") val driverPod = kubernetesTestComponents.kubernetesClient .pods() .withLabel("spark-app-locator", appLocator) @@ -275,9 +276,11 @@ class KubernetesSuite extends SparkFunSuite override def onClose(cause: KubernetesClientException): Unit = logInfo("Ending watch of executors") override def eventReceived(action: Watcher.Action, resource: Pod): Unit = { + println("Event received.") val name = resource.getMetadata.getName action match { case Action.ADDED | Action.MODIFIED => + println("Add or modification event received.") execPods(name) = resource // If testing decomissioning delete the node 5 seconds after it starts running if (decomissioningTest) { @@ -296,6 +299,7 @@ class KubernetesSuite extends SparkFunSuite println(s"Pod: $name deleted") } case Action.DELETED | Action.ERROR => + println("Deleted or error event received.") execPods.remove(name) } } diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py index d025ff6e4455d..a99e76d2ec2b4 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py +++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py @@ -27,6 +27,7 @@ """ Usage: decomissioning_water """ + print("Starting decom test") spark = SparkSession \ .builder \ .appName("PyMemoryTest") \ @@ -34,5 +35,7 @@ sc = spark._sc rdd = sc.parallelize(range(10)) rdd.collect() - time.sleep(30) + print("Waiting to give nodes time to finish.") + time.sleep(50) + spark.stop() sys.exit(0) From be18e5202b8097704e7480082c354fd2f9a29592 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Dec 2018 12:54:33 -0800 Subject: [PATCH 046/194] Re-enable basic test suite. 
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index b511cb6e2a115..2fe12d56d057f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -40,7 +40,7 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - //with BasicTestsSuite with SecretsTestsSuite + with BasicTestsSuite //with SecretsTestsSuite // with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with DecommissionSuite with Logging with Eventually with Matchers { From 2914581143c236cb97b3e9c55651b71cd7ffd1b2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Dec 2018 15:54:48 -0800 Subject: [PATCH 047/194] More debugging --- .../k8s/integrationtest/KubernetesSuite.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 2fe12d56d057f..e0f9fd4e91ade 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -40,7 +40,7 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - with BasicTestsSuite //with SecretsTestsSuite + // with BasicTestsSuite with SecretsTestsSuite // with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with DecommissionSuite with Logging with Eventually with Matchers { @@ -265,8 +265,11 @@ class KubernetesSuite extends SparkFunSuite .list() .getItems .get(0) + println("Doing driver pod check") driverPodChecker(driverPod) + println("Done driver pod check") val execPods = scala.collection.mutable.Map[String, Pod]() + println("Creating watched...") val execWatcher = kubernetesTestComponents.kubernetesClient .pods() .withLabel("spark-app-locator", appLocator) @@ -280,10 +283,10 @@ class KubernetesSuite extends SparkFunSuite val name = resource.getMetadata.getName action match { case Action.ADDED | Action.MODIFIED => - println("Add or modification event received.") + println(s"Add or modification event received for $name.") execPods(name) = resource // If testing decomissioning delete the node 5 seconds after it starts running - if (decomissioningTest) { + if (decomissioningTest && false) { // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(TIMEOUT, INTERVAL) { @@ -297,19 +300,22 @@ class KubernetesSuite extends SparkFunSuite val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() println(s"Pod: $name deleted") + } else { + println(s"Resource $name added") 
} case Action.DELETED | Action.ERROR => println("Deleted or error event received.") execPods.remove(name) + println("Resrouce $name removed") } } }) - // If we're testing decomissioning we delete all the executors - Eventually.eventually(TIMEOUT, INTERVAL) { - execPods.values.nonEmpty || decomissioningTest should be (true) } + // If we're testing decomissioning we delete all the executors, but we should have + // an executor at some point. + Eventually.eventually(TIMEOUT, INTERVAL) { execPods.values.nonEmpty } execWatcher.close() execPods.values.foreach(executorPodChecker(_)) - println(s"Exec pods are $execPods") + println(s"Close to the end exec pods are $execPods") Eventually.eventually(TIMEOUT, INTERVAL) { expectedLogOnCompletion.foreach { e => assert(kubernetesTestComponents.kubernetesClient @@ -319,6 +325,7 @@ class KubernetesSuite extends SparkFunSuite .contains(e), "The application did not complete.") } } + println(s"end exec pods are $execPods") } protected def doBasicDriverPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) From 953094a5882b62bdf70e31ba7426cb6756a509e5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 28 Dec 2018 18:49:24 -0800 Subject: [PATCH 048/194] Hey did we not run the Python tests? --- .../integration-tests/dev/dev-run-integration-tests.sh | 3 ++- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 68f284ca1d1ce..67b41300a139f 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -80,7 +80,8 @@ while (( "$#" )); do shift ;; *) - break + echo "Unexpected propert $2 $1 breaking parsing" + exit 1 ;; esac shift diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index e0f9fd4e91ade..dbc38d305a6da 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -40,7 +40,8 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - // with BasicTestsSuite with SecretsTestsSuite + with BasicTestsSuite // with SecretsTestsSuite + with PythonTestsSuite // with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with DecommissionSuite with Logging with Eventually with Matchers { From 5d173bd88aad8ee85d1c048ca814ed1c7f3353d9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 01:19:27 -0800 Subject: [PATCH 049/194] Python tests aren't registering the executors, lets avoid that noise and just do SparkPI since we don't really need any special logic in the driver. 
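For comparison, the retired Python job does nothing decommission-specific: it starts a session, runs one small job so the executors register, idles long enough for the pod deletion to land, and stops. A rough Scala equivalent of decommissioning.py (the object name and sleep length here are arbitrary):

```scala
import org.apache.spark.sql.SparkSession

object DecomWorkloadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("DecomWorkloadSketch").getOrCreate()
    val sc = spark.sparkContext
    // One trivial job so the executors register and run something.
    sc.parallelize(0 until 10).collect()
    // Stay alive long enough for the harness to delete an executor pod.
    Thread.sleep(30 * 1000)
    spark.stop()
  }
}
```

Any workload that keeps executors alive for a while works equally well, which is why SparkPi with a large argument is a reasonable stand-in.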
--- .../k8s/integrationtest/DecommissionSuite.scala | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 897a85c97ea45..72b2fda3a062f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -24,23 +24,11 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") - .set("spark.kubernetes.pyspark.pythonVersion", "3") - .set("spark.kubernetes.container.image", pyImage) runSparkApplicationAndVerifyCompletion( - appResource = PYSPARK_DECOMISSIONING, - mainClass = "", + SPARK_PI_MAIN_CLASS, expectedLogOnCompletion = Seq("Decommissioning executor"), - appArgs = Array.empty[String], - driverPodChecker = doBasicDriverPyPodCheck, - executorPodChecker = doBasicExecutorPyPodCheck, - appLocator = appLocator, - isJVM = false, + appArgs = Array("100"), // Give it some time to run decomissioningTest = true) } } - -private[spark] object DecommissionSuite { - val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" - val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decommissioning.py" -} From 705fd583441dd4771ddaea1ea686f277dd4eeaef Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 01:27:24 -0800 Subject: [PATCH 050/194] Fix using SparkPI for decom test. --- .../deploy/k8s/integrationtest/DecommissionSuite.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 72b2fda3a062f..681ca13f83508 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -18,17 +18,22 @@ package org.apache.spark.deploy.k8s.integrationtest private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => - import DecommissionSuite._ import KubernetesSuite.k8sTestTag + import KubernetesSuite.SPARK_PI_MAIN_CLASS test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") runSparkApplicationAndVerifyCompletion( - SPARK_PI_MAIN_CLASS, + appResource = containerLocalSparkDistroExamplesJar, + mainClass = SPARK_PI_MAIN_CLASS, expectedLogOnCompletion = Seq("Decommissioning executor"), appArgs = Array("100"), // Give it some time to run + driverPodChecker = doBasicDriverPodCheck, + executorPodChecker = doBasicExecutorPodCheck, + appLocator = appLocator, + isJVM = true, decomissioningTest = true) } } From ca60dbf817325f388a7a6dee84171290cf60e34c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 01:50:13 -0800 Subject: [PATCH 051/194] Enable all the tests... 
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index dbc38d305a6da..993775d88aee8 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -40,9 +40,8 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - with BasicTestsSuite // with SecretsTestsSuite - with PythonTestsSuite - // with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite + with BasicTestsSuite with SecretsTestsSuite + with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with DecommissionSuite with Logging with Eventually with Matchers { From 044f8c5493efda7d5d15c966e3028509173153bf Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 01:59:32 -0800 Subject: [PATCH 052/194] more ... --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 993775d88aee8..cb24fa5d2836e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -269,7 +269,7 @@ class KubernetesSuite extends SparkFunSuite driverPodChecker(driverPod) println("Done driver pod check") val execPods = scala.collection.mutable.Map[String, Pod]() - println("Creating watched...") + println("Creating watcher...") val execWatcher = kubernetesTestComponents.kubernetesClient .pods() .withLabel("spark-app-locator", appLocator) From c09867b51e27c3ef324f246e8fc605d60a3ad273 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 12 Dec 2018 10:03:50 +0800 Subject: [PATCH 053/194] [SPARK-26193][SQL][FOLLOW UP] Read metrics rename and display text changes ## What changes were proposed in this pull request? Follow up pr for #23207, include following changes: - Rename `SQLShuffleMetricsReporter` to `SQLShuffleReadMetricsReporter` to make it match with write side naming. - Display text changes for read side for naming consistent. - Rename function in `ShuffleWriteProcessor`. - Delete `private[spark]` in execution package. ## How was this patch tested? Existing tests. Closes #23286 from xuanyuanking/SPARK-26193-follow. 
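A quick way to see the renamed read-side display text is to look at the metrics on a shuffle exchange node. A small spark-shell style sketch, assuming a local session and the default (non-adaptive) planner; the metric keys come from SQLShuffleReadMetricsReporter below:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

val spark = SparkSession.builder().master("local[2]").appName("shuffle-metrics").getOrCreate()
val df = spark.range(0, 100).repartition(4)
df.collect()

// Locate the exchange in the executed plan and print the renamed display names.
val exchange = df.queryExecution.executedPlan.collectFirst {
  case e: ShuffleExchangeExec => e
}.get
println(exchange.metrics("localBlocksFetched").name)   // Some(local blocks read)
println(exchange.metrics("remoteBlocksFetched").name)  // Some(remote blocks read)
```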
Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../spark/scheduler/ShuffleMapTask.scala | 2 +- .../spark/shuffle/ShuffleWriteProcessor.scala | 2 +- .../spark/sql/execution/ShuffledRowRDD.scala | 6 ++-- .../exchange/ShuffleExchangeExec.scala | 4 +-- .../apache/spark/sql/execution/limit.scala | 6 ++-- .../metric/SQLShuffleMetricsReporter.scala | 36 +++++++++---------- .../execution/UnsafeRowSerializerSuite.scala | 4 +-- .../execution/metric/SQLMetricsSuite.scala | 20 +++++------ 8 files changed, 40 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2a8d1dd995e27..35664ff515d4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -92,7 +92,7 @@ private[spark] class ShuffleMapTask( threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L - dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition) + dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition) } override def preferredLocations: Seq[TaskLocation] = preferredLocs diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index f5213157a9a85..5b0c7e9f2b0b4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -41,7 +41,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for * this task. */ - def writeProcess( + def write( rdd: RDD[_], dep: ShuffleDependency[_, _, _], partitionId: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 9b05faaed0459..079ff25fcb67e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -22,7 +22,7 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -157,9 +157,9 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() - // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator, + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. - val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics) + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. 
val reader = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0c2020572e721..da7b0c6f43fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair @@ -50,7 +50,7 @@ case class ShuffleExchangeExec( private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") ) ++ readMetrics ++ writeMetrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1f2fdde538645..bfaf080292bce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.metric.{SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} /** * Take the first `limit` elements and collect them to a single partition. 
@@ -41,7 +41,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) @@ -165,7 +165,7 @@ case class TakeOrderedAndProjectExec( private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala index ff7941e3b3e8d..2c0ea80495abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -27,23 +27,23 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter * @param metrics All metrics in current SparkPlan. This param should not empty and * contains all shuffle metrics defined in createShuffleReadMetrics. */ -private[spark] class SQLShuffleMetricsReporter( +class SQLShuffleReadMetricsReporter( tempMetrics: TempShuffleReadMetrics, metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { private[this] val _remoteBlocksFetched = - metrics(SQLShuffleMetricsReporter.REMOTE_BLOCKS_FETCHED) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BLOCKS_FETCHED) private[this] val _localBlocksFetched = - metrics(SQLShuffleMetricsReporter.LOCAL_BLOCKS_FETCHED) + metrics(SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED) private[this] val _remoteBytesRead = - metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ) private[this] val _remoteBytesReadToDisk = - metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ_TO_DISK) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ_TO_DISK) private[this] val _localBytesRead = - metrics(SQLShuffleMetricsReporter.LOCAL_BYTES_READ) + metrics(SQLShuffleReadMetricsReporter.LOCAL_BYTES_READ) private[this] val _fetchWaitTime = - metrics(SQLShuffleMetricsReporter.FETCH_WAIT_TIME) + metrics(SQLShuffleReadMetricsReporter.FETCH_WAIT_TIME) private[this] val _recordsRead = - metrics(SQLShuffleMetricsReporter.RECORDS_READ) + metrics(SQLShuffleReadMetricsReporter.RECORDS_READ) override def incRemoteBlocksFetched(v: Long): Unit = { _remoteBlocksFetched.add(v) @@ -75,7 +75,7 @@ private[spark] class SQLShuffleMetricsReporter( } } -private[spark] object SQLShuffleMetricsReporter { +object SQLShuffleReadMetricsReporter { val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" val REMOTE_BYTES_READ = "remoteBytesRead" @@ -88,8 +88,8 @@ private[spark] object SQLShuffleMetricsReporter { * Create all shuffle read relative metrics and return the Map. 
*/ def createShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( - REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks fetched"), - LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks fetched"), + REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks read"), + LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks read"), REMOTE_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "remote bytes read"), REMOTE_BYTES_READ_TO_DISK -> SQLMetrics.createSizeMetric(sc, "remote bytes read to disk"), LOCAL_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "local bytes read"), @@ -102,7 +102,7 @@ private[spark] object SQLShuffleMetricsReporter { * @param metricsReporter Other reporter need to be updated in this SQLShuffleWriteMetricsReporter. * @param metrics Shuffle write metrics in current SparkPlan. */ -private[spark] class SQLShuffleWriteMetricsReporter( +class SQLShuffleWriteMetricsReporter( metricsReporter: ShuffleWriteMetricsReporter, metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter { private[this] val _bytesWritten = @@ -112,29 +112,29 @@ private[spark] class SQLShuffleWriteMetricsReporter( private[this] val _writeTime = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME) - override private[spark] def incBytesWritten(v: Long): Unit = { + override def incBytesWritten(v: Long): Unit = { metricsReporter.incBytesWritten(v) _bytesWritten.add(v) } - override private[spark] def decRecordsWritten(v: Long): Unit = { + override def decRecordsWritten(v: Long): Unit = { metricsReporter.decBytesWritten(v) _recordsWritten.set(_recordsWritten.value - v) } - override private[spark] def incRecordsWritten(v: Long): Unit = { + override def incRecordsWritten(v: Long): Unit = { metricsReporter.incRecordsWritten(v) _recordsWritten.add(v) } - override private[spark] def incWriteTime(v: Long): Unit = { + override def incWriteTime(v: Long): Unit = { metricsReporter.incWriteTime(v) _writeTime.add(v) } - override private[spark] def decBytesWritten(v: Long): Unit = { + override def decBytesWritten(v: Long): Unit = { metricsReporter.decBytesWritten(v) _bytesWritten.set(_bytesWritten.value - v) } } -private[spark] object SQLShuffleWriteMetricsReporter { +object SQLShuffleWriteMetricsReporter { val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten" val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten" val SHUFFLE_WRITE_TIME = "shuffleWriteTime" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1ad5713ab8ae6..ca8692290edb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter +import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -140,7 +140,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { new UnsafeRowSerializer(2)) val shuffled = new 
ShuffledRowRDD( dependency, - SQLShuffleMetricsReporter.createShuffleReadMetrics(spark.sparkContext)) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index f6495496a58e1..47265df4831df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -96,8 +96,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) val shuffleExpected1 = Map( "records read" -> 2L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 2L, + "remote blocks read" -> 0L, "shuffle records written" -> 2L) testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), @@ -114,8 +114,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) val shuffleExpected2 = Map( "records read" -> 4L, - "local blocks fetched" -> 4L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 4L, + "remote blocks read" -> 0L, "shuffle records written" -> 4L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), @@ -175,8 +175,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared 1L -> (("Exchange", Map( "shuffle records written" -> 2L, "records read" -> 2L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L))), + "local blocks read" -> 2L, + "remote blocks read" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 1L)))) ) @@ -187,8 +187,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared 1L -> (("Exchange", Map( "shuffle records written" -> 4L, "records read" -> 4L, - "local blocks fetched" -> 4L, - "remote blocks fetched" -> 0L))), + "local blocks read" -> 4L, + "remote blocks read" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 3L)))) ) } @@ -216,8 +216,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "number of output rows" -> 4L))), 2L -> (("Exchange", Map( "records read" -> 4L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 2L, + "remote blocks read" -> 0L, "shuffle records written" -> 2L)))) ) } From afe463c6584c23bb12e89315278b45ee456d6641 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 12 Dec 2018 09:03:13 -0600 Subject: [PATCH 054/194] [SPARK-19827][R][FOLLOWUP] spark.ml R API for PIC ## What changes were proposed in this pull request? Follow up style fixes to PIC in R; see #23072 ## How was this patch tested? Existing tests. Closes #23292 from srowen/SPARK-19827.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- R/pkg/R/mllib_clustering.R | 15 ++++++--------- R/pkg/R/mllib_fpm.R | 4 ++-- examples/src/main/r/ml/powerIterationClustering.R | 3 ++- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 7d9dcebfe70d3..9b32b71d34fef 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -621,11 +621,10 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' #' A scalable graph clustering algorithm. 
Users can call \code{spark.assignClusters} to #' return a cluster assignment for each input vertex. -#' -# Run the PIC algorithm and returns a cluster assignment for each input vertex. +#' Run the PIC algorithm and returns a cluster assignment for each input vertex. #' @param data a SparkDataFrame. #' @param k the number of clusters to create. -#' @param initMode the initialization algorithm. +#' @param initMode the initialization algorithm; "random" or "degree" #' @param maxIter the maximum number of iterations. #' @param sourceCol the name of the input column for source vertex IDs. #' @param destinationCol the name of the input column for destination vertex IDs @@ -633,18 +632,16 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' we treat all instance weights as 1.0. #' @param ... additional argument(s) passed to the method. #' @return A dataset that contains columns of vertex id and the corresponding cluster for the id. -#' The schema of it will be: -#' \code{id: Long} -#' \code{cluster: Int} +#' The schema of it will be: \code{id: integer}, \code{cluster: integer} #' @rdname spark.powerIterationClustering -#' @aliases assignClusters,PowerIterationClustering-method,SparkDataFrame-method +#' @aliases spark.assignClusters,SparkDataFrame-method #' @examples #' \dontrun{ #' df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), #' list(1L, 2L, 1.0), list(3L, 4L, 1.0), #' list(4L, 0L, 0.1)), #' schema = c("src", "dst", "weight")) -#' clusters <- spark.assignClusters(df, initMode="degree", weightCol="weight") +#' clusters <- spark.assignClusters(df, initMode = "degree", weightCol = "weight") #' showDF(clusters) #' } #' @note spark.assignClusters(SparkDataFrame) since 3.0.0 @@ -652,7 +649,7 @@ setMethod("spark.assignClusters", signature(data = "SparkDataFrame"), function(data, k = 2L, initMode = c("random", "degree"), maxIter = 20L, sourceCol = "src", destinationCol = "dst", weightCol = NULL) { - if (!is.numeric(k) || k < 1) { + if (!is.integer(k) || k < 1) { stop("k should be a number with value >= 1.") } if (!is.integer(maxIter) || maxIter <= 0) { diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index c248e9ec9be94..0cc7a16c302dc 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -183,8 +183,8 @@ setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), #' @return A complete set of frequent sequential patterns in the input sequences of itemsets. #' The returned \code{SparkDataFrame} contains columns of sequence and corresponding #' frequency. 
The schema of it will be: -#' \code{sequence: ArrayType(ArrayType(T))} (T is the item type) -#' \code{freq: Long} +#' \code{sequence: ArrayType(ArrayType(T))}, \code{freq: integer} +#' where T is the item type #' @rdname spark.prefixSpan #' @aliases findFrequentSequentialPatterns,PrefixSpan,SparkDataFrame-method #' @examples diff --git a/examples/src/main/r/ml/powerIterationClustering.R b/examples/src/main/r/ml/powerIterationClustering.R index ba43037106d14..3530d88e50509 100644 --- a/examples/src/main/r/ml/powerIterationClustering.R +++ b/examples/src/main/r/ml/powerIterationClustering.R @@ -30,7 +30,8 @@ df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), list(4L, 0L, 0.1)), schema = c("src", "dst", "weight")) # assign clusters -clusters <- spark.assignClusters(df, k=2L, maxIter=20L, initMode="degree", weightCol="weight") +clusters <- spark.assignClusters(df, k = 2L, maxIter = 20L, + initMode = "degree", weightCol = "weight") showDF(arrange(clusters, clusters$id)) # $example off$ From c76d70a4f3a85d41f8b6d4394662798ff4ebb81b Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 12 Dec 2018 10:06:41 -0600 Subject: [PATCH 055/194] [SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator ## What changes were proposed in this pull request? The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data. I've closed the PR: https://github.com/apache/spark/pull/16557 as recommended in favor of creating three pull requests, one for each of the evaluators (binary/regression/multiclass) to make it easier to review/update. The updates to the regression metrics were based on (and updated with new changes based on comments): https://issues.apache.org/jira/browse/SPARK-11520 ("RegressionMetrics should support instance weights") but the pull request was closed as the changes were never checked in. ## How was this patch tested? I added tests to the metrics class. Closes #17085 from imatiach-msft/ilmat/regression-evaluate. 
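As a usage sketch of the new weight support in RegressionEvaluator, here is a spark-shell style example using the same predictions, observations and weights as the hand-calculated test below; the column names and local master are illustrative, and when no weight column is set every row is treated as weight 1.0:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("weighted-regression-eval").getOrCreate()
import spark.implicits._

val scored = Seq(
  (3.0, 2.25, 0.10),
  (-0.5, -0.25, 0.20),
  (2.0, 1.75, 0.15),
  (7.0, 7.75, 0.05)
).toDF("label", "prediction", "weight")

val evaluator = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setWeightCol("weight")   // new in this change; omit it to keep the old unweighted behaviour
println(evaluator.evaluate(scored))
```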
Authored-by: Ilya Matiach Signed-off-by: Sean Owen --- .../ml/evaluation/RegressionEvaluator.scala | 19 ++++--- .../mllib/evaluation/RegressionMetrics.scala | 30 ++++++----- .../stat/MultivariateOnlineSummarizer.scala | 25 ++++++---- .../stat/MultivariateStatisticalSummary.scala | 6 +++ .../evaluation/RegressionMetricsSuite.scala | 50 +++++++++++++++++++ project/MimaExcludes.scala | 5 +- 6 files changed, 106 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 031cd0d635bf4..616569bb55e4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "rmse") @Since("2.0.0") @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = dataset - .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) + val predictionAndLabelsWithWeights = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) .rdd - .map { case Row(prediction: Double, label: Double) => (prediction, label) } - val metrics = new RegressionMetrics(predictionAndLabels) + .map { case Row(prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) } + val metrics = new RegressionMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "rmse" => metrics.rootMeanSquaredError case "mse" => metrics.meanSquaredError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 020676cac5a64..525047973ad5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) + * or (prediction, observation) pairs * @param throughOrigin True if the regression is through the origin. For example, in linear * regression, it will be true without fitting intercept. */ @Since("1.2.0") class RegressionMetrics @Since("2.0.0") ( - predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean) extends Logging { @Since("1.2.0") - def this(predictionAndObservations: RDD[(Double, Double)]) = + def this(predictionAndObservations: RDD[_ <: Product]) = this(predictionAndObservations, false) /** @@ -52,10 +53,13 @@ class RegressionMetrics @Since("2.0.0") ( * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { - val summary: MultivariateStatisticalSummary = predictionAndObservations.map { - case (prediction, observation) => Vectors.dense(observation, observation - prediction) + val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map { + case (prediction: Double, observation: Double, weight: Double) => + (Vectors.dense(observation, observation - prediction), weight) + case (prediction: Double, observation: Double) => + (Vectors.dense(observation, observation - prediction), 1.0) }.treeAggregate(new MultivariateOnlineSummarizer())( - (summary, v) => summary.add(v), + (summary, sample) => summary.add(sample._1, sample._2), (sum1, sum2) => sum1.merge(sum2) ) summary @@ -63,11 +67,13 @@ class RegressionMetrics @Since("2.0.0") ( private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) - private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SStot = summary.variance(0) * (summary.weightSum - 1) private lazy val SSreg = { val yMean = summary.mean(0) - predictionAndObservations.map { - case (prediction, _) => math.pow(prediction - yMean, 2) + predAndObsWithOptWeight.map { + case (prediction: Double, _: Double, weight: Double) => + math.pow(prediction - yMean, 2) * weight + case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2) }.sum() } @@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def explainedVariance: Double = { - SSreg / summary.count + SSreg / summary.weightSum } /** @@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanAbsoluteError: Double = { - summary.normL1(1) / summary.count + summary.normL1(1) / summary.weightSum } /** @@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanSquaredError: Double = { - SSerr / summary.count + SSerr / summary.weightSum } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0554b6d8ff5b5..6d510e1633d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S 
private var totalCnt: Long = 0 private var totalWeightSum: Double = 0.0 private var weightSquareSum: Double = 0.0 - private var weightSum: Array[Double] = _ + private var currWeightSum: Array[Double] = _ private var nnz: Array[Long] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n = Array.ofDim[Double](n) currM2 = Array.ofDim[Double](n) currL1 = Array.ofDim[Double](n) - weightSum = Array.ofDim[Double](n) + currWeightSum = Array.ofDim[Double](n) nnz = Array.ofDim[Long](n) currMax = Array.fill[Double](n)(Double.MinValue) currMin = Array.fill[Double](n)(Double.MaxValue) @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 - val localWeightSum = weightSum + val localWeightSum = currWeightSum val localNumNonzeros = nnz val localCurrMax = currMax val localCurrMin = currMin @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = weightSum(i) - val otherNnz = other.weightSum(i) + val thisNnz = currWeightSum(i) + val otherNnz = other.currWeightSum(i) val totalNnz = thisNnz + otherNnz val totalCnnz = nnz(i) + other.nnz(i) if (totalNnz != 0.0) { @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMax(i) = math.max(currMax(i), other.currMax(i)) currMin(i) = math.min(currMin(i), other.currMin(i)) } - weightSum(i) = totalNnz + currWeightSum(i) = totalNnz nnz(i) = totalCnnz i += 1 } @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.totalCnt = other.totalCnt this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum - this.weightSum = other.weightSum.clone() + this.currWeightSum = other.currWeightSum.clone() this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) + realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum) i += 1 } Vectors.dense(realMean) @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val len = currM2n.length while (i < len) { // We prevent variance from negative value caused by numerical error. - realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) + realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) * + (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S @Since("1.1.0") override def count: Long = totalCnt + /** + * Sum of weights. + */ + override def weightSum: Double = totalWeightSum + /** * Number of nonzero elements in each dimension. 
* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 39a16fb743d64..a4381032f8c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary { @Since("1.0.0") def count: Long + /** + * Sum of weights. + */ + @Since("3.0.0") + def weightSum: Double + /** * Number of nonzero elements (including explicitly presented zero values) in each column. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index f1d517383643d..23809777f7d3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { "root mean squared error mismatch") assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } + + test("regression metrics with same (1.0) weight samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false) + assert(metrics.explainedVariance ~== 8.79687 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch") + } + + /** + * The following values are hand calculated using the formula: + * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + * preds = c(2.25, -0.25, 1.75, 7.75) + * obs = c(3.0, -0.5, 2.0, 7.0) + * weights = c(0.1, 0.2, 0.15, 0.05) + * count = 4 + * + * Weighted metrics can be calculated with MultivariateStatisticalSummary. 
+ * (observations, observations - predictions) + * mean (1.7, 0.05) + * variance (7.3, 0.3) + * numNonZeros (0.5, 0.5) + * max (7.0, 0.75) + * min (-0.5, -0.75) + * normL2 (2.0, 0.32596) + * normL1 (1.05, 0.2) + * + * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425 + * meanAbsoluteError: normL1(1) / weightedCount = 0.4 + * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125 + * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098 + * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910 + */ + test("regression metrics with weighted samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false) + assert(metrics.explainedVariance ~== 5.2425 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b3252d70a80c8..883913332ca1e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -531,7 +531,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"), + + // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") ) ++ Seq( // [SPARK-17019] Expose on-heap and off-heap memory usage in various places ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), From d6396279d0c85bcb5121e958a655f895559eb821 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 12 Dec 2018 12:01:21 -0800 Subject: [PATCH 056/194] [SPARK-25877][K8S] Move all feature logic to feature classes. This change makes the driver and executor builders a lot simpler by encapsulating almost all feature logic into the respective feature classes. The only logic that remains is the creation of the initial pod, which needs to happen before anything else so is better to be left in the builder class. Most feature classes already behave fine when the config has nothing they should handle, but a few minor tweaks had to be added. Unit tests were also updated or added to account for these. The builder suites were simplified a lot and just test the remaining pod-related code in the builders themselves. Author: Marcelo Vanzin Closes #23220 from vanzin/SPARK-25877. 
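The refactor described above boils down to a single pattern: every feature step is constructed unconditionally, and a step that has nothing to configure simply returns the pod unchanged, so the builder reduces to a plain fold over all steps. The sketch below is a minimal, self-contained illustration of that pattern only; the types (Pod, FeatureStep, LabelStep, TemplateVolumeStep) are invented stand-ins for the example and are not the actual Spark Kubernetes classes.

    // Minimal sketch of the "fold over always-present feature steps" pattern.
    // Illustrative stand-ins only, not the real Spark classes.
    final case class Pod(labels: Map[String, String] = Map.empty)

    trait FeatureStep {
      def configurePod(pod: Pod): Pod
    }

    // A step that guards on its own configuration (hypothetical example):
    // when no template file is configured it is a no-op.
    class TemplateVolumeStep(templateFile: Option[String]) extends FeatureStep {
      override def configurePod(pod: Pod): Pod =
        templateFile.fold(pod)(f => pod.copy(labels = pod.labels + ("template" -> f)))
    }

    // A step that always applies (hypothetical example).
    class LabelStep(key: String, value: String) extends FeatureStep {
      override def configurePod(pod: Pod): Pod =
        pod.copy(labels = pod.labels + (key -> value))
    }

    object BuilderSketch {
      def main(args: Array[String]): Unit = {
        val features: Seq[FeatureStep] =
          Seq(new LabelStep("app", "spark"), new TemplateVolumeStep(templateFile = None))
        // With no template configured the second step leaves the pod untouched,
        // so the builder itself needs no conditional wiring.
        val finalPod = features.foldLeft(Pod()) { (pod, step) => step.configurePod(pod) }
        println(finalPod) // Pod(Map(app -> spark))
      }
    }

Keeping the conditional behavior inside each step, rather than in the builder, is what lets the driver and executor builder tests shrink to the shared pod-template checks introduced in PodBuilderSuite below.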
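Stepping back to the weighted RegressionMetrics change from the earlier patch in this series: the standalone sketch below shows how the new (prediction, observation, weight) form of the constructor is expected to be used, reusing the hand-calculated figures from the test suite above. It is an illustrative program rather than code from the patch; the local-mode SparkContext setup is an assumption made only for the example.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.evaluation.RegressionMetrics

    object WeightedRegressionMetricsExample {
      def main(args: Array[String]): Unit = {
        // Local-mode context, assumed purely for the example.
        val sc = new SparkContext(
          new SparkConf().setAppName("weighted-metrics").setMaster("local[2]"))

        // (prediction, observation, weight) triples: the same data as the weighted test above.
        val predObsWeight = sc.parallelize(Seq(
          (2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)

        val metrics = new RegressionMetrics(predObsWeight, false)
        println(s"weighted MAE  = ${metrics.meanAbsoluteError}")    // ~0.4 per the test expectations
        println(s"weighted MSE  = ${metrics.meanSquaredError}")     // ~0.2125
        println(s"weighted RMSE = ${metrics.rootMeanSquaredError}") // ~0.46098

        sc.stop()
      }
    }

Because the metrics are driven by MultivariateOnlineSummarizer's weightSum rather than the raw count, passing all weights as 1.0 (or using the plain two-element constructor) reproduces the unweighted results, as the "same (1.0) weight samples" test above verifies.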
--- .../HadoopConfExecutorFeatureStep.scala | 10 +- .../HadoopSparkUserExecutorFeatureStep.scala | 5 +- .../KerberosConfExecutorFeatureStep.scala | 26 +-- .../features/PodTemplateConfigMapStep.scala | 82 +++++--- .../submit/KubernetesClientApplication.scala | 4 +- .../k8s/submit/KubernetesDriverBuilder.scala | 99 +++------ .../cluster/k8s/ExecutorPodsAllocator.scala | 3 +- .../k8s/KubernetesClusterManager.scala | 2 +- .../k8s/KubernetesExecutorBuilder.scala | 100 +++------ .../spark/deploy/k8s/PodBuilderSuite.scala | 177 ++++++++++++++++ .../PodTemplateConfigMapStepSuite.scala | 25 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 2 +- .../submit/KubernetesDriverBuilderSuite.scala | 194 +----------------- .../k8s/submit/PodBuilderSuiteUtils.scala | 142 ------------- .../k8s/ExecutorPodsAllocatorSuite.scala | 4 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 144 +------------ 16 files changed, 343 insertions(+), 676 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala index bca66759d586e..da332881ae1a2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala @@ -31,10 +31,10 @@ private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf) override def configurePod(pod: SparkPod): SparkPod = { val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME) - require(hadoopConfDirCMapName.isDefined, - "Ensure that the env `HADOOP_CONF_DIR` is defined either in the client or " + - " using pre-existing ConfigMaps") - logInfo("HADOOP_CONF_DIR defined") - HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) + if (hadoopConfDirCMapName.isDefined) { + HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) + } else { + pod + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala index e342110763196..c038e75491ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala @@ -28,7 +28,8 @@ private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutor extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { - val sparkUserName = conf.get(KERBEROS_SPARK_USER_NAME) - HadoopBootstrapUtil.bootstrapSparkUserPod(sparkUserName, pod) + conf.getOption(KERBEROS_SPARK_USER_NAME).map { user => + HadoopBootstrapUtil.bootstrapSparkUserPod(user, pod) + }.getOrElse(pod) } } diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala index 32bb6a5d2bcbb..907271b1cb483 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala @@ -27,18 +27,20 @@ import org.apache.spark.internal.Logging private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep with Logging { - private val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) - require(maybeKrb5CMap.isDefined, "HADOOP_CONF_DIR ConfigMap not found") - override def configurePod(pod: SparkPod): SparkPod = { - logInfo(s"Mounting Resources for Kerberos") - HadoopBootstrapUtil.bootstrapKerberosPod( - conf.get(KERBEROS_DT_SECRET_NAME), - conf.get(KERBEROS_DT_SECRET_KEY), - conf.get(KERBEROS_SPARK_USER_NAME), - None, - None, - maybeKrb5CMap, - pod) + val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) + if (maybeKrb5CMap.isDefined) { + logInfo(s"Mounting Resources for Kerberos") + HadoopBootstrapUtil.bootstrapKerberosPod( + conf.get(KERBEROS_DT_SECRET_NAME), + conf.get(KERBEROS_DT_SECRET_KEY), + conf.get(KERBEROS_SPARK_USER_NAME), + None, + None, + maybeKrb5CMap, + pod) + } else { + pod + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index 09dcf93a54f8e..7f41ca43589b6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -28,44 +28,60 @@ import org.apache.spark.deploy.k8s.Constants._ private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf) extends KubernetesFeatureConfigStep { + + private val hasTemplate = conf.contains(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + def configurePod(pod: SparkPod): SparkPod = { - val podWithVolume = new PodBuilder(pod.pod) - .editSpec() - .addNewVolume() - .withName(POD_TEMPLATE_VOLUME) - .withNewConfigMap() - .withName(POD_TEMPLATE_CONFIGMAP) - .addNewItem() - .withKey(POD_TEMPLATE_KEY) - .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .endSpec() - .build() + if (hasTemplate) { + val podWithVolume = new PodBuilder(pod.pod) + .editSpec() + .addNewVolume() + .withName(POD_TEMPLATE_VOLUME) + .withNewConfigMap() + .withName(POD_TEMPLATE_CONFIGMAP) + .addNewItem() + .withKey(POD_TEMPLATE_KEY) + .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .endSpec() + .build() - val containerWithVolume = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(POD_TEMPLATE_VOLUME) - .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) - .endVolumeMount() - .build() - SparkPod(podWithVolume, containerWithVolume) + val containerWithVolume = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(POD_TEMPLATE_VOLUME) + .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) + .endVolumeMount() + .build() + 
SparkPod(podWithVolume, containerWithVolume) + } else { + pod + } } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String]( - KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> - (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + override def getAdditionalPodSystemProperties(): Map[String, String] = { + if (hasTemplate) { + Map[String, String]( + KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> + (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + } else { + Map.empty + } + } override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { - require(conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) - val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get - val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) - Seq(new ConfigMapBuilder() - .withNewMetadata() - .withName(POD_TEMPLATE_CONFIGMAP) - .endMetadata() - .addToData(POD_TEMPLATE_KEY, podTemplateString) - .build()) + if (hasTemplate) { + val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get + val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) + Seq(new ConfigMapBuilder() + .withNewMetadata() + .withName(POD_TEMPLATE_CONFIGMAP) + .endMetadata() + .addToData(POD_TEMPLATE_KEY, podTemplateString) + .build()) + } else { + Nil + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 70a93c968795e..3888778bf84ca 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -104,7 +104,7 @@ private[spark] class Client( watcher: LoggingPodStatusWatcher) extends Logging { def run(): Unit = { - val resolvedDriverSpec = builder.buildFromFeatures(conf) + val resolvedDriverSpec = builder.buildFromFeatures(conf, kubernetesClient) val configMapName = s"${conf.resourceNamePrefix}-driver-conf-map" val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the @@ -232,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None)) { kubernetesClient => val client = new Client( kubernetesConf, - KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf), + new KubernetesDriverBuilder(), kubernetesClient, waitForAppCompletion, watcher) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index a5ad9729aee9a..d2c0ced9fa2f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -20,90 +20,49 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.features._ -private[spark] class KubernetesDriverBuilder( - 
provideBasicStep: (KubernetesDriverConf => BasicDriverFeatureStep) = - new BasicDriverFeatureStep(_), - provideCredentialsStep: (KubernetesDriverConf => DriverKubernetesCredentialsFeatureStep) = - new DriverKubernetesCredentialsFeatureStep(_), - provideServiceStep: (KubernetesDriverConf => DriverServiceFeatureStep) = - new DriverServiceFeatureStep(_), - provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_), - provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = - new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = - new MountVolumesFeatureStep(_), - provideDriverCommandStep: (KubernetesDriverConf => DriverCommandFeatureStep) = - new DriverCommandFeatureStep(_), - provideHadoopGlobalStep: (KubernetesDriverConf => KerberosConfDriverFeatureStep) = - new KerberosConfDriverFeatureStep(_), - providePodTemplateConfigMapStep: (KubernetesConf => PodTemplateConfigMapStep) = - new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = () => SparkPod.initialPod) { +private[spark] class KubernetesDriverBuilder { - def buildFromFeatures(kubernetesConf: KubernetesDriverConf): KubernetesDriverSpec = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - - val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { - Seq(provideSecretsStep(kubernetesConf)) - } else Nil - val envSecretFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { - Seq(provideEnvSecretsStep(kubernetesConf)) - } else Nil - val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { - Seq(provideVolumesStep(kubernetesConf)) - } else Nil - val podTemplateFeature = if ( - kubernetesConf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { - Seq(providePodTemplateConfigMapStep(kubernetesConf)) - } else Nil - - val driverCommandStep = provideDriverCommandStep(kubernetesConf) - - val hadoopConfigStep = Some(provideHadoopGlobalStep(kubernetesConf)) + def buildFromFeatures( + conf: KubernetesDriverConf, + client: KubernetesClient): KubernetesDriverSpec = { + val initialPod = conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE) + .map { file => + KubernetesUtils.loadPodFromTemplate( + client, + new File(file), + conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME)) + } + .getOrElse(SparkPod.initialPod()) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ Seq(driverCommandStep) ++ - secretFeature ++ envSecretFeature ++ volumesFeature ++ - hadoopConfigStep ++ podTemplateFeature + val features = Seq( + new BasicDriverFeatureStep(conf), + new DriverKubernetesCredentialsFeatureStep(conf), + new DriverServiceFeatureStep(conf), + new MountSecretsFeatureStep(conf), + new EnvSecretsFeatureStep(conf), + new LocalDirsFeatureStep(conf), + new MountVolumesFeatureStep(conf), + new DriverCommandFeatureStep(conf), + new KerberosConfDriverFeatureStep(conf), + new PodTemplateConfigMapStep(conf)) - var spec = KubernetesDriverSpec( - provideInitialPod(), + val spec = KubernetesDriverSpec( + initialPod, driverKubernetesResources = Seq.empty, - kubernetesConf.sparkConf.getAll.toMap) - for (feature <- allFeatures) { + conf.sparkConf.getAll.toMap) + + features.foldLeft(spec) { case (spec, feature) => val configuredPod = feature.configurePod(spec.pod) val 
addedSystemProperties = feature.getAdditionalPodSystemProperties() val addedResources = feature.getAdditionalKubernetesResources() - spec = KubernetesDriverSpec( + KubernetesDriverSpec( configuredPod, spec.driverKubernetesResources ++ addedResources, spec.systemProperties ++ addedSystemProperties) } - spec } -} -private[spark] object KubernetesDriverBuilder { - def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesDriverBuilder = { - conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE) - .map(new File(_)) - .map(file => new KubernetesDriverBuilder(provideInitialPod = () => - KubernetesUtils.loadPodFromTemplate( - kubernetesClient, - file, - conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME)) - )) - .getOrElse(new KubernetesDriverBuilder()) - } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index ac42554b1334b..da3edfeca9b1f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -136,7 +136,8 @@ private[spark] class ExecutorPodsAllocator( newExecutorId.toString, applicationId, driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr) + val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr, + kubernetesClient) val podWithAttachedContainer = new PodBuilder(executorPod.pod) .editOrNewSpec() .addToContainers(executorPod.container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b31fbb420ed6d..809bdf8ca8c27 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -95,7 +95,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val executorPodsAllocator = new ExecutorPodsAllocator( sc.conf, sc.env.securityManager, - KubernetesExecutorBuilder(kubernetesClient, sc.conf), + new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index ba273cad6a8e5..0b74966fe8685 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -20,86 +20,36 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SecurityManager import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ -private[spark] class KubernetesExecutorBuilder( - provideBasicStep: 
(KubernetesExecutorConf, SecurityManager) => BasicExecutorFeatureStep = - new BasicExecutorFeatureStep(_, _), - provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_), - provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = - new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = - new MountVolumesFeatureStep(_), - provideHadoopConfStep: (KubernetesExecutorConf => HadoopConfExecutorFeatureStep) = - new HadoopConfExecutorFeatureStep(_), - provideKerberosConfStep: (KubernetesExecutorConf => KerberosConfExecutorFeatureStep) = - new KerberosConfExecutorFeatureStep(_), - provideHadoopSparkUserStep: (KubernetesExecutorConf => HadoopSparkUserExecutorFeatureStep) = - new HadoopSparkUserExecutorFeatureStep(_), - provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { +private[spark] class KubernetesExecutorBuilder { def buildFromFeatures( - kubernetesConf: KubernetesExecutorConf, - secMgr: SecurityManager): SparkPod = { - val sparkConf = kubernetesConf.sparkConf - val maybeHadoopConfigMap = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME) - val maybeDTSecretName = sparkConf.getOption(KERBEROS_DT_SECRET_NAME) - val maybeDTDataItem = sparkConf.getOption(KERBEROS_DT_SECRET_KEY) - - val baseFeatures = Seq(provideBasicStep(kubernetesConf, secMgr), - provideLocalDirsStep(kubernetesConf)) - val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { - Seq(provideSecretsStep(kubernetesConf)) - } else Nil - val secretEnvFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { - Seq(provideEnvSecretsStep(kubernetesConf)) - } else Nil - val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { - Seq(provideVolumesStep(kubernetesConf)) - } else Nil - - val maybeHadoopConfFeatureSteps = maybeHadoopConfigMap.map { _ => - val maybeKerberosStep = - if (maybeDTSecretName.isDefined && maybeDTDataItem.isDefined) { - provideKerberosConfStep(kubernetesConf) - } else { - provideHadoopSparkUserStep(kubernetesConf) - } - Seq(provideHadoopConfStep(kubernetesConf)) :+ - maybeKerberosStep - }.getOrElse(Seq.empty[KubernetesFeatureConfigStep]) - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - secretFeature ++ - secretEnvFeature ++ - volumesFeature ++ - maybeHadoopConfFeatureSteps - - var executorPod = provideInitialPod() - for (feature <- allFeatures) { - executorPod = feature.configurePod(executorPod) - } - executorPod + conf: KubernetesExecutorConf, + secMgr: SecurityManager, + client: KubernetesClient): SparkPod = { + val initialPod = conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + .map { file => + KubernetesUtils.loadPodFromTemplate( + client, + new File(file), + conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) + } + .getOrElse(SparkPod.initialPod()) + + val features = Seq( + new BasicExecutorFeatureStep(conf, secMgr), + new MountSecretsFeatureStep(conf), + new EnvSecretsFeatureStep(conf), + new LocalDirsFeatureStep(conf), + new MountVolumesFeatureStep(conf), + new HadoopConfExecutorFeatureStep(conf), + new KerberosConfExecutorFeatureStep(conf), + new HadoopSparkUserExecutorFeatureStep(conf)) + + features.foldLeft(initialPod) { case (pod, feature) => feature.configurePod(pod) } } -} -private[spark] object KubernetesExecutorBuilder { - def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesExecutorBuilder = { - 
conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) - .map(new File(_)) - .map(file => new KubernetesExecutorBuilder(provideInitialPod = () => - KubernetesUtils.loadPodFromTemplate( - kubernetesClient, - file, - conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) - )) - .getOrElse(new KubernetesExecutorBuilder()) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala new file mode 100644 index 0000000000000..7dde0c1377168 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File + +import io.fabric8.kubernetes.api.model.{Config => _, _} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, never, verify, when} +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.internal.config.ConfigEntry + +abstract class PodBuilderSuite extends SparkFunSuite { + + protected def templateFileConf: ConfigEntry[_] + + protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod + + private val baseConf = new SparkConf(false) + .set(Config.CONTAINER_IMAGE, "spark-executor:latest") + + test("use empty initial pod if template is not specified") { + val client = mock(classOf[KubernetesClient]) + buildPod(baseConf.clone(), client) + verify(client, never()).pods() + } + + test("load pod template if specified") { + val client = mockKubernetesClient() + val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + val pod = buildPod(sparkConf, client) + verifyPod(pod) + } + + test("complain about misconfigured pod template") { + val client = mockKubernetesClient( + new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .endMetadata() + .build()) + val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + val exception = intercept[SparkException] { + buildPod(sparkConf, client) + } + assert(exception.getMessage.contains("Could not load pod from template file.")) + } + + private def mockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = { + val kubernetesClient = mock(classOf[KubernetesClient]) + val pods = + mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]]) + val podResource = 
mock(classOf[PodResource[Pod, DoneablePod]]) + when(kubernetesClient.pods()).thenReturn(pods) + when(pods.load(any(classOf[File]))).thenReturn(podResource) + when(podResource.get()).thenReturn(pod) + kubernetesClient + } + + private def verifyPod(pod: SparkPod): Unit = { + val metadata = pod.pod.getMetadata + assert(metadata.getLabels.containsKey("test-label-key")) + assert(metadata.getAnnotations.containsKey("test-annotation-key")) + assert(metadata.getNamespace === "namespace") + assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference")) + val spec = pod.pod.getSpec + assert(!spec.getContainers.asScala.exists(_.getName == "executor-container")) + assert(spec.getDnsPolicy === "dns-policy") + assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname"))) + assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference")) + assert(spec.getInitContainers.asScala.exists(_.getName == "init-container")) + assert(spec.getNodeName == "node-name") + assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value") + assert(spec.getSchedulerName === "scheduler") + assert(spec.getSecurityContext.getRunAsUser === 1000L) + assert(spec.getServiceAccount === "service-account") + assert(spec.getSubdomain === "subdomain") + assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key")) + assert(spec.getVolumes.asScala.exists(_.getName == "test-volume")) + val container = pod.container + assert(container.getName === "executor-container") + assert(container.getArgs.contains("arg")) + assert(container.getCommand.equals(List("command").asJava)) + assert(container.getEnv.asScala.exists(_.getName == "env-key")) + assert(container.getResources.getLimits.get("gpu") === + new QuantityBuilder().withAmount("1").build()) + assert(container.getSecurityContext.getRunAsNonRoot) + assert(container.getStdin) + assert(container.getTerminationMessagePath === "termination-message-path") + assert(container.getTerminationMessagePolicy === "termination-message-policy") + assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume")) + } + + private def podWithSupportedFeatures(): Pod = { + new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .addToAnnotations("test-annotation-key", "test-annotation-value") + .withNamespace("namespace") + .addNewOwnerReference() + .withController(true) + .withName("owner-reference") + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withDnsPolicy("dns-policy") + .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build()) + .withImagePullSecrets( + new LocalObjectReferenceBuilder().withName("local-reference").build()) + .withInitContainers(new ContainerBuilder().withName("init-container").build()) + .withNodeName("node-name") + .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava) + .withSchedulerName("scheduler") + .withNewSecurityContext() + .withRunAsUser(1000L) + .endSecurityContext() + .withServiceAccount("service-account") + .withSubdomain("subdomain") + .withTolerations(new TolerationBuilder() + .withKey("toleration-key") + .withOperator("Equal") + .withEffect("NoSchedule") + .build()) + .addNewVolume() + .withNewHostPath() + .withPath("/test") + .endHostPath() + .withName("test-volume") + .endVolume() + .addNewContainer() + .withArgs("arg") + .withCommand("command") + .addNewEnv() + .withName("env-key") + .withValue("env-value") + .endEnv() + .withImagePullPolicy("Always") + 
.withName("executor-container") + .withNewResources() + .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava) + .endResources() + .withNewSecurityContext() + .withRunAsNonRoot(true) + .endSecurityContext() + .withStdin(true) + .withTerminationMessagePath("termination-message-path") + .withTerminationMessagePolicy("termination-message-policy") + .addToVolumeMounts( + new VolumeMountBuilder() + .withName("test-volume") + .withMountPath("/test") + .build()) + .endContainer() + .endSpec() + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala index 7295b82ca4799..5e7388dc8e672 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala @@ -20,25 +20,32 @@ import java.io.{File, PrintWriter} import java.nio.file.Files import io.fabric8.kubernetes.api.model.ConfigMap -import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ -class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { - private var kubernetesConf : KubernetesConf = _ - private var templateFile: File = _ +class PodTemplateConfigMapStepSuite extends SparkFunSuite { - before { - templateFile = Files.createTempFile("pod-template", "yml").toFile + test("Do nothing when executor template is not specified") { + val conf = KubernetesTestConf.createDriverConf() + val step = new PodTemplateConfigMapStep(conf) + + val initialPod = SparkPod.initialPod() + val configuredPod = step.configurePod(initialPod) + assert(configuredPod === initialPod) + + assert(step.getAdditionalKubernetesResources().isEmpty) + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + + test("Mounts executor template volume if config specified") { + val templateFile = Files.createTempFile("pod-template", "yml").toFile templateFile.deleteOnExit() val sparkConf = new SparkConf(false) .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, templateFile.getAbsolutePath) - kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - } + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - test("Mounts executor template volume if config specified") { val writer = new PrintWriter(templateFile) writer.write("pod-template-contents") writer.close() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index e9c05fef6f5db..1bb926cbca23d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -126,7 +126,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { MockitoAnnotations.initMocks(this) kconf = KubernetesTestConf.createDriverConf( resourceNamePrefix = Some(KUBERNETES_RESOURCE_PREFIX)) - when(driverBuilder.buildFromFeatures(kconf)).thenReturn(BUILT_KUBERNETES_SPEC) + when(driverBuilder.buildFromFeatures(kconf, 
kubernetesClient)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 7e7dc4763c2e7..6518c91a1a1fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -16,201 +16,21 @@ */ package org.apache.spark.deploy.k8s.submit -import io.fabric8.kubernetes.api.model.PodBuilder import io.fabric8.kubernetes.client.KubernetesClient -import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Config.{CONTAINER_IMAGE, KUBERNETES_DRIVER_PODTEMPLATE_FILE, KUBERNETES_EXECUTOR_PODTEMPLATE_FILE} -import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.internal.config.ConfigEntry -class KubernetesDriverBuilderSuite extends SparkFunSuite { +class KubernetesDriverBuilderSuite extends PodBuilderSuite { - private val BASIC_STEP_TYPE = "basic" - private val CREDENTIALS_STEP_TYPE = "credentials" - private val SERVICE_STEP_TYPE = "service" - private val LOCAL_DIRS_STEP_TYPE = "local-dirs" - private val SECRETS_STEP_TYPE = "mount-secrets" - private val DRIVER_CMD_STEP_TYPE = "driver-command" - private val ENV_SECRETS_STEP_TYPE = "env-secrets" - private val HADOOP_GLOBAL_STEP_TYPE = "hadoop-global" - private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" - private val TEMPLATE_VOLUME_STEP_TYPE = "template-volume" - - private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) - - private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) - - private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) - - private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) - - private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) - - private val driverCommandStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - DRIVER_CMD_STEP_TYPE, classOf[DriverCommandFeatureStep]) - - private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) - - private val hadoopGlobalStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_GLOBAL_STEP_TYPE, classOf[KerberosConfDriverFeatureStep]) - - private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) - - private val templateVolumeStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - TEMPLATE_VOLUME_STEP_TYPE, classOf[PodTemplateConfigMapStep] - ) - - private val builderUnderTest: KubernetesDriverBuilder = - new KubernetesDriverBuilder( - _ => 
basicFeatureStep, - _ => credentialsStep, - _ => serviceStep, - _ => secretsStep, - _ => envSecretsStep, - _ => localDirsStep, - _ => mountVolumesStep, - _ => driverCommandStep, - _ => hadoopGlobalStep, - _ => templateVolumeStep) - - test("Apply fundamental steps all the time.") { - val conf = KubernetesTestConf.createDriverConf() - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) + override protected def templateFileConf: ConfigEntry[_] = { + Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE } - test("Apply secrets step if secrets are present.") { - val conf = KubernetesTestConf.createDriverConf( - secretEnvNamesToKeyRefs = Map("EnvName" -> "SecretName:secretKey"), - secretNamesToMountPaths = Map("secret" -> "secretMountPath")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply volumes step if mounts are present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "", - false, - KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply volumes step if a mount subpath is present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "foo", - false, - KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply template volume step if executor template is present.") { - val sparkConf = new SparkConf(false) - .set(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "filename") + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE, - TEMPLATE_VOLUME_STEP_TYPE) - } - - private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) - : Unit = { - val addedProperties = resolvedSpec.systemProperties - .filter { case (k, _) => !k.startsWith("spark.") } - .toMap - assert(addedProperties.keys.toSet === stepTypes.toSet) - stepTypes.foreach { stepType => - assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) - assert(resolvedSpec.driverKubernetesResources.containsSlice( - KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) - assert(resolvedSpec.systemProperties(stepType) === stepType) - } - } - - test("Start with empty pod if template is not specified") { - val kubernetesClient = mock(classOf[KubernetesClient]) - val driverBuilder = KubernetesDriverBuilder.apply(kubernetesClient, new SparkConf()) - 
verify(kubernetesClient, never()).pods() + new KubernetesDriverBuilder().buildFromFeatures(conf, client).pod } - test("Starts with template if specified") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - val driverSpec = KubernetesDriverBuilder - .apply(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf) - PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(driverSpec.pod) - } - - test("Throws on misconfigured pod template") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient( - new PodBuilder() - .withNewMetadata() - .addToLabels("test-label-key", "test-label-value") - .endMetadata() - .build()) - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - val exception = intercept[SparkException] { - KubernetesDriverBuilder - .apply(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf) - } - assert(exception.getMessage.contains("Could not load pod from template file.")) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala deleted file mode 100644 index c92e9e6e3b6b3..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit - -import java.io.File - -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.KubernetesClient -import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} -import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, when} -import org.scalatest.FlatSpec -import scala.collection.JavaConverters._ - -import org.apache.spark.deploy.k8s.SparkPod - -object PodBuilderSuiteUtils extends FlatSpec { - - def loadingMockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = { - val kubernetesClient = mock(classOf[KubernetesClient]) - val pods = - mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]]) - val podResource = mock(classOf[PodResource[Pod, DoneablePod]]) - when(kubernetesClient.pods()).thenReturn(pods) - when(pods.load(any(classOf[File]))).thenReturn(podResource) - when(podResource.get()).thenReturn(pod) - kubernetesClient - } - - def verifyPodWithSupportedFeatures(pod: SparkPod): Unit = { - val metadata = pod.pod.getMetadata - assert(metadata.getLabels.containsKey("test-label-key")) - assert(metadata.getAnnotations.containsKey("test-annotation-key")) - assert(metadata.getNamespace === "namespace") - assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference")) - val spec = pod.pod.getSpec - assert(!spec.getContainers.asScala.exists(_.getName == "executor-container")) - assert(spec.getDnsPolicy === "dns-policy") - assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname"))) - assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference")) - assert(spec.getInitContainers.asScala.exists(_.getName == "init-container")) - assert(spec.getNodeName == "node-name") - assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value") - assert(spec.getSchedulerName === "scheduler") - assert(spec.getSecurityContext.getRunAsUser === 1000L) - assert(spec.getServiceAccount === "service-account") - assert(spec.getSubdomain === "subdomain") - assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key")) - assert(spec.getVolumes.asScala.exists(_.getName == "test-volume")) - val container = pod.container - assert(container.getName === "executor-container") - assert(container.getArgs.contains("arg")) - assert(container.getCommand.equals(List("command").asJava)) - assert(container.getEnv.asScala.exists(_.getName == "env-key")) - assert(container.getResources.getLimits.get("gpu") === - new QuantityBuilder().withAmount("1").build()) - assert(container.getSecurityContext.getRunAsNonRoot) - assert(container.getStdin) - assert(container.getTerminationMessagePath === "termination-message-path") - assert(container.getTerminationMessagePolicy === "termination-message-policy") - assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume")) - - } - - - def podWithSupportedFeatures(): Pod = new PodBuilder() - .withNewMetadata() - .addToLabels("test-label-key", "test-label-value") - .addToAnnotations("test-annotation-key", "test-annotation-value") - .withNamespace("namespace") - .addNewOwnerReference() - .withController(true) - .withName("owner-reference") - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withDnsPolicy("dns-policy") - .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build()) - .withImagePullSecrets( - new LocalObjectReferenceBuilder().withName("local-reference").build()) - .withInitContainers(new 
ContainerBuilder().withName("init-container").build()) - .withNodeName("node-name") - .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava) - .withSchedulerName("scheduler") - .withNewSecurityContext() - .withRunAsUser(1000L) - .endSecurityContext() - .withServiceAccount("service-account") - .withSubdomain("subdomain") - .withTolerations(new TolerationBuilder() - .withKey("toleration-key") - .withOperator("Equal") - .withEffect("NoSchedule") - .build()) - .addNewVolume() - .withNewHostPath() - .withPath("/test") - .endHostPath() - .withName("test-volume") - .endVolume() - .addNewContainer() - .withArgs("arg") - .withCommand("command") - .addNewEnv() - .withName("env-key") - .withValue("env-value") - .endEnv() - .withImagePullPolicy("Always") - .withName("executor-container") - .withNewResources() - .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava) - .endResources() - .withNewSecurityContext() - .withRunAsNonRoot(true) - .endSecurityContext() - .withStdin(true) - .withTerminationMessagePath("termination-message-path") - .withTerminationMessagePolicy("termination-message-policy") - .addToVolumeMounts( - new VolumeMountBuilder() - .withName("test-volume") - .withMountPath("/test") - .build()) - .endContainer() - .endSpec() - .build() - -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index d4fa31af3d5ce..278a3821a6f3d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -80,8 +80,8 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) when(driverPodOperations.get).thenReturn(driverPod) - when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr))) - .thenAnswer(executorPodAnswer()) + when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr), + meq(kubernetesClient))).thenAnswer(executorPodAnswer()) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() waitForExecutorPodsClock = new ManualClock(0L) podsAllocatorUnderTest = new ExecutorPodsAllocator( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index ef521fd801e97..bd716174a8271 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -16,147 +16,23 @@ */ package org.apache.spark.scheduler.cluster.k8s -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{Config => _, _} import io.fabric8.kubernetes.client.KubernetesClient -import org.mockito.Mockito.{mock, never, verify} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf} import 
org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features._ -import org.apache.spark.deploy.k8s.submit.PodBuilderSuiteUtils -import org.apache.spark.util.SparkConfWithEnv - -class KubernetesExecutorBuilderSuite extends SparkFunSuite { - private val BASIC_STEP_TYPE = "basic" - private val SECRETS_STEP_TYPE = "mount-secrets" - private val ENV_SECRETS_STEP_TYPE = "env-secrets" - private val LOCAL_DIRS_STEP_TYPE = "local-dirs" - private val HADOOP_CONF_STEP_TYPE = "hadoop-conf-step" - private val HADOOP_SPARK_USER_STEP_TYPE = "hadoop-spark-user" - private val KERBEROS_CONF_STEP_TYPE = "kerberos-step" - private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" - - private val secMgr = new SecurityManager(new SparkConf(false)) - - private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) - private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) - private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) - private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) - private val hadoopConfStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_CONF_STEP_TYPE, classOf[HadoopConfExecutorFeatureStep]) - private val hadoopSparkUser = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_SPARK_USER_STEP_TYPE, classOf[HadoopSparkUserExecutorFeatureStep]) - private val kerberosConf = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - KERBEROS_CONF_STEP_TYPE, classOf[KerberosConfExecutorFeatureStep]) - private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) +import org.apache.spark.internal.config.ConfigEntry - private val builderUnderTest = new KubernetesExecutorBuilder( - (_, _) => basicFeatureStep, - _ => mountSecretsStep, - _ => envSecretsStep, - _ => localDirsStep, - _ => mountVolumesStep, - _ => hadoopConfStep, - _ => kerberosConf, - _ => hadoopSparkUser) - - test("Basic steps are consistently applied.") { - val conf = KubernetesTestConf.createExecutorConf() - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) - } - - test("Apply secrets step if secrets are present.") { - val conf = KubernetesTestConf.createExecutorConf( - secretEnvNamesToKeyRefs = Map("secret-name" -> "secret-key"), - secretNamesToMountPaths = Map("secret" -> "secretMountPath")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE) - } +class KubernetesExecutorBuilderSuite extends PodBuilderSuite { - test("Apply volumes step if mounts are present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "", - false, - KubernetesHostPathVolumeConf("/checkpoint")) - val conf = KubernetesTestConf.createExecutorConf( - volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE) + override protected def templateFileConf: ConfigEntry[_] = { + Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE } - test("Apply 
basicHadoop step if HADOOP_CONF_DIR is defined") { - // HADOOP_DELEGATION_TOKEN - val conf = KubernetesTestConf.createExecutorConf( - sparkConf = new SparkConfWithEnv(Map("HADOOP_CONF_DIR" -> "/var/hadoop-conf")) - .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") - .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - HADOOP_CONF_STEP_TYPE, - HADOOP_SPARK_USER_STEP_TYPE) + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { + sparkConf.set("spark.driver.host", "https://driver.host.com") + val conf = KubernetesTestConf.createExecutorConf(sparkConf = sparkConf) + val secMgr = new SecurityManager(sparkConf) + new KubernetesExecutorBuilder().buildFromFeatures(conf, secMgr, client) } - test("Apply kerberos step if DT secrets created") { - val conf = KubernetesTestConf.createExecutorConf( - sparkConf = new SparkConf(false) - .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") - .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name") - .set(KERBEROS_SPARK_USER_NAME, "spark-user") - .set(KERBEROS_DT_SECRET_NAME, "dt-secret") - .set(KERBEROS_DT_SECRET_KEY, "dt-key" )) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - HADOOP_CONF_STEP_TYPE, - KERBEROS_CONF_STEP_TYPE) - } - - private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { - assert(resolvedPod.pod.getMetadata.getLabels.asScala.keys.toSet === stepTypes.toSet) - } - - test("Starts with empty executor pod if template is not specified") { - val kubernetesClient = mock(classOf[KubernetesClient]) - val executorBuilder = KubernetesExecutorBuilder.apply(kubernetesClient, new SparkConf()) - verify(kubernetesClient, never()).pods() - } - - test("Starts with executor template if specified") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() - val sparkConf = new SparkConf(false) - .set("spark.driver.host", "https://driver.host.com") - .set(Config.CONTAINER_IMAGE, "spark-executor:latest") - .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createExecutorConf( - sparkConf = sparkConf, - driverPod = Some(new PodBuilder() - .withNewMetadata() - .withName("driver") - .endMetadata() - .build())) - val sparkPod = KubernetesExecutorBuilder(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf, secMgr) - PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod) - } } From 3ec9e3bac191277ada38f15f225c8c4e93766cff Mon Sep 17 00:00:00 2001 From: Luca Canali Date: Wed, 12 Dec 2018 16:18:22 -0800 Subject: [PATCH 057/194] [SPARK-25277][YARN] YARN applicationMaster metrics should not register static metrics ## What changes were proposed in this pull request? YARN applicationMaster metrics registration introduced in SPARK-24594 causes further registration of static metrics (Codegenerator and HiveExternalCatalog) and of JVM metrics, which I believe do not belong in this context. This looks like an unintended side effect of using the start method of [[MetricsSystem]]. A possible solution proposed here, is to introduce startNoRegisterSources to avoid these additional registrations of static sources and of JVM sources in the case of YARN applicationMaster metrics (this could be useful for other metrics that may be added in the future). ## How was this patch tested? 
Manually tested on a YARN cluster, Closes #22279 from LucaCanali/YarnMetricsRemoveExtraSourceRegistration. Lead-authored-by: Luca Canali Co-authored-by: LucaCanali Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/metrics/MetricsSystem.scala | 8 +++++--- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index bb7b434e9a113..301317a79dfcf 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -94,11 +94,13 @@ private[spark] class MetricsSystem private ( metricsConfig.initialize() - def start() { + def start(registerStaticSources: Boolean = true) { require(!running, "Attempting to start a MetricsSystem that is already running") running = true - StaticSources.allSources.foreach(registerSource) - registerSources() + if (registerStaticSources) { + StaticSources.allSources.foreach(registerSource) + registerSources() + } registerSinks() sinks.foreach(_.start) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c1f3211bcab29..e46c4f970c4a3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -449,7 +449,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr) val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) ms.registerSource(new ApplicationMasterSource(prefix, allocator)) - ms.start() + // do not register static sources in this case as per SPARK-25277 + ms.start(false) metricsSystem = Some(ms) reporterThread = launchReporterThread() } From d9bccb5d16d245102c7d97ca919ccd140252fd3d Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 12 Dec 2018 16:45:50 -0800 Subject: [PATCH 058/194] [SPARK-26322][SS] Add spark.kafka.sasl.token.mechanism to ease delegation token configuration. ## What changes were proposed in this pull request? When Kafka delegation token obtained, SCRAM `sasl.mechanism` has to be configured for authentication. This can be configured on the related source/sink which is inconvenient from user perspective. Such granularity is not required and this configuration can be implemented with one central parameter. In this PR `spark.kafka.sasl.token.mechanism` added to configure this centrally (default: `SCRAM-SHA-512`). ## How was this patch tested? Existing unit tests + on cluster. Closes #23274 from gaborgsomogyi/SPARK-26322. 
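For illustration, a minimal usage sketch (not part of this patch) of setting the new option centrally; the broker address and the `SCRAM-SHA-256` value are placeholders and have to match the Kafka broker configuration:

```
// Hypothetical sketch: the delegation token SASL mechanism is configured once per
// application instead of on every source/sink. All values below are placeholders.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("kafka-delegation-token-example")
  .config("spark.kafka.bootstrap.servers", "kafkabroker:9093")
  .config("spark.kafka.sasl.token.mechanism", "SCRAM-SHA-256")
  .getOrCreate()

// No per-query kafka.sasl.mechanism option is needed any more.
val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "kafkabroker:9093")
  .option("subscribe", "topic1")
  .load()
```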
Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../apache/spark/internal/config/Kafka.scala | 9 ++ .../structured-streaming-kafka-integration.md | 144 +----------------- .../sql/kafka010/KafkaSourceProvider.scala | 15 +- 3 files changed, 21 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala index 064fc93cb8ed8..e91ddd3e9741a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala @@ -79,4 +79,13 @@ private[spark] object Kafka { "For further details please see kafka documentation. Only used to obtain delegation token.") .stringConf .createOptional + + val TOKEN_SASL_MECHANISM = + ConfigBuilder("spark.kafka.sasl.token.mechanism") + .doc("SASL mechanism used for client connections with delegation token. Because SCRAM " + + "login module used for authentication a compatible mechanism has to be set here. " + + "For further details please see kafka documentation (sasl.mechanism). Only used to " + + "authenticate against Kafka broker with delegation token.") + .stringConf + .createWithDefault("SCRAM-SHA-512") } diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 7040f8da2c614..3d64ec4cb55f7 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -642,9 +642,9 @@ This way the application can be configured via Spark parameters and may not need configuration (Spark can use Kafka's dynamic JAAS configuration feature). For further information about delegation tokens, see [Kafka delegation token docs](http://kafka.apache.org/documentation/#security_delegation_token). -The process is initiated by Spark's Kafka delegation token provider. When `spark.kafka.bootstrap.servers`, +The process is initiated by Spark's Kafka delegation token provider. When `spark.kafka.bootstrap.servers` is set, Spark considers the following log in options, in order of preference: -- **JAAS login configuration** +- **JAAS login configuration**, please see example below. - **Keytab file**, such as, ./bin/spark-submit \ @@ -669,144 +669,8 @@ Kafka broker configuration): After obtaining delegation token successfully, Spark distributes it across nodes and renews it accordingly. Delegation token uses `SCRAM` login module for authentication and because of that the appropriate -`sasl.mechanism` has to be configured on source/sink (it must match with Kafka broker configuration): - -
-
-{% highlight scala %} - -// Setting on Kafka Source for Streaming Queries -val df = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - -// Setting on Kafka Source for Batch Queries -val df = spark - .read - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - -// Setting on Kafka Sink for Streaming Queries -val ds = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .start() - -// Setting on Kafka Sink for Batch Queries -val ds = df - .selectExpr("topic1", "CAST(key AS STRING)", "CAST(value AS STRING)") - .write - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .save() - -{% endhighlight %} -
-
-{% highlight java %} - -// Setting on Kafka Source for Streaming Queries -Dataset df = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load(); -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - -// Setting on Kafka Source for Batch Queries -Dataset df = spark - .read() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load(); -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - -// Setting on Kafka Sink for Streaming Queries -StreamingQuery ds = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .start(); - -// Setting on Kafka Sink for Batch Queries -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .write() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .save(); - -{% endhighlight %} -
-
-{% highlight python %} - -// Setting on Kafka Source for Streaming Queries -df = spark \ - .readStream \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("subscribe", "topic1") \ - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - -// Setting on Kafka Source for Batch Queries -df = spark \ - .read \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("subscribe", "topic1") \ - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - -// Setting on Kafka Sink for Streaming Queries -ds = df \ - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ - .writeStream \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("topic", "topic1") \ - .start() - -// Setting on Kafka Sink for Batch Queries -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ - .write \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("topic", "topic1") \ - .save() - -{% endhighlight %} -
-
+`spark.kafka.sasl.token.mechanism` (default: `SCRAM-SHA-512`) has to be configured. Also, this parameter +must match with Kafka broker configuration. When delegation token is available on an executor it can be overridden with JAAS login configuration. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0ac330435e5c5..6a0c2088ac3d1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,6 +30,7 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.SparkEnv import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ @@ -501,7 +502,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() def kafkaParamsForExecutors( @@ -523,7 +524,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() /** @@ -556,7 +557,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { this } - def setTokenJaasConfigIfNeeded(): ConfigUpdater = { + def setAuthenticationConfigIfNeeded(): ConfigUpdater = { // There are multiple possibilities to log in and applied in the following order: // - JVM global security provided -> try to log in with JVM global security configuration // which can be configured for example with 'java.security.auth.login.config'. 
@@ -568,11 +569,11 @@ private[kafka010] object KafkaSourceProvider extends Logging { } else if (KafkaSecurityHelper.isTokenAvailable()) { logDebug("Delegation token detected, using it for login.") val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) - val mechanism = kafkaParams - .getOrElse(SaslConfigs.SASL_MECHANISM, SaslConfigs.DEFAULT_SASL_MECHANISM) + set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) require(mechanism.startsWith("SCRAM"), "Delegation token works only with SCRAM mechanism.") - set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + set(SaslConfigs.SASL_MECHANISM, mechanism) } this } @@ -600,7 +601,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { ConfigUpdater("executor", specifiedKafkaParams) .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName) .set(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, serClassName) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() } From b6504149c53db1c94f795e4b093a0f87391e1dd9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 11:13:15 +0800 Subject: [PATCH 059/194] [SPARK-26297][SQL] improve the doc of Distribution/Partitioning ## What changes were proposed in this pull request? Some documents of `Distribution/Partitioning` are stale and misleading, this PR fixes them: 1. `Distribution` never have intra-partition requirement 2. `OrderedDistribution` does not require tuples that share the same value being colocated in the same partition. 3. `RangePartitioning` can provide a weaker guarantee for a prefix of its `ordering` expressions. ## How was this patch tested? comment-only PR. Closes #23249 from cloud-fan/doc. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../plans/physical/partitioning.scala | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e835d9cd..17e1cb416fc8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -22,13 +22,11 @@ import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed - * in parallel on many machines. Distribution can be used to refer to two distinct physical - * properties: - * - Inter-node partitioning of data: In this case the distribution describes how tuples are - * partitioned across physical machines in a cluster. Knowing this property allows some - * operators (e.g., Aggregate) to perform partition local operations instead of global ones. - * - Intra-partition ordering of data: In this case the distribution describes guarantees made - * about how tuples are distributed within a single partition. + * in parallel on many machines. + * + * Distribution here refers to inter-node partitioning of data. That is, it describes how tuples + * are partitioned across physical machines in a cluster. Knowing this property allows some + * operators (e.g., Aggregate) to perform partition local operations instead of global ones. 
*/ sealed trait Distribution { /** @@ -70,9 +68,7 @@ case object AllTuples extends Distribution { /** * Represents data where tuples that share the same values for the `clustering` - * [[Expression Expressions]] will be co-located. Based on the context, this - * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. + * [[Expression Expressions]] will be co-located in the same partition. */ case class ClusteredDistribution( clustering: Seq[Expression], @@ -118,10 +114,12 @@ case class HashClusteredDistribution( /** * Represents data where tuples have been ordered according to the `ordering` - * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the - * same value for the ordering expressions are contiguous and will never be split across - * partitions. + * [[Expression Expressions]]. Its requirement is defined as the following: + * - Given any 2 adjacent partitions, all the rows of the second partition must be larger than or + * equal to any row in the first partition, according to the `ordering` expressions. + * + * In other words, this distribution requires the rows to be ordered across partitions, but not + * necessarily within a partition. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -241,12 +239,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) /** * Represents a partitioning where rows are split across partitions based on some total ordering of - * the expressions specified in `ordering`. When data is partitioned in this manner the following - * two conditions are guaranteed to hold: - * - All row where the expressions in `ordering` evaluate to the same values will be in the same - * partition. - * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows - * that are in between `min` and `max` in this `ordering` will reside in this partition. + * the expressions specified in `ordering`. When data is partitioned in this manner, it guarantees: + * Given any 2 adjacent partitions, all the rows of the second partition must be larger than any row + * in the first partition, according to the `ordering` expressions. + * + * This is a strictly stronger guarantee than what `OrderedDistribution(ordering)` requires, as + * there is no overlap between partitions. * * This class extends expression primarily so that transformations over expression will descend * into its child. @@ -262,6 +260,22 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => + // If `ordering` is a prefix of `requiredOrdering`: + // Let's say `ordering` is [a, b] and `requiredOrdering` is [a, b, c]. According to the + // RangePartitioning definition, any [a, b] in a previous partition must be smaller + // than any [a, b] in the following partition. This also means any [a, b, c] in a + // previous partition must be smaller than any [a, b, c] in the following partition. + // Thus `RangePartitioning(a, b)` satisfies `OrderedDistribution(a, b, c)`. + // + // If `requiredOrdering` is a prefix of `ordering`: + // Let's say `ordering` is [a, b, c] and `requiredOrdering` is [a, b]. 
According to the + // RangePartitioning definition, any [a, b, c] in a previous partition must be smaller + // than any [a, b, c] in the following partition. If there is a [a1, b1] from a previous + // partition which is larger than a [a2, b2] from the following partition, then there + // must be a [a1, b1 c1] larger than [a2, b2, c2], which violates RangePartitioning + // definition. So it's guaranteed that, any [a, b] in a previous partition must not be + // greater(i.e. smaller or equal to) than any [a, b] in the following partition. Thus + // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`. val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering, _) => From 73a373b72020f90bff2023dc7485da88b7ecead5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 12:50:15 +0800 Subject: [PATCH 060/194] [SPARK-26348][SQL][TEST] make sure expression is resolved during test ## What changes were proposed in this pull request? cleanup some tests to make sure expression is resolved during test. ## How was this patch tested? test-only PR Closes #23297 from cloud-fan/test. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 10 ++++---- .../expressions/JsonExpressionsSuite.scala | 11 ++++----- .../catalyst/expressions/PredicateSuite.scala | 23 ++++++------------- .../expressions/StringExpressionsSuite.scala | 7 ++---- 5 files changed, 20 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 176ea823b1fcd..151481c80ee96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -136,7 +136,7 @@ package object dsl { implicit def longToLiteral(l: Long): Literal = Literal(l) implicit def floatToLiteral(f: Float): Literal = Literal(f) implicit def doubleToLiteral(d: Double): Literal = Literal(d) - implicit def stringToLiteral(s: String): Literal = Literal(s) + implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType) implicit def dateToLiteral(d: Date): Literal = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying()) implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4fd170467d81..1c91adab71375 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import 
org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + val expr = resolver.resolveTimeZones(expression) + assert(expr.resolved) + serializer.deserialize(serializer.serialize(expr)) } protected def checkEvaluation( @@ -296,9 +298,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation()) - // We should analyze the plan first, otherwise we possibly optimize an unresolved plan. - val analyzedPlan = SimpleAnalyzer.execute(plan) - val optimizedPlan = SimpleTestOptimizer.execute(analyzedPlan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 5d60cefc13896..238e6e34b4ae5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -694,11 +694,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val mapType2 = MapType(IntegerType, CalendarIntervalType) val schema2 = StructType(StructField("a", mapType2) :: Nil) val struct2 = Literal.create(null, schema2) - intercept[TreeNodeException[_]] { - checkEvaluation( - StructsToJson(Map.empty, struct2, gmtId), - null - ) + StructsToJson(Map.empty, struct2, gmtId).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("Unable to convert column a of type calendarinterval to JSON")) + case _ => fail("from_json should not work on interval map value type.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0f63717f9daf2..3541afcd2144d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -24,6 +24,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} @@ -231,22 +232,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { testWithRandomDataGeneration(structType, nullable) } - // Map types: not supported - for ( - keyType <- atomicTypes; - valueType <- atomicTypes; - nullable <- Seq(true, false)) { - val mapType = MapType(keyType, valueType) - val e = intercept[Exception] { - testWithRandomDataGeneration(mapType, nullable) - } - if (e.getMessage.contains("Code generation of")) { - // If the `value` expression is null, `eval` will be short-circuited. - // Codegen version evaluation will be run then. - assert(e.getMessage.contains("cannot generate equality code for un-comparable type")) - } else { - assert(e.getMessage.contains("Exception evaluating")) - } + // In doesn't support map type and will fail the analyzer. + val map = Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)) + In(map, Seq(map)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("function in does not support ordering on type map")) + case _ => fail("In should not work on map type") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa334e040d5fc..e95f2dff231b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -744,16 +744,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("ParseUrl") { def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = { - checkEvaluation( - ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected) + checkEvaluation(ParseUrl(Seq(urlStr, partToExtract)), expected) } def checkParseUrlWithKey( expected: String, urlStr: String, partToExtract: String, key: String): Unit = { - checkEvaluation( - ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected) + checkEvaluation(ParseUrl(Seq(urlStr, partToExtract, key)), expected) } checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST") @@ -798,7 +796,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sentences(nullString, nullString, nullString), null) checkEvaluation(Sentences(nullString, nullString), null) checkEvaluation(Sentences(nullString), null) - checkEvaluation(Sentences(Literal.create(null, NullType)), null) checkEvaluation(Sentences("", nullString, nullString), Seq.empty) checkEvaluation(Sentences("", nullString), Seq.empty) checkEvaluation(Sentences(""), Seq.empty) From 4a5acc79f76be9cbda28886e6c5745f3fbd3d2e8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 13 Dec 2018 13:14:59 +0800 Subject: [PATCH 061/194] [SPARK-26355][PYSPARK] Add a workaround for PyArrow 0.11. ## What changes were proposed in this pull request? In PyArrow 0.11, there is a API breaking change. - [ARROW-1949](https://issues.apache.org/jira/browse/ARROW-1949) - [Python/C++] Add option to Array.from_pandas and pyarrow.array to perform unsafe casts. 
This causes test failures in `ScalarPandasUDFTests.test_vectorized_udf_null_(byte|short|int|long)`: ``` File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/worker.py", line 377, in main process() File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/worker.py", line 372, in process serializer.dump_stream(func(split_index, iterator), outfile) File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 317, in dump_stream batch = _create_batch(series, self._timezone) File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 286, in _create_batch arrs = [create_array(s, t) for s, t in series] File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 284, in create_array return pa.Array.from_pandas(s, mask=mask, type=t) File "pyarrow/array.pxi", line 474, in pyarrow.lib.Array.from_pandas return array(obj, mask=mask, type=type, safe=safe, from_pandas=True, File "pyarrow/array.pxi", line 169, in pyarrow.lib.array return _ndarray_to_array(values, mask, type, from_pandas, safe, File "pyarrow/array.pxi", line 69, in pyarrow.lib._ndarray_to_array check_status(NdarrayToArrow(pool, values, mask, from_pandas, File "pyarrow/error.pxi", line 81, in pyarrow.lib.check_status raise ArrowInvalid(message) ArrowInvalid: Floating point value truncated ``` We should add a workaround to support PyArrow 0.11. ## How was this patch tested? In my local environment. Closes #23305 from ueshin/issues/SPARK-26355/pyarrow_0.11. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- python/pyspark/serializers.py | 5 ++++- .../pyspark/sql/tests/test_pandas_udf_grouped_map.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index f3ebd3767a0a1..fd4695210fb7c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -281,7 +281,10 @@ def create_array(s, t): # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. return pa.Array.from_pandas(s.apply( lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) - return pa.Array.from_pandas(s, mask=mask, type=t) + elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. + return pa.Array.from_pandas(s, mask=mask, type=t) + return pa.Array.from_pandas(s, mask=mask, type=t, safe=False) arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index bfecc071386e9..a12c608dff9dd 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -468,8 +468,15 @@ def invalid_positional_types(pdf): with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): grouped_df.apply(column_name_typo).collect() - with self.assertRaisesRegexp(Exception, "No cast implemented"): - grouped_df.apply(invalid_positional_types).collect() + from distutils.version import LooseVersion + import pyarrow as pa + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. 
+ with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + else: + with self.assertRaisesRegexp(Exception, "an integer is required"): + grouped_df.apply(invalid_positional_types).collect() def test_positional_assignment_conf(self): import pandas as pd From 239b8ec63091b44f521849359917d3f8c7c635dd Mon Sep 17 00:00:00 2001 From: Qi Shao Date: Thu, 13 Dec 2018 20:05:49 +0800 Subject: [PATCH 062/194] [MINOR][R] Fix indents of sparkR welcome message to be consistent with pyspark and spark-shell ## What changes were proposed in this pull request? 1. Removed empty space at the beginning of welcome message lines of sparkR to be consistent with welcome message of `pyspark` and `spark-shell` 2. Setting indent of logo message lines to 3 to be consistent with welcome message of `pyspark` and `spark-shell` Output of `pyspark`: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ Using Python version 3.6.6 (default, Jun 28 2018 11:07:29) SparkSession available as 'spark'. ``` Output of `spark-shell`: ``` Spark session available as 'spark'. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ Using Scala version 2.11.12 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_161) Type in expressions to have them evaluated. Type :help for more information. ``` ## How was this patch tested? Before: Output of `sparkR`: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ SparkSession available as 'spark'. ``` After: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ SparkSession available as 'spark'. ``` Closes #23293 from AzureQ/master. Authored-by: Qi Shao Signed-off-by: Hyukjin Kwon --- R/pkg/inst/profile/shell.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 32eb3671b5941..e4e0d032997de 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -33,19 +33,19 @@ sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", spark) assign("sc", sc, envir = .GlobalEnv) sparkVer <- SparkR:::callJMethod(sc, "version") - cat("\n Welcome to") + cat("\nWelcome to") cat("\n") - cat(" ____ __", "\n") - cat(" / __/__ ___ _____/ /__", "\n") - cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") - cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") if (nchar(sparkVer) == 0) { cat("\n") } else { - cat(" version ", sparkVer, "\n") + cat(" version", sparkVer, "\n") } - cat(" /_/", "\n") + cat(" /_/", "\n") cat("\n") - cat("\n SparkSession available as 'spark'.\n") + cat("\nSparkSession available as 'spark'.\n") } From bda9e84a7b9ee52142e7c654aa6e491625304084 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Thu, 13 Dec 2018 07:40:13 -0600 Subject: [PATCH 063/194] [MINOR][DOC] Fix comments of ConvertToLocalRelation rule ## What changes were proposed in this pull request? There are some comments issues left when `ConvertToLocalRelation` rule was added (see #22205/[SPARK-25212](https://issues.apache.org/jira/browse/SPARK-25212)). This PR fixes those comments issues. ## How was this patch tested? N/A Closes #23273 from seancxmao/ConvertToLocalRelation-doc. 
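As a small illustration of the rule whose comments are being fixed (not part of this patch), consider a query over a local collection; the data and column names are invented:

```
// Hypothetical example: a Filter/Project over a local collection that
// ConvertToLocalRelation can fold into a single LocalRelation early,
// so the heavier optimizer batches only see a LocalRelation.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("local-relation-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")
// The optimized plan of this query is expected to collapse to a LocalRelation.
df.filter($"id" > 1).select($"name").explain(true)
```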
Authored-by: seancxmao Signed-off-by: Sean Owen --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d251eeab8484..f615757a837a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -131,11 +131,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: - // run this once earlier. this might simplify the plan and reduce cost of optimizer. - // for example, a query such as Filter(LocalRelation) would go through all the heavy + // Run this once earlier. This might simplify the plan and reduce cost of optimizer. + // For example, a query such as Filter(LocalRelation) would go through all the heavy // optimizer rules that are triggered when there is a filter - // (e.g. InferFiltersFromConstraints). if we run this batch earlier, the query becomes just - // LocalRelation and does not trigger many rules + // (e.g. InferFiltersFromConstraints). If we run this batch earlier, the query becomes just + // LocalRelation and does not trigger many rules. Batch("LocalRelation early", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: @@ -1370,10 +1370,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { } /** - * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to - * another LocalRelation. - * - * This is relatively simple as it currently handles only 2 single case: Project and Limit. + * Converts local operations (i.e. ones that don't require data exchange) on `LocalRelation` to + * another `LocalRelation`. */ object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { From d2a58a1354b5790c0603190e54c6fc31c8c1976c Mon Sep 17 00:00:00 2001 From: lichaoqun Date: Thu, 13 Dec 2018 07:42:17 -0600 Subject: [PATCH 064/194] [MINOR][DOC] update the condition description of BypassMergeSortShuffle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? These three condition descriptions should be updated, follow #23228 :
  • no Ordering is specified,
  • no Aggregator is specified, and
  • the number of partitions is less than spark.shuffle.sort.bypassMergeThreshold.
  • 1、If the shuffle dependency specifies aggregation, but it only aggregates at the reduce-side, BypassMergeSortShuffle can still be used. 2、If the number of output partitions is spark.shuffle.sort.bypassMergeThreshold(eg.200), we can use BypassMergeSortShuffle. ## How was this patch tested? N/A Closes #23281 from lcqzte10192193/wid-lcq-1211. Authored-by: lichaoqun Signed-off-by: Sean Owen --- .../spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index fda33cd8293d5..997bc9e3f0435 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -58,9 +58,8 @@ * simultaneously opens separate serializers and file streams for all partitions. As a result, * {@link SortShuffleManager} only selects this write path when *
- *    <li>no Ordering is specified,</li>
- *    <li>no Aggregator is specified, and</li>
- *    <li>the number of partitions is less than
+ *    <li>no map-side combine is specified, and</li>
+ *    <li>the number of partitions is less than or equal to
 *    <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
 * </ul>
    * From 6a1cdf4d88d978de4917391e9e641db12599f3de Mon Sep 17 00:00:00 2001 From: "n.fraison" Date: Thu, 13 Dec 2018 08:34:47 -0600 Subject: [PATCH 065/194] [SPARK-26340][CORE] Ensure cores per executor is greater than cpu per task Currently this check is only performed for dynamic allocation use case in ExecutorAllocationManager. ## What changes were proposed in this pull request? Checks that cpu per task is lower than number of cores per executor otherwise throw an exception ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23290 from ashangit/master. Authored-by: n.fraison Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/SparkConf.scala | 9 +++++++++ .../src/test/scala/org/apache/spark/SparkConfSuite.scala | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 21c5cbc04d813..8d135d3e083d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -605,6 +605,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.executor.cores") && contains("spark.task.cpus")) { + val executorCores = getInt("spark.executor.cores", 1) + val taskCpus = getInt("spark.task.cpus", 1) + + if (executorCores < taskCpus) { + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index df274d949bae3..7cb03deae1391 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -138,6 +138,13 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(sc.appName === "My other app") } + test("creating SparkContext with cpus per tasks bigger than cores per executors") { + val conf = new SparkConf(false) + .set("spark.executor.cores", "1") + .set("spark.task.cpus", "2") + intercept[SparkException] { sc = new SparkContext(conf) } + } + test("nested property names") { // This wasn't supported by some external conf parsing libraries System.setProperty("spark.test.a", "a") From 9b127e1670bdc3616308c50c7c852478d7b28008 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 23:03:26 +0800 Subject: [PATCH 066/194] [SPARK-26313][SQL] move `newScanBuilder` from Table to read related mix-in traits ## What changes were proposed in this pull request? As discussed in https://github.com/apache/spark/pull/23208/files#r239684490 , we should put `newScanBuilder` in read related mix-in traits like `SupportsBatchRead`, to support write-only table. In the `Append` operator, we should skip schema validation if not necessary. In the future we would introduce a capability API, so that data source can tell Spark that it doesn't want to do validation. ## How was this patch tested? existing tests. Closes #23266 from cloud-fan/ds-read. 
Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../sql/sources/v2/SupportsBatchRead.java | 8 ++--- .../spark/sql/sources/v2/SupportsRead.java | 35 +++++++++++++++++++ .../apache/spark/sql/sources/v2/Table.java | 15 ++------ 3 files changed, 41 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java index 0df89dbb608a4..6c5a95d2a75b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java @@ -24,10 +24,10 @@ /** * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. *

    - * If a {@link Table} implements this interface, its {@link Table#newScanBuilder(DataSourceOptions)} - * must return a {@link ScanBuilder} that builds {@link Scan} with {@link Scan#toBatch()} - * implemented. + * If a {@link Table} implements this interface, the + * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that + * builds {@link Scan} with {@link Scan#toBatch()} implemented. *

    */ @Evolving -public interface SupportsBatchRead extends Table { } +public interface SupportsBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java new file mode 100644 index 0000000000000..e22738d20d507 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; + +/** + * An internal base interface of mix-in interfaces for readable {@link Table}. This adds + * {@link #newScanBuilder(DataSourceOptions)} that is used to create a scan for batch, micro-batch, + * or continuous processing. + */ +interface SupportsRead extends Table { + + /** + * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this + * method to configure each scan. + */ + ScanBuilder newScanBuilder(DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 0c65fe0f9e76a..08664859b8de2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -18,8 +18,6 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; import org.apache.spark.sql.types.StructType; /** @@ -43,17 +41,8 @@ public interface Table { String name(); /** - * Returns the schema of this table. + * Returns the schema of this table. If the table is not readable and doesn't have a schema, an + * empty schema can be returned here. */ StructType schema(); - - /** - * Returns a {@link ScanBuilder} which can be used to build a {@link Scan} later. Spark will call - * this method for each data scanning query. - *

    - * The builder can take some query specific information to do operators pushdown, and keep these - * information in the created {@link Scan}. - *

    - */ - ScanBuilder newScanBuilder(DataSourceOptions options); } From 14b497863a2336f072ad93e0791cdfc11f3f1004 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 13 Dec 2018 09:07:33 -0800 Subject: [PATCH 067/194] [SPARK-26098][WEBUI] Show associated SQL query in Job page ## What changes were proposed in this pull request? For jobs associated to SQL queries, it would be easier to understand the context to showing the SQL query in Job detail page. Before code change, it is hard to tell what the job is about from the job page: ![image](https://user-images.githubusercontent.com/1097932/48659359-96baa180-ea8a-11e8-8419-a0a87c3f30fc.png) After code change: ![image](https://user-images.githubusercontent.com/1097932/48659390-26f8e680-ea8b-11e8-8fdd-3b58909ea364.png) After navigating to the associated SQL detail page, We can see the whole context : ![image](https://user-images.githubusercontent.com/1097932/48659463-9fac7280-ea8c-11e8-9dfe-244e849f72a5.png) **For Jobs don't have associated SQL query, the text won't be shown.** ## How was this patch tested? Manual test Closes #23068 from gengliangwang/addSQLID. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../apache/spark/status/AppStatusListener.scala | 7 ++++++- .../org/apache/spark/status/AppStatusStore.scala | 7 +++++++ .../scala/org/apache/spark/status/LiveEntity.scala | 5 +++-- .../scala/org/apache/spark/status/storeTypes.scala | 3 ++- .../scala/org/apache/spark/ui/jobs/JobPage.scala | 14 +++++++++++++- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index bd3f58b6182c0..262ff6547faa5 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -70,6 +70,8 @@ private[spark] class AppStatusListener( private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() + + private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id" // Keep the active executor count as a separate variable to avoid having to do synchronization // around liveExecutors. @volatile private var activeExecutorCount = 0 @@ -318,6 +320,8 @@ private[spark] class AppStatusListener( val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") val jobGroup = Option(event.properties) .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + val sqlExecutionId = Option(event.properties) + .flatMap(p => Option(p.getProperty(SQL_EXECUTION_ID_KEY)).map(_.toLong)) val job = new LiveJob( event.jobId, @@ -325,7 +329,8 @@ private[spark] class AppStatusListener( if (event.time > 0) Some(new Date(event.time)) else None, event.stageIds, jobGroup, - numTasks) + numTasks, + sqlExecutionId) liveJobs.put(event.jobId, job) liveUpdate(job, now) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index b35781cb36e81..312bcccb1cca1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -56,6 +56,13 @@ private[spark] class AppStatusStore( store.read(classOf[JobDataWrapper], jobId).info } + // Returns job data and associated SQL execution ID of certain Job ID. 
+ // If there is no related SQL execution, the SQL execution ID part will be None. + def jobWithAssociatedSql(jobId: Int): (v1.JobData, Option[Long]) = { + val data = store.read(classOf[JobDataWrapper], jobId) + (data.info, data.sqlExecutionId) + } + def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { val base = store.view(classOf[ExecutorSummaryWrapper]) val filtered = if (activeOnly) { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 47e45a66ecccb..7f7b83a54d794 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -64,7 +64,8 @@ private class LiveJob( val submissionTime: Option[Date], val stageIds: Seq[Int], jobGroup: Option[String], - numTasks: Int) extends LiveEntity { + numTasks: Int, + sqlExecutionId: Option[Long]) extends LiveEntity { var activeTasks = 0 var completedTasks = 0 @@ -108,7 +109,7 @@ private class LiveJob( skippedStages.size, failedStages, killedSummary) - new JobDataWrapper(info, skippedStages) + new JobDataWrapper(info, skippedStages, sqlExecutionId) } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index ef19e86f3135f..eea47b3b17098 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -68,7 +68,8 @@ private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { */ private[spark] class JobDataWrapper( val info: JobData, - val skippedStages: Set[Int]) { + val skippedStages: Set[Int], + val sqlExecutionId: Option[Long]) { @JsonIgnore @KVIndex private def id: Int = info.jobId diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 55444a2c0c9ab..b58a6ca447edf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -189,7 +189,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val jobId = parameterId.toInt - val jobData = store.asOption(store.job(jobId)).getOrElse { + val (jobData, sqlExecutionId) = store.asOption(store.jobWithAssociatedSql(jobId)).getOrElse { val content =
    No information to display for job {jobId}
    @@ -197,6 +197,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP return UIUtils.headerSparkPage( request, s"Details for Job $jobId", content, parent) } + val isComplete = jobData.status != JobExecutionStatus.RUNNING val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the @@ -278,6 +279,17 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP Status: {jobData.status} + { + if (sqlExecutionId.isDefined) { +
+ Associated SQL Query: + {{sqlExecutionId.get}}
+ } + } { if (jobData.jobGroup.isDefined) {
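The association above relies on the `spark.sql.execution.id` local property that the SQL engine sets before submitting jobs; the new `SQL_EXECUTION_ID_KEY` lookup in `AppStatusListener.onJobStart` reads it back from the job-start event's properties. A minimal sketch of that read path, with illustrative names (`SqlExecutionIdSketch` and `extractExecutionId` are not part of the patch):

```scala
import java.util.Properties

// Illustrative helper only; the patch performs the equivalent extraction inline
// from the SparkListenerJobStart properties.
object SqlExecutionIdSketch {
  val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"

  def extractExecutionId(jobProperties: Properties): Option[Long] =
    Option(jobProperties)
      .flatMap(p => Option(p.getProperty(SQL_EXECUTION_ID_KEY)))
      .map(_.toLong)
}
```

In the patch itself the extracted value is threaded from `LiveJob` through `JobDataWrapper` so that `JobPage` can render the link to the SQL execution page when it is present.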
  • From dde56e4e80f8418020d1f9a53d8462299b88fefb Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 13 Dec 2018 16:12:55 -0800 Subject: [PATCH 068/194] [SPARK-23886][SS] Update query status for ContinuousExecution ## What changes were proposed in this pull request? Added query status updates to ContinuousExecution. ## How was this patch tested? Existing unit tests + added ContinuousQueryStatusAndProgressSuite. Closes #23095 from gaborgsomogyi/SPARK-23886. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../streaming/MicroBatchExecution.scala | 6 ++ .../streaming/ProgressReporter.scala | 1 - .../continuous/ContinuousExecution.scala | 6 ++ .../sql/streaming/StreamingQueryStatus.scala | 6 +- ...ontinuousQueryStatusAndProgressSuite.scala | 55 +++++++++++++++++++ 5 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 64e09edf27f58..03beefeca269b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -147,6 +147,12 @@ class MicroBatchExecution( logInfo(s"Query $prettyIdString was stopped") } + /** Begins recording statistics about query progress for a given trigger. */ + override protected def startTrigger(): Unit = { + super.startTrigger() + currentStatus = currentStatus.copy(isTriggerActive = true) + } + /** * Repeatedly attempts to run batches as data arrives. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 6a22f0cc8431a..39ab702ee083c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -114,7 +114,6 @@ trait ProgressReporter extends Logging { logDebug("Starting Trigger Calculation") lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() - currentStatus = currentStatus.copy(isTriggerActive = true) currentTriggerStartOffsets = null currentTriggerEndOffsets = null currentDurationsMs.clear() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4d42428fd189e..f0859aaaa3041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -118,6 +118,8 @@ class ContinuousExecution( // For at least once, we can just ignore those reports and risk duplicates. 
commitLog.getLatest() match { case Some((latestEpochId, _)) => + updateStatusMessage("Starting new streaming query " + + s"and getting offsets from latest epoch $latestEpochId") val nextOffsets = offsetLog.get(latestEpochId).getOrElse { throw new IllegalStateException( s"Batch $latestEpochId was committed without end epoch offsets!") @@ -129,6 +131,7 @@ class ContinuousExecution( nextOffsets case None => // We are starting this stream for the first time. Offsets are all None. + updateStatusMessage("Starting new streaming query") logInfo(s"Starting new streaming query.") currentBatchId = 0 OffsetSeq.fill(continuousSources.map(_ => null): _*) @@ -263,6 +266,7 @@ class ContinuousExecution( epochUpdateThread.setDaemon(true) epochUpdateThread.start() + updateStatusMessage("Running") reportTimeTaken("runContinuous") { SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution) { @@ -322,6 +326,8 @@ class ContinuousExecution( * before this is called. */ def commit(epoch: Long): Unit = { + updateStatusMessage(s"Committing epoch $epoch") + assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 9dc62b7aac891..6ca9aacab7247 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -28,9 +28,11 @@ import org.apache.spark.annotation.Evolving * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. - * @param isDataAvailable True when there is new data to be processed. + * @param isDataAvailable True when there is new data to be processed. Doesn't apply + * to ContinuousExecution where it is always false. * @param isTriggerActive True when the trigger is actively firing, false when waiting for the - * next trigger time. + * next trigger time. Doesn't apply to ContinuousExecution where it is + * always false. * * @since 2.1.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala new file mode 100644 index 0000000000000..10bea7f090571 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.streaming.Trigger + +class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase { + test("StreamingQueryStatus - ContinuousExecution isDataAvailable and isTriggerActive " + + "should be false") { + import testImplicits._ + + val input = ContinuousMemoryStream[Int] + + def assertStatus(stream: StreamExecution): Unit = { + assert(stream.status.isDataAvailable === false) + assert(stream.status.isTriggerActive === false) + } + + val trigger = Trigger.Continuous(100) + testStream(input.toDF(), useV2Sink = true)( + StartStream(trigger), + Execute(assertStatus), + AddData(input, 0, 1, 2), + Execute(assertStatus), + CheckAnswer(0, 1, 2), + Execute(assertStatus), + StopStream, + Execute(assertStatus), + AddData(input, 3, 4, 5), + Execute(assertStatus), + StartStream(trigger), + Execute(assertStatus), + CheckAnswer(0, 1, 2, 3, 4, 5), + Execute(assertStatus), + StopStream, + Execute(assertStatus)) + } +} From 3fac7d4140c0167018bd6d9dea9633a02beddc25 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 14 Dec 2018 10:45:24 +0800 Subject: [PATCH 069/194] [SPARK-26364][PYTHON][TESTING] Clean up imports in test_pandas_udf* ## What changes were proposed in this pull request? Clean up unconditional import statements and move them to the top. Conditional imports (pandas, numpy, pyarrow) are left as-is. ## How was this patch tested? Exising tests. Closes #23314 from icexelloss/clean-up-test-imports. Authored-by: Li Jin Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_pandas_udf.py | 16 +--- .../sql/tests/test_pandas_udf_grouped_agg.py | 39 +--------- .../sql/tests/test_pandas_udf_grouped_map.py | 40 +++------- .../sql/tests/test_pandas_udf_scalar.py | 75 +++++-------------- .../sql/tests/test_pandas_udf_window.py | 29 +------ 5 files changed, 36 insertions(+), 163 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index c4b5478a7e893..d4d9679649ee9 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -17,12 +17,16 @@ import unittest +from pyspark.sql.functions import udf, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.sql.utils import ParseException +from pyspark.rdd import PythonEvalType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest +from py4j.protocol import Py4JJavaError + @unittest.skipIf( not have_pandas or not have_pyarrow, @@ -30,9 +34,6 @@ class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @@ -65,10 +66,6 @@ def test_pandas_udf_basic(self): self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - from pyspark.sql.types import StructType, StructField, DoubleType - @pandas_udf(DoubleType()) def foo(x): return x @@ -114,8 +111,6 @@ 
def foo(x): self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaises(ParseException): @pandas_udf('blah') @@ -151,9 +146,6 @@ def foo(k, v, w): return k def test_stopiteration_in_udf(self): - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - from py4j.protocol import Py4JJavaError - def foo(x): raise StopIteration() diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 5383704434c85..18264ead2fd08 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -17,6 +17,9 @@ import unittest +from pyspark.rdd import PythonEvalType +from pyspark.sql.functions import array, explode, col, lit, mean, sum, \ + udf, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ @@ -31,7 +34,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))) \ @@ -40,8 +42,6 @@ def data(self): @property def python_plus_one(self): - from pyspark.sql.functions import udf - @udf('double') def plus_one(v): assert isinstance(v, (int, float)) @@ -51,7 +51,6 @@ def plus_one(v): @property def pandas_scalar_plus_two(self): import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.SCALAR) def plus_two(v): @@ -61,8 +60,6 @@ def plus_two(v): @property def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() @@ -70,8 +67,6 @@ def avg(v): @property def pandas_agg_sum_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def sum(v): return v.sum() @@ -80,7 +75,6 @@ def sum(v): @property def pandas_agg_weighted_mean_udf(self): import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUPED_AGG) def weighted_mean(v, w): @@ -88,8 +82,6 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): - from pyspark.sql.functions import pandas_udf, array - df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf @@ -118,8 +110,6 @@ def test_manual(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_basic(self): - from pyspark.sql.functions import col, lit, mean - df = self.data weighted_mean_udf = self.pandas_agg_weighted_mean_udf @@ -150,9 +140,6 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import DoubleType, MapType - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): pandas_udf( @@ -173,8 +160,6 @@ def mean_and_std_udf(v): return {v.mean(): v.std()} def test_alias(self): - from pyspark.sql.functions import mean - df = self.data mean_udf = self.pandas_agg_mean_udf @@ -187,8 
+172,6 @@ def test_mixed_sql(self): """ Test mixing group aggregate pandas UDF with sql expression. """ - from pyspark.sql.functions import sum - df = self.data sum_udf = self.pandas_agg_sum_udf @@ -225,8 +208,6 @@ def test_mixed_udfs(self): """ Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. """ - from pyspark.sql.functions import sum - df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two @@ -292,8 +273,6 @@ def test_multiple_udfs(self): """ Test multiple group aggregate pandas UDFs in one agg function. """ - from pyspark.sql.functions import sum, mean - df = self.data mean_udf = self.pandas_agg_mean_udf sum_udf = self.pandas_agg_sum_udf @@ -315,8 +294,6 @@ def test_multiple_udfs(self): self.assertPandasEqual(expected1, result1) def test_complex_groupby(self): - from pyspark.sql.functions import sum - df = self.data sum_udf = self.pandas_agg_sum_udf plus_one = self.python_plus_one @@ -359,8 +336,6 @@ def test_complex_groupby(self): self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) def test_complex_expressions(self): - from pyspark.sql.functions import col, sum - df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two @@ -434,7 +409,6 @@ def test_complex_expressions(self): self.assertPandasEqual(expected3, result3) def test_retain_group_columns(self): - from pyspark.sql.functions import sum with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -444,8 +418,6 @@ def test_retain_group_columns(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) @@ -453,8 +425,6 @@ def test_array_type(self): self.assertEquals(result1.first()['v2'], [1.0, 2.0]) def test_invalid_args(self): - from pyspark.sql.functions import mean - df = self.data plus_one = self.python_plus_one mean_udf = self.pandas_agg_mean_udf @@ -478,9 +448,6 @@ def test_invalid_args(self): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() def test_register_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - sum_pandas_udf = pandas_udf( lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index a12c608dff9dd..80e70349b78d3 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -18,7 +18,12 @@ import datetime import unittest +from collections import OrderedDict +from decimal import Decimal +from distutils.version import LooseVersion + from pyspark.sql import Row +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -32,16 +37,12 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') def test_supported_types(self): - from decimal 
import Decimal - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType values = [ 1, 2, 3, @@ -131,8 +132,6 @@ def test_supported_types(self): self.assertPandasEqual(expected3, result3) def test_array_type_correct(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") output_schema = StructType( @@ -151,8 +150,6 @@ def test_array_type_correct(self): self.assertPandasEqual(expected, result) def test_register_grouped_map_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -161,7 +158,6 @@ def test_register_grouped_map_udf(self): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data @pandas_udf( @@ -176,7 +172,6 @@ def foo(pdf): self.assertPandasEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( @@ -191,7 +186,6 @@ def test_coerce(self): self.assertPandasEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data @pandas_udf( @@ -210,7 +204,6 @@ def normalize(pdf): self.assertPandasEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data @pandas_udf( @@ -229,7 +222,6 @@ def normalize(pdf): self.assertPandasEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -243,8 +235,6 @@ def test_datatype_string(self): self.assertPandasEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -255,7 +245,6 @@ def test_wrong_return_type(self): PandasUDFType.GROUPED_MAP) def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): @@ -277,9 +266,7 @@ def test_wrong_args(self): pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' unsupported_types = [ @@ -300,7 +287,6 @@ def test_unsupported_types(self): # Regression test for SPARK-23314 def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am dt = [datetime.datetime(2015, 11, 1, 0, 30), datetime.datetime(2015, 11, 1, 1, 30), @@ -311,12 +297,12 @@ def test_timestamp_dst(self): self.assertPandasEqual(df.toPandas(), result.toPandas()) def test_udf_with_key(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType + import numpy as np + df = self.data pdf = df.toPandas() def foo1(key, pdf): - import numpy as np assert type(key) == tuple assert type(key[0]) == np.int64 @@ -326,7 +312,6 @@ def foo1(key, pdf): v4=pdf.v * pdf.id.mean()) def foo2(key, pdf): - import numpy as np assert type(key) == tuple assert type(key[0]) == np.int64 
assert type(key[1]) == np.int32 @@ -385,9 +370,7 @@ def foo3(key, pdf): self.assertPandasEqual(expected4, result4) def test_column_order(self): - from collections import OrderedDict import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType # Helper function to set column names from a list def rename_pdf(pdf, names): @@ -468,7 +451,6 @@ def invalid_positional_types(pdf): with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): grouped_df.apply(column_name_typo).collect() - from distutils.version import LooseVersion import pyarrow as pa if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. @@ -480,7 +462,6 @@ def invalid_positional_types(pdf): def test_positional_assignment_conf(self): import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType with self.sql_conf({ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): @@ -496,9 +477,7 @@ def foo(_): self.assertEqual(r.b, 1) def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + @pandas_udf('key long, col string', PandasUDFType.GROUPED_MAP) def dummy_pandas_udf(df): return df[['key', 'col']] @@ -508,12 +487,11 @@ def dummy_pandas_udf(df): # this was throwing an AnalysisException before SPARK-24208 res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) + col('temp0.key') == col('temp1.key')) self.assertEquals(res.count(), 5) def test_mixed_scalar_udfs_followed_by_grouby_apply(self): import pandas as pd - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType df = self.spark.range(0, 10).toDF('v1') df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 2f585a3725988..6a6865a9fb16d 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -16,12 +16,20 @@ # import datetime import os +import random import shutil import sys import tempfile import time import unittest +from datetime import date, datetime +from decimal import Decimal +from distutils.version import LooseVersion + +from pyspark.rdd import PythonEvalType +from pyspark.sql import Column +from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf from pyspark.sql.types import Row from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException @@ -59,18 +67,16 @@ def tearDownClass(cls): @property def nondeterministic_vectorized_udf(self): - from pyspark.sql.functions import pandas_udf + import pandas as pd + import numpy as np @pandas_udf('double') def random_udf(v): - import pandas as pd - import numpy as np return pd.Series(np.random.random(len(v))) random_udf = random_udf.asNondeterministic() return random_udf def test_pandas_udf_tokenize(self): - from pyspark.sql.functions import pandas_udf tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), ArrayType(StringType())) self.assertEqual(tokenize.returnType, ArrayType(StringType())) @@ -79,7 +85,6 @@ def test_pandas_udf_tokenize(self): self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) def test_pandas_udf_nested_arrays(self): - from pyspark.sql.functions import pandas_udf tokenize = pandas_udf(lambda s: s.apply(lambda 
str: [str.split(' ')]), ArrayType(ArrayType(StringType()))) self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) @@ -88,7 +93,6 @@ def test_pandas_udf_nested_arrays(self): self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -114,9 +118,6 @@ def test_vectorized_udf_basic(self): self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - import random random_pandas_udf = pandas_udf( lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() self.assertEqual(random_pandas_udf.deterministic, False) @@ -129,7 +130,6 @@ def test_register_nondeterministic_vectorized_udf_basic(self): self.assertEqual(row[0], 7) def test_vectorized_udf_null_boolean(self): - from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] schema = StructType().add("bool", BooleanType()) df = self.spark.createDataFrame(data, schema) @@ -138,7 +138,6 @@ def test_vectorized_udf_null_boolean(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_byte(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("byte", ByteType()) df = self.spark.createDataFrame(data, schema) @@ -147,7 +146,6 @@ def test_vectorized_udf_null_byte(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_short(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("short", ShortType()) df = self.spark.createDataFrame(data, schema) @@ -156,7 +154,6 @@ def test_vectorized_udf_null_short(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_int(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("int", IntegerType()) df = self.spark.createDataFrame(data, schema) @@ -165,7 +162,6 @@ def test_vectorized_udf_null_int(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_long(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("long", LongType()) df = self.spark.createDataFrame(data, schema) @@ -174,7 +170,6 @@ def test_vectorized_udf_null_long(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_float(self): - from pyspark.sql.functions import pandas_udf, col data = [(3.0,), (5.0,), (-1.0,), (None,)] schema = StructType().add("float", FloatType()) df = self.spark.createDataFrame(data, schema) @@ -183,7 +178,6 @@ def test_vectorized_udf_null_float(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_double(self): - from pyspark.sql.functions import pandas_udf, col data = [(3.0,), (5.0,), (-1.0,), (None,)] schema = StructType().add("double", DoubleType()) df = self.spark.createDataFrame(data, schema) @@ -192,8 +186,6 @@ def test_vectorized_udf_null_double(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_decimal(self): - from decimal import Decimal - from pyspark.sql.functions import pandas_udf, col data = [(Decimal(3.0),), (Decimal(5.0),), 
(Decimal(-1.0),), (None,)] schema = StructType().add("decimal", DecimalType(38, 18)) df = self.spark.createDataFrame(data, schema) @@ -202,7 +194,6 @@ def test_vectorized_udf_null_decimal(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_string(self): - from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] schema = StructType().add("str", StringType()) df = self.spark.createDataFrame(data, schema) @@ -211,7 +202,6 @@ def test_vectorized_udf_null_string(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_string_in_udf(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd df = self.spark.range(10) str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) @@ -220,7 +210,6 @@ def test_vectorized_udf_string_in_udf(self): self.assertEquals(expected.collect(), actual.collect()) def test_vectorized_udf_datatype_string(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -244,9 +233,8 @@ def test_vectorized_udf_datatype_string(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_binary(self): - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -262,7 +250,6 @@ def test_vectorized_udf_null_binary(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_array_type(self): - from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) df = self.spark.createDataFrame(data, schema=array_schema) @@ -271,7 +258,6 @@ def test_vectorized_udf_array_type(self): self.assertEquals(df.collect(), result.collect()) def test_vectorized_udf_null_array(self): - from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) df = self.spark.createDataFrame(data, schema=array_schema) @@ -280,7 +266,6 @@ def test_vectorized_udf_null_array(self): self.assertEquals(df.collect(), result.collect()) def test_vectorized_udf_complex(self): - from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( col('id').cast('int').alias('a'), col('id').cast('int').alias('b'), @@ -293,7 +278,6 @@ def test_vectorized_udf_complex(self): self.assertEquals(expected.collect(), res.collect()) def test_vectorized_udf_exception(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) with QuietTest(self.sc): @@ -301,8 +285,8 @@ def test_vectorized_udf_exception(self): df.select(raise_exception(col('id'))).collect() def test_vectorized_udf_invalid_length(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + df = self.spark.range(10) raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) with QuietTest(self.sc): @@ -312,7 +296,6 @@ def test_vectorized_udf_invalid_length(self): df.select(raise_exception(col('id'))).collect() def test_vectorized_udf_chained(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) f = pandas_udf(lambda x: x + 1, LongType()) g = 
pandas_udf(lambda x: x - 1, LongType()) @@ -320,7 +303,6 @@ def test_vectorized_udf_chained(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -328,7 +310,6 @@ def test_vectorized_udf_wrong_return_type(self): pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): @@ -336,7 +317,6 @@ def test_vectorized_udf_return_scalar(self): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) @pandas_udf(returnType=LongType()) @@ -346,21 +326,18 @@ def identity(x): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_empty_partition(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda x: x, LongType()) res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_varargs(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda *v: v[0], LongType()) res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -368,8 +345,6 @@ def test_vectorized_udf_unsupported_types(self): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) def test_vectorized_udf_dates(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import date schema = StructType().add("idx", LongType()).add("date", DateType()) data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), @@ -405,8 +380,6 @@ def check_data(idx, date, date_copy): self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime schema = StructType([ StructField("idx", LongType(), True), StructField("timestamp", TimestampType(), True)]) @@ -447,8 +420,8 @@ def check_data(idx, timestamp, timestamp_copy): self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + df = self.spark.range(10) @pandas_udf(returnType=TimestampType()) @@ -465,8 +438,8 @@ def gen_timestamps(id): self.assertEquals(expected, ts) def test_vectorized_udf_check_config(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @@ -479,9 +452,8 @@ def check_records_per_batch(x): self.assertTrue(r <= 3) def test_vectorized_udf_timestamps_respect_session_timezone(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime import pandas as pd + schema = StructType([ StructField("idx", LongType(), True), StructField("timestamp", TimestampType(), True)]) @@ -519,8 +491,6 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): def test_nondeterministic_vectorized_udf(self): # 
Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import pandas_udf, col - @pandas_udf('double') def plus_ten(v): return v + 10 @@ -533,8 +503,6 @@ def plus_ten(v): self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) def test_nondeterministic_vectorized_udf_in_aggregate(self): - from pyspark.sql.functions import sum - df = self.spark.range(10) random_udf = self.nondeterministic_vectorized_udf @@ -545,8 +513,6 @@ def test_nondeterministic_vectorized_udf_in_aggregate(self): df.agg(sum(random_udf(df.id))).collect() def test_register_vectorized_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( col('id').cast('int').alias('a'), col('id').cast('int').alias('b')) @@ -563,11 +529,10 @@ def test_register_vectorized_udf_basic(self): # Regression test for SPARK-23314 def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] + dt = [datetime(2015, 11, 1, 0, 30), + datetime(2015, 11, 1, 1, 30), + datetime(2015, 11, 1, 2, 30)] df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') foo_udf = pandas_udf(lambda x: x, 'timestamp') result = df.withColumn('time', foo_udf(df.time)) @@ -593,7 +558,6 @@ def test_type_annotation(self): def test_mixed_udf(self): import pandas as pd - from pyspark.sql.functions import col, udf, pandas_udf df = self.spark.range(0, 1).toDF('v') @@ -696,8 +660,6 @@ def f4(x): def test_mixed_udf_and_sql(self): import pandas as pd - from pyspark.sql import Column - from pyspark.sql.functions import udf, pandas_udf df = self.spark.range(0, 1).toDF('v') @@ -758,7 +720,6 @@ def test_datasource_with_udf(self): # This needs to a separate test because Arrow dependency is optional import pandas as pd import numpy as np - from pyspark.sql.functions import pandas_udf, lit, col path = tempfile.mkdtemp() shutil.rmtree(path) diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index f0e6d2696df62..0a7a19c1c0814 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -18,6 +18,8 @@ import unittest from pyspark.sql.utils import AnalysisException +from pyspark.sql.functions import array, explode, col, lit, mean, min, max, rank, \ + udf, pandas_udf, PandasUDFType from pyspark.sql.window import Window from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -30,7 +32,6 @@ class WindowPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))) \ @@ -39,18 +40,14 @@ def data(self): @property def python_plus_one(self): - from pyspark.sql.functions import udf return udf(lambda v: v + 1, 'double') @property def pandas_scalar_time_two(self): - from pyspark.sql.functions import pandas_udf return pandas_udf(lambda v: v * 2, 'double') @property def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', 
PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() @@ -58,8 +55,6 @@ def avg(v): @property def pandas_agg_max_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def max(v): return v.max() @@ -67,8 +62,6 @@ def max(v): @property def pandas_agg_min_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def min(v): return v.min() @@ -88,8 +81,6 @@ def unpartitioned_window(self): return Window.partitionBy() def test_simple(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -105,8 +96,6 @@ def test_simple(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_multiple_udfs(self): - from pyspark.sql.functions import max, min, mean - df = self.data w = self.unbounded_window @@ -121,8 +110,6 @@ def test_multiple_udfs(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_replace_existing(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -132,8 +119,6 @@ def test_replace_existing(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_mixed_sql(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window mean_udf = self.pandas_agg_mean_udf @@ -144,8 +129,6 @@ def test_mixed_sql(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_mixed_udf(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -171,8 +154,6 @@ def test_mixed_udf(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_without_partitionBy(self): - from pyspark.sql.functions import mean - df = self.data w = self.unpartitioned_window mean_udf = self.pandas_agg_mean_udf @@ -187,8 +168,6 @@ def test_without_partitionBy(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_mixed_sql_and_udf(self): - from pyspark.sql.functions import max, min, rank, col - df = self.data w = self.unbounded_window ow = self.ordered_window @@ -221,8 +200,6 @@ def test_mixed_sql_and_udf(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data w = self.unbounded_window @@ -231,8 +208,6 @@ def test_array_type(self): self.assertEquals(result1.first()['v2'], [1.0, 2.0]) def test_invalid_args(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data w = self.unbounded_window ow = self.ordered_window From 611a13b645cc503865d34473a4d0e88c71631b46 Mon Sep 17 00:00:00 2001 From: jasonwayne Date: Fri, 14 Dec 2018 10:47:58 +0800 Subject: [PATCH 070/194] [SPARK-26360] remove redundant validateQuery call ## What changes were proposed in this pull request? remove a redundant `KafkaWriter.validateQuery` call in `KafkaSourceProvider ` ## How was this patch tested? Just removing duplicate codes, so I just build and run unit tests. Closes #23309 from JasonWayne/SPARK-26360. 
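The call being removed is redundant presumably because the same check already runs later on the write path. The sketch below shows only the general shape of dropping a duplicated pre-check; the class and method names are simplified placeholders, not the actual Kafka connector structure.

```scala
// Simplified sketch: the provider no longer pre-validates because the write
// support it constructs performs the same check itself, so coverage is unchanged.
object RedundantValidationSketch {
  def validateQuery(topic: Option[String], params: Map[String, String]): Unit =
    require(topic.isDefined || params.contains("topic"), "topic must be specified")

  class StreamingWriteSupport(topic: Option[String], params: Map[String, String]) {
    // The single remaining validation site.
    validateQuery(topic, params)
  }

  def createStreamingWriteSupport(
      topic: Option[String],
      params: Map[String, String]): StreamingWriteSupport = {
    // The outer validateQuery call that used to live here is the one removed.
    new StreamingWriteSupport(topic, params)
  }
}
```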
Authored-by: jasonwayne Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/kafka010/KafkaSourceProvider.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a0c2088ac3d1..4b8b5c0019b44 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -266,8 +266,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - KafkaWriter.validateQuery(schema.toAttributes, producerParams, topic) - new KafkaStreamingWriteSupport(topic, producerParams, schema) } From 41b01075c2cf6bad6f2fc2e080bf66a1e8b159ff Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 14 Dec 2018 10:50:48 +0800 Subject: [PATCH 071/194] [SPARK-26337][SQL][TEST] Add benchmark for LongToUnsafeRowMap ## What changes were proposed in this pull request? Regarding the performance issue of SPARK-26155, it reports the issue on TPC-DS. I think it is better to add a benchmark for `LongToUnsafeRowMap` which is the root cause of performance regression. It can be easier to show performance difference between different metric implementations in `LongToUnsafeRowMap`. ## How was this patch tested? Manually run added benchmark. Closes #23284 from viirya/SPARK-26337. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- ...HashedRelationMetricsBenchmark-results.txt | 11 +++ .../HashedRelationMetricsBenchmark.scala | 84 +++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala diff --git a/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt new file mode 100644 index 0000000000000..338244ad542f4 --- /dev/null +++ b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt @@ -0,0 +1,11 @@ +================================================================================================ +LongToUnsafeRowMap metrics +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz +LongToUnsafeRowMap metrics: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +LongToUnsafeRowMap 234 / 315 2.1 467.3 1.0X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala new file mode 100644 index 0000000000000..bdf753debe62a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap +import org.apache.spark.sql.types.LongType + +/** + * Benchmark to measure metrics performance at HashedRelation. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/HashedRelationMetricsBenchmark-results.txt". + * }}} + */ +object HashedRelationMetricsBenchmark extends SqlBasedBenchmark { + + def benchmarkLongToUnsafeRowMapMetrics(numRows: Int): Unit = { + runBenchmark("LongToUnsafeRowMap metrics") { + val benchmark = new Benchmark("LongToUnsafeRowMap metrics", numRows, output = output) + benchmark.addCase("LongToUnsafeRowMap") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + val keys = Range.Long(0, numRows, 1) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + + val threads = (0 to 100).map { _ => + val thread = new Thread { + override def run: Unit = { + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) == k) + } + } + } + thread.start() + thread + } + threads.map(_.join()) + map.free() + } + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + benchmarkLongToUnsafeRowMapMetrics(500000) + } +} From 2ba82a7c898f1cf09a14571aac66474a391a0891 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2018 20:55:12 -0800 Subject: [PATCH 072/194] [SPARK-26368][SQL] Make it clear that getOrInferFileFormatSchema doesn't create InMemoryFileIndex ## What changes were proposed in this pull request? I was looking at the code and it was a bit difficult to see the life cycle of InMemoryFileIndex passed into getOrInferFileFormatSchema, because once it is passed in, and another time it was created in getOrInferFileFormatSchema. It'd be easier to understand the life cycle if we move the creation of it out. ## How was this patch tested? This is a simple code move and should be covered by existing tests. 
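The refactor follows a common pattern: pass a factory function and bind it to a `lazy val`, so the expensive index is only constructed if schema inference actually needs it. A minimal sketch of that pattern with made-up names (`FileIndex` and `inferSchema` are illustrative, not the Spark API):

```scala
// Sketch of the "factory plus lazy val" pattern used by the change.
object LazyIndexSketch {
  final case class FileIndex(paths: Seq[String])

  def inferSchema(userSchema: Option[String], buildIndex: () => FileIndex): String = {
    lazy val index = buildIndex()  // the expensive listing only runs if this is forced
    userSchema.getOrElse(s"inferred from ${index.paths.size} paths")
  }

  // User schema present: the factory is never invoked.
  val a = inferSchema(Some("a INT"), () => sys.error("never evaluated"))
  // No user schema: the factory runs exactly once to infer the schema.
  val b = inferSchema(None, () => FileIndex(Seq("/data/part-0", "/data/part-1")))
}
```

In the patch, the callers build the `InMemoryFileIndex` at the call site and hand `getOrInferFileFormatSchema` a `() => InMemoryFileIndex`, which keeps the index's life cycle visible to the caller while preserving the lazy behavior.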
Closes #23317 from rxin/SPARK-26368. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../execution/datasources/DataSource.scala | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 795a6d0b6b040..fefff68c4ba8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -122,21 +122,14 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list + * @param getFileIndex [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { - // The operations below are expensive therefore try not to do them if we don't need to, e.g., - // in streaming mode, we have already inferred and registered partition columns, we will - // never have to materialize the lazy val below - lazy val tempFileIndex = fileIndex.getOrElse { - val globbedPaths = - checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) - createInMemoryFileIndex(globbedPaths) - } + getFileIndex: () => InMemoryFileIndex): (StructType, StructType) = { + lazy val tempFileIndex = getFileIndex() val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning @@ -236,7 +229,15 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, () => { + // The operations below are expensive therefore try not to do them if we don't need to, + // e.g., in streaming mode, we have already inferred and registered partition columns, + // we will never have to materialize the lazy val below + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) + }) SourceInfo( s"FileSource[$path]", StructType(dataSchema ++ partitionSchema), @@ -370,7 +371,7 @@ case class DataSource( } else { val index = createInMemoryFileIndex(globbedPaths) val (resultDataSchema, resultPartitionSchema) = - getOrInferFileFormatSchema(format, Some(index)) + getOrInferFileFormatSchema(format, () => index) (index, resultDataSchema, resultPartitionSchema) } From e3b179071aa58d08211fa6fe5a1efa43e4a756b8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 15 Dec 2018 00:23:28 +0800 Subject: [PATCH 073/194] [SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier. ## What changes were proposed in this pull request? 
When using a higher-order function with the same variable name as the existing columns in `Filter` or something which uses `Analyzer.resolveExpressionBottomUp` during the resolution, e.g.,: ```scala val df = Seq( (Seq(1, 9, 8, 7), 1, 2), (Seq(5, 9, 7), 2, 2), (Seq.empty, 3, 2), (null, 4, 2) ).toDF("i", "x", "d") checkAnswer(df.filter("exists(i, x -> x % d == 0)"), Seq(Row(Seq(1, 9, 8, 7), 1, 2))) checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), Seq(Row(1))) ``` the following exception happens: ``` java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.immutable.List.foreach(List.scala:392) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.immutable.List.map(List.scala:298) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185) at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215) ... ``` because the `UnresolvedAttribute`s in `LambdaFunction` are unexpectedly resolved by the rule. This pr modified to use a placeholder `UnresolvedNamedLambdaVariable` to prevent unexpected resolution. ## How was this patch tested? Added a test and modified some tests. Closes #23320 from ueshin/issues/SPARK-26370/hof_resolution. 
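The fix amounts to introducing a placeholder node: lambda parameters are parsed into a type that the generic bottom-up column resolution does not touch, and only the lambda-specific rule binds them or falls back to a real attribute. A toy sketch of that idea, using simplified expression classes rather than the Catalyst ones:

```scala
// Toy model only: Param stands in for UnresolvedNamedLambdaVariable and Attr for
// UnresolvedAttribute. The generic rule rewrites Attr nodes, so a lambda parameter
// named like a table column can no longer be captured by it.
object LambdaPlaceholderSketch {
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Param(name: String) extends Expr
  final case class Lambda(body: Expr, args: Seq[String]) extends Expr

  def resolveColumns(e: Expr, columns: Set[String]): Expr = e match {
    case Attr(n) if columns(n) => Attr(s"col:$n")  // bound to a table column
    case Lambda(body, args)    => Lambda(resolveColumns(body, columns), args)
    case other                 => other            // Param is deliberately left alone
  }
}
```

As the diff that follows shows, the parser now emits the placeholder for lambda arguments and `ResolveLambdaVariables` either binds it to the enclosing lambda's variable or downgrades it to an `UnresolvedAttribute`.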
Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../analysis/higherOrderFunctions.scala | 5 ++-- .../expressions/higherOrderFunctions.scala | 26 +++++++++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 7 +++-- .../ResolveLambdaVariablesSuite.scala | 10 ++++--- ...ReplaceNullWithFalseInPredicateSuite.scala | 14 +++++----- .../parser/ExpressionParserSuite.scala | 6 +++-- .../typeCoercion/native/mapZipWith.sql.out | 4 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++++++++ 8 files changed, 72 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index a8a7bbd9f9cd0..1cd7f412bb678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap)) - case u @ UnresolvedAttribute(name +: nestedFields) => + case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) => parentLambdaMap.get(canonicalizer(name)) match { case Some(lambda) => nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) => ExtractValue(expr, Literal(fieldName), conf.resolver) } - case None => u + case None => + UnresolvedAttribute(u.nameParts) } case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a8639d29f964d..7141b6e996389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods +/** + * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. 
+ */ +case class UnresolvedNamedLambdaVariable(nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") + override lazy val resolved = false + + override def toString: String = s"lambda '$name" + + override def sql: String = name +} + /** * A named lambda variable. */ @@ -79,7 +101,7 @@ case class LambdaFunction( object LambdaFunction { val identity: LambdaFunction = { - val id = UnresolvedAttribute.quoted("id") + val id = UnresolvedNamedLambdaVariable(Seq("id")) LambdaFunction(id, Seq(id)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 672bffcfc0cad..8959f78b656d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { val arguments = ctx.IDENTIFIER().asScala.map { name => - UnresolvedAttribute.quoted(name.getText) + UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) } - LambdaFunction(expression(ctx.expression), arguments) + val function = expression(ctx.expression).transformUp { + case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) + } + LambdaFunction(function, arguments) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala index c4171c75ecd03..a5847ba7c522d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest { comparePlans(Analyzer.execute(plan(e1)), plan(e2)) } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("resolution - no op") { checkExpression(key, key) } test("resolution - simple") { - val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil)) val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) checkExpression(in, out) } test("resolution - nested") { val in = ArrayTransform(values2, LambdaFunction( - ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil)) val out = ArrayTransform(values2, LambdaFunction( ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), 
lvArray :: Nil)) checkExpression(in, out) @@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest { test("fail - name collisions") { val p = plan(ArrayTransform(values1, - LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("arguments should not have names that are semantically the same")) } test("fail - lambda arguments") { val p = plan(ArrayTransform(values1, - LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("does not match the number of arguments expected")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ee0d04da3e46c..748075bfd6a68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("replace nulls in lambda function of ArrayFilter") { - testHigherOrderFunc('a, ArrayFilter, Seq('e)) + testHigherOrderFunc('a, ArrayFilter, Seq(lv('e))) } test("replace nulls in lambda function of ArrayExists") { - testHigherOrderFunc('a, ArrayExists, Seq('e)) + testHigherOrderFunc('a, ArrayExists, Seq(lv('e))) } test("replace nulls in lambda function of MapFilter") { - testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v))) } test("inability to replace nulls in arbitrary higher-order function") { val lambdaFunc = LambdaFunction( - function = If('e > 0, Literal(null, BooleanType), TrueLiteral), - arguments = Seq[NamedExpression]('e)) + function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression](lv('e))) val column = ArrayTransform('a, lambdaFunc) testProjection(originalExpr = column, expectedExpr = column) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4df22c5b29fa..8bcc69d580d83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest { intercept("foo(a x)", "extraneous input 'x'") } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("lambda functions") { - assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr))) - assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr))) + assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x)))) + assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y)))) } test("window function expressions") { diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index 35740094ba53e..86a578ca013df 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -85,7 +85,7 @@ FROM various_maps struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 -- !query 6 @@ -113,7 +113,7 @@ FROM various_maps struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query 9 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e6d1a038a5918..b7fc9570af919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext { } assert(ex.getMessage.contains("Cannot use null as map key")) } + + test("SPARK-26370: Fix resolution of higher-order function for the same identifier") { + val df = Seq( + (Seq(1, 9, 8, 7), 1, 2), + (Seq(5, 9, 7), 2, 2), + (Seq.empty, 3, 2), + (null, 4, 2) + ).toDF("i", "x", "d") + + checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"), + Seq( + Row(1, true), + Row(2, false), + Row(3, false), + Row(4, null))) + checkAnswer(df.filter("exists(i, x -> x % d == 0)"), + Seq(Row(Seq(1, 9, 8, 7), 1, 2))) + checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), + Seq(Row(1))) + } } object DataFrameFunctionsSuite { From 736df41874eacad27329f8eb2f445c36c3576606 Mon Sep 17 00:00:00 2001 From: CarolinPeng <00244106@zte.intra> Date: Fri, 14 Dec 2018 14:23:21 -0600 Subject: [PATCH 074/194] [MINOR][SQL] Some errors in the notes. ## What changes were proposed in this pull request? When using ordinals to access linked list, the time cost is O(n). ## How was this patch tested? Existing tests. Closes #23280 from CarolinePeng/update_Two. Authored-by: CarolinPeng <00244106@zte.intra> Signed-off-by: Sean Owen --- .../org/apache/spark/sql/catalyst/expressions/package.scala | 2 +- .../apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 0083ee64653e9..bf18e8bcb52df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -101,7 +101,7 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) } - // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when + // It's possible that `attrs` is a linked list, which can lead to bad O(n) loops when // accessing attributes by their ordinals. To avoid this performance penalty, convert the input // to an array. @transient private lazy val attrsArray = attrs.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a520eba001af1..3ad2ee6923615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -93,7 +93,7 @@ abstract class LogicalPlan /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as - * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. + * string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolveChildren( nameParts: Seq[String], From ff2a8b8816dced0b42dece1349c1272d33f04021 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 15 Dec 2018 13:52:07 +0800 Subject: [PATCH 075/194] [SPARK-26265][CORE][FOLLOWUP] Put freePage into a finally block ## What changes were proposed in this pull request? Based on the [comment](https://github.com/apache/spark/pull/23272#discussion_r240735509), it seems to be better to put `freePage` into a `finally` block. This patch as a follow-up to do so. 
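A minimal, self-contained sketch of the resulting shape (in Scala for brevity; the actual change is to the Java method `MapIterator.advanceToNextPage` shown in the diff below, and the names here are simplified placeholders):

```scala
object FreeInFinallySketch {
  private val lock = new Object
  private def freePage(page: Array[Byte]): Unit = println(s"freed ${page.length} bytes")

  def advanceToNextPage(currentPage: Array[Byte]): Unit = {
    var pageToFree: Array[Byte] = null
    try {
      lock.synchronized {
        // remember the page to release while holding the iterator lock ...
        pageToFree = currentPage
        // ... and advance to the next page, which may throw
      }
    } finally {
      // free outside the lock, and even if advancing failed
      if (pageToFree != null) {
        freePage(pageToFree)
      }
    }
  }
}
```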
## How was this patch tested? Existing tests. Closes #23294 from viirya/SPARK-26265-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Hyukjin Kwon --- .../spark/unsafe/map/BytesToBytesMap.java | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index fbba002f1f80f..7df8aafb2b674 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -262,36 +262,39 @@ private void advanceToNextPage() { // reference to the page to free and free it after releasing the lock of `MapIterator`. MemoryBlock pageToFree = null; - synchronized (this) { - int nextIdx = dataPages.indexOf(currentPage) + 1; - if (destructive && currentPage != null) { - dataPages.remove(currentPage); - pageToFree = currentPage; - nextIdx --; - } - if (dataPages.size() > nextIdx) { - currentPage = dataPages.get(nextIdx); - pageBaseObject = currentPage.getBaseObject(); - offsetInPage = currentPage.getBaseOffset(); - recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); - offsetInPage += UnsafeAlignedOffset.getUaoSize(); - } else { - currentPage = null; - if (reader != null) { - handleFailedDelete(); + try { + synchronized (this) { + int nextIdx = dataPages.indexOf(currentPage) + 1; + if (destructive && currentPage != null) { + dataPages.remove(currentPage); + pageToFree = currentPage; + nextIdx--; } - try { - Closeables.close(reader, /* swallowIOException = */ false); - reader = spillWriters.getFirst().getReader(serializerManager); - recordsInPage = -1; - } catch (IOException e) { - // Scala iterator does not handle exception - Platform.throwException(e); + if (dataPages.size() > nextIdx) { + currentPage = dataPages.get(nextIdx); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); + offsetInPage += UnsafeAlignedOffset.getUaoSize(); + } else { + currentPage = null; + if (reader != null) { + handleFailedDelete(); + } + try { + Closeables.close(reader, /* swallowIOException = */ false); + reader = spillWriters.getFirst().getReader(serializerManager); + recordsInPage = -1; + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } } } - } - if (pageToFree != null) { - freePage(pageToFree); + } finally { + if (pageToFree != null) { + freePage(pageToFree); + } } } From aa584728ff0c03fbd3c52ab289af7cb1456f726a Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 15 Dec 2018 13:55:24 +0800 Subject: [PATCH 076/194] [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts ## What changes were proposed in this pull request? Multiple SparkContexts are discouraged and it has been warning for last 4 years, see SPARK-4180. It could cause arbitrary and mysterious error cases, see SPARK-2243. Honestly, I didn't even know Spark still allows it, which looks never officially supported, see SPARK-2243. I believe It should be good timing now to remove this configuration. ## How was this patch tested? Each doc was manually checked and manually tested: ``` $ ./bin/spark-shell --conf=spark.driver.allowMultipleContexts=true ... 
scala> new SparkContext() org.apache.spark.SparkException: Only one SparkContext should be running in this JVM (see SPARK-2243).The currently running SparkContext was created at: org.apache.spark.sql.SparkSession$Builder.getOrCreate(SparkSession.scala:939) ... org.apache.spark.SparkContext$.$anonfun$assertNoOtherContextIsRunning$2(SparkContext.scala:2435) at scala.Option.foreach(Option.scala:274) at org.apache.spark.SparkContext$.assertNoOtherContextIsRunning(SparkContext.scala:2432) at org.apache.spark.SparkContext$.markPartiallyConstructed(SparkContext.scala:2509) at org.apache.spark.SparkContext.(SparkContext.scala:80) at org.apache.spark.SparkContext.(SparkContext.scala:112) ... 49 elided ``` Closes #23311 from HyukjinKwon/SPARK-26362. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/SparkContext.scala | 65 +++++++------------ .../spark/api/java/JavaSparkContext.scala | 4 +- .../org/apache/spark/SparkContextSuite.scala | 19 +----- .../ExternalClusterManagerSuite.scala | 3 +- docs/rdd-programming-guide.md | 2 +- project/MimaExcludes.scala | 4 ++ python/pyspark/context.py | 3 + .../execution/ExchangeCoordinatorSuite.scala | 1 - 8 files changed, 34 insertions(+), 67 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 696dafda6d1ec..09cc346db0ed2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -64,9 +64,8 @@ import org.apache.spark.util.logging.DriverLogger * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * - * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before - * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. - * + * @note Only one `SparkContext` should be active per JVM. You must `stop()` the + * active `SparkContext` before creating a new one. * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ @@ -75,14 +74,10 @@ class SparkContext(config: SparkConf) extends Logging { // The call site where this SparkContext was constructed. private val creationSite: CallSite = Utils.getCallSite() - // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active - private val allowMultipleContexts: Boolean = - config.getBoolean("spark.driver.allowMultipleContexts", false) - // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having started construction. // NOTE: this must be placed at the beginning of the SparkContext constructor. - SparkContext.markPartiallyConstructed(this, allowMultipleContexts) + SparkContext.markPartiallyConstructed(this) val startTime = System.currentTimeMillis() @@ -2392,7 +2387,7 @@ class SparkContext(config: SparkConf) extends Logging { // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction. // NOTE: this must be placed at the end of the SparkContext constructor. 
- SparkContext.setActiveContext(this, allowMultipleContexts) + SparkContext.setActiveContext(this) } /** @@ -2409,18 +2404,18 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. + * Access to this field is guarded by `SPARK_CONTEXT_CONSTRUCTOR_LOCK`. */ private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** - * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * Points to a partially-constructed SparkContext if another thread is in the SparkContext * constructor, or `None` if no SparkContext is being constructed. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by `SPARK_CONTEXT_CONSTRUCTOR_LOCK`. */ private var contextBeingConstructed: Option[SparkContext] = None @@ -2428,24 +2423,16 @@ object SparkContext extends Logging { * Called to ensure that no other SparkContext is running in this JVM. * * Throws an exception if a running context is detected and logs a warning if another thread is - * constructing a SparkContext. This warning is necessary because the current locking scheme + * constructing a SparkContext. This warning is necessary because the current locking scheme * prevents us from reliably distinguishing between cases where another context is being * constructed and cases where another constructor threw an exception. */ - private def assertNoOtherContextIsRunning( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private def assertNoOtherContextIsRunning(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { Option(activeContext.get()).filter(_ ne sc).foreach { ctx => - val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + - " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + val errMsg = "Only one SparkContext should be running in this JVM (see SPARK-2243)." + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" - val exception = new SparkException(errMsg) - if (allowMultipleContexts) { - logWarning("Multiple running SparkContexts detected in the same JVM!", exception) - } else { - throw exception - } + throw new SparkException(errMsg) } contextBeingConstructed.filter(_ ne sc).foreach { otherContext => @@ -2454,7 +2441,7 @@ object SparkContext extends Logging { val otherContextCreationSite = Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + - " constructor). This may indicate an error, since only one SparkContext may be" + + " constructor). This may indicate an error, since only one SparkContext should be" + " running in this JVM (see SPARK-2243)." + s" The other SparkContext was created at:\n$otherContextCreationSite" logWarning(warnMsg) @@ -2467,8 +2454,6 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * @note This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. 
* @param config `SparkConfig` that will be used for initialisation of the `SparkContext` * @return current `SparkContext` (or a new one if it wasn't created before the function call) */ @@ -2477,7 +2462,7 @@ object SparkContext extends Logging { // from assertNoOtherContextIsRunning within setActiveContext SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { if (activeContext.get() == null) { - setActiveContext(new SparkContext(config), allowMultipleContexts = false) + setActiveContext(new SparkContext(config)) } else { if (config.getAll.nonEmpty) { logWarning("Using an existing SparkContext; some configuration may not take effect.") @@ -2494,14 +2479,12 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * @note This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. * @return current `SparkContext` (or a new one if wasn't created before the function call) */ def getOrCreate(): SparkContext = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { if (activeContext.get() == null) { - setActiveContext(new SparkContext(), allowMultipleContexts = false) + setActiveContext(new SparkContext()) } activeContext.get() } @@ -2516,16 +2499,14 @@ object SparkContext extends Logging { /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is - * running. Throws an exception if a running context is detected and logs a warning if another - * thread is constructing a SparkContext. This warning is necessary because the current locking + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking * scheme prevents us from reliably distinguishing between cases where another context is being * constructed and cases where another constructor threw an exception. */ - private[spark] def markPartiallyConstructed( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private[spark] def markPartiallyConstructed(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - assertNoOtherContextIsRunning(sc, allowMultipleContexts) + assertNoOtherContextIsRunning(sc) contextBeingConstructed = Some(sc) } } @@ -2534,18 +2515,16 @@ object SparkContext extends Logging { * Called at the end of the SparkContext constructor to ensure that no other SparkContext has * raced with this constructor and started. */ - private[spark] def setActiveContext( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private[spark] def setActiveContext(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - assertNoOtherContextIsRunning(sc, allowMultipleContexts) + assertNoOtherContextIsRunning(sc) contextBeingConstructed = None activeContext.set(sc) } } /** - * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's * also called in unit tests to prevent a flood of warnings from test suites that don't / can't * properly clean up their SparkContexts. 
*/ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 03f259d73e975..2f74d09b3a2bc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -40,8 +40,8 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD} * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. * - * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before - * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * @note Only one `SparkContext` should be active per JVM. You must `stop()` the + * active `SparkContext` before creating a new one. */ class JavaSparkContext(val sc: SparkContext) extends Closeable { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ec4c7efb5835a..66de2f2ac86a4 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -44,7 +44,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 val conf = new SparkConf().setAppName("test").setMaster("local") - .set("spark.driver.allowMultipleContexts", "false") sc = new SparkContext(conf) val envBefore = SparkEnv.get // A SparkContext is already running, so we shouldn't be able to create a second one @@ -58,7 +57,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test("Can still construct a new SparkContext after failing to construct a previous one") { - val conf = new SparkConf().set("spark.driver.allowMultipleContexts", "false") + val conf = new SparkConf() // This is an invalid configuration (no app name or master URL) intercept[SparkException] { new SparkContext(conf) @@ -67,18 +66,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(conf.setMaster("local").setAppName("test")) } - test("Check for multiple SparkContexts can be disabled via undocumented debug option") { - var secondSparkContext: SparkContext = null - try { - val conf = new SparkConf().setAppName("test").setMaster("local") - .set("spark.driver.allowMultipleContexts", "true") - sc = new SparkContext(conf) - secondSparkContext = new SparkContext(conf) - } finally { - Option(secondSparkContext).foreach(_.stop()) - } - } - test("Test getOrCreate") { var sc2: SparkContext = null SparkContext.clearActiveContext() @@ -92,10 +79,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc === sc2) assert(sc eq sc2) - // Try creating second context to confirm that it's still possible, if desired - sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") - .set("spark.driver.allowMultipleContexts", "true")) - sc2.stop() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index e8e8632fe7f8a..17c0a262056f4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -25,8 +25,7 @@ import org.apache.spark.util.AccumulatorV2 class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext { test("launch of backend and scheduler") { - val conf = new SparkConf().setMaster("myclusterManager"). - setAppName("testcm").set("spark.driver.allowMultipleContexts", "true") + val conf = new SparkConf().setMaster("myclusterManager").setAppName("testcm") sc = new SparkContext(conf) // check if the scheduler components are created and initialized sc.schedulerBackend match { diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2d1ddae5780de..308a8ea653909 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -138,7 +138,7 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. -Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. +Only one SparkContext should be active per JVM. You must `stop()` the active SparkContext before creating a new one. {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 883913332ca1e..7bb70a29195d6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -220,6 +220,10 @@ object MimaExcludes { // [SPARK-26139] Implement shuffle write metrics in SQL ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"), + // [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.setActiveContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.markPartiallyConstructed"), + // Data Source V2 API changes (problem: Problem) => problem match { case MissingClassProblem(cls) => diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1180bf91baa5a..6137ed25a0dd9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -63,6 +63,9 @@ class SparkContext(object): Main entry point for Spark functionality. A SparkContext represents the connection to a Spark cluster, and can be used to create L{RDD} and broadcast variables on that cluster. + + .. note:: Only one :class:`SparkContext` should be active per JVM. You must `stop()` + the active :class:`SparkContext` before creating a new one. 
""" _gateway = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 6ad025f37e440..4a439940beb74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -263,7 +263,6 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .setMaster("local[*]") .setAppName("test") .set("spark.ui.enabled", "false") - .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") From 6fce1afee174a7a834b12b286bc3392e41e35244 Mon Sep 17 00:00:00 2001 From: Jing Chen He Date: Sat, 15 Dec 2018 08:41:16 -0600 Subject: [PATCH 077/194] [SPARK-26315][PYSPARK] auto cast threshold from Integer to Float in approxSimilarityJoin of BucketedRandomProjectionLSHModel ## What changes were proposed in this pull request? If the input parameter 'threshold' to the function approxSimilarityJoin is not a float, we would get an exception. The fix is to convert the 'threshold' into a float before calling the java implementation method. ## How was this patch tested? Added a new test case. Without this fix, the test will throw an exception as reported in the JIRA. With the fix, the test passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23313 from jerryjch/SPARK-26315. Authored-by: Jing Chen He Signed-off-by: Sean Owen --- python/pyspark/ml/feature.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c9507c20918e3..08ae58246adb6 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -192,6 +192,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol") "datasetA" and "datasetB", and a column "distCol" is added to show the distance between each pair. """ + threshold = TypeConverters.toFloat(threshold) return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol) @@ -239,6 +240,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp | 3| 6| 2.23606797749979| +---+---+-----------------+ ... + >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select( + ... col("datasetA.id").alias("idA"), + ... col("datasetB.id").alias("idB"), + ... col("EuclideanDistance")).show() + +---+---+-----------------+ + |idA|idB|EuclideanDistance| + +---+---+-----------------+ + | 3| 6| 2.23606797749979| + +---+---+-----------------+ + ... >>> brpPath = temp_path + "/brp" >>> brp.save(brpPath) >>> brp2 = BucketedRandomProjectionLSH.load(brpPath) From ecd5aa1db8542214483b8aa51f63b6fed79fe36b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 16 Dec 2018 09:32:13 +0800 Subject: [PATCH 078/194] [SPARK-26243][SQL] Use java.time API for parsing timestamps and dates from JSON ## What changes were proposed in this pull request? In the PR, I propose to switch on **java.time API** for parsing timestamps and dates from JSON inputs with microseconds precision. The SQL config `spark.sql.legacy.timeParser.enabled` allow to switch back to previous behavior with using `java.text.SimpleDateFormat`/`FastDateFormat` for parsing/generating timestamps/dates. 
## How was this patch tested? It was tested by `JsonExpressionsSuite`, `JsonFunctionsSuite` and `JsonSuite`. Closes #23196 from MaxGekk/json-time-parser. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 +- .../sql/catalyst/csv/CSVInferSchema.scala | 6 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 8 +- .../sql/catalyst/csv/UnivocityParser.scala | 6 +- .../spark/sql/catalyst/json/JSONOptions.scala | 10 +- .../sql/catalyst/json/JacksonGenerator.scala | 14 +- .../sql/catalyst/json/JacksonParser.scala | 35 +-- ...rmatter.scala => TimestampFormatter.scala} | 93 ++++---- .../sql/util/DateTimeFormatterSuite.scala | 103 --------- .../util/DateTimestampFormatterSuite.scala | 174 +++++++++++++++ .../datasources/json/JsonSuite.scala | 201 ++++++++++-------- .../sql/sources/HadoopFsRelationTest.scala | 105 ++++----- 12 files changed, 422 insertions(+), 335 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/{DateTimeFormatter.scala => TimestampFormatter.scala} (63%) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 8834e8991d8c3..115fc6516fb4c 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -35,7 +35,7 @@ displayTitle: Spark SQL Upgrading Guide - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. - - Since Spark 3.0, CSV datasource uses java.time API for parsing and generating CSV content. New formatting implementation supports date/timestamp patterns conformed to ISO 8601. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpuse with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 345dc4d41993e..35ade136cc607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,13 +22,13 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.DateTimeFormatter +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timeParser = DateTimeFormatter( + private lazy val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) @@ -160,7 +160,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. - if ((allCatch opt timeParser.parse(field)).isDefined) { + if ((allCatch opt timestampParser.parse(field)).isDefined) { TimestampType } else { tryParseBoolean(field) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index af09cd6c8449b..f012d96138f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -22,7 +22,7 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter} +import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.types._ class UnivocityGenerator( @@ -41,18 +41,18 @@ class UnivocityGenerator( private val valueConverters: Array[ValueConverter] = schema.map(_.dataType).map(makeConverter).toArray - private val timeFormatter = DateTimeFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => (row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal)) case TimestampType => - (row: InternalRow, ordinal: Int) => timeFormatter.format(row.getLong(ordinal)) + (row: InternalRow, ordinal: Int) => timestampFormatter.format(row.getLong(ordinal)) case udt: UserDefinedType[_] => makeConverter(udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 0f375e036029c..ed089120055e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -74,11 +74,11 @@ class UnivocityParser( private val row = new 
GenericInternalRow(requiredSchema.length) - private val timeFormatter = DateTimeFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -158,7 +158,7 @@ class UnivocityParser( } case _: TimestampType => (d: String) => - nullSafeDatum(d, name, nullable, options)(timeFormatter.parse) + nullSafeDatum(d, name, nullable, options)(timestampFormatter.parse) case _: DateType => (d: String) => nullSafeDatum(d, name, nullable, options)(dateFormatter.parse) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index e10b8a327c01a..eaff3fa7bec25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -21,7 +21,6 @@ import java.nio.charset.{Charset, StandardCharsets} import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} -import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ @@ -82,13 +81,10 @@ private[sql] class JSONOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. - val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) + val dateFormat: String = parameters.getOrElse("dateFormat", "yyyy-MM-dd") - val timestampFormat: FastDateFormat = - FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) + val timestampFormat: String = + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index d02a2be8ddad6..951f5190cd504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -23,7 +23,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ /** @@ -77,6 +77,12 @@ private[sql] class JacksonGenerator( private val lineSeparator: String = options.lineSeparatorInWrite + private val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -116,14 +122,12 @@ private[sql] class 
JacksonGenerator( case TimestampType => (row: SpecializedGetters, ordinal: Int) => - val timestampString = - options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + val timestampString = timestampFormatter.format(row.getLong(ordinal)) gen.writeString(timestampString) case DateType => (row: SpecializedGetters, ordinal: Int) => - val dateString = - options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + val dateString = dateFormatter.format(row.getInt(ordinal)) gen.writeString(dateString) case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7e3bd4df51bb7..3f245e1400fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -55,6 +55,12 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) + private val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -218,17 +224,7 @@ class JacksonParser( case TimestampType => (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_STRING if parser.getTextLength >= 1 => - val stringValue = parser.getText - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Long.box { - Try(options.timestampFormat.parse(stringValue).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(stringValue).getTime * 1000L - } - } + timestampFormatter.parse(parser.getText) case VALUE_NUMBER_INT => parser.getLongValue * 1000000L @@ -237,22 +233,7 @@ class JacksonParser( case DateType => (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_STRING if parser.getTextLength >= 1 => - val stringValue = parser.getText - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Int.box { - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime)) - .orElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime)) - } - .getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. 
- stringValue.toInt - } - } + dateFormatter.parse(parser.getText) } case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala similarity index 63% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index ad1f4131de2f6..2b8d22dde9267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.time._ import java.time.format.DateTimeFormatterBuilder -import java.time.temporal.{ChronoField, TemporalQueries} +import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} import java.util.{Locale, TimeZone} import scala.util.Try @@ -28,31 +28,44 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait DateTimeFormatter { +sealed trait TimestampFormatter { def parse(s: String): Long // returns microseconds since epoch def format(us: Long): String } -class Iso8601DateTimeFormatter( +trait FormatterUtils { + protected def zoneId: ZoneId + protected def buildFormatter( + pattern: String, + locale: Locale): java.time.format.DateTimeFormatter = { + new DateTimeFormatterBuilder() + .appendPattern(pattern) + .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) + .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) + .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) + .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) + .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) + .toFormatter(locale) + } + protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor): java.time.Instant = { + val localDateTime = LocalDateTime.from(temporalAccessor) + val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) + Instant.from(zonedDateTime) + } +} + +class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends DateTimeFormatter { - val formatter = new DateTimeFormatterBuilder() - .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) - .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) - .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) - .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) - .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) - .toFormatter(locale) + locale: Locale) extends TimestampFormatter with FormatterUtils { + val zoneId = timeZone.toZoneId + val formatter = buildFormatter(pattern, locale) def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) if (temporalAccessor.query(TemporalQueries.offset()) == null) { - val localDateTime = LocalDateTime.from(temporalAccessor) - val zonedDateTime = ZonedDateTime.of(localDateTime, timeZone.toZoneId) - Instant.from(zonedDateTime) + toInstantWithZoneId(temporalAccessor) } else { Instant.from(temporalAccessor) } @@ -75,10 +88,10 @@ class Iso8601DateTimeFormatter( } } -class LegacyDateTimeFormatter( +class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends DateTimeFormatter { + locale: Locale) extends TimestampFormatter { val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: 
String): Long = format.parse(s).getTime @@ -90,21 +103,21 @@ class LegacyDateTimeFormatter( } } -class LegacyFallbackDateTimeFormatter( +class LegacyFallbackTimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends LegacyDateTimeFormatter(pattern, timeZone, locale) { + locale: Locale) extends LegacyTimestampFormatter(pattern, timeZone, locale) { override def toMillis(s: String): Long = { Try {super.toMillis(s)}.getOrElse(DateTimeUtils.stringToTime(s).getTime) } } -object DateTimeFormatter { - def apply(format: String, timeZone: TimeZone, locale: Locale): DateTimeFormatter = { +object TimestampFormatter { + def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateTimeFormatter(format, timeZone, locale) + new LegacyFallbackTimestampFormatter(format, timeZone, locale) } else { - new Iso8601DateTimeFormatter(format, timeZone, locale) + new Iso8601TimestampFormatter(format, timeZone, locale) } } } @@ -116,13 +129,19 @@ sealed trait DateFormatter { class Iso8601DateFormatter( pattern: String, - timeZone: TimeZone, - locale: Locale) extends DateFormatter { + locale: Locale) extends DateFormatter with FormatterUtils { + + val zoneId = ZoneId.of("UTC") + + val formatter = buildFormatter(pattern, locale) - val dateTimeFormatter = new Iso8601DateTimeFormatter(pattern, timeZone, locale) + def toInstant(s: String): Instant = { + val temporalAccessor = formatter.parse(s) + toInstantWithZoneId(temporalAccessor) + } override def parse(s: String): Int = { - val seconds = dateTimeFormatter.toInstant(s).getEpochSecond + val seconds = toInstant(s).getEpochSecond val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) days.toInt @@ -130,15 +149,12 @@ class Iso8601DateFormatter( override def format(days: Int): String = { val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) - dateTimeFormatter.formatter.withZone(timeZone.toZoneId).format(instant) + formatter.withZone(zoneId).format(instant) } } -class LegacyDateFormatter( - pattern: String, - timeZone: TimeZone, - locale: Locale) extends DateFormatter { - val format = FastDateFormat.getInstance(pattern, timeZone, locale) +class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { + val format = FastDateFormat.getInstance(pattern, locale) def parse(s: String): Int = { val milliseconds = format.parse(s).getTime @@ -153,8 +169,7 @@ class LegacyDateFormatter( class LegacyFallbackDateFormatter( pattern: String, - timeZone: TimeZone, - locale: Locale) extends LegacyDateFormatter(pattern, timeZone, locale) { + locale: Locale) extends LegacyDateFormatter(pattern, locale) { override def parse(s: String): Int = { Try(super.parse(s)).orElse { // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards @@ -169,11 +184,11 @@ class LegacyFallbackDateFormatter( } object DateFormatter { - def apply(format: String, timeZone: TimeZone, locale: Locale): DateFormatter = { + def apply(format: String, locale: Locale): DateFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateFormatter(format, timeZone, locale) + new LegacyFallbackDateFormatter(format, locale) } else { - new Iso8601DateFormatter(format, timeZone, locale) + new Iso8601DateFormatter(format, locale) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala deleted file mode 
100644 index 02d4ee0490604..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.util - -import java.util.{Locale, TimeZone} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter, DateTimeTestUtils} - -class DateTimeFormatterSuite extends SparkFunSuite { - test("parsing dates using time zones") { - val localDate = "2018-12-02" - val expectedDays = Map( - "UTC" -> 17867, - "PST" -> 17867, - "CET" -> 17866, - "Africa/Dakar" -> 17867, - "America/Los_Angeles" -> 17867, - "Antarctica/Vostok" -> 17866, - "Asia/Hong_Kong" -> 17866, - "Europe/Amsterdam" -> 17866) - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) - val daysSinceEpoch = formatter.parse(localDate) - assert(daysSinceEpoch === expectedDays(timeZone)) - } - } - - test("parsing timestamps using time zones") { - val localDate = "2018-12-02T10:11:12.001234" - val expectedMicros = Map( - "UTC" -> 1543745472001234L, - "PST" -> 1543774272001234L, - "CET" -> 1543741872001234L, - "Africa/Dakar" -> 1543745472001234L, - "America/Los_Angeles" -> 1543774272001234L, - "Antarctica/Vostok" -> 1543723872001234L, - "Asia/Hong_Kong" -> 1543716672001234L, - "Europe/Amsterdam" -> 1543741872001234L) - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateTimeFormatter( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - TimeZone.getTimeZone(timeZone), - Locale.US) - val microsSinceEpoch = formatter.parse(localDate) - assert(microsSinceEpoch === expectedMicros(timeZone)) - } - } - - test("format dates using time zones") { - val daysSinceEpoch = 17867 - val expectedDate = Map( - "UTC" -> "2018-12-02", - "PST" -> "2018-12-01", - "CET" -> "2018-12-02", - "Africa/Dakar" -> "2018-12-02", - "America/Los_Angeles" -> "2018-12-01", - "Antarctica/Vostok" -> "2018-12-02", - "Asia/Hong_Kong" -> "2018-12-02", - "Europe/Amsterdam" -> "2018-12-02") - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) - val date = formatter.format(daysSinceEpoch) - assert(date === expectedDate(timeZone)) - } - } - - test("format timestamps using time zones") { - val microsSinceEpoch = 1543745472001234L - val expectedTimestamp = Map( - "UTC" -> "2018-12-02T10:11:12.001234", - "PST" -> "2018-12-02T02:11:12.001234", - "CET" -> "2018-12-02T11:11:12.001234", - "Africa/Dakar" -> "2018-12-02T10:11:12.001234", - "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", - "Antarctica/Vostok" -> 
"2018-12-02T16:11:12.001234", - "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", - "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateTimeFormatter( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - TimeZone.getTimeZone(timeZone), - Locale.US) - val timestamp = formatter.format(microsSinceEpoch) - assert(timestamp === expectedTimestamp(timeZone)) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala new file mode 100644 index 0000000000000..43e348c7eebf4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.{Locale, TimeZone} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { + test("parsing dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val daysSinceEpoch = formatter.parse("2018-12-02") + assert(daysSinceEpoch === 17867) + } + } + } + + test("parsing timestamps using time zones") { + val localDate = "2018-12-02T10:11:12.001234" + val expectedMicros = Map( + "UTC" -> 1543745472001234L, + "PST" -> 1543774272001234L, + "CET" -> 1543741872001234L, + "Africa/Dakar" -> 1543745472001234L, + "America/Los_Angeles" -> 1543774272001234L, + "Antarctica/Vostok" -> 1543723872001234L, + "Asia/Hong_Kong" -> 1543716672001234L, + "Europe/Amsterdam" -> 1543741872001234L) + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = TimestampFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val microsSinceEpoch = formatter.parse(localDate) + assert(microsSinceEpoch === expectedMicros(timeZone)) + } + } + + test("format dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(17867) + assert(date === "2018-12-02") + } + } + } + + test("format timestamps using time zones") { + val microsSinceEpoch = 1543745472001234L + val expectedTimestamp = Map( + "UTC" -> "2018-12-02T10:11:12.001234", + "PST" -> "2018-12-02T02:11:12.001234", + "CET" -> "2018-12-02T11:11:12.001234", + 
"Africa/Dakar" -> "2018-12-02T10:11:12.001234", + "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", + "Antarctica/Vostok" -> "2018-12-02T16:11:12.001234", + "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", + "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = TimestampFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val timestamp = formatter.format(microsSinceEpoch) + assert(timestamp === expectedTimestamp(timeZone)) + } + } + + test("roundtrip timestamp -> micros -> timestamp using timezones") { + Seq( + -58710115316212000L, + -18926315945345679L, + -9463427405253013L, + -244000001L, + 0L, + 99628200102030L, + 1543749753123456L, + 2177456523456789L, + 11858049903010203L).foreach { micros => + DateTimeTestUtils.outstandingTimezones.foreach { timeZone => + val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US) + val timestamp = formatter.format(micros) + val parsed = formatter.parse(timestamp) + assert(micros === parsed) + } + } + } + + test("roundtrip micros -> timestamp -> micros using timezones") { + Seq( + "0109-07-20T18:38:03.788000", + "1370-04-01T10:00:54.654321", + "1670-02-11T14:09:54.746987", + "1969-12-31T23:55:55.999999", + "1970-01-01T00:00:00.000000", + "1973-02-27T02:30:00.102030", + "2018-12-02T11:22:33.123456", + "2039-01-01T01:02:03.456789", + "2345-10-07T22:45:03.010203").foreach { timestamp => + DateTimeTestUtils.outstandingTimezones.foreach { timeZone => + val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US) + val micros = formatter.parse(timestamp) + val formatted = formatter.format(micros) + assert(timestamp === formatted) + } + } + } + + test("roundtrip date -> days -> date") { + Seq( + "0050-01-01", + "0953-02-02", + "1423-03-08", + "1969-12-31", + "1972-08-25", + "1975-09-26", + "2018-12-12", + "2038-01-01", + "5010-11-17").foreach { date => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val days = formatter.parse(date) + val formatted = formatter.format(days) + assert(date === formatted) + } + } + } + } + + test("roundtrip days -> date -> days") { + Seq( + -701265, + -371419, + -199722, + -1, + 0, + 967, + 2094, + 17877, + 24837, + 1110657).foreach { days => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(days) + val parsed = formatter.parse(date) + assert(days === parsed) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3330de3584ebb..786335b42e3cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -57,14 +57,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } val factory = new JsonFactory() - def enforceCorrectType(value: Any, dataType: DataType): Any = { + def enforceCorrectType( + value: Any, + dataType: DataType, + options: Map[String, String] = Map.empty): Any = { val writer = new StringWriter() 
Utils.tryWithResource(factory.createGenerator(writer)) { generator => generator.writeObject(value) generator.flush() } - val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") + val dummyOption = new JSONOptions(options, SQLConf.get.sessionLocalTimeZone) val dummySchema = StructType(Seq.empty) val parser = new JacksonParser(dummySchema, dummyOption, allowArrayAsStructs = true) @@ -96,19 +99,27 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), - enforceCorrectType(strTime, TimestampType)) + checkTypePromotion( + expected = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"))) val strDate = "2014-10-15" checkTypePromotion( DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), - enforceCorrectType(ISO8601Time1, TimestampType)) + enforceCorrectType( + ISO8601Time1, + TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss.SX"))) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), - enforceCorrectType(ISO8601Time2, TimestampType)) + enforceCorrectType( + ISO8601Time2, + TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ssXXX"))) val ISO8601Date = "1970-01-01" checkTypePromotion(DateTimeUtils.millisToDays(32400000), @@ -1440,103 +1451,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("backward compatibility") { - // This test we make sure our JSON support can read JSON data generated by previous version - // of Spark generated through toJSON method and JSON data source. - // The data is generated by the following program. - // Here are a few notes: - // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) - // in the JSON object. - // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to - // JSON objects generated by those Spark versions (col17). - // - If the type is NullType, we do not write data out. - - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. 
+ val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new TestUDT.MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new TestUDT.MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) - val constantValues = - Seq( - "a string in binary".getBytes(StandardCharsets.UTF_8), - null, - true, - 1.toByte, - 2.toShort, - 3, - Long.MaxValue, - 0.25.toFloat, - 0.75, - new java.math.BigDecimal(s"1234.23456"), - new java.math.BigDecimal(s"1.23456"), - java.sql.Date.valueOf("2015-01-01"), - java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), - Seq(2, 3, 4), - Map("a string" -> 2000L), - Row(4.75.toFloat, Seq(false, true)), - new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) - val data = - Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil + val constantValues = + Seq( + "a string in binary".getBytes(StandardCharsets.UTF_8), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil - // Data generated by previous versions. - // scalastyle:off - val existingJSONData = + // Data generated by previous versions. 
+ // scalastyle:off + val existingJSONData = """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil - // scalastyle:on - - // Generate data for the current version. 
- val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) - withTempPath { path => - df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) - // df.toJSON will convert internal rows to external rows first and then generate - // JSON objects. While, df.write.format("json") will write internal rows directly. - val allJSON = + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = existingJSONData ++ df.toJSON.collect() ++ sparkContext.textFile(path.getCanonicalPath).collect() - Utils.deleteRecursively(path) - sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) - - // Read data back with the schema specified. 
- val col0Values = - Seq( - "Spark 1.2.2", - "Spark 1.3.1", - "Spark 1.3.1", - "Spark 1.4.1", - "Spark 1.4.1", - "Spark 1.5.0", - "Spark 1.5.0", - "Spark " + spark.sparkContext.version, - "Spark " + spark.sparkContext.version) - val expectedResult = col0Values.map { v => - Row.fromSeq(Seq(v) ++ constantValues) + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. + val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + spark.sparkContext.version, + "Spark " + spark.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + spark.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) } - checkAnswer( - spark.read.format("json").schema(schema).load(path.getCanonicalPath), - expectedResult - ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 6075f2c8877d6..f0f62b608785d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.io.File +import java.util.TimeZone import scala.util.Random @@ -125,56 +126,62 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } else { Seq(false) } - for (dataType <- supportedDataTypes) { - for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { - val extraMessage = if (isParquetDataSource) { - s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled" - } else { - "" - } - logInfo(s"Testing $dataType data type$extraMessage") - - val extraOptions = Map[String, String]( - "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString - ) - - withTempPath { file => - val path = file.getCanonicalPath - - val dataGenerator = RandomDataGenerator.forType( - dataType = dataType, - nullable = true, - new Random(System.nanoTime()) - ).getOrElse { - fail(s"Failed to create data generator for schema $dataType") + // TODO: Support new parser too, see SPARK-26374. + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + for (dataType <- supportedDataTypes) { + for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { + val extraMessage = if (isParquetDataSource) { + s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled" + } else { + "" + } + logInfo(s"Testing $dataType data type$extraMessage") + + val extraOptions = Map[String, String]( + "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString + ) + + withTempPath { file => + val path = file.getCanonicalPath + + val seed = System.nanoTime() + withClue(s"Random data generated with the seed: ${seed}") { + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + new Random(seed) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. 
+ val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = + spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .options(extraOptions) + .save(path) + + val loadedDF = spark + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .options(extraOptions) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } } - - // Create a DF for the schema with random data. The index field is used to sort the - // DataFrame. This is a workaround for SPARK-10591. - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", dataType, nullable = true) - val rdd = - spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - - df.write - .mode("overwrite") - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .options(extraOptions) - .save(path) - - val loadedDF = spark - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .schema(df.schema) - .options(extraOptions) - .load(path) - .orderBy("index") - - checkAnswer(loadedDF, df) } } } From ffbc6b12369580ae92dca3a0891d1774e43a6eec Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 16 Dec 2018 10:57:11 +0800 Subject: [PATCH 079/194] [SPARK-26078][SQL] Dedup self-join attributes on IN subqueries ## What changes were proposed in this pull request? When there is a self-join as result of a IN subquery, the join condition may be invalid, resulting in trivially true predicates and return wrong results. The PR deduplicates the subquery output in order to avoid the issue. ## How was this patch tested? added UT Closes #23057 from mgaido91/SPARK-26078. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/subquery.scala | 99 ++++++++++++------- .../org/apache/spark/sql/SubquerySuite.scala | 37 +++++++ 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index e9b7a8b76e683..34840c6c977a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -43,31 +43,53 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { + + private def buildJoin( + outerPlan: LogicalPlan, + subplan: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Join = { + // Deduplicate conflicting attributes if any. 
+ val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition) + Join(outerPlan, dedupSubplan, joinType, condition) + } + + private def dedupSubqueryOnSelfJoin( + outerPlan: LogicalPlan, + subplan: LogicalPlan, + valuesOpt: Option[Seq[Expression]], + condition: Option[Expression] = None): LogicalPlan = { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should - // de-duplicate conflicting attributes. We don't use transformation here because we only - // care about the most top join converted from correlated predicate subquery. - case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) => - val duplicates = right.outputSet.intersect(left.outputSet) - if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = right.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val newRight = Project(aliasedExpressions, right) - val newJoinCond = joinCond.map { condExpr => - condExpr transform { - case a: Attribute => aliasMap.getOrElse(a, a).toAttribute + // de-duplicate conflicting attributes. + // SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer + // values. In this case, the resulting join would contain trivially true conditions (eg. + // id#3 = id#3) which cannot be de-duplicated after. In this method, if there are conflicting + // attributes in the join condition, the subquery's conflicting attributes are changed using + // a projection which aliases them and resolves the problem. + val outerReferences = valuesOpt.map(values => + AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty) + val outerRefs = outerPlan.outputSet ++ outerReferences + val duplicates = outerRefs.intersect(subplan.outputSet) + if (duplicates.nonEmpty) { + condition.foreach { e => + val conflictingAttrs = e.references.intersect(duplicates) + if (conflictingAttrs.nonEmpty) { + throw new AnalysisException("Found conflicting attributes " + + s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " + + s"$outerPlan\nand subplan:\n $subplan") } - } - Join(left, newRight, joinType, newJoinCond) - } else { - j } - case _ => joinPlan + val rewrites = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = subplan.output.map { ref => + rewrites.getOrElse(ref, ref) + } + Project(aliasedExpressions, subplan) + } else { + subplan + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -85,17 +107,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { withSubquery.foldLeft(newFilter) { case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) + buildJoin(outerPlan, sub, LeftSemi, joinCond) case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - // Deduplicate conflicting attributes if any. 
- dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) + buildJoin(outerPlan, sub, LeftAnti, joinCond) case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => - val inConditions = values.zip(sub.output).map(EqualTo.tupled) - val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) + val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, newSub, LeftSemi, joinCond) case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive @@ -103,7 +124,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = values.zip(sub.output).map(EqualTo.tupled) + + // Deduplicate conflicting attributes if any. + val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) - // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) + Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) @@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { e transformUp { case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - // Deduplicate conflicting attributes if any. - newPlan = dedupJoin( - Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) + newPlan = + buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = values.zip(sub.output).map(EqualTo.tupled) - val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. 
- newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) + val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions) exists } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5088821ad7361..c95c52f1d3a9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(subqueries.length == 1) } } + + test("SPARK-26078: deduplicate fake self joins for IN subqueries") { + withTempView("a", "b") { + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") + + val df1 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b"))) + val df2 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b"))) + val df3 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR + |c.id IN (SELECT id FROM b WHERE num = 3) + """.stripMargin) + checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b"))) + } + } } From 4c2af74a0accd5a516e4c8a4180b0c570bbde49c Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 16 Dec 2018 11:02:00 +0800 Subject: [PATCH 080/194] [SPARK-26372][SQL] Don't reuse value from previous row when parsing bad CSV input field ## What changes were proposed in this pull request? CSV parsing accidentally uses the previous good value for a bad input field. See example in Jira. This PR ensures that the associated column is set to null when an input field cannot be converted. ## How was this patch tested? Added new test. Ran all SQL unit tests (testOnly org.apache.spark.sql.*). Ran pyspark tests for pyspark-sql Closes #23323 from bersprockets/csv-bad-field. 
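For reference, a minimal user-level sketch of the behavior this patch addresses, mirroring the `bad_after_good.csv` fixture and the CSV test added below; the file path and SparkSession setup here are illustrative, not part of the change:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Schema matching the new test: a string column followed by a date column.
val schema = StructType(
  StructField("col1", StringType) ::
  StructField("col2", DateType) :: Nil)

// The fixture contains a good date followed by a malformed one:
//   "good record",1999-08-01
//   "bad record",1999-088-01
val rows = spark.read
  .schema(schema)
  .csv("/tmp/bad_after_good.csv")   // illustrative path

// Before the fix, the second row could silently reuse 1999-08-01 (the previous
// row's good value) for col2; with the fix, the unconvertible field is null
// under the default PERMISSIVE mode.
rows.show()
```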
Authored-by: Bruce Robbins Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/csv/UnivocityParser.scala | 1 + .../resources/test-data/bad_after_good.csv | 2 ++ .../execution/datasources/csv/CSVSuite.scala | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/bad_after_good.csv diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index ed089120055e2..82a5b3c302b18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -239,6 +239,7 @@ class UnivocityParser( } catch { case NonFatal(e) => badRecordException = badRecordException.orElse(Some(e)) + row.setNullAt(i) } i += 1 } diff --git a/sql/core/src/test/resources/test-data/bad_after_good.csv b/sql/core/src/test/resources/test-data/bad_after_good.csv new file mode 100644 index 0000000000000..4621a7d23714d --- /dev/null +++ b/sql/core/src/test/resources/test-data/bad_after_good.csv @@ -0,0 +1,2 @@ +"good record",1999-08-01 +"bad record",1999-088-01 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 3b977d74053e6..d9e5d7af19671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -63,6 +63,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val datesFile = "test-data/dates.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" + private val badAfterGoodFile = "test-data/bad_after_good.csv" /** Verifies data and schema. */ private def verifyCars( @@ -2012,4 +2013,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(!files.exists(_.getName.endsWith("csv"))) } } + + test("Do not reuse last good value for bad input field") { + val schema = StructType( + StructField("col1", StringType) :: + StructField("col2", DateType) :: + Nil + ) + val rows = spark.read + .schema(schema) + .format("csv") + .load(testFile(badAfterGoodFile)) + + val expectedRows = Seq( + Row("good record", java.sql.Date.valueOf("1999-08-01")), + Row("bad record", null)) + + checkAnswer(rows, expectedRows) + } } From f2a56a68ccc8257477553b6e772c7225c25541a4 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 17 Dec 2018 08:24:51 +0800 Subject: [PATCH 081/194] [SPARK-26248][SQL] Infer date type from CSV ## What changes were proposed in this pull request? The `CSVInferSchema` class is extended to support inferring of `DateType` from CSV input. The attempt to infer `DateType` is performed after inferring `TimestampType`. ## How was this patch tested? Added new test for inferring date types from CSV . It was also tested by existing suites like `CSVInferSchemaSuite`, `CsvExpressionsSuite`, `CsvFunctionsSuite` and `CsvSuite`. Closes #23202 from MaxGekk/csv-date-inferring. 
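As a sketch of the intended inference order (a timestamp pattern is tried first, then a date), essentially restating the `CSVInferSchemaSuite` cases added below; this assumes the internal `CSVOptions`/`CSVInferSchema` API exactly as used in that suite:

```scala
import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types._

// With both formats configured, a value matching timestampFormat is still
// inferred as TimestampType, while a bare date string now becomes DateType
// instead of falling through to StringType.
val options = new CSVOptions(
  Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"),
  columnPruning = false,
  defaultTimeZoneId = "GMT")
val inferSchema = new CSVInferSchema(options)

assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType)
assert(inferSchema.inferField(NullType, "2018-12-03") == DateType)
```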
Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/CSVInferSchema.scala | 20 +++++++++++++++---- .../catalyst/csv/CSVInferSchemaSuite.scala | 18 +++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 35ade136cc607..11f3740d99a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,16 +22,20 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.TimestampFormatter +import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timestampParser = TimestampFormatter( + private lazy val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) + @transient + private lazy val dateFormatter = DateFormatter( + options.dateFormat, + options.locale) private val decimalParser = { ExprUtils.getDecimalParser(options.locale) @@ -104,6 +108,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) + case DateType => tryParseDate(field) case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => @@ -159,9 +164,16 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } private def tryParseTimestamp(field: String): DataType = { - // This case infers a custom `dataFormat` is set. 
- if ((allCatch opt timestampParser.parse(field)).isDefined) { + if ((allCatch opt timestampFormatter.parse(field)).isDefined) { TimestampType + } else { + tryParseDate(field) + } + } + + private def tryParseDate(field: String): DataType = { + if ((allCatch opt dateFormatter.parse(field)).isDefined) { + DateType } else { tryParseBoolean(field) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index c2b525ad1a9f8..84b2e616a4426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -187,4 +187,22 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } + + test("inferring date type") { + var options = new CSVOptions(Map("dateFormat" -> "yyyy/MM/dd"), false, "GMT") + var inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "2018/12/02") == DateType) + + options = new CSVOptions(Map("dateFormat" -> "MMM yyyy"), false, "GMT") + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "Dec 2018") == DateType) + + options = new CSVOptions( + Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + columnPruning = false, + defaultTimeZoneId = "GMT") + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType) + assert(inferSchema.inferField(NullType, "2018-12-03") == DateType) + } } From 51a1cbbc2a81a607a03ecf17eb17fa17f9909099 Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Sun, 16 Dec 2018 17:11:58 -0800 Subject: [PATCH 082/194] [MINOR][DOCS] Fix the "not found: value Row" error on the "programmatic_schema" example ## What changes were proposed in this pull request? Print `import org.apache.spark.sql.Row` of `SparkSQLExample.scala` on the `programmatic_schema` example to fix the `not found: value Row` error on it. ``` scala> val rowRDD = peopleRDD.map(_.split(",")).map(attributes => Row(attributes(0), attributes(1).trim)) :28: error: not found: value Row val rowRDD = peopleRDD.map(_.split(",")).map(attributes => Row(attributes(0), attributes(1).trim)) ``` ## How was this patch tested? NA Closes #23326 from kjmrknsn/fix-sql-getting-started. 
Authored-by: Keiji Yoshida Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/examples/sql/SparkSQLExample.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 958361a6684c5..678cbc64aff1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.examples.sql +// $example on:programmatic_schema$ import org.apache.spark.sql.Row +// $example off:programmatic_schema$ // $example on:init_session$ import org.apache.spark.sql.SparkSession // $example off:init_session$ From c26df2b6456decd8092e2c0cdc91e83deae6916a Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 17 Dec 2018 11:53:14 +0800 Subject: [PATCH 083/194] Revert "[SPARK-26248][SQL] Infer date type from CSV" This reverts commit 5217f7b2263c7aaeadf60ef602776bb3777269cd. --- .../sql/catalyst/csv/CSVInferSchema.scala | 20 ++++--------------- .../catalyst/csv/CSVInferSchemaSuite.scala | 18 ----------------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 11f3740d99a72..35ade136cc607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,20 +22,16 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timestampFormatter = TimestampFormatter( + private lazy val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - @transient - private lazy val dateFormatter = DateFormatter( - options.dateFormat, - options.locale) private val decimalParser = { ExprUtils.getDecimalParser(options.locale) @@ -108,7 +104,6 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) - case DateType => tryParseDate(field) case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => @@ -164,16 +159,9 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } private def tryParseTimestamp(field: String): DataType = { - if ((allCatch opt timestampFormatter.parse(field)).isDefined) { + // This case infers a custom `dataFormat` is set. 
+ if ((allCatch opt timestampParser.parse(field)).isDefined) { TimestampType - } else { - tryParseDate(field) - } - } - - private def tryParseDate(field: String): DataType = { - if ((allCatch opt dateFormatter.parse(field)).isDefined) { - DateType } else { tryParseBoolean(field) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 84b2e616a4426..c2b525ad1a9f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -187,22 +187,4 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } - - test("inferring date type") { - var options = new CSVOptions(Map("dateFormat" -> "yyyy/MM/dd"), false, "GMT") - var inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "2018/12/02") == DateType) - - options = new CSVOptions(Map("dateFormat" -> "MMM yyyy"), false, "GMT") - inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "Dec 2018") == DateType) - - options = new CSVOptions( - Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), - columnPruning = false, - defaultTimeZoneId = "GMT") - inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType) - assert(inferSchema.inferField(NullType, "2018-12-03") == DateType) - } } From 90c9bd5d0c1c6ab46deee2d6d0384265fad63f4f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 17 Dec 2018 13:41:20 +0800 Subject: [PATCH 084/194] [SPARK-26352][SQL] join reorder should not change the order of output attributes ## What changes were proposed in this pull request? The optimizer rule `org.apache.spark.sql.catalyst.optimizer.ReorderJoin` performs join reordering on inner joins. This was introduced from SPARK-12032 (https://github.com/apache/spark/pull/10073) in 2015-12. After it had reordered the joins, though, it didn't check whether or not the output attribute order is still the same as before. Thus, it's possible to have a mismatch between the reordered output attributes order vs the schema that a DataFrame thinks it has. The same problem exists in the CBO version of join reordering (`CostBasedJoinReorder`) too. 
This can be demonstrated with the example: ```scala spark.sql("create table table_a (x int, y int) using parquet") spark.sql("create table table_b (i int, j int) using parquet") spark.sql("create table table_c (a int, b int) using parquet") val df = spark.sql(""" with df1 as (select * from table_a cross join table_b) select * from df1 join table_c on a = x and b = i """) ``` here's what the DataFrame thinks: ``` scala> df.printSchema root |-- x: integer (nullable = true) |-- y: integer (nullable = true) |-- i: integer (nullable = true) |-- j: integer (nullable = true) |-- a: integer (nullable = true) |-- b: integer (nullable = true) ``` here's what the optimized plan thinks, after join reordering: ``` scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}")) |-- x: integer |-- y: integer |-- a: integer |-- b: integer |-- i: integer |-- j: integer ``` If we exclude the `ReorderJoin` rule (using Spark 2.4's optimizer rule exclusion feature), it's back to normal: ``` scala> spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ReorderJoin") scala> val df = spark.sql("with df1 as (select * from table_a cross join table_b) select * from df1 join table_c on a = x and b = i") df: org.apache.spark.sql.DataFrame = [x: int, y: int ... 4 more fields] scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}")) |-- x: integer |-- y: integer |-- i: integer |-- j: integer |-- a: integer |-- b: integer ``` Note that this output attribute ordering problem leads to data corruption, and can manifest itself in various symptoms: * Silently corrupting data, if the reordered columns happen to either have matching types or have sufficiently-compatible types (e.g. all fixed length primitive types are considered as "sufficiently compatible" in an `UnsafeRow`), then only the resulting data is going to be wrong but it might not trigger any alarms immediately. Or * Weird Java-level exceptions like `java.lang.NegativeArraySizeException`, or even SIGSEGVs. ## How was this patch tested? Added new unit test in `JoinReorderSuite` and new end-to-end test in `JoinSuite`. Also made `JoinReorderSuite` and `StarJoinReorderSuite` assert more strongly on maintaining output attribute order. Closes #23303 from rednaxelafx/fix-join-reorder. 
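For comparison, with the projection this patch injects whenever reordering changes the output order, the same reproduction should report the column order the DataFrame declares; the output below is a sketch of the expected result rather than a captured run:

```
scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}"))
|-- x: integer
|-- y: integer
|-- i: integer
|-- j: integer
|-- a: integer
|-- b: integer
```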
Authored-by: Kris Mok Signed-off-by: Wenchen Fan --- .../optimizer/CostBasedJoinReorder.scala | 10 +++++ .../spark/sql/catalyst/optimizer/joins.scala | 12 +++++- .../optimizer/JoinOptimizationSuite.scala | 3 ++ .../catalyst/optimizer/JoinReorderSuite.scala | 38 +++++++++++++++++-- .../StarJoinCostBasedReorderSuite.scala | 21 +++++++++- .../optimizer/StarJoinReorderSuite.scala | 28 ++++++++++++-- .../org/apache/spark/sql/JoinSuite.scala | 14 +++++++ 7 files changed, 116 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 064ca68b7a628..01634a9d852c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -48,6 +48,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { if projectList.forall(_.isInstanceOf[Attribute]) => reorder(p, p.output) } + // After reordering is finished, convert OrderedJoin back to Join result transformDown { case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) @@ -175,11 +176,20 @@ object JoinReorderDP extends PredicateHelper with Logging { assert(topOutputSet == p.outputSet) // Keep the same order of final output attributes. p.copy(projectList = output) + case finalPlan if !sameOutput(finalPlan, output) => + Project(output, finalPlan) case finalPlan => finalPlan } } + private def sameOutput(plan: LogicalPlan, expectedOutput: Seq[Attribute]): Boolean = { + val thisOutput = plan.output + thisOutput.length == expectedOutput.length && thisOutput.zip(expectedOutput).forall { + case (a1, a2) => a1.semanticEquals(a2) + } + } + /** Find all possible plans at the next level, based on existing levels. */ private def searchLevel( existingLevels: Seq[JoinPlanMap], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 6ebb194d71c2e..0b6471289a471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -86,9 +86,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ExtractFiltersAndInnerJoins(input, conditions) + case p @ ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { + val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) @@ -99,6 +99,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } else { createOrderedJoin(input, conditions) } + + if (p.sameOutput(reordered)) { + reordered + } else { + // Reordering the joins have changed the order of the columns. + // Inject a projection to make sure we restore to the expected ordering. 
+ Project(p.output, reordered) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index ccd9d8dd4d213..e9438b2eee550 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -102,16 +102,19 @@ class JoinOptimizationSuite extends PlanTest { x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), x.join(z, condition = Some("x.b".attr === "z.b".attr)) .join(y, condition = Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ), ( x.join(y, Cross).join(z, Cross) .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), x.join(z, Cross, Some("x.b".attr === "z.b".attr)) .join(y, Cross, Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ), ( x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr), x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 565b0a10154a8..c94a8b9e318f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} @@ -124,7 +124,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { // the original order (t1 J t2) J t3. 
val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -139,7 +140,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) // this is redundant but we'll take it for now .join(t4) + .select(outputsOf(t1, t2, t4, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -202,6 +205,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + .select(outputsOf(t1, t4, t2, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -219,6 +223,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } + test("SPARK-26352: join reordering should not change the order of attributes") { + // This test case does not rely on CBO. + // It's similar to the test case above, but catches a reordering bug that the one above doesn't + val tab1 = LocalRelation('x.int, 'y.int) + val tab2 = LocalRelation('i.int, 'j.int) + val tab3 = LocalRelation('a.int, 'b.int) + val original = + tab1.join(tab2, Cross) + .join(tab3, Inner, Some('a === 'x && 'b === 'i)) + val expected = + tab1.join(tab3, Inner, Some('a === 'x)) + .join(tab2, Cross, Some('b === 'i)) + .select(outputsOf(tab1, tab2, tab3): _*) + + assertEqualPlans(original, expected) + } + test("reorder recursively") { // Original order: // Join @@ -266,8 +287,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { - val optimized = Optimize.execute(originalPlan.analyze) + val analyzed = originalPlan.analyze + val optimized = Optimize.execute(analyzed) val expected = groundTruthBestPlan.analyze + + assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect + assert(analyzed.sameOutput(optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index d4d23ad69b2c2..baae934e1e4fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -218,6 +218,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + .select(outputsOf(f1, t1, t2, d1, d2): _*) assertEqualPlans(query, expected) } @@ -256,6 +257,7 @@ class StarJoinCostBasedReorderSuite extends 
PlanTest with StatsEstimationTestBas .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + .select(outputsOf(d1, t1, t2, f1, d2, t3): _*) assertEqualPlans(query, expected) } @@ -297,6 +299,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + .select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*) assertEqualPlans(query, expected) } @@ -347,6 +350,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + .select(outputsOf(d1, t3, t4, f1, d2, t5, t6, d3, t1, t2): _*) assertEqualPlans(query, expected) } @@ -375,6 +379,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .select(outputsOf(d1, d2, f1, d3): _*) assertEqualPlans(query, expected) } @@ -400,13 +405,27 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + .select(outputsOf(t1, f1, t2, t3): _*) assertEqualPlans(query, expected) } private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val optimized = Optimize.execute(plan1.analyze) + val analyzed = plan1.analyze + val optimized = Optimize.execute(analyzed) val expected = plan2.analyze + + assert(equivalentOutput(analyzed, expected)) // if this fails, the expected itself is incorrect + assert(equivalentOutput(analyzed, optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } + + private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + normalizeExprIds(plan1).output == normalizeExprIds(plan2).output + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 4e0883e91e84a..9dc653b9d6c44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -182,6 +182,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d2, f1, d3, s3): _*) assertEqualPlans(query, expected) } @@ -220,6 +221,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === 
nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -255,7 +257,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) - + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -292,6 +294,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -395,6 +398,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f11, f1, d2, s3): _*) assertEqualPlans(query, equivQuery) } @@ -430,6 +434,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -465,6 +470,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -499,6 +505,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -532,6 +539,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -565,13 +573,27 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } - private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val optimized = Optimize.execute(plan1.analyze) + private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val analyzed = plan1.analyze + val optimized = Optimize.execute(analyzed) val expected = plan2.analyze + + assert(equivalentOutput(analyzed, expected)) // if 
this fails, the expected itself is incorrect + assert(equivalentOutput(analyzed, optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } + + private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + normalizeExprIds(plan1).output == normalizeExprIds(plan2).output + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index aa2162c9d2cda..91445c8d96d85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -895,4 +895,18 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row(0, 0, 0)) } } + + test("SPARK-26352: join reordering should not change the order of columns") { + withTable("tab1", "tab2", "tab3") { + spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1") + spark.sql("select 42 as i, 200 as j").write.saveAsTable("tab2") + spark.sql("select 1 as a, 42 as b").write.saveAsTable("tab3") + + val df = spark.sql(""" + with tmp as (select * from tab1 cross join tab2) + select * from tmp join tab3 on a = x and b = i + """) + checkAnswer(df, Row(1, 100, 42, 200, 1, 42)) + } + } } From 33de7df15cec6f63f5b7ff4fc8f9d44839ba60ff Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 16 Dec 2018 23:40:06 -0800 Subject: [PATCH 085/194] [SPARK-26327][SQL][FOLLOW-UP] Refactor the code and restore the metrics name ## What changes were proposed in this pull request? - The original comment about `updateDriverMetrics` is not right. - Refactor the code to ensure `selectedPartitions ` has been set before sending the driver-side metrics. - Restore the original name, which is more general and extendable. ## How was this patch tested? The existing tests. Closes #23328 from gatorsmile/followupSpark-26142. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/execution/DataSourceScanExec.scala | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index c0fa4e777b49c..322ffffca564b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} @@ -167,14 +167,26 @@ case class FileSourceScanExec( partitionSchema = relation.partitionSchema, relation.sparkSession.sessionState.conf) - private var fileListingTime = 0L + val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has + * been initialized. See SPARK-26327 for more details. 
+ */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = relation.location.listFiles(partitionFilters, dataFilters) + driverMetrics("numFiles") = ret.map(_.files.size.toLong).sum val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 - fileListingTime = timeTakenMs + driverMetrics("metadataTime") = timeTakenMs ret } @@ -286,8 +298,6 @@ case class FileSourceScanExec( } private lazy val inputRDD: RDD[InternalRow] = { - // Update metrics for taking effect in both code generation node and normal node. - updateDriverMetrics() val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -298,12 +308,14 @@ case class FileSourceScanExec( options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - relation.bucketSpec match { + val readRDD = relation.bucketSpec match { case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled => createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation) case _ => createNonBucketedReadRDD(readFile, selectedPartitions, relation) } + sendDriverMetrics() + readRDD } override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -313,7 +325,7 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), - "fileListingTime" -> SQLMetrics.createMetric(sparkContext, "file listing time (ms)"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { @@ -504,19 +516,6 @@ case class FileSourceScanExec( } } - /** - * Send the updated metrics to driver, while this function calling, selectedPartitions has - * been initialized. See SPARK-26327 for more detail. - */ - private def updateDriverMetrics() = { - metrics("numFiles").add(selectedPartitions.map(_.files.size.toLong).sum) - metrics("fileListingTime").add(fileListingTime) - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, - metrics("numFiles") :: metrics("fileListingTime") :: Nil) - } - override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, From a1c97b5446eb32a5b583e3f0be5123856c491b8c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 17 Dec 2018 00:13:51 -0800 Subject: [PATCH 086/194] [SPARK-20636] Add the rule TransposeWindow to the optimization batch ## What changes were proposed in this pull request? This PR is a follow-up of the PR https://github.com/apache/spark/pull/17899. It is to add the rule TransposeWindow the optimizer batch. ## How was this patch tested? The existing tests. Closes #23222 from gatorsmile/followupSPARK-20636. 
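To see the effect of the rule interactively rather than through the test suite, one can count the shuffles produced for two stacked window aggregates whose partition keys are compatible. A sketch for spark-shell, mirroring the data used in the new test; the expected count of one exchange assumes `TransposeWindow` is not excluded via `spark.sql.optimizer.excludedRules`:

```scala
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("S1", "P1", 100), ("S1", "P1", 700),
  ("S2", "P1", 200), ("S2", "P2", 300)
).toDF("sno", "pno", "qty")

val w1 = Window.partitionBy("sno")
val w2 = Window.partitionBy("sno", "pno")

val q = df
  .select($"sno", $"pno", $"qty", sum($"qty").over(w2).as("sum_qty_2"))
  .select($"sno", $"pno", $"qty", $"sum_qty_2", sum($"qty").over(w1).as("sum_qty_1"))

// With TransposeWindow in the optimizer batch the adjacent Window operators are
// swapped so that a single shuffle satisfies both partitionings, leaving one
// Exchange in the physical plan (two when the rule is excluded).
val numExchanges = q.queryExecution.executedPlan.collect { case e: Exchange => e }.length
println(s"number of exchanges: $numExchanges") // expected: 1 (2 with the rule excluded)
```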
Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/DataFrameWindowFunctionsSuite.scala | 38 +++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f615757a837a1..3eb6bca6ec976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -73,6 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CombineLimits, CombineUnions, // Constant folding and strength reduction + TransposeWindow, NullPropagation, ConstantPropagation, FoldablePropagation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 9a5d5a9966ab7..9277dc6859247 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.catalyst.optimizer.TransposeWindow +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -668,18 +670,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { ("S2", "P2", 300) ).toDF("sno", "pno", "qty") - val w1 = Window.partitionBy("sno") - val w2 = Window.partitionBy("sno", "pno") - - checkAnswer( - df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) - .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")), - Seq( - Row("S1", "P1", 100, 800, 800), - Row("S1", "P1", 700, 800, 800), - Row("S2", "P1", 200, 200, 500), - Row("S2", "P2", 300, 300, 500))) - + Seq(true, false).foreach { transposeWindowEnabled => + val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) { + val w1 = Window.partitionBy("sno") + val w2 = Window.partitionBy("sno", "pno") + + val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) + .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")) + + val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2 + val actualNumExchanges = select.queryExecution.executedPlan.collect { + case e: Exchange => e + }.length + assert(actualNumExchanges == expectedNumExchanges) + + checkAnswer( + select, + Seq( + Row("S1", "P1", 100, 800, 800), + Row("S1", "P1", 700, 800, 800), + Row("S2", "P1", 200, 200, 500), + Row("S2", "P2", 300, 300, 500))) + } + } } test("NaN and -0.0 in window partition keys") { From 0ed3f6aa122e08e0efc08431063cae9da2b33997 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 17 Dec 2018 21:47:38 +0800 Subject: [PATCH 087/194] [SPARK-26243][SQL][FOLLOWUP] fix code style issues in TimestampFormatter.scala ## What changes were proposed in this pull request? 1. 
rename `FormatterUtils` to `DateTimeFormatterHelper`, and move it to a separated file 2. move `DateFormatter` and its implementation to a separated file 3. mark some methods as private 4. add `override` to some methods ## How was this patch tested? existing tests Closes #23329 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/util/DateFormatter.scala | 96 +++++++++++++++ .../util/DateTimeFormatterHelper.scala | 44 +++++++ .../catalyst/util/TimestampFormatter.scala | 115 ++---------------- .../spark/sql/util/DateFormatterSuite.scala | 92 ++++++++++++++ ...te.scala => TimestampFormatterSuite.scala} | 73 +---------- 5 files changed, 246 insertions(+), 174 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/util/{DateTimestampFormatterSuite.scala => TimestampFormatterSuite.scala} (66%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala new file mode 100644 index 0000000000000..9e8d51cc65f03 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import java.time.{Instant, ZoneId} +import java.util.Locale + +import scala.util.Try + +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.sql.internal.SQLConf + +sealed trait DateFormatter { + def parse(s: String): Int // returns days since epoch + def format(days: Int): String +} + +class Iso8601DateFormatter( + pattern: String, + locale: Locale) extends DateFormatter with DateTimeFormatterHelper { + + private val formatter = buildFormatter(pattern, locale) + private val UTC = ZoneId.of("UTC") + + private def toInstant(s: String): Instant = { + val temporalAccessor = formatter.parse(s) + toInstantWithZoneId(temporalAccessor, UTC) + } + + override def parse(s: String): Int = { + val seconds = toInstant(s).getEpochSecond + val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) + days.toInt + } + + override def format(days: Int): String = { + val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) + formatter.withZone(UTC).format(instant) + } +} + +class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { + private val format = FastDateFormat.getInstance(pattern, locale) + + override def parse(s: String): Int = { + val milliseconds = format.parse(s).getTime + DateTimeUtils.millisToDays(milliseconds) + } + + override def format(days: Int): String = { + val date = DateTimeUtils.toJavaDate(days) + format.format(date) + } +} + +class LegacyFallbackDateFormatter( + pattern: String, + locale: Locale) extends LegacyDateFormatter(pattern, locale) { + override def parse(s: String): Int = { + Try(super.parse(s)).orElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime)) + }.getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + s.toInt + } + } +} + +object DateFormatter { + def apply(format: String, locale: Locale): DateFormatter = { + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyFallbackDateFormatter(format, locale) + } else { + new Iso8601DateFormatter(format, locale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala new file mode 100644 index 0000000000000..b85101d38d9e6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import java.time.{Instant, LocalDateTime, ZonedDateTime, ZoneId} +import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder} +import java.time.temporal.{ChronoField, TemporalAccessor} +import java.util.Locale + +trait DateTimeFormatterHelper { + + protected def buildFormatter(pattern: String, locale: Locale): DateTimeFormatter = { + new DateTimeFormatterBuilder() + .appendPattern(pattern) + .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) + .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) + .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) + .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) + .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) + .toFormatter(locale) + } + + protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor, zoneId: ZoneId): Instant = { + val localDateTime = LocalDateTime.from(temporalAccessor) + val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) + Instant.from(zonedDateTime) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 2b8d22dde9267..eb1303303463d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.util import java.time._ -import java.time.format.DateTimeFormatterBuilder -import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} +import java.time.temporal.TemporalQueries import java.util.{Locale, TimeZone} import scala.util.Try @@ -33,39 +32,16 @@ sealed trait TimestampFormatter { def format(us: Long): String } -trait FormatterUtils { - protected def zoneId: ZoneId - protected def buildFormatter( - pattern: String, - locale: Locale): java.time.format.DateTimeFormatter = { - new DateTimeFormatterBuilder() - .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) - .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) - .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) - .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) - .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) - .toFormatter(locale) - } - protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor): java.time.Instant = { - val localDateTime = LocalDateTime.from(temporalAccessor) - val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) - Instant.from(zonedDateTime) - } -} - class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends TimestampFormatter with FormatterUtils { - val zoneId = timeZone.toZoneId - val formatter = buildFormatter(pattern, locale) + locale: Locale) extends TimestampFormatter with DateTimeFormatterHelper { + private val formatter = buildFormatter(pattern, locale) - def toInstant(s: String): Instant = { + private def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) if (temporalAccessor.query(TemporalQueries.offset()) == null) { - toInstantWithZoneId(temporalAccessor) + toInstantWithZoneId(temporalAccessor, timeZone.toZoneId) } else { Instant.from(temporalAccessor) } @@ -77,9 +53,9 @@ class Iso8601TimestampFormatter( result } - def parse(s: String): Long = instantToMicros(toInstant(s)) + override def parse(s: String): Long = instantToMicros(toInstant(s)) - def format(us: 
Long): String = { + override def format(us: Long): String = { val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND) val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND) val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS) @@ -92,13 +68,13 @@ class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter { - val format = FastDateFormat.getInstance(pattern, timeZone, locale) + private val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: String): Long = format.parse(s).getTime - def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS + override def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS - def format(us: Long): String = { + override def format(us: Long): String = { format.format(DateTimeUtils.toJavaTimestamp(us)) } } @@ -121,74 +97,3 @@ object TimestampFormatter { } } } - -sealed trait DateFormatter { - def parse(s: String): Int // returns days since epoch - def format(days: Int): String -} - -class Iso8601DateFormatter( - pattern: String, - locale: Locale) extends DateFormatter with FormatterUtils { - - val zoneId = ZoneId.of("UTC") - - val formatter = buildFormatter(pattern, locale) - - def toInstant(s: String): Instant = { - val temporalAccessor = formatter.parse(s) - toInstantWithZoneId(temporalAccessor) - } - - override def parse(s: String): Int = { - val seconds = toInstant(s).getEpochSecond - val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) - - days.toInt - } - - override def format(days: Int): String = { - val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) - formatter.withZone(zoneId).format(instant) - } -} - -class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { - val format = FastDateFormat.getInstance(pattern, locale) - - def parse(s: String): Int = { - val milliseconds = format.parse(s).getTime - DateTimeUtils.millisToDays(milliseconds) - } - - def format(days: Int): String = { - val date = DateTimeUtils.toJavaDate(days) - format.format(date) - } -} - -class LegacyFallbackDateFormatter( - pattern: String, - locale: Locale) extends LegacyDateFormatter(pattern, locale) { - override def parse(s: String): Int = { - Try(super.parse(s)).orElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime)) - }.getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - s.toInt - } - } -} - -object DateFormatter { - def apply(format: String, locale: Locale): DateFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateFormatter(format, locale) - } else { - new Iso8601DateFormatter(format, locale) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala new file mode 100644 index 0000000000000..019615b81101c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +class DateFormatterSuite extends SparkFunSuite with SQLHelper { + test("parsing dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val daysSinceEpoch = formatter.parse("2018-12-02") + assert(daysSinceEpoch === 17867) + } + } + } + + test("format dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(17867) + assert(date === "2018-12-02") + } + } + } + + test("roundtrip date -> days -> date") { + Seq( + "0050-01-01", + "0953-02-02", + "1423-03-08", + "1969-12-31", + "1972-08-25", + "1975-09-26", + "2018-12-12", + "2038-01-01", + "5010-11-17").foreach { date => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val days = formatter.parse(date) + val formatted = formatter.format(days) + assert(date === formatted) + } + } + } + } + + test("roundtrip days -> date -> days") { + Seq( + -701265, + -371419, + -199722, + -1, + 0, + 967, + 2094, + 17877, + 24837, + 1110657).foreach { days => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(days) + val parsed = formatter.parse(date) + assert(days === parsed) + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala similarity index 66% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 43e348c7eebf4..c110ffa01f733 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -21,19 +21,9 @@ import java.util.{Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, TimestampFormatter} -class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { - test("parsing dates") { - 
DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val daysSinceEpoch = formatter.parse("2018-12-02") - assert(daysSinceEpoch === 17867) - } - } - } +class TimestampFormatterSuite extends SparkFunSuite with SQLHelper { test("parsing timestamps using time zones") { val localDate = "2018-12-02T10:11:12.001234" @@ -56,16 +46,6 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("format dates") { - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val date = formatter.format(17867) - assert(date === "2018-12-02") - } - } - } - test("format timestamps using time zones") { val microsSinceEpoch = 1543745472001234L val expectedTimestamp = Map( @@ -87,7 +67,7 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("roundtrip timestamp -> micros -> timestamp using timezones") { + test("roundtrip micros -> timestamp -> micros using timezones") { Seq( -58710115316212000L, -18926315945345679L, @@ -107,7 +87,7 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("roundtrip micros -> timestamp -> micros using timezones") { + test("roundtrip timestamp -> micros -> timestamp using timezones") { Seq( "0109-07-20T18:38:03.788000", "1370-04-01T10:00:54.654321", @@ -126,49 +106,4 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } } - - test("roundtrip date -> days -> date") { - Seq( - "0050-01-01", - "0953-02-02", - "1423-03-08", - "1969-12-31", - "1972-08-25", - "1975-09-26", - "2018-12-12", - "2038-01-01", - "5010-11-17").foreach { date => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val days = formatter.parse(date) - val formatted = formatter.format(days) - assert(date === formatted) - } - } - } - } - - test("roundtrip days -> date -> days") { - Seq( - -701265, - -371419, - -199722, - -1, - 0, - 967, - 2094, - 17877, - 24837, - 1110657).foreach { days => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val date = formatter.format(days) - val parsed = formatter.parse(date) - assert(days === parsed) - } - } - } - } } From 62a8466cea4f1ab07b41c2b3facd6576b9946359 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 17 Dec 2018 09:28:23 -0600 Subject: [PATCH 088/194] [SPARK-20351][ML] Add trait hasTrainingSummary to replace the duplicate code ## What changes were proposed in this pull request? Add a trait HasTrainingSummary to avoid code duplicate related to training summary. Currently all the training summary use the similar pattern which can be generalized, ``` private[ml] final var trainingSummary: Option[T] = None def hasSummary: Boolean = trainingSummary.isDefined def summary: T = trainingSummary.getOrElse... private[ml] def setSummary(summary: Option[T]): ... ``` Classes with the trait need to override `setSummry`. And for Java compatibility, they will also have to override `summary` method, otherwise the java code will regard all the summary class as Object due to a known issue with Scala. ## How was this patch tested? 
existing Java and Scala unit tests Closes #17654 from hhbyyh/hassummary. Authored-by: Yuhao Yang Signed-off-by: Sean Owen --- .../classification/LogisticRegression.scala | 24 ++------- .../spark/ml/clustering/BisectingKMeans.scala | 25 ++------- .../spark/ml/clustering/GaussianMixture.scala | 24 ++------- .../apache/spark/ml/clustering/KMeans.scala | 23 ++------ .../GeneralizedLinearRegression.scala | 22 ++------ .../ml/regression/LinearRegression.scala | 21 ++------ .../spark/ml/util/HasTrainingSummary.scala | 52 +++++++++++++++++++ 7 files changed, 78 insertions(+), 113 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 27a7db0b2f5d4..f2a5c11a34867 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -934,8 +934,8 @@ class LogisticRegressionModel private[spark] ( @Since("2.1.0") val interceptVector: Vector, @Since("1.3.0") override val numClasses: Int, private val isMultinomial: Boolean) - extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with MLWritable { + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable + with LogisticRegressionParams with HasTrainingSummary[LogisticRegressionTrainingSummary] { require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " + s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " + @@ -1018,20 +1018,16 @@ class LogisticRegressionModel private[spark] ( @Since("1.6.0") override val numFeatures: Int = coefficientMatrix.numCols - private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None - /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None`. + * if `hasSummary` is false. */ @Since("1.5.0") - def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LogisticRegressionModel") - } + override def summary: LogisticRegressionTrainingSummary = super.summary /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None` or it is a multiclass model. + * if `hasSummary` is false or it is a multiclass model. */ @Since("2.3.0") def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match { @@ -1062,16 +1058,6 @@ class LogisticRegressionModel private[spark] ( (model, model.getProbabilityCol, model.getPredictionCol) } - private[classification] - def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. */ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined - /** * Evaluates the model on a test dataset. 
* diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1a94aefa3f563..49e9f51368131 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -87,8 +87,9 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter @Since("2.0.0") class BisectingKMeansModel private[ml] ( @Since("2.0.0") override val uid: String, - private val parentModel: MLlibBisectingKMeansModel - ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { + private val parentModel: MLlibBisectingKMeansModel) + extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable + with HasTrainingSummary[BisectingKMeansSummary] { @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { @@ -143,28 +144,12 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) - private var trainingSummary: Option[BisectingKMeansSummary] = None - - private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.1.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("2.1.0") - def summary: BisectingKMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: BisectingKMeansSummary = super.summary } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88abc1605d69f..bb10b3228b93f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -85,7 +85,8 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override val uid: String, @Since("2.0.0") val weights: Array[Double], @Since("2.0.0") val gaussians: Array[MultivariateGaussian]) - extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable + with HasTrainingSummary[GaussianMixtureSummary] { /** @group setParam */ @Since("2.1.0") @@ -160,28 +161,13 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) - private var trainingSummary: Option[GaussianMixtureSummary] = None - - private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. 
*/ @Since("2.0.0") - def summary: GaussianMixtureSummary = trainingSummary.getOrElse { - throw new RuntimeException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: GaussianMixtureSummary = super.summary + } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2eed84d51782a..319747d4a1930 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -107,7 +107,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private[clustering] val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable + with HasTrainingSummary[KMeansSummary] { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -153,28 +154,12 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: GeneralMLWriter = new GeneralMLWriter(this) - private var trainingSummary: Option[KMeansSummary] = None - - private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("2.0.0") - def summary: KMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: KMeansSummary = super.summary } /** Helper class for storing model data */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index abb60ea205751..885b13bf8dac3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1001,7 +1001,8 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("2.0.0") val intercept: Double) extends RegressionModel[Vector, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase with MLWritable { + with GeneralizedLinearRegressionBase with MLWritable + with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary] { /** * Sets the link prediction (linear predictor) column name. @@ -1054,29 +1055,12 @@ class GeneralizedLinearRegressionModel private[ml] ( output.toDF() } - private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None - /** * Gets R-like summary of model on training set. An exception is * thrown if there is no summary available. */ @Since("2.0.0") - def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException( - "No training summary available for this GeneralizedLinearRegressionModel") - } - - /** - * Indicates if [[summary]] is available. 
- */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - - private[regression] - def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } + override def summary: GeneralizedLinearRegressionTrainingSummary = super.summary /** * Evaluate the model on the given dataset, returning a summary of the results. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ce6c12cc368dd..197828762d160 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -647,33 +647,20 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with GeneralMLWritable { + with LinearRegressionParams with GeneralMLWritable + with HasTrainingSummary[LinearRegressionTrainingSummary] { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) - private var trainingSummary: Option[LinearRegressionTrainingSummary] = None - override val numFeatures: Int = coefficients.size /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("1.5.0") - def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LinearRegressionModel") - } - - private[regression] - def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. */ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined + override def summary: LinearRegressionTrainingSummary = super.summary /** * Evaluates the model on a test dataset. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala new file mode 100644 index 0000000000000..edb0208144e10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since + + +/** + * Trait for models that provides Training summary. 
+ * + * @tparam T Summary instance type + */ +@Since("3.0.0") +private[ml] trait HasTrainingSummary[T] { + + private[ml] final var trainingSummary: Option[T] = None + + /** Indicates whether a training summary exists for this model instance. */ + @Since("3.0.0") + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Gets summary of model on training set. An exception is + * thrown if if `hasSummary` is false. + */ + @Since("3.0.0") + def summary: T = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for this ${this.getClass.getSimpleName}") + } + + private[ml] def setSummary(summary: Option[T]): this.type = { + this.trainingSummary = summary + this + } +} From 97a1d0d186eec85d79e6d3ec38fde0b93d47b48d Mon Sep 17 00:00:00 2001 From: chakravarthi Date: Mon, 17 Dec 2018 09:46:50 -0800 Subject: [PATCH 089/194] [SPARK-26255][YARN] Apply user provided UI filters to SQL tab in yarn mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? User specified filters are not applied to SQL tab in yarn mode, as it is overridden by the yarn AmIp filter. So we need to append user provided filters (spark.ui.filters) with yarn filter. ## How was this patch tested? 【Test step】: 1) Launch spark sql with authentication filter as below: 2) spark-sql --master yarn --conf spark.ui.filters=org.apache.hadoop.security.authentication.server.AuthenticationFilter --conf spark.org.apache.hadoop.security.authentication.server.AuthenticationFilter.params="type=simple" 3) Go to Yarn application list UI link 4) Launch the application master for the Spark-SQL app ID and access all the tabs by appending tab name. 5) It will display an error for all tabs including SQL tab.(before able to access SQL tab,as Authentication filter is not applied for SQL tab) 6) Also can be verified with info logs,that Authentication filter applied to SQL tab.(before it is not applied). I have attached the behaviour below in following order.. 1) Command used 2) Before fix (logs and UI) 3) After fix (logs and UI) **1) COMMAND USED**: launching spark-sql with authentication filter. ![image](https://user-images.githubusercontent.com/45845595/49947295-e7e97400-ff16-11e8-8c9a-10659487ddee.png) **2) BEFORE FIX:** **UI result:** able to access SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49948398-62b38e80-ff19-11e8-95dc-e74f9e3c2ba7.png) **logs**: authentication filter not applied to SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947343-ff286180-ff16-11e8-9de0-3f8db140bc32.png) **3) AFTER FIX:** **UI result**: Not able to access SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947360-0d767d80-ff17-11e8-9e9e-a95311949164.png) **in logs**: Both yarn filter and Authentication filter applied to SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947377-1a936c80-ff17-11e8-9f44-700eb3dc0ded.png) Closes #23312 from chakravarthiT/SPARK-26255_ui. 
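The change itself is small: instead of overwriting `spark.ui.filters` with the YARN AmIp filter, the backend now appends the user-provided filters to it. A standalone sketch of the before/after behaviour (the AmIp filter class name below is only an illustrative value; the real filter name is pushed by the application master at runtime):

```scala
import org.apache.spark.SparkConf

// User-provided filter, as passed on the command line above via --conf spark.ui.filters=...
val conf = new SparkConf().set("spark.ui.filters",
  "org.apache.hadoop.security.authentication.server.AuthenticationFilter")

// Filter name supplied by YARN at runtime (illustrative value).
val filterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"

// Before the fix: conf.set("spark.ui.filters", filterName) silently dropped the user filter.
// After the fix both filters are kept, so the SQL tab is protected as well:
val allFilters = filterName + "," + conf.get("spark.ui.filters", "")
conf.set("spark.ui.filters", allFilters)

println(conf.get("spark.ui.filters"))
// AmIpFilter first, then the user-provided AuthenticationFilter, comma-separated.
```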
Authored-by: chakravarthi Signed-off-by: Marcelo Vanzin --- .../apache/spark/scheduler/cluster/YarnSchedulerBackend.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 67c36aac49266..1289d4be79ea4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -168,8 +168,10 @@ private[spark] abstract class YarnSchedulerBackend( filterName != null && filterName.nonEmpty && filterParams != null && filterParams.nonEmpty if (hasFilter) { + // SPARK-26255: Append user provided filters(spark.ui.filters) with yarn filter. + val allFilters = filterName + "," + conf.get("spark.ui.filters", "") logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) + conf.set("spark.ui.filters", allFilters) filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } From 53e05acef2fd1c50181c4d3c283aad3e08a53b91 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 17 Dec 2018 10:07:35 -0800 Subject: [PATCH 090/194] [SPARK-26371][SS] Increase kafka ConfigUpdater test coverage. ## What changes were proposed in this pull request? As Kafka delegation token added logic into ConfigUpdater it would be good to test it. This PR contains the following changes: * ConfigUpdater extracted to a separate file and renamed to KafkaConfigUpdater * mockito-core dependency added to kafka-0-10-sql * Unit tests added ## How was this patch tested? Existing + new unit tests + on cluster. Closes #23321 from gaborgsomogyi/SPARK-26371. 
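For context, a minimal usage sketch of the extracted helper, mirroring how `KafkaSourceProvider` chains it in the diff below; the parameter map and the `earliest` offset value are illustrative placeholders, not Spark defaults:

```
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.serialization.ByteArrayDeserializer

// Illustrative input params; real callers pass the user-specified Kafka options.
val kafkaParams = Map("bootstrap.servers" -> "broker:9092")
val deserClassName = classOf[ByteArrayDeserializer].getName

val consumerParams = KafkaConfigUpdater("source", kafkaParams)
  .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
  .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
  .setIfUnset(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")   // placeholder value
  .setAuthenticationConfigIfNeeded()  // picks up JAAS config or a delegation token, if any
  .build()                            // returns a java.util.Map[String, Object]
```
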
Authored-by: Gabor Somogyi Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 5 + .../sql/kafka010/KafkaConfigUpdater.scala | 74 ++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 52 +------- .../kafka010/KafkaConfigUpdaterSuite.scala | 113 ++++++++++++++++++ .../kafka010/KafkaDelegationTokenTest.scala | 90 ++++++++++++++ .../kafka010/KafkaSecurityHelperSuite.scala | 46 +------ 6 files changed, 287 insertions(+), 93 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index de8731c4b774b..1c77906f43b17 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -106,6 +106,11 @@ ${jetty.version} test + + org.mockito + mockito-core + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala new file mode 100644 index 0000000000000..bc1b8019f6a63 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.config.SaslConfigs + +import org.apache.spark.SparkEnv +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Kafka + +/** + * Class to conveniently update Kafka config params, while logging the changes + */ +private[kafka010] case class KafkaConfigUpdater(module: String, kafkaParams: Map[String, String]) + extends Logging { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") + this + } + + def setIfUnset(key: String, value: Object): this.type = { + if (!map.containsKey(key)) { + map.put(key, value) + logDebug(s"$module: Set $key to $value") + } + this + } + + def setAuthenticationConfigIfNeeded(): this.type = { + // There are multiple possibilities to log in and applied in the following order: + // - JVM global security provided -> try to log in with JVM global security configuration + // which can be configured for example with 'java.security.auth.login.config'. + // For this no additional parameter needed. + // - Token is provided -> try to log in with scram module using kafka's dynamic JAAS + // configuration. + if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { + logDebug("JVM global security configuration detected, using it for login.") + } else if (KafkaSecurityHelper.isTokenAvailable()) { + logDebug("Delegation token detected, using it for login.") + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) + set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) + require(mechanism.startsWith("SCRAM"), + "Delegation token works only with SCRAM mechanism.") + set(SaslConfigs.SASL_MECHANISM, mechanism) + } + this + } + + def build(): ju.Map[String, Object] = map +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 4b8b5c0019b44..5774ee7a1c945 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -24,13 +24,9 @@ import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.config.SaslConfigs import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} -import org.apache.spark.SparkEnv -import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ @@ -483,7 +479,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { } def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = - ConfigUpdater("source", specifiedKafkaParams) + KafkaConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, 
deserClassName) .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -506,7 +502,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { def kafkaParamsForExecutors( specifiedKafkaParams: Map[String, String], uniqueGroupId: String): ju.Map[String, Object] = - ConfigUpdater("executor", specifiedKafkaParams) + KafkaConfigUpdater("executor", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -537,48 +533,6 @@ private[kafka010] object KafkaSourceProvider extends Logging { s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}" } - /** Class to conveniently update Kafka config params, while logging the changes */ - private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { - private val map = new ju.HashMap[String, Object](kafkaParams.asJava) - - def set(key: String, value: Object): this.type = { - map.put(key, value) - logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") - this - } - - def setIfUnset(key: String, value: Object): ConfigUpdater = { - if (!map.containsKey(key)) { - map.put(key, value) - logDebug(s"$module: Set $key to $value") - } - this - } - - def setAuthenticationConfigIfNeeded(): ConfigUpdater = { - // There are multiple possibilities to log in and applied in the following order: - // - JVM global security provided -> try to log in with JVM global security configuration - // which can be configured for example with 'java.security.auth.login.config'. - // For this no additional parameter needed. - // - Token is provided -> try to log in with scram module using kafka's dynamic JAAS - // configuration. - if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { - logDebug("JVM global security configuration detected, using it for login.") - } else if (KafkaSecurityHelper.isTokenAvailable()) { - logDebug("Delegation token detected, using it for login.") - val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) - set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) - val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) - require(mechanism.startsWith("SCRAM"), - "Delegation token works only with SCRAM mechanism.") - set(SaslConfigs.SASL_MECHANISM, mechanism) - } - this - } - - def build(): ju.Map[String, Object] = map - } - private[kafka010] def kafkaParamsForProducer( parameters: Map[String, String]): ju.Map[String, Object] = { val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } @@ -596,7 +550,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { val specifiedKafkaParams = convertToSpecifiedParams(parameters) - ConfigUpdater("executor", specifiedKafkaParams) + KafkaConfigUpdater("executor", specifiedKafkaParams) .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName) .set(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, serClassName) .setAuthenticationConfigIfNeeded() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala new file mode 100644 index 0000000000000..25ccca3cb9846 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.config.SaslConfigs + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config._ + +class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTest { + private val testModule = "testModule" + private val testKey = "testKey" + private val testValue = "testValue" + private val otherTestValue = "otherTestValue" + + test("set should always set value") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .set(testKey, testValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setIfUnset without existing key should set value") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setIfUnset(testKey, testValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setIfUnset with existing key should not set value") { + val params = Map[String, String](testKey -> testValue) + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setIfUnset(testKey, otherTestValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setAuthenticationConfigIfNeeded with global security should not set values") { + val params = Map.empty[String, String] + setGlobalKafkaClientConfig() + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + + assert(updatedParams.size() === 0) + } + + test("setAuthenticationConfigIfNeeded with token should set values") { + val params = Map.empty[String, String] + setSparkEnv(Map.empty) + addTokenToUGI() + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + + assert(updatedParams.size() === 2) + assert(updatedParams.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) + assert(updatedParams.get(SaslConfigs.SASL_MECHANISM) === + Kafka.TOKEN_SASL_MECHANISM.defaultValueString) + } + + test("setAuthenticationConfigIfNeeded with token and invalid mechanism should throw exception") { + val params = Map.empty[String, String] + setSparkEnv(Map[String, String](Kafka.TOKEN_SASL_MECHANISM.key -> "INVALID")) + addTokenToUGI() + + val e = intercept[IllegalArgumentException] { + KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + } + + assert(e.getMessage.contains("Delegation token works only with SCRAM mechanism.")) + } + + test("setAuthenticationConfigIfNeeded without security should not set values") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + 
+ assert(updatedParams.size() === 0) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala new file mode 100644 index 0000000000000..1899c65c721bb --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import javax.security.auth.login.{AppConfigurationEntry, Configuration} + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token +import org.mockito.Mockito.{doReturn, mock} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier + +/** + * This is a trait which provides functionalities for Kafka delegation token related test suites. 
+ */ +trait KafkaDelegationTokenTest extends BeforeAndAfterEach { + self: SparkFunSuite => + + protected val tokenId = "tokenId" + ju.UUID.randomUUID().toString + protected val tokenPassword = "tokenPassword" + ju.UUID.randomUUID().toString + + private class KafkaJaasConfiguration extends Configuration { + val entry = + new AppConfigurationEntry( + "DummyModule", + AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, + ju.Collections.emptyMap[String, Object]() + ) + + override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = { + if (name.equals("KafkaClient")) { + Array(entry) + } else { + null + } + } + } + + override def afterEach(): Unit = { + try { + Configuration.setConfiguration(null) + UserGroupInformation.setLoginUser(null) + SparkEnv.set(null) + } finally { + super.afterEach() + } + } + + protected def setGlobalKafkaClientConfig(): Unit = { + Configuration.setConfiguration(new KafkaJaasConfiguration) + } + + protected def addTokenToUGI(): Unit = { + val token = new Token[KafkaDelegationTokenIdentifier]( + tokenId.getBytes, + tokenPassword.getBytes, + KafkaTokenUtil.TOKEN_KIND, + KafkaTokenUtil.TOKEN_SERVICE + ) + val creds = new Credentials() + creds.addToken(KafkaTokenUtil.TOKEN_SERVICE, token) + UserGroupInformation.getCurrentUser.addCredentials(creds) + } + + protected def setSparkEnv(settings: Traversable[(String, String)]): Unit = { + val conf = new SparkConf().setAll(settings) + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala index fd9dee390d185..d908bbfc2c5f4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala @@ -17,51 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.UUID - -import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.Token -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.security.KafkaTokenUtil -import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier - -class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { - private val tokenId = "tokenId" + UUID.randomUUID().toString - private val tokenPassword = "tokenPassword" + UUID.randomUUID().toString - - private var sparkConf: SparkConf = null - - override def beforeEach(): Unit = { - super.beforeEach() - sparkConf = new SparkConf() - } - - override def afterEach(): Unit = { - try { - resetUGI - } finally { - super.afterEach() - } - } - - private def addTokenToUGI(): Unit = { - val token = new Token[KafkaDelegationTokenIdentifier]( - tokenId.getBytes, - tokenPassword.getBytes, - KafkaTokenUtil.TOKEN_KIND, - KafkaTokenUtil.TOKEN_SERVICE - ) - val creds = new Credentials() - creds.addToken(KafkaTokenUtil.TOKEN_SERVICE, token) - UserGroupInformation.getCurrentUser.addCredentials(creds) - } - - private def resetUGI: Unit = { - UserGroupInformation.setLoginUser(null) - } +class KafkaSecurityHelperSuite extends SparkFunSuite with KafkaDelegationTokenTest { test("isTokenAvailable without token should return false") { assert(!KafkaSecurityHelper.isTokenAvailable()) } @@ -75,7 +33,7 
@@ class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { test("getTokenJaasParams with token should return scram module") { addTokenToUGI() - val jaasParams = KafkaSecurityHelper.getTokenJaasParams(sparkConf) + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(new SparkConf()) assert(jaasParams.contains("ScramLoginModule required")) assert(jaasParams.contains("tokenauth=true")) From 9bbc1f936edf9d46fdf73c7be7442e8cd4e26ba7 Mon Sep 17 00:00:00 2001 From: Vaclav Kosar Date: Mon, 17 Dec 2018 11:50:24 -0800 Subject: [PATCH 091/194] [SPARK-24933][SS] Report numOutputRows in SinkProgress ## What changes were proposed in this pull request? SinkProgress should report similar properties like SourceProgress as long as they are available for given Sink. Count of written rows is metric availble for all Sinks. Since relevant progress information is with respect to commited rows, ideal object to carry this info is WriterCommitMessage. For brevity the implementation will focus only on Sinks with API V2 and on Micro Batch mode. Implemention for Continuous mode will be provided at later date. ### Before ``` {"description":"org.apache.spark.sql.kafka010.KafkaSourceProvider3c0bd317"} ``` ### After ``` {"description":"org.apache.spark.sql.kafka010.KafkaSourceProvider3c0bd317","numOutputRows":5000} ``` ### This PR is related to: - https://issues.apache.org/jira/browse/SPARK-24647 - https://issues.apache.org/jira/browse/SPARK-21313 ## How was this patch tested? Existing and new unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21919 from vackosar/feature/SPARK-24933-numOutputRows. Lead-authored-by: Vaclav Kosar Co-authored-by: Kosar, Vaclav: Functions Transformation Signed-off-by: gatorsmile --- .../spark/sql/kafka010/KafkaSinkSuite.scala | 21 +++++++++++++ .../v2/WriteToDataSourceV2Exec.scala | 30 +++++++++++++++---- .../streaming/MicroBatchExecution.scala | 11 +++++-- .../streaming/ProgressReporter.scala | 7 +++-- .../execution/streaming/StreamExecution.scala | 4 +++ .../apache/spark/sql/streaming/progress.scala | 21 +++++++++++-- ...StreamingQueryStatusAndProgressSuite.scala | 10 ++++--- 7 files changed, 88 insertions(+), 16 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index d46c4139011da..07d2b8a5dc420 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -232,6 +232,27 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { } } + test("streaming - sink progress is produced") { + /* ensure sink progress is correctly produced. 
*/ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))() + + try { + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + assert(writer.lastProgress.sink.numOutputRows == 3L) + } finally { + writer.stop() + } + } test("streaming - write data with bad schema") { val input = MemoryStream[String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 9a1fe1e0a328b..d7e20eed4cbc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{LongAccumulator, Utils} /** * Deprecated logical plan for writing data into data source v2. This is being replaced by more @@ -47,6 +47,8 @@ case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPl case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) extends UnaryExecNode { + var commitProgress: Option[StreamWriterCommitProgress] = None + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil @@ -55,6 +57,7 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) + val totalNumRowsAccumulator = new LongAccumulator() logInfo(s"Start processing data source write support: $writeSupport. 
" + s"The input RDD has ${messages.length} partitions.") @@ -65,15 +68,18 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark (context: TaskContext, iter: Iterator[InternalRow]) => DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, - (index, message: WriterCommitMessage) => { - messages(index) = message - writeSupport.onDataWriterCommit(message) + (index, result: DataWritingSparkTaskResult) => { + val commitMessage = result.writerCommitMessage + messages(index) = commitMessage + totalNumRowsAccumulator.add(result.numRows) + writeSupport.onDataWriterCommit(commitMessage) } ) logInfo(s"Data source write support $writeSupport is committing.") writeSupport.commit(messages) logInfo(s"Data source write support $writeSupport committed.") + commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value)) } catch { case cause: Throwable => logError(s"Data source write support $writeSupport is aborting.") @@ -102,7 +108,7 @@ object DataWritingSparkTask extends Logging { writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], - useCommitCoordinator: Boolean): WriterCommitMessage = { + useCommitCoordinator: Boolean): DataWritingSparkTaskResult = { val stageId = context.stageId() val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() @@ -110,9 +116,12 @@ object DataWritingSparkTask extends Logging { val attemptId = context.attemptNumber() val dataWriter = writerFactory.createWriter(partId, taskId) + var count = 0L // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { + // Count is here. + count += 1 dataWriter.write(iter.next()) } @@ -139,7 +148,7 @@ object DataWritingSparkTask extends Logging { logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + s"stage $stageId.$stageAttempt)") - msg + DataWritingSparkTaskResult(count, msg) })(catchBlock = { // If there is an error, abort this writer @@ -151,3 +160,12 @@ object DataWritingSparkTask extends Logging { }) } } + +private[v2] case class DataWritingSparkTaskResult( + numRows: Long, + writerCommitMessage: WriterCommitMessage) + +/** + * Sink progress information collected after commit. 
+ */ +private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 03beefeca269b..8ad436a4ff57d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} @@ -246,6 +246,7 @@ class MicroBatchExecution( * DONE */ private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { + sinkCommitProgress = None offsetLog.getLatest() match { case Some((latestBatchId, nextOffsets)) => /* First assume that we are re-executing the latest known batch @@ -537,7 +538,8 @@ class MicroBatchExecution( val nextBatch = new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) - reportTimeTaken("addBatch") { + val batchSinkProgress: Option[StreamWriterCommitProgress] = + reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) @@ -545,10 +547,15 @@ class MicroBatchExecution( // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
nextBatch.collect() } + lastExecution.executedPlan match { + case w: WriteToDataSourceV2Exec => w.commitProgress + case _ => None + } } } withProgressLocked { + sinkCommitProgress = batchSinkProgress watermarkTracker.updateWatermark(lastExecution.executedPlan) commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 39ab702ee083c..d1f3f74c5e731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamWriterCommitProgress} import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent @@ -56,6 +56,7 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] + protected def sinkCommitProgress: Option[StreamWriterCommitProgress] protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -167,7 +168,9 @@ trait ProgressReporter extends Logging { ) } - val sinkProgress = new SinkProgress(sink.toString) + val sinkProgress = SinkProgress( + sink.toString, + sinkCommitProgress.map(_.numOutputRows)) val newProgress = new StreamingQueryProgress( id = id, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 89b4f40c9c0b9..83824f40ab90b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand +import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -114,6 +115,9 @@ abstract class StreamExecution( @volatile var availableOffsets = new StreamProgress + @volatile + var sinkCommitProgress: Option[StreamWriterCommitProgress] = None + /** The current batchId or -1 if execution has not yet been initialized. 
*/ protected var currentBatchId: Long = -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 3cd6700efef5f..0b3945cbd1323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -30,6 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. @@ -207,11 +208,19 @@ class SourceProgress protected[sql]( * during a trigger. See [[StreamingQueryProgress]] for more information. * * @param description Description of the source corresponding to this status. + * @param numOutputRows Number of rows written to the sink or -1 for Continuous Mode (temporarily) + * or Sink V1 (until decommissioned). * @since 2.1.0 */ @Evolving class SinkProgress protected[sql]( - val description: String) extends Serializable { + val description: String, + val numOutputRows: Long) extends Serializable { + + /** SinkProgress without custom metrics. */ + protected[sql] def this(description: String) { + this(description, DEFAULT_NUM_OUTPUT_ROWS) + } /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -222,6 +231,14 @@ class SinkProgress protected[sql]( override def toString: String = prettyJson private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) + ("description" -> JString(description)) ~ + ("numOutputRows" -> JInt(numOutputRows)) } } + +private[sql] object SinkProgress { + val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L + + def apply(description: String, numOutputRows: Option[Long]): SinkProgress = + new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 7bef687e7e43b..2f460b044b237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -73,7 +73,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "inputRowsPerSecond" : 10.0 | } ], | "sink" : { - | "description" : "sink" + | "description" : "sink", + | "numOutputRows" : -1 | } |} """.stripMargin.trim) @@ -105,7 +106,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "numInputRows" : 678 | } ], | "sink" : { - | "description" : "sink" + | "description" : "sink", + | "numOutputRows" : -1 | } |} """.stripMargin.trim) @@ -250,7 +252,7 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.PositiveInfinity // should not be present in the json ) ), - sink = new SinkProgress("sink") + sink = SinkProgress("sink", None) ) val testProgress2 = new StreamingQueryProgress( @@ -274,7 +276,7 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.NegativeInfinity // should not be present in the json ) ), - sink = new SinkProgress("sink") + sink = SinkProgress("sink", None) ) val testStatus = new StreamingQueryStatus("active", true, false) From 
a97ca7ab365c53e9e5375c129b8c1a63746a4e18 Mon Sep 17 00:00:00 2001 From: suxingfate Date: Mon, 17 Dec 2018 13:36:57 -0800 Subject: [PATCH 092/194] [SPARK-25922][K8] Spark Driver/Executor "spark-app-selector" label mismatch ## What changes were proposed in this pull request? In K8S Cluster mode, the algorithm to generate spark-app-selector/spark.app.id of spark driver is different with spark executor. This patch makes sure spark driver and executor to use the same spark-app-selector/spark.app.id if spark.app.id is set, otherwise it will use superclass applicationId. In K8S Client mode, spark-app-selector/spark.app.id for executors will use superclass applicationId. ## How was this patch tested? Manually run." Closes #23322 from suxingfate/SPARK-25922. Lead-authored-by: suxingfate Co-authored-by: xinglwang Signed-off-by: Yinan Li --- .../KubernetesClusterSchedulerBackend.scala | 28 ++++++++++++++----- ...bernetesClusterSchedulerBackendSuite.scala | 14 +++++----- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 68f6f2e46e316..03f5da2bb0bce 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -18,9 +18,10 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} +import io.fabric8.kubernetes.client.KubernetesClient + import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -39,10 +40,10 @@ private[spark] class KubernetesClusterSchedulerBackend( lifecycleEventHandler: ExecutorPodsLifecycleManager, watchEvents: ExecutorPodsWatchSnapshotSource, pollEvents: ExecutorPodsPollingSnapshotSource) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( - requestExecutorsService) + private implicit val requestExecutorContext = + ExecutionContext.fromExecutorService(requestExecutorsService) protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { @@ -60,6 +61,17 @@ private[spark] class KubernetesClusterSchedulerBackend( removeExecutor(executorId, reason) } + /** + * Get an application ID associated with the job. + * This returns the string value of spark.app.id if set, otherwise + * the locally-generated ID from the superclass. 
+ * + * @return The application ID + */ + override def applicationId(): String = { + conf.getOption("spark.app.id").map(_.toString).getOrElse(super.applicationId) + } + override def start(): Unit = { super.start() if (!Utils.isDynamicAllocationEnabled(conf)) { @@ -88,7 +100,8 @@ private[spark] class KubernetesClusterSchedulerBackend( if (shouldDeleteExecutors) { Utils.tryLogNonFatalError { - kubernetesClient.pods() + kubernetesClient + .pods() .withLabel(SPARK_APP_ID_LABEL, applicationId()) .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) .delete() @@ -120,7 +133,8 @@ private[spark] class KubernetesClusterSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - kubernetesClient.pods() + kubernetesClient + .pods() .withLabel(SPARK_APP_ID_LABEL, applicationId()) .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) @@ -133,7 +147,7 @@ private[spark] class KubernetesClusterSchedulerBackend( } private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) - extends DriverEndpoint(rpcEnv, sparkProperties) { + extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { // Don't do anything besides disabling the executor - allow the Kubernetes API events to diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 75232f7b98b04..6e182bed459f8 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -37,6 +37,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val requestExecutorsService = new DeterministicScheduler() private val sparkConf = new SparkConf(false) .set("spark.executor.instances", "3") + .set("spark.app.id", TEST_SPARK_APP_ID) @Mock private var sc: SparkContext = _ @@ -87,8 +88,10 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn when(sc.env).thenReturn(env) when(env.rpcEnv).thenReturn(rpcEnv) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(rpcEnv.setupEndpoint( - mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + when( + rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), + driverEndpoint.capture())) .thenReturn(driverEndpointRef) when(kubernetesClient.pods()).thenReturn(podOperations) schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( @@ -100,9 +103,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn podAllocator, lifecycleEventHandler, watchEvents, - pollEvents) { - override def applicationId(): String = TEST_SPARK_APP_ID - } + pollEvents) } test("Start all components") { @@ -127,8 +128,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn test("Remove executor") { schedulerBackendUnderTest.start() - schedulerBackendUnderTest.doRemoveExecutor( - "1", ExecutorKilled) + schedulerBackendUnderTest.doRemoveExecutor("1", ExecutorKilled) verify(driverEndpointRef).send(RemoveExecutor("1", 
ExecutorKilled)) } From 3ee251f765dc573b2e09281e597b69e6bc1beec7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 18 Dec 2018 09:15:21 +0800 Subject: [PATCH 093/194] [SPARK-24561][SQL][PYTHON] User-defined window aggregation functions with Pandas UDF (bounded window) ## What changes were proposed in this pull request? This PR implements a new feature - window aggregation Pandas UDF for bounded window. #### Doc: https://docs.google.com/document/d/14EjeY5z4-NC27-SmIP9CsMPCANeTcvxN44a7SIJtZPc/edit#heading=h.c87w44wcj3wj #### Example: ``` from pyspark.sql.functions import pandas_udf, PandasUDFType from pyspark.sql.window import Window df = spark.range(0, 10, 2).toDF('v') w1 = Window.partitionBy().orderBy('v').rangeBetween(-2, 4) w2 = Window.partitionBy().orderBy('v').rowsBetween(-2, 2) pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() df.withColumn('v_mean', avg(df['v']).over(w1)).show() # +---+------+ # | v|v_mean| # +---+------+ # | 0| 1.0| # | 2| 2.0| # | 4| 4.0| # | 6| 6.0| # | 8| 7.0| # +---+------+ df.withColumn('v_mean', avg(df['v']).over(w2)).show() # +---+------+ # | v|v_mean| # +---+------+ # | 0| 2.0| # | 2| 3.0| # | 4| 4.0| # | 6| 5.0| # | 8| 6.0| # +---+------+ ``` #### High level changes: This PR modifies the existing WindowInPandasExec physical node to deal with unbounded (growing, shrinking and sliding) windows. * `WindowInPandasExec` now share the same base class as `WindowExec` and share utility functions. See `WindowExecBase` * `WindowFunctionFrame` now has two new functions `currentLowerBound` and `currentUpperBound` - to return the lower and upper window bound for the current output row. It is also modified to allow `AggregateProcessor` == null. Null aggregator processor is used for `WindowInPandasExec` where we don't have an aggregator and only uses lower and upper bound functions from `WindowFunctionFrame` * The biggest change is in `WindowInPandasExec`, where it is modified to take `currentLowerBound` and `currentUpperBound` and write those values together with the input data to the python process for rolling window aggregation. See `WindowInPandasExec` for more details. #### Discussion In benchmarking, I found numpy variant of the rolling window UDF is much faster than the pandas version: Spark SQL window function: 20s Pandas variant: ~80s Numpy variant: 10s Numpy variant with numba: 4s Allowing numpy variant of the vectorized UDFs is something I want to discuss because of the performance improvement, but doesn't have to be in this PR. ## How was this patch tested? New tests Closes #22305 from icexelloss/SPARK-24561-bounded-window-udf. 
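For illustration only (not part of this patch): the sketch below mimics what the bounded-window code path does for `w2 = rowsBetween(-2, 2)` over the five rows in the example above. The JVM side ships a lower/upper bound pair per output row and the Python worker slices each input series with `s.iloc[begin:end]` (end exclusive); the index arrays here are hand-computed for this example and are an assumption about the exact encoding, not taken from the implementation:

```
import pandas as pd

v = pd.Series([0.0, 2.0, 4.0, 6.0, 8.0])
begin = [0, 0, 0, 1, 2]   # inclusive lower bound per output row
end = [3, 4, 5, 5, 5]     # exclusive upper bound per output row

# Evaluate the UDF (here: mean) once per row over its window slice.
means = [v.iloc[b:e].mean() for b, e in zip(begin, end)]
print(means)  # [2.0, 3.0, 4.0, 5.0, 6.0] -- matches the expected output above
```
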
Authored-by: Li Jin Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 21 +- .../sql/tests/test_pandas_udf_window.py | 157 ++++++++- python/pyspark/worker.py | 57 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 - .../execution/python/WindowInPandasExec.scala | 329 +++++++++++++++--- .../sql/execution/window/WindowExec.scala | 189 +--------- .../sql/execution/window/WindowExecBase.scala | 230 ++++++++++++ .../window/WindowFunctionFrame.scala | 108 ++++-- 8 files changed, 792 insertions(+), 304 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f98e550e39da8..d188de39e21c7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2982,8 +2982,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - This example shows using grouped aggregated UDFs as window functions. Note that only - unbounded window frame is supported at the moment: + This example shows using grouped aggregated UDFs as window functions. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql import Window @@ -2993,20 +2992,24 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() - >>> w = Window \\ - ... .partitionBy('id') \\ - ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> w = (Window.partitionBy('id') + ... .orderBy('v') + ... .rowsBetween(-1, 0)) >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP +---+----+------+ | id| v|mean_v| +---+----+------+ - | 1| 1.0| 1.5| + | 1| 1.0| 1.0| | 1| 2.0| 1.5| - | 2| 3.0| 6.0| - | 2| 5.0| 6.0| - | 2|10.0| 6.0| + | 2| 3.0| 3.0| + | 2| 5.0| 4.0| + | 2|10.0| 7.5| +---+----+------+ + .. note:: For performance reasons, the input series to window functions are not copied. + Therefore, mutating the input series is not allowed and will cause incorrect results. + For the same reason, users should also not rely on the index of the input series. + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. 
Due to diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index 0a7a19c1c0814..3ba98e76468b3 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -46,6 +46,15 @@ def python_plus_one(self): def pandas_scalar_time_two(self): return pandas_udf(lambda v: v * 2, 'double') + @property + def pandas_agg_count_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('long', PandasUDFType.GROUPED_AGG) + def count(v): + return len(v) + return count + @property def pandas_agg_mean_udf(self): @pandas_udf('double', PandasUDFType.GROUPED_AGG) @@ -70,7 +79,7 @@ def min(v): @property def unbounded_window(self): return Window.partitionBy('id') \ - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v') @property def ordered_window(self): @@ -80,6 +89,32 @@ def ordered_window(self): def unpartitioned_window(self): return Window.partitionBy() + @property + def sliding_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1) + + @property + def sliding_range_window(self): + return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4) + + @property + def growing_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3) + + @property + def growing_range_window(self): + return Window.partitionBy('id').orderBy('v') \ + .rangeBetween(Window.unboundedPreceding, 4) + + @property + def shrinking_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing) + + @property + def shrinking_range_window(self): + return Window.partitionBy('id').orderBy('v') \ + .rangeBetween(-3, Window.unboundedFollowing) + def test_simple(self): df = self.data w = self.unbounded_window @@ -100,12 +135,12 @@ def test_multiple_udfs(self): w = self.unbounded_window result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ - .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ - .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ - .withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('min_w', min(df['w']).over(w)) + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -183,16 +218,16 @@ def test_mixed_sql_and_udf(self): # Test chaining sql aggregate function and udf result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('min_v', min(df['v']).over(w)) \ - .withColumn('v_diff', col('max_v') - col('min_v')) \ - .drop('max_v', 'min_v') + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') expected3 = expected1 # Test mixing sql window function and udf result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) + .withColumn('rank', rank().over(ow)) expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) + .withColumn('rank', rank().over(ow)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), 
result2.toPandas()) @@ -210,8 +245,6 @@ def test_array_type(self): def test_invalid_args(self): df = self.data w = self.unbounded_window - ow = self.ordered_window - mean_udf = self.pandas_agg_mean_udf with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -220,11 +253,101 @@ def test_invalid_args(self): foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) df.withColumn('v2', foo_udf(df['v']).over(w)) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*Only unbounded window frame is supported.*'): - df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + def test_bounded_simple(self): + from pyspark.sql.functions import mean, max, min, count + + df = self.data + w1 = self.sliding_row_window + w2 = self.shrinking_range_window + + plus_one = self.python_plus_one + count_udf = self.pandas_agg_count_udf + mean_udf = self.pandas_agg_mean_udf + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \ + .withColumn('count_v', count_udf(df['v']).over(w2)) \ + .withColumn('max_v', max_udf(df['v']).over(w2)) \ + .withColumn('min_v', min_udf(df['v']).over(w1)) + + expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \ + .withColumn('count_v', count(df['v']).over(w2)) \ + .withColumn('max_v', max(df['v']).over(w2)) \ + .withColumn('min_v', min(df['v']).over(w1)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_growing_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.growing_row_window + w2 = self.growing_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_sliding_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.sliding_row_window + w2 = self.sliding_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_shrinking_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.shrinking_row_window + w2 = self.shrinking_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_bounded_mixed(self): + from pyspark.sql.functions import mean, max + + df = self.data + w1 = self.sliding_row_window + w2 = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + max_udf = self.pandas_agg_max_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \ + .withColumn('max_v', max_udf(df['v']).over(w2)) \ + .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \ + .withColumn('max_v', max(df['v']).over(w2)) \ + .withColumn('mean_unbounded_v', mean(df['v']).over(w1)) 
+ + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) if __name__ == "__main__": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 953b468e96519..bf007b0c62d8d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -145,7 +145,18 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def wrap_window_agg_pandas_udf(f, return_type): +def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index): + window_bound_types_str = runner_conf.get('pandas_window_bound_types') + window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(',')][udf_index] + if window_bound_type == 'bounded': + return wrap_bounded_window_agg_pandas_udf(f, return_type) + elif window_bound_type == 'unbounded': + return wrap_unbounded_window_agg_pandas_udf(f, return_type) + else: + raise RuntimeError("Invalid window bound type: {} ".format(window_bound_type)) + + +def wrap_unbounded_window_agg_pandas_udf(f, return_type): # This is similar to grouped_agg_pandas_udf, the only difference # is that window_agg_pandas_udf needs to repeat the return value # to match window length, where grouped_agg_pandas_udf just returns @@ -160,7 +171,41 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type, runner_conf): +def wrap_bounded_window_agg_pandas_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(begin_index, end_index, *series): + import pandas as pd + result = [] + + # Index operation is faster on np.ndarray, + # So we turn the index series into np array + # here for performance + begin_array = begin_index.values + end_array = end_index.values + + for i in range(len(begin_array)): + # Note: Create a slice from a series for each window is + # actually pretty expensive. However, there + # is no easy way to reduce cost here. + # Note: s.iloc[i : j] is about 30% faster than s[i: j], with + # the caveat that the created slices shares the same + # memory with s. Therefore, user are not allowed to + # change the value of input series inside the window + # function. It is rare that user needs to modify the + # input series in the window function, and therefore, + # it is be a reasonable restriction. + # Note: Calling reset_index on the slices will increase the cost + # of creating slices by about 100%. Therefore, for performance + # reasons we don't do it here. 
+ series_slices = [s.iloc[begin_array[i]: end_array[i]] for s in series] + result.append(f(*series_slices)) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -184,7 +229,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf): elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: - return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -226,7 +271,8 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -238,7 +284,8 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=i) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6a91d556b2f3e..88d41e8824405 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -134,11 +134,6 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") - case _ @ WindowExpression(_: PythonUDF, - WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) - if !frame.isUnbounded => - failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") - case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window // function or a Pandas window UDF. 
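
The changes above (the new bounded-window code path in `worker.py` together with dropping the analyzer restriction in `CheckAnalysis`) are what allow a grouped-agg pandas UDF to be applied over a bounded window frame. A minimal usage sketch, assuming an active `SparkSession` named `spark`; the sample data, column names, and the UDF are illustrative and not part of this patch:

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ["id", "v"])

# A grouped-agg pandas UDF receives a pandas.Series per window and returns a scalar.
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

# A bounded row frame; before this patch only unbounded frames were accepted
# for pandas UDFs and this would fail analysis.
w = Window.partitionBy("id").orderBy("v").rowsBetween(-1, 1)
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).show()
```

The runner conf `pandas_window_bound_types`, added in `worker.py` above and populated by `WindowInPandasExec` below, is what tells the Python worker, per UDF, whether to evaluate once per row with bound indices (bounded) or once per partition (unbounded).
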
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 82973307feef3..1ce1215bfdd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -27,17 +27,64 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.arrow.ArrowUtils -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.execution.window._ +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * This class calculates and outputs windowed aggregates over the rows in a single partition. + * + * This is similar to [[WindowExec]]. The main difference is that this node does not compute + * any window aggregation values. Instead, it computes the lower and upper bound for each window + * (i.e. window bounds) and pass the data and indices to Python worker to do the actual window + * aggregation. + * + * It currently materializes all data associated with the same partition key and passes them to + * Python worker. This is not strictly necessary for sliding windows and can be improved (by + * possibly slicing data into overlapping chunks and stitching them together). + * + * This class groups window expressions by their window boundaries so that window expressions + * with the same window boundaries can share the same window bounds. The window bounds are + * prepended to the data passed to the python worker. + * + * For example, if we have: + * avg(v) over specifiedwindowframe(RowFrame, -5, 5), + * avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing), + * avg(v) over specifiedwindowframe(RowFrame, -3, 3), + * max(v) over specifiedwindowframe(RowFrame, -3, 3) + * + * The python input will look like: + * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v) + * + * where w1 is specifiedwindowframe(RowFrame, -5, 5) + * w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing) + * w3 is specifiedwindowframe(RowFrame, -3, 3) + * + * Note that w2 doesn't have bound indices in the python input because it's unbounded window + * so it's bound indices will always be the same. + * + * Bounded window and Unbounded window are evaluated differently in Python worker: + * (1) Bounded window takes the window bound indices in addition to the input columns. + * Unbounded window takes only input columns. + * (2) Bounded window evaluates the udf once per input row. + * Unbounded window evaluates the udf once per window partition. + * This is controlled by Python runner conf "pandas_window_bound_types" + * + * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with + * [[WindowExec]] + * + * Note this doesn't support partial aggregation and all aggregation is computed from the entire + * window. 
+ */ case class WindowInPandasExec( windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - child: SparkPlan) extends UnaryExecNode { + child: SparkPlan) + extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) { override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) @@ -60,6 +107,26 @@ case class WindowInPandasExec( override def outputPartitioning: Partitioning = child.outputPartitioning + /** + * Helper functions and data structures for window bounds + * + * It contains: + * (1) Total number of window bound indices in the python input row + * (2) Function from frame index to its lower bound column index in the python input row + * (3) Function from frame index to its upper bound column index in the python input row + * (4) Seq from frame index to its window bound type + */ + private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) + + /** + * Enum for window bound types. Used only inside this class. + */ + private sealed case class WindowBoundType(value: String) + private object UnboundedWindow extends WindowBoundType("unbounded") + private object BoundedWindow extends WindowBoundType("bounded") + + private val windowBoundTypeConf = "pandas_window_bound_types" + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,68 +140,150 @@ case class WindowInPandasExec( } /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. + * See [[WindowBoundHelpers]] for details. 
*/ - private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map { case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) + private def computeWindowBoundHelpers( + factories: Seq[InternalRow => WindowFunctionFrame] + ): WindowBoundHelpers = { + val functionFrames = factories.map(_(EmptyRow)) + + val windowBoundTypes = functionFrames.map { + case _: UnboundedWindowFunctionFrame => UnboundedWindow + case _: UnboundedFollowingWindowFunctionFrame | + _: SlidingWindowFunctionFrame | + _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow + // It should be impossible to get other types of window function frame here + case frame => throw new RuntimeException(s"Unexpected window function frame $frame.") } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) + + val requiredIndices = functionFrames.map { + case _: UnboundedWindowFunctionFrame => 0 + case _ => 2 + } + + val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail + + val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => + if (num == 0) { + // Sentinel values for unbounded window + (-1, -1) + } else { + (upperBoundIndex - 2, upperBoundIndex - 1) + } + } + + def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 + def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 + + (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) } protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute() + // Unwrap the expressions and factories from the map. + val expressionsWithFrameIndex = + windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { + case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) + } + + val expressions = expressionsWithFrameIndex.map(_._1) + val expressionIndexToFrameIndex = + expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap + + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + // Helper functions + val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = + computeWindowBoundHelpers(factories) + val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + val numFrames = factories.length + + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Extract window expressions and window functions - val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) - - val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) + val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + // We shouldn't be chaining anything here. + // All chained python functions should only contain one function. 
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + require(pyFuncs.length == expressions.length) + + val udfWindowBoundTypes = pyFuncs.indices.map(i => + frameWindowBoundTypes(expressionIndexToFrameIndex(i))) + val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) // Filter child output attributes down to only those that are UDF inputs. - // Also eliminate duplicate UDF inputs. - val allInputs = new ArrayBuffer[Expression] - val dataTypes = new ArrayBuffer[DataType] + // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node + // handles UDF inputs. + val dataInputs = new ArrayBuffer[Expression] + val dataInputTypes = new ArrayBuffer[DataType] val argOffsets = inputs.map { input => input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) + if (dataInputs.exists(_.semanticEquals(e))) { + dataInputs.indexWhere(_.semanticEquals(e)) } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + dataInputs += e + dataInputTypes += e.dataType + dataInputs.length - 1 } }.toArray }.toArray - // Schema of input rows to the python runner - val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => - StructField(s"_$i", dt) - }) + // In addition to UDF inputs, we will prepend window bounds for each UDFs. + // For bounded windows, we prepend lower bound and upper bound. For unbounded windows, + // we no not add window bounds. (strictly speaking, we only need to lower or upper bound + // if the window is bounded only on one side, this can be improved in the future) - inputRDD.mapPartitionsInternal { iter => - val context = TaskContext.get() + // Setting window bounds for each window frames. Each window frame has different bounds so + // each has its own window bound columns. + val windowBoundsInput = factories.indices.flatMap { frameIndex => + if (isBounded(frameIndex)) { + Seq( + BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), + BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false) + ) + } else { + Seq.empty + } + } - val grouped = if (partitionSpec.isEmpty) { - // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset + // for the UDF is (lowerBoundOffet, upperBoundOffset, inputOffset1, inputOffset2, ...) + // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...) + pyFuncs.indices.foreach { exprIndex => + val frameIndex = expressionIndexToFrameIndex(exprIndex) + if (isBounded(frameIndex)) { + argOffsets(exprIndex) = + Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ + argOffsets(exprIndex).map(_ + windowBoundsInput.length) } else { - GroupedIterator(iter, partitionSpec, child.output) + argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) } + } + + val allInputs = windowBoundsInput ++ dataInputs + val allInputTypes = allInputs.map(_.dataType) + + // Start processing. + child.execute().mapPartitions { iter => + val context = TaskContext.get() + + // Get all relevant projections. 
+ val resultProj = createResultProjection(expressions) + val pythonInputProj = UnsafeProjection.create( + allInputs, + windowBoundsInput.map(ref => + AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output + ) + val pythonInputSchema = StructType( + allInputTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + } + ) + val grouping = UnsafeProjection.create(partitionSpec, child.output) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. @@ -144,11 +293,94 @@ case class WindowInPandasExec( queue.close() } - val inputProj = UnsafeProjection.create(allInputs, child.output) - val pythonInput = grouped.map { case (_, rows) => - rows.map { row => - queue.add(row.asInstanceOf[UnsafeRow]) - inputProj(row) + val stream = iter.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + row + } + + val pythonInput = new Iterator[Iterator[UnsafeRow]] { + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + + val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) + + val frames = factories.map(_(indexRow)) + + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + + override final def next(): Iterator[UnsafeRow] = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + val join = new JoinedRow + + bufferIterator.zipWithIndex.map { + case (current, index) => + var frameIndex = 0 + while (frameIndex < numFrames) { + frames(frameIndex).write(index, current) + // If the window is unbounded we don't need to write out window bounds. 
+ if (isBounded(frameIndex)) { + indexRow.setInt( + lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) + indexRow.setInt( + upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) + } + frameIndex += 1 + } + + pythonInputProj(join(indexRow, current)) + } } } @@ -156,12 +388,11 @@ case class WindowInPandasExec( pyFuncs, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, argOffsets, - windowInputSchema, + pythonInputSchema, sessionLocalTimeZone, pythonRunnerConf).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow - val resultProj = createResultProjection(expressions) windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => val leftRow = queue.remove() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 729b8bdb3dae8..89f6edda2ef57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -83,7 +83,7 @@ case class WindowExec( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) - extends UnaryExecNode { + extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) { override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) @@ -104,193 +104,6 @@ case class WindowExec( override def outputPartitioning: Partitioning = child.outputPartitioning - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame to evaluate. This can either be a Row or Range frame. - * @param bound with respect to the row. - * @param timeZone the session local timezone for time related calculations. - * @return a bound ordering object. - */ - private[this] def createBoundOrdering( - frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { - (frame, bound) match { - case (RowFrame, CurrentRow) => - RowBoundOrdering(0) - - case (RowFrame, IntegerLiteral(offset)) => - RowBoundOrdering(offset) - - case (RangeFrame, CurrentRow) => - val ordering = newOrdering(orderSpec, child.output) - RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) - - case (RangeFrame, offset: Expression) if orderSpec.size == 1 => - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => UnaryMinus(offset) - case Ascending => offset - } - - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = (expr.dataType, boundOffset.dataType) match { - case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (a, b) if a== b => Add(expr, boundOffset) - } - val bound = newMutableProjection(boundExpr :: Nil, child.output) - - // Construct the ordering. 
This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil - val ordering = newOrdering(boundSortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - - case (RangeFrame, _) => - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frame's - * [[WindowExpression]]s and factory function for the WindowFrameFunction. - */ - private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Expression, Expression) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, fr.lower, fr.upper) - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) - case f => sys.error(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - val timeZone = conf.sessionLocalTimeZone - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - def processor = AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) - - // Create the factory - val factory = key match { - // Offset Frame - case ("OFFSET", _, IntegerLiteral(offset), _) => - target: InternalRow => - new OffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunctions. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled), - offset) - - // Entire Partition Frame. - case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - - // Growing Frame. - case ("AGGREGATE", frameType, UnboundedPreceding, upper) => - target: InternalRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, upper, timeZone)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, lower, UnboundedFollowing) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) - } - - // Moving Frame. 
- case ("AGGREGATE", frameType, lower, upper) => - target: InternalRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone)) - } - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Frame Expression - Factory pair. - (expressions, factory) - } - } - - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map{ case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - protected override def doExecute(): RDD[InternalRow] = { // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala new file mode 100644 index 0000000000000..dcb86f48bdf32 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType} + +abstract class WindowExecBase( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
+ */ + protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. + * @return a bound ordering object. + */ + private def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset + } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = (expr.dataType, boundOffset.dataType) match { + case (DateType, IntegerType) => DateAdd(expr, boundOffset) + case (TimestampType, CalendarIntervalType) => + TimeAdd(expr, boundOffset, Some(timeZone)) + case (a, b) if a == b => Add(expr, boundOffset) + } + val bound = newMutableProjection(boundExpr :: Nil, child.output) + + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) + RangeBoundOrdering(ordering, current, bound) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + } + + /** + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * [[WindowExpression]]s and factory function for the WindowFrameFunction. + */ + protected lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Expression, Expression) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. 
+ def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, fr.lower, fr.upper) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es += e + fns += fn + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f: PythonUDF => collect("AGGREGATE", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions + // in a single Window physical node. Therefore, we can assume no SQL aggregation + // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL + // aggregation function in a single physical node. + def processor = if (functions.exists(_.isInstanceOf[PythonUDF])) { + null + } else { + AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) + } + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", _, IntegerLiteral(offset), _) => + target: InternalRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + // OFFSET frame functions are guaranteed be OffsetWindowFunctions. + functions.map(_.asInstanceOf[OffsetWindowFunction]), + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), + offset) + + // Entire Partition Frame. + case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + + // Growing Frame. + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => + target: InternalRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, upper, timeZone)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => + target: InternalRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, lower, upper) => + target: InternalRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. 
+ (expressions, factory) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 156002ef58fbe..a5601899ea2de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. */ -private[window] abstract class WindowFunctionFrame { +abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * @@ -42,6 +42,20 @@ private[window] abstract class WindowFunctionFrame { * Write the current results to the target row. */ def write(index: Int, current: InternalRow): Unit + + /** + * The current lower window bound in the row array (inclusive). + * + * This should be called after the current row is updated via [[write]] + */ + def currentLowerBound(): Int + + /** + * The current row index of the upper window bound in the row array (exclusive) + * + * This should be called after the current row is updated via [[write]] + */ + def currentUpperBound(): Int } object WindowFunctionFrame { @@ -62,7 +76,7 @@ object WindowFunctionFrame { * @param newMutableProjection function used to create the projection. * @param offset by which rows get moved within a partition. */ -private[window] final class OffsetWindowFunctionFrame( +final class OffsetWindowFunctionFrame( target: InternalRow, ordinal: Int, expressions: Array[OffsetWindowFunction], @@ -137,6 +151,10 @@ private[window] final class OffsetWindowFunctionFrame( } inputIndex += 1 } + + override def currentLowerBound(): Int = throw new UnsupportedOperationException() + + override def currentUpperBound(): Int = throw new UnsupportedOperationException() } /** @@ -148,7 +166,7 @@ private[window] final class OffsetWindowFunctionFrame( * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ -private[window] final class SlidingWindowFunctionFrame( +final class SlidingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering, @@ -170,24 +188,24 @@ private[window] final class SlidingWindowFunctionFrame( private[this] val buffer = new util.ArrayDeque[InternalRow]() /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ - private[this] var inputHighIndex = 0 + private[this] var lowerBound = 0 /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. + * Index of the first input row with a value greater than the upper bound of the current + * output row. */ - private[this] var inputLowIndex = 0 + private[this] var upperBound = 0 /** Prepare the frame for calculating a new partition. Reset all variables. 
*/ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIterator = input.generateIterator() nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex = 0 - inputLowIndex = 0 + lowerBound = 0 + upperBound = 0 buffer.clear() } @@ -197,27 +215,27 @@ private[window] final class SlidingWindowFunctionFrame( // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + while (!buffer.isEmpty && lbound.compare(buffer.peek(), lowerBound, current, index) < 0) { buffer.remove() - inputLowIndex += 1 + lowerBound += 1 bufferUpdated = true } // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) { - inputLowIndex += 1 + while (nextRow != null && ubound.compare(nextRow, upperBound, current, index) <= 0) { + if (lbound.compare(nextRow, lowerBound, current, index) < 0) { + lowerBound += 1 } else { buffer.add(nextRow.copy()) bufferUpdated = true } nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex += 1 + upperBound += 1 } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { @@ -226,6 +244,10 @@ private[window] final class SlidingWindowFunctionFrame( processor.evaluate(target) } } + + override def currentLowerBound(): Int = lowerBound + + override def currentUpperBound(): Int = upperBound } /** @@ -239,27 +261,39 @@ private[window] final class SlidingWindowFunctionFrame( * @param target to write results to. * @param processor to calculate the row values with. */ -private[window] final class UnboundedWindowFunctionFrame( +final class UnboundedWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor) extends WindowFunctionFrame { + val lowerBound: Int = 0 + var upperBound: Int = 0 + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { - processor.initialize(rows.length) - - val iterator = rows.generateIterator() - while (iterator.hasNext) { - processor.update(iterator.next()) + if (processor != null) { + processor.initialize(rows.length) + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) + } } + + upperBound = rows.length } /** Write the frame columns for the current row to the given target row. */ override def write(index: Int, current: InternalRow): Unit = { // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate // for each row. - processor.evaluate(target) + if (processor != null) { + processor.evaluate(target) + } } + + override def currentLowerBound(): Int = lowerBound + + override def currentUpperBound(): Int = upperBound } /** @@ -276,7 +310,7 @@ private[window] final class UnboundedWindowFunctionFrame( * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. 
*/ -private[window] final class UnboundedPrecedingWindowFunctionFrame( +final class UnboundedPrecedingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, ubound: BoundOrdering) @@ -308,7 +342,9 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( nextRow = inputIterator.next() } - processor.initialize(input.length) + if (processor != null) { + processor.initialize(input.length) + } } /** Write the frame columns for the current row to the given target row. */ @@ -318,17 +354,23 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { - processor.update(nextRow) + if (processor != null) { + processor.update(nextRow) + } nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.evaluate(target) } } + + override def currentLowerBound(): Int = 0 + + override def currentUpperBound(): Int = inputIndex } /** @@ -347,7 +389,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ -private[window] final class UnboundedFollowingWindowFunctionFrame( +final class UnboundedFollowingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering) @@ -384,7 +426,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.initialize(input.length) if (nextRow != null) { processor.update(nextRow) @@ -395,4 +437,8 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( processor.evaluate(target) } } + + override def currentLowerBound(): Int = inputIndex + + override def currentUpperBound(): Int = input.length } From 435392e47391c0bc860c0dfb75c35d620187352d Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 18 Dec 2018 13:50:55 +0800 Subject: [PATCH 094/194] [SPARK-26246][SQL] Inferring TimestampType from JSON ## What changes were proposed in this pull request? The `JsonInferSchema` class is extended to support `TimestampType` inferring from string fields in JSON input: - If the `prefersDecimal` option is set to `true`, it tries to infer decimal type from the string field. - If decimal type inference fails or `prefersDecimal` is disabled, `JsonInferSchema` tries to infer `TimestampType`. - If timestamp type inference fails, `StringType` is returned as the inferred type. ## How was this patch tested? Added new test suite - `JsonInferSchemaSuite` to check date and timestamp types inferring from JSON using `JsonInferSchema` directly. A few tests were added `JsonSuite` to check type merging and roundtrip tests. This changes was tested by `JsonSuite`, `JsonExpressionsSuite` and `JsonFunctionsSuite` as well. Closes #23201 from MaxGekk/json-infer-time. 
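
A short sketch of the user-visible effect of this change, assuming an active `SparkSession` named `spark`; the sample records mirror the ones used in the new `JsonSuite` test:

```python
# Strings matching the (default) timestampFormat are now inferred as TimestampType.
rdd = spark.sparkContext.parallelize([
    '{"a": "2018-12-17T10:11:12.123-01:00"}',
    '{"a": "2018-12-16T22:23:24.123-02:00"}',
])
spark.read.json(rdd).printSchema()
# root
#  |-- a: timestamp (nullable = true)
#
# With prefersDecimal=true, decimal inference still takes precedence for fields
# that parse as decimals; fields matching neither fall back to StringType.
```
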
Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/json/JsonInferSchema.scala | 22 +++- .../catalyst/json/JsonInferSchemaSuite.scala | 102 ++++++++++++++++++ .../datasources/json/JsonSuite.scala | 52 +++++++++ 3 files changed, 171 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 263e05de32075..d1bc00c08c1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -37,6 +37,12 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val decimalParser = ExprUtils.getDecimalParser(options.locale) + @transient + private lazy val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record @@ -115,13 +121,19 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { // record fields' types have been combined. NullType - case VALUE_STRING if options.prefersDecimal => + case VALUE_STRING => + val field = parser.getText val decimalTry = allCatch opt { - val bigDecimal = decimalParser(parser.getText) + val bigDecimal = decimalParser(field) DecimalType(bigDecimal.precision, bigDecimal.scale) } - decimalTry.getOrElse(StringType) - case VALUE_STRING => StringType + if (options.prefersDecimal && decimalTry.isDefined) { + decimalTry.get + } else if ((allCatch opt timestampFormatter.parse(field)).isDefined) { + TimestampType + } else { + StringType + } case START_OBJECT => val builder = Array.newBuilder[StructField] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala new file mode 100644 index 0000000000000..9307f9b47b807 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.json + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { + + def checkType(options: Map[String, String], json: String, dt: DataType): Unit = { + val jsonOptions = new JSONOptions(options, "UTC", "") + val inferSchema = new JsonInferSchema(jsonOptions) + val factory = new JsonFactory() + jsonOptions.setJacksonOptions(factory) + val parser = CreateJacksonParser.string(factory, json) + parser.nextToken() + val expectedType = StructType(Seq(StructField("a", dt, true))) + + assert(inferSchema.inferField(parser) === expectedType) + } + + def checkTimestampType(pattern: String, json: String): Unit = { + checkType(Map("timestampFormat" -> pattern), json, TimestampType) + } + + test("inferring timestamp type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkTimestampType("yyyy", """{"a": "2018"}""") + checkTimestampType("yyyy=MM", """{"a": "2018=12"}""") + checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSS", + """{"a": "2018-12-02T21:04:00.123"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX", + """{"a": "2018-12-02T21:04:00.123567+01:00"}""") + } + } + } + + test("prefer decimals over timestamps") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "true", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = DecimalType(17, 9) + ) + } + } + } + + test("skip decimal type inferring") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "false", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = TimestampType + ) + } + } + } + + test("fallback to string type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map("timestampFormat" -> "yyyy,MM,dd.HHmmssSSS"), + json = """{"a": "20181202.210400123"}""", + dt = StringType + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 786335b42e3cb..8f575a371c98e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType.fromDDL import org.apache.spark.util.Utils class TestFileFilter extends PathFilter { @@ -2589,4 +2590,55 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(null, Array(0, 1, 2), "abc", """{"a":"-","b":[0, 1, 2],"c":"abc"}""") :: Row(0.1, null, "def", """{"a":0.1,"b":{},"c":"def"}""") :: Nil) } + + test("inferring timestamp type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + def schemaOf(jsons: String*): StructType = spark.read.json(jsons.toDS).schema + + assert(schemaOf( + """{"a":"2018-12-17T10:11:12.123-01:00"}""", + """{"a":"2018-12-16T22:23:24.123-02:00"}""") === fromDDL("a timestamp")) + + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":1}""") + === fromDDL("a string")) + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":"123"}""") + === fromDDL("a string")) + + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":null}""") + === fromDDL("a timestamp")) + assert(schemaOf("""{"a":null}""", """{"a":"2018-12-17T10:11:12.123-01:00"}""") + === fromDDL("a timestamp")) + } + } + } + + test("roundtrip for timestamp type inferring") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val customSchema = new StructType().add("date", TimestampType) + withTempDir { dir => + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json" + val timestampsWithFormat = spark.read + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + assert(timestampsWithFormat.schema === customSchema) + + timestampsWithFormat.write + .format("json") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") + .option(DateTimeUtils.TIMEZONE_OPTION, "UTC") + .save(timestampsWithFormatPath) + + val readBack = spark.read + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") + .option(DateTimeUtils.TIMEZONE_OPTION, "UTC") + .json(timestampsWithFormatPath) + + assert(readBack.schema === customSchema) + checkAnswer(readBack, timestampsWithFormat) + } + } + } + } } From d33bf4b12ce5fcd4222624c7347444433e3033de Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 18 Dec 2018 20:52:02 +0800 Subject: [PATCH 095/194] [SPARK-26081][SQL][FOLLOW-UP] Use foreach instead of misuse of map (for Unit) ## What changes were proposed in this pull request? This PR proposes to use foreach instead of misuse of map (for Unit). This could cause some weird errors potentially and it's not a good practice anyway. See also SPARK-16694 ## How was this patch tested? N/A Closes #23341 from HyukjinKwon/followup-SPARK-26081. 
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/datasources/csv/CSVFileFormat.scala | 2 +- .../spark/sql/execution/datasources/json/JsonFileFormat.scala | 2 +- .../spark/sql/execution/datasources/text/TextFileFormat.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index f7d8a9e1042d5..f4f139d180058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -189,5 +189,5 @@ private[csv] class CsvOutputWriter( gen.write(row) } - override def close(): Unit = univocityGenerator.map(_.close()) + override def close(): Unit = univocityGenerator.foreach(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3042133ee43aa..40f55e7068010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -190,5 +190,5 @@ private[json] class JsonOutputWriter( gen.writeLineEnding() } - override def close(): Unit = jacksonGenerator.map(_.close()) + override def close(): Unit = jacksonGenerator.foreach(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 01948ab25d63c..0607f7b3c0d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -153,7 +153,7 @@ class TextOutputWriter( private var outputStream: Option[OutputStream] = None override def write(row: InternalRow): Unit = { - val os = outputStream.getOrElse{ + val os = outputStream.getOrElse { val newStream = CodecStreams.createOutputStream(context, new Path(path)) outputStream = Some(newStream) newStream @@ -167,6 +167,6 @@ class TextOutputWriter( } override def close(): Unit = { - outputStream.map(_.close()) + outputStream.foreach(_.close()) } } From 77d78b8f531a34b87d037311e190442a81e49e45 Mon Sep 17 00:00:00 2001 From: Stan Zhai Date: Tue, 18 Dec 2018 07:02:09 -0600 Subject: [PATCH 096/194] [SPARK-24680][DEPLOY] Support spark.executorEnv.JAVA_HOME in Standalone mode ## What changes were proposed in this pull request? spark.executorEnv.JAVA_HOME does not take effect when a Worker starting an Executor process in Standalone mode. This PR fixed this. ## How was this patch tested? Manual tests. Closes #21663 from stanzhai/fix-executor-env-java-home. 
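
A brief sketch of how the new behavior is exercised, assuming a standalone cluster; the master URL and JDK path are placeholders, not values from this patch:

```python
from pyspark.sql import SparkSession

# With this fix, a standalone Worker launching executors for this application
# honors spark.executorEnv.JAVA_HOME when building the executor's java command.
spark = (SparkSession.builder
         .master("spark://master-host:7077")                        # placeholder standalone master
         .config("spark.executorEnv.JAVA_HOME", "/opt/java/jdk8")   # placeholder JDK path
         .appName("executor-java-home-example")
         .getOrCreate())
```
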
Lead-authored-by: Stan Zhai Co-authored-by: Stan Zhai Signed-off-by: Sean Owen --- .../spark/launcher/AbstractCommandBuilder.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index ce24400f557cd..56edceb17bfb8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -91,14 +91,18 @@ abstract List buildCommand(Map env) */ List buildJavaCommand(String extraClassPath) throws IOException { List cmd = new ArrayList<>(); - String envJavaHome; - if (javaHome != null) { - cmd.add(join(File.separator, javaHome, "bin", "java")); - } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { - cmd.add(join(File.separator, envJavaHome, "bin", "java")); - } else { - cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); + String[] candidateJavaHomes = new String[] { + javaHome, + childEnv.get("JAVA_HOME"), + System.getenv("JAVA_HOME"), + System.getProperty("java.home") + }; + for (String javaHome : candidateJavaHomes) { + if (javaHome != null) { + cmd.add(join(File.separator, javaHome, "bin", "java")); + break; + } } // Load extra JAVA_OPTS from conf/java-opts, if it exists. From 4af79806c91217fdc5ed7fb335fdbc5341874f70 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 19 Dec 2018 00:01:53 +0800 Subject: [PATCH 097/194] [SPARK-26384][SQL] Propagate SQL configs for CSV schema inferring ## What changes were proposed in this pull request? Currently, SQL configs are not propagated to executors while schema inferring in CSV datasource. For example, changing of `spark.sql.legacy.timeParser.enabled` does not impact on inferring timestamp types. In the PR, I propose to fix the issue by wrapping schema inferring action using `SQLExecution.withSQLConfPropagated`. ## How was this patch tested? Added logging to `TimestampFormatter`: ```patch -object TimestampFormatter { +object TimestampFormatter extends Logging { def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { + logError("LegacyFallbackTimestampFormatter is being used") new LegacyFallbackTimestampFormatter(format, timeZone, locale) } else { + logError("Iso8601TimestampFormatter is being used") new Iso8601TimestampFormatter(format, timeZone, locale) } } ``` and run the command in `spark-shell`: ```shell $ ./bin/spark-shell --conf spark.sql.legacy.timeParser.enabled=true ``` ```scala scala> Seq("2010|10|10").toDF.repartition(1).write.mode("overwrite").text("/tmp/foo") scala> spark.read.option("inferSchema", "true").option("header", "false").option("timestampFormat", "yyyy|MM|dd").csv("/tmp/foo").printSchema() 18/12/18 10:47:27 ERROR TimestampFormatter: LegacyFallbackTimestampFormatter is being used root |-- _c0: timestamp (nullable = true) ``` Closes #23345 from MaxGekk/csv-schema-infer-propagate-configs. 
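The shape of the fix, as a minimal sketch (assuming an active `SparkSession` named `spark`; the `count()` body is a stand-in for the real `CSVInferSchema` job), is to run the distributed part of schema inference inside `SQLExecution.withSQLConfPropagated`, so executor tasks observe the driver's SQL confs:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

// Sketch only: any job whose tasks should see driver-side SQL confs
// (e.g. spark.sql.legacy.timeParser.enabled) is wrapped like this.
def runWithPropagatedConfs(spark: SparkSession, lines: RDD[String]): Long = {
  SQLExecution.withSQLConfPropagated(spark) {
    lines.count() // stand-in for the schema-inference action
  }
}
```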
Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- .../sql/execution/datasources/csv/CSVDataSource.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b46dfb94c133e..375cec597166c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -135,7 +136,9 @@ object TextInputCSVDataSource extends CSVDataSource { val parser = new CsvParser(parsedOptions.asParserSettings) linesWithoutHeader.map(parser.parseLine) } - new CSVInferSchema(parsedOptions).infer(tokenRDD, header) + SQLExecution.withSQLConfPropagated(csv.sparkSession) { + new CSVInferSchema(parsedOptions).infer(tokenRDD, header) + } case _ => // If the first line could not be read, just return the empty schema. StructType(Nil) @@ -208,7 +211,9 @@ object MultiLineCSVDataSource extends CSVDataSource { encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) - new CSVInferSchema(parsedOptions).infer(sampled, header) + SQLExecution.withSQLConfPropagated(sparkSession) { + new CSVInferSchema(parsedOptions).infer(sampled, header) + } case None => // If the first row could not be read, just return the empty schema. StructType(Nil) From 0702b7059f3e0a2be4371e8d4d7727e5f7ff2255 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 18 Dec 2018 10:09:56 -0800 Subject: [PATCH 098/194] [SPARK-26382][CORE] prefix comparator should handle -0.0 ## What changes were proposed in this pull request? This is kind of a followup of https://github.com/apache/spark/pull/23239 The `UnsafeProject` will normalize special float/double values(NaN and -0.0), so the sorter doesn't have to handle it. However, for consistency and future-proof, this PR proposes to normalize `-0.0` in the prefix comparator, so that it's same with the normal ordering. Note that prefix comparator handles NaN as well. This is not a bug fix, but a safe guard. ## How was this patch tested? existing tests Closes #23334 from cloud-fan/sort. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../unsafe/sort/PrefixComparators.java | 2 ++ .../unsafe/sort/PrefixComparatorsSuite.scala | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 0910db22af004..bef1bdadb27aa 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -69,6 +69,8 @@ public static final class DoublePrefixComparator { * details see http://stereopsis.com/radix.html. 
*/ public static long computePrefix(double value) { + // normalize -0.0 to 0.0, as they should be equal + value = value == -0.0 ? 0.0 : value; // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible // positive NaN, so there's nothing special we need to do for NaNs. long bits = Double.doubleToLongBits(value); diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 73546ef1b7a60..38cb37c524594 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -125,6 +125,7 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + // NaN is greater than the max double value. assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } @@ -134,22 +135,34 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { assert(java.lang.Double.doubleToRawLongBits(negativeNan) < 0) val prefix = PrefixComparators.DoublePrefixComparator.computePrefix(negativeNan) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + // -NaN is greater than the max double value. assert(PrefixComparators.DOUBLE.compare(prefix, doubleMaxPrefix) === 1) } test("double prefix comparator handles other special values properly") { - val nullValue = 0L + // See `SortPrefix.nullValue` for how we deal with nulls for float/double type + val smallestNullPrefix = 0L + val largestNullPrefix = -1L val nan = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NaN) val posInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.PositiveInfinity) val negInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NegativeInfinity) val minValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MinValue) val maxValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) val zero = PrefixComparators.DoublePrefixComparator.computePrefix(0.0) + val minusZero = PrefixComparators.DoublePrefixComparator.computePrefix(-0.0) + + // null is greater than everything including NaN, when we need to treat it as the largest value. + assert(PrefixComparators.DOUBLE.compare(largestNullPrefix, nan) === 1) + // NaN is greater than the positive infinity. assert(PrefixComparators.DOUBLE.compare(nan, posInf) === 1) assert(PrefixComparators.DOUBLE.compare(posInf, maxValue) === 1) assert(PrefixComparators.DOUBLE.compare(maxValue, zero) === 1) assert(PrefixComparators.DOUBLE.compare(zero, minValue) === 1) assert(PrefixComparators.DOUBLE.compare(minValue, negInf) === 1) - assert(PrefixComparators.DOUBLE.compare(negInf, nullValue) === 1) + // null is smaller than everything including negative infinity, when we need to treat it as + // the smallest value. + assert(PrefixComparators.DOUBLE.compare(negInf, smallestNullPrefix) === 1) + // 0.0 should be equal to -0.0. 
+ assert(PrefixComparators.DOUBLE.compare(zero, minusZero) === 0) } } From c446c9e0d1ae9e942ed491508f3ed8a44ddfb2a8 Mon Sep 17 00:00:00 2001 From: Jackey Lee Date: Tue, 18 Dec 2018 12:15:36 -0600 Subject: [PATCH 099/194] [SPARK-26394][CORE] Fix annotation error for Utils.timeStringAsMs ## What changes were proposed in this pull request? Change microseconds to milliseconds in annotation of Utils.timeStringAsMs. Closes #23346 from stczwd/stczwd. Authored-by: Jackey Lee Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b4ea1ee950217..143abd3bbea8e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1037,7 +1037,7 @@ private[spark] object Utils extends Logging { } /** - * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. */ def timeStringAsMs(str: String): Long = { From 4578d12facd3e3b51006bf5dfae768eb09cc45e6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 18 Dec 2018 13:30:09 -0800 Subject: [PATCH 100/194] [SPARK-25815][K8S] Support kerberos in client mode, keytab-based token renewal. This change hooks up the k8s backed to the updated HadoopDelegationTokenManager, so that delegation tokens are also available in client mode, and keytab-based token renewal is enabled. The change re-works the k8s feature steps related to kerberos so that the driver does all the credential management and provides all the needed information to executors - so nothing needs to be added to executor pods. This also makes cluster mode behave a lot more similarly to client mode, since no driver-related config steps are run in the latter case. The main two things that don't need to happen in executors anymore are: - adding the Hadoop config to the executor pods: this is not needed since the Spark driver will serialize the Hadoop config and send it to executors when running tasks. - mounting the kerberos config file in the executor pods: this is not needed once you remove the above. The Hadoop conf sent by the driver with the tasks is already resolved (i.e. has all the kerberos names properly defined), so executors do not need access to the kerberos realm information anymore. The change also avoids creating delegation tokens unnecessarily. This means that they'll only be created if a secret with tokens was not provided, and if a keytab is not provided. In either of those cases, the driver code will handle delegation tokens: in cluster mode by creating a secret and stashing them, in client mode by using existing mechanisms to send DTs to executors. One last feature: the change also allows defining a keytab with a "local:" URI. This is supported in client mode (although that's the same as not saying "local:"), and in k8s cluster mode. This allows the keytab to be mounted onto the image from a pre-existing secret, for example. Finally, the new code always sets SPARK_USER in the driver and executor pods. This is in line with how other resource managers behave: the submitting user reflects which user will access Hadoop services in the app. (With kerberos, that's overridden by the logged in user.) 
That user is unrelated to the OS user the app is running as inside the containers. Tested: - client and cluster mode with kinit - cluster mode with keytab - cluster mode with local: keytab - YARN cluster with keytab (to make sure it isn't broken) Closes #22911 from vanzin/SPARK-25815. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/deploy/SparkSubmit.scala | 29 +- .../HadoopDelegationTokenManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 8 + .../apache/spark/deploy/k8s/Constants.scala | 9 +- .../spark/deploy/k8s/KubernetesConf.scala | 4 - .../apache/spark/deploy/k8s/SparkPod.scala | 25 +- .../k8s/features/BasicDriverFeatureStep.scala | 4 + .../features/BasicExecutorFeatureStep.scala | 4 + .../HadoopConfDriverFeatureStep.scala | 124 +++++++ .../HadoopConfExecutorFeatureStep.scala | 40 --- .../HadoopSparkUserExecutorFeatureStep.scala | 35 -- .../KerberosConfDriverFeatureStep.scala | 315 ++++++++++-------- .../KerberosConfExecutorFeatureStep.scala | 46 --- .../hadooputils/HadoopBootstrapUtil.scala | 283 ---------------- .../hadooputils/KerberosConfigSpec.scala | 33 -- .../k8s/submit/KubernetesDriverBuilder.scala | 1 + .../KubernetesClusterSchedulerBackend.scala | 7 +- .../k8s/KubernetesExecutorBuilder.scala | 5 +- .../BasicDriverFeatureStepSuite.scala | 3 +- .../BasicExecutorFeatureStepSuite.scala | 9 +- .../HadoopConfDriverFeatureStepSuite.scala | 71 ++++ .../KerberosConfDriverFeatureStepSuite.scala | 171 ++++++++++ .../KubernetesFeaturesTestUtils.scala | 6 + .../org/apache/spark/deploy/yarn/Client.scala | 24 +- .../spark/deploy/yarn/ClientSuite.scala | 6 +- 25 files changed, 649 insertions(+), 621 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d4055cb6c5853..763bd0a70a035 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} -import java.net.URL +import java.net.{URI, URL} import java.security.PrivilegedExceptionAction import java.text.ParseException import java.util.UUID @@ -334,19 +334,20 @@ private[spark] class SparkSubmit extends Logging { val hadoopConf = 
conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient || isKubernetesCluster) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sparkConf.set(KEYTAB, args.keytab) - sparkConf.set(PRINCIPAL, args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } + // Kerberos is not supported in standalone mode, and keytab support is not yet available + // in Mesos cluster mode. + if (clusterManager != STANDALONE + && !isMesosCluster + && args.principal != null + && args.keytab != null) { + // If client mode, make sure the keytab is just a local path. + if (deployMode == CLIENT && Utils.isLocalUri(args.keytab)) { + args.keytab = new URI(args.keytab).getPath() + } + + if (!Utils.isLocalUri(args.keytab)) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 126a6ab801369..f7e3ddecee093 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.security import java.io.File +import java.net.URI import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -71,11 +72,13 @@ private[spark] class HadoopDelegationTokenManager( private val providerEnabledConfig = "spark.security.credentials.%s.enabled" private val principal = sparkConf.get(PRINCIPAL).orNull - private val keytab = sparkConf.get(KEYTAB).orNull + + // The keytab can be a local: URI for cluster mode, so translate it to a regular path. If it is + // needed later on, the code will check that it exists. 
+ private val keytab = sparkConf.get(KEYTAB).map { uri => new URI(uri).getPath() }.orNull require((principal == null) == (keytab == null), "Both principal and keytab must be defined, or neither.") - require(keytab == null || new File(keytab).isFile(), s"Cannot find keytab at $keytab.") private val delegationTokenProviders = loadProviders() logDebug("Using the following builtin delegation token providers: " + @@ -264,6 +267,7 @@ private[spark] class HadoopDelegationTokenManager( private def doLogin(): UserGroupInformation = { logInfo(s"Attempting to login to KDC using principal: $principal") + require(new File(keytab).isFile(), s"Cannot find keytab at $keytab.") val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") ugi diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 143abd3bbea8e..f322e92c6c8cb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -92,6 +92,9 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null + /** Scheme used for files that are locally available on worker nodes in the cluster. */ + val LOCAL_SCHEME = "local" + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -2829,6 +2832,11 @@ private[spark] object Utils extends Logging { def isClientMode(conf: SparkConf): Boolean = { "client".equals(conf.get(SparkLauncher.DEPLOY_MODE, "client")) } + + /** Returns whether the URI is a "local:" URI. */ + def isLocalUri(uri: String): Boolean = { + uri.startsWith(s"$LOCAL_SCHEME:") + } } private[util] object CallerContext extends Logging { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 85917b88e912a..76041e7de5182 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -87,25 +87,22 @@ private[spark] object Constants { val NON_JVM_MEMORY_OVERHEAD_FACTOR = 0.4d // Hadoop Configuration - val HADOOP_FILE_VOLUME = "hadoop-properties" + val HADOOP_CONF_VOLUME = "hadoop-properties" val KRB_FILE_VOLUME = "krb5-file" val HADOOP_CONF_DIR_PATH = "/opt/hadoop/conf" val KRB_FILE_DIR_PATH = "/etc" val ENV_HADOOP_CONF_DIR = "HADOOP_CONF_DIR" val HADOOP_CONFIG_MAP_NAME = "spark.kubernetes.executor.hadoopConfigMapName" - val KRB5_CONFIG_MAP_NAME = - "spark.kubernetes.executor.krb5ConfigMapName" // Kerberos Configuration - val KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME = "delegation-tokens" val KERBEROS_DT_SECRET_NAME = "spark.kubernetes.kerberos.dt-secret-name" val KERBEROS_DT_SECRET_KEY = "spark.kubernetes.kerberos.dt-secret-key" - val KERBEROS_SPARK_USER_NAME = - "spark.kubernetes.kerberos.spark-user-name" val KERBEROS_SECRET_KEY = "hadoop-tokens" + val KERBEROS_KEYTAB_VOLUME = "kerberos-keytab" + val KERBEROS_KEYTAB_MOUNT_POINT = "/mnt/secrets/kerberos-keytab" // Hadoop credentials secrets for the Spark app. 
val SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR = "/mnt/secrets/hadoop-credentials" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 0bea02324b7bc..0a2df9ed2b8e8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -42,10 +42,6 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { def appName: String = get("spark.app.name", "spark") - def hadoopConfigMapName: String = s"$resourceNamePrefix-hadoop-config" - - def krbConfigMapName: String = s"$resourceNamePrefix-krb5-file" - def namespace: String = get(KUBERNETES_NAMESPACE) def imagePullPolicy: String = get(CONTAINER_IMAGE_PULL_POLICY) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 345dd117fd35f..fd1196368a7ff 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -18,7 +18,30 @@ package org.apache.spark.deploy.k8s import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -private[spark] case class SparkPod(pod: Pod, container: Container) +private[spark] case class SparkPod(pod: Pod, container: Container) { + + /** + * Convenience method to apply a series of chained transformations to a pod. + * + * Use it like: + * + * original.modify { case pod => + * // update pod and return new one + * }.modify { case pod => + * // more changes that create a new pod + * }.modify { + * case pod if someCondition => // new pod + * } + * + * This makes it cleaner to apply multiple transformations, avoiding having to create + * a bunch of awkwardly-named local variables. Since the argument is a partial function, + * it can do matching without needing to exhaust all the possibilities. If the function + * is not applied, then the original pod will be kept. 
+ */ + def transform(fn: PartialFunction[SparkPod, SparkPod]): SparkPod = fn.lift(this).getOrElse(this) + +} + private[spark] object SparkPod { def initialPod(): SparkPod = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index d8cf3653d3226..8362c14fb289d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -110,6 +110,10 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) .withContainerPort(driverUIPort) .withProtocol("TCP") .endPort() + .addNewEnv() + .withName(ENV_SPARK_USER) + .withValue(Utils.getCurrentUserName()) + .endEnv() .addAllToEnv(driverCustomEnvs.asJava) .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 1bc9c292597b5..d1bddf290f6eb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -163,6 +163,10 @@ private[spark] class BasicExecutorFeatureStep( .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() + .addNewEnv() + .withName(ENV_SPARK_USER) + .withValue(Utils.getCurrentUserName()) + .endEnv() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) .addToArgs("executor") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala new file mode 100644 index 0000000000000..d602ed5481e65 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Mounts the Hadoop configuration - either a pre-defined config map, or a local configuration + * directory - on the driver pod. + */ +private[spark] class HadoopConfDriverFeatureStep(conf: KubernetesConf) + extends KubernetesFeatureConfigStep { + + private val confDir = Option(conf.sparkConf.getenv(ENV_HADOOP_CONF_DIR)) + private val existingConfMap = conf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) + + KubernetesUtils.requireNandDefined( + confDir, + existingConfMap, + "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + + "as the creation of an additional ConfigMap, when one is already specified is extraneous") + + private lazy val confFiles: Seq[File] = { + val dir = new File(confDir.get) + if (dir.isDirectory) { + dir.listFiles.filter(_.isFile).toSeq + } else { + Nil + } + } + + private def newConfigMapName: String = s"${conf.resourceNamePrefix}-hadoop-config" + + private def hasHadoopConf: Boolean = confDir.isDefined || existingConfMap.isDefined + + override def configurePod(original: SparkPod): SparkPod = { + original.transform { case pod if hasHadoopConf => + val confVolume = if (confDir.isDefined) { + val keyPaths = confFiles.map { file => + new KeyToPathBuilder() + .withKey(file.getName()) + .withPath(file.getName()) + .build() + } + new VolumeBuilder() + .withName(HADOOP_CONF_VOLUME) + .withNewConfigMap() + .withName(newConfigMapName) + .withItems(keyPaths.asJava) + .endConfigMap() + .build() + } else { + new VolumeBuilder() + .withName(HADOOP_CONF_VOLUME) + .withNewConfigMap() + .withName(existingConfMap.get) + .endConfigMap() + .build() + } + + val podWithConf = new PodBuilder(pod.pod) + .editSpec() + .addNewVolumeLike(confVolume) + .endVolume() + .endSpec() + .build() + + val containerWithMount = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(HADOOP_CONF_VOLUME) + .withMountPath(HADOOP_CONF_DIR_PATH) + .endVolumeMount() + .addNewEnv() + .withName(ENV_HADOOP_CONF_DIR) + .withValue(HADOOP_CONF_DIR_PATH) + .endEnv() + .build() + + SparkPod(podWithConf, containerWithMount) + } + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (confDir.isDefined) { + val fileMap = confFiles.map { file => + (file.getName(), Files.toString(file, StandardCharsets.UTF_8)) + }.toMap.asJava + + Seq(new ConfigMapBuilder() + .withNewMetadata() + .withName(newConfigMapName) + .endMetadata() + .addToData(fileMap) + .build()) + } else { + Nil + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala deleted file mode 100644 index da332881ae1a2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil -import org.apache.spark.internal.Logging - -/** - * This step is responsible for bootstraping the container with ConfigMaps - * containing Hadoop config files mounted as volumes and an ENV variable - * pointed to the mounted file directory. - */ -private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep with Logging { - - override def configurePod(pod: SparkPod): SparkPod = { - val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME) - if (hadoopConfDirCMapName.isDefined) { - HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) - } else { - pod - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala deleted file mode 100644 index c038e75491ca5..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil - -/** - * This step is responsible for setting ENV_SPARK_USER when HADOOP_FILES are detected - * however, this step would not be run if Kerberos is enabled, as Kerberos sets SPARK_USER - */ -private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep { - - override def configurePod(pod: SparkPod): SparkPod = { - conf.getOption(KERBEROS_SPARK_USER_NAME).map { user => - HadoopBootstrapUtil.bootstrapSparkUserPod(user, pod) - }.getOrElse(pod) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index c6d5a866fa7bc..721d7e97b21f8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -16,31 +16,40 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, Secret, SecretBuilder} +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model._ import org.apache.commons.codec.binary.Base64 -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.k8s.{KubernetesDriverConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils /** - * Runs the necessary Hadoop-based logic based on Kerberos configs and the presence of the - * HADOOP_CONF_DIR. This runs various bootstrap methods defined in HadoopBootstrapUtil. + * Provide kerberos / service credentials to the Spark driver. + * + * There are three use cases, in order of precedence: + * + * - keytab: if a kerberos keytab is defined, it is provided to the driver, and the driver will + * manage the kerberos login and the creation of delegation tokens. + * - existing tokens: if a secret containing delegation tokens is provided, it will be mounted + * on the driver pod, and the driver will handle distribution of those tokens to executors. + * - tgt only: if Hadoop security is enabled, the local TGT will be used to create delegation + * tokens which will be provided to the driver. The driver will handle distribution of the + * tokens to executors. 
*/ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDriverConf) - extends KubernetesFeatureConfigStep { - - private val hadoopConfDir = Option(kubernetesConf.sparkConf.getenv(ENV_HADOOP_CONF_DIR)) - private val hadoopConfigMapName = kubernetesConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) - KubernetesUtils.requireNandDefined( - hadoopConfDir, - hadoopConfigMapName, - "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + - "as the creation of an additional ConfigMap, when one is already specified is extraneous") + extends KubernetesFeatureConfigStep with Logging { private val principal = kubernetesConf.get(org.apache.spark.internal.config.PRINCIPAL) private val keytab = kubernetesConf.get(org.apache.spark.internal.config.KEYTAB) @@ -49,15 +58,6 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri private val krb5File = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_FILE) private val krb5CMap = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf) - private val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, hadoopConf) - private val isKerberosEnabled = - (hadoopConfDir.isDefined && UserGroupInformation.isSecurityEnabled) || - (hadoopConfigMapName.isDefined && (krb5File.isDefined || krb5CMap.isDefined)) - require(keytab.isEmpty || isKerberosEnabled, - "You must enable Kerberos support if you are specifying a Kerberos Keytab") - - require(existingSecretName.isEmpty || isKerberosEnabled, - "You must enable Kerberos support if you are specifying a Kerberos Secret") KubernetesUtils.requireNandDefined( krb5File, @@ -79,128 +79,183 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri "If a secret storing a Kerberos Delegation Token is specified you must also" + " specify the item-key where the data is stored") - private val hadoopConfigurationFiles = hadoopConfDir.map { hConfDir => - HadoopBootstrapUtil.getHadoopConfFiles(hConfDir) + if (!hasKerberosConf) { + logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. " + + "Make sure that you have the krb5.conf locally on the driver image.") } - private val newHadoopConfigMapName = - if (hadoopConfigMapName.isEmpty) { - Some(kubernetesConf.hadoopConfigMapName) - } else { - None - } - // Either use pre-existing secret or login to create new Secret with DT stored within - private val kerberosConfSpec: Option[KerberosConfigSpec] = (for { - secretName <- existingSecretName - secretItemKey <- existingSecretItemKey - } yield { - KerberosConfigSpec( - dtSecret = None, - dtSecretName = secretName, - dtSecretItemKey = secretItemKey, - jobUserName = UserGroupInformation.getCurrentUser.getShortUserName) - }).orElse( - if (isKerberosEnabled) { - Some(buildKerberosSpec()) + // Create delegation tokens if needed. This is a lazy val so that it's not populated + // unnecessarily. But it needs to be accessible to different methods in this class, + // since it's not clear based solely on available configuration options that delegation + // tokens are needed when other credentials are not available. 
+ private lazy val delegationTokens: Array[Byte] = { + if (keytab.isEmpty && existingSecretName.isEmpty) { + val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, + SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf)) + val creds = UserGroupInformation.getCurrentUser().getCredentials() + tokenManager.obtainDelegationTokens(creds) + // If no tokens and no secrets are stored in the credentials, make sure nothing is returned, + // to avoid creating an unnecessary secret. + if (creds.numberOfTokens() > 0 || creds.numberOfSecretKeys() > 0) { + SparkHadoopUtil.get.serialize(creds) + } else { + null + } } else { - None + null } - ) + } - override def configurePod(pod: SparkPod): SparkPod = { - if (!isKerberosEnabled) { - return pod - } + private def needKeytabUpload: Boolean = keytab.exists(!Utils.isLocalUri(_)) - val hadoopBasedSparkPod = HadoopBootstrapUtil.bootstrapHadoopConfDir( - hadoopConfDir, - newHadoopConfigMapName, - hadoopConfigMapName, - pod) - kerberosConfSpec.map { hSpec => - HadoopBootstrapUtil.bootstrapKerberosPod( - hSpec.dtSecretName, - hSpec.dtSecretItemKey, - hSpec.jobUserName, - krb5File, - Some(kubernetesConf.krbConfigMapName), - krb5CMap, - hadoopBasedSparkPod) - }.getOrElse( - HadoopBootstrapUtil.bootstrapSparkUserPod( - UserGroupInformation.getCurrentUser.getShortUserName, - hadoopBasedSparkPod)) - } + private def dtSecretName: String = s"${kubernetesConf.resourceNamePrefix}-delegation-tokens" - override def getAdditionalPodSystemProperties(): Map[String, String] = { - if (!isKerberosEnabled) { - return Map.empty - } + private def ktSecretName: String = s"${kubernetesConf.resourceNamePrefix}-kerberos-keytab" - val resolvedConfValues = kerberosConfSpec.map { hSpec => - Map(KERBEROS_DT_SECRET_NAME -> hSpec.dtSecretName, - KERBEROS_DT_SECRET_KEY -> hSpec.dtSecretItemKey, - KERBEROS_SPARK_USER_NAME -> hSpec.jobUserName, - KRB5_CONFIG_MAP_NAME -> krb5CMap.getOrElse(kubernetesConf.krbConfigMapName)) - }.getOrElse( - Map(KERBEROS_SPARK_USER_NAME -> - UserGroupInformation.getCurrentUser.getShortUserName)) - Map(HADOOP_CONFIG_MAP_NAME -> - hadoopConfigMapName.getOrElse(kubernetesConf.hadoopConfigMapName)) ++ resolvedConfValues - } + private def hasKerberosConf: Boolean = krb5CMap.isDefined | krb5File.isDefined - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { - if (!isKerberosEnabled) { - return Seq.empty - } + private def newConfigMapName: String = s"${kubernetesConf.resourceNamePrefix}-krb5-file" - val hadoopConfConfigMap = for { - hName <- newHadoopConfigMapName - hFiles <- hadoopConfigurationFiles - } yield { - HadoopBootstrapUtil.buildHadoopConfigMap(hName, hFiles) - } + override def configurePod(original: SparkPod): SparkPod = { + original.transform { case pod if hasKerberosConf => + val configMapVolume = if (krb5CMap.isDefined) { + new VolumeBuilder() + .withName(KRB_FILE_VOLUME) + .withNewConfigMap() + .withName(krb5CMap.get) + .endConfigMap() + .build() + } else { + val krb5Conf = new File(krb5File.get) + new VolumeBuilder() + .withName(KRB_FILE_VOLUME) + .withNewConfigMap() + .withName(newConfigMapName) + .withItems(new KeyToPathBuilder() + .withKey(krb5Conf.getName()) + .withPath(krb5Conf.getName()) + .build()) + .endConfigMap() + .build() + } - val krb5ConfigMap = krb5File.map { fileLocation => - HadoopBootstrapUtil.buildkrb5ConfigMap( - kubernetesConf.krbConfigMapName, - fileLocation) - } + val podWithVolume = new PodBuilder(pod.pod) + .editSpec() + .addNewVolumeLike(configMapVolume) + .endVolume() + .endSpec() 
+ .build() + + val containerWithMount = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(KRB_FILE_VOLUME) + .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf") + .withSubPath("krb5.conf") + .endVolumeMount() + .build() + + SparkPod(podWithVolume, containerWithMount) + }.transform { + case pod if needKeytabUpload => + // If keytab is defined and is a submission-local file (not local: URI), then create a + // secret for it. The keytab data will be stored in this secret below. + val podWitKeytab = new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(KERBEROS_KEYTAB_VOLUME) + .withNewSecret() + .withSecretName(ktSecretName) + .endSecret() + .endVolume() + .endSpec() + .build() + + val containerWithKeytab = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(KERBEROS_KEYTAB_VOLUME) + .withMountPath(KERBEROS_KEYTAB_MOUNT_POINT) + .endVolumeMount() + .build() + + SparkPod(podWitKeytab, containerWithKeytab) + + case pod if existingSecretName.isDefined | delegationTokens != null => + val secretName = existingSecretName.getOrElse(dtSecretName) + val itemKey = existingSecretItemKey.getOrElse(KERBEROS_SECRET_KEY) + + val podWithTokens = new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .endVolume() + .endSpec() + .build() - val kerberosDTSecret = kerberosConfSpec.flatMap(_.dtSecret) + val containerWithTokens = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR) + .endVolumeMount() + .addNewEnv() + .withName(ENV_HADOOP_TOKEN_FILE_LOCATION) + .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$itemKey") + .endEnv() + .build() - hadoopConfConfigMap.toSeq ++ - krb5ConfigMap.toSeq ++ - kerberosDTSecret.toSeq + SparkPod(podWithTokens, containerWithTokens) + } } - private def buildKerberosSpec(): KerberosConfigSpec = { - // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal - // The login happens in the SparkSubmit so login logic is not necessary to include - val jobUserUGI = UserGroupInformation.getCurrentUser - val creds = jobUserUGI.getCredentials - tokenManager.obtainDelegationTokens(creds) - val tokenData = SparkHadoopUtil.get.serialize(creds) - require(tokenData.nonEmpty, "Did not obtain any delegation tokens") - val newSecretName = - s"${kubernetesConf.resourceNamePrefix}-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME" - val secretDT = - new SecretBuilder() - .withNewMetadata() - .withName(newSecretName) - .endMetadata() - .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(tokenData)) - .build() - KerberosConfigSpec( - dtSecret = Some(secretDT), - dtSecretName = newSecretName, - dtSecretItemKey = KERBEROS_SECRET_KEY, - jobUserName = jobUserUGI.getShortUserName) + override def getAdditionalPodSystemProperties(): Map[String, String] = { + // If a submission-local keytab is provided, update the Spark config so that it knows the + // path of the keytab in the driver container. 
+ if (needKeytabUpload) { + val ktName = new File(keytab.get).getName() + Map(KEYTAB.key -> s"$KERBEROS_KEYTAB_MOUNT_POINT/$ktName") + } else { + Map.empty + } } - private case class KerberosConfigSpec( - dtSecret: Option[Secret], - dtSecretName: String, - dtSecretItemKey: String, - jobUserName: String) + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + Seq[HasMetadata]() ++ { + krb5File.map { path => + val file = new File(path) + new ConfigMapBuilder() + .withNewMetadata() + .withName(newConfigMapName) + .endMetadata() + .addToData( + Map(file.getName() -> Files.toString(file, StandardCharsets.UTF_8)).asJava) + .build() + } + } ++ { + // If a submission-local keytab is provided, stash it in a secret. + if (needKeytabUpload) { + val kt = new File(keytab.get) + Seq(new SecretBuilder() + .withNewMetadata() + .withName(ktSecretName) + .endMetadata() + .addToData(kt.getName(), Base64.encodeBase64String(Files.toByteArray(kt))) + .build()) + } else { + Nil + } + } ++ { + if (delegationTokens != null) { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(dtSecretName) + .endMetadata() + .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(delegationTokens)) + .build()) + } else { + Nil + } + } + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala deleted file mode 100644 index 907271b1cb483..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil -import org.apache.spark.internal.Logging - -/** - * This step is responsible for mounting the DT secret for the executors - */ -private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep with Logging { - - override def configurePod(pod: SparkPod): SparkPod = { - val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) - if (maybeKrb5CMap.isDefined) { - logInfo(s"Mounting Resources for Kerberos") - HadoopBootstrapUtil.bootstrapKerberosPod( - conf.get(KERBEROS_DT_SECRET_NAME), - conf.get(KERBEROS_DT_SECRET_KEY), - conf.get(KERBEROS_SPARK_USER_NAME), - None, - None, - maybeKrb5CMap, - pod) - } else { - pod - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala deleted file mode 100644 index 5bee766caf2be..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.features.hadooputils - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ - -import com.google.common.io.Files -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkPod -import org.apache.spark.internal.Logging - -private[spark] object HadoopBootstrapUtil extends Logging { - - /** - * Mounting the DT secret for both the Driver and the executors - * - * @param dtSecretName Name of the secret that stores the Delegation Token - * @param dtSecretItemKey Name of the Item Key storing the Delegation Token - * @param userName Name of the SparkUser to set SPARK_USER - * @param fileLocation Optional Location of the krb5 file - * @param newKrb5ConfName Optional location of the ConfigMap for Krb5 - * @param existingKrb5ConfName Optional name of ConfigMap for Krb5 - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapKerberosPod( - dtSecretName: String, - dtSecretItemKey: String, - userName: String, - fileLocation: Option[String], - newKrb5ConfName: Option[String], - existingKrb5ConfName: Option[String], - pod: SparkPod): SparkPod = { - - val preConfigMapVolume = existingKrb5ConfName.map { kconf => - new VolumeBuilder() - .withName(KRB_FILE_VOLUME) - .withNewConfigMap() - .withName(kconf) - .endConfigMap() - .build() - } - - val createConfigMapVolume = for { - fLocation <- fileLocation - krb5ConfName <- newKrb5ConfName - } yield { - val krb5File = new File(fLocation) - val fileStringPath = krb5File.toPath.getFileName.toString - new VolumeBuilder() - .withName(KRB_FILE_VOLUME) - .withNewConfigMap() - .withName(krb5ConfName) - .withItems(new KeyToPathBuilder() - .withKey(fileStringPath) - .withPath(fileStringPath) - .build()) - .endConfigMap() - .build() - } - - // Breaking up Volume creation for clarity - val configMapVolume = preConfigMapVolume.orElse(createConfigMapVolume) - if (configMapVolume.isEmpty) { - logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. 
" + - "Make sure that you have the krb5.conf locally on the Driver and Executor images") - } - - val kerberizedPodWithDTSecret = new PodBuilder(pod.pod) - .editOrNewSpec() - .addNewVolume() - .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) - .withNewSecret() - .withSecretName(dtSecretName) - .endSecret() - .endVolume() - .endSpec() - .build() - - // Optionally add the krb5.conf ConfigMap - val kerberizedPod = configMapVolume.map { cmVolume => - new PodBuilder(kerberizedPodWithDTSecret) - .editSpec() - .addNewVolumeLike(cmVolume) - .endVolume() - .endSpec() - .build() - }.getOrElse(kerberizedPodWithDTSecret) - - val kerberizedContainerWithMounts = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) - .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR) - .endVolumeMount() - .addNewEnv() - .withName(ENV_HADOOP_TOKEN_FILE_LOCATION) - .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$dtSecretItemKey") - .endEnv() - .addNewEnv() - .withName(ENV_SPARK_USER) - .withValue(userName) - .endEnv() - .build() - - // Optionally add the krb5.conf Volume Mount - val kerberizedContainer = - if (configMapVolume.isDefined) { - new ContainerBuilder(kerberizedContainerWithMounts) - .addNewVolumeMount() - .withName(KRB_FILE_VOLUME) - .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf") - .withSubPath("krb5.conf") - .endVolumeMount() - .build() - } else { - kerberizedContainerWithMounts - } - - SparkPod(kerberizedPod, kerberizedContainer) - } - - /** - * setting ENV_SPARK_USER when HADOOP_FILES are detected - * - * @param sparkUserName Name of the SPARK_USER - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapSparkUserPod(sparkUserName: String, pod: SparkPod): SparkPod = { - val envModifiedContainer = new ContainerBuilder(pod.container) - .addNewEnv() - .withName(ENV_SPARK_USER) - .withValue(sparkUserName) - .endEnv() - .build() - SparkPod(pod.pod, envModifiedContainer) - } - - /** - * Grabbing files in the HADOOP_CONF_DIR - * - * @param path location of HADOOP_CONF_DIR - * @return a list of File object - */ - def getHadoopConfFiles(path: String): Seq[File] = { - val dir = new File(path) - if (dir.isDirectory) { - dir.listFiles.filter(_.isFile).toSeq - } else { - Seq.empty[File] - } - } - - /** - * Bootstraping the container with ConfigMaps that store - * Hadoop configuration files - * - * @param hadoopConfDir directory location of HADOOP_CONF_DIR env - * @param newHadoopConfigMapName name of the new configMap for HADOOP_CONF_DIR - * @param existingHadoopConfigMapName name of the pre-defined configMap for HADOOP_CONF_DIR - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapHadoopConfDir( - hadoopConfDir: Option[String], - newHadoopConfigMapName: Option[String], - existingHadoopConfigMapName: Option[String], - pod: SparkPod): SparkPod = { - val preConfigMapVolume = existingHadoopConfigMapName.map { hConf => - new VolumeBuilder() - .withName(HADOOP_FILE_VOLUME) - .withNewConfigMap() - .withName(hConf) - .endConfigMap() - .build() } - - val createConfigMapVolume = for { - dirLocation <- hadoopConfDir - hConfName <- newHadoopConfigMapName - } yield { - val hadoopConfigFiles = getHadoopConfFiles(dirLocation) - val keyPaths = hadoopConfigFiles.map { file => - val fileStringPath = file.toPath.getFileName.toString - new KeyToPathBuilder() - .withKey(fileStringPath) - .withPath(fileStringPath) - .build() - } - new VolumeBuilder() - .withName(HADOOP_FILE_VOLUME) - .withNewConfigMap() 
- .withName(hConfName) - .withItems(keyPaths.asJava) - .endConfigMap() - .build() - } - - // Breaking up Volume Creation for clarity - val configMapVolume = preConfigMapVolume.getOrElse(createConfigMapVolume.get) - - val hadoopSupportedPod = new PodBuilder(pod.pod) - .editSpec() - .addNewVolumeLike(configMapVolume) - .endVolume() - .endSpec() - .build() - - val hadoopSupportedContainer = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(HADOOP_FILE_VOLUME) - .withMountPath(HADOOP_CONF_DIR_PATH) - .endVolumeMount() - .addNewEnv() - .withName(ENV_HADOOP_CONF_DIR) - .withValue(HADOOP_CONF_DIR_PATH) - .endEnv() - .build() - SparkPod(hadoopSupportedPod, hadoopSupportedContainer) - } - - /** - * Builds ConfigMap given the file location of the - * krb5.conf file - * - * @param configMapName name of configMap for krb5 - * @param fileLocation location of krb5 file - * @return a ConfigMap - */ - def buildkrb5ConfigMap( - configMapName: String, - fileLocation: String): ConfigMap = { - val file = new File(fileLocation) - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(Map(file.toPath.getFileName.toString -> - Files.toString(file, StandardCharsets.UTF_8)).asJava) - .build() - } - - /** - * Builds ConfigMap given the ConfigMap name - * and a list of Hadoop Conf files - * - * @param hadoopConfigMapName name of hadoopConfigMap - * @param hadoopConfFiles list of hadoopFiles - * @return a ConfigMap - */ - def buildHadoopConfigMap( - hadoopConfigMapName: String, - hadoopConfFiles: Seq[File]): ConfigMap = { - new ConfigMapBuilder() - .withNewMetadata() - .withName(hadoopConfigMapName) - .endMetadata() - .addToData(hadoopConfFiles.map { file => - (file.toPath.getFileName.toString, - Files.toString(file, StandardCharsets.UTF_8)) - }.toMap.asJava) - .build() - } - -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala deleted file mode 100644 index 7f7ef216cf485..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.hadooputils - -import io.fabric8.kubernetes.api.model.Secret - -/** - * Represents a given configuration of the Kerberos Configuration logic - *
    - * - The secret containing a DT, either previously specified or built on the fly - * - The name of the secret where the DT will be stored - * - The data item-key on the secret which correlates with where the current DT data is stored - * - The Job User's username - */ -private[spark] case class KerberosConfigSpec( - dtSecret: Option[Secret], - dtSecretName: String, - dtSecretItemKey: String, - jobUserName: String) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index d2c0ced9fa2f4..57e4060bc85b9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -46,6 +46,7 @@ private[spark] class KubernetesDriverBuilder { new LocalDirsFeatureStep(conf), new MountVolumesFeatureStep(conf), new DriverCommandFeatureStep(conf), + new HadoopConfDriverFeatureStep(conf), new KerberosConfDriverFeatureStep(conf), new PodTemplateConfigMapStep(conf)) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 03f5da2bb0bce..cd298971e02a7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -25,6 +25,7 @@ import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -143,7 +144,11 @@ private[spark] class KubernetesClusterSchedulerBackend( } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { - new KubernetesDriverEndpoint(rpcEnv, properties) + new KubernetesDriverEndpoint(sc.env.rpcEnv, properties) + } + + override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { + Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration)) } private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 0b74966fe8685..48aa2c56d4d69 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -44,10 +44,7 @@ private[spark] class KubernetesExecutorBuilder { new MountSecretsFeatureStep(conf), new EnvSecretsFeatureStep(conf), new 
LocalDirsFeatureStep(conf), - new MountVolumesFeatureStep(conf), - new HadoopConfExecutorFeatureStep(conf), - new KerberosConfExecutorFeatureStep(conf), - new HadoopSparkUserExecutorFeatureStep(conf)) + new MountVolumesFeatureStep(conf)) features.foldLeft(initialPod) { case (pod, feature) => feature.configurePod(pod) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index e4951bc1e69ed..5ceb9d6d6fcd0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -73,7 +74,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val foundPortNames = configuredPod.container.getPorts.asScala.toSet assert(expectedPortNames === foundPortNames) - assert(configuredPod.container.getEnv.size === 3) val envs = configuredPod.container .getEnv .asScala @@ -82,6 +82,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ENVS.foreach { case (k, v) => assert(envs(v) === v) } + assert(envs(ENV_SPARK_USER) === Utils.getCurrentUserName()) assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 05989d9be7ad5..c2efab01e4248 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -200,7 +200,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> KubernetesTestConf.APP_ID, ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_SPARK_USER -> Utils.getCurrentUserName()) val extraJavaOptsStart = additionalEnvVars.keys.count(_.startsWith(ENV_JAVA_OPT_PREFIX)) val extraJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) @@ -208,9 +209,11 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { s"$ENV_JAVA_OPT_PREFIX${ind + extraJavaOptsStart}" -> opt }.toMap - val mapEnvs = executorPod.container.getEnv.asScala.map { + val containerEnvs = executorPod.container.getEnv.asScala.map { x => (x.getName, x.getValue) }.toMap - assert((defaultEnvs ++ extraJavaOptsEnvs) === mapEnvs) + + val expectedEnvs = defaultEnvs ++ additionalEnvVars ++ extraJavaOptsEnvs + assert(containerEnvs === expectedEnvs) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..e1c01dbdc7358 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model.ConfigMap + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.util.{SparkConfWithEnv, Utils} + +class HadoopConfDriverFeatureStepSuite extends SparkFunSuite { + + import KubernetesFeaturesTestUtils._ + import SecretVolumeUtils._ + + test("mount hadoop config map if defined") { + val sparkConf = new SparkConf(false) + .set(Config.KUBERNETES_HADOOP_CONF_CONFIG_MAP, "testConfigMap") + val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) + val step = new HadoopConfDriverFeatureStep(conf) + checkPod(step.configurePod(SparkPod.initialPod())) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("create hadoop config map if config dir is defined") { + val confDir = Utils.createTempDir() + val confFiles = Set("core-site.xml", "hdfs-site.xml") + + confFiles.foreach { f => + Files.write("some data", new File(confDir, f), UTF_8) + } + + val sparkConf = new SparkConfWithEnv(Map(ENV_HADOOP_CONF_DIR -> confDir.getAbsolutePath())) + val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) + + val step = new HadoopConfDriverFeatureStep(conf) + checkPod(step.configurePod(SparkPod.initialPod())) + + val hadoopConfMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head + assert(hadoopConfMap.getData().keySet().asScala === confFiles) + } + + private def checkPod(pod: SparkPod): Unit = { + assert(podHasVolume(pod.pod, HADOOP_CONF_VOLUME)) + assert(containerHasVolume(pod.container, HADOOP_CONF_VOLUME, HADOOP_CONF_DIR_PATH)) + assert(containerHasEnvVar(pod.container, ENV_HADOOP_CONF_DIR)) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..41ca3a94ce7a7 --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model.{ConfigMap, Secret} +import org.apache.commons.codec.binary.Base64 +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { + + import KubernetesFeaturesTestUtils._ + import SecretVolumeUtils._ + + private val tmpDir = Utils.createTempDir() + + test("mount krb5 config map if defined") { + val configMap = "testConfigMap" + val step = createStep( + new SparkConf(false).set(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP, configMap)) + + checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), configMap) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(filter[ConfigMap](step.getAdditionalKubernetesResources()).isEmpty) + } + + test("create krb5.conf config map if local config provided") { + val krbConf = File.createTempFile("krb5", ".conf", tmpDir) + Files.write("some data", krbConf, UTF_8) + + val sparkConf = new SparkConf(false) + .set(KUBERNETES_KERBEROS_KRB5_FILE, krbConf.getAbsolutePath()) + val step = createStep(sparkConf) + + val confMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head + assert(confMap.getData().keySet().asScala === Set(krbConf.getName())) + + checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), confMap.getMetadata().getName()) + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + + test("create keytab secret if client keytab file used") { + val keytab = File.createTempFile("keytab", ".bin", tmpDir) + Files.write("some data", keytab, UTF_8) + + val sparkConf = new SparkConf(false) + .set(KEYTAB, keytab.getAbsolutePath()) + .set(PRINCIPAL, "alice") + val step = createStep(sparkConf) + + val pod = step.configurePod(SparkPod.initialPod()) + assert(podHasVolume(pod.pod, KERBEROS_KEYTAB_VOLUME)) + assert(containerHasVolume(pod.container, KERBEROS_KEYTAB_VOLUME, KERBEROS_KEYTAB_MOUNT_POINT)) + + 
assert(step.getAdditionalPodSystemProperties().keys === Set(KEYTAB.key)) + + val secret = filter[Secret](step.getAdditionalKubernetesResources()).head + assert(secret.getData().keySet().asScala === Set(keytab.getName())) + } + + test("do nothing if container-local keytab used") { + val sparkConf = new SparkConf(false) + .set(KEYTAB, "local:/my.keytab") + .set(PRINCIPAL, "alice") + val step = createStep(sparkConf) + + val initial = SparkPod.initialPod() + assert(step.configurePod(initial) === initial) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("mount delegation tokens if provided") { + val dtSecret = "tokenSecret" + val sparkConf = new SparkConf(false) + .set(KUBERNETES_KERBEROS_DT_SECRET_NAME, dtSecret) + .set(KUBERNETES_KERBEROS_DT_SECRET_ITEM_KEY, "dtokens") + val step = createStep(sparkConf) + + checkPodForTokens(step.configurePod(SparkPod.initialPod()), dtSecret) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("create delegation tokens if needed") { + // Since HadoopDelegationTokenManager does not create any tokens without proper configs and + // services, start with a test user that already has some tokens that will just be piped + // through to the driver. + val testUser = UserGroupInformation.createUserForTesting("k8s", Array()) + testUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val creds = testUser.getCredentials() + creds.addSecretKey(new Text("K8S_TEST_KEY"), Array[Byte](0x4, 0x2)) + testUser.addCredentials(creds) + + val tokens = SparkHadoopUtil.get.serialize(creds) + + val step = createStep(new SparkConf(false)) + + val dtSecret = filter[Secret](step.getAdditionalKubernetesResources()).head + assert(dtSecret.getData().get(KERBEROS_SECRET_KEY) === Base64.encodeBase64String(tokens)) + + checkPodForTokens(step.configurePod(SparkPod.initialPod()), + dtSecret.getMetadata().getName()) + + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + }) + } + + test("do nothing if no config and no tokens") { + val step = createStep(new SparkConf(false)) + val initial = SparkPod.initialPod() + assert(step.configurePod(initial) === initial) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + private def checkPodForKrbConf(pod: SparkPod, confMapName: String): Unit = { + val podVolume = pod.pod.getSpec().getVolumes().asScala.find(_.getName() == KRB_FILE_VOLUME) + assert(podVolume.isDefined) + assert(containerHasVolume(pod.container, KRB_FILE_VOLUME, KRB_FILE_DIR_PATH + "/krb5.conf")) + assert(podVolume.get.getConfigMap().getName() === confMapName) + } + + private def checkPodForTokens(pod: SparkPod, dtSecretName: String): Unit = { + val podVolume = pod.pod.getSpec().getVolumes().asScala + .find(_.getName() == SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + assert(podVolume.isDefined) + assert(containerHasVolume(pod.container, SPARK_APP_HADOOP_SECRET_VOLUME_NAME, + SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR)) + assert(containerHasEnvVar(pod.container, ENV_HADOOP_TOKEN_FILE_LOCATION)) + assert(podVolume.get.getSecret().getSecretName() === dtSecretName) + } + + private def createStep(conf: SparkConf): KerberosConfDriverFeatureStep = { + val kconf = KubernetesTestConf.createDriverConf(sparkConf = conf) + new KerberosConfDriverFeatureStep(kconf) + } + +} diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index f90380e30e52a..076b681be2397 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers @@ -63,4 +64,9 @@ object KubernetesFeaturesTestUtils { def containerHasEnvVar(container: Container, envVarName: String): Boolean = { container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) } + + def filter[T: ClassTag](list: Seq[HasMetadata]): Seq[T] = { + val desired = implicitly[ClassTag[T]].runtimeClass + list.filter(_.getClass() == desired).map(_.asInstanceOf[T]).toSeq + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 6240f7b68d2c8..184fb6a8ad13e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -116,6 +116,8 @@ private[spark] class Client( } } + require(keytab == null || !Utils.isLocalUri(keytab), "Keytab should reference a local file.") + private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -472,7 +474,7 @@ private[spark] class Client( appMasterOnly: Boolean = false): (Boolean, String) = { val trimmedPath = path.trim() val localURI = Utils.resolveURI(trimmedPath) - if (localURI.getScheme != LOCAL_SCHEME) { + if (localURI.getScheme != Utils.LOCAL_SCHEME) { if (addDistributedUri(localURI)) { val localPath = getQualifiedLocalPath(localURI, hadoopConf) val linkname = targetDir.map(_ + "/").getOrElse("") + @@ -515,7 +517,7 @@ private[spark] class Client( val sparkArchive = sparkConf.get(SPARK_ARCHIVE) if (sparkArchive.isDefined) { val archive = sparkArchive.get - require(!isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.") + require(!Utils.isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.") distribute(Utils.resolveURI(archive).toString, resType = LocalResourceType.ARCHIVE, destName = Some(LOCALIZED_LIB_DIR)) @@ -525,7 +527,7 @@ private[spark] class Client( // Break the list of jars to upload, and resolve globs. 
val localJars = new ArrayBuffer[String]() jars.foreach { jar => - if (!isLocalUri(jar)) { + if (!Utils.isLocalUri(jar)) { val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf) val pathFs = FileSystem.get(path.toUri(), hadoopConf) pathFs.globStatus(path).filter(_.isFile()).foreach { entry => @@ -814,7 +816,7 @@ private[spark] class Client( } (pySparkArchives ++ pyArchives).foreach { path => val uri = Utils.resolveURI(path) - if (uri.getScheme != LOCAL_SCHEME) { + if (uri.getScheme != Utils.LOCAL_SCHEME) { pythonPath += buildPath(Environment.PWD.$$(), new Path(uri).getName()) } else { pythonPath += uri.getPath() @@ -1183,9 +1185,6 @@ private object Client extends Logging { // Alias for the user jar val APP_JAR_NAME: String = "__app__.jar" - // URI scheme that identifies local resources - val LOCAL_SCHEME = "local" - // Staging directory for any temporary jars or files val SPARK_STAGING: String = ".sparkStaging" @@ -1307,7 +1306,7 @@ private object Client extends Logging { addClasspathEntry(buildPath(Environment.PWD.$$(), LOCALIZED_LIB_DIR, "*"), env) if (sparkConf.get(SPARK_ARCHIVE).isEmpty) { sparkConf.get(SPARK_JARS).foreach { jars => - jars.filter(isLocalUri).foreach { jar => + jars.filter(Utils.isLocalUri).foreach { jar => val uri = new URI(jar) addClasspathEntry(getClusterPath(sparkConf, uri.getPath()), env) } @@ -1340,7 +1339,7 @@ private object Client extends Logging { private def getMainJarUri(mainJar: Option[String]): Option[URI] = { mainJar.flatMap { path => val uri = Utils.resolveURI(path) - if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None + if (uri.getScheme == Utils.LOCAL_SCHEME) Some(uri) else None }.orElse(Some(new URI(APP_JAR_NAME))) } @@ -1368,7 +1367,7 @@ private object Client extends Logging { uri: URI, fileName: String, env: HashMap[String, String]): Unit = { - if (uri != null && uri.getScheme == LOCAL_SCHEME) { + if (uri != null && uri.getScheme == Utils.LOCAL_SCHEME) { addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath(Environment.PWD.$$(), fileName), env) @@ -1489,11 +1488,6 @@ private object Client extends Logging { components.mkString(Path.SEPARATOR) } - /** Returns whether the URI is a "local:" URI. 
*/ - def isLocalUri(uri: String): Boolean = { - uri.startsWith(s"$LOCAL_SCHEME:") - } - def createAppReport(report: ApplicationReport): YarnAppReport = { val diags = report.getDiagnostics() val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index b3286e8fd824e..a6f57fcdb2461 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -100,7 +100,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => val uri = new URI(entry) - if (LOCAL_SCHEME.equals(uri.getScheme())) { + if (Utils.LOCAL_SCHEME.equals(uri.getScheme())) { cp should contain (uri.getPath()) } else { cp should not contain (uri.getPath()) @@ -136,7 +136,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val expected = ADDED.split(",") .map(p => { val uri = new URI(p) - if (LOCAL_SCHEME == uri.getScheme()) { + if (Utils.LOCAL_SCHEME == uri.getScheme()) { p } else { Option(uri.getFragment()).getOrElse(new File(p).getName()) @@ -249,7 +249,7 @@ class ClientSuite extends SparkFunSuite with Matchers { any(classOf[MutableHashMap[URI, Path]]), anyBoolean(), any()) classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*")) - sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath()) + sparkConf.set(SPARK_ARCHIVE, Utils.LOCAL_SCHEME + ":" + archive.getPath()) intercept[IllegalArgumentException] { client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil) } From 93089b5b36a223f390fb45b88911714db80c6f18 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 18 Dec 2018 23:21:52 -0800 Subject: [PATCH 101/194] [SPARK-26366][SQL] ReplaceExceptWithFilter should consider NULL as False ## What changes were proposed in this pull request? In `ReplaceExceptWithFilter` we do not consider properly the case in which the condition returns NULL. Indeed, in that case, since negating NULL still returns NULL, so it is not true the assumption that negating the condition returns all the rows which didn't satisfy it, rows returning NULL may not be returned. This happens when constraints inferred by `InferFiltersFromConstraints` are not enough, as it happens with `OR` conditions. The rule had also problems with non-deterministic conditions: in such a scenario, this rule would change the probability of the output. The PR fixes these problem by: - returning False for the condition when it is Null (in this way we do return all the rows which didn't satisfy it); - avoiding any transformation when the condition is non-deterministic. ## How was this patch tested? added UTs Closes #23315 from mgaido91/SPARK-26366. 
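A minimal reproduction sketch of the behaviour being fixed (editor's illustration mirroring the `DatasetSuite` test added below; it assumes an active `SparkSession` named `spark` and is not part of the patch itself):

```scala
import spark.implicits._

// The filter condition below evaluates to NULL (not FALSE) for the row ("1", null).
val inputDF = Seq(("0", "a"), ("1", null)).toDF("a", "b")
val exceptDF = inputDF.filter($"a".isin("0") || $"b" > "c")

// Before this fix the row ("1", null) was wrongly dropped from the EXCEPT result,
// because Not(NULL) is still NULL; with the Coalesce(condition, FALSE) wrapping,
// the negated condition becomes TRUE for that row and it is returned as expected.
inputDF.except(exceptDF).show()   // expected: a single row ["1", null]
```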
Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../optimizer/ReplaceExceptWithFilter.scala | 32 ++++++++------ .../optimizer/ReplaceOperatorSuite.scala | 44 ++++++++++++++----- .../org/apache/spark/sql/DatasetSuite.scala | 11 +++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 38 ++++++++++++++++ 4 files changed, 101 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index efd3944eba7f5..4996d24dfd298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule * Note: * Before flipping the filter condition of the right node, we should: * 1. Combine all it's [[Filter]]. - * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + * 2. Update the attribute references to the left node; + * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition). */ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { @@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { plan.transform { case e @ Except(left, right, false) if isEligible(left, right) => - val newCondition = transformCondition(left, skipProject(right)) - newCondition.map { c => - Distinct(Filter(Not(c), left)) - }.getOrElse { + val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition + if (filterCondition.deterministic) { + transformCondition(left, filterCondition).map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } + } else { e } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { - val filterCondition = - InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition - - val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - - if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { - Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = { + val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap + if (condition.references.forall(r => attributeNameMap.contains(r.name))) { + val rewrittenCondition = condition.transform { + case a: AttributeReference => attributeNameMap(a.name) + } + // We need to consider as False when the condition is NULL, otherwise we do not return those + // rows containing NULL which are instead filtered in the Except right plan + Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral))) } else { None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 3b1b2d588ef67..c8e15c7da763e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType class ReplaceOperatorSuite extends PlanTest { @@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), table1)).analyze + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), + table1)).analyze comparePlans(optimized, correctAnswer) } @@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), table1))).analyze comparePlans(optimized, correctAnswer) @@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA === 1 && attributeB === 2)), + Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze @@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, query) } + + test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a.in(1, 2) || 'b.in()) + val except = Except(basePlan, otherPlan, false) + val result = OptimizeIn(Optimize.execute(except.analyze)) + val correctAnswer = Aggregate(basePlan.output, basePlan.output, + Filter(!Coalesce(Seq( + 'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)), + Literal.FalseLiteral)), + basePlan)).analyze + comparePlans(result, correctAnswer) + } + + test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a > rand(1L)) + val except = Except(basePlan, 
otherPlan, false) + val result = Optimize.execute(except.analyze) + val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => + a1 <=> a2 }.reduce( _ && _) + val correctAnswer = Aggregate(basePlan.output, otherPlan.output, + Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze + comparePlans(result, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 525c7cef39563..c90b15814a534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1656,6 +1656,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) } + + test("SPARK-26366: return nulls which are not filtered in except") { + val inputDF = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true)))) + + val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c") + checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4cc8a45391996..37a8815350a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2899,6 +2899,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-26366: verify ReplaceExceptWithFilter") { + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(0, 3, 5), + Row(0, 3, null), + Row(null, 3, 5), + Row(0, null, 5), + Row(0, null, null), + Row(null, null, 5), + Row(null, 3, null), + Row(null, null, null))), + StructType(Seq(StructField("c1", IntegerType), + StructField("c2", IntegerType), + StructField("c3", IntegerType)))) + val where = "c2 >= 3 OR c1 >= 0" + val whereNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + |OR (c1 IS NOT NULL AND c1 >= 0) + """.stripMargin + + val df_a = df.filter(where) + val df_b = df.filter(whereNullSafe) + checkAnswer(df.except(df_a), df.except(df_b)) + + val whereWithIn = "c2 >= 3 OR c1 in (2)" + val whereWithInNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + """.stripMargin + val dfIn_a = df.filter(whereWithIn) + val dfIn_b = df.filter(whereWithInNullSafe) + checkAnswer(df.except(dfIn_a), df.except(dfIn_b)) + } + } + } } case class Foo(bar: Option[String]) From c3a7a52edda96de37501d3e41acc4a3636a79234 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 19 Dec 2018 09:41:30 -0800 Subject: [PATCH 102/194] [SPARK-26390][SQL] ColumnPruning rule should only do column pruning ## What changes were proposed in this pull request? This is a small clean up. By design catalyst rules should be orthogonal: each rule should have its own responsibility. However, the `ColumnPruning` rule does not only do column pruning, but also remove no-op project and window. 
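As a concrete illustration of the no-op operators in question (an editor's sketch using the Catalyst test DSL, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int)

// A Project that merely re-emits its child's output adds nothing to the plan;
// likewise, a Window whose windowExpressions list is empty is redundant.
// The RemoveNoopOperators rule introduced below collapses both back to `relation`.
val noopProject = relation.select('a, 'b).analyze
```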
This PR updates the `RemoveRedundantProject` rule to remove no-op window as well, and clean up the `ColumnPruning` rule to only do column pruning. ## How was this patch tested? existing tests Closes #23343 from cloud-fan/column-pruning. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/optimizer/Optimizer.scala | 23 +++++++++---------- .../optimizer/ColumnPruningSuite.scala | 7 +++--- .../optimizer/CombiningLimitsSuite.scala | 5 ++-- .../optimizer/JoinOptimizationSuite.scala | 1 + .../RemoveRedundantAliasAndProjectSuite.scala | 2 +- .../optimizer/RewriteSubquerySuite.scala | 2 +- .../optimizer/TransposeWindowSuite.scala | 2 +- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3eb6bca6ec976..44d5543114902 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -93,7 +93,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewriteCorrelatedScalarSubquery, EliminateSerialization, RemoveRedundantAliases, - RemoveRedundantProject, + RemoveNoopOperators, SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules @@ -177,7 +177,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) :+ + RemoveNoopOperators) :+ Batch("UpdateAttributeReferences", Once, UpdateNullabilityInAttributeReferences) } @@ -403,11 +403,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { } /** - * Remove projections from the query plan that do not make any modifications. + * Remove no-op operators from the query plan that do not make any modifications. */ -object RemoveRedundantProject extends Rule[LogicalPlan] { +object RemoveNoopOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(_, child) if p.output == child.output => child + // Eliminate no-op Projects + case p @ Project(_, child) if child.sameOutput(p) => child + + // Eliminate no-op Window + case w: Window if w.windowExpressions.isEmpty => w.child } } @@ -602,17 +606,12 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = w.copy( windowExpressions = w.windowExpressions.filter(p.references.contains))) - // Eliminate no-op Window - case w: Window if w.windowExpressions.isEmpty => w.child - - // Eliminate no-op Projects - case p @ Project(_, child) if child.sameOutput(p) => child - // Can't prune the columns on LeafNode case p @ Project(_, _: LeafNode) => p // for all other logical plans that inherits the output from it's children - case p @ Project(_, child) => + // Project over project is handled by the first case, skip it here. 
+ case p @ Project(_, child) if !child.isInstanceOf[Project] => val required = child.references ++ p.references if (!child.inputSet.subsetOf(required)) { val newChildren = child.children.map(c => prunedChild(c, required)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 8d7c9bf220bc2..57195d5fda7c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest { val batches = Batch("Column pruning", FixedPoint(100), PushDownPredicate, ColumnPruning, + RemoveNoopOperators, CollapseProject) :: Nil } @@ -340,10 +341,8 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on Union") { val input1 = LocalRelation('a.int, 'b.string, 'c.double) val input2 = LocalRelation('c.int, 'd.string, 'e.double) - val query = Project('b :: Nil, - Union(input1 :: input2 :: Nil)).analyze - val expected = Project('b :: Nil, - Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze + val expected = Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze comparePlans(Optimize.execute(query), expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index ef4b848924f06..b190dd5a7c220 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,8 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Filter Pushdown", FixedPoint(100), - ColumnPruning) :: + Batch("Column Pruning", FixedPoint(100), + ColumnPruning, + RemoveNoopOperators) :: Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e9438b2eee550..6fe5e619d03ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -39,6 +39,7 @@ class JoinOptimizationSuite extends PlanTest { ReorderJoin, PushPredicateThroughJoin, ColumnPruning, + RemoveNoopOperators, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 1973b5abb462d..3802dbf5d6e06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -33,7 +33,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper 
FixedPoint(50), PushProjectionThroughUnion, RemoveRedundantAliases, - RemoveRedundantProject) :: Nil + RemoveNoopOperators) :: Nil } test("all expressions in project list are aliased child output") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 6b3739c372c3a..f00d22e6e96a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -34,7 +34,7 @@ class RewriteSubquerySuite extends PlanTest { RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) :: Nil + RemoveNoopOperators) :: Nil } test("Column pruning after rewriting predicate subquery") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 58b3d1c98f3cd..4acd57832d2f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class TransposeWindowSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) :: + Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveNoopOperators) :: Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil } From 2867beb716b3e3a1edae74c35e352cb004bd096e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 20 Dec 2018 10:41:45 +0800 Subject: [PATCH 103/194] [SPARK-26262][SQL] Runs SQLQueryTestSuite on mixed config sets: WHOLESTAGE_CODEGEN_ENABLED and CODEGEN_FACTORY_MODE ## What changes were proposed in this pull request? For better test coverage, this pr proposed to use the 4 mixed config sets of `WHOLESTAGE_CODEGEN_ENABLED` and `CODEGEN_FACTORY_MODE` when running `SQLQueryTestSuite`: 1. WHOLESTAGE_CODEGEN_ENABLED=true, CODEGEN_FACTORY_MODE=CODEGEN_ONLY 2. WHOLESTAGE_CODEGEN_ENABLED=false, CODEGEN_FACTORY_MODE=CODEGEN_ONLY 3. WHOLESTAGE_CODEGEN_ENABLED=true, CODEGEN_FACTORY_MODE=NO_CODEGEN 4. WHOLESTAGE_CODEGEN_ENABLED=false, CODEGEN_FACTORY_MODE=NO_CODEGEN This pr also moved some existing tests into `ExplainSuite` because explain output results are different between codegen and interpreter modes. ## How was this patch tested? Existing tests. Closes #23213 from maropu/InterpreterModeTest. 
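For reference, the four configuration sets listed above expressed as (key, value) pairs (editor's sketch, not the exact code in the diff; the config key strings correspond to `WHOLESTAGE_CODEGEN_ENABLED` and `CODEGEN_FACTORY_MODE` in `SQLConf` and are assumed here):

```scala
val codegenConfigSets: Seq[Seq[(String, String)]] = Seq(
  Seq("spark.sql.codegen.wholeStage" -> "true",  "spark.sql.codegen.factoryMode" -> "CODEGEN_ONLY"),
  Seq("spark.sql.codegen.wholeStage" -> "false", "spark.sql.codegen.factoryMode" -> "CODEGEN_ONLY"),
  Seq("spark.sql.codegen.wholeStage" -> "true",  "spark.sql.codegen.factoryMode" -> "NO_CODEGEN"),
  Seq("spark.sql.codegen.wholeStage" -> "false", "spark.sql.codegen.factoryMode" -> "NO_CODEGEN"))

// SQLQueryTestSuite then runs every .sql input file once per config set and checks
// that the results still match the golden output files.
```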
Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../resources/sql-tests/inputs/group-by.sql | 5 - .../sql-tests/inputs/inline-table.sql | 3 - .../resources/sql-tests/inputs/operators.sql | 21 -- .../inputs/sql-compatibility-functions.sql | 5 - .../sql-tests/inputs/string-functions.sql | 27 --- .../inputs/table-valued-functions.sql | 6 - .../sql-tests/results/group-by.sql.out | 30 +-- .../sql-tests/results/inline-table.sql.out | 32 +-- .../sql-tests/results/operators.sql.out | 204 +++++++----------- .../sql-compatibility-functions.sql.out | 61 ++---- .../results/string-functions.sql.out | 131 +++-------- .../results/table-valued-functions.sql.out | 41 +--- .../org/apache/spark/sql/ExplainSuite.scala | 133 +++++++++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 51 ++--- 14 files changed, 281 insertions(+), 469 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index ec263ea70bd4a..7e81ff1aba37b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -141,8 +141,3 @@ SELECT every("true"); SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; - --- simple explain of queries having every/some/any agregates. Optimized --- plan should show the rewritten aggregate expression. -EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; - diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 41d316444ed6b..b3ec956cd178e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -49,6 +49,3 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); -- string to timestamp select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); - --- cross-join inline tables -EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 37f9cd44da7f2..ba14789d48db6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -29,27 +29,6 @@ select 2 * 5; select 5 % 3; select pmod(-7, 3); --- check operator precedence. 
--- We follow Oracle operator precedence in the table below that lists the levels of precedence --- among SQL operators from high to low: ------------------------------------------------------------------------------------------- --- Operator Operation ------------------------------------------------------------------------------------------- --- +, - identity, negation --- *, / multiplication, division --- +, -, || addition, subtraction, concatenation --- =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison --- NOT exponentiation, logical negation --- AND conjunction --- OR disjunction ------------------------------------------------------------------------------------------- -explain select 'a' || 1 + 2; -explain select 1 - 2 || 'b'; -explain select 2 * 4 + 3 || 'b'; -explain select 3 + 1 || 'a' || 4 / 2; -explain select 1 == 1 OR 'a' || 'b' == 'ab'; -explain select 'a' || 'c' == 'ac' AND 2 == 3; - -- math functions select cot(1); select cot(null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index f1461032065ad..1ae49c8bfc76a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -12,11 +12,6 @@ SELECT nullif(1, 2.1d), nullif(1, 1.0d); SELECT nvl(1, 2.1d), nvl(null, 2.1d); SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d); --- explain for these functions; use range to avoid constant folding -explain extended -select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') -from range(2); - -- SPARK-16730 cast alias functions for Hive compatibility SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1); SELECT float(1), double(1), decimal(1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 2effb43183d75..fbc231627e36f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -5,10 +5,6 @@ select format_string(); -- A pipe operator for string concatenation select 'a' || 'b' || 'c'; --- Check if catalyst combine nested `Concat`s -EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); - -- replace function select replace('abc', 'b', '123'); select replace('abc', 'b'); @@ -25,29 +21,6 @@ select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); --- turn off concatBinaryAsString -set spark.sql.function.concatBinaryAsString=false; - --- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false -EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); - -EXPLAIN SELECT (col1 || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); - -- split function SELECT split('aa1cc2ee3', '[1-9]+'); SELECT split('aa1cc2ee3', '[1-9]+', 2); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 72cd8ca9d8722..6f14c8ca87821 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -21,9 +21,3 @@ select * from range(1, null); -- range call with a mixed-case function name select * from RaNgE(2); - --- Explain -EXPLAIN select * from RaNgE(2); - --- cross-join table valued functions -EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9a8d025331b67..daf47c4d0a39a 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 47 +-- Number of queries: 46 -- !query 0 @@ -459,31 +459,3 @@ struct --- !query 46 output -== Parsed Logical Plan == -'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] -+- 'UnresolvedRelation `test_agg` - -== Analyzed Logical Plan == -k: int, every(v): boolean, some(v): boolean, any(v): boolean -Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] -+- SubqueryAlias `test_agg` - +- Project [k#x, v#x] - +- SubqueryAlias `test_agg` - +- LocalRelation [k#x, v#x] - -== Optimized Logical Plan == -Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x] -+- LocalRelation [k#x, v#x] - -== Physical Plan == -*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) -+- Exchange hashpartitioning(k#x, 200) - +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x]) - +- LocalTableScan [k#x, v#x] diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index c065ce5012929..4e80f0bda5513 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 17 -- !query 0 @@ -151,33 +151,3 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- struct> -- !query 16 output 1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] - - --- !query 17 -EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) --- !query 17 schema -struct --- !query 17 output -== Parsed Logical Plan == -'Project [*] -+- 'Join Cross - :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] - +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] - -== Analyzed Logical Plan == -col1: string, col2: int, col1: string, col2: int -Project [col1#x, col2#x, col1#x, col2#x] -+- Join Cross - :- LocalRelation [col1#x, col2#x] - +- LocalRelation [col1#x, col2#x] - -== Optimized Logical Plan == -Join Cross -:- LocalRelation [col1#x, col2#x] -+- LocalRelation [col1#x, col2#x] - -== Physical Plan == -BroadcastNestedLoopJoin BuildRight, Cross -:- LocalTableScan [col1#x, col2#x] -+- BroadcastExchange IdentityBroadcastMode - +- LocalTableScan [col1#x, col2#x] diff 
--git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 570b281353f3d..e0cbd575bc346 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 49 -- !query 0 @@ -195,260 +195,200 @@ struct -- !query 24 -explain select 'a' || 1 + 2 +select cot(1) -- !query 24 schema -struct +struct -- !query 24 output -== Physical Plan == -*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] -+- *Scan OneRowRelation[] +0.6420926159343306 -- !query 25 -explain select 1 - 2 || 'b' +select cot(null) -- !query 25 schema -struct +struct -- !query 25 output -== Physical Plan == -*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] -+- *Scan OneRowRelation[] +NULL -- !query 26 -explain select 2 * 4 + 3 || 'b' +select cot(0) -- !query 26 schema -struct +struct -- !query 26 output -== Physical Plan == -*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] -+- *Scan OneRowRelation[] +Infinity -- !query 27 -explain select 3 + 1 || 'a' || 4 / 2 +select cot(-1) -- !query 27 schema -struct +struct -- !query 27 output -== Physical Plan == -*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] -+- *Scan OneRowRelation[] +-0.6420926159343306 -- !query 28 -explain select 1 == 1 OR 'a' || 'b' == 'ab' +select ceiling(0) -- !query 28 schema -struct +struct -- !query 28 output -== Physical Plan == -*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] -+- *Scan OneRowRelation[] +0 -- !query 29 -explain select 'a' || 'c' == 'ac' AND 2 == 3 +select ceiling(1) -- !query 29 schema -struct +struct -- !query 29 output -== Physical Plan == -*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] -+- *Scan OneRowRelation[] +1 -- !query 30 -select cot(1) +select ceil(1234567890123456) -- !query 30 schema -struct +struct -- !query 30 output -0.6420926159343306 +1234567890123456 -- !query 31 -select cot(null) +select ceiling(1234567890123456) -- !query 31 schema -struct +struct -- !query 31 output -NULL +1234567890123456 -- !query 32 -select cot(0) +select ceil(0.01) -- !query 32 schema -struct +struct -- !query 32 output -Infinity +1 -- !query 33 -select cot(-1) +select ceiling(-0.10) -- !query 33 schema -struct +struct -- !query 33 output --0.6420926159343306 +0 -- !query 34 -select ceiling(0) +select floor(0) -- !query 34 schema -struct +struct -- !query 34 output 0 -- !query 35 -select ceiling(1) +select floor(1) -- !query 35 schema -struct +struct -- !query 35 output 1 -- !query 36 -select ceil(1234567890123456) +select floor(1234567890123456) -- !query 36 schema -struct +struct -- !query 36 output 1234567890123456 -- !query 37 -select ceiling(1234567890123456) --- !query 37 schema -struct --- !query 37 output -1234567890123456 - - --- !query 38 -select ceil(0.01) --- !query 38 schema -struct --- !query 38 output -1 - - --- !query 39 -select ceiling(-0.10) --- !query 39 schema -struct --- !query 39 output -0 - - --- !query 40 -select floor(0) --- !query 40 schema -struct --- !query 40 output -0 - - --- !query 41 -select floor(1) --- !query 41 schema -struct --- !query 41 output -1 - - --- !query 42 -select floor(1234567890123456) --- !query 42 schema -struct --- !query 42 output -1234567890123456 - - --- !query 43 select floor(0.01) --- 
!query 43 schema +-- !query 37 schema struct --- !query 43 output +-- !query 37 output 0 --- !query 44 +-- !query 38 select floor(-0.10) --- !query 44 schema +-- !query 38 schema struct --- !query 44 output +-- !query 38 output -1 --- !query 45 +-- !query 39 select 1 > 0.00001 --- !query 45 schema +-- !query 39 schema struct<(CAST(1 AS BIGINT) > 0):boolean> --- !query 45 output +-- !query 39 output true --- !query 46 +-- !query 40 select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) --- !query 46 schema +-- !query 40 schema struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> --- !query 46 output +-- !query 40 output 1 NULL 0 NULL NULL NULL --- !query 47 +-- !query 41 select BIT_LENGTH('abc') --- !query 47 schema +-- !query 41 schema struct --- !query 47 output +-- !query 41 output 24 --- !query 48 +-- !query 42 select CHAR_LENGTH('abc') --- !query 48 schema +-- !query 42 schema struct --- !query 48 output +-- !query 42 output 3 --- !query 49 +-- !query 43 select CHARACTER_LENGTH('abc') --- !query 49 schema +-- !query 43 schema struct --- !query 49 output +-- !query 43 output 3 --- !query 50 +-- !query 44 select OCTET_LENGTH('abc') --- !query 50 schema +-- !query 44 schema struct --- !query 50 output +-- !query 44 output 3 --- !query 51 +-- !query 45 select abs(-3.13), abs('-2.19') --- !query 51 schema +-- !query 45 schema struct --- !query 51 output +-- !query 45 output 3.13 2.19 --- !query 52 +-- !query 46 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) --- !query 52 schema +-- !query 46 schema struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> --- !query 52 output +-- !query 46 output -1.11 -1.11 1.11 1.11 --- !query 53 +-- !query 47 select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null) --- !query 53 schema +-- !query 47 schema struct --- !query 53 output +-- !query 47 output 1 0 NULL NULL NULL NULL --- !query 54 +-- !query 48 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)) --- !query 54 schema +-- !query 48 schema struct --- !query 54 output +-- !query 48 output NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index e035505f15d28..69a8e958000db 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 14 -- !query 0 @@ -67,74 +67,49 @@ struct -- !query 8 -explain extended -select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') -from range(2) --- !query 8 schema -struct --- !query 8 output -== Parsed Logical Plan == -'Project [unresolvedalias('ifnull('id, x), None), unresolvedalias('nullif('id, x), None), unresolvedalias('nvl('id, x), None), unresolvedalias('nvl2('id, x, y), None)] -+- 'UnresolvedTableValuedFunction range, [2] - -== Analyzed Logical Plan == -ifnull(`id`, 'x'): string, nullif(`id`, 'x'): bigint, nvl(`id`, 'x'): string, nvl2(`id`, 'x', 'y'): string -Project [ifnull(id#xL, x) AS ifnull(`id`, 'x')#x, nullif(id#xL, x) AS nullif(`id`, 'x')#xL, nvl(id#xL, x) AS nvl(`id`, 
'x')#x, nvl2(id#xL, x, y) AS nvl2(`id`, 'x', 'y')#x] -+- Range (0, 2, step=1, splits=None) - -== Optimized Logical Plan == -Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- Range (0, 2, step=1, splits=None) - -== Physical Plan == -*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=2) - - --- !query 9 SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1) --- !query 9 schema +-- !query 8 schema struct --- !query 9 output +-- !query 8 output true 1 1 1 1 --- !query 10 +-- !query 9 SELECT float(1), double(1), decimal(1) --- !query 10 schema +-- !query 9 schema struct --- !query 10 output +-- !query 9 output 1.0 1.0 1 --- !query 11 +-- !query 10 SELECT date("2014-04-04"), timestamp(date("2014-04-04")) --- !query 11 schema +-- !query 10 schema struct --- !query 11 output +-- !query 10 output 2014-04-04 2014-04-04 00:00:00 --- !query 12 +-- !query 11 SELECT string(1, 2) --- !query 12 schema +-- !query 11 schema struct<> --- !query 12 output +-- !query 11 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 --- !query 13 +-- !query 12 CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) --- !query 13 schema +-- !query 12 schema struct<> --- !query 13 output +-- !query 12 output --- !query 14 +-- !query 13 SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") --- !query 14 schema +-- !query 13 schema struct --- !query 14 output +-- !query 13 output gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index e8f2e0a81455a..25d93b2063146 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 13 -- !query 0 @@ -29,151 +29,80 @@ abc -- !query 3 -EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) --- !query 3 schema -struct --- !query 3 output -== Parsed Logical Plan == -'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias `__auto_generated_subquery_name` - +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] - +- 'UnresolvedTableValuedFunction range, [10] - -== Analyzed Logical Plan == -col: string -Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias `__auto_generated_subquery_name` - +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] - +- Range (0, 10, step=1, splits=None) - -== Optimized Logical Plan == -Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] -+- Range (0, 10, step=1, splits=None) - -== Physical Plan == -*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 4 select 
replace('abc', 'b', '123') --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output a123c --- !query 5 +-- !query 4 select replace('abc', 'b') --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output ac --- !query 6 +-- !query 5 select length(uuid()), (uuid() <> uuid()) --- !query 6 schema +-- !query 5 schema struct --- !query 6 output +-- !query 5 output 36 true --- !query 7 +-- !query 6 select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) --- !query 7 schema +-- !query 6 schema struct --- !query 7 output +-- !query 6 output 4 NULL NULL --- !query 8 +-- !query 7 select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null) --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output ab abcd ab NULL --- !query 9 +-- !query 8 select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') --- !query 9 schema +-- !query 8 schema struct --- !query 9 output +-- !query 8 output NULL NULL --- !query 10 +-- !query 9 select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) --- !query 10 schema +-- !query 9 schema struct --- !query 10 output +-- !query 9 output cd abcd cd NULL --- !query 11 +-- !query 10 select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') --- !query 11 schema +-- !query 10 schema struct --- !query 11 output +-- !query 10 output NULL NULL --- !query 12 -set spark.sql.function.concatBinaryAsString=false --- !query 12 schema -struct --- !query 12 output -spark.sql.function.concatBinaryAsString false - - --- !query 13 -EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 13 schema -struct --- !query 13 output -== Physical Plan == -*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 14 -EXPLAIN SELECT (col1 || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 14 schema -struct --- !query 14 output -== Physical Plan == -*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 15 +-- !query 11 SELECT split('aa1cc2ee3', '[1-9]+') --- !query 15 schema +-- !query 11 schema struct> --- !query 15 output +-- !query 11 output ["aa","cc","ee",""] --- !query 16 +-- !query 12 SELECT split('aa1cc2ee3', '[1-9]+', 2) --- !query 16 schema +-- !query 12 schema struct> --- !query 16 output +-- !query 12 output ["aa","cc2ee3"] diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 94af9181225d6..fdbea0ee90720 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 8 -- !query 0 @@ -99,42 +99,3 @@ struct -- !query 7 
output 0 1 - - --- !query 8 -EXPLAIN select * from RaNgE(2) --- !query 8 schema -struct --- !query 8 output -== Physical Plan == -*Range (0, 2, step=1, splits=2) - - --- !query 9 -EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) --- !query 9 schema -struct --- !query 9 output -== Parsed Logical Plan == -'Project [*] -+- 'Join Cross - :- 'UnresolvedTableValuedFunction range, [3] - +- 'UnresolvedTableValuedFunction range, [3] - -== Analyzed Logical Plan == -id: bigint, id: bigint -Project [id#xL, id#xL] -+- Join Cross - :- Range (0, 3, step=1, splits=None) - +- Range (0, 3, step=1, splits=None) - -== Optimized Logical Plan == -Join Cross -:- Range (0, 3, step=1, splits=None) -+- Range (0, 3, step=1, splits=None) - -== Physical Plan == -BroadcastNestedLoopJoin BuildRight, Cross -:- *Range (0, 3, step=1, splits=2) -+- BroadcastExchange IdentityBroadcastMode - +- *Range (0, 3, step=1, splits=2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 56d300e30a58e..ce475922eb5e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -29,10 +30,11 @@ class ExplainSuite extends QueryTest with SharedSQLContext { private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { val output = new java.io.ByteArrayOutputStream() Console.withOut(output) { - df.explain(extended = false) + df.explain(extended = true) } + val normalizedOutput = output.toString.replaceAll("#\\d+", "#x") for (key <- keywords) { - assert(output.toString.contains(key)) + assert(normalizedOutput.contains(key)) } } @@ -53,6 +55,133 @@ class ExplainSuite extends QueryTest with SharedSQLContext { checkKeywordsExistsInExplain(df, keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") } + + test("optimized plan should show the rewritten aggregate expression") { + withTempView("test_agg") { + sql( + """ + |CREATE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + | (1, true), (1, false), + | (2, true), + | (3, false), (3, null), + | (4, null), (4, null), + | (5, null), (5, true), (5, false) AS test_agg(k, v) + """.stripMargin) + + // simple explain of queries having every/some/any aggregates. Optimized + // plan should show the rewritten aggregate expression. 
+ val df = sql("SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k") + checkKeywordsExistsInExplain(df, + "Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, " + + "max(v#x) AS any(v)#x]") + } + } + + test("explain inline tables cross-joins") { + val df = sql( + """ + |SELECT * FROM VALUES ('one', 1), ('three', null) + | CROSS JOIN VALUES ('one', 1), ('three', null) + """.stripMargin) + checkKeywordsExistsInExplain(df, + "Join Cross", + ":- LocalRelation [col1#x, col2#x]", + "+- LocalRelation [col1#x, col2#x]") + } + + test("explain table valued functions") { + checkKeywordsExistsInExplain(sql("select * from RaNgE(2)"), "Range (0, 2, step=1, splits=None)") + checkKeywordsExistsInExplain(sql("SELECT * FROM range(3) CROSS JOIN range(3)"), + "Join Cross", + ":- Range (0, 3, step=1, splits=None)", + "+- Range (0, 3, step=1, splits=None)") + } + + test("explain string functions") { + // Check if catalyst combine nested `Concat`s + val df1 = sql( + """ + |SELECT (col1 || col2 || col3 || col4) col + | FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) + """.stripMargin) + checkKeywordsExistsInExplain(df1, + "Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)" + + ", cast(id#xL as string)) AS col#x]") + + // Check if catalyst combine nested `Concat`s if concatBinaryAsString=false + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { + val df2 = sql( + """ + |SELECT ((col1 || col2) || (col3 || col4)) col + |FROM ( + | SELECT + | string(id) col1, + | string(id + 1) col2, + | encode(string(id + 2), 'utf-8') col3, + | encode(string(id + 3), 'utf-8') col4 + | FROM range(10) + |) + """.stripMargin) + checkKeywordsExistsInExplain(df2, + "Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), " + + "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " + + "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]") + + val df3 = sql( + """ + |SELECT (col1 || (col3 || col4)) col + |FROM ( + | SELECT + | string(id) col1, + | encode(string(id + 2), 'utf-8') col3, + | encode(string(id + 3), 'utf-8') col4 + | FROM range(10) + |) + """.stripMargin) + checkKeywordsExistsInExplain(df3, + "Project [concat(cast(id#xL as string), " + + "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " + + "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]") + } + } + + test("check operator precedence") { + // We follow Oracle operator precedence in the table below that lists the levels + // of precedence among SQL operators from high to low: + // --------------------------------------------------------------------------------------- + // Operator Operation + // --------------------------------------------------------------------------------------- + // +, - identity, negation + // *, / multiplication, division + // +, -, || addition, subtraction, concatenation + // =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison + // NOT exponentiation, logical negation + // AND conjunction + // OR disjunction + // --------------------------------------------------------------------------------------- + checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"), + "Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]") + checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"), + "Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]") + checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"), + "Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), 
b)#x]") + checkKeywordsExistsInExplain(sql("select 3 + 1 || 'a' || 4 / 2"), + "Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), " + + "CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]") + checkKeywordsExistsInExplain(sql("select 1 == 1 OR 'a' || 'b' == 'ab'"), + "Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]") + checkKeywordsExistsInExplain(sql("select 'a' || 'c' == 'ac' AND 2 == 3"), + "Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]") + } + + test("explain for these functions; use range to avoid constant folding") { + val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " + + "from range(2)") + checkKeywordsExistsInExplain(df, + "Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, " + + "id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, " + + "x AS nvl2(`id`, 'x', 'y')#x]") + } } case class ExplainSingleData(id: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index cf4585bf7ac6c..b2515226d9a14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -137,28 +137,39 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } } + // For better test coverage, runs the tests on mixed config sets: WHOLESTAGE_CODEGEN_ENABLED + // and CODEGEN_FACTORY_MODE. + private lazy val codegenConfigSets = Array( + ("true", "CODEGEN_ONLY"), + ("false", "CODEGEN_ONLY"), + ("false", "NO_CODEGEN") + ).map { case (wholeStageCodegenEnabled, codegenFactoryMode) => + Array(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageCodegenEnabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode) + } + /** Run a test case. */ private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) val (comments, code) = input.split("\n").partition(_.startsWith("--")) - // Runs all the tests on both codegen-only and interpreter modes - val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map { - case codegenFactoryMode => - Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString) - } - val configSets = { - val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) - val configs = configLines.map(_.split(",").map { confAndValue => - val (conf, value) = confAndValue.span(_ != '=') - conf.trim -> value.substring(1).trim - }) - // When we are regenerating the golden files, we don't need to set any config as they - // all need to return the same result - if (regenerateGoldenFiles) { - Array.empty[Array[(String, String)]] - } else { + // List of SQL queries to run + // note: this is not a robust way to split queries using semicolon, but works for now. 
+ val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + // When we are regenerating the golden files, we don't need to set any config as they + // all need to return the same result + if (regenerateGoldenFiles) { + runQueries(queries, testCase.resultFile, None) + } else { + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + if (configs.nonEmpty) { codegenConfigSets.flatMap { codegenConfig => configs.map { config => @@ -169,15 +180,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { codegenConfigSets } } - } - // List of SQL queries to run - // note: this is not a robust way to split queries using semicolon, but works for now. - val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq - - if (configSets.isEmpty) { - runQueries(queries, testCase.resultFile, None) - } else { configSets.foreach { configSet => try { runQueries(queries, testCase.resultFile, Some(configSet)) From 4b723012677a31df8eee5534d98abd49db270913 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 20 Dec 2018 10:47:24 +0800 Subject: [PATCH 104/194] [SPARK-25271][SQL] Hive ctas commands should use data source if it is convertible ## What changes were proposed in this pull request? In Spark 2.3.0 and previous versions, Hive CTAS command will convert to use data source to write data into the table when the table is convertible. This behavior is controlled by the configs like HiveUtils.CONVERT_METASTORE_ORC and HiveUtils.CONVERT_METASTORE_PARQUET. In 2.3.1, we drop this optimization by mistake in the PR [SPARK-22977](https://github.com/apache/spark/pull/20521/files#r217254430). Since that Hive CTAS command only uses Hive Serde to write data. This patch adds this optimization back to Hive CTAS command. This patch adds OptimizedCreateHiveTableAsSelectCommand which uses data source to write data. ## How was this patch tested? Added test. Closes #22514 from viirya/SPARK-25271-2. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/execution/command/ddl.scala | 8 ++ .../datasources/DataSourceStrategy.scala | 12 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 43 +++++- .../spark/sql/hive/HiveStrategies.scala | 62 +++----- .../org/apache/spark/sql/hive/HiveUtils.scala | 8 ++ .../CreateHiveTableAsSelectCommand.scala | 134 +++++++++++++----- .../spark/sql/hive/HiveParquetSuite.scala | 14 ++ .../sql/hive/execution/SQLQuerySuite.scala | 40 ++++++ 8 files changed, 230 insertions(+), 91 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e1faecedd20ed..096481f68275d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -820,6 +820,14 @@ object DDLUtils { table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } + def readHiveTable(table: CatalogTable): HiveTableRelation = { + HiveTableRelation( + table, + // Hive table columns are always nullable. + table.dataSchema.asNullable.toAttributes, + table.partitionSchema.asNullable.toAttributes) + } + /** * Throws a standard error for actions that require partitionProvider = hive. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b304e2da6e1cf..b5cf8c9515bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -244,27 +244,19 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] }) } - private def readHiveTable(table: CatalogTable): LogicalPlan = { - HiveTableRelation( - table, - // Hive table columns are always nullable. - table.dataSchema.asNullable.toAttributes, - table.partitionSchema.asNullable.toAttributes) - } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) => - i.copy(table = readHiveTable(tableMeta)) + i.copy(table = DDLUtils.readHiveTable(tableMeta)) case UnresolvedCatalogRelation(tableMeta) if DDLUtils.isDatasourceTable(tableMeta) => readDataSourceTable(tableMeta) case UnresolvedCatalogRelation(tableMeta) => - readHiveTable(tableMeta) + DDLUtils.readHiveTable(tableMeta) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5823548a8063c..03f4b8d83e353 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.control.NonFatal import com.google.common.util.concurrent.Striped @@ -29,6 +31,8 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -113,7 +117,44 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - def convertToLogicalRelation( + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + + def convert(relation: HiveTableRelation): LogicalRelation = { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. 
+ if (serde.contains("parquet")) { + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + } else { + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties + if (SQLConf.get.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { + convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], + "orc") + } else { + convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], + "orc") + } + } + } + + private def convertToLogicalRelation( relation: HiveTableRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 07ee105404311..8a5ab188a949f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTab import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -181,49 +180,17 @@ case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { private def isConvertible(relation: HiveTableRelation): Boolean = { - val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) - serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || - serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) + isConvertible(relation.tableMeta) } - // Return true for Apache ORC and Hive ORC-related configuration names. - // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. - private def isOrcProperty(key: String) = - key.startsWith("orc.") || key.contains(".orc.") - - private def isParquetProperty(key: String) = - key.startsWith("parquet.") || key.contains(".parquet.") - - private def convert(relation: HiveTableRelation): LogicalRelation = { - val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) - - // Consider table and storage properties. For properties existing in both sides, storage - // properties will supersede table properties. 
- if (serde.contains("parquet")) { - val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ - relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) - sessionCatalog.metastoreCatalog - .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") - } else { - val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ - relation.tableMeta.storage.properties - if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { - sessionCatalog.metastoreCatalog.convertToLogicalRelation( - relation, - options, - classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], - "orc") - } else { - sessionCatalog.metastoreCatalog.convertToLogicalRelation( - relation, - options, - classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], - "orc") - } - } + private def isConvertible(tableMeta: CatalogTable): Boolean = { + val serde = tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + private val metastoreCatalog = sessionCatalog.metastoreCatalog + override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { // Write path @@ -231,12 +198,21 @@ case class RelationConversions( // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && !r.isPartitioned && isConvertible(r) => - InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) + InsertIntoTable(metastoreCatalog.convert(r), partition, + query, overwrite, ifPartitionNotExists) // Read path case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => - convert(relation) + metastoreCatalog.convert(relation) + + // CTAS + case CreateTable(tableDesc, mode, Some(query)) + if DDLUtils.isHiveTable(tableDesc) && tableDesc.partitionColumnNames.isEmpty && + isConvertible(tableDesc) && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_CTAS) => + DDLUtils.checkDataColNames(tableDesc) + OptimizedCreateHiveTableAsSelectCommand( + tableDesc, query, query.output.map(_.name), mode) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 66067704195dd..b60d4c71f5941 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -110,6 +110,14 @@ private[spark] object HiveUtils extends Logging { .booleanConf .createWithDefault(true) + val CONVERT_METASTORE_CTAS = buildConf("spark.sql.hive.convertMetastoreCtas") + .doc("When set to true, Spark will try to use built-in data source writer " + + "instead of Hive serde in CTAS. This flag is effective only if " + + "`spark.sql.hive.convertMetastoreParquet` or `spark.sql.hive.convertMetastoreOrc` is " + + "enabled respectively for Parquet and ORC formats") + .booleanConf + .createWithDefault(true) + val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + "that is shared between Spark SQL and a specific version of Hive. 
An example of classes " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index fd1e931ee0c7a..608f21e726259 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,32 +20,26 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation} +import org.apache.spark.sql.hive.HiveSessionCatalog +trait CreateHiveTableAsSelectBase extends DataWritingCommand { + val tableDesc: CatalogTable + val query: LogicalPlan + val outputColumnNames: Seq[String] + val mode: SaveMode -/** - * Create table and insert the query result into it. - * - * @param tableDesc the Table Describe, which may contain serde, storage handler etc. - * @param query the query whose result will be insert into the new relation - * @param mode SaveMode - */ -case class CreateHiveTableAsSelectCommand( - tableDesc: CatalogTable, - query: LogicalPlan, - outputColumnNames: Seq[String], - mode: SaveMode) - extends DataWritingCommand { - - private val tableIdentifier = tableDesc.identifier + protected val tableIdentifier = tableDesc.identifier override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (catalog.tableExists(tableIdentifier)) { + val tableExists = catalog.tableExists(tableIdentifier) + + if (tableExists) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -57,15 +51,8 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - // For CTAS, there is no static partition values to insert. - val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap - InsertIntoHiveTable( - tableDesc, - partition, - query, - overwrite = false, - ifPartitionNotExists = false, - outputColumnNames = outputColumnNames).run(sparkSession, child) + val command = getWritingCommand(catalog, tableDesc, tableExists = true) + command.run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data @@ -77,15 +64,8 @@ case class CreateHiveTableAsSelectCommand( try { // Read back the metadata of the table which was created just now. val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) - // For CTAS, there is no static partition values to insert. 
- val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap - InsertIntoHiveTable( - createdTableMeta, - partition, - query, - overwrite = true, - ifPartitionNotExists = false, - outputColumnNames = outputColumnNames).run(sparkSession, child) + val command = getWritingCommand(catalog, createdTableMeta, tableExists = false) + command.run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. @@ -97,9 +77,89 @@ case class CreateHiveTableAsSelectCommand( Seq.empty[Row] } + // Returns `DataWritingCommand` which actually writes data into the table. + def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand + override def argString: String = { s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } } + +/** + * Create table and insert the query result into it. + * + * @param tableDesc the table description, which may contain serde, storage handler etc. + * @param query the query whose result will be insert into the new relation + * @param mode SaveMode + */ +case class CreateHiveTableAsSelectCommand( + tableDesc: CatalogTable, + query: LogicalPlan, + outputColumnNames: Seq[String], + mode: SaveMode) + extends CreateHiveTableAsSelectBase { + + override def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand = { + // For CTAS, there is no static partition values to insert. + val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + tableDesc, + partition, + query, + overwrite = if (tableExists) false else true, + ifPartitionNotExists = false, + outputColumnNames = outputColumnNames) + } +} + +/** + * Create table and insert the query result into it. This creates Hive table but inserts + * the query result into it by using data source. + * + * @param tableDesc the table description, which may contain serde, storage handler etc. + * @param query the query whose result will be insert into the new relation + * @param mode SaveMode + */ +case class OptimizedCreateHiveTableAsSelectCommand( + tableDesc: CatalogTable, + query: LogicalPlan, + outputColumnNames: Seq[String], + mode: SaveMode) + extends CreateHiveTableAsSelectBase { + + override def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand = { + val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog + val hiveTable = DDLUtils.readHiveTable(tableDesc) + + val hadoopRelation = metastoreCatalog.convert(hiveTable) match { + case LogicalRelation(t: HadoopFsRelation, _, _, _) => t + case _ => throw new AnalysisException(s"$tableIdentifier should be converted to " + + "HadoopFsRelation.") + } + + InsertIntoHadoopFsRelationCommand( + hadoopRelation.location.rootPaths.head, + Map.empty, // We don't support to convert partitioned table. + false, + Seq.empty, // We don't support to convert partitioned table. 
+ hadoopRelation.bucketSpec, + hadoopRelation.fileFormat, + hadoopRelation.options, + query, + if (tableExists) mode else SaveMode.Overwrite, + Some(tableDesc), + Some(hadoopRelation.location), + query.output.map(_.name)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index e5c9df05d5674..470c6a342b4dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -92,4 +92,18 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } } } + + test("SPARK-25271: write empty map into hive parquet table") { + import testImplicits._ + + Seq(Map(1 -> "a"), Map.empty[Int, String]).toDF("m").createOrReplaceTempView("p") + withTempView("p") { + val targetTable = "targetTable" + withTable(targetTable) { + sql(s"CREATE TABLE $targetTable STORED AS PARQUET AS SELECT m FROM p") + checkAnswer(sql(s"SELECT m FROM $targetTable"), + Row(Map(1 -> "a")) :: Row(Map.empty[Int, String]) :: Nil) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index fab2a27cdef17..6acf44606cbbe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2276,6 +2276,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-25271: Hive ctas commands should use data source if it is convertible") { + withTempView("p") { + Seq(1, 2, 3).toDF("id").createOrReplaceTempView("p") + + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + Seq(true, false).foreach { isConvertedCtas => + withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> s"$isConvertedCtas") { + + val targetTable = "targetTable" + withTable(targetTable) { + val df = sql(s"CREATE TABLE $targetTable STORED AS $format AS SELECT id FROM p") + checkAnswer(sql(s"SELECT id FROM $targetTable"), + Row(1) :: Row(2) :: Row(3) :: Nil) + + val ctasDSCommand = df.queryExecution.analyzed.collect { + case _: OptimizedCreateHiveTableAsSelectCommand => true + }.headOption + val ctasCommand = df.queryExecution.analyzed.collect { + case _: CreateHiveTableAsSelectCommand => true + }.headOption + + if (isConverted && isConvertedCtas) { + assert(ctasDSCommand.nonEmpty) + assert(ctasCommand.isEmpty) + } else { + assert(ctasDSCommand.isEmpty) + assert(ctasCommand.nonEmpty) + } + } + } + } + } + } + } + } + } test("SPARK-26181 hasMinMaxStats method of ColumnStatsMap is not correct") { withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { From 2aa1022bd10b5f661d2702bf5c23061261a155f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BA=AE?= Date: Thu, 20 Dec 2018 13:22:12 +0800 Subject: [PATCH 105/194] [SPARK-26318][SQL] Deprecate Row.merge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Deprecate Row.merge ## How was this patch tested? N/A Closes #23271 from KyleLi1985/master. 
Authored-by: 李亮 Signed-off-by: Hyukjin Kwon --- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e12bf9616e2de..4f5af9ac80b10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -57,6 +57,7 @@ object Row { /** * Merge multiple rows into a single row, one after another. */ + @deprecated("This method is deprecated and will be removed in future versions.", "3.0.0") def merge(rows: Row*): Row = { // TODO: Improve the performance of this if used in performance critical part. new GenericRow(rows.flatMap(_.toSeq).toArray) From 029933c5d67c27e782175d0e1417402062176269 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 20 Dec 2018 14:17:44 +0800 Subject: [PATCH 106/194] [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF ## What changes were proposed in this pull request? Currently, when we infer the schema for scala/java decimals, we return as data type the `SYSTEM_DEFAULT` implementation, ie. the decimal type with precision 38 and scale 18. But this is not right, as we know nothing about the right precision and scale and these values can be not enough to store the data. This problem arises in particular with UDF, where we cast all the input of type `DecimalType` to a `DecimalType(38, 18)`: in case this is not enough, null is returned as input for the UDF. The PR defines a custom handling for casting to the expected data types for ScalaUDF: the decimal precision and scale is picked from the input, so no casting to different and maybe wrong percision and scale happens. ## How was this patch tested? added UTs Closes #23308 from mgaido91/SPARK-26308. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/TypeCoercion.scala | 31 ++++++++++++++++++ .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 32 ++++++++++++++++++- 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 133fa119b7aa6..1706b3eece6d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -879,6 +879,37 @@ object TypeCoercion { } } e.withNewChildren(children) + + case udf: ScalaUDF if udf.inputTypes.nonEmpty => + val children = udf.children.zip(udf.inputTypes).map { case (in, expected) => + implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in) + } + udf.withNewChildren(children) + } + + private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = { + (input, expectedType) match { + // SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note + // that precision and scale cannot be inferred properly for a ScalaUDF because, when it is + // created, it is not bound to any column. So here the precision and scale of the input + // column is used. 
+ case (in: DecimalType, _: DecimalType) => in + case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) => + ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp) + case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) => + MapType(udfInputToCastType(keyDtIn, keyDtExp), + udfInputToCastType(valueDtIn, valueDtExp), + nullableExp) + case (StructType(fieldsIn), StructType(fieldsExp)) => + val fieldTypes = + fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) => + udfInputToCastType(dtIn, dtExp) + } + StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) => + field.copy(dataType = newDt) + }) + case (_, other) => other + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index fae90caebf96c..a23aaa3a0b3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -52,7 +52,7 @@ case class ScalaUDF( udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) - extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + extends Expression with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 def this( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 20dcefa7e3cad..a26d306cff6b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.math.BigDecimal + import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.QueryExecution @@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.{DataTypes, DoubleType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.util.QueryExecutionListener @@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) } } + + test("SPARK-26308: udf with decimal") { + val df1 = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))), + StructType(Seq(StructField("col1", DecimalType(30, 0))))) + val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => { + if (value == null) null else value.toBigInteger.toString + }) + checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556"))) + } + + test("SPARK-26308: udf with complex types of decimal") { + val df1 = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))), + StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0)))))) + val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => { + arr.map(value => if (value == null) null else value.toBigInteger.toString) + }) + checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556")))) + + val df2 = spark.createDataFrame( + 
sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))), + StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 0)))))) + val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => { + map.mapValues(value => if (value == null) null else value.toBigInteger.toString) + }) + checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556")))) + } } From a647251b91ac49b29e836249d7802a30752d6abd Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 20 Dec 2018 08:26:25 -0600 Subject: [PATCH 107/194] [SPARK-24687][CORE] Avoid job hanging when generate task binary causes fatal error ## What changes were proposed in this pull request? When NoClassDefFoundError thrown,it will cause job hang. `Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: Lcom/xxx/data/recommend/aggregator/queue/QueueName; at java.lang.Class.getDeclaredFields0(Native Method) at java.lang.Class.privateGetDeclaredFields(Class.java:2436) at java.lang.Class.getDeclaredField(Class.java:1946) at java.io.ObjectStreamClass.getDeclaredSUID(ObjectStreamClass.java:1659) at java.io.ObjectStreamClass.access$700(ObjectStreamClass.java:72) at java.io.ObjectStreamClass$2.run(ObjectStreamClass.java:480) at java.io.ObjectStreamClass$2.run(ObjectStreamClass.java:468) at java.security.AccessController.doPrivileged(Native Method) at java.io.ObjectStreamClass.(ObjectStreamClass.java:468) at java.io.ObjectStreamClass.lookup(ObjectStreamClass.java:365) at java.io.ObjectOutputStream.writeClass(ObjectOutputStream.java:1212) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1119) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1377) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1173) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at 
java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1377)` It is caused by NoClassDefFoundError will not catch up during task seriazation. `var taskBinary: Broadcast[Array[Byte]] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => JavaUtils.bufferToArray( closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return }` image below shows that stage 33 blocked and never be scheduled. 2018-06-28 4 28 42 2018-06-28 4 28 49 ## How was this patch tested? UT Closes #21664 from caneGuy/zhoukang/fix-noclassdeferror. Authored-by: zhoukang Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 06966e77db81e..6f4c326442e1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1170,9 +1170,11 @@ private[spark] class DAGScheduler( // Abort execution return - case NonFatal(e) => + case e: Throwable => abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage + + // Abort execution return } From b011de3cf6713c4abbf01c48869a79f2b01fb705 Mon Sep 17 00:00:00 2001 From: Jorge Machado Date: Thu, 20 Dec 2018 08:29:51 -0600 Subject: [PATCH 108/194] [SPARK-26324][DOCS] Add Spark docs for Running in Mesos with SSL ## What changes were proposed in this pull request? Added docs for running spark jobs with Mesos on SSL Closes #23342 from jomach/master. Lead-authored-by: Jorge Machado Co-authored-by: Jorge Machado Co-authored-by: Jorge Machado Co-authored-by: Jorge Machado Signed-off-by: Sean Owen --- docs/running-on-mesos.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 968d668e2c93a..a07773c1c71e1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -108,6 +108,19 @@ Please note that if you specify multiple ways to obtain the credentials then the An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. +### Deploy to a Mesos running on Secure Sockets + +If you want to deploy a Spark Application into a Mesos cluster that is running in a secure mode there are some environment variables that need to be set. 
+ +- `LIBPROCESS_SSL_ENABLED=true` enables SSL communication +- `LIBPROCESS_SSL_VERIFY_CERT=false` verifies the ssl certificate +- `LIBPROCESS_SSL_KEY_FILE=pathToKeyFile.key` path to key +- `LIBPROCESS_SSL_CERT_FILE=pathToCRTFile.crt` the certificate file to be used + +All options can be found at http://mesos.apache.org/documentation/latest/ssl/ + +Then submit happens as described in Client mode or Cluster mode below + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary From f4e4c1a66c6c53f3e8604d1fbd35f03d80282241 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Dec 2018 10:05:56 -0800 Subject: [PATCH 109/194] [SPARK-26409][SQL][TESTS] SQLConf should be serializable in test sessions ## What changes were proposed in this pull request? `SQLConf` is supposed to be serializable. However, currently it is not serializable in `WithTestConf`. `WithTestConf` uses the method `overrideConfs` in closure, while the classes which implements it (`TestHiveSessionStateBuilder` and `TestSQLSessionStateBuilder`) are not serializable. This PR is to use a local variable to fix it. ## How was this patch tested? Add unit test. Closes #23352 from gengliangwang/serializableSQLConf. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../apache/spark/sql/internal/BaseSessionStateBuilder.scala | 3 ++- .../test/scala/org/apache/spark/sql/SerializationSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index ac07e1f6bb4f8..319c2649592fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -309,13 +309,14 @@ private[sql] trait WithTestConf { self: BaseSessionStateBuilder => def overrideConfs: Map[String, String] override protected lazy val conf: SQLConf = { + val overrideConfigurations = overrideConfs val conf = parentState.map(_.conf.clone()).getOrElse { new SQLConf { clear() override def clear(): Unit = { super.clear() // Make sure we start with the default test configs even after clear - overrideConfs.foreach { case (key, value) => setConfString(key, value) } + overrideConfigurations.foreach { case (key, value) => setConfString(key, value) } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index cd6b2647e0be6..1a1c956aed3d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -27,4 +27,9 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext { val spark = SparkSession.builder.getOrCreate() new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext) } + + test("[SPARK-26409] SQLConf should be serializable") { + val spark = SparkSession.builder.getOrCreate() + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sessionState.conf) + } } From ec539d19b93d4c1ab0282004a8753afc01b6f8a9 Mon Sep 17 00:00:00 2001 From: Ngone51 Date: Thu, 20 Dec 2018 10:25:52 -0800 Subject: [PATCH 110/194] [SPARK-26392][YARN] Cancel pending allocate requests by taking locality preference into account ## What changes were proposed in this pull request? 
Right now, we cancel pending allocate requests by its sending order. I thing we can take locality preference into account when do this to perfom least impact on task locality preference. ## How was this patch tested? N.A. Closes #23344 from Ngone51/dev-cancel-pending-allocate-requests-by-taking-locality-preference-into-account. Authored-by: Ngone51 Signed-off-by: Marcelo Vanzin --- .../spark/deploy/yarn/YarnAllocator.scala | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d37d0d66d8ae2..54b1ec266113f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -294,6 +294,15 @@ private[yarn] class YarnAllocator( s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") + // Split the pending container request into three groups: locality matched list, locality + // unmatched list and non-locality list. Take the locality matched container request into + // consideration of container placement, treat as allocated containers. + // For locality unmatched and locality free container requests, cancel these container + // requests, since required locality preference has been changed, recalculating using + // container placement strategy. + val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( + hostToLocalTaskCounts, pendingAllocate) + if (missing > 0) { if (log.isInfoEnabled()) { var requestContainerMessage = s"Will request $missing executor container(s), each with " + @@ -306,15 +315,6 @@ private[yarn] class YarnAllocator( logInfo(requestContainerMessage) } - // Split the pending container request into three groups: locality matched list, locality - // unmatched list and non-locality list. Take the locality matched container request into - // consideration of container placement, treat as allocated containers. - // For locality unmatched and locality free container requests, cancel these container - // requests, since required locality preference has been changed, recalculating using - // container placement strategy. 
- val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( - hostToLocalTaskCounts, pendingAllocate) - // cancel "stale" requests for locations that are no longer needed staleRequests.foreach { stale => amClient.removeContainerRequest(stale) @@ -374,14 +374,9 @@ private[yarn] class YarnAllocator( val numToCancel = math.min(numPendingAllocate, -missing) logInfo(s"Canceling requests for $numToCancel executor container(s) to have a new desired " + s"total $targetNumExecutors executors.") - - val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) - if (!matchingRequests.isEmpty) { - matchingRequests.iterator().next().asScala - .take(numToCancel).foreach(amClient.removeContainerRequest) - } else { - logWarning("Expected to find pending requests, but found none.") - } + // cancel pending allocate requests by taking locality preference into account + val cancelRequests = (staleRequests ++ anyHostRequests ++ localRequests).take(numToCancel) + cancelRequests.foreach(amClient.removeContainerRequest) } } From cc07eaef6ca9a4040aacbe432df2ffb105a16379 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 20 Dec 2018 11:22:49 -0800 Subject: [PATCH 111/194] [SPARK-25970][ML] Add Instrumentation to PrefixSpan ## What changes were proposed in this pull request? Add Instrumentation to PrefixSpan ## How was this patch tested? existing tests Closes #22971 from zhengruifeng/log_PrefixSpan. Authored-by: zhengruifeng Signed-off-by: Xiangrui Meng --- .../src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 2a3413553a6af..b0006a8d4a58e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -135,7 +136,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = instrumented { instr => + instr.logDataset(dataset) + instr.logParams(this, params: _*) + val sequenceColParam = $(sequenceCol) val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && From 07d111dabded53a6a5fb1dc81d96d080f9f57144 Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 21 Dec 2018 13:01:14 +0800 Subject: [PATCH 112/194] [MINOR][SQL] Locality does not need to be implemented ## What changes were proposed in this pull request? `HadoopFileWholeTextReader` and `HadoopFileLinesReader` will be eventually called in `FileSourceScanExec`. In fact, locality has been implemented in `FileScanRDD`, even if we implement it in `HadoopFileWholeTextReader ` and `HadoopFileLinesReader`, it would be useless. So I think these `TODO` can be removed. ## How was this patch tested? N/A Closes #23339 from 10110346/noneededtodo. 
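To make the point above concrete, here is a minimal sketch (hypothetical class and host lists, not Spark's actual `FileScanRDD`) of how locality is expressed once at the RDD level through `getPreferredLocations`, which is why the low-level file readers need no locality hints of their own:

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical RDD for illustration: the scheduler calls getPreferredLocations for each
// partition and tries to place the corresponding task on one of the returned hosts, so
// the record-level readers never have to carry locality information themselves.
class HostHintedRDD(sc: SparkContext, hostsPerPartition: Seq[Seq[String]])
  extends RDD[Int](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    Array.tabulate(hostsPerPartition.length) { i =>
      new Partition { override def index: Int = i }
    }

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator.single(split.index)

  // The only place where locality information is surfaced to the scheduler.
  override protected def getPreferredLocations(split: Partition): Seq[String] =
    hostsPerPartition(split.index)
}
```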
Authored-by: liuxian Signed-off-by: Wenchen Fan --- .../spark/sql/execution/datasources/HadoopFileLinesReader.scala | 2 +- .../sql/execution/datasources/HadoopFileWholeTextReader.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 00a78f7343c59..57082b40e1132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -51,7 +51,7 @@ class HadoopFileLinesReader( new Path(new URI(file.filePath)), file.start, file.length, - // TODO: Implement Locality + // The locality is decided by `getPreferredLocations` in `FileScanRDD`. Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala index c61a89e6e8c3f..f5724f7c5955d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala @@ -40,7 +40,7 @@ class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration) Array(new Path(new URI(file.filePath))), Array(file.start), Array(file.length), - // TODO: Implement Locality + // The locality is decided by `getPreferredLocations` in `FileScanRDD`. Array.empty[String]) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) From 75ff5d99b75283643e60e5fd48a9abca39090f8d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 21 Dec 2018 16:09:30 +0800 Subject: [PATCH 113/194] [SPARK-26422][R] Support to disable Hive support in SparkR even for Hadoop versions unsupported by Hive fork ## What changes were proposed in this pull request? Currently, even if I explicitly disable Hive support in SparkR session as below: ```r sparkSession <- sparkR.session("local[4]", "SparkR", Sys.getenv("SPARK_HOME"), enableHiveSupport = FALSE) ``` produces when the Hadoop version is not supported by our Hive fork: ``` java.lang.reflect.InvocationTargetException ... Caused by: java.lang.IllegalArgumentException: Unrecognized Hadoop major version number: 3.1.1.3.1.0.0-78 at org.apache.hadoop.hive.shims.ShimLoader.getMajorVersion(ShimLoader.java:174) at org.apache.hadoop.hive.shims.ShimLoader.loadShims(ShimLoader.java:139) at org.apache.hadoop.hive.shims.ShimLoader.getHadoopShims(ShimLoader.java:100) at org.apache.hadoop.hive.conf.HiveConf$ConfVars.(HiveConf.java:368) ... 
43 more Error in handleErrors(returnStatus, conn) : java.lang.ExceptionInInitializerError at org.apache.hadoop.hive.conf.HiveConf.(HiveConf.java:105) at java.lang.Class.forName0(Native Method) at java.lang.Class.forName(Class.java:348) at org.apache.spark.util.Utils$.classForName(Utils.scala:193) at org.apache.spark.sql.SparkSession$.hiveClassesArePresent(SparkSession.scala:1116) at org.apache.spark.sql.api.r.SQLUtils$.getOrCreateSparkSession(SQLUtils.scala:52) at org.apache.spark.sql.api.r.SQLUtils.getOrCreateSparkSession(SQLUtils.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ``` The root cause is that: ``` SparkSession.hiveClassesArePresent ``` check if the class is loadable or not to check if that's in classpath but `org.apache.hadoop.hive.conf.HiveConf` has a check for Hadoop version as static logic which is executed right away. This throws an `IllegalArgumentException` and that's not caught: https://github.com/apache/spark/blob/36edbac1c8337a4719f90e4abd58d38738b2e1fb/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala#L1113-L1121 So, currently, if users have a Hive built-in Spark with unsupported Hadoop version by our fork (namely 3+), there's no way to use SparkR even though it could work. This PR just propose to change the order of bool comparison so that we can don't execute `SparkSession.hiveClassesArePresent` when: 1. `enableHiveSupport` is explicitly disabled 2. `spark.sql.catalogImplementation` is `in-memory` so that we **only** check `SparkSession.hiveClassesArePresent` when Hive support is explicitly enabled by short circuiting. ## How was this patch tested? It's difficult to write a test since we don't run tests against Hadoop 3 yet. See https://github.com/apache/spark/pull/21588. Manually tested. Closes #23356 from HyukjinKwon/SPARK-26422. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index becb05cf72aba..e98cab8b56d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -49,9 +49,17 @@ private[sql] object SQLUtils extends Logging { sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { val spark = - if (SparkSession.hiveClassesArePresent && enableHiveSupport && + if (enableHiveSupport && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == - "hive") { + "hive" && + // Note that the order of conditions here are on purpose. + // `SparkSession.hiveClassesArePresent` checks if Hive's `HiveConf` is loadable or not; + // however, `HiveConf` itself has some static logic to check if Hadoop version is + // supported or not, which throws an `IllegalArgumentException` if unsupported. + // If this is checked first, there's no way to disable Hive support in the case above. + // So, we intentionally check if Hive classes are loadable or not only when + // Hive support is explicitly enabled by short-circuiting. See also SPARK-26422. 
+ SparkSession.hiveClassesArePresent) { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { if (enableHiveSupport) { From 65de56408f49edf9d21f40c1fe19fac3a1a37c0c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 21 Dec 2018 10:41:25 -0800 Subject: [PATCH 114/194] [SPARK-26267][SS] Retry when detecting incorrect offsets from Kafka ## What changes were proposed in this pull request? Due to [KAFKA-7703](https://issues.apache.org/jira/browse/KAFKA-7703), Kafka may return an earliest offset when we are request a latest offset. This will cause Spark to reprocess data. As per suggestion in KAFKA-7703, we put a position call between poll and seekToEnd to block the fetch request triggered by `poll` before calling `seekToEnd`. In addition, to avoid other unknown issues, we also use the previous known offsets to audit the latest offsets returned by Kafka. If we find some incorrect offsets (a latest offset is less than an offset in `knownOffsets`), we will retry at most `maxOffsetFetchAttempts` times. ## How was this patch tested? Jenkins Closes #23324 from zsxwing/SPARK-26267. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../kafka010/KafkaContinuousReadSupport.scala | 4 +- .../kafka010/KafkaMicroBatchReadSupport.scala | 19 ++++- .../kafka010/KafkaOffsetRangeCalculator.scala | 2 + .../sql/kafka010/KafkaOffsetReader.scala | 80 +++++++++++++++++-- .../spark/sql/kafka010/KafkaSource.scala | 5 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 48 +++++++++++ 6 files changed, 145 insertions(+), 13 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index 1753a28fba2fb..02dfb9ca2b95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -60,7 +60,7 @@ class KafkaContinuousReadSupport( override def initialOffset(): Offset = { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } logInfo(s"Initial offsets: $offsets") @@ -107,7 +107,7 @@ class KafkaContinuousReadSupport( override def needsReconfiguration(config: ScanConfig): Boolean = { val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions - offsetReader.fetchLatestOffsets().keySet != knownPartitions + offsetReader.fetchLatestOffsets(None).keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index bb4de674c3c72..b4f042e93a5da 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( override def 
latestOffset(start: Offset): Offset = { val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets)) endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) }.getOrElse { @@ -133,10 +133,21 @@ private[kafka010] class KafkaMicroBatchReadSupport( }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + val fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets + val untilOffsets = endPartitionOffsets + untilOffsets.foreach { case (tp, untilOffset) => + fromOffsets.get(tp).foreach { fromOffset => + if (untilOffset < fromOffset) { + reportDataLoss(s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + } + } + } + // Calculate offset ranges val offsetRanges = rangeCalculator.getRanges( - fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, - untilOffsets = endPartitionOffsets, + fromOffsets = fromOffsets, + untilOffsets = untilOffsets, executorLocations = getSortedExecutorList()) // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, @@ -186,7 +197,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => - KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index fb209c724afba..6008794924052 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -37,6 +37,8 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int * the read tasks of the skewed partitions to multiple Spark tasks. * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more * depending on rounding errors or Kafka partitions that didn't receive any new data. + * + * Empty ranges (`KafkaOffsetRange.size <= 0`) will be dropped. 
*/ def getRanges( fromOffsets: PartitionOffsetMap, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 82066697cb95a..fc443d22bf5a2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -21,6 +21,7 @@ import java.{util => ju} import java.util.concurrent.{Executors, ThreadFactory} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.control.NonFatal @@ -137,6 +138,12 @@ private[kafka010] class KafkaOffsetReader( // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() + + // Call `position` to wait until the potential offset request triggered by `poll(0)` is + // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by + // `poll(0)` may reset offsets that should have been set by another request. + partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) + consumer.pause(partitions) assert(partitions.asScala == partitionOffsets.keySet, "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + @@ -192,19 +199,82 @@ private[kafka010] class KafkaOffsetReader( /** * Fetch the latest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. + * + * Kafka may return earliest offsets when we are requesting latest offsets if `poll` is called + * right before `seekToEnd` (KAFKA-7703). As a workaround, we will call `position` right after + * `poll` to wait until the potential offset request triggered by `poll(0)` is done. + * + * In addition, to avoid other unknown issues, we also use the given `knownOffsets` to audit the + * latest offsets returned by Kafka. If we find some incorrect offsets (a latest offset is less + * than an offset in `knownOffsets`), we will retry at most `maxOffsetFetchAttempts` times. When + * a topic is recreated, the latest offsets may be less than offsets in `knownOffsets`. We cannot + * distinguish this with KAFKA-7703, so we just return whatever we get from Kafka after retrying. */ - def fetchLatestOffsets(): Map[TopicPartition, Long] = runUninterruptibly { + def fetchLatestOffsets( + knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() + + // Call `position` to wait until the potential offset request triggered by `poll(0)` is + // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by + // `poll(0)` may reset offsets that should have been set by another request. + partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) + consumer.pause(partitions) logDebug(s"Partitions assigned to consumer: $partitions. 
Seeking to the end.") - consumer.seekToEnd(partitions) - val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got latest offsets for partition : $partitionOffsets") - partitionOffsets + if (knownOffsets.isEmpty) { + consumer.seekToEnd(partitions) + partitions.asScala.map(p => p -> consumer.position(p)).toMap + } else { + var partitionOffsets: PartitionOffsetMap = Map.empty + + /** + * Compare `knownOffsets` and `partitionOffsets`. Returns all partitions that have incorrect + * latest offset (offset in `knownOffsets` is great than the one in `partitionOffsets`). + */ + def findIncorrectOffsets(): Seq[(TopicPartition, Long, Long)] = { + var incorrectOffsets = ArrayBuffer[(TopicPartition, Long, Long)]() + partitionOffsets.foreach { case (tp, offset) => + knownOffsets.foreach(_.get(tp).foreach { knownOffset => + if (knownOffset > offset) { + val incorrectOffset = (tp, knownOffset, offset) + incorrectOffsets += incorrectOffset + } + }) + } + incorrectOffsets + } + + // Retry to fetch latest offsets when detecting incorrect offsets. We don't use + // `withRetriesWithoutInterrupt` to retry because: + // + // - `withRetriesWithoutInterrupt` will reset the consumer for each attempt but a fresh + // consumer has a much bigger chance to hit KAFKA-7703. + // - Avoid calling `consumer.poll(0)` which may cause KAFKA-7703. + var incorrectOffsets: Seq[(TopicPartition, Long, Long)] = Nil + var attempt = 0 + do { + consumer.seekToEnd(partitions) + partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + attempt += 1 + + incorrectOffsets = findIncorrectOffsets() + if (incorrectOffsets.nonEmpty) { + logWarning("Found incorrect offsets in some partitions " + + s"(partition, previous offset, fetched offset): $incorrectOffsets") + if (attempt < maxOffsetFetchAttempts) { + logWarning("Retrying to fetch latest offsets because of incorrect offsets") + Thread.sleep(offsetFetchAttemptIntervalMs) + } + } + } while (incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts) + + logDebug(s"Got latest offsets for partition : $partitionOffsets") + partitionOffsets + } } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 66ec7e0cd084a..d65b3cea632c4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( metadataLog.get(0).getOrElse { val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) @@ -148,7 +148,8 @@ private[kafka010] class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets - val latest = kafkaReader.fetchLatestOffsets() + val latest = kafkaReader.fetchLatestOffsets( + currentPartitionOffsets.orElse(Some(initialPartitionOffsets))) val offsets = maxOffsetsPerTrigger match { case None => latest diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 5ee76990b54f4..61cbb3285a4f0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -329,6 +329,54 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } + test("subscribe topic by pattern with topic recreation between batches") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-good" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, Array("1", "3")) + testUtils.createTopic(topic2, partitions = 1) + testUtils.sendMessages(topic2, Array("2", "4")) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") + .option("startingOffsets", "earliest") + .option("subscribePattern", s"$topicPrefix-.*") + + val ds = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + testStream(ds)( + StartStream(), + AssertOnQuery { q => + q.processAllAvailable() + true + }, + CheckAnswer(1, 2, 3, 4), + // Restart the stream in this test to make the test stable. When recreating a topic when a + // consumer is alive, it may not be able to see the recreated topic even if a fresh consumer + // has seen it. + StopStream, + // Recreate `topic2` and wait until it's available + WithOffsetSync(new TopicPartition(topic2, 0), expectedOffset = 1) { () => + testUtils.deleteTopic(topic2) + testUtils.createTopic(topic2) + testUtils.sendMessages(topic2, Array("6")) + }, + StartStream(), + ExpectFailure[IllegalStateException](e => { + // The offset of `topic2` should be changed from 2 to 1 + assert(e.getMessage.contains("was changed from 2 to 1")) + }) + ) + } + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-current" From def767a77a8787056c4f4ae6a4df719eb968a7fc Mon Sep 17 00:00:00 2001 From: wuyi Date: Fri, 21 Dec 2018 13:21:58 -0600 Subject: [PATCH 115/194] [SPARK-26269][YARN] Yarnallocator should have same blacklist behaviour with yarn to maxmize use of cluster resource ## What changes were proposed in this pull request? As I mentioned in jira [SPARK-26269](https://issues.apache.org/jira/browse/SPARK-26269), in order to maxmize the use of cluster resource, this pr try to make `YarnAllocator` have the same blacklist behaviour with YARN. ## How was this patch tested? Added. Closes #23223 from Ngone51/dev-YarnAllocator-should-have-same-blacklist-behaviour-with-YARN. 
Lead-authored-by: wuyi Co-authored-by: Ngone51 Signed-off-by: Thomas Graves --- .../spark/deploy/yarn/YarnAllocator.scala | 32 ++++++-- .../yarn/YarnAllocatorBlacklistTracker.scala | 4 +- .../YarnAllocatorBlacklistTrackerSuite.scala | 2 +- .../deploy/yarn/YarnAllocatorSuite.scala | 75 ++++++++++++++++++- 4 files changed, 101 insertions(+), 12 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 54b1ec266113f..a3feca5dfd229 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -607,13 +607,23 @@ private[yarn] class YarnAllocator( val message = "Container killed by YARN for exceeding physical memory limits. " + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}." (true, message) - case _ => - // all the failures which not covered above, like: - // disk failure, kill by app master or resource manager, ... - allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) - (true, "Container marked as failed: " + containerId + onHostStr + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) + case other_exit_status => + // SPARK-26269: follow YARN's blacklisting behaviour(see https://github + // .com/apache/hadoop/blob/228156cfd1b474988bc4fedfbf7edddc87db41e3/had + // oop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/ap + // ache/hadoop/yarn/util/Apps.java#L273 for details) + if (NOT_APP_AND_SYSTEM_FAULT_EXIT_STATUS.contains(other_exit_status)) { + (false, s"Container marked as failed: $containerId$onHostStr" + + s". Exit status: ${completedContainer.getExitStatus}" + + s". Diagnostics: ${completedContainer.getDiagnostics}.") + } else { + // completed container from a bad node + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) + (true, s"Container from a bad node: $containerId$onHostStr" + + s". Exit status: ${completedContainer.getExitStatus}" + + s". 
Diagnostics: ${completedContainer.getDiagnostics}.") + } + } if (exitCausedByApp) { @@ -739,4 +749,12 @@ private object YarnAllocator { val MEM_REGEX = "[0-9.]+ [KMG]B" val VMEM_EXCEEDED_EXIT_CODE = -103 val PMEM_EXCEEDED_EXIT_CODE = -104 + + val NOT_APP_AND_SYSTEM_FAULT_EXIT_STATUS = Set( + ContainerExitStatus.KILLED_BY_RESOURCEMANAGER, + ContainerExitStatus.KILLED_BY_APPMASTER, + ContainerExitStatus.KILLED_AFTER_APP_COMPLETION, + ContainerExitStatus.ABORTED, + ContainerExitStatus.DISKS_FAILED + ) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala index ceac7cda5f8be..268976b629507 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -120,7 +120,9 @@ private[spark] class YarnAllocatorBlacklistTracker( if (removals.nonEmpty) { logInfo(s"removing nodes from YARN application master's blacklist: $removals") } - amClient.updateBlacklist(additions.asJava, removals.asJava) + if (additions.nonEmpty || removals.nonEmpty) { + amClient.updateBlacklist(additions.asJava, removals.asJava) + } currentBlacklistedYarnNodes = nodesToBlacklist } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala index aeac68e6ed330..201910731e934 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -87,7 +87,7 @@ class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers // expired blacklisted nodes (simulating a resource request) yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) // no change is communicated to YARN regarding the blacklisting - verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + verify(amClientMock, times(0)).updateBlacklist(Collections.emptyList(), Collections.emptyList()) } test("combining scheduler and allocation blacklist") { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index b61e7df4420ef..53a538dc1de29 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.yarn +import java.util.Collections + import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration @@ -114,13 +116,29 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter clock) } - def createContainer(host: String, resource: Resource = containerResource): Container = { - val containerId = ContainerId.newContainerId(appAttemptId, containerNum) + def createContainer( + host: String, + containerNumber: Int = containerNum, + resource: Resource = containerResource): Container = { + val containerId: ContainerId = ContainerId.newContainerId(appAttemptId, containerNum) containerNum 
+= 1 val nodeId = NodeId.newInstance(host, 1000) Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null) } + def createContainers(hosts: Seq[String], containerIds: Seq[Int]): Seq[Container] = { + hosts.zip(containerIds).map{case (host, id) => createContainer(host, id)} + } + + def createContainerStatus( + containerId: ContainerId, + exitStatus: Int, + containerState: ContainerState = ContainerState.COMPLETE, + diagnostics: String = "diagnostics"): ContainerStatus = { + ContainerStatus.newInstance(containerId, containerState, diagnostics, exitStatus) + } + + test("single container allocated") { // request a single container and receive it val handler = createAllocator(1) @@ -148,7 +166,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter Map(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "gpu" -> "2G")) handler.updateResourceRequests() - val container = createContainer("host1", handler.resource) + val container = createContainer("host1", resource = handler.resource) handler.handleAllocatedContainers(Array(container)) // get amount of memory and vcores from resource, so effectively skipping their validation @@ -417,4 +435,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter clock.advance(50 * 1000L) handler.getNumExecutorsFailed should be (0) } + + test("SPARK-26269: YarnAllocator should have same blacklist behaviour with YARN") { + val rmClientSpy = spy(rmClient) + val maxExecutors = 11 + + val handler = createAllocator( + maxExecutors, + rmClientSpy, + Map( + "spark.yarn.blacklist.executor.launch.blacklisting.enabled" -> "true", + "spark.blacklist.application.maxFailedExecutorsPerNode" -> "0")) + handler.updateResourceRequests() + + val hosts = (0 until maxExecutors).map(i => s"host$i") + val ids = 0 to maxExecutors + val containers = createContainers(hosts, ids) + + val nonBlacklistedStatuses = Seq( + ContainerExitStatus.SUCCESS, + ContainerExitStatus.PREEMPTED, + ContainerExitStatus.KILLED_EXCEEDED_VMEM, + ContainerExitStatus.KILLED_EXCEEDED_PMEM, + ContainerExitStatus.KILLED_BY_RESOURCEMANAGER, + ContainerExitStatus.KILLED_BY_APPMASTER, + ContainerExitStatus.KILLED_AFTER_APP_COMPLETION, + ContainerExitStatus.ABORTED, + ContainerExitStatus.DISKS_FAILED) + + val nonBlacklistedContainerStatuses = nonBlacklistedStatuses.zipWithIndex.map { + case (exitStatus, idx) => createContainerStatus(containers(idx).getId, exitStatus) + } + + val BLACKLISTED_EXIT_CODE = 1 + val blacklistedStatuses = Seq(ContainerExitStatus.INVALID, BLACKLISTED_EXIT_CODE) + + val blacklistedContainerStatuses = blacklistedStatuses.zip(9 until maxExecutors).map { + case (exitStatus, idx) => createContainerStatus(containers(idx).getId, exitStatus) + } + + handler.handleAllocatedContainers(containers.slice(0, 9)) + handler.processCompletedContainers(nonBlacklistedContainerStatuses) + verify(rmClientSpy, never()) + .updateBlacklist(hosts.slice(0, 9).asJava, Collections.emptyList()) + + handler.handleAllocatedContainers(containers.slice(9, 11)) + handler.processCompletedContainers(blacklistedContainerStatuses) + verify(rmClientSpy) + .updateBlacklist(hosts.slice(9, 10).asJava, Collections.emptyList()) + verify(rmClientSpy) + .updateBlacklist(hosts.slice(10, 11).asJava, Collections.emptyList()) + } } From 38930f009ab40eaedbdb62843c60f5096d3be2b8 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 21 Dec 2018 11:28:22 -0800 Subject: [PATCH 116/194] [SPARK-25642][YARN] Adding two new metrics to record the number of registered connections as well 
as the number of active connections to YARN Shuffle Service Recently, the ability to expose the metrics for YARN Shuffle Service was added as part of [SPARK-18364](https://github.com/apache/spark/pull/22485). We need to add some metrics to be able to determine the number of active connections as well as open connections to the external shuffle service to benchmark network and connection issues on large cluster environments. Added two more shuffle server metrics for Spark Yarn shuffle service: numRegisteredConnections which indicate the number of registered connections to the shuffle service and numActiveConnections which indicate the number of active connections to the shuffle service at any given point in time. If these metrics are outputted to a file, we get something like this: 1533674653489 default.shuffleService: Hostname=server1.abc.com, openBlockRequestLatencyMillis_count=729, openBlockRequestLatencyMillis_rate15=0.7110833548897356, openBlockRequestLatencyMillis_rate5=1.657808981793011, openBlockRequestLatencyMillis_rate1=2.2404486061620474, openBlockRequestLatencyMillis_rateMean=0.9242558551196706, numRegisteredConnections=35, blockTransferRateBytes_count=2635880512, blockTransferRateBytes_rate15=2578547.6094160094, blockTransferRateBytes_rate5=6048721.726302424, blockTransferRateBytes_rate1=8548922.518223226, blockTransferRateBytes_rateMean=3341878.633637769, registeredExecutorsSize=5, registerExecutorRequestLatencyMillis_count=5, registerExecutorRequestLatencyMillis_rate15=0.0027973949328659836, registerExecutorRequestLatencyMillis_rate5=0.0021278007987206426, registerExecutorRequestLatencyMillis_rate1=2.8270296777387467E-6, registerExecutorRequestLatencyMillis_rateMean=0.006339206380043053, numActiveConnections=35 Closes #22498 from pgandhi999/SPARK-18364. 
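Reduced to a hedged sketch with made-up names (not the shuffle service classes themselves), the pattern behind the two new values is a pair of Dropwizard `Counter`s bumped from connection lifecycle callbacks and exposed through a `MetricSet`:

```scala
import java.util.{HashMap => JHashMap, Map => JMap}

import com.codahale.metrics.{Counter, Metric, MetricSet}

// Illustration only: one counter tracks every connection ever registered, the other is
// incremented on connect and decremented on disconnect, and both are published through
// getMetrics so a metrics system can poll them by name.
class ConnectionMetrics extends MetricSet {
  val registeredConnections = new Counter()
  val activeConnections = new Counter()

  def onConnectionEstablished(): Unit = {
    registeredConnections.inc()
    activeConnections.inc()
  }

  def onConnectionClosed(): Unit = activeConnections.dec()

  override def getMetrics: JMap[String, Metric] = {
    val metrics = new JHashMap[String, Metric]()
    metrics.put("numRegisteredConnections", registeredConnections)
    metrics.put("numActiveConnections", activeConnections)
    metrics
  }
}
```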
Authored-by: pgandhi Signed-off-by: Marcelo Vanzin --- .../spark/network/TransportContext.java | 9 ++++++- .../server/TransportChannelHandler.java | 18 +++++++++++++- .../spark/network/server/TransportServer.java | 5 ++++ .../shuffle/ExternalShuffleBlockHandler.java | 24 +++++++++++++++++-- .../network/yarn/YarnShuffleService.java | 21 +++++++++------- .../yarn/YarnShuffleServiceMetrics.java | 5 ++++ .../spark/deploy/ExternalShuffleService.scala | 2 ++ .../yarn/YarnShuffleServiceMetricsSuite.scala | 3 ++- 8 files changed, 73 insertions(+), 14 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 480b52652de53..1a3f3f2a6f249 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import com.codahale.metrics.Counter; import io.netty.channel.Channel; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; @@ -66,6 +67,8 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; private final boolean isClientOnly; + // Number of registered connections to the shuffle service + private Counter registeredConnections = new Counter(); /** * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created @@ -221,7 +224,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler, conf.maxChunksBeingTransferred()); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs(), closeIdleConnections); + conf.connectionTimeoutMs(), closeIdleConnections, this); } /** @@ -234,4 +237,8 @@ private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler } public TransportConf getConf() { return conf; } + + public Counter getRegisteredConnections() { + return registeredConnections; + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c824a7b0d4740..ca81099c4d5cb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -21,6 +21,7 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; +import org.apache.spark.network.TransportContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,18 +58,21 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler allMetrics; // Time latency for open block request in ms private final Timer openBlockRequestLatencyMillis = new Timer(); @@ -181,14 +183,20 @@ private class ShuffleMetrics implements MetricSet { private final Timer registerExecutorRequestLatencyMillis = new Timer(); // Block transfer rate in byte per second private final Meter blockTransferRateBytes = new Meter(); + // Number of active connections to the shuffle service + private Counter activeConnections = new Counter(); + // Number of 
registered connections to the shuffle service + private Counter registeredConnections = new Counter(); - private ShuffleMetrics() { + public ShuffleMetrics() { allMetrics = new HashMap<>(); allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis); allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis); allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); allMetrics.put("registeredExecutorsSize", (Gauge) () -> blockManager.getRegisteredExecutorsSize()); + allMetrics.put("numActiveConnections", activeConnections); + allMetrics.put("numRegisteredConnections", registeredConnections); } @Override @@ -244,4 +252,16 @@ public ManagedBuffer next() { } } + @Override + public void channelActive(TransportClient client) { + metrics.activeConnections.inc(); + super.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + metrics.activeConnections.dec(); + super.channelInactive(client); + } + } diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 72ae1a1295236..7e8d3b2bc3ba4 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -170,15 +170,6 @@ protected void serviceInit(Configuration conf) throws Exception { TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); - // register metrics on the block handler into the Node Manager's metrics system. - YarnShuffleServiceMetrics serviceMetrics = - new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); - - MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance(); - metricsSystem.register( - "sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics); - logger.info("Registered metrics with Hadoop's DefaultMetricsSystem"); - // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests List bootstraps = Lists.newArrayList(); @@ -199,6 +190,18 @@ protected void serviceInit(Configuration conf) throws Exception { port = shuffleServer.getPort(); boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; + + // register metrics on the block handler into the Node Manager's metrics system. + blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections", + shuffleServer.getRegisteredConnections()); + YarnShuffleServiceMetrics serviceMetrics = + new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); + + MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance(); + metricsSystem.register( + "sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics); + logger.info("Registered metrics with Hadoop's DefaultMetricsSystem"); + logger.info("Started YARN shuffle service for Spark on port {}. " + "Authentication is {}. 
Registered executor file is {}", port, authEnabledString, registeredExecutorFile); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java index 3e4d479b862b3..501237407e9b2 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java @@ -107,6 +107,11 @@ public static void collectMetric( throw new IllegalStateException( "Not supported class type of metric[" + name + "] for value " + gaugeValue); } + } else if (metric instanceof Counter) { + Counter c = (Counter) metric; + long counterValue = c.getCount(); + metricsRecordBuilder.addGauge(new ShuffleServiceMetricsInfo(name, "Number of " + + "connections to shuffle service " + name), counterValue); } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f6b3c37f0fe72..03e3abb3ce569 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -84,6 +84,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana server = transportContext.createServer(port, bootstraps.asJava) shuffleServiceSource.registerMetricSet(server.getAllMetrics) + blockHandler.getAllMetrics.getMetrics.put("numRegisteredConnections", + server.getRegisteredConnections) shuffleServiceSource.registerMetricSet(blockHandler.getAllMetrics) masterMetricsSystem.registerSource(shuffleServiceSource) masterMetricsSystem.start() diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 40b92282a3b8f..952fd0b70bb7b 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -38,7 +38,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { test("metrics named as expected") { val allMetrics = Set( "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", - "blockTransferRateBytes", "registeredExecutorsSize") + "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", + "numRegisteredConnections") metrics.getMetrics.keySet().asScala should be (allMetrics) } From ac33584615642b4d1d56c631e768b2bf8f390f71 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 22 Dec 2018 10:16:27 +0800 Subject: [PATCH 117/194] [SPARK-26216][SQL][FOLLOWUP] use abstract class instead of trait for UserDefinedFunction ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/23178 , to keep binary compability by using abstract class. ## How was this patch tested? Manual test. 
I created a simple app with Spark 2.4 ``` object TryUDF { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate() import spark.implicits._ val f1 = udf((i: Int) => i + 1) println(f1.deterministic) spark.range(10).select(f1.asNonNullable().apply($"id")).show() spark.stop() } } ``` When I run it with current master, it fails with ``` java.lang.IncompatibleClassChangeError: Found interface org.apache.spark.sql.expressions.UserDefinedFunction, but class was expected ``` When I run it with this PR, it works Closes #23351 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 -- project/MimaExcludes.scala | 28 ++++++++++++++++++- .../sql/expressions/UserDefinedFunction.scala | 2 +- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 115fc6516fb4c..1bd3b5ad0e1aa 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -33,8 +33,6 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`. - - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. - - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpuse with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. 
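Reduced to a hedged sketch with hypothetical names (the real type is `UserDefinedFunction`), the shape of the fix is: a `trait` compiles to a JVM interface, so applications compiled against the 2.4 class hit the `IncompatibleClassChangeError` shown above, whereas a `sealed abstract class` keeps the public type a class in bytecode and existing call sites continue to link, while the concrete implementation stays hidden:

```scala
package example

// Illustration only, not the actual Spark API: the public surface is a sealed abstract
// class so older client bytecode still resolves its methods against a class, and the
// case-class implementation is private to the package.
sealed abstract class PublicHandle {
  def deterministic: Boolean
  def asNondeterministic(): PublicHandle
}

private[example] case class PublicHandleImpl(deterministic: Boolean) extends PublicHandle {
  override def asNondeterministic(): PublicHandle = copy(deterministic = false)
}
```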
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7bb70a29195d6..89fc53ce3972f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -241,7 +241,33 @@ object MimaExcludes { // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction") + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3") ) // Exclude rules for 2.4.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index f88e0e0f299de..901472d8e0360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Stable -sealed trait UserDefinedFunction { +sealed abstract class UserDefinedFunction { /** * Returns true when the UDF can return a nullable value. From ffd2ef65631e72787cc03559bb34d083e190bc64 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:41:21 -0800 Subject: [PATCH 118/194] [SPARK-26427][BUILD] Upgrade Apache ORC to 1.5.4 ## What changes were proposed in this pull request? This PR aims to update Apache ORC dependency to the latest version 1.5.4 released at Dec. 20. ([Release Notes](https://issues.apache.org/jira/secure/ReleaseNote.jspa?projectId=12318320&version=12344187])) ``` [ORC-237] OrcFile.mergeFiles Specified block size is less than configured minimum value [ORC-409] Changes for extending MemoryManagerImpl [ORC-410] Fix a locale-dependent test in TestCsvReader [ORC-416] Avoid opening data reader when there is no stripe [ORC-417] Use dynamic Apache Maven mirror link [ORC-419] Ensure to call `close` at RecordReaderImpl constructor exception [ORC-432] openjdk 8 has a bug that prevents surefire from working [ORC-435] Ability to read stripes that are greater than 2GB [ORC-437] Make acid schema checks case insensitive [ORC-411] Update build to work with Java 10. [ORC-418] Fix broken docker build script ``` ## How was this patch tested? Build and pass Jenkins. Closes #23364 from dongjoon-hyun/SPARK-26427. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/deps/spark-deps-hadoop-3.1 | 6 +++--- pom.xml | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 71423af0789c6..1af29fcaff2aa 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -155,9 +155,9 @@ objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar +orc-core-1.5.4-nohive.jar +orc-mapreduce-1.5.4-nohive.jar +orc-shims-1.5.4.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 93eafef045330..05f180b17a588 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -172,9 +172,9 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar +orc-core-1.5.4-nohive.jar +orc-mapreduce-1.5.4-nohive.jar +orc-shims-1.5.4.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 310d7de955125..de9421419edc2 100644 --- a/pom.xml +++ b/pom.xml @@ -132,7 +132,7 @@ 2.1.0 10.12.1.1 1.10.0 - 1.5.3 + 1.5.4 nohive 1.6.0 9.4.12.v20180830 @@ -1740,6 +1740,10 @@ ${orc.classifier} ${orc.deps.scope} + + javax.xml.bind + jaxb-api + org.apache.hadoop hadoop-common From 2b503818000daf0450fb9cb6d64837bbeca0bc92 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:43:59 -0800 Subject: [PATCH 119/194] [SPARK-26428][SS][TEST] Minimize deprecated `ProcessingTime` usage ## What changes were proposed in this pull request? Use of `ProcessingTime` class was deprecated in favor of `Trigger.ProcessingTime` in Spark 2.2. And, [SPARK-21464](https://issues.apache.org/jira/browse/SPARK-21464) minimized it at 2.2.1. Recently, it grows again in test suites. This PR aims to clean up newly introduced deprecation warnings for Spark 3.0. ## How was this patch tested? Pass the Jenkins with existing tests and manually check the warnings. Closes #23367 from dongjoon-hyun/SPARK-26428. 
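For context, a minimal sketch of the API change these test updates adopt (illustrative only; the streaming Dataset `df` and the console sink are assumed here and are not taken from the patch):

```
import org.apache.spark.sql.streaming.Trigger

// Before (deprecated since Spark 2.2, still compiles but emits a warning):
//   df.writeStream.trigger(ProcessingTime("10 seconds")).format("console").start()

// After (the form the updated suites use):
val query = df.writeStream
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .format("console")
  .start()
```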
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 16 ++++++++-------- .../sql/streaming/FileStreamSourceSuite.scala | 2 +- .../apache/spark/sql/streaming/StreamSuite.scala | 4 ++-- .../streaming/StreamingQueryListenerSuite.scala | 6 +++--- .../sql/streaming/StreamingQuerySuite.scala | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 61cbb3285a4f0..d4eb526540053 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext @@ -236,7 +236,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // 1 from smallest, 1 from middle, 8 from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), @@ -247,7 +247,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 ), StopStream, - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // smallest now empty, 1 more from middle, 9 more from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, @@ -282,7 +282,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { val mapped = kafka.map(kv => kv._2.toInt + 1) testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), + StartStream(trigger = Trigger.ProcessingTime(1)), makeSureGetOffsetCalled, AddKafkaData(Set(topic), 1, 2, 3), CheckAnswer(2, 3, 4), @@ -605,7 +605,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } testStream(kafka)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // 5 from smaller topic, 5 from bigger one CheckLastBatch((0 to 4) ++ (100 to 104): _*), @@ -618,7 +618,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // smaller topic empty, 5 from bigger one CheckLastBatch(110 to 114: _*), StopStream, - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // smallest now empty, 5 from bigger one CheckLastBatch(115 to 119: _*), @@ -727,7 +727,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, CheckAnswer(), 
WithOffsetSync(topicPartition, expectedOffset = 5) { () => @@ -850,7 +850,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, CheckNewAnswer(), WithOffsetSync(topicPartition, expectedOffset = 5) { () => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index d4bd9c7987f2d..de664cafed3b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1360,7 +1360,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { options = srcOptions) val clock = new StreamManualClock() testStream(fileStream)( - StartStream(trigger = ProcessingTime(10), triggerClock = clock), + StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock), AssertOnQuery { _ => // Block until the first batch finishes. eventually(timeout(streamingTimeout)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f55ddb5419d20..55fdcee83f114 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -312,7 +312,7 @@ class StreamSuite extends StreamTest { val inputData = MemoryStream[Int] testStream(inputData.toDS())( - StartStream(ProcessingTime("10 seconds"), new StreamManualClock), + StartStream(Trigger.ProcessingTime("10 seconds"), new StreamManualClock), /* -- batch 0 ----------------------- */ // Add some data in batch 0 @@ -353,7 +353,7 @@ class StreamSuite extends StreamTest { /* Stop then restart the Stream */ StopStream, - StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), + StartStream(Trigger.ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), /* -- batch 1 no rerun ----------------- */ // batch 1 would not re-run because the latest batch id logged in commit log is 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index fe77a1b4469c5..d00f2e3bf4d1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -82,7 +82,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { testStream(df, OutputMode.Append)( // Start event generated when query started - StartStream(ProcessingTime(100), triggerClock = clock), + StartStream(Trigger.ProcessingTime(100), triggerClock = clock), AssertOnQuery { query => assert(listener.startEvent !== null) assert(listener.startEvent.id === query.id) @@ -124,7 +124,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { }, // Termination event generated with exception message when stopped with error - StartStream(ProcessingTime(100), triggerClock = clock), + StartStream(Trigger.ProcessingTime(100), triggerClock = clock), 
AssertStreamExecThreadToWaitForClock(), AddData(inputData, 0), AdvanceManualClock(100), // process bad data @@ -306,7 +306,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } val clock = new StreamManualClock() val actions = mutable.ArrayBuffer[StreamAction]() - actions += StartStream(trigger = ProcessingTime(10), triggerClock = clock) + actions += StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock) for (_ <- 1 to 100) { actions += AdvanceManualClock(10) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c170641372d61..29b816486a1fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -257,7 +257,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(1000), triggerClock = clock), + StartStream(Trigger.ProcessingTime(1000), triggerClock = clock), AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -370,7 +370,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated with error - StartStream(ProcessingTime(1000), triggerClock = clock), + StartStream(Trigger.ProcessingTime(1000), triggerClock = clock), AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), AdvanceManualClock(1000), // allow another trigger From 3892c5d8bc833bd14baaf4f26bbbc73a707efa4f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:46:36 -0800 Subject: [PATCH 120/194] [SPARK-26430][BUILD][TEST-MAVEN] Upgrade Surefire plugin to 3.0.0-M2 ## What changes were proposed in this pull request? This PR aims to upgrade Maven Surefile plugin for JDK11 support. 3.0.0-M2 is [released Dec. 9th.](https://issues.apache.org/jira/projects/SUREFIRE/versions/12344396) ``` [SUREFIRE-1568] Versions 2.21 and higher doesn't work with junit-platform for Java 9 module [SUREFIRE-1605] NoClassDefFoundError (RunNotifier) with JDK 11 [SUREFIRE-1600] Surefire Project using surefire:2.12.4 is not fully able to work with JDK 10+ on internal build system. Therefore surefire-shadefire should go with Surefire:3.0.0-M2. 
[SUREFIRE-1593] 3.0.0-M1 produces invalid code sources on Windows [SUREFIRE-1602] Surefire fails loading class ForkedBooter when using a sub-directory pom file and a local maven repo [SUREFIRE-1606] maven-shared-utils must not be on provider's classpath [SUREFIRE-1531] Option to switch-off Java 9 modules [SUREFIRE-1590] Deploy multiple versions of Report XSD [SUREFIRE-1591] Java 1.7 feature Diamonds replaced Generics [SUREFIRE-1594] Java 1.7 feature try-catch - multiple exceptions in one catch [SUREFIRE-1595] Java 1.7 feature System.lineSeparator() [SUREFIRE-1597] ModularClasspathForkConfiguration with debug logs (args file and its path on file system) [SUREFIRE-1596] Unnecessary check JAVA_RECENT == JAVA_1_7 in unit tests [SUREFIRE-1598] Fixed typo in assertion statement in integration test Surefire855AllowFailsafeUseArtifactFileIT [SUREFIRE-1607] Roadmap on Project Site ``` ## How was this patch tested? Pass the Jenkins. Closes #23370 from dongjoon-hyun/SPARK-26430. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index de9421419edc2..321de209a56a1 100644 --- a/pom.xml +++ b/pom.xml @@ -2103,7 +2103,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M1 + 3.0.0-M2 From c19093ac64984abf55d2f2cea013ffb09d98cb0e Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Sat, 22 Dec 2018 09:03:02 -0600 Subject: [PATCH 121/194] =?UTF-8?q?[SPARK-26285][CORE]=20accumulator=20met?= =?UTF-8?q?rics=20sources=20for=20LongAccumulator=20and=20Doub=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …leAccumulator ## What changes were proposed in this pull request? This PR implements metric sources for LongAccumulator and DoubleAccumulator, such that a user can register these accumulators easily and have their values be reported by the driver's metric namespace. ## How was this patch tested? Unit tests, and manual tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23242 from abellina/SPARK-26285_accumulator_source. Lead-authored-by: Alessandro Bellina Co-authored-by: Alessandro Bellina Co-authored-by: Alessandro Bellina Signed-off-by: Thomas Graves --- .../metrics/source/AccumulatorSource.scala | 89 ++++++++++++++++++ .../source/AccumulatorSourceSuite.scala | 91 +++++++++++++++++++ .../examples/AccumulatorMetricsTest.scala | 77 ++++++++++++++++ 3 files changed, 257 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala diff --git a/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala new file mode 100644 index 0000000000000..45a4d224d45fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.util.{AccumulatorV2, DoubleAccumulator, LongAccumulator} + +/** + * AccumulatorSource is a Spark metric Source that reports the current value + * of the accumulator as a gauge. + * + * It is restricted to the LongAccumulator and the DoubleAccumulator, as those + * are the current built-in numerical accumulators with Spark, and excludes + * the CollectionAccumulator, as that is a List of values (hard to report, + * to a metrics system) + */ +private[spark] class AccumulatorSource extends Source { + private val registry = new MetricRegistry + protected def register[T](accumulators: Map[String, AccumulatorV2[_, T]]): Unit = { + accumulators.foreach { + case (name, accumulator) => + val gauge = new Gauge[T] { + override def getValue: T = accumulator.value + } + registry.register(MetricRegistry.name(name), gauge) + } + } + + override def sourceName: String = "AccumulatorSource" + override def metricRegistry: MetricRegistry = registry +} + +@Experimental +class LongAccumulatorSource extends AccumulatorSource + +@Experimental +class DoubleAccumulatorSource extends AccumulatorSource + +/** + * :: Experimental :: + * Metrics source specifically for LongAccumulators. Accumulators + * are only valid on the driver side, so these metrics are reported + * only by the driver. + * Register LongAccumulators using: + * LongAccumulatorSource.register(sc, {"name" -> longAccumulator}) + */ +@Experimental +object LongAccumulatorSource { + def register(sc: SparkContext, accumulators: Map[String, LongAccumulator]): Unit = { + val source = new LongAccumulatorSource + source.register(accumulators) + sc.env.metricsSystem.registerSource(source) + } +} + +/** + * :: Experimental :: + * Metrics source specifically for DoubleAccumulators. Accumulators + * are only valid on the driver side, so these metrics are reported + * only by the driver. + * Register DoubleAccumulators using: + * DoubleAccumulatorSource.register(sc, {"name" -> doubleAccumulator}) + */ +@Experimental +object DoubleAccumulatorSource { + def register(sc: SparkContext, accumulators: Map[String, DoubleAccumulator]): Unit = { + val source = new DoubleAccumulatorSource + source.register(accumulators) + sc.env.metricsSystem.registerSource(source) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala new file mode 100644 index 0000000000000..6a6c07cb068cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.{mock, never, spy, times, verify, when} + +import org.apache.spark.{SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.{DoubleAccumulator, LongAccumulator} + +class AccumulatorSourceSuite extends SparkFunSuite { + test("that that accumulators register against the metric system's register") { + val acc1 = new LongAccumulator() + val acc2 = new LongAccumulator() + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map("my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + LongAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert (gauges.size == 2) + assert (gauges.firstKey == "my-accumulator-1") + assert (gauges.lastKey == "my-accumulator-2") + } + + test("the accumulators value property is checked when the gauge's value is requested") { + val acc1 = new LongAccumulator() + acc1.add(123) + val acc2 = new LongAccumulator() + acc2.add(456) + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map("my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + LongAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert(gauges.get("my-accumulator-1").getValue() == 123) + assert(gauges.get("my-accumulator-2").getValue() == 456) + } + + test("the double accumulators value propety is checked when the gauge's value is requested") { + val acc1 = new DoubleAccumulator() + acc1.add(123.123) + val acc2 = new DoubleAccumulator() + acc2.add(456.456) + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map( + "my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + DoubleAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, 
times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert(gauges.get("my-accumulator-1").getValue() == 123.123) + assert(gauges.get("my-accumulator-2").getValue() == 456.456) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala new file mode 100644 index 0000000000000..5d9a9a73f12ec --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import org.apache.spark.metrics.source.{DoubleAccumulatorSource, LongAccumulatorSource} +import org.apache.spark.sql.SparkSession + +/** + * Usage: AccumulatorMetricsTest [numElem] + * + * This example shows how to register accumulators against the accumulator source. + * A simple RDD is created, and during the map, the accumulators are incremented. + * + * The only argument, numElem, sets the number elements in the collection to parallize. + * + * The result is output to stdout in the driver with the values of the accumulators. + * For the long accumulator, it should equal numElem the double accumulator should be + * roughly 1.1 x numElem (within double precision.) This example also sets up a + * ConsoleSink (metrics) instance, and so registered codahale metrics (like the + * accumulator source) are reported to stdout as well. 
+ */ +object AccumulatorMetricsTest { + def main(args: Array[String]) { + + val spark = SparkSession + .builder() + .config("spark.metrics.conf.*.sink.console.class", + "org.apache.spark.metrics.sink.ConsoleSink") + .getOrCreate() + + val sc = spark.sparkContext + + val acc = sc.longAccumulator("my-long-metric") + // register the accumulator, the metric system will report as + // [spark.metrics.namespace].[execId|driver].AccumulatorSource.my-long-metric + LongAccumulatorSource.register(sc, List(("my-long-metric" -> acc)).toMap) + + val acc2 = sc.doubleAccumulator("my-double-metric") + // register the accumulator, the metric system will report as + // [spark.metrics.namespace].[execId|driver].AccumulatorSource.my-double-metric + DoubleAccumulatorSource.register(sc, List(("my-double-metric" -> acc2)).toMap) + + val num = if (args.length > 0) args(0).toInt else 1000000 + + val startTime = System.nanoTime + + val accumulatorTest = sc.parallelize(1 to num).foreach(_ => { + acc.add(1) + acc2.add(1.1) + }) + + // Print a footer with test time and accumulator values + println("Test took %.0f milliseconds".format((System.nanoTime - startTime) / 1E6)) + println("Accumulator values:") + println("*** Long accumulator (my-long-metric): " + acc.value) + println("*** Double accumulator (my-double-metric): " + acc2.value) + + spark.stop() + } +} +// scalastyle:on println From ce1961034b3bf807e597fe3ff1e4c6b70faf125d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 22 Dec 2018 10:32:32 -0600 Subject: [PATCH 122/194] [SPARK-25245][DOCS][SS] Explain regarding limiting modification on "spark.sql.shuffle.partitions" for structured streaming ## What changes were proposed in this pull request? This patch adds explanation of `why "spark.sql.shuffle.partitions" keeps unchanged in structured streaming`, which couple of users already wondered and some of them even thought it as a bug. This patch would help other end users to know about such behavior before they find by theirselves and being wondered. ## How was this patch tested? No need to test because this is a simple addition on guide doc with markdown editor. Closes #22238 from HeartSaVioR/SPARK-25245. Lead-authored-by: Jungtaek Lim Co-authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Sean Owen --- docs/structured-streaming-programming-guide.md | 10 ++++++++++ .../scala/org/apache/spark/sql/internal/SQLConf.scala | 8 ++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 32d61dcdb4599..e76b53dbb4dc3 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -3113,6 +3113,16 @@ See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections f # Additional Information +**Notes** + +- Several configurations are not modifiable after the query has run. To change them, discard the checkpoint and start a new query. These configurations include: + - `spark.sql.shuffle.partitions` + - This is due to the physical partitioning of state: state is partitioned via applying hash function to key, hence the number of partitions for state should be unchanged. + - If you want to run fewer tasks for stateful operations, `coalesce` would help with avoiding unnecessary repartitioning. + - After `coalesce`, the number of (reduced) tasks will be kept unless another shuffle happens. 
+ - `spark.sql.streaming.stateStore.providerClass`: To read the previous state of the query properly, the class of state store provider should be unchanged. + - `spark.sql.streaming.multipleWatermarkPolicy`: Modification of this would lead inconsistent watermark value when query contains multiple watermarks, hence the policy should be unchanged. + **Further Reading** - See and run the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 86e068bf632bd..fe445e0019353 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -263,7 +263,9 @@ object SQLConf { .createWithDefault(true) val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions") - .doc("The default number of partitions to use when shuffling data for joins or aggregations.") + .doc("The default number of partitions to use when shuffling data for joins or aggregations. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") .intConf .createWithDefault(200) @@ -882,7 +884,9 @@ object SQLConf { .internal() .doc( "The class used to manage state data in stateful streaming queries. This class must " + - "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") + "be a subclass of StateStoreProvider, and must have a zero-arg constructor. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") .stringConf .createWithDefault( "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") From c03498e551e1d261a94136da85bc1919db814ebf Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sat, 22 Dec 2018 10:35:14 -0800 Subject: [PATCH 123/194] [SPARK-26402][SQL] Accessing nested fields with different cases in case insensitive mode ## What changes were proposed in this pull request? GetStructField with different optional names should be semantically equal. We will use this as building block to compare the nested fields used in the plans to be optimized by catalyst optimizer. This PR also fixes a bug below that accessing nested fields with different cases in case insensitive mode will result `AnalysisException`. ``` sql("create table t (s struct) using json") sql("select s.I from t group by s.i") ``` which is currently failing ``` org.apache.spark.sql.AnalysisException: expression 'default.t.`s`' is neither present in the group by, nor is it an aggregate function ``` as cloud-fan pointed out. ## How was this patch tested? New tests are added. Closes #23353 from dbtsai/nestedEqual. 
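To make the behavioural fix concrete, a minimal sketch of the query pattern that previously failed (illustrative only; the nested element type `struct<i: int>` is an assumption, and a running SparkSession named `spark` is presumed):

```
spark.conf.set("spark.sql.caseSensitive", "false")
spark.sql("create table t (s struct<i: int>) using json")
// With the fix, `s.I` and `s.i` canonicalize to the same nested field,
// so the aggregate no longer raises AnalysisException:
spark.sql("select s.I from t group by s.i").show()
```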
Lead-authored-by: DB Tsai Co-authored-by: DB Tsai Signed-off-by: Dongjoon Hyun --- .../catalyst/expressions/Canonicalize.scala | 4 ++- .../expressions/CanonicalizeSuite.scala | 29 ++++++++++++++++++ .../BinaryComparisonSimplificationSuite.scala | 30 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 19 ++++++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index fe6db8b344d3d..4d218b936b3a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -26,6 +26,7 @@ package org.apache.spark.sql.catalyst.expressions * * The following rules are applied: * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. + * - Names for [[GetStructField]] are stripped. * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. @@ -37,10 +38,11 @@ object Canonicalize { expressionReorder(ignoreNamesTypes(e)) } - /** Remove names and nullability from types. */ + /** Remove names and nullability from types, and names from `GetStructField`. */ private[expressions] def ignoreNamesTypes(e: Expression): Expression = e match { case a: AttributeReference => AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + case GetStructField(child, ordinal, Some(_)) => GetStructField(child, ordinal, None) case _ => e } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca3..9802a6e5891b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class CanonicalizeSuite extends SparkFunSuite { @@ -50,4 +51,32 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(arrays1).sameResult(range.where(arrays2))) assert(!range.where(arrays1).sameResult(range.where(arrays3))) } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + val expId = NamedExpression.newExprId + val qualifier = Seq.empty[String] + val structType = StructType( + StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil) + + // GetStructField with different names are semantically equal + val fieldA1 = GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")) + val fieldA2 = GetStructField( + AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")) + assert(fieldA1.semanticEquals(fieldA2)) + + val fieldB1 = GetStructField( + GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")), + 0, Some("b1")) + val fieldB2 = GetStructField( + GetStructField( + 
AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")), + 0, Some("b2")) + assert(fieldB1.semanticEquals(fieldB2)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a313681eeb8f0..5794691a365a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { @@ -92,4 +93,33 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper val correctAnswer = nonNullableRelation.analyze comparePlans(actual, correctAnswer) } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + val expId = NamedExpression.newExprId + val qualifier = Seq.empty[String] + val structType = StructType( + StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil) + + val fieldA1 = GetStructField( + GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")), + 0, Some("b1")) + val fieldA2 = GetStructField( + GetStructField( + AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")), + 0, Some("b2")) + + // GetStructField with different names are semantically equal; thus, `EqualTo(fieldA1, fieldA2)` + // will be optimized to `TrueLiteral` by `SimplifyBinaryComparison`. 
+ val originalQuery = nonNullableRelation + .where(EqualTo(fieldA1, fieldA2)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = nonNullableRelation.analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 37a8815350a53..656da9fa01806 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2937,6 +2937,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AnalysisException] { + withTable("t") { + sql("create table t (s struct) using json") + checkAnswer(sql("select s.I from t group by s.i"), Nil) + } + }.message + assert(msg.contains("No such struct field I in i")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTable("t") { + sql("create table t (s struct) using json") + checkAnswer(sql("select s.I from t group by s.i"), Nil) + } + } + } } case class Foo(bar: Option[String]) From 4536d5389612adfdba55b8bc4d6e2761426397e1 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 24 Dec 2018 10:47:47 +0800 Subject: [PATCH 124/194] [SPARK-26178][SPARK-26243][SQL][FOLLOWUP] Replacing SimpleDateFormat by DateTimeFormatter in comments ## What changes were proposed in this pull request? The PRs #23150 and #23196 switched JSON and CSV datasources on new formatter for dates/timestamps which is based on `DateTimeFormatter`. In this PR, I replaced `SimpleDateFormat` by `DateTimeFormatter` to reflect the changes. Closes #23374 from MaxGekk/java-time-docs. Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/readwriter.py | 28 +++++++++++-------- python/pyspark/sql/streaming.py | 14 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 12 ++++---- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++---- .../sql/streaming/DataStreamReader.scala | 12 ++++---- 5 files changed, 42 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b10512a43294..3da052391a95b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -226,11 +226,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. 
If None is @@ -406,11 +407,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param negativeInf: sets the string representation of a negative infinity value. If None is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is @@ -803,11 +805,12 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param encoding: specifies encoding (charset) of saved json files. If None is set, @@ -904,11 +907,12 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fc23b9d99c34a..b981fdc4edc77 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -456,11 +456,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. :param dateFormat: sets the string that indicates a date format. 
Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. If None is @@ -630,11 +631,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param negativeInf: sets the string representation of a negative infinity value. If None is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9751528654ffb..ce8e4c8f5b82b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -375,11 +375,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.

  • *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • *
  • `encoding` (by default it is not set): allows to forcibly set one of standard basic @@ -585,11 +585,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
  • *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b9c4076994e96..981b3a8fd4ac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -530,11 +530,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `encoding` (by default it is not set): specifies encoding (charset) of saved json * files. If it is not set, the UTF-8 charset will be used.
  • *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • @@ -649,11 +649,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`). *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading * whitespaces from values being written should be skipped.
  • *
  • `ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 914fa90ae7e14..98589da9552cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -286,11 +286,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator @@ -347,11 +347,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
  • *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
  • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed From e73d73ed8ac9075a3d3b91305112798ffcd554f6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 23 Dec 2018 21:09:44 -0800 Subject: [PATCH 125/194] [SPARK-14023][CORE][SQL] Don't reference 'field' in StructField errors for clarity in exceptions ## What changes were proposed in this pull request? Variation of https://github.com/apache/spark/pull/20500 I cheated by not referencing fields or columns at all as this exception propagates in contexts where both would be applicable. ## How was this patch tested? Existing tests Closes #23373 from srowen/SPARK-14023.2. Authored-by: Sean Owen Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++---------- .../spark/sql/types/StructTypeSuite.scala | 8 ++++---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6e8bbde7787a6..e01d7c59cac52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,7 +28,6 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} -import org.apache.spark.util.Utils /** * A [[StructType]] object can be constructed by @@ -57,7 +56,7 @@ import org.apache.spark.util.Utils * * // If this struct does not have a field called "d", it throws an exception. * struct("d") - * // java.lang.IllegalArgumentException: Field "d" does not exist. + * // java.lang.IllegalArgumentException: d does not exist. * // ... * * // Extract multiple StructFields. Field names are provided in a set. @@ -69,7 +68,7 @@ import org.apache.spark.util.Utils * // Any names without matching fields will throw an exception. * // For the case shown below, an exception is thrown due to "d". * struct(Set("b", "c", "d")) - * // java.lang.IllegalArgumentException: Field "d" does not exist. + * // java.lang.IllegalArgumentException: d does not exist. * // ... * }}} * @@ -272,22 +271,21 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def apply(name: String): StructField = { nameToField.getOrElse(name, throw new IllegalArgumentException( - s"""Field "$name" does not exist. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) + s"$name does not exist. Available: ${fieldNames.mkString(", ")}")) } /** * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the * original order of fields. * - * @throws IllegalArgumentException if a field cannot be found for any of the given names + * @throws IllegalArgumentException if at least one given field name does not exist */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) + s"${nonExistFields.mkString(", ")} do(es) not exist. " + + s"Available: ${fieldNames.mkString(", ")}") } // Preserve the original order of fields. 
StructType(fields.filter(f => names.contains(f.name))) @@ -301,8 +299,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, throw new IllegalArgumentException( - s"""Field "$name" does not exist. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) + s"$name does not exist. Available: ${fieldNames.mkString(", ")}")) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 53a78c94aa6fb..b4ce26be24de2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -22,21 +22,21 @@ import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { - val s = StructType.fromDDL("a INT, b STRING") + private val s = StructType.fromDDL("a INT, b STRING") test("lookup a single missing field should output existing fields") { val e = intercept[IllegalArgumentException](s("c")).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("lookup a set of missing fields should output existing fields") { val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("lookup fieldIndex for missing field should output existing fields") { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("SPARK-24849: toDDL - simple struct") { From 35c680e9fde4b091a7174fdf0b6aee2c63574cb4 Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 25 Dec 2018 15:53:42 +0800 Subject: [PATCH 126/194] [SPARK-26426][SQL] fix ExpresionInfo assert error in windows operation system. ## What changes were proposed in this pull request? fix ExpresionInfo assert error in windows operation system, when running unit tests. ## How was this patch tested? unit tests Closes #23363 from yanlin-Lynn/unit-test-windows. Authored-by: wangyanlin01 Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/catalyst/expressions/ExpressionInfo.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index ab13ac9cc5483..d5a1b77c0ec81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -79,7 +79,7 @@ public ExpressionInfo( assert name != null; assert arguments != null; assert examples != null; - assert examples.isEmpty() || examples.startsWith("\n Examples:"); + assert examples.isEmpty() || examples.startsWith(System.lineSeparator() + " Examples:"); assert note != null; assert since != null; From 210550c1552495ae4393db223e656cca35d8f8cf Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 27 Dec 2018 11:09:50 +0800 Subject: [PATCH 127/194] [SPARK-26424][SQL] Use java.time API in date/timestamp expressions ## What changes were proposed in this pull request? 
In the PR, I propose to switch the `DateFormatClass`, `ToUnixTimestamp`, `FromUnixTime`, `UnixTime` on java.time API for parsing/formatting dates and timestamps. The API has been already implemented by the `Timestamp`/`DateFormatter` classes. One of benefit is those classes support parsing timestamps with microsecond precision. Old behaviour can be switched on via SQL config: `spark.sql.legacy.timeParser.enabled` (`false` by default). ## How was this patch tested? It was tested by existing test suites - `DateFunctionsSuite`, `DateExpressionsSuite`, `JsonSuite`, `CsvSuite`, `SQLQueryTestSuite` as well as PySpark tests. Closes #23358 from MaxGekk/new-time-cast. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- R/pkg/R/functions.R | 8 +- docs/sql-migration-guide-upgrade.md | 1 + python/pyspark/sql/functions.py | 6 +- .../sql/catalyst/csv/CSVInferSchema.scala | 3 +- .../expressions/datetimeExpressions.scala | 82 +++++++++++-------- .../sql/catalyst/json/JsonInferSchema.scala | 3 +- .../sql/catalyst/util/DateFormatter.scala | 8 +- .../util/DateTimeFormatterHelper.scala | 21 +++-- .../sql/catalyst/util/DateTimeUtils.scala | 10 --- .../catalyst/util/TimestampFormatter.scala | 22 ++++- .../catalyst/csv/UnivocityParserSuite.scala | 2 +- .../spark/sql/util/DateFormatterSuite.scala | 7 ++ .../sql/util/TimestampFormatterSuite.scala | 12 +++ .../org/apache/spark/sql/functions.scala | 10 +-- .../apache/spark/sql/DateFunctionsSuite.scala | 2 +- 15 files changed, 122 insertions(+), 75 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f568a931ae1fe..5b3cc0940d9c3 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1723,7 +1723,7 @@ setMethod("radians", #' @details #' \code{to_date}: Converts the column into a DateType. You may optionally specify #' a format according to the rules in: -#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' \url{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a DateType if the format is omitted @@ -1819,7 +1819,7 @@ setMethod("to_csv", signature(x = "Column"), #' @details #' \code{to_timestamp}: Converts the column into a TimestampType. You may optionally specify #' a format according to the rules in: -#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' \url{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a TimestampType if the format is omitted @@ -2240,7 +2240,7 @@ setMethod("n", signature(x = "Column"), #' \code{date_format}: Converts a date/timestamp/string to a value of string in the format #' specified by the date format given by the second argument. A pattern could be for instance #' \code{dd.MM.yyyy} and could return a string like '18.03.1993'. All -#' pattern letters of \code{java.text.SimpleDateFormat} can be used. +#' pattern letters of \code{java.time.format.DateTimeFormatter} can be used. #' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. 
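To make the stricter parsing described above concrete, here is a small sketch that is not part of this patch: it calls the JDK's `DateTimeFormatter` directly rather than Spark's `TimestampFormatter`, and it uses the proleptic-year pattern letter `uuuu` so that strict resolution does not also demand an era field (the patch itself handles that by defaulting `ChronoField.ERA`).

```scala
import java.time.LocalDate
import java.time.format.{DateTimeFormatter, ResolverStyle}

// Strict resolution: every field must be in range and the whole input must be
// consumed, so trailing text is rejected rather than silently ignored.
val strict = DateTimeFormatter.ofPattern("uuuu-MM-dd")
  .withResolverStyle(ResolverStyle.STRICT)

LocalDate.parse("2015-07-22", strict)              // parses fine
// LocalDate.parse("2015-07-22 10:00:00", strict)  // DateTimeParseException:
//                                                 // unparsed text found at index 10
```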
#' @@ -2666,7 +2666,7 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) #' to a string representing the timestamp of that moment in the current system time zone in the JVM #' in the given format. -#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' See \href{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 1bd3b5ad0e1aa..c4d2157de8b60 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -39,6 +39,7 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully. + - Since Spark 3.0, the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions use java.time API for parsing and formatting dates/timestamps from/to strings by using ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html) based on Proleptic Gregorian calendar. In Spark version 2.4 and earlier, java.text.SimpleDateFormat and java.util.GregorianCalendar (hybrid calendar that supports both the Julian and Gregorian calendar systems, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html) is used for the same purpuse. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d188de39e21c7..d2a771e9bb8ea 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -874,7 +874,7 @@ def date_format(date, format): format given by the second argument. A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - pattern letters of the Java class `java.text.SimpleDateFormat` can be used. + pattern letters of the Java class `java.time.format.DateTimeFormatter` can be used. .. note:: Use when ever possible specialized functions like `year`. 
These benefit from a specialized implementation. @@ -1094,7 +1094,7 @@ def to_date(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to - `SimpleDateFormats `_. + `DateTimeFormatter `_. # noqa By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format is omitted (equivalent to ``col.cast("date")``). @@ -1119,7 +1119,7 @@ def to_timestamp(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to - `SimpleDateFormats `_. + `DateTimeFormatter `_. # noqa By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format is omitted (equivalent to ``col.cast("timestamp")``). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 35ade136cc607..4dd41042856d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -27,8 +27,7 @@ import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { - @transient - private lazy val timestampParser = TimestampFormatter( + private val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 45e17ae235a94..73af0a3c5c2ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import java.text.DateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.util.control.NonFatal @@ -28,7 +27,8 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -562,16 +562,17 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti copy(timeZoneId = Option(timeZoneId)) override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val df = DateTimeUtils.newDateFormat(format.toString, timeZone) - UTF8String.fromString(df.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) + val df = TimestampFormatter(format.toString, timeZone, Locale.US) + UTF8String.fromString(df.format(timestamp.asInstanceOf[Long])) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = 
DateTimeUtils.getClass.getName.stripSuffix("$") + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") val tz = ctx.addReferenceObj("timeZone", timeZone) + val locale = ctx.addReferenceObj("locale", Locale.US) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) - .format(new java.util.Date($timestamp / 1000)))""" + s"""UTF8String.fromString($tf.apply($format.toString(), $tz, $locale) + .format($timestamp))""" }) } @@ -612,9 +613,10 @@ case class ToUnixTimestamp( } /** - * Converts time string with given pattern. - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), returns null if fail. + * Converts time string with given pattern to Unix time stamp (in seconds), returns null if fail. + * See [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html] + * if SQL config spark.sql.legacy.timeParser.enabled is set to true otherwise + * [https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html]. * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". * If no parameters provided, the first parameter will be current_timestamp. @@ -663,9 +665,9 @@ abstract class UnixTime override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: DateFormat = + private lazy val formatter: TimestampFormatter = try { - DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + TimestampFormatter(constFormat.toString, timeZone, Locale.US) } catch { case NonFatal(_) => null } @@ -677,16 +679,16 @@ abstract class UnixTime } else { left.dataType match { case DateType => - DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / 1000L + DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / MILLIS_PER_SECOND case TimestampType => - t.asInstanceOf[Long] / 1000000L + t.asInstanceOf[Long] / MICROS_PER_SECOND case StringType if right.foldable => if (constFormat == null || formatter == null) { null } else { try { formatter.parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L + t.asInstanceOf[UTF8String].toString) / MICROS_PER_SECOND } catch { case NonFatal(_) => null } @@ -698,8 +700,8 @@ abstract class UnixTime } else { val formatString = f.asInstanceOf[UTF8String].toString try { - DateTimeUtils.newDateFormat(formatString, timeZone).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L + TimestampFormatter(formatString, timeZone, Locale.US).parse( + t.asInstanceOf[UTF8String].toString) / MICROS_PER_SECOND } catch { case NonFatal(_) => null } @@ -712,7 +714,7 @@ abstract class UnixTime val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => - val df = classOf[DateFormat].getName + val df = classOf[TimestampFormatter].getName if (formatter == null) { ExprCode.forNullValue(dataType) } else { @@ -724,24 +726,35 @@ abstract class UnixTime $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; + ${ev.value} = $formatterName.parse(${eval1.value}.toString()) / 1000000L; + } catch (java.lang.IllegalArgumentException e) { + ${ev.isNull} = true; } catch (java.text.ParseException e) { ${ev.isNull} = true; + } catch (java.time.format.DateTimeParseException e) { + ${ev.isNull} = 
true; + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; } }""") } case StringType => val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val locale = ctx.addReferenceObj("locale", Locale.US) + val dtu = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $dtu.newDateFormat($format.toString(), $tz) - .parse($string.toString()).getTime() / 1000L; + ${ev.value} = $dtu.apply($format.toString(), $tz, $locale) + .parse($string.toString()) / 1000000L; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } catch (java.text.ParseException e) { ${ev.isNull} = true; + } catch (java.time.format.DateTimeParseException e) { + ${ev.isNull} = true; + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; } """ }) @@ -806,9 +819,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ copy(timeZoneId = Option(timeZoneId)) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: DateFormat = + private lazy val formatter: TimestampFormatter = try { - DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + TimestampFormatter(constFormat.toString, timeZone, Locale.US) } catch { case NonFatal(_) => null } @@ -823,8 +836,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(formatter.format( - new java.util.Date(time.asInstanceOf[Long] * 1000L))) + UTF8String.fromString(formatter.format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -835,8 +847,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(DateTimeUtils.newDateFormat(f.toString, timeZone) - .format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + UTF8String.fromString(TimestampFormatter(f.toString, timeZone, Locale.US) + .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -846,7 +858,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val df = classOf[DateFormat].getName + val df = classOf[TimestampFormatter].getName if (format.foldable) { if (formatter == null) { ExprCode.forNullValue(StringType) @@ -859,8 +871,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.value} = UTF8String.fromString($formatterName.format( - new java.util.Date(${t.value} * 1000L))); + ${ev.value} = UTF8String.fromString($formatterName.format(${t.value} * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } @@ -868,12 +879,13 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } } else { val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val locale = ctx.addReferenceObj("locale", Locale.US) + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString($dtu.newDateFormat($f.toString(), $tz).format( - new java.util.Date($seconds * 1000L))); + ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), 
$tz, $locale). + format($seconds * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index d1bc00c08c1c6..3203e626ea400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -37,8 +37,7 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val decimalParser = ExprUtils.getDecimalParser(options.locale) - @transient - private lazy val timestampFormatter = TimestampFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 9e8d51cc65f03..b4c99674fc1cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait DateFormatter { +sealed trait DateFormatter extends Serializable { def parse(s: String): Int // returns days since epoch def format(days: Int): String } @@ -35,7 +35,8 @@ class Iso8601DateFormatter( pattern: String, locale: Locale) extends DateFormatter with DateTimeFormatterHelper { - private val formatter = buildFormatter(pattern, locale) + @transient + private lazy val formatter = buildFormatter(pattern, locale) private val UTC = ZoneId.of("UTC") private def toInstant(s: String): Instant = { @@ -56,7 +57,8 @@ class Iso8601DateFormatter( } class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { - private val format = FastDateFormat.getInstance(pattern, locale) + @transient + private lazy val format = FastDateFormat.getInstance(pattern, locale) override def parse(s: String): Int = { val milliseconds = format.parse(s).getTime diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index b85101d38d9e6..91cc57e0bb019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -17,27 +17,36 @@ package org.apache.spark.sql.catalyst.util -import java.time.{Instant, LocalDateTime, ZonedDateTime, ZoneId} -import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder} -import java.time.temporal.{ChronoField, TemporalAccessor} +import java.time._ +import java.time.chrono.IsoChronology +import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, ResolverStyle} +import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} import java.util.Locale trait DateTimeFormatterHelper { protected def buildFormatter(pattern: String, locale: Locale): DateTimeFormatter = { new DateTimeFormatterBuilder() + .parseCaseInsensitive() .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.ERA, 1) 
.parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) .toFormatter(locale) + .withChronology(IsoChronology.INSTANCE) + .withResolverStyle(ResolverStyle.STRICT) } protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor, zoneId: ZoneId): Instant = { - val localDateTime = LocalDateTime.from(temporalAccessor) + val localTime = if (temporalAccessor.query(TemporalQueries.localTime) == null) { + LocalTime.ofNanoOfDay(0) + } else { + LocalTime.from(temporalAccessor) + } + val localDate = LocalDate.from(temporalAccessor) + val localDateTime = LocalDateTime.of(localDate, localTime) val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) Instant.from(zonedDateTime) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index c6dfdbf2505ba..3e5e1fbc2b368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -111,16 +111,6 @@ object DateTimeUtils { computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone) } - def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = { - val sdf = new SimpleDateFormat(formatString, Locale.US) - sdf.setTimeZone(timeZone) - // Enable strict parsing, if the input date/format is invalid, it will throw an exception. - // e.g. to parse invalid date '2016-13-12', or '2016-01-12' with invalid format 'yyyy-aa-dd', - // an exception will be throwed. - sdf.setLenient(false) - sdf - } - // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisUtc: Long): SQLDate = { millisToDays(millisUtc, defaultTimeZone()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index eb1303303463d..b67b2d7cc3c51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.util +import java.text.ParseException import java.time._ +import java.time.format.DateTimeParseException import java.time.temporal.TemporalQueries import java.util.{Locale, TimeZone} @@ -27,7 +29,19 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait TimestampFormatter { +sealed trait TimestampFormatter extends Serializable { + /** + * Parses a timestamp in a string and converts it to microseconds. + * + * @param s - string with timestamp to parse + * @return microseconds since epoch. 
+ * @throws ParseException can be thrown by legacy parser + * @throws DateTimeParseException can be thrown by new parser + * @throws DateTimeException unable to obtain local date or time + */ + @throws(classOf[ParseException]) + @throws(classOf[DateTimeParseException]) + @throws(classOf[DateTimeException]) def parse(s: String): Long // returns microseconds since epoch def format(us: Long): String } @@ -36,7 +50,8 @@ class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter with DateTimeFormatterHelper { - private val formatter = buildFormatter(pattern, locale) + @transient + private lazy val formatter = buildFormatter(pattern, locale) private def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) @@ -68,7 +83,8 @@ class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter { - private val format = FastDateFormat.getInstance(pattern, timeZone, locale) + @transient + private lazy val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: String): Long = format.parse(s).getTime diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 2d0b0d3033a9c..4ae61bc61255c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -112,7 +112,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { assert(parser.makeConverter("_1", BooleanType).apply("true") == true) var timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy HH:mm"), false, "GMT") parser = new UnivocityParser(StructType(Seq.empty), timestampsOptions) val customTimestamp = "31/01/2015 00:00" var format = FastDateFormat.getInstance( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala index 019615b81101c..2dc55e0e1f633 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.util +import java.time.LocalDate import java.util.Locale import org.apache.spark.SparkFunSuite @@ -89,4 +90,10 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } } } + + test("parsing date without explicit day") { + val formatter = DateFormatter("yyyy MMM", Locale.US) + val daysSinceEpoch = formatter.parse("2018 Dec") + assert(daysSinceEpoch === LocalDate.of(2018, 12, 1).toEpochDay) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index c110ffa01f733..edccbb2a7f5db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.util +import java.time.{LocalDateTime, ZoneOffset} import java.util.{Locale, TimeZone} +import java.util.concurrent.TimeUnit import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.catalyst.plans.SQLHelper @@ -106,4 +108,14 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } } + + test(" case insensitive parsing of am and pm") { + val formatter = TimestampFormatter( + "yyyy MMM dd hh:mm:ss a", + TimeZone.getTimeZone("UTC"), + Locale.US) + val micros = formatter.parse("2009 Mar 20 11:30:01 am") + assert(micros === TimeUnit.SECONDS.toMicros( + LocalDateTime.of(2009, 3, 20, 11, 30, 1).toEpochSecond(ZoneOffset.UTC))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 33186f778d868..645452553e6a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2578,7 +2578,7 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2811,7 +2811,7 @@ object functions { * representing the timestamp of that moment in the current system time zone in the given * format. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param ut A number of a type that is castable to a long, such as string or integer. Can be * negative for timestamps before the unix epoch @@ -2855,7 +2855,7 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param s A date, timestamp or string. If a string, the data must be in a format that can be * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2883,7 +2883,7 @@ object functions { /** * Converts time string with the given pattern to timestamp. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param s A date, timestamp or string. If a string, the data must be in a format that can be * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2908,7 +2908,7 @@ object functions { /** * Converts the column into a `DateType` with a specified format * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param e A date, timestamp or string. 
If a string, the data must be in a format that can be * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index c4ec7150c4075..62bb72dd6ea25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -405,7 +405,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Date.valueOf("2014-12-31")))) checkAnswer( df.select(to_date(col("s"), "yyyy-MM-dd")), - Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) + Seq(Row(null), Row(Date.valueOf("2014-12-31")), Row(null))) // now switch format checkAnswer( From 46913ceb97b1dd599e6c8d025543de3a684f4bbe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 27 Dec 2018 16:03:14 +0800 Subject: [PATCH 128/194] [SPARK-26435][SQL] Support creating partitioned table using Hive CTAS by specifying partition column names ## What changes were proposed in this pull request? Spark SQL doesn't support creating partitioned table using Hive CTAS in SQL syntax. However it is supported by using DataFrameWriter API. ```scala val df = Seq(("a", 1)).toDF("part", "id") df.write.format("hive").partitionBy("part").saveAsTable("t") ``` Hive begins to support this syntax in newer version: https://issues.apache.org/jira/browse/HIVE-20241: ``` CREATE TABLE t PARTITIONED BY (part) AS SELECT 1 as id, "a" as part ``` This patch adds this support to SQL syntax. ## How was this patch tested? Added tests. Closes #23376 from viirya/hive-ctas-partitioned-table. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/execution/SparkSqlParser.scala | 33 ++++++++----- .../sql/hive/execution/HiveDDLSuite.scala | 48 ++++++++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5e732edb17baa..b39681d886c5c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -88,7 +88,8 @@ statement (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? ((COMMENT comment=STRING) | - (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + (PARTITIONED BY '(' partitionColumns=colTypeList ')' | + PARTITIONED BY partitionColumnNames=identifierList) | bucketSpec | skewSpec | rowFormat | diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 364efea52830e..8deb55b00a9d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1196,33 +1196,40 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { selectQuery match { case Some(q) => - // Hive does not allow to use a CTAS statement to create a partitioned table. - if (tableDesc.partitionColumnNames.nonEmpty) { - val errorMessage = "A Create Table As Select (CTAS) statement is not allowed to " + - "create a partitioned table using Hive's file formats. 
" + - "Please use the syntax of \"CREATE TABLE tableName USING dataSource " + - "OPTIONS (...) PARTITIONED BY ...\" to create a partitioned table through a " + - "CTAS statement." - operationNotAllowed(errorMessage, ctx) - } - // Don't allow explicit specification of schema for CTAS. - if (schema.nonEmpty) { + if (dataCols.nonEmpty) { operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", ctx) } + // When creating partitioned table with CTAS statement, we can't specify data type for the + // partition columns. + if (partitionCols.nonEmpty) { + val errorMessage = "Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target table." + operationNotAllowed(errorMessage, ctx) + } + + // Hive CTAS supports dynamic partition by specifying partition column names. + val partitionColumnNames = + Option(ctx.partitionColumnNames) + .map(visitIdentifierList(_).toArray) + .getOrElse(Array.empty[String]) + + val tableDescWithPartitionColNames = + tableDesc.copy(partitionColumnNames = partitionColumnNames) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. - val newTableDesc = tableDesc.copy( + val newTableDesc = tableDescWithPartitionColNames.copy( storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { - CreateTable(tableDesc, mode, Some(q)) + CreateTable(tableDescWithPartitionColNames, mode, Some(q)) } case None => CreateTable(tableDesc, mode, None) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index fd38944a5dd2e..6abdc4054cb0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI -import java.util.Date import scala.language.existentials @@ -33,6 +32,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog @@ -2370,4 +2370,50 @@ class HiveDDLSuite )) } } + + test("Hive CTAS can't create partitioned table by specifying schema") { + val err1 = intercept[ParseException] { + spark.sql( + s""" + |CREATE TABLE t (a int) + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b + """.stripMargin) + }.getMessage + assert(err1.contains("Schema may not be specified in a Create Table As Select " + + "(CTAS) statement")) + + val err2 = intercept[ParseException] { + spark.sql( + s""" + |CREATE TABLE t + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b + """.stripMargin) + }.getMessage + assert(err2.contains("Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target 
table")) + } + + test("Hive CTAS with dynamic partition") { + Seq("orc", "parquet").foreach { format => + withTable("t") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + spark.sql( + s""" + |CREATE TABLE t + |PARTITIONED BY (b) + |STORED AS $format + |AS SELECT 1 as a, "a" as b + """.stripMargin) + checkAnswer(spark.table("t"), Row(1, "a")) + + assert(spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + .partitionColumnNames === Seq("b")) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6acf44606cbbe..70efad103d13e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -692,8 +692,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |AS SELECT key, value FROM mytable1 """.stripMargin) }.getMessage - assert(e.contains("A Create Table As Select (CTAS) statement is not allowed to " + - "create a partitioned table using Hive's file formats")) + assert(e.contains("Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target table")) } } } From a72b963df91e72ba316452fa5dff180e46cd4885 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 27 Dec 2018 11:13:16 +0100 Subject: [PATCH 129/194] [SPARK-26191][SQL] Control truncation of Spark plans via maxFields parameter ## What changes were proposed in this pull request? In the PR, I propose to add `maxFields` parameter to all functions involved in creation of textual representation of spark plans such as `simpleString` and `verboseString`. New parameter restricts number of fields converted to truncated strings. Any elements beyond the limit will be dropped and replaced by a `"... N more fields"` placeholder. The threshold is bumped up to `Int.MaxValue` for `toFile()`. ## How was this patch tested? Added a test to `QueryExecutionSuite` which checks `maxFields` impacts on number of truncated fields in `LocalRelation`. Closes #23159 from MaxGekk/to-file-max-fields. 
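As an illustrative aside, not part of the patch: a minimal standalone sketch of the truncation behaviour the new `maxFields` parameter controls. The helper name below is hypothetical; Spark's own `truncatedString` follows the same "... N more fields" pattern.

```scala
// Hypothetical helper: render at most `maxFields` entries and collapse the rest
// into a "... N more fields" placeholder.
def truncated(items: Seq[String], maxFields: Int): String = {
  if (items.length <= maxFields) {
    items.mkString("[", ", ", "]")
  } else {
    val kept = items.take(math.max(0, maxFields - 1))
    val dropped = items.length - kept.length
    (kept :+ s"... $dropped more fields").mkString("[", ", ", "]")
  }
}

// truncated((1 to 10).map(i => s"c$i"), maxFields = 3)  // "[c1, c2, ... 8 more fields]"
```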
Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../catalyst/encoders/ExpressionEncoder.scala | 8 ++- .../sql/catalyst/expressions/Expression.scala | 6 +- .../expressions/codegen/javaCode.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../expressions/higherOrderFunctions.scala | 4 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/namedExpressions.scala | 4 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 8 +-- .../spark/sql/catalyst/trees/TreeNode.scala | 56 +++++++++++-------- .../spark/sql/catalyst/util/package.scala | 10 ++-- .../apache/spark/sql/types/StructType.scala | 6 +- .../aggregate/PercentileSuite.scala | 2 +- .../org/apache/spark/sql/util/UtilSuite.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 12 ++-- .../spark/sql/execution/ExistingRDD.scala | 6 +- .../spark/sql/execution/QueryExecution.scala | 21 ++++--- .../spark/sql/execution/SparkPlanInfo.scala | 6 +- .../sql/execution/WholeStageCodegenExec.scala | 28 ++++++++-- .../aggregate/HashAggregateExec.scala | 12 ++-- .../aggregate/ObjectHashAggregateExec.scala | 12 ++-- .../aggregate/SortAggregateExec.scala | 12 ++-- .../execution/basicPhysicalOperators.scala | 4 +- .../execution/columnar/InMemoryRelation.scala | 4 +- .../datasources/LogicalRelation.scala | 4 +- .../SaveIntoDataSourceCommand.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 8 ++- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +- .../v2/DataSourceV2StreamingScanExec.scala | 2 +- .../v2/DataSourceV2StringFormat.scala | 6 +- .../spark/sql/execution/debug/package.scala | 3 +- .../apache/spark/sql/execution/limit.scala | 6 +- .../streaming/MicroBatchExecution.scala | 6 +- .../continuous/ContinuousExecution.scala | 7 ++- .../sql/execution/streaming/memory.scala | 5 +- .../apache/spark/sql/execution/subquery.scala | 2 +- .../sql/execution/QueryExecutionSuite.scala | 13 +++++ .../CreateHiveTableAsSelectCommand.scala | 2 +- 43 files changed, 203 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 777053168a056..198645d875c47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -979,7 +979,7 @@ class Analyzer( a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) case q: LogicalPlan => - logTrace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) } @@ -1777,7 +1777,7 @@ class Analyzer( case p if p.expressions.exists(hasGenerator) => throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + - "got: " + p.simpleString) + "got: " + p.simpleString(SQLConf.get.maxToStringFields)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 88d41e8824405..c28a97839fe49 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -303,7 +304,7 @@ trait CheckAnalysis extends PredicateHelper { val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + - s"from $input in operator ${operator.simpleString}." + s"from $input in operator ${operator.simpleString(SQLConf.get.maxToStringFields)}." val resolver = plan.conf.resolver val attrsWithSameName = o.missingInput.filter { missing => @@ -368,7 +369,7 @@ trait CheckAnalysis extends PredicateHelper { s"""nondeterministic expressions are only allowed in |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.sql).mkString(",")} - |in operator ${operator.simpleString} + |in operator ${operator.simpleString(SQLConf.get.maxToStringFields)} """.stripMargin) case _: UnresolvedHint => @@ -380,7 +381,8 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") + case o if !o.resolved => + failAnalysis(s"unresolved operator ${o.simpleString(SQLConf.get.maxToStringFields)}") case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1706b3eece6d7..b19aa50ba2156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1069,8 +1069,8 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { // Leave the same if the dataTypes match. 
case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug( - s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}") + logDebug(s"Promoting $a from ${a.dataType} to ${newType.dataType} in " + + s" ${q.simpleString(SQLConf.get.maxToStringFields)}") newType } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index fbf0bd68b9584..da5c1fd0feb01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -323,8 +324,8 @@ case class ExpressionEncoder[T]( extractProjection(inputRow) } catch { case e: Exception => - throw new RuntimeException( - s"Error while encoding: $e\n${serializer.map(_.simpleString).mkString("\n")}", e) + throw new RuntimeException(s"Error while encoding: $e\n" + + s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e) } /** @@ -336,7 +337,8 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${deserializer.simpleString}", e) + throw new RuntimeException(s"Error while decoding: $e\n" + + s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c89c2272be752..d5d119543da77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -259,12 +259,12 @@ abstract class Expression extends TreeNode[Expression] { // Marks this as final, Expression.verboseString should never be called, and thus shouldn't be // overridden by concrete classes. - final override def verboseString: String = simpleString + final override def verboseString(maxFields: Int): String = simpleString(maxFields) - override def simpleString: String = toString + override def simpleString(maxFields: Int): String = toString override def toString: String = prettyName + truncatedString( - flatArguments.toSeq, "(", ", ", ")") + flatArguments.toSeq, "(", ", ", ")", SQLConf.get.maxToStringFields) /** * Returns SQL representation of this expression. 
For expressions extending [[NonSQLExpression]], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 17d4a0dc4e884..17fff64a1b7df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -197,7 +197,7 @@ trait Block extends TreeNode[Block] with JavaCode { case _ => code"$this\n$other" } - override def verboseString: String = toString + override def verboseString(maxFields: Int): String = toString } object Block { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9c74fdf6c9a14..6b6da1c8b4142 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -101,7 +102,7 @@ case class UserDefinedGenerator( inputRow = new InterpretedProjection(children) convertToScala = { val inputSchema = StructType(children.map { e => - StructField(e.simpleString, e.dataType, nullable = true) + StructField(e.simpleString(SQLConf.get.maxToStringFields), e.dataType, nullable = true) }) CatalystTypeConverters.createToScalaConverter(inputSchema) }.asInstanceOf[InternalRow => Row] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7141b6e996389..e6cc11d1ad280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -76,7 +76,9 @@ case class NamedLambdaVariable( override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" - override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}" + override def simpleString(maxFields: Int): String = { + s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0cdeda9b10516..1f1decc45a3f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -40,7 +41,7 @@ case class PrintToStderr(child: Expression) 
extends UnaryExpression { input } - private val outputPrefix = s"Result of ${child.simpleString} is " + private val outputPrefix = s"Result of ${child.simpleString(SQLConf.get.maxToStringFields)} is " override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix) @@ -72,7 +73,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def prettyName: String = "assert_true" - private val errMsg = s"'${child.simpleString}' is not true!" + private val errMsg = s"'${child.simpleString(SQLConf.get.maxToStringFields)}' is not true!" override def eval(input: InternalRow) : Any = { val v = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 02b48f9e30f2d..131459bf27bc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -330,7 +330,9 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. - override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + override def simpleString(maxFields: Int): String = { + s"$name#${exprId.id}: ${dataType.simpleString(maxFields)}" + } override def sql: String = { val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ca0cea6ba7de3..125181fb213f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -172,9 +172,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString: String = statePrefix + super.simpleString + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - override def verboseString: String = simpleString + override def verboseString(maxFields: Int): String = simpleString(maxFields) /** * All the subqueries of current plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 3ad2ee6923615..51e0f4b4c84dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -36,8 +36,8 @@ abstract class LogicalPlan /** Returns true if this subtree has data from a streaming data source. 
*/ def isStreaming: Boolean = children.exists(_.isStreaming == true) - override def verboseStringWithSuffix: String = { - super.verboseString + statsCache.map(", " + _.toString).getOrElse("") + override def verboseStringWithSuffix(maxFields: Int): String = { + super.verboseString(maxFields) + statsCache.map(", " + _.toString).getOrElse("") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a26ec4eed8648..d8b3a4af4f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -468,7 +468,7 @@ case class View( override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"View (${desc.identifier}, ${output.mkString("[", ",", "]")})" } } @@ -484,8 +484,8 @@ case class View( case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def simpleString: String = { - val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]") + override def simpleString(maxFields: Int): String = { + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields) s"CTE $cteAliases" } @@ -557,7 +557,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"Range ($start, $end, step=$step, splits=$numSlices)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2e9f9f53e94ac..21e59bbd283e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -433,17 +434,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]] /** Returns a string representing the arguments to this node, minus any children */ - def argString: String = stringArgs.flatMap { + def argString(maxFields: Int): String = stringArgs.flatMap { case tn: TreeNode[_] if allChildren.contains(tn) => Nil case Some(tn: TreeNode[_]) if allChildren.contains(tn) => Nil - case Some(tn: TreeNode[_]) => tn.simpleString :: Nil - case tn: TreeNode[_] => tn.simpleString :: Nil + case Some(tn: TreeNode[_]) => tn.simpleString(maxFields) :: Nil + case tn: TreeNode[_] => tn.simpleString(maxFields) :: Nil case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil case iter: Iterable[_] if iter.isEmpty => Nil - case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil - case set: 
Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]", maxFields) :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}", maxFields) :: Nil case array: Array[_] if array.isEmpty => Nil - case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]", maxFields) :: Nil case null => Nil case None => Nil case Some(null) => Nil @@ -456,24 +457,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other :: Nil }.mkString(", ") - /** ONE line description of this node. */ - def simpleString: String = s"$nodeName $argString".trim + /** + * ONE line description of this node. + * @param maxFields Maximum number of fields that will be converted to strings. + * Any elements beyond the limit will be dropped. + */ + def simpleString(maxFields: Int): String = { + s"$nodeName ${argString(maxFields)}".trim + } /** ONE line description of this node with more information */ - def verboseString: String + def verboseString(maxFields: Int): String /** ONE line description of this node with some suffix information */ - def verboseStringWithSuffix: String = verboseString + def verboseStringWithSuffix(maxFields: Int): String = verboseString(maxFields) override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ def treeString: String = treeString(verbose = true) - def treeString(verbose: Boolean, addSuffix: Boolean = false): String = { + def treeString( + verbose: Boolean, + addSuffix: Boolean = false, + maxFields: Int = SQLConf.get.maxToStringFields): String = { val writer = new StringBuilderWriter() try { - treeString(writer, verbose, addSuffix) + treeString(writer, verbose, addSuffix, maxFields) writer.toString } finally { writer.close() @@ -483,8 +493,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def treeString( writer: Writer, verbose: Boolean, - addSuffix: Boolean): Unit = { - generateTreeString(0, Nil, writer, verbose, "", addSuffix) + addSuffix: Boolean, + maxFields: Int): Unit = { + generateTreeString(0, Nil, writer, verbose, "", addSuffix, maxFields) } /** @@ -550,7 +561,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { + addSuffix: Boolean = false, + maxFields: Int): Unit = { if (depth > 0) { lastChildren.init.foreach { isLast => @@ -560,9 +572,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } val str = if (verbose) { - if (addSuffix) verboseStringWithSuffix else verboseString + if (addSuffix) verboseStringWithSuffix(maxFields) else verboseString(maxFields) } else { - simpleString + simpleString(maxFields) } writer.write(prefix) writer.write(str) @@ -571,17 +583,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose, - addSuffix = addSuffix)) + addSuffix = addSuffix, maxFields = maxFields)) innerChildren.last.generateTreeString( depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose, - addSuffix = addSuffix) + addSuffix = addSuffix, maxFields = maxFields) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix)) + 
depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix, maxFields)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix) + depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix, maxFields) } } @@ -664,7 +676,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => JArray(t.map(parseToJson).toList) case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => - JString(truncatedString(t, "[", ", ", "]")) + JString(truncatedString(t, "[", ", ", "]", SQLConf.get.maxToStringFields)) case t: Seq[_] => JNull case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 277584b20dcd2..7f5860e12cfd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -184,14 +184,14 @@ package object util extends Logging { start: String, sep: String, end: String, - maxNumFields: Int = SQLConf.get.maxToStringFields): String = { - if (seq.length > maxNumFields) { + maxFields: Int): String = { + if (seq.length > maxFields) { if (truncationWarningPrinted.compareAndSet(false, true)) { logWarning( "Truncated the string representation of a plan since it was too large. This " + s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") } - val numFields = math.max(0, maxNumFields - 1) + val numFields = math.max(0, maxFields - 1) seq.take(numFields).mkString( start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) } else { @@ -200,7 +200,9 @@ package object util extends Logging { } /** Shorthand for calling truncatedString() without start or end strings. 
*/ - def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") + def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = { + truncatedString(seq, "", sep, "", maxFields) + } /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e01d7c59cac52..d563276a5711d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} +import org.apache.spark.sql.internal.SQLConf /** * A [[StructType]] object can be constructed by @@ -343,7 +344,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") - truncatedString(fieldTypes, "struct<", ",", ">") + truncatedString( + fieldTypes, + "struct<", ",", ">", + SQLConf.get.maxToStringFields) } override def catalogString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 63c7b42978025..0e0c8e167a0a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -215,7 +215,7 @@ class PercentileSuite extends SparkFunSuite { val percentile2 = new Percentile(child, percentage) assertEqual(percentile2.checkInputDataTypes(), TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + - s"but got ${percentage.simpleString}")) + s"but got ${percentage.simpleString(100)}")) } val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala index 9c162026942f6..d95de71e897a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala @@ -26,6 +26,6 @@ class UtilSuite extends SparkFunSuite { assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 
3 more fields]") - assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + assert(truncatedString(Seq(1, 2, 3), ", ", 10) == "1, 2, 3") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 322ffffca564b..1d7dd73706c48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -52,19 +52,19 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { // Metadata that describes more details of this scan. protected def metadata: Map[String, String] - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { val metadataEntries = metadata.toSeq.sorted.map { case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) } - val metadataStr = truncatedString(metadataEntries, " ", ", ", "") - s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr" + val metadataStr = truncatedString(metadataEntries, " ", ", ", "", maxFields) + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]", maxFields)}$metadataStr" } - override def verboseString: String = redact(super.verboseString) + override def verboseString(maxFields: Int): String = redact(super.verboseString(maxFields)) - override def treeString(verbose: Boolean, addSuffix: Boolean): String = { - redact(super.treeString(verbose, addSuffix)) + override def treeString(verbose: Boolean, addSuffix: Boolean, maxFields: Int): String = { + redact(super.treeString(verbose, addSuffix, maxFields)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 49fb288fdea6a..981ecae80a724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -79,7 +79,7 @@ case class ExternalRDDScanExec[T]( } } - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"$nodeName${output.mkString("[", ",", "]")}" } } @@ -156,8 +156,8 @@ case class RDDScanExec( } } - override def simpleString: String = { - s"$nodeName${truncatedString(output, "[", ",", "]")}" + override def simpleString(maxFields: Int): String = { + s"$nodeName${truncatedString(output, "[", ",", "]", maxFields)}" } // Input can be InternalRow, has to be turned into UnsafeRows. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index eef5a3f899f55..9b8d2e830867d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -208,27 +209,27 @@ class QueryExecution( } } - private def writePlans(writer: Writer): Unit = { + private def writePlans(writer: Writer, maxFields: Int): Unit = { val (verbose, addSuffix) = (true, false) writer.write("== Parsed Logical Plan ==\n") - writeOrError(writer)(logical.treeString(_, verbose, addSuffix)) + writeOrError(writer)(logical.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Analyzed Logical Plan ==\n") val analyzedOutput = stringOrError(truncatedString( - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")) + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields)) writer.write(analyzedOutput) writer.write("\n") - writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix)) + writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Optimized Logical Plan ==\n") - writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix)) + writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Physical Plan ==\n") - writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix)) + writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix, maxFields)) } override def toString: String = withRedaction { val writer = new StringBuilderWriter() try { - writePlans(writer) + writePlans(writer, SQLConf.get.maxToStringFields) writer.toString } finally { writer.close() @@ -280,14 +281,16 @@ class QueryExecution( /** * Dumps debug information about query execution into the specified file. + * + * @param maxFields maximim number of fields converted to string representation. 
*/ - def toFile(path: String): Unit = { + def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = { val filePath = new Path(path) val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) try { - writePlans(writer) + writePlans(writer, maxFields) writer.write("\n== Whole Stage Codegen ==\n") org.apache.spark.sql.execution.debug.writeCodegen(writer, executedPlan) } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 59ffd16381116..f554ff0aa775f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.sql.internal.SQLConf /** * :: DeveloperApi :: @@ -62,7 +63,10 @@ private[execution] object SparkPlanInfo { case fileScan: FileSourceScanExec => fileScan.metadata case _ => Map[String, String]() } - new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + new SparkPlanInfo( + plan.nodeName, + plan.simpleString(SQLConf.get.maxToStringFields), + children.map(fromSparkPlan), metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fbda0d87a175f..f4927dedabe56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -87,7 +87,7 @@ trait CodegenSupport extends SparkPlan { this.parent = parent ctx.freshNamePrefix = variablePrefix s""" - |${ctx.registerComment(s"PRODUCE: ${this.simpleString}")} + |${ctx.registerComment(s"PRODUCE: ${this.simpleString(SQLConf.get.maxToStringFields)}")} |${doProduce(ctx)} """.stripMargin } @@ -188,7 +188,7 @@ trait CodegenSupport extends SparkPlan { parent.doConsume(ctx, inputVars, rowVar) } s""" - |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} + |${ctx.registerComment(s"CONSUME: ${parent.simpleString(SQLConf.get.maxToStringFields)}")} |$evaluated |$consumeFunc """.stripMargin @@ -496,8 +496,16 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { - child.generateTreeString(depth, lastChildren, writer, verbose, prefix = "", addSuffix = false) + addSuffix: Boolean = false, + maxFields: Int): Unit = { + child.generateTreeString( + depth, + lastChildren, + writer, + verbose, + prefix = "", + addSuffix = false, + maxFields) } override def needCopyResult: Boolean = false @@ -772,8 +780,16 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { - child.generateTreeString(depth, lastChildren, writer, verbose, s"*($codegenStageId) ", false) + addSuffix: Boolean = false, + maxFields: Int): Unit = { + child.generateTreeString( + depth, + lastChildren, + writer, + verbose, + s"*($codegenStageId) ", + false, + maxFields) } override def 
needStopCheck: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4827f838fc514..2355d305c38e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -922,18 +922,18 @@ case class HashAggregateExec( """ } - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 7145bb03028d9..bd52c6321647a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -137,15 +137,15 @@ case class ObjectHashAggregateExec( } } - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index d732b905dcdd5..7ab6ecc08a7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -107,16 +107,16 @@ case class SortAggregateExec( } } - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 09effe087e195..2570b36b3166d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -586,7 +586,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" + override def simpleString(maxFields: Int): String = { + s"Range ($start, $end, step=$step, splits=$numSlices)" + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 73eb65f84489c..4109d9994dd8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -209,6 +209,6 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) - override def simpleString: String = - s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" + override def simpleString(maxFields: Int): String = + s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 1023572d19e2e..db3604fe92cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -63,7 +63,9 @@ case class LogicalRelation( case _ => // Do nothing. 
} - override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation" + override def simpleString(maxFields: Int): String = { + s"Relation[${truncatedString(output, ",", maxFields)}] $relation" + } } object LogicalRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 00b1b5dedb593..f29e7869fb27c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -48,7 +48,7 @@ case class SaveIntoDataSourceCommand( Seq.empty[Row] } - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fdc5e85f3c2ea..042320edea4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -68,7 +68,7 @@ case class CreateTempViewUsing( s"Temporary view '$tableIdent' should not have specified a database") } - override def argString: String = { + override def argString(maxFields: Int): String = { s"[tableIdent:$tableIdent " + userSpecifiedSchema.map(_ + " ").getOrElse("") + s"replace:$replace " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 0a6b0afe6cfe5..7bf2b8bff3732 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -52,8 +52,8 @@ case class DataSourceV2Relation( override def name: String = table.name() - override def simpleString: String = { - s"RelationV2${truncatedString(output, "[", ", ", "]")} $name" + override def simpleString(maxFields: Int): String = { + s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) @@ -96,7 +96,9 @@ case class StreamingDataSourceV2Relation( override def isStreaming: Boolean = true - override def simpleString: String = "Streaming RelationV2 " + metadataString + override def simpleString(maxFields: Int): String = { + "Streaming RelationV2 " + metadataString(maxFields) + } override def pushedFilters: Seq[Expression] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 725bcc3af3ca5..53e4e77c65e26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,8 +35,8 @@ case class DataSourceV2ScanExec( @transient batch: Batch) extends LeafExecNode with ColumnarBatchScan { - override def simpleString: String = { - 
s"ScanV2${truncatedString(output, "[", ", ", "]")} $scanDesc" + override def simpleString(maxFields: Int): String = { + s"ScanV2${truncatedString(output, "[", ", ", "]", maxFields)} $scanDesc" } // TODO: unify the equal/hashCode implementation for all data source v2 query plans. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala index c872940909964..be75fe4f596dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala @@ -42,7 +42,7 @@ case class DataSourceV2StreamingScanExec( @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def simpleString: String = "ScanV2 " + metadataString + override def simpleString(maxFields: Int): String = "ScanV2 " + metadataString(maxFields) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index e829f621b4ea3..f11703c8a2773 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -59,7 +59,7 @@ trait DataSourceV2StringFormat { case _ => Utils.getSimpleName(source.getClass) } - def metadataString: String = { + def metadataString(maxFields: Int): String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] if (pushedFilters.nonEmpty) { @@ -73,12 +73,12 @@ trait DataSourceV2StringFormat { }.mkString("[", ",", "]") } - val outputStr = truncatedString(output, "[", ", ", "]") + val outputStr = truncatedString(output, "[", ", ", "]", maxFields) val entriesStr = if (entries.nonEmpty) { truncatedString(entries.map { case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")") + }, " (", ", ", ")", maxFields) } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3511cefa7c292..ae8197f617a28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, Codegen import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} @@ -216,7 +217,7 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - debugPrint(s"== ${child.simpleString} ==") + debugPrint(s"== ${child.simpleString(SQLConf.get.maxToStringFields)} ==") 
debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => // This is called on driver. All accumulator updates have a fixed value. So it's safe to use diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index bfaf080292bce..56973af8fd648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -198,9 +198,9 @@ case class TakeOrderedAndProjectExec( override def outputPartitioning: Partitioning = SinglePartition - override def simpleString: String = { - val orderByString = truncatedString(sortOrder, "[", ",", "]") - val outputString = truncatedString(output, "[", ",", "]") + override def simpleString(maxFields: Int): String = { + val orderByString = truncatedString(sortOrder, "[", ",", "]", maxFields) + val outputString = truncatedString(output, "[", ",", "]", maxFields) s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 8ad436a4ff57d..38ecb0dd12daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -482,9 +483,10 @@ class MicroBatchExecution( val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => + val maxFields = SQLConf.get.maxToStringFields assert(output.size == dataPlan.output.size, - s"Invalid batch: ${truncatedString(output, ",")} != " + - s"${truncatedString(dataPlan.output, ",")}") + s"Invalid batch: ${truncatedString(output, ",", maxFields)} != " + + s"${truncatedString(dataPlan.output, ",", maxFields)}") val aliases = output.zip(dataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f0859aaaa3041..89033b70f1431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import 
org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} @@ -166,10 +167,10 @@ class ContinuousExecution( val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = readSupport.fullSchema().toAttributes - + val maxFields = SQLConf.get.maxToStringFields assert(output.size == newOutput.size, - s"Invalid reader: ${truncatedString(output, ",")} != " + - s"${truncatedString(newOutput, ",")}") + s"Invalid reader: ${truncatedString(output, ",", maxFields)} != " + + s"${truncatedString(newOutput, ",", maxFields)}") replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daee089f3871d..13b75ae4a4339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -117,7 +118,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]" + override def toString: String = { + s"MemoryStream[${truncatedString(output, ",", SQLConf.get.maxToStringFields)}]" + } override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 310ebcdf67686..e180d2228c3b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -51,7 +51,7 @@ case class ScalarSubquery( override def dataType: DataType = plan.schema.fields.head.dataType override def children: Seq[Expression] = Nil override def nullable: Boolean = true - override def toString: String = plan.simpleString + override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query) override def semanticEquals(other: Expression): Boolean = other match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 0c47a2040f171..3cc97c995702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -106,6 +106,19 @@ class QueryExecutionSuite extends SharedSQLContext { } } + test("check maximum fields restriction") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val ds = spark.createDataset(Seq(QueryExecutionTestRecord( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) + ds.queryExecution.debug.toFile(path) + val localRelations = Source.fromFile(path).getLines().filter(_.contains("LocalRelation")) + + assert(!localRelations.exists(_.contains("more fields"))) + } + } + test("toString() exception/error handling") { spark.experimental.extraStrategies = Seq( new SparkStrategy { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 608f21e726259..7249eacfbf9a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -83,7 +83,7 @@ trait CreateHiveTableAsSelectBase extends DataWritingCommand { tableDesc: CatalogTable, tableExists: Boolean): DataWritingCommand - override def argString: String = { + override def argString(maxFields: Int): String = { s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" From b3032c985fec90e97c9c53851948fb47109520df Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 27 Dec 2018 22:26:37 +0800 Subject: [PATCH 130/194] [SPARK-25892][SQL] Change AttributeReference.withMetadata's return type to AttributeReference ## What changes were proposed in this pull request? Currently the `AttributeReference.withMetadata` method have return type `Attribute`, the rest of with methods in the `AttributeReference` return type are `AttributeReference`, as the [spark-25892](https://issues.apache.org/jira/browse/SPARK-25892?jql=project%20%3D%20SPARK%20AND%20component%20in%20(ML%2C%20PySpark%2C%20SQL)) mentioned. This PR will change `AttributeReference.withMetadata` method's return type from `Attribute` to `AttributeReference`. ## How was this patch tested? Run all `sql/test,` `catalyst/test` and `org.apache.spark.sql.execution.streaming.*` Closes #22918 from kevinyu98/spark-25892. 
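To make the motivation concrete, here is a hedged toy sketch (illustrative names only, `Attr` and `AttrRef` are not Spark classes) of why the concrete return type matters: if one `with*` method returns the base trait while the others return the concrete class, a caller loses the concrete type at that step and needs a cast to keep chaining or to assign to a concretely typed val.

```scala
// Toy sketch, illustrative names only.
trait Attr { def name: String }

case class AttrRef(name: String, meta: Map[String, String] = Map.empty) extends Attr {
  def withName(n: String): AttrRef = copy(name = n)
  // Declaring this as "def withMetadata(...): Attr" (the pre-patch shape) would force
  // callers below to cast; returning the concrete type keeps the chain well-typed.
  def withMetadata(m: Map[String, String]): AttrRef = copy(meta = m)
}

object ReturnTypeSketch {
  def main(args: Array[String]): Unit = {
    // Compiles only because withMetadata returns AttrRef, matching the other with* methods.
    val renamed: AttrRef = AttrRef("id").withMetadata(Map("comment" -> "pk")).withName("user_id")
    println(renamed)
  }
}
```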
Authored-by: Kevin Yu Signed-off-by: Hyukjin Kwon --- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 131459bf27bc8..7ebb171f34ba2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -311,7 +311,7 @@ case class AttributeReference( } } - override def withMetadata(newMetadata: Metadata): Attribute = { + override def withMetadata(newMetadata: Metadata): AttributeReference = { AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } From e6d6eafcb6ea840b3db7856074f5ce45e136f50d Mon Sep 17 00:00:00 2001 From: deepyaman Date: Fri, 28 Dec 2018 00:02:41 +0800 Subject: [PATCH 131/194] [SPARK-26451][SQL] Change lead/lag argument name from count to offset ## What changes were proposed in this pull request? Change aligns argument name with that in Scala version and documentation. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23357 from deepyaman/patch-1. Authored-by: deepyaman Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d2a771e9bb8ea..3c33e2bed92d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -798,7 +798,7 @@ def factorial(col): # --------------- Window functions ------------------------ @since(1.4) -def lag(col, count=1, default=None): +def lag(col, offset=1, default=None): """ Window function: returns the value that is `offset` rows before the current row, and `defaultValue` if there is less than `offset` rows before the current row. For example, @@ -807,15 +807,15 @@ def lag(col, count=1, default=None): This is equivalent to the LAG function in SQL. :param col: name of column or expression - :param count: number of row to extend + :param offset: number of row to extend :param default: default value """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + return Column(sc._jvm.functions.lag(_to_java_column(col), offset, default)) @since(1.4) -def lead(col, count=1, default=None): +def lead(col, offset=1, default=None): """ Window function: returns the value that is `offset` rows after the current row, and `defaultValue` if there is less than `offset` rows after the current row. For example, @@ -824,11 +824,11 @@ def lead(col, count=1, default=None): This is equivalent to the LEAD function in SQL. 
:param col: name of column or expression - :param count: number of row to extend + :param offset: number of row to extend :param default: default value """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + return Column(sc._jvm.functions.lead(_to_java_column(col), offset, default)) @since(1.4) From 11f1c8d4d6f961affb4f5ba5d272ce7ee560889d Mon Sep 17 00:00:00 2001 From: wuqingxin Date: Fri, 28 Dec 2018 00:15:57 -0800 Subject: [PATCH 132/194] [SPARK-26446][CORE] Add cachedExecutorIdleTimeout docs at ExecutorAllocationManager ## What changes were proposed in this pull request? Add docs to describe how remove policy act while considering the property `spark.dynamicAllocation.cachedExecutorIdleTimeout` in ExecutorAllocationManager ## How was this patch tested? comment-only PR. Closes #23386 from TopGunViper/SPARK-26446. Authored-by: wuqingxin Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/ExecutorAllocationManager.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c3e5b96a55884..3f0b71bbe17f1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -57,7 +57,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * a long time to ramp up under heavy workloads. * * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not - * been scheduled to run any tasks, then it is removed. + * been scheduled to run any tasks, then it is removed. Note that an executor caching any data + * blocks will be removed if it has been idle for more than L seconds. * * There is no retry logic in either case because we make the assumption that the cluster manager * will eventually fulfill all requests it receives asynchronously. @@ -81,7 +82,12 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * This is used only after the initial backlog timeout is exceeded * * spark.dynamicAllocation.executorIdleTimeout (K) - - * If an executor has been idle for this duration, remove it + * If an executor without caching any data blocks has been idle for this duration, remove it + * + * spark.dynamicAllocation.cachedExecutorIdleTimeout (L) - + * If an executor with caching data blocks has been idle for more than this duration, + * the executor will be removed + * */ private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, From fcbec315624bd73e34577f442c3a16d4a1d27115 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 28 Dec 2018 07:40:59 -0600 Subject: [PATCH 133/194] [SPARK-26444][WEBUI] Stage color doesn't change with it's status ## What changes were proposed in this pull request? On job page, in event timeline section, stage color doesn't change according to its status. Below are some screenshots. ACTIVE: active COMPLETE: complete FAILED: failed This PR lets stage color change with it's status. The main idea is to make css style class name match the corresponding stage status. ## How was this patch tested? Manually tested locally. 
``` // active/complete stage sc.parallelize(1 to 3, 3).map { n => Thread.sleep(10* 1000); n }.count // failed stage sc.parallelize(1 to 3, 3).map { n => Thread.sleep(10* 1000); throw new Exception() }.count ``` Note we need to clear browser cache to let new `timeline-view.css` take effect. Below are screenshots after this PR. ACTIVE: active-after COMPLETE: complete-after FAILED: failed-after Closes #23385 from seancxmao/timeline-stage-color. Authored-by: seancxmao Signed-off-by: Sean Owen --- .../org/apache/spark/ui/static/timeline-view.css | 8 ++++---- .../src/main/scala/org/apache/spark/ui/jobs/JobPage.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index 3bf3e8bfa1f31..10bceae2fbdda 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -98,12 +98,12 @@ rect.getting-result-time-proportion { cursor: pointer; } -.vis-timeline .vis-item.stage.succeeded { +.vis-timeline .vis-item.stage.complete { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis-timeline .vis-item.stage.succeeded.vis-selected { +.vis-timeline .vis-item.stage.complete.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -130,12 +130,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis-timeline .vis-item.stage.running { +.vis-timeline .vis-item.stage.active { background-color: #A2FCC0; border-color: #36F572; } -.vis-timeline .vis-item.stage.running.vis-selected { +.vis-timeline .vis-item.stage.active.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: auto; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index b58a6ca447edf..cd82439223b07 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -62,7 +62,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stageId = stage.stageId val attemptId = stage.attemptId val name = stage.name - val status = stage.status.toString + val status = stage.status.toString.toLowerCase(Locale.ROOT) val submissionTime = stage.submissionTime.get.getTime() val completionTime = stage.completionTime.map(_.getTime()) .getOrElse(System.currentTimeMillis()) From 1bb70d9a133bd59bf134fe30ca85b3d3c3eb5657 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 28 Dec 2018 11:29:06 -0800 Subject: [PATCH 134/194] [SPARK-26424][SQL][FOLLOWUP] Fix DateFormatClass/UnixTime codegen ## What changes were proposed in this pull request? This PR fixes the codegen bug introduced by #23358 . - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.11/158/ ``` Line 44, Column 93: A method named "apply" is not declared in any enclosing class nor any supertype, nor through a static import ``` ## How was this patch tested? Manual. `DateExpressionsSuite` should be passed with Scala-2.11. Closes #23394 from dongjoon-hyun/SPARK-26424. 
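For background, a small hedged sketch of the naming convention involved (the object name below is made up). A Scala `object Foo` compiles to a class `Foo$` whose singleton instance is the static field `MODULE$`, so generated Java code can always reach the object's methods through `Foo$.MODULE$`; calling a static `Foo.apply(...)` instead relies on scalac emitting a static forwarder, which appears to be what failed to resolve in the Scala 2.11 build quoted above.

```scala
// Sketch only; Formatter stands in for the real companion object.
object Formatter {
  def apply(pattern: String): String = s"formatter($pattern)"
}
// From Java source (including Janino-compiled generated code), the reliable call form is:
//   String f = Formatter$.MODULE$.apply("yyyy-MM-dd");
// which is the shape the codegen strings above now emit via "$tf$$.MODULE$$.apply(...)".
```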
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/expressions/datetimeExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 73af0a3c5c2ee..8fc0112c02577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -571,7 +571,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti val tz = ctx.addReferenceObj("timeZone", timeZone) val locale = ctx.addReferenceObj("locale", Locale.US) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($tf.apply($format.toString(), $tz, $locale) + s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $tz, $locale) .format($timestamp))""" }) } @@ -741,11 +741,11 @@ abstract class UnixTime case StringType => val tz = ctx.addReferenceObj("timeZone", timeZone) val locale = ctx.addReferenceObj("locale", Locale.US) - val dtu = TimestampFormatter.getClass.getName.stripSuffix("$") + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $dtu.apply($format.toString(), $tz, $locale) + ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $tz, $locale) .parse($string.toString()) / 1000000L; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; From bb07fe9b76b12b9ac64816a4e0f99d892a6188c8 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 02:03:52 -0800 Subject: [PATCH 135/194] Maybe we have a race condition with the watcher? idk why we aren't getting any pods. 
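A plausible reading of the change below: the executor-pod watch was previously registered only after `SparkAppLauncher.launch`, so executors that started and exited quickly never produced events and `execPods` stayed empty. The commit therefore registers the watch before launching and only then checks `execPods`, which would explain the missing pods if the race-condition guess in the subject line is right.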
--- .../k8s/integrationtest/KubernetesSuite.scala | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index cb24fa5d2836e..5b7fcfadd3475 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -249,25 +249,7 @@ class KubernetesSuite extends SparkFunSuite mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch( - appArguments, - sparkAppConf, - TIMEOUT.value.toSeconds.toInt, - sparkHomeDir, - isJVM, - pyFiles) - println("Running spark job.") - val driverPod = kubernetesTestComponents.kubernetesClient - .pods() - .withLabel("spark-app-locator", appLocator) - .withLabel("spark-role", "driver") - .list() - .getItems - .get(0) - println("Doing driver pod check") - driverPodChecker(driverPod) - println("Done driver pod check") val execPods = scala.collection.mutable.Map[String, Pod]() println("Creating watcher...") val execWatcher = kubernetesTestComponents.kubernetesClient @@ -286,6 +268,7 @@ class KubernetesSuite extends SparkFunSuite println(s"Add or modification event received for $name.") execPods(name) = resource // If testing decomissioning delete the node 5 seconds after it starts running + // Open question: could we put this in the checker if (decomissioningTest && false) { // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") @@ -310,9 +293,30 @@ class KubernetesSuite extends SparkFunSuite } } }) + + println("Running spark job.") + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + println("Doing driver pod check") + driverPodChecker(driverPod) + println("Done driver pod check") // If we're testing decomissioning we delete all the executors, but we should have // an executor at some point. 
Eventually.eventually(TIMEOUT, INTERVAL) { execPods.values.nonEmpty } + println(s"Closing watcher with execPods $execPods nonEmpty: ${execPods.values.nonEmpty}") execWatcher.close() execPods.values.foreach(executorPodChecker(_)) println(s"Close to the end exec pods are $execPods") From ea77b2332ba9038ce88ea0cd9b54797b796d745b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 29 Dec 2018 10:43:23 -0800 Subject: [PATCH 136/194] wtf is going on with this eventually block --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 5b7fcfadd3475..5832135df860f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -315,7 +315,10 @@ class KubernetesSuite extends SparkFunSuite println("Done driver pod check") // If we're testing decomissioning we delete all the executors, but we should have // an executor at some point. - Eventually.eventually(TIMEOUT, INTERVAL) { execPods.values.nonEmpty } + Eventually.eventually(TIMEOUT, INTERVAL) { + println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") + execPods.values.nonEmpty + } println(s"Closing watcher with execPods $execPods nonEmpty: ${execPods.values.nonEmpty}") execWatcher.close() execPods.values.foreach(executorPodChecker(_)) From c12071312e9e039e6a5a8fb0e2d1e26ed9c347c1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 3 Jan 2019 11:13:43 -0800 Subject: [PATCH 137/194] Rewrite the eventually's to use should be which I had accidently removed. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 5832135df860f..a174ebc2042d9 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -275,7 +275,7 @@ class KubernetesSuite extends SparkFunSuite Eventually.eventually(TIMEOUT, INTERVAL) { resource.getStatus.getConditions().asScala .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") - .headOption.getOrElse(false) + .headOption.getOrElse(false) shouldBe (true) } // Sleep a small interval to ensure everything is registered. Thread.sleep(500) @@ -317,7 +317,7 @@ class KubernetesSuite extends SparkFunSuite // an executor at some point. 
Eventually.eventually(TIMEOUT, INTERVAL) { println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") - execPods.values.nonEmpty + execPods.values.nonEmpty should be (true) } println(s"Closing watcher with execPods $execPods nonEmpty: ${execPods.values.nonEmpty}") execWatcher.close() From fecd0cf71d2794f1151cf5fea4d485b4aa49653b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 3 Jan 2019 11:44:53 -0800 Subject: [PATCH 138/194] Change how we handle decom tests to actually decom workers and check that workers are decomed eventually. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index a174ebc2042d9..e5a684221786c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -269,7 +269,7 @@ class KubernetesSuite extends SparkFunSuite execPods(name) = resource // If testing decomissioning delete the node 5 seconds after it starts running // Open question: could we put this in the checker - if (decomissioningTest && false) { + if (decomissioningTest) { // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(TIMEOUT, INTERVAL) { @@ -319,6 +319,12 @@ class KubernetesSuite extends SparkFunSuite println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (true) } + if (decomissioningTest) { + Eventually.eventually(TIMEOUT, INTERVAL) { + println(s"Decom: This iteration is ${execPods.values.nonEmpty} with ${execPods}") + execPods.values.nonEmpty should be (false) + } + } println(s"Closing watcher with execPods $execPods nonEmpty: ${execPods.values.nonEmpty}") execWatcher.close() execPods.values.foreach(executorPodChecker(_)) From b668feb4c70cd29a3b7d16c91a1648ab65825bc4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 3 Jan 2019 12:16:20 -0800 Subject: [PATCH 139/194] Revert "Fix using SparkPI for decom test." This reverts commit 705fd583441dd4771ddaea1ea686f277dd4eeaef. 
--- .../deploy/k8s/integrationtest/DecommissionSuite.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 681ca13f83508..72b2fda3a062f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -18,22 +18,17 @@ package org.apache.spark.deploy.k8s.integrationtest private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => + import DecommissionSuite._ import KubernetesSuite.k8sTestTag - import KubernetesSuite.SPARK_PI_MAIN_CLASS test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") runSparkApplicationAndVerifyCompletion( - appResource = containerLocalSparkDistroExamplesJar, - mainClass = SPARK_PI_MAIN_CLASS, + SPARK_PI_MAIN_CLASS, expectedLogOnCompletion = Seq("Decommissioning executor"), appArgs = Array("100"), // Give it some time to run - driverPodChecker = doBasicDriverPodCheck, - executorPodChecker = doBasicExecutorPodCheck, - appLocator = appLocator, - isJVM = true, decomissioningTest = true) } } From 2905c8a8ff51f5932525ce81cfefb41745f5456a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 3 Jan 2019 12:16:26 -0800 Subject: [PATCH 140/194] Revert "Python tests aren't registering the executors, lets avoid that noise and just do SparkPI since we don't really need any special logic in the driver." Turns out I was just checking executor registration incorrectly. This reverts commit 5d173bd88aad8ee85d1c048ca814ed1c7f3353d9. 
--- .../k8s/integrationtest/DecommissionSuite.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 72b2fda3a062f..897a85c97ea45 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -24,11 +24,23 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => test("Test basic decommissioning", k8sTestTag) { sparkAppConf .set("spark.worker.decommission.enabled", "true") + .set("spark.kubernetes.pyspark.pythonVersion", "3") + .set("spark.kubernetes.container.image", pyImage) runSparkApplicationAndVerifyCompletion( - SPARK_PI_MAIN_CLASS, + appResource = PYSPARK_DECOMISSIONING, + mainClass = "", expectedLogOnCompletion = Seq("Decommissioning executor"), - appArgs = Array("100"), // Give it some time to run + appArgs = Array.empty[String], + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, decomissioningTest = true) } } + +private[spark] object DecommissionSuite { + val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" + val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decommissioning.py" +} From b77eb9e2411d135fdbf81b21934276d3785a21fe Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 3 Jan 2019 12:20:20 -0800 Subject: [PATCH 141/194] Match the wait logic (TODO refactor) --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index e5a684221786c..a6e5b71a0bd3e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -278,7 +278,7 @@ class KubernetesSuite extends SparkFunSuite .headOption.getOrElse(false) shouldBe (true) } // Sleep a small interval to ensure everything is registered. - Thread.sleep(500) + Thread.sleep(100) // Delete the pod to simulate cluster scale down/migration. 
val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() @@ -319,7 +319,17 @@ class KubernetesSuite extends SparkFunSuite println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (true) } + // If decomissioning we need to wait and check the executors were removed if (decomissioningTest) { + // Wait for the executors to become ready + Eventually.eventually(TIMEOUT, INTERVAL) { + resource.getStatus.getConditions().asScala + .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") + .headOption.getOrElse(false) shouldBe (true) + } + // Sleep a small interval to allow execution + Thread.sleep(100) + // Wait for the executors to be removed Eventually.eventually(TIMEOUT, INTERVAL) { println(s"Decom: This iteration is ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (false) From 9f069a4a0a2b9894c887fc66cc08bb60299f8288 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 4 Jan 2019 10:35:50 -0800 Subject: [PATCH 142/194] Debug pods not becoming ready. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index a6e5b71a0bd3e..0d404b9268ce1 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -323,9 +323,13 @@ class KubernetesSuite extends SparkFunSuite if (decomissioningTest) { // Wait for the executors to become ready Eventually.eventually(TIMEOUT, INTERVAL) { - resource.getStatus.getConditions().asScala + val resourceConditions = execPods.flatMap{ + case (podName, resource) => resource.getStatus.getConditions().asScala} + val podsReady = (resourceConditions .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") - .headOption.getOrElse(false) shouldBe (true) + .headOption.getOrElse(false)) + assert(podsReady, + "The pods did not become ready the resource conditions are ${resourceConditions}") } // Sleep a small interval to allow execution Thread.sleep(100) From 789bbdd3f01c42810458fda34e082f9401cb04f1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 5 Jan 2019 22:36:33 -0800 Subject: [PATCH 143/194] special format string. 
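For reference, the one-character fix below matters because Scala only substitutes ${...} inside an s-interpolated string; in a plain literal the placeholder is kept verbatim, so the assertion message never shows the actual resource conditions. A tiny illustration with a made-up value:

    val resourceConditions = Seq("Ready=False")
    val raw = "conditions are ${resourceConditions}"            // "conditions are ${resourceConditions}"
    val interpolated = s"conditions are ${resourceConditions}"  // "conditions are List(Ready=False)"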
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 0d404b9268ce1..d919e13f0e0e1 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -329,7 +329,7 @@ class KubernetesSuite extends SparkFunSuite .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false)) assert(podsReady, - "The pods did not become ready the resource conditions are ${resourceConditions}") + s"The pods did not become ready the resource conditions are ${resourceConditions}") } // Sleep a small interval to allow execution Thread.sleep(100) From bf834e21f6de580526487d688e620385dc1e588c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 6 Jan 2019 18:21:04 -0800 Subject: [PATCH 144/194] If all the pods are done we don't need the pods to be ready. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index d919e13f0e0e1..4911b7bcf5cf0 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -323,13 +323,14 @@ class KubernetesSuite extends SparkFunSuite if (decomissioningTest) { // Wait for the executors to become ready Eventually.eventually(TIMEOUT, INTERVAL) { - val resourceConditions = execPods.flatMap{ - case (podName, resource) => resource.getStatus.getConditions().asScala} + val resourceConditions = execPods.values.flatMap{ + resource => resource.getStatus.getConditions().asScala} val podsReady = (resourceConditions .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false)) - assert(podsReady, - s"The pods did not become ready the resource conditions are ${resourceConditions}") + val podsEmpty = execPods.values.isEmpty + assert(podsReady || podsEmpty, + s"The pods (${execPods.values}) did not become ready the resource conditions are ${resourceConditions}") } // Sleep a small interval to allow execution Thread.sleep(100) From 1ed5c3d8043539671652f0f38f9eece4e65ec2e3 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 6 Jan 2019 19:05:04 -0800 Subject: [PATCH 145/194] Sleep 100 before waiting. 
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 4911b7bcf5cf0..7d1505340d387 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -321,6 +321,8 @@ class KubernetesSuite extends SparkFunSuite } // If decomissioning we need to wait and check the executors were removed if (decomissioningTest) { + // Sleep a small interval to ensure everything is registered. + Thread.sleep(100) // Wait for the executors to become ready Eventually.eventually(TIMEOUT, INTERVAL) { val resourceConditions = execPods.values.flatMap{ From 7acc255e1835b4060271878e7d65f469d425da69 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 7 Jan 2019 14:20:34 -0800 Subject: [PATCH 146/194] Give the pod 5 minutes to become ready. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 7d1505340d387..9e14b6df27911 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -272,7 +272,7 @@ class KubernetesSuite extends SparkFunSuite if (decomissioningTest) { // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") - Eventually.eventually(TIMEOUT, INTERVAL) { + Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { resource.getStatus.getConditions().asScala .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false) shouldBe (true) @@ -315,7 +315,7 @@ class KubernetesSuite extends SparkFunSuite println("Done driver pod check") // If we're testing decomissioning we delete all the executors, but we should have // an executor at some point. 
- Eventually.eventually(TIMEOUT, INTERVAL) { + Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (true) } @@ -449,5 +449,6 @@ private[spark] object KubernetesSuite { val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val POD_RUNNING_TIMEOUT = PatienceConfiguration.Timeout(Span(5, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) } From b2c58b6ce4427e56f1cab9071cb901d44fd418cd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 7 Jan 2019 15:10:12 -0800 Subject: [PATCH 147/194] I think we might have had a race condition where the top thread deleted the pods before the bottom thread could find them in a ready state. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 9e14b6df27911..660d911b58a5f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -277,8 +277,8 @@ class KubernetesSuite extends SparkFunSuite .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false) shouldBe (true) } - // Sleep a small interval to ensure everything is registered. - Thread.sleep(100) + // Sleep a small interval to allow execution & downstream pod ready check to also catch up + Thread.sleep(2000) // Delete the pod to simulate cluster scale down/migration. 
val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() @@ -335,7 +335,7 @@ class KubernetesSuite extends SparkFunSuite s"The pods (${execPods.values}) did not become ready the resource conditions are ${resourceConditions}") } // Sleep a small interval to allow execution - Thread.sleep(100) + Thread.sleep(3000) // Wait for the executors to be removed Eventually.eventually(TIMEOUT, INTERVAL) { println(s"Decom: This iteration is ${execPods.values.nonEmpty} with ${execPods}") @@ -450,5 +450,5 @@ private[spark] object KubernetesSuite { val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) val POD_RUNNING_TIMEOUT = PatienceConfiguration.Timeout(Span(5, Minutes)) - val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val INTERVAL = PatienceConfiguration.Interval(Span(1, Seconds)) } From 92f4289a1dd971dadb34e227d532bd43014fa055 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 7 Jan 2019 19:12:01 -0800 Subject: [PATCH 148/194] Use POD_RUNNING_TIMEOUT --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 660d911b58a5f..82359e9c5b323 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -324,7 +324,7 @@ class KubernetesSuite extends SparkFunSuite // Sleep a small interval to ensure everything is registered. 
Thread.sleep(100) // Wait for the executors to become ready - Eventually.eventually(TIMEOUT, INTERVAL) { + Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { val resourceConditions = execPods.values.flatMap{ resource => resource.getStatus.getConditions().asScala} val podsReady = (resourceConditions From 0b33e1a792cff22e2eda9febd0c891c35325962a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 10 Feb 2019 23:39:29 -0800 Subject: [PATCH 149/194] Fix appclient suite --- .../spark/deploy/client/AppClientSuite.scala | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 8f12cdd65cfb6..0a44248e8214b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -131,26 +131,25 @@ class AppClientSuite // Send request to kill executor, verify request was made whenReady( - whenReady( - ci.client.killExecutors(Seq(executorId)), - timeout(10.seconds), - interval(10.millis)) { acknowledged => - assert(acknowledged) - } - - // Verify that asking for executors on the decommissioned workers fails - whenReady( - ci.client.requestTotalExecutors(numExecutorsRequested), - timeout(10.seconds), - interval(10.millis)) { acknowledged => - assert(acknowledged) - } - assert(getApplications().head.executors.size === 0) - - // Issue stop command for Client to disconnect from Master - ci.client.stop() - - // Verify Client is marked dead and unregistered from Master + ci.client.killExecutors(Seq(executorId)), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } + + // Verify that asking for executors on the decommissioned workers fails + whenReady( + ci.client.requestTotalExecutors(numExecutorsRequested), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } + assert(getApplications().head.executors.size === 0) + + // Issue stop command for Client to disconnect from Master + ci.client.stop() + + // Verify Client is marked dead and unregistered from Master eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(ci.listener.deadReasonList.size === 1, "client should have been marked dead") From 2dc806e127c79b263938464d15c5edb0ba64a07f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 11:27:07 -0800 Subject: [PATCH 150/194] Try and debug the waiting on killing the exec pod logic. 
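The printlns added below follow a simple debugging pattern: an Eventually block only reports the last failed assertion, so logging the observed state on every retry shows whether the condition is stuck or slowly converging. Roughly, assuming the suite's TIMEOUT, INTERVAL, execPods and its ScalaTest Eventually/Matchers imports:

    Eventually.eventually(TIMEOUT, INTERVAL) {
      // Each retry logs what it actually saw before asserting, so a timeout leaves a trail.
      println(s"retry sees ${execPods.size} executor pod(s): ${execPods.keys.mkString(", ")}")
      execPods.values.nonEmpty should be (true)
    }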
--- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 82359e9c5b323..325e8db19cf2d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -273,11 +273,16 @@ class KubernetesSuite extends SparkFunSuite // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - resource.getStatus.getConditions().asScala + val resourceStatus = resource.getStatus + val conditions = resourceStatus.getConditions().asScala + val result = conditions .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false) shouldBe (true) + println(s"Waiting on ${resource} status ${resourceStatus} with conditions ${conditions} result: ${result}") + result } // Sleep a small interval to allow execution & downstream pod ready check to also catch up + println("Sleeping before killing pod.") Thread.sleep(2000) // Delete the pod to simulate cluster scale down/migration. val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) @@ -316,7 +321,7 @@ class KubernetesSuite extends SparkFunSuite // If we're testing decomissioning we delete all the executors, but we should have // an executor at some point. Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - println(s"This iteration is ${execPods.values.nonEmpty} with ${execPods}") + println(s"Driver podcheck iteration non empty: ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (true) } // If decomissioning we need to wait and check the executors were removed @@ -338,7 +343,7 @@ class KubernetesSuite extends SparkFunSuite Thread.sleep(3000) // Wait for the executors to be removed Eventually.eventually(TIMEOUT, INTERVAL) { - println(s"Decom: This iteration is ${execPods.values.nonEmpty} with ${execPods}") + println(s"decom iteration pods non-empty ${execPods.values.nonEmpty} with ${execPods}") execPods.values.nonEmpty should be (false) } } From 1963bf36a2b8af1f5be61f84115241d4f8d63027 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 11:45:08 -0800 Subject: [PATCH 151/194] Maybe we were not actually hitting the k8s end point which is why the status was never changing. 
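The suspicion here is right in one important way: a Pod object handed to the watcher callback is a point-in-time snapshot, so polling resource.getStatus in a loop can never observe a later readiness transition. The check has to go back to the API server on every retry, roughly as below (kubernetesTestComponents, name, POD_RUNNING_TIMEOUT and the JavaConverters/Matchers imports are assumed from the suite, withName is used here the way later patches end up doing it, and the exists call is just a compact form of the filter/headOption chain):

    Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) {
      // Re-read the pod so each retry sees current state, not the watch-event snapshot.
      val freshPod = kubernetesTestComponents.kubernetesClient
        .pods()
        .withName(name)
        .get()
      val ready = freshPod.getStatus.getConditions.asScala
        .exists(c => c.getType == "Ready" && c.getStatus == "True")
      ready shouldBe true
    }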
--- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 325e8db19cf2d..1486392883e42 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -273,7 +273,13 @@ class KubernetesSuite extends SparkFunSuite // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val resourceStatus = resource.getStatus + val execPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("name", name) + .list() + .getItems + .get(0) + val resourceStatus = execPod.getStatus val conditions = resourceStatus.getConditions().asScala val result = conditions .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") From 5771b055a744f2dfd5cc4a6dfe8e64504e23edea Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 11:52:45 -0800 Subject: [PATCH 152/194] re-order tests to fail faster. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 1486392883e42..11f70c3c3b30e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -40,9 +40,10 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite + with BasicTestsSuite with DecommissionSuite + with SecretsTestsSuite + with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with Logging with Eventually with Matchers { import KubernetesSuite._ From a39cd8515bc5ecdabcce5b781e96ff7bd5427401 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 12:09:56 -0800 Subject: [PATCH 153/194] Refactor the pod ready status check to be shared in the two places we need it. 
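Stripped of the debug output, the shared helper this patch introduces boils down to reading the pod's Ready condition. With the namespace parameter that a few later patches add, its shape is roughly (kubernetesTestComponents and the JavaConverters import assumed from the suite):

    def checkPodReady(namespace: String, name: String): Boolean = {
      val pod = kubernetesTestComponents.kubernetesClient
        .pods()
        .inNamespace(namespace)
        .withName(name)
        .get()
      // "Ready" with status "True" means all containers passed their readiness checks.
      pod.getStatus.getConditions.asScala
        .exists(cond => cond.getType == "Ready" && cond.getStatus == "True")
    }

Sharing one helper keeps the delete-after-ready path in the watcher and the end-of-test verification from drifting apart.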
--- .../k8s/integrationtest/KubernetesSuite.scala | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 11f70c3c3b30e..42b7f38c7f46e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -252,6 +252,22 @@ class KubernetesSuite extends SparkFunSuite appArgs = appArgs) val execPods = scala.collection.mutable.Map[String, Pod]() + def checkPodReady(name: String) = { + val execPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("name", name) + .list() + .getItems + .get(0) + val resourceStatus = execPod.getStatus + val conditions = resourceStatus.getConditions().asScala + val result = conditions + .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") + .headOption.getOrElse(false) + println(s"Pod name ${name} with entry ${execPod} has status" + + s"${resourceStatus} with conditions ${conditions} result: ${result}") + result + } println("Creating watcher...") val execWatcher = kubernetesTestComponents.kubernetesClient .pods() @@ -274,23 +290,12 @@ class KubernetesSuite extends SparkFunSuite // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val execPod = kubernetesTestComponents.kubernetesClient - .pods() - .withLabel("name", name) - .list() - .getItems - .get(0) - val resourceStatus = execPod.getStatus - val conditions = resourceStatus.getConditions().asScala - val result = conditions - .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") - .headOption.getOrElse(false) shouldBe (true) - println(s"Waiting on ${resource} status ${resourceStatus} with conditions ${conditions} result: ${result}") - result + val result = checkPodReady(name) + result shouldBe (true) } // Sleep a small interval to allow execution & downstream pod ready check to also catch up println("Sleeping before killing pod.") - Thread.sleep(2000) + Thread.sleep(100) // Delete the pod to simulate cluster scale down/migration. val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() @@ -337,14 +342,10 @@ class KubernetesSuite extends SparkFunSuite Thread.sleep(100) // Wait for the executors to become ready Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val resourceConditions = execPods.values.flatMap{ - resource => resource.getStatus.getConditions().asScala} - val podsReady = (resourceConditions - .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") - .headOption.getOrElse(false)) + val podsReady = ! 
execPods.keys.filter(checkPodReady).isEmpty val podsEmpty = execPods.values.isEmpty assert(podsReady || podsEmpty, - s"The pods (${execPods.values}) did not become ready the resource conditions are ${resourceConditions}") + s"None of the pods in ${execPods} became ready") } // Sleep a small interval to allow execution Thread.sleep(3000) From 3a16ee89dabd7a79afe477d5a95be8bd884486eb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 12:48:57 -0800 Subject: [PATCH 154/194] More debugging and use shouldBe rather than a direct assert. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 42b7f38c7f46e..24e25eec73a16 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -253,14 +253,18 @@ class KubernetesSuite extends SparkFunSuite val execPods = scala.collection.mutable.Map[String, Pod]() def checkPodReady(name: String) = { + println(s"!!! doing ready check on pod $name") val execPod = kubernetesTestComponents.kubernetesClient .pods() .withLabel("name", name) .list() .getItems .get(0) + println(s"!!! god pod $execPod for $name") val resourceStatus = execPod.getStatus + println(s"!!! status $resourceStatus for $name") val conditions = resourceStatus.getConditions().asScala + println(s"!!! conditions $conditions for $name") val result = conditions .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") .headOption.getOrElse(false) @@ -344,8 +348,8 @@ class KubernetesSuite extends SparkFunSuite Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { val podsReady = ! execPods.keys.filter(checkPodReady).isEmpty val podsEmpty = execPods.values.isEmpty - assert(podsReady || podsEmpty, - s"None of the pods in ${execPods} became ready") + val podsReadyOrDead = podsReady || podsEmpty + podsReadyOrDead shouldBe (true) } // Sleep a small interval to allow execution Thread.sleep(3000) From 804444160db6c4aff2d66225931245dae2119faa Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Feb 2019 16:26:14 -0800 Subject: [PATCH 155/194] Name isn't a label, just get the pod by name directly. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 24e25eec73a16..04a9b5e24ba9b 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -256,10 +256,8 @@ class KubernetesSuite extends SparkFunSuite println(s"!!! 
doing ready check on pod $name") val execPod = kubernetesTestComponents.kubernetesClient .pods() - .withLabel("name", name) - .list() - .getItems - .get(0) + .withName(name) + .get() println(s"!!! god pod $execPod for $name") val resourceStatus = execPod.getStatus println(s"!!! status $resourceStatus for $name") From 25dc90775a50cc462cd5f325c3b3eada5def1808 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 15 Feb 2019 12:28:57 -0800 Subject: [PATCH 156/194] Get namespace as well since we are not finding the pod whuich is odd. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 04a9b5e24ba9b..2d3cce41de2a7 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -252,10 +252,11 @@ class KubernetesSuite extends SparkFunSuite appArgs = appArgs) val execPods = scala.collection.mutable.Map[String, Pod]() - def checkPodReady(name: String) = { + def checkPodReady(namespace: String, name: String) = { println(s"!!! doing ready check on pod $name") val execPod = kubernetesTestComponents.kubernetesClient .pods() + .inNamespace(namespace) .withName(name) .get() println(s"!!! god pod $execPod for $name") @@ -282,6 +283,7 @@ class KubernetesSuite extends SparkFunSuite override def eventReceived(action: Watcher.Action, resource: Pod): Unit = { println("Event received.") val name = resource.getMetadata.getName + val nameSpace = pod.getMetadata().getNamespace() action match { case Action.ADDED | Action.MODIFIED => println(s"Add or modification event received for $name.") From 46b5725f763e1858704c408b7a55f49f717790b0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 15 Feb 2019 12:41:20 -0800 Subject: [PATCH 157/194] Fix namespace for pod exec check --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 2d3cce41de2a7..8f8b04c21a330 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -283,7 +283,7 @@ class KubernetesSuite extends SparkFunSuite override def eventReceived(action: Watcher.Action, resource: Pod): Unit = { println("Event received.") val name = resource.getMetadata.getName - val nameSpace = pod.getMetadata().getNamespace() + val namespace = resource.getMetadata().getNamespace() action match { case Action.ADDED | Action.MODIFIED => println(s"Add or modification event received for $name.") @@ -294,7 +294,7 @@ class KubernetesSuite extends SparkFunSuite // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") 
Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val result = checkPodReady(name) + val result = checkPodReady(namespace, name) result shouldBe (true) } // Sleep a small interval to allow execution & downstream pod ready check to also catch up @@ -346,7 +346,10 @@ class KubernetesSuite extends SparkFunSuite Thread.sleep(100) // Wait for the executors to become ready Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val podsReady = ! execPods.keys.filter(checkPodReady).isEmpty + val podsReady = ! execPods.map{ + case (name, resource) => + (name, resource.getMetadata().getNamespace()) + }.filter(case (name, namespace) => checkPodReady(namespace, name)).isEmpty val podsEmpty = execPods.values.isEmpty val podsReadyOrDead = podsReady || podsEmpty podsReadyOrDead shouldBe (true) From 4154eef68162f0051fbb17bb0c3e47f01cffbf95 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 15 Feb 2019 16:55:04 -0800 Subject: [PATCH 158/194] Fix pod check --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 8f8b04c21a330..1d86622f643ad 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -346,12 +346,14 @@ class KubernetesSuite extends SparkFunSuite Thread.sleep(100) // Wait for the executors to become ready Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { - val podsReady = ! execPods.map{ + val anyReadyPods = ! execPods.map{ case (name, resource) => (name, resource.getMetadata().getNamespace()) - }.filter(case (name, namespace) => checkPodReady(namespace, name)).isEmpty + }.filter{ + case (name, namespace) => checkPodReady(namespace, name) + }.isEmpty val podsEmpty = execPods.values.isEmpty - val podsReadyOrDead = podsReady || podsEmpty + val podsReadyOrDead = anyReadyPods || podsEmpty podsReadyOrDead shouldBe (true) } // Sleep a small interval to allow execution From 8a2f5a7fd54c7f64e419d21fb1b0e6e76659b9c2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 11:49:25 -0800 Subject: [PATCH 159/194] For now skip scala style println checks in KubernetesSuite while we're figuiring out whats going on --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 1d86622f643ad..6e807ca695a0a 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.deploy.k8s.integrationtest import java.io.File @@ -474,3 +475,4 @@ private[spark] object KubernetesSuite { val POD_RUNNING_TIMEOUT = PatienceConfiguration.Timeout(Span(5, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(1, Seconds)) } +// scalastyle:on println From 7d4f2646b8edf88de9ce390917a9e8d58486486f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 11:58:53 -0800 Subject: [PATCH 160/194] Fix long line comment in KubernetesSuite --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 6e807ca695a0a..9af1d5b51afc0 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -298,7 +298,7 @@ class KubernetesSuite extends SparkFunSuite val result = checkPodReady(namespace, name) result shouldBe (true) } - // Sleep a small interval to allow execution & downstream pod ready check to also catch up + // Sleep a small interval to allow execution of job println("Sleeping before killing pod.") Thread.sleep(100) // Delete the pod to simulate cluster scale down/migration. From 367a6664b02fdd3fd658b2f61fde6c5f07db8ba1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:03:58 -0800 Subject: [PATCH 161/194] Print out when we are running decom suite. --- .../deploy/k8s/integrationtest/DecommissionSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 897a85c97ea45..595594f24be75 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -22,6 +22,9 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import KubernetesSuite.k8sTestTag test("Test basic decommissioning", k8sTestTag) { + // scalastyle:off println + println("***TESTING decommissioning***") + // scalastyle:on println sparkAppConf .set("spark.worker.decommission.enabled", "true") .set("spark.kubernetes.pyspark.pythonVersion", "3") @@ -37,6 +40,9 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => appLocator = appLocator, isJVM = false, decomissioningTest = true) + // scalastyle:off println + println("***END TESTING decommissioning***") + // scalastyle:on println } } From e54dbaada6ac3f88a14eb7e689ac08c5f11f0995 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:08:19 -0800 Subject: [PATCH 162/194] Print out the namespace as well. 
--- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 9af1d5b51afc0..61248b2582bde 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -254,13 +254,13 @@ class KubernetesSuite extends SparkFunSuite val execPods = scala.collection.mutable.Map[String, Pod]() def checkPodReady(namespace: String, name: String) = { - println(s"!!! doing ready check on pod $name") + println(s"!!! doing ready check on pod $name in $namespace") val execPod = kubernetesTestComponents.kubernetesClient .pods() .inNamespace(namespace) .withName(name) .get() - println(s"!!! god pod $execPod for $name") + println(s"!!! got pod $execPod for $name in namespace $namespace") val resourceStatus = execPod.getStatus println(s"!!! status $resourceStatus for $name") val conditions = resourceStatus.getConditions().asScala From 06bbbed6b4352d395eb045055d97700ae798a84a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:15:31 -0800 Subject: [PATCH 163/194] Debugging is easier when I see what failed --- .../integration-tests/dev/dev-run-integration-tests.sh | 1 + .../integration-tests/scripts/setup-integration-test-env.sh | 2 ++ 2 files changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 67b41300a139f..ed33b6522620c 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +set -ex set -xo errexit TEST_ROOT_DIR=$(git rev-parse --show-toplevel) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index e70101bbc08cf..deea4d3e46e94 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -16,6 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +set -e +set -x TEST_ROOT_DIR=$(git rev-parse --show-toplevel) UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" From e8081cbdb443bd3393da3cab1754dc4d4a455015 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:21:57 -0800 Subject: [PATCH 164/194] Without specifiying a spark release tarball the setup env script will fail --- resource-managers/kubernetes/integration-tests/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 73fc0581d64f5..60924a9c9021a 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -11,7 +11,7 @@ is subject to change. Note that currently the integration tests only run with Ja The simplest way to run the integration tests is to install and run Minikube, then run the following from this directory: - dev/dev-run-integration-tests.sh + dev/dev-run-integration-tests.sh --spark-tgz [spark_release_build] The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: From 908b204ebc52a3ce2c07d70d44ca3b248363d93b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:36:49 -0800 Subject: [PATCH 165/194] I don't know why the ready check isn't doing what I expected, lets break it down a little. --- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 61248b2582bde..41792bf6717ba 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -265,11 +265,12 @@ class KubernetesSuite extends SparkFunSuite println(s"!!! status $resourceStatus for $name") val conditions = resourceStatus.getConditions().asScala println(s"!!! conditions $conditions for $name") - val result = conditions - .map(cond => cond.getStatus() == "True" && cond.getType() == "Ready") + val readyConditions = conditions.filter(cond.getType() == "Ready") + println(s"!!! 
ready conditions $conditions for $name") + val result = readyConditions + .map(cond => cond.getStatus() == "True") .headOption.getOrElse(false) - println(s"Pod name ${name} with entry ${execPod} has status" + - s"${resourceStatus} with conditions ${conditions} result: ${result}") + println(s"Pod name ${name} ready check resulted in ${result}") result } println("Creating watcher...") From 158838c469163abb05134ce777b45648d043394e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 12:48:04 -0800 Subject: [PATCH 166/194] Fix filter --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 41792bf6717ba..ad133d2d4d07a 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -265,7 +265,7 @@ class KubernetesSuite extends SparkFunSuite println(s"!!! status $resourceStatus for $name") val conditions = resourceStatus.getConditions().asScala println(s"!!! conditions $conditions for $name") - val readyConditions = conditions.filter(cond.getType() == "Ready") + val readyConditions = conditions.filter{cond => cond.getType() == "Ready"} println(s"!!! ready conditions $conditions for $name") val result = readyConditions .map(cond => cond.getStatus() == "True") From a6ad1ff6fa187c005d7697fcd253725ad2b10aad Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 14:20:05 -0800 Subject: [PATCH 167/194] Remove unrelated subquery suite change --- sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 523c554ec10c1..53eb7ea7a0a05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.execution.{ExecSubqueryExpression, FileSourceScanExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.datasources.FileScanRDD import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ From d85c2299850fdfef8ad369e14b6696a016c8c805 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 14:20:58 -0800 Subject: [PATCH 168/194] Revert "Speed up running the kubernetes integration tests locally by allowing folks to skip the tgz dist build and extraction" This reverts commit 8306827f31f77a87d89d8324235c3e643abce8f1. 
--- .../scripts/setup-integration-test-env.sh | 12 ++++++------ .../deploy/k8s/integrationtest/KubernetesSuite.scala | 12 ++---------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index deea4d3e46e94..a74299b738866 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -60,15 +60,15 @@ while (( "$#" )); do shift done -rm -rf $UNPACKED_SPARK_TGZ -if [[ $SPARK_TGZ == "N/A" && $IMAGE_TAG == "N/A" ]]; +if [[ $SPARK_TGZ == "N/A" ]]; then - echo "Must specify a Spark tarball to build Docker images against with --spark-tgz OR image with --image-tag." && exit 1; -else - mkdir -p $UNPACKED_SPARK_TGZ - tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; fi +rm -rf $UNPACKED_SPARK_TGZ +mkdir -p $UNPACKED_SPARK_TGZ +tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + if [[ $IMAGE_TAG == "N/A" ]]; then IMAGE_TAG=$(uuidgen); diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index ad133d2d4d07a..822cc179e0100 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -107,16 +107,8 @@ class KubernetesSuite extends SparkFunSuite System.clearProperty(key) } - val possible_spark_dirs = List( - // If someone specified the tgz for the tests look at the extraction dir - System.getProperty(CONFIG_KEY_UNPACK_DIR), - // If otherwise use my working dir + 3 up - new File(Paths.get(System.getProperty("user.dir")).toFile, ("../" * 3)).getAbsolutePath() - ) - val sparkDirProp = possible_spark_dirs.filter(x => - new File(Paths.get(x).toFile, "bin/spark-submit").exists).headOption.getOrElse(null) - require(sparkDirProp != null, - s"Spark home directory must be provided in system properties tested $possible_spark_dirs") + val sparkDirProp = System.getProperty(CONFIG_KEY_UNPACK_DIR) + require(sparkDirProp != null, "Spark home directory must be provided in system properties.") sparkHomeDir = Paths.get(sparkDirProp) require(sparkHomeDir.toFile.isDirectory, s"No directory found for spark home specified at $sparkHomeDir.") From e77009266c9b846f95434ec1c2157466fba3608b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 14:23:48 -0800 Subject: [PATCH 169/194] Take out set -e because I _think_ in integration env this fails with R issues but previously wasn't triggering any errors. This is a little skethy. 
--- .../integration-tests/scripts/setup-integration-test-env.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index a74299b738866..178c875cf0a97 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -16,7 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -set -e set -x TEST_ROOT_DIR=$(git rev-parse --show-toplevel) UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" From 0cad83ac229165c9ee2ebe0f8cb84cb4b1e322aa Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 14:34:54 -0800 Subject: [PATCH 170/194] Temporary commit to support running tests locally, should be part of https://github.com/apache/spark/pull/23380 --- .../scripts/setup-integration-test-env.sh | 33 ++++++++++++------- .../k8s/integrationtest/KubernetesSuite.scala | 12 +++++-- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index 178c875cf0a97..ab0123c2c5f08 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -59,50 +59,59 @@ while (( "$#" )); do shift done -if [[ $SPARK_TGZ == "N/A" ]]; +rm -rf "$UNPACKED_SPARK_TGZ" +if [[ $SPARK_TGZ == "N/A" && $IMAGE_TAG == "N/A" ]]; then - echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." 
&& exit 1; + # If there is no spark image tag to test with and no src dir, build from current + SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + SPARK_INPUT_DIR="$(cd "$SCRIPT_DIR/"../../../../ >/dev/null 2>&1 && pwd )" + DOCKER_FILE_BASE_PATH="$SPARK_INPUT_DIR/resource-managers/kubernetes/docker/src/main/dockerfiles/spark" +elif [[ $IMAGE_TAG == "N/A" ]]; +then + # If there is a test src tarball and no image tag we will want to build from that + mkdir -p $UNPACKED_SPARK_TGZ + tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + SPARK_INPUT_DIR="$UNPACKED_SPARK_TGZ" + DOCKER_FILE_BASE_PATH="$SPARK_INPUT_DIR/kubernetes/dockerfiles/spark" fi -rm -rf $UNPACKED_SPARK_TGZ -mkdir -p $UNPACKED_SPARK_TGZ -tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; +# If there is a specific Spark image skip building and extraction/copy if [[ $IMAGE_TAG == "N/A" ]]; then IMAGE_TAG=$(uuidgen); - cd $UNPACKED_SPARK_TGZ + cd $SPARK_INPUT_DIR # Build PySpark image - LANGUAGE_BINDING_BUILD_ARGS="-p $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/python/Dockerfile" + LANGUAGE_BINDING_BUILD_ARGS="-p $DOCKER_FILE_BASE_PATH/bindings/python/Dockerfile" # Build SparkR image - LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/R/Dockerfile" + LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $DOCKER_FILE_BASE_PATH/bindings/R/Dockerfile" case $DEPLOY_MODE in cloud) # Build images - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; then gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG else - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push fi ;; docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 822cc179e0100..25720c58e55c2 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -107,8 +107,16 @@ class KubernetesSuite extends SparkFunSuite System.clearProperty(key) } - val 
sparkDirProp = System.getProperty(CONFIG_KEY_UNPACK_DIR) - require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + val possible_spark_dirs = List( + // If someone specified the tgz for the tests look at the extraction dir + System.getProperty(CONFIG_KEY_UNPACK_DIR), + // Try the spark test home + sys.props("spark.test.home") + ) + val sparkDirProp = possible_spark_dirs.filter(x => + new File(Paths.get(x).toFile, "bin/spark-submit").exists).headOption.getOrElse(null) + require(sparkDirProp != null, + s"Spark home directory must be provided in system properties tested $possible_spark_dirs") sparkHomeDir = Paths.get(sparkDirProp) require(sparkHomeDir.toFile.isDirectory, s"No directory found for spark home specified at $sparkHomeDir.") From ad3474ddf218a8d6830503ec51df81c78f6050a1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 15:33:58 -0800 Subject: [PATCH 171/194] Tests are running locally, pod is created and deleted but we don't get decomissioning state. Took out wait for all executors to be removed because Spark just schedules another executor once we delete the first one. --- .../k8s/integrationtest/KubernetesSuite.scala | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 25720c58e55c2..3491d4ffc096c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -41,10 +41,10 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter - with BasicTestsSuite +// with BasicTestsSuite with DecommissionSuite - with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite +// with SecretsTestsSuite +// with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with Logging with Eventually with Matchers { import KubernetesSuite._ @@ -265,8 +265,10 @@ class KubernetesSuite extends SparkFunSuite println(s"!!! status $resourceStatus for $name") val conditions = resourceStatus.getConditions().asScala println(s"!!! conditions $conditions for $name") + val conditionTypes = conditions.map(_.getType()) + println(s"!!! condition types $conditionTypes") val readyConditions = conditions.filter{cond => cond.getType() == "Ready"} - println(s"!!! ready conditions $conditions for $name") + println(s"!!! 
ready conditions $readyConditions for $name") val result = readyConditions .map(cond => cond.getStatus() == "True") .headOption.getOrElse(false) @@ -287,12 +289,14 @@ class KubernetesSuite extends SparkFunSuite val name = resource.getMetadata.getName val namespace = resource.getMetadata().getNamespace() action match { - case Action.ADDED | Action.MODIFIED => - println(s"Add or modification event received for $name.") + case Action.MODIFIED => execPods(name) = resource - // If testing decomissioning delete the node 5 seconds after it starts running - // Open question: could we put this in the checker - if (decomissioningTest) { + case Action.ADDED => + println(s"Add event received for $name.") + execPods(name) = resource + // If testing decomissioning delete the first with a delay after it starts + // running. + if (decomissioningTest && execPods.size == 1) { // Wait for all the containers in the pod to be running println("Waiting for pod to become OK then delete.") Eventually.eventually(POD_RUNNING_TIMEOUT, INTERVAL) { @@ -301,7 +305,7 @@ class KubernetesSuite extends SparkFunSuite } // Sleep a small interval to allow execution of job println("Sleeping before killing pod.") - Thread.sleep(100) + Thread.sleep(30000) // Delete the pod to simulate cluster scale down/migration. val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) pod.delete() @@ -358,13 +362,6 @@ class KubernetesSuite extends SparkFunSuite val podsReadyOrDead = anyReadyPods || podsEmpty podsReadyOrDead shouldBe (true) } - // Sleep a small interval to allow execution - Thread.sleep(3000) - // Wait for the executors to be removed - Eventually.eventually(TIMEOUT, INTERVAL) { - println(s"decom iteration pods non-empty ${execPods.values.nonEmpty} with ${execPods}") - execPods.values.nonEmpty should be (false) - } } println(s"Closing watcher with execPods $execPods nonEmpty: ${execPods.values.nonEmpty}") execWatcher.close() From 23decb9fd95717528d14b1e3fb1fc4fef73adde1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 15:43:03 -0800 Subject: [PATCH 172/194] Move the config variable for decom into worker, add a bit more logging since we don't appear to be triggering the decom script. --- .../main/scala/org/apache/spark/deploy/worker/Worker.scala | 5 ++++- .../main/scala/org/apache/spark/internal/config/Worker.scala | 5 +++++ .../scala/org/apache/spark/internal/config/package.scala | 5 ----- .../scala/org/apache/spark/deploy/k8s/KubernetesConf.scala | 2 +- .../spark/deploy/k8s/features/BasicExecutorFeatureStep.scala | 5 ++++- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f1f8344a2c7f9..2388270984f47 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -64,8 +64,11 @@ private[deploy] class Worker( assert (port > 0) // If worker decommissioning is enabled register a handler on SIGPWR to shutdown. - if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { + if (conf.get(Worker.WORKER_DECOMMISSION_ENABLED)) { + logInfo("Registering SIGPWR handler.") SignalUtils.register("SIGPWR")(decommissionSelf) + } else { + logInfo("Worker decomissioning not enabled, skipping SIGPWR") } // A scheduled executor used to send messages at the specified time. 
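    # Note on the handler registered above: from a shell, either spelling of
    # the signal reaches the process, since kill accepts names with or without
    # the SIG prefix. The pgrep pattern below is only an assumption about how
    # the worker shows up in ps:
    WORKER_PID=$(pgrep -f org.apache.spark.deploy.worker.Worker)
    kill -s SIGPWR "$WORKER_PID"
    # The JVM side (sun.misc.Signal, wrapped by SignalUtils) only accepts the
    # bare name "PWR", which is what a later patch in this series switches to.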
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala index 47f7167d2c9cb..f29d7d4aa467e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala @@ -60,4 +60,9 @@ private[spark] object Worker { ConfigBuilder("spark.worker.ui.compressedLogFileLengthCacheSize") .intConf .createWithDefault(100) + + private[spark] val WORKER_DECOMMISSION_ENABLED = + ConfigBuilder("spark.worker.decommission.enabled") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84ad882d81e11..d6a359db66f48 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -489,11 +489,6 @@ package object config { .createWithDefault(false) // End blacklist confs - private[spark] val WORKER_DECOMMISSION_ENABLED = - ConfigBuilder("spark.worker.decommission.enabled") - .booleanConf - .createWithDefault(false) - private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost") .doc("Whether to un-register all the outputs on the host in condition that we receive " + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 12f5c17f847d0..b040451988e5c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -55,7 +55,7 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { } def workerDecomissioning: Boolean = - sparkConf.get(org.apache.spark.internal.config.WORKER_DECOMMISSION_ENABLED) + sparkConf.get(org.apache.spark.internal.config.Worker.WORKER_DECOMMISSION_ENABLED) def nodeSelector: Map[String, String] = KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 02c863e5fd616..4902bd8271ae3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -189,6 +189,7 @@ private[spark] class BasicExecutorFeatureStep( }.getOrElse(executorContainer) val containerWithLifecycle = kubernetesConf.workerDecomissioning match { case true => + logger.info("Adding decommission script to lifecycle") new ContainerBuilder(executorContainer).editOrNewLifecycle() .withNewPreStop() .withNewExec() @@ -198,7 +199,9 @@ private[spark] class BasicExecutorFeatureStep( .endPreStop() .endLifecycle() .build() - case false => containerWithLimitCores + case false => + logger.info("Decommissioning not enabled, skipping shutdown script") + containerWithLimitCores } val ownerReference = kubernetesConf.driverPod.map { pod => new OwnerReferenceBuilder() From 
78d02d3700a4e5cf7e5605af69eeed24c243f128 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 17:04:19 -0800 Subject: [PATCH 173/194] Fix instances of decomi , register SIGPWR in CoarseGrainedExecutorBackend so we don't depend on Worker (only in standalone), use Worker.WORKER_DECOMMISSION_ENABLED.key, instead of string value in the k8s integration test --- .../apache/spark/deploy/master/Master.scala | 4 ++-- .../apache/spark/deploy/worker/Worker.scala | 6 ++--- .../CoarseGrainedExecutorBackend.scala | 23 ++++++++++++++++++- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 4 ++-- .../spark/deploy/client/AppClientSuite.scala | 2 +- .../scheduler/WorkerDecommissionSuite.scala | 4 ++-- .../integrationtest/DecommissionSuite.scala | 7 +++++- 8 files changed, 39 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b622445ddad12..439628b19b838 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -250,7 +250,7 @@ private[deploy] class Master( if (state == RecoveryState.STANDBY) { workerRef.send(MasterInStandby) } else { - // If a worker attempts to decomission that isn't registered ignore it. + // If a worker attempts to decommission that isn't registered ignore it. idToWorker.get(id).foreach(decommissionWorker) } @@ -808,7 +808,7 @@ private[deploy] class Master( logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port)) worker.setState(WorkerState.DECOMMISSIONED) for (exec <- worker.executors.values) { - logInfo("Telling app of decomission executors") + logInfo("Telling app of decommission executors") exec.application.driver.send(ExecutorUpdated( exec.id, ExecutorState.DECOMMISSIONED, Some("worker decommissioned"), None, workerLost = false)) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 2388270984f47..40901a9e46659 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -64,11 +64,11 @@ private[deploy] class Worker( assert (port > 0) // If worker decommissioning is enabled register a handler on SIGPWR to shutdown. - if (conf.get(Worker.WORKER_DECOMMISSION_ENABLED)) { + if (conf.get(WORKER_DECOMMISSION_ENABLED)) { logInfo("Registering SIGPWR handler.") SignalUtils.register("SIGPWR")(decommissionSelf) } else { - logInfo("Worker decomissioning not enabled, skipping SIGPWR") + logInfo("Worker decommissioning not enabled, skipping SIGPWR") } // A scheduled executor used to send messages at the specified time. @@ -712,7 +712,7 @@ private[deploy] class Worker( } private[deploy] def decommissionSelf(): Boolean = { - if (conf.get(config.WORKER_DECOMMISSION_ENABLED)) { + if (conf.get(WORKER_DECOMMISSION_ENABLED)) { logDebug("Decommissioning self") decommissioned = true // TODO: Send decommission notification to executors & shuffle service. 
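    # Note on the gate above: everything in decommissionSelf() sits behind the
    # spark.worker.decommission.enabled key, so exercising it end to end means
    # enabling that explicitly. A sketch; the app class and jar below are
    # placeholders, not anything from this patch:
    ./bin/spark-submit \
      --conf spark.worker.decommission.enabled=true \
      --class org.example.MyApp my-app.jar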
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 645f58716de63..af0526c26569c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.internal.config import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging @@ -35,7 +36,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{SignalUtils, ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -49,6 +50,7 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null + @volatile private var decommissioned = false @volatile var driver: Option[RpcEndpointRef] = None // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need @@ -56,6 +58,9 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { + logInfo("Registering SIGPWR handler.") + SignalUtils.register("SIGPWR")(decommissionSelf) + logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" @@ -99,6 +104,8 @@ private[spark] class CoarseGrainedExecutorBackend( case LaunchTask(data) => if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") + } else if (decommissioned) { + logWarning("Asked to launch a task while decommissioned. 
Not launching.") } else { val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) @@ -177,6 +184,20 @@ private[spark] class CoarseGrainedExecutorBackend( System.exit(code) } + + private def decommissionSelf(): Boolean = { + logDebug("Decommissioning self") + decommissioned = true + // Tell master we are are decommissioned so it stops trying to schedule us + if (driver.nonEmpty) { + driver.get.send(DecommissionExecutor(executorId)) + } + if (executor != null) { + executor.decommission() + } + // Return true since we are handling a signal + true + } } private[spark] object CoarseGrainedExecutorBackend extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 14d1243b1d3c0..9471b434e36f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -90,7 +90,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage - case class DecomissionExecutor(executorId: String) extends CoarseGrainedClusterMessage + case class DecommissionExecutor(executorId: String) extends CoarseGrainedClusterMessage case class RemoveWorker(workerId: String, host: String, message: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 14f6e763febd0..8dda9df51c2f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -249,7 +249,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeWorker(workerId, host, message) context.reply(true) - case DecomissionExecutor(executorId) => + case DecommissionExecutor(executorId) => logInfo(s"Decommissioning executor ${executorId}") decommissionExecutor(executorId) context.reply(true) @@ -530,7 +530,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ private[spark] def decommissionExecutor(executorId: String): Unit = { // Only log the failure since we don't care about the result. 
- driverEndpoint.ask[Boolean](DecomissionExecutor(executorId)).onFailure { case t => + driverEndpoint.ask[Boolean](DecommissionExecutor(executorId)).onFailure { case t => logError(t.getMessage, t) }(ThreadUtils.sameThread) } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 0a44248e8214b..94150b7e6ab61 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -58,7 +58,7 @@ class AppClientSuite */ override def beforeAll(): Unit = { super.beforeAll() - conf = new SparkConf().set(config.WORKER_DECOMMISSION_ENABLED.key, "true") + conf = new SparkConf().set(config.Worker.WORKER_DECOMMISSION_ENABLED.key, "true") securityManager = new SecurityManager(conf) masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) workerRpcEnvs = (0 until numWorkers).map { i => diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala index 4dc9a9dc4109c..483c449fa9e14 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala @@ -30,7 +30,7 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { override def beforeEach(): Unit = { val conf = new SparkConf().setAppName("test").setMaster("local") - .set(config.WORKER_DECOMMISSION_ENABLED.key, "true") + .set(config.Worker.WORKER_DECOMMISSION_ENABLED.key, "true") sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) } @@ -69,6 +69,6 @@ class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { val result = ThreadUtils.awaitResult(postDecomAsyncCount, 1.seconds) } assert(postDecomAsyncCount.isCompleted === false, - "After exec decomission new task could not launch") + "After exec decommission new task could not launch") } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 595594f24be75..64d19737227e9 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s.integrationtest +import import org.apache.spark.internal.config.Worker + private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => import DecommissionSuite._ @@ -26,10 +28,13 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => println("***TESTING decommissioning***") // scalastyle:on println sparkAppConf - .set("spark.worker.decommission.enabled", "true") + .set(Worker.WORKER_DECOMMISSION_ENABLED.key, "true") .set("spark.kubernetes.pyspark.pythonVersion", "3") .set("spark.kubernetes.container.image", pyImage) + // scalastyle:off println + println("***Running app***") + // scalastyle:on println runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_DECOMISSIONING, mainClass = "", From fa6db3278801d9640f8228edfef723e8c8699117 Mon Sep 17 00:00:00 2001 From: 
Holden Karau Date: Thu, 7 Mar 2019 17:12:21 -0800 Subject: [PATCH 174/194] Fixed compile errors from last --- .../deploy/k8s/features/BasicExecutorFeatureStep.scala | 7 ++++--- .../deploy/k8s/integrationtest/DecommissionSuite.scala | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 4902bd8271ae3..3cc60ed57e135 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python._ +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -33,7 +34,7 @@ import org.apache.spark.util.Utils private[spark] class BasicExecutorFeatureStep( kubernetesConf: KubernetesExecutorConf, secMgr: SecurityManager) - extends KubernetesFeatureConfigStep { + extends KubernetesFeatureConfigStep with Logging { // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf private val executorContainerImage = kubernetesConf @@ -189,7 +190,7 @@ private[spark] class BasicExecutorFeatureStep( }.getOrElse(executorContainer) val containerWithLifecycle = kubernetesConf.workerDecomissioning match { case true => - logger.info("Adding decommission script to lifecycle") + logInfo("Adding decommission script to lifecycle") new ContainerBuilder(executorContainer).editOrNewLifecycle() .withNewPreStop() .withNewExec() @@ -200,7 +201,7 @@ private[spark] class BasicExecutorFeatureStep( .endLifecycle() .build() case false => - logger.info("Decommissioning not enabled, skipping shutdown script") + logInfo("Decommissioning not enabled, skipping shutdown script") containerWithLimitCores } val ownerReference = kubernetesConf.driverPod.map { pod => diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 64d19737227e9..8735004d797de 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import import org.apache.spark.internal.config.Worker +import org.apache.spark.internal.config.Worker private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => From 6d31986e2727d89bd71e66639393f90c02d1cd65 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 17:16:43 -0800 Subject: [PATCH 175/194] s/SIGPWR/PWR/ in the Scala code. 
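The rename matters because the two layers disagree about the prefix: shells are
forgiving, the JVM is not. A quick sketch to see it (PID is a placeholder for
any long-lived test process; on Linux both kill lines deliver the same signal):

    kill -s SIGPWR "$PID"
    kill -s PWR "$PID"
    # sun.misc.Signal, which SignalUtils wraps, only accepts the bare name, so
    # registering "SIGPWR" fell through to the "Failed to register" warning
    # instead of installing a handler.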
--- .../scala/org/apache/spark/deploy/worker/Worker.scala | 8 ++++---- .../spark/executor/CoarseGrainedExecutorBackend.scala | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 40901a9e46659..b81d9979ed836 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -63,12 +63,12 @@ private[deploy] class Worker( Utils.checkHost(host) assert (port > 0) - // If worker decommissioning is enabled register a handler on SIGPWR to shutdown. + // If worker decommissioning is enabled register a handler on PWR to shutdown. if (conf.get(WORKER_DECOMMISSION_ENABLED)) { - logInfo("Registering SIGPWR handler.") - SignalUtils.register("SIGPWR")(decommissionSelf) + logInfo("Registering PWR handler.") + SignalUtils.register("PWR")(decommissionSelf) } else { - logInfo("Worker decommissioning not enabled, skipping SIGPWR") + logInfo("Worker decommissioning not enabled, skipping PWR") } // A scheduled executor used to send messages at the specified time. diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index af0526c26569c..3fd2dad37eac4 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -58,8 +58,8 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { - logInfo("Registering SIGPWR handler.") - SignalUtils.register("SIGPWR")(decommissionSelf) + logInfo("Registering PWR handler.") + SignalUtils.register("PWR")(decommissionSelf) logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => From ba755de416c42d2985c9209aea8311ee20b517b4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 17:32:21 -0800 Subject: [PATCH 176/194] Add lifecycle change --- .../spark/deploy/k8s/features/BasicExecutorFeatureStep.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 3cc60ed57e135..93c33a65b7525 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -228,6 +228,6 @@ private[spark] class BasicExecutorFeatureStep( .endSpec() .build() - SparkPod(executorPod, containerWithLimitCores) + SparkPod(executorPod, containerWithLifeCycle) } } From 777da86541aab9a9faa599a36711f4e69cf6c716 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 7 Mar 2019 17:32:32 -0800 Subject: [PATCH 177/194] Print out when we're getting ready to stop Spark and increase sleep --- .../kubernetes/integration-tests/tests/decommissioning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py 
b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py index a99e76d2ec2b4..c7d1fc64409c8 100644 --- a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py +++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py @@ -36,6 +36,7 @@ rdd = sc.parallelize(range(10)) rdd.collect() print("Waiting to give nodes time to finish.") - time.sleep(50) + time.sleep(120) + print("Stopping spark") spark.stop() sys.exit(0) From 209bf18a42e7c254bf516ef0054ac8f54eba8041 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 16:55:52 -0700 Subject: [PATCH 178/194] Try and debug whats going on with our container lifecycle. Currently it looks like it is not getting written out --- .../features/BasicExecutorFeatureStep.scala | 18 +++++++++++++----- .../src/main/dockerfiles/spark/Dockerfile | 1 + .../docker/src/main/dockerfiles/spark/decom.sh | 8 ++++++++ sbin/decommission-slave.sh | 2 ++ 4 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 93c33a65b7525..11a42751e2995 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -189,20 +189,26 @@ private[spark] class BasicExecutorFeatureStep( .build() }.getOrElse(executorContainer) val containerWithLifecycle = kubernetesConf.workerDecomissioning match { + case false => + logInfo("Decommissioning not enabled, skipping shutdown script") + containerWithLimitCores case true => logInfo("Adding decommission script to lifecycle") - new ContainerBuilder(executorContainer).editOrNewLifecycle() + new ContainerBuilder(containerWithLimitCores).withNewLifecycle() + .withNewPostStart() + .withNewExec() + .withCommand( + List("/bin/sh", "-c", "exit 1").asJava) + .endExec() + .endPostStart() .withNewPreStop() .withNewExec() .withCommand( - List("/opt/spark/sbin/decommission-slave.sh", "--block-until-exit").asJava) + List("/opt/decom.sh").asJava) .endExec() .endPreStop() .endLifecycle() .build() - case false => - logInfo("Decommissioning not enabled, skipping shutdown script") - containerWithLimitCores } val ownerReference = kubernetesConf.driverPod.map { pod => new OwnerReferenceBuilder() @@ -228,6 +234,8 @@ private[spark] class BasicExecutorFeatureStep( .endSpec() .build() + logInfo(s"Built container $containerWithLifeCycle") + SparkPod(executorPod, containerWithLifeCycle) } } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 1d8ac3ce89c78..7aa255f336eb5 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -43,6 +43,7 @@ COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark/decom.sh /opt/ COPY examples /opt/spark/examples COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data diff --git 
a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh new file mode 100644 index 0000000000000..e071d40238f10 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -ex +echo "hi" +echo "Decom adventures" > /dev/termination-log || echo "logging is hard" +WORKER_PID=$(ps axf | grep java |grep org.apache.spark | grep -v grep | awk '{print "kill -9 " $1}') +kill -s SIGPWR ${WORKER_PID} +waitpid ${WORKER_PID} diff --git a/sbin/decommission-slave.sh b/sbin/decommission-slave.sh index 53048931d69b4..4bbf257ff1d3a 100644 --- a/sbin/decommission-slave.sh +++ b/sbin/decommission-slave.sh @@ -27,6 +27,8 @@ # Usage: decommission-slave.sh [--block-until-exit] # Decommissions all slaves on this worker machine +set -ex + if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi From 0faf47c7f61a80832c3575a8f39476f1e6cabae7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:28:09 -0700 Subject: [PATCH 179/194] Try and check that lifecycle is being added with more logging and a life cycle which should cause failure --- .../k8s/features/BasicExecutorFeatureStep.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 11a42751e2995..db49d884c8b5d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -24,9 +24,9 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python._ -import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -197,15 +197,15 @@ private[spark] class BasicExecutorFeatureStep( new ContainerBuilder(containerWithLimitCores).withNewLifecycle() .withNewPostStart() .withNewExec() - .withCommand( - List("/bin/sh", "-c", "exit 1").asJava) + .addToCommand("/bin/sh") + .addtoCommand("-c") + .addtoCommand("exit 1") .endExec() .endPostStart() .withNewPreStop() .withNewExec() - .withCommand( - List("/opt/decom.sh").asJava) - .endExec() + .addToCommand("/opt/decom.sh") + .endExec() .endPreStop() .endLifecycle() .build() From 40b75f5e0510ffd1bf0ee64efffe7bc27afd8468 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:28:32 -0700 Subject: [PATCH 180/194] Print out our scala version and set it for downstream. 
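The grep/tail pipeline works, but if this checkout's maven-help-plugin is new
enough (3.1.0 or later, which is an assumption), the same value can be read
without filtering the log output:

    build/mvn -q help:evaluate -Dexpression=scala.binary.version -DforceStdout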
--- .../integration-tests/dev/dev-run-integration-tests.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index ed33b6522620c..c1fc81b11d24b 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -32,11 +32,13 @@ INCLUDE_TAGS="k8s" EXCLUDE_TAGS= MVN="$TEST_ROOT_DIR/build/mvn" -SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ +export SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ | grep -v "INFO"\ | grep -v "WARNING"\ | tail -n 1) +echo $SCALA_VERSION + # Parse arguments while (( "$#" )); do case $1 in From 0ee3ae0fe6ce7a1e678731f5681cc530641ade48 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:29:48 -0700 Subject: [PATCH 181/194] Disable R build because it's not working and also set SPARK_SCALA_VERSION for downstream building with docker image and scala 2.12 --- .../scripts/setup-integration-test-env.sh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index ab0123c2c5f08..85724ca978fee 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -18,6 +18,7 @@ # set -x TEST_ROOT_DIR=$(git rev-parse --show-toplevel) +MVN="$TEST_ROOT_DIR/build/mvn" UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" DEPLOY_MODE="minikube" @@ -60,6 +61,15 @@ while (( "$#" )); do done rm -rf "$UNPACKED_SPARK_TGZ" + +MVN_SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) + +export SCALA_VERSION=${SCALA_VERSION:=$MVN_SCALA_VERSION} +export SPARK_SCALA_VERSION=${SPARK_SCALA_VERSION:=$SCALA_VERSION} + if [[ $SPARK_TGZ == "N/A" && $IMAGE_TAG == "N/A" ]]; then # If there is no spark image tag to test with and no src dir, build from current @@ -86,7 +96,7 @@ then LANGUAGE_BINDING_BUILD_ARGS="-p $DOCKER_FILE_BASE_PATH/bindings/python/Dockerfile" # Build SparkR image - LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $DOCKER_FILE_BASE_PATH/bindings/R/Dockerfile" + #LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $DOCKER_FILE_BASE_PATH/bindings/R/Dockerfile" case $DEPLOY_MODE in cloud) From cd0817545af59f2f866a3e0c4c1fe55ca1409a71 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:30:12 -0700 Subject: [PATCH 182/194] Add scala version as a param but I don't think we need this --- external/docker/spark-test/base/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index c70cd71367679..c730bdcc78cd0 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -25,7 +25,8 @@ RUN apt-get update && \ apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.11.8 
+ARG scala_version_buildtime=2.11.8 +ENV SCALA_VERSION $scala_version_buildtime ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark From 387035b49c3d08e59247a2d8d92836fb318b2d53 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:30:19 -0700 Subject: [PATCH 183/194] Revert "Add scala version as a param but I don't think we need this" This reverts commit cd0817545af59f2f866a3e0c4c1fe55ca1409a71. --- external/docker/spark-test/base/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index c730bdcc78cd0..c70cd71367679 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -25,8 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ARG scala_version_buildtime=2.11.8 -ENV SCALA_VERSION $scala_version_buildtime +ENV SCALA_VERSION 2.11.8 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark From f11d9f5c02f31046c757cc7be331d340466cb941 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:42:09 -0700 Subject: [PATCH 184/194] Lifecycle now running ok take out the bad lifecycle stage --- .../k8s/features/BasicExecutorFeatureStep.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index db49d884c8b5d..61f27dcd6914e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -195,17 +195,10 @@ private[spark] class BasicExecutorFeatureStep( case true => logInfo("Adding decommission script to lifecycle") new ContainerBuilder(containerWithLimitCores).withNewLifecycle() - .withNewPostStart() - .withNewExec() - .addToCommand("/bin/sh") - .addtoCommand("-c") - .addtoCommand("exit 1") - .endExec() - .endPostStart() .withNewPreStop() .withNewExec() .addToCommand("/opt/decom.sh") - .endExec() + .endExec() .endPreStop() .endLifecycle() .build() @@ -234,8 +227,8 @@ private[spark] class BasicExecutorFeatureStep( .endSpec() .build() - logInfo(s"Built container $containerWithLifeCycle") + logInfo(s"Built container $containerWithLifecycle") - SparkPod(executorPod, containerWithLifeCycle) + SparkPod(executorPod, containerWithLifecycle) } } From 727f76ed29172b9484cfcd06859bb9606ee2e02f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 17:42:43 -0700 Subject: [PATCH 185/194] decom script should be executable --- .../kubernetes/docker/src/main/dockerfiles/spark/decom.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh old mode 100644 new mode 100755 From b37892d200f5e4a04ba319ae42f009e8de9998c7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 18:08:51 -0700 Subject: [PATCH 186/194] Decom script is in the 
pod yaml, not seeing SIGTERM anymore in the log of executor. Attempt to use killall incase the grep approach wasn't working and add even more logging --- .../CoarseGrainedExecutorBackend.scala | 26 ++++++++++++------- .../org/apache/spark/util/SignalUtils.scala | 3 ++- .../src/main/dockerfiles/spark/decom.sh | 6 +++-- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 3fd2dad37eac4..90e256e356ca0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -186,17 +186,23 @@ private[spark] class CoarseGrainedExecutorBackend( } private def decommissionSelf(): Boolean = { - logDebug("Decommissioning self") - decommissioned = true - // Tell master we are are decommissioned so it stops trying to schedule us - if (driver.nonEmpty) { - driver.get.send(DecommissionExecutor(executorId)) - } - if (executor != null) { - executor.decommission() + logError("Decommissioning self") + try { + decommissioned = true + // Tell master we are are decommissioned so it stops trying to schedule us + if (driver.nonEmpty) { + driver.get.send(DecommissionExecutor(executorId)) + } + if (executor != null) { + executor.decommission() + } + // Return true since we are handling a signal + true + } catch { + case e: Exception => + logError(s"Error ${e} during attempt to decommission self") + false } - // Return true since we are handling a signal - true } } diff --git a/core/src/main/scala/org/apache/spark/util/SignalUtils.scala b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala index 5a24965170cef..1d507185802ce 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala @@ -60,10 +60,11 @@ private[spark] object SignalUtils extends Logging { if (SystemUtils.IS_OS_UNIX) { try { val handler = handlers.getOrElseUpdate(signal, { - logInfo("Registered signal handler for " + signal) + logInfo("Registering signal handler for " + signal) new ActionHandler(new Signal(signal)) }) handler.register(action) + logInfo("Registered signal handler for " + signal) } catch { case ex: Exception => logWarning(s"Failed to register signal handler for " + signal, ex) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index e071d40238f10..a609c8506d5c9 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -1,8 +1,10 @@ #!/usr/bin/env bash -set -ex +set -x echo "hi" echo "Decom adventures" > /dev/termination-log || echo "logging is hard" -WORKER_PID=$(ps axf | grep java |grep org.apache.spark | grep -v grep | awk '{print "kill -9 " $1}') +WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep | awk '{print "kill -9 " $1}') kill -s SIGPWR ${WORKER_PID} +killall -s SIGPWR java waitpid ${WORKER_PID} +sleep 30 From c1917b5f6be42d04f1f79618cff30b0906b4f7b1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 18:19:03 -0700 Subject: [PATCH 187/194] Try and log our exit process --- .../docker/src/main/dockerfiles/spark/decom.sh | 15 ++++++++++----- 1 file changed, 10 
insertions(+), 5 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index a609c8506d5c9..d59e9aa252358 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -1,10 +1,15 @@ #!/usr/bin/env bash set -x +LOG=/dev/termination-log echo "hi" -echo "Decom adventures" > /dev/termination-log || echo "logging is hard" +echo "Starting decom adventures" > ${LOG} || echo "logging is hard" +date | tee -a ${LOG} WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep | awk '{print "kill -9 " $1}') -kill -s SIGPWR ${WORKER_PID} -killall -s SIGPWR java -waitpid ${WORKER_PID} -sleep 30 +echo "Using worker pid $WORKER_PID" | tee -a ${LOG} +kill -s SIGPWR ${WORKER_PID} | tee -a ${LOG} +killall -s SIGPWR java | tee -a ${LOG} +waitpid ${WORKER_PID} | tee -a ${LOG} +sleep 30 | tee -a ${LOG} +echo "Done" | tee -a ${LOG} +date | tee -a ${LOG} From fb559fef73cec741a68bc36e6ede9eb2399953b9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 18:26:05 -0700 Subject: [PATCH 188/194] Fix printing the worker pid --- .../kubernetes/docker/src/main/dockerfiles/spark/decom.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index d59e9aa252358..42c39367639ae 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash set -x -LOG=/dev/termination-log +export LOG=/dev/termination-log echo "hi" echo "Starting decom adventures" > ${LOG} || echo "logging is hard" date | tee -a ${LOG} -WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep | awk '{print "kill -9 " $1}') +WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep | awk '{print $1}') echo "Using worker pid $WORKER_PID" | tee -a ${LOG} kill -s SIGPWR ${WORKER_PID} | tee -a ${LOG} killall -s SIGPWR java | tee -a ${LOG} From 0982c11ed8985898e15b4a5017b397dd629a42bb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Mar 2019 18:27:00 -0700 Subject: [PATCH 189/194] Wait that awk statement wasn't doing anything for me --- .../kubernetes/docker/src/main/dockerfiles/spark/decom.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index 42c39367639ae..0d9e79260ef58 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -5,7 +5,7 @@ export LOG=/dev/termination-log echo "hi" echo "Starting decom adventures" > ${LOG} || echo "logging is hard" date | tee -a ${LOG} -WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep | awk '{print $1}') +WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep) echo "Using worker pid $WORKER_PID" | tee -a ${LOG} kill -s SIGPWR ${WORKER_PID} | tee -a ${LOG} killall -s 
SIGPWR java | tee -a ${LOG} From 09a01cf49969c943b85303bd70b90fda14672ca8 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 May 2019 18:36:19 -0700 Subject: [PATCH 190/194] Fix minor style issues after merge --- .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 2 +- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 90e256e356ca0..d509cc6ca4b15 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -28,10 +28,10 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.internal.config import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging +import org.apache.spark.internal.config import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 6cf4f6db9f3be..739183491869d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -256,6 +256,7 @@ class KubernetesSuite extends SparkFunSuite } } + // scalastyle:off argcount protected def runSparkApplicationAndVerifyCompletion( appResource: String, mainClass: String, @@ -268,6 +269,7 @@ class KubernetesSuite extends SparkFunSuite pyFiles: Option[String] = None, interval: Option[PatienceConfiguration.Interval] = None, decomissioningTest: Boolean = false): Unit = { + // scalastyle:on argcount val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, @@ -312,7 +314,7 @@ class KubernetesSuite extends SparkFunSuite action match { case Action.MODIFIED => execPods(name) = resource - case Action.ADDED => + case Action.ADDED => println(s"Add event received for $name.") execPods(name) = resource // If testing decomissioning delete the first with a delay after it starts From 9a5000d494ee49da54f1a97a9738e017696811b7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 May 2019 18:45:30 -0700 Subject: [PATCH 191/194] Add license header to decom script --- .../docker/src/main/dockerfiles/spark/decom.sh | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index 0d9e79260ef58..8e3ccad1f9af8 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -1,5 +1,23 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + set -x export LOG=/dev/termination-log echo "hi" From e271a1d311fd40adeb818a2d7f5fa0115f3d947d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 May 2019 18:58:01 -0700 Subject: [PATCH 192/194] waitpid is the syscall wait is the shell command --- .../kubernetes/docker/src/main/dockerfiles/spark/decom.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index 8e3ccad1f9af8..35aa3a1f269a7 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -27,7 +27,7 @@ WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExe echo "Using worker pid $WORKER_PID" | tee -a ${LOG} kill -s SIGPWR ${WORKER_PID} | tee -a ${LOG} killall -s SIGPWR java | tee -a ${LOG} -waitpid ${WORKER_PID} | tee -a ${LOG} +wait ${WORKER_PID} | tee -a ${LOG} sleep 30 | tee -a ${LOG} echo "Done" | tee -a ${LOG} date | tee -a ${LOG} From 74007922073f8ed7fe9a43bf186aee6a12072687 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 May 2019 19:08:07 -0700 Subject: [PATCH 193/194] Start cleaning up the decom script, todo fix the PID extraction --- .../docker/src/main/dockerfiles/spark/decom.sh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh index 35aa3a1f269a7..0149c0e92cdbb 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -18,16 +18,22 @@ # -set -x +set -ex export LOG=/dev/termination-log echo "hi" echo "Starting decom adventures" > ${LOG} || echo "logging is hard" date | tee -a ${LOG} +# TODO(holden): Fix this PID extraction WORKER_PID=$(ps axf | grep java |grep org.apache.spark.executor.CoarseGrainedExecutorBackend | grep -v grep) echo "Using worker pid $WORKER_PID" | tee -a ${LOG} kill -s SIGPWR ${WORKER_PID} | tee -a ${LOG} -killall -s SIGPWR java | tee -a ${LOG} -wait ${WORKER_PID} | tee -a ${LOG} -sleep 30 | tee -a ${LOG} +echo "Waiting for worker pid to exit" +date +timeout 60 tail --pid=${WORKER_PID} -f /dev/null | tee -a ${LOG} +date +sleep 60 | tee -a ${LOG} +date echo "Done" | tee -a ${LOG} date | tee -a ${LOG} +echo "Term log was:" +cat $LOG From 55fa260da101ca35e07ea36c33553c0bf0cd3479 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 6 May 2019 19:08:24 -0700 Subject: [PATCH 194/194] Print out the termination log at the end as well --- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git 
a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 2097fb8865de9..c1208be984d67 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -95,3 +95,9 @@ esac # Execute the container CMD under tini for better hygiene exec /sbin/tini -s -- "${CMD[@]}" + +# Print out the termination log as we exit +sleep 1 +echo "Finished. Termination log:" +cat /dev/termination-log +sleep 1
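
A possible shape for the PID extraction fix flagged in the TODO of patch 193, sketched below against the same class name, signal, and termination-log path used in these patches. The use of pgrep and of a bounded wait are assumptions for illustration, not something this series ships:

    #!/usr/bin/env bash
    # Sketch only: assumes procps (pgrep) and coreutils (timeout, tail --pid)
    # are available in the container image, and one executor JVM per pod.
    set -ex
    export LOG=/dev/termination-log

    echo "Starting decom adventures" > "${LOG}" || echo "logging is hard"
    date | tee -a "${LOG}"

    # pgrep -f matches against the full command line, so it returns just the
    # executor JVM's PID instead of a whole ps output line, and it does not
    # match itself, which removes the need for the 'grep -v grep' step.
    WORKER_PID=$(pgrep -f org.apache.spark.executor.CoarseGrainedExecutorBackend)
    echo "Using worker pid ${WORKER_PID}" | tee -a "${LOG}"

    # Ask the executor to decommission itself, then give it up to 60 seconds
    # to exit before the pod is torn down.
    kill -s SIGPWR "${WORKER_PID}"
    timeout 60 tail --pid="${WORKER_PID}" -f /dev/null

    echo "Done" | tee -a "${LOG}"
    date | tee -a "${LOG}"
    echo "Term log was:"
    cat "${LOG}"

The trade-off is an extra dependency on procps in the image; the ps/grep pipeline used in the patches works with a more minimal toolset.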
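
One caveat on the entrypoint.sh hunk in patch 194: `exec` replaces the shell with tini, so once it succeeds the lines added after it never run. A minimal sketch of the same intent without the exec follows; the caveat in the comments is an assumption about how this image is run, not a tested change:

    # Run the container CMD under tini as a child process instead of exec'ing
    # it, so this shell survives long enough to dump the termination log.
    # Note: without exec the shell stays PID 1, which changes signal
    # forwarding and reaping behavior, so this shows the intent only and is
    # not a drop-in replacement.
    /sbin/tini -s -- "${CMD[@]}"
    EXIT_CODE=$?

    # Print out the termination log as we exit, then propagate the CMD's
    # exit code so Kubernetes still sees the real result.
    echo "Finished. Termination log:"
    cat /dev/termination-log
    exit "${EXIT_CODE}"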