Commit dd37529
[SPARK-24250][SQL] support accessing SQLConf inside tasks
## What changes were proposed in this pull request?

Previously in apache#20136 we decided to forbid tasks from accessing `SQLConf`, because it doesn't work and always gives you the default conf value. In apache#21190 we fixed the check and all the places that violated it.

Currently the pattern for accessing configs at the executor side is: read the configs at the driver side, then reference the variables holding the config values in the RDD closure, so that they are serialized to the executor side. Something like

```
val someConf = conf.getXXX
child.execute().mapPartitions {
  if (someConf == ...) ...
  ...
}
```

However, this pattern is hard to apply if the config needs to be propagated through a long call stack. An example is `DataType.sameType`; see how many changes were made in apache#21190. When it comes to code generation, it's even worse: I tried it locally and we would need to change a ton of files to propagate configs to the code generators.

This PR proposes to allow tasks to access `SQLConf`. The idea is: save all the SQL configs to the job's local properties when a SQL execution is triggered. At the executor side we rebuild the `SQLConf` from those job properties.

## How was this patch tested?

a new test suite

Author: Wenchen Fan <[email protected]>

Closes apache#21299 from cloud-fan/config.
1 parent 434d74e commit dd37529
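
To illustrate the new capability (a hedged sketch, not part of the diff; it assumes a running `SparkSession` named `spark` and an illustrative config read): executor-side code can now call `SQLConf.get` directly inside a task instead of capturing driver-side config values in the closure, because the SQL configs are rebuilt from the job's local properties on the executor.

```scala
// Minimal sketch of the executor-side access pattern enabled by this commit.
// Assumes a running SparkSession named `spark`.
import spark.implicits._
import org.apache.spark.sql.internal.SQLConf

val partitionsSeenByTasks = spark.range(8).mapPartitions { _ =>
  // SQLConf.get now works inside a task: it returns a read-only conf rebuilt
  // from the local properties that the driver propagated with the job.
  Iterator(SQLConf.get.numShufflePartitions)
}.collect()
```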

File tree: 9 files changed, +210 -36 lines


core/src/main/scala/org/apache/spark/TaskContextImpl.scala (2 additions & 0 deletions)

```diff
@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(
 
   private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
+  // TODO: shall we publish it and define it in `TaskContext`?
+  private[spark] def getLocalProperties(): Properties = localProperties
 }
```

org/apache/spark/sql/internal/ReadOnlySQLConf.scala (new file: 66 additions & 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.internal

import java.util.{Map => JMap}

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}

/**
 * A readonly SQLConf that will be created by tasks running at the executor side. It reads the
 * configs from the local properties which are propagated from driver to executors.
 */
class ReadOnlySQLConf(context: TaskContext) extends SQLConf {

  @transient override val settings: JMap[String, String] = {
    context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
  }

  @transient override protected val reader: ConfigReader = {
    new ConfigReader(new TaskContextConfigProvider(context))
  }

  override protected def setConfWithCheck(key: String, value: String): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def unsetConf(key: String): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def unsetConf(entry: ConfigEntry[_]): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def clear(): Unit = {
    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
  }

  override def clone(): SQLConf = {
    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
  }

  override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
  }
}

class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
  override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
}
```
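
For orientation, a hedged usage sketch (not part of the diff; assumes a running `SparkSession` named `spark`): inside a task, `SQLConf.get` returns a `ReadOnlySQLConf`, reads are served by `TaskContextConfigProvider` from the task's local properties, and any mutation attempt is rejected.

```scala
// Illustrative sketch of ReadOnlySQLConf behavior on executors.
import spark.implicits._
import org.apache.spark.sql.internal.{ReadOnlySQLConf, SQLConf}

val checks = spark.range(4).mapPartitions { _ =>
  val conf = SQLConf.get                                   // ReadOnlySQLConf inside a task
  val isReadOnly = conf.isInstanceOf[ReadOnlySQLConf]
  val mutationRejected =
    try { conf.setConfString("spark.sql.x", "a"); false }  // setConfWithCheck throws here
    catch { case _: UnsupportedOperationException => true }
  Iterator(isReadOnly && mutationRejected)
}.collect()

assert(checks.forall(identity))
```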

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (10 additions & 11 deletions)

```diff
@@ -27,13 +27,12 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
-import org.apache.spark.util.Utils
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -107,7 +106,13 @@ object SQLConf {
    * run tests in parallel. At the time this feature was implemented, this was a no-op since we
    * run unit tests (that does not involve SparkSession) in serial order.
    */
-  def get: SQLConf = confGetter.get()()
+  def get: SQLConf = {
+    if (TaskContext.get != null) {
+      new ReadOnlySQLConf(TaskContext.get())
+    } else {
+      confGetter.get()()
+    }
+  }
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1292,17 +1297,11 @@ object SQLConf {
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
-  if (Utils.isTesting && SparkEnv.get != null) {
-    // assert that we're only accessing it on the driver.
-    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
-      "SQLConf should only be created and accessed on the driver.")
-  }
-
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())
 
-  @transient private val reader = new ConfigReader(settings)
+  @transient protected val reader = new ConfigReader(settings)
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
@@ -1765,7 +1764,7 @@ class SQLConf extends Serializable with Logging {
     settings.containsKey(key)
   }
 
-  private def setConfWithCheck(key: String, value: String): Unit = {
+  protected def setConfWithCheck(key: String, value: String): Unit = {
     settings.put(key, value)
   }
 
```
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (18 additions & 3 deletions)

```diff
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
+     assertOnDriver()
      // Get the session from current thread's active session.
      var session = activeThreadSession.get()
      if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
    *
    * @since 2.2.0
    */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+  def getActiveSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(activeThreadSession.get)
+  }
 
   /**
    * Returns the default SparkSession that is returned by the builder.
    *
    * @since 2.2.0
    */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+  def getDefaultSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(defaultSession.get)
+  }
 
   /**
    * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
     }
   }
 
+  private def assertOnDriver(): Unit = {
+    if (Utils.isTesting && TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException(
+        "SparkSession should only be created and accessed on the driver.")
+    }
+  }
+
   /**
    * Helper method to create an instance of `SessionState` based on `className` from conf.
    * The result is either `SessionState` or a Hive based `SessionState`.
```
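
A hedged illustration of the new `assertOnDriver` guard (it is only enforced when `Utils.isTesting` is true): looking up a `SparkSession` from inside a task now fails fast instead of quietly handing back a driver-side object. The snippet below is illustrative only and assumes a running `SparkSession` named `spark` in a testing environment.

```scala
// Illustrative only; assumes `spark` is a running SparkSession and Utils.isTesting is true.
import org.apache.spark.sql.SparkSession

spark.range(1).rdd.foreach { _ =>
  // Runs on an executor: TaskContext.get != null, so assertOnDriver() throws
  // IllegalStateException("SparkSession should only be created and accessed on the driver.")
  SparkSession.getActiveSession
}
```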

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala (38 additions & 12 deletions)

```diff
@@ -68,16 +68,18 @@ object SQLExecution {
       // sparkContext.getCallSite() would first try to pick up any call site that was previously
       // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
       // streaming queries would give us call site like "run at <unknown>:0"
-      val callSite = sparkSession.sparkContext.getCallSite()
+      val callSite = sc.getCallSite()
 
-      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-      try {
-        body
-      } finally {
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-          executionId, System.currentTimeMillis()))
+      withSQLConfPropagated(sparkSession) {
+        sc.listenerBus.post(SparkListenerSQLExecutionStart(
+          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+        try {
+          body
+        } finally {
+          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
+            executionId, System.currentTimeMillis()))
+        }
       }
     } finally {
       executionIdToQueryExecution.remove(executionId)
@@ -90,13 +92,37 @@
    * thread from the original one, this method can be used to connect the Spark jobs in this action
    * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
    */
-  def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+  def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+    val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    withSQLConfPropagated(sparkSession) {
+      try {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+        body
+      } finally {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      }
+    }
+  }
+
+  def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    // Set all the specified SQL configs to local properties, so that they can be available at
+    // the executor side.
+    val allConfigs = sparkSession.sessionState.conf.getAllConfs
+    val originalLocalProps = allConfigs.collect {
+      case (key, value) if key.startsWith("spark") =>
+        val originalValue = sc.getLocalProperty(key)
+        sc.setLocalProperty(key, value)
+        (key, originalValue)
+    }
+
     try {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
       body
     } finally {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      for ((key, value) <- originalLocalProps) {
+        sc.setLocalProperty(key, value)
+      }
     }
   }
 }
```
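
A minimal sketch of the propagation mechanism `withSQLConfPropagated` introduces (not part of the diff; assumes a running `SparkSession` named `spark`): while the body runs, every explicitly-set, `spark`-prefixed SQL conf of the session is mirrored into `SparkContext` local properties, which Spark already ships to executors with each job, and the executor-side `ReadOnlySQLConf` reads them back through `TaskContext.getLocalProperty`.

```scala
// Illustrative sketch of how SQL configs travel to executors as job-local properties.
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.SQLExecution

spark.conf.set("spark.sql.shuffle.partitions", "5")   // any explicitly-set "spark.sql.*" conf

SQLExecution.withSQLConfPropagated(spark) {
  spark.sparkContext.parallelize(1 to 2, 2).foreach { _ =>
    // On the executor, the conf set above is visible as a task local property.
    val v = TaskContext.get().getLocalProperty("spark.sql.shuffle.partitions")
    assert(v == "5")
  }
}
// On exit, the original local-property values are restored on the driver.
```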

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala (1 addition & 1 deletion)

```diff
@@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
         val rows: Array[InternalRow] = child.executeCollect()
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala (8 additions & 8 deletions)

```diff
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource {
       CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
     }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
 
-    JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    SQLExecution.withSQLConfPropagated(json.sparkSession) {
+      JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    }
   }
 
   private def createBaseDataset(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Dataset[String] = {
-    val paths = inputPaths.map(_.getPath.toString)
-    val textOptions = Map.empty[String, String] ++
-      parsedOptions.encoding.map("encoding" -> _) ++
-      parsedOptions.lineSeparator.map("lineSep" -> _)
-
     sparkSession.baseRelationToDataFrame(
       DataSource.apply(
         sparkSession,
-        paths = paths,
+        paths = inputPaths.map(_.getPath.toString),
         className = classOf[TextFileFormat].getName,
         options = parsedOptions.parameters
       ).resolveRelation(checkFilesExist = false))
@@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
       .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
       .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
-    JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    SQLExecution.withSQLConfPropagated(sparkSession) {
+      JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    }
   }
 
   private def createBaseRdd(
```

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala (1 addition & 1 deletion)

```diff
@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         try {
           val beforeCollect = System.nanoTime()
           // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
```

org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala (new file: 66 additions & 0 deletions)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.internal

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SQLTestUtils

class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
  import testImplicits._

  protected var spark: SparkSession = null

  // Create a new [[SparkSession]] running in local-cluster mode.
  override def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder()
      .master("local-cluster[2,1,1024]")
      .appName("testing")
      .getOrCreate()
  }

  override def afterAll(): Unit = {
    spark.stop()
    spark = null
  }

  test("ReadOnlySQLConf is correctly created at the executor side") {
    SQLConf.get.setConfString("spark.sql.x", "a")
    try {
      val checks = spark.range(10).mapPartitions { it =>
        val conf = SQLConf.get
        Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a")
      }.collect()
      assert(checks.forall(_ == true))
    } finally {
      SQLConf.get.unsetConf("spark.sql.x")
    }
  }

  test("case-sensitive config should work for json schema inference") {
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
      withTempPath { path =>
        val pathString = path.getCanonicalPath
        spark.range(10).select('id.as("ID")).write.json(pathString)
        spark.range(10).write.mode("append").json(pathString)
        assert(spark.read.json(pathString).columns.toSet == Set("id", "ID"))
      }
    }
  }
}
```
