5 changes: 5 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -916,6 +916,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
tlSession.remove()
}

// Attaches the given SQLSession to the current thread, replacing any existing session.
protected[sql] def setSession(session: SQLSession): Unit = {
detachSession()
tlSession.set(session)
}

protected[sql] class SQLSession {
// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
@@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
confOverlay: JMap[String, String],
async: Boolean): ExecuteStatementOperation = synchronized {

val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)(
hiveContext, sessionToActivePool)
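// Run in the background only when the client requested asynchronous execution and
// spark.sql.hive.thriftServer.async is enabled on the server.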
val runInBackground = async && hiveContext.hiveThriftServerAsync
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
runInBackground)(hiveContext, sessionToActivePool)
handleToOperation.put(operation.getHandle, operation)
logDebug(s"Created Operation for $statement with session=$parentSession, " +
s"runInBackground=$runInBackground")
operation
}
}
@@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.thriftserver

import java.io.File
import java.net.URL
import java.sql.{Date, DriverManager, Statement}
import java.nio.charset.StandardCharsets
import java.sql.{Date, DriverManager, SQLException, Statement}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.concurrent.{Await, Promise}
import scala.concurrent.{Await, Promise, future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.sys.process.{Process, ProcessLogger}
import scala.util.{Random, Try}

@@ -337,6 +339,42 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest
}
)
}

test("test jdbc cancel") {
withJdbcStatement { statement =>
val queries = Seq(
"DROP TABLE IF EXISTS test_map",
"CREATE TABLE test_map(key INT, value STRING)",
s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")

queries.foreach(statement.execute)

val largeJoin = "SELECT COUNT(*) FROM test_map " +
List.fill(10)("join test_map").mkString(" ")
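// Cancel the long-running cross join from another thread shortly after it starts.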
val f = future { Thread.sleep(100); statement.cancel(); }
val e = intercept[SQLException] {
statement.executeQuery(largeJoin)
}
assert(e.getMessage contains "cancelled")
Await.result(f, 3.minute)

// with asynchronous execution disabled, cancel is effectively a no-op and the query completes
statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
val sf = future { Thread.sleep(100); statement.cancel(); }
val smallJoin = "SELECT COUNT(*) FROM test_map " +
List.fill(4)("join test_map").mkString(" ")
val rs1 = statement.executeQuery(smallJoin)
Await.result(sf, 3.minute)
rs1.next()
assert(rs1.getInt(1) === math.pow(5, 5))
rs1.close()

val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
rs2.next()
assert(rs2.getInt(1) === 5)
rs2.close()
}
}
}

class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
@@ -17,6 +17,7 @@

package org.apache.spark.sql.hive.thriftserver

import java.security.PrivilegedExceptionAction
import java.sql.{Date, Timestamp}
import java.util.concurrent.Executors
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID}
@@ -29,8 +30,15 @@ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, Map => SMap}
import java.util.concurrent.RejectedExecutionException
import scala.util.control.NonFatal

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.metadata.HiveException
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.shims.ShimLoader
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.ExecuteStatementOperation
@@ -73,19 +81,21 @@ private[hive] class SparkExecuteStatementOperation(
parentSession: HiveSession,
statement: String,
confOverlay: JMap[String, String],
runInBackground: Boolean = true)(
hiveContext: HiveContext,
sessionToActivePool: SMap[SessionHandle, String])
// NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {
runInBackground: Boolean = true)
(hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String])
extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)
with Logging {

private var result: DataFrame = _
private var iter: Iterator[SparkRow] = _
private var dataTypes: Array[DataType] = _
private var statementId: String = _

def close(): Unit = {
// RDDs will be cleaned automatically upon garbage collection.
logDebug("CLOSING")
hiveContext.sparkContext.clearJobGroup()
logDebug(s"CLOSING $statementId")
cleanup(OperationState.CLOSED)
}

def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) {
@@ -149,20 +159,84 @@ private[hive] class SparkExecuteStatementOperation(
}

def getResultSetSchema: TableSchema = {
logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}")
if (result.queryExecution.analyzed.output.size == 0) {
if (result == null || result.queryExecution.analyzed.output.size == 0) {
new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
} else {
logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}")
val schema = result.queryExecution.analyzed.output.map { attr =>
new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
}
new TableSchema(schema)
}
}

def run(): Unit = {
val statementId = UUID.randomUUID().toString
logInfo(s"Running query '$statement'")
override def run(): Unit = {
setState(OperationState.PENDING)
setHasResultSet(true) // ensure a result set is reported even for asynchronous runs

if (!runInBackground) {
runInternal()
} else {
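// Capture the Hive session state, operation conf, UGI, and current SQL session on the
// calling thread so they can be re-installed on the background thread before the query runs.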
val parentSessionState = SessionState.get()
val hiveConf = getConfigForOperation()
val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf)
val sessionHive = getCurrentHive()
val currentSqlSession = hiveContext.currentSession

// Runnable impl to call runInternal asynchronously,
// from a different thread
val backgroundOperation = new Runnable() {

override def run(): Unit = {
val doAsAction = new PrivilegedExceptionAction[Object]() {
override def run(): Object = {

// User information is part of the metastore client member in Hive
hiveContext.setSession(currentSqlSession)
Hive.set(sessionHive)
SessionState.setCurrentSessionState(parentSessionState)
try {
runInternal()
} catch {
case e: HiveSQLException =>
setOperationException(e)
log.error("Error running hive query: ", e)
}
return null
}
}

try {
ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction)
} catch {
case e: Exception =>
setOperationException(new HiveSQLException(e))
logError("Error running hive query as user : " +
sparkServiceUGI.getShortUserName(), e)
}
}
}
try {
// This submit blocks if no background threads are available to run this operation
val backgroundHandle =
getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation)
setBackgroundHandle(backgroundHandle)
} catch {
case rejected: RejectedExecutionException =>
setState(OperationState.ERROR)
throw new HiveSQLException("The background threadpool cannot accept" +
" new task for execution, please retry the operation", rejected)
case NonFatal(e) =>
logError(s"Error executing query in background", e)
setState(OperationState.ERROR)
throw e
}
}
}

// Runs the statement on the calling thread: invoked directly for synchronous execution,
// or from the background Runnable when running asynchronously.
private def runInternal(): Unit = {
statementId = UUID.randomUUID().toString
logInfo(s"Running query '$statement' with $statementId")
setState(OperationState.RUNNING)
HiveThriftServer2.listener.onStatementStart(
statementId,
@@ -194,20 +268,84 @@ private[hive] class SparkExecuteStatementOperation(
}
}
dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
setHasResultSet(true)
} catch {
case e: HiveSQLException =>
if (getStatus().getState() == OperationState.CANCELED) {
return
} else {
setState(OperationState.ERROR);
throw e
}
// Actually do need to catch Throwable as some failures don't inherit from Exception and
// HiveServer will silently swallow them.
case e: Throwable =>
val currentState = getStatus().getState()
logError(s"Error executing query, currentState $currentState, ", e)
setState(OperationState.ERROR)
HiveThriftServer2.listener.onStatementError(
statementId, e.getMessage, e.getStackTraceString)
logError("Error executing query:", e)
throw new HiveSQLException(e.toString)
}
setState(OperationState.FINISHED)
HiveThriftServer2.listener.onStatementFinish(statementId)
}

override def cancel(): Unit = {
logInfo(s"Cancel '$statement' with $statementId")
if (statementId != null) {
hiveContext.sparkContext.cancelJobGroup(statementId)
}
cleanup(OperationState.CANCELED)
}

// Moves the operation to the given terminal state and, when running in the background,
// interrupts the submitted background task if there is one.
private def cleanup(state: OperationState) {
setState(state)
if (runInBackground) {
val backgroundHandle = getBackgroundHandle()
if (backgroundHandle != null) {
backgroundHandle.cancel(true)
}
}
}

/**
* If there are query-specific settings to overlay, create a copy of the config.
* There are two cases in which the session config passed to the Hive driver must be cloned:
* 1. Async query -
* if the client changes a config setting, that change should not affect an execution
* that is already underway
* 2. confOverlay -
* query-specific settings should only be applied to the query config, not the session
* @return new configuration
* @throws HiveSQLException
*/
private def getConfigForOperation(): HiveConf = {
var sqlOperationConf = getParentSession().getHiveConf()
if (!getConfOverlay().isEmpty() || runInBackground) {
// clone the parent session config for this query
sqlOperationConf = new HiveConf(sqlOperationConf)

// apply query-specific overlay settings, if any
getConfOverlay().foreach { case (k, v) =>
try {
sqlOperationConf.verifyAndSet(k, v)
} catch {
case e: IllegalArgumentException =>
throw new HiveSQLException("Error applying statement specific settings", e)
}
}
}
return sqlOperationConf
}

private def getCurrentHive(): Hive = {
try {
return Hive.get()
} catch {
case e: HiveException =>
throw new HiveSQLException("Failed to get current Hive object", e);
}
}
}

private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
@@ -147,6 +147,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
getConf("spark.sql.hive.metastore.barrierPrefixes", "")
.split(",").filterNot(_ == "")

/*
* The Hive Thrift server uses a background Spark SQL thread pool to execute SQL queries.
*/
protected[hive] def hiveThriftServerAsync: Boolean =
getConf("spark.sql.hive.thriftServer.async", "true").toBoolean
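// As exercised in HiveThriftBinaryServerSuite, a JDBC client can switch off background
// execution with: SET spark.sql.hive.thriftServer.async=false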

@transient
protected[sql] lazy val substitutor = new VariableSubstitution()
